Fine-Tune the BART Model using the Writing Prompts Dataset

In [None]:
!pip install transformers -U
!pip install accelerate -U

In [None]:
# function to combine prompts and stories
def combinetext(prompt, story):
    """Pair each prompt line with its story line, joined by ' <sep> '.

    Args:
        prompt: Path to a UTF-8 text file with one prompt per line.
        story: Path to a UTF-8 text file with one story per line, aligned
            line-for-line with ``prompt``.

    Returns:
        A list of strings of the form ``"<prompt> <sep> <story>"``. The prompt
        has trailing whitespace stripped; the story is cut to its first 300
        whitespace-separated words, re-joined with single spaces.

    Raises:
        AssertionError: If the two files do not have the same number of lines.
    """
    # Context managers guarantee both handles are closed, even if reading fails.
    with open(prompt, 'r', encoding='utf8') as prompt_file:
        prompts = prompt_file.readlines()
    with open(story, 'r', encoding='utf8') as story_file:
        stories = story_file.readlines()

    # Raised explicitly rather than via ``assert`` so the check also runs
    # under ``python -O``; the exception type is kept for existing callers.
    if len(prompts) != len(stories):
        raise AssertionError(
            f"{prompt} has {len(prompts)} lines but {story} has {len(stories)}"
        )

    return [
        prompt_line.rstrip() + ' <sep> ' + " ".join(story_line.split()[:300])
        for prompt_line, story_line in zip(prompts, stories)
    ]

# Preprocessing the data (punctuation, contractions, newline markers)
def cleanpunctuation(s):
    """Return ``s`` with detached punctuation and contractions re-attached.

    Removes the space in front of each of ``! , . : ; ?`` and in front of the
    listed contraction suffixes (including the split forms ``' m``, ``' ve``
    and ``' s``), then turns every ``<newline>`` marker into a line break.
    """
    # Rules run strictly in this order; a later rule may match text that an
    # earlier one produced.
    rules = [(' ' + mark, mark) for mark in '!,.:;?']
    rules += [
        (" n't", "n't"),
        (" 's", "'s"),
        (" 're", "'re"),
        (" 've", "'ve"),
        (" 'll", "'ll"),
        (" 'am", "'am"),
        (" 'm", "'m"),
        (" ' m", "'m"),
        # Kept: the rule above can itself create a new " 'm" (e.g. "  ' m").
        (" 'm", "'m"),
        (" ' ve", "'ve"),
        (" ' s", "'s"),
        ('<newline>', '\n'),
    ]
    for old, new in rules:
        s = s.replace(old, new)
    return s

# Build the cleaned prompt/story corpora used for training and validation.
# NOTE(review): the training text is read from the 'valid.wp_*' files and the
# validation text from the 'test.wp_*' files - presumably to keep the run
# small; confirm this is intended.
train_text = [
    cleanpunctuation(text)
    for text in combinetext('valid.wp_source', 'valid.wp_target')
]

valid_text = [
    cleanpunctuation(text)
    for text in combinetext('test.wp_source', 'test.wp_target')
]


In [None]:
# Load the pretrained BART-base model and its tokenizer from Hugging Face
from transformers import BartForConditionalGeneration
from transformers import BartTokenizer

# Both objects are loaded from the same checkpoint name.
model_name = "facebook/bart-base"
model = BartForConditionalGeneration.from_pretrained(model_name)
tokenizer = BartTokenizer.from_pretrained(model_name)


In [None]:
# tokenize and encode the input (concatenated prompt and story text joined by ' <sep> ')
def tokenize_and_encode(examples):
    """Split each example into prompt and story and tokenize both sides.

    Args:
        examples: Iterable of strings of the form ``"<prompt> <sep> <story>"``.

    Returns:
        A tuple ``(encoded_prompts, encoded_stories)`` holding the outputs of
        the module-level ``tokenizer`` for the prompts and for the stories,
        each padded/truncated to 1024 tokens. Examples without the separator
        are reported on stdout and left out of both outputs.
    """
    inputs, labels = [], []
    for example in examples:
        # Split on the FIRST separator only, so a story that itself contains
        # ' <sep> ' is kept whole instead of being cut at the second one.
        prompt_text, separator, story_text = example.partition(' <sep> ')
        if not separator:
            print("Separator not found in example:", example)
            continue  # Skipping this example
        inputs.append(prompt_text)
        labels.append(story_text)

    encoded_prompts = tokenizer(inputs, padding="max_length", truncation=True, max_length=1024)
    encoded_stories = tokenizer(labels, padding="max_length", truncation=True, max_length=1024)
    return encoded_prompts, encoded_stories

