# Evaluate Ensemble

This notebook combines the classification results of several trained models by averaging their output scores (logit ensembling).

-----

## Load Packages

In [1]:
# Move one directory up before anything else runs; the relative paths used later in this
# notebook (e.g. 'local/checkpoint') are resolved from there.
# NOTE(review): `%cd ..` is relative, so every re-run of this cell climbs one more level --
# restart the kernel instead of re-running it.
%cd ..
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

/home/imkbsz/workspace/eeg_analysis


In [2]:
# Load some packages
import os
import sys
import pickle
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd

import pprint
import wandb
import matplotlib
import matplotlib.pyplot as plt

# custom package
from datasets.caueeg_script import load_caueeg_config
from datasets.caueeg_script import make_dataloader
from datasets.caueeg_script import compose_transforms, compose_preprocess
from datasets.caueeg_script import load_caueeg_task_datasets
from datasets.caueeg_script import load_caueeg_task_split
import models
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.evaluate import calculate_confusion_matrix
from train.evaluate import calculate_confusion_matrix2
from train.evaluate import calculate_class_wise_metrics
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion, draw_confusion2
from train.visualize import draw_class_wise_metrics
from train.visualize import draw_error_table
from train.visualize import annotate_heatmap

In [3]:
# Report the installed PyTorch build and pick the compute device (GPU when one is visible).
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `device` was derived from the availability check above, so its type selects the message.
print('cuda is available.' if device.type == 'cuda' else 'cuda is unavailable.')

PyTorch version: 1.11.0
cuda is available.


-----

## List the models to evaluate

In [4]:
# Checkpoint directory names under local/checkpoint, one per trained model of the ensemble.
model_names = [
    'lo88puq7',  # 2D-VGG-19
    'l8524nml',  # 1D-ResNet-18   // 2s1700lg, l8524nml
    'v301o425',  # 1D-ResNeXt-50 
    '1sl7ipca',  # 2D-ResNeXt-50 
    'gvqyvmrj',  # 1D-ResNet-50 
    'gjkysllw',  # 2D-ViT-B-16 
    'xci5svkl',  # 2D-ResNet-18 
    '1vc80n1f',  # 1D-VGG-19 
    'syrx7bmk',  # 2D-ResNet-50 
]

# Alternative pool kept for reference (NOTE(review): presumably for another task variant -- confirm).
# model_names = [
#     'tp7qn5hd',  # 1D-ResNeXt-50 
#     'q1hhkmik',  # 1D-ResNet-50
#     '0svudowu',  # 2D-ResNeXt-50
#     'ruqd8r7g',  # 2D-VGG-19
#     'dn10a6bv',  # 2D-ResNet-18
#     'atbhqdgg',  # 2D-ResNet-50
#     '4439k9pg',  # 1D-ResNet-18
#     'nemy8ikm',  # 1D-VGG-19
#     '1cdws3t5',  # 2D-ViT-B-16
# ]

# Each entry starts as {'name', 'path'}; only checkpoints that can actually be opened are kept.
model_pool = []

for model_name in model_names:
    path = os.path.join(r'local/checkpoint', model_name, 'checkpoint.pt')
    try:
        # Only the stored config is read here, so keep the tensors on the CPU instead of
        # filling the GPU with weights that are thrown away right after the print.
        ckpt = torch.load(path, map_location='cpu')
        print(ckpt['config']['model'])
        model_pool.append({'name': model_name, 'path': path})
    except Exception as e:
        # Best effort: report the unreadable checkpoint and continue with the others.
        print(e)
        print(f'- checkpoint cannot be opened: {path}')

pprint.pprint([model_dict['name'] for model_dict in model_pool])

2D-VGG-19
1D-ResNet-18
1D-ResNeXt-50
2D-ResNeXt-50
1D-ResNet-50
2D-ViT-B-16
2D-ResNet-18
1D-VGG-19
2D-ResNet-50
['lo88puq7',
 'l8524nml',
 'v301o425',
 '1sl7ipca',
 'gvqyvmrj',
 'gjkysllw',
 'xci5svkl',
 '1vc80n1f',
 'syrx7bmk']


---

## Configurations

