In [None]:
import comet_ml
import torch
import json
import linecache
import collections
import transformers
import os
import einops

In [None]:
# Device used for training/evaluation. Override with the DEVICE environment variable
# (e.g. DEVICE=cuda:0); falls back to CPU when CUDA is not available at all.
DEVICE = os.getenv("DEVICE", "cuda:2" if torch.cuda.is_available() else "cpu")

In [None]:
# Peek at the structure of one raw comparison record (the first line of batch10.json).
with open("comparisons/batch10.json", "r") as comparisons_file:
    first_line = next(comparisons_file, None)

if first_line is not None:
    sample_comparison = json.loads(first_line)
    print(sample_comparison.keys())
    for key in ("info", "summaries", "choice"):
        print(sample_comparison[key])

In [None]:
class ComparisonDataset(torch.utils.data.Dataset):
    """Human-preference comparisons read lazily from JSON-lines files.

    Every line of each ``batch{k}.json`` file under ``path_to_dataset_dir`` is one JSON
    object with ``info.post`` (the post text), ``summaries`` (a two-element list of
    objects carrying a ``text`` field) and ``choice`` (0 or 1, the index of the preferred
    summary). Item ``i`` is the pair ``(post_good, post_bad)``; each string has the form
    ``"{post} TLDR:{summary}"``.

    Args:
        path_to_dataset_dir: directory holding the ``batch{k}.json`` files.
        batch_indices: which ``k`` values to read, in order. Defaults to batches 3-9.
    """

    def __init__(self, path_to_dataset_dir, batch_indices=range(3, 10)):
        self.path_to_dataset_dir = path_to_dataset_dir
        self.file_names = [f"{path_to_dataset_dir}/batch{k}.json" for k in batch_indices]
        # Line count per file, filled lazily on the first len()/indexing call.
        self.file_lengths = None

    def __len__(self):
        """Total number of comparisons (lines) across all files."""
        if self.file_lengths is None:
            self.file_lengths = collections.OrderedDict()
            for file_name in self.file_names:
                with open(file_name) as f:
                    self.file_lengths[file_name] = sum(1 for _ in f)

        return sum(self.file_lengths.values())

    def __getitem__(self, i):
        """Return the ``(post_good, post_bad)`` pair for comparison ``i``.

        Negative indices count from the end, as for a list. Any index outside
        ``[-len(self), len(self))`` raises ``IndexError``. That is also what lets
        ``for pair in dataset`` terminate, because the class relies on the
        ``__getitem__`` sequence protocol for iteration.
        """
        n = len(self)
        index = i + n if i < 0 else i

        if not (0 <= index < n):
            raise IndexError(
                f"Tried to retrieve sample at index {i}, but only indices between {-n} and {n - 1} are valid."
            )

        # Walk the files, subtracting each file's length until the index falls inside one.
        offset = index
        for file_name in self.file_names:
            length = self.file_lengths[file_name]
            if offset < length:
                # linecache line numbers are 1-based.
                line = linecache.getline(file_name, lineno=offset + 1)
                payload = json.loads(line)
                choice = payload["choice"]
                summary_good = payload["summaries"][choice]["text"]
                summary_bad = payload["summaries"][1 - choice]["text"]
                post = payload["info"]["post"]
                post_good = f"{post} TLDR:{summary_good}"
                post_bad = f"{post} TLDR:{summary_bad}"
                return post_good, post_bad
            offset -= length


dataset = ComparisonDataset(path_to_dataset_dir="./comparisons")

In [None]:
# Total number of comparisons across all loaded batch files.
num_comparisons = len(dataset)
num_comparisons

In [None]:
# Peek at the final (post + good TLDR, post + bad TLDR) pair.
last_pair = dataset[len(dataset) - 1]
last_pair

