# AlexNet fine-tuning & SVM classification

In this assignment, we investigate the use of transfer learning for image classification by leveraging a pre-trained convolutional neural network, specifically AlexNet.

In [None]:
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, transforms, datasets

from sklearn.svm import LinearSVC
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score


In [11]:
def get_dataloaders(data_dir: Path, batch_size: int) -> tuple:
    """Create PyTorch data loaders for train, validation and test sets."""
    
    # -----------------------------
    # TRANSFORMS
    # -----------------------------

    # Define preprocessing + augmentation for training images.
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),                # Resize all images to 224×224
        transforms.Grayscale(num_output_channels=3),  # Convert 1-channel grayscale -> 3-channel (AlexNet expects RGB)
        transforms.ToTensor(),                        # Convert PIL image -> PyTorch tensor (range [0,1])
        transforms.Normalize(                         # Normalize using ImageNet mean/std
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225]
        ),
    ])

    # Define preprocessing for validation/test (no augmentation).
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),                # Resize same as training
        transforms.Grayscale(num_output_channels=3),  # Convert grayscale -> 3 channels
        transforms.ToTensor(),                        # Convert to tensor
        transforms.Normalize(                         # Normalize with same stats
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225]
        ),
    ])

    # -----------------------------
    # DATASETS
    # -----------------------------

    # Load images from train folder and apply training transforms
    train_dataset = datasets.ImageFolder(str(data_dir / 'train'), transform=transform_train)

    # Load validation dataset with test transforms (no augmentation)
    val_dataset = datasets.ImageFolder(str(data_dir / 'val'), transform=transform_test)

    # Load test dataset
    test_dataset = datasets.ImageFolder(str(data_dir / 'test'), transform=transform_test)

    # -----------------------------
    # DATALOADERS
    # -----------------------------

    # Create dataloader for training
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,          # Shuffle training batches
        num_workers=4          # Number of background workers loading data
    )

    # Validation dataloader (no shuffle)
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=4
    )

    # Test dataloader (no shuffle)
    test_loader = DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=4
    )

    # Return loaders as a tuple
    return train_loader, val_loader, test_loader


In [12]:
def fine_tune_last_layer(data_dir: Path, batch_size: int = 32, epochs: int = 10, learning_rate: float = 0.001) -> None:
    """Fine-tune only the final fully connected layer of a pre-trained AlexNet."""
    
    # Select GPU if available, otherwise CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create train/validation/test dataloaders
    train_loader, val_loader, test_loader = get_dataloaders(data_dir, batch_size)

    # Load AlexNet pre-trained on ImageNet
    model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
    model.to(device)  # Move model to chosen device

    # Freeze all layers so they will NOT update during training
    for param in model.parameters():
        param.requires_grad = False

    # Extract number of input features to the final linear layer
    num_features = model.classifier[6].in_features

    # Replace the last classification layer with a new one for 15 classes
    model.classifier[6] = nn.Linear(num_features, 15)
    model.classifier[6].to(device)

    # Only train the parameters of the new final layer
    parameters_to_update = model.classifier[6].parameters()

    # Loss function for multi-class classification
    criterion = nn.CrossEntropyLoss()

    # SGD optimizer with momentum, applied ONLY to the last layer
    optimizer = optim.SGD(parameters_to_update, lr=learning_rate, momentum=0.9)

    best_val_acc = 0.0  # Track best validation accuracy

    # -------------------------
    # TRAINING LOOP
    # -------------------------
    for epoch in range(1, epochs + 1):

        model.train()  # Set model to training mode
        running_loss = 0.0  # Track total loss
        correct = 0         # Track correct predictions
        total = 0           # Track total samples

        # Iterate over training batches
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)  # Move data to device

            optimizer.zero_grad()       # Reset gradients
            outputs = model(inputs)     # Forward pass
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()             # Backpropagation
            optimizer.step()            # Update weights (only last layer)

            running_loss += loss.item() * inputs.size(0)  # Accumulate batch loss
            _, preds = outputs.max(1)                     # Predicted class indices
            correct += preds.eq(labels).sum().item()      # Count correct predictions
            total += labels.size(0)                       # Add batch size

        # Compute epoch-level training stats
        train_loss = running_loss / total
        train_acc = correct / total

        # -------------------------
        # VALIDATION LOOP
        # -------------------------
        model.eval()   # Switch to eval mode
        val_correct = 0
        val_total = 0

        with torch.no_grad():  # No gradient calculations during validation
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, preds = outputs.max(1)
                val_correct += preds.eq(labels).sum().item()
                val_total += labels.size(0)

        # Compute validation accuracy
        val_acc = val_correct / val_total

        print(
            f"Epoch {epoch}/{epochs} - "
            f"train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, val_acc: {val_acc:.4f}"
        )

        # Save model if validation accuracy improved
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'finetuned_alexnet_best.pth')

    print(f"Best validation accuracy: {best_val_acc * 100:.2f}%")

    # -------------------------
    # TESTING LOOP
    # -------------------------

    # Load the best-performing checkpoint
    model.load_state_dict(torch.load('finetuned_alexnet_best.pth', map_location=device))
    model.eval()

    test_correct = 0
    test_total = 0

    with torch.no_grad():  # Disable gradient computation
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = outputs.max(1)
            test_correct += preds.eq(labels).sum().item()
            test_total += labels.size(0)

    # Compute final test accuracy
    test_acc = test_correct / test_total
    print(f"Test accuracy (fine-tuning last layer): {test_acc * 100:.2f}%")


