In [None]:
# IPython autoreload: re-import edited Python modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import tiktoken
import torch
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from torch.nn import functional as F

# Prefer the GPU, but fall back to CPU so the notebook still runs on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42  # seed for the sampling generator; change it to draw different samples

# Dedicated RNG for token sampling, seeded explicitly so generations are reproducible.
sample_rng = torch.Generator(device=device)
sample_rng.manual_seed(SEED)

# GPT-2 wrapped with trl's value head; checkpoint weights are loaded into it further down.
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model.to(device)
model.eval()  # inference only: put every submodule (including any dropout in the value head) in eval mode

def add_key_prefix(prefix, d, passthrough_prefix="v_head."):
    """Return a copy of state dict ``d`` with ``prefix`` prepended to its keys.

    Keys that already start with ``passthrough_prefix`` (by default the value-head
    parameters, e.g. ``"v_head.summary.weight"``) are copied unchanged. All other keys
    get ``prefix`` prepended, e.g. ``"transformer.wte.weight"`` becomes
    ``"pretrained_model.transformer.wte.weight"`` when ``prefix="pretrained_model."``.

    Args:
        prefix: String to prepend to every non-passthrough key.
        d: Mapping of parameter name -> tensor (or any value); it is not modified.
        passthrough_prefix: Keys starting with this string keep their original name.

    Returns:
        A new dict with the same values and rewritten keys.
    """
    # Match on a leading prefix rather than a substring, so a key that merely contains
    # "v_head" somewhere in the middle is still prefixed.
    return {
        (key if key.startswith(passthrough_prefix) else prefix + key): value
        for key, value in d.items()
    }

In [None]:
# def add_key_prefix(prefix, d):
#     return dict((prefix + k, v) if "v_head" not in k else (k, v) for k, v in d.items())

In [None]:
# Load the student weights of ONE training checkpoint into the value-head wrapper.
# Available steps under log/ include 0, 1000, 2000, 3000, 4000 and 8000. Loading several
# in a row only wastes time and memory, because each load overwrites the previous weights.
STUDENT_STEP = 8000
student_ckpt_path = f"log/model_{STUDENT_STEP:05d}.pt"
# The wrapper's own parameters live under "pretrained_model.*", while the saved state dict
# uses bare GPT-2 names plus "v_head.*", hence the key rewrite before loading.
model.load_state_dict(
    add_key_prefix("pretrained_model.", torch.load(student_ckpt_path, map_location=device)["student_model"])
)

In [None]:
from transformers import GPT2LMHeadModel

# NOTE: this cell rebinds `model`, replacing the value-head student loaded above, so the
# generation cell below samples from the expert. Skip this cell to sample from the student.
EXPERT_STEP = 8000
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
# map_location puts the tensors straight onto `device`, so a checkpoint saved on a
# different device (e.g. CUDA) still loads here.
model.load_state_dict(torch.load(f"log/model_{EXPERT_STEP:05d}.pt", map_location=device)["expert_model"])
model.eval();  # inference only; trailing ';' suppresses the long model repr

In [None]:
# Sampling settings: tune these here instead of editing numbers inside the loop.
prompt = "1+1="  # earlier experiment: "You are all resolved. "
num_return_sequences = 5
max_length = 60  # total tokens per sequence, prompt included
top_k = 50  # HuggingFace pipeline default
sampling_seed = 42

enc = tiktoken.get_encoding('gpt2')
tokens = torch.tensor(enc.encode(prompt), dtype=torch.long)  # (T0,), T0 = prompt length in tokens
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)  # (B, T0)
xgen = tokens.to(device)

# autocast expects a device *type* ("cuda"/"cpu"), not an indexed string such as "cuda:0".
device_type = torch.device(device).type
# Reseed on every run of this cell so that re-executing it reproduces the same samples.
sample_rng.manual_seed(sampling_seed)
model.eval()

with torch.no_grad():
    while xgen.size(1) < max_length:
        # forward the model to get the logits
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits = model(xgen)[0]  # (B, T, vocab_size)
        # Only the last position predicts the next token. Upcast so softmax/top-k run in fp32, not bf16.
        logits = logits[:, -1, :].float()  # (B, vocab_size)
        probs = F.softmax(logits, dim=-1)
        # top-k filtering: both results are (B, top_k)
        topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
        # multinomial does not require its weights to sum to 1, so no renormalisation is needed
        ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # (B, 1)
        # map positions within the top-k back to vocabulary ids
        xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
        xgen = torch.cat((xgen, xcol), dim=1)

output = xgen

for row in output[:, :max_length].tolist():
    print(">", enc.decode(row))