In [None]:
def last_non_masked_indices(mask):
    """
    Index of the last non-masked position along the final axis of ``mask``.

    A position counts as masked when its mask value is 0. The index returned for a row
    is the position just before the row's first masked entry (or the final position if
    nothing is masked), so it assumes the non-masked entries form a prefix (right padding).
    Fully-masked rows map to 0.

    Adapted from https://github.com/openai/summarize-from-feedback/blob/master/summarize_from_feedback/reward_model.py
    """
    is_masked = mask == 0
    seq_len = is_masked.size(-1)
    positions = torch.arange(seq_len, dtype=torch.long, device=is_masked.device)

    # Push every non-masked position past seq_len so the minimum lands on the first
    # masked slot (or on seq_len when the row has no masked slot at all).
    shifted = (~is_masked).long() * seq_len + positions
    first_masked = shifted.min(dim=-1).values

    last_indices = first_masked - 1
    # Clamp fully-masked rows (which would give -1) to 0.
    return torch.maximum(last_indices, last_indices.new_zeros(1))


demo_mask = torch.tensor([
    [1, 1, 1],
    [1, 1, 0],
    [1, 0, 0],
    [0, 0, 0],
])
last_non_masked_indices(mask=demo_mask)

In [None]:
def last_non_masked_tokens(tokens, mask):
    """Select, for every row, the value of ``tokens`` at its last non-masked position.

    Args:
        tokens: tensor of shape ``[..., seq_len]`` (token ids, per-position rewards, ...).
        mask: tensor with the same shape as ``tokens``; 0 marks a masked position.
            Assumes right padding (each row's non-masked entries come first).

    Returns:
        Tensor of shape ``tokens.shape[:-1]``. Fully-masked rows fall back to position 0.
    """
    last_indices = last_non_masked_indices(mask)  # tokens.shape[:-1]
    # gather needs an index with the same rank as `tokens`, so add a trailing length-1
    # axis; unsqueeze(-1) (unlike [:, None]) also works with extra leading batch dims.
    last_tokens = torch.gather(tokens, dim=-1, index=last_indices.unsqueeze(-1))
    return last_tokens.squeeze(-1)

def last_non_masked_tokens_test():
    """Check last_non_masked_tokens on 2-D input and on input with an extra batch dim."""
    tokens = torch.tensor([
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
    ])
    mask = torch.tensor([
        [1, 1, 1],
        [1, 1, 0],
        [0, 0, 0],
    ])

    actual = last_non_masked_tokens(tokens, mask)
    expected = torch.tensor([3, 5, 7])

    # Exact comparison: these are integer tensors.
    assert torch.equal(actual, expected)
    assert actual.shape == tokens.shape[:-1]

    # An extra leading batch dimension must be handled too.
    batched = last_non_masked_tokens(torch.stack([tokens, tokens]), torch.stack([mask, mask]))
    assert torch.equal(batched, torch.stack([expected, expected]))


last_non_masked_tokens_test()

In [None]:
class GPTWithRewardHead(torch.nn.Module):
    """GPT-2 with a scalar reward head on top of its final hidden states.

    The GPT-2 weights are initialised from ``"gpt2"`` and then overwritten with the
    state dict at ``baseline_state_dict_path``. An MLP (hidden -> 4*hidden -> 1) maps each
    position's final hidden state to a scalar. The reward of a sequence is the value at
    its last non-masked token.

    Args:
        mask_token_id: passed to GPT-2 as ``pad_token_id``.
            NOTE(review): -100 is not a valid vocabulary id. If the borrowed ``generate``
            is ever used on padded batches this likely breaks; confirm, or pass the
            tokenizer's pad id instead.
        baseline_state_dict_path: state dict of the baseline GPT-2 to start from.
    """

    def __init__(self, mask_token_id=-100, baseline_state_dict_path="models/baseline.pt"):
        super().__init__()
        self.mask_token_id = mask_token_id
        self.gpt = transformers.GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=mask_token_id)
        # Load onto CPU regardless of the device the checkpoint was saved from; the module
        # is moved to its target device by the caller afterwards.
        self.gpt.load_state_dict(torch.load(baseline_state_dict_path, map_location="cpu"))
        self.generate = self.gpt.generate  # borrow existing generate function
        hidden_size = self.gpt.transformer.wte.weight.shape[-1]
        self.reward_network = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 4 * hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * hidden_size, 1),
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Scalar reward for each sequence in the batch.

        Args:
            input_ids: token ids, presumably ``[batch_size, seq_len]``.
            attention_mask: same shape as ``input_ids``; 0 marks padding. Assumes right
                padding. Defaults to all ones, i.e. the reward is read at the final position.
            **kwargs: forwarded to the GPT-2 transformer body.

        Returns:
            Tensor ``[batch_size]``: the reward at each sequence's last non-masked token.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Run only the transformer body. Its last_hidden_state is the same tensor as the
        # LM model's hidden_states[-1], and skipping the LM head avoids materialising the
        # [batch_size, seq_len, vocab_size] logits, which are never used here.
        response = self.gpt.transformer(input_ids, attention_mask=attention_mask, **kwargs)
        last_hidden_state = response.last_hidden_state  # [batch_size, seq_len, hidden_size]
        rewards = self.reward_network(last_hidden_state).squeeze(-1)  # [batch_size, seq_len]
        last_rewards = last_non_masked_tokens(rewards, attention_mask)  # [batch_size]
        return last_rewards


tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 has no padding token of its own. pad_token must be a token *string*; an int such
# as -100 is not a token, and newer transformers versions reject it outright. Reusing EOS
# is the usual choice for GPT-2. Padded positions are excluded via the attention_mask.
tokenizer.pad_token = tokenizer.eos_token

model = GPTWithRewardHead().to(DEVICE)

In [None]:
def collate_fn(batches):
    """Tokenize a list of ``(good, bad)`` text pairs into a single padded batch.

    The rows are every "good" text in order, followed by every "bad" text in the same
    order. Row ``k`` and row ``k + len(batches)`` therefore form one comparison pair.
    Uses the notebook-level ``tokenizer`` and ``tokenizer_config``.
    """
    preferred, rejected = zip(*batches)
    stacked_texts = (*preferred, *rejected)
    return tokenizer(stacked_texts, **tokenizer_config)

tokenizer_config = {
    "max_length": 350,
    "padding": "longest",
    # NOTE(review): transformers truncates from the right by default, i.e. from the end of
    # "{post} TLDR:{summary}". For long posts this can cut off the summary itself, while
    # the reward is read at the last kept token. Consider tokenizer.truncation_side = "left"
    # (confirm the installed transformers version supports it).
    "truncation": True,
    "return_tensors": "pt",
}

data_loader_config = {
    "batch_size": 8,
    "shuffle": True,
    "collate_fn": collate_fn,
    "drop_last": True,  # every batch holds exactly batch_size comparison pairs
}

TRAIN_FRACTION = 0.97
SPLIT_SEED = 0  # fixed so the train/test split (and hence val accuracy) is reproducible

num_train = int(TRAIN_FRACTION * len(dataset))
num_test = len(dataset) - num_train

