# Evaluate

This notebook evaluates the networks trained in the previous notebooks and analyzes the results.

-----

## Load Packages

In [1]:
# Change the working directory to the parent folder before importing project code
# (assumes this notebook lives one level below the project root — TODO confirm).
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%cd ..
%load_ext autoreload
%autoreload 2

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd

import pprint
import wandb
import matplotlib
import matplotlib.pyplot as plt

# custom package
from datasets.caueeg_script import build_dataset_for_train
import models
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_and_throughput
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion
from train.visualize import draw_class_wise_metrics
from train.visualize import draw_error_table
from train.visualize import annotate_heatmap

In [3]:
# Report the PyTorch build and pick the compute device (GPU when one is usable).
print('PyTorch version:', torch.__version__)

cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

print('cuda is available.' if cuda_available else 'cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


-----

## List the models to evaluate

In [10]:
# Names (run IDs) of the trained models to evaluate; one checkpoint folder per name.
model_names = ['34oksw2q', ]
# Root folder that holds one sub-folder per model, each containing 'checkpoint.pt'.
checkpoint_root = r'E:\CAUEEG\checkpoint'
# Filled with {'name': ..., 'path': ...} for every checkpoint that can be opened.
model_pool = []

for model_name in model_names:
    path = os.path.join(checkpoint_root, model_name, 'checkpoint.pt')
    try:
        # Probe-load only to verify the checkpoint is readable; the result is not kept here.
        torch.load(path, map_location=device)
        model_pool.append({'name': model_name, 'path': path})
    except Exception as e:
        # Best-effort: report the unreadable checkpoint and continue with the rest.
        print(e)
        print(f'- {model_name}\'s checkpoint cannot be opened: {path}')

pprint.pprint([model_dict['name'] for model_dict in model_pool])

['34oksw2q']


In [11]:
# Disabled alternative: build model_pool from every run of a Weights & Biases project
# instead of a hand-written list of names. Uncomment to use.
# model_pool = []

# api = wandb.Api()
# runs = api.runs('ipis-mjkim/caueeg-task2-ablation')

# for run in runs:
#     path = os.path.join(r'E:\CAUEEG\checkpoint', run.name, 'checkpoint.pt')
#     try:
#         ckpt = torch.load(path, map_location=device)
#         model_pool.append({'name': run.name, 'path': path})
#     except Exception as e:
#         print(e)
#         print(f'- {run.name}\'s checkpoint cannot be opened: {path}')
        
# pprint.pprint([model_dict['name'] for model_dict in model_pool])

---

## Configurations

In [12]:
# Repetition budget for evaluation; the evaluation loop divides it by each
# checkpoint's config['crop_multiple'] to obtain the `repeat` argument.
base_repeat = 200
# Forwarded as `verbose` to build_dataset_for_train.
verbose = False

-----

## Load and check accuracy

In [13]:
# Evaluate every checkpoint in model_pool and store the metrics back into its dict.
for model_dict in model_pool:
    # Load the checkpoint; only the weights and the training config are needed here.
    # NOTE(review): torch.load unpickles arbitrary objects — open trusted checkpoints only.
    ckpt = torch.load(model_dict['path'], map_location=device)
    model_state = ckpt['model_state']
    config = ckpt['config']

    # Instantiate the model described by the stored config.
    model = hydra.utils.instantiate(config).to(device)

    if config.get('ddp', False):
        # Weights saved through DataParallel/DistributedDataParallel carry a 'module.'
        # prefix on every key; strip it (and only it) so the bare model accepts them.
        ddp_prefix = 'module.'
        model_state = OrderedDict(
            (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key, value)
            for key, value in model_state.items()
        )

    model.load_state_dict(model_state)

    # Reconfigure the stored training config for single-device evaluation.
    config.pop('cwd', 0)
    config['ddp'] = False
    config['crop_timing_analysis'] = True
    config['eval'] = True
    config['device'] = device
    repeat = round(base_repeat / config['crop_multiple'])

    # Build the data loaders. Named variables are used instead of `_`, which IPython
    # reserves for the most recent cell output.
    loaders = build_dataset_for_train(config, verbose=verbose)
    train_loader = loaders[0]
    val_loader = loaders[1]
    test_loader = loaders[2]
    multicrop_test_loader = loaders[3]

    # Train accuracy
    model_dict['Train Accuracy'] = check_accuracy(
        model, train_loader, config['preprocess_test'], config, repeat=repeat)

    # Validation accuracy
    model_dict['Validation Accuracy'] = check_accuracy(
        model, val_loader, config['preprocess_test'], config, repeat=repeat)

    # Test accuracy and throughput (first and second element of the returned result)
    test_result = check_accuracy_and_throughput(
        model, test_loader, config['preprocess_test'], config, repeat=repeat)
    model_dict['Test Accuracy'] = test_result[0]
    model_dict['Throughput'] = test_result[1]

    # Multi-crop test accuracy
    model_dict['Multi-Crop Test Accuracy'] = check_accuracy_multicrop(
        model, multicrop_test_loader, config['preprocess_test'], config, repeat=repeat)

In [16]:
# Show everything collected so far for each model (name, path and stored metrics).
pprint.pprint(model_pool)

[{'Multi-Crop Test Accuracy': 95.60681520314547,
  'Test Accuracy': 87.55291612057667,
  'Throughput': 5.5176137532272255e-05,
  'Train Accuracy': 99.71800986842105,
  'Validation Accuracy': 87.31880733944953,
  'awgn': 0,
  'awgn_age': 0,
  'mgn': 0,
  'mixup': 0.0,
  'name': '34oksw2q',
  'path': 'E:\\CAUEEG\\checkpoint\\34oksw2q\\checkpoint.pt'}]


In [19]:
# Config entries copied from each checkpoint into its model_pool record, in this order.
# 'dropout' is currently left out (it was disabled in the original cell).
config_keys = (
    'model', 'awgn', 'awgn_age', 'mgn', 'mixup',
    'fc_stages', 'use_age', 'photic', 'EKG', 'seed',
)

for model_dict in model_pool:
    # Reload the checkpoint to read its stored training config; the weights are not needed.
    config = torch.load(model_dict['path'], map_location=device)['config']

    for key in config_keys:
        model_dict[key] = config[key]

In [20]:
# Tabulate the per-model records once, save them to 'temp.csv', then display the table.
results_table = pd.DataFrame(model_pool)
results_table.to_csv('temp.csv')
results_table

Unnamed: 0,name,path,Train Accuracy,Validation Accuracy,Test Accuracy,Throughput,Multi-Crop Test Accuracy,awgn,awgn_age,mgn,mixup,fc_stages,use_age,photic,EKG,seed,model
0,34oksw2q,E:\CAUEEG\checkpoint\34oksw2q\checkpoint.pt,99.71801,87.318807,87.552916,5.5e-05,95.606815,0,0,0,0.0,1,no,X,X,0,Ieracitano-CNN


In [21]:
# Reciprocal of the rounded 'Throughput' value (5.5e-05) shown in the table above.
# NOTE(review): presumably converts time-per-sample into samples-per-second — confirm the
# units returned by check_accuracy_and_throughput; the constant is hand-typed and goes
# stale when the evaluation is re-run.
1 / 0.000055

18181.81818181818