In [1]:
from comet_ml import Experiment
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy
import einops
import os
from dotenv import load_dotenv
from torchtyping import TensorType

load_dotenv()  # take environment variables from .env

True

In [2]:
# Select the compute device. This notebook was written against GPU index 3, so
# that index is preferred, but an unconditional "cuda:3" fails on machines that
# have fewer GPUs (or none). Fall back to the first GPU, then to the CPU.
PREFERRED_GPU_INDEX = 3
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

In [3]:
class GPTWithValueHead(torch.nn.Module):
    """Pretrained GPT-2 LM plus a small MLP head that emits one scalar per token position."""

    def __init__(self):
        super().__init__()
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        # The tokenizer's end-of-sequence id is reused as the pad token id.
        eos_id = self.tokenizer.eos_token_id
        self.gpt = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=eos_id)
        # Expose the wrapped model's generate() directly on this module.
        self.generate = self.gpt.generate
        # Last dimension of the token-embedding matrix.
        embed_dim = self.gpt.transformer.wte.weight.shape[-1]
        value_layers = [
            torch.nn.Linear(embed_dim, 4 * embed_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * embed_dim, 1),
        ]
        self.value_network = torch.nn.Sequential(*value_layers)

    def forward(self, input_ids):
        """Return (logits, values): LM logits and the value head's per-position scalar."""
        outputs = self.gpt(input_ids, output_hidden_states=True)
        final_hidden = outputs.hidden_states[-1]  # [batch_size, seq_len, hidden_size]
        # Drop the trailing size-1 dimension produced by the last Linear layer.
        per_token_values = self.value_network(final_hidden).squeeze(-1)
        return outputs.logits, per_token_values  # logits: [batch_size, seq_len, vocab_size]


# Reference model: moved to the chosen device and switched to eval mode
# (both .to() and .eval() return the module itself, so the calls can be chained).
ref_model = GPTWithValueHead().to(device).eval()
# Module-level tokenizer used by the sampling helpers below.
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

In [4]:
def get_samples(model, input_ids, batch_size=50, gen_len=10):
    """Sample `batch_size` continuations of exactly `gen_len` tokens from `model`.

    Uses the module-level `tokenizer` for decoding.

    Returns:
        (sample_ids, samples, gen_samples): a copy of the generated id tensor
        (prompt included), the decoded full texts, and the decoded continuations only.
    """
    prompt_len = input_ids.shape[-1]
    # min_length == max_length forces every sequence to the same total length.
    total_len = prompt_len + gen_len
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_length=total_len,
            min_length=total_len,
            do_sample=True,
            temperature=0.6,
            top_k=len(tokenizer),
            top_p=1.0,
            num_return_sequences=batch_size,
        )
        sample_ids = copy.deepcopy(generated)
        full_texts = tokenizer.batch_decode(generated)
        continuation_texts = tokenizer.batch_decode(generated[:, prompt_len:])
    return sample_ids, full_texts, continuation_texts

In [5]:
# Smoke test: draw five continuations of the prompt 'Testing' from the reference model.
input_ids = tokenizer.encode('Testing', return_tensors='pt')
input_ids = input_ids.to(device)
sample_ids, samples, gen_samples = get_samples(ref_model, input_ids, batch_size=5)

In [6]:
def reward_fn(gen_sample):
    """Reward = number of '.' occurrences; a list input is scored element-wise (recursively)."""
    if not isinstance(gen_sample, list):
        return gen_sample.count('.')
    return list(map(reward_fn, gen_sample))

def reward_fn_test():
    """Check reward_fn on three single strings and on the same strings passed as a list."""
    expected_by_text = {
        'This is a test.': 1,
        '......': 6,
        'Whatever': 0,
    }
    for text, expected in expected_by_text.items():
        assert reward_fn(text) == expected
    assert reward_fn(list(expected_by_text)) == list(expected_by_text.values())

    print('Passed test.')


reward_fn_test()

Passed test.


In [7]:
def get_logprobs(input_ids, logits):
    """Pick, at every position, the log-softmax probability of the id given in `input_ids`.

    `logits[b, t]` is normalised over its last dimension and indexed with
    `input_ids[b, t]`; the result has the same shape as `input_ids`.
    """
    token_index = input_ids.unsqueeze(-1)
    all_logprobs = F.log_softmax(logits, dim=-1)
    return all_logprobs.gather(-1, token_index).squeeze(-1)