In [13]:
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from itertools import combinations

class DAGSVM(BaseEstimator, ClassifierMixin):
    """
    Multiclass SVM using a Directed Acyclic Graph over One-vs-One binary classifiers.
    """

    def __init__(self, base_estimator=None):
        # base_estimator: an unfitted sklearn SVM 
        self.base_estimator = base_estimator


    def fit(self, X, y):
        X = np.asarray(X)      # Convert X to NumPy array
        y = np.asarray(y)      # Convert y to NumPy array

        self.classes_ = np.unique(y)   # All unique class labels
        self.pair_clfs_ = {}           # Dict to store all pairwise classifiers

        # Train one binary classifier for each pair of classes
        for ci, cj in combinations(self.classes_, 2):  # Generate all (ci, cj) pairs
            mask = np.logical_or(y == ci, y == cj)     # Keep only samples of class ci or cj
            X_ij = X[mask]                             # Filter features
            y_ij = y[mask]                             # Filter labels

            if self.base_estimator is None:
                # Default classifier: LinearSVC
                clf = LinearSVC(C=1.0, dual=False, max_iter=10000)
            else:
                clf = clone(self.base_estimator)       # Clone user-provided estimator

            clf.fit(X_ij, y_ij)                        # Train on the two-class subset
            self.pair_clfs_[(ci, cj)] = clf            # Store classifier for this pair

        return self


    def _predict_single(self, x):
        # DAG prediction: start with all classes and eliminate losers
        remaining = list(self.classes_)           # The set of candidate classes
        i = 0                                     # Pointer to first class
        j = len(remaining) - 1                    # Pointer to last class

        # Continue eliminating until only one class remains
        while len(remaining) > 1:
            ci, cj = remaining[i], remaining[j]  # Compare the "left" and "right" classes

            # Retrieve classifier for (ci, cj) regardless of order
            key = (ci, cj) if (ci, cj) in self.pair_clfs_ else (cj, ci)
            clf = self.pair_clfs_[key]

            pred = clf.predict(x.reshape(1, -1))[0]  # Predict which of ci/cj wins

            # Winner stays in DAG, loser gets removed
            if pred == ci:
                # ci wins -> remove cj
                remaining.pop(j)
                j = len(remaining) - 1               # Update end pointer
            else:
                # cj wins -> remove ci
                remaining.pop(i)
                j = len(remaining) - 1               # Update end pointer

        return remaining[0]  # Only one class left -> final prediction


    def predict(self, X):
        X = np.asarray(X)                          # Ensure NumPy array
        # Predict sample-by-sample using DAG logic
        return np.array([self._predict_single(x) for x in X])


