# t-SNE visualization

This notebook uses t-SNE to visualize the EEG embeddings computed by a trained model, both for the original training recordings and for mixup combinations of them.

-----

## Load Packages

In [1]:
# Move to the project root (the parent of this notebook's folder) so that the
# project packages imported below (`datasets`, `models`, `train`) resolve.
# NOTE: `%cd ..` is not idempotent -- re-running this cell climbs one more level.
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%cd ..
%load_ext autoreload
# Mode 2: reload all modules before executing each cell, so edits to the project's
# .py files are picked up without restarting the kernel.
%autoreload 2

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.preprocessing import MinMaxScaler
from sklearn.manifold import TSNE

import pprint
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import offsetbox

# custom package
from datasets.caueeg_script import build_dataset_for_train
import models
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion, draw_confusion2
from train.visualize import draw_class_wise_metrics
from train.visualize import draw_error_table
from train.visualize import annotate_heatmap

In [3]:
print('PyTorch version:', torch.__version__)

# Prefer the GPU when one is visible; later cells move the model and checkpoint to `device`.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print('cuda is available.' if has_cuda else 'cuda is unavailable.')

PyTorch version: 2.0.0+cu117
cuda is available.


In [4]:
# Other settings
# Render figures inline in the notebook, at high (retina) resolution.
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text

# Reset to matplotlib's default style; the list below records the available alternatives.
plt.style.use('default') 
# ['Solarize_Light2', '_classic_test_patch', 'bmh', 'classic', 'dark_background', 'fast', 
#  'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn', 'seaborn-bright', 'seaborn-colorblind', 
#  'seaborn-dark', 'seaborn-dark-palette', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 
#  'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 
#  'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'tableau-colorblind10']

plt.rcParams['image.interpolation'] = 'bicubic'
# NOTE(review): if 'Helvetica' is not installed, matplotlib falls back to its default
# font with a warning. The 3D-plot cells later switch the family to 'Roboto Slab'.
plt.rcParams["font.family"] = 'Helvetica' # 'NanumGothic' # for Hangul in Windows

-----

## Load the configuration used during the training phase

In [5]:
# Training runs to choose from (identified by run id). The trailing notes record
# which augmentations were used -- presumably 'o' = used, 'x' = not used.
# VGG
# model_name = 'lo88puq7'  # mixup o, awgn o
# model_name = '1nu3jagp'  # mixup x, awgn x
# model_name = '1mwdhqbz'  # mixup x, awgn x, dropout x

# ResNet
# model_name = 'l8524nml'  # mixup o, awgn o
model_name = 'ph0mix3b'    # mixup x, awgn o, dropout x
# model_name = '2apj72km'  # mixup x, awgn o
# model_name = '2k8xomy6'  # mixup x, awgn x

# Root folder holding one sub-folder per run. Set the CAUEEG_CKPT_ROOT environment
# variable to point elsewhere instead of editing this machine-specific default.
checkpoint_root = os.environ.get('CAUEEG_CKPT_ROOT', r'E:\CAUEEG\checkpoint')
model_path = os.path.join(checkpoint_root, model_name, 'checkpoint.pt')

# Number of passes over the train loader used to sample mixup pairs (see the mixup cell).
mix_repeat = 3

save_fig = True
output_folder = './local/output/mixup_tsne'

# NOTE: torch.load unpickles arbitrary Python objects -- only load checkpoints you trust.
ckpt = torch.load(model_path, map_location=device)
print(ckpt.keys())

dict_keys(['model_state', 'config', 'optimizer_state', 'scheduler_state'])


In [6]:
# Split the checkpoint into its parts. `optimizer` / `scheduler` hold the saved
# state dicts (not live objects); no visible cell below uses them.
model_state, config, optimizer, scheduler = (
    ckpt[key] for key in ('model_state', 'config', 'optimizer_state', 'scheduler_state')
)

In [7]:
# Dump the training configuration; the wide line width reduces wrapping of long entries.
pprint.PrettyPrinter(width=250).pprint(config)

{'EKG': 'O',
 '_target_': 'models.resnet_1d.ResNet1D',
 'activation': 'gelu',
 'age_mean': tensor([71.2768], device='cuda:0'),
 'age_std': tensor([9.7251], device='cuda:0'),
 'awgn': 0,
 'awgn_age': 0,
 'base_channels': 64,
 'base_lr': 0.00033918432381593736,
 'block': 'basic',
 'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'conv_layers': [2, 2, 2, 2],
 'criterion': 'multi-bce',
 'crop_multiple': 4,
 'crop_timing_analysis': False,
 'cwd': 'C:\\Users\\Minjae\\Desktop\\EEG_Project',
 'dataset_name': 'CAUEEG dataset',
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'ddp': False,
 'device': device(type='cuda'),
 'draw_result': True,
 'dropout': 0,
 'fc_stages': 3,
 'file_format': 'memmap',
 'in_channels': 21,
 'input_norm': 'dataset',
 'iterations': 195312,
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'cosine_decay_with_warmup_half',
 'mgn': 0,
 'minibatch': 512,
 'mixup': 0,
 'model': '1D-

-----

## Load the target model

In [8]:
# Rebuild the network from the stored hydra config (its `_target_` names the class).
model = hydra.utils.instantiate(config).to(device)

if config.get('ddp', False):
    # (Distributed)DataParallel prefixes every parameter key with 'module.'. Strip it
    # only where present, so re-running this cell (or a key without the prefix)
    # cannot chop real characters off the key names. A new dict is built, so
    # nothing needs to be deep-copied.
    prefix = 'module.'
    model_state = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in model_state.items()
    )