def logprobs_test(logprobs_fn):
    """Assert that `logprobs_fn(input_ids, logits)` agrees with get_logprobs on random inputs."""
    n_batch, n_pos, n_vocab = 10, 10, 100
    input_ids = torch.randint(0, n_vocab, (n_batch, n_pos))
    logits = torch.randn(n_batch, n_pos, n_vocab)
    reference = get_logprobs(input_ids, logits)
    candidate = logprobs_fn(input_ids, logits)
    assert torch.allclose(candidate, reference)


In [8]:
def noas_log_probs_from_logits(sample_ids, logits, prefix_len):
    """Log-probs of the tokens after the prefix, via get_logprobs.

    Token `t` is scored with the logits emitted at position `t - 1`, hence the
    one-step offset between the two slices.
    """
    tokens_after_prefix = sample_ids[:, prefix_len:]
    logits_predicting_them = logits[:, prefix_len - 1:-1]
    return get_logprobs(tokens_after_prefix, logits_predicting_them)

def log_probs_from_logits(logits: TensorType["batch_size", "seq_len", "vocab_size"], input_ids: TensorType["batch_size", "seq_len"], prefix_len=1) -> TensorType["batch_size", "seq_len"]:
    """Log-probability of each token after the prefix under the given logits.

    Token `t` (for `t >= prefix_len`) is scored with the logits at position
    `t - 1`, so the logits of the final position are never used. Despite the
    return annotation, the slicing below yields `seq_len - prefix_len` columns.

    Raises:
        AssertionError: if `prefix_len` is not positive.
    """
    assert prefix_len > 0

    # [batch_size, seq_len - prefix_len, 1]: ids to look up, shaped for gather.
    targets = input_ids[:, prefix_len:, None]
    # [batch_size, seq_len - prefix_len, vocab_size]: distributions that predict those ids.
    predictions = torch.nn.functional.log_softmax(logits[:, prefix_len - 1:-1], dim=-1)
    return predictions.gather(dim=-1, index=targets).squeeze(-1)

def log_probs_from_logits_test():
    """Compare log_probs_from_logits with noas_log_probs_from_logits for prefix lengths 1..9."""
    n_batch, n_pos, n_vocab = 10, 10, 100
    input_ids = torch.randint(0, n_vocab, (n_batch, n_pos))
    logits = torch.randn(n_batch, n_pos, n_vocab)
    for n_prefix in range(1, n_pos):
        got = log_probs_from_logits(logits=logits, input_ids=input_ids, prefix_len=n_prefix)
        want = noas_log_probs_from_logits(logits=logits, sample_ids=input_ids, prefix_len=n_prefix)
        assert torch.allclose(got, want)


log_probs_from_logits_test()

In [9]:
def kl_divergence(p_logits, q_logits):
    """KL(p || q) over the last dimension, where p and q are the softmaxes of the given logits."""
    log_p = torch.log_softmax(p_logits, dim=-1)
    log_q = torch.log_softmax(q_logits, dim=-1)
    # Pointwise term p * (log p - log q), then summed over the distribution axis.
    pointwise = log_p.exp() * (log_p - log_q)
    return pointwise.sum(dim=-1)

def kl_divergence_test():
    """Check kl_divergence against a hand-written KL(p || q) on distinct distributions.

    The two logit rows must not differ by a constant: softmax is shift-invariant,
    so e.g. [1, 2, 3] vs [4, 5, 6] gives p == q and a KL of 0, which would pass
    for almost any implementation (including one with p and q swapped).
    """
    p_logits = torch.tensor([[1, 2, 3]], dtype=torch.float32)
    q_logits = torch.tensor([[3, 1, 2]], dtype=torch.float32)
    p = torch.nn.functional.softmax(p_logits, dim=-1)
    q = torch.nn.functional.softmax(q_logits, dim=-1)
    actual = kl_divergence(p_logits, q_logits)
    # Sum the per-class terms so `expected` has the same shape as `actual`
    # instead of relying on broadcasting.
    expected = torch.sum(p * torch.log(p / q), dim=-1)
    assert actual.shape == expected.shape, f"{actual.shape}, {expected.shape}"
    assert torch.allclose(actual, expected, atol=1e-6), f"{actual}, {expected}"
    # A distribution has zero divergence from itself.
    self_kl = kl_divergence(p_logits, p_logits)
    assert torch.allclose(self_kl, torch.zeros_like(self_kl), atol=1e-6), f"{self_kl}"