data_train, data_test = torch.utils.data.random_split(
    dataset,
    (num_train, num_test),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_data_loader = torch.utils.data.DataLoader(data_train, **data_loader_config)
test_data_loader = torch.utils.data.DataLoader(data_test, **data_loader_config)

In [None]:
def average_number_of_tokens_in_dataset(dataset, n_samples=1_000, count_tokens=None):
    """Mean token count of the preferred ("good") text over the first ``n_samples`` items.

    Args:
        dataset: iterable of ``(good, bad)`` text pairs.
        n_samples: maximum number of items to inspect.
        count_tokens: optional callable mapping one text to its token count. Defaults to
            the notebook-level ``tokenizer`` applied with ``tokenizer_config``, so lengths
            are capped by its truncation settings.

    Returns:
        The mean token count over the items actually inspected.

    Raises:
        ValueError: if no item was inspected (empty dataset or ``n_samples <= 0``).
    """
    if count_tokens is None:
        def count_tokens(text):
            return tokenizer(text, **tokenizer_config).input_ids.shape[-1]

    total_length = 0
    n_seen = 0
    for good, _bad in dataset:
        if n_seen >= n_samples:
            break
        total_length += count_tokens(good)
        n_seen += 1

    if n_seen == 0:
        raise ValueError("No samples were inspected; cannot average over an empty selection.")

    # Divide by the number of items actually counted. The loop breaks *before* counting
    # item n_samples, so "last index + 1" would overcount by one.
    return total_length / n_seen


# Mean token count of the preferred texts among the first 1,000 comparisons.
average_number_of_tokens_in_dataset(dataset, n_samples=1_000)

In [None]:
# Scratch check: splitting a length-16 tensor with chunk size 8 yields two length-8 chunks.
torch.split(torch.rand(16), 8)

In [None]:
def _pairwise_rewards(model, inputs, device):
    """Score one collated batch and split the rewards into (preferred, rejected) halves.

    The batch is expected to hold every preferred text followed by every rejected text in
    the same order (as produced by ``collate_fn``), so the first half of the rows pairs
    up element-wise with the second half.
    """
    input_ids = inputs["input_ids"].to(device)  # [2 * n_pairs, seq_len]
    attention_mask = inputs["attention_mask"].to(device)
    rewards = model(input_ids, attention_mask=attention_mask)  # [2 * n_pairs]
    n_pairs = rewards.shape[0] // 2
    return rewards[:n_pairs], rewards[n_pairs:]


@torch.no_grad()
def _validation_accuracy(model, val_data_loader, device):
    """Fraction of validation pairs whose preferred text receives the strictly higher reward.

    Puts the model in eval mode for the pass and restores its previous mode afterwards.
    Returns NaN if the loader yields no pairs.
    """
    was_training = model.training
    model.eval()
    n_correct = 0
    n_pairs = 0
    try:
        for val_inputs in val_data_loader:
            rewards_good, rewards_bad = _pairwise_rewards(model, val_inputs, device)
            n_correct += int((rewards_good > rewards_bad).sum())
            n_pairs += rewards_good.shape[0]
    finally:
        model.train(was_training)
    return n_correct / n_pairs if n_pairs else float("nan")


def train(model, train_data_loader, val_data_loader, comet_experiment, epochs=1, lr=1.5e-5,
          val_every=10, device=None):
    """Train a reward model on pairwise comparisons with the loss -log(sigmoid(r_good - r_bad)).

    Args:
        model: callable ``model(input_ids, attention_mask=...)`` returning one reward per row.
        train_data_loader / val_data_loader: iterables of dicts with ``input_ids`` and
            ``attention_mask``. Each batch holds all preferred texts, then all rejected texts.
        comet_experiment: object with ``log_metric(name, value, step=...)`` and ``end()``.
            ``end()`` is always called, even if training raises.
        epochs: passes over ``train_data_loader``.
        lr: Adam learning rate.
        val_every: run validation whenever the in-epoch step index is a multiple of this.
        device: device for the input tensors; defaults to the notebook-level ``DEVICE``.
    """
    if device is None:
        device = DEVICE

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    global_step = 0

    try:
        for _ in range(epochs):
            for i, inputs in enumerate(train_data_loader):
                optimizer.zero_grad()

                rewards_good, rewards_bad = _pairwise_rewards(model, inputs, device)
                reward_gap = rewards_good - rewards_bad
                # Numerically stable -log(sigmoid(x)); no ad-hoc epsilon inside the log.
                loss = -torch.nn.functional.logsigmoid(reward_gap).mean()

                loss.backward()
                optimizer.step()

                comet_experiment.log_metric('train loss', float(loss), step=global_step)
                comet_experiment.log_metric('avg (good - bad) rewards', float(reward_gap.mean()), step=global_step)
                comet_experiment.log_metric('avg good summary rewards', float(rewards_good.mean()), step=global_step)
                comet_experiment.log_metric('avg bad summary rewards', float(rewards_bad.mean()), step=global_step)

                if i % val_every == 0:
                    print("validating")
                    val_accuracy = _validation_accuracy(model, val_data_loader, device)
                    # Logged only when freshly computed, not re-logged stale on every step.
                    comet_experiment.log_metric('val accuracy', val_accuracy, step=global_step)

                global_step += 1
    finally:
        comet_experiment.end()


experiment = comet_ml.Experiment(
    api_key=os.getenv("COMET_API_KEY"),  # read from the environment, never hard-coded
    project_name="learning-to-summarise-using-human-feedback",
    workspace="danesherbs",
    log_env_cpu=False,
    log_env_gpu=False,
)

# The held-out split doubles as the validation set during training.
train(model, train_data_loader, test_data_loader, experiment)