# Tutorial For GW Methods
This notebook demonstrates the main functions of this toolbox.

For most use cases, we recommend using these functions through `align_representations.py`, as demonstrated in `tutorial.ipynb`.

So, please check the `tutorial.ipynb` before checking this notebook.

However, this tutorial may be helpful when you want to apply this toolbox to a use case that `align_representations.py` does not cover.

In [None]:
# Standard Library
import os
import pickle as pkl
import sys

sys.path.append(os.path.join(os.getcwd(), '../'))

# Third Party Library
import matplotlib.pyplot as plt
import numpy as np
import ot
import pandas as pd
# import pymysql
import scipy.io
import seaborn as sns
import torch

# First Party Library
from src.gw_alignment import GW_Alignment
from src.utils.gw_optimizer import load_optimizer
from src.utils.init_matrix import InitMatrix
# os.chdir(os.path.dirname(__file__))


## Step:1 load data
#### you can choose the following data
1. 'DNN': representations of 2000 imagenet images in AlexNet and VGG
1. 'color': human similarity judgements of 93 colors for 5 participant groups
1. 'face': human similarity judgements of 16 faces, attended vs unattended condition in the same participant

The three datasets above are sample data for demonstrating the computations of this toolbox.
If you want to use your own data, please replace `C1` and `C2` in the block below.

In [None]:
# Select which sample dataset to load: 'DNN', 'color', or 'face'.
data_select = 'color'

# Load the two matrices C1 and C2 that will be aligned.
if data_select == 'DNN':
    path1 = '../data/model1.pt'
    path2 = '../data/model2.pt'
    C1 = torch.load(path1)
    C2 = torch.load(path2)
elif data_select == 'color':
    data_path = '../data/num_groups_5_seed_0_fill_val_3.5.pickle'
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(data_path, "rb") as f:
        data = pkl.load(f)
    sim_mat_list = data["group_ave_mat"]
    # Use the group-averaged matrices stored at indices 1 and 2.
    C1 = sim_mat_list[1]
    C2 = sim_mat_list[2]
elif data_select == 'face':
    data_path = '../data/faces_GROUP_interp.mat'
    mat_dic = scipy.io.loadmat(data_path)
    C1 = mat_dic["group_mean_ATTENDED"]
    C2 = mat_dic["group_mean_UNATTENDED"]
else:
    # Fail early with a clear message instead of a NameError on C1 below.
    raise ValueError(
        f"Unknown data_select {data_select!r}; expected 'DNN', 'color', or 'face'."
    )

# Show both matrices side by side, each with its own colorbar.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
im1 = axes[0].imshow(C1, cmap='viridis')
cbar1 = fig.colorbar(im1, ax=axes[0])
im2 = axes[1].imshow(C2, cmap='viridis')
cbar2 = fig.colorbar(im2, ax=axes[1])

axes[0].set_title('Dissimilarity matrix #1')
axes[1].set_title('Dissimilarity matrix #2')
plt.show()

## Step:2 set the parameter used for computing and saving the results
### Set the filename and folder name for saving optuna results  
The filename is also used as the Optuna `study_name`.

In [None]:
# Base name for the result files.
filename = "test"
# Folder where the results for this run are written.
save_path = f"../results/gw_alignment/{filename}"

### set the device ('cuda' or 'cpu') and variable type ('torch' or 'numpy')

In [None]:
# Computation device: 'cuda' or 'cpu'.
device = "cpu"
# Array backend for the variables: 'torch' or 'numpy'.
to_types = "numpy"

### Set the database URL to store the optimization results.  

The URL notation should follow the SQLAlchemy documentation:   
https://docs.sqlalchemy.org/en/20/core/engines.html  

To use remote databases, you need to start the database server beforehand. For detailed instruction, please refer to the Optuna official tutorial:  
https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html  

When using SQLite, the database file is automatically created, so you only need to set the URL.