kl_divergence_test()

In [10]:
def get_lr_scheduler(warmup_steps, total_steps, final_scale):
    """Build a learning-rate multiplier: linear warmup, then linear decay to `final_scale`.

    Args:
        warmup_steps: number of steps over which the multiplier rises from 0 to 1.
        total_steps: step at which the multiplier reaches `final_scale`.
        final_scale: multiplier held from `total_steps` onwards.

    Returns:
        A function `step -> multiplier`, suitable as `lr_lambda` for
        `torch.optim.lr_scheduler.LambdaLR`.
    """
    def lr_scheduler(step):
        if step < warmup_steps:
            return step / warmup_steps
        # Hold the final value once the schedule is exhausted. Without this the
        # multiplier keeps decreasing (eventually below zero) for step > total_steps,
        # and total_steps == warmup_steps divides by zero.
        if step >= total_steps:
            return final_scale
        return 1 - (1 - final_scale) * (step - warmup_steps) / (total_steps - warmup_steps)
    return lr_scheduler

In [11]:
def whiten(t, eps=1e-5):
    """Subtract the mean of `t` and divide by (std + eps), both taken over all elements."""
    centered = t - t.mean()
    spread = centered.std() + eps
    return centered / spread

In [12]:
def shift_tensor_left(t: TensorType["batch_size", "seq_len"]) -> TensorType["batch_size", "seq_len"]:
    """
    Drop the first column and append a column of zeros:
    [v_0, v_1, ..., v_n] -> [v_1, ..., v_n, 0].

    Note: the result stays connected to `t` for autograd; you probably want to detach it.
    """
    zero_column = torch.zeros_like(t[:, :1])
    return torch.cat((t[:, 1:], zero_column), dim=1)

def shift_tensor_left_test():
    """Check shift_tensor_left on a small 2x4 integer example."""
    source = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
    want = torch.tensor([[2, 3, 4, 0], [6, 7, 8, 0]])
    assert torch.allclose(shift_tensor_left(source), want)


shift_tensor_left_test()

In [13]:
def rewards_to_go_from_rewards_per_timestep(rewards_per_timestep: TensorType["batch_size", "seq_len"]) -> TensorType["batch_size", "seq_len"]:
    """Suffix sums along the sequence axis: entry t becomes the sum of rewards from t to the end."""
    # Reverse, take a running total, reverse back.
    reversed_rewards = rewards_per_timestep.flip(1)
    reversed_running_total = reversed_rewards.cumsum(dim=-1)
    return reversed_running_total.flip(1)

def rewards_to_go_from_rewards_per_timestep_test():
    """Check the suffix sums on a 2x3 example."""
    per_step = torch.tensor([[0.5, 1.0, 0.5], [1.5, 1.0, 1.0]])
    want = torch.tensor([[2.0, 1.5, 0.5], [3.5, 2.0, 1.0]])
    got = rewards_to_go_from_rewards_per_timestep(per_step)
    assert torch.allclose(got, want)


rewards_to_go_from_rewards_per_timestep_test()

In [14]:
def rewards_per_timestep_from_rewards(rewards: TensorType["batch_size"], seq_len: int) -> TensorType["batch_size", "seq_len"]:
    """Spread one reward per sequence into a per-timestep grid that is zero except at the last step.

    Args:
        rewards: one scalar reward per sequence.
        seq_len: number of timesteps (columns) in the result.

    Returns:
        A (len(rewards), seq_len) tensor of torch's default float dtype whose
        last column holds `rewards`; all other entries are 0.
    """
    batch_size = rewards.shape[0]
    # Allocate directly on the device `rewards` lives on rather than on a
    # module-level `device`: this removes the hidden global dependency and the
    # CPU allocation followed by a copy.
    rewards_per_timestep = torch.zeros(batch_size, seq_len, device=rewards.device)
    rewards_per_timestep[:, -1] = rewards
    return rewards_per_timestep

def rewards_per_timestep_from_rewards_test():
    """Check that each reward lands in the last column and every other entry is zero."""
    per_sequence = torch.tensor([1.0, 2.0, 3.0]).to(device)
    want = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]]).to(device)
    got = rewards_per_timestep_from_rewards(per_sequence, 3)
    assert torch.allclose(got, want)


rewards_per_timestep_from_rewards_test()

