In [None]:
!gdown -O  t5_que_gen.zip --id 1vhsDOW9wUUO83IQasTPlkxb82yxmMH-V
!unzip t5_que_gen.zip

Downloading...
From: https://drive.google.com/uc?id=1vhsDOW9wUUO83IQasTPlkxb82yxmMH-V
To: /content/t5_que_gen.zip
1.65GB [00:23, 71.0MB/s]
Archive:  t5_que_gen.zip
   creating: t5_que_gen_model/
   creating: t5_que_gen_model/t5_base_tok_que_gen/
  inflating: t5_que_gen_model/t5_base_tok_que_gen/spiece.model  
 extracting: t5_que_gen_model/t5_base_tok_que_gen/added_tokens.json  
 extracting: t5_que_gen_model/t5_base_tok_que_gen/tokenizer_config.json  
  inflating: t5_que_gen_model/t5_base_tok_que_gen/special_tokens_map.json  
   creating: t5_que_gen_model/t5_base_que_gen/
  inflating: t5_que_gen_model/t5_base_que_gen/config.json  
  inflating: t5_que_gen_model/t5_base_que_gen/pytorch_model.bin  
 extracting: t5_que_gen_model/logs.zip  
   creating: t5_ans_gen_model/
   creating: t5_ans_gen_model/t5_base_tok_ans_gen/
  inflating: t5_ans_gen_model/t5_base_tok_ans_gen/spiece.model  
  inflating: t5_ans_gen_model/t5_base_tok_ans_gen/added_tokens.json  
 extracting: t5_ans_gen_model/t5_base_

