# t-SNE visualization

This notebook uses t-SNE to visualize the EEG feature embeddings computed by a trained model.

-----

## Load Packages

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# Move one directory up (presumably the repository root) so that the project packages imported below
# (datasets, models, train) and relative paths such as './local/...' resolve -- TODO confirm layout.
# NOTE: `%cd ..` is relative, so re-running this cell moves up yet another directory.
%cd ..
%load_ext autoreload
%autoreload 2

In [None]:
# Load some packages
import os
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.preprocessing import MinMaxScaler
from sklearn.manifold import TSNE

import pprint
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import offsetbox

# custom package
from datasets.caueeg_script import build_dataset_for_train
import models
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion, draw_confusion2
from train.visualize import draw_class_wise_metrics
from train.visualize import draw_error_table
from train.visualize import annotate_heatmap

In [None]:
print('PyTorch version:', torch.__version__)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')
print('cuda is available.' if cuda_ok else 'cuda is unavailable.')

In [None]:
# Other settings
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text

plt.style.use('default') 
# ['Solarize_Light2', '_classic_test_patch', 'bmh', 'classic', 'dark_background', 'fast', 
#  'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn', 'seaborn-bright', 'seaborn-colorblind', 
#  'seaborn-dark', 'seaborn-dark-palette', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 
#  'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 
#  'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'tableau-colorblind10']

plt.rcParams['image.interpolation'] = 'bicubic'
# NOTE: if Helvetica is not installed, matplotlib warns and falls back to another font.
# The plotting cells further down set 'font.family' to 'Arial' anyway.
plt.rcParams["font.family"] = 'Helvetica' # 'NanumGothic' # for Hangul in Windows

-----

## Load the configuration used during the train phase

In [None]:
# Checkpoint to visualize; the name is the checkpoint's sub-directory.
# model_name = 'lo88puq7'
model_name = '1nu3jagp'  # no mixup version

# The checkpoint root used to be hard-coded to a local Windows drive. It can now be overridden with the
# CAUEEG_CHECKPOINT_DIR environment variable; the original location remains the default.
checkpoint_dir = os.environ.get('CAUEEG_CHECKPOINT_DIR', r'E:\CAUEEG\checkpoint')
model_path = os.path.join(checkpoint_dir, model_name, 'checkpoint.pt')

# NOTE(security): torch.load unpickles arbitrary Python objects (the stored config holds callables such as
# 'preprocess_test'), so only load checkpoints from a trusted source.
ckpt = torch.load(model_path, map_location=device)
print(ckpt.keys())

In [None]:
# Split the checkpoint dictionary into the parts stored during training.
model_state, config, optimizer, scheduler = (
    ckpt[key] for key in ('model_state', 'config', 'optimizer_state', 'scheduler_state')
)

In [None]:
# Show the full training configuration; the wide line width keeps nested entries on one line.
print(pprint.pformat(config, width=250))

-----

## Load the target model

In [None]:
# Rebuild the architecture from the stored config and load the trained weights.
# model = config['generator'](**config).to(device)
model = hydra.utils.instantiate(config).to(device)

if config.get('ddp', False):
    # Weights saved from a DataParallel/DistributedDataParallel wrapper carry a 'module.' key prefix.
    # Strip the prefix only where it is present. The old unconditional `k[7:]` also truncated keys
    # without the prefix, and it corrupted every key when this cell was re-run.
    ddp_prefix = 'module.'
    model_state = OrderedDict(
        (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k, v)
        for k, v in model_state.items()
    )

model.load_state_dict(model_state)

-----

## Evaluate the model and analyze its performance by crop timing

### Configurations

In [None]:
# Reuse the training configuration, switched to single-process evaluation with 32 crops per recording.
config = ckpt['config']
config.pop('cwd', 0)
config.update({
    'ddp': False,
    'crop_timing_analysis': True,
    'eval': True,
    'crop_multiple': 32,
    'device': device,
})

