## Imports

In [9]:
import os

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import trange

# from custom files
from dataset import CompCarsImageFolder, WrapperDataset, match_class_to_name, match_classes, TestImagesFromTextFile, split_sv_data
from resnet import ResNet, resnet_cfg
from resnet import test
from utils import fix_all_seeds, compute_mean_std_from_dataset

## Configuration

In [10]:
## Configuration
###### Set the locations of the CompCars dataset ######

### NOTE: ADAPT TO YOUR FOLDER STRUCTURE
# Instead of editing this cell, both roots can be overridden with environment variables:
#   COMPCARS_DATA_DIR -> folder containing `image/` (web-based images used for training)
#   COMPCARS_SV_DIR   -> folder containing `image/` and the surveillance .txt files
# The defaults below are EDO'S PATHS. MICHAEL'S layout would be e.g.
#   COMPCARS_DATA_DIR=../cars_data/data  COMPCARS_SV_DIR=../cars_data/sv_data
# TODO: add custom test data files (download from WA group)
COMPCARS_DATA_DIR = os.environ.get('COMPCARS_DATA_DIR', '/Volumes/EDO/NNDL/CompCars dataset/data')
COMPCARS_SV_DIR = os.environ.get('COMPCARS_SV_DIR', os.path.join(COMPCARS_DATA_DIR, 'sv_data'))

# Joining with a trailing '' keeps the trailing '/' on the image roots.
root_data = os.path.join(COMPCARS_DATA_DIR, 'image', '')
root_sv_data = os.path.join(COMPCARS_SV_DIR, 'image', '')
sv_data_make_model_names = os.path.join(COMPCARS_SV_DIR, 'sv_make_model_name.txt')
data_make_names = os.path.join(COMPCARS_SV_DIR, 'make_names.txt')
data_model_names = os.path.join(COMPCARS_SV_DIR, 'model_names.txt')
file = os.path.join(COMPCARS_SV_DIR, 'surveillance.txt')

#############################################################

### Hyperparam configuration
resnet_type = 'resnet18'                # 'resnet18', 'resnet34', 'resnet50'

params = {                              ## Test Params
    'epoch_num': 50,                    # number of epochs (not read anywhere in this notebook)
    'batch_size': 128,                  # for test dataloader
    'hierarchy': 1,                     # Choose 0 for manufacturer classification, 1 for model classification
    'resnet': resnet_cfg[resnet_type],  # ResNet configuration (the model itself is rebuilt from the checkpoint's own config)
    'seed': 28,                         # for reproducibility
    'supcon': True                      # which algorithm to test
}
fix_all_seeds(seed=params['seed'])

### Model checkpoint to test
# The file name follows `resnet_type`, so changing the architecture above also selects
# the matching weights file.
car_target = 'makers' if params['hierarchy'] == 0 else 'models'
if params['supcon']:
    # BUG: fails to load from this model path
    MODEL_PATH = f'./trained_models/lin_moco_{resnet_type}_weights_car_{car_target}_full_dataset_256_nomlp.pth'
else:
    MODEL_PATH = f'./trained_models/{resnet_type}_weights_car_{car_target}_full_dataset_128.pth'

### Device
if torch.cuda.is_available():
    params["device"] = torch.device("cuda")   # option for NVIDIA GPUs
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    params["device"] = torch.device("mps")    # option for Mac M-series chips (GPUs); guarded for torch builds without MPS
else:
    params["device"] = torch.device("cpu")    # default option if none of the above devices are available

print("Device: {}".format(params["device"]))

Device: mps


## Load Model & Test data

In [11]:
## Load full dataset (training images); used here for its class list and class_to_idx mapping
total_set = CompCarsImageFolder(root_data, hierarchy=params['hierarchy'])
num_classes = len(total_set.classes)

## Load Model
# NOTE(review): torch.load unpickles arbitrary Python objects; only load checkpoints you trust.
saved_model = torch.load(MODEL_PATH, map_location=params['device'])

