In [1]:
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
)
from transformers import DataCollatorForLanguageModeling

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def generate_answer(question, context, model, tokenizer, max_length=256):
    prompt = f"Question: {question}\nContext: {context}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=len(inputs.input_ids[0]) + max_length,
            temperature=0.7,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    answer_start = answer.find("Answer:") + len("Answer:")
    return answer[answer_start:].strip()

In [3]:
dataset = load_dataset("pubmed_qa", "pqa_labeled")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
question = "What causes COVID-19?"
context = "COVID-19 is caused by the SARS-CoV-2 virus."
answer = generate_answer(question, context, model, tokenizer)
print(answer)



The SARS-CoV-2 virus is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It is a virus that is present in the human body. It

In [4]:
def preprocess_function(examples):
    texts = []
    for question, context, answer in zip(
        examples["question"], examples["context"], examples["final_decision"]
    ):
        context_text = " ".join(context) if isinstance(context, list) else context
        text = f"Question: {question}\nContext: {context_text}\nAnswer: {answer}"
        texts.append(text)
    tokenized = tokenizer(texts, truncation=True, padding="max_length", max_length=512)
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized


tokenized_dataset = dataset.map(preprocess_function, batched=True)
if "validation" not in tokenized_dataset:
    split_dataset = tokenized_dataset["train"].train_test_split(test_size=0.1, seed=42)
    train_dataset = split_dataset["train"]
    eval_dataset = split_dataset["test"]
else:
    train_dataset = tokenized_dataset["train"]
    eval_dataset = tokenized_dataset["validation"]

Map: 100%|██████████| 1000/1000 [00:02<00:00, 409.66 examples/s]


In [5]:
def calculate_perplexity(model, eval_dataset, batch_size=4):
    model.eval()
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    data_loader = DataLoader(
        eval_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator
    )
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in data_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
            labels = batch["labels"].to(device)
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss
            total_loss += loss.item() * labels.size(0) * labels.size(1)
            total_tokens += torch.sum(labels != -100).item()
    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss))
    return perplexity.item()


eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
pre_finetune_perplexity = calculate_perplexity(model, eval_dataset)
print(f"Perplexity before finetune: {pre_finetune_perplexity:.2f}")

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Perplexity before finetune: 28.32


In [6]:
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=1000,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    num_train_epochs=3,
    warmup_steps=500,
    logging_dir="./logs",
    logging_steps=100,
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()
trainer.save_model("../results/fine_tuned_gpt2_pubmedqa")
tokenizer.save_pretrained("../results/fine_tuned_gpt2_pubmedqa")
results = trainer.evaluate()
print(f"Perplexity after finetune: {torch.exp(torch.tensor(results['eval_loss'])):.2f}")

Step,Training Loss,Validation Loss
500,2.2554,2.195654


Perplexity after finetune: 8.78


In [7]:
model_finetune = GPT2LMHeadModel.from_pretrained(
    "../results/fine_tuned_gpt2_pubmedqa"
).to(device)
tokenizer_finetune = GPT2Tokenizer.from_pretrained(
    "../results/fine_tuned_gpt2_pubmedqa"
)

answer = generate_answer(question, context, model_finetune, tokenizer_finetune)
print(answer)



The aim of this study was to determine the cause of COVID-19 in a population of patients with severe respiratory syndrome (SARS).
METHODS: A retrospective cohort study was conducted in the Netherlands. Patients with severe SARS were included in the study. Patients with severe SARS were excluded from the study because of the presence of SARS-CoV-2 virus.
RESULTS: The mean age of the patients was 43.5 years (range, 30-54 years). The mean age of the patients was 65.4 years (range, 45-69 years). The mean age of the patients was 65.4 years (range, 45-69 years). The mean age of the patients was 65.4 years (range, 45-69 years). The mean age of the patients was 65.4 years (range, 45-69 years). The mean age of the patients was 65.4 years (range, 45-69 years). The mean age of the patients was 65.4 years (range, 45-69 years). The mean age of the patients was 65.4 years (range, 45-69 years). The mean age of the patients was 65.4 years (range, 45-69 years). The mean age of the
