# Occlusion Sensitivity

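This notebook runs a crop-timing (occlusion) sensitivity experiment on our trained networks: a fixed-length window is slid across each EEG recording, the model output is recorded for every window position, and the outputs are then aggregated and plotted over time.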

-----

## Load Packages

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# Move one directory up, presumably to the repository root, so that the project
# packages (`datasets`, `models`) import and the relative `./local/...` paths resolve.
# NOTE(review): `%cd ..` is not idempotent -- re-running this cell climbs one more level.
%cd ..
%load_ext autoreload
%autoreload 2

In [None]:
# Load some packages
import os
from copy import deepcopy
import hydra
from collections import OrderedDict
import glob

import numpy as np
import pandas as pd
from cycler import cycler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchaudio
from torch.utils.data import DataLoader

import pprint
from tqdm.auto import tqdm
import wandb
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
from matplotlib.patches import FancyBboxPatch
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredDirectionArrows
from matplotlib.colors import Normalize

# custom package
from datasets.caueeg_script import load_caueeg_config
from datasets.caueeg_script import load_caueeg_task_datasets
from datasets.caueeg_script import make_dataloader
from datasets.caueeg_script import compose_preprocess
from datasets.pipeline import EegDropChannels
from datasets.pipeline import EegToTensor
from datasets.pipeline import eeg_collate_fn
import models

In [None]:
print('PyTorch version:', torch.__version__)

# Preferred GPU index on the original workstation. If that index does not exist
# on the current machine, fall back to cuda:0 instead of failing later at `.to(device)`;
# use the CPU when no GPU is present at all.
GPU_INDEX = 3

if torch.cuda.is_available():
    gpu_index = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{gpu_index}')
    print('cuda is available.')
else:
    device = torch.device('cpu')
    print('cuda is unavailable.')

-----

## Load a trained model

In [None]:
# Earlier set of checkpoints, kept for reference only (commented out).
# model_list = [
#     '1vc80n1f',  # 1D-VGG-19 
#     'l8524nml',  # 1D-ResNet-18   // 2s1700lg, l8524nml
#     'gvqyvmrj',  # 1D-ResNet-50 
#     'v301o425',  # 1D-ResNeXt-50 
#     'lo88puq7',  # 2D-VGG-19
#     'xci5svkl',  # 2D-ResNet-18 
#     'syrx7bmk',  # 2D-ResNet-50 
#     '1sl7ipca',  # 2D-ResNeXt-50 
#     'gjkysllw',  # 2D-ViT-B-16 
# ]

# Checkpoint IDs to analyze. Each ID must have ./local/checkpoint/<id>/checkpoint.pt,
# and results are written under local/output/timing_analysis/<id>/.
# NOTE(review): the IDs look like W&B run IDs -- confirm.
# Uncomment entries to include more architectures in the analysis.
model_list = [
    # 'nemy8ikm',  # 1D-VGG-19
    # '4439k9pg',  # 1D-ResNet-18
    # 'q1hhkmik',  # 1D-ResNet-50
    # 'tp7qn5hd',  # 1D-ResNeXt-50 
    # 'ruqd8r7g',  # 2D-VGG-19
    'dn10a6bv',  # 2D-ResNet-18
    'atbhqdgg',  # 2D-ResNet-50
    # '0svudowu',  # 2D-ResNeXt-50
    # '1cdws3t5',  # 2D-ViT-B-16
]

---

## Configurations

In [None]:
minibatch = 512  # number of crops gathered into one forward pass of the model
time_interval = 10  # stride between consecutive crop start indices along the signal's time axis

# Loader splits to analyze; each entry must be one of 'train', 'val', 'test'.
target_datasets = [
    'train',
    # 'val',
    # 'test',
]

# Figure-saving toggle. NOTE(review): confirm the plotting cell actually reads this flag.
save_fig = True

## Helper functions

In [None]:
def compose_transforms_modified(config, verbose=False):
    """Build the deterministic per-recording transform for this analysis.

    Only channel dropping and tensor conversion are applied; no cropping is done
    here, so the caller can crop the full signal manually.

    Args:
        config: run configuration. Reads ``channel_reduction_list`` (optional),
            ``EKG`` and ``photic`` (each 'O', 'X' or None), and ``signal_header``
            when a channel has to be looked up by name.
        verbose: accepted for interface compatibility; currently unused.

    Returns:
        A ``transforms.Compose`` of ``EegDropChannels`` followed by ``EegToTensor``.

    Raises:
        ValueError: if ``EKG`` or ``photic`` holds an unsupported value.
    """
    # Work on a copy: appending to the list stored in `config` would make it grow
    # every time this function is called with the same config.
    channel_reduction_list = list(config.get("channel_reduction_list") or [])

    ekg = config.get("EKG", None)
    if ekg not in ["O", "X", None]:
        raise ValueError("config['EKG'] should be one of ['O', 'X', None].")
    if ekg == "X":
        channel_reduction_list.append(config["signal_header"].index("EKG"))

    photic = config.get("photic", None)
    if photic not in ["O", "X", None]:
        raise ValueError("config['photic'] should be one of ['O', 'X', None].")
    if photic == "X":
        channel_reduction_list.append(config["signal_header"].index("Photic"))

    # de-duplicate and order the channel indices to drop
    dropped_channels = sorted(set(channel_reduction_list))
    return transforms.Compose([EegDropChannels(dropped_channels), EegToTensor()])

In [None]:
def build_dataset_for_train2(config, verbose=False):
    """Build the CAUEEG data loaders for ``config`` and record derived settings in it.

    ``config`` is updated in place with the dataset config, the task config, the
    transform, the train/test preprocessing callables, ``in_channels`` (taken
    from ``signal.shape[1]`` of one preprocessed training batch) and ``out_dims``
    (the number of class labels).

    Args:
        config: run configuration; reads ``dataset_path`` and optionally ``cwd``,
            ``run_mode``, ``load_event``, ``reject_events``, ``task``, ``file_format``.
        verbose: forwarded to the dataset helpers; when true, the first few
            training batches are preprocessed and their shapes printed.

    Returns:
        ``(train_loader, val_loader, test_loader, multicrop_test_loader)``.
    """
    relative_path = config["dataset_path"]
    dataset_path = os.path.join(config["cwd"], relative_path) if "cwd" in config else relative_path

    config.update(**load_caueeg_config(dataset_path))

    if "run_mode" not in config.keys():
        separator = "\n" + "=" * 80 + "\n"
        print(separator)
        print('WARNING: run_mode is not specified.\n \t==> run_mode is set to "train" automatically.')
        print(separator)
        config["run_mode"] = "train"

    eeg_transform = compose_transforms_modified(config, verbose=verbose)
    config["transform"] = eeg_transform
    wants_events = config["load_event"] or config.get("reject_events", False)

    config_task, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
        dataset_path=dataset_path,
        task=config["task"],
        load_event=wants_events,
        file_format=config["file_format"],
        transform=eeg_transform,
        verbose=verbose,
    )
    config.update(**config_task)

    # the test split also serves as the multi-crop test set
    loaders = make_dataloader(config, train_dataset, val_dataset, test_dataset, test_dataset, verbose=False)
    train_loader, val_loader, test_loader, multicrop_test_loader = loaders

    preprocess_train, preprocess_test = compose_preprocess(config, train_loader, verbose=verbose)
    config["preprocess_train"] = preprocess_train
    config["preprocess_test"] = preprocess_test
    probe_batch = preprocess_train(next(iter(train_loader)))
    config["in_channels"] = probe_batch["signal"].shape[1]
    config["out_dims"] = len(config["class_label_to_name"])

    if verbose:
        for batch_index, batch in enumerate(train_loader):
            preprocess_train(batch)  # preprocessing includes the to-device operation
            print(batch_index, batch["signal"].shape, batch["age"].shape, batch["class_label"].shape)
            if batch_index >= 4:
                break
        print("\n" + "-" * 100 + "\n")

    return train_loader, val_loader, test_loader, multicrop_test_loader


In [None]:
def estimate_score(model, sample_batched, config):
    """Return the raw output of ``model.compute_feature_embedding`` for a minibatch.

    Args:
        model: object exposing ``compute_feature_embedding(signal, age)``.
        sample_batched: collated batch with ``'signal'`` and ``'age'`` entries.
        config: accepted for interface compatibility; not read.

    Returns:
        The un-normalized model output. Callers apply softmax themselves if needed.
    """
    signal, age = sample_batched['signal'], sample_batched['age']
    return model.compute_feature_embedding(signal, age)

## Run experiments

In [None]:
def _load_model(model_name):
    """Load ./local/checkpoint/<model_name>/checkpoint.pt and rebuild the model on `device`.

    Returns ``(model, config)``. The model is in eval mode with gradients disabled.
    ``config`` is the checkpoint's config dict; its 'generator'/'block' entries may
    be rewritten in place while rebuilding the model.
    """
    path = os.path.join(r'./local/checkpoint', model_name, 'checkpoint.pt')
    # NOTE(review): the config inside the checkpoint stores Python objects, so on
    # torch versions defaulting to weights_only=True this may need weights_only=False.
    ckpt = torch.load(path, map_location=device)
    model_state = ckpt['model_state']
    config = ckpt['config']

    if '_target_' in config:
        model = hydra.utils.instantiate(config).to(device)
    elif type(config['generator']) is str:
        # generator/block were serialized as dotted names: resolve them in `models`
        config['generator'] = getattr(models, config['generator'].split('.')[-1])
        if 'block' in config:
            config['block'] = getattr(models, config['block'].split('.')[-1])
        model = config['generator'](**config).to(device)
    else:
        # block stored as a class object: convert it to its 'bottleneck'/'basic' alias
        if 'block' in config:
            if config['block'] == models.resnet_1d.BottleneckBlock1D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_2d.Bottleneck2D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_1d.BasicBlock1D:
                config['block'] = 'basic'
            elif config['block'] == models.resnet_2d.BasicBlock2D:
                config['block'] = 'basic'
        model = config['generator'](**config).to(device)

    if config.get('ddp', False):
        # remove the 'module.' prefix added by DataParallel/DistributedDataParallel,
        # leaving keys without it untouched
        prefix = 'module.'
        model_state = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()
        )

    model.load_state_dict(model_state)
    model.requires_grad_(False)
    model.eval()
    return model, config


def _configure_for_evaluation(config):
    """Override loader settings in place so each loader batch holds a single recording."""
    config.pop('cwd', 0)
    config['ddp'] = False
    config['minibatch'] = 1
    config['crop_multiple'] = 1
    config['test_crop_multiple'] = 1
    config['crop_timing_analysis'] = False
    config['eval'] = True
    config['device'] = device
    # checkpoints built on the '220419' dataset snapshot are redirected to the local copy
    if '220419' in config['dataset_path']:
        config['dataset_path'] = './local/dataset/caueeg-dataset/'
    config['run_mode'] = 'eval'


def _flush_minibatch(model, crops, config, timing_results):
    """Collate buffered crops, run the model once, and append one result per crop."""
    batch = eeg_collate_fn(crops)
    # each crop still carries the loader's batch axis of size 1 (minibatch=1), so
    # collation yields (N, 1, ...); fold that singleton axis away
    n_crops, _, *rest = batch['signal'].shape
    batch['signal'] = batch['signal'].reshape(n_crops, *rest)
    batch['age'] = batch['age'].reshape(n_crops)
    batch['class_label'] = batch['class_label'].reshape(n_crops)
    output = estimate_score(model, batch, config)

    for i in range(n_crops):
        timing_results.append({
            'st': batch['timing'][i],
            'ed': batch['timing'][i] + config['seq_length'],
            'logit': output[i].cpu(),
        })


def _analyze_recording(model, config, sample_origin):
    """Slide a `seq_length` window (stride `time_interval`) over one recording.

    Returns the list of per-window results; it is empty when the recording is too
    short to fit a single window after the `latency` offset.
    """
    timing_results = []
    pending = []
    st = config.get("latency", 0)
    ed = sample_origin['signal'].shape[2] - config['seq_length'] + 1

    for t in range(st, ed, time_interval):
        # preprocessing mutates its input, so every crop starts from a fresh copy
        crop = deepcopy(sample_origin)
        crop['signal'] = crop['signal'][:, :, t:t + config['seq_length']]
        config['preprocess_test'](crop)
        crop['timing'] = t
        pending.append(crop)

        if len(pending) == minibatch:
            _flush_minibatch(model, pending, config, timing_results)
            pending = []

    if pending:  # leftover crops that did not fill a whole minibatch
        _flush_minibatch(model, pending, config, timing_results)
    return timing_results


for model_name in tqdm(model_list, desc='Model', leave=False):
    model, config = _load_model(model_name)
    _configure_for_evaluation(config)

    train_loader, val_loader, test_loader = build_dataset_for_train2(config, verbose=False)[:3]
    loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}

    for target_dataset in tqdm(target_datasets, desc="Dataset", leave=False):
        if target_dataset not in loaders:
            raise ValueError(f"Unknown target dataset {target_dataset!r}; expected one of {sorted(loaders)}.")
        loader = loaders[target_dataset]

        path = f'local/output/timing_analysis/{model_name}/{target_dataset}/'
        os.makedirs(path, exist_ok=True)

        for sample_origin in tqdm(loader, total=len(loader), desc='Batch', leave=False):
            serial = sample_origin['serial'][0]
            timing_results = _analyze_recording(model, config, sample_origin)
            if not timing_results:
                # nothing to save; an empty result would break the post analysis
                tqdm.write(f'{model_name}/{target_dataset}: skipping {serial} (too short for one crop)')
                continue

            result_dict = {
                'name': model_name,
                'model': config['model'],
                'seq_length': config['seq_length'],
                'serial': serial,
                # label from the un-preprocessed loader sample, so it is saved as a CPU tensor
                'class_label': sample_origin['class_label'][0],
                'timing_results': timing_results,
            }
            torch.save(result_dict, os.path.join(path, serial + '.pt'))

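## Post-analysis

Aggregate each recording's per-window logits into a per-time-step average, where each step is averaged over the windows that cover it.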

In [None]:
# For every model/dataset, load the saved per-window results and accumulate, per
# time step, the sum of logits and the number of windows covering that step.
total_results = []
for model_name in tqdm(model_list, desc='Model', leave=False):
    results = {'model': model_name}
    for dataset in tqdm(target_datasets, desc="Dataset", leave=False):
        results[dataset] = []
        # only the .pt files written by the experiment, in a deterministic order
        fnames = sorted(glob.glob(f'local/output/timing_analysis/{model_name}/{dataset}/*.pt'))
        for fname in tqdm(fnames, desc='Inner Loop', leave=False):
            # map to CPU: saved tensors may reference the GPU used during the experiment
            result_dict = torch.load(fname, map_location='cpu')
            timing_results = result_dict['timing_results']
            if len(timing_results) == 0:
                continue  # recording too short for a single window; nothing to aggregate

            C = timing_results[0]['logit'].shape[0]  # number of output classes
            eed = timing_results[-1]['ed']  # windows are stored in ascending start order

            result = {'serial': result_dict['serial'],
                      'name': result_dict['name'],
                      'model': result_dict['model'],
                      'class_label': result_dict['class_label'],
                      'seq_length': result_dict['seq_length'],
                      'timing': torch.arange(eed),
                      'logit': torch.zeros((eed, C)),
                      'fraction': torch.zeros((eed))}  # number of windows covering each step

            for r in timing_results:
                result['logit'][r['st']:r['ed']] += r['logit']
                result['fraction'][r['st']:r['ed']] += 1
            results[dataset].append(result)
    total_results.append(results)

In [None]:
# Regroup the aggregated results by recording serial (one entry per model) and
# turn the accumulated logit sums into per-step averages.
results_by_serial = {}

for results in total_results:
    for dataset, result in results.items():
        if dataset == 'model':
            continue
        for r in result:
            # Coverage counts are non-negative integers. Steps covered by no window
            # (count 0, e.g. before the latency offset) are divided by 1 so their
            # logits stay 0 instead of NaN. clamp() returns a new tensor, so the
            # previous cell's `total_results` is left untouched.
            coverage = r['fraction'].clamp(min=1)
            results_by_serial.setdefault(r['serial'], []).append({
                'name': r['name'],
                'model': r['model'],
                'seq_length': r['seq_length'],
                'class_label': r['class_label'].item(),
                'timing': r['timing'].numpy(),
                'logit': (r['logit'] / coverage.unsqueeze(1)).numpy(),
            })

In [None]:
# ensemble_result_by_serial = {}

# for serial, results in results_by_serial.items():
#     ensemble_score = torch.zeros_like(torch.Tensor(results[0]['score']))
#     for r in results:
#         ensemble_score += r['score']
#     ensemble_score /= len(results)
#     ensemble_result_by_serial[serial] = ensemble_score

# path = f'local/output/timing_analysis/'
# os.makedirs(path, exist_ok=True)
# torch.save(ensemble_result_by_serial, os.path.join(path, 'ensemble_result_by_serial.pt'))

In [None]:
%config InlineBackend.figure_format = 'retina' # cleaner text
# 'default' resets rcParams first, so 'bmh' is applied on a clean slate
plt.style.use('default')
plt.style.use('bmh')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 14})
plt.rcParams.update({'font.family': 'Arial'})

base_path = './local/output/imgs/timing_analysis'
os.makedirs(base_path, exist_ok=True)
# One color per class index. The first three keep the original palette, and the
# rest of tab10 keeps tasks with more than three classes from failing.
colors = ['tab:red', 'tab:orange', 'tab:green', 'tab:blue', 'tab:purple',
          'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

for serial, results in tqdm(results_by_serial.items(), desc='Serial', leave=False):
    n_rows = len(results)
    # squeeze=False keeps `axs` 2-D even for a single model, so no empty panel is needed
    fig, axs = plt.subplots(n_rows, 1, sharex=True, figsize=(15.0, 3.0 * n_rows),
                            squeeze=False, constrained_layout=True)
    axs = axs[:, 0]

    for ax, result in zip(axs, results):
        score = torch.as_tensor(result['logit']).softmax(dim=1).numpy()
        for c in range(score.shape[1]):
            is_target = c == result['class_label']  # highlight the ground-truth class
            ax.plot(result['timing'], score[:, c], color=colors[c % len(colors)],
                    lw=2.0 if is_target else 1.0,
                    ls='-' if is_target else '--',
                    label=f'class {c}')
        ax.set_xlim(0, result['timing'][-1])
        ax.set_ylim(0, 1.0)
        ax.set_ylabel('Prob')
        ax.set_title(f"Model: {result['model']}, Crop length: {result['seq_length']}")
        ax.legend(loc='upper right', fontsize='small')

    axs[-1].set_xlabel('Timing')
    fig.suptitle(serial, fontsize=25, fontweight='semibold')
    if save_fig:
        fig.savefig(os.path.join(base_path, f'{serial}.jpg'), transparent=True)
    plt.close(fig)  # free memory: one figure per recording