In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader


In [8]:
# Compute device: a CUDA GPU when one is usable, the CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Preprocessing applied to every MNIST image, in order:
#   1. resize to 224x224,
#   2. replicate the single grey channel into 3 channels,
#   3. convert the image to a tensor,
#   4. normalise with mean 0.5 / std 0.5 (the 1-tuples are given once and
#      applied to every channel).
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(_preprocessing_steps)

In [10]:



# Load MNIST dataset: both splits share the same root folder, preprocessing
# and download flag, so those options are written once.
_mnist_options = dict(root='./data', transform=transform, download=True)
train_dataset = datasets.MNIST(train=True, **_mnist_options)
test_dataset = datasets.MNIST(train=False, **_mnist_options)

# Mini-batches of 64: shuffled for training, fixed order for evaluation
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Load ResNet18 with randomly initialised weights (not pretrained).
# `weights=None` is the current spelling; the `pretrained=` flag is deprecated
# in torchvision and emits a warning.
model = models.resnet18(weights=None)

# Modify the final FC layer: 10 outputs instead of the stock classifier head
model.fc = nn.Linear(model.fc.in_features, 10)

# Send to device
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
def train(num_epochs=5):
    """Train the notebook-level ``model`` on ``train_loader``.

    Relies on the notebook-level ``criterion``, ``optimizer`` and ``device``.
    After every epoch the loss averaged over that epoch's batches is printed.

    Args:
        num_epochs: Number of full passes over ``train_loader``.
    """
    model.train()
    for epoch in range(num_epochs):
        loss_sum = 0.0
        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # Clear gradients from the previous step before the new backward pass
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_images), batch_labels)
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()

        mean_loss = loss_sum / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")

