In [1]:
import comet_ml
import torch
import json
import linecache
import collections
import transformers
import os
import datasets

In [2]:
# Device that the model and every training/validation batch are moved to.
# NOTE(review): hard-coded to the CUDA device with index 2; adjust on machines with a different GPU layout.
DEVICE = "cuda:2"

In [3]:
def get_nth_line_of_file(path_to_file, n=1):
    """Print a summary of the JSON record stored on line ``n`` (1-based) of a file.

    Each line of the file is parsed as a standalone JSON object. For the
    selected line the record's keys are printed, followed by its "info",
    "summaries" and "choice" fields. Nothing is printed when the file has
    fewer than ``n`` lines.
    """
    with open(path_to_file, "r") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            if line_number != n:
                continue

            record = json.loads(raw_line)
            print(record.keys())
            for field in ("info", "summaries", "choice"):
                print(record[field])
            return


# Peek at the second record of the batch-3 comparisons file to see what a record contains.
get_nth_line_of_file("comparisons/batch3.json", n=2)

dict_keys(['info', 'split', 'summaries', 'choice', 'worker', 'batch', 'extra'])
{'id': 't3_34xale', 'post': "My boyfriend and I are long distance. We have a trip planned this summer which involves me going over to him in the USA. This will be the second time I have actually been with him in person. I am flying from the UK with my mum to the east coast. The original plan was for me to fly over to my boyfriend in the west coast (my parents are holidaying on the east coast) but because my mum was freaking out so much about me going to meet my boyfriend i said we can all road trip there together. I even invited her on the trip with us. I have given her all of our dates so that she can travel around with us.\n\nThe plan was for me to stay on the 4th July and fly back on the 5th. Mum knew this. I told her I had booked a flight back already from the west coast to east coast (where she would pick me up and we would fly back to the UK together). She has gone mad at me because she can't believe 

In [4]:
# NOTE(review): ComparisonDataset is not defined in this notebook; `datasets` is presumably a
# project-local module rather than the Hugging Face package — confirm. Later cells unpack each
# item as a (good, bad) pair.
dataset = datasets.ComparisonDataset()

In [5]:
def last_non_masked_indices(mask):
    """
    Return, for each row of ``mask``, the index just before its first zero entry.

    Rows without any zero map to their last index; rows whose first entry is
    zero are clamped to index 0.

    Adapted from https://github.com/openai/summarize-from-feedback/blob/master/summarize_from_feedback/reward_model.py
    """
    is_masked = mask == 0
    seq_len = is_masked.size(-1)
    positions = torch.arange(seq_len, dtype=torch.long, device=is_masked.device)
    # Unmasked positions are pushed past the end of the row so that the
    # minimum lands on the first masked position (or on seq_len if none).
    candidates = torch.where(is_masked, positions, positions + seq_len)
    last_kept = candidates.min(dim=-1).values - 1
    floor = torch.zeros([1], dtype=last_kept.dtype, device=mask.device)
    return torch.max(last_kept, floor)

def last_non_masked_indices_test():
    """Check last_non_masked_indices on rows with zero, one, two and three masked positions."""
    mask = torch.tensor([
        [1, 1, 1],
        [1, 1, 0],
        [1, 0, 0],
        [0, 0, 0],
    ])
    want = torch.tensor([2, 1, 0, 0])

    got = last_non_masked_indices(mask=mask)

    assert torch.allclose(got, want)


# Run the check as soon as the cell executes.
last_non_masked_indices_test()

In [6]:
def last_non_masked_tokens(tokens, mask):
    """Select, from each row of ``tokens``, the entry at that row's last non-masked position.

    The positions come from ``last_non_masked_indices(mask)``; the result drops
    the last dimension of ``tokens``.
    """
    picks = last_non_masked_indices(mask).unsqueeze(1)
    return tokens.gather(-1, picks).squeeze(-1)

def last_non_masked_tokens_test():
    """Check that last_non_masked_tokens picks the expected entry per row and drops the last dimension."""
    mask = torch.tensor([
        [1, 1, 1],
        [1, 1, 0],
        [0, 0, 0],
    ])
    tokens = torch.tensor([
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
    ])
    want = torch.tensor([3, 5, 7])

    got = last_non_masked_tokens(tokens, mask)

    assert torch.allclose(got, want)
    assert got.shape == tokens.shape[:-1]


# Run the check as soon as the cell executes.
last_non_masked_tokens_test()

In [7]:
class GPTWithRewardHead(torch.nn.Module):
    """GPT-2 language model with an MLP head that scores each sequence with one scalar reward.

    The GPT-2 weights are initialised from the "gpt2" checkpoint and then
    overwritten with the state dict stored in ``models/baseline.pt``. The reward
    head maps the final hidden state of every position to a scalar; ``forward``
    returns the scalar at the last non-masked position of each sequence.

    Args:
        mask_token_id: value forwarded to ``from_pretrained`` as ``pad_token_id``.
    """

    def __init__(self, mask_token_id=-100):
        super().__init__()
        self.gpt = transformers.GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=mask_token_id).to(DEVICE)
        # map_location makes loading independent of the device the checkpoint was saved from.
        self.gpt.load_state_dict(torch.load("models/baseline.pt", map_location=DEVICE))
        self.generate = self.gpt.generate  # borrow existing generate function
        hidden_size = self.gpt.transformer.wte.weight.shape[-1]
        self.reward_network = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 4 * hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * hidden_size, 1),
        )

    def forward(self, input_ids, **kwargs):
        """Return one reward per sequence, shape ``[batch_size]``.

        Args:
            input_ids: token ids, ``[batch_size, seq_len]``.
            **kwargs: forwarded to the GPT-2 model. ``attention_mask`` (same shape
                as ``input_ids``) marks real tokens with non-zero values; when it is
                omitted every position is treated as a real token.
        """
        attention_mask = kwargs.get("attention_mask")
        if attention_mask is None:
            # Without a mask, the reward is read from the final position of each row.
            attention_mask = torch.ones_like(input_ids)

        response = self.gpt(input_ids, output_hidden_states=True, **kwargs)
        # hidden_states is a tuple with one tensor per layer output; take the last one.
        last_hidden_state = response.hidden_states[-1]  # [batch_size, seq_len, hidden_size]
        rewards = self.reward_network(last_hidden_state).squeeze(-1)  # [batch_size, seq_len]
        return last_non_masked_tokens(rewards, attention_mask)


tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 ships without a padding token. `pad_token` must be a token *string*
# (not an integer id such as -100), so reuse the end-of-text token. Padded
# positions are excluded through the attention mask, which is what the reward
# model uses to locate the last real token of each sequence.
tokenizer.pad_token = tokenizer.eos_token

model = GPTWithRewardHead().to(DEVICE)

In [8]:
def collate_fn(batches):
    """Tokenize a batch of (good, bad) summary pairs in a single tokenizer call.

    All good summaries come first, followed by all bad summaries in the same
    order, so the tokenized batch holds ``2 * len(batches)`` sequences. The
    tokenizer is called with the options in ``tokenizer_config``.
    """
    preferred, rejected = zip(*batches)
    return tokenizer(preferred + rejected, **tokenizer_config)

# Options for every tokenizer call: truncate to 350 tokens, pad to the longest
# sequence in the call, and return PyTorch tensors.
tokenizer_config = {
    "max_length": 350,
    "padding": "longest",
    "truncation": True,
    "return_tensors": "pt",
}

# drop_last keeps every batch at exactly `batch_size` pairs, which the training
# loop relies on when it splits rewards into good/bad halves.
data_loader_config = {
    "batch_size": 8,
    "shuffle": True,
    "collate_fn": collate_fn,
    "drop_last": True,
}

# 97% / 3% train / held-out split.
num_train = int(0.97 * len(dataset))
num_test = len(dataset) - num_train

