In [1]:
from comet_ml import Experiment
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy
import einops
import os
from dotenv import load_dotenv

load_dotenv()  # take environment variables from .env

True

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = "cuda:3"

In [3]:
class GPTWithValueHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
        d_model = self.model.transformer.wte.weight.shape[-1]
        self.value_head = nn.Sequential(
            nn.Linear(d_model, 4*d_model),
            nn.ReLU(),
            nn.Linear(4*d_model, 1)
        )
        self.generate = self.model.generate
    def forward(self, input_ids):
        outputs = self.model(input_ids, output_hidden_states = True)
        logits = outputs.logits
        values = self.value_head(outputs.hidden_states[-1]).squeeze(-1)
        return logits, values

ref_model = GPTWithValueHead().to(device)
ref_model.eval()
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2');

In [4]:
def get_samples(model, input_ids, batch_size=50, gen_len=10):
    with torch.no_grad():
        samples = model.generate(input_ids, max_length=input_ids.shape[-1]+gen_len, min_length=input_ids.shape[-1]+gen_len, do_sample=True, temperature=0.6, top_k=len(tokenizer), top_p=1.0, num_return_sequences=batch_size)
        gen_samples = samples[:, input_ids.shape[-1]:]
        sample_ids = copy.deepcopy(samples)
        samples = tokenizer.batch_decode(samples)
        gen_samples = tokenizer.batch_decode(gen_samples)
    return sample_ids, samples, gen_samples

In [5]:
input_ids = tokenizer.encode('Testing', return_tensors='pt').to(device)
sample_ids, samples, gen_samples =  get_samples(model=ref_model, input_ids=input_ids, batch_size=5)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [6]:
def reward_fn(gen_sample):
    if isinstance(gen_sample, list):
        return [reward_fn(item) for item in gen_sample]
    else:
        return gen_sample.count('.')

def reward_fn_test():
    A = 'This is a test.'
    assert reward_fn(A) == 1
    B = '......'
    assert reward_fn(B) ==6
    C = 'Whatever'
    assert reward_fn(C) == 0
    assert reward_fn([A, B, C]) == [1, 6, 0]

    print('Passed test.')
    return

reward_fn_test()

Passed test.


In [7]:
def get_logprobs(input_ids, logits):
    logprobs = F.log_softmax(logits, dim=-1)
    logprobs = torch.gather(logprobs, -1, input_ids[:,:,None])[:,:,0]
    return logprobs

def logprobs_test(logprobs_fn):
    input_ids = torch.randint(0, 100, (10, 10))
    logits = torch.randn(10, 10, 100)
    ref_logprobs = get_logprobs(input_ids, logits)
    logprobs = logprobs_fn(input_ids, logits)
    assert torch.allclose(logprobs, ref_logprobs)


In [8]:
def get_logprobs(input_ids, logits):
    logprobs = F.log_softmax(logits, dim=-1)
    logprobs = torch.gather(logprobs, -1, input_ids[:,:,None])[:,:,0]
    return logprobs

def logprobs_test(logprobs_fn):
    input_ids = torch.randint(0, 100, (10, 10))
    logits = torch.randn(10, 10, 100)
    ref_logprobs = get_logprobs(input_ids, logits)
    logprobs = logprobs_fn(input_ids, logits)
    assert torch.allclose(logprobs, ref_logprobs)


In [9]:
def get_kl(logits, ref_logits, eps=1e-4):
    ref_probs = torch.nn.functional.softmax(ref_logits, dim=-1)
    return (ref_probs * (torch.log(ref_probs)-F.log_softmax(logits, dim=-1))).sum(-1)

def kl_divergence(p_logits, q_logits):
    p = torch.nn.functional.softmax(p_logits, dim=-1)
    q = torch.nn.functional.softmax(q_logits, dim=-1)
    log_p = torch.log(p)
    log_q = torch.log(q)
    return torch.sum(p * log_p - p * log_q, dim=-1)

def kl_divergence_test():
    for _ in range(10):
        p = torch.rand(2, 3, 5)
        q = torch.rand(2, 3, 5)
        actual = kl_divergence(p, q)
        expected = get_kl(q, p)
        assert torch.allclose(actual, expected), f"{actual}, {expected}"


kl_divergence_test()

In [10]:
def get_entropy(logits):
    probs = F.softmax(logits, dim=-1)
    log_probs = torch.log(probs)
    entropy = -(probs * log_probs).sum(dim=-1)
    return entropy

In [11]:
def get_lr_scheduler(warmup_steps, total_steps, final_scale):
    def lr_scheduler(step):
        if step < warmup_steps:
            return step / warmup_steps
        else:
            return 1-(1-final_scale)*(step-warmup_steps)/(total_steps-warmup_steps)
    return lr_scheduler

In [12]:
def whiten(t, eps=1e-5):
    t = t - t.mean()
    t = t/(t.std()+eps)
    return t

