In [1]:
from datasets import load_dataset
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, TrainingArguments
from trl import DPOTrainer, DPOConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch the "comparisons" configuration of the summarize_from_feedback dataset
dataset = load_dataset(path="openai/summarize_from_feedback", name="comparisons")

In [3]:
# Peek at one raw training record to see the nested field layout
first_train_example = dataset["train"][0]
print(first_train_example)

{'info': {'id': 't3_34xale', 'post': "My boyfriend and I are long distance. We have a trip planned this summer which involves me going over to him in the USA. This will be the second time I have actually been with him in person. I am flying from the UK with my mum to the east coast. The original plan was for me to fly over to my boyfriend in the west coast (my parents are holidaying on the east coast) but because my mum was freaking out so much about me going to meet my boyfriend i said we can all road trip there together. I even invited her on the trip with us. I have given her all of our dates so that she can travel around with us.\n\nThe plan was for me to stay on the 4th July and fly back on the 5th. Mum knew this. I told her I had booked a flight back already from the west coast to east coast (where she would pick me up and we would fly back to the UK together). She has gone mad at me because she can't believe I would book a flight when she told me she didn't want me flying on my 

In [4]:
def transform_example(example):
    """Convert one raw comparison record into a DPO preference triple.

    Args:
        example: A mapping with the keys
            - 'info': mapping holding the source text under 'post'
              (and, optionally, under 'article'),
            - 'summaries': a two-element sequence of mappings with a 'text' key,
            - 'choice': index (0 or 1) of the preferred summary.

    Returns:
        A dict with the keys 'prompt', 'chosen' and 'rejected'.
    """
    info = example['info']

    # Use the post as the prompt. When it is None, fall back to the 'article'
    # field so the prompt is not silently None.
    # NOTE(review): presumably some rows carry the source text only under
    # 'article' -- confirm against the dataset card.
    post = info['post']
    if post is None:
        post = info.get('article')

    summaries = example['summaries']
    choice = example['choice']

    # With exactly two candidates, the rejected one sits at the other index.
    chosen_summary = summaries[choice]['text']
    rejected_summary = summaries[1 - choice]['text']

    return {
        'prompt': post,
        'chosen': chosen_summary,
        'rejected': rejected_summary
    }

In [5]:
# Convert both splits to prompt/chosen/rejected records, dropping every original column
train_split = dataset['train']
validation_split = dataset['validation']
train_dataset = train_split.map(transform_example, remove_columns=train_split.column_names)
eval_dataset = validation_split.map(transform_example, remove_columns=validation_split.column_names)

In [6]:
# Keep at most the first EVAL_SUBSET_SIZE evaluation examples.
# Clamping to the split length keeps this cell safe to re-run and avoids an
# out-of-range selection if the split is smaller than the cap.
EVAL_SUBSET_SIZE = 83000
eval_dataset = eval_dataset.select(range(min(EVAL_SUBSET_SIZE, len(eval_dataset))))

In [7]:
# Checkpoint shared by the tokenizer, the trainable model and the reference model
model_name = "facebook/bart-large-cnn"

# Prefer the GPU when one is visible, otherwise stay on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = BartTokenizer.from_pretrained(model_name)

# Trainable model, moved to the selected device right after loading
model = BartForConditionalGeneration.from_pretrained(model_name).to(device)

# Second copy of the same checkpoint, used as the DPO reference model
# (it is not moved to `device` in this cell)
ref_model = BartForConditionalGeneration.from_pretrained(model_name)



In [8]:
# Set the tokenizer padding and truncation settings.
# Keep the tokenizer's own pad token; only fall back to EOS when none is defined,
# so the pad id is not silently changed away from the one the checkpoint ships with.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# BART uses absolute position embeddings, so inputs should be padded on the right.
tokenizer.padding_side = 'right'
# When the tokenizer truncates, drop tokens from the start and keep the end.
tokenizer.truncation_side = 'left'

In [9]:
# Training configuration for the DPO run, grouped by purpose
training_args = DPOConfig(
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Evaluate once per epoch
    evaluation_strategy="epoch",
    # Checkpointing: written under ./results; with save_steps=10 and
    # save_total_limit=2 this presumably saves every 10 steps and keeps only
    # the two most recent checkpoints -- TODO confirm the interval is intended
    output_dir="./results",
    save_steps=10,
    save_total_limit=2,
    # Logging: metrics every 10 steps, log files under ./logs
    logging_dir='./logs',
    logging_steps=10,
)

In [10]:
# Wire the trainable model, the reference model and the preference data into the trainer
dpo_trainer = DPOTrainer(
    args=training_args,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # presumably controls how strongly the policy is kept close to ref_model;
    # adjust as needed
    beta=0.1,
)

Map: 100%|████████████████████████| 92858/92858 [02:21<00:00, 655.89 examples/s]
Map: 100%|████████████████████████| 83000/83000 [02:05<00:00, 663.86 examples/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


In [None]:
# Run the DPO fine-tuning loop
dpo_trainer.train()

# Persist the fine-tuned weights and the tokenizer side by side in one directory
fine_tuned_dir = "./fine-tuned-bart"
model.save_pretrained(fine_tuned_dir)
tokenizer.save_pretrained(fine_tuned_dir)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33melsiga[0m. Use [1m`wandb login --relogin`[0m to force relogin


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss


Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_