## Pipeline Quiz Generator (Separate Quiz and Distractor Approach)

Description: Quiz Generator with separate pipeline for quiz generation and then distractor generator

### Step 1 : SciQ Loading

Load dataset

In [140]:
from datasets import load_dataset

sciq_dataset = load_dataset("allenai/sciq")
sciq_dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 11679
    })
    validation: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 1000
    })
})

Sample Data:

In [141]:
sciq_dataset["train"][27]

{'question': 'A small scale version of what type of map displays individual rock units?',
 'distractor3': 'polar map',
 'distractor1': 'seismic map',
 'distractor2': 'geographic map',
 'correct_answer': 'geologic map',
 'support': 'Geologic maps display rock units and geologic features. A small scale map displays individual rock units while a large scale map shows geologic provinces.'}

Drop every data with empty support. 

In [142]:
filtered_sciq = sciq_dataset.filter(lambda example: example["support"] != '')
filtered_sciq

DatasetDict({
    train: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 10481
    })
    validation: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 887
    })
    test: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 884
    })
})

Check for support with longer than 512 tokens/words (Maximum token of T5).

In [143]:
import pandas as pd
df_train = pd.DataFrame(sciq_dataset["train"])
df_train.head()

Unnamed: 0,question,distractor3,distractor1,distractor2,correct_answer,support
0,What type of organism is commonly used in prep...,viruses,protozoa,gymnosperms,mesophilic organisms,"Mesophiles grow best in moderate temperature, ..."
1,What phenomenon makes global winds blow northe...,tropical effect,muon effect,centrifugal effect,coriolis effect,Without Coriolis Effect the global winds would...
2,Changes from a less-ordered state to a more-or...,endothermic,unbalanced,reactive,exothermic,Summary Changes of state are examples of phase...
3,What is the least dangerous radioactive decay?,zeta decay,beta decay,gamma decay,alpha decay,All radioactive decay is dangerous to living t...
4,Kilauea in hawaii is the world’s most continuo...,magma,greenhouse gases,carbon and smog,smoke and ash,Example 3.5 Calculating Projectile Motion: Hot...


In [144]:
print(df_train['support'].str.len().max())

3559


In [145]:
print(df_train['question'].str.len().max())

399


In [15]:
test = filtered_sciq.filter(lambda example: len(example["support"]) > 3000) 
test

DatasetDict({
    train: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 10
    })
    validation: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 0
    })
    test: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 0
    })
})

In [28]:
test_data = test['train'][7]
test_data['support']

'membrane gradients was known, Mitchell proposed that energy captured through the absorption of light (by phototrophs) or the breakdown of molecules into more stable molecules (by various types of chemotrophs) relied on the same basic (homologous) mechanism, namely the generation of H+ gradients across membranes (the plasma membrane in prokaryotes or the internal membranes of mitochondria or chloroplasts (intracellular organelles, derived from bacteria – see below) in eukaryotes. What makes us think that these processes might have a similar evolutionary root, that they are homologous? Basically, it is the observation that in both light- and chemical-based processes captured energy is transferred through the movement of electrons through a membrane-embedded “electron transport chain”. An electron transport chain involves a series of membrane and associated proteins and a series of reduction-oxidation or redox reactions (see below) during which electrons move from a high energy donor to 

We can see above that support is long but only a few sentences is relevant, we cannot do raw summarization, we have to extract text based on keywords which are answers (distractors and keywords from questions too!). If we left this, support and answer will be truncated. If we summarize it raw, we lose important info of what is asked.

Extractive Summarization based on answer and questions

In [23]:
import yake

def extract_question(question):
    kw_extractor = yake.KeywordExtractor(top=10, stopwords=None)
    keywords = kw_extractor.extract_keywords(question)
    return [keyword for keyword, score in keywords]

In [24]:
import spacy

nlp = spacy.load("en_core_web_sm")

def clean_text(text):
    lemmatizer = nlp.get_pipe("lemmatizer")
    doc = nlp(text)
    tokens = [tok for tok in doc]
    lemmas = [tok.lemma_ for tok in tokens]
    return ' '.join(lemmas)

In [25]:
def score_sentence(sentence, words):
    score = 0
    clean_sentences = clean_text(sentence.lower())
    for word in words:
        if clean_text(word.lower()) in clean_sentences:
            score += 1
    return score

In [88]:
def summarize_support(example, max_words=256):
    text = example["support"]
    words = extract_question(example["question"])
    words.extend([test_data["correct_answer"]])

    scored_sentences = (
        (i, sentence, score_sentence(sentence, words))
        for i, sentence in enumerate(text.split("."))
        if any(clean_text(w.lower()) in clean_text(sentence.lower()) for w in words)
    )
    ranked_sentences = sorted(scored_sentences, key=lambda x: x[2], reverse=True)

    sentence_in_summary = []
    sum_of_words = 0
    for order, sentence, _ in ranked_sentences:
        num_of_words = len(sentence.split())
        if sum_of_words + num_of_words < max_words:
            sentence_in_summary.append((order, sentence))
            sum_of_words += num_of_words 

    summary = sorted(sentence_in_summary, key=lambda x: x[1])
    return ".".join(sent for _, sent in summary)


