In [1]:
import torch
import transformers
import einops
from torchtyping import TensorType

In [2]:
DEVICE = "cuda:1"

In [3]:
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id).to(DEVICE)
ref_model = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(DEVICE)

In [4]:
def count_periods(s: str) -> int:
    return sum(1 if c == "." else 0 for c in s)

In [5]:
def lr_schedule(epoch: int) -> float:
    if epoch <= 40:
        return epoch / 40
    
    return 1.0


def train(model, reference_model, num_steps, num_tokens_to_generate, batch_size, lr, kl_loss_coefficient, print_every=1, device=DEVICE):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule)
    vocab_size = len(tokenizer)

    for i in range(num_steps):
        optimizer.zero_grad()
        
        input_ids = torch.full(size=(batch_size, 1), fill_value=tokenizer.eos_token_id).to(device)
        response_ids = model.generate(
            input_ids,
            max_length=num_tokens_to_generate,
            min_length=num_tokens_to_generate,
            do_sample=True,
            temperature=0.6,
            top_p=1.0,
            top_k=vocab_size,
        )
        reponse_texts = tokenizer.batch_decode(response_ids)
        rewards = torch.tensor([count_periods(response) for response in reponse_texts], dtype=torch.float32).to(device)
        
        rewards_mean = torch.mean(rewards, dim=-1, keepdim=True)
        rewards_std = torch.std(rewards, dim=-1, keepdim=True)
        rewards_normed = (rewards - rewards_mean) / (rewards_std + 1e-6)

        log_probs = log_prob_of_sequence(model, response_ids)
        reference_log_probs = log_prob_of_sequence(reference_model, response_ids)

        probs = torch.exp(reference_log_probs)
        reference_probs = torch.exp(reference_log_probs)

        kl_loss = torch.nn.functional.kl_div(reference_probs, probs)
        loss = -(log_probs * rewards_normed).mean() + kl_loss_coefficient * kl_loss.mean()
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        if (i+1) % print_every == 0:
            print(f"Step {(i+1):2} | loss {loss:.5f} | avg reward {rewards.mean():.5f} | lr {scheduler.get_last_lr()[0]}")


def log_prob_of_sequence(model: torch.nn.Module, input_ids: TensorType["batch_size", "seq_len"]) -> TensorType["batch_size"]:
    """
    Calculates p(x_1, x_2, ..., x_n | x_0) = p(x_1 | x_0) * p(x_2 | x_1, x_0) * ... * p(x_{n} | x_{n-1}, ..., x_0)
    """

    logits = model(input_ids).logits[:, :-1]  # [batch_size, seq_len-1, vocab_size] -- ignore x_{n+1}
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 
    
    input_ids = input_ids[:, 1:]  # [batch_size, seq_len-1, vocab_size]  -- ignore x_0
    input_ids = input_ids.unsqueeze(-1)  # [batch_size, seq_len] -> [batch_size, seq_len, 1]
    seq_log_probs = torch.gather(input=log_probs, dim=-1, index=input_ids)
    seq_log_probs = seq_log_probs.squeeze(-1) # [batch_size, seq_len, 1] -> [batch_size, seq_len]

    return torch.sum(seq_log_probs, dim=-1)  # [batch_size,]


train(
    model=model,
    reference_model=ref_model,
    num_steps=100,
    num_tokens_to_generate=20,
    batch_size=20,
    lr=3e-5,
    kl_loss_coefficient=0.2,
    print_every=10,
)



Step 10 | loss -9.76044 | avg reward 1.95000 | lr 7.5e-06
Step 20 | loss -17.20097 | avg reward 3.70000 | lr 1.5e-05
Step 30 | loss -0.00038 | avg reward 19.00000 | lr 2.25e-05
Step 40 | loss -0.00038 | avg reward 19.00000 | lr 3e-05
Step 50 | loss -0.00038 | avg reward 19.00000 | lr 3e-05
Step 60 | loss -0.00038 | avg reward 19.00000 | lr 3e-05
Step 70 | loss -0.00038 | avg reward 19.00000 | lr 3e-05
Step 80 | loss -0.00038 | avg reward 19.00000 | lr 3e-05
Step 90 | loss -0.00038 | avg reward 19.00000 | lr 3e-05
Step 100 | loss -0.00038 | avg reward 19.00000 | lr 3e-05