# Evaluation
def evaluate():
    """Print the top-1 accuracy of the notebook-level ``model`` on ``test_loader``.

    Switches the model to eval mode, runs every test batch without gradient
    tracking and prints the percentage of samples whose highest-scoring class
    equals the label.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            # Index of the largest score along the class dimension
            predicted = torch.max(model(batch_images), 1).indices
            seen += batch_labels.size(0)
            hits += (predicted == batch_labels).sum().item()
    print(f"Accuracy: {100 * hits / seen:.2f}%")

# Run five epochs of training, then report test accuracy.
# In a notebook kernel __name__ is '__main__', so this guard passes and the
# cell starts training as soon as it is executed.
if __name__ == '__main__':
    train(num_epochs=5)
    evaluate()


100%|██████████████████████████████████████| 9.91M/9.91M [00:01<00:00, 7.24MB/s]
100%|███████████████████████████████████████| 28.9k/28.9k [00:00<00:00, 259kB/s]
100%|██████████████████████████████████████| 1.65M/1.65M [00:00<00:00, 2.68MB/s]
100%|██████████████████████████████████████| 4.54k/4.54k [00:00<00:00, 6.94MB/s]


KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np

# Load and modify model: randomly initialised ResNet18 with a 10-output head.
# `weights=None` replaces the deprecated `pretrained=False` flag.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)

# Most recent output of each hooked layer, keyed by the name passed to get_activation
activations = {}

# Hook factory
def get_activation(name):
    """Build a forward hook that records a layer's output under ``name``.

    The returned callable takes ``(module, inputs, output)`` positionally, as
    used with ``register_forward_hook``. On every call it prints ``name``
    together with the output shape and stores ``output.detach()`` in the
    module-level ``activations`` dict, replacing any earlier entry for
    ``name``. It returns ``None``, so the layer output is left unchanged.
    """
    def hook(module, inputs, output):
        # Log the shape, then keep a reference that is cut off from autograd
        print(f"{name:10s} | shape: {tuple(output.shape)}")
        activations[name] = output.detach()
        # Uncomment to print some values (can be large!)
        # print(f"Sample values from {name}:\n", output[0, :2, :4, :4] if output.ndim == 4 else output[0, :10])
    return hook

# Register hooks on the direct children of the model only (conv1, bn1, relu,
# maxpool, layer1, ...); modules nested inside those children are not hooked.
for name, layer in model.named_children():
    layer.register_forward_hook(get_activation(name))

# Dummy input (mimicking MNIST resized to 3×224×224)
# One random sample, run without gradient tracking; each hook prints its layer's output shape.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(x)

print("\nFinal output shape:", output.shape)


In [13]:
import torch
from torchvision import models

# Fresh, randomly initialised ResNet18; `weights=None` replaces the deprecated `pretrained=False`
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 10)  # Replace for MNIST

# Print the module tree to inspect layer names and shapes
print(model)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [15]:
# ROF

In [19]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets, transforms, models


In [57]:
# Use the MPS backend when this PyTorch build reports it as available;
# otherwise fall back to the CPU.
_backend_name = "mps" if torch.backends.mps.is_available() else "cpu"
device = torch.device(_backend_name)
print(f"Using device: {device}")

Using device: mps


In [59]:
# Randomly initialised ResNet18 with a 10-output head
# (`weights=None` replaces the deprecated `pretrained=False` flag).
# NOTE(review): no trained weights are loaded into this instance, so every
# metric computed with it describes an untrained network. Load a trained
# state_dict here if the analysis is meant to run on a trained model.
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 10)
model.to(device)
# Inference mode for BatchNorm; as the last expression this also displays the module tree
model.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [47]:
# Load MNIST and convert to 3 channels
# Steps, in order: resize to 112x112, replicate the grey channel three times,
# convert the image to a tensor.
# NOTE(review): the transform defined earlier in this notebook resizes to
# 224x224 and applies Normalize(0.5, 0.5); this one does neither. Confirm the
# difference is intended before evaluating weights trained with the other pipeline.
_eval_steps = [
    transforms.Resize((112, 112)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
]
transform = transforms.Compose(_eval_steps)


In [75]:
# Evaluation data: the MNIST test split (train=False), read in a fixed order
full_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)
# The whole split is used as-is; enable the Subset line below for a quick run
subset = full_dataset
#subset = torch.utils.data.Subset(full_dataset, list(range(1000)))  # First 1000 samples
test_loader = DataLoader(
    subset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)


In [85]:
# Bare expression: shows the dataset's repr as the cell output
full_dataset

Dataset MNIST
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               Resize(size=(112, 112), interpolation=bilinear, max_size=None, antialias=True)
               Grayscale(num_output_channels=3)
               ToTensor()
           )

In [77]:
# Buffer filled by fc_hook: one CPU tensor per call of the hook
fc_outputs = []

def fc_hook(module, input, output):
    """Forward hook that appends ``output`` to the module-level ``fc_outputs`` list.

    The stored tensor is detached from autograd and moved to the CPU.
    ``module`` and ``input`` are accepted only because the forward-hook
    signature passes them. Returns ``None``, so the layer output is unchanged.
    """
    captured = output.detach()
    captured = captured.cpu()  # kept on the CPU for the later L1-norm computation
    fc_outputs.append(captured)

# Re-running this cell must not stack a second capture hook on model.fc
# (every batch would then be appended twice), so drop the handle left over
# from a previous run first. remove() on an already-removed handle is a no-op.
if "hook_handle" in globals():
    hook_handle.remove()
hook_handle = model.fc.register_forward_hook(fc_hook)


In [87]:
def test_baseline_accuracy(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    acc = correct / total
    print(f"Baseline accuracy (no masking): {acc*100:.2f}%")
    return acc

# Run baseline test
# NOTE(review): `model` is built a few cells earlier without loading trained
# weights, so a result near chance level (about 10% for ten classes) is
# expected here — confirm that is the intended baseline.
baseline_acc = test_baseline_accuracy(model, test_loader, device)


Baseline accuracy (no masking): 9.82%


In [79]:
# 1. Run model to collect fc activations
# Re-arm the capture hook on an empty buffer. Any forward pass made since the
# hook was first registered (for example a baseline accuracy run) has already
# appended to fc_outputs, and re-running this cell after the hook was removed
# would capture nothing new. remove() on an already-removed handle is a no-op.
hook_handle.remove()
fc_outputs.clear()
hook_handle = model.fc.register_forward_hook(fc_hook)

model.eval()  # keep BatchNorm in inference mode while collecting activations
try:
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            _ = model(images)
finally:
    # Clean up even if the pass is interrupted, so the capture hook never
    # lingers on model.fc during later evaluations
    hook_handle.remove()

# Combine outputs and compute the per-unit score: the mean absolute activation
# over all samples (the L1 norm across samples divided by N)
fc_outputs_tensor = torch.cat(fc_outputs, dim=0)  # Shape: [N, 10]
l1_norms = fc_outputs_tensor.abs().mean(dim=0)    # Shape: [10]

# Rank units by that score, largest first
unit_ranking = torch.argsort(l1_norms, descending=True)

# Print result
print("\n📊 L1-norm per unit in FC layer:")
for i, score in enumerate(l1_norms):
    print(f"Unit {i}: {score:.6f}")

print("\n🏅 Ranked units (most active to least):", unit_ranking.tolist())


📊 L1-norm per unit in FC layer:
Unit 0: 0.235651
Unit 1: 0.443512
Unit 2: 0.099427
Unit 3: 0.087424
Unit 4: 0.752889
Unit 5: 0.134741
Unit 6: 0.208930
Unit 7: 0.024645
Unit 8: 0.065303
Unit 9: 0.393814

🏅 Ranked units (most active to least): [4, 1, 9, 0, 6, 5, 2, 3, 8, 7]


In [80]:
# 2. --- ROF Evaluation: Reactivate top N FC units ---
def rof_fc_eval(unit_ranking, top_k, eval_model=None, eval_loader=None, eval_device=None):
    """Evaluate the network with only the ``top_k`` highest-ranked fc units active.

    A temporary forward hook on ``fc`` multiplies the layer output by a 0/1
    mask, so every unit outside the first ``top_k`` entries of
    ``unit_ranking`` outputs 0. Note that a masked unit is set to 0, not to
    -inf, so it can still win the argmax when all active outputs are negative.

    Args:
        unit_ranking: Sequence or 1-D tensor of fc unit indices, best first.
        top_k: Number of leading entries of ``unit_ranking`` to keep active.
            Values above the layer width keep as many as fit; values <= 0
            mask every unit.
        eval_model: Model exposing an ``fc`` submodule. Defaults to the
            notebook-level ``model``.
        eval_loader: Iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``test_loader``.
        eval_device: Device batches are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        ``(avg_loss, acc)``: the sample-weighted mean cross-entropy loss and
        the fraction of correct top-1 predictions.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    net = model if eval_model is None else eval_model
    loader = test_loader if eval_loader is None else eval_loader
    dev = device if eval_device is None else eval_device

    # Resolve the active unit indices once, as plain ints, instead of
    # re-reading the ranking element by element for every batch.
    active_units = [int(idx) for idx in unit_ranking[:max(top_k, 0)]]

    def rof_fc_mask(module, inputs, output):
        mask = torch.zeros_like(output)
        cols = active_units[:output.shape[1]]  # never more units than the layer has
        if cols:
            mask[:, cols] = 1.0
        return output * mask

    total_correct, total_loss, total_samples = 0, 0.0, 0
    criterion = nn.CrossEntropyLoss()

    net.eval()
    # Register temporary hook. The finally block guarantees it is removed even
    # when evaluation raises or is interrupted; otherwise the mask would stay
    # on fc and silently corrupt every later forward pass.
    hook = net.fc.register_forward_hook(rof_fc_mask)
    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(dev), labels.to(dev)
                outputs = net(images)
                loss = criterion(outputs, labels)

                preds = outputs.argmax(dim=1)
                total_correct += (preds == labels).sum().item()
                # criterion returns the batch mean; weight it by batch size so
                # a smaller final batch does not skew the average
                total_loss += loss.item() * labels.size(0)
                total_samples += labels.size(0)
    finally:
        hook.remove()

    acc = total_correct / total_samples
    avg_loss = total_loss / total_samples
    return avg_loss, acc