summarize_support(test_data)

' ) The major pigment in this system, chlorophyll, is based on a complex molecule, a porphyrin (see above) and it is primarily these pigments that give plants their green color. At this point, we consider only one aspect of this photosynthetic system, known as the oxygenic or non-cyclic system (look to more advanced classes for more details. Chlorophyll is synthesized by a conserved biosynthetic pathway that is also used to synthesize heme, which is found in the hemoglobin of animals and in the cytochromes, within the electron transport chain present in both plants and animals (which. For simplicity’s sake we will describe the photosynthetic system of cyanobacterium; the system in eukaryotic algae and plants, while more complex, follows the same basic logic. In all of these organisms, their photosynthetic systems appear to be homologous, that is derived from a common ancestor, a topic we will return to later in this chapter. Oxygenic photosynthesis \u2028 Compared to the salt loving ar

In [93]:
def generate_context(example, max_token_size=256):
    answer_size = len(example['correct_answer'].split())
    support_size = len(example['support'].split())
    words_len = answer_size + support_size + 1
    context = example['support']

    if words_len > max_token_size:
        max_new_token_size = max_token_size - answer_size - 1
        context = summarize_support(example, max_words=max_new_token_size)
    
    return {
        "context": context
    }

preprocessed_sciq = filtered_sciq.map(generate_context, num_proc=4)

Map (num_proc=4):   0%|          | 0/10481 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/887 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/884 [00:00<?, ? examples/s]

In [126]:
test_x = preprocessed_sciq.filter(lambda example: len(example["question"].split()) > 64) 
test_x

DatasetDict({
    train: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support', 'context'],
        num_rows: 3
    })
    validation: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support', 'context'],
        num_rows: 0
    })
    test: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support', 'context'],
        num_rows: 0
    })
})

In [127]:
test_x['train'][0]

