In [5]:
import argparse
from blocks import *
import torch
import torchvision as tv
import torchvision.transforms.v2 as v2
from utils import validateModelIO, getNormalizedTransforms
from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.data import DataLoader, Dataset, random_split, Subset, TensorDataset, SubsetRandomSampler

from trainableModel import TrainingParameters, TrainableModel

from dataLoading import CIFAR10Dataset

# Import model and transform variables
from models import *
from transforms import *
from utils import getModel, determinsticSplitFullDataset

import numpy as np

import os

In [22]:
def evaluateModel(model:nn.Sequential, dataloader:DataLoader, freezeModel=True, transform=None) -> tuple[float, float]:

    device = "cuda" if torch.cuda.is_available() else "cpu"
    assert device == 'cuda', f'Device is not cuda: {device}'

    lossFunction = nn.CrossEntropyLoss()
    
    totalLoss = 0
    
    N = 0
    correct = 0
    
    model.to(device)
    # model.eval()
    
    with torch.no_grad():
        for features, labels in dataloader:
            # x, y = transform(features).to(device).detach(), labels.to(device).detach()
            x, y = (features).to(device).detach(), labels.to(device).detach()

            forwardPass = model.forward(x)
            
            # This adds the current accuracy to correct which is averaged over all iterations of the epoch.
            correct += (forwardPass.argmax(dim=1) == y).float().mean().item()

            loss = lossFunction(forwardPass, y)
            totalLoss += loss.item()
            
            N += 1
                
    return totalLoss / N, correct / N

dataLoaderParameters = {
                'num_workers': 2,
                'pin_memory': True,
                'prefetch_factor': 4
            }

