In [15]:
from comet_ml import Experiment
import os
import torch
import transformers
import einops
from torchtyping import TensorType
from dotenv import load_dotenv

load_dotenv()  # take environment variables from .env

True

In [16]:
DEVICE = "cuda:2"

In [17]:
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id).to(DEVICE)
ref_model = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id).to(DEVICE)

In [18]:
def count_periods(s: str) -> int:
    return sum(1 if c == "." else 0 for c in s)

In [5]:
def lr_schedule(epoch: int) -> float:
    if epoch <= 40:
        return epoch / 40
    
    return 1.0


def train(model, reference_model, num_steps, num_tokens_to_generate, batch_size, lr, kl_loss_coefficient, print_every=1, device=DEVICE):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule)
    vocab_size = len(tokenizer)
    
    for i in range(num_steps):
        optimizer.zero_grad()
        
        input_ids = torch.full(size=(batch_size, 1), fill_value=tokenizer.eos_token_id).to(device)
        response_ids = model.generate(
            input_ids,
            max_length=num_tokens_to_generate,
            min_length=num_tokens_to_generate,
            do_sample=True,
            temperature=0.6,
            top_p=1.0,
            top_k=vocab_size,
        )
        reponse_texts = tokenizer.batch_decode(response_ids)
        rewards = torch.tensor([count_periods(response) for response in reponse_texts], dtype=torch.float32).to(device)
        
        rewards_mean = torch.mean(rewards, dim=-1, keepdim=True)
        rewards_std = torch.std(rewards, dim=-1, keepdim=True)
        rewards_normed = (rewards - rewards_mean) / (rewards_std + 1e-6)

        log_probs = log_prob_from_input_ids(model, response_ids)
        reference_log_probs = log_prob_from_input_ids(reference_model, response_ids)

        probs = torch.exp(reference_log_probs)
        reference_probs = torch.exp(reference_log_probs)

        kl_loss = torch.nn.functional.kl_div(reference_probs, probs)
        loss = -(log_probs * rewards_normed).mean() + kl_loss_coefficient * kl_loss.mean()
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        if (i+1) % print_every == 0:
            print(f"Step {(i+1):2} | loss {loss:.5f} | avg reward {rewards.mean():.5f} | lr {scheduler.get_last_lr()[0]}")


def log_prob_from_input_ids(model: torch.nn.Module, input_ids: TensorType["batch_size", "seq_len"]) -> TensorType["batch_size"]:
    """
    Calculates p(x_1, x_2, ..., x_n | x_0) = p(x_1 | x_0) * p(x_2 | x_1, x_0) * ... * p(x_{n} | x_{n-1}, ..., x_0)
    """

    logits = model(input_ids).logits[:, :-1]  # [batch_size, seq_len-1, vocab_size] -- ignore x_{n+1}
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 
    
    input_ids = input_ids[:, 1:]  # [batch_size, seq_len-1, vocab_size]  -- ignore x_0
    input_ids = input_ids.unsqueeze(-1)  # [batch_size, seq_len] -> [batch_size, seq_len, 1]
    seq_log_probs = torch.gather(input=log_probs, dim=-1, index=input_ids)
    seq_log_probs = seq_log_probs.squeeze(-1) # [batch_size, seq_len, 1] -> [batch_size, seq_len]

    return torch.sum(seq_log_probs, dim=-1)  # [batch_size,]


# train(
#     model=model,
#     reference_model=ref_model,
#     num_steps=100,
#     num_tokens_to_generate=20,
#     batch_size=20,
#     lr=3e-5,
#     kl_loss_coefficient=0.2,
#     print_every=10,
# )

In [6]:
input_ids = tokenizer(["This is"], return_tensors="pt").input_ids.to(DEVICE)
response_ids = model.generate(
    input_ids,
    max_length=20,
    min_length=20,
    do_sample=True,
    temperature=0.6,
    top_p=1.0,
    top_k=len(tokenizer),
)
tokenizer.batch_decode(response_ids)

['This is the first time the company has made a statement that it is not aware of a problem with']

In [19]:
class GPTWithValueHead(torch.nn.Module):
    
    def __init__(self):
        super().__init__()
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        self.gpt = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=self.tokenizer.eos_token_id)
        hidden_size = self.gpt.transformer.wte.weight.shape[-1]
        self.value_network = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 4 * hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * hidden_size, 1),
        )
    
    def forward(self, input_ids):
        response = self.gpt(input_ids, output_hidden_states=True)  # [batch_size, num_layers, hidden_dim]
        last_hidden_state = response.hidden_states[-1]  # [batch_size, seq_len, hidden_size]
        values = self.value_network(last_hidden_state).squeeze(-1)
        logits = response.logits  # [batch_size, seq_len, vocab_size]
        return logits, values


