In [None]:
# ------------------------------
# 1️⃣ Imports and device setup
# ------------------------------
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import cv2
import matplotlib.pyplot as plt

SEED = 123

# Seed every RNG the notebook relies on: Python, NumPy and PyTorch (CPU + all GPUs).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

# Prefer reproducible cuDNN kernels over autotuned (potentially faster) ones.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

In [None]:
# ------------------------------
# 2️⃣ Parameters
# ------------------------------
# Training schedule
N_EDGE_EPOCHS = 4         # leading epochs that train on Sobel edge maps instead of raw RGB
EPOCHS = 10               # total training epochs per experiment

# Optimisation
BATCH_SIZE = 64
LEARNING_RATE = 1e-3      # Adam learning rate


In [None]:
# ------------------------------
# 3️⃣ Edge transform
# ------------------------------
class EdgeTransform:
    """Turn an image into a 3-channel Sobel edge-magnitude tensor in [0, 1].

    Steps:
    1. An RGB image is converted to grayscale; 2-D input is already grayscale and is used as-is.
    2. The horizontal and vertical 3x3 Sobel derivatives are combined into the gradient magnitude.
    3. The magnitude is scaled by its per-image maximum.
    4. The single edge channel is repeated three times so it can feed a 3-channel network.
    """

    def __call__(self, img):
        """Return the edge map of ``img`` as a float32 tensor of shape (3, H, W).

        Args:
            img: an H x W x 3 RGB image (PIL image or array-like, uint8 or float)
                or an H x W grayscale image. The output is divided by its maximum,
                so the absolute value scale of ``img`` does not matter.

        Returns:
            ``torch.float32`` tensor of shape (3, H, W) with values in [0, 1].
            It is all zeros for a uniform image, which has no gradients.
        """
        img_np = np.asarray(img)
        # cv2.cvtColor accepts 8U/16U/32F input but not float64; downcast so
        # float64 arrays work too.
        if img_np.dtype == np.float64:
            img_np = img_np.astype(np.float32)

        if img_np.ndim == 2:
            gray = img_np  # already single-channel
        else:
            gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)

        grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        edges = np.sqrt(grad_x**2 + grad_y**2)

        peak = edges.max()
        if peak > 0:
            edges = np.clip(edges / peak, 0, 1)
        else:
            # Uniform image: dividing by a zero maximum would produce NaNs that
            # propagate into the loss, so return an all-zero edge map instead.
            edges = np.zeros_like(edges)

        edges = np.stack([edges] * 3, axis=0)
        return torch.tensor(edges, dtype=torch.float32)

edge_transform = EdgeTransform()

In [None]:
# ------------------------------
# 4️⃣ Data loading with subset
# ------------------------------
from torch.utils.data import Subset, DataLoader

def load_data(batch_size=BATCH_SIZE, subset_percent=None, root='./data'):
    """Build CIFAR-10 train/test DataLoaders, optionally on a random training subset.

    Images are only converted with ``transforms.ToTensor()``; no normalization or
    augmentation is applied.

    Args:
        batch_size: batch size for both loaders.
        subset_percent: *fraction* (not percentage) of the training set to keep,
            in (0, 1]. For example, 0.1 keeps a random 10%; ``None`` keeps everything.
            Subset indices come from ``torch.randperm``, i.e. the global torch RNG.
        root: directory where CIFAR-10 is stored / downloaded.

    Returns:
        ``(trainloader, testloader)``. The train loader shuffles every epoch.
        The test loader keeps dataset order and is never subsampled.

    Raises:
        ValueError: if ``subset_percent`` is outside (0, 1] or selects no samples.
    """
    # Validate before touching disk/network. A value such as 10 (meaning 10%)
    # should fail fast instead of silently training on the full set.
    if subset_percent is not None and not 0 < subset_percent <= 1:
        raise ValueError(
            f"subset_percent must be a fraction in (0, 1], got {subset_percent!r}")

    transform_full = transforms.Compose([transforms.ToTensor()])

    trainset_full = torchvision.datasets.CIFAR10(root=root, train=True,
                                                 download=True, transform=transform_full)
    testset = torchvision.datasets.CIFAR10(root=root, train=False,
                                           download=True, transform=transform_full)

    if subset_percent is not None:
        subset_size = int(len(trainset_full) * subset_percent)
        if subset_size == 0:
            raise ValueError(
                f"subset_percent={subset_percent!r} selects no training samples")
        # .tolist() gives Subset plain Python ints rather than 0-d tensors.
        indices = torch.randperm(len(trainset_full))[:subset_size].tolist()
        trainset = Subset(trainset_full, indices)
    else:
        trainset = trainset_full

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    return trainloader, testloader