In [5]:
# --- what to evaluate --------------------------------------------------------
task = 'ensemble-dementia-class-score'  # base name of the files written under local/
no_patient_overlap = False              # True: '-no-overlap' is appended to the task names
eval_ensemble = True                    # also build the averaged 'Ensemble' entry
eval_train, eval_val, eval_test = True, True, True  # dataset splits to evaluate

# --- multi-crop evaluation budget --------------------------------------------
crop_multiple = 8
test_crop_multiple = 8
base_repeat = 8  # 800 ; the per-model repeat is round(base_repeat / crop_multiple)

# --- miscellaneous -----------------------------------------------------------
verbose = False
save_fig = False

In [6]:
# Tag the output name for the no-overlap variant. The suffix is added only when it is not
# there yet, so re-running this cell cannot produce '...-no-overlap-no-overlap'.
if no_patient_overlap and not task.endswith('-no-overlap'):
    task += '-no-overlap'

In [7]:
def build_dataset_for_train(config, verbose=False):
    """Build the multi-crop data loaders of the train / validation / test splits.

    ``config`` is updated in place: the dataset and task configurations are merged into it
    and the keys ``run_mode`` (only when missing), ``transform``, ``transform_multicrop``,
    ``preprocess_train``, ``preprocess_test``, ``in_channels`` and ``out_dims`` are written.

    Args:
        config: dict-like configuration. ``dataset_path``, ``task``, ``load_event`` and
            ``file_format`` are read directly; ``cwd`` and ``reject_events`` are optional.
        verbose: forwarded to the transform / dataset / preprocess helpers. When true, a
            few training batches are additionally preprocessed and their shapes printed.

    Returns:
        The tuple ``(multicrop_train_loader, multicrop_val_loader, multicrop_test_loader)``.
    """
    dataset_path = config["dataset_path"]
    if "cwd" in config:
        dataset_path = os.path.join(config["cwd"], dataset_path)

    config.update(**load_caueeg_config(dataset_path))

    if "run_mode" not in config:
        banner = "\n" + "=" * 80 + "\n"
        print(banner)
        print('WARNING: run_mode is not specified.\n \t==> run_mode is set to "train" automatically.')
        print(banner)
        config["run_mode"] = "train"

    transform, transform_multicrop = compose_transforms(config, verbose=verbose)
    config["transform"] = transform
    config["transform_multicrop"] = transform_multicrop
    load_event = config["load_event"] or config.get("reject_events", False)

    config_task, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
        dataset_path=dataset_path,
        task=config["task"],
        load_event=load_event,
        file_format=config["file_format"],
        transform=transform,
        verbose=verbose,
    )
    config.update(**config_task)

    def load_multicrop_split(split):
        # Same split, but read through the multi-crop transform.
        _, split_dataset = load_caueeg_task_split(
            dataset_path=dataset_path,
            task=config["task"],
            split=split,
            load_event=load_event,
            file_format=config["file_format"],
            transform=transform_multicrop,
            verbose=verbose,
        )
        return split_dataset

    # All three multi-crop datasets are loaded first, then wrapped in loaders one by one.
    multicrop_datasets = [load_multicrop_split(split) for split in ("train", "validation", "test")]

    multicrop_loaders = []
    for multicrop_dataset in multicrop_datasets:
        # make_dataloader also hands back the regular loaders on every call; the train
        # loader of the last call is the one used for the preprocessing below.
        train_loader, _, _, multicrop_loader = make_dataloader(
            config, train_dataset, val_dataset, test_dataset, multicrop_dataset, verbose=False
        )
        multicrop_loaders.append(multicrop_loader)

    preprocess_train, preprocess_test = compose_preprocess(config, train_loader, verbose=verbose)
    config["preprocess_train"] = preprocess_train
    config["preprocess_test"] = preprocess_test

    first_batch = next(iter(train_loader))
    config["in_channels"] = preprocess_train(first_batch)["signal"].shape[1]
    config["out_dims"] = len(config["class_label_to_name"])

    if verbose:
        for i_batch, sample_batched in enumerate(train_loader):
            # preprocessing includes to-device operation
            preprocess_train(sample_batched)

            shapes = (
                sample_batched["signal"].shape,
                sample_batched["age"].shape,
                sample_batched["class_label"].shape,
            )
            print(i_batch, *shapes)

            if i_batch > 3:
                break
        print("\n" + "-" * 100 + "\n")

    return tuple(multicrop_loaders)


-----

## Evaluate each model and accumulate the logits

