In [None]:
# Mount Google Drive at /content/drive so the project folder defined in a later cell is reachable.
# google.colab is only importable inside a Colab runtime.
from google.colab import drive
drive.mount("/content/drive")

In [2]:
# Root folder for datasets, saved models and plots. The trailing slash matters: paths elsewhere
# in this notebook are built by plain string concatenation (e.g. f"{project_dir}plots/...").
project_dir = "/content/drive/My Drive/OMSCS/CS7643/project/"

In [None]:
# Sanity check: list the project folder to confirm it is reachable from this runtime.
import os
print(os.listdir(project_dir))

In [None]:
# Report whether PyTorch can see a CUDA GPU; train() falls back to the CPU when it cannot.
import torch
print(torch.cuda.is_available())

Main logic

In [5]:
import copy
import itertools
import json
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

# Number of target classes; get_model() uses it as the output size of the replacement head.
NUM_CLASSES = 4

# ImageNet normalization values (per-channel mean, then per-channel std), applied after
# ToTensor() in both the training and the test transform pipelines.
IMAGENET_NORMALIZE = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Alternative project root (the current directory) instead of the Drive path; uncomment to use:
# project_dir = './'

def _dilate_and_unfreeze_convs(block):
    """Make every Conv2d inside ``block`` trainable and dilate those whose name contains 'conv2'."""
    for name, module in block.named_modules():
        if not isinstance(module, nn.Conv2d):
            continue
        if 'conv2' in name:
            # Padding is raised together with dilation; presumably this keeps the
            # spatial output size of these convolutions unchanged -- TODO confirm.
            module.dilation = (2, 2)
            module.padding = (2, 2)
        for param in module.parameters():
            param.requires_grad = True


