# Fine-Tuning GPT-2 for Medical Question Answering

In [1]:
import os
import json
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import Trainer, TrainingArguments
from transformers import TextDataset, DataCollatorForLanguageModeling

In [2]:
from datasets import load_dataset

# Load the MedQuad medical Q&A dataset by its repository id.
# NOTE: this must resolve to `datasets.load_dataset` (imported just above); a later
# cell defines its own `load_dataset(file_path, tokenizer, ...)`, so re-running this
# cell after that one without re-importing will call the wrong function.
dataset = load_dataset("keivalya/MedQuad-MedicalQnADataset")

In [3]:
# Display the loaded DatasetDict (splits, column names, row counts).
dataset

DatasetDict({
    train: Dataset({
        features: ['qtype', 'Question', 'Answer'],
        num_rows: 16407
    })
})

In [10]:
def train_text_generate_question_answer(data, train_name):
    """Write question/answer/type records to a plain-text training file.

    Each record is written as three tagged lines followed by a blank line::

        [Q] <Question>
        [A] <Answer>
        [T] <qtype>

    Args:
        data: Iterable of mappings, each providing the keys ``'Question'``,
            ``'Answer'`` and ``'qtype'``. It is consumed in a single pass, so
            one-shot iterables (generators, iterators) are supported.
        train_name: Path of the text file to create; an existing file is
            overwritten.

    Raises:
        KeyError: If a record lacks one of the three required keys. Records are
            collected before the file is opened, so a malformed record does not
            truncate an existing file.
    """
    # Collect every record up front in one pass over `data`.
    records = [(entry['Question'], entry['Answer'], entry['qtype']) for entry in data]

    # UTF-8 is set explicitly so non-ASCII text is written identically on every
    # platform instead of depending on the locale's default encoding.
    with open(train_name, 'w', encoding='utf-8') as text_file:
        text_file.writelines(
            f"[Q] {q}\n[A] {a}\n[T] {t}\n\n" for q, a, t in records
        )

In [11]:
# Serialize the 'train' split to train.txt (relative path, i.e. the current working
# directory), which is where the training configuration cell looks for it.
train_text_generate_question_answer(dataset['train'], 'train.txt')

In [12]:
def load_dataset(file_path, tokenizer, block_size = 128):
    """Wrap the text file at ``file_path`` in a ``TextDataset``.

    NOTE: defining this function rebinds the name ``load_dataset`` that was
    imported from ``datasets`` earlier in the notebook; after this cell runs,
    the Hub loader is no longer reachable under that name.

    Args:
        file_path: Path to the plain-text training file.
        tokenizer: Tokenizer handed to ``TextDataset``.
        block_size: Forwarded unchanged to ``TextDataset`` (presumably the
            length, in tokens, of each training chunk -- confirm against the
            transformers documentation).

    Returns:
        The constructed ``TextDataset`` instance.
    """
    dataset_kwargs = {
        "tokenizer": tokenizer,
        "file_path": file_path,
        "block_size": block_size,
    }
    return TextDataset(**dataset_kwargs)

def load_data_collator(tokenizer, mlm = False):
    """Create the ``DataCollatorForLanguageModeling`` used to batch examples.

    Args:
        tokenizer: Tokenizer handed to the collator.
        mlm: Forwarded unchanged as the collator's ``mlm`` flag; defaults to
            ``False``.

    Returns:
        The constructed ``DataCollatorForLanguageModeling`` instance.
    """
    return DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=mlm)

