In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split

# Load the QA pairs; adjust the path/reader to match your dataset format.
df = pd.read_json('datasets/qa_dataset.json')

# Later cells index the 'question' and 'answer' columns; fail early and clearly if absent.
_missing_cols = {'question', 'answer'} - set(df.columns)
assert not _missing_cols, f"dataset is missing required columns: {sorted(_missing_cols)}"

# Split into training and validation sets. A fixed random_state makes the split
# reproducible, so Restart & Run All yields the same train/val partition every time.
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)


In [None]:
from transformers import AutoTokenizer

# Load the FLAN-T5 tokenizer matching the checkpoint that is fine-tuned further down.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path='google/flan-t5-base',
)

def tokenize_qa_pairs(df, max_length=128, tok=None):
    """Tokenize the ``question`` column of ``df`` as encoder input for a seq2seq model.

    Only the questions are encoded. The answers are the generation targets and are
    turned into ``labels`` separately, so they must not appear in the encoder input.
    Passing them as a second text segment would let the model copy the answer during
    training, while no answer is available at inference time.

    Args:
        df: DataFrame with a ``question`` column.
        max_length: Every question is truncated/padded to exactly this many tokens.
        tok: Tokenizer to call; defaults to the module-level ``tokenizer``.

    Returns:
        Whatever the tokenizer returns for the list of questions, requested as
        PyTorch tensors (``return_tensors='pt'``).
    """
    if tok is None:
        tok = tokenizer
    return tok(
        df['question'].tolist(),
        truncation=True,
        padding='max_length',
        max_length=max_length,  # adjust based on model and memory constraints
        return_tensors='pt',
    )

# Tokenize both splits with the same helper (train first, then validation).
train_encodings, val_encodings = (
    tokenize_qa_pairs(split_df) for split_df in (train_df, val_df)
)


In [None]:
import torch

# The tokenizer is deliberately NOT deleted here. QADataset (next cell) uses it to
# tokenize the answers into labels, so `del tokenizer` made training fail with NameError.
# Release cached-but-unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

In [None]:
from torch.utils.data import Dataset

class QADataset(Dataset):
    """Map-style dataset pairing pre-tokenized questions with answer labels.

    Args:
        encodings: Mapping of name -> indexable sequence (e.g. the tokenizer output
            for the questions); item ``idx`` takes row ``idx`` from every entry.
        answers: List of answer strings; its length defines the dataset length.
        tok: Tokenizer used to encode answers into labels. Defaults to the
            module-level ``tokenizer``, which must exist when the dataset is built.
        max_length: Answers are truncated/padded to exactly this many tokens.
    """

    def __init__(self, encodings, answers, tok=None, max_length=128):
        self.encodings = encodings
        self.answers = answers
        # Resolve the tokenizer once, so a later deletion of the global name cannot
        # break __getitem__ in the middle of training.
        self.tok = tok if tok is not None else tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        # For generative QA, encode the answer and use its token ids as labels.
        answer_encoding = self.tok(
            self.answers[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        labels = answer_encoding['input_ids'][0]
        # Padding positions must not contribute to the loss. The cross-entropy used
        # by HF seq2seq models ignores label id -100.
        pad_id = self.tok.pad_token_id
        if pad_id is not None:
            labels = labels.masked_fill(labels == pad_id, -100)
        item['labels'] = labels
        return item

    def __len__(self):
        return len(self.answers)

# Wrap each split's encodings together with its answer strings (train, then validation).
train_dataset, val_dataset = (
    QADataset(split_encodings, split_df['answer'].tolist())
    for split_encodings, split_df in ((train_encodings, train_df), (val_encodings, val_df))
)


In [None]:
from transformers import AutoModelForSeq2SeqLM, TrainingArguments, Trainer

# With DeepSpeed ZeRO-3, TrainingArguments must be created BEFORE the model is loaded,
# so from_pretrained can partition the weights at load time. This ordering is harmless
# for other ZeRO stages, so keep it regardless of what ds_config.json selects.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    deepspeed='ds_config.json',  # Path to your DeepSpeed config
    logging_dir='./logs',
    # NOTE(review): newer transformers releases renamed this to `eval_strategy`;
    # adjust to match the installed version.
    evaluation_strategy='epoch',
)

model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base')

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)


In [None]:
# Run fine-tuning (evaluation happens each epoch per training_args); the object
# returned by train() is shown as this cell's output.
train_result = trainer.train()
train_result