def get_model(model_name, dilation=False):
    """Build a pretrained torchvision backbone with a fresh ``NUM_CLASSES``-way head.

    All pretrained parameters are frozen. The replacement classification layer is
    newly created and therefore trainable. When ``dilation`` is True, selected
    convolutions in the last stage of the network get dilation/padding of 2 and
    that stage is unfrozen as well.

    Args:
        model_name: one of 'resnet50', 'densenet121', 'mobilenet_v2'.
        dilation: enable the dilated, partially unfrozen variant of the model.

    Returns:
        Tuple ``(model, input_size, resize_size)`` where ``input_size`` is the crop
        size fed to the network and ``resize_size`` is the size test images are
        resized to before center-cropping.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if model_name == 'resnet50':
        model = models.resnet50(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False

        if dilation:
            _dilate_and_unfreeze_convs(model.layer4)

        model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    elif model_name == 'densenet121':
        model = models.densenet121(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False

        if dilation:
            _dilate_and_unfreeze_convs(model.features.denseblock4)

        model.classifier = nn.Linear(model.classifier.in_features, NUM_CLASSES)
    elif model_name == 'mobilenet_v2':
        model = models.mobilenet_v2(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False

        if dilation:
            # Unfreeze the last two entries of the feature extractor.
            for stage in (model.features[-2], model.features[-1]):
                for param in stage.parameters():
                    param.requires_grad = True

            # This edit used to run unconditionally, so dilation=False still produced
            # a dilated network; it is now gated like the other two architectures.
            conv = model.features[17].conv[1][0]
            conv.padding = (2, 2)
            conv.dilation = (2, 2)

        num_features = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_features, NUM_CLASSES)
    else:
        raise ValueError("Invalid model name. Choose from 'resnet50', 'densenet121', 'mobilenet_v2'")

    # All three supported backbones use the same crop and resize sizes.
    input_size = 224
    resize_size = 256
    return model, input_size, resize_size

def train(parameters, suffix="0"):
    """Fine-tune a model from get_model() and persist the best weights, plots and metrics.

    Each epoch trains on ``parameters["train_dir"]`` and evaluates on
    ``parameters["test_dir"]``. The weights with the highest evaluation accuracy are
    saved to ``{project_dir}models/``; curves, the confusion matrix and a JSON dump of
    the metrics go to ``{project_dir}plots/``.

    NOTE(review): the same ``test_dir`` data drives model selection and early stopping,
    so the reported best accuracy is not an unbiased test estimate.

    Args:
        parameters: dict with the keys "model", "dilation", "train_dir", "test_dir",
            "batch_size", "num_workers", "learning_rate", "momentum" and "epochs".
        suffix: tag embedded in every output file name of this run.

    Returns:
        dict with the per-epoch lists "train_losses", "validation_losses",
        "train_accuracies", "validation_accuracies", "train_time", "validation_time"
        and the scalar "best_acc".
    """
    print(parameters)

    # Make sure the output folders exist before spending time on training.
    os.makedirs(f'{project_dir}models', exist_ok=True)
    os.makedirs(f'{project_dir}plots', exist_ok=True)

    # use gpu if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, input_size, resize_size = get_model(parameters["model"], parameters["dilation"])
    model = model.to(device)

    print(model)

    # Training data
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.ToTensor(),
        IMAGENET_NORMALIZE,
    ])
    train_dataset = datasets.ImageFolder(root=parameters["train_dir"], transform=train_transform)
    train_loader = DataLoader(train_dataset, batch_size=parameters["batch_size"], shuffle=True, num_workers=parameters["num_workers"])

    # Test data
    test_transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        IMAGENET_NORMALIZE,
    ])
    test_dataset = datasets.ImageFolder(root=parameters["test_dir"], transform=test_transform)
    test_loader = DataLoader(test_dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=parameters["num_workers"])

    # Loss fn
    criterion = nn.CrossEntropyLoss()

    # Optimizer
    optimizer = optim.SGD(model.parameters(), lr=parameters["learning_rate"], momentum=parameters["momentum"])
    print(optimizer)

    # Early Stopping parameters
    patience = 5  # Number of epochs to wait for improvement before stopping
    min_epochs = 15  # Never stop before this many epochs have run
    best_val_loss = None  # Best validation loss observed
    epochs_no_improve = 0  # Count of epochs with no improvement

    train_dataset_size = len(train_dataset)
    train_losses = []
    validation_losses = []
    train_accuracies = []
    validation_accuracies = []
    train_time = []
    validation_time = []
    best_acc = 0
    # Only the weights are kept (not a copy of the whole module); that is all that gets saved.
    best_state = None
    best_true_labels = None
    best_predicted_labels = None
    for epoch in range(parameters["epochs"]):
        # Training
        model.train()
        train_loss = 0.0
        correct_train = 0
        total_train = 0
        samples_used_in_epoch = 0

        train_start = time.perf_counter()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            # Call the modules directly (not .forward) so registered hooks run.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            _, predicted = torch.max(outputs.detach(), 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

            samples_used_in_epoch += inputs.size(0)
            print(f"Training {samples_used_in_epoch}/{train_dataset_size}, Loss: {loss.item()}, Running loss: {train_loss} ")
        train_end = time.perf_counter()

        train_duration = train_end - train_start
        train_losses.append(train_loss / len(train_loader))
        train_accuracy = correct_train / total_train
        train_accuracies.append(train_accuracy)
        train_time.append(train_duration)

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        # Labels collected for the confusion matrix
        all_true_labels = []
        all_predicted_labels = []
        with torch.no_grad():
            validation_start = time.perf_counter()
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Collect true and predicted labels
                all_true_labels.extend(labels.cpu().numpy())
                all_predicted_labels.extend(predicted.cpu().numpy())

            validation_end = time.perf_counter()

        validation_duration = validation_end - validation_start
        validation_losses.append(val_loss / len(test_loader))
        validation_accuracy = correct / total
        validation_accuracies.append(validation_accuracy)
        validation_time.append(validation_duration)

        # "best_state is None" guarantees a checkpoint exists even if accuracy never exceeds 0.
        if best_state is None or validation_accuracy > best_acc:
            best_acc = validation_accuracy
            best_state = copy.deepcopy(model.state_dict())
            best_true_labels = all_true_labels
            best_predicted_labels = all_predicted_labels

        print(f'Epoch {epoch + 1}/{parameters["epochs"]}, Train Loss: {train_loss}, Avg Train Loss/Batch: {train_loss/len(train_loader)}, Train Acc: {train_accuracy}, Train time: {train_duration}, Val Loss: {val_loss}, Avg Val Loss/Batch: {val_loss/len(test_loader)}, Accuracy: {validation_accuracy}, Val Time: {validation_duration}')

        # Early Stopping: an epoch counts as an improvement only if the summed
        # validation loss drops by more than 1e-3.
        if best_val_loss is None or val_loss < (best_val_loss - 1e-3):
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience and epoch + 1 >= min_epochs:
                print("Early Stopping...")
                break  # Break the training loop

    # Plot curves
    plot_curves(parameters["model"], train_losses, train_accuracies, validation_losses, validation_accuracies, best_true_labels, best_predicted_labels, suffix)

    # Save the weights of the best epoch
    torch.save(best_state, f'{project_dir}models/fine_tuned_{parameters["model"]}_{suffix}.pth')

    results = {
        "train_losses": train_losses,
        "validation_losses": validation_losses,
        "train_accuracies": train_accuracies,
        "validation_accuracies": validation_accuracies,
        "train_time": train_time,
        "validation_time": validation_time,
        "best_acc": best_acc,
    }

    with open(f"{project_dir}plots/{parameters['model']}_{suffix}_data.json", "w") as file:
        json.dump(results, file, indent=4)

    return results

def _save_history_plot(train_history, valid_history, title, ylabel, path):
    """Draw train/validation curves for one metric and write the figure to ``path``."""
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(train_history, label='Train', color='blue')
    ax.plot(valid_history, label='Validation', color='orange')
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend(loc='best')
    ax.set_xlim(left=0)
    ax.grid(True)
    fig.savefig(path)
    # Close explicitly so figures do not pile up across many training runs.
    plt.close(fig)


def plot_curves(model_name, train_loss_history, train_acc_history, valid_loss_history, valid_acc_history, true_labels, predicted_labels, suffix):
    """Save the loss curve, accuracy curve and confusion matrix of one training run.

    The three PNG files are written to ``{project_dir}plots/`` and are named
    ``{model_name}_{suffix}_loss_curve.png``, ``..._accuracy_curve.png`` and
    ``..._confusion_matrix.png``.

    Args:
        model_name: used in the plot titles and file names.
        train_loss_history, valid_loss_history: per-epoch loss values.
        train_acc_history, valid_acc_history: per-epoch accuracy values.
        true_labels, predicted_labels: label sequences for the confusion matrix;
            assumed to be integer class indices 0..3 -- confirm against the dataset.
        suffix: run tag embedded in the file names.
    """
    prefix = f'{project_dir}plots/{model_name}_{suffix}'

    # Loss
    _save_history_plot(train_loss_history, valid_loss_history,
                       f'Loss Curve - {model_name}', 'Loss',
                       f'{prefix}_loss_curve.png')

    # Accuracy
    _save_history_plot(train_acc_history, valid_acc_history,
                       f'Accuracy Curve - {model_name}', 'Accuracy',
                       f'{prefix}_accuracy_curve.png')

    # Confusion matrix, normalized per true class. Passing `labels` pins the matrix
    # to every class even when one never shows up in this run's labels, so plots
    # from different runs stay comparable.
    class_labels = ['0', '1', '2', '3']
    cm = confusion_matrix(true_labels, predicted_labels,
                          labels=list(range(len(class_labels))), normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_labels)
    disp.plot(include_values=True, values_format='.2f')
    disp.ax_.set_title(f'Confusion Matrix - {model_name}')
    disp.figure_.savefig(f'{prefix}_confusion_matrix.png')
    plt.close(disp.figure_)


Hyperparameter tuning

In [None]:
# Baseline configuration; the grid search below overrides batch size, learning rate and momentum.
parameters = {
    # Select model: 'resnet50', 'densenet121', 'mobilenet_v2'
    "model": "mobilenet_v2",
    "train_dir": f"{project_dir}filtered/BASE",
    "test_dir": f"{project_dir}test_set",
    "epochs": 100,
    "num_workers": 1,
    # Hyperparameters
    "dilation": False,
    "batch_size": 64,
    "learning_rate": 0.001,
    "momentum": 0.99,
}


batch_sizes = [32, 64, 128]
learning_rates = [0.01, 0.001, 0.0001]
momentum = [0.9, 0.95, 0.99]

# project_dir already ends with a slash, so none is added here.
results_path = f"{project_dir}plots/results.csv"
accs = []

# Full grid: every combination of the three lists, last list varying fastest.
for b, lr, m in itertools.product(batch_sizes, learning_rates, momentum):
    p = parameters.copy()
    p["batch_size"] = b
    p["learning_rate"] = lr
    p["momentum"] = m

    suffix = f"b{b}-lr{lr}-m{m}"

    results = train(p, suffix)
    acc = results["best_acc"]
    accs.append(f"{suffix},{acc}\n")

    # Rewrite the summary after every run so an interrupted session keeps the
    # results of the runs that already finished.
    with open(results_path, "w") as results_file:
        results_file.writelines(accs)

Run on filtered image sets

In [None]:
# Configuration for the per-filter comparison; "train_dir" is replaced for each filter below.
parameters = {
    # Select model: 'resnet50', 'densenet121', 'mobilenet_v2'
    "model": "mobilenet_v2",
    "train_dir": f"{project_dir}filtered/BASE",
    "test_dir": f"{project_dir}test_set",
    "epochs": 100,
    "num_workers": 1,
    # Hyperparameters
    "dilation": True,
    "batch_size": 64,
    "learning_rate": 0.001,
    "momentum": 0.99,
}

train_types = ['BASE', 'BLACK_WHITE', 'SHARPEN', 'EDGE_ENHANCE', 'DETAIL', 'EMBOSS']

# This sweep gets its own summary file: writing to plots/results.csv here would
# overwrite the summary produced by the hyperparameter-tuning cell.
results_path = f"{project_dir}plots/filtered_results.csv"
accs = []
for t in train_types:
    p = parameters.copy()
    p["train_dir"] = f"{project_dir}filtered/{t}"

    suffix = t

    results = train(p, suffix)
    acc = results["best_acc"]
    accs.append(f"{suffix},{acc}\n")

    # Rewrite the summary after every run so an interrupted session keeps the
    # results of the runs that already finished.
    with open(results_path, "w") as results_file:
        results_file.writelines(accs)