In [None]:
# Checkpoint identifier; it is interpolated into the input file names and the
# result paths used throughout this notebook.
model_ckpt = "CL_expression_after_valaro_z64_bs512"

In [None]:
import os
import sys

# Put ../scripts on the import path (presumably where the local `utils` module
# lives -- confirm). The membership check keeps the cell idempotent: re-running
# it no longer appends a duplicate entry to sys.path each time.
scripts_dir = os.path.dirname('../scripts/')
if scripts_dir not in sys.path:
    sys.path.append(scripts_dir)

In [None]:
import numpy as np
import pandas as pd
import pickle as pkl
import torch

from GWTune.src.align_representations import Representation, AlignRepresentations, OptimizationConfig, VisualizationConfig

In [None]:
# Use the GPU when PyTorch reports that CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
# Load the two header-less tables for the selected checkpoint. Both file names
# are derived from `model_ckpt`, so changing the checkpoint can never silently
# mix embeddings from one run with labels from another.
token_valaro = pd.read_csv(f'./data/{model_ckpt}token_and_valaro.csv', header=None)
token_emotion = pd.read_csv(f'./data/{model_ckpt}token_and_emotion.csv', header=None)

# token_and_valaro.csv: the last two columns form the 'val-aro' embedding and
# every preceding column forms the token embedding.
valaro = token_valaro.iloc[:, -2:].values
token = token_valaro.iloc[:, :-2].values

# token_and_emotion.csv: only the last column is used (the emotion label).
emotion = token_emotion.iloc[:, -1].values

In [None]:
from utils import exclude_id
from GWTune.src.utils.utils_functions import get_category_data, sort_matrix_with_categories


# One-hot encode the emotion labels: one row per sample, one indicator column
# per distinct label value.
category_mat = pd.get_dummies(emotion)

# Replace the indicator column names with the label names returned by
# `exclude_id`; its second return value is not needed here.
# NOTE(review): [8, 9, 10] are presumably the label ids left out -- confirm in utils.
id2label, _ = exclude_id([8, 9, 10])
category_mat.columns = id2label.values()

# Category metadata that is handed to every `Representation` later on.
(
    object_labels,
    category_idx_list,
    num_category_list,
    category_name_list,
) = get_category_data(category_mat=category_mat)

In [None]:
from utils import exclude_id
from GWTune.src.utils.utils_functions import get_category_data, sort_matrix_with_categories


# `object_labels`, `category_idx_list`, `num_category_list` and
# `category_name_list` are reused from the previous cell, which already builds
# `category_mat` and derives them. Recomputing them here was a verbatim
# duplicate of that cell and has been removed.

# Build one Representation per embedding; both share the same category metadata
# so their similarity matrices can be sorted and compared category by category.
representations = []
for name, emb in zip(['CLS-token', 'val-aro'],
                     [token, valaro]):
    representation = Representation(
        name=name,
        embedding=emb,
        metric='seuclidean',
        get_embedding=False,
        object_labels=object_labels,
        category_name_list=category_name_list,
        category_idx_list=category_idx_list,
        num_category_list=num_category_list,
        func_for_sort_sim_mat=sort_matrix_with_categories,
    )
    representations.append(representation)

In [None]:
# Epsilon values handed to the optimizer as `eps_list`.
# NOTE(review): presumably the (lower, upper) bounds of the search range -- confirm
# against the GWTune documentation.
eps_list_tutorial = [1e-4, 1e-2]

# Whether epsilon is sampled on a log scale, and how many trials are run.
eps_log, num_trial = True, 10

# Array backend requested from the optimizer and how the initial plan is drawn.
to_types = 'torch'
init_mat_plan = "random"

In [None]:
# Solver, precision and GPU settings depend on where the computation runs.
if device == 'cuda':
    # Choose one of the Sinkhorn methods implemented by POT
    # (https://pythonot.github.io/gen_modules/ot.bregman.html#id87).
    # "sinkhorn_log" is the recommended one when running on GPU.
    # NOTE(review): the GPU ids [1, 2, 3] are hard-coded -- verify they exist on
    # the machine this notebook runs on.
    sinkhorn_method, data_type, multi_gpu = 'sinkhorn_log', 'float', [1, 2, 3]