In [83]:
# 3. --- Run ROF for 1 to N active units (N = number of ranked fc units) ---
# E_n collects one (n, loss, acc) tuple per setting.
E_n = []
print("\n📈 ROF Curve (Final FC Layer on MPS):")
# Derive the sweep length from the ranking instead of hard-coding 10, so the
# loop stays correct if the fc layer width changes.
for n in range(1, len(unit_ranking) + 1):
    loss, acc = rof_fc_eval(unit_ranking, top_k=n)
    print(f"Top {n:2d} units activated | Loss: {loss:.4f} | Acc: {acc*100:.2f}%")
    E_n.append((n, loss, acc))


📈 ROF Curve (Final FC Layer on MPS):
Top  1 units activated | Loss: 2.3379 | Acc: 9.82%
Top  2 units activated | Loss: 2.3490 | Acc: 9.82%
Top  3 units activated | Loss: 2.3602 | Acc: 9.82%
Top  4 units activated | Loss: 2.3669 | Acc: 9.82%
Top  5 units activated | Loss: 2.3672 | Acc: 9.82%
Top  6 units activated | Loss: 2.3692 | Acc: 9.82%
Top  7 units activated | Loss: 2.3675 | Acc: 9.82%
Top  8 units activated | Loss: 2.3685 | Acc: 9.82%
Top  9 units activated | Loss: 2.3683 | Acc: 9.82%
Top 10 units activated | Loss: 2.3691 | Acc: 9.82%


