<a href="https://colab.research.google.com/github/ethangearey/nc-lora/blob/main/Experiment_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Experiment 3: LoRA and Frozen Backbone Probes
- Take a pretrained classifier (e.g., a trained ResNet).
- Fine-tune only the final layer using LoRA, varying the rank (e.g., 1–16).
- Evaluate:
  - Whether Neural Collapse (NC) geometry persists or evolves under low-rank adaptation.
  - Whether LoRA directions align with NC class-mean directions (cosine similarity, projection overlap).

### TODO
1. baseline: full fine-tuning (necessary?) (think: MPU)
2. validate data collection


In [None]:
import torch
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Mount Google Drive; the checkpoint loaded later lives under /content/drive
from google.colab import drive
drive.mount('/content/drive')

# Reuse existing imports and configurations from Experiment 1+2
import gc
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.models as models

from tqdm import tqdm
from collections import OrderedDict
from scipy.sparse.linalg import svds
from sklearn.utils.extmath import randomized_svd
from torchvision import datasets, transforms
from IPython import embed
from sklearn.decomposition import IncrementalPCA
from peft import LoraConfig, get_peft_model

# When True, the training and analysis loops stop early (after ~20 batches)
# so the notebook can be smoke-tested quickly.
debug = True

# Random seed (shared by NumPy and PyTorch)
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# CIFAR dataset parameters
im_size = 32
padded_im_size = 32
input_ch = 3
C = 10  # number of classes

# Optimization criterion
loss_name = 'CrossEntropyLoss'

# SGD hyperparameters
lr = 0.01  # verify with ablation
batch_size = 128
momentum = 0.9
weight_decay = 5e-4  # too high?

# Fine-tuning schedule: epochs per run, and the LoRA ranks to sweep
epochs = 20
ranks = [1, 2, 4, 8, 16]


Mounted at /content/drive


In [None]:
def _classifier_weight(fc):
    """Return the dense weight of the final layer, unwrapping adapter wrappers.

    Looks for a wrapped layer under ``base_layer`` or ``original`` and falls
    back to ``fc.weight`` for a plain linear layer.
    """
    for attr in ('base_layer', 'original'):
        inner = getattr(fc, attr, None)
        if inner is not None and hasattr(inner, 'weight'):
            return inner.weight
    return fc.weight