In [None]:
# Use the full training set: subset_percent is a fraction, so 1 keeps all
# training images (e.g. subset_percent=0.1 would keep a random 10% subset).
trainloader, testloader = load_data(subset_percent = 1)

In [None]:
# ------------------------------
# 5️⃣ ResNet-18 adapted for 32x32 CIFAR images
# ------------------------------
def resnet18_cifar(output_classes = 100):
    """Return an untrained torchvision ResNet-18 adapted to 32x32 CIFAR-style inputs.

    Changes from the stock model:
    - The ImageNet stem convolution is replaced by a 3x3, stride-1, padding-1
      convolution.
    - The stem max-pool is replaced by ``nn.Identity()``, so small images are not
      down-sampled as aggressively.
    - The classifier head is resized to ``output_classes`` outputs. The default of
      100 matches CIFAR-100; this notebook passes 10 for CIFAR-10.
    """
    net = models.resnet18(weights=None)  # random init, no pretrained weights

    stem_conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.conv1 = stem_conv
    net.maxpool = nn.Identity()

    head_in_features = net.fc.in_features
    net.fc = nn.Linear(head_in_features, output_classes)
    return net

In [None]:
# ------------------------------
# 6️⃣ Training function
# ------------------------------
def train_one_epoch(model, loader, optimizer, criterion, epoch_num, n_edge_epochs=N_EDGE_EPOCHS):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    While ``epoch_num < n_edge_epochs``, each image in a batch is replaced by its
    Sobel edge map (module-level ``edge_transform``) before the forward pass.
    From epoch ``n_edge_epochs`` onward, the raw images are used unchanged.

    Args:
        model: network to train; it is switched to train mode here.
        loader: iterable of ``(images, labels)`` batches. Images are assumed to be
            CHW float tensors in [0, 1], as produced by ``transforms.ToTensor()``;
            the edge path relies on this layout.
        optimizer: optimizer that steps ``model``'s parameters.
        criterion: loss function called as ``criterion(outputs, labels)``.
        epoch_num: 0-based index of the current epoch.
        n_edge_epochs: number of leading epochs trained on edge maps
            (0 disables the edge phase).

    Returns:
        Tuple of the sample-weighted mean training loss and the training accuracy
        (fraction of correct top-1 predictions) over the epoch.
    """
    model.train()
    use_edges = epoch_num < n_edge_epochs  # constant for the whole epoch
    running_loss = 0.0
    correct, total = 0, 0
    for images, labels in loader:
        if use_edges:
            # Build edge maps from the CPU batch *before* moving it to DEVICE.
            # edge_transform works on NumPy arrays, so transforming after the
            # transfer would copy every image back from the accelerator one at a time.
            images = torch.stack([
                edge_transform(img.cpu().permute(1, 2, 0).numpy() * 255)
                for img in images
            ])
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assumes criterion returns a batch mean (CrossEntropyLoss default), so
        # re-weight by batch size to get a per-sample average over the epoch.
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    return running_loss / total, correct / total

In [None]:
# ------------------------------
# 7️⃣ Validation function
# ------------------------------
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The model is put in eval mode. Images are fed as-is (no edge transform).

    Returns:
        ``(mean_loss, accuracy)``. The loss is weighted by batch size; accuracy is
        the fraction of correct top-1 predictions.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            loss_sum += batch_loss.item() * batch_x.size(0)

            top1 = logits.max(1)[1]
            n_correct += top1.eq(batch_y).sum().item()
            n_seen += batch_y.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

In [None]:
def freeze_high_layers(model, layers=("layer3", "layer4")):
    """Freeze the parameters of the given top-level submodules; make all others trainable.

    By default ResNet ``layer3`` and ``layer4`` are frozen. Everything else,
    including ``fc``, is explicitly set trainable.

    Note: this only sets ``requires_grad``. BatchNorm layers inside frozen blocks
    still update their running mean/var whenever the model is in train mode.

    Args:
        model: module whose parameters are (un)frozen in place.
        layers: names of top-level child modules to freeze.
    """
    frozen = set(layers)
    for name, param in model.named_parameters():
        # Names are dotted paths like "layer3.0.conv1.weight". Compare the
        # top-level module name exactly rather than with a substring test, so a
        # module named e.g. "sublayer3" is not frozen by accident.
        param.requires_grad = name.split(".", 1)[0] not in frozen


def unfreeze_all_layers(model):
    """Mark every parameter of ``model`` as trainable again (in place); returns None."""
    for _, weight in model.named_parameters():
        weight.requires_grad = True

In [None]:
def mean_abs_activation(model, loader, device=None):
    """
    Print and return the mean absolute activation after each ResNet stage for one batch.

    The first batch of ``loader`` is pushed through the model stage by stage, in
    the same order as torchvision's ResNet forward pass:
    conv1 -> bn1 -> relu -> maxpool -> layer1..layer4 -> avgpool -> flatten -> fc.
    ``output.abs().mean()`` is recorded after conv1, each of layer1-layer4,
    avgpool and fc.

    Args:
        model: ResNet-style module exposing conv1, bn1, relu, maxpool,
            layer1-layer4, avgpool and fc (e.g. the result of resnet18_cifar).
        loader: iterable of ``(inputs, labels)`` batches; only the first batch is used.
            NOTE: iterating a shuffled DataLoader may draw from the global torch
            RNG, which can change later shuffling.
        device: device to move the batch to. ``None`` (default) uses the device of
            the model's first parameter, or "cpu" for a parameterless model.

    Returns:
        dict mapping stage name -> mean absolute activation (Python float).
        The model's original train/eval mode is restored before returning.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    x, _ = next(iter(loader))
    x = x.to(device)

    activations = {}
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            out = model.conv1(x)
            activations["conv1"] = out.abs().mean().item()

            # bn1/relu/maxpool must run here to match the real forward pass;
            # otherwise layer1 would see an un-normalized tensor and every later
            # statistic would be misleading.
            out = model.maxpool(model.relu(model.bn1(out)))

            for stage in ("layer1", "layer2", "layer3", "layer4"):
                out = getattr(model, stage)(out)
                activations[stage] = out.abs().mean().item()

            out = model.avgpool(out)
            activations["avgpool"] = out.abs().mean().item()

            out = model.fc(torch.flatten(out, 1))
            activations["fc"] = out.abs().mean().item()
    finally:
        # Do not leave a model that was mid-training stuck in eval mode.
        model.train(was_training)

    print(activations)
    return activations


