In [8]:
from transformers import GPT2Tokenizer, AutoModelForCausalLM
from torch.nn.functional import log_softmax
import torch
import numpy as np
import json
from matplotlib import pyplot as plt
from tqdm import tqdm
import random

In [2]:
model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [4]:
def encode(string):
    ids, breakpoints = [], []
    spl = string.split()
    for tok in spl:
        encoded = tokenizer.encode(tok)
        ids.extend(encoded)
        breakpoints.append(len(ids) + 1)
    ids = [tokenizer.bos_token_id] + ids + [tokenizer.eos_token_id]
    return ids, breakpoints
    
encode("Outperforms ALIGN in supervised entity linking")

([50256,
  7975,
  525,
  23914,
  1847,
  16284,
  259,
  16668,
  16149,
  26858,
  75,
  8040,
  50256],
 [3, 5, 6, 8, 9, 11])

In [5]:
taskC_train = []
with open ("data/subtaskA_train_monolingual.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        if line:
            parsed = json.loads(line)
            taskC_train.append((parsed["text"], parsed["label"], parsed["id"]))
        

In [6]:
def perplexity(logits, labels):
    sm = log_softmax(logits, dim=-1)
    probs = []
    for i in range(labels.shape[1] - 1):
        next_idx = labels[0, i + 1].item()
        scores = sm[0, i, :]
        prob_next_idx = scores[next_idx].item()
        probs.append(prob_next_idx)

    probs = np.array(probs)
    probs = np.sum(probs)
    l = labels.shape[1] - 1
    ppl = - (1 / l) * np.sum(probs)
    return np.exp(ppl)

In [9]:
result = []
window_size = 1024
slide_amount = 256
random.shuffle(taskC_train)
firstk = taskC_train[:100]

for i, (text, label, id) in tqdm(enumerate(firstk), total=len(firstk)):
    input_ids, breakpoints = encode(text)
    perplexities = []
    if len(input_ids) > window_size:
        window_start = 0
        window_end = window_size
        window = input_ids[window_start:window_end]
        out = model(torch.tensor(window).unsqueeze(0))
        for bp in breakpoints:
            if (bp + 1) > window_end:
                window_start += slide_amount
                window_end += slide_amount
                window = input_ids[window_start:window_end]
            bp = bp - window_start
            partiad_ids = window[0:bp + 1]
            partiad_ids = torch.tensor(partiad_ids).unsqueeze(0)
            partial_logits = out.logits[:, 0:bp + 1, :]
            try:
                ppl = perplexity(partial_logits, partiad_ids)
                perplexities.append(ppl)
            except:
                print(partiad_ids.shape, partial_logits.shape, bp, window_start, window_end, len(input_ids))
    else:
        input_ids = torch.tensor(input_ids).unsqueeze(0)
        out = model(input_ids)
        for bp in breakpoints:
            partiad_ids = input_ids[:, 0:bp + 1]
            partial_logits = out.logits[:, 0:bp + 1, :]
            ppl = perplexity(partial_logits, partiad_ids)
            perplexities.append(ppl)
    result.append(perplexities)
        
    

100%|██████████| 100/100 [20:57<00:00, 12.57s/it]


In [None]:
with open("data/taskA_perplexities.json", "w", encoding="utf-8") as f:
    json.dump(result, f, indent=4)