@torch.no_grad()
def analysis(graphs, model, criterion_summed, device, num_classes, loader, epoch):
    """Measure the neural-collapse metrics of ``model`` and append them to ``graphs``.

    Two passes are made over ``loader``: the first accumulates class means
    (and the summed loss), the second accumulates the within-class covariance,
    the network accuracy and the agreement with the nearest-class-centre rule.

    Relies on module-level state: ``features.value`` (penultimate activations
    stored by the forward hook), ``debug`` (early stop after ~20 batches),
    ``weight_decay`` and ``tqdm``.

    Args:
        graphs: ``Graphs`` instance; one value is appended to each metric list
            and ``graphs.mean`` is replaced by the list of class means.
        model: network exposing its final linear layer as ``model.fc``.
        criterion_summed: loss with sum reduction; only accumulated when it is
            a ``CrossEntropyLoss``.
        device: device the batches and accumulators are placed on.
        num_classes: number of classes.
        loader: data loader yielding ``(data, target)``. It should yield the
            same samples on both passes (i.e. not be shuffled).
        epoch: epoch number, used only for the progress-bar text.
    """
    model.eval()

    # Look the classifier up on the model itself instead of relying on a
    # module-level `classifier` variable.
    fc = model.fc
    feat_dim = fc.in_features

    N = [0] * num_classes
    mean = [torch.zeros(feat_dim, device=device) for _ in range(num_classes)]
    Sw = torch.zeros(feat_dim, feat_dim, device=device)
    M = None        # feat_dim x num_classes matrix of class means (set after pass 1)
    n_cov = 0       # samples actually seen in the covariance pass

    loss = 0.0
    net_correct = 0
    NCC_match_net = 0

    uses_ce = isinstance(criterion_summed, torch.nn.CrossEntropyLoss)

    for computation in ['Mean', 'Cov']:
        # Context manager closes the bar even if a batch raises.
        with tqdm(total=len(loader), position=0, leave=True) as pbar:
            for batch_idx, (data, target) in enumerate(loader, start=1):

                data, target = data.to(device), target.to(device)

                output = model(data)
                h = features.value.detach().view(data.shape[0], -1)  # B CHW

                # during calculation of class means, calculate loss
                if computation == 'Mean' and uses_ce:
                    loss += criterion_summed(output, target).item()

                for c in range(num_classes):
                    # features belonging to class c
                    idxs = (target == c).nonzero(as_tuple=True)[0]
                    if idxs.numel() == 0:  # no class-c sample in this batch
                        continue

                    h_c = h[idxs]  # n_c CHW

                    if computation == 'Mean':
                        # update class means
                        mean[c] += h_c.sum(dim=0)
                        N[c] += h_c.shape[0]
                    else:
                        # update within-class cov; z.T @ z equals the sum of
                        # per-sample outer products without materialising them
                        z = h_c - mean[c].unsqueeze(0)
                        Sw += z.T @ z
                        n_cov += h_c.shape[0]

                        # 1) network's accuracy
                        net_pred = output[idxs].argmax(dim=1)
                        net_correct += (net_pred == target[idxs]).sum().item()

                        # 2) agreement between prediction and nearest class center
                        dists = (h_c.unsqueeze(1) - M.T.unsqueeze(0)).norm(dim=2)  # n_c C
                        NCC_pred = dists.argmin(dim=1)
                        NCC_match_net += (NCC_pred == net_pred).sum().item()

                pbar.update(1)
                pbar.set_description(
                    'Analysis {}\t'
                    'Epoch: {} [{}/{} ({:.0f}%)]'.format(
                        computation,
                        epoch,
                        batch_idx,
                        len(loader),
                        100. * batch_idx / len(loader)))

                if debug and batch_idx > 20:
                    break

        if computation == 'Mean':
            for c in range(num_classes):
                mean[c] /= N[c]
            M = torch.stack(mean).T
            graphs.mean = mean
            loss /= sum(N)
        else:
            # Normalise by what this pass actually saw, so the statistics stay
            # consistent even if the two passes covered different samples.
            Sw /= n_cov

    graphs.loss.append(loss)
    graphs.accuracy.append(net_correct / n_cov)
    graphs.NCC_mismatch.append(1 - NCC_match_net / n_cov)

    # loss with weight decay
    reg_loss = loss
    for param in model.parameters():
        reg_loss += 0.5 * weight_decay * torch.sum(param ** 2).item()
    graphs.reg_loss.append(reg_loss)

    # global mean
    muG = torch.mean(M, dim=1, keepdim=True)  # CHW 1

    # between-class covariance
    M_ = M - muG
    Sb = torch.matmul(M_, M_.T) / num_classes

    # avg norm
    # NOTE(review): for a LoRA-wrapped layer this is the wrapped (base) weight
    # only; the low-rank update is not included in the W-based metrics.
    # Confirm whether the effective weight should be used instead.
    W = _classifier_weight(fc).detach()
    M_norms = torch.norm(M_, dim=0)
    W_norms = torch.norm(W.T, dim=0)

    graphs.norm_M_CoV.append((torch.std(M_norms) / torch.mean(M_norms)).item())
    graphs.norm_W_CoV.append((torch.std(W_norms) / torch.mean(W_norms)).item())

    # tr{Sw Sb^-1}
    Sw = Sw.cpu().double()
    Sw += 1e-8 * torch.eye(Sw.shape[0], dtype=Sw.dtype)  # jitter for numerical stability
    Sb = Sb.cpu().double()  # extra precision for small eigenvalues
    eigvec, eigval, _ = torch.svd_lowrank(Sb, q=num_classes - 1)
    inv_Sb = eigvec @ torch.diag(1 / eigval) @ eigvec.T
    graphs.Sw_invSb.append(torch.trace(Sw @ inv_Sb).item())

    # ||W^T - M_||
    normalized_M = M_ / torch.norm(M_, 'fro')
    normalized_W = W.T / torch.norm(W.T, 'fro')
    graphs.W_M_dist.append((torch.norm(normalized_W - normalized_M) ** 2).item())

    # mutual coherence
    def coherence(V):
        G = V.T @ V
        G += torch.ones((num_classes, num_classes), device=device) / (num_classes - 1)
        G -= torch.diag(torch.diag(G))
        return torch.norm(G, 1).item() / (num_classes * (num_classes - 1))

    graphs.cos_M.append(coherence(M_ / M_norms))
    graphs.cos_W.append(coherence(W.T / W_norms))


