In [6]:
%load_ext autoreload
%autoreload 2

In [12]:
from comet_ml import Experiment

import collections
import einops
import matplotlib.pyplot as plt
import math
import numpy as np
import torch as t
from torch import nn
import transformers
from IPython.core.display import HTML, display
from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from trl.ppo import PPOTrainer
from typing import Tuple, List, Dict
from tqdm import tqdm
from functools import partial

from minigpt_utils import get_minigpt, MiniGPT
from days.utils import *

In [13]:
device = "cuda:1"

In [14]:
tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
tokenizer._add_tokens(["[BEGIN]", "[END]"])
tokenizer.pad_token = "[END]"
tokenizer.eos_token = "[END]"
ref_model = GPT2HeadWithValueModel.from_pretrained("gpt2").to(device)

Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight', 'v_head.summary.bias', 'v_head.summary.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
BEGIN = t.tensor(50257, dtype=t.long, device=device).unsqueeze(0)

In [16]:
def train(
    gen_model,
    ref_model,
    ppo_config,
    reward_fn,
    reward_fn_args,
    gen_len:int = 20,
    num_batches:int = 2,
    logging:bool = True,
    comet_tag:str = "experiment"
):
    ppo_trainer = PPOTrainer(gen_model, ref_model, **ppo_config)

    # encode a query
    # query_txt = "This morning I went to the "
    query_tensor = (t.tensor([tokenizer.bos_token_id], dtype=t.long, device=device)
                    .unsqueeze(0)
                    .repeat(ppo_config['batch_size'], 1))
    
    if logging:
        experiment = Experiment(
            api_key="72XQSdnwnBcob4Q8NpbJHewll",
            project_name="jenny-dan",
            workspace="danielb",
            auto_output_logging=False,
        )
        experiment.add_tag(comet_tag)
        experiment.log_parameters(ppo_trainer.ppo_params)
        experiment.log_parameters(reward_fn_args)
    try:
        batch_info = []
        for batch in tqdm(range(num_batches)):
            # get model response
            response_tensor  = respond_to_batch(gen_model, query_tensor, txt_len=gen_len)
            response_txt = tokenizer.batch_decode(response_tensor)

            # reward fn may return auxillary dict of metrics that we are not scoring on
            reward, reward_metrics = reward_fn(response_tensor, response_txt, **reward_fn_args)

            # train model with ppo
            train_stats = ppo_trainer.step(query_tensor, response_tensor, reward)
            batch_info.append(train_stats)

            if logging:
                experiment.log_metric("ppo_reward", train_stats['ppo/returns/mean'][0])
                experiment.log_metric("reward", t.mean(reward))
                experiment.log_metric("policy loss", train_stats['ppo/loss/policy'][0])
                experiment.log_metric("value head loss", train_stats['ppo/loss/value'][0])
                # experiment.log_metric("unnormalized reward", rewards_per_batch.mean())
                # experiment.log_metric("entropy", entropy)
                experiment.log_metric("kl", train_stats['objective/kl'])
                experiment.log_metric("kl_coef", train_stats['objective/kl_coef'])
                # experiment.log_metric("grad norm", grad_norm)
                
                for metric in reward_metrics.items():
                    experiment.log_metric(metric[0], metric[1])
                
                if batch % 8 == 0:
                    experiment.log_text(response_txt[0], metadata={"reward": reward[0]})
                    
                if batch % 64 == 64-1:
                    fname = f"evalmodel_{experiment.get_name()}_b{batch}"
                    t.save(gen_model, fname)
                    if batch % 128 == 128-1:
                        experiment.log_model(fname, fname)

        return batch_info
    finally:
        if logging:
            experiment.end()

In [17]:
def silly_reward(response_tensor: t.Tensor, response_txt: List) -> Tuple[float, Dict]:
        reward = t.tensor([s.count(".") for s in response_txt], device=device).float()
        return reward, {}

In [18]:
def head_sum(response_tensor: t.Tensor, response_txt: List, layer: int, head: int, eval_model):
    response_with_begin = t.cat((BEGIN.repeat(response_tensor.shape[0], 1), response_tensor), dim=-1)
    weighted_attns = eval_model.weighted_attention(response_with_begin)
    vwas = weighted_attns[layer,:,head,:,:]
    return t.sum(vwas, dim=[-1,-2]), {"max_attn": t.max(vwas).item()}

def head_max(response_tensor: t.Tensor, response_txt: List, layer: int, head: int, eval_model):
    response_with_begin = t.cat((BEGIN.repeat(response_tensor.shape[0], 1), response_tensor), dim=-1)
    weighted_attns = eval_model.weighted_attention(response_with_begin)
    vwas = weighted_attns[layer,:,head,:,:]
    return t.max(t.max(vwas, dim=-1).values, dim=-1).values, {"sum_attn": t.sum(vwas).item()}

In [None]:
eval_model = get_minigpt("model.pt").to(device)
ppo_config = {
    "batch_size": 128,
    "forward_batch_size": 8,
    "adap_kl_ctrl": True,
    "init_kl_coef": 0.2,
}
reward_fn_args = {
    "layer": 1,
    "head": 4,
    "eval_model": eval_model
}
gen_model = GPT2HeadWithValueModel.from_pretrained("gpt2").to(device)
batch_info=train(gen_model, ref_model, ppo_config, head_max, reward_fn_args, gen_len=48, num_batches=256, logging=True)

Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'v_head.summary.bias', 'v_head.summary.weight', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/danielb/jenny-dan/eeab7776d2a645a8a169d45318e413bc

  2%|▏         | 5/256 [02:47<2:20:43, 33.64s/it]

In [None]:
eval_model = get_minigpt("model.pt").to(device)
ppo_config = {
    "batch_size": 128,
    "forward_batch_size": 8,
    "adap_kl_ctrl": True,
    "init_kl_coef": 0.2,
}
# I might actually prefer the explicit one // I like the for loop just to make sure we haven't missed any lol
for layer, head in [(x,y) for x in range(2) for y in range(8)]:
    reward_fn_args = {
        "layer": layer,
        "head": head,
        "eval_model": eval_model
    }
    gen_model = GPT2HeadWithValueModel.from_pretrained("gpt2").to(device)
    batch_info=train(gen_model, ref_model, ppo_config, head_max, reward_fn_args, gen_len=48, num_batches=256, logging=True, comet_tag=f"{layer=},{head=}")

Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight', 'v_head.summary.bias', 'v_head.summary.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/danielb/jenny-dan/2ecdfc1f6d78407bad9a8d7485095366

  0%|          | 1/256 [00:33<2:22:02, 33.42s/it]COMET INFO: invalid metadata, expecting JSON-encodable object
  2%|▏         | 5/256 [02:46<2:19:31, 33.35s/it]