In [None]:
from src.metrics.intrinsic_dimension import IntrinsicDimension
from src.metrics.clustering import LabelClustering
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Shared matplotlib rcParams for every figure in this notebook.
# 'font.size' is intentionally not set, so matplotlib's default base font size applies.
plot_config = dict(
    [
        # Text sizes: figure title, axis labels, tick labels, legend entries.
        ("axes.titlesize", 30),
        ("axes.labelsize", 29),
        ("xtick.labelsize", 20),
        ("ytick.labelsize", 20),
        ("legend.fontsize", 23),
        # Figure geometry (inches) and default line/marker sizes.
        ("figure.figsize", (10, 8)),
        ("lines.linewidth", 2.5),
        ("lines.markersize", 10),
    ]
)


# Data visualization
The following snippets can be used to replicate some of the plots from the paper.

## Intrinsic Dimension

In [None]:
# Fixed y-tick grid for intrinsic-dimension plots; equal to
# np.arange(2.5, 22, 22 / 10).round(3).
ID_YTICKS = (2.5, 4.7, 6.9, 9.1, 11.3, 13.5, 15.7, 17.9, 20.1)


def plotter(data, title, names=None, ylabel="ID", yticks=ID_YTICKS):
    """Plot one per-layer curve per few-shot setting on a single figure.

    Each entry of ``data`` is drawn as point markers plus a line against its
    layer index; x-tick labels show 1-based layer numbers.

    Args:
        data: Non-empty sequence of 1-D numpy arrays, one per setting, each
            indexed by layer. The layer count is taken from
            ``data[0].shape[0]``, so all entries are assumed to share it.
        title: Figure title.
        names: Legend labels, one per entry of ``data``. Defaults to
            "0 shot pt", "1 shot pt", "2 shot pt", "5 shot pt".
        ylabel: y-axis label. Defaults to "ID".
        yticks: y-tick positions. Defaults to ``ID_YTICKS`` (2.5 to 20.1 in
            steps of 2.2); matplotlib expands the y-limits so every given tick
            is visible. Pass ``None`` to let matplotlib choose ticks from the
            data (e.g. for scores in [0, 1]).

    Raises:
        ValueError: If ``data`` is empty or has more entries than ``names``.
    """
    if len(data) == 0:
        raise ValueError("data must contain at least one curve")
    if names is None:
        names = ["0 shot pt", "1 shot pt", "2 shot pt", "5 shot pt"]
    if len(data) > len(names):
        # zip() below would otherwise silently drop the extra curves.
        raise ValueError(
            f"got {len(data)} curves but only {len(names)} legend names"
        )

    sns.set_style(
        "whitegrid",
        rc={"axes.edgecolor": ".15", "xtick.bottom": True, "ytick.left": True},
    )
    # Apply the shared sizes *before* creating the figure: matplotlib reads
    # figure.figsize when the figure is created and font/line sizes when each
    # artist is created, so updating rcParams after plotting would only
    # affect the next figure.
    plt.rcParams.update(plot_config)
    fig, ax = plt.subplots(dpi=200)

    n_layers = data[0].shape[0]
    layers = np.arange(n_layers)
    # Tick every 4th layer for shallower models, every 8th for deeper ones.
    step = 4 if n_layers < 50 else 8
    tick_positions = np.arange(0, n_layers, step)
    tick_labels = tick_positions + 1  # display layers 1-based

    for idx, (values, label) in enumerate(zip(data, names)):
        # Pin one colour per curve so the markers and the line always match.
        color = f"C{idx}"
        sns.scatterplot(x=layers, y=values, marker="o", color=color, ax=ax)
        sns.lineplot(x=layers, y=values, label=label, color=color, ax=ax)

    ax.set_xlabel("Layer")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    if yticks is not None:
        ax.set_yticks(yticks)
    ax.legend()
    fig.tight_layout()
    plt.show()

In [None]:
# Compute the per-layer intrinsic dimension for each few-shot setting and plot them together.
shot = [0, 1, 2, 5]
data = []
for n_shot in shot:
    # Interpolate the current shot count (not the whole list) into the file name.
    # NOTE(review): '...' looks like a placeholder for the results directory — TODO set it.
    intrinsic_dim = IntrinsicDimension(path=f'.../{n_shot}shot.csv')
    data.append(intrinsic_dim.main())
plotter(data, "Intrinsic Dimension")

## Clustering

### Subjects

In [None]:
# Label-clustering score w.r.t. the "subject" label for each few-shot setting.
shot = [0, 1, 2, 5]
data = []
for n_shot in shot:
    # Interpolate the current shot count (not the whole list) into the file name.
    # NOTE(review): '...' looks like a placeholder for the results directory — TODO set it.
    clustering = LabelClustering(path=f'.../{n_shot}shot.csv')
    data.append(clustering.main(label="subject"))
# NOTE(review): plotter's default y-label ("ID") and fixed y-ticks target
# intrinsic-dimension values; confirm they suit these clustering scores.
plotter(data, "Label Clustering")

### Letters

In [None]:
# Label-clustering score w.r.t. the "letters" label for each few-shot setting.
shot = [0, 1, 2, 5]
data = []
for n_shot in shot:
    # Interpolate the current shot count (not the whole list) into the file name.
    # NOTE(review): '...' looks like a placeholder for the results directory — TODO set it.
    clustering = LabelClustering(path=f'.../{n_shot}shot.csv')
    data.append(clustering.main(label="letters"))
# NOTE(review): plotter's default y-label ("ID") and fixed y-ticks target
# intrinsic-dimension values; confirm they suit these clustering scores.
plotter(data, "Label Clustering")