model.load_state_dict(model_state)

<All keys matched successfully>

-----

## Evaluate the model and analyze its performance by crop timing

### Configurations

In [9]:
# Switch the training config to evaluation / crop-timing mode.
# NOTE: `config` is the same dict object as ckpt['config'], so these edits also
# change the in-memory checkpoint copy.
config = ckpt['config']
config.pop('cwd', None)
config.update({
    'ddp': False,
    'crop_timing_analysis': True,
    'eval': True,
    'crop_multiple': 32,
    'device': device,
})

# Forwarded as `target_from_last` to model.compute_feature_embedding in the embedding helpers.
target_from_last = 2

### Build Dataset

In [10]:
# Checkpoints trained on the older '02_Curated_Data_220419' folder are redirected
# to the current dataset location before the loaders are built.
if '220419' in config['dataset_path']:
    config['dataset_path'] = './local/dataset/caueeg-dataset/'

loaders = build_dataset_for_train(config, verbose=True)
train_loader, val_loader, test_loader, multicrop_test_loader = loaders

transform: Compose(
    EegRandomCrop(crop_length=2000, length_limit=10000000, multiple=32, latency=2000, segment_simulation=False, return_timing=True, reject_events=False)
    EegDropChannels(drop_index=[])
    EegToTensor()
)

----------------------------------------------------------------------------------------------------

transform_multicrop: Compose(
    EegRandomCrop(crop_length=2000, length_limit=10000000, multiple=8, latency=2000, segment_simulation=False, return_timing=True, reject_events=False)
    EegDropChannels(drop_index=[])
    EegToTensor()
)

----------------------------------------------------------------------------------------------------