elif device == 'cpu':
    sinkhorn_method, data_type, multi_gpu = 'sinkhorn', 'double', False

In [None]:
# Settings forwarded to the pruner selected below via `pruner_name`.
hyperband_params = {
    'n_startup_trials': 1,
    'n_warmup_steps': 2,
    'min_resource': 2,
    'reduction_factor': 3,
}

config = OptimizationConfig(
    # epsilon search
    eps_list=eps_list_tutorial,
    eps_log=eps_log,
    num_trial=num_trial,
    sinkhorn_method=sinkhorn_method,

    # backend / hardware
    to_types=to_types,
    device=device,
    data_type=data_type,
    n_jobs=4,
    multi_gpu=multi_gpu,

    # storage of the optimisation results
    db_params={"drivername": "sqlite"},

    # initialisation and iteration budget
    init_mat_plan=init_mat_plan,
    n_iter=1,
    max_iter=1000,

    # sampler and pruner
    sampler_name='tpe',
    pruner_name='hyperband',
    pruner_params=hyperband_params,
)

In [None]:
# Folder (or file name) under which the results are saved.
# NOTE(review): the concatenation yields "./GWOT_results//seuclidean/<ckpt>" with a
# doubled slash; it is kept as-is because existing results are read back from it.
results_dir = "./GWOT_results/" + '/seuclidean/' + model_ckpt

align_representation = AlignRepresentations(
    config=config,
    representations_list=representations,

    # Histogram matching would adjust the histogram of the target to that of
    # the source; it is switched off here.
    histogram_matching=False,

    # Distance metric between embeddings. It must be one that
    # "scipy.spatial.distance.cdist()" accepts.
    metric="seuclidean",

    main_results_dir=results_dir,

    # Name of the data set; rewrite it when using your own data.
    data_name=model_ckpt,
)

In [None]:
# Layout requested for the similarity matrices (reused later as the OT format).
sim_mat_format = "sorted"

# Plot settings for the similarity-matrix figures.
visualize_config = VisualizationConfig(
    figsize=(12, 12),
    title_size=15,
    cmap='rocket_r',
    fig_ext='svg',
    colorbar_range=None,
    ot_object_tick=False,
    ot_category_tick=True,
    draw_category_line=True,
)

# Plot settings passed as `visualization_config_hist`.
visualize_hist = VisualizationConfig(color='C0', figsize=(8, 6))

sim_mat = align_representation.show_sim_mat(
    sim_mat_format=sim_mat_format,
    visualization_config=visualize_config,
    visualization_config_hist=visualize_hist,
    show_distribution=False,
)

In [None]:
# Presumably computes the RSA correlation between the representations using
# Pearson's r -- confirm against the GWTune documentation.
align_representation.RSA_get_corr(metric = "pearson")

In [None]:
# Plot settings for the transport-plan figures. This rebinds `visualize_config`,
# replacing the similarity-matrix settings defined earlier.
visualize_config = VisualizationConfig(
    show_figure=True,
    figsize=(8, 6),
    title_size=15,
    cmap='viridis',
    ot_object_tick=True,
    plot_eps_log=eps_log,
)

# NOTE(review): with compute_OT=False this presumably loads previously computed
# results instead of running the optimisation -- confirm against GWTune.
ot_mat = align_representation.gw_alignment(
    compute_OT=False,
    delete_results=False,
    return_data=True,
    return_figure=True,
    OT_format=sim_mat_format,
    visualization_config=visualize_config,
)

In [None]:
# Show the optimisation log using the most recently defined `visualize_config`.
# NOTE(review): fig_dir=None presumably selects GWTune's default figure location -- confirm.
align_representation.show_optimization_log(fig_dir=None, visualization_config=visualize_config)

In [None]:
# Top-k matching accuracy evaluated on the transport plan, returned as a DataFrame.
ot_df: pd.DataFrame = align_representation.calc_accuracy(
    top_k_list=[1, 5, 10],
    eval_type="ot_plan",
    return_dataframe=True,
)
align_representation.plot_accuracy(eval_type="ot_plan", scatter=True)

