In [187]:
import os
import sys

from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F

import tiktoken
import itertools

import numpy as np
import pandas as pd
import pickle as pkl

sys.path.insert(1, '../../models/')
from model import GPT, GPTConfig

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

In [23]:
all_dfs = []

for filename in os.listdir('.'):
    
    if not filename.endswith('.jsonl'):
        continue

    cur_df = pd.read_json(filename, lines=True)
    all_dfs.append(cur_df)

In [24]:
df = pd.concat(all_dfs).reset_index(drop=True)

In [33]:
checkpoint = torch.load('../../models/out/ckpt-5-5-2.5-48-mean-0.pt', map_location=device)
model_args = checkpoint['model_args']
model_args['position_dir'] = '../../models/gpt2-positions-5-5'

gptconf = GPTConfig(**model_args)
model = GPT(gptconf)

state_dict = checkpoint['model']
# fix the keys of the state dictionary :(
# honestly no idea how checkpoints sometimes get this prefix, have to debug more
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

model.load_state_dict(state_dict, strict=False)
model.eval()

tokenizer = tiktoken.get_encoding('gpt2')
pad_token = tokenizer.encode('<|endoftext|>', allowed_special="all")[0]

number of parameters: 127.97M


In [119]:
prompt = 'hello i am'

start_ids = tokenizer.encode(prompt)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

In [125]:
x.shape

torch.Size([1, 3])

In [123]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.
    """
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= model.config.block_size else idx[:, -model.config.block_size:]
        # forward the model to get the logits for the index in the sequence
        logits, _, _, _, _ = model(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx


In [124]:
generate(model, x, 10)

tensor([[31373,  1312,   716,  1654,   284,   651,   340,  1107,   329,   502,
            11,  1521,   561]])

In [178]:
@torch.no_grad()
def surprisal(model, tokens):
    """
    compute surprisal of a batch of tokens
    """
    B, L = tokens.shape
    context = tokens[:, 0].unsqueeze(1)
    surp = torch.zeros(B)
    
    for i in range(L - 1):
        
        logits, _, _, _, _ = model(context)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        
        next_token = tokens[:, i + 1]
        surp += -torch.log(probs[range(B), next_token])
        
        # append sampled index to the running sequence and continue
        context = torch.cat((context, tokens[:, i + 1].unsqueeze(1)), dim=1)

    return surp

In [182]:
def tokenize_batch(sents):
    tokens = tokenizer.encode_batch(sents, allowed_special = 'all')
    padded = list(zip(*itertools.zip_longest(*tokens, fillvalue=pad_token)))
    return torch.from_numpy(np.array(padded))

In [190]:
good_surps = []
bad_surps = []

batch_size = 48
for i in tqdm(range(0, len(df), batch_size)):
    batch_df = df.iloc[i:i + batch_size]
    
    good_surps += list(surprisal(model, tokenize_batch(batch_df['sentence_good'])))
    bad_surps += list(surprisal(model, tokenize_batch(batch_df['sentence_bad'])))

  0%|          | 0/1396 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
df['good_surps'] = good_surps
df['bad_surps'] = bad_surps
df['correct'] = df['good_surp'] < df['bad_surp']

In [None]:
df.to_csv('blimp_results.csv')

In [None]:
df.read_