In [None]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/12/b5/ac41e3e95205ebf53439e4dd087c58e9fd371fd8e3724f2b9b4cdb8282e5/transformers-2.10.0-py3-none-any.whl (660kB)
[K     |▌                               | 10kB 16.5MB/s eta 0:00:01[K     |█                               | 20kB 4.5MB/s eta 0:00:01[K     |█▌                              | 30kB 6.2MB/s eta 0:00:01[K     |██                              | 40kB 7.9MB/s eta 0:00:01[K     |██▌                             | 51kB 5.0MB/s eta 0:00:01[K     |███                             | 61kB 5.9MB/s eta 0:00:01[K     |███▌                            | 71kB 6.6MB/s eta 0:00:01[K     |████                            | 81kB 7.4MB/s eta 0:00:01[K     |████▌                           | 92kB 6.0MB/s eta 0:00:01[K     |█████                           | 102kB 6.5MB/s eta 0:00:01[K     |█████▌                          | 112kB 6.5MB/s eta 0:00:01[K     |██████                          | 122kB 6.5

In [None]:
import argparse
import glob
import os
import json
import time
import logging
import random
from itertools import chain
from string import punctuation

import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
class QueGenerator():
  """Generate answer/question pairs from a passage using two T5 models.

  The pipeline has two stages:

  1. The answer model reads the passage once per sentence, with that sentence
     wrapped in ``[HL]`` markers, and emits candidate answers separated by
     ``[SEP]``.
  2. The question model receives ``"<answer> [SEP] <passage>"`` for every
     extracted answer and emits one question per answer.
  """

  def __init__(self):
    """Load both models and tokenizers from the local checkpoint folders
    and move the models to the GPU when one is available, else the CPU."""
    self.que_model = T5ForConditionalGeneration.from_pretrained('./t5_que_gen_model/t5_base_que_gen/')
    self.ans_model = T5ForConditionalGeneration.from_pretrained('./t5_ans_gen_model/t5_base_ans_gen/')

    self.que_tokenizer = T5Tokenizer.from_pretrained('./t5_que_gen_model/t5_base_tok_que_gen/')
    self.ans_tokenizer = T5Tokenizer.from_pretrained('./t5_ans_gen_model/t5_base_tok_ans_gen/')

    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    self.que_model = self.que_model.to(self.device)
    self.ans_model = self.ans_model.to(self.device)

  def generate(self, text):
    """Return answer/question pairs for ``text``.

    Args:
      text: The passage to generate questions about.

    Returns:
      A list of ``{'answer': str, 'question': str}`` dicts, one per extracted
      answer. The list is empty when ``text`` is empty or whitespace-only, or
      when no answer could be extracted.
    """
    # Nothing to tokenize: bail out before touching the tokenizers/models,
    # which are not meant to be called with an empty batch.
    if not text or not text.strip():
      return []

    answers = self._get_answers(text)
    if not answers:
      return []

    questions = self._get_questions(text, answers)
    return [{'answer': ans, 'question': que} for ans, que in zip(answers, questions)]

  def _get_answers(self, text):
    """Extract candidate answers from ``text``.

    Builds one model input per sentence (that sentence highlighted with
    ``[HL]`` markers inside the full passage), runs the answer model and
    splits each decoded output on ``[SEP]``.

    Returns:
      A list of non-empty, whitespace-stripped answer strings.
    """
    # split into sentences
    sents = sent_tokenize(text)
    if not sents:
      return []

    examples = []
    for i in range(len(sents)):
      # Whole passage, with only the i-th sentence wrapped in highlight markers.
      marked = ["[HL] %s [HL]" % sent if i == j else sent
                for j, sent in enumerate(sents)]
      # The end-of-sequence marker is appended by hand here.
      examples.append(" ".join(marked).strip() + " </s>")

    batch = self.ans_tokenizer.batch_encode_plus(examples, max_length=512, pad_to_max_length=True, return_tensors="pt")
    with torch.no_grad():
      outs = self.ans_model.generate(input_ids=batch['input_ids'].to(self.device),
                                     attention_mask=batch['attention_mask'].to(self.device),
                                     max_length=32)
    # Special tokens are kept so the "[SEP]" separators survive decoding.
    dec = [self.ans_tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]

    answers = []
    for item in dec:
      for piece in item.split('[SEP]'):
        # Strip first, then drop empties: splitting "foo [SEP]" leaves a
        # trailing '' (or runs of spaces) that must not become an answer.
        piece = piece.strip()
        if piece:
          answers.append(piece)
    return answers

  def _get_questions(self, text, answers):
    """Generate one question per answer.

    Args:
      text: The full passage the answers were extracted from.
      answers: Iterable of answer strings.

    Returns:
      A list of decoded question strings, in the same order as ``answers``;
      empty when ``answers`` is empty.
    """
    examples = ["%s [SEP] %s </s>" % (ans, text) for ans in answers]
    if not examples:
      return []

    batch = self.que_tokenizer.batch_encode_plus(examples, max_length=512, pad_to_max_length=True, return_tensors="pt")
    with torch.no_grad():
      # Beam search with 4 beams; generated questions are capped at 32 tokens.
      outs = self.que_model.generate(input_ids=batch['input_ids'].to(self.device),
                                     attention_mask=batch['attention_mask'].to(self.device),
                                     max_length=32,
                                     num_beams=4)
    dec = [self.que_tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
    return dec

In [None]:
# Build the generator once and reuse it: the constructor loads both T5 models
# and their tokenizers from the unzipped folders and moves the models to the
# selected device.
que_generator = QueGenerator()

In [None]:
# Sample passage about the Python language.
text = (
    "Python is an interpreted, high-level, general-purpose programming language. "
    "Created by Guido van Rossum and first released in 1991, Python's design "
    "philosophy emphasizes code readability with its notable use of significant "
    "whitespace."
)

# Longer sample passage about gravity (no trailing period, as in the source text).
text2 = (
    "Gravity (from Latin gravitas, meaning 'weight'), or gravitation, is a natural "
    "phenomenon by which all things with mass or energy—including planets, stars, "
    "galaxies, and even light—are brought toward (or gravitate toward) one another. "
    "On Earth, gravity gives weight to physical objects, and the Moon's gravity "
    "causes the ocean tides. "
    "The gravitational attraction of the original gaseous matter present in the "
    "Universe caused it to begin coalescing and forming stars and caused the stars "
    "to group together into galaxies, so gravity is responsible for many of the "
    "large-scale structures in the Universe. "
    "Gravity has an infinite range, although its effects become increasingly weaker "
    "as objects get further away"
)

In [None]:
# Run the pipeline on the Python passage; the cell displays the returned list
# of {'answer': ..., 'question': ...} dicts.
que_generator.generate(text)

[{'answer': 'Python',
  'question': 'What is the name of the interpreted, high-level, general-purpose programming language?'},
 {'answer': 'Guido van Rossum', 'question': 'Who created Python?'},
 {'answer': '1991', 'question': 'When was Python released?'}]

In [None]:
# Same pipeline on the longer gravity passage; the cell displays the returned
# list of {'answer': ..., 'question': ...} dicts.
que_generator.generate(text2)

[{'answer': 'weight', 'question': 'What does gravitas mean in English?'},
 {'answer': 'Earth',
  'question': 'On what planet does gravity give weight to physical objects?'},
 {'answer': 'galaxies', 'question': 'What do the stars form together into?'},
 {'answer': 'weaker',
  'question': "What do gravity's effects become as objects get further away?"}]

In [None]:
# Third sample passage (dentistry). The variable name `tetx` is kept as-is
# because the following cell refers to it by that name.
tetx = (
    "A dentist, also known as a dental surgeon, is a surgeon who specializes in "
    "dentistry, the diagnosis, prevention, and treatment of diseases and conditions "
    "of the oral cavity. "
    "The dentist's supporting team aids in providing oral health services. "
    "The dental team includes dental assistants, dental hygienists, dental "
    "technicians, and sometimes dental therapists."
)

In [None]:
# Run the pipeline on the dentistry passage (stored in `tetx`); the cell
# displays the returned list of {'answer': ..., 'question': ...} dicts.
que_generator.generate(tetx)

[{'answer': 'dental surgeon',
  'question': 'What is another name for a dentist?'},
 {'answer': "The dentist's supporting team",
  'question': 'Who provides oral health services?'},
 {'answer': 'dental assistants, dental hygienists, dental technicians, and sometimes dental therapists',
  'question': 'What is a dental team comprised of?'}]