# Evaluate

This notebook evaluates the network trained in the previous notebooks and analyzes the results.

-----

## Load Packages

In [1]:
# Notebook setup: working directory and module auto-reloading.
#
# `%cd ..` moves the kernel's working directory one level up (the cell output shows
# the resulting path). Presumably this is the project root, so that the relative
# paths and project-local imports used in later cells resolve -- confirm.
#
# The autoreload extension re-imports external modules that changed on disk, so
# edits to project code take effect without restarting the kernel. Mode 2 reloads
# all modules before each cell execution.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%cd ..
%load_ext autoreload
%autoreload 2

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict
import wandb

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pprint
import matplotlib.pyplot as plt

# custom package
from datasets.caueeg_script import build_dataset_for_train
import models
from models import *
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion
from train.visualize import draw_error_table

In [3]:
# Report the installed PyTorch version, then choose the compute device:
# the GPU when CUDA is usable, the CPU otherwise. `device` is used by later cells.
print('PyTorch version:', torch.__version__)

if torch.cuda.is_available():
    device = torch.device('cuda')
    print('cuda is available.')
else:
    device = torch.device('cpu')
    print('cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


-----

## Load the configuration used during the training phase

In [4]:
# Re-evaluate one trained run: load its checkpoint, rebuild the model, measure the
# test accuracy again and print it next to the value logged to W&B during training.
#
# Side effects: reads a checkpoint from 'local/checkpoint_temp/<run name>/' and
# writes `task = 'task2'` into the run's config on the W&B server.

TARGET_RUN_NAME = 'e91v8u7y'  # the only run processed; every other run is skipped
DDP_PREFIX = 'module.'        # key prefix added by DataParallel/DistributedDataParallel

api = wandb.Api()
runs = api.runs('ipis-mjkim/caueeg-task2')

for run in runs:
    if run.name != TARGET_RUN_NAME:
        continue
    print(run)

    # ---- load the checkpoint ------------------------------------------------
    # The first element of the '(Best, Last) Test Accuracy' summary entry,
    # lower-cased, is the checkpoint file prefix (presumably 'best' or 'last'
    # -- TODO confirm against the training script).
    best_last = run.summary['(Best, Last) Test Accuracy'][0].lower()
    model_path = os.path.join('local/checkpoint_temp', run.name, f'{best_last}_checkpoint.pt')
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(model_path, map_location=device)

    model_state = ckpt['model_state']
    config = ckpt['config']
    # Not used in this cell; kept in the namespace in case later cells read them.
    optimizer = ckpt['optimizer_state']
    scheduler = ckpt['scheduler_state']

    pprint.pprint(config)

    # ---- rebuild the model --------------------------------------------------
    if '_target_' in config:
        # Hydra-style config: the '_target_' entry names the class to instantiate.
        model = hydra.utils.instantiate(config).to(device)
    else:
        if 'block' in config:
            # The stored 'block' entry is compared against the block classes and
            # replaced by its string name before the config is passed on.
            # NOTE(review): any value that is not one of the two bottleneck
            # classes (including an already-converted string) becomes 'basic'.
            if (models.resnet_1d.BottleneckBlock1D == config['block']
                    or models.resnet_2d.Bottleneck2D == config['block']):
                config['block'] = 'bottleneck'
            else:
                config['block'] = 'basic'

        model = config['generator'](**config).to(device)

    # ---- restore the weights ------------------------------------------------
    if config.get('ddp', False):
        # Strip the wrapper prefix only from keys that actually carry it, so keys
        # without the prefix are left intact. Tensors are reused, not copied.
        model_state = OrderedDict(
            (key[len(DDP_PREFIX):] if key.startswith(DDP_PREFIX) else key, value)
            for key, value in model_state.items()
        )
    model.load_state_dict(model_state)

    # ---- adapt the training config for local, single-process evaluation -----
    config.pop('cwd', None)
    config['dataset_path'] = 'local/dataset/02_Curated_Data_220419/'
    config['ddp'] = False
    config['eval'] = True
    config['device'] = device

    train_loader, val_loader, test_loader, multicrop_test_loader = build_dataset_for_train(config, verbose=False)
    # config['test_crop_multiple'] = 32

    # ---- evaluate -----------------------------------------------------------
    # The number of evaluation passes is scaled inversely with 'crop_multiple'
    # (presumably to keep the total number of evaluated crops near 200 -- confirm).
    repeat = round(200 / config['crop_multiple'])
    test_acc = check_accuracy(model, test_loader,
                              config['preprocess_test'], config, repeat=repeat)
    # multi_test_acc = check_accuracy_multicrop(model, multicrop_test_loader,
    #                                           config['preprocess_test'], config, repeat=repeat)

    # Freshly measured accuracy first, then the value logged during training.
    print('test_acc:', test_acc)
    print(run.summary['Test Accuracy'])

    # ---- write metadata back to W&B ------------------------------------------
    # run.summary['Multi-Crop Test Accuracy'] = multi_test_acc
    # `run.config` is a plain dict in the public API; `run.update()` is the call
    # that persists the change.
    run.config['task'] = 'task2'
    run.summary.update()
    run.update()
    # print(run.summary['Multi-Crop Test Accuracy'])
    print('----')

<Run ipis-mjkim/caueeg-task2/coge5ofz (finished)>
test_acc: 62.190677966101696
61.54661016949152
----
<Run ipis-mjkim/caueeg-task2/nymygq8h (finished)>
test_acc: 60.279661016949156
60.282485875706215
----
<Run ipis-mjkim/caueeg-task2/lm6j0kiz (finished)>
test_acc: 61.86440677966102
62.11864406779661
----
<Run ipis-mjkim/caueeg-task2/arn6s5v2 (finished)>
test_acc: 59.26271186440678
59.583333333333336
----
<Run ipis-mjkim/caueeg-task2/yir9g4lz (finished)>
test_acc: 58.95762711864407
59.17372881355932
----
<Run ipis-mjkim/caueeg-task2/1lkk6af2 (finished)>
test_acc: 58.58474576271186
58.78531073446328
----
<Run ipis-mjkim/caueeg-task2/6rcq11k1 (finished)>
test_acc: 61.20762711864407
61.15112994350282
----
<Run ipis-mjkim/caueeg-task2/261dq5me (finished)>
test_acc: 59.95338983050848
59.851694915254235
----
<Run ipis-mjkim/caueeg-task2/bocotv1e (finished)>
test_acc: 59.51271186440678
59.39265536723164
----
<Run ipis-mjkim/caueeg-task2/ytb1raqf (finished)>
test_acc: 60.82627118644068
60.96045