# Fail early with a readable message if the checkpoint does not have the layout this cell
# expects (may help diagnose the supcon checkpoints flagged as BUG in the config cell).
REQUIRED_CKPT_KEYS = ('resnet', 'model_state_dict')
if not isinstance(saved_model, dict) or any(k not in saved_model for k in REQUIRED_CKPT_KEYS):
    found = list(saved_model.keys()) if isinstance(saved_model, dict) else type(saved_model).__name__
    raise KeyError(f"Checkpoint {MODEL_PATH!r} lacks expected keys {REQUIRED_CKPT_KEYS}; found: {found}")

model = ResNet(saved_model['resnet']['block'], saved_model['resnet']['layers'],
               num_classes).to(params['device'])
model.load_state_dict(saved_model['model_state_dict'])
model.eval()   # inference mode; `test` may already do this, and calling it twice is harmless

# Path of the text file listing human-readable make (hierarchy 0) or model (hierarchy 1) names
class_names = data_make_names if params['hierarchy'] == 0 else data_model_names

## Load Test Set
matches_classes = match_class_to_name(class_names, total_set.class_to_idx, params['hierarchy'])   # Find actual names of car makers and models
sv_data = split_sv_data(sv_data_make_model_names)                                                 # Load and separate surveillance data
test_class_to_idx = match_classes(matches_classes, sv_data, params['hierarchy'])                  # Find the dictionaries of car makers/models present in surveillance data

test_set = TestImagesFromTextFile(root_sv_data,
                                  sv_data_txt=sv_data_make_model_names,
                                  txt_file=file,
                                  hierarchy=params['hierarchy'],
                                  matches=test_class_to_idx,
                                  train_class_to_idx=total_set.class_to_idx)

## Normalization
# Resize/crop below match the validation pipeline; the normalization statistics are selectable.
train_mean, train_std = [0.483, 0.471, 0.463], [0.297, 0.296, 0.302]            # training and validation data (web-based images)
test_mean, test_std = [0.2943, 0.3006, 0.3072], [0.2455, 0.2456, 0.2529]        # test data (surveillance data) NOTE: worse results

# To recompute the surveillance statistics instead of using the constants above:
# test_mean, test_std = compute_mean_std_from_dataset(test_set)
# print(f"Test dataset mean: {test_mean}")
# print(f"Test dataset std: {test_std}")

# True: normalize with surveillance-set statistics; False: with the training statistics.
# NOTE(review): models are usually evaluated with the statistics they were trained with — confirm which to report.
USE_TEST_NORM_STATS = True
norm_mean, norm_std = (test_mean, test_std) if USE_TEST_NORM_STATS else (train_mean, train_std)

data_transforms = {
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),     # Evaluate using 224x224 central part of image
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])
}


## Prepare test loader
wrapped_testset = WrapperDataset(test_set, transform=data_transforms['test'])
# os.cpu_count() can return None, which DataLoader rejects; fall back to main-process loading.
test_loader = DataLoader(wrapped_testset, batch_size=params['batch_size'], shuffle=False,
                         num_workers=os.cpu_count() or 0)

## Run test

In [12]:
# Progress bar with one step per test batch; it is handed to `test`, which presumably advances it.
TEST_BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_inv_fmt}]"
n_test_batches = len(test_loader)
pbar_inside = trange(n_test_batches, desc="Test", bar_format=TEST_BAR_FORMAT)

# Evaluate the loaded model on the surveillance test loader, using cross-entropy as the loss.
criterion = torch.nn.CrossEntropyLoss()
test_results = test(
    test_loader,
    model,
    criterion,
    params["device"],
    pbar=pbar_inside,
)

Test:   0%|          | 0/348 [00:00<?, ?s/it]


-- TEST --
Test results:
 - Loss: 6.646 +- 0.388
 - Top-1-Accuracy: 0.17
 - Top-5-Accuracy: 0.27
 - Time: 135.95s
