In [1]:
import nltk
from datasets import load_dataset

In [2]:
squad = load_dataset("squad", split="train[:40000]")

In [3]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\dante\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping tokenizers\punkt.zip.


True

In [16]:
from nltk.tokenize import sent_tokenize

In [5]:
squad = squad.train_test_split(test_size=0.1)

In [6]:
squad

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 36000
    })
    test: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 4000
    })
})

In [7]:
squad["train"][0]

{'id': '56df72ab5ca0a614008f9a7b',
 'title': 'Oklahoma_City',
 'context': 'The Oklahoma City Zoo and Botanical Garden is home to numerous natural habitats, WPA era architecture and landscaping, and hosts major touring concerts during the summer at its amphitheater. Oklahoma City also has two amusement parks, Frontier City theme park and White Water Bay water park. Frontier City is an \'Old West\'-themed amusement park. The park also features a recreation of a western gunfight at the \'OK Corral\' and many shops that line the "Western" town\'s main street. Frontier City also hosts a national concert circuit at its amphitheater during the summer. Oklahoma City also has a combination racetrack and casino open year-round, Remington Park, which hosts both Quarter horse (March – June) and Thoroughbred (August – December) seasons.',
 'question': 'Which amusement park is western themed? ',
 'answers': {'text': ['Frontier City'], 'answer_start': [235]}}

In [44]:
def filter_samples(example):
    """
    Returns bool for valid samples. Each sample must:
    1. Have an answer
    2. Have a question
    3. The answer must be in the context    
    """
    # Check if there are no answers
    if example["answers"]["text"] == []: return False
    answer = example["answers"]["text"][0]

    valid = example["question"] != "" and answer in example["context"] 
    return valid

In [46]:
squad.filter(filter_samples)

Filter:   0%|          | 0/36000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/4000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 36000
    })
    test: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 4000
    })
})

In [72]:
def extract_answer_sentence(context, answer):
    """
    Extracts the sentence containing the answer from the context.
    
    Args:
    - example: A dictionary representing a SQuaD example.
    
    Returns:
    - Modified example with the 'context' field replaced by the sentence containing the answer.
    """
    # Split the context into sentences
    sentences = sent_tokenize(context)
    # Return the sentences that contain the answer's name
    return " ".join(filter(lambda s: answer in s, sentences))

In [64]:
example = squad["train"][0]

print(example["context"])
print(example["question"])
print(example["answers"]["text"][0])

print(extract_answer_sentence(example["context"], example["answers"]["text"][0]))

The Oklahoma City Zoo and Botanical Garden is home to numerous natural habitats, WPA era architecture and landscaping, and hosts major touring concerts during the summer at its amphitheater. Oklahoma City also has two amusement parks, Frontier City theme park and White Water Bay water park. Frontier City is an 'Old West'-themed amusement park. The park also features a recreation of a western gunfight at the 'OK Corral' and many shops that line the "Western" town's main street. Frontier City also hosts a national concert circuit at its amphitheater during the summer. Oklahoma City also has a combination racetrack and casino open year-round, Remington Park, which hosts both Quarter horse (March – June) and Thoroughbred (August – December) seasons.
Which amusement park is western themed? 
Frontier City
Oklahoma City also has two amusement parks, Frontier City theme park and White Water Bay water park. Frontier City is an 'Old West'-themed amusement park. Frontier City also hosts a nationa

In [77]:
inputs = list(map(
    lambda c, a: extract_answer_sentence(c, a["text"][0]),
    squad["train"][10:14]["context"],
    squad["train"][10:14]["answers"]
))
inputs

['This feast is called in older prayer books the Purification of the Blessed Virgin Mary on February 2.',
 'Such institutional support may include government recognition or designation; presentation as being the "correct" form of a language in schools; published grammars, dictionaries, and textbooks that set forth a correct spoken and written form; and an extensive formal literature that employs that dialect (prose, poetry, non-fiction, etc.).',
 'Quoted at constant 2002 prices, GDP fell from £12 million in 1999-2000 to £11 million in 2005-06.',
 'Iran was Sunni at the time.']