class Graphs:
    """Container of metric histories; every attribute starts as an empty list."""

    def __init__(self):
        series = (
            # performance
            'accuracy', 'loss', 'reg_loss',
            # NC1
            'Sw_invSb',
            # NC2
            'norm_M_CoV', 'norm_W_CoV', 'cos_M', 'cos_W',
            # NC3
            'W_M_dist',
            # NC4
            'NCC_mismatch',
            # class means, and a feature-rank slot that is not used
            'mean', 'feature_rank',
        )
        for name in series:
            # a fresh list per attribute so histories never share storage
            setattr(self, name, [])


In [None]:
def prepare_lora_model(pretrained_model, rank=4, lora_alpha=16, lora_dropout=0.05,
                       target_modules=("fc",)):
    """Wrap ``pretrained_model`` with LoRA adapters and move it to ``device``.

    Args:
        pretrained_model: model to adapt. NOTE(review): PEFT appears to inject
            the adapters into this object rather than a copy — confirm before
            reusing it elsewhere.
        rank: LoRA rank ``r``.
        lora_alpha: LoRA scaling numerator. With the default, the adapter
            scaling ``lora_alpha / r`` shrinks as ``rank`` grows, so runs at
            different ranks do not share the same update scale; pass
            ``lora_alpha=rank`` (or similar) to hold it fixed.
        lora_dropout: dropout applied on the LoRA branch.
        target_modules: names of the modules that receive adapters; the
            default adapts only the final layer ``fc``.

    Returns:
        The PEFT-wrapped model on the module-level ``device``.
    """
    config = LoraConfig(
        r=rank,
        lora_alpha=lora_alpha,
        target_modules=list(target_modules),
        lora_dropout=lora_dropout,
        bias="none",
    )
    model = get_peft_model(pretrained_model, config)
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Trainable parameters:", n_trainable)
    return model.to(device)