# The CSV is written to ./GWOT_results/<model_ckpt>/, which is not the
# `main_results_dir` given to AlignRepresentations and therefore may not exist
# yet; create it first so that `to_csv` cannot fail on a missing directory.
# NOTE(review): 'op_plan' in the file name looks like a typo for 'ot_plan'; it is
# kept unchanged so existing consumers of the file keep working.
accuracy_dir = "./GWOT_results/" + f'{model_ckpt}/'
os.makedirs(accuracy_dir, exist_ok=True)
ot_df.to_csv(accuracy_dir + 'op_plan_accuracy_.csv')

In [None]:
# Top-k matching accuracy evaluated at category level, returned as a DataFrame.
category_df = align_representation.calc_accuracy(
    top_k_list=[1, 5, 10],
    eval_type="category",
    category_mat=category_mat,
    return_dataframe=True,
)
align_representation.plot_accuracy(eval_type="category", scatter=True)

# The CSV is written to ./GWOT_results/<model_ckpt>/, which is not the
# `main_results_dir` given to AlignRepresentations and therefore may not exist
# yet; create it first so that `to_csv` cannot fail on a missing directory.
accuracy_dir = "./GWOT_results/" + f'{model_ckpt}/'
os.makedirs(accuracy_dir, exist_ok=True)
category_df.to_csv(accuracy_dir + 'category_accuracy_.csv')

In [None]:
def sorted2raw_indices(object_labels):
    """Map every position in ``object_labels`` to the entry stored there.

    Args:
        object_labels: Iterable whose i-th entry is the raw index of the item
            located at sorted position ``i``.

    Returns:
        dict: ``{sorted_position: raw_index}`` in iteration order.
    """
    return dict(enumerate(object_labels))


def find_matching(ot):
    """Return the row-wise best match of a transport matrix in raw indices.

    Each row is matched to the column holding its largest value. Row and column
    positions are then translated through ``sorted2raw_indices`` applied to the
    notebook-level ``object_labels``.

    Args:
        ot: 2-D array-like accepted by ``np.argmax(..., axis=1)``.

    Returns:
        dict: ``{raw_row_index: raw_index_of_best_column}``.
    """
    best_columns = np.argmax(ot, axis=1)
    to_raw = sorted2raw_indices(object_labels)
    return {to_raw[row]: to_raw[col] for row, col in enumerate(best_columns)}


def eval_valaro(ot_mat, align_representation):
    """Print the mean val-aro distance between OT-matched items for every pair.

    For each pair, every row of the transport matrix is matched to its best
    column (see ``find_matching``) and the Euclidean distance between the two
    matched rows of the notebook-level ``valaro`` array is accumulated.

    Args:
        ot_mat: Iterable of transport matrices, consumed in lockstep with
            ``align_representation.pairwise_list``.
        align_representation: Object exposing ``pairwise_list``; each element
            must provide a ``pair_name`` attribute.

    Returns:
        None. One ``"<pair_name>: <distance>"`` line is printed per pair.
    """
    for pairwise, ot in zip(align_representation.pairwise_list, ot_mat):
        pair_name = pairwise.pair_name
        matching = find_matching(ot)
        valaro_dist = 0
        for k, v in matching.items():
            valaro_dist += np.linalg.norm(valaro[k] - valaro[v])
        # Average over the number of rows of the transport matrix.
        valaro_dist /= ot.shape[0]
        print(f'{pair_name}: {valaro_dist}')

In [None]:
# Print, for each aligned pair, the mean val-aro distance between OT-matched items.
eval_valaro(ot_mat, align_representation)

In [None]:
# visualization_embedding = VisualizationConfig(
#     cmap="cool",
#     colorbar_label="frame",
#     colorbar_range=[0, 30],
#     color_labels=None,
#     color_hue="cool", # If "color_labels=None", you have the option to choose the color hue as either "cool", "warm", or "None".
#     figsize=(10, 10), 
#     xlabel="PC1", 
#     ylabel="PC2",
#     marker_size=50,
#     legend_size=11
# )

# align_representation.visualize_embedding(
#     dim=2, # the dimensionality of the space the points are embedded in. You can choose either 2 or 3.
#     pivot=0, # the number of one of the representations or the "barycenter".
#     visualization_config=visualization_embedding
# )