# Evaluate Ensemble

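This notebook measures how strongly the classification outputs (logits) of several trained models correlate with one another, including their logit ensemble (the average of all models' logits).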

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# Move one directory up (presumably the project root, where the local packages `datasets`, `models`
# and `train` imported below live) so those imports and the relative ./local/... paths resolve.
%cd ..
%load_ext autoreload
%autoreload 2

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import sys
import pickle
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd

import pprint
import wandb
import matplotlib
import matplotlib.pyplot as plt

# custom package
from datasets.caueeg_script import build_dataset_for_train
import models
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.evaluate import calculate_confusion_matrix
from train.evaluate import calculate_confusion_matrix2
from train.evaluate import calculate_class_wise_metrics
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion, draw_confusion2
from train.visualize import draw_class_wise_metrics
from train.visualize import draw_error_table
from train.visualize import annotate_heatmap
from train.visualize import draw_heatmap

In [3]:
# Report the PyTorch build and pick the GPU when one is visible, otherwise the CPU.
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('cuda is available.' if device.type == 'cuda' else 'cuda is unavailable.')

PyTorch version: 2.0.0+cu117
cuda is available.


-----

## List the models to compare

In [21]:
# Checkpoint folder names (run IDs) of the trained models to compare, one per architecture.
model_names = [
    '1vc80n1f',  # 1D-VGG-19 
    'l8524nml',  # 1D-ResNet-18   // 2s1700lg, l8524nml
    'gvqyvmrj',  # 1D-ResNet-50 
    'v301o425',  # 1D-ResNeXt-50 
    'lo88puq7',  # 2D-VGG-19
    'xci5svkl',  # 2D-ResNet-18 
    'syrx7bmk',  # 2D-ResNet-50 
    '1sl7ipca',  # 2D-ResNeXt-50 
    'gjkysllw',  # 2D-ViT-B-16 
]

# Alternative set of run IDs for the same architectures:
# model_names = [
#     'nemy8ikm',  # 1D-VGG-19
#     '4439k9pg',  # 1D-ResNet-18
#     'q1hhkmik',  # 1D-ResNet-50
#     'tp7qn5hd',  # 1D-ResNeXt-50 
#     'ruqd8r7g',  # 2D-VGG-19
#     'dn10a6bv',  # 2D-ResNet-18
#     'atbhqdgg',  # 2D-ResNet-50
#     '0svudowu',  # 2D-ResNeXt-50
#     '1cdws3t5',  # 2D-ViT-B-16
# ]

model_pool = []

for model_name in model_names:
    path = os.path.join(r'E:\CAUEEG\checkpoint', model_name, 'checkpoint.pt')
    try:
        # Only the config is inspected here, so keep the tensors on the CPU instead of
        # materialising every checkpoint's weights in GPU memory.
        ckpt = torch.load(path, map_location='cpu')
        print(ckpt['config']['model'])
        model_pool.append({'name': model_name, 'path': path})
    except Exception as e:  # best effort: report unreadable checkpoints and keep the rest
        print(e)
        print(f'- checkpoint cannot be opened: {path}')

pprint.pprint([model_dict['name'] for model_dict in model_pool])

1D-VGG-19
1D-ResNet-18
1D-ResNet-50
1D-ResNeXt-50
2D-VGG-19
2D-ResNet-18
2D-ResNet-50
2D-ResNeXt-50
2D-ViT-B-16
['1vc80n1f',
 'l8524nml',
 'gvqyvmrj',
 'v301o425',
 'lo88puq7',
 'xci5svkl',
 'syrx7bmk',
 '1sl7ipca',
 'gjkysllw']


---

## Configurations

In [22]:
# Experiment settings
task = 'cross-correlation'    # sub-folder of ./local/output where the figures are saved
no_patient_overlap = False    # not referenced by the cells in this notebook
eval_ensemble = True          # append an 'Ensemble' entry (mean of all models' logits)

# Each subject's logits are averaged over `crop_multiple` crops per pass, and the evaluation
# is repeated round(base_repeat / crop_multiple) times.
base_repeat = 800  # 800
crop_multiple = 8
test_crop_multiple = 8        # copied into the config but not used for the estimates

verbose = False

# Which data splits to evaluate (and later reload).
eval_train, eval_val, eval_test = False, True, False

In [23]:
@torch.no_grad()
def estimate_score_epoch(model, loader, preprocess, config, repeat=1):
    """Estimate per-subject model outputs, averaged over crops and over repeated passes.

    Each batch is assumed to hold whole subjects, with ``config['crop_multiple']`` consecutive
    crops per subject (TODO confirm against the dataset's crop transform). The outputs of those
    crops are averaged into one row per subject. Rows are filled in loader order using a running
    offset, so the result does not depend on ``config['minibatch']`` matching the real batch size.

    Args:
        model: network exposing ``eval()`` and ``compute_feature_embedding(x, age, target_from_last=0)``.
        loader: iterable of dict batches with 'signal' and 'age'; ``len(loader.dataset)`` is
            the number of output rows.
        preprocess: callable applied to each batch; its return value is ignored
            (presumably it modifies the batch in place).
        config: dict providing 'class_label_to_name' (its length is the output width)
            and 'crop_multiple'.
        repeat: number of full passes over ``loader`` to average.

    Returns:
        Tensor of shape (len(loader.dataset), len(config['class_label_to_name'])) on the global
        ``device``. Rows for subjects the loader never produced stay zero. Trailing crops that
        do not fill a whole group of ``crop_multiple`` are ignored.
    """
    # evaluation mode
    model.eval()

    N = len(loader.dataset)
    C = len(config['class_label_to_name'])
    crop_multiple = config['crop_multiple']

    embeddings = torch.zeros((repeat, N, C), device=device)

    for k in range(repeat):
        offset = 0  # index of the first subject of the current batch
        for sample_batched in loader:
            preprocess(sample_batched)
            x = sample_batched['signal']
            age = sample_batched['age']

            output = model.compute_feature_embedding(x, age, target_from_last=0)

            # group consecutive crops per subject and average each group
            n_subjects = x.shape[0] // crop_multiple
            crops = output[: n_subjects * crop_multiple].reshape(n_subjects, crop_multiple, *output.shape[1:])
            embeddings[k, offset:offset + n_subjects] = crops.mean(dim=1)
            offset += n_subjects

    return embeddings.mean(dim=0)

-----

## Evaluate each model and cache its logits

In [None]:
# Evaluate every checkpoint in the pool and save its crop-averaged outputs per split to
# local/output/cross-correlation/<run id>/<split>_embeddings.pt for the correlation cells below.
for model_dict in model_pool:
    # load and parse the checkpoint
    ckpt = torch.load(model_dict['path'], map_location=device)
    model_state = ckpt['model_state']
    config = ckpt['config']

    model_dict['task'] = config['task']
    model_dict['model'] = config['model']

    # configs whose dataset_path mentions '220419' are redirected to the local dataset copy
    if '220419' in config['dataset_path']:
        config['dataset_path'] = './local/dataset/caueeg-dataset/'
    config['run_mode'] = 'eval'

    print('- checking for', model_dict['name'], config['model'], '...')

    # initiate the model; the architecture is stored in one of three config formats
    if '_target_' in config:
        # hydra-style config
        model = hydra.utils.instantiate(config).to(device)
    elif isinstance(config['generator'], str):
        # generator/block stored as dotted names: resolve them from the `models` package
        config['generator'] = getattr(models, config['generator'].split('.')[-1])
        if 'block' in config:
            config['block'] = getattr(models, config['block'].split('.')[-1])
        model = config['generator'](**config).to(device)
    else:
        # generator stored as a class; map the block class back to its string name
        if 'block' in config:
            if config['block'] in (models.resnet_1d.BottleneckBlock1D, models.resnet_2d.Bottleneck2D):
                config['block'] = 'bottleneck'
            elif config['block'] in (models.resnet_1d.BasicBlock1D, models.resnet_2d.BasicBlock2D):
                config['block'] = 'basic'
        model = config['generator'](**config).to(device)

    if config.get('ddp', False):
        # Strip the 'module.' prefix added by DataParallel/DistributedDataParallel. A new dict is
        # built from the existing tensors, so no deepcopy (a second copy of every weight) is needed.
        prefix = 'module.'
        model_state = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()
        )

    model.load_state_dict(model_state)

    # reconfigure and update
    config.pop('cwd', 0)
    config['ddp'] = False
    config['crop_multiple'] = crop_multiple
    config['test_crop_multiple'] = test_crop_multiple  # do not use
    config['crop_timing_analysis'] = False
    config['eval'] = True
    config['device'] = device

    repeat = round(base_repeat / crop_multiple)

    # build dataset (only the first three returned items, the loaders, are used)
    train_loader, val_loader, test_loader = build_dataset_for_train(config, verbose=verbose)[:3]

    path = f"local/output/cross-correlation/{model_dict['name']}"
    os.makedirs(path, exist_ok=True)

    # train accuracy
    if eval_train:
        train_embeddings = estimate_score_epoch(model, train_loader,
                                                config['preprocess_test'], config, repeat=repeat)
        torch.save(train_embeddings, os.path.join(path, 'train_embeddings.pt'))

    # val accuracy
    if eval_val:
        val_embeddings = estimate_score_epoch(model, val_loader,
                                              config['preprocess_test'], config, repeat=repeat)
        torch.save(val_embeddings, os.path.join(path, 'val_embeddings.pt'))

    # Test accuracy
    if eval_test:
        test_embeddings = estimate_score_epoch(model, test_loader,
                                               config['preprocess_test'], config, repeat=repeat)
        torch.save(test_embeddings, os.path.join(path, 'test_embeddings.pt'))


print('==== Finished ====')

## Calculate Cross-Correlation

In [24]:
# Reload the cached per-split outputs of every model (as numpy arrays) for the correlation analysis.
for model_dict in model_pool:
    # only the config is needed here, so load the checkpoint onto the CPU
    ckpt = torch.load(model_dict['path'], map_location='cpu')
    config = ckpt['config']

    model_dict['task'] = config['task']
    # shorten the ViT label ('2D-ViT-...' -> 'ViT-...') for the figure axes; no-op for other models
    model_dict['model'] = config['model'].replace('2D-ViT', 'ViT')

    path = f"local/output/cross-correlation/{model_dict['name']}"
    for split, enabled in (('train', eval_train), ('val', eval_val), ('test', eval_test)):
        if enabled:
            embeddings = torch.load(os.path.join(path, f'{split}_embeddings.pt'), map_location='cpu')
            model_dict[f'{split}_embeddings'] = embeddings.cpu().numpy()

In [25]:
# Add an 'Ensemble' entry whose outputs are the mean of all models' outputs.
if eval_ensemble:
    # Remove an Ensemble entry left by an earlier run of this cell, so it is neither duplicated
    # nor averaged into the new ensemble (keeps the cell idempotent).
    model_pool[:] = [md for md in model_pool if md['name'] != 'Ensemble']

    model_dict = {'name': 'Ensemble', 'model': 'Ensemble'}
    for split, enabled in (('train', eval_train), ('val', eval_val), ('test', eval_test)):
        if enabled:
            key = f'{split}_embeddings'
            model_dict[key] = np.mean([md[key] for md in model_pool], axis=0)

    model_pool.append(model_dict)

In [26]:
from scipy.stats import pearsonr

def pearson_correlation(model_dict, target_embeddings):
    """Class-averaged Pearson correlation between the outputs of every pair of models.

    Args:
        model_dict: list of model entries (dicts); each holds a 2-D array under ``target_embeddings``.
        target_embeddings: key of the array to compare, e.g. 'val_embeddings'.

    Returns:
        (M, M) array whose [a, b] entry is the Pearson r between column c of model a and
        column c of model b, averaged over all columns c.
    """
    arrays = [entry[target_embeddings] for entry in model_dict]
    num_classes = arrays[0].shape[1]
    num_models = len(arrays)

    result = np.zeros((num_models, num_models))
    for row in range(num_models):
        for col in range(num_models):
            per_class = [pearsonr(arrays[row][:, c], arrays[col][:, c])[0] for c in range(num_classes)]
            result[row, col] = sum(per_class) / num_classes

    return result

In [27]:
from sklearn.metrics.pairwise import cosine_similarity

def cosine_correlation(model_dict, target_embeddings):
    """Mean paired cosine similarity between the outputs of every pair of models.

    Entry [a, b] is the mean over samples n of cos(A[n], B[n]): each sample's output vector is
    compared with the *same* sample's output from the other model. Averaging the full N x N
    similarity matrix would also compare different samples, and a model compared with itself
    would not score 1.

    Args:
        model_dict: list of model entries (dicts); each holds an (N, C) array under ``target_embeddings``.
        target_embeddings: key of the array to compare, e.g. 'val_embeddings'.

    Returns:
        (M, M) array of mean paired cosine similarities. All-zero rows count as similarity 0.
    """
    def _unit_rows(x):
        # scale every row to unit L2 norm; all-zero rows are left as zeros
        # (the same convention as sklearn's cosine_similarity)
        x = np.asarray(x, dtype=np.float64)
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        norms[norms == 0.0] = 1.0
        return x / norms

    unit = [_unit_rows(md[target_embeddings]) for md in model_dict]
    result = np.zeros((len(unit), len(unit)))

    for i, u_1 in enumerate(unit):
        for k, u_2 in enumerate(unit):
            # row-wise dot products of unit vectors = paired cosine similarities
            result[i, k] = np.mean(np.sum(u_1 * u_2, axis=1))

    return result

In [27]:
# Other settings
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text
import scienceplots

plt.style.use('default') 
# ['Solarize_Light2', '_classic_test_patch', 'bmh', 'classic', 'dark_background', 'fast', 
#  'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn', 'seaborn-bright', 'seaborn-colorblind', 
#  'seaborn-dark', 'seaborn-dark-palette', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 
#  'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 
#  'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'tableau-colorblind10']

plt.rcParams['image.interpolation'] = 'bicubic'
plt.rcParams["font.family"] = 'Roboto Slab' # 'NanumGothic' # for Hangul in Windows
plt.style.use('classic') 
plt.style.use('default') 
plt.style.use('default') # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 14})
plt.rcParams.update({'font.family': 'Roboto Slab'})
plt.rcParams["savefig.dpi"] = 1200

def _draw_correlation_heatmap(correlation, row_labels, col_labels, figsize, title, save_path):
    """Render `correlation` as an annotated heatmap, then save it to `save_path` or show it inline."""
    with plt.style.context(['science', 'default']):  # science, ieee, default, fivethirtyeight
        plt.rcParams["font.family"] = 'Roboto Slab'  # 'NanumGothic' for Hangul in Windows
        plt.rcParams.update({"font.size": 16})
        plt.rcParams["savefig.dpi"] = 1200

        fig = plt.figure(num=1, clear=True, figsize=figsize, constrained_layout=True)
        ax = fig.add_subplot(1, 1, 1)

        im = draw_heatmap(
            correlation,
            row_labels,
            col_labels,
            ax=ax,
            imshow_kw={"alpha": 0.9, "cmap": "coolwarm"},  # jet, YlOrRd, RdPu
            draw_cbar=False,
            cbar_label="",
            cbar_kw={},
        )
        annotate_heatmap(im, anno_format="{x:.2f}", text_colors=("black", "white"), threshold=0.9)
        ax.set_title(title)

        if save_path:
            # the output folder is only created by the optional evaluation cell, so ensure it exists
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path, transparent=True)
        else:
            plt.show()

        plt.close(fig)


def draw_correlation(correlation, model_names, title='', save_path=None):
    """Draw a model-by-model correlation matrix as an annotated heatmap.

    Args:
        correlation: square matrix, presumably (M, M) as returned by pearson_correlation.
        model_names: M labels, used for both rows and columns.
        title: figure title.
        save_path: if truthy, save the figure there (creating parent folders) instead of showing it.
    """
    size = len(model_names) + 0.5
    _draw_correlation_heatmap(correlation, model_names, model_names, (size, size), title, save_path)


def draw_correlation_mean(correlation, model_names, title='', save_path=None):
    """Draw a single 'Average' row of per-model mean correlations as an annotated heatmap.

    Args:
        correlation: (1, M) array, e.g. ``pearson_result.mean(axis=0, keepdims=True)``.
        model_names: M column labels.
        title: figure title.
        save_path: if truthy, save the figure there (creating parent folders) instead of showing it.
    """
    _draw_correlation_heatmap(correlation, ['Average'], model_names,
                              (len(model_names) + 0.5, 2.5), title, save_path)

In [28]:
# Pairwise Pearson correlation of the validation outputs, drawn as a full matrix and as per-model means.
pearson_result = pearson_correlation(model_pool, 'val_embeddings')

save_path = f"./local/output/{task}"
model_labels = [md['model'] for md in model_pool]
task_name = model_pool[0]['task']

draw_correlation(pearson_result, model_labels,
                 save_path=os.path.join(save_path, f"{task_name}-val.pdf"))

draw_correlation_mean(pearson_result.mean(axis=0, keepdims=True), model_labels,
                      save_path=os.path.join(save_path, f"{task_name}-val-mean.pdf"))

In [29]:
# Stack the validation outputs of the four 1D models (pool positions 0-3) into one (4*N, C) array.
model_1d_aggregation = {'name': '1D', 'model': '1D'}

pool_1d = [model_pool[idx] for idx in range(4)]
for entry in pool_1d:
    print(entry['model'])

# (4, N, C) float64 stack, flattened so its row blocks pair up position-wise with the 2D stack
val_embeddings = np.stack([entry['val_embeddings'] for entry in pool_1d]).astype(np.float64)
model_1d_aggregation['val_embeddings'] = val_embeddings.reshape(-1, val_embeddings.shape[-1])
print( model_1d_aggregation['val_embeddings'].shape )

1D-VGG-19
1D-ResNet-18
1D-ResNet-50
1D-ResNeXt-50
(476, 3)


In [30]:
# Stack the validation outputs of the four 2D CNNs (pool positions 4-7) into one (4*N, C) array,
# in the same architecture order as the 1D stack so rows pair up position-wise in the correlation.
model_2d_aggregation = {'name': '2D', 'model': '2D'}

val_embeddings = np.zeros((4, *model_pool[0]['val_embeddings'].shape))
for k in range(4):
    i = k + 4
    print(model_pool[i]['model'])
    # guard against a reordered `model_names`: the hard-coded positions must hold 2D models
    assert model_pool[i]['model'].startswith('2D-'), \
        f"expected a 2D model at pool position {i}, got {model_pool[i]['model']}"
    val_embeddings[k] = model_pool[i]['val_embeddings']

model_2d_aggregation['val_embeddings'] = val_embeddings.reshape(-1, model_pool[0]['val_embeddings'].shape[1])
print( model_2d_aggregation['val_embeddings'].shape )

2D-VGG-19
2D-ResNet-18
2D-ResNet-50
2D-ResNeXt-50
(476, 3)


In [31]:
pearson_correlation(model_dict=[model_1d_aggregation, model_2d_aggregation],
                    target_embeddings='val_embeddings')

array([[1.        , 0.42392487],
       [0.42392487, 1.        ]])

In [None]:
# cosine_result = cosine_correlation(model_pool, 'val_embeddings')

# draw_correlation(cosine_result, [md['model'] for md in model_pool])
# draw_correlation_mean(cosine_result.mean(axis=0, keepdims=True), [md['model'] for md in model_pool])

##### 

##### 