In [13]:
def get_minibatches(sample_ids, old_logprobs, ref_logits, old_values, rewards):
    sample_ids = einops.rearrange(sample_ids,'(m b) t -> m b t', b=minibatch_size)
    old_logprobs = einops.rearrange(old_logprobs, '(m b) t -> m b t', b=minibatch_size)
    ref_logits = einops.rearrange(ref_logits, '(m b) t d -> m b t d', b=minibatch_size)
    old_values = einops.rearrange(old_values, '(m b) t -> m b t', b=minibatch_size)
    rewards = einops.rearrange(rewards, '(m b) -> m b', b=minibatch_size)
    for i in range(batch_size//minibatch_size):
        yield {'sample_ids': sample_ids[i], 'old_logprobs': old_logprobs[i], 'ref_logits': ref_logits[i], 'old_values': old_values[i], 'rewards': rewards[i]}

In [14]:
n_minibatches_per_epoch = 4
minibatch_size=20
n_epochs = 40
ent_coef = 0.0  # .001
kl_coef = .2
vf_coef = .3
n_steps = 300
warmup_steps = 10
lr = 3e-5
gen_len=30

In [15]:
torch.manual_seed(42)

experiment = Experiment(
    api_key=os.getenv("COMET_API_KEY"),
    project_name="learning-to-summarise-using-human-feedback",
    workspace="danesherbs",
    log_env_cpu=False,
    log_env_gpu=False,
)

model = GPTWithValueHead().to(device)
prefix = 'This is'
input_ids = tokenizer(prefix, return_tensors='pt').input_ids.to(device)

batch_size = minibatch_size*n_minibatches_per_epoch
prefix_len=input_ids.shape[-1]

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = get_lr_scheduler(5, n_steps, 0.1)
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)

def get_loss(sample_ids, old_logprobs, ref_logits, old_values, rewards, prefix_len, clip_range=.2):
    logits, est_values = model(sample_ids)
    logprobs = get_logprobs(sample_ids[:,prefix_len:], logits[:,prefix_len-1:-1])
    
    entropy = get_entropy(logits[:,prefix_len-1:-1])
    ent_loss = -entropy.mean()

    kl = kl_divergence(q_logits=logits, p_logits=ref_logits)[:,prefix_len-1:-1]
    kl_loss = kl.mean()
    
    def get_advantages(values, prefix_len):
        one_step_q_est = torch.cat((est_values[:,prefix_len:-1].detach(), rewards[:,None]), dim=-1)
        # s_0a_0r_0s_1a_1r_1s_2a_2r_2s_3a_3r_3s_4
        #          v_1 ---- v_2 ---- v_3 ---- 0
        #        + 0   ---- 0   ---- 0   ---- r_3
    
        zero_step_value_est = est_values[:,prefix_len-1:-1]
        # s_0a_0r_0s_1a_1r_1s_2a_2r_2s_3a_3r_3s_4
        # v_0 ---- v_1 ---- v_2 ---- v_3
    
        advantages = one_step_q_est - zero_step_value_est
        
        return advantages

    advantages = get_advantages(est_values, prefix_len)

    vf_loss = (advantages**2).mean()

    ratio = torch.exp(logprobs - old_logprobs)
    pg_losses1 = -advantages.detach() * ratio
    pg_losses2 = -advantages.detach() * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    pg_loss = torch.max(pg_losses1, pg_losses2).mean()
    pg_clipfrac = torch.mean(torch.gt(pg_losses2, pg_losses1).double())

    loss = pg_loss + vf_coef * vf_loss + kl_coef * kl_loss + ent_coef * ent_loss
    
    experiment.log_metric('pg_clipfrac', pg_clipfrac.item())
    experiment.log_metric('vf_loss', vf_loss.item())
    experiment.log_metric('mean kl', kl_loss.item())
    experiment.log_metric('total_loss', loss.item())
    experiment.log_metric('lr', lr_scheduler.get_last_lr()[0])
    experiment.log_metric('pg_loss', pg_loss.item())
    experiment.log_metric('mean entropy', -ent_loss.item())

    return loss
    

for batch_idx in range(n_steps):
    sample_ids, samples, gen_samples = get_samples(model, input_ids=input_ids, batch_size=batch_size, gen_len=gen_len)
    sample_ids = sample_ids.to(device)
    experiment.log_text(gen_samples[0])
    old_logits, old_values = model(sample_ids)
    old_logits, old_values = old_logits.detach(), old_values.detach()
    old_logprobs = get_logprobs(sample_ids[:,input_ids.shape[-1]:], old_logits[:,input_ids.shape[-1]-1:-1]).detach()
    ref_logits, _ = ref_model(sample_ids)
    ref_logits = ref_logits.detach()

    rewards = torch.tensor(reward_fn(samples), dtype=torch.float32).to(device)
    experiment.log_metric('mean_reward', rewards.mean())
    rewards = whiten(rewards)  
    
    for epoch in range(1):
        for minibatch in get_minibatches(sample_ids, old_logprobs, ref_logits, old_values, rewards):
            loss = get_loss(**minibatch, prefix_len=2, clip_range=.2)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0, norm_type=2.0, error_if_nonfinite=True)
            optimizer.step()
            optimizer.zero_grad()

    lr_scheduler.step()
    
experiment.end()

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/danesherbs/learning-to-summarise-using-human-feedback/66ac5f6e0a9543888efca88b306372bd

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-

In [16]:
cls_model = transformers.AutoModelForSequenceClassification.from_pretrained('cardiffnlp/twitter-roberta-base-sentiment').to(device)
cls_tokenizer = transformers.AutoTokenizer.from_pretrained('cardiffnlp/twitter-roberta-base-sentiment')

def reward_fn(gen_sample):
    if isinstance(gen_sample, list):
        return [reward_fn(item) for item in gen_sample]
    else:
        logits = cls_model(cls_tokenizer(gen_sample, return_tensors='pt')['input_ids'].to(device)).logits[0]
        logprobs = F.log_softmax(logits, dim=0)
        assert logprobs.shape == (3,)
        return float(logprobs[-1])