In [1]:
from comet_ml import Experiment
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy
import einops
import os
from dotenv import load_dotenv
from torchtyping import TensorType
import collections
import json
import linecache

load_dotenv()  # take environment variables from .env

True

In [2]:
device = "cuda:2"

In [3]:
def last_non_masked_indices(mask):
    """
    Adapted from https://github.com/openai/summarize-from-feedback/blob/master/summarize_from_feedback/reward_model.py
    """
    bools = mask == 0
    row_len = bools.size(-1)
    zero_or_index = row_len * (~bools).type(torch.long) + torch.arange(
        row_len, dtype=torch.long, device=bools.device
    )
    indices = torch.min(zero_or_index, dim=-1).values - 1
    return torch.max(indices, torch.zeros([1], dtype=indices.dtype, device=mask.device))

def last_non_masked_tokens(tokens, mask):
    last_indices = last_non_masked_indices(mask)
    last_tokens = torch.gather(tokens, dim=-1, index=last_indices[:, None])
    return last_tokens.squeeze(-1)


class GPTWithRewardHead(torch.nn.Module):
    def __init__(self, mask_token_id=-100):
        super().__init__()
        self.gpt = transformers.GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=mask_token_id).to(device)
        self.generate = self.gpt.generate  # borrow existing generate function
        hidden_size = self.gpt.transformer.wte.weight.shape[-1]
        self.reward_network = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 4 * hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * hidden_size, 1),
        )
    
    def forward(self, input_ids, **kwargs):
        response = self.gpt(input_ids, output_hidden_states=True, **kwargs)  # [batch_size, num_layers, hidden_dim]
        last_hidden_state = response.hidden_states[-1]  # [batch_size, seq_len, hidden_size]
        rewards = self.reward_network(last_hidden_state).squeeze(-1)  # [batch_size, seq_len]
        last_rewards = last_non_masked_tokens(rewards, kwargs["attention_mask"])
        return last_rewards


tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = -100

reward_model = GPTWithRewardHead().to(device)
reward_model.load_state_dict(torch.load("models/reward.pt"))
reward_model.eval()
None

In [4]:
class GPTWithValueHead(torch.nn.Module):
    
    def __init__(self, path_to_baseline_model="models/baseline.pt"):
        super().__init__()
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        self.gpt = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=self.tokenizer.eos_token_id)
        self.gpt.load_state_dict(torch.load(path_to_baseline_model))
        self.generate = self.gpt.generate  # borrow existing generate function
        hidden_size = self.gpt.transformer.wte.weight.shape[-1]
        self.value_network = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 4 * hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * hidden_size, 1),
        )
    
    def forward(self, input_ids):
        # response = self.gpt(input_ids, output_hidden_states=True)  # [batch_size, num_layers, hidden_dim]
        response = self.gpt(input_ids)  # [batch_size, num_layers, hidden_dim]
        # last_hidden_state = response.hidden_states[-1]  # [batch_size, seq_len, hidden_size]
        # values = self.value_network(last_hidden_state).squeeze(-1)
        logits = response.logits  # [batch_size, seq_len, vocab_size]
        # return logits, values
        return logits


# ref_model = GPTWithValueHead().to(device)
ref_model = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id).to(device)
ref_model.load_state_dict(torch.load("models/baseline.pt"))
ref_model.eval()
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2');

In [18]:
def get_samples(model, input_ids, batch_size=50, gen_len=10, temperature=0.6):
    with torch.no_grad():
        samples = model.generate(
            input_ids,
            max_length=input_ids.shape[-1] + gen_len,
            min_length=input_ids.shape[-1] + gen_len,
            do_sample=True,
            temperature=temperature,
            top_k=len(tokenizer),
            top_p=1.0,
            num_return_sequences=batch_size
        )
        gen_samples = samples[:, input_ids.shape[-1]:]
        sample_ids = copy.deepcopy(samples)
        samples = tokenizer.batch_decode(samples)
        gen_samples = tokenizer.batch_decode(gen_samples)
    return sample_ids, samples, gen_samples

In [6]:
input_ids = tokenizer.encode('Testing', return_tensors='pt').to(device)
sample_ids, samples, gen_samples =  get_samples(model=ref_model, input_ids=input_ids, batch_size=5)

In [7]:
def reward_fn(gen_sample):
    if isinstance(gen_sample, list):
        return [reward_fn(item) for item in gen_sample]
    else:
        return gen_sample.count('.')

