# Autoregressive Generation Wavelet Perturbation Test

In [12]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import pywt
import numpy as np

### Load Model and Tokenizer

In [25]:
# Hugging Face checkpoint identifier shared by the tokenizer and the model.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# nn.Module.eval() returns the module itself, so the call can be chained onto
# the load; eval mode turns off the dropout layers for inference.
model = GPT2LMHeadModel.from_pretrained(model_name).eval()

# Display the architecture as the cell output.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

### Tokenize Input Prompt

In [26]:
prompt = "Italy is known for what food?"
# Call the tokenizer directly: `encode_plus` is deprecated in favour of
# `__call__`, which returns the same `input_ids` / `attention_mask` encoding.
inputs = tokenizer(prompt, return_tensors="pt")
inputs

{'input_ids': tensor([[45001,   318,  1900,   329,   644,  2057,    30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

### Extract Token Embeddings

In [27]:
# Prompt token ids; kept separately because generation later appends sampled ids to them.
input_ids = inputs["input_ids"]

# Look up only the token-embedding table (wte); this is inference, so no
# autograd bookkeeping is needed.
with torch.no_grad():
    embeddings = model.transformer.wte(input_ids)  # shape: [1, seq_len, embed_dim]

# Per-token view of the single prompt in the batch: (seq_len, embed_dim)
embeddings[0].shape

torch.Size([7, 768])

### Wavelet Perturbation Function

In [28]:
def wavelet_perturb(embeddings_tensor, wavelet='haar'):
    """
    Low-pass filter token embeddings with a single-level discrete wavelet transform.

    Every embedding vector (the last axis) is decomposed with ``pywt.dwt``,
    its detail (high-frequency) coefficients are discarded, and the vector is
    rebuilt from the approximation coefficients alone with ``pywt.idwt``.

    Args:
        embeddings_tensor: Tensor of shape [batch, seq_len, embed_dim]. A 2-D
            [seq_len, embed_dim] tensor is treated as a batch of one.
        wavelet: Wavelet name (or ``pywt.Wavelet``) forwarded to PyWavelets.

    Returns:
        Tensor of shape [batch, seq_len, embed_dim] on the same device and with
        the same dtype as ``embeddings_tensor``.
    """
    if embeddings_tensor.dim() == 2:
        embeddings_tensor = embeddings_tensor.unsqueeze(0)

    embed_dim = embeddings_tensor.shape[-1]

    # detach() so tensors that track gradients can be exported to NumPy;
    # float32 on CPU because NumPy cannot read GPU or bfloat16 tensors.
    emb_np = embeddings_tensor.detach().to(device="cpu", dtype=torch.float32).numpy()

    # Transform along the embedding axis for all tokens at once.
    cA, _ = pywt.dwt(emb_np, wavelet, axis=-1)
    # Passing None for the detail coefficients is treated as all zeros,
    # i.e. the high-frequency band is removed.
    low_pass = pywt.idwt(cA, None, wavelet, axis=-1)
    # For an odd embed_dim the reconstruction is one sample longer than the
    # input; trim it back so the output shape matches.
    low_pass = low_pass[..., :embed_dim]

    perturbed_tensor = torch.from_numpy(np.ascontiguousarray(low_pass, dtype=np.float32))
    # Hand the result back where (and how) the caller's tensor lives so it can
    # be concatenated with model embeddings without device/dtype mismatches.
    return perturbed_tensor.to(device=embeddings_tensor.device, dtype=embeddings_tensor.dtype)

### Apply Wavelet Perturbation to Embeddings

In [29]:
perturbed_embeddings = wavelet_perturb(embeddings)

# The perturbation must be a drop-in replacement for the original embeddings,
# so fail fast if the transform changed the tensor shape.
assert perturbed_embeddings.shape == embeddings.shape, "wavelet_perturb changed the embedding shape"

perturbed_embeddings.shape

torch.Size([1, 7, 768])

### Define Sampling and Generation Helpers for Original and Perturbed Embeddings

In [30]:
def sample_next_token(logits, temperature=1.0, top_k=50):
    """
    Sample one token id per batch row using temperature and top-k filtering.

    Args:
        logits: Tensor of shape [batch, vocab_size] with next-token logits.
        temperature: Positive scaling factor; values below 1 sharpen the
            distribution, values above 1 flatten it.
        top_k: Number of highest-scoring tokens to sample from. Clamped to the
            vocabulary size when larger.

    Returns:
        LongTensor of shape [batch, 1] holding the sampled token ids.

    Raises:
        ValueError: If ``temperature`` is not positive or ``top_k`` is below 1.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    # Temperature scaling
    scaled_logits = logits / temperature

    # Top-k filtering; torch.topk raises if k exceeds the vocabulary size.
    k = min(top_k, scaled_logits.size(-1))
    top_k_values, top_k_indices = torch.topk(scaled_logits, k, dim=-1)

    # Renormalise over the surviving candidates and draw one position per row.
    probs = torch.softmax(top_k_values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)  # [batch, 1], index into the top-k list

    # Map the position inside the top-k list back to a vocabulary id, row by row.
    return torch.gather(top_k_indices, -1, choice)  # [batch, 1]

def generate_from_embeddings(start_embeddings, start_ids, max_new_tokens=20):
    """
    Autoregressively sample a continuation, feeding the model embeddings directly.

    The prompt is supplied as ``start_embeddings`` (which may be perturbed);
    every sampled token is embedded with the model's own token-embedding table
    and appended, so only the prompt part differs between runs.

    Args:
        start_embeddings: Tensor [1, seq_len, embed_dim] used as the prompt input.
        start_ids: LongTensor [1, seq_len] with the prompt token ids; used only
            to build the decoded output text.
        max_new_tokens: Upper bound on the number of tokens to sample.

    Returns:
        The decoded string for the prompt ids followed by the sampled ids.
        Generation stops early at the end-of-sequence token or when the
        model's positional-embedding table is exhausted.
    """
    generated_embeds = start_embeddings
    generated_ids = start_ids

    # The model cannot attend past its positional-embedding table.
    max_positions = model.transformer.wpe.num_embeddings
    eos_token_id = tokenizer.eos_token_id

    # Keep the whole loop gradient-free, including the embedding lookup, so the
    # growing sequence does not accumulate an autograd graph.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if generated_embeds.shape[1] >= max_positions:
                break

            outputs = model(inputs_embeds=generated_embeds)
            next_token_logits = outputs.logits[:, -1, :]  # [1, vocab_size]

            # Normalise to [1, 1] regardless of the sampler's exact output shape.
            next_token_id = sample_next_token(next_token_logits).view(1, 1)

            # Stop at end-of-text: anything sampled afterwards is an unrelated
            # document, and the marker itself is hidden when decoding.
            if eos_token_id is not None and next_token_id.item() == eos_token_id:
                break

            # Append new token ID
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

            # Append new token embedding
            next_token_emb = model.transformer.wte(next_token_id)  # [1, 1, embed_dim]
            generated_embeds = torch.cat([generated_embeds, next_token_emb], dim=1)

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

### Generate and Compare Text from Original and Perturbed Embeddings

Prompt: "Italy is known for what food?"

In [32]:
# Sampling is stochastic. Re-seed before each run so both generations start
# from the same RNG state; otherwise differences between the two texts could
# come from sampling noise rather than from the perturbation.
GENERATION_SEED = 0

torch.manual_seed(GENERATION_SEED)
original_text = generate_from_embeddings(embeddings, input_ids, max_new_tokens=50)

torch.manual_seed(GENERATION_SEED)
perturbed_text = generate_from_embeddings(perturbed_embeddings, input_ids, max_new_tokens=50)

print("Original generation:")
print(original_text)
print("\nPerturbed generation:")
print(perturbed_text)

Original generation:
Italy is known for what food? But what are they? If there was a restaurant near the coast to where I live on the beach, it would be on my way home for a vacation.

There's an idea that a restaurant is the answer to being in a different world

Perturbed generation:
Italy is known for what food?

The following are my thoughts on this issue. I would like to remind the people of my country, that I, and the government my country, have been a part of this and I would like to be an observer in this, and one


### Quantify Changes in Output Logits