# Brief Tutorial on Extracting an Optuna Study
This notebook is intended to introduce the minimal functions needed to perform the GWOT optimization.    
For most users, using `align_representations.py`, as demonstrated in our main tutorial `tutorial.ipynb`, will be sufficient.   
However, this tutorial is intended for users who want to understand how Optuna is used in this toolbox and to customize the optimization process by using Optuna directly.    

This notebook briefly demonstrates how to:
1. Use `opt.run_study` to create an Optuna study.   
2. Extract the `best_trial` from the study.    

Please make sure that you have worked through the main tutorial (`tutorial.ipynb`) before diving into this one, as this tutorial focuses only on specific objectives and assumes familiarity with the main concepts.

In [None]:
# Standard Library
import os
import pickle as pkl
import sys

sys.path.append(os.path.join(os.getcwd(), '../'))

# Third Party Library
import matplotlib.pyplot as plt
import numpy as np
import torch
from sqlalchemy import URL

# First Party Library
from src.gw_alignment import GW_Alignment
from src.utils.gw_optimizer import load_optimizer
# os.chdir(os.path.dirname(__file__))

## Step 1: Load data
Here we use the `color` data for demonstration.   
`color`: human similarity judgements of 93 colors from 5 participant groups

In [None]:
# Load the precomputed group-averaged dissimilarity matrices of the color data.
# NOTE: pickle.load can execute arbitrary code -- only load pickle files from trusted sources.
data_path = '../data/color/num_groups_5_seed_0_fill_val_3.5.pickle'
with open(data_path, "rb") as f:
    data = pkl.load(f)
sim_mat_list = data["group_ave_mat"]

# Use the first two groups as the pair of representations to align.
C1 = sim_mat_list[0]
C2 = sim_mat_list[1]

# 'rocket_r' is a seaborn colormap: it is only registered in matplotlib once seaborn has been
# imported, which this notebook does not do explicitly. Fall back to a built-in reversed
# colormap so this cell does not fail with an invalid-cmap error.
cmap = 'rocket_r' if 'rocket_r' in plt.colormaps() else 'magma_r'

# show the two dissimilarity matrices side by side, each with its own colorbar
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, mat, title in zip(axes, (C1, C2), ('Dissimilarity matrix #1', 'Dissimilarity matrix #2')):
    im = ax.imshow(mat, cmap=cmap)
    fig.colorbar(im, ax=ax, shrink=0.8)
    ax.set_title(title)
plt.show()

## Step 2: Set the parameters used for computing and saving the results

In [None]:
# ---- Epsilon search range ----
# For the "tpe" sampler, eps_list holds the [minimum, maximum] of epsilon.
# For the "grid" or "random" samplers, a step size can also be given.
eps_list = [1e-2, 1e-1]
# Sample epsilon on a logarithmic scale (True) or a linear scale (False).
eps_log = True

# ---- Device and output type ----
device = "cpu"
to_types = "numpy"

# ---- Optimization budget ----
# Number of Optuna trials, i.e. how many epsilon values are tested (typical value: 20).
num_trial = 20
# Maximum number of iterations of each GW optimization (typical value: 1000; a smaller value is used in this tutorial).
max_iter = 200

# ---- Optuna sampler and pruner ----
sampler_name = 'tpe'
pruner_name = 'hyperband'
pruner_params = dict(n_startup_trials=1, n_warmup_steps=2, min_resource=2, reduction_factor=3)

# ---- Initialization of the transportation plan ----
# 'uniform': uniform matrix, 'diag': diagonal matrix, 'random': random matrix, 'permutation': permutation matrix
init_mat_plan = 'random'
# Number of random initial matrices for the 'random' or 'permutation' options
# (typical value: 100; 1 is used here for a quick demonstration).
n_iter = 1

# ---- GW alignment computation ----
# Sinkhorn method implemented by POT (URL : https://pythonot.github.io/gen_modules/ot.bregman.html#id87).
# "sinkhorn_log" is recommended when using a GPU.
sinkhorn_method = 'sinkhorn'
# dtype used for both numpy and torch: "float" (= float32) or "double" (= float64).
# When using a GPU with "sinkhorn", "double" is strongly recommended.
data_type = "double"

### Set the filename and folder name for saving Optuna results  
`filename` is also used as the Optuna study name.

In [None]:
# `filename` doubles as the Optuna study name; results are grouped by study and by sampler.
filename = 'test'
save_path = f'../results/tutorial_minimal/{filename}/{sampler_name}'

### Set the database URL to store the optimization results.  

The URL notation should follow the SQLAlchemy documentation:   
https://docs.sqlalchemy.org/en/20/core/engines.html  

To use remote databases, you need to start the database server beforehand. For detailed instructions, please refer to the Optuna official tutorial:  
https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html  

When using SQLite, the database file is automatically created, so you only need to set the URL.

