# t-SNE visualization

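This notebook uses t-SNE to visualize the EEG feature embeddings produced by a trained model. It covers the train, validation, and test splits, as well as the annotated recordings that are not part of those splits.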

-----

## Load Packages

In [None]:
# Move to the parent directory (presumably the project root that contains the local `datasets`,
# `models` and `train` packages imported in the next cell — confirm). Then enable auto-reloading
# of external modules so edits to those packages are picked up without restarting the kernel.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# NOTE: `%cd ..` is relative, so re-running this cell moves up one more directory each time;
# restart the kernel instead of re-running it.
%cd ..
%load_ext autoreload
%autoreload 2

In [None]:
# Load some packages
import os
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.preprocessing import MinMaxScaler
from sklearn.manifold import TSNE

import pprint
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import offsetbox

# custom package
from datasets.caueeg_script import build_dataset_for_train
import models
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion
from train.visualize import draw_class_wise_metrics
from train.visualize import draw_error_table
from train.visualize import annotate_heatmap

In [None]:
# Report the PyTorch build and pick the GPU when one is visible, otherwise fall back to the CPU.
print('PyTorch version:', torch.__version__)

cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')
print('cuda is available.' if cuda_ok else 'cuda is unavailable.')

In [None]:
# Other settings
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text

plt.style.use('default') 
# ['Solarize_Light2', '_classic_test_patch', 'bmh', 'classic', 'dark_background', 'fast', 
#  'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn', 'seaborn-bright', 'seaborn-colorblind', 
#  'seaborn-dark', 'seaborn-dark-palette', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 
#  'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 
#  'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'tableau-colorblind10']

plt.rcParams['image.interpolation'] = 'bicubic'
plt.rcParams["font.family"] = 'Helvetica' # 'NanumGothic' # for Hangul in Windows

-----

## Load the checkpoint and the configuration used during training

In [None]:
# Checkpoint to visualize: <checkpoint_dir>/<model_name>/checkpoint.pt.
# Set the CAUEEG_CHECKPOINT_DIR environment variable to point at another checkpoint directory
# instead of editing the machine-specific default below.
model_name = 'amzr0uzt'
checkpoint_dir = os.environ.get('CAUEEG_CHECKPOINT_DIR', r'E:\CAUEEG\checkpoint')
model_path = os.path.join(checkpoint_dir, model_name, 'checkpoint.pt')

# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust. Recent PyTorch
# releases default to weights_only=True, which may reject a checkpoint whose 'config' holds
# arbitrary Python objects — confirm against the installed version.
ckpt = torch.load(model_path, map_location=device)
print(ckpt.keys())

In [None]:
# Unpack the pieces stored in the checkpoint. `optimizer` and `scheduler` are the saved state
# entries, not optimizer/scheduler objects; they are not used elsewhere in this notebook.
model_state, config, optimizer, scheduler = (
    ckpt[key] for key in ('model_state', 'config', 'optimizer_state', 'scheduler_state')
)

In [None]:
# Pretty-print the training configuration on wide lines so nested entries stay readable.
pprint.PrettyPrinter(width=250).pprint(config)

-----

## Load the target model

In [None]:
# Build the network described by the training config and move it to the target device.
model = hydra.utils.instantiate(config).to(device)

# Checkpoints saved from a DataParallel/DistributedDataParallel wrapper prefix every key with
# 'module.'. Strip that prefix only where it is actually present, so unprefixed keys are never
# truncated. Building a new OrderedDict leaves ckpt['model_state'] untouched, so no deep copy of
# the tensors is needed.
ddp_prefix = 'module.'
if config.get('ddp', False) or any(k.startswith(ddp_prefix) for k in model_state):
    model_state = OrderedDict(
        (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k, v)
        for k, v in model_state.items()
    )

model.load_state_dict(model_state)

-----

## Compute the feature embeddings of the trained model

### Configurations

In [None]:
# Re-read the training configuration and adapt it for single-process evaluation.
config = ckpt['config']

config.pop('cwd', 0)  # drop the training-time working directory, if present
config.update(
    ddp=False,                  # no DataParallel/DDP wrapping at evaluation time
    crop_timing_analysis=True,
    eval=True,
    crop_multiple=64,           # crops per recording; averaged back into one embedding per recording
    device=device,
)

# Forwarded to model.compute_feature_embedding(); presumably selects which layer (counted from
# the last one) provides the embedding — confirm in the model code.
target_from_last = 2

### Build Dataset

In [None]:
# Build the train/validation/test loaders (plus a multi-crop test loader) from the evaluation config.
(train_loader,
 val_loader,
 test_loader,
 multicrop_test_loader) = build_dataset_for_train(config, verbose=True)

In [None]:
@torch.no_grad()
def compute_embedding(model, sample_batched, preprocess, crop_multiple, target_from_last):
    """Compute feature embeddings for one mini-batch, averaging the crops of each recording.

    Args:
        model: network exposing ``compute_feature_embedding(x, age, target_from_last=...)``.
            It is switched to evaluation mode.
        sample_batched: batch dict with 'signal', 'age' and 'class_label' entries. ``preprocess``
            is applied to it in place before the model is called.
        preprocess: callable applied to ``sample_batched`` (this includes the to-device operation).
        crop_multiple: number of consecutive batch entries that are crops of the same recording.
            When > 1, every group of ``crop_multiple`` embeddings is averaged into a single row.
        target_from_last: forwarded unchanged to ``model.compute_feature_embedding``.

    Returns:
        ``(e, y)``: embeddings with one row per recording and the matching labels. ``e`` stays on
        the device/dtype produced by the model for every ``crop_multiple``. When
        ``crop_multiple > 1``, ``y`` holds the label of the first crop of each group, as int32.

    Raises:
        ValueError: if the batch size is not a multiple of ``crop_multiple``.
    """
    # evaluation mode
    model.eval()

    # preprocessing (this includes to-device operation)
    preprocess(sample_batched)

    # apply model on whole batch directly on device
    x = sample_batched['signal']
    age = sample_batched['age']
    e = model.compute_feature_embedding(x, age, target_from_last=target_from_last)
    y = sample_batched['class_label']

    if crop_multiple > 1:
        if e.size(0) % crop_multiple != 0:
            raise ValueError(f"compute_embedding(): Real minibatch size={e.size(0)} is not multiple of "
                             f"crop_multiple={crop_multiple}.")

        real_minibatch = e.size(0) // crop_multiple
        # Crops of one recording are contiguous, so view the batch as (recordings, crops, ...) and
        # average over the crop axis in one vectorized op. This keeps the result on the model's
        # device and works for embeddings of any rank.
        e = e.reshape(real_minibatch, crop_multiple, *e.shape[1:]).mean(dim=1)
        # All crops of a recording share its label; keep the first one of each group.
        y = torch.as_tensor(y)[::crop_multiple].to(torch.int32)

    return e, y

In [None]:
# Embed every split with the trained model. For each split we store:
#   'embedding'      : crop-averaged feature embeddings, one row per recording (loader order),
#   'target'         : the matching class labels,
#   'target_symptom' : {symptom key: [row indices into 'embedding' whose symptoms contain the key]}.
result = [{'name': 'Train Dataset',
           'loader': train_loader},
          {'name': 'Validation Dataset',
           'loader': val_loader},
          {'name': 'Test Dataset',
           'loader': test_loader}]

crop_multiple = config['crop_multiple']

for res in result:
    target_symptom = {'mci_amnestic_ef': [], 'mci_amnestic_rf': []}
    embedding_chunks, target_chunks = [], []
    # Rows collected so far == row index (in the final array) of this batch's first recording.
    # Counting real rows avoids assuming every batch is full or that batch_size counts recordings.
    offset = 0

    for sample_batched in res['loader']:
        # estimate
        e, y = compute_embedding(model, sample_batched, config['preprocess_test'], crop_multiple,
                                 target_from_last=target_from_last)
        embedding_chunks.append(e.detach().cpu().numpy())
        target_chunks.append(y.detach().cpu().numpy())

        # Each recording occupies `crop_multiple` consecutive entries; inspect the first of each group.
        for s in range(0, len(sample_batched['symptom']), crop_multiple):
            symp = sample_batched['symptom'][s]
            for k in target_symptom:
                if k in symp:
                    target_symptom[k].append(offset + s // crop_multiple)
        offset += e.shape[0]

    # Fail loudly instead of silently reusing another split's arrays.
    if not embedding_chunks:
        raise RuntimeError(f"{res['name']}: the loader yielded no batches.")

    res['embedding'] = np.concatenate(embedding_chunks, axis=0)
    res['target'] = np.concatenate(target_chunks, axis=0)
    res['target_symptom'] = target_symptom
    assert res['embedding'].shape[0] == res['target'].shape[0] == offset

In [None]:
import inspect

# Independent 2-D t-SNE fit for each split. scikit-learn 1.5 renamed `n_iter` to `max_iter` and
# 1.7 removed `n_iter`, so pass the iteration budget under whichever keyword is accepted here.
_tsne_iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
tsne_transform = TSNE(n_components=2, init="pca", learning_rate="auto", perplexity=50.0,
                      n_iter_without_progress=500, n_jobs=2, random_state=0,
                      **{_tsne_iter_kw: 5000})

for res in result:
    res['tsne_embedding'] = tsne_transform.fit_transform(res['embedding'])
    print(res['name'], '-', res['tsne_embedding'].shape)

In [None]:
# Per-split t-SNE maps (each split fitted on its own), coloured by diagnostic class.
plt.style.use('default')
plt.style.use('fivethirtyeight')  # alternatives: default, ggplot, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11, 'font.family': 'Arial'})
color_map = ['tab:green', 'tab:orange', 'tab:red']

for res in result:
    points = res['tsne_embedding']
    _, ax = plt.subplots()
    for class_name, class_label in config['class_name_to_label'].items():
        in_class = points[res['target'] == class_label]
        ax.scatter(in_class[:, 0], in_class[:, 1],
                   label=class_name, color=color_map[class_label],
                   alpha=0.8, edgecolors='k', zorder=2)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    ax.set_title(f"t-SNE embedding of {res['name']}")

In [None]:
# One t-SNE over all splits stacked in `result` order (train, validation, test), so the splits
# share a single 2-D coordinate system; later cells slice it back apart by split size.
total_out_embedding = tsne_transform.fit_transform(
    np.concatenate(tuple(res['embedding'] for res in result), axis=0)
)

In [None]:
# Joint t-SNE map. In each figure every split is drawn faintly as context, then the split named
# in the title is drawn on top at full opacity, coloured by diagnostic class.
plt.style.use('default')
plt.style.use('fivethirtyeight')  # alternatives: default, ggplot, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11, 'font.family': 'Arial'})
color_map = ['tab:green', 'tab:orange', 'tab:red']

# Rows of `total_out_embedding` follow the order of `result`; precompute each split's row range.
split_ranges = []
start_from = 0
for res in result:
    n_size = res['tsne_embedding'].shape[0]
    split_ranges.append(slice(start_from, start_from + n_size))
    start_from += n_size

for res, rows in zip(result, split_ranges):
    _, ax = plt.subplots()

    # Faint, unlabeled context layer: all splits.
    for other, other_rows in zip(result, split_ranges):
        other_points = total_out_embedding[other_rows]
        for class_label in config['class_name_to_label'].values():
            faint = other_points[other['target'] == class_label]
            ax.scatter(faint[:, 0], faint[:, 1], color=color_map[class_label], alpha=0.1, zorder=2)

    # Highlighted layer: the current split, labelled for the legend.
    points = total_out_embedding[rows]
    for class_name, class_label in config['class_name_to_label'].items():
        bold = points[res['target'] == class_label]
        ax.scatter(bold[:, 0], bold[:, 1], label=class_name, color=color_map[class_label],
                   alpha=0.8, edgecolors='k', zorder=2)

    ax.set_title(f"t-SNE embedding of {res['name']}")
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

In [None]:
# Per-split t-SNE maps showing only the recordings carrying each tracked symptom.
plt.style.use('default')
plt.style.use('fivethirtyeight')  # alternatives: default, ggplot, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11, 'font.family': 'Arial'})
color_map = ['tab:green', 'tab:orange', 'tab:red']

for res in result:
    points = res['tsne_embedding']
    _, ax = plt.subplots()
    for symptom, symptom_rows in res['target_symptom'].items():
        picked = points[list(set(symptom_rows))]
        ax.scatter(picked[:, 0], picked[:, 1], label=symptom,
                   alpha=0.8, edgecolors='k', zorder=2)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    ax.set_title(f"t-SNE embedding {list(res['target_symptom'])} in {res['name']}")

In [None]:
# Joint t-SNE map. All splits are drawn faintly (coloured by class) as context, and the tracked
# symptoms of the split named in the title are drawn on top.
plt.style.use('default')
plt.style.use('fivethirtyeight')  # alternatives: default, ggplot, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11, 'font.family': 'Arial'})
color_map = ['tab:green', 'tab:orange', 'tab:red']

# Rows of `total_out_embedding` follow the order of `result`; precompute each split's row range.
split_ranges = []
start_from = 0
for res in result:
    n_size = res['tsne_embedding'].shape[0]
    split_ranges.append(slice(start_from, start_from + n_size))
    start_from += n_size

for res, rows in zip(result, split_ranges):
    _, ax = plt.subplots()

    # Faint context layer; only the first split contributes legend entries for the classes.
    for idx, (other, other_rows) in enumerate(zip(result, split_ranges)):
        other_points = total_out_embedding[other_rows]
        for class_name, class_label in config['class_name_to_label'].items():
            faint = other_points[other['target'] == class_label]
            ax.scatter(faint[:, 0], faint[:, 1], color=color_map[class_label],
                       label=class_name if idx == 0 else None, alpha=0.1, zorder=2)

    # Highlighted layer: tracked symptoms of the current split.
    points = total_out_embedding[rows]
    for symptom, symptom_rows in res['target_symptom'].items():
        picked = points[list(set(symptom_rows))]
        ax.scatter(picked[:, 0], picked[:, 1], label=symptom,
                   alpha=0.8, edgecolors='k', zorder=2)

    ax.set_title(f"t-SNE embedding of {res['name']}")
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

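## Extra Symptoms

Embed the annotated recordings that do not appear in the train/validation/test loaders. Then show where each symptom combination falls on a joint t-SNE map together with the labelled splits.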

In [None]:
# Serial numbers of every recording produced by the train/validation/test loaders; the remaining
# annotated recordings are embedded separately below.
# NOTE(review): if any of these loaders drops its last incomplete batch, those recordings are
# missing here and will be treated as "extra" below — confirm the loader settings.
serial_set = {
    serial
    for split_loader in (train_loader, val_loader, test_loader)
    for batch in split_loader
    for serial in batch['serial']
}

In [None]:
import json
from torch.utils.data import DataLoader
from datasets.caueeg_dataset import CauEegDataset
from datasets.pipeline import eeg_collate_fn

# Resolve the dataset directory.
# NOTE(review): 'cwd' is popped from `config` in the configuration cell, so unless it is re-added
# this always falls back to config['dataset_path'] as-is (resolved against the current directory).
dataset_path = os.path.join(config['cwd'], config['dataset_path']) if 'cwd' in config.keys() else config['dataset_path']

# JSON is UTF-8 by specification; pass the encoding explicitly so the file is not decoded with the
# platform's locale encoding (e.g. a Korean-locale Windows codepage).
with open(os.path.join(dataset_path, 'annotation.json'), 'r', encoding='utf-8') as json_file:
    annotation_ = json.load(json_file)

# Keep the full annotation intact and build a filtered copy holding only the recordings that none
# of the train/validation/test loaders produced.
annotation = deepcopy(annotation_)
annotation['data'] = [data for data in annotation['data'] if data['serial'] not in serial_set]

print(len(annotation_['data']))
print(len(serial_set))
print(len(annotation['data']))

In [None]:
# Dataset over the recordings not covered by the train/validation/test loaders, using the loading
# options and transform stored in the config.
# NOTE(review): confirm config['transform'] is the intended (evaluation-time) transform here.
extra_eeg_dataset = CauEegDataset(dataset_path, annotation['data'],
                                  load_event=config['load_event'],
                                  file_format=config['file_format'],
                                  transform=config['transform'])

print(len(extra_eeg_dataset))

# Deterministic order (no shuffling, keep the last partial batch) so the row indices collected in
# `extra_symptom` line up with the embedding rows.
extra_loader = DataLoader(extra_eeg_dataset,
                          batch_size=4,
                          shuffle=False,
                          drop_last=False,
                          num_workers=0,
                          # Pinned host memory only helps host-to-GPU copies; PyTorch warns when it
                          # is requested without an accelerator.
                          pin_memory=(device.type == 'cuda'),
                          collate_fn=eeg_collate_fn)

In [None]:
# Embed the extra recordings and group their row indices by the full symptom list, joined into a
# single ', '-separated key.
extra_symptom = {}
extra_result = {}

crop_multiple = config['crop_multiple']
embedding_chunks = []
# Rows collected so far == row index (in the final array) of this batch's first recording.
offset = 0

for sample_batched in extra_loader:
    # compute_embedding() reads 'class_label'; these recordings are only grouped by symptom, so
    # feed a dummy label per entry and discard the returned labels.
    sample_batched['class_label'] = torch.zeros((len(sample_batched['symptom'])))

    # estimate
    e, _ = compute_embedding(model, sample_batched, config['preprocess_test'], crop_multiple,
                             target_from_last=target_from_last)
    embedding_chunks.append(e.detach().cpu().numpy())

    # Each recording occupies `crop_multiple` consecutive entries; inspect the first of each group.
    for s in range(0, len(sample_batched['symptom']), crop_multiple):
        symp = ', '.join(sample_batched['symptom'][s])
        extra_symptom.setdefault(symp, []).append(offset + s // crop_multiple)
    offset += e.shape[0]

# Fail loudly instead of silently reusing an `embedding` left over from an earlier cell.
if not embedding_chunks:
    raise RuntimeError('extra_loader yielded no batches; there are no extra recordings to embed.')

extra_result['embedding'] = np.concatenate(embedding_chunks, axis=0)

In [None]:
import inspect

# t-SNE for the extra recordings only (larger perplexity than the per-split fits). This estimator
# is also reused for the joint fit further below. scikit-learn 1.5 renamed `n_iter` to `max_iter`
# and 1.7 removed `n_iter`, so pass the iteration budget under whichever keyword is accepted here.
_tsne_iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
tsne_transform = TSNE(n_components=2, init="pca", learning_rate="auto", perplexity=70.0,
                      n_iter_without_progress=500, n_jobs=2, random_state=0,
                      **{_tsne_iter_kw: 5000})

extra_result['tsne_embedding'] = tsne_transform.fit_transform(extra_result['embedding'])
print(extra_result['tsne_embedding'].shape)

In [None]:
# t-SNE map of the extra recordings only, one colour per symptom combination.
_, ax = plt.subplots()
points = extra_result['tsne_embedding']

for symptom, symptom_rows in extra_symptom.items():
    picked = points[list(set(symptom_rows))]
    ax.scatter(picked[:, 0], picked[:, 1], label=symptom, alpha=0.5, zorder=2)

ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

In [None]:
# One t-SNE over train, validation, test and the extra recordings, stacked in that order.
extra_total_out_embedding = tsne_transform.fit_transform(
    np.concatenate(
        [*(res['embedding'] for res in result), extra_result['embedding']],
        axis=0,
    )
)

In [None]:
# One figure per extra symptom combination. All train/validation/test points are drawn faintly
# (coloured by class) as context, and the extra recordings with that combination are drawn on top.
# Uses the colour list defined in this cell rather than a `color_map` left over from earlier cells.
color_map1 = ['tab:green', 'tab:orange', 'tab:red']

for i, k in enumerate(sorted(extra_symptom.keys())):
    v = extra_symptom[k]

    fig = plt.figure(num=1, clear=True, figsize=(6.0, 6.0))
    ax = fig.add_subplot(1, 1, 1)

    start_from_temp = 0
    for rr in range(len(result)):
        n_size_temp = result[rr]['tsne_embedding'].shape[0]
        split_points = extra_total_out_embedding[start_from_temp:start_from_temp + n_size_temp]
        for class_name, class_label in config['class_name_to_label'].items():
            in_class = split_points[result[rr]['target'] == class_label]
            ax.scatter(in_class[:, 0], in_class[:, 1],
                       color=color_map1[class_label],
                       label=class_name if rr == 0 else None,
                       alpha=0.1,
                       zorder=2)
        start_from_temp += n_size_temp

    # Rows after the labelled splits belong to the extra recordings, in extra_loader order.
    extra_points = extra_total_out_embedding[start_from_temp:][list(set(v))]
    ax.scatter(extra_points[:, 0], extra_points[:, 1],
               label=k,
               # tab10 has 10 colours and clamps larger integers to the last one; wrap instead.
               color=plt.cm.tab10(i % 10),
               edgecolors='k',
               alpha=0.8,
               s=50,
               zorder=2)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.show()
    plt.close(fig)  # release the figure so the loop does not accumulate memory