def reward_fn_test():
    A = 'This is a test.'
    assert reward_fn(A) == 1
    B = '......'
    assert reward_fn(B) ==6
    C = 'Whatever'
    assert reward_fn(C) == 0
    assert reward_fn([A, B, C]) == [1, 6, 0]

    print('Passed test.')
    return

reward_fn_test()

Passed test.


In [8]:
def get_logprobs(input_ids, logits):
    logprobs = F.log_softmax(logits, dim=-1)
    logprobs = torch.gather(logprobs, -1, input_ids[:,:,None])[:,:,0]
    return logprobs

def logprobs_test(logprobs_fn):
    input_ids = torch.randint(0, 100, (10, 10))
    logits = torch.randn(10, 10, 100)
    ref_logprobs = get_logprobs(input_ids, logits)
    logprobs = logprobs_fn(input_ids, logits)
    assert torch.allclose(logprobs, ref_logprobs)

In [9]:
def noas_log_probs_from_logits(sample_ids, logits, prefix_len):
    return get_logprobs(sample_ids[:,prefix_len:], logits[:,prefix_len-1:-1])

def log_probs_from_logits(logits: TensorType["batch_size", "seq_len", "vocab_size"], input_ids: TensorType["batch_size", "seq_len"], prefix_len=1) -> TensorType["batch_size", "seq_len"]:
    assert prefix_len > 0

    logits = logits[:, prefix_len-1:-1]  # [batch_size, seq_len-1, vocab_size] -- ignore x_{n+1}
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 
    
    input_ids = input_ids[:, prefix_len:]  # [batch_size, seq_len-1, vocab_size]  -- ignore x_0
    input_ids = input_ids.unsqueeze(-1)  # [batch_size, seq_len-1, vocab_size, 1]
    seq_log_probs = torch.gather(input=log_probs, dim=-1, index=input_ids)
    seq_log_probs = seq_log_probs.squeeze(-1) # [batch_size, seq_len-1]
    
    return seq_log_probs

def log_probs_from_logits_test():
    input_ids = torch.randint(0, 100, (10, 10))
    logits = torch.randn(10, 10, 100)
    for prefix_len in range(1, 10):
        actual = log_probs_from_logits(logits=logits, input_ids=input_ids, prefix_len=prefix_len)
        expected = noas_log_probs_from_logits(logits=logits, sample_ids=input_ids, prefix_len=prefix_len)
        assert torch.allclose(actual, expected)


log_probs_from_logits_test()

In [10]:
def kl_divergence(p_logits, q_logits):
    p_log_probs = torch.nn.functional.log_softmax(p_logits, dim=-1)
    q_log_probs = torch.nn.functional.log_softmax(q_logits, dim=-1)
    kl_div = torch.nn.functional.kl_div(input=q_log_probs, target=p_log_probs, reduction="none", log_target=True)
    return torch.sum(kl_div, dim=-1)

def kl_divergence_test():
    p_logits = torch.tensor([[1, 2, 3]], dtype=torch.float32)
    q_logits = torch.tensor([[4, 5, 6]], dtype=torch.float32)
    p = torch.nn.functional.softmax(p_logits, dim=-1)
    q = torch.nn.functional.softmax(q_logits, dim=-1)
    actual = kl_divergence(p_logits, q_logits)
    expected = torch.tensor([[p[0][i] * torch.log(p[0][i] / q[0][i]) for i in range(3)]])
    assert torch.allclose(actual, expected), f"{actual}, {expected}"


kl_divergence_test()

In [11]:
def get_lr_scheduler(warmup_steps, total_steps, final_scale):
    def lr_scheduler(step):
        if step < warmup_steps:
            return step / warmup_steps
        else:
            return 1-(1-final_scale)*(step-warmup_steps)/(total_steps-warmup_steps)
    
    return lr_scheduler

In [12]:
def whiten(t, eps=1e-5):
    t = t - t.mean()
    t = t / (t.std() + eps)
    return t