def compute_alignment_metrics(model, class_means, rank):
    """Calculate alignment between LoRA directions and NC class means.

    The LoRA update ``B @ A`` maps features (``in_dim``) to logits
    (``out_dim``). Its feature-space directions are therefore the *right*
    singular vectors, which live in the same space as the class means.

    Args:
        model: PEFT model whose final layer is reachable as
            ``model.base_model.model.fc`` with a ``'default'`` adapter.
        class_means: 2-D tensor of class means, either ``[in_dim, C]`` or
            ``[C, in_dim]``. The means are used as passed (not re-centred).
        rank: maximum number of LoRA directions to use.

    Returns:
        ``(mean_abs_cos, proj_score)``: the mean absolute cosine similarity
        over all (direction, class mean) pairs, and the fraction of squared
        class-mean norm that lies in the span of the LoRA directions.
        Both are ``0.0`` when the update is numerically zero.

    Raises:
        ValueError: if ``class_means`` is not 2-D with one axis of length
            ``in_dim``.
    """
    fc = model.base_model.model.fc
    lora_A = fc.lora_A['default'].weight  # [rank, in_dim]
    lora_B = fc.lora_B['default'].weight  # [out_dim, rank]
    W_lora = (lora_B @ lora_A).detach().double().cpu().numpy()  # [out_dim, in_dim]
    in_dim = W_lora.shape[1]

    M = torch.as_tensor(class_means).detach().double().cpu().numpy()
    if M.ndim != 2 or in_dim not in M.shape:
        raise ValueError(
            f"class_means must be 2-D with one axis of length {in_dim}, "
            f"got shape {M.shape}")
    if M.shape[0] != in_dim:
        M = M.T  # -> [in_dim, C]

    # Exact SVD: the matrix is tiny and the result is deterministic.
    _, s, Vt = np.linalg.svd(W_lora, full_matrices=False)
    k = min(rank, Vt.shape[0])
    # Drop directions with (numerically) zero singular value: they are
    # arbitrary and carry no information about the update.
    tol = s.max() * max(W_lora.shape) * np.finfo(float).eps if s.size else 0.0
    V = Vt[:k][s[:k] > tol].T  # [in_dim, n_dirs], orthonormal columns
    if V.shape[1] == 0:
        return 0.0, 0.0

    # Cosine similarity between every LoRA direction and every class mean.
    # Singular vectors are only defined up to sign, so use the magnitude.
    proj = V.T @ M  # [n_dirs, C]
    mean_norms = np.linalg.norm(M, axis=0)
    cos_sims = np.abs(proj) / (mean_norms + 1e-8)

    # Subspace projection score
    total_energy = np.linalg.norm(M) ** 2
    proj_score = np.linalg.norm(proj) ** 2 / total_energy if total_energy > 0 else 0.0

    return float(np.mean(cos_sims)), float(proj_score)


In [None]:
def train_lora(model, criterion, device, train_loader, optimizer, epoch, rank):
    """Train ``model`` for one epoch and show a progress bar with the batch loss.

    Args:
        model: model to train; put into training mode here.
        criterion: loss called as ``criterion(output, target)``.
        device: device each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress-bar label.
        rank: LoRA rank, used only in the progress-bar label.

    When the module-level ``debug`` flag is set, the epoch stops early
    (after batch index 21).
    """
    model.train()
    # model.fc.original.requires_grad_(False)  # redundant as peft freezes weights automatically

    # The context manager closes the bar even when a batch raises (e.g. OOM),
    # so a failed run does not leave a dangling bar in the notebook output.
    with tqdm(total=len(train_loader), desc=f'LoRA Rank {rank} Epoch {epoch}') as pbar:
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Update progress bar
            pbar.set_postfix({'Loss': loss.item()})
            pbar.update(1)

            if debug and batch_idx > 20:
                break


In [None]:
# ====================== EXECUTION ======================
# checkpoint from Experiment 1+2
# NOTE: weights_only=False unpickles arbitrary Python objects; only load
# checkpoints you created yourself. map_location lets a checkpoint saved on a
# GPU be loaded on a CPU-only runtime.
checkpoint = torch.load('/content/drive/MyDrive/checkpoint.pth',
                        map_location=device, weights_only=False)

# dataset, optimizer setup from Experiment 1+2
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                                ])

train_set = datasets.CIFAR10('../data', train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True, drop_last=True)

# Not shuffled: analysis() walks the loader twice (class means, then
# covariance) and both passes must see the same samples.
analysis_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=False, drop_last=True)

criterion = nn.CrossEntropyLoss()
criterion_summed = nn.CrossEntropyLoss(reduction='sum')


# Feature hook: stores the input of the final layer (penultimate features)
# on a module-level holder that analysis() reads.
class features:
    pass


def hook(self, input, output):
    features.value = input[0].clone()