In [8]:
# Evaluate every checkpoint of the pool with multi-crop inference and store, per split, the
# throughput, accuracy, scores and targets in its `model_dict`.
for model_dict in model_pool:
    # Entries without a checkpoint path (the 'Ensemble' summary appended by the ensemble
    # cell) cannot be evaluated; skipping them keeps this cell re-runnable.
    if 'path' not in model_dict:
        continue

    # load and parse the checkpoint
    ckpt = torch.load(model_dict['path'], map_location=device)
    model_state = ckpt['model_state']
    config = ckpt['config']

    model_dict['model'] = config['model']
    model_dict['num_params'] = config.get('num_params', '???')
    # size of the pickled state dict, reported in MiB
    model_dict['model size (MiB)'] = sys.getsizeof(pickle.dumps(model_state)) / (1024 * 1024)

    if no_patient_overlap:
        config['task'] += '-no-overlap'

    # checkpoints that still point at the '220419' dataset location are redirected
    if '220419' in config['dataset_path']:
        config['dataset_path'] = './local/dataset/caueeg-dataset/'
    config['run_mode'] = 'eval'

    # copy the training settings into the summary row ('???' / 0 mark values the
    # checkpoint config does not carry)
    model_dict['seq_length'] = config['seq_length']
    model_dict['use_age'] = config['use_age']
    model_dict['photic'] = config['photic']
    model_dict['EKG'] = config['EKG']

    model_dict['awgn'] = config.get('awgn', 0)
    model_dict['awgn_age'] = config.get('awgn_age', 0)
    model_dict['mgn'] = config.get('mgn', 0)
    model_dict['mixup'] = config.get('mixup', 0)
    model_dict['dropout'] = config.get('dropout', 0)
    model_dict['weight_decay'] = config.get('weight_decay', '???')
    model_dict['fc_stages'] = config.get('fc_stages', 1)
    model_dict['activation'] = config.get('activation', 0)

    model_dict['minibatch'] = round(config['minibatch'])
    model_dict['total_samples'] = round(config.get('total_samples', config['iterations'] * config['minibatch']))
    model_dict['base_lr'] = config.get('base_lr', config.get('LR', '???'))
    model_dict['lr_scheduler_type'] = config.get('lr_scheduler_type', 'constant_with_decay')
    model_dict['warmup_steps'] = config.get('warmup_steps', '???')
    model_dict['seed'] = config.get('seed', '???')

    print('- checking for', model_dict['name'], config['model'], '...')

    # initiate the model; three config layouts are supported
    if '_target_' in config:
        # hydra-style config
        model = hydra.utils.instantiate(config).to(device)
    elif isinstance(config['generator'], str):
        # generator (and block) stored as dotted names: resolve them inside `models`
        config['generator'] = getattr(models, config['generator'].split('.')[-1])
        if 'block' in config:
            config['block'] = getattr(models, config['block'].split('.')[-1])
        model = config['generator'](**config).to(device)
    else:
        # generator stored as a callable; block classes are translated to their string alias
        if 'block' in config:
            if config['block'] == models.resnet_1d.BottleneckBlock1D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_2d.Bottleneck2D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_1d.BasicBlock1D:
                config['block'] = 'basic'
            elif config['block'] == models.resnet_2d.BasicBlock2D:
                config['block'] = 'basic'

        model = config['generator'](**config).to(device)

    if config.get('ddp', False):
        # remove 'module.' of DataParallel/DistributedDataParallel -- but only from keys that
        # really carry it, so an already-clean key is never truncated
        ddp_prefix = 'module.'
        model_state = OrderedDict(
            (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k, v)
            for k, v in model_state.items()
        )

    model.load_state_dict(model_state)

    # reconfigure and update
    config.pop('cwd', 0)
    config['ddp'] = False
    config['crop_multiple'] = crop_multiple
    config['test_crop_multiple'] = test_crop_multiple
    config['crop_timing_analysis'] = False
    config['eval'] = True
    config['device'] = device

    repeat = round(base_repeat / crop_multiple)
    model_dict['repeat'] = repeat
    model_dict['crop_multiple'] = crop_multiple
    model_dict['test_crop_multiple'] = test_crop_multiple

    # build dataset (multi-crop loaders of the three splits)
    train_loader, val_loader, test_loader = build_dataset_for_train(config, verbose=verbose)

    # warm-up stage: one pass whose result is discarded
    check_accuracy_extended(model, test_loader,
                            config['preprocess_test'], config, repeat=1)

    # Multi-crop accuracy of every enabled split
    for enabled, split, loader in (
        (eval_train, 'Train', train_loader),
        (eval_val, 'Val', val_loader),
        (eval_test, 'Test', test_loader),
    ):
        if not enabled:
            continue

        result = check_accuracy_multicrop_extended(model, loader,
                                                   config['preprocess_test'], config, repeat=repeat)
        model_dict[f'Multi-Crop {split} Throughput'] = result[4]
        model_dict[f'Multi-Crop {split} Accuracy'] = result[0]
        model_dict[f'Multi-Crop {split} Score'] = result[1]
        model_dict[f'Multi-Crop {split} Target'] = result[2]