In [15]:
def get_minibatches(sample_ids, old_log_probs, ref_logits, old_values, rewards):
    """Yield the batch as consecutive minibatches of `minibatch_size` rows each.

    Reads the module-level `minibatch_size`. Every input's leading dimension is
    split as `(m b) ... -> m b ...` with `b = minibatch_size`, so it must be a
    multiple of `minibatch_size` (einops.rearrange rejects it otherwise).

    Yields:
        dict with keys 'sample_ids', 'old_log_probs', 'ref_logits',
        'old_values' and 'rewards', each holding one minibatch slice.
    """
    def split(tensor, pattern):
        return einops.rearrange(tensor, pattern, b=minibatch_size)

    chunked = {
        'sample_ids': split(sample_ids, '(m b) t -> m b t'),
        'old_log_probs': split(old_log_probs, '(m b) t -> m b t'),
        'ref_logits': split(ref_logits, '(m b) t d -> m b t d'),
        'old_values': split(old_values, '(m b) t -> m b t'),
        'rewards': split(rewards, '(m b) -> m b'),
    }
    # Count minibatches from the data itself instead of the global `batch_size`,
    # which could disagree with the tensors passed in (IndexError if larger,
    # silently dropped minibatches if smaller).
    n_minibatches = chunked['sample_ids'].shape[0]
    for i in range(n_minibatches):
        yield {key: value[i] for key, value in chunked.items()}

In [16]:
# --- Batch layout ---
minibatch_size = 20
n_minibatches_per_epoch = 4
n_epochs = 40

# --- Loss weights (trailing comments are the author's alternative values) ---
ent_coef = 0.0  # .001
kl_coef = 80  # 0.2
vf_coef = 0.3

# --- Optimisation schedule ---
lr = 3e-5
n_steps = 300
warmup_steps = 10

# --- Sampling ---
gen_len = 30

In [17]:
# Seed torch's global RNG before anything below draws from it.
torch.manual_seed(42)

# Comet experiment used for all metric/text logging in this cell. The API key is
# read from the COMET_API_KEY environment variable; CPU/GPU environment logging is off.
experiment = Experiment(
    api_key=os.getenv("COMET_API_KEY"),
    project_name="learning-to-summarise-using-human-feedback",
    workspace="danesherbs",
    log_env_cpu=False,
    log_env_gpu=False,
)

# The model being trained (only its parameters are handed to the optimizer below).
model = GPTWithValueHead().to(device)
# Fixed prompt every sample is generated from.
prefix = "This is"
input_ids = tokenizer(prefix, return_tensors='pt').input_ids.to(device)

# Number of sequences sampled per training step.
batch_size = minibatch_size*n_minibatches_per_epoch
# Number of prompt tokens at the start of every sampled sequence.
prefix_len=input_ids.shape[-1]

optimizer = optim.Adam(model.parameters(), lr=lr)
# LR multiplier: warm up over 5 steps, then decay linearly to 0.1 at n_steps.
# NOTE(review): the literal 5 is used here rather than the `warmup_steps`
# hyperparameter defined in the config cell — confirm which one is intended.
scheduler = get_lr_scheduler(5, n_steps, 0.1)
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)

