In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
from copy import deepcopy
import sys
import datasets
import matplotlib.pyplot as plt
import numpy as np
sys.path.append('../experiments/')
import os
import scipy
import pandas as pd
import numpy as np
import transformers
import sys
from os.path import join
import datasets
from dict_hash import sha256
import numpy as np
from torch.autograd import grad
import torch
from tqdm import tqdm
from torch.autograd.functional import jacobian
from torch.func import jacfwd
from transformers import AutoTokenizer, AutoModelForCausalLM

# load model
checkpoint = 'gpt2'
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.eval()
model = model.to(device)

# load dset: a random subset of 100 training reviews, drawn from a seeded
# generator so every Restart & Run All sees the same examples.
SEED = 42
dset_train = datasets.load_dataset('rotten_tomatoes')['train']
rng = np.random.default_rng(SEED)
dset_train = dset_train.select(rng.choice(
    len(dset_train), size=100, replace=False))

In [None]:
class NaiveBayesHuggingFaceClassifier:
    """Wrap a Hugging Face model + tokenizer and return raw logits for text input.

    Parameters
    ----------
    model
        Model whose forward pass returns an object with a ``.logits`` tensor
        and which exposes a ``.device`` attribute.
    tokenizer
        Matching tokenizer; called with ``return_tensors='pt'``.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, x):
        """Tokenize ``x``, run the model, and return the logits of the first sequence.

        Parameters
        ----------
        x : str or list of str
            Text to score.

        Returns
        -------
        numpy.ndarray
            ``logits[0]`` as a CPU numpy array.
        """
        # GPT-2-style tokenizers define no pad token; asking them to pad raises a
        # ValueError even for a single string, so only pad when it is possible.
        can_pad = getattr(self.tokenizer, 'pad_token', None) is not None
        encoded = self.tokenizer(x, return_tensors='pt', padding=can_pad)
        # Follow the model's own device rather than a notebook-global variable.
        encoded = {k: v.to(self.model.device) for k, v in encoded.items()}
        # Pure inference: no autograd graph needed.
        with torch.no_grad():
            logits = self.model(**encoded).logits
        return logits.detach().cpu().numpy()[0]

### Ground truth: exact probabilities of all bigrams

In [None]:
def get_unigram_logits(prefix, tokenizer, model):
    """Return the model's next-token logits after ``prefix``.

    Parameters
    ----------
    prefix : str
        Text to condition on.
    tokenizer
        Tokenizer whose ``encode(..., return_tensors="pt")`` yields a (1, T) id tensor.
    model
        Causal LM whose output exposes ``.logits`` indexed as (batch, position, vocab).

    Returns
    -------
    numpy.ndarray
        Logits at the last position, shape (1, vocab).
    """
    input_ids = tokenizer.encode(prefix, return_tensors="pt").to(model.device)
    # Inference only: don't build an autograd graph through the whole model.
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
    return logits[:, -1, :].detach().cpu().numpy()


# Quick smoke test: next-token logits after the single word 'The', shape (1, vocab).
get_unigram_logits('The', tokenizer, model)

In [None]:
# Build one input sequence per vocabulary entry: the prompt followed by that token.
prefix = 'The most popular words that might appear in a positive movie review are'
batch_size = 2048
voc_size = model.transformer.wte.weight.shape[0]

# NB: despite the name, these are the prompt's token ids (1-D), not embeddings.
prefix_embs = tokenizer.encode(prefix, return_tensors="pt")[0]
# Row i is [*prefix_ids, i]. Built in one vectorized op instead of a Python
# loop with one concatenation per vocabulary entry.
input_ids_bigrams = torch.cat(
    (
        prefix_embs.expand(voc_size, -1),
        torch.arange(voc_size, dtype=prefix_embs.dtype).unsqueeze(1),
    ),
    dim=1,
).to(model.device)

In [None]:
# Row i holds the next-token logits after [prefix, i], i.e. the distribution
# of the second token given first token i. This is a V x V float64 array,
# which is very large for a full vocabulary.
bigram_logits = np.zeros((voc_size, voc_size))
unigram_logits = get_unigram_logits(prefix, tokenizer, model)
# Pure inference: without no_grad every batch would also retain its autograd
# activations until the logits are copied back to the CPU.
with torch.no_grad():
    for i in tqdm(range(0, voc_size, batch_size)):
        input_ids = input_ids_bigrams[i:i + batch_size]
        bigram_logits[i:i + batch_size] = \
            model(input_ids).logits[:, -1, :].cpu().numpy()

In [None]:
# P(i | prefix) for every first token i. Normalize explicitly over the vocab
# axis rather than over the whole (1, V) array.
unigram1_probs = scipy.special.softmax(unigram_logits, axis=-1).flatten()
# Row i: P(j | prefix, i), normalized over the second-token axis.
unigram2_probs = scipy.special.softmax(bigram_logits, axis=1)

# Joint P(i, j | prefix) = P(i | prefix) * P(j | prefix, i): scale each row i
# of unigram2_probs by unigram1_probs[i].
bigram_probs = unigram2_probs * unigram1_probs[:, None]
# Marginalizing out the second token must recover the first-token
# distribution, and the joint must sum to 1 (same tolerances as np.allclose).
np.testing.assert_allclose(
    bigram_probs.sum(axis=1), unigram1_probs, rtol=1e-5, atol=1e-8, equal_nan=False)
np.testing.assert_allclose(
    bigram_probs.sum(), 1, rtol=1e-5, atol=1e-8, equal_nan=False)

In [None]:
# Find the top-100 bigrams as (row, col) = (first token id, second token id)
# pairs. Only the k largest entries are needed, so partition instead of fully
# sorting the V*V array (the previous full sort + argsort took ~20 mins).
TOP_K_BIGRAMS = 100
bigram_probs_flat = bigram_probs.ravel()  # view for a contiguous array, no copy
k = min(TOP_K_BIGRAMS, bigram_probs_flat.size)
top_flat_idx = np.argpartition(bigram_probs_flat, -k)[-k:]
# argpartition leaves the top-k unordered; sort just those k, descending.
top_flat_idx = top_flat_idx[np.argsort(bigram_probs_flat[top_flat_idx])[::-1]]
top_bigram_probs = bigram_probs_flat[top_flat_idx]
# Shape (k, 2): each row indexes bigram_probs as (first token, second token).
bigram_probs_sorted_idx = np.stack(
    np.unravel_index(top_flat_idx, bigram_probs.shape), axis=1)

In [None]:
# Turn each (first token id, second token id) pair into its two-token string.
top_bigrams = [tokenizer.decode(pair) for pair in bigram_probs_sorted_idx]

### Evaluate Jacobian-based logits

In [None]:
# Same prompt as the ground-truth section; reassigned here so this section
# does not rely on that cell's state.
prefix = 'The most popular words that might appear in a positive movie review are'
input_ids = tokenizer.encode(prefix, return_tensors="pt")
input_ids = input_ids.to(model.device)
# Look the prompt's token embeddings up directly in the input embedding table.
input_embs = model.transformer.wte(input_ids)


def forward_embs(embs):
    """Run the notebook-global ``model`` on pre-computed embeddings; return its logits."""
    return model(inputs_embeds=embs).logits

In [None]:
# Keep the embeddings and the model on the device configured at the top of
# the notebook instead of hard-coding 'cuda' again.
input_embs = input_embs.to(device)
model = model.to(device)

In [None]:
def forward_embs(embs):
    """Redefinition of ``forward_embs`` with identical behavior to the earlier one.

    Feeds ``embs`` to the notebook-global ``model`` via ``inputs_embeds`` and
    returns the resulting logits.
    """
    result = model(inputs_embeds=embs)
    logits = result.logits
    return logits

In [None]:
# Forward-mode Jacobian of every logit w.r.t. every prompt embedding.
# NOTE(review): jacfwd returns shape logits.shape + input_embs.shape, i.e.
# (1, T, V, 1, T, d) before squeeze. For a full vocabulary this is likely
# tens of GB. The later "output x input" projection and the imshow cell appear
# to expect a 2-D Jacobian. Confirm whether only the last position (e.g.
# logits[0, -1] w.r.t. embs[0, -1]) was intended.
jac = jacfwd(forward_embs)(input_embs).squeeze()

In [None]:
# Inspect the shape of the (squeezed) Jacobian.
jac.shape

In [None]:
# Project the embedding-space Jacobian onto every row of the input embedding
# table; the original note labels the result "output x input".
jac_input = torch.matmul(jac, model.transformer.wte.weight.t())
jac_input.device

In [None]:
# Convert to numpy for display/plotting. Guarded so re-running the cell does
# not crash once jac_input is already an ndarray (which has no .cpu()).
if torch.is_tensor(jac_input):
    # Detach first so the device copy is not recorded by autograd.
    jac_input = jac_input.detach().cpu().numpy()

In [None]:
# Peek at the first 3 entries along the first two axes of the projected Jacobian.
jac_input[:3, :3]

In [None]:
# Heatmap of the projected Jacobian (rows: output token, cols: input token).
fig, ax = plt.subplots()
im = ax.imshow(jac_input, aspect='auto')
ax.set(title='Projected Jacobian jac @ wte.T (output x input)',
       xlabel='input token id', ylabel='output token id')
fig.colorbar(im, ax=ax)
plt.show()

### Sanity check: Jacobian of a multi-dimensional function with `jacfwd`

In [None]:
# Sanity-check jacfwd on a linear map f(x) = x @ param, whose Jacobian is known.
# A local seeded generator makes the printed numbers reproducible without
# touching torch's global RNG state.
gen = torch.Generator().manual_seed(0)
# Create x directly; wrapping an existing tensor in torch.tensor(...)
# copy-constructs it and emits a UserWarning.
x = torch.normal(0, 1, size=(1, 5), generator=gen).requires_grad_()

# Parameters of the linear map: all ones, except row 0 (the weights applied
# to x[0, 0]) is 2.
param = torch.ones(5, 4)
param[0, :] = 2


def f(x):
    """Linear map from shape (1, 5) to shape (1, 4)."""
    return x @ param


# Forward-mode Jacobian; its shape is f(x).shape + x.shape = (1, 4, 1, 5).
gradient = jacfwd(f)(x)

print("Input x:", x)
print("Output y:", f(x))
print("Jacob dy/dx:", gradient)

# dy_j / dx_i = param[i, j], so dropping the singleton dims gives param.T.
assert gradient.shape == (1, 4, 1, 5)
assert torch.allclose(gradient[0, :, 0, :], param.T)