In [9]:
pip install torchvision==0.16.2

Collecting torchvision==0.16.2
  Downloading torchvision-0.16.2-cp310-cp310-manylinux1_x86_64.whl.metadata (6.6 kB)
Downloading torchvision-0.16.2-cp310-cp310-manylinux1_x86_64.whl (6.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m53.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torchvision
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.17.1
    Uninstalling torchvision-0.17.1:
      Successfully uninstalled torchvision-0.17.1
Successfully installed torchvision-0.16.2
[0mNote: you may need to restart the kernel to use updated packages.


In [6]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import GPT2Config
from torch.optim import SGD
import torch
# from torch.utils.tensorboard import SummaryWriter
import time
import gpytorch
import gc
import os
import torch.nn as nn
from datasets import load_dataset
from matplotlib import pyplot as plt

In [7]:
# Load the tokenizer and model
model_name = "distilgpt2"  # Using a smaller model for demonstration
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)



In [8]:
ds = load_dataset("wikipedia", "20220301.simple")

subsample_size = int(0.001 * len(ds['train']))
subsample = ds['train'].shuffle(seed=42).select(range(subsample_size))

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128)

tokenized_docs = subsample.map(tokenize_function, batched=True)

from torch.utils.data import DataLoader

def select_model_inputs(batch):
    return {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"]
    }

model_inputs = tokenized_docs.map(select_model_inputs, batched=True)

# Manually collate a batch
def manual_collate_fn(batch):
    input_ids = [item['input_ids'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    return {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'attention_mask': torch.tensor(attention_mask, dtype=torch.long)
    }

# Create DataLoader
data_loader = DataLoader(model_inputs, batch_size=1, collate_fn=manual_collate_fn)
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")


In [13]:
class CurvVecProduct:
    def __init__(self, inputs, model, criterion, labels, init_vec=None):
        self.inputs = inputs
        self.model = model
        self.criterion = criterion
        self.labels = labels
        self.init_vec = init_vec
        self.iters = 0

    def __call__(self, vector):
        if self.iters == 0 and self.init_vec is not None:
            vector = self.init_vec
        output = hess_vec(vector, self.inputs, self.model, self.criterion, self.labels, cuda=device.type == 'cuda')
        self.iters += 1
        return output.unsqueeze(1)

# Define the Hessian-vector product function
def hess_vec(vector, inputs, model, criterion, labels, cuda=True):
    param_list = list(model.parameters())
    vector_list = []

    offset = 0
    for param in param_list:
        vector_list.append(vector[offset:offset + param.numel()].detach().view_as(param).to(param.device))
        offset += param.numel()

    model.eval()
    model.zero_grad()
    outputs = model(inputs, labels=inputs)
    loss = outputs.loss
    # loss = criterion(outputs, labels)
    #loss = loss.mean()

    grad_list = torch.autograd.grad(loss, param_list, create_graph=True)
    dL_dvec = torch.zeros(1, device='cuda' if cuda else 'cpu')
    for v, g in zip(vector_list, grad_list):
        dL_dvec += torch.sum(v * g)
    dL_dvec.backward()

    return torch.cat([param.grad.view(-1) for param in param_list]).view(-1)


In [None]:
# Set the model to training mode
momentum_buffers = {}

model.train()
losses = []
# Check if MPS is available
learning_rate = 1e-2
lanczos_iters = 20
delta = 1e-3
momentum = 0
weight_decay = 0
smoothing = 0.7
regularity = 1
criterion = nn.CrossEntropyLoss()
model.to(device)

# Number of training epochs
epochs = 1

# Training loop
for batch_idx, batch in enumerate(data_loader):
    # Prepare inputs and labels
    inputs = batch["input_ids"].to(device)

    # Forward pass
    outputs = model(inputs, labels=inputs)
    loss = outputs.loss
    gradients = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    grad_vector = torch.cat([grad.view(-1) for grad in gradients])

    if (batch_idx+1) % regularity == 0 or batch_idx == 0:

        # Curvature vector product and Lanczos tridiagonalization
        print("before curvvec")
        productor = CurvVecProduct(inputs, model, criterion, inputs, init_vec=grad_vector)
        P = sum(p.numel() for p in model.parameters())
        Q, T = gpytorch.utils.lanczos.lanczos_tridiag(
            productor,
            max_iter=lanczos_iters,
            dtype=torch.float32,
            device=device,
            matrix_shape=(P, P)
        )
        print("after curvvec")
        # Compute eigenvalues and eigenvectors
        c_eigvals, c_eigvects = torch.linalg.eigh(T)
        c_gammas = c_eigvects[0, :] ** 2
        c_V = c_eigvects.t() @ Q.t()

        if batch_idx == 0:
            eigvals, V = c_eigvals[:], c_V[:]
        elif not torch.isnan(c_eigvals).any():
            eigvals = smoothing*eigvals + (1-smoothing)*c_eigvals[:]
            V = smoothing*V+(1-smoothing)*c_V[:]
        else:
            print("failure to calculate curvature")

    adjusted_grad_vector = grad_vector.clone()
    new_grad = torch.zeros_like(grad_vector).to(device)
    
    for i, eigval in enumerate(eigvals):
        intermediate_vec = V[i].to(device)
        dot_product = torch.dot(grad_vector, intermediate_vec)
        adjustment = (1 / (abs(eigval) + delta)) * dot_product * intermediate_vec
        new_grad += adjustment
    
    split_sizes = [p.numel() for p in model.parameters()]
    split_gradients = torch.split(new_grad, split_sizes)
    adjusted_gradients = [g.view(p.size()) for g, p in zip(split_gradients, model.parameters())]

    # Update parameters with momentum and weight decay
    with torch.no_grad():
        for param, adj_grad in zip(model.parameters(), adjusted_gradients):
            # Calculate weight decay term if applicable
            weight_decay_term = weight_decay * param.data if weight_decay != 0 else 0
            adjusted_grad_with_weight_decay = adj_grad + weight_decay_term

            # Update momentum buffer
            if param in momentum_buffers:
                momentum_buffers[param] = momentum_buffers[param] * momentum + adjusted_grad_with_weight_decay
            else:
                momentum_buffers[param] = adjusted_grad_with_weight_decay

            # Apply the update to parameters
            param.data -= learning_rate * momentum_buffers[param]

            # Optionally set param.grad for potential further gradient manipulations
            param.grad = momentum_buffers[param]


    losses.append(loss.item())
    print(f"Epoch: {1}, Loss: {loss.item()}")


before curvvec
after curvvec
Epoch: 1, Loss: 3.002678394317627
before curvvec
after curvvec
Epoch: 1, Loss: 8.563395500183105
before curvvec
after curvvec
Epoch: 1, Loss: 2.373828887939453
before curvvec
after curvvec
Epoch: 1, Loss: 2.3428092002868652
before curvvec
after curvvec
Epoch: 1, Loss: 3.6841518878936768
before curvvec
after curvvec
Epoch: 1, Loss: 8.902929306030273
before curvvec
after curvvec
Epoch: 1, Loss: 3.975749969482422
before curvvec
after curvvec
Epoch: 1, Loss: 3.765594720840454
before curvvec
after curvvec
Epoch: 1, Loss: 5.644918918609619
before curvvec