def get_loss(sample_ids, old_log_probs, ref_logits, old_values, rewards, prefix_len, clip_range=.2):
    """Compute the scalar training loss for one minibatch and log its parts to Comet.

    Runs the module-level `model` on `sample_ids` and returns
    `pg_loss + vf_coef * value_net_loss + kl_coef * kl_loss`.
    Reads the globals `model`, `vf_coef`, `kl_coef`, `experiment` and `lr_scheduler`.

    Args:
        sample_ids: token ids of the sampled sequences, prompt included
            (get_minibatches slices it as (minibatch, seq_len)).
        old_log_probs: log-probs of the generated tokens as computed by the
            caller before this update; only used for the PPO ratio below.
        ref_logits: reference-model logits for the same `sample_ids`.
        old_values: accepted but not read anywhere in this function.
        rewards: one scalar per sequence (indexed as `rewards[:, None]`).
        prefix_len: number of leading prompt tokens to exclude.
        clip_range: half-width of the clamp applied to the probability ratio.

    Returns:
        The scalar loss tensor to call `.backward()` on.
    """
    logits, est_values = model(sample_ids)
    # Log-probs of the generated (post-prefix) tokens under the current model.
    log_probs = log_probs_from_logits(logits=logits, input_ids=sample_ids, prefix_len=prefix_len)
    
    # KL(reference || current) per position, restricted to the positions whose
    # logits predict a generated token (prefix_len-1 .. second-to-last).
    kl = kl_divergence(q_logits=logits, p_logits=ref_logits)[:,prefix_len-1:-1]
    kl_loss = kl.mean()
    
    def get_advantages(values, prefix_len):
        """One-step advantage: (next value, or the reward at the final step) minus current value.

        NOTE: the `values` argument is ignored; the enclosing `est_values` is used instead.
        The bootstrap targets are detached, so gradients flow only through
        `zero_step_value_est`.
        """
        one_step_q_est = torch.cat((est_values[:,prefix_len:-1].detach(), rewards[:,None]), dim=-1)
        # s_0a_0r_0s_1a_1r_1s_2a_2r_2s_3a_3r_3s_4
        #          v_1 ---- v_2 ---- v_3 ---- 0
        #        + 0   ---- 0   ---- 0   ---- r_3
        
        zero_step_value_est = est_values[:,prefix_len-1:-1]
        # s_0a_0r_0s_1a_1r_1s_2a_2r_2s_3a_3r_3s_4
        # v_0 ---- v_1 ---- v_2 ---- v_3
    
        advantages = one_step_q_est - zero_step_value_est
        
        return advantages
    
    def my_get_advantages(values, prefix_len):
        """Unfinished alternative: builds rewards-to-go but returns None. Not called below."""
        seq_len = logits.shape[1]
        rewards_per_timestep = rewards_per_timestep_from_rewards(rewards, seq_len)
        rewards_to_go = rewards_to_go_from_rewards_per_timestep(rewards_per_timestep)
        return None
    
    def get_value_net_loss(est_values):
        """Alternative squared-error value loss against reward + detached next value.

        Not called: its only call site below is commented out.
        """
        shifted_values = shift_tensor_left(est_values).detach()
        seq_len = logits.shape[1]
        rewards_per_timestep = rewards_per_timestep_from_rewards(rewards, seq_len)
        value_net_loss = ((est_values - (rewards_per_timestep + shifted_values)) ** 2).mean()
        return value_net_loss
    
    advantages = get_advantages(est_values, prefix_len)
    # Value-head loss currently in use: mean squared one-step advantage.
    value_net_loss = (advantages**2).mean()

    # value_net_loss = get_value_net_loss(est_values)

    # Noa's policy gradient loss
    # Clipped-ratio surrogate: take the larger (more pessimistic) of the
    # unclipped and clipped losses; advantages are detached so this term does
    # not train the value head.
    ratio = torch.exp(log_probs - old_log_probs)
    pg_losses1 = -advantages.detach() * ratio
    pg_losses2 = -advantages.detach() * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    pg_loss = torch.max(pg_losses1, pg_losses2).mean()
    # Fraction of entries where the clipped loss exceeds the unclipped one.
    pg_clipfrac = torch.mean(torch.gt(pg_losses2, pg_losses1).double())
    # end

    # My policy gradient loss
    # seq_len = logits.shape[1]
    # rewards_per_timestep = rewards_per_timestep_from_rewards(rewards, seq_len)
    # rewards_to_go = rewards_to_go_from_rewards_per_timestep(rewards_per_timestep)
    # rewards_to_go = rewards_to_go[:, prefix_len:]  # ignore rewards of prefix
    # assert log_probs.shape == rewards_to_go.shape
    # baseline = est_values[:, prefix_len:]
    # assert log_probs.shape == baseline.shape
    # pg_loss = -torch.sum(log_probs * (rewards_to_go - baseline), dim=-1).mean()
    
    # NOTE(review): this assignment replaces the clipped `pg_loss` computed above,
    # so only this REINFORCE-style term (sequence log-prob times sequence reward)
    # reaches `loss`; the clipped objective currently feeds nothing but the
    # `pg_clipfrac` metric. Confirm this is intended.
    pg_loss = -(torch.sum(log_probs, dim=-1) * rewards).mean()
    # end

    loss = pg_loss + vf_coef * value_net_loss + kl_coef * kl_loss
    
    # Log every component once per minibatch.
    experiment.log_metric('pg_clipfrac', pg_clipfrac.item())
    experiment.log_metric('value_net_loss', value_net_loss.item())
    experiment.log_metric('kl loss', kl_loss.item())
    experiment.log_metric('total_loss', loss.item())
    experiment.log_metric('lr', lr_scheduler.get_last_lr()[0])
    experiment.log_metric('pg_loss', pg_loss.item())

    return loss


