In [2]:
import os, torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reproducibility, local model cache, device selection, and model/tokenizer loading.
torch.manual_seed(42)
MODEL_NAME = "openai-community/gpt2-medium"
cache_dir = (Path.cwd() / "models").resolve()
cache_dir.mkdir(parents=True, exist_ok=True)
# Mirror the HF_HOME layout (<HF_HOME>/hub) so the files land where HF_HOME=cache_dir would put them.
hub_cache_dir = cache_dir / "hub"

# MPS is intentionally not used; fall back to CPU when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: transformers / huggingface_hub were already imported above, and the hub resolves
# HF_HOME when it is imported, so setting it here does NOT redirect downloads on its own.
# It is kept for child processes; the explicit `cache_dir=` below is what uses ./models.
os.environ["HF_HOME"] = str(cache_dir)
print(f'Device: {device}')

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=hub_cache_dir).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=hub_cache_dir)
# Reuse EOS as the padding token so padded batches can be tokenized.
tokenizer.pad_token = tokenizer.eos_token
model.eval();

Device: cpu


In [3]:
# Whitening of the unembedding (lm_head) matrix, computed in the model's native float32.
# NOTE: this float32 pass is kept on purpose: it reproduces the NaN problem diagnosed in the
# following cells. The float64 recomputation further down is the version to rely on.
gamma = model.lm_head.weight.detach()
W, d = gamma.shape  # W = rows (vocab entries), d = hidden size — see the printed g.size()
gamma_bar = torch.mean(gamma, dim=0)  # mean row of gamma
centered_gamma = gamma - gamma_bar

### compute Cov(gamma) and transform gamma to g ###
# Covariance of the centered rows, normalized by W (not W - 1); shape (d, d).
cov_gamma = centered_gamma.T @ centered_gamma / W
# cov_gamma is symmetric, so the symmetric eigensolver eigh applies.
eigenvalues, eigenvectors = torch.linalg.eigh(cov_gamma)

# Cov^{-1/2} and Cov^{1/2} built as V diag(.) V^T from the eigendecomposition.
# In exact arithmetic both are symmetric and are inverses of each other.
inv_sqrt_cov_gamma = eigenvectors @ torch.diag(1/torch.sqrt(eigenvalues)) @ eigenvectors.T
sqrt_cov_gamma = eigenvectors @ torch.diag(torch.sqrt(eigenvalues)) @ eigenvectors.T

# gamma is our original head; right-multiplying by inv_sqrt_cov_gamma puts us in the causal basis.
g = gamma @ inv_sqrt_cov_gamma

# Notation (referencing the paper's equation and the accompanying presentation):
#   A = inv_sqrt_cov_gamma,  A_inv = sqrt_cov_gamma
#   g(y) = gamma(y) @ A,     l(x) = lambda(x) @ A_inv
# Because A is symmetric and A_inv = A^{-1}, l(x) @ g(y).T == lambda(x) @ gamma(y).T,
# i.e. the transform leaves the logits unchanged.
print(model.config.hidden_size)
print(g.size())

1024
torch.Size([50257, 1024])


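### GPT2-Medium Issues
- **problem:** `gamma @ inv_sqrt_cov_gamma` produces NaNs (both its min and its max print as `nan`).
- **cause:** float32 precision; the float64 recomputation below removes the NaNs. Cov(gamma) is badly conditioned: its eigenvalues range from about 1.1e-7 to about 0.15, a condition number of roughly 1.4e6.
- **fix:** compute in float64, then cast back to float32 after the causal transform (`gamma @ inv_sqrt_cov_gamma`).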

In [4]:
# Sanity report for the float32 pass: spectrum of Cov(gamma), value range of the raw head
# gamma and of the transformed head g, and gamma's dtype. A `nan` for g's min/max here
# signals NaNs in g.
eigenval_min_max = "\n".join([
    f"Eigenval min: {eigenvalues.min()}",
    f"Eigenval max: {eigenvalues.max()}",
])
gamma_min_max = "\n".join([
    f"gamma min: {gamma.min()}",
    f"gamma max: {gamma.max()}",
])
g_min_max = "\n".join([
    f"gamma @ inv_sqrt_cov_gamma min: {g.min()}",
    f"gamma @ inv_sqrt_cov_gamma max: {g.max()}",
])
print(eigenval_min_max, gamma_min_max, g_min_max, f"gamma dtype: {gamma.dtype}", sep="\n")

Eigenval min: 1.1103568198223002e-07
Eigenval max: 0.15361294150352478
gamma min: -1.3290700912475586
gamma max: 0.9381266236305237
gamma @ inv_sqrt_cov_gamma min: nan
gamma @ inv_sqrt_cov_gamma max: nan
gamma dtype: torch.float32


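### GPT2-Medium Fix
- Compute the covariance, its eigendecomposition, and `gamma @ inv_sqrt_cov_gamma` in float64.
- Then cast `g`, `inv_sqrt_cov_gamma`, and `sqrt_cov_gamma` back to float32.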

In [5]:
# Same whitening as the float32 pass, but computed in float64 because the float32 version
# yields NaNs. Cov(gamma) is badly conditioned, so extra precision matters here.
gamma = model.lm_head.weight.detach().double()
W, d = gamma.shape  # W = rows (vocab entries), d = hidden size — see the printed g.size()
gamma_bar = torch.mean(gamma, dim=0)  # mean row of gamma
centered_gamma = gamma - gamma_bar

### compute Cov(gamma) and transform gamma to g ###
# Covariance of the centered rows, normalized by W (not W - 1); shape (d, d).
cov_gamma = centered_gamma.T @ centered_gamma / W
eigenvalues, eigenvectors = torch.linalg.eigh(cov_gamma)
# A zero/negative eigenvalue would turn 1/sqrt into inf/NaN; fail loudly instead of
# silently propagating non-finite values into g.
assert (eigenvalues > 0).all(), f"non-positive eigenvalue in Cov(gamma): {eigenvalues.min().item()}"

# Cov^{-1/2} and Cov^{1/2} as V diag(s) V^T. Scaling V's columns by broadcasting gives the
# same values as V @ torch.diag(s), without building a dense (d, d) diagonal matrix and
# running an extra d^3 matmul (and without 0 * inf turning a whole row into NaN).
inv_sqrt_cov_gamma = (eigenvectors * (1 / torch.sqrt(eigenvalues))) @ eigenvectors.T
sqrt_cov_gamma = (eigenvectors * torch.sqrt(eigenvalues)) @ eigenvectors.T

# gamma is our original head; right-multiplying by inv_sqrt_cov_gamma puts us in the causal basis.
g = gamma @ inv_sqrt_cov_gamma

# Notation (referencing the paper's equation and the accompanying presentation):
#   A = inv_sqrt_cov_gamma,  A_inv = sqrt_cov_gamma
#   g(y) = gamma(y) @ A,     l(x) = lambda(x) @ A_inv
# Because A is symmetric and A_inv = A^{-1}, l(x) @ g(y).T == lambda(x) @ gamma(y).T,
# i.e. the transform leaves the logits unchanged. Verify the inverse relationship numerically.
assert torch.allclose(
    sqrt_cov_gamma @ inv_sqrt_cov_gamma,
    torch.eye(d, dtype=sqrt_cov_gamma.dtype, device=sqrt_cov_gamma.device),
    atol=1e-6,
), "sqrt_cov_gamma and inv_sqrt_cov_gamma are not numerically inverse"
print(model.config.hidden_size)
print(g.size())


#### cast back to float32
# Only the transform results are cast back; gamma and the intermediates above stay float64.
g = g.float()
inv_sqrt_cov_gamma = inv_sqrt_cov_gamma.float()
sqrt_cov_gamma = sqrt_cov_gamma.float()
assert torch.isfinite(g).all(), "causal transform still produced non-finite values"

1024
torch.Size([50257, 1024])


In [6]:
# Sanity report after the float64 recomputation: spectrum of Cov(gamma), value range of the
# raw head gamma and of the transformed head g, and gamma's dtype. At this point g has been
# cast back to float32, while gamma is still the float64 copy.
eigenval_min_max = "\n".join([
    f"Eigenval min: {eigenvalues.min()}",
    f"Eigenval max: {eigenvalues.max()}",
])
gamma_min_max = "\n".join([
    f"gamma min: {gamma.min()}",
    f"gamma max: {gamma.max()}",
])
g_min_max = "\n".join([
    f"gamma @ inv_sqrt_cov_gamma min: {g.min()}",
    f"gamma @ inv_sqrt_cov_gamma max: {g.max()}",
])
print(eigenval_min_max, gamma_min_max, g_min_max, f"gamma dtype: {gamma.dtype}", sep="\n")

Eigenval min: 1.0977974690670904e-07
Eigenval max: 0.1536130006206978
gamma min: -1.3290700912475586
gamma max: 0.9381266236305237
gamma @ inv_sqrt_cov_gamma min: -207.949951171875
gamma @ inv_sqrt_cov_gamma max: 645.2638549804688
gamma dtype: torch.float64
