## Load packages and model

In [3]:
import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import pywt
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
# Checkpoint identifier handed to from_pretrained for both the tokenizer and the model.
model_name = "gpt2" 
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# Put the model in evaluation mode so its Dropout layers are inactive during inference.
# eval() returns the model, so as the last expression of the cell it is also displayed.
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

## Load and tokenize prompts

In [8]:
# Load the prompt table; quotechar='"' keeps quoted prompts that contain commas in one field.
df = pd.read_csv("./data/prompts.csv", quotechar='"')
# head(-1) returns every row except the last one (it is not the default 5-row preview).
# NOTE(review): the "Unnamed: 0" column in the output suggests the CSV carries a saved
# index column; reading it with index_col=0 would avoid that — confirm against the file.
df.head(-1)

Unnamed: 0,id,prompt
0,1,Who discovered penicillin?
1,2,When was the Eiffel Tower built?
2,3,What is the capital of Australia?
3,4,How tall is Mount Everest?
4,5,Who painted the Mona Lisa?
...,...,...
94,95,"Finish: 'In the morning, I always…'"
95,96,Give an example of a metaphor.
96,97,List two programming languages.
97,98,Write a haiku about winter.


In [10]:
# Tokenize each prompt on its own (no padding across prompts), so every entry
# holds 'input_ids' and 'attention_mask' tensors of shape (1, seq_len).
# Calling the tokenizer directly replaces the deprecated encode_plus() and
# returns the same encoding.
tokenized_prompts = [
    tokenizer(prompt, return_tensors="pt") for prompt in df['prompt']
]

# Inspect the first encoded prompt.
tokenized_prompts[0]

