In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import json
import math
import random
import numpy as np
from tqdm import tqdm
from scipy.stats import t as tdist

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name_or_path = "gpt2-xl"
dataset_path = "/kaggle/input/boolq-dev/dev.jsonl"
context_len = 256
permutations_per_shard = 20
max_examples = 5000

num_shards = 50
context_len = 256
stride = 1024
device = "cuda"

In [4]:
def load_dataset(dataset_path):
    # For loading a JSON-serialized list of examples.
    if dataset_path.endswith(".json"):
        print("loading from json...")
        with open(dataset_path, "r") as f:
            data = f.read()
            examples = json.loads(data)
            return examples

    # For loading a dataset where each example is on its own line.
    with open(dataset_path, "r") as f:
        lines = f.readlines()
    return lines

In [12]:
examples = load_dataset(dataset_path)
examples = examples[:max_examples]
num_examples = len(examples)
print(f"Loaded {num_examples} examples from {dataset_path}")

Loaded 3270 examples from /kaggle/input/boolq-dev/dev.jsonl


In [13]:
t = AutoTokenizer.from_pretrained(model_name_or_path)
tokenized_examples = [t.encode(ex) for ex in examples]

Token indices sequence length is longer than the specified maximum sequence length for this model (1072 > 1024). Running this sequence through the model will result in indexing errors


In [14]:
def compute_logprob_of_token_sequence(tokens, model, context_len=2048, stride=1024, device=0):
    inputs = tokens[:-1]
    targets = tokens[1:]

    logp = torch.zeros((1, 1), dtype=torch.float32).to(device)

    # compute the smallest multiple k of s so that t <= ks + c.
    for j in range(math.ceil(max(0, len(inputs) - context_len) / stride)):
        start = stride * j
        end = min(stride * j + context_len, len(inputs))
        rel_offs = max(0, context_len - stride) if j > 0 else 0

        w_inp = inputs[start:end]
        w_inp = torch.tensor(w_inp).to(device)
        w_trg = targets[start:end]
        w_trg = torch.tensor(w_trg).to(device)

        model.eval()
        with torch.no_grad():
            out = model(torch.unsqueeze(w_inp, 0))
            logps = torch.nn.functional.log_softmax(out.logits[0], dim=-1)
            logps = logps.gather(-1, w_trg.unsqueeze(-1)).squeeze(-1)
            logp += logps[rel_offs:].sum()

        del w_inp
        del w_trg
        torch.cuda.empty_cache()

    return logp.item()

In [15]:
import gc

gc.collect()
torch.cuda.empty_cache()

In [16]:
shard_idx = enumerate([num_examples // num_shards] * num_shards)
shard_counts = [(x + 1 if i < num_examples % num_shards else x) for i, x in shard_idx]
shard_bounds = [0] + np.cumsum(np.asarray(shard_counts)).tolist()

m = AutoModelForCausalLM.from_pretrained(model_name_or_path)
m.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1600)
    (wpe): Embedding(1024, 1600)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-47): 48 x GPT2Block(
        (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1600, out_features=50257, bias=False)
)

In [17]:
flatten = lambda l : [x for s in l for x in s]
shuffle = lambda l : random.sample(l, k=len(l))

In [18]:
canon, shuffled = [], []

for start, end in tqdm(list(zip(shard_bounds, shard_bounds[1:]))):
    cur_tokens = flatten(tokenized_examples[start:end])
    canon.append(compute_logprob_of_token_sequence(cur_tokens, m, context_len, stride, device))
    shuffled.append([])
    for _ in range(permutations_per_shard):
        shuffled[-1].append(compute_logprob_of_token_sequence(shuffle(cur_tokens), m, context_len, stride, device))

100%|██████████| 50/50 [27:45<00:00, 33.32s/it]


In [19]:
def t_test(canon, shuffled):
    diffs = canon - shuffled.mean(axis=1)
    z = np.mean(diffs) / np.std(diffs) * np.sqrt(len(diffs))
    pval = 1 - tdist.cdf(z, df=len(diffs)-1)
    return pval

In [20]:
canon = np.asarray(canon)
shuffled = np.asarray(shuffled)
t_test(canon, shuffled)

0.0

In [None]:
canon, shuffled = [], []

for start, end in tqdm(list(zip(shard_bounds, shard_bounds[1:]))):
    cur_tokens = flatten(tokenized_examples[start:end])
    canon.append(compute_logprob_of_token_sequence(cur_tokens, m, context_len, stride, device))
    shuffled.append([])
    for _ in range(permutations_per_shard):
        shuffled[-1].append(compute_logprob_of_token_sequence(shuffle(cur_tokens), m, context_len, stride, device))