# Encode the training split first, then the validation split; each result is
# whatever tokenize_and_encode returns for that split.
tokenized_train, tokenized_valid = map(tokenize_and_encode, (train_text, valid_text))

In [None]:
# Imported here because this cell runs before the training cell that
# otherwise provides ``torch``; without it the check raises NameError.
import torch

# Move the model to the GPU when one is available.
if torch.cuda.is_available():
    model.to("cuda")

In [None]:
import dataclasses

from transformers import Trainer, TrainingArguments
import torch


class PromptStoryDataset(torch.utils.data.Dataset):
    """Expose an (inputs, labels) pair of tokenizer outputs as a dataset.

    ``Trainer`` indexes its datasets and feeds each item to the model as a
    dict of features, so the two tokenizer outputs produced by
    ``tokenize_and_encode`` have to be zipped into per-example dicts that
    include a ``labels`` entry.

    Args:
        encodings: Tuple ``(encoded_prompts, encoded_stories)``; both must
            provide ``"input_ids"`` and the prompts ``"attention_mask"``.
        pad_token_id: Token id used for padding in the story encodings.
    """

    def __init__(self, encodings, pad_token_id):
        self.inputs, self.targets = encodings
        self.pad_token_id = pad_token_id

    def __len__(self):
        return len(self.inputs["input_ids"])

    def __getitem__(self, idx):
        # Padding positions become -100 so the loss skips them instead of
        # training the model to emit padding.
        labels = [
            -100 if token == self.pad_token_id else token
            for token in self.targets["input_ids"][idx]
        ]
        return {
            "input_ids": torch.tensor(self.inputs["input_ids"][idx]),
            "attention_mask": torch.tensor(self.inputs["attention_mask"][idx]),
            "labels": torch.tensor(labels),
        }


# Check GPU availability and print the GPU name
use_cuda = torch.cuda.is_available()
if use_cuda:
    print("GPU is available. Device name:", torch.cuda.get_device_name(0))
else:
    print("GPU is not available.")

# Newer transformers releases name this argument ``eval_strategy``; older ones
# only know ``evaluation_strategy``. Pick whichever the installed version has.
eval_strategy_key = (
    "eval_strategy"
    if any(f.name == "eval_strategy" for f in dataclasses.fields(TrainingArguments))
    else "evaluation_strategy"
)

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=1,  # Reduced batch size
    per_device_eval_batch_size=1,   # Reduced batch size
    gradient_accumulation_steps=4,  # Using gradient accumulation
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=use_cuda,  # Mixed precision needs a GPU; requesting it on CPU fails
    **{eval_strategy_key: "epoch"},
)

# Move model to GPU if available
if use_cuda:
    model.to("cuda")

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=PromptStoryDataset(tokenized_train, tokenizer.pad_token_id),
    eval_dataset=PromptStoryDataset(tokenized_valid, tokenizer.pad_token_id),
)

# Start training
trainer.train()


## I tried reducing the batch size, introducing gradient accumulation steps, and enabling mixed-precision training (fp16=True), but the model still seems too big to run on the Colab GPU.

In [None]:
# Write the model and its tokenizer into the same directory.
save_dir = "./fine_tuned_bart"
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_dir)

Generate a story for an example set of three captions

In [None]:
def generate_story(prompt):
    """Generate a story for ``prompt`` with the module-level model.

    Args:
        prompt: Prompt text; it is truncated to 1024 tokens.

    Returns:
        The decoded text of the first sequence returned by beam search
        (5 beams, at most 1024 tokens), with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    # The model may have been moved to the GPU; the input tensors must live on
    # the same device or generate() fails with a device mismatch.
    inputs = inputs.to(model.device)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=1024,
        num_beams=5,
        early_stopping=True,
    )
    story = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return story

# Example: generate and print a story for a prompt listing three captions
prompt = "A dragon, a castle, and a mysterious old book"
generated_story = generate_story(prompt)
print(generated_story)