## Run Experiment 3
original_class_means = None  # To store initial NC geometry
results = []                 # rows for every rank; used by the plotting cell
for rank in ranks:

    rank_results = []

    # Load fresh pretrained model for each run (no torchvision weights;
    # the state dict below supplies them)
    pretrained_model = models.resnet18(num_classes=C)
    pretrained_model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    pretrained_model.maxpool = nn.Identity()
    pretrained_model.fc.register_forward_hook(hook)
    pretrained_model.load_state_dict(checkpoint['model_state_dict'])
    pretrained_model.to(device)

    # Apply LoRA
    model = prepare_lora_model(pretrained_model, rank=rank)

    # Expose the current final layer under the module-level name `classifier`
    # for code that looks it up globally.
    classifier = model.fc

    # SGD, not Adam, for consistency with Exp1+2. Only parameters left
    # trainable by PEFT are handed to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable_params,
                          lr=lr,
                          momentum=momentum,
                          weight_decay=weight_decay)

    # Initial NC state, from Exp1+2
    graphs = Graphs()
    analysis(graphs, model, criterion_summed, device, C, analysis_loader, epoch=0)
    if original_class_means is None:
        original_class_means = torch.stack(graphs.mean).T

    # Experiment 3: Fine-tune pretrained model under LoRA
    for epoch in range(1, epochs + 1):
        train_lora(model, criterion, device, train_loader, optimizer, epoch, rank)
        analysis(graphs, model, criterion_summed, device, C, analysis_loader, epoch)

        # Compute alignment metrics
        current_class_means = torch.stack(graphs.mean).T
        cos_sim, proj_score = compute_alignment_metrics(model, current_class_means, rank)

        # Store the latest scalar of each metric. Storing the list itself
        # would make every row alias the same, still-growing history.
        rank_results.append({
            'rank': rank,
            'epoch': epoch,
            'Sw_invSb': graphs.Sw_invSb[-1],
            'W_M_dist': graphs.W_M_dist[-1],
            'NCC_mismatch': graphs.NCC_mismatch[-1],
            'norm_M_CoV': graphs.norm_M_CoV[-1],
            'norm_W_CoV': graphs.norm_W_CoV[-1],
            'cos_M': graphs.cos_M[-1],
            'cos_W': graphs.cos_W[-1],
            'cos_sim': cos_sim,
            'proj_score': proj_score
        })

    # Save results per rank, and keep them for cross-rank plotting
    pd.DataFrame(rank_results).to_csv(f'lora_rank_{rank}_results.csv', index=False)
    results.extend(rank_results)


100%|██████████| 170M/170M [00:04<00:00, 38.5MB/s]


Trainable parameters: 522


LoRA Rank 1 Epoch 1:   6%|▌         | 22/390 [01:10<19:41,  3.21s/it, Loss=1.28]
LoRA Rank 1 Epoch 2:   6%|▌         | 22/390 [01:10<19:39,  3.21s/it, Loss=1.18]
LoRA Rank 1 Epoch 3:   6%|▌         | 22/390 [01:13<20:24,  3.33s/it, Loss=1.19]
LoRA Rank 1 Epoch 4:   6%|▌         | 22/390 [01:10<19:46,  3.22s/it, Loss=1.35]
LoRA Rank 1 Epoch 5:   6%|▌         | 22/390 [01:11<19:49,  3.23s/it, Loss=1.33]
LoRA Rank 1 Epoch 6:   6%|▌         | 22/390 [01:11<19:54,  3.24s/it, Loss=1.35]
LoRA Rank 1 Epoch 7:   6%|▌         | 22/390 [01:11<19:58,  3.26s/it, Loss=1.33]
LoRA Rank 1 Epoch 8:   6%|▌         | 22/390 [01:10<19:43,  3.22s/it, Loss=1.41]
LoRA Rank 1 Epoch 9:   6%|▌         | 22/390 [01:10<19:44,  3.22s/it, Loss=1.39]
LoRA Rank 1 Epoch 10:   6%|▌         | 22/390 [01:11<19:54,  3.25s/it, Loss=1.48]
LoRA Rank 1 Epoch 11:   6%|▌         | 22/390 [01:10<19:31,  3.18s/it, Loss=1.3]
LoRA Rank 1 Epoch 12:   6%|▌         | 22/390 [01:14<20:39,  3.37s/it, Loss=1.27]
LoRA Rank 1 Epoch 13:   6%

