In [1]:
import einops

from functools import partial
from itertools import product
import numpy as np
from pathlib import Path
from plotnine import (
    ggplot,
    geom_point, 
    geom_histogram, 
    geom_line,
    geom_ribbon,
    qplot, 
    coord_fixed, 
    aes, 
    facet_wrap, 
    labs,
    scale_x_log10,
    scale_y_log10
)
import polars as pl
import torch
from scipy.sparse import coo_array, csr_array
from scipy import linalg
import scipy.sparse.linalg as sparselinalg
from tokengrams import MemmapIndex, InMemoryIndex
from tqdm.notebook import tqdm
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer, HookedTransformerConfig
import zstandard as zstd


from ngram_markov.utils import create_ngrams, nanogpt_to_hooked_transformer_config, convert_nanogpt_weights
from ngram_markov.model import GPT, GPTConfig
from torch.nn.functional import softmax


import collections
from collections import defaultdict
from itertools import islice
import numpy as np
import rustworkx as rx

from typing import List, Tuple


import cola



In [2]:
torch.cuda.is_available()

True

In [3]:
tokenizer = AutoTokenizer.from_pretrained("tokenizer/tinystories512")
ts_bin_path = 'data/tinystories/train.bin'
index_path = "data/tinystories/ngrams/suffix_tree.idx"

index = MemmapIndex(ts_bin_path, index_path)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from ngram_markov.model import Block, LayerNorm

class GPTEmbeddingModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embed, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        b, t, e = x.size()
        assert e == self.config.n_embed, f"Input embedding dimension {e} doesn't match model dimension {self.config.n_embed}"
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # forward the GPT model itself
        x = self.transformer.drop(x)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

    

        return logits

def create_embedding_encoder(gpt_model):
    class EmbeddingEncoder(nn.Module):
        def __init__(self, wte, wpe):
            super().__init__()
            self.wte = wte
            self.wpe = wpe

        def forward(self, idx):
            device = idx.device
            b, t = idx.size()
            pos = torch.arange(0, t, dtype=torch.long, device=device)
            tok_emb = self.wte(idx)
            pos_emb = self.wpe(pos)
            return tok_emb + pos_emb

    return EmbeddingEncoder(gpt_model.transformer.wte, gpt_model.transformer.wpe)

def split_gpt_model(gpt_model):
    embedding_encoder = create_embedding_encoder(gpt_model)
    
    config = gpt_model.config
    embedding_model = GPTEmbeddingModel(config)
    
    # Copy weights from gpt_model to embedding_model
    embedding_model_dict = embedding_model.state_dict()
    gpt_model_dict = gpt_model.state_dict()
    
    for name, param in gpt_model_dict.items():
        if name in embedding_model_dict:
            embedding_model_dict[name].copy_(param)
    
    embedding_model.load_state_dict(embedding_model_dict)
    
    return embedding_encoder, embedding_model

# Usage example:
# gpt_model = GPT(config)  # Your original GPT model
# embedding_encoder, embedding_model = split_gpt_model(gpt_model)
#
# # To use:
# indices = torch.randint(0, config.vocab_size, (batch_size, seq_length))
# embeddings = embedding_encoder(indices)
# logits, loss = embedding_model(embeddings, targets)

In [5]:
#tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
#MAX_TOKEN = 50276 

#data_path = '/media/External01/pile-resharded/document-00000-of-00020.bin'
#index_path = '/media/External01/pile-shard-suffix-arrays/suffix_array0.idx'

#index = MemmapIndex(data_path, index_path)


In [6]:
epoch = 200
model_path = Path('/media/External01/ngram-checkpoints/4layer_tinystories')

ckpt = torch.load(model_path / f'ckpt{epoch}.pt', map_location='cpu')

In [7]:
from ngram_markov.model import GPT, GPTConfig
from ngram_markov.utils import create_ngrams, nanogpt_to_hooked_transformer_config, convert_nanogpt_weights
from torch.nn.functional import softmax


#tl_weights = convert_nanogpt_weights(ckpt['model'], config)
#tl_model = HookedTransformer(config)
#tl_model.load_state_dict(tl_weights)

gpt_config = GPTConfig(**ckpt['model_args'])

gpt_model = GPT(gpt_config)
gpt_model.load_state_dict(ckpt['model'])





number of parameters: 12.85M


RuntimeError: Error(s) in loading state_dict for GPT:
	Missing key(s) in state_dict: "transformer.h.0.attn.bias", "transformer.h.1.attn.bias", "transformer.h.2.attn.bias", "transformer.h.3.attn.bias". 

In [8]:
encoder, embedding_model = split_gpt_model(gpt_model)

In [9]:
encoder

EmbeddingEncoder(
  (wte): Embedding(512, 512)
  (wpe): Embedding(1024, 512)
)

In [10]:
from ngram_markov.ngrams import create_ngrams, calculate_ngram_kl_divergence


data_dir = Path('data/tinystories')
n = 2
batch_size = 8
block_size = 1024
device_type = 'cuda'
device = 'cuda'

def get_batch(split):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(data_dir / 'train.bin', dtype=np.uint16, mode='r')
    else:
        data = np.memmap(data_dir / 'validation.bin', dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y


batch, _ = get_batch('train')
#ngrams = create_ngrams(batch.cpu(), n-1)

In [11]:
embedding_model.to('cuda')

GPTEmbeddingModel(
  (transformer): ModuleDict(
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-3): 4 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=512, out_features=1536, bias=False)
          (c_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=512, out_features=2048, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=2048, out_features=512, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=512, out_features=512, bias=False)
)