# Forwarded to model.compute_feature_embedding(); presumably selects which layer (counted from the
# output) the features are read from -- TODO confirm.
target_from_last = 2

### Build Dataset

In [None]:
# Checkpoints whose dataset_path mentions '220419' are redirected to the local copy of the dataset.
dataset_path = config['dataset_path']
if '220419' in dataset_path:
    config['dataset_path'] = './local/dataset/caueeg-dataset/'

train_loader, val_loader, test_loader, multicrop_test_loader = build_dataset_for_train(config, verbose=True)

## Test accuracy

In [None]:
# Evaluate on the test split (repeat=50 is passed through to check_accuracy_extended_debug).
test_results = check_accuracy_extended_debug(model, test_loader,
                                             config['preprocess_test'], config, repeat=50)
# Unpack by position instead of through `_`. IPython reserves `_` for the last cell output and
# stops updating it once user code assigns to it.
(test_acc, test_score, test_target,
 test_confusion, test_error_table, test_crop_timing) = test_results[:6]

print(test_acc)

In [None]:
# Draw the test-set diagnostics inline (wandb logging switched off).
class_label_to_name = config['class_label_to_name']
draw_roc_curve(test_score, test_target, class_label_to_name, use_wandb=False)
draw_confusion(test_confusion, class_label_to_name, use_wandb=False)
draw_class_wise_metrics(test_confusion, class_label_to_name, use_wandb=False)
draw_error_table(test_error_table, use_wandb=False)

## t-SNE embedding

In [None]:
@torch.no_grad()
def compute_embedding(model, sample_batched, preprocess, crop_multiple, target_from_last):
    """Compute feature embeddings and labels for one minibatch, averaging over crops.

    Args:
        model: network exposing ``compute_feature_embedding(x, age, target_from_last=...)``.
            It is switched to evaluation mode.
        sample_batched: dict with at least 'signal', 'age' and 'class_label'. It is modified in
            place by ``preprocess``.
        preprocess: callable applied to ``sample_batched`` (this includes the to-device step).
        crop_multiple: number of consecutive rows that are crops of the same recording. When > 1,
            each group of ``crop_multiple`` embeddings is averaged into one row, and the label of
            the group's first crop is kept.
        target_from_last: forwarded to ``model.compute_feature_embedding``.

    Returns:
        (e, y). When ``crop_multiple > 1``, ``e`` has shape (batch // crop_multiple, D) on the CPU
        with the default float dtype, and ``y`` is a CPU int32 tensor with one label per group.
        Otherwise both are returned exactly as produced by the model / the batch.

    Raises:
        ValueError: if the batch size is not a multiple of ``crop_multiple``.
    """
    # evaluation mode
    model.eval()

    # preprocessing (this includes to-device operation)
    preprocess(sample_batched)

    # apply model on whole batch directly on device
    x = sample_batched['signal']
    age = sample_batched['age']
    e = model.compute_feature_embedding(x, age, target_from_last=target_from_last)
    y = sample_batched['class_label']

    if crop_multiple > 1:
        # multi-crop averaging
        if e.size(0) % crop_multiple != 0:
            raise ValueError(f"compute_embedding(): Minibatch size={e.size(0)} is not multiple of "
                             f"crop_multiple={crop_multiple}.")

        real_minibatch = e.size(0) // crop_multiple
        # Rows [m*crop_multiple, (m+1)*crop_multiple) belong to recording m. Average them in a single
        # reduction instead of a Python loop with one device-to-host copy per recording.
        e = (e.reshape(real_minibatch, crop_multiple, e.size(1))
             .mean(dim=1)
             .to(device='cpu', dtype=torch.get_default_dtype()))
        y = y[::crop_multiple].to(device='cpu', dtype=torch.int32)

    return e, y

