Fine-Tune a Flan-T5 Model Using the WritingPrompts Dataset

In [None]:
pip install transformers torch

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
from tqdm.auto import tqdm

In [None]:
# function to combine prompts and stories
def combinetext(prompt, story):
    prompts = open(prompt, 'r', encoding='utf8').readlines()
    stories = open(story, 'r', encoding='utf8').readlines()
    assert len(prompts) == len(stories)
    combine = []
    for i in range(len(prompts)):
        combine.append(prompts[i].rstrip() + ' <sep> ' + " ".join(stories[i].split()[:300]))
    return combine

# Prprocessing the data (punctuations, etc)
def cleanpunctuation(s):
    for p in '!,.:;?':
        s = s.replace(' ' + p, p)
    s = s.replace(' ' + 'n\'t', 'n\'t')
    s = s.replace(' ' + '\'s', '\'s')
    s = s.replace(' ' + '\'re', '\'re')
    s = s.replace(' ' + '\'ve', '\'ve')
    s = s.replace(' ' + '\'ll', '\'ll')
    s = s.replace(' ' + '\'am', '\'am')
    s = s.replace(' ' + '\'m', '\'m')
    s = s.replace(' ' + '\' m', '\'m')
    s = s.replace(' ' + '\'m', '\'m')
    s = s.replace(' ' + '\' ve', '\'ve')
    s = s.replace(' ' + '\' s', '\'s')
    s = s.replace('<newline>', '\n')
    return s

In [None]:
!pip install sentencepiece

In [None]:
# NOTE(review): StoryDataset is defined in a later cell. Run that cell before
# this one, otherwise the StoryDataset(...) calls below raise NameError.
# NOTE(review): the 'valid' split is used for training and the 'test' split for
# validation, presumably to keep training small. Confirm this is intended.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

# Training data: combined prompt/story lines with punctuation spacing undone.
train_texts = [cleanpunctuation(line) for line in combinetext('valid.wp_source', 'valid.wp_target')]
train_dataset = StoryDataset(tokenizer, train_texts)

# Validation data, prepared the same way.
valid_texts = [cleanpunctuation(line) for line in combinetext('test.wp_source', 'test.wp_target')]
valid_dataset = StoryDataset(tokenizer, valid_texts)


In [None]:
# tokenize and encode the input (conactenated text prompt and story usign <SEP>)
class StoryDataset(Dataset):
    def __init__(self, tokenizer, texts, max_length=512):
        self.tokenizer = tokenizer
        self.inputs = []
        self.targets = []
        for text in texts:
            prompt, story = text.split('<sep>')
            tokenized_input = tokenizer(prompt, max_length=max_length, truncation=True, padding='max_length', return_tensors="pt")
            tokenized_target = tokenizer(story, max_length=max_length, truncation=True, padding='max_length', return_tensors="pt")
            self.inputs.append(tokenized_input)
            self.targets.append(tokenized_target)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        input_ids = self.inputs[index]['input_ids'].squeeze(0)
        attention_mask = self.inputs[index]['attention_mask'].squeeze(0)
        target_ids = self.targets[index]['input_ids'].squeeze(0)
        return input_ids, attention_mask, target_ids

In [None]:
# Load the pretrained Flan-T5 base model from Hugging Face.
model = T5ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path="google/flan-t5-base"
)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Shuffle only the training data; validation order does not affect the loss.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8)

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set to transformers.AdamW's defaults so the optimizer
# behaves as before (torch's own defaults are eps=1e-8, weight_decay=0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)


In [None]:
# Fine-tune the model and report the mean training loss of every epoch.
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for input_ids, attention_mask, target_ids in tqdm(train_loader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        target_ids = target_ids.to(device)

        # Forward pass: with labels given, the model returns the loss in outputs.loss.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
        loss = outputs.loss

        # Backward pass, parameter update, then clear gradients for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # .item() yields a plain Python float, so no autograd graph is retained.
        epoch_loss += loss.item()
        num_batches += 1

    if num_batches:
        print(f"Epoch {epoch + 1}/{num_epochs} - mean train loss: {epoch_loss / num_batches:.4f}")

In [None]:
# Save the fine-tuned weights and the tokenizer side by side so the directory
# can be reloaded on its own with from_pretrained().
output_dir = './fine_tuned_model'
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [None]:
from torch.nn.functional import cross_entropy

In [None]:
# Model evaluation and compute loss
def evaluate(model, val_loader, device):
    model.eval()
    total_loss = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            targets = {k: v.to(device) for k, v in targets.items()}

            outputs = model(**inputs, labels=targets["input_ids"])
            loss = outputs.loss
            total_loss += loss.item()
            total += 1

    return total_loss / total

In [None]:
# Compute and print the loss of the fine-tuned model on the validation split.
val_loss = evaluate(model, valid_loader, device)
print("Validation Loss: " + str(val_loss))