In [None]:
def svm_feature_extraction(data_dir: Path, batch_size: int = 32) -> None:
    """Use AlexNet as a feature extractor and train a multi-class linear SVM."""

    # Select GPU if available, otherwise CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # ----------------------------------------------------
    # TRANSFORMS & DATASETS — same as fine-tuning but NO augmentation
    # ----------------------------------------------------
    transform = transforms.Compose([
        transforms.Resize((224, 224)),               # Resize input for AlexNet
        transforms.Grayscale(num_output_channels=3), # Convert grayscale -> RGB-like 3 channels
        transforms.ToTensor(),                       # Convert to tensor
        transforms.Normalize(                        # Normalise using ImageNet statistics
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    # ImageFolder automatically assigns class labels based on directory names
    train_dataset = datasets.ImageFolder(str(data_dir / 'train'), transform=transform)
    val_dataset   = datasets.ImageFolder(str(data_dir / 'val'),   transform=transform)
    test_dataset  = datasets.ImageFolder(str(data_dir / 'test'),  transform=transform)

    # Data loaders: shuffle=False because order matters for feature extraction
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=4)
    test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=4)

    # ----------------------------------------------------
    # LOAD ALEXNET AND BUILD FEATURE EXTRACTOR (remove classifier)
    # ----------------------------------------------------
    alexnet = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)  # Pretrained on ImageNet
    alexnet.to(device)
    alexnet.eval()  # Disable dropout, batch norm updates

    # Create sequential model: convolutional layers + average pooling + flatten
    feature_extractor = nn.Sequential(
        alexnet.features,   # Convolutional backbone
        alexnet.avgpool,    # Adaptive average pooling
        nn.Flatten(),       # Flatten output into a 1D vector
    ).to(device)

    # ----------------------------------------------------
    # FEATURE EXTRACTION FUNCTION
    # ----------------------------------------------------
    def extract_features(loader: DataLoader) -> tuple:
        features_list = []  # store extracted features
        labels_list = []    # store labels matching features

        with torch.no_grad():  # no gradient computation
            for inputs, labels in loader:
                inputs = inputs.to(device)
                feats = feature_extractor(inputs)    # forward pass through CNN backbone
                features_list.append(feats.cpu().numpy())  # save features to CPU array
                labels_list.append(labels.numpy())         # save labels

        # Concatenate batches into full dataset arrays
        features = np.concatenate(features_list, axis=0)
        labels   = np.concatenate(labels_list,   axis=0)
        return features, labels

    # ----------------------------------------------------
    # EXTRACT FEATURES FOR TRAIN + VAL (training set)
    # ----------------------------------------------------
    print("Extracting features for training and validation...")
    train_features, train_labels = extract_features(train_loader)
    val_features,   val_labels   = extract_features(val_loader)

    # Merge training and validation sets for SVM training
    X_train = np.vstack((train_features, val_features))
    y_train = np.concatenate((train_labels, val_labels))

    # ----------------------------------------------------
    # EXTRACT FEATURES FOR TEST SET
    # ----------------------------------------------------
    print("Extracting features for test set...")
    X_test, y_test = extract_features(test_loader)

    # ----------------------------------------------------
    # SCALE FEATURES (VERY IMPORTANT FOR SVMs)
    # ----------------------------------------------------
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)   # Fit on train, transform train
    X_test_scaled  = scaler.transform(X_test)        # Transform test only

    # ----------------------------------------------------
    # TRAIN DAG-SVM (ONE-VS-ONE CLASSIFIERS + DAG INFERENCE)
    # ----------------------------------------------------
    print("Training multi-class DAGSVM")

    base_svm = LinearSVC(C=1.0, dual=False, max_iter=10)  # Base binary classifier
    svm = DAGSVM(base_estimator=base_svm)                  # Wrap into DAG-SVM structure
    svm.fit(X_train_scaled, y_train)                       # Train all pairwise SVMs

    # ----------------------------------------------------
    # EVALUATE ON TEST SET
    # ----------------------------------------------------
    y_pred = svm.predict(X_test_scaled)          # Predict using DAG traversal
    acc = accuracy_score(y_test, y_pred)         # Compute test accuracy

    print(f"Test accuracy (SVM on features): {acc * 100:.2f}%")

In [15]:
data_dir = Path('/Users/leonardoangellotti/Desktop/universita/Comp_Vision/CNN/data_augmented')

In [16]:
fine_tune_last_layer(data_dir) # 15 minutes

Using device: cpu
Epoch 1/10 - train_loss: 0.7389, train_acc: 0.7559, val_acc: 0.8933
Epoch 2/10 - train_loss: 0.3577, train_acc: 0.8837, val_acc: 0.9022
Epoch 3/10 - train_loss: 0.2708, train_acc: 0.9141, val_acc: 0.8889
Epoch 4/10 - train_loss: 0.2394, train_acc: 0.9233, val_acc: 0.9289
Epoch 5/10 - train_loss: 0.2080, train_acc: 0.9320, val_acc: 0.9022
Epoch 6/10 - train_loss: 0.1913, train_acc: 0.9378, val_acc: 0.9156
Epoch 7/10 - train_loss: 0.1723, train_acc: 0.9457, val_acc: 0.9200
Epoch 8/10 - train_loss: 0.1636, train_acc: 0.9490, val_acc: 0.9022
Epoch 9/10 - train_loss: 0.1532, train_acc: 0.9553, val_acc: 0.8978
Epoch 10/10 - train_loss: 0.1450, train_acc: 0.9533, val_acc: 0.8978
Best validation accuracy: 92.89%
Test accuracy (fine-tuning last layer): 87.04%


In [17]:
svm_feature_extraction(data_dir) # 3 minutes

Using device: cpu
Extracting features for training and validation...
Extracting features for test set...
Training multi-class DAGSVM (one-vs-one + DAG inference)... this may take a few minutes
Test accuracy (SVM on features): 85.56%