# Seed the split so a Restart & Run All reproduces the same train/test partition.
SPLIT_SEED = 0
data_train, data_test = torch.utils.data.random_split(
    dataset,
    (num_train, num_test),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_data_loader = torch.utils.data.DataLoader(data_train, **data_loader_config)
# No shuffling for evaluation: with drop_last=True a shuffled loader would drop
# a different remainder on every pass, so accuracies would be computed on
# slightly different subsets each time.
test_data_loader = torch.utils.data.DataLoader(data_test, **{**data_loader_config, "shuffle": False})

In [9]:
def average_number_of_tokens_in_dataset(dataset, n_samples=1_000):
    """Return the mean token count of the "good" summary over the first ``n_samples`` items.

    Args:
        dataset: iterable yielding ``(good, bad)`` pairs; only ``good`` is measured.
        n_samples: maximum number of items to inspect.

    Returns:
        Mean length (after tokenizing with ``tokenizer_config``) of the inspected
        summaries, or ``0.0`` when no item was inspected.
    """
    total_length = 0
    n_seen = 0

    for good, _bad in dataset:
        if n_seen == n_samples:
            break

        total_length += tokenizer(good, **tokenizer_config).input_ids.shape[-1]
        n_seen += 1

    # Divide by the number of summaries actually measured; using the loop index
    # over-counts by one whenever the loop stops at the n_samples limit.
    if n_seen == 0:
        return 0.0
    return total_length / n_seen


# average_number_of_tokens_in_dataset(dataset)

In [10]:
def _score_pairs(model, inputs, device):
    """Score one collated batch and return ``(rewards_good, rewards_bad)``.

    The batch must hold the good summaries in its first half and the matching
    bad summaries in its second half.
    """
    input_ids = inputs["input_ids"].to(device)  # [2 * n_pairs, seq_len]
    attention_mask = inputs["attention_mask"].to(device)
    rewards = model(input_ids, attention_mask=attention_mask)  # [2 * n_pairs]
    # Split at the midpoint rather than at a fixed batch size so a smaller
    # final batch is handled too.
    n_pairs = rewards.shape[0] // 2
    return rewards[:n_pairs], rewards[n_pairs:]


def _validation_accuracy(model, val_data_loader, device):
    """Return the fraction of validation pairs ranked correctly, or ``None`` if there are none.

    The model is switched to eval mode for the pass and back to train mode afterwards.
    """
    model.eval()
    n_correct = 0
    n_pairs = 0
    with torch.no_grad():
        for val_inputs in val_data_loader:
            rewards_good, rewards_bad = _score_pairs(model, val_inputs, device)
            # Ties count as correct, matching an argmax over [good, bad].
            n_correct += int((rewards_good >= rewards_bad).sum())
            n_pairs += rewards_good.shape[0]
    model.train()
    return n_correct / n_pairs if n_pairs else None


def train(model, train_data_loader, val_data_loader, comet_experiment, epochs=1, lr=1.5e-5,
          save_path="models/reward.pt", save_every=100, eval_every=10):
    """Fit the reward model on pairwise comparisons and log progress to Comet.

    Each batch contains good summaries followed by the matching bad summaries.
    The loss is ``-log(sigmoid(reward_good - reward_bad))`` averaged over pairs.

    Args:
        model: module called as ``model(input_ids, attention_mask=...)`` that
            returns one reward per sequence.
        train_data_loader: iterable of batches with "input_ids" and "attention_mask".
        val_data_loader: iterable of batches in the same format, used for accuracy.
        comet_experiment: object providing ``add_tag``, ``log_metric`` and ``end``.
        epochs: number of passes over ``train_data_loader``.
        lr: Adam learning rate.
        save_path: file the model state dict is written to.
        save_every: write a checkpoint every this many training steps.
        eval_every: compute validation accuracy every this many training steps.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Move batches to wherever the model lives instead of relying on a global.
    device = next(model.parameters()).device
    # Tag the experiment that was passed in, not a notebook-level global.
    comet_experiment.add_tag("reward_model")

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Counts steps across epochs so logged steps never repeat.
    global_step = 0

    for _ in range(epochs):
        for inputs in train_data_loader:
            optimizer.zero_grad()

            rewards_good, rewards_bad = _score_pairs(model, inputs, device)
            margin = rewards_good - rewards_bad

            # logsigmoid is the numerically stable form of log(sigmoid(x)); it
            # keeps a useful gradient even when a pair is ranked badly wrong.
            loss = -torch.nn.functional.logsigmoid(margin).mean()

            loss.backward()
            optimizer.step()

            comet_experiment.log_metric('train loss', float(loss))
            comet_experiment.log_metric('avg (good - bad) rewards', float(margin.mean()))
            comet_experiment.log_metric('avg good summary rewards', float(rewards_good.mean()))
            comet_experiment.log_metric('avg bad summary rewards', float(rewards_bad.mean()))

            if global_step % save_every == 0:
                torch.save(model.state_dict(), save_path)

            if global_step % eval_every == 0:
                accuracy = _validation_accuracy(model, val_data_loader, device)
                if accuracy is not None:
                    comet_experiment.log_metric('val accuracy', accuracy, step=global_step)

            global_step += 1

    comet_experiment.end()


# The Comet API key is read from the environment so it is never stored in the notebook.
experiment = comet_ml.Experiment(
    project_name="learning-to-summarise-using-human-feedback",
    workspace="danesherbs",
    api_key=os.environ.get("COMET_API_KEY"),
    log_env_gpu=False,
    log_env_cpu=False,
)

# Train the reward model, using the held-out split for validation accuracy.
train(model, train_data_loader, test_data_loader, experiment)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/danesherbs/learning-to-summarise-using-human-feedback/0829b2992a1544bda3eba1d036d7a186