In [6]:
input_ids = tokenizer(["This is"], return_tensors="pt").input_ids.to(DEVICE)
response_ids = model.generate(
    input_ids,
    max_length=20,
    min_length=20,
    do_sample=True,
    temperature=0.6,
    top_p=1.0,
    top_k=len(tokenizer),
)
tokenizer.batch_decode(response_ids)

['This is where things get weird and crazy. At least offensively, it was hard."\n\nListen']

In [7]:
class LayerHook:
    def __init__(self):
        self.module = None
        self.input = None
        self.output = None
        
    def __call__(self, module, input, output):
        self.module = module
        self.input = input
        self.output = output


class GPTWithValueNetwork(torch.nn.Module):
    def __init__(self):
        self.hook = LayerHook()
        self.gpt = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id)
        self.gpt.transformer.h[11].register_forward_hook(self.hook)
        self.linear = torch.nn.Linear(768, 1)
    
    def forward(self, input_ids):
        policy = self.gpt(input_ids).logits
        value = self.linear(self.hook.output)
        return policy, value



In [22]:
def lr_schedule(epoch: int) -> float:
    if epoch <= 40:
        return epoch / 40
    
    return 1.0


def train_temporal_difference(model, reference_model, value_network, num_steps, num_tokens_to_generate, batch_size, lr, kl_loss_coefficient, print_every=1, device=DEVICE):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule)
    vocab_size = len(tokenizer)

    for i in range(num_steps):
        optimizer.zero_grad()
        
        input_ids = torch.full(size=(batch_size, 1), fill_value=tokenizer.eos_token_id).to(device)
        response_ids = model.generate(
            input_ids,
            max_length=num_tokens_to_generate,
            min_length=num_tokens_to_generate,
            do_sample=True,
            temperature=0.6,
            top_p=1.0,
            top_k=vocab_size,
        )
        reponse_texts = tokenizer.batch_decode(response_ids)
        rewards = torch.tensor([count_periods(response) for response in reponse_texts], dtype=torch.float32).to(device)
        
        rewards_mean = torch.mean(rewards, dim=-1, keepdim=True)
        rewards_std = torch.std(rewards, dim=-1, keepdim=True)
        rewards_normed = (rewards - rewards_mean) / (rewards_std + 1e-6)

        with torch.no_grad():
            [hidden_states] = model(response_ids, output_hidden_states=True)
            last_hidden_state = hidden_states[:, -1]
        
        values = value_network(last_hidden_state)  # [batch_size, seq_len]

        log_probs = log_prob_of_sequence(model, response_ids)
        reference_log_probs = log_prob_of_sequence(reference_model, response_ids)

        probs = torch.exp(reference_log_probs)
        reference_probs = torch.exp(reference_log_probs)

        kl_loss = torch.nn.functional.kl_div(reference_probs, probs)
        loss = -(log_probs * rewards_normed).mean() + kl_loss_coefficient * kl_loss.mean()
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        if (i+1) % print_every == 0:
            print(f"Step {(i+1):2} | loss {loss:.5f} | avg reward {rewards.mean():.5f} | lr {scheduler.get_last_lr()[0]}")


def log_prob_of_sequence(model: torch.nn.Module, input_ids: TensorType["batch_size", "seq_len"]) -> TensorType["batch_size"]:
    """
    Calculates p(x_1, x_2, ..., x_n | x_0) = p(x_1 | x_0) * p(x_2 | x_1, x_0) * ... * p(x_{n} | x_{n-1}, ..., x_0)
    """

    logits = model(input_ids).logits[:, :-1]  # [batch_size, seq_len-1, vocab_size] -- ignore x_{n+1}
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 
    
    input_ids = input_ids[:, 1:]  # [batch_size, seq_len-1, vocab_size]  -- ignore x_0
    input_ids = input_ids.unsqueeze(-1)  # [batch_size, seq_len] -> [batch_size, seq_len, 1]
    seq_log_probs = torch.gather(input=log_probs, dim=-1, index=input_ids)
    seq_log_probs = seq_log_probs.squeeze(-1) # [batch_size, seq_len, 1] -> [batch_size, seq_len]

    return torch.sum(seq_log_probs, dim=-1)  # [batch_size,]


train_temporal_difference(
    model=model,
    reference_model=ref_model,
    num_steps=100,
    num_tokens_to_generate=20,
    batch_size=16,
    lr=3e-5,
    kl_loss_coefficient=0.2,
    print_every=10,
)

torch.Size([16, 20, 768])