Trainable parameters: 1044


LoRA Rank 2 Epoch 1:   6%|▌         | 22/390 [01:09<19:30,  3.18s/it, Loss=1.32]
LoRA Rank 2 Epoch 2:   6%|▌         | 22/390 [01:11<19:56,  3.25s/it, Loss=1.54]
LoRA Rank 2 Epoch 3:   6%|▌         | 22/390 [01:11<19:54,  3.25s/it, Loss=1.59]
LoRA Rank 2 Epoch 4:   6%|▌         | 22/390 [01:10<19:31,  3.18s/it, Loss=1.44]
LoRA Rank 2 Epoch 5:   6%|▌         | 22/390 [01:10<19:36,  3.20s/it, Loss=1.32]
LoRA Rank 2 Epoch 6:   6%|▌         | 22/390 [01:10<19:39,  3.20s/it, Loss=1.28]
LoRA Rank 2 Epoch 7:   6%|▌         | 22/390 [01:10<19:37,  3.20s/it, Loss=1.47]
LoRA Rank 2 Epoch 8:   6%|▌         | 22/390 [01:09<19:30,  3.18s/it, Loss=1.41]
LoRA Rank 2 Epoch 9:   6%|▌         | 22/390 [01:11<19:51,  3.24s/it, Loss=1.42]
LoRA Rank 2 Epoch 10:   6%|▌         | 22/390 [01:11<19:49,  3.23s/it, Loss=1.29]
LoRA Rank 2 Epoch 11:   6%|▌         | 22/390 [01:10<19:31,  3.18s/it, Loss=1.32]
LoRA Rank 2 Epoch 12:   6%|▌         | 22/390 [01:11<20:03,  3.27s/it, Loss=1.29]
LoRA Rank 2 Epoch 13:   6

Trainable parameters: 2088


LoRA Rank 4 Epoch 1:   6%|▌         | 22/390 [01:09<19:30,  3.18s/it, Loss=1.27]
LoRA Rank 4 Epoch 2:   6%|▌         | 22/390 [01:10<19:39,  3.21s/it, Loss=1.3]
LoRA Rank 4 Epoch 3:   6%|▌         | 22/390 [01:10<19:35,  3.19s/it, Loss=1.32]
LoRA Rank 4 Epoch 4:   6%|▌         | 22/390 [01:10<19:46,  3.22s/it, Loss=1.32]
LoRA Rank 4 Epoch 5:   6%|▌         | 22/390 [01:10<19:32,  3.19s/it, Loss=1.41]
LoRA Rank 4 Epoch 6:   6%|▌         | 22/390 [01:12<20:12,  3.30s/it, Loss=1.32]
LoRA Rank 4 Epoch 7:   6%|▌         | 22/390 [01:10<19:31,  3.18s/it, Loss=1.26]
LoRA Rank 4 Epoch 8:   6%|▌         | 22/390 [01:10<19:36,  3.20s/it, Loss=1.28]
LoRA Rank 4 Epoch 9:   6%|▌         | 22/390 [01:11<19:54,  3.25s/it, Loss=1.56]
LoRA Rank 4 Epoch 10:   6%|▌         | 22/390 [01:09<19:30,  3.18s/it, Loss=1.42]
LoRA Rank 4 Epoch 11:   6%|▌         | 22/390 [01:10<19:31,  3.18s/it, Loss=1.37]
LoRA Rank 4 Epoch 12:   6%|▌         | 22/390 [01:11<19:59,  3.26s/it, Loss=1.21]
LoRA Rank 4 Epoch 13:   6%

In [None]:
# Visualization
# `results` holds one row per (rank, epoch) collected by the training loop.
df = pd.DataFrame(results)

