In [43]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
from torch.nn import CrossEntropyLoss
import cola
from cola import Auto

# Initialize the tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Tokenize the input text
text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors='pt')
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask

print(f"input_ids shape: {input_ids.shape}")
print(f"attention_mask shape: {attention_mask.shape}")

# Target labels for language modeling (shifted input_ids)
labels = input_ids.clone()
labels[:, :-1] = input_ids[:, 1:].clone()
labels[:, -1] = -100  # Mask the last token for the loss calculation

print(f"labels shape: {labels.shape}")

# Flatten the parameters
def flatten_params(params):
    shapes = [p.shape for p in params]
    flat_params = torch.cat([p.flatten() for p in params])
#     print(f"Shapes: {shapes}")
    print(f"Flat parameters length: {flat_params.numel()}")
    return flat_params, shapes

def unflatten_params(flat_params, shapes):
    params = []
    i = 0
    for shape in shapes:
        size = torch.prod(torch.tensor(shape)).item()
        param = flat_params[i:i + size]
        param = param.view(shape)
        params.append(param)
        i += size
    return params

flat_p, shape = flatten_params(list(model.parameters()))
flat_p = flat_p.detach().requires_grad_(True)

print(f"flat_p shape: {flat_p.shape}")

# Define the stateless model call
def stateless_model(flat_params, args):
    params = unflatten_params(flat_params, shape)
    param_dict = {name: param for name, param in zip(dict(model.named_parameters()).keys(), params)}
    return torch.func.functional_call(model, param_dict, args)

# Define flat_fn
def flat_fn(p):
    print(p.shape, "inside flat_fn p")
    after = p.view(p.size(-1))  # Properly reshape the logits
    print(f"logits shape (after reshape): {after.shape}")
    return after

# Criterion and GN function
criterion = CrossEntropyLoss(ignore_index=-100)

def GN(flat_fn, criterion, flat_p, labels):
    """Gauss-Newton approximation using J^T * J"""
    J = cola.ops.Jacobian(flat_fn, flat_p)
    print(f"Jacobian shape: {J.shape}")
    reshaped_logits = flat_fn(flat_p)
    print(f"reshaped_logits shape: {reshaped_logits.shape}")
    reshaped_labels = labels.view(-1)
    print(f"reshaped_labels shape: {reshaped_labels.shape}")

    def loss_fn(p):
        logits = flat_fn(p)
        return criterion(logits, reshaped_labels)
    
    # Compute J^T J
    G = J.T @ J
    print(f"Gauss-Newton matrix shape: {G.shape}")
    return G

input_ids shape: torch.Size([1, 6])
attention_mask shape: torch.Size([1, 6])
labels shape: torch.Size([1, 6])
Flat parameters length: 124439808
flat_p shape: torch.Size([124439808])


In [44]:
# Compute the Gauss-Newton approximation
G = GN(flat_fn, criterion, flat_p, labels)

torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
Jacobian shape: (124439808, 124439808)
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
reshaped_logits shape: torch.Size([124439808])
reshaped_labels shape: torch.Size([6])
Gauss-Newton matrix shape: (124439808, 124439808)


In [48]:
l = torch.randn(G.shape[1], 1, device='cpu')
print(l.shape, G.shape)
try:
    result = G @ l
except AssertionError as e:
    print(f"Assertion Error: {e}")

torch.Size([124439808, 1]) (124439808, 124439808)
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])


In [49]:
def power_iteration(G, num_simulations=100):
    """Power iteration method for finding the largest eigenvalue."""
    b_k = torch.randn(G.shape[1], 1, device=G.device)

    for _ in range(num_simulations):
        # Calculate the matrix-by-vector product Ab
        try:
            b_k1 = G @ b_k
            print("success")
        except AssertionError as e:
            print(f"Assertion Error in num_simulation: {e}")
            return None
        # Re-normalize the vector
        b_k1_norm = torch.norm(b_k1)
        b_k = b_k1 / b_k1_norm

    # Rayleigh quotient
    try:
        max_eigval = torch.matmul(b_k.T, G @ b_k) / torch.matmul(b_k.T, b_k)
        return max_eigval.item()
    except AssertionError as e:
        print(f"Assertion Error in Rayleigh quotient: {e}")
        return None

# Compute the largest eigenvalue using power iteration
max_eigval = power_iteration(G)
if max_eigval is not None:
    print(f"Largest eigenvalue: {max_eigval}")
else:
    print("Failed to compute the largest eigenvalue due to dimension mismatch.")


torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
success
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
success
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
success
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
success
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])
success
torch.Size

In [60]:
def flat_loss(p):
    logits = flat_fn(p)
    logits = logits.view(-1, logits.size(-1))  # Reshape logits to match the labels
    reshaped_labels = labels.view(-1)
    print(logits.shape, reshaped_labels.shape)
    return criterion(logits, reshaped_labels)


In [55]:
p = flat_p.clone().detach().requires_grad_(True)

gnh_losses = []
for epoch in range(1):
    with torch.no_grad():  # Don't pay extra memory for recording the computation graph
        g = torch.autograd.grad(flat_loss(p), p)[0]
        G_inv = cola.inv(GN(flat_fn, criterion, p, labels), alg=Auto(tol=1e-3, max_iters=20))
        p -= G_inv @ g
        loss = flat_loss(p)
        gnh_losses.append(loss.item())
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item()}')

# Update the model parameters with the final p
updated_params = unflatten_params(p, shape)
with torch.no_grad():
    for param, updated_param in zip(model.parameters(), updated_params):
        param.copy_(updated_param)


torch.Size([124439808]) inside flat_fn p
logits shape (after reshape): torch.Size([124439808])


ValueError: Expected input batch_size (1) to match target batch_size (6).