In [None]:
# Database URL (SQLAlchemy notation) used to store the optimization results.
# The SQLite file is placed inside the results folder and named after `filename`.
storage = f"sqlite:///{save_path}/{filename}.db"
# Example for a remote MySQL server. Never commit real credentials; substitute
# your own values (or read them from environment variables):
# storage = 'mysql+pymysql://<user>:<password>@<host>/<database>'

### Set the range of epsilon
  
set only the minimum value and maximum value for 'tpe' sampler
   
for 'grid' or 'random' sampler, you can also set the step size

In [None]:
# Range of epsilon to search, given as [minimum, maximum].
eps_list = [0.01, 0.1]
# Alternative with a third entry; presumably the step size for the 'grid' or
# 'random' sampler (TODO confirm the expected order):
# eps_list = [1e-2, 1e-1, 1e-3]

# Sample epsilon on a log scale when True.
eps_log = True

#### Set the number of optimization trials and the maximum number of iterations for the GW alignment computation

In [None]:
# Number of trials, i.e. how many epsilon values are tested during the
# optimization (default: 20).
num_trial = 8

# Maximum number of iterations for the GW optimization (default: 1000).
max_iter = 200

### choose sampler
1. 'random': randomly select epsilon between the range of epsilon
1. 'grid': grid search between the range of epsilon
1. 'tpe': Bayesian sampling
  
### ※ For the TPE sampler and the grid sampler, we recommend setting `n_jobs` to 1 because of limitations of these algorithms.  
Parallel computation with multiple threads (`n_jobs > 1`) is also available through Optuna's built-in functionality, but `n_jobs = 1` is safer, especially for the grid and TPE samplers, because of technical limitations in Optuna.  

In [None]:
# Sampler used to pick epsilon: 'random', 'grid', or 'tpe'.
sampler_name = "random"

# Number of parallel jobs.
n_jobs = 1

### choose pruner
1. 'median': Pruning if the score is below the past median at a certain point in time  
    n_startup_trials: Do not activate the pruner until this number of trials has finished  
    n_warmup_steps: Do not activate the pruner for each trial below this step  
      
1. 'hyperband': Uses multiple SuccessiveHalvingPruners with gradually longer pruning decision periods and gradually stricter criteria  
    min_resource: Do not activate the pruner for each trial below this step  
    reduction_factor: How often to check for pruning. Smaller values result in more frequent pruning checks. Between 2 to 6.  
      
1. 'nop': no pruning

In [None]:
# Pruner to use: 'median', 'hyperband', or 'nop'.
pruner_name = "hyperband"

# Settings for the pruners: the first two keys are the 'median' options,
# the last two are the 'hyperband' options.
pruner_params = {
    "n_startup_trials": 1,
    "n_warmup_steps": 2,
    "min_resource": 2,
    "reduction_factor": 3,
}

In [None]:
# Uniform distributions over the items of the source space (C1) and the
# target space (C2).
p, q = ot.unif(len(C1)), ot.unif(len(C2))

### Set the parameters for initialization of transportation plan
Choose how the transportation plan is initialized:
1. 'uniform': uniform matrix
2. 'diag': diagonal matrix
3. 'random': random matrix
4. 'permutation': permutation matrix

In [None]:
# How the transportation plan is initialized:
#   'uniform'     - uniform matrix
#   'diag'        - diagonal matrix
#   'random'      - random matrix
#   'permutation' - permutation matrix
init_mat_plan = "random"

# Number of random initial matrices for the 'random' or 'permutation'
# options (default: 100).
n_iter = 1

## Set the parameters for GW alignment computation 

In [None]:
# Sinkhorn method, chosen from those implemented by POT
# (https://pythonot.github.io/gen_modules/ot.bregman.html#id87).
# "sinkhorn_log" is recommended when using a GPU.
sinkhorn_method = "sinkhorn_log"

# Dtype used for both numpy and torch: "float" (= float32) or "double"
# (= float64). "double" is strongly recommended when using a GPU with "sinkhorn".
data_type = "double"

## Step:3 Perform GW Alignment