# (DataFrame column, y-axis label, figure title) for each figure.
# NC1 is recorded under the 'Sw_invSb' column; the rows have no 'NC1' column.
panels = [
    ('cos_sim', 'Mean Cosine Similarity', 'LoRA Direction Alignment with Class Means'),
    ('Sw_invSb', 'NC1 (Tr{Sw Sb^-1})', 'Neural Collapse Persistence During LoRA Tuning'),
]

for column, ylabel, title in panels:
    plt.figure(figsize=(12, 4))
    # sort=False keeps the ranks in the order they were run
    for rank, subset in df.groupby('rank', sort=False):
        plt.plot(subset['epoch'], subset[column], label=f'Rank {rank}')
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.show()

## Baseline: Full Fine-Tuning


Note: the cell below was generated with Sonnet and has not been debugged yet.

In [None]:
# Full fine-tuning baseline: run after the LoRA experiments. It reuses
# `checkpoint`, `hook`, `criterion`, the data loaders and
# `original_class_means` from the cells above.


def train_full_ft(model, criterion, device, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch with plain SGD steps and a progress bar.

    Stops early (after batch index 21) when the module-level ``debug`` flag
    is set.
    """
    model.train()
    with tqdm(total=len(train_loader), desc=f'Full FT Epoch {epoch}') as pbar:
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

            pbar.set_postfix({'Loss': loss.item()})
            pbar.update(1)

            if debug and batch_idx > 20:
                break


def compute_effective_rank(matrix, eps=1e-12):
    """Return the entropy-based effective rank of a 2-D tensor.

    The singular values are normalised to a probability distribution ``p``
    and the result is ``exp(-sum(p * log p))``: 1.0 for a rank-one matrix,
    ``k`` for ``k`` equal singular values, and 0.0 for an all-zero matrix.
    """
    s = torch.linalg.svdvals(matrix.detach().double().cpu())
    total = s.sum()
    if total <= eps:
        return 0.0
    p = s / total
    p = p[p > 0]  # 0 * log(0) is taken as 0
    return torch.exp(-(p * p.log()).sum()).item()


full_model = models.resnet18(num_classes=C)
full_model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
full_model.maxpool = nn.Identity()
full_model.load_state_dict(checkpoint['model_state_dict'])
full_model.to(device)

# Register hook
full_model.fc.register_forward_hook(hook)

# Expose the current final layer under the module-level name `classifier`
# for code that looks it up globally.
classifier = full_model.fc

# Save original weights for later comparison
original_fc_weights = full_model.fc.weight.clone().detach()

# For full fine-tuning, only train the classifier layer for fair comparison with LoRA
for param in full_model.parameters():
    param.requires_grad = False
full_model.fc.weight.requires_grad = True
full_model.fc.bias.requires_grad = True

# Use same optimizer settings as LoRA experiments
optimizer = optim.SGD(full_model.fc.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

# Train and collect same metrics
full_results = []
graphs = Graphs()
for epoch in range(1, epochs + 1):
    train_full_ft(full_model, criterion, device, train_loader, optimizer, epoch)
    analysis(graphs, full_model, criterion_summed, device, C, analysis_loader, epoch)

    # Compute weight update and its rank
    current_weights = full_model.fc.weight.clone().detach()
    weight_delta = current_weights - original_fc_weights
    weight_delta_rank = compute_effective_rank(weight_delta)

    # Current class means
    current_class_means = torch.stack(graphs.mean).T

    # Store results with same metrics as LoRA
    full_results.append({
        'method': 'full_ft',
        'epoch': epoch,
        'NC1': graphs.Sw_invSb[-1],
        'NC2': graphs.norm_M_CoV[-1],
        'NC3': graphs.W_M_dist[-1],
        'NC4': graphs.NCC_mismatch[-1],
        'weight_update_rank': weight_delta_rank,
        'class_mean_change': torch.norm(current_class_means - original_class_means).item(),
    })

pd.DataFrame(full_results).to_csv('full_ft_results.csv', index=False)