NAMPATHS = [
    # ('baseline130kN_vanilla', r'C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline130kN_vanilla-FULL_Epoch200_Batch256_LR0.05_Momentum0.9'),
    # ('baseline130kN_Easy', r'C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline130kN_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9'),
    # ('baseline130kN_Hard2', r'C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline130kN_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9'),
    # ('baseline130kN_Hard3', r'C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline130kN_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9'),

    # ('baseline430kN_vanilla', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline430kN_vanilla-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('baseline430kN_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline430kN_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('baseline430kN_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline430kN_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('baseline430kN_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline430kN_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),

    # ('baseline108MN_vanilla', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline108MN_vanilla-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('baseline108MN_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline108MN_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('baseline108MN_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline108MN_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('baseline108MN_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\baseline108MN_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    
    # # RES
    # ('residualNetv1_vanilla', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\residualNetv1_vanilla-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('residualNetv1_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\residualNetv1_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('residualNetv1_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\residualNetv1_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('residualNetv1_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\residualNetv1_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    
    # # DBN/BN
    # ('bottleneckResidualv2_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\bottleneckResidualv2_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('bottleneckResidualv2_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\bottleneckResidualv2_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('bottleneckResidualv2_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\bottleneckResidualv2_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    
    # ('doubleBottleneckResidualv1_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\doubleBottleneckResidualv1_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('doubleBottleneckResidualv1_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\doubleBottleneckResidualv1_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('doubleBottleneckResidualv1_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\doubleBottleneckResidualv1_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    
    # ('bottleneckResidualv1_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\bottleneckResidualv1_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('bottleneckResidualv1_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\bottleneckResidualv1_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('bottleneckResidualv1_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\bottleneckResidualv1_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    
    # # HWY
    # ('highwayResidualv1_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\highwayResidualv1_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('highwayResidualv1_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\highwayResidualv1_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('highwayResidualv1_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\highwayResidualv1_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    
    # ('highwayResidualv2_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\highwayResidualv2_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('highwayResidualv2_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\highwayResidualv2_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('highwayResidualv2_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\highwayResidualv2_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    
    # # BRANCH
    # ('branchResidualv1_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualv1_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('branchResidualv1_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualv1_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('branchResidualv1_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualv1_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    
    # ('branchResidualv2_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualv2_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('branchResidualv2_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualv2_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    # ('branchResidualv2_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualv2_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    
    # NORMS
    ('branchResidualNormv1_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualNormv1_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    ('branchResidualNormv1_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualNormv1_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    ('branchResidualNormv1_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualNormv1_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    
    ('branchResidualNormv2_Easy', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualNormv2_easyaugmentation-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    ('branchResidualNormv2_Hard2', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualNormv2_hardAugmentation2-FULL_Epoch200_Batch256_LR0.05_Momentum0.9"),
    ('branchResidualNormv2_Hard3', r"C:\Users\Nicholas\Documents\Desktop\SCHOOL\GRADUATE Offline\CS 444\Final Project\CS-444-Final-Project\models\branchResidualNormv2_hardAugmentation3-FULL_Epoch200_Batch256_LR0.05_Momentum0.9")
    
    
    ]


# Validate args
for modelName, path in NAMPATHS:
    model = getModel(modelName, printResult=False)
    assert os.path.exists(path=path)
print('Args validated!')

# Create dataset instances
fullDataset = CIFAR10Dataset(rootDirectory='cifar-10', csvFilename='trainLabels.csv', dataFolder='train', transform=None)

trainDataset, validationDataset, testDataset = determinsticSplitFullDataset(fullDataset, [0.8, 0.1, 0.1])

_, transform = getNormalizedTransforms(fullDataset=fullDataset, trainTransform=v2.Compose([v2.Identity()]), valTestTransform=v2.Compose([v2.Identity()]))
print(transform)

trainLoader = DataLoader(trainDataset, batch_size=256, shuffle=False, **dataLoaderParameters)
validationLoader = DataLoader(validationDataset, batch_size=256, shuffle=False, **dataLoaderParameters)
testLoader = DataLoader(testDataset, batch_size=256, shuffle=False, **dataLoaderParameters)

names = ['Train', 'Validation', 'Test']
loaders = [trainLoader, validationLoader, testLoader]

print()

for modelName, path in NAMPATHS:

    model = getModel(modelName, printResult=False)
    modelPath = os.path.normpath(path)
    loadedModel = torch.load(modelPath)
    model.load_state_dict(loadedModel)

    print(modelName)
    for name, loader in zip(names, loaders):
        
        _, accuracy = evaluateModel(model=model, dataloader=loader, transform=transform)
        print(f'{name} accuracy: {round(accuracy*100, 1)}')
    print()

Args validated!
tensor([0.4928, 0.4811, 0.4450])
tensor([0.2026, 0.1995, 0.2021])
Compose(
      Identity()
      Normalize(mean=[tensor(0.4928), tensor(0.4811), tensor(0.4450)], std=[tensor(0.2026), tensor(0.1995), tensor(0.2021)], inplace=False)
)

branchResidualNormv1_Easy
Train accuracy: 10.6
Validation accuracy: 11.1
Test accuracy: 10.8

branchResidualNormv1_Hard2


KeyboardInterrupt: 

Args validated!
tensor([0.4921, 0.4828, 0.4507])
tensor([0.2039, 0.2011, 0.2035])
Compose(
      Identity()
      Normalize(mean=[tensor(0.4921), tensor(0.4828), tensor(0.4507)], std=[tensor(0.2039), tensor(0.2011), tensor(0.2035)], inplace=False)
)

baseline130kN_vanilla
Train accuracy: 99.9
Validation accuracy: 80.8
Test accuracy: 80.4

baseline130kN_Easy
Train accuracy: 91.1
Validation accuracy: 82.8
Test accuracy: 82.6

baseline130kN_Hard2
Train accuracy: 87.1
Validation accuracy: 80.0
Test accuracy: 79.2

baseline130kN_Hard3
Train accuracy: 75.4
Validation accuracy: 72.6
Test accuracy: 72.9

baseline430kN_vanilla
Train accuracy: 100.0
Validation accuracy: 81.5
Test accuracy: 82.1

baseline430kN_Easy
Train accuracy: 96.5
Validation accuracy: 84.4
Test accuracy: 82.9

baseline430kN_Hard2
Train accuracy: 92.8
Validation accuracy: 79.2
Test accuracy: 78.7

baseline430kN_Hard3
Train accuracy: 83.2
Validation accuracy: 78.5
Test accuracy: 78.9

baseline108MN_vanilla
Train accuracy: 99.9
Validation accuracy: 80.9
Test accuracy: 80.5

baseline108MN_Easy
Train accuracy: 98.6
Validation accuracy: 82.6
Test accuracy: 82.4

baseline108MN_Hard2
Train accuracy: 96.9
Validation accuracy: 81.7
Test accuracy: 81.5

baseline108MN_Hard3
Train accuracy: 73.9
Validation accuracy: 70.7
Test accuracy: 70.8

residualNetv1_vanilla
Train accuracy: 100.0
Validation accuracy: 87.9
Test accuracy: 88.4

residualNetv1_Easy
Train accuracy: 98.9
Validation accuracy: 89.5
Test accuracy: 89.5

residualNetv1_Hard2
Train accuracy: 98.3
Validation accuracy: 89.1
Test accuracy: 88.8

residualNetv1_Hard3
Train accuracy: 88.8
Validation accuracy: 84.2
Test accuracy: 84.4

bottleneckResidualv2_Easy
Train accuracy: 97.6
Validation accuracy: 88.3
Test accuracy: 87.8

bottleneckResidualv2_Hard2
Train accuracy: 83.0
Validation accuracy: 76.2
Test accuracy: 77.5

bottleneckResidualv2_Hard3
Train accuracy: 85.7
Validation accuracy: 82.2
Test accuracy: 82.2

doubleBottleneckResidualv1_Easy
Train accuracy: 95.0
Validation accuracy: 86.5
Test accuracy: 86.3

doubleBottleneckResidualv1_Hard2
Train accuracy: 82.9
Validation accuracy: 76.4
Test accuracy: 76.9

doubleBottleneckResidualv1_Hard3
Train accuracy: 70.4
Validation accuracy: 68.2
Test accuracy: 68.4

bottleneckResidualv1_Easy
Train accuracy: 98.0
Validation accuracy: 88.5
Test accuracy: 88.5

bottleneckResidualv1_Hard2
Train accuracy: 93.9
Validation accuracy: 85.0
Test accuracy: 85.1

bottleneckResidualv1_Hard3
Train accuracy: 80.2
Validation accuracy: 77.5
Test accuracy: 77.2

highwayResidualv1_Easy
Train accuracy: 98.3
Validation accuracy: 88.4
Test accuracy: 88.2

highwayResidualv1_Hard2
Train accuracy: 92.8
Validation accuracy: 82.4
Test accuracy: 84.0

highwayResidualv1_Hard3
Train accuracy: 84.0
Validation accuracy: 80.4
Test accuracy: 81.0

highwayResidualv2_Easy
Train accuracy: 99.2
Validation accuracy: 89.1
Test accuracy: 89.2

highwayResidualv2_Hard2
Train accuracy: 99.2
Validation accuracy: 89.2
Test accuracy: 89.1

highwayResidualv2_Hard3
Train accuracy: 89.9
Validation accuracy: 84.2
Test accuracy: 84.1

branchResidualv1_Easy
Train accuracy: 95.1
Validation accuracy: 87.1
Test accuracy: 87.1

branchResidualv1_Hard2
Train accuracy: 90.2
Validation accuracy: 84.3
Test accuracy: 85.3

branchResidualv1_Hard3
Train accuracy: 84.9
Validation accuracy: 82.1
Test accuracy: 82.3

branchResidualv2_Easy
Train accuracy: 92.9
Validation accuracy: 86.9
Test accuracy: 87.2

branchResidualv2_Hard2
Train accuracy: 91.3
Validation accuracy: 85.3
Test accuracy: 85.0

branchResidualv2_Hard3
Train accuracy: 81.9
Validation accuracy: 79.2
Test accuracy: 79.7

branchResidualNormv1_Easy
Train accuracy: 55.1
Validation accuracy: 51.2
Test accuracy: 52.3

branchResidualNormv1_Hard2
Train accuracy: 21.3
Validation accuracy: 21.4
Test accuracy: 20.9

branchResidualNormv1_Hard3
Train accuracy: 26.5
Validation accuracy: 27.3
Test accuracy: 26.4

branchResidualNormv2_Easy
Train accuracy: 30.4
Validation accuracy: 30.4
Test accuracy: 29.2

branchResidualNormv2_Hard2
Train accuracy: 16.5
Validation accuracy: 16.1
Test accuracy: 17.3

branchResidualNormv2_Hard3
Train accuracy: 16.7
Validation accuracy: 16.7
Test accuracy: 17.0