In [13]:
def get_minibatches(sample_ids, ref_logits, rewards):
    sample_ids = einops.rearrange(sample_ids,"(m b) t -> m b t", b=minibatch_size)
    ref_logits = einops.rearrange(ref_logits, "(m b) t d -> m b t d", b=minibatch_size)
    rewards = einops.rearrange(rewards, "(m b) -> m b", b=minibatch_size)
    for i in range(batch_size // minibatch_size):
        yield {"sample_ids": sample_ids[i], "ref_logits": ref_logits[i], "rewards": rewards[i]}

In [14]:
class ComparisonDataset(torch.utils.data.Dataset):
    def __init__(self, path_to_dataset_dir="./comparisons"):
        self.path_to_dataset_dir = path_to_dataset_dir
        self.file_names = [f"{path_to_dataset_dir}/batch{i}.json" for i in range(11, 21)]
        self.file_lengths = None
        self.seen = set()

    def __len__(self):
        if self.file_lengths is None:
            self.file_lengths = collections.OrderedDict()
            for file_name in self.file_names:
                with open(file_name) as f:
                    self.file_lengths[file_name] = sum(1 for line in f)

        return sum(self.file_lengths.values())

    def __getitem__(self, i):
        i = i % len(self)

        if not (0 <= i < len(self)):
            raise IndexError(
                f"Tried to retrieve sample at index {i}, but only indicies between 0 and {len(self)-1} modulo {len(self)} are valid."
            )

        cum_length = 0

        for file_name in self.file_names:
            cum_length += self.file_lengths[file_name]
            if i < cum_length:
                file_idx = i - cum_length + self.file_lengths[file_name]
                line = linecache.getline(file_name, lineno=file_idx + 1)
                payload = json.loads(line)
                post = payload["info"]["post"]
                return f"{post} TLDR:"


dataset = ComparisonDataset()

In [15]:
n_minibatches_per_epoch = 4
minibatch_size = 5
kl_coef = 80  # 0.2
vf_coef = 0.3
n_steps = 1_000  # 300
warmup_steps = 30
lr = 3e-5
gen_len = 30

In [16]:
torch.manual_seed(42)

experiment = Experiment(
    api_key=os.getenv("COMET_API_KEY"),
    project_name="learning-to-summarise-using-human-feedback",
    workspace="danesherbs",
    log_env_cpu=False,
    log_env_gpu=False,
)

experiment.add_tag("human_feedback_model")

data_loader_config = {
    "batch_size": 1,
    "shuffle": True,
}

data_loader = iter(torch.utils.data.DataLoader(dataset, **data_loader_config))
# model = GPTWithValueHead().to(device)
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id).to(device)
model.load_state_dict(torch.load("models/baseline.pt"))

batch_size = minibatch_size * n_minibatches_per_epoch
prefix_len = input_ids.shape[-1]

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = get_lr_scheduler(5, n_steps, 0.1)
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)

def get_loss(sample_ids, ref_logits, rewards, prefix_len):
    # logits, values = model(sample_ids)
    logits = model(sample_ids).logits
    log_probs = log_probs_from_logits(logits=logits, input_ids=sample_ids, prefix_len=prefix_len)
    
    # kl loss
    kl = kl_divergence(q_logits=logits, p_logits=ref_logits)[:,prefix_len-1:-1]
    kl_loss = kl.mean()
    
    # value loss
    # seq_len = logits.shape[1]
    # rewards_to_go = einops.repeat(rewards, "batch_size -> batch_size seq_len", seq_len=seq_len)
    # value_loss = torch.nn.functional.mse_loss(values, rewards_to_go)
    
    # policy loss
    policy_loss = -(torch.sum(log_probs, dim=-1) * rewards).mean()
    
    # total loss
    # loss = policy_loss + vf_coef * value_loss + kl_coef * kl_loss
    loss = policy_loss + kl_coef * kl_loss
    
    # experiment.log_metric('value_loss', value_loss.item())
    experiment.log_metric('kl loss', kl_loss.item())
    experiment.log_metric('total_loss', loss.item())
    experiment.log_metric('lr', lr_scheduler.get_last_lr()[0])
    experiment.log_metric('policy_loss', policy_loss.item())

    return loss


