In [1]:
import comet_ml
import torch
import json
import linecache
import collections
import transformers
import os

In [2]:
# GPU to train on. Falls back to CPU when that GPU index does not exist, so the
# notebook still runs (slowly) on machines with fewer GPUs instead of crashing
# at `model.to(DEVICE)`.
GPU_INDEX = 3
DEVICE = f"cuda:{GPU_INDEX}" if torch.cuda.device_count() > GPU_INDEX else "cpu"

In [3]:
# Peek at the first raw comparison record to see which fields are available.
with open("comparisons/batch10.json", "r") as file:
    first_line = next(iter(file), None)

if first_line is not None:
    record = json.loads(first_line)
    print(record.keys())
    for field in ("info", "summaries", "choice"):
        print(record[field])

dict_keys(['info', 'split', 'summaries', 'choice', 'worker', 'batch', 'extra'])
{'id': 't3_52mb8y', 'post': "This is my first post so please be kind :)\n\nI know that lots of people often feel confused when they come out of a long-term relationship. They think they have forgotten how to be single, or how to flirt/date.\n\nI am one of these people.\n\nThe problem is, my relationship started when I had just turned 16. I have never been single - as an adult. That might sound silly. But the only time I have ever flirted or dated was as an over-confident, hormone-riddled teenager.\n\nNow I have a pretty demanding job, responsibilities blah blah... And I just don't know how to this!\n\nI'm no way in a rush to get into a new relationship, but that doesn't mean I want to be completely alone in the mean time.\n\nIf anyone has experienced anything similar, or just generally has some advice, it would be greatly appreciated!", 'title': "I [23F] have just come out of 8 year relationship. Feel like 

In [4]:
class ComparisonDataset(torch.utils.data.Dataset):
    """(post + preferred summary, post + rejected summary) pairs read from JSON-lines files.

    Each line of ``{path_to_dataset_dir}/batch{k}.json`` is one JSON record containing at
    least ``info.post``, ``summaries[j].text`` and ``choice`` (index of the preferred summary
    in ``summaries``). Lines are fetched on demand through :mod:`linecache`, which caches a
    file's lines in memory after its first access.

    Args:
        path_to_dataset_dir: directory holding the ``batch{k}.json`` files.
        batch_ids: batch numbers to load, in order (default: batches 3 through 10).
    """

    def __init__(self, path_to_dataset_dir, batch_ids=range(3, 11)):
        self.path_to_dataset_dir = path_to_dataset_dir
        self.file_names = [f"{path_to_dataset_dir}/batch{i}.json" for i in batch_ids]
        # file name -> number of lines; filled lazily on the first len() call.
        self.file_lengths = None

    def __len__(self):
        """Total number of comparisons across all files (counts lines once, then caches)."""
        if self.file_lengths is None:
            self.file_lengths = collections.OrderedDict()
            for file_name in self.file_names:
                # Explicit UTF-8 so counting matches linecache's decoding on every platform.
                with open(file_name, encoding="utf-8") as f:
                    self.file_lengths[file_name] = sum(1 for _ in f)
        return sum(self.file_lengths.values())

    def __getitem__(self, i):
        """Return the ``(post_good, post_bad)`` text pair at index ``i``.

        Negative indices count from the end, like a list. Any index outside
        ``[-len(self), len(self))`` raises IndexError, which also lets plain
        ``for x in dataset`` iteration terminate.

        Raises:
            IndexError: if ``i`` is out of range.
        """
        n = len(self)  # also initialises self.file_lengths
        if not -n <= i < n:
            raise IndexError(
                f"Tried to retrieve sample at index {i}, but only indices between {-n} and {n - 1} are valid."
            )
        if i < 0:
            i += n

        # Walk the files in order, subtracting each file's length until i falls inside one.
        for file_name, length in self.file_lengths.items():
            if i < length:
                line = linecache.getline(file_name, lineno=i + 1)  # linecache is 1-based
                return self._to_pair(json.loads(line))
            i -= length

    @staticmethod
    def _to_pair(payload):
        """Build the (preferred, rejected) "{post} TLDR:{summary}" strings from one record."""
        choice = payload["choice"]
        summary_good = payload["summaries"][choice]["text"]
        summary_bad = payload["summaries"][1 - choice]["text"]
        post = payload["info"]["post"]
        post_good = f"{post} TLDR:{summary_good}"
        post_bad = f"{post} TLDR:{summary_bad}"
        return post_good, post_bad


dataset = ComparisonDataset(path_to_dataset_dir="./comparisons")

In [10]:
# Total number of comparison pairs across all batch files
# (the first call counts every file's lines and caches the result).
len(dataset)

103901

In [5]:
# Inspect the last pair: (post + preferred summary, post + rejected summary).
# Negative indexing is supported by ComparisonDataset.__getitem__.
dataset[-1]

('A bit of background:\n\nAll of my exes that I have had, (with the [current] exception of my most recent one [less than two months since breakup]) are married, and all of the ones that had interest, have kids.\n\nEvery single one of them.\n\nMy most recent ex broke up with me under the guise of we needed to take a break so we were better together, and kept saying things like we needed space so we could become better for ourselves, not each other.\nThen today I found out that not only is she with someone else in all possible ways, less than two months after our breakup, but she realized when she met him while we were still together that he was the one she was going to spend her life with. TLDR: exes who I have had breakups with all married, every single one of them, and every single one of them is with someone else less than two months after our breakup.',
 'A bit of background:\n\nAll of my exes that I have had, (with the [current] exception of my most recent one [less than two months

In [6]:
class GPTWithRewardHead(torch.nn.Module):
    """GPT-2 language model with an extra scalar reward head on its final hidden states.

    ``forward`` returns the LM logits together with one reward per sequence. The reward
    is read off the hidden state of the last *non-padding* token.

    Args:
        mask_token_id: forwarded to ``GPT2LMHeadModel.from_pretrained`` as ``pad_token_id``.
            NOTE(review): -100 is not a vocabulary id; if ``generate`` ever pads finished
            sequences with it, the output would presumably contain invalid ids — consider
            the end-of-text id (50256) instead.
    """

    def __init__(self, mask_token_id=-100):
        super().__init__()
        self.gpt = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=mask_token_id)
        self.generate = self.gpt.generate  # borrow existing generate function
        hidden_size = self.gpt.transformer.wte.weight.shape[-1]
        # Two-layer MLP mapping each hidden state to a single scalar.
        self.reward_network = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 4 * hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * hidden_size, 1),
        )

    def forward(self, input_ids, **kwargs):
        """Run GPT-2 and score each sequence.

        Args:
            input_ids: token ids, shape [batch_size, seq_len].
            **kwargs: forwarded to the GPT-2 model. If ``attention_mask`` is given, it is
                also used to locate each sequence's last real (mask == 1) token.

        Returns:
            (logits, last_reward): LM logits [batch_size, seq_len, vocab_size] and the
            reward at the last non-padding token of each sequence, shape [batch_size].
            Without an attention mask the final position is used.
        """
        response = self.gpt(input_ids, output_hidden_states=True, **kwargs)
        last_hidden_state = response.hidden_states[-1]  # [batch_size, seq_len, hidden_size]
        rewards = self.reward_network(last_hidden_state).squeeze(-1)  # [batch_size, seq_len]

        attention_mask = kwargs.get("attention_mask")
        if attention_mask is None:
            last_reward = rewards[:, -1]
        else:
            # Largest position whose mask is 1 == last real token. Works for left or right
            # padding; taking rewards[:, -1] would score a pad token in padded rows.
            positions = torch.arange(rewards.shape[-1], device=rewards.device)
            last_index = (positions * attention_mask.to(positions.dtype)).max(dim=-1).values
            batch_index = torch.arange(rewards.shape[0], device=rewards.device)
            last_reward = rewards[batch_index, last_index]

        logits = response.logits  # [batch_size, seq_len, vocab_size]
        return logits, last_reward


tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
# pad_token must be a token *string*. The previous `-100` (an int) only worked because
# the lookup fell back to id 50256 (see the padded batches printed below). Make that
# explicit by padding with the end-of-text token.
tokenizer.pad_token = tokenizer.eos_token
model = GPTWithRewardHead().to(DEVICE)

In [7]:
def collate_fn(batches):
    """Turn a list of (good_text, bad_text) pairs into two tokenized batches.

    Uses the module-level ``tokenizer`` and ``tokenizer_config``. Returns
    ``(tokens_good, tokens_bad)``: the tokenizer's output for the preferred
    texts, then for the rejected texts.
    """
    good_texts, bad_texts = zip(*batches)
    return tuple(tokenizer(texts, **tokenizer_config) for texts in (good_texts, bad_texts))


# Pad every batch to its longest sequence and cap sequences at 512 tokens.
# NOTE(review): if the tokenizer truncates from the right (the usual default — confirm),
# long posts lose the end of "{post} TLDR:{summary}" first, which can make the good
# and bad inputs identical.
tokenizer_config = dict(
    max_length=512,
    padding="longest",
    truncation=True,
    return_tensors="pt",
)

data_loader_config = {
    "batch_size": 4,
    "shuffle": True,
    "collate_fn": collate_fn,
}