{'question': 'The ideal gas law is used like any other gas law, with attention paid to the unit and making sure that temperature is expressed in kelvin. however, the ideal gas law does not require a change in the conditions of a gas sample. the ideal gas law implies that if you know any three of the physical properties of a gas, you can calculate this?',
 'distractor3': 'unrelated',
 'distractor1': 'second',
 'distractor2': 'third',
 'correct_answer': 'fourth',
 'support': 'The ideal gas law is used like any other gas law, with attention paid to the unit and making sure that temperature is expressed in Kelvin. However, the ideal gas law does not require a change in the conditions of a gas sample. The ideal gas law implies that if you know any three of the physical properties of a gas, you can calculate the fourth property.',
 'context': 'The ideal gas law is used like any other gas law, with attention paid to the unit and making sure that temperature is expressed in Kelvin. However, th

In [109]:
preprocessed_sciq.save_to_disk("preprocessed_sciq-qg-256-new")

Saving the dataset (0/1 shards):   0%|          | 0/10481 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/887 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/884 [00:00<?, ? examples/s]

Now preprocessed has shorter context for long ones, due to map process running slow, uncomment below after getting zip file from me

In [110]:
# from datasets import load_from_disk
# preprocessed_sciq = load_from_disk("preprocessed_sciq-qg-256")

### Step 2 Question Generation

#### Tokenize for Question Generation 

Input : Context and Answer \
Output : Question

In [111]:
import torch
import copy
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
tokenizer.add_special_tokens({"sep_token": "<sep>"})

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


1

In [134]:
question = test_x['train'][0]['question']
tokenized_targets = tokenizer.encode_plus(question, max_length=128, padding='max_length', truncation=True, return_tensors="pt")
tokenized_targets

{'input_ids': tensor([[   37,  1523,  1807,   973,    19,   261,   114,   136,   119,  1807,
           973,     6,    28,  1388,  1866,    12,     8,  1745,    11,   492,
           417,    24,  2912,    19,  7103,    16,     3,  5768,  2494,     5,
           983,     6,     8,  1523,  1807,   973,   405,    59,  1457,     3,
             9,   483,    16,     8,  1124,    13,     3,     9,  1807,  3106,
             5,     8,  1523,  1807,   973, 18841,    24,     3,    99,    25,
           214,   136,   386,    13,     8,  1722,  2605,    13,     3,     9,
          1807,     6,    25,    54, 11837,    48,    58,     1,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  

In [135]:
for x in tokenized_targets['input_ids'][:2]:
    print(tokenizer.decode(x))

The ideal gas law is used like any other gas law, with attention paid to the unit and making sure that temperature is expressed in kelvin. however, the ideal gas law does not require a change in the conditions of a gas sample. the ideal gas law implies that if you know any three of the physical properties of a gas, you can calculate this?</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>


In [133]:
def preprocess_dataset(example):
    text = "{}<sep>{}".format(example['correct_answer'], example['context'])
    question = example['question']

    max_length = 256
    max_length_target=64
    
    tokenized_inputs = tokenizer.encode_plus(text, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt")
    tokenized_targets = tokenizer.encode_plus(question, max_length=max_length_target, padding='max_length', truncation=True, return_tensors="pt")
    
    input_ids = tokenized_inputs['input_ids'].squeeze()
    input_attention = tokenized_inputs['attention_mask'].squeeze()

    target_ids = tokenized_targets['input_ids'].squeeze()
    target_attention = tokenized_targets['attention_mask'].squeeze()

    labels = copy.deepcopy(target_ids)
    labels[labels == 0] = -100
    
    outputs = {
        'input_ids':input_ids, 
        'attention_mask': input_attention, 
        'labels': labels
    }

    return outputs
    
tokenized_dataset = preprocessed_sciq.map(preprocess_dataset, remove_columns= ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support', 'context'])
tokenized_dataset

Map:   0%|          | 0/10481 [00:00<?, ? examples/s]

Map:   0%|          | 0/887 [00:00<?, ? examples/s]

Map:   0%|          | 0/884 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 10481
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 887
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 884
    })
})

In [136]:
import numpy as np
import evaluate

def compute_metrics(eval_pred):
    metric = evaluate.load("bleu")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

2024-03-25 13:20:32.725572: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-25 13:20:32.725713: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-25 13:20:32.814307: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-25 13:20:33.046298: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [137]:
from transformers import T5ForConditionalGeneration, TrainingArguments, Trainer, default_data_collator

training_args = TrainingArguments(
    output_dir="pretrained_question_generation", 
    evaluation_strategy="no",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    fp16=True
    # gradient_checkpointing=True
)

model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", device_map="auto")

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

data_collator = default_data_collator

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [138]:
torch.cuda.empty_cache()

There is no validation yet (it's buggy in my comp for now)

In [139]:
trainer.train()

Epoch,Training Loss,Validation Loss


Checkpoint destination directory pretrained_question_generation/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [29]:
trainer.save_model('model-qg')

In [35]:
tokenized_dataset['validation'][137]

{'input_ids': [11499,
  32100,
  37,
  2677,
  5013,
  2107,
  387,
  45,
  3,
  9,
  6957,
  42,
  13064,
  616,
  190,
  6079,
  9243,
  7293,
  7,
  6,
  114,
  273,
  16,
  7996,
  666,
  3,
  5,
  100,
  19,
  2953,
  57,
  579,
  2677,
  11,
  23295,
  24,
  169,
  8,
  387,
  12,
  1633,
  70,
  4096,
  5,
  634,
  52,
  1982,
  10441,
  19,
  10441,
  24,
  3033,
  7,
  8,
  2912,
  13,
  387,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,


In [114]:
model = T5ForConditionalGeneration.from_pretrained("model-qg")

In [115]:
text = "{}<sep>{}".format('minute amounts', "Within a nervous system, a neuron, neurone, or nerve cell is an electrically excitable cell that fires electric signals called action potentials across a neural network. Neurons communicate with other cells via synapses, which are specialized connections that commonly use minute amounts of chemical neurotransmitters to pass the electric signal from the presynaptic neuron to the target cell through the synaptic gap.")
tokenized_inputs = tokenizer.encode_plus(text, max_length=256, padding='max_length', truncation=True, return_tensors="pt")
decoder_input_ids = tokenized_inputs['input_ids']

In [116]:
output = model.generate(
    input_ids=tokenized_inputs['input_ids']
)
output

tensor([[   0,  363,   19,    8, 1657,   21,    3,    9, 6567,   29,    6, 6567,
           29,   15,    6,   42, 9077, 2358,   24, 1472]])

In [117]:
print(tokenizer.decode(output[0]))

<pad>What is the term for a neuron, neurone, or nerve cell that fire


In [36]:
preprocessed_sciq['validation'][137]

{'question': 'When the temperature of water is increased after being used in cooling, it is this form of pollution?',
 'distractor3': 'air',
 'distractor1': 'atmospheric',
 'distractor2': 'cosmic',
 'correct_answer': 'thermal',
 'support': "Thermal pollution is pollution that raises the temperature of water. This is caused by power plants and factories that use the water to cool their machines. The plants pump cold water from a lake or coastal area through giant cooling towers, like those in Figure below . As it flows through the towers, the cold water absorbs heat. This warmed water is returned to the lake or sea. Thermal pollution can kill fish and other water life. It's not just the warm temperature that kills them. Warm water can’t hold as much oxygen as cool water. If the water gets too warm, there may not be enough oxygen for living things.",
 'context': ' The plants pump cold water from a lake or coastal area through giant cooling towers, like those in Figure below . This is cau

In [53]:
# for x in tokenized_dataset['validation'][137]['input_ids']:
#     print(tokenizer.decode(x))

In [30]:
trainer.evaluate()

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.32 GiB (GPU 0; 6.00 GiB total capacity; 5.98 GiB already allocated; 0 bytes free; 8.37 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

### Step 3 Distractor Generation

#### Tokenize for Distractor Generation 

Input : Answer, Question, Context \
Output : 3 Distractors

In [126]:
def preprocess_dataset_for_distractor(example):
    text = "{}<sep>{}<sep>{}".format(example['question'], example['correct_answer'], example['context'])
    distractor = "{}<sep>{}<sep>{}".format(example['distractor1'], example['distractor2'], example['distractor3'])

    max_length = 256
    
    tokenized_inputs = tokenizer.encode_plus(text, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt")
    tokenized_targets = tokenizer.encode_plus(distractor, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt")
    
    input_ids = tokenized_inputs['input_ids'].squeeze()
    input_attention = tokenized_inputs['attention_mask'].squeeze()

    target_ids = tokenized_targets['input_ids'].squeeze()
    target_attention = tokenized_targets['attention_mask'].squeeze()

    labels = copy.deepcopy(target_ids)
    labels[labels == 0] = -100
    
    outputs = {
        'input_ids':input_ids, 
        'attention_mask': input_attention, 
        'labels': labels
    }

    return outputs
    
tokenized_dataset_distractor = preprocessed_sciq.map(preprocess_dataset_for_distractor, remove_columns= ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support', 'context'])
tokenized_dataset_distractor

Map:   0%|          | 0/10481 [00:00<?, ? examples/s]

Map:   0%|          | 0/887 [00:00<?, ? examples/s]

Map:   0%|          | 0/884 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 10481
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 887
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 884
    })
})

In [127]:
training_args_dis = TrainingArguments(
    output_dir="pretrained_distractor_generation", 
    evaluation_strategy="no", 
    logging_strategy="epoch",
    auto_find_batch_size=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    fp16=True
    # gradient_checkpointing=True
)

model_dis = T5ForConditionalGeneration.from_pretrained("t5-small")

device = "cuda" if torch.cuda.is_available() else "cpu"
model_dis = model_dis.to(device)

trainer = Trainer(
    model=model_dis,
    args=training_args_dis,
    train_dataset=tokenized_dataset_distractor["train"],
    eval_dataset=tokenized_dataset_distractor["validation"],
    compute_metrics=compute_metrics
)

In [128]:
trainer.train()

Step,Training Loss
655,1.9005
1311,1.6972
1965,1.6619


TrainOutput(global_step=1965, training_loss=1.7532189910345102, metrics={'train_runtime': 6312.5134, 'train_samples_per_second': 4.981, 'train_steps_per_second': 0.311, 'total_flos': 2126625726529536.0, 'train_loss': 1.7532189910345102, 'epoch': 3.0})

In [129]:
trainer.save_model('model-dg')

In [133]:
model_dis = T5ForConditionalGeneration.from_pretrained("model-dg")

In [183]:
test_data = preprocessed_sciq['test'][700]
test_data

{'question': 'Each species has a particular way of making a living which is called its what?',
 'distractor3': 'life-cycle',
 'distractor1': 'habit',
 'distractor2': 'system',
 'correct_answer': 'niche',
 'support': 'Each species has a particular way of making a living. This is called its niche . You can see the niche of a lion in Figure below . A lion makes its living by hunting and eating other animals. Each species also has a certain place where it is best suited to live. This is called its habitat . The lion’s habitat is a grassland. Why is a lion better off in a grassland than in a forest?.',
 'context': ' A lion makes its living by hunting and eating other animals. Each species also has a certain place where it is best suited to live. This is called its habitat . This is called its niche .Each species has a particular way of making a living'}

In [184]:
text = "{}<sep>{}<sep>{}".format(test_data['question'], test_data['correct_answer'], test_data['context'])
tokenized_inputs = tokenizer.encode_plus(text, max_length=256, padding='max_length', truncation=True, return_tensors="pt")
decoder_input_ids = tokenized_inputs['input_ids']

In [185]:
output = model_dis.generate(
    input_ids=tokenized_inputs['input_ids']
)
output

tensor([[   0,    3,    7,  232,    2,    7,   15,  102, 3155,    7,   15,    9,
            2,    7,   15,  102, 3155,    7,   15,    9]])

In [186]:
print(tokenizer.decode(output[0]))

<pad> sand<unk> sep>sea<unk>sep>sea