def train(train_file_path,model_name,
          output_dir,
          overwrite_output_dir,
          per_device_train_batch_size,
          num_train_epochs,
          save_steps):
    """Fine-tune a pretrained GPT-2 checkpoint on a plain-text file.

    Args:
        train_file_path: Path to the text file used to build the training set.
        model_name: Name or path given to ``from_pretrained`` for both the
            tokenizer and the model.
        output_dir: Directory that receives the tokenizer files, the model
            files and whatever the ``Trainer`` writes.
        overwrite_output_dir: Forwarded to ``TrainingArguments``.
        per_device_train_batch_size: Forwarded to ``TrainingArguments``.
        num_train_epochs: Forwarded to ``TrainingArguments``.
        save_steps: Forwarded to ``TrainingArguments``. Previously this
            argument was accepted but silently dropped, so the library default
            was used regardless of what the caller requested.

    Returns:
        None. The results are the files written under ``output_dir``.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # `load_dataset` here is the notebook's TextDataset helper (not the
    # `datasets` loader of the same name), so that cell must have been run.
    train_dataset = load_dataset(train_file_path, tokenizer)
    data_collator = load_data_collator(tokenizer)

    # The tokenizer is not given to the Trainer, so it is saved explicitly to
    # keep `output_dir` loadable on its own.
    tokenizer.save_pretrained(output_dir)

    model = GPT2LMHeadModel.from_pretrained(model_name)

    # Writes the not-yet-fine-tuned weights; trainer.save_model() is called
    # again after training.
    model.save_pretrained(output_dir)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=overwrite_output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        num_train_epochs=num_train_epochs,
        save_steps=save_steps,  # fix: honour the caller's value
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()
    trainer.save_model()

In [14]:
# Training configuration; every value below is passed to train() in the next cell.
# Absolute path to the train.txt file written earlier into the working directory.
train_file_path = os.path.join(os.getcwd(),"train.txt" )
# Checkpoint name used for both the tokenizer and the model.
model_name = 'gpt2'
# Directory (relative to the working directory) that receives the saved files.
output_dir = 'output'
overwrite_output_dir = False
per_device_train_batch_size = 8
num_train_epochs = 3
# Passed to train() as `save_steps`.
save_steps = 50000

In [15]:
# Run fine-tuning with the settings defined in the configuration cell above.
# The captured output below shows a 15933-step run.
train(
    train_file_path=train_file_path,
    model_name=model_name,
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    num_train_epochs=num_train_epochs,
    save_steps=save_steps
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjohnmosesng[0m ([33mjohnmosesng-axiis[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011169211577777737, max=1.0…

  0%|          | 0/15933 [00:00<?, ?it/s]

{'loss': 2.1298, 'learning_rate': 4.8430929517353924e-05, 'epoch': 0.09}
{'loss': 1.8796, 'learning_rate': 4.686185903470784e-05, 'epoch': 0.19}
{'loss': 1.7911, 'learning_rate': 4.529278855206176e-05, 'epoch': 0.28}
{'loss': 1.7144, 'learning_rate': 4.3723718069415684e-05, 'epoch': 0.38}
{'loss': 1.7018, 'learning_rate': 4.21546475867696e-05, 'epoch': 0.47}
{'loss': 1.669, 'learning_rate': 4.058557710412352e-05, 'epoch': 0.56}
{'loss': 1.6141, 'learning_rate': 3.9016506621477436e-05, 'epoch': 0.66}
{'loss': 1.6205, 'learning_rate': 3.744743613883136e-05, 'epoch': 0.75}
{'loss': 1.5766, 'learning_rate': 3.5878365656185273e-05, 'epoch': 0.85}
{'loss': 1.5777, 'learning_rate': 3.4309295173539195e-05, 'epoch': 0.94}
{'loss': 1.539, 'learning_rate': 3.274022469089312e-05, 'epoch': 1.04}
{'loss': 1.4684, 'learning_rate': 3.117115420824704e-05, 'epoch': 1.13}
{'loss': 1.4635, 'learning_rate': 2.9602083725600955e-05, 'epoch': 1.22}
{'loss': 1.4599, 'learning_rate': 2.8033013242954877e-05, 'ep

wandb: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error resolved after 0:02:42.680286, resuming normal operation.
wandb: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error resolved after 0:02:43.901390, resuming normal operation.