print('==== Finished ====')

- checking for lo88puq7 2D-VGG-19 ...
- checking for l8524nml 1D-ResNet-18 ...
- checking for v301o425 1D-ResNeXt-50 ...
- checking for 1sl7ipca 2D-ResNeXt-50 ...
- checking for gvqyvmrj 1D-ResNet-50 ...
- checking for gjkysllw 2D-ViT-B-16 ...
- checking for xci5svkl 2D-ResNet-18 ...
- checking for 1vc80n1f 1D-VGG-19 ...
- checking for syrx7bmk 2D-ResNet-50 ...
==== Finished ====


## Conduct ensemble

In [9]:
def _check_same_targets(model_dict, reference_dict, key):
    """Raise ``ValueError`` unless both entries hold identical targets under ``key``.

    Averaging scores row by row is only meaningful when every model scored the same
    samples in the same order.
    """
    if not np.array_equal(model_dict[key], reference_dict[key]):
        raise ValueError(
            f"'{key}' of {model_dict['name']} differs from {reference_dict['name']}; "
            "the scores are not sample-aligned and cannot be averaged."
        )


if eval_ensemble:
    # Re-running this cell: drop the previous summary so it is not averaged in as a model.
    if model_pool and model_pool[-1]['name'] == 'Ensemble':
        model_pool.pop()

    if not model_pool:
        raise RuntimeError('model_pool is empty: no evaluated model to ensemble.')

    # conduct ensembling: running mean of the scores, running sum of the per-sample latency
    if eval_train:
        ensem_multi_train_score = np.zeros_like(model_pool[0]['Multi-Crop Train Score'])
        ensem_multi_train_latency = 0

    if eval_val:
        ensem_multi_val_score = np.zeros_like(model_pool[0]['Multi-Crop Val Score'])
        ensem_multi_val_latency = 0

    if eval_test:
        ensem_multi_test_score = np.zeros_like(model_pool[0]['Multi-Crop Test Score'])
        ensem_multi_test_latency = 0

    ensem_params = 0
    ensem_model_size = 0

    for model_dict in model_pool:
        # 'num_params' may be a string placeholder such as '???' when it is unknown; in that
        # case the total is unknown too, instead of failing on int + str.
        num_params = model_dict['num_params']
        if isinstance(ensem_params, str) or isinstance(num_params, str):
            ensem_params = '???'
        else:
            ensem_params += num_params
        ensem_model_size += model_dict['model size (MiB)']

        if eval_train:
            _check_same_targets(model_dict, model_pool[0], 'Multi-Crop Train Target')
            ensem_multi_train_score += model_dict['Multi-Crop Train Score'] / len(model_pool)
            ensem_multi_train_latency += 1 / model_dict['Multi-Crop Train Throughput']

        if eval_val:
            _check_same_targets(model_dict, model_pool[0], 'Multi-Crop Val Target')
            ensem_multi_val_score += model_dict['Multi-Crop Val Score'] / len(model_pool)
            ensem_multi_val_latency += 1 / model_dict['Multi-Crop Val Throughput']

        if eval_test:
            _check_same_targets(model_dict, model_pool[0], 'Multi-Crop Test Target')
            ensem_multi_test_score += model_dict['Multi-Crop Test Score'] / len(model_pool)
            ensem_multi_test_latency += 1 / model_dict['Multi-Crop Test Throughput']

    # accuracy (in percent) of the arg-max of the averaged scores
    if eval_train:
        pred = ensem_multi_train_score.argmax(axis=-1)
        target = model_pool[0]['Multi-Crop Train Target']
        ensem_multi_train_acc = 100.0 * (pred.squeeze() == target).sum() / pred.shape[0]

    if eval_val:
        pred = ensem_multi_val_score.argmax(axis=-1)
        target = model_pool[0]['Multi-Crop Val Target']
        ensem_multi_val_acc = 100.0 * (pred.squeeze() == target).sum() / pred.shape[0]

    if eval_test:
        pred = ensem_multi_test_score.argmax(axis=-1)
        target = model_pool[0]['Multi-Crop Test Target']
        ensem_multi_test_acc = 100.0 * (pred.squeeze() == target).sum() / pred.shape[0]

    # summarize the ensemble results
    ensem_dict = {}

    ensem_dict['name'] = 'Ensemble'
    ensem_dict['num_params'] = ensem_params
    ensem_dict['model size (MiB)'] = ensem_model_size

    # the ensemble runs all models one after another, so its throughput is the inverse of
    # the summed latencies
    if eval_train:
        ensem_dict['Multi-Crop Train Throughput'] = 1 / ensem_multi_train_latency
        ensem_dict['Multi-Crop Train Accuracy'] = ensem_multi_train_acc
        ensem_dict['Multi-Crop Train Score'] = ensem_multi_train_score

    if eval_val:
        ensem_dict['Multi-Crop Val Throughput'] = 1 / ensem_multi_val_latency
        ensem_dict['Multi-Crop Val Accuracy'] = ensem_multi_val_acc
        ensem_dict['Multi-Crop Val Score'] = ensem_multi_val_score

    if eval_test:
        ensem_dict['Multi-Crop Test Throughput'] = 1 / ensem_multi_test_latency
        ensem_dict['Multi-Crop Test Accuracy'] = ensem_multi_test_acc
        ensem_dict['Multi-Crop Test Score'] = ensem_multi_test_score

    model_pool.append(ensem_dict)