In [None]:
# ------------------------------
# 8️⃣ Training loop (piecewise)
# ------------------------------
def run_experiment(model, criterion, n_edge_epochs=0, freeze_high=False,
                   epochs=EPOCHS, lr=LEARNING_RATE):
    """Train and validate ``model`` for ``epochs`` epochs with Adam.

    Args:
        model: network already placed on DEVICE.
        criterion: loss function shared by training and validation.
        n_edge_epochs: leading epochs trained on edge maps (0 = plain RGB baseline).
        freeze_high: if True, layer3/layer4 are frozen for the first
            ``n_edge_epochs`` epochs. They are then unfrozen and a fresh Adam
            optimizer is created.
        epochs: total number of epochs.
        lr: Adam learning rate.

    Returns:
        ``(history, best, optimizer)``:
        - ``history`` has per-epoch lists under "train_loss", "val_loss",
          "train_acc" and "val_acc".
        - ``best`` holds "val_loss", "val_acc" and "train_acc" for the epoch
          with the lowest validation loss.
        - ``optimizer`` is the last optimizer used.
    """
    if freeze_high:
        freeze_high_layers(model)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Each run gets its own history so the two experiments are never mixed.
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best = {"val_loss": float("inf"), "val_acc": float("nan"), "train_acc": float("nan")}

    for epoch in range(epochs):
        if freeze_high and epoch == n_edge_epochs:
            # Edge phase over: train the whole network with a fresh optimizer.
            # Adam state accumulated during the frozen phase is discarded.
            unfreeze_all_layers(model)
            optimizer = optim.Adam(model.parameters(), lr=lr)

        train_loss, train_acc = train_one_epoch(model, trainloader, optimizer, criterion,
                                                epoch, n_edge_epochs)
        val_loss, val_acc = validate(model, testloader, criterion)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
        # Pass the device explicitly so this also works on CPU-only machines.
        mean_abs_activation(model, trainloader, device=DEVICE)

        if best["val_loss"] > val_loss:
            best.update(val_loss=val_loss, val_acc=val_acc, train_acc=train_acc)

    print(f"Best run ---> val_acc : {best['val_acc']:.4f} | "
          f"train_acc : {best['train_acc']:.4f} | "
          f"val_loss : {best['val_loss']:.4f}")
    return history, best, optimizer