In [None]:
# Specify the relational database (RDB) Optuna uses to store the trials (also used for distributed calculations).
db_params = {"drivername": "sqlite"}  # SQLite
# db_params = {"drivername": "mysql+pymysql", "username": "root", "password": "****", "host": "localhost"}  # MySQL
# NOTE: do not commit real passwords; read them from an environment variable (e.g. os.environ) instead.

if db_params["drivername"] == "sqlite":
    # SQLite creates the .db file on first connection but NOT its parent directory,
    # so make sure save_path exists before Optuna opens the database.
    os.makedirs(save_path, exist_ok=True)
    storage = "sqlite:///" + save_path + '/' + filename + '.db'
else:
    # Other backends: build the URL from its components (syntax differs from SQLite).
    # The database is named after `filename`, i.e. the study name.
    storage = URL.create(database=filename, **db_params).render_as_string(hide_password=False)

## Step 3: Perform GW alignment

In [None]:
# Create the GW alignment problem for the pair (C1, C2).
# save_path is passed as the output directory; the OT plans are read back from it in Step 4.
gw_settings = dict(
    max_iter=max_iter,
    n_iter=n_iter,
    to_types=to_types,
    data_type=data_type,
    sinkhorn_method=sinkhorn_method,
)
test_gw = GW_Alignment(C1, C2, save_path, **gw_settings)

In [None]:
# Create the optimizer that runs the epsilon search with Optuna and records trials in `storage`.
optimizer_settings = {
    "save_path": save_path,
    "filename": filename,
    "storage": storage,
    "init_mat_plan": init_mat_plan,
    "n_iter": n_iter,
    "num_trial": num_trial,
    "n_jobs": 1,
    "method": 'optuna',
    "sampler_name": sampler_name,
    "pruner_name": pruner_name,
    "pruner_params": pruner_params,
}
opt = load_optimizer(**optimizer_settings)

### Compute GW Alignment

In [None]:
### Running the Optimization using `opt.run_study`
# Run the optimization: each trial tests one epsilon value from the range in eps_list
# (sampled on a log scale when eps_log is True) and solves the GW alignment with test_gw.
study = opt.run_study(
    test_gw,
    device,
    init_mat_plan=init_mat_plan,
    eps_list=eps_list,
    eps_log=eps_log,
    search_space=None,
    seed=42,
)

## Step 4: View the results

In [None]:
### View Results
display(study.trials_dataframe().sort_values('params_eps'))

In [None]:
### Extracting the Best Trial from the Study
# Optuna selects best_trial according to the study's optimization direction.
best_trial = study.best_trial
print(best_trial)

# extract the optimized epsilon and the corresponding GWD (objective value) from best_trial
eps_opt = best_trial.params['eps']
GWD_opt = best_trial.values[0]

# load the optimized transportation plan saved for this trial; the file format depends on `to_types`
if to_types == 'numpy':
    OT = np.load(save_path + f'/gw_{best_trial.number}.npy')
elif to_types == 'torch':
    OT = torch.load(save_path + f'/gw_{best_trial.number}.pt')
    OT = OT.to('cpu').numpy()
else:
    # Fail here with a clear message instead of a NameError on `OT` in later cells.
    raise ValueError(f"Unsupported to_types: {to_types!r} (expected 'numpy' or 'torch')")

# plot the optimal transportation plan
plt.imshow(OT)
plt.title(f'OT eps:{eps_opt:.3f} GWD:{GWD_opt:.3f}')
plt.show()


In [None]:
# Scatter plot of epsilon (x-axis) vs. GWD (y-axis) for every trial, colored by matching rate.
df_trial = study.trials_dataframe()

fig, ax = plt.subplots()
points = ax.scatter(
    df_trial['params_eps'],
    df_trial['value'],
    s=50,
    c=df_trial['user_attrs_best_acc'] * 100,
    cmap='viridis',
)
if eps_log:
    # epsilon was sampled on a log scale, so show it on a log axis as well
    ax.set_xscale('log')
ax.set_xlabel(r'$\epsilon$')  # raw string: '\e' is an invalid escape sequence in a normal string
ax.set_ylabel('GWD')
fig.colorbar(points, ax=ax, label='Matching Rate (%)')
plt.show()

In [None]:
# Evaluate the accuracy of the unsupervised alignment: a row of the OT plan counts as correct
# when its largest entry is at its own index (the identity mapping is treated as ground truth).
max_indices = OT.argmax(axis=1)
accuracy = (max_indices == np.arange(len(OT))).mean() * 100
print(f'accuracy={accuracy}%')


# figure plotting matching rate (%) as x-axis and GWD as y-axis, colored by epsilon
matching_rate = df_trial['user_attrs_best_acc'] * 100
plt.scatter(matching_rate, df_trial['value'], s=50, c=df_trial['params_eps'])
plt.xlabel('Matching Rate (%)')
plt.ylabel('GWD')
plt.colorbar(label='epsilon')
plt.show()