In [10]:
# Per-sample score / target arrays are left out of the summary table.
dropped_keys = {
    f'{prefix}{split} {kind}'
    for prefix in ('', 'Multi-Crop ')
    for split in ('Train', 'Val', 'Test')
    for kind in ('Score', 'Target')
}

# Filter while copying: the large arrays are skipped instead of being deep-copied first and
# popped afterwards; `model_pool` itself is left untouched.
model_pool_frame = [
    {key: value for key, value in model_dict.items() if key not in dropped_keys}
    for model_dict in model_pool
]

summary_frame = pd.DataFrame(model_pool_frame)

# to_csv does not create missing directories
os.makedirs('local/output', exist_ok=True)
summary_frame.to_csv(f'local/output/{task}.csv')
summary_frame

Unnamed: 0,name,path,model,num_params,model size (MiB),seq_length,use_age,photic,EKG,awgn,...,seed,repeat,crop_multiple,test_crop_multiple,Multi-Crop Train Throughput,Multi-Crop Train Accuracy,Multi-Crop Val Throughput,Multi-Crop Val Accuracy,Multi-Crop Test Throughput,Multi-Crop Test Accuracy
0,lo88puq7,local/checkpoint/lo88puq7/checkpoint.pt,2D-VGG-19,20184131,77.073947,4000.0,conv,X,O,0.04794,...,0,1.0,8.0,8.0,279.923334,100.0,277.724131,67.226891,278.252892,67.79661
1,l8524nml,local/checkpoint/l8524nml/checkpoint.pt,1D-ResNet-18,11394051,43.551189,2000.0,conv,O,O,0.004873,...,0,1.0,8.0,8.0,1448.368414,99.473684,1433.228223,62.184874,1430.853026,69.491525
2,v301o425,local/checkpoint/v301o425/checkpoint.pt,1D-ResNeXt-50,25650051,98.228438,4000.0,fc,X,O,0.037536,...,0,1.0,8.0,8.0,542.062034,100.0,540.304496,66.386555,540.804528,68.644068
3,1sl7ipca,local/checkpoint/1sl7ipca/checkpoint.pt,2D-ResNeXt-50,25886467,99.137356,4000.0,fc,O,X,0.10395,...,???,1.0,8.0,8.0,155.030883,98.947368,154.879775,69.747899,155.612666,68.644068
4,gvqyvmrj,local/checkpoint/gvqyvmrj/checkpoint.pt,1D-ResNet-50,26178179,100.185495,1000.0,fc,O,X,0.012513,...,???,1.0,8.0,8.0,1286.20875,99.684211,1284.985731,63.02521,1281.997365,67.79661
5,gjkysllw,local/checkpoint/gjkysllw/checkpoint.pt,2D-ViT-B-16,90054147,343.588083,12000.0,conv,X,O,0.029602,...,0,1.0,8.0,8.0,44.386958,100.0,44.408312,60.504202,44.371241,66.101695
6,xci5svkl,local/checkpoint/xci5svkl/checkpoint.pt,2D-ResNet-18,11425155,43.664536,2000.0,fc,X,O,0.074615,...,???,1.0,8.0,8.0,424.566894,100.0,425.86946,64.705882,424.17685,66.101695
7,1vc80n1f,local/checkpoint/1vc80n1f/checkpoint.pt,1D-VGG-19,20205251,77.157321,2000.0,conv,X,O,0.00812,...,0,1.0,8.0,8.0,1259.372863,99.789474,1207.736137,61.344538,1212.794715,67.79661
8,syrx7bmk,local/checkpoint/syrx7bmk/checkpoint.pt,2D-ResNet-50,25729475,98.468181,4000.0,conv,X,O,0.002375,...,0,1.0,8.0,8.0,181.748415,99.894737,181.977453,62.184874,181.534706,67.79661
9,Ensemble,,,256706907,981.054547,,,,,,...,,,,,22.463787,100.0,22.434757,72.268908,22.433679,75.423729