In [None]:
def mixup_data(x, age, y, alpha=0, use_cuda=True):
    """Blend every sample with a randomly permuted partner from the same batch (mixup).

    Args:
        x: batched signal tensor, mixed along its first dimension.
        age: per-sample age tensor, mixed with the same partners and weight as ``x``.
        y: per-sample labels. They are not mixed, but are returned together with the partners' labels.
        alpha: Beta(alpha, alpha) parameter for the mixing weight ``lam``. Values <= 1e-12 disable
            mixing (``lam = 1``).
        use_cuda: kept for backward compatibility. The permutation is always created on
            ``x.device``, so CPU tensors no longer require CUDA and GPU tensors still get a GPU index.

    Returns:
        (mixed_x, mixed_age, y_a, y_b, lam, index), where ``y_a`` is ``y``, ``y_b`` is ``y[index]``,
        and ``index`` is the permutation that picked the partners.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 1e-12 else 1
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    mixed_age = lam * age + (1 - lam) * age[index]
    y_a, y_b = y, y[index]
    # Return the permutation actually used (the original returned the undefined name `index1`).
    return mixed_x, mixed_age, y_a, y_b, lam, index

def mixup_data_lam(x, age, y, lam=0.5, use_cuda=True):
    """Blend every sample with a randomly permuted partner using a fixed mixing weight ``lam``.

    Args:
        x: batched signal tensor, mixed along its first dimension.
        age: per-sample age tensor, mixed with the same partners and weight as ``x``.
        y: per-sample labels. They are not mixed, but are returned together with the partners' labels.
        lam: weight of the original sample (the partner gets ``1 - lam``).
        use_cuda: kept for backward compatibility. The permutation is always created on
            ``x.device``. The old default (``True``) called ``.cuda()`` unconditionally and
            failed on CPU-only machines.

    Returns:
        (mixed_x, mixed_age, y_a, y_b, lam, index), where ``y_a`` is ``y``, ``y_b`` is ``y[index]``,
        and ``index`` is the permutation that picked the partners.
    """
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    mixed_age = lam * age + (1 - lam) * age[index]
    y_a, y_b = y, y[index]
    return mixed_x, mixed_age, y_a, y_b, lam, index


@torch.no_grad()
def compute_mixup_embedding(model, sample_batched, preprocess, crop_multiple, mixup_alpha, target_from_last):
    """Compute crop-averaged embeddings of mixup-blended inputs for one minibatch.

    The preprocessed batch is blended with a random permutation of itself using a fixed weight
    of 0.5 (``mixup_data_lam``) and then embedded with ``model.compute_feature_embedding``.
    The permutation is drawn over the whole, crop-expanded batch. Presumably, then, the crops of one
    recording can be blended with crops of different recordings -- TODO confirm this is intended.

    Args:
        model: network exposing ``compute_feature_embedding(x, age, target_from_last=...)``.
            It is switched to evaluation mode.
        sample_batched: dict with at least 'signal', 'age' and 'class_label'. It is modified in
            place by ``preprocess``.
        preprocess: callable applied to ``sample_batched`` (this includes the to-device step).
        crop_multiple: number of consecutive rows per recording. When > 1, each group is averaged
            into one row.
        mixup_alpha: currently unused. The Beta-sampled variant (``mixup_data``) is disabled in
            favour of the fixed ``lam=0.5``; the parameter is kept so existing calls still work.
        target_from_last: forwarded to ``model.compute_feature_embedding``.

    Returns:
        Embedding tensor. When ``crop_multiple > 1`` it has shape (batch // crop_multiple, D) on the
        CPU with the default float dtype.

    Raises:
        ValueError: if the batch size is not a multiple of ``crop_multiple``.
    """
    # evaluation mode
    model.eval()

    # preprocessing (this includes to-device operation)
    preprocess(sample_batched)

    x = sample_batched['signal']
    age = sample_batched['age']
    y = sample_batched['class_label']

    # x, age, y1, y2, lam, mixup_index = mixup_data(x, age, y, mixup_alpha)
    # Build the permutation on the signal's device so this also runs without CUDA.
    x, age, *_ = mixup_data_lam(x, age, y, lam=0.5, use_cuda=x.is_cuda)
    e = model.compute_feature_embedding(x, age, target_from_last=target_from_last)

    if crop_multiple > 1:
        # multi-crop averaging
        if e.size(0) % crop_multiple != 0:
            raise ValueError(f"compute_mixup_embedding(): Minibatch size={e.size(0)} is not multiple of "
                             f"crop_multiple={crop_multiple}.")

        real_minibatch = e.size(0) // crop_multiple
        # Average each group of crop_multiple consecutive rows in a single reduction.
        e = (e.reshape(real_minibatch, crop_multiple, e.size(1))
             .mean(dim=1)
             .to(device='cpu', dtype=torch.get_default_dtype()))

    return e

In [None]:
# Clean (non-mixup) embeddings and labels of every split.
result = [{'name': 'Train Dataset', 'loader': train_loader},
          {'name': 'Validation Dataset', 'loader': val_loader},
          {'name': 'Test Dataset', 'loader': test_loader}]

crop_multiple = config['crop_multiple']

for entry in result:
    # Collect the per-batch arrays and concatenate once. Concatenating inside the loop is quadratic,
    # and the old `if i == 0` initialisation silently reused the previous split's arrays when a
    # loader yielded no batches.
    embedding_batches, target_batches = [], []
    for sample_batched in entry['loader']:
        e, y = compute_embedding(model, sample_batched, config['preprocess_test'], crop_multiple,
                                 target_from_last=target_from_last)
        embedding_batches.append(e.detach().cpu().numpy())
        target_batches.append(y.detach().cpu().numpy())

    if not embedding_batches:
        raise ValueError(f"{entry['name']}: the loader produced no batches.")
    entry['embedding'] = np.concatenate(embedding_batches, axis=0)
    entry['target'] = np.concatenate(target_batches, axis=0)

In [None]:
# Embeddings of mixup-blended inputs for every split (no labels: each row blends several recordings).
mixup_result = [{'name': 'Train Dataset', 'loader': train_loader},
                {'name': 'Validation Dataset', 'loader': val_loader},
                {'name': 'Test Dataset', 'loader': test_loader}]

crop_multiple = config['crop_multiple']

for entry in mixup_result:
    # Collect the per-batch arrays and concatenate once. Concatenating inside the loop is quadratic,
    # and an empty loader used to silently keep the previous split's array.
    embedding_batches = []
    for sample_batched in entry['loader']:
        e = compute_mixup_embedding(model, sample_batched, config['preprocess_test'], crop_multiple,
                                    mixup_alpha=config['mixup'], target_from_last=target_from_last)
        embedding_batches.append(e.detach().cpu().numpy())

    if not embedding_batches:
        raise ValueError(f"{entry['name']}: the loader produced no batches.")
    entry['embedding'] = np.concatenate(embedding_batches, axis=0)

In [None]:
# NOTE: scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` (the old name is deprecated); adjust when upgrading.
tsne_transform = TSNE(n_components=2, init="pca", learning_rate="auto", perplexity=30.0,
                      n_iter=5000, n_iter_without_progress=500, n_jobs=2, random_state=0,)

# Embed each split jointly with its mixup counterpart, then cut the joint output back into the two parts.
for clean, mixed in zip(result, mixup_result):
    n_clean = clean['embedding'].shape[0]
    joint = tsne_transform.fit_transform(np.concatenate([clean['embedding'], mixed['embedding']]))
    clean['tsne_embedding'], mixed['tsne_embedding'] = joint[:n_clean], joint[n_clean:]

In [None]:
plt.style.use('default')
plt.style.use('fivethirtyeight')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11, 'font.family': 'Arial'})
# plt.rcParams["savefig.dpi"] = 1200
color_map = ['tab:green', 'tab:orange', 'tab:red']

for clean, mixed in zip(result, mixup_result):
    fig, ax = plt.subplots()

    # clean embeddings, coloured by class
    for class_name, class_label in config['class_name_to_label'].items():
        points = clean['tsne_embedding'][clean['target'] == class_label]
        ax.scatter(points[:, 0], points[:, 1], label=class_name, color=color_map[class_label],
                   alpha=0.8, edgecolors='k', zorder=2)

    # mixup embeddings share a single grey colour
    mixed_points = mixed['tsne_embedding']
    ax.scatter(mixed_points[:, 0], mixed_points[:, 1], label='mixup', color='gray',
               alpha=0.5, edgecolors='k', zorder=2)

    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    ax.set_title(f"t-SNE embedding of {clean['name']}")

In [None]:
# Joint t-SNE of all splits; rows follow the order of `result` so the splits share one coordinate system.
total_out_embedding = tsne_transform.fit_transform(np.vstack([entry['embedding'] for entry in result]))

In [None]:
plt.style.use('default')
plt.style.use('fivethirtyeight')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11, 'font.family': 'Arial'})
# plt.rcParams["savefig.dpi"] = 1200
color_map = ['tab:green', 'tab:orange', 'tab:red']

# Row boundaries of each split inside total_out_embedding (splits were stacked in `result` order).
split_sizes = [entry['tsne_embedding'].shape[0] for entry in result]
split_bounds = np.cumsum([0] + split_sizes)

for idx, entry in enumerate(result):
    fig, ax = plt.subplots()

    # faint background: every split, coloured by class
    for other, lo, hi in zip(result, split_bounds[:-1], split_bounds[1:]):
        segment = total_out_embedding[lo:hi]
        for class_label in config['class_name_to_label'].values():
            points = segment[other['target'] == class_label]
            ax.scatter(points[:, 0], points[:, 1], color=color_map[class_label], alpha=0.1, zorder=2)

    # highlighted foreground: the current split only
    segment = total_out_embedding[split_bounds[idx]:split_bounds[idx + 1]]
    for class_name, class_label in config['class_name_to_label'].items():
        points = segment[entry['target'] == class_label]
        ax.scatter(points[:, 0], points[:, 1], label=class_name, color=color_map[class_label],
                   alpha=0.8, edgecolors='k', zorder=2)

    ax.set_title(f"t-SNE embedding of {entry['name']}")
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

In [None]:
plt.style.use('default')
plt.style.use('fivethirtyeight')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11})
plt.rcParams.update({'font.family': 'Arial'})
# plt.rcParams["savefig.dpi"] = 1200
color_map = ['tab:green', 'tab:orange', 'tab:red']

for entry in result:
    # 'target_symptom' maps a group name to row indices (presumably symptom subgroups). No cell stores it
    # in `result`, so on a fresh kernel this used to raise KeyError. Skip splits without groups instead.
    target_symptom = entry.get('target_symptom')
    if not target_symptom:
        print(f"No 'target_symptom' groups for {entry['name']}; skipping.")
        continue

    fig, ax = plt.subplots()
    for k, v in target_symptom.items():
        points = entry['tsne_embedding'][[*set(v)]]
        ax.scatter(
            points[:, 0],
            points[:, 1],
            label=k,
            alpha=0.8,
            edgecolors='k',
            zorder=2)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    ax.set_title(f"t-SNE embedding {[*target_symptom.keys()]} in {entry['name']}")

In [None]:
plt.style.use('default')
plt.style.use('fivethirtyeight')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11})
plt.rcParams.update({'font.family': 'Arial'})
# plt.rcParams["savefig.dpi"] = 1200
color_map = ['tab:green', 'tab:orange', 'tab:red']
start_from = 0

for r in range(len(result)):
    fig, ax = plt.subplots()
    n_size = result[r]['tsne_embedding'].shape[0]

    # faint background: every split coloured by class (rows of total_out_embedding follow `result` order)
    start_from_temp = 0
    for rr in range(len(result)):
        n_size_temp = result[rr]['tsne_embedding'].shape[0]
        segment = total_out_embedding[start_from_temp:start_from_temp + n_size_temp]
        for class_name, class_label in config['class_name_to_label'].items():
            points = segment[result[rr]['target'] == class_label]
            ax.scatter(
                points[:, 0],
                points[:, 1],
                color=color_map[class_label],
                label=class_name if rr == 0 else None,
                alpha=0.1,
                zorder=2)
        start_from_temp += n_size_temp

    # Highlight the current split's groups. No cell stores 'target_symptom' in `result`, so fall back
    # to "no groups" (background only) instead of raising KeyError on a fresh kernel.
    segment = total_out_embedding[start_from:start_from + n_size]
    for k, v in result[r].get('target_symptom', {}).items():
        points = segment[[*set(v)]]
        ax.scatter(
            points[:, 0],
            points[:, 1],
            label=k,
            alpha=0.8,
            edgecolors='k',
            zorder=2)
    start_from += n_size

    ax.set_title(f"t-SNE embedding of {result[r]['name']}")
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

### 3D

In [None]:
# NOTE: scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` (the old name is deprecated); adjust when upgrading.
tsne_transform = TSNE(n_components=3, init="pca", learning_rate="auto", perplexity=200.0,
                      n_iter=50000, n_iter_without_progress=5000, n_jobs=4, random_state=0,)

# 3D version: embed each split jointly with its mixup counterpart, then separate the two parts again.
for clean, mixed in zip(result, mixup_result):
    n_clean = len(clean['embedding'])
    joint = tsne_transform.fit_transform(np.concatenate([clean['embedding'], mixed['embedding']]))
    clean['tsne_embedding'] = joint[:n_clean]
    mixed['tsne_embedding'] = joint[n_clean:]

In [None]:
plt.style.use('default')
plt.style.use('fivethirtyeight')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11, 'font.family': 'Arial'})
# plt.rcParams["savefig.dpi"] = 1200
color_map = ['tab:green', 'tab:orange', 'tab:red']

for clean, mixed in zip(result, mixup_result):
    # Reuse (and clear) figure 1 for every split.
    fig = plt.figure(num=1, clear=True, figsize=(12.0, 12.0))
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    # clean embeddings, coloured by class
    for class_name, class_label in config['class_name_to_label'].items():
        points = clean['tsne_embedding'][clean['target'] == class_label]
        ax.scatter(xs=points[:, 0], ys=points[:, 1], zs=points[:, 2], label=class_name,
                   color=color_map[class_label], alpha=0.8, s=40, edgecolors='k')

    # mixup embeddings in grey
    mixed_points = mixed['tsne_embedding']
    ax.scatter(xs=mixed_points[:, 0], ys=mixed_points[:, 1], zs=mixed_points[:, 2], label='mixup',
               color='gray', alpha=0.5, s=40, edgecolors='k')

    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    ax.set_title(f"t-SNE embedding of {clean['name']}")
    plt.show()

In [None]:
# Joint t-SNE (with the current `tsne_transform`) of all splits, stacked in `result` order.
total_out_embedding = tsne_transform.fit_transform(np.vstack([entry['embedding'] for entry in result]))

In [None]:
plt.style.use('default')
plt.style.use('fivethirtyeight')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11, 'font.family': 'Arial'})
# plt.rcParams["savefig.dpi"] = 1200
color_map = ['tab:green', 'tab:orange', 'tab:red']

# Row boundaries of each split inside total_out_embedding (splits were stacked in `result` order).
split_sizes = [entry['tsne_embedding'].shape[0] for entry in result]
split_bounds = np.cumsum([0] + split_sizes)

for idx, entry in enumerate(result):
    fig = plt.figure(num=1, clear=True, figsize=(12.0, 12.0))
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    # faint background: every split, coloured by class
    for other, lo, hi in zip(result, split_bounds[:-1], split_bounds[1:]):
        segment = total_out_embedding[lo:hi]
        for class_label in config['class_name_to_label'].values():
            points = segment[other['target'] == class_label]
            ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                       color=color_map[class_label], alpha=0.1, s=40, zorder=2)

    # highlighted foreground: the current split only
    segment = total_out_embedding[split_bounds[idx]:split_bounds[idx + 1]]
    for class_name, class_label in config['class_name_to_label'].items():
        points = segment[entry['target'] == class_label]
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], label=class_name,
                   color=color_map[class_label], alpha=0.8, s=40, edgecolors='k', zorder=2)

    ax.set_title(f"t-SNE embedding of {entry['name']}")
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.show()

In [None]:
plt.style.use('default')
plt.style.use('fivethirtyeight')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11})
plt.rcParams.update({'font.family': 'Arial'})
# plt.rcParams["savefig.dpi"] = 1200
color_map = ['tab:green', 'tab:orange', 'tab:red']

for entry in result:
    # 'target_symptom' maps a group name to row indices (presumably symptom subgroups). No cell stores it
    # in `result`, so on a fresh kernel this used to raise KeyError. Skip splits without groups instead.
    target_symptom = entry.get('target_symptom')
    if not target_symptom:
        print(f"No 'target_symptom' groups for {entry['name']}; skipping.")
        continue

    fig = plt.figure(num=1, clear=True, figsize=(12.0, 12.0))
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    for k, v in target_symptom.items():
        points = entry['tsne_embedding'][[*set(v)]]
        ax.scatter(
            points[:, 0],
            points[:, 1],
            points[:, 2],
            label=k,
            s=40,
            alpha=0.8,
            edgecolors='k',
            zorder=2)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    ax.set_title(f"t-SNE embedding {[*target_symptom.keys()]} in {entry['name']}")
    plt.show()

In [None]:
plt.style.use('default')
plt.style.use('fivethirtyeight')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11})
plt.rcParams.update({'font.family': 'Arial'})
# plt.rcParams["savefig.dpi"] = 1200
color_map = ['tab:green', 'tab:orange', 'tab:red']
start_from = 0

for r in range(len(result)):
    fig = plt.figure(num=1, clear=True, figsize=(12.0, 12.0))
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    n_size = result[r]['tsne_embedding'].shape[0]

    # faint background: every split coloured by class (rows of total_out_embedding follow `result` order)
    start_from_temp = 0
    for rr in range(len(result)):
        n_size_temp = result[rr]['tsne_embedding'].shape[0]
        segment = total_out_embedding[start_from_temp:start_from_temp + n_size_temp]
        for class_name, class_label in config['class_name_to_label'].items():
            points = segment[result[rr]['target'] == class_label]
            ax.scatter(
                points[:, 0],
                points[:, 1],
                points[:, 2],
                color=color_map[class_label],
                label=class_name if rr == 0 else None,
                s=40,
                alpha=0.1,
                zorder=2)
        start_from_temp += n_size_temp

    # Highlight the current split's groups. No cell stores 'target_symptom' in `result`, so fall back
    # to "no groups" (background only) instead of raising KeyError on a fresh kernel.
    segment = total_out_embedding[start_from:start_from + n_size]
    for k, v in result[r].get('target_symptom', {}).items():
        points = segment[[*set(v)]]
        ax.scatter(
            points[:, 0],
            points[:, 1],
            points[:, 2],
            label=k,
            s=40,
            alpha=0.8,
            edgecolors='k',
            zorder=2)
    start_from += n_size

    ax.set_title(f"t-SNE embedding of {result[r]['name']}")
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.show()

## Same data under different augmentations

In [None]:
# Embed every split several times. Presumably each pass draws different random crops, so the same
# recording gets several embeddings -- TODO confirm the loaders' crop randomness.
NUM_AUGMENTATION_RUNS = 4
crop_multiple = config['crop_multiple']

multi_time_result = []
for t in range(NUM_AUGMENTATION_RUNS):
    time_result = [{'name': 'Train Dataset', 'loader': train_loader},
                   {'name': 'Validation Dataset', 'loader': val_loader},
                   {'name': 'Test Dataset', 'loader': test_loader}]

    for entry in time_result:
        # Collect the per-batch arrays and concatenate once. Concatenating inside the loop is quadratic,
        # and an empty loader used to silently keep the previous split's arrays.
        embedding_batches, target_batches = [], []
        for sample_batched in entry['loader']:
            e, y = compute_embedding(model, sample_batched, config['preprocess_test'], crop_multiple,
                                     target_from_last=target_from_last)
            embedding_batches.append(e.detach().cpu().numpy())
            target_batches.append(y.detach().cpu().numpy())

        if not embedding_batches:
            raise ValueError(f"{entry['name']}: the loader produced no batches.")
        entry['embedding'] = np.concatenate(embedding_batches, axis=0)
        entry['target'] = np.concatenate(target_batches, axis=0)
        # Empty group placeholders (group name -> row indices); nothing in this cell fills them.
        entry['target_symptom'] = {'mci_amnestic_ef': [], 'mci_amnestic_rf': []}

    multi_time_result.append(time_result)

In [None]:
# NOTE: scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` (the old name is deprecated); adjust when upgrading.
tsne_transform = TSNE(n_components=2, init="pca", learning_rate="auto", perplexity=80.0,
                      n_iter=15000, n_iter_without_progress=1000, n_jobs=2, random_state=0,)

# Stack every (run, split) embedding in run-major order and embed all of them jointly.
stacked_embeddings = np.concatenate([split['embedding']
                                     for run in multi_time_result
                                     for split in run])
multi_time_total_out_embedding = tsne_transform.fit_transform(stacked_embeddings)
print(multi_time_total_out_embedding.shape)

In [None]:
plt.style.use('default')
plt.style.use('fivethirtyeight')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 11})
plt.rcParams.update({'font.family': 'Arial'})
# plt.rcParams["savefig.dpi"] = 1200

color_map = ['tab:green', 'tab:orange', 'tab:red']

# Position, within each class of each split, of the sample whose augmented copies are highlighted.
# This assumes every augmentation run visits the samples in the same order -- TODO confirm the
# loaders do not shuffle; otherwise the highlighted points belong to different recordings.
HIGHLIGHT_INDEX = 3

fig, ax = plt.subplots()

# Faint background: all samples of the first run, whose rows come first in the joint t-SNE output.
# (Previously this read `time_result` leaked from an earlier loop, i.e. the LAST run's labels,
# while slicing the first run's rows.)
start_from = 0
for split in multi_time_result[0]:
    n_size = split['embedding'].shape[0]
    segment = multi_time_total_out_embedding[start_from:start_from + n_size]
    for class_label in config['class_name_to_label'].values():
        points = segment[split['target'] == class_label]
        ax.scatter(
            points[:, 0],
            points[:, 1],
            color=color_map[class_label],
            alpha=0.05,
            edgecolors='k',
            zorder=2)
    start_from += n_size

# Highlight: the HIGHLIGHT_INDEX-th sample of every class/split in every augmentation run.
start_from = 0
for time_result in multi_time_result:
    for split in time_result:
        n_size = split['embedding'].shape[0]
        segment = multi_time_total_out_embedding[start_from:start_from + n_size]
        for class_label in config['class_name_to_label'].values():
            points = segment[split['target'] == class_label]
            if points.shape[0] <= HIGHLIGHT_INDEX:
                continue  # this class has too few samples in this split (previously an IndexError)
            ax.scatter(
                points[HIGHLIGHT_INDEX, 0],
                points[HIGHLIGHT_INDEX, 1],
                color=color_map[class_label],
                alpha=0.5,
                edgecolors='k',
                zorder=2)
        start_from += n_size

ax.set_title("t-SNE embedding by different augmentation")