model1 = resnet18_cifar(10).to(DEVICE)
criterion1 = nn.CrossEntropyLoss()
print("1.) Baseline.")
history1, best1, optimizer1 = run_experiment(model1, criterion1, n_edge_epochs=0)

print("2.) Edge training + freeze layer 3 and 4 for edge training.")
model2 = resnet18_cifar(10).to(DEVICE)
criterion2 = nn.CrossEntropyLoss()
history2, best2, optimizer2 = run_experiment(model2, criterion2,
                                             n_edge_epochs=N_EDGE_EPOCHS,
                                             freeze_high=True)

# Backward-compatible module-level names.
# The flat lists are the baseline epochs followed by the edge-run epochs; plot
# history1 / history2 separately to compare the runs. The best-epoch scalars
# refer to the second (edge) run, as before.
train_losses = history1["train_loss"] + history2["train_loss"]
val_losses = history1["val_loss"] + history2["val_loss"]
train_accs = history1["train_acc"] + history2["train_acc"]
val_accs = history1["val_acc"] + history2["val_acc"]
min_loss = best2["val_loss"]
val_acc_at_best_epoch = best2["val_acc"]
train_acc_at_best_epoch = best2["train_acc"]


In [None]:
# ------------------------------
# 9️⃣ Plot function (can run separately)
# ------------------------------
def plot_curves(train_losses, val_losses, train_accs, val_accs):
    """Show side-by-side loss and accuracy curves (x-axis = list index, i.e. epoch).

    The left panel plots train/val loss and the right panel plots train/val
    accuracy. Each panel has axis labels, a title and a legend; the figure is
    displayed with ``plt.show()``.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (ax_loss, train_losses, val_losses, "Train Loss", "Val Loss", "Loss", "Loss Curves"),
        (ax_acc, train_accs, val_accs, "Train Acc", "Val Acc", "Accuracy", "Accuracy Curves"),
    )
    for ax, train_series, val_series, train_label, val_label, y_label, title in panels:
        ax.plot(train_series, label=train_label)
        ax.plot(val_series, label=val_label)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
    plt.show()