##### Save the ensemble train scores keyed by serial

In [11]:
# Serial of every scored training sample, in loader order.
# NOTE(review): `train_loader` is the multi-crop train loader left over from the last model of
# the evaluation loop; the stride presumably keeps one serial per group of
# `test_crop_multiple` crops -- confirm against the loader.
serial_chunks = [
    np.asarray(sample_batched['serial'][::test_crop_multiple])
    for sample_batched in train_loader
]
serials = np.concatenate(serial_chunks) if serial_chunks else np.array([])

for model_dict in model_pool:
    if model_dict['name'] != 'Ensemble':
        continue

    train_scores = model_dict['Multi-Crop Train Score']
    # A length mismatch would pair serials with the wrong score rows (or silently drop some).
    if len(serials) != len(train_scores):
        raise ValueError(
            f'{len(serials)} serials were collected but the ensemble holds '
            f'{len(train_scores)} train scores.'
        )

    # one score tensor per serial
    ensemble_scores = {
        serial: torch.tensor(score) for serial, score in zip(serials, train_scores)
    }
    torch.save(ensemble_scores, f'local/{task}.pt')

In [15]:
# Reload the saved {serial: score} mapping. The notebook-level `device` is used rather than
# `config["device"]`, which only exists as a leftover of the evaluation loop.
a = torch.load(f'local/{task}.pt', map_location=device)

# Dense lookup table indexed by the integer value of the serial; rows whose serial is not in
# the file stay all-zero.
score_shape = next(iter(a.values())).shape
teacher_score = torch.zeros((max(int(k) for k in a) + 1, *score_shape))

for k, v in a.items():
    teacher_score[int(k)] = v

In [17]:
# Spot-check the stored scores of a few serials.
teacher_score[[2, 5, 12, 31]]

tensor([[ 4.4546, -2.6231, -3.3622],
        [-3.4511,  3.4890, -3.4008],
        [ 3.1652, -2.4867, -2.4365],
        [-2.6721, -2.0320,  3.5146]])

In [22]:
# The same lookup from zero-padded serial strings, cast to int first (repeats are allowed).
teacher_score[[int(serial) for serial in ('00001', '00005', '00002', '00005', '00100')]]

tensor([[ 0.0000,  0.0000,  0.0000],
        [-3.4511,  3.4890, -3.4008],
        [ 4.4546, -2.6231, -3.3622],
        [-3.4511,  3.4890, -3.4008],
        [ 3.0729, -2.0876, -2.4810]])

In [23]:
# Integer indices give the same rows as the string serials above.
teacher_score[[1, 2, 5, 100]]

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 4.4546, -2.6231, -3.3622],
        [-3.4511,  3.4890, -3.4008],
        [ 3.0729, -2.0876, -2.4810]])