In [12]:
import cola
from cola.ops import Hessian
from torch.nn.functional import softmax, log_softmax

def fisher_information(logits_fn, theta):
    probs = softmax(logits_fn(theta), dim=-1)

    def entropy(theta):
        log_probs = log_softmax(logits_fn(theta), dim=-1)
        return -torch.sum(probs * log_probs, dim=-1).mean()

    return cola.PSD(Hessian(entropy, theta))
    

In [18]:
encoder.to('cuda')

EmbeddingEncoder(
  (wte): Embedding(512, 512)
  (wpe): Embedding(1024, 512)
)

In [21]:
embedding_model(encoder(batch[:1].to('cuda')).to('cuda')).shape

torch.Size([1, 1024, 512])

In [25]:
tokens = encoder(batch[:1].to('cuda')).to('cuda')[:, :10, :]
tokens

tensor([[[ 0.0448, -0.0171, -0.0329,  ..., -0.0176, -0.0195,  0.0132],
         [ 0.0198,  0.0429, -0.0390,  ...,  0.0104,  0.0404,  0.0054],
         [ 0.0003, -0.0088, -0.0053,  ..., -0.0449,  0.0216, -0.0007],
         ...,
         [ 0.0292,  0.0157, -0.0411,  ...,  0.0057,  0.0302, -0.0180],
         [ 0.0107,  0.0262,  0.0185,  ...,  0.0300,  0.0748,  0.0060],
         [ 0.0091, -0.0076, -0.0149,  ..., -0.0279,  0.0307, -0.0038]]],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [27]:
def entropy(theta):
    logits = embedding_model(theta)[:, -1, :].squeeze()
    log_probs = log_softmax(logits, dim=-1)
    probs = softmax(logits, dim=-1)
    return -1.0 * torch.sum(probs * log_probs, dim=-1).mean()


fim = torch.func.hessian(entropy)

fim(tokens)

NotImplementedError: Trying to use forward AD with _scaled_dot_product_efficient_attention that does not support it because it has not been implemented yet.
Please file an issue to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml so that we can prioritize its implementation.
Note that forward AD support for some operators require PyTorch to be built with TorchScript and for JIT to be enabled. If the environment var PYTORCH_JIT=0 is set or if the library is not built with TorchScript, some operators may no longer be used with forward AD.

In [43]:
import torch
import torch.nn.functional as F

def entropy(theta):
    logits = embedding_model(theta)
    log_probs = F.log_softmax(logits, dim=-1)
    probs = F.softmax(logits, dim=-1)
    return -1.0 * torch.sum(probs * log_probs, dim=-1).mean()

def fim_estimator_jvp_vjp(func, inputs, num_samples=100):
    input_size = inputs.numel()
    device = inputs.device

    def hvp(v):
        def grad_func(x):
            return torch.func.grad(func)(x)
        
        _, jvp_out = torch.func.jvp(grad_func, (inputs,), (v,))
        return jvp_out

    fim_diag = torch.zeros(input_size, device=device)
    
    for _ in range(num_samples):
        v = torch.randn(input_size, device=device)
        v = v / torch.norm(v)
        v = v.view_as(inputs)
        
        hvp_result = hvp(v)
        fim_diag += (hvp_result * v).view(-1)
    
    fim_diag /= num_samples
    return fim_diag.view_as(inputs)



def fim_estimator_jacrev(func, inputs, num_samples=100):
    input_size = inputs.numel()
    device = inputs.device

    def sample_and_contract(v):
        jacrev_out = torch.func.jacrev(func)(inputs)
        return torch.sum(jacrev_out * v)

    fim_diag = torch.zeros(input_size, device=device)
    
    for _ in range(num_samples):
        v = torch.randn(input_size, device=device)
        v = v / torch.norm(v)
        v = v.view_as(inputs)
        
        _, vjp_out = torch.func.vjp(sample_and_contract, v)
        fim_diag += (vjp_out[0] * v).view(-1)
    
    fim_diag /= num_samples
    return fim_diag.view_as(inputs)

# Usage
tokens = batch[:1]  # Your input tokens
embeddings = encoder(tokens)

# Using JVP and VJP
fim_diag = fim_estimator_jvp_vjp(entropy, embeddings)

# Using jacrev
#fim_diag_jacrev = fim_estimator_jacrev(entropy, embeddings)

print("Estimated diagonal of Fisher Information Matrix (JVP/VJP):", fim_diag_jvp_vjp)
#print("Estimated diagonal of Fisher Information Matrix (jacrev):", fim_diag_jacrev)

# Visualization
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 6))

im1 = ax1.imshow(fim_diag_jvp_vjp.detach().cpu().numpy().mean(axis=0), cmap='viridis')
ax1.set_title("Mean FIM diagonal (JVP/VJP)")
ax1.set_xlabel("Embedding dimension")
ax1.set_ylabel("Sequence position")
plt.colorbar(im1, ax=ax1)

im2 = ax2.imshow(fim_diag_jacrev.detach().cpu().numpy().mean(axis=0), cmap='viridis')
ax2.set_title("Mean FIM diagonal (jacrev)")
ax2.set_xlabel("Embedding dimension")
ax2.set_ylabel("Sequence position")
plt.colorbar(im2, ax=ax2)

plt.tight_layout()
plt.show()

OutOfMemoryError: CUDA out of memory. Tried to allocate 4096.00 GiB. GPU 