In [80]:
targets = list(map(
    lambda q, a: f"{q} {a["text"][0]}",
    squad["train"][10:14]["question"],
    squad["train"][10:14]["answers"]
))
targets

['On what date is the Presentation of Christ in the Temple celebrated by Anglicans? February 2',
 'Recognition from what body may help a dialect to become standardized? government',
 'What was the GDP of the island in 1999-2000? £12 million',
 'In the later Abbasid era, what branch of Islam did Iran adhere to? Sunni']

In [49]:
checkpoint = "google-t5/t5-small"

In [65]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [74]:
def preprocess_function(examples):
    # Add the prompt to every context provided
    inputs = list(map(
        lambda c, a: extract_answer_sentence(c, a["text"][0]),
        examples["context"],
        examples["answers"]
    ))
    
    ["ask: " + context for context in examples["context"]]
    # Model every target as "question? answer"
    targets = list(map(
        lambda q, a: f"{q} {a["text"][0]}",
        examples["question"],
        examples["answers"]
    ))

    model_inputs = tokenizer(inputs, max_length=512, padding="max_length", truncation=True)
    labels = tokenizer(text_target=targets, max_length=128, padding="max_length", truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [144]:
# Test what mapping over the samples would look like
outputs = list(map(
    lambda q, a: f"Q: {q} A: {a["text"][0]}",
    squad["train"][:5]["question"],
    squad["train"][:5]["answers"]
))

In [145]:
list(outputs)

["Q: Where is Volkswagen Group's AutoEuropa assembly plant located? A: Palmela",
 'Q: For what movie did Beyonce receive  her first Golden Globe nomination? A: Dreamgirls',
 "Q: Who provided information about the game's controls in December of 2005? A: NGC Magazine",
 'Q: Instead of being a single person, what does Whitehead view a person as? A: continuum of overlapping events',
 'Q: What was the percentage increase in the Broadway ticket revenue from 2012-3 to 2013-4? A: 11.4%']

In [146]:
tokenized_squad = squad.map(preprocess_function, batched=True)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [151]:
tokenized_squad

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 2000
    })
})

In [147]:
import evaluate
rouge = evaluate.load("rouge")

In [148]:
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [149]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [150]:
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [152]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=4,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    logging_steps=10,
    weight_decay=0.01,
    predict_with_generate=True,
    fp16=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_squad["train"],
    eval_dataset=tokenized_squad["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [153]:
trainer.train()

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,0.3783,0.309609,0.3639,0.111,0.3249,0.325,18.465
2,0.3573,0.291394,0.3719,0.113,0.3306,0.3307,18.499
3,0.3305,0.284827,0.3743,0.1141,0.3338,0.3336,18.484
4,0.3278,0.283001,0.3769,0.1158,0.3361,0.336,18.4955




TrainOutput(global_step=2000, training_loss=0.6134686719179153, metrics={'train_runtime': 2557.6803, 'train_samples_per_second': 12.511, 'train_steps_per_second': 0.782, 'total_flos': 4330937647104000.0, 'train_loss': 0.6134686719179153, 'epoch': 4.0})

In [154]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

mcq_model = AutoModelForSeq2SeqLM.from_pretrained("results/checkpoint-2000")

In [181]:
input_text = "context: Chemical engineering involves the production and manufacturing of products through chemical processes. This includes designing equipment, systems, and processes for refining raw materials and for mixing, compounding, and processing chemicals."
input_ids = tokenizer.encode(input_text, return_tensors="pt")

In [182]:
outputs = mcq_model.generate(input_ids)

In [183]:
tokenizer.decode(outputs[0])

'<pad> Q: What is the process of chemical engineering? A: manufacturing</s>'