In [71]:
def test_accuracy():
    """Return the top-1 accuracy of the notebook-level ``model`` on ``test_loader``.

    The model is switched to eval mode and run without gradient tracking.
    The result is the fraction ``correct / total``.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            preds = model(imgs).argmax(dim=1)
            seen += labels.size(0)
            hits += (preds == labels).sum().item()
    return hits / seen

# Report the unmasked accuracy as a percentage
print(f"Baseline test accuracy: {test_accuracy() * 100:.2f}%")


Baseline test accuracy: 11.00%


In [73]:
# Re-arm the capture hook on an empty buffer so this cell gives the same
# result however often, and in whatever order, it is run: earlier forward
# passes may already have appended to fc_outputs, and once the hook has been
# removed a plain re-run would capture nothing new. remove() on an
# already-removed handle is a no-op.
hook_handle.remove()
fc_outputs.clear()
hook_handle = model.fc.register_forward_hook(fc_hook)

model.eval()  # keep BatchNorm in inference mode while collecting activations
try:
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            _ = model(images)
finally:
    # Remove hook, even if the pass is interrupted
    hook_handle.remove()

# Combine all outputs: shape [N, 10]
fc_outputs_tensor = torch.cat(fc_outputs, dim=0)

# Per-unit score: mean absolute activation over all samples (mean over dim 0)
l1_norms = fc_outputs_tensor.abs().mean(dim=0)

# Sort unit indices in descending order of that score
unit_ranking = torch.argsort(l1_norms, descending=True)

# Print results
print("L1-norm per FC unit:")
for i, norm in enumerate(l1_norms):
    print(f"Unit {i}: {norm:.6f}")

print("\nRanked units (most to least active):", unit_ranking.tolist())

L1-norm per FC unit:
Unit 0: 0.228653
Unit 1: 0.425979
Unit 2: 0.093129
Unit 3: 0.081897
Unit 4: 0.719917
Unit 5: 0.128647
Unit 6: 0.196815
Unit 7: 0.023364
Unit 8: 0.061695
Unit 9: 0.382617

Ranked units (most to least active): [4, 1, 9, 0, 6, 5, 2, 3, 8, 7]
