## Tutorial for Histogram Matching

In [None]:
# Standard Library
import itertools
import os
import warnings
import functools

warnings.simplefilter("ignore")

# Third Party Library
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib.cm import ScalarMappable, get_cmap
import matplotlib.colors as colors
# %config InlineBackend.figure_formats = {'png', 'retina'} # for notebook?
plt.rcParams["font.size"] = 14

import torch
import optuna
import ot

from scipy.special import comb
from sklearn.metrics import confusion_matrix, accuracy_score

In [None]:
def correct_rate_mv(gw):
    """Score a matching matrix by taking the argmax of each row.

    Row ``i`` is assumed to belong to item ``i``, and its correct partner is
    assumed to be column ``i`` (an identity correspondence). Each row is
    assigned to the column holding its largest entry, and that assignment is
    compared against the identity mapping.

    Args:
        gw: 2-D array-like of matching scores. Presumably a (Gromov-)
            Wasserstein transport plan, given the name — confirm with callers.

    Returns:
        Tuple ``(acc, cm)``:
            acc: fraction of rows whose argmax column equals the row index
                (``sklearn.metrics.accuracy_score``).
            cm: confusion matrix of row index vs. predicted column
                (``sklearn.metrics.confusion_matrix``).
    """
    predicted = np.argmax(gw, axis=1)
    truth = np.arange(0, len(gw))  # identity ground truth: row i -> column i
    matrix = confusion_matrix(truth, predicted)
    accuracy = accuracy_score(truth, predicted)
    return accuracy, matrix

def sc_plot(x, y, labels):
    """Display a scatter of ``y`` against ``x`` in a fresh figure.

    Args:
        x: Values for the horizontal axis.
        y: Values for the vertical axis (same length as ``x``).
        labels: Sequence of at least two strings; ``labels[0]`` names the
            x-axis and ``labels[1]`` names the y-axis.
    """
    point_fmt = '.'  # matplotlib format string: dot markers, no connecting line
    plt.figure()
    plt.plot(x, y, point_fmt)
    # Label the axes in order: x first, then y.
    for set_label, position in ((plt.xlabel, 0), (plt.ylabel, 1)):
        set_label(labels[position])
    plt.show()


def im_plot(X, Y, title_list):
    """Show two matrices side by side as images, each with its own colorbar.

    Args:
        X: First 2-D array, drawn in the left panel.
        Y: Second 2-D array, drawn in the right panel.
        title_list: Panel titles; ``title_list[0]`` titles X and
            ``title_list[1]`` titles Y. ``zip`` stops at the shortest input,
            so a panel without a title is left empty.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, matrix, title in zip(axes.flat, (X, Y), title_list):
        image = ax.imshow(matrix)
        ax.set_title(title)
        # One colorbar per panel, because X and Y may cover different value ranges.
        fig.colorbar(image, ax=ax, shrink=0.7)
    plt.show()

In [None]:
# generate dissimilarity matrix data
# Import the submodules explicitly. `import scipy as sp` alone does not load
# `scipy.spatial` / `scipy.stats` on older SciPy releases, so `sp.spatial...`
# would only work if some other import had already loaded them.
from scipy.spatial.distance import squareform
from scipy.stats import pearsonr

n = 1000  # number of elements
sigma = 1  # upper bound of the uniform noise added to Y (not a standard deviation)
n_pairs = comb(n, 2, exact=True)  # length of the condensed (upper-triangle) vector

np.random.seed(seed=42)
x_sim = np.random.uniform(0, 1, size=n_pairs)
np.random.seed(seed=0)
# Y is a noisy linear function of X: 2 * x plus Uniform(0, sigma) noise.
y_sim = 2 * x_sim + np.random.uniform(0, sigma, size=n_pairs)

X = squareform(x_sim)  # dissimilarity matrix 1 (n x n, symmetric, zero diagonal)
Y = squareform(y_sim)  # dissimilarity matrix 2

# %%
# RSA correlation
# Convert back to condensed vectors so each off-diagonal pair is counted once.
x_flat = squareform(X)
y_flat = squareform(Y)

corr, _ = pearsonr(x_flat, y_flat)
print(f'pearson r = {corr}')

sc_plot(x_flat, y_flat, ['X', 'Y'])
im_plot(X, Y, ['X', 'Y'])

# histogram alignment (disabled: requires a `histogram_matching` function)
# Y_t = histogram_matching(X, Y)
# im_plot(X, Y_t, ['X', 'transformed Y'])
