## Rotating Logos – Reference Set Sampled from a Different (Biased) Distribution

In [None]:
### Import relevant libraries
import matplotlib.pyplot as plt
from matplotlib import font_manager
import numpy as np
import cv2
import scipy as sci
import tqdm
import os
from scipy import ndimage
import helper_functions.plotting_funcs as plot_funcs
from helper_functions.logo_funcs import imresize_pad, lin2circ_angles, get_validation_indices
import json
import pickle
import seaborn as sns


# Simulation configuration; written verbatim to sim_params.json in the output folder.
sim_params = dict(
    delete_kernels=False,
    generate_data=False,  # regenerate the rotation dataset even if the cache exists
    evd_solver='arpack',  # 'arpack' / 'randomized' / 'svd'
    ad_methods=['lead', 'forward_only', 'ncca', 'kcca_impute', 'nystrom', 'adm_plus', 'backward_only'],
    embed_dim=2,
    t=0,
    scales=[2, 8, 10, 20],
    angle_bias_factor_max=0.8,
    angles_for_bias='mnm',
    im_resize_factor=2,
    Nr=50,  # number of samples in the reference set
    N=1000,  # number of total samples
    valid_size=0.2,
)

# Output folder name encodes the run configuration (any '.' becomes 'p').
fig_str = (
    "figures/logos_rs"
    f"_factor_{sim_params['im_resize_factor']}"
    f"_{sim_params['evd_solver']}"
    f"_N_{sim_params['N']}"
    f"_Nr_{sim_params['Nr']}"
    "_diff_dist"
).replace('.', 'p')
figures_path = fig_str
os.makedirs(figures_path, exist_ok=True)

# Reference-set size and total number of samples, used throughout the notebook.
Nr, N_d = sim_params['Nr'], sim_params['N']

# Persist the configuration next to the figures for reproducibility.
with open(f"{figures_path}/sim_params.json", 'w') as params_file:
    json.dump(sim_params, params_file, indent=4)

In [None]:
# Typography for the embedding figures.
font_name = "Times New Roman"  # swap for any other installed serif font if needed

# Preset for two plots per row: title size 28, tick size 22.
# (For three plots per row use title size 38 and tick size 32 with the same figsize.)
font_properties_title, font_properties_ticks = (
    font_manager.FontProperties(family=font_name, size=size) for size in (28, 22)
)
figsize = (8, 7)

## Load Images

In [None]:
## Load logo images as RGB (plus grayscale versions of the M&M and Starbucks logos)
def _imread_first(candidates):
    """Return the BGR image read by OpenCV from the first readable path in ``candidates``.

    ``cv2.imread`` returns ``None`` instead of raising when a file is missing or
    unreadable, which would otherwise surface later as an obscure
    ``cv2.cvtColor`` error.

    Raises:
        FileNotFoundError: if none of the candidate paths could be read.
    """
    for path in candidates:
        image = cv2.imread(path)
        if image is not None:
            return image
    raise FileNotFoundError(f"Could not read any of: {', '.join(candidates)}")


def _resize_to_height(image, target_height):
    """Resize ``image`` to ``target_height`` rows, keeping its aspect ratio (INTER_AREA)."""
    scale = image.shape[0] / target_height
    width = int(image.shape[1] / scale)
    return cv2.resize(image, (width, target_height), interpolation=cv2.INTER_AREA)


# KFC logo. The other logos are .jpg files, so try 'KFC_logo.jpg' first and keep
# the extension-less name as a fallback.
img = _imread_first(('KFC_logo.jpg', 'KFC_logo'))
img_kfc = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# M&M logo
img = _imread_first(('mnm_logo.jpg',))
img_mnm = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_mnm_g = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

# Starbucks logo
img = _imread_first(('starbucks_logo.jpg',))
img_str = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_str_g = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

# Resize all logos to a common height so they can be concatenated side by side.
height = 400
img_kfc_rs = _resize_to_height(img_kfc, height)
img_mnm_rs = _resize_to_height(img_mnm, height)
img_str_rs = _resize_to_height(img_str, height)

## Create Example of Rotation for Visualization

In [None]:
## Visual example: rotate each logo by a fixed angle and compose the two sensor views
angle_KFC, angle_MNM, angle_STR = 30, 74, 110

# The common (M&M) logo is passed through imresize_pad with the configured factor.
mnm_resized = imresize_pad(img_mnm_rs, sim_params['im_resize_factor'])

# Rotate in place (same canvas size), filling borders with the nearest pixel.
kfc_rot, mnm_rot, str_rot = (
    ndimage.rotate(image, angle, reshape=False, mode='nearest')
    for image, angle in ((img_kfc_rs, angle_KFC), (mnm_resized, angle_MNM), (img_str_rs, angle_STR))
)

# Sensor 1 sees KFC + M&M, sensor 2 sees M&M + Starbucks; the whole scene shows all three.
sensor_1 = np.concatenate((kfc_rot, mnm_rot), axis=1)
sensor_2 = np.concatenate((mnm_rot, str_rot), axis=1)
total_scene = np.concatenate((sensor_1, str_rot), axis=1)

# Global plot styling.
plt.rcParams.update({
    'figure.figsize': (8, 6),
    'font.size': 18,
    # 'text.usetex': True,
    'font.family': 'serif',
    'axes.labelsize': 18,
    'axes.titlesize': 22,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    'legend.fontsize': 18,
    'lines.linewidth': 2,
    'axes.grid': True,
    'grid.alpha': 0.5,
    'grid.linestyle': '--',
})

# One figure per view, each saved as a PDF in the output folder.
for image, title, file_name, size in (
    (sensor_1, "Sensor 1", "sensor1_scene", (8, 4)),
    (sensor_2, "Sensor 2", "sensor2_scene", (8, 4)),
    (total_scene, "Whole Scene", "whole_scene", (16, 8)),
):
    fig, ax = plot_funcs.subplots_imshow(1, 1, figsize=size)
    ax.imshow(image, cmap=plt.cm.gray)
    ax.set_title(title, font_properties=font_properties_title)
    plt.savefig(f"{figures_path}/{file_name}.pdf", dpi=300, format='pdf', bbox_inches='tight')

## Create a Dataset of Synchronized Pairs of Rotating Images
We create a dataset of two sensors, each observing a pair of rotating logos: the M&M logo is common to both sensors, while the other logo differs (KFC for sensor 1, Starbucks for sensor 2).\
We aim to embed the common variable — the rotation angle of the common (M&M) logo.

In [None]:

# Constant angular velocity of each logo [degrees/timestamp].
w_kfc = 2.93 * 4
w_str = 1.27 * 4
w_mnm = 2.11 * 4

# Angle of every logo at timestamps 0, 1, ..., N_d - 1, passed through lin2circ_angles.
timestamps = np.linspace(0, N_d - 1, N_d)
angles_kfc_d, angles_mnm_d, angles_str_d = (
    lin2circ_angles(velocity * timestamps) for velocity in (w_kfc, w_mnm, w_str)
)

In [None]:
## Create Dataset - Deterministic Rotation
if sim_params['generate_data'] or not os.path.isfile(f"{figures_path}/s1_low_csr.npy"):
    d1 = sensor_1.size # number of pixels/dimension of data points
    d2 = sensor_2.size # number of pixels/dimension of data points
    s1_points_d = np.zeros((N_d, d1), dtype='uint8') # initalize data points 
    s2_points_d = np.zeros((N_d, d2), dtype='uint8') # initalize data points 
    
    for i in tqdm.tqdm(range(N_d)):
        # rotate images
        kfc_rot = ndimage.rotate(img_kfc_rs, angles_kfc_d[i], reshape=False, mode='nearest')
        mnm_rot = ndimage.rotate(mnm_resized, angles_mnm_d[i], reshape=False, mode='nearest')
        str_rot = ndimage.rotate(img_str_rs, angles_str_d[i], reshape=False, mode='nearest')
        # concatenate images
        sensor_1 = np.concatenate((kfc_rot, mnm_rot), axis=1)
        sensor_2 = np.concatenate((mnm_rot, str_rot), axis=1)
        # flatten images to create a vector
        s1_points_d[i,:] = sensor_1.reshape((1, d1))
        s2_points_d[i,:] = sensor_2.reshape((1, d2))
        
    %time np.save(f"{figures_path}/s1_low_csr", s1_points_d)
    %time np.save(f"{figures_path}/s2_low_csr", s2_points_d)
    
    print('Successfully Generated Dataset')
else:
    print('Data Already Generated')

## Load Data

In [None]:
# Reload the cached, flattened sensor frames (one row per timestamp).
s1_points_d, s2_points_d = map(np.load, (f"{figures_path}/s1_low_csr.npy", f"{figures_path}/s2_low_csr.npy"))

## Split the Data into Reference Set and Total Set
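We build the reference set by subsampling the rotation timestamps. Each timestamp is drawn with probability proportional to its normalized rotation angle raised to a bias factor. As a result, the reference set's angle distribution differs from that of the full set.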

In [None]:
def sample_reference_set(s1, s2, angles, Nr, bias_factor=1, seed=None):
    """Draw a biased reference set and reorder both sensors' samples around it.

    Each sample ``i`` is drawn (without replacement) with probability
    proportional to ``(angles[i] / max(angles)) ** bias_factor``.
    ``bias_factor=0`` gives uniform sampling; larger values favor samples with
    larger angles. Angles must be non-negative so the weights form a valid
    probability vector.

    Args:
        s1, s2: Data matrices of the two sensors; row ``i`` of each belongs to sample ``i``.
        angles: One angle per sample (its length defines the number of samples).
        Nr: Number of reference samples to draw.
        bias_factor: Exponent applied to the normalized angles.
        seed: Optional seed for a private ``np.random.default_rng``. When ``None``
            (default), the global ``np.random`` state is used.

    Returns:
        s1_ref, s2_ref: Rows of ``s1``/``s2`` at the reference indices.
        s1_aligned, s2_aligned: Reference rows first, followed by all remaining
            rows in increasing index order.
        reorder_idx: Permutation such that ``aligned[reorder_idx]`` restores the
            original sample order.
        ref_idx: The sampled reference indices (in draw order).
    """
    angles = np.asarray(angles)
    n_samples = len(angles)

    # Sampling probabilities from the normalized, exponentiated angles.
    weights = (angles / np.max(angles)) ** bias_factor
    probabilities = weights / np.sum(weights)

    if seed is None:
        ref_idx = np.random.choice(n_samples, size=Nr, replace=False, p=probabilities)
    else:
        ref_idx = np.random.default_rng(seed).choice(n_samples, size=Nr, replace=False, p=probabilities)

    # Remaining indices in increasing order (np.setdiff1d returns sorted values);
    # avoids an O(N * Nr) membership scan.
    single_idx = np.setdiff1d(np.arange(n_samples), ref_idx)
    # argsort of the concatenated order is its inverse permutation.
    reorder_idx = np.argsort(np.concatenate((ref_idx, single_idx)))

    # Split the reference set and build the reference-first ("aligned") views.
    s1_ref = s1[ref_idx, :]
    s2_ref = s2[ref_idx, :]
    s1_aligned = np.concatenate((s1_ref, s1[single_idx, :]), axis=0)
    s2_aligned = np.concatenate((s2_ref, s2[single_idx, :]), axis=0)

    return s1_ref, s2_ref, s1_aligned, s2_aligned, reorder_idx, ref_idx

In [None]:
## select in which angle to create bias
if sim_params['angles_for_bias'] == 'mnn':
    angles = angles_mnm_d
elif sim_params['angles_for_bias'] == 'str':
    angles = angles_str_d
else:
    angles = angles_kfc_d


### Run all methods

In [None]:
from helper_functions.AD_funcs import Create_Transition_Mat, Create_Asym_Tran_Kernel, embed_wrapper

# Kernel pair (K1, K2) handed to embed_wrapper for each method, by kernel name:
#   'A1' / 'A2' -> Create_Asym_Tran_Kernel(aligned set, reference set) of sensor 1 / 2
#   'K1' / 'K2' -> Create_Transition_Mat on the aligned set of sensor 1 / 2
#   'K2_ref'    -> Create_Transition_Mat on sensor 2's reference set
KERNEL_PAIR_BY_METHOD = {
    **dict.fromkeys(["forward_only", "forward_only_slow", "alternating_roseland", "ffbb", "fbfb",
                     "ncca", "kcca", "nystrom", "adm_plus", "backward_only"], ("A1", "K2_ref")),
    **dict.fromkeys(["ad", "dm", "kcca_full"], ("K1", "K2")),
    "kcca_impute": ("K1", "K2_ref"),
    "lead": ("A1", "A2"),
}

# Fail fast instead of silently producing no embedding for an unrecognized method name.
unknown_methods = [m for m in sim_params['ad_methods'] if m not in KERNEL_PAIR_BY_METHOD]
if unknown_methods:
    raise ValueError(f"No kernel configuration for ad_methods {unknown_methods}")

# Reference sets are drawn from the global NumPy RNG; seed it so runs are reproducible
# (add a 'seed' entry to sim_params to change the draw).
np.random.seed(sim_params.get('seed', 0))

embed_dict = dict()
bias_factors = np.linspace(0, sim_params['angle_bias_factor_max'], 6)
for bias_factor in bias_factors:
    s1_ref, s2_ref, s1_aligned, s2_aligned, reorder_idx, ref_idx = sample_reference_set(
        s1_points_d, s2_points_d, angles, Nr, bias_factor)
    for scale in tqdm.tqdm(sim_params['scales']):
        A1, _, _ = Create_Asym_Tran_Kernel(s1_aligned, s1_ref, mode='median', scale=scale)
        A2, _, _ = Create_Asym_Tran_Kernel(s2_aligned, s2_ref, mode='median', scale=scale)
        K1, _ = Create_Transition_Mat(s1_aligned, scale=scale)
        K2, _ = Create_Transition_Mat(s2_aligned, scale=scale)
        K2_ref, _ = Create_Transition_Mat(s2_ref, scale=scale)
        kernels = {'A1': A1, 'A2': A2, 'K1': K1, 'K2': K2, 'K2_ref': K2_ref}

        for method in sim_params['ad_methods']:
            k1_name, k2_name = KERNEL_PAIR_BY_METHOD[method]
            embed = embed_wrapper(s1_ref, s1_aligned, s2_ref, s2_aligned, method=method,
                                  embed_dim=sim_params['embed_dim'], t=sim_params['t'],
                                  K1=kernels[k1_name], K2=kernels[k2_name],
                                  solver=sim_params['evd_solver'],
                                  delete_kernels=sim_params['delete_kernels'])
            # Undo the reference-first ordering so rows follow the original timestamps.
            embed_dict[f'{method}_scale_{scale}_factor_{bias_factor}'] = embed[reorder_idx, :]

# save embeddings to file
with open(f"{figures_path}/embedding_dictionary.pkl", 'wb') as fp:
    pickle.dump(embed_dict, fp)
    print('dictionary saved successfully to file')

In [None]:
# Reload the embeddings saved above (keys: '<method>_scale_<scale>_factor_<bias_factor>').
with open(f"{figures_path}/embedding_dictionary.pkl", 'rb') as pkl_file:
    embed_dict = pickle.load(pkl_file)
print('Dictionary Loaded Successfully')

In [None]:
def _show_or_close(figure, plot_flag):
    """Show the figure via ``plt.show()`` if ``plot_flag`` is set; otherwise close ``figure``."""
    if plot_flag:
        plt.show()
    else:
        plt.close(figure)


def plot_method_embedding(embed, figures_path, angles, Nr, method, ref_idx, plot_flag=True, pointsize=20,
                          pointsize_ref=30, fontproperties=None, tick_fontproperties=None, figsize=(8, 7)):
    """Plot a 2-D embedding and save two PDF figures named after ``method``.

    Figure 1 (``{method}_embedding.pdf``):
      - points colored by ``lin2circ_angles(angles)``, with reference samples drawn as 'x';
      - a dashed circle at the median radius;
      - a shaded band between the 0.15 and 0.85 radius quantiles;
      - the band width normalized by the median radius.

    Figure 2 (``{method}_embedding_comp_vs_ref.pdf``): reference points in blue and
    out-of-reference points in red.

    Args:
        embed: Embedding with one row per sample; columns 0 and 1 are plotted. Radii
            and axis limits are computed from all columns, measured from the origin
            (assumes a centered embedding).
        figures_path: Directory the PDFs are written to.
        angles: Per-sample angles used for coloring.
        Nr: Unused; kept for interface compatibility.
        method: Name used in the output file names.
        ref_idx: Row indices of ``embed`` that belong to the reference set.
        plot_flag: If True, show each figure after saving it. Otherwise close it, so
            repeated calls do not accumulate open figures.
        pointsize, pointsize_ref: Marker sizes of out-of-reference / reference points.
        fontproperties: Font for axis labels, the width annotation and figure-2 tick
            labels (default: Times New Roman, size 18).
        tick_fontproperties: Font for figure-1 tick labels (same default).
        figsize: Size of both figures.
    """
    # define fonts
    if fontproperties is None:
        my_fontproperties = font_manager.FontProperties(family='Times New Roman', size=18)
    else:
        my_fontproperties = fontproperties
    if tick_fontproperties is None:
        my_tick_fontproperties = font_manager.FontProperties(family='Times New Roman', size=18)
    else:
        my_tick_fontproperties = tick_fontproperties

    colors = lin2circ_angles(angles)
    # All non-reference rows, in increasing order.
    single_idx = np.setdiff1d(np.arange(embed.shape[0]), ref_idx)

    # ---- Figure 1: angle-colored embedding with radius statistics ----
    _, ax = plot_funcs.subplots_plot(1, 1, figsize=figsize)
    ax.scatter(embed[ref_idx, 0], embed[ref_idx, 1], marker='x', c=colors[ref_idx],
               label='Reference', s=pointsize_ref)
    ax.scatter(embed[single_idx, 0], embed[single_idx, 1], marker='.', c=colors[single_idx],
               label='Out of Reference', s=pointsize)

    # Radii from the origin (assumes the embedding is centered at (0, 0)).
    radii = np.linalg.norm(embed, axis=1)
    median_radius = np.median(radii)
    quantile_15 = np.quantile(radii, 0.15)
    quantile_85 = np.quantile(radii, 0.85)

    # Dashed circle at the median radius.
    circle = plt.Circle((0, 0), median_radius, color='black', linestyle='--', fill=False, label='Median Radius')
    ax.add_artist(circle)

    # Shaded annulus between the 0.15 and 0.85 radius quantiles.
    theta = np.linspace(0, 2 * np.pi, 100)
    outer_x, outer_y = quantile_85 * np.cos(theta), quantile_85 * np.sin(theta)
    inner_x, inner_y = quantile_15 * np.cos(theta), quantile_15 * np.sin(theta)
    ax.fill(np.concatenate([outer_x, inner_x[::-1]]),
            np.concatenate([outer_y, inner_y[::-1]]),
            color='gray', alpha=0.3, label='0.15-0.85 Quantile Radius')

    # Symmetric square limits with a 10% margin.
    max_extent = np.max(np.abs(embed)) * 1.1
    ax.set_xlim(-max_extent, max_extent)
    ax.set_ylim(-max_extent, max_extent)

    ax.set_xlabel("First Diffusion Coordinate", font_properties=my_fontproperties)
    ax.set_ylabel("Second Diffusion Coordinate", font_properties=my_fontproperties)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontproperties(my_tick_fontproperties)

    # Width of the 0.15-0.85 quantile band, relative to the median radius.
    normalized_ring_width = (quantile_85 - quantile_15) / median_radius
    ax.text(1, 1, f'Width = {normalized_ring_width:.2f} Median(R)',
            transform=ax.transAxes, fontsize=22, fontproperties=my_fontproperties,
            verticalalignment='top', horizontalalignment='right')

    # Save BEFORE showing: plt.show() may close the figure (e.g. Jupyter's inline
    # backend), and saving afterwards would write a blank page without the annotation.
    ax.figure.savefig(f"{figures_path}/{method}_embedding.pdf", dpi=300, format='pdf', bbox_inches='tight')
    _show_or_close(ax.figure, plot_flag)

    # ---- Figure 2: reference (blue) vs. out-of-reference (red) points ----
    _, ax = plot_funcs.subplots_plot(1, 1, figsize=figsize)
    ax.scatter(embed[single_idx, 0], embed[single_idx, 1], marker='.', c='r',
               label='Out of Reference', s=pointsize)
    ax.scatter(embed[ref_idx, 0], embed[ref_idx, 1], marker='.', c='b', label='Reference', s=pointsize)
    ax.set_xlabel("First Diffusion Coordinate", font_properties=my_fontproperties)
    ax.set_ylabel("Second Diffusion Coordinate", font_properties=my_fontproperties)
    ax.legend(loc='upper right')
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontproperties(my_fontproperties)

    ax.figure.savefig(f"{figures_path}/{method}_embedding_comp_vs_ref.pdf", dpi=300, format='pdf',
                      bbox_inches='tight')
    _show_or_close(ax.figure, plot_flag)

In [None]:
# plot and save embeddings
# from helper_functions.logo_funcs import plot_method_embedding
# Points are colored by the M&M (common logo) angles; each key yields two PDFs named after the key.
# NOTE(review): `ref_idx` here is the global left over from the LAST iteration of the
# bias-factor loop in "Run all methods" (or undefined if only the pickle-loading cell ran).
# For every other bias factor, the reference / out-of-reference split drawn here does not
# match the reference set that embedding was computed with. Keeping the reference
# indices per bias factor would be needed to fix this.
for method in embed_dict.keys():
    plot_method_embedding(embed_dict[method], figures_path, angles_mnm_d, Nr, method, ref_idx, plot_flag=False, pointsize=20, pointsize_ref=30, fontproperties=font_properties_ticks, figsize=figsize)


## Evaluation

In [None]:
import pandas as pd
from helper_functions.logo_funcs import embed_error
# calculate Error per Method
# Builds one row per embedding (method x scale x bias factor) with angle-recovery errors
# against the common-logo (M&M) angles, and writes the table to results_<solver>.csv.
errors = []
# randomly select validation set (seed=0, presumably so every run uses the same indices)
validation_idx = get_validation_indices(sim_params, seed=0)
# calculate error for each method
for key in embed_dict.keys():
    # keep only the real part (presumably guards against complex eigensolver output — confirm)
    embed = np.real(embed_dict[key])
    error_mae, error_std = embed_error(embed, angles_mnm_d, plot_flag=False, metric='MAE')
    error_mae_val, _ = embed_error(embed[validation_idx, :], angles_mnm_d[validation_idx], plot_flag=False, metric='MAE')
    # NOTE: despite its name, error_mse holds the RMSE metric (stored in the 'RMSE' column).
    error_mse, _ = embed_error(embed, angles_mnm_d, plot_flag=False, metric='RMSE')
    error_mae_center, error_std_center = embed_error(embed, angles_mnm_d, plot_flag=False, metric='MAE', center_data=True)
    error_mse_center, _ = embed_error(embed, angles_mnm_d, plot_flag=False, metric='RMSE', center_data=True)
    # Recover method / scale / bias factor from the key '<method>_scale_<scale>_factor_<bias_factor>'.
    parts = key.split('_scale_')
    method = parts[0]
    parts = parts[1].split('_factor_')
    scale = float(parts[0])
    bias_factor = float(parts[1])
    new_line = {'Method' : method,
                'scale': scale,
                'bias_factor': bias_factor,
                'RMSE' : error_mse,
                'MAE' : error_mae,
                'MAE_valid': error_mae_val,
                'STD' : error_std,
                'RMSE w centered data' : error_mse_center,
                'MAE w centered data' : error_mae_center,
                'STD center' : error_std_center,
                # NOTE(review): embeddings were put back into original timestamp order
                # (embed[reorder_idx, :] in "Run all methods"). So embed[:Nr] holds the first
                # Nr timestamps, NOT the reference set, and embed[Nr:] is not the
                # out-of-reference set. Both radius columns need the per-bias-factor ref_idx.
                'ref mean radius' : np.mean(np.sqrt(embed[:Nr, 0] ** 2 + embed[:Nr, 1] ** 2)),
                'out of ref mean radius' : np.mean(np.sqrt(embed[Nr:, 0] ** 2 + embed[Nr:, 1] ** 2))
                }
    errors.append(new_line)

error_df = pd.DataFrame(errors)
error_df.to_csv(f"{figures_path}/results_{sim_params['evd_solver']}.csv")
# display the table (notebook cell output)
error_df

In [None]:
# For every (method, bias factor) keep the row (i.e. kernel scale) with the lowest validation MAE.
best_row_labels = error_df.groupby(['Method', 'bias_factor'])['MAE_valid'].idxmin()
best_error_df = error_df.loc[best_row_labels]
best_error_df.to_csv(f"{figures_path}/results_best.csv")

### Post Processing 

In [None]:
import pandas as pd

# Best-per-(method, bias factor) results written by the evaluation step.
results_df = pd.read_csv(f'{figures_path}/results_best.csv')

# Folder for the summary figures.
summary_path = f"{figures_path}/summary"
os.makedirs(summary_path, exist_ok=True)

In [None]:
# format
sns.set_style("whitegrid", {'grid.linestyle': '--'})  # Adjust grid style
plt.rcParams['figure.figsize'] = (8, 6)
plt.rcParams['font.size'] = 20
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'serif'
plt.rcParams['text.latex.preamble'] = r'\usepackage{newtxmath} \usepackage{newtxtext} \usepackage{newtxtext}'
plt.rcParams['font.serif'] = "Times New Roman"
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['axes.titlesize'] = 20
plt.rcParams['xtick.labelsize'] = 20
plt.rcParams['ytick.labelsize'] = 20
plt.rcParams['legend.fontsize'] = 20
plt.rcParams['lines.linewidth'] = 3
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.5
plt.rcParams['grid.linestyle'] = '--'

font_name = "Times New Roman"  # Change to any other installed serif font if needed

# Set font properties using the font name
font_properties = font_manager.FontProperties(family=font_name, size=18)

In [None]:
# Method groups drawn in the summary figure.
reference_methods = ['lead']                                 # performance reference
competing_methods = ['nystrom', 'ncca']                      # competitors under the same setting
our_methods = ['adm_plus', 'forward_only', 'backward_only']

# Legend display name of every method key.
method_names = dict(
    lead='LAD',
    ad='ADM',
    dm='DM',
    nystrom='Dov et al.',  # alternative label: 'Nyström'
    ncca='NCCA',
    kcca='KCCA (ChatGPT)',
    kcca_impute='KCCA',
    forward_only='forward only',
    backward_only='backward only',
    adm_plus='ADM+',
)

# Line colors: black/grey for references, blues for our methods, other hues for competitors.
palette_reference = dict(ad='black', dm='grey', lead='black')
palette_our_methods = dict(forward_only='blue', adm_plus='dodgerblue', backward_only='cyan')
palette_competing = dict(nystrom='green', ncca='orange', kcca='violet', kcca_impute='purple')

# Line style and font sizes.
linewidth, errorbar = 3.5, 'sd'
legend_fontsize, label_fontsize, tick_fontsize = 26, 26, 32
tick_fontsize = 32


In [None]:
# MAE vs. the sampling-bias exponent alpha, one line per method, grouped by method family.
fig, ax = plt.subplots(figsize=(10, 6))

# (methods, palette, extra line kwargs) per group; reference methods are drawn dashed.
method_groups = [
    (reference_methods, palette_reference, {'linestyle': 'dashed'}),
    (competing_methods, palette_competing, {}),
    (our_methods, palette_our_methods, {}),
]
for group_methods, group_palette, line_kwargs in method_groups:
    df_filtered = results_df[results_df['Method'].isin(group_methods)]
    sns.lineplot(data=df_filtered, x='bias_factor', y='MAE', hue='Method',
                 marker='o', palette=group_palette, linewidth=linewidth,
                 hue_order=group_methods, estimator='mean', errorbar=errorbar,
                 ax=ax, **line_kwargs)

# Replace method keys by display names; unknown labels are kept as-is instead of raising.
handles, labels = ax.get_legend_handles_labels()
labels = [method_names.get(label, label) for label in labels]
# Axes.legend ignores `fontsize` when `prop` is given, so only the font properties are passed.
ax.legend(handles, labels, loc='lower right', frameon=True, prop=font_properties)

ax.set_xlabel(r'$\alpha$', fontsize=label_fontsize)
ax.set_ylabel(r'MAE[deg]', fontsize=label_fontsize)
# Tick labels are rendered at 24 pt (tick_fontsize from the config cell is not applied here).
ax.tick_params(axis='both', labelsize=24)

# Gridlines for readability; drop the top/right spines for a cleaner look.
ax.grid(True, which='both', linestyle='--', linewidth=0.6)
sns.despine()
fig.savefig(f'{summary_path}/alpha_vs_mae.pdf', dpi=300, format='pdf', bbox_inches='tight')

# Biased Sampling Test: Mean Sampled Angle vs. Bias Factor

In [None]:
# Sanity check of the biased sampler: mean sampled angle as a function of the bias
# exponent alpha, for evenly spaced angles in [0, 360], averaged over several seeds.
# Runs inside a helper so the notebook globals `angles`, `Nr`, `ref_idx` and
# `bias_factors` used by the experiment cells are not overwritten.
def _mean_sampled_angle(angles, n_ref, bias_factors, seeds):
    """Return, per bias factor, the mean of ``n_ref`` sampled angles averaged over ``seeds``.

    Uses the same weighting as ``sample_reference_set``: sample ``i`` is drawn
    without replacement with probability proportional to
    ``(angles[i] / max(angles)) ** bias_factor``. The global ``np.random`` state is
    re-seeded before every draw, so each bias factor uses the same seeds.
    """
    avg = np.zeros(len(bias_factors))
    for seed in seeds:
        for i, bias_factor in enumerate(bias_factors):
            weights = (angles / np.max(angles)) ** bias_factor
            np.random.seed(seed)
            idx = np.random.choice(len(angles), size=n_ref, replace=False, p=weights / np.sum(weights))
            avg[i] += angles[idx].mean()
    return avg / len(seeds)


bias_factor_grid = np.linspace(0, 5, 30)
uniform_angles = np.linspace(0, 360, 1000)  # alternatively: angles_mnm_d
avg_angles = _mean_sampled_angle(uniform_angles, sim_params['Nr'], bias_factor_grid,
                                 seeds=[0, 32, 192, 240, 211])

# plot
fig, ax = plt.subplots(1, 1)
ax.plot(bias_factor_grid, avg_angles)
ax.set_xlabel(r'$\alpha$', fontsize=20)
ax.set_ylabel(r'$\theta_{avg}$', fontsize=20)