task config:
{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-Dementia benchmark'}

 -------------------------------------------

## Test accuracy

In [11]:
# _ = check_accuracy_extended_debug(model, test_loader, 
#                                   config['preprocess_test'], config, repeat=50)
# test_acc = _[0]
# test_score = _[1]
# test_target = _[2]
# test_confusion = _[3]
# test_error_table = _[4]
# test_crop_timing = _[5]

# print(test_acc)

In [12]:
# draw_roc_curve(test_score, test_target, config['class_label_to_name'], use_wandb=False)
# draw_confusion(test_confusion, config['class_label_to_name'], use_wandb=False)
# draw_class_wise_metrics(test_confusion, config['class_label_to_name'], use_wandb=False)
# draw_error_table(test_error_table, use_wandb=False)

## t-SNE embedding

In [13]:
@torch.no_grad()
def compute_embedding(model, sample_batched, preprocess, crop_multiple, target_from_last):
    """Compute feature embeddings for one minibatch, averaging over grouped crops.

    Args:
        model: network exposing ``compute_feature_embedding(x, age, target_from_last=...)``;
            it is switched to eval mode.
        sample_batched: dict with 'signal', 'age' and 'class_label' entries. It is
            passed to ``preprocess``, which may modify it in place.
        preprocess: callable applied to ``sample_batched`` before inference
            (this includes the to-device operation).
        crop_multiple: number of consecutive batch rows to average into one embedding
            (presumably the crops of a single recording). Values <= 1 disable averaging.
        target_from_last: forwarded to ``model.compute_feature_embedding``.

    Returns:
        ``(e, y)``. When ``crop_multiple > 1``, ``e`` has ``batch // crop_multiple`` rows
        (on the CPU, in the default float dtype) and ``y`` holds the label of the first
        row of each group (on the CPU, int32). Otherwise both are returned as produced.

    Raises:
        ValueError: if the batch size is not a multiple of ``crop_multiple``.
    """
    model.eval()
    preprocess(sample_batched)

    x = sample_batched['signal']
    age = sample_batched['age']
    e = model.compute_feature_embedding(x, age, target_from_last=target_from_last)
    y = sample_batched['class_label']

    if crop_multiple > 1:
        if e.size(0) % crop_multiple != 0:
            raise ValueError(f"compute_embedding(): Real minibatch size={e.size(0)} is not multiple of "
                             f"crop_multiple={crop_multiple}.")

        real_minibatch = e.size(0) // crop_multiple
        # Rows [k*crop_multiple, (k+1)*crop_multiple) form group k. Average them in a
        # single vectorized reduction instead of a Python loop with one device->host
        # copy per row.
        e = e.reshape(real_minibatch, crop_multiple, *e.shape[1:]).mean(dim=1)
        e = e.to(device='cpu', dtype=torch.get_default_dtype())
        y = y[::crop_multiple].to(device='cpu', dtype=torch.int32)

    return e, y

In [14]:
def mixup_data(x, age, y, alpha=0, use_cuda=True):
    """Mix every sample with a randomly chosen partner from the same batch (mixup).

    The mixing weight ``lam`` is drawn from Beta(alpha, alpha). When ``alpha`` is
    (numerically) zero, mixup is disabled: ``lam`` is 1 and the inputs come back unmixed.

    Args:
        x: batch tensor; its first dimension indexes samples.
        age: per-sample tensor, indexed along its first dimension like ``x``.
        y: per-sample labels.
        alpha: Beta-distribution concentration parameter.
        use_cuda: kept for backward compatibility. The permutation is always created
            on ``x.device``, so CPU tensors and CPU-only machines work too.

    Returns:
        ``(mixed_x, mixed_age, y_a, y_b, lam, index)`` with ``y_a = y`` and
        ``y_b = y[index]``, where ``index`` is the partner permutation.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 1e-12 else 1
    index = torch.randperm(x.size(0), device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_age = lam * age + (1 - lam) * age[index.to(age.device)]
    y_a, y_b = y, y[index.to(y.device)]
    # Return the permutation itself (the original referenced an undefined `index1`).
    return mixed_x, mixed_age, y_a, y_b, lam, index

def mixup_data_lam(x, age, y, lam=0.5, use_cuda=True):
    """Mixup with a fixed weight: ``lam * sample + (1 - lam) * random partner``.

    Args:
        x: batch tensor; its first dimension indexes samples.
        age: per-sample tensor, indexed along its first dimension like ``x``.
        y: per-sample labels.
        lam: mixing weight applied to the original sample.
        use_cuda: kept for backward compatibility. The permutation is always created
            on ``x.device`` (forcing ``.cuda()`` failed on CPU-only machines).

    Returns:
        ``(mixed_x, mixed_age, y_a, y_b, lam, index)`` with ``y_a = y`` and
        ``y_b = y[index]``, where ``index`` is the partner permutation.
    """
    index = torch.randperm(x.size(0), device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_age = lam * age + (1 - lam) * age[index.to(age.device)]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, mixed_age, y_a, y_b, lam, index


@torch.no_grad()
def compute_mixup_embedding(model, sample_batched, preprocess, crop_multiple, mixup_alpha, target_from_last,
                            lam=0.5):
    """Embed mixup combinations of the recordings in one minibatch.

    Each group of ``crop_multiple`` consecutive rows (presumably the crops of one
    recording) is mixed with another randomly chosen group of the batch: crop ``j``
    of group ``m`` is blended with crop ``j`` of group ``perm[m]``. Pairing whole
    groups matters because the crops of a group are averaged afterwards. Mixing each
    crop with an independent partner would blend the average with several recordings
    while the returned label pair names only one of them.

    Args:
        model: network exposing ``compute_feature_embedding(x, age, target_from_last=...)``;
            it is switched to eval mode.
        sample_batched: dict with 'signal', 'age' and 'class_label' entries. It is
            passed to ``preprocess``, which may modify it in place.
        preprocess: callable applied to ``sample_batched`` before inference
            (this includes the to-device operation).
        crop_multiple: number of consecutive rows per group. Values <= 1 mean one row per group.
        mixup_alpha: unused; kept for signature compatibility. A fixed ``lam`` is used,
            as the Beta-sampled variant was disabled.
        target_from_last: forwarded to ``model.compute_feature_embedding``.
        lam: mixing weight of the original group (the partner gets ``1 - lam``).

    Returns:
        ``(e, y)`` where ``y[:, 0]`` is the original label and ``y[:, 1]`` the partner's.
        When ``crop_multiple > 1`` there is one row per group, on the CPU (``e`` in the
        default float dtype, ``y`` as int32).

    Raises:
        ValueError: if the batch size is not a multiple of ``crop_multiple``.
    """
    model.eval()
    preprocess(sample_batched)

    x = sample_batched['signal']
    age = sample_batched['age']
    y = sample_batched['class_label']

    batch_size = x.size(0)
    if crop_multiple > 1 and batch_size % crop_multiple != 0:
        raise ValueError(f"compute_mixup_embedding(): Real minibatch size={batch_size} is not multiple of "
                         f"crop_multiple={crop_multiple}.")
    group = crop_multiple if crop_multiple > 1 else 1
    real_minibatch = batch_size // group

    # Permute groups, then expand to row indices: row (m, j) takes row (perm[m], j).
    perm = torch.randperm(real_minibatch, device=x.device)
    index = (perm.unsqueeze(1) * group + torch.arange(group, device=x.device)).reshape(-1)

    x = lam * x + (1 - lam) * x[index]
    age = lam * age + (1 - lam) * age[index.to(age.device)]
    y = torch.stack((y, y[index.to(y.device)]), dim=-1)

    e = model.compute_feature_embedding(x, age, target_from_last=target_from_last)

    if crop_multiple > 1:
        # Average each group's crops in one vectorized reduction; every row of a group
        # carries the same label pair, so the first row's pair is representative.
        e = e.reshape(real_minibatch, crop_multiple, *e.shape[1:]).mean(dim=1)
        e = e.to(device='cpu', dtype=torch.get_default_dtype())
        y = y[::crop_multiple].to(device='cpu', dtype=torch.int32)

    return e, y

In [15]:
# Embed the train set once, without mixup. compute_embedding averages each group of
# `crop_multiple` consecutive crops (presumably one recording) into a single row.
result = [{'name': 'Train Dataset', 'loader': train_loader}]
crop_multiple = config['crop_multiple']

for entry in result:
    batch_embeddings, batch_targets = [], []
    for sample_batched in entry['loader']:
        e, y = compute_embedding(model, sample_batched, config['preprocess_test'], crop_multiple,
                                 target_from_last=target_from_last)
        batch_embeddings.append(e.detach().cpu().numpy())
        batch_targets.append(y.detach().cpu().numpy())

    # Concatenate once at the end; growing the arrays batch by batch re-copies
    # everything accumulated so far (quadratic in the number of batches).
    entry['embedding'] = np.concatenate(batch_embeddings, axis=0)
    entry['target'] = np.concatenate(batch_targets, axis=0)

In [16]:
# Embed mixup combinations of train recordings. The loader is traversed `mix_repeat`
# times, each pass with fresh random pairings, so rows from pass k follow those of
# pass k-1. Each target row is the (original label, partner label) pair.
mixup_result = [{'name': 'Train Dataset', 'loader': train_loader}]
crop_multiple = config['crop_multiple']

for entry in mixup_result:
    batch_embeddings, batch_targets = [], []
    for _ in range(mix_repeat):
        for sample_batched in entry['loader']:
            e, y = compute_mixup_embedding(model, sample_batched, config['preprocess_test'], crop_multiple,
                                           mixup_alpha=config['mixup'], target_from_last=target_from_last)
            batch_embeddings.append(e.detach().cpu().numpy())
            batch_targets.append(y.detach().cpu().numpy())

    # Concatenate once instead of re-allocating the arrays on every batch.
    entry['embedding'] = np.concatenate(batch_embeddings, axis=0)
    entry['target'] = np.concatenate(batch_targets, axis=0)

## Draw 2D

In [17]:
# tsne_transform = TSNE(n_components=2, init="pca", learning_rate="auto", perplexity=50,
#                       n_iter=10000, n_iter_without_progress=1000, n_jobs=2, random_state=0,)

# for r in range(len(result)):
#     output = tsne_transform.fit_transform( np.concatenate([result[r]['embedding'], mixup_result[r]['embedding']]))
#     result[r]['tsne_embedding'] = output[:result[r]['embedding'].shape[0]]
#     mixup_result[r]['tsne_embedding'] = output[result[r]['embedding'].shape[0]:]

In [18]:
# plt.style.use('default') 
# plt.style.use('fivethirtyeight') # default, ggplot, fivethirtyeight, bmh, dark_background, classic
# plt.rcParams.update({'font.size': 16})
# plt.rcParams.update({'font.family': 'Roboto Slab'})
# # plt.rcParams["savefig.dpi"] = 1200
# color_map = ['tab:green', 'tab:orange', 'tab:red']

# for r in range(len(result)):
#     _, ax = plt.subplots()
#     for class_name, class_label in config['class_name_to_label'].items():
#         ax.scatter(
#             result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 0],
#             result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 1],
#             label=class_name,
#             color=color_map[class_label],
#             alpha=0.8,
#             edgecolors='k',
#             zorder=2)
#     ax.set_xticklabels([])
#     ax.set_yticklabels([])
#     ax.legend(bbox_to_anchor=(1.04, 1), loc='center left', borderaxespad=0.)
#     ax.set_title(f"t-SNE embedding of {result[r]['name']}")

In [19]:
# plt.style.use('default') 
# plt.style.use('fivethirtyeight') # default, ggplot, fivethirtyeight, bmh, dark_background, classic
# plt.rcParams.update({'font.size': 16})
# plt.rcParams.update({'font.family': 'Roboto Slab'})
# # plt.rcParams["savefig.dpi"] = 1200
# color_map = ['tab:green', 'tab:orange', 'tab:red']
# color_map2 = [['tab:green', 'tab:brown', 'tab:blue'], 
#               ['gray', 'tab:orange', 'tab:pink'], 
#               ['gray', 'gray', 'tab:red']]

# for r in range(len(result)):
#     _, ax = plt.subplots()
#     for class_name, class_label in config['class_name_to_label'].items():
#         ax.scatter(
#             result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 0],
#             result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 1],
#             label=class_name,
#             color=color_map[class_label],
#             alpha=0.2,
#             zorder=2)

#     for class_name, class_label in config['class_name_to_label'].items():        
#         for class_name2, class_label2 in config['class_name_to_label'].items():
#             if class_label2 <= class_label:
#                 continue
#             mixup_idx = np.all(mixup_result[r]['target'] == [class_label, class_label2], axis=-1)
#             mixup_idx = mixup_idx | np.all(mixup_result[r]['target'] == [class_label2, class_label], axis=-1)
#             ax.scatter(
#                 mixup_result[r]['tsne_embedding'][mixup_idx][:, 0],
#                 mixup_result[r]['tsne_embedding'][mixup_idx][:, 1],
#                 label=f"mixup of {class_name} and {class_name2}",
#                 color=color_map2[class_label][class_label2],
#                 alpha=0.8,
#                 edgecolors='k',
#                 zorder=2)

#     ax.legend(bbox_to_anchor=(1.04, 1), loc='center left', borderaxespad=0.)
#     ax.set_xticklabels([])
#     ax.set_yticklabels([])
#     ax.set_title(f"t-SNE embedding of {result[r]['name']}")

In [20]:
# plt.style.use('default') 
# plt.style.use('fivethirtyeight') # default, ggplot, fivethirtyeight, bmh, dark_background, classic
# plt.rcParams.update({'font.size': 16})
# plt.rcParams.update({'font.family': 'Roboto Slab'})
# plt.rcParams["savefig.dpi"] = 1200
# color_map = ['tab:green', 'tab:orange', 'tab:red']
# color_map2 = [['tab:green', 'tab:brown', 'tab:blue'], 
#               ['gray', 'tab:orange', 'tab:pink'], 
#               ['gray', 'gray', 'tab:red']]


# for n_iter in [7000]:
#     for perplexity in [25, 50, 100, 200, 300]:
#         tsne_transform = TSNE(n_components=2, init="pca", learning_rate="auto", perplexity=perplexity,
#                               n_iter=n_iter, n_iter_without_progress=1000, n_jobs=2, random_state=0,)
        
#         for r in range(len(result)):
#             output = tsne_transform.fit_transform(np.concatenate([result[r]['embedding'], mixup_result[r]['embedding']]))
#             result[r]['tsne_embedding'] = output[:result[r]['embedding'].shape[0]]
#             mixup_result[r]['tsne_embedding'] = output[result[r]['embedding'].shape[0]:]
            
#         for r in range(len(result)):
#             fig, ax = plt.subplots()
#             for class_name, class_label in config['class_name_to_label'].items():
#                 ax.scatter(
#                     result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 0],
#                     result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 1],
#                     label=class_name,
#                     color=color_map[class_label],
#                     alpha=0.8,
#                     edgecolors='k',
#                     zorder=2)
#             ax.set_xticklabels([])
#             ax.set_yticklabels([])
#             ax.legend(bbox_to_anchor=(1.04, 0.5), loc='center left', borderaxespad=0.)
#             if save_fig:
#                 for ext in ['pdf', 'jpg', 'svg']:
#                     os.makedirs(os.path.join(output_folder, model_name, ext), exist_ok=True)
#                     fig.savefig(os.path.join(output_folder, model_name, ext, f"dim2_per{perplexity:03}_iter{n_iter:05}_ori.{ext}"), 
#                                 transparent=True, bbox_inches='tight')
#             else:
#                 plt.show()
#             fig.clear()
#             plt.close(fig)
            
#         for r in range(len(result)):
#             fig, ax = plt.subplots()
#             for class_name, class_label in config['class_name_to_label'].items():
#                 ax.scatter(
#                     result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 0],
#                     result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 1],
#                     label=class_name,
#                     color=color_map[class_label],
#                     alpha=0.2,
#                     zorder=2)        
                
#             for class_name, class_label in config['class_name_to_label'].items():        
#                 for class_name2, class_label2 in config['class_name_to_label'].items():
#                     if class_label2 <= class_label:
#                         continue
#                     mixup_idx = np.all(mixup_result[r]['target'] == [class_label, class_label2], axis=-1)
#                     mixup_idx = mixup_idx | np.all(mixup_result[r]['target'] == [class_label2, class_label], axis=-1)
#                     ax.scatter(
#                         mixup_result[r]['tsne_embedding'][mixup_idx][:, 0],
#                         mixup_result[r]['tsne_embedding'][mixup_idx][:, 1],
#                         label=f"mixup of {class_name} and {class_name2}",
#                         color=color_map2[class_label][class_label2],
#                         alpha=0.8,
#                         edgecolors='k',
#                         zorder=2)
        
#             ax.set_xticklabels([])
#             ax.set_yticklabels([])
#             ax.legend(bbox_to_anchor=(1.04, 0.5), loc='center left', borderaxespad=0.)
#             if save_fig:
#                 for ext in ['pdf', 'jpg', 'svg']:
#                     os.makedirs(os.path.join(output_folder, model_name, ext), exist_ok=True)
#                     fig.savefig(os.path.join(output_folder, model_name, ext, f"dim2_per{perplexity:03}_iter{n_iter:05}.{ext}"), 
#                                 transparent=True, bbox_inches='tight')
#             else:
#                 plt.show()
#             fig.clear()
#             plt.close(fig)

In [21]:
# plt.style.use('default') 
# plt.style.use('fivethirtyeight') # default, ggplot, fivethirtyeight, bmh, dark_background, classic
# plt.rcParams.update({'font.size': 16})
# plt.rcParams.update({'font.family': 'Roboto Slab'})
# plt.rcParams["savefig.dpi"] = 1200
# color_map = ['tab:green', 'tab:orange', 'tab:red']
# color_map2 = [['tab:green', 'tab:brown', 'tab:blue'], 
#               ['gray', 'tab:orange', 'tab:pink'], 
#               ['gray', 'gray', 'tab:red']]


# for n_iter in [7000]:
#     for perplexity in [25, 50, 100, 200, 300]:
#         tsne_transform = TSNE(n_components=2, init="pca", learning_rate="auto", perplexity=perplexity,
#                               n_iter=n_iter, n_iter_without_progress=1000, n_jobs=2, random_state=0,)
        
#         for r in range(len(result)):
#             N = result[r]['embedding'].shape[0]
#             output = tsne_transform.fit_transform( np.concatenate([result[r]['embedding'], 
#                                                                    mixup_result[r]['embedding'][:N]]))
#             result[r]['tsne_embedding'] = output[:result[r]['embedding'].shape[0]]
#             mixup_result[r]['tsne_embedding'] = output[result[r]['embedding'].shape[0]:]

#         for r in range(len(result)):
#             fig, ax = plt.subplots()
#             for class_name, class_label in config['class_name_to_label'].items():
#                 ax.scatter(
#                     result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 0],
#                     result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 1],
#                     label=class_name,
#                     color=color_map[class_label],
#                     alpha=0.8,
#                     edgecolors='k',
#                     zorder=2)
#             ax.set_xticklabels([])
#             ax.set_yticklabels([])
#             ax.legend(bbox_to_anchor=(1.04, 0.5), loc='center left', borderaxespad=0.)
#             if save_fig:
#                 for ext in ['pdf', 'jpg', 'svg']:
#                     os.makedirs(os.path.join(output_folder, model_name, ext), exist_ok=True)
#                     fig.savefig(os.path.join(output_folder, model_name, ext, f"dim2_per{perplexity:03}_iter{n_iter:05}_ori.{ext}"), 
#                                 transparent=True, bbox_inches='tight')
#             else:
#                 plt.show()
#             fig.clear()
#             plt.close(fig)
            
#         for r in range(len(result)):
#             fig, ax = plt.subplots()
#             for class_name, class_label in config['class_name_to_label'].items():
#                 ax.scatter(
#                     result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 0],
#                     result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 1],
#                     label=class_name,
#                     color=color_map[class_label],
#                     alpha=0.2,
#                     zorder=2)
        
#             N = result[r]['embedding'].shape[0]
#             for class_name, class_label in config['class_name_to_label'].items():        
#                 for class_name2, class_label2 in config['class_name_to_label'].items():
#                     if class_label2 < class_label:
#                         continue
#                     mixup_idx = np.all(mixup_result[r]['target'] == [class_label, class_label2], axis=-1)
#                     mixup_idx = mixup_idx | np.all(mixup_result[r]['target'] == [class_label2, class_label], axis=-1)
#                     ax.scatter(
#                         mixup_result[r]['tsne_embedding'][mixup_idx[:N]][:, 0],
#                         mixup_result[r]['tsne_embedding'][mixup_idx[:N]][:, 1],
#                         label=f"mixup of {class_name} and {class_name2}",
#                         color=color_map2[class_label][class_label2],
#                         alpha=0.8,
#                         edgecolors='k',
#                         zorder=2)
        
#             ax.set_xticklabels([])
#             ax.set_yticklabels([])
#             ax.legend(bbox_to_anchor=(1.04, 0.5), loc='center left', borderaxespad=0.)
#             if save_fig:
#                 for ext in ['pdf', 'jpg', 'svg']:
#                     os.makedirs(os.path.join(output_folder, model_name, ext), exist_ok=True)
#                     fig.savefig(os.path.join(output_folder, model_name, ext, f"dim2_per{perplexity:03}_iter{n_iter:05}_1Epoch.{ext}"), 
#                                 transparent=True, bbox_inches='tight')
#             else:
#                 plt.show()
#             fig.clear()
#             plt.close(fig)

## Draw 3D

In [22]:
# 3D t-SNE of the original train embeddings fitted jointly with *all* mixup embeddings.
# Two figures per setting: the originals alone (`_ori`), and the originals (faded)
# overlaid with the mixups of every pair of distinct classes.
plt.style.use('default')
plt.style.use('fivethirtyeight')  # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 16})
plt.rcParams.update({'font.family': 'Roboto Slab'})
plt.rcParams["savefig.dpi"] = 1200
color_map = ['tab:green', 'tab:orange', 'tab:red']
# color_map2[a][b] colours mixups of class a with class b (only a < b is drawn in this cell).
color_map2 = [['tab:green', 'tab:brown', 'tab:blue'],
              ['gray', 'tab:orange', 'tab:pink'],
              ['gray', 'gray', 'tab:red']]


def _new_figure_3d():
    """Reuse (and clear) figure #1 and give it a single 3D axes."""
    fig = plt.figure(num=1, clear=True, figsize=(12.0, 12.0))
    return fig, fig.add_subplot(1, 1, 1, projection='3d')


def _scatter_xyz(ax, points, **kwargs):
    """3D scatter of the first three columns of `points`."""
    ax.scatter(xs=points[:, 0], ys=points[:, 1], zs=points[:, 2], **kwargs)


def _scatter_classes(ax, points, labels, alpha):
    """One scatter per class, coloured by `color_map`."""
    for class_name, class_label in config['class_name_to_label'].items():
        _scatter_xyz(ax, points[labels == class_label],
                     label=class_name, color=color_map[class_label], alpha=alpha, s=40)


def _finish_and_save(fig, ax, stem):
    """Hide tick labels, add the legend, then save as pdf/jpg/svg (or show); always close."""
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
    ax.legend(bbox_to_anchor=(1.04, 0.5), loc='center left', borderaxespad=0.)
    if save_fig:
        for ext in ['pdf', 'jpg', 'svg']:
            ext_dir = os.path.join(output_folder, model_name, ext)
            os.makedirs(ext_dir, exist_ok=True)
            fig.savefig(os.path.join(ext_dir, f"{stem}.{ext}"),
                        transparent=True, bbox_inches='tight')
    else:
        plt.show()
    fig.clear()
    plt.close(fig)


for n_iter in [10000]:
    for perplexity in [100]:  # [50, 70, 100, 150, 200]:
        tsne_transform = TSNE(n_components=3, init="pca", learning_rate="auto", perplexity=perplexity,
                              n_iter=n_iter, n_iter_without_progress=2000, n_jobs=4, random_state=0,)
        stem = f"dim3_per{perplexity:03}_iter{n_iter:05}"

        # Fit originals and mixups together so they share one embedding space, then split back.
        for r in range(len(result)):
            n_ori = result[r]['embedding'].shape[0]
            output = tsne_transform.fit_transform(
                np.concatenate([result[r]['embedding'], mixup_result[r]['embedding']]))
            result[r]['tsne_embedding'] = output[:n_ori]
            mixup_result[r]['tsne_embedding'] = output[n_ori:]

        # Originals only.
        for r in range(len(result)):
            fig, ax = _new_figure_3d()
            _scatter_classes(ax, result[r]['tsne_embedding'], result[r]['target'], alpha=0.8)
            _finish_and_save(fig, ax, f"{stem}_ori")

        # Faded originals + mixups of each unordered pair of distinct classes.
        for r in range(len(result)):
            fig, ax = _new_figure_3d()
            _scatter_classes(ax, result[r]['tsne_embedding'], result[r]['target'], alpha=0.2)

            pairs = mixup_result[r]['target']
            for class_name, class_label in config['class_name_to_label'].items():
                for class_name2, class_label2 in config['class_name_to_label'].items():
                    if class_label2 <= class_label:
                        continue
                    mixup_idx = (np.all(pairs == [class_label, class_label2], axis=-1)
                                 | np.all(pairs == [class_label2, class_label], axis=-1))
                    _scatter_xyz(ax, mixup_result[r]['tsne_embedding'][mixup_idx],
                                 label=f"mixup of {class_name} and {class_name2}",
                                 color=color_map2[class_label][class_label2],
                                 alpha=0.8, s=40, edgecolors='k')

            _finish_and_save(fig, ax, stem)


The PCA initialization in TSNE will change to have the standard deviation of PC1 equal to 1e-4 in 1.2. This will ensure better convergence.



In [23]:
# 3D t-SNE of the original train embeddings fitted jointly with only the first N mixup
# embeddings (N = number of original rows) -- presumably about one pass of mixups,
# hence the `_1Epoch` suffix of the overlay figure.
# Unlike the previous cell, same-class mixups (class_label2 == class_label) are drawn
# too, because the pair loop below skips only `class_label2 < class_label`.
plt.style.use('default') 
plt.style.use('fivethirtyeight') # default, ggplot, fivethirtyeight, bmh, dark_background, classic
plt.rcParams.update({'font.size': 16})
plt.rcParams.update({'font.family': 'Roboto Slab'})
plt.rcParams["savefig.dpi"] = 1200
color_map = ['tab:green', 'tab:orange', 'tab:red']
# color_map2[a][b] colours mixups of class a with class b (a <= b); the diagonal is used here.
color_map2 = [['tab:green', 'tab:brown', 'tab:blue'], 
              ['gray', 'tab:orange', 'tab:pink'], 
              ['gray', 'gray', 'tab:red']]


for n_iter in [10000]:
    for perplexity in [100]: # [50, 70, 100, 150, 200]:
        tsne_transform = TSNE(n_components=3, init="pca", learning_rate="auto", perplexity=perplexity,
                              n_iter=n_iter, n_iter_without_progress=2000, n_jobs=4, random_state=0,)
        
        # Fit originals + the first N mixup rows together, then split the output back.
        for r in range(len(result)):
            N = result[r]['embedding'].shape[0]
            output = tsne_transform.fit_transform( np.concatenate([result[r]['embedding'], 
                                                                   mixup_result[r]['embedding'][:N]]))
            result[r]['tsne_embedding'] = output[:result[r]['embedding'].shape[0]]
            mixup_result[r]['tsne_embedding'] = output[result[r]['embedding'].shape[0]:]

        # Figure 1: originals only (saved with the `_ori` suffix).
        for r in range(len(result)):
            fig = plt.figure(num=1, clear=True, figsize=(12.0, 12.0))
            ax = fig.add_subplot(1, 1, 1, projection='3d')
            for class_name, class_label in config['class_name_to_label'].items():
                ax.scatter(
                    xs=result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 0],
                    ys=result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 1],
                    zs=result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 2],
                    label=class_name,
                    color=color_map[class_label],
                    alpha=0.8,
                    s=40,
                    # zorder=2
                )
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_zticklabels([])
            ax.legend(bbox_to_anchor=(1.04, 0.5), loc='center left', borderaxespad=0.)
            if save_fig:
                for ext in ['pdf', 'jpg', 'svg']:
                    os.makedirs(os.path.join(output_folder, model_name, ext), exist_ok=True)
                    fig.savefig(os.path.join(output_folder, model_name, ext, f"dim3_per{perplexity:03}_iter{n_iter:05}_ori.{ext}"), 
                                transparent=True, bbox_inches='tight')
            else:
                plt.show()
            fig.clear()
            plt.close(fig)
            
        # Figure 2: faded originals overlaid with the first-N mixups, one scatter per class pair.
        for r in range(len(result)):
            fig = plt.figure(num=1, clear=True, figsize=(12.0, 12.0))
            ax = fig.add_subplot(1, 1, 1, projection='3d')
            for class_name, class_label in config['class_name_to_label'].items():
                ax.scatter(
                    xs=result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 0],
                    ys=result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 1],
                    zs=result[r]['tsne_embedding'][result[r]['target'] == class_label][:, 2],
                    label=class_name,
                    color=color_map[class_label],
                    alpha=0.2,
                    s=40,
                    # zorder=2
                )
        
            N = result[r]['embedding'].shape[0]
            for class_name, class_label in config['class_name_to_label'].items():        
                for class_name2, class_label2 in config['class_name_to_label'].items():
                    # Each unordered pair once, including same-class pairs.
                    if class_label2 < class_label:
                        continue
                    # Masks are built over ALL mixup targets, but only the first N rows
                    # were embedded, so they are truncated with [:N] below to line up.
                    # NOTE(review): assumes there are at least N mixup rows; otherwise
                    # the mask length would not match the t-SNE output.
                    mixup_idx = np.all(mixup_result[r]['target'] == [class_label, class_label2], axis=-1)
                    mixup_idx = mixup_idx | np.all(mixup_result[r]['target'] == [class_label2, class_label], axis=-1)
                    ax.scatter(
                        xs=mixup_result[r]['tsne_embedding'][mixup_idx[:N]][:, 0],
                        ys=mixup_result[r]['tsne_embedding'][mixup_idx[:N]][:, 1],
                        zs=mixup_result[r]['tsne_embedding'][mixup_idx[:N]][:, 2],
                        label=f"mixup of {class_name} and {class_name2}",
                        color=color_map2[class_label][class_label2],
                        alpha=0.8,
                        s=40,
                        edgecolors='k',
                        # zorder=2
                    )
        
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_zticklabels([])
            ax.legend(bbox_to_anchor=(1.04, 0.5), loc='center left', borderaxespad=0.)
            if save_fig:
                for ext in ['pdf', 'jpg', 'svg']:
                    os.makedirs(os.path.join(output_folder, model_name, ext), exist_ok=True)
                    fig.savefig(os.path.join(output_folder, model_name, ext, f"dim3_per{perplexity:03}_iter{n_iter:05}_1Epoch.{ext}"), 
                                transparent=True, bbox_inches='tight')
            else:
                plt.show()
            fig.clear()
            plt.close(fig)


The PCA initialization in TSNE will change to have the standard deviation of PC1 equal to 1e-4 in 1.2. This will ensure better convergence.