# Training loop: each step samples a fresh batch from the current model, scores
# it, then takes one optimiser step per minibatch. The try/finally guarantees the
# Comet experiment is closed even if a step raises (e.g. non-finite gradients).
try:
    for batch_idx in range(n_steps):
        sample_ids, samples, gen_samples = get_samples(model, input_ids=input_ids, batch_size=batch_size, gen_len=gen_len)
        sample_ids = sample_ids.to(device)
        experiment.log_text(gen_samples[0])

        # Snapshot quantities are constants for the update below, so compute them
        # without building autograd graphs (replaces forward + .detach()).
        with torch.no_grad():
            old_logits, old_values = model(sample_ids)
            # Same helper and prefix length that get_loss uses, so the old and
            # new log-probs are guaranteed to line up position by position.
            old_log_probs = log_probs_from_logits(logits=old_logits, input_ids=sample_ids, prefix_len=prefix_len)
            ref_logits, _ = ref_model(sample_ids)

        rewards = torch.tensor(reward_fn(samples), dtype=torch.float32).to(device)
        # Log a plain float rather than a (possibly GPU-resident) tensor.
        experiment.log_metric('mean_reward', rewards.mean().item())
        rewards = whiten(rewards)

        # NOTE(review): a single pass over the batch is made per step even though
        # the config cell defines `n_epochs` — confirm which is intended.
        for epoch in range(1):
            for minibatch in get_minibatches(sample_ids, old_log_probs, ref_logits, old_values, rewards):
                # Use the prompt length computed from the tokenised prefix
                # instead of a hard-coded 2, so changing `prefix` cannot
                # silently misalign the loss.
                loss = get_loss(**minibatch, prefix_len=prefix_len, clip_range=.2)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0, norm_type=2.0, error_if_nonfinite=True)
                optimizer.step()
                optimizer.zero_grad()

        lr_scheduler.step()
finally:
    experiment.end()

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/danesherbs/learning-to-summarise-using-human-feedback/c15d8ee3c5db482c9bc4e2c55ff62853

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/danesherbs/learning-to-summarise-using-human-feedback/c15d8ee3c5db482c9bc4e2c55ff62853
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     kl loss [1200]        : (2.3543589300345502e-09, 0.08422788232564926)
COMET INFO:     loss [120]            : (-46.14427185058594, 15.33053970336914)
COMET INFO:     lr [1200]             : (0.0, 3e-05)
COMET INFO:     mean_reward [300]     : (1.2625000476837158, 4.200000286102295)
COMET INFO:     pg_clipfrac [1200]    : (0.0, 0.043333333333333335)
COMET INFO:     pg_loss [1200]        : (-58.8792724609375, 24.934045791625977)
COMET INFO:     total_loss [1

In [18]:
# Sentiment classifier and its matching tokenizer, loaded from the same checkpoint.
SENTIMENT_CHECKPOINT = 'cardiffnlp/twitter-roberta-base-sentiment'
cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(SENTIMENT_CHECKPOINT)
cls_model = cls_model.to(device)
cls_tokenizer = transformers.AutoTokenizer.from_pretrained(SENTIMENT_CHECKPOINT)

def reward_fn(gen_sample):
    """Reward = log-probability the sentiment classifier assigns to its last class.

    NOTE: this redefinition shadows the period-counting `reward_fn` defined
    earlier in the notebook. The last of the three classes is presumably
    "positive" for this checkpoint — TODO confirm against the model card.

    Args:
        gen_sample: a string, or a list of strings (scored element-wise, recursively).

    Returns:
        A float for a string input, or a list of results for a list input.
    """
    if isinstance(gen_sample, list):
        return [reward_fn(item) for item in gen_sample]

    token_ids = cls_tokenizer(gen_sample, return_tensors='pt')['input_ids'].to(device)
    # Only a Python float leaves this function, so skip autograd bookkeeping
    # for the classifier forward pass.
    with torch.no_grad():
        logits = cls_model(token_ids).logits[0]
    logprobs = F.log_softmax(logits, dim=0)
    assert logprobs.shape == (3,)
    return float(logprobs[-1])