for batch_idx in range(n_steps):
    [prefix] = next(data_loader)
    input_ids = tokenizer(prefix, return_tensors='pt').input_ids.to(device)
    sample_ids, samples, gen_samples = get_samples(model, input_ids=input_ids, batch_size=batch_size, gen_len=gen_len)
    sample_ids = sample_ids.to(device)
    # ref_logits, _ = ref_model(sample_ids)
    with torch.no_grad():
        with torch.autocast(device_type="cuda"):
            ref_logits = ref_model(sample_ids).logits.detach()
    
    with torch.no_grad():
        with torch.autocast(device_type="cuda"):
            attention_mask = torch.ones_like(sample_ids).to(device)  # all samples have the same length and therefore no masking takes place
            rewards = reward_model(sample_ids, attention_mask=attention_mask).detach()
            rewards_normed = whiten(rewards)
            rewards_normed = rewards_normed
    
    experiment.log_metric('mean_reward', rewards.mean())
    experiment.log_text(prefix + gen_samples[0])
    
    for minibatch in get_minibatches(sample_ids, ref_logits, rewards_normed):
        with torch.autocast(device_type="cuda"):
            loss = get_loss(**minibatch, prefix_len=prefix_len)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0, norm_type=2.0, error_if_nonfinite=True)
        optimizer.step()
        optimizer.zero_grad()
    
    if batch_idx > 0 and batch_idx % 1_000 == 0:
        torch.save(model.state_dict(), "models/human_feedback.pt")

    lr_scheduler.step()
    
experiment.end()

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/danesherbs/learning-to-summarise-using-human-feedback/1359bbe84af44bdab036e0d0a8e40e99

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/danesherbs/learning-to-summarise-using-human-feedback/1359bbe84af44bdab036e0d0a8e40e99
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     kl loss [4000]     : (0.0, 12.157878875732422)
COMET INFO:     loss [400]         : (-2255.92724609375, 1920.4840087890625)
COMET INFO:     lr [4000]          : (0.0, 3e-05)
COMET INFO:     mean_reward [1000] : (-0.25146484375, 0.87548828125)
COMET INFO:     policy_loss [4000] : (-3939.677001953125, 2653.654541015625)
COMET INFO:     total_loss [4000]  : (-3551.366943359375, 3077.697021484375)
COMET INFO:   Uploads:
COMET INFO:     environment details     

In [17]:
cls_model = transformers.AutoModelForSequenceClassification.from_pretrained('cardiffnlp/twitter-roberta-base-sentiment').to(device)
cls_tokenizer = transformers.AutoTokenizer.from_pretrained('cardiffnlp/twitter-roberta-base-sentiment')

def reward_fn(gen_sample):
    if isinstance(gen_sample, list):
        return [reward_fn(item) for item in gen_sample]
    else:
        logits = cls_model(cls_tokenizer(gen_sample, return_tensors='pt')['input_ids'].to(device)).logits[0]
        logprobs = F.log_softmax(logits, dim=0)
        assert logprobs.shape == (3,)
        return float(logprobs[-1])

In [33]:
prompt = """
So, I was taking a man's order at the cafe I work at. He was pretty overweight, crazy hair, sunglasses with one of the arms broken off, and just seemed generally socially awkward. He orders a cupcake, which is pretty expensive, and I tell him the price is $4.26. He reacts: "Oh wow, $4.26?" And I reply that it will change his life.

The response that floored me: "Oh, well, I'm going to hold you to that, if it doesn't I'll put it on my blog. Yeah, I've already been taking pictures of you with my phone, oh my god what am I saying."

I laugh quietly and just finish the transaction in silence, pretending to not have heard him. But seriously you guys, this shit was hella awkward. And I felt super bad for this dude because it obviously just slipped out, as a joke that just turned out to be really creepy. SO WHAT THE FUCK AM I SUPPOSED TO DO?

TLDR:"""


test_input_ids = torch.tensor(tokenizer([prompt]).input_ids).to(device)

_, generated_text, _ = get_samples(
    model,
    test_input_ids,
    batch_size=1,
    gen_len=30,
    temperature=1e-3,
)

generated_text

['\nSo, I was taking a man\'s order at the cafe I work at. He was pretty overweight, crazy hair, sunglasses with one of the arms broken off, and just seemed generally socially awkward. He orders a cupcake, which is pretty expensive, and I tell him the price is $4.26. He reacts: "Oh wow, $4.26?" And I reply that it will change his life.\n\nThe response that floored me: "Oh, well, I\'m going to hold you to that, if it doesn\'t I\'ll put it on my blog. Yeah, I\'ve already been taking pictures of you with my phone, oh my god what am I saying."\n\nI laugh quietly and just finish the transaction in silence, pretending to not have heard him. But seriously you guys, this shit was hella awkward. And I felt super bad for this dude because it obviously just slipped out, as a joke that just turned out to be really creepy. SO WHAT THE FUCK AM I SUPPOSED TO DO?\n\nTLDR: I was in a car with a guy and he was in a car with a guy and he was in a car with a guy and he was in']