{'input_ids': tensor([[ 8241,  5071,  3112,   291, 32672,    30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

## Define perturbation methods

1. Zeroing the high-frequency (detail) wavelet coefficients
2. Zeroing the low-frequency (approximation) wavelet coefficients

In [11]:
def wavelet_perturb(embeddings_tensor, wavelet='haar'):
    """
    Low-pass filter token embeddings with a single-level discrete wavelet transform.

    Each embedding vector (the last axis) is decomposed with ``pywt.dwt``, its
    detail (high-frequency) coefficients are replaced by zeros, and the vector
    is rebuilt with ``pywt.idwt``.

    Args:
        embeddings_tensor: torch tensor whose last axis is the embedding
            dimension, e.g. shape (1, seq_len, embed_dim). A leading axis of
            size 1 is squeezed away before the transform.
        wavelet: wavelet name passed to PyWavelets (default 'haar').

    Returns:
        A new float32 CPU tensor holding the filtered embeddings with a leading
        axis of size 1 added back, i.e. (1, seq_len, embed_dim) for the input
        shape above. The input tensor is not modified.
    """
    # detach() allows tensors that require grad to be converted to NumPy.
    emb_np = embeddings_tensor.detach().squeeze(0).cpu().numpy()
    embed_dim = emb_np.shape[-1]

    # Transform all token vectors in one call along the embedding axis rather
    # than looping over tokens in Python.
    cA, cD = pywt.dwt(emb_np, wavelet, axis=-1)

    # Zero out the high-frequency (detail) coefficients and reconstruct.
    perturbed_np = pywt.idwt(cA, np.zeros_like(cD), wavelet, axis=-1)

    # idwt yields one extra sample when embed_dim is odd; trim it so the output
    # width always equals the input width.
    perturbed_np = perturbed_np[..., :embed_dim]

    return torch.tensor(perturbed_np, dtype=torch.float32).unsqueeze(0)

def wavelet_perturb2(embeddings_tensor, wavelet='haar'):
    """
    High-pass filter token embeddings with a single-level discrete wavelet transform.

    Each embedding vector (the last axis) is decomposed with ``pywt.dwt``, its
    approximation (low-frequency) coefficients are replaced by zeros, and the
    vector is rebuilt with ``pywt.idwt``.

    Args:
        embeddings_tensor: torch tensor whose last axis is the embedding
            dimension, e.g. shape (1, seq_len, embed_dim). A leading axis of
            size 1 is squeezed away before the transform.
        wavelet: wavelet name passed to PyWavelets (default 'haar').

    Returns:
        A new float32 CPU tensor holding the filtered embeddings with a leading
        axis of size 1 added back, i.e. (1, seq_len, embed_dim) for the input
        shape above. The input tensor is not modified.
    """
    # detach() allows tensors that require grad to be converted to NumPy.
    emb_np = embeddings_tensor.detach().squeeze(0).cpu().numpy()
    embed_dim = emb_np.shape[-1]

    # Transform all token vectors in one call along the embedding axis rather
    # than looping over tokens in Python.
    cA, cD = pywt.dwt(emb_np, wavelet, axis=-1)

    # Zero out the low-frequency (approximation) coefficients and reconstruct.
    perturbed_np = pywt.idwt(np.zeros_like(cA), cD, wavelet, axis=-1)

    # idwt yields one extra sample when embed_dim is odd; trim it so the output
    # width always equals the input width.
    perturbed_np = perturbed_np[..., :embed_dim]

    return torch.tensor(perturbed_np, dtype=torch.float32).unsqueeze(0)

## Embed tokenized inputs

In [20]:
# Look up each prompt's token embeddings in the model's token-embedding table
# (wte) without tracking gradients, and keep them as NumPy arrays of shape
# (1, seq_len, 768).
with torch.no_grad():
    embeddings = [
        model.transformer.wte(encoded['input_ids']).cpu().numpy()
        for encoded in tokenized_prompts
    ]

# Shape check on the second prompt.
embeddings[1].shape

(1, 9, 768)

## Define perturbed and original inputs

In [28]:
# Original embeddings as float32 tensors. The same tensors feed both
# perturbations below; neither perturbation writes to its input.
original_embeddings = [
    torch.tensor(emb, dtype=torch.float32) for emb in embeddings
]

# Zero out high frequency components
zero_high_freq = [wavelet_perturb(source) for source in original_embeddings]

# Zero out low frequency components
zero_low_freq = [wavelet_perturb2(source) for source in original_embeddings]

# All three variants of a prompt should share the same shape.
for variant in (zero_high_freq, zero_low_freq, original_embeddings):
    print(variant[1].shape)

torch.Size([1, 9, 768])
torch.Size([1, 9, 768])
torch.Size([1, 9, 768])


## Auto-regressive top-k logits

- `x = 5`: number of tokens to generate auto-regressively for each prompt
- `k = 10`: number of top-scoring candidate tokens recorded at each step, for interpretability

In [34]:
def autoregressive_topk(model, tokenizer, embeddings_list, x=5, k=10):
    """
    Greedily generate x tokens per prompt and record the top-k candidates of each step.

    Every prompt is supplied as input embeddings rather than token ids. At each
    step the model is run on the embeddings accumulated so far, the k
    highest-scoring next tokens are decoded to strings, and the embedding of
    the single best token is appended so generation continues from it.

    Args:
        model: causal language model that accepts ``inputs_embeds=...``,
            returns an object with a ``.logits`` tensor of shape
            (1, seq_len, vocab_size), and exposes ``get_input_embeddings()``.
        tokenizer: object whose ``decode`` turns a single token id into text.
        embeddings_list: list of torch tensors, each (1, seq_len, hidden_dim).
            The tensors are cloned and never modified.
        x: number of tokens to generate for each prompt.
        k: number of candidate tokens to record at each step.

    Returns:
        A list with one entry per element of embeddings_list; each entry is a
        list of x steps, and each step is a list of k whitespace-stripped token
        strings ordered from highest to lowest logit.
    """
    # Ask the model for its input-embedding layer instead of reaching into a
    # GPT-2-specific attribute path, so other causal LMs work as well.
    embed_tokens = model.get_input_embeddings()
    all_topk_predictions = []

    with torch.no_grad():
        for emb in embeddings_list:
            topk_predictions = []
            generated_embeds = emb.clone()  # (1, seq_len, hidden_dim)

            for _ in range(x):
                outputs = model(inputs_embeds=generated_embeds)
                next_token_logits = outputs.logits[:, -1, :]  # (1, vocab_size)

                # torch.topk returns (values, indices); the values are raw
                # logits, not probabilities, and only the ids are needed.
                _, topk_ids = torch.topk(next_token_logits, k=k, dim=-1)
                topk_predictions.append(
                    [tokenizer.decode(token_id).strip() for token_id in topk_ids[0]]
                )

                # Continue from the top-1 token: embed it and extend the input.
                next_token_embed = embed_tokens(topk_ids[:, :1])  # (1, 1, hidden_dim)
                generated_embeds = torch.cat(
                    [generated_embeds, next_token_embed], dim=1
                )

            all_topk_predictions.append(topk_predictions)

    return all_topk_predictions

## Extract top-k logits

## Softmax for interpretable probabilities

## Define probability distributions