In [None]:
# Create the object that solves the GW alignment between C1 and C2.
test_gw = GW_Alignment(
    C1,
    C2,
    p,
    q,
    save_path,
    max_iter=max_iter,
    n_iter=n_iter,
    to_types=to_types,
    data_type=data_type,
    sinkhorn_method=sinkhorn_method,
)

In [None]:
# Create the optimizer object that tunes the GW alignment with Optuna.
opt = load_optimizer(
    save_path,
    n_jobs=n_jobs,
    num_trial=num_trial,
    to_types=to_types,
    method="optuna",
    sampler_name=sampler_name,
    pruner_name=pruner_name,
    pruner_params=pruner_params,
    n_iter=n_iter,
    filename=filename,
    storage=storage,
)

### define the space used for Grid Sampler. 

In [None]:
# Optimization, step 1: choose the initial matrix for the GW alignment.
# init_mat_builder.implemented_init_plans() is documented by the toolbox as
# removing plans that are not implemented.
# NOTE(review): `init_plan` is not read by any later code visible in this
# notebook — confirm whether it should be passed on to the optimizer.
init_plan = test_gw.main_compute.init_mat_builder.implemented_init_plans(init_mat_plan)

# An explicit epsilon search space is needed only by the grid sampler.
if sampler_name != "grid":
    search_space = None
else:
    eps_space = opt.define_eps_space(eps_list, eps_log, num_trial)
    search_space = {"eps": eps_space}

### Compute GW Alignment

In [None]:
# Optimization, step 2: run the study over epsilon.
study = opt.run_study(
    test_gw,
    device,
    init_mat_plan=init_mat_plan,
    eps_list=eps_list,
    eps_log=eps_log,
    search_space=search_space,
)

## Step:4 View the result

In [None]:
# Show all trials as a table, ordered by the epsilon value that was tried.
display(
    study.trials_dataframe()
    .sort_values('params_eps')
)

In [None]:
# Table of all trials in the study.
df_trial = study.trials_dataframe()
# Best trial of the study, printed for inspection.
best_trial = study.best_trial
print(best_trial)

In [None]:
# Optimized epsilon and GWD, taken from the best trial.
eps_opt = best_trial.params['eps']
GWD_opt = best_trial.values[0]

# Load the transportation plan that was saved for the best trial.
if to_types == 'numpy':
    OT = np.load(save_path+f'/{init_mat_plan}/gw_{best_trial.number}.npy')
elif to_types == 'torch':
    OT = torch.load(save_path+f'/{init_mat_plan}/gw_{best_trial.number}.pt')
    # Move to CPU and convert so the plotting / numpy code below works.
    OT = OT.to('cpu').numpy()
else:
    # Fail early with a clear message instead of a NameError on OT below.
    raise ValueError(f"Unknown to_types {to_types!r}; expected 'numpy' or 'torch'.")

plt.imshow(OT)
plt.title(f'OT eps:{eps_opt:.3f} GWD:{GWD_opt:.3f}')
plt.show()

df_trial = study.trials_dataframe()

# Evaluate the accuracy of the unsupervised alignment: the percentage of rows
# whose largest entry lies on the diagonal.
# NOTE(review): this assumes item i of C1 corresponds to item i of C2 — confirm
# for your own data.
max_indices = np.argmax(OT, axis=1)
accuracy = np.mean(max_indices == np.arange(OT.shape[0])) * 100
print(f'accuracy={accuracy}%')

# Figure with epsilon on the x-axis and GWD on the y-axis.
sns.scatterplot(data = df_trial, x = 'params_eps', y = 'value', s = 50)
# Raw string: '\e' is not a valid escape sequence in a normal string literal.
plt.xlabel(r'$\epsilon$')
plt.ylabel('GWD')
plt.show()

# Figure with GWD on the x-axis and accuracy on the y-axis.
# NOTE(review): relies on a 'user_attrs_best_acc' column being present in the
# trials dataframe — presumably set by the toolbox during the study; confirm.
sns.scatterplot(data = df_trial, x = 'value', y = 'user_attrs_best_acc', s = 50)
plt.xlabel('GWD')
plt.ylabel('accuracy')
plt.show()