TRAIN_FRACTION = 0.95
# Fixed seed so the train/test split is identical across kernel restarts;
# otherwise held-out examples leak into training between runs.
SPLIT_SEED = 0

num_train = int(TRAIN_FRACTION * len(dataset))
num_test = len(dataset) - num_train

data_train, data_test = torch.utils.data.random_split(
    dataset,
    (num_train, num_test),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_data_loader = torch.utils.data.DataLoader(data_train, **data_loader_config)
# Evaluation order is irrelevant, so keep the held-out loader deterministic.
test_data_loader = torch.utils.data.DataLoader(data_test, **{**data_loader_config, "shuffle": False})

In [8]:
# Show one collated (tokens_good, tokens_bad) batch to sanity-check padding.
first_batch = next(iter(train_data_loader), None)
if first_batch is not None:
    print(first_batch)

({'input_ids': tensor([[   40,   357,  1731,  ..., 50256, 50256, 50256],
        [   40,  3888,   287,  ..., 50256, 50256, 50256],
        [   40,  1053,   587,  ..., 50256, 50256, 50256],
        [    7,  5305,  1096,  ...,   477,  1865,    13]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]])}, {'input_ids': tensor([[   40,   357,  1731,  ..., 50256, 50256, 50256],
        [   40,  3888,   287,  ..., 50256, 50256, 50256],
        [   40,  1053,   587,  ..., 50256, 50256, 50256],
        [    7,  5305,  1096,  ...,  5876,  4547,     8]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]])})


In [9]:
def _batch_rewards(model, inputs, device):
    """Move one tokenized batch to `device` and return the model's per-sequence rewards."""
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    _, rewards = model(input_ids, attention_mask=attention_mask)
    return rewards


def train(model, train_data_loader, epochs=30, lr=1e-3, comet_experiment=None, device=None):
    """Fit the reward model on (preferred, rejected) comparison batches.

    Minimises the pairwise preference loss ``-log(sigmoid(r_good - r_bad))``, which
    pushes the reward of the preferred text above that of the rejected one.

    Args:
        model: module whose ``forward(input_ids, attention_mask=...)`` returns
            ``(logits, rewards)`` with one reward per sequence.
        train_data_loader: iterable of ``(inputs_good, inputs_bad)``, each a mapping
            with ``"input_ids"`` and ``"attention_mask"`` tensors.
        epochs: number of passes over ``train_data_loader``.
        lr: initial Adam learning rate.
        comet_experiment: optional experiment object; receives ``log_metric`` calls
            for the loss and learning rate each step, and ``end()`` when training stops
            (including on error or KeyboardInterrupt).
        device: where to move input tensors; defaults to the device of the model's
            first parameter.
    """
    if device is None:
        device = next(model.parameters()).device

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Stepped once per batch, so `patience` is measured in batches, not epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=50)

    try:
        for _ in range(epochs):
            for inputs_good, inputs_bad in train_data_loader:
                optimizer.zero_grad()

                rewards_good = _batch_rewards(model, inputs_good, device)
                rewards_bad = _batch_rewards(model, inputs_bad, device)

                # Negative log-likelihood that the preferred text wins. logsigmoid is the
                # numerically stable form of log(sigmoid(x)).
                loss = -torch.nn.functional.logsigmoid(rewards_good - rewards_bad).mean()

                loss.backward()
                optimizer.step()
                loss_value = loss.item()
                scheduler.step(loss_value)

                if comet_experiment is not None:
                    comet_experiment.log_metric('train loss', loss_value)
                    comet_experiment.log_metric('lr', optimizer.param_groups[0]['lr'])
    finally:
        if comet_experiment is not None:
            comet_experiment.end()


# Comet run configuration; the API key comes from the environment, never the notebook.
comet_settings = {
    "api_key": os.getenv("COMET_API_KEY"),
    "project_name": "learning-to-summarise-using-human-feedback",
    "workspace": "danesherbs",
    "log_env_cpu": False,
    "log_env_gpu": False,
}
experiment = comet_ml.Experiment(**comet_settings)

train(model, train_data_loader, lr=3e-5, comet_experiment=experiment)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/danesherbs/learning-to-summarise-using-human-feedback/1868132920eb48b5acea254d777693b8



KeyboardInterrupt: 