model_with_value_head = GPTWithValueHead().to(DEVICE)

In [27]:
def lr_schedule(epoch: int) -> float:
    if epoch <= 40:
        return epoch / 40
    
    return 1.0


def train_with_value_network(model, reference_model, num_steps, num_tokens_to_generate, batch_size, lr, kl_loss_coefficient, print_every=1, device=DEVICE):
    experiment = Experiment(
        api_key=os.getenv("COMET_API_KEY"),
        project_name="learning-to-summarise-using-human-feedback",
        workspace="danesherbs",
        log_env_cpu=False,
        log_env_gpu=False,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule)
    
    vocab_size = len(tokenizer)
    seq_len = num_tokens_to_generate  # for readability

    for i in range(num_steps):
        optimizer.zero_grad()
        
        # Rewards
        input_ids = torch.full(size=(batch_size, 1), fill_value=tokenizer.eos_token_id).to(device)
        response_ids = reference_model.generate(
            input_ids,
            max_length=num_tokens_to_generate,
            min_length=num_tokens_to_generate,
            do_sample=True,
            temperature=0.6,
            top_p=1.0,
            top_k=vocab_size,
        )  # [batch_size, seq_len]
        reponse_texts = tokenizer.batch_decode(response_ids)
        rewards = torch.tensor([count_periods(response) for response in reponse_texts], dtype=torch.float32).to(device)  # [batch_size]
        
        assert response_ids.shape == (batch_size, seq_len)
        assert rewards.shape == (batch_size,)
        
        rewards_mean = torch.mean(rewards, dim=-1, keepdim=True)
        rewards_std = torch.std(rewards, dim=-1, keepdim=True)
        rewards_normed = (rewards - rewards_mean) / (rewards_std + 1e-6)

        # Query model
        logits, values = model(response_ids)
        assert values.shape == (batch_size, seq_len)
        
        # Value network
        # shifted_values = torch.cat([values[:, 1:], torch.zeros((batch_size, 1)).to(DEVICE)], dim=-1).detach()  # [v_0, v_1, ..., v_n] -> [v_1, ..., v_n, 0]
        shifted_values = shift_tensor_left(values).detach()
        rewards_per_timestep = torch.cat([torch.zeros(batch_size, seq_len-1).to(DEVICE), rewards.unsqueeze(-1)], dim=-1)  # [batch_size, seq_len]
        value_net_loss = ((values  - (rewards_per_timestep + shifted_values)) ** 2).mean()

        experiment.log_metric("value_net_loss", value_net_loss)
        experiment.log_metric("learning_rate", scheduler.get_last_lr()[0])
        
        if i % print_every == 0:
            # print(f"Value {list(values[0].cpu().detach().numpy())}\nDecoded {[tokenizer.decode(tok) for tok in response_ids[0]]}")
            print(list(zip(list(values[0].cpu().detach().numpy()), [tokenizer.decode(tok) for tok in response_ids[0]])))
        
        value_net_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        continue
        
        # Policy network
        reference_logits = reference_model(response_ids).logits
        log_probs = log_probs_from_logits(logits, response_ids)
        reference_log_probs = log_probs_from_logits(reference_logits, response_ids)

        probs = torch.exp(torch.sum(reference_log_probs, dim=-1))
        reference_probs = torch.exp(torch.sum(reference_log_probs, dim=-1))

        kl_loss = torch.nn.functional.kl_div(reference_probs, probs)
        policy_net_loss = -(log_probs * rewards_normed).mean() + kl_loss_coefficient * kl_loss.mean()
        
        policy_net_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        if (i+1) % print_every == 0:
            print(f"Step {(i+1):2} | value loss {value_net_loss:.5f} | policy loss {policy_net_loss:.5f} | avg reward {rewards.mean():.5f} | value lr {value_net_scheduler.get_last_lr()[0]} | policy lr {scheduler.get_last_lr()[0]}")
    
    experiment.end()




train_with_value_network(
    model=model_with_value_head,
    reference_model=ref_model,
    num_steps=250,
    num_tokens_to_generate=20,
    batch_size=40,
    lr=3e-5,
    kl_loss_coefficient=0.2,
    print_every=20,
)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/danesherbs/learning-to-summarise-using-human-feedback/53d782a1d4614750a5886a7494c6abfe



[(0.71381366, '<|endoftext|>'), (2.271839, '\n'), (1.5142207, 'The'), (1.8790493, ' boy'), (1.4063834, ' who'), (1.834041, ' was'), (1.7273602, ' murdered'), (1.5485362, ' in'), (1.6271002, ' the'), (0.27771175, ' aftermath'), (1.2530975, ' of'), (1.4790462, ' the'), (1.0655807, ' Sandy'), (0.99257016, ' Hook'), (0.11359513, ' Elementary'), (0.4894036, ' School'), (1.356084, ' shooting'), (2.1453974, ' was'), (1.1290555, ' a'), (4.679453, ' 17')]
[(1.0935266, '<|endoftext|>'), (1.0750427, 'The'), (1.3006316, ' game'), (1.2943171, ' is'), (1.2408774, ' based'), (0.96647996, ' on'), (1.0045441, ' the'), (0.85677874, ' Japanese'), (1.0331157, ' RPG'), (1.4218235, ' series'), (1.3857914, '.'), (1.4985912, ' In'), (1.2338622, ' addition'), (1.1306831, ' to'), (1.2686871, ' its'), (1.1506214, ' Japanese'), (0.55834204, ' name'), (0.92333907, ','), (1.1911438, ' the'), (1.0125614, ' game')]
[(0.05698468, '<|endoftext|>'), (0.23794253, '\n'), (0.045566875, 'F'), (0.2945436, 'ounded'), (0.38689

KeyboardInterrupt: 

In [35]:
def log_probs_from_logits(logits: TensorType["batch_size", "seq_len", "vocab_size"], input_ids: TensorType["batch_size", "seq_len"]) -> TensorType["batch_size", "seq_len"]:
    """
    Calculates p(x_1, x_2, ..., x_n | x_0) = p(x_1 | x_0) * p(x_2 | x_1, x_0) * ... * p(x_{n} | x_{n-1}, ..., x_0)
    """
    logits = logits[:, :-1]  # [batch_size, seq_len-1, vocab_size] -- ignore x_{n+1}
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 
    
    input_ids = input_ids[:, 1:]  # [batch_size, seq_len-1, vocab_size]  -- ignore x_0
    input_ids = input_ids.unsqueeze(-1)  # [batch_size, seq_len, vocab_size, 1]
    seq_log_probs = torch.gather(input=log_probs, dim=-1, index=input_ids)
    seq_log_probs = seq_log_probs.squeeze(-1) # [batch_size, seq_len]

    return seq_log_probs

logits = torch.tensor([[[3, 6, 1], [2, 8, 0], [3, 6, 1]]], dtype=torch.float32)
input_ids = torch.tensor([[1, 0, 2]])
log_probs = log_probs_from_logits(logits, input_ids)
softmaxed = torch.nn.functional.log_softmax(logits, dim=-1)
assert torch.allclose(log_probs, torch.tensor([[softmaxed[0, 1, 0], softmaxed[0, 2, 2]]])), f"{log_probs} {softmaxed}"

AssertionError: tensor([[-3.0550, -8.0028]]) tensor([[[-3.0550e+00, -5.4985e-02, -5.0550e+00],
         [-6.0028e+00, -2.8102e-03, -8.0028e+00],
         [-3.0550e+00, -5.4985e-02, -5.0550e+00]]])

In [28]:
def shift_tensor_left(t: TensorType["batch_size", "seq_len"]) -> TensorType["batch_size", "seq_len"]:
    """
    Shifts tensor left one and fills with zeros: [v_0, v_1, ..., v_n] -> [v_1, ..., v_n, 0].
    
    Note: you probably want to detach the result of this function.
    """
    shifted = torch.zeros_like(t)
    shifted[:, :-1] = t[:, 1:]
    return shifted

t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
shifted_t = shift_tensor_left(t)
assert torch.allclose(shifted_t, torch.tensor([[2, 3, 4, 0], [6, 7, 8, 0]]))

In [25]:
def rewards_to_go_from_rewards_and_values(rewards: TensorType["batch_size", "seq_len"], values: TensorType["batch_size", "seq_len"]) -> TensorType["batch_size", "seq_len"]:
    """
    Computes r_t + v(s_{t+1}) - v(s_t) for t = 0, ..., T-1
    """
    baseline = values
    values_left_shifted = shift_tensor_left(values).detach()
    # rewards_reversed = torch.flip(rewards, dim=-1)  # [r_T, ..., r_1, r_0]
    # rewards_reversed_cumsum = torch.cumsum(rewards_reversed, dim=-1)  # [r_T + ... + r_0, r_T + ... + r_1, ..., r_T]
    # rewards_to_go = torch.flip(rewards_reversed_cumsum)
    return rewards + values_left_shifted - baseline

rewards = torch.tensor([[0.5, 1, 0.5]])
values = torch.tensor([[0.1, 0, 0.3]])
rewards_to_go = rewards_to_go_from_rewards_and_values(rewards, values)
assert torch.allclose(rewards_to_go, torch.tensor([[0.5 + 0 - 0.1, 1 + 0.3 - 0, 0.5 + 0 - 0.3]]))