In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os
import io
from PIL import Image
import math

import numpy as np
from matplotlib.gridspec import GridSpec

In [None]:
# Display names used across the notebook; `datasetnames` also fixes the
# order in which the LaTeX tables are emitted.
datasetnames = [
    "NL Pop200",
    "Fish Data",
    "NL Grouping",
    "NL Merging",
    "NL 3Groups",
]
dimrednames = "PCA SNE SAM UMAP".split()
visstratnames = "ML MR GR nGR".split()

# Testing Pearson Correlations

In [None]:
import os
import matplotlib.pyplot as plt
import seaborn as sns

# Read 
def read_pearsons(dataset, directory='pearsons/'):
    """Collect the per-frame Pearson correlation values of one dataset.

    Files in ``directory`` are named
    ``<dataset>_<dimRedStrat>_<imageStrat>_<epsilon>.<ext>``. The first line
    of each file is a header; every following non-blank line holds one float.

    Args:
        dataset: raw dataset name as it appears in file names
            (e.g. ``"mergeFocus"``).
        directory: folder that contains the correlation files.

    Returns:
        dict mapping image strategy -> flat list of floats. Values from all
        dim.-red. files of the same image strategy are concatenated, in
        sorted file-name order, so the result is reproducible.
    """
    pearsons = {}

    for file in sorted(os.listdir(directory)):
        # Only the part before the first '.' is parsed, so a dotted epsilon
        # is truncated; epsilon is not used here.
        parts = file.split('.')[0].split('_')
        # Skip hidden or stray files (e.g. ".DS_Store") that do not follow
        # the four-component naming scheme, rather than failing to unpack.
        if len(parts) != 4:
            continue
        datasetname, _dimRedStrat, imageStrat, _epsilon = parts

        if datasetname != dataset:
            continue

        with open(os.path.join(directory, file), 'r') as f:
            next(f, None)  # skip the header line
            values = pearsons.setdefault(imageStrat, [])
            for line in f:
                line = line.strip()
                if line:  # tolerate blank or trailing empty lines
                    values.append(float(line))

    return pearsons

# Per-frame Pearson values for the raw "mergeFocus" dataset
# (shown as "NL Merging" elsewhere in this notebook), keyed by image strategy.
new = read_pearsons("mergeFocus")

def make_histogram(pearsons):
    """Show one 20-bin histogram per entry of ``pearsons``, titled by its key."""
    import matplotlib.pyplot as plt

    for label, values in pearsons.items():
        plt.hist(values, bins=20)
        plt.title(label)
        plt.show()


def violinplot(pearsons):
    """Draw one violin per entry of ``pearsons`` (e.g. one per image strategy).

    Each value list is wrapped in a ``pd.Series``, so lists of different
    lengths are NaN-padded (seaborn drops NaN) instead of ``pd.DataFrame``
    raising "All arrays must be of the same length".
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd

    data = pd.DataFrame({name: pd.Series(values, dtype=float)
                         for name, values in pearsons.items()})
    sns.violinplot(data=data)
    plt.show()

# make_histogram(new)  # optional alternative: one histogram per image strategy

# One violin per image strategy for the "mergeFocus" Pearson values.
violinplot(new)

In [None]:
import math

def create_pearsons_histogram(pearsons, visstrat="Motion Lines"):
    """Render the per-frame Pearson values of one strategy as a bar image.

    Each frame becomes one pixel column of a black image 100 pixels high. A
    value ``v`` fills ``int(v * 100)`` pixels, counted from the bottom row,
    in light grey. Bar heights are clamped to [0, 100]: negative values draw
    nothing, and values above 1 can no longer address rows outside the image.
    NaN frames are left empty.

    Args:
        pearsons: dict strategy -> list of floats (as from ``read_pearsons``).
        visstrat: key of the strategy to render.

    Returns:
        PIL ``Image`` of size ``(len(values), 100)``.
    """
    height = 100
    expcolor = (200, 200, 200)

    metric = pearsons[visstrat]

    img = Image.new('RGB', (len(metric), height), color='black')

    for i, framevalue in enumerate(metric):
        # Undefined correlations get no bar.
        if math.isnan(framevalue):
            continue

        y = min(max(int(framevalue * height), 0), height)

        # Draw the bar upward from the bottom row.
        for j in range(y):
            img.putpixel((i, height - 1 - j), expcolor)

    return img

def plot_image_pearsons(img, title=""):
    """Show a Pearson bar image on a 20x10 inch axes.

    The image is stretched over y in [-50, 50]. The tick labels rename that
    range to correlation values 0 .. 1.
    """
    fig, ax = plt.subplots()
    fig.set_size_inches(20, 10)

    ax.set_xlabel('Frame')
    ax.set_ylabel('Pearsons Correlation Coefficient')
    ax.set_title(title)

    tick_positions = [-50, -25, 0, 25, 50]
    ax.imshow(img, extent=[0, img.width, tick_positions[0], tick_positions[-1]])
    ax.set_yticks(tick_positions)
    ax.set_yticklabels([0, 0.25, 0.5, 0.75, 1])

# Build one Pearson bar image per visualization strategy, then show each one.
img_ml = create_pearsons_histogram(new, visstrat="Motion Lines")
img_mr = create_pearsons_histogram(new, visstrat="MotionRugs")
img_gr = create_pearsons_histogram(new, visstrat="Ordered Rugs")
img_ngr = create_pearsons_histogram(new, visstrat="Fuzzy Rugs")

for _image, _caption in ((img_ml, "Motion Lines"),
                         (img_mr, "MotionRugs"),
                         (img_gr, "Ordered Rugs"),
                         (img_ngr, "Fuzzy Rugs")):
    plot_image_pearsons(_image, _caption)

# Initialization

In [None]:
class DatasetStats:
    """Metrics parsed from one metrics file, plus the metadata in its file name.

    Attributes:
        name: dataset display name.
        dimRedStrat: dim.-red. short name (e.g. "PCA").
        imageStrat: visualization-strategy short name (e.g. "ML").
        epsilon: epsilon parsed from the file name.
        metrics: dict metric name -> metric object exposing ``name`` and
            ``metricvalues`` (e.g. ``Metric``).
    """

    def __init__(self, name, dimRedStrat, imageStrat, epsilon):
        self.name = name
        self.dimRedStrat = dimRedStrat
        self.imageStrat = imageStrat
        self.epsilon = epsilon

        self.metrics = {}

    def addMetric(self, metric):
        """Register ``metric`` under ``metric.name``, replacing any previous entry."""
        self.metrics[metric.name] = metric

    def getMetric(self, metricName):
        """Return the metric named ``metricName``; raises KeyError if absent."""
        return self.metrics[metricName]

    def _values(self, metricName):
        """Return the metric's values; raise ValueError if the metric has none."""
        values = self.getMetric(metricName).metricvalues
        if not values:
            # min()/max() already raise ValueError on empty input; make the
            # mean/std-dev consistent instead of raising ZeroDivisionError.
            raise ValueError(f"metric {metricName!r} has no values")
        return values

    def getAverage(self, metricName):
        """Arithmetic mean of the metric's values."""
        values = self._values(metricName)
        return sum(values) / len(values)

    def getStdDev(self, metricName):
        """Population standard deviation (divides by N, not N - 1)."""
        values = self._values(metricName)
        mean = sum(values) / len(values)
        return (sum((x - mean) ** 2 for x in values) / len(values)) ** 0.5

    def getMin(self, metricName):
        """Smallest value of the metric."""
        return min(self._values(metricName))

    def getMax(self, metricName):
        """Largest value of the metric."""
        return max(self._values(metricName))

class Metric:
    """A named series of float values read from a metrics file."""

    def __init__(self, name):
        self.name = name
        # addValue appends to this list in file order.
        self.metricvalues = list()

    def addValue(self, value):
        """Append one value to the series."""
        self.metricvalues += [value]

In [None]:
# Lookup tables from the raw components of metric file names
# ("<dataset>_<dimRed>_<imageStrat>_<epsilon>") to short display names.

# Dim.-red. strategy -> abbreviation (both t-SNE variants map to "SNE").
drt_dict = dict([
    ("Stable sammon mapping", "SAM"),
    ("UMAPStrategy", "UMAP"),
    ("Stable TSNEStableStrategy", "SNE"),
    ("t-SNE (simple)", "SNE"),
    ("PrincipalComponentStrategy", "PCA"),
])

# Raw dataset name -> display name.
datasetnames_dict = dict([
    ("200pop1", "NL Pop200"),
    ("fishdatamerge", "Fish Data"),
    ("grouping", "NL Grouping"),
    ("mergeFocus", "NL Merging"),
    ("tryagain", "NL 3Groups"),
])

# Visualization strategy -> abbreviation.
visstrats_dict = dict([
    ("Motion Lines", "ML"),
    ("MotionRugs", "MR"),
    ("Ordered Rugs", "GR"),
    ("Fuzzy Rugs", "nGR"),
])

In [None]:
class MetricInstance:
    """All metrics of one metrics file, tagged with its technique and strategy.

    Attributes:
        metrics: dict metric name -> metric object exposing ``name`` and
            ``metricvalues`` (e.g. ``Metric``).
        dimred: dim.-red. short name (e.g. "PCA").
        visstrat: visualization-strategy short name (e.g. "ML").
    """

    def __init__(self, dimred, visstrat):
        self.metrics = {}
        self.dimred = dimred
        self.visstrat = visstrat

    def addMetric(self, metric):
        """Register ``metric`` under ``metric.name``, replacing any previous entry."""
        self.metrics[metric.name] = metric

    def getMetric(self, metricName):
        """Return the metric named ``metricName``; raises KeyError if absent."""
        return self.metrics[metricName]

    def _values(self, metricName):
        """Return the metric's values; raise ValueError if the metric has none."""
        values = self.getMetric(metricName).metricvalues
        if not values:
            # Match min()/max(), which already raise ValueError on empty
            # input, instead of raising ZeroDivisionError in mean/std-dev.
            raise ValueError(f"metric {metricName!r} has no values")
        return values

    def getAverage(self, metricName):
        """Arithmetic mean of the metric's values."""
        values = self._values(metricName)
        return sum(values) / len(values)

    def getStdDev(self, metricName):
        """Population standard deviation (divides by N, not N - 1)."""
        values = self._values(metricName)
        mean = sum(values) / len(values)
        return (sum((x - mean) ** 2 for x in values) / len(values)) ** 0.5

    def getMin(self, metricName):
        """Smallest value of the metric."""
        return min(self._values(metricName))

    def getMax(self, metricName):
        """Largest value of the metric."""
        return max(self._values(metricName))

In [None]:
def load_files(directory='metrics/'):
    """Load every metrics file in ``directory``, grouped by dataset display name.

    File names look like ``<dataset>_<dimRed>_<imageStrat>_<epsilon>.<ext>``.
    The components are translated with ``datasetnames_dict``, ``drt_dict`` and
    ``visstrats_dict``. Inside a file, a line containing ':' starts a new
    metric named by the text before the colon. Every other non-blank line is
    a float value of the current metric.

    Args:
        directory: folder with the metrics files.

    Returns:
        dict dataset display name -> list of ``MetricInstance`` (one per
        file, in sorted file-name order).

    Raises:
        KeyError: a file-name component is missing from a mapping dict.
        ValueError: a value line appears before any metric header in a file.
    """
    datasets = {}

    for file in sorted(os.listdir(directory)):
        parts = file.split('.')[0].split('_')
        # Ignore hidden or stray files (e.g. ".DS_Store") that do not follow
        # the naming scheme, instead of failing on tuple unpacking.
        if len(parts) != 4:
            continue
        datasetname, dimRedStrat, imageStrat, _epsilon = parts

        datasetname = datasetnames_dict[datasetname]
        metricinstance = MetricInstance(drt_dict[dimRedStrat], visstrats_dict[imageStrat])

        with open(os.path.join(directory, file), 'r') as f:
            # Reset per file so values can never be attached to a metric
            # that belongs to the previously parsed file.
            metric = None
            for line in f:
                if ":" in line:
                    metric = Metric(line.split(":")[0])
                    metricinstance.addMetric(metric)
                elif line.strip():  # tolerate blank lines
                    if metric is None:
                        raise ValueError(
                            f"{file}: value line before any metric header: {line!r}")
                    metric.addValue(float(line))

        datasets.setdefault(datasetname, []).append(metricinstance)

    return datasets

# Global used by the plotting and table cells below:
# dataset display name -> list of MetricInstance (one per metrics file).
datasets = load_files()

In [None]:
# Create function that reads and saves all filenames taking a directory path as an argument.
def read_files(directory, datasetname, excludeUMAP=False):
    """Parse the metrics files of one raw dataset into ``DatasetStats`` objects.

    Args:
        directory: folder with the metrics files. It is joined to each file
            name with ``os.path.join``, so a trailing slash is optional.
        datasetname: raw dataset name as used in file names (e.g. "200pop1").
        excludeUMAP: skip files whose dim.-red. component equals "UMAP".
            NOTE(review): the check uses the raw file-name component, but
            ``drt_dict`` lists the raw UMAP name as "UMAPStrategy", so this
            flag may never match. ``create_dataset`` filters on the mapped
            ``dimRedStrat`` itself. The check is left unchanged because
            fixing it would change what ``create_dataset`` returns.

    Returns:
        list of ``DatasetStats``, one per matching file. Metric names keep
        their trailing ':' (e.g. "Silhouette Score:").

    Raises:
        ValueError: a value line appears before the first metric header.
    """
    fileslist = []
    for file in sorted(os.listdir(directory)):
        # Drop the extension (e.g. ".txt") and split into the four components.
        parts = os.path.splitext(file)[0].split('_')
        if len(parts) != 4:
            continue  # hidden or stray files such as ".DS_Store"
        name, drstrat, imagestrat, eps = parts

        # Filter before opening, so unrelated files are never read.
        if datasetname != name:
            continue
        if excludeUMAP and drstrat == "UMAP":
            continue

        dst = DatasetStats(datasetnames_dict[name], drt_dict[drstrat],
                           visstrats_dict[imagestrat], float(eps))

        with open(os.path.join(directory, file), 'r') as f:
            metric = None
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines
                if ":" in line:
                    if metric is not None:
                        dst.addMetric(metric)
                    metric = Metric(line)
                elif metric is None:
                    raise ValueError(
                        f"{file}: value line before any metric header: {line!r}")
                else:
                    metric.addValue(float(line))
        # Register the last metric. An empty file yields no metrics rather
        # than an addMetric(None) AttributeError.
        if metric is not None:
            dst.addMetric(metric)

        fileslist.append(dst)

    return fileslist

# Barcharts, Violin Plots and Min-Max

In [None]:
import numpy as np

def create_violinplot(datasetname, metricname, dimred):
    """Violin plot of one metric for every visualization strategy of a dataset.

    Reads the global ``datasets`` and draws one violin per ``MetricInstance``
    whose ``dimred`` matches, in the order the instances are stored.

    Raises:
        ValueError: ``datasetname`` has no instance using ``dimred``.
    """
    series = []
    names = []
    for metricinstance in datasets[datasetname]:
        if metricinstance.dimred == dimred:
            series.append(pd.Series(metricinstance.getMetric(metricname).metricvalues, dtype=float))
            names.append(metricinstance.visstrat)
    if not series:
        raise ValueError(f"no {dimred} results for {datasetname}")

    # Concatenating Series NaN-pads shorter value lists. A plain
    # np.array(...).T needed every strategy to have the same frame count.
    df = pd.concat(series, axis=1, keys=names)

    fig = plt.figure(figsize=(20, 10))
    ax = fig.add_axes([0, 0, 1, 1])
    # NOTE: seaborn >= 0.13 deprecates `scale` in favour of `density_norm`;
    # `scale` is kept for compatibility with older seaborn versions.
    sns.violinplot(data=df, scale='width', inner='point', ax=ax)
    ax.set_title(f"{datasetname} - {metricname} - {dimred}")
    plt.show()

# create_violinplot("NL Grouping", "Spatial Quality Enc", "UMAP")

def create_violinplots(metric, dimred):
    """Call ``create_violinplot`` for every dataset in the global ``datasets``."""
    for name in list(datasets):
        create_violinplot(name, metric, dimred)

# Example: silhouette scores of every visualization strategy for NL 3Groups (PCA).
create_violinplot("NL 3Groups", "Silhouette Score", "PCA")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def create_violinplot(datasetname, metricname, dimred, order=("nGR", "GR")):
    """Violin plot of one metric for selected visualization strategies.

    This redefines the earlier ``create_violinplot``: it only shows the
    strategies listed in ``order`` (by default nGR then GR, as before).

    Args:
        datasetname: key of the global ``datasets`` dict.
        metricname: metric to plot (e.g. "Silhouette Score").
        dimred: dim.-red. short name (e.g. "PCA").
        order: strategies to show, left to right. Strategies without data
            are skipped.

    Raises:
        ValueError: none of the requested strategies has data.
    """
    data_dict = {}
    for metricinstance in datasets[datasetname]:
        if metricinstance.dimred == dimred:
            data_dict[metricinstance.visstrat] = metricinstance.getMetric(metricname).metricvalues

    present = [visstrat for visstrat in order if visstrat in data_dict]
    if not present:
        raise ValueError(f"no {dimred} results for {list(order)} in {datasetname}")

    # Series-based construction NaN-pads value lists of unequal length;
    # np.array(...).T required equal lengths.
    df = pd.DataFrame({visstrat: pd.Series(data_dict[visstrat], dtype=float)
                       for visstrat in present})

    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_axes([0, 0, 1, 1])
    # NOTE: seaborn >= 0.13 prefers `density_norm` over the deprecated `scale`.
    sns.violinplot(data=df, scale='width', inner='point', ax=ax)
    ax.set_title(f"{datasetname} - {metricname} - {dimred}")
    plt.show()

# Example usage:
# create_violinplot("NL Grouping", "Spatial Quality Enc", "UMAP")

def create_violinplots(metric, dimred):
    """Draw ``create_violinplot`` once per dataset key of the global ``datasets``."""
    for dataset_key in datasets.keys():
        create_violinplot(dataset_key, metric, dimred)

# Example usage (uses the redefinition above, which shows nGR and GR only):
create_violinplot("NL Merging", "Silhouette Score", "PCA")


In [None]:
# Compare the nGR stability distributions of PCA and SAM on NL Pop200.
data_dict = {}

for mi in datasets["NL Pop200"]:
    if mi.visstrat == "nGR" and mi.dimred == "PCA":
        stab_PCA = mi.getMetric("Stability Dist").metricvalues
        data_dict["PCA"] = stab_PCA

    if mi.visstrat == "nGR" and mi.dimred == "SAM":
        stab_SAM = mi.getMetric("Stability Dist").metricvalues
        data_dict["SAM"] = stab_SAM

# Wrap each list in a Series so unequal lengths are NaN-padded instead of
# making pd.DataFrame raise.
df = pd.DataFrame({name: pd.Series(values, dtype=float) for name, values in data_dict.items()})

# Create the figure at its final size first, then draw into it explicitly.
fig, ax = plt.subplots(figsize=(10, 5))
sns.violinplot(data=df, scale='width', inner='point', ax=ax)
ax.set_title("NL Pop200 - Stability Dist - PCA vs SAM")
plt.show()

In [None]:
def printMax(dataset, metricname, dimred):
    """Print "<visstrat> <max>" for each instance of ``dataset`` that uses ``dimred``."""
    matching = (mi for mi in datasets[dataset] if mi.dimred == dimred)
    for mi in matching:
        print(f"{mi.visstrat} {mi.getMax(metricname)}")

# Maximum silhouette score per visualization strategy for Fish Data (PCA).
printMax("Fish Data", "Silhouette Score", "PCA")

In [None]:
# Metric names as stored in the metrics files, in table order.
_table_metric_names = [
    "Silhouette Score",
    "Stability Dist",
    "Spatial Quality Enc",
    "Spatial Quality Dist",
]

# Metric name -> short column label used in LaTeX table headers.
metric_dict = dict(zip(_table_metric_names, ["Silh.", "Stab.", "SS", "KS"]))

# Metric name -> long name used in table captions and labels.
report_metrics = dict(zip(
    _table_metric_names,
    ["Silhouette Score", "Stability", "Spatial Similarity (SS)", "Keys Similarity (KS)"],
))

# Function to write code for overleaf table
def makeTable(dataset, metric, dsmentioned = False, visstratmentioned = False, source=None):
    """Print the LaTeX body rows (avg/min/max of ``metric``) for one dataset.

    Rows are grouped by visualization strategy (MR, ML, nGR, GR), then by
    dim.-red. technique. The dataset and strategy names appear only on the
    first row of their group. "NL Pop200", "NL Merging" and "Fish Data" omit
    SNE. Combinations without data are printed as "--" instead of repeating
    the previous row's numbers.

    Args:
        dataset: dataset display name (key of ``source``).
        metric: metric name passed to getAverage/getMin/getMax.
        dsmentioned: if True, the dataset name is never printed.
        visstratmentioned: ignored (reset for every strategy); kept for
            backward compatibility.
        source: mapping dataset name -> metric instances. Defaults to the
            global ``datasets``.
    """
    if source is None:
        source = datasets

    dimreds = ["PCA", "SAM", "UMAP", "SNE"]
    visstrats = ["MR", "ML", "nGR", "GR"]
    # `dataset == ("A" or "B" or "C")` only ever compared against "A";
    # a membership test covers all three datasets.
    if dataset in ("NL Pop200", "NL Merging", "Fish Data"):
        dimreds = ["PCA", "SAM", "UMAP"]

    # The last matching instance wins, as with the original linear scan.
    lookup = {(mi.dimred, mi.visstrat): mi for mi in source[dataset]}

    for visstrat in visstrats:
        visstratmentioned = False

        for dimred in dimreds:
            mi = lookup.get((dimred, visstrat))
            if mi is None:
                values = "-- & -- & --"
            else:
                avg = mi.getAverage(metric)
                minval = mi.getMin(metric)
                maxval = mi.getMax(metric)
                values = f"{avg:.2f} & {minval:.2f} & {maxval:.2f}"

            print(f"\\hline")
            if not dsmentioned and not visstratmentioned:
                print(f" {dataset} & {visstrat} & {dimred} & {values} \\\\")
                dsmentioned = True
                visstratmentioned = True
            elif not visstratmentioned:
                print(f" & {visstrat} & {dimred} & {values} \\\\")
                visstratmentioned = True
            else:
                print(f" & & {dimred} & {values} \\\\")
# Example usage:
# makeTable("NL Pop200", "Silhouette Score")

# Emit one LaTeX table per (dataset, metric) pair, in datasetnames order.
for dataset in datasetnames:
    ds_mentioned = False  # NOTE(review): not read anywhere visible; makeTable tracks this itself

    for metric in metric_dict:
        short_label = metric_dict[metric]
        long_label = report_metrics[metric]

        print("\n".join([
            "\\begin{table}[H]",
            "\\centering",
            "\\begin{tabular}{|c|c|c|c|c|c|}",
            "\\hline",
            f" Dataset & Vis. Strat. & Dim. Red. Techn. & {short_label} avg & {short_label} min & {short_label} max \\\\",
            "\\hline",
        ]))
        makeTable(dataset, metric)
        print("\n".join([
            "\\hline",
            "\\end{tabular}",
            f"\\caption{{Displaying average, minimum and maximum values for the {long_label} metric on the {dataset} dataset.}}",
            f"\\label{{tab:{dataset}_{long_label}}}",
            "\\end{table}",
        ]))

In [None]:
# Create function that reads and saves all filenames taking a directory path as an argument.
def read_files(directory, datasetname, excludeUMAP=False):
    """Parse the metrics files of one raw dataset into ``DatasetStats`` objects.

    This re-defines ``read_files`` with the same behaviour as the earlier
    cell, so this cell can be run on its own.

    Args:
        directory: folder with the metrics files. It is joined to each file
            name with ``os.path.join``, so a trailing slash is optional.
        datasetname: raw dataset name as used in file names (e.g. "200pop1").
        excludeUMAP: skip files whose dim.-red. component equals "UMAP".
            NOTE(review): the check uses the raw file-name component, but
            ``drt_dict`` lists the raw UMAP name as "UMAPStrategy", so this
            flag may never match. ``create_dataset`` filters on the mapped
            ``dimRedStrat`` itself. The check is left unchanged because
            fixing it would change what ``create_dataset`` returns.

    Returns:
        list of ``DatasetStats``, one per matching file. Metric names keep
        their trailing ':' (e.g. "Silhouette Score:").

    Raises:
        ValueError: a value line appears before the first metric header.
    """
    fileslist = []
    for file in sorted(os.listdir(directory)):
        # Drop the extension (e.g. ".txt") and split into the four components.
        parts = os.path.splitext(file)[0].split('_')
        if len(parts) != 4:
            continue  # hidden or stray files such as ".DS_Store"
        name, drstrat, imagestrat, eps = parts

        # Filter before opening, so unrelated files are never read.
        if datasetname != name:
            continue
        if excludeUMAP and drstrat == "UMAP":
            continue

        dst = DatasetStats(datasetnames_dict[name], drt_dict[drstrat],
                           visstrats_dict[imagestrat], float(eps))

        with open(os.path.join(directory, file), 'r') as f:
            metric = None
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines
                if ":" in line:
                    if metric is not None:
                        dst.addMetric(metric)
                    metric = Metric(line)
                elif metric is None:
                    raise ValueError(
                        f"{file}: value line before any metric header: {line!r}")
                else:
                    metric.addValue(float(line))
        # Register the last metric. An empty file yields no metrics rather
        # than an addMetric(None) AttributeError.
        if metric is not None:
            dst.addMetric(metric)

        fileslist.append(dst)

    return fileslist

# Scatterplot Code

In [None]:
def create_dataset(excludeUMAP=False, directory='metrics/'):
    """Build one summary DataFrame per dataset, with one row per metrics file.

    Columns: Dataset (display name), Stability, Spatial Quality (KS),
    Spatial Quality (SS), Image Strategy, Dim. Red. Technique, Silhouette.
    Metric columns hold the per-file averages.

    Args:
        excludeUMAP: drop rows whose mapped dim.-red. technique is "UMAP".
        directory: folder with the metrics files.

    Returns:
        list of DataFrames in the order 200pop1, fishdatamerge, grouping,
        mergeFocus, tryagain.
    """
    columns = ['Dataset', 'Stability', 'Spatial Quality (KS)', 'Spatial Quality (SS)',
               'Image Strategy', 'Dim. Red. Technique', 'Silhouette']

    df_list = []
    dsnames = ['200pop1', 'fishdatamerge', 'grouping', 'mergeFocus', 'tryagain']
    for dsname in dsnames:
        # Forward the caller's flag instead of a hard-coded True. The explicit
        # filter below is what actually guarantees the exclusion.
        files = read_files(directory, dsname, excludeUMAP)

        rows = []
        for file in files:
            if excludeUMAP and file.dimRedStrat == "UMAP":
                continue
            rows.append({
                'Dataset': file.name,
                'Stability': file.getAverage("Stability Dist:"),
                'Spatial Quality (KS)': file.getAverage("Spatial Quality Dist:"),
                'Spatial Quality (SS)': file.getAverage("Spatial Quality Enc:"),
                'Image Strategy': file.imageStrat,
                'Dim. Red. Technique': file.dimRedStrat,
                'Silhouette': file.getAverage("Silhouette Score:"),
            })

        # Passing `columns` keeps the schema even when no rows were collected.
        df_list.append(pd.DataFrame(rows, columns=columns))

    return df_list

In [None]:
# Marker per image strategy (filled markers only).
marker_shapes = {
    'MR': 's',   # square
    'ML': 'D',   # diamond
    'nGR': 'o',  # circle
    'GR': '^',   # triangle up
}

# Prefix for saved scatterplot paths ("" = current working directory).
scatterdir = ""

# Fixed colour per dim.-red. technique, shared by all scatterplots.
technique_colors = {
    "UMAP": "#c6c752",
    "PCA": "#a5dad2",
    "SNE": "#a2a18f",
    "SAM": "#f69431",
}

# The same colours as a seaborn palette, in technique_colors order.
# NOTE(review): the plotting functions pass technique_colors directly;
# no visible code uses `palette`.
palette = sns.color_palette(list(technique_colors.values()))


def scatterplotstab_ks_spat(df, save=False):
    """Scatter Stability (x) against Spatial Quality (KS) (y) for one dataset.

    Points are coloured by dim.-red. technique (``technique_colors``) and
    shaped by image strategy (``marker_shapes``, restricted to the strategies
    present). With ``save=True`` the figure is written under ``scatterdir``
    before it is shown.
    """
    ks_column = 'Spatial Quality (KS)'
    present_markers = {
        strategy: marker_shapes[strategy]
        for strategy in df['Image Strategy'].unique()
    }

    sns.scatterplot(data=df, x='Stability', y=ks_column,
                    style='Image Strategy', hue='Dim. Red. Technique',
                    edgecolor='black', s=100, palette=technique_colors,
                    markers=present_markers)

    title_name = df['Dataset'][0]

    plt.title('Stability vs Spatial Quality (' + title_name + ')')
    plt.xlabel('Stability (KSte)')
    plt.ylabel('Spatial Quality (KS)')
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.grid()

    if save:
        out_path = scatterdir + 'scatter_stability_spatial_quality' + title_name + '.png'
        plt.savefig(out_path, bbox_inches='tight')

    plt.show()

def scatterplotstabexp(df, save=False):
    """Scatter Stability (x) against Silhouette / expressiveness (y) for one dataset.

    Colours come from ``technique_colors``. Markers come from
    ``marker_shapes`` (restricted to the strategies present) rather than
    seaborn's order-based default, so each image strategy keeps the same
    shape across plots. With ``save=True`` the figure is written under
    ``scatterdir``.
    """
    markers = {strat: marker_shapes[strat] for strat in df['Image Strategy'].unique()}
    sns.scatterplot(data=df, x='Stability', y='Silhouette', style='Image Strategy',
                    hue='Dim. Red. Technique', edgecolor='black', s=100,
                    palette=technique_colors, markers=markers)

    # Positional lookup: works even if df's index does not contain 0.
    dsname = df['Dataset'].iloc[0]

    plt.title('Stability vs Expressiveness (' + dsname + ')')
    plt.xlabel('Stability (KSte)')
    plt.ylabel('Expressiveness (Silhouette)')

    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.grid()

    if save:
        plt.savefig(scatterdir + 'scatter_stability_expressiveness' + dsname + '.png', bbox_inches='tight')
    plt.show()

def scatterplotspat_ks_exp(df, save=False):
    """Scatter Silhouette / expressiveness (x) against Spatial Quality (KS) (y).

    Colours come from ``technique_colors``. Markers come from
    ``marker_shapes`` (restricted to the strategies present) so each image
    strategy keeps the same shape across plots. With ``save=True`` the
    figure is written under ``scatterdir``.
    """
    y = 'Spatial Quality (KS)'

    # Positional lookup: works even if df's index does not contain 0.
    dsname = df['Dataset'].iloc[0]

    markers = {strat: marker_shapes[strat] for strat in df['Image Strategy'].unique()}
    sns.scatterplot(data=df, x='Silhouette', y=y, style='Image Strategy',
                    hue='Dim. Red. Technique', edgecolor='black', s=100,
                    palette=technique_colors, markers=markers)

    plt.title('Expressiveness vs Spatial Quality (' + dsname + ')')
    plt.xlabel('Expressiveness (Silhouette)')
    plt.ylabel('Spatial Quality (KS)')

    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.grid()

    if save:
        plt.savefig(scatterdir + 'scatter_expressiveness_spatial_quality' + dsname + '.png', bbox_inches='tight')
    plt.show()

# ==============================================================================
# Spatial Quality (SS) plots
# ==============================================================================
def scatterplot_ss_ks(df, save=False, xlim=False):
    """Scatter Spatial Quality (SS) (x) against Spatial Quality (KS) (y).

    Args:
        df: one DataFrame from ``create_dataset``.
        save: write the figure under ``scatterdir`` before showing it.
        xlim: zoom to x in [62, 70] and y in [5, 9] (hand-picked limits) and
            add an ``_xlim`` suffix to the saved file name.

    Markers come from ``marker_shapes`` so each image strategy keeps the same
    shape across plots.
    """
    x = 'Spatial Quality (SS)'
    y = 'Spatial Quality (KS)'

    # Positional lookup: works even if df's index does not contain 0.
    dsname = df['Dataset'].iloc[0]

    markers = {strat: marker_shapes[strat] for strat in df['Image Strategy'].unique()}
    sns.scatterplot(data=df, x=x, y=y, style='Image Strategy', hue='Dim. Red. Technique',
                    edgecolor='black', s=100, palette=technique_colors, markers=markers)

    plt.title(x + ' vs ' + y + ' (' + dsname + ')')
    plt.xlabel(x)
    plt.ylabel(y)

    if xlim:
        plt.xlim(62, 70)
        plt.ylim(5, 9)

    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.grid()

    if save:
        suffix = '_xlim.png' if xlim else '.png'
        plt.savefig(scatterdir + 'scatter_' + x + '_' + y + dsname + suffix, bbox_inches='tight')

    plt.show()

def scatterplotstab_ss_spat(df, save=False):
    """Scatter Stability (x) against Spatial Quality (SS) (y) for one dataset.

    Colours come from ``technique_colors``. Markers come from
    ``marker_shapes`` so each image strategy keeps the same shape across
    plots. With ``save=True`` the figure is written under ``scatterdir``.
    """
    y = 'Spatial Quality (SS)'

    markers = {strat: marker_shapes[strat] for strat in df['Image Strategy'].unique()}
    sns.scatterplot(data=df, x='Stability', y=y, style='Image Strategy',
                    hue='Dim. Red. Technique', edgecolor='black', s=100,
                    palette=technique_colors, markers=markers)

    # Positional lookup: works even if df's index does not contain 0.
    dsname = df['Dataset'].iloc[0]

    plt.title('Stability vs Spatial Quality (SS) (' + dsname + ')')
    plt.xlabel('Stability (KSte)')
    plt.ylabel('Spatial Quality (SS)')

    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.grid()

    if save:
        plt.savefig(scatterdir + 'scatter_stability_spatial_quality' + dsname + '_SS.png', bbox_inches='tight')

    plt.show()

def scatterplotspat_ss_exp(df, save=False):
    """Scatter Silhouette / expressiveness (x) against Spatial Quality (SS) (y).

    The title names the SS variant so it is distinguishable from the
    otherwise identical KS plot. Markers come from ``marker_shapes`` so each
    image strategy keeps the same shape across plots. With ``save=True`` the
    figure is written under ``scatterdir``.
    """
    y = 'Spatial Quality (SS)'

    # Positional lookup: works even if df's index does not contain 0.
    dsname = df['Dataset'].iloc[0]

    markers = {strat: marker_shapes[strat] for strat in df['Image Strategy'].unique()}
    sns.scatterplot(data=df, x='Silhouette', y=y, style='Image Strategy',
                    hue='Dim. Red. Technique', edgecolor='black', s=100,
                    palette=technique_colors, markers=markers)

    plt.title('Expressiveness vs Spatial Quality (SS) (' + dsname + ')')
    plt.xlabel('Expressiveness (Silhouette)')
    plt.ylabel('Spatial Quality (SS)')

    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.grid()

    if save:
        plt.savefig(scatterdir + 'scatter_expressiveness_spatial_quality' + dsname + '_SS.png', bbox_inches='tight')
    plt.show()

# Re-assigns the image-strategy -> marker mapping (same values as in the
# plotting-config cell above).
marker_shapes = {
    'MR': 's',  # Square
    'ML': 'D',  # Diamond
    'nGR': 'o',  # Circle
    'GR': '^'   # Triangle up
}

In [None]:
# Stability vs KS scatterplot for the third DataFrame from create_dataset
# (index 2 = 'grouping', i.e. NL Grouping), with UMAP excluded.
df = create_dataset(True)[2]
scatterplotstab_ks_spat(df, save=True)

In [None]:
# ==============================================================================
# Spatial Quality (KS) plots for all datasets in one plot
# ==============================================================================
def scatterplot_ks_stab_all(df_list, save=False):
    """Grid of Stability vs Spatial Quality (KS) scatterplots, one per dataset.

    Panels fill a 2x3 grid row by row (at most six datasets). Unused panels
    are hidden, and only the last panel keeps its legend. With ``save=True``
    the figure is written under ``scatterdir`` before being shown.
    """
    df_list = list(df_list)
    x = 'Stability'
    y = 'Spatial Quality (KS)'
    fig, axs = plt.subplots(2, 3, figsize=(20, 10))

    for i, df in enumerate(df_list):
        ax = axs[divmod(i, 3)]  # row-major: axs[i // 3, i % 3]

        sns.scatterplot(
            data=df,
            x=x,
            y=y,
            style='Image Strategy',
            hue='Dim. Red. Technique',
            edgecolor='black',
            s=100,
            palette=technique_colors,
            ax=ax,
            markers=marker_shapes
        )

        # Keep a single legend, on the last panel.
        legend = ax.get_legend()
        if i < len(df_list) - 1 and legend is not None:
            legend.remove()

        ax.set_title(df['Dataset'].iloc[0])
        ax.set_xlabel(x)
        ax.set_ylabel(y)

    # Hide grid cells without a dataset (e.g. the sixth cell for five datasets).
    for unused in axs.flat[len(df_list):]:
        unused.set_visible(False)

    if save:
        plt.savefig(scatterdir + 'scatter_' + x + '_' + y + '_all.png', bbox_inches='tight')

    plt.show()

# All five datasets (UMAP included) on one 2x3 grid; the figure is saved.
scatterplot_ks_stab_all(create_dataset(), True)

In [None]:
# ==============================================================================
# Spatial Quality (KS) plots for all datasets in one plot
# ==============================================================================
def scatterplot_ks_stab_all(df_list, save=False):
    """Grid of Stability vs Spatial Quality (KS) scatterplots, one per dataset.

    Redefines the function from the previous cell. Panels fill a 2x3 grid
    row by row (at most six datasets), unused panels are hidden, and only the
    last panel keeps its legend. Markers come from ``marker_shapes`` so every
    strategy has the same shape in all panels. With ``save=True`` the figure
    is written under ``scatterdir``.
    """
    df_list = list(df_list)
    x = 'Stability'
    y = 'Spatial Quality (KS)'
    fig, axs = plt.subplots(2, 3, figsize=(20, 10))

    for i, df in enumerate(df_list):
        ax = axs[divmod(i, 3)]  # row-major: axs[i // 3, i % 3]

        sns.scatterplot(data=df, x=x, y=y, style='Image Strategy', hue='Dim. Red. Technique',
                        edgecolor='black', s=100, palette=technique_colors, ax=ax,
                        markers=marker_shapes)

        # Keep a single legend, on the last panel.
        legend = ax.get_legend()
        if i < len(df_list) - 1 and legend is not None:
            legend.remove()

        ax.set_title(df['Dataset'].iloc[0])
        ax.set_xlabel(x)
        ax.set_ylabel(y)

    # Hide grid cells without a dataset.
    for unused in axs.flat[len(df_list):]:
        unused.set_visible(False)

    if save:
        plt.savefig(scatterdir + 'scatter_' + x + '_' + y + '_all.png', bbox_inches='tight')

    plt.show()

# Same grid with UMAP excluded. It saves to the same file name as the
# previous cell's call, so it overwrites that image.
scatterplot_ks_stab_all(create_dataset(True), True)

In [None]:
# ==============================================================================
# Spatial Quality (KS) and Expressiveness plots for all datasets in one plot
# ==============================================================================
def scatterplot_ks_exp_all(df_list, save=False):
    """Grid of Spatial Quality (KS) vs Silhouette scatterplots, one per dataset.

    Panels fill a 2x3 grid row by row (at most six datasets), unused panels
    are hidden, and only the last panel keeps its legend. Markers come from
    ``marker_shapes`` so every strategy has the same shape in all panels.
    With ``save=True`` the figure is written under ``scatterdir``.
    """
    df_list = list(df_list)
    x = 'Spatial Quality (KS)'
    y = 'Silhouette'
    fig, axs = plt.subplots(2, 3, figsize=(20, 10))

    for i, df in enumerate(df_list):
        ax = axs[divmod(i, 3)]  # row-major: axs[i // 3, i % 3]

        sns.scatterplot(data=df, x=x, y=y, style='Image Strategy', hue='Dim. Red. Technique',
                        edgecolor='black', s=100, palette=technique_colors, ax=ax,
                        markers=marker_shapes)

        # Keep a single legend, on the last panel.
        legend = ax.get_legend()
        if i < len(df_list) - 1 and legend is not None:
            legend.remove()

        ax.set_title(df['Dataset'].iloc[0])
        ax.set_xlabel(x)
        ax.set_ylabel(y)

    # Hide grid cells without a dataset.
    for unused in axs.flat[len(df_list):]:
        unused.set_visible(False)

    if save:
        plt.savefig(scatterdir + 'scatter_' + x + '_' + y + '_all.png', bbox_inches='tight')

    plt.show()

# KS vs Silhouette for all five datasets (UMAP included); the figure is saved.
scatterplot_ks_exp_all(create_dataset(), True)

In [None]:
# ==============================================================================
# Spatial Quality (SS) and Expressiveness plots for all datasets in one plot
# ==============================================================================
def scatterplot_ss_exp_all(df_list, save=False):
    """Grid of Spatial Quality (SS) vs Silhouette scatterplots, one per dataset.

    Panels fill a 2x3 grid row by row (at most six datasets), unused panels
    are hidden, and only the last panel keeps its legend. Markers come from
    ``marker_shapes`` so every strategy has the same shape in all panels.
    With ``save=True`` the figure is written under ``scatterdir``.
    """
    df_list = list(df_list)
    x = 'Spatial Quality (SS)'
    y = 'Silhouette'
    fig, axs = plt.subplots(2, 3, figsize=(20, 10))

    for i, df in enumerate(df_list):
        ax = axs[divmod(i, 3)]  # row-major: axs[i // 3, i % 3]

        sns.scatterplot(data=df, x=x, y=y, style='Image Strategy', hue='Dim. Red. Technique',
                        edgecolor='black', s=100, palette=technique_colors, ax=ax,
                        markers=marker_shapes)

        # Keep a single legend, on the last panel.
        legend = ax.get_legend()
        if i < len(df_list) - 1 and legend is not None:
            legend.remove()

        ax.set_title(df['Dataset'].iloc[0])
        ax.set_xlabel(x)
        ax.set_ylabel(y)

    # Hide grid cells without a dataset.
    for unused in axs.flat[len(df_list):]:
        unused.set_visible(False)

    if save:
        plt.savefig(scatterdir + 'scatter_' + x + '_' + y + '_all.png', bbox_inches='tight')

    plt.show()

# SS vs Silhouette for all five datasets (UMAP included); the figure is saved.
scatterplot_ss_exp_all(create_dataset(), True)

In [None]:
# ==============================================================================
# Stability and Expressiveness plots for all datasets in one plot
# ==============================================================================
def scatterplot_stab_exp_all(df_list, save=False):
    """Grid of Stability vs Silhouette scatterplots, one per dataset.

    Panels fill a 2x3 grid row by row (at most six datasets), unused panels
    are hidden, and only the last panel keeps its legend. Markers come from
    ``marker_shapes`` so every strategy has the same shape in all panels.
    With ``save=True`` the figure is written under ``scatterdir``.
    """
    df_list = list(df_list)
    x = 'Stability'
    y = 'Silhouette'
    fig, axs = plt.subplots(2, 3, figsize=(20, 10))

    for i, df in enumerate(df_list):
        ax = axs[divmod(i, 3)]  # row-major: axs[i // 3, i % 3]

        sns.scatterplot(data=df, x=x, y=y, style='Image Strategy', hue='Dim. Red. Technique',
                        edgecolor='black', s=100, palette=technique_colors, ax=ax,
                        markers=marker_shapes)

        # Keep a single legend, on the last panel.
        legend = ax.get_legend()
        if i < len(df_list) - 1 and legend is not None:
            legend.remove()

        ax.set_title(df['Dataset'].iloc[0])
        ax.set_xlabel(x)
        ax.set_ylabel(y)

    # Hide grid cells without a dataset.
    for unused in axs.flat[len(df_list):]:
        unused.set_visible(False)

    if save:
        plt.savefig(scatterdir + 'scatter_' + x + '_' + y + '_all.png', bbox_inches='tight')

    plt.show()

# Stability vs Silhouette for all five datasets (UMAP included); the figure is saved.
scatterplot_stab_exp_all(create_dataset(), True)

In [None]:
def create_scatterplots(df_list, save=False):
    """Draw the five per-dataset scatterplots for every DataFrame in ``df_list``."""
    plotters = (
        scatterplotstab_ks_spat,
        scatterplotspat_ks_exp,
        scatterplotstab_ss_spat,
        scatterplotspat_ss_exp,
        scatterplotstabexp,
    )
    for frame in df_list:
        for plot in plotters:
            plot(frame, save=save)

# One summary DataFrame per dataset (UMAP included), then every per-dataset
# scatterplot, saved under scatterdir.
df_list = create_dataset()  


create_scatterplots(df_list, save=True)

# Histogram Code

### Initialization

In [None]:
import math

def create_silhouette_histogram(datasetname,dimred,visstrat):
    """Render the per-frame Silhouette Score of one run as a 100 px tall bar image.

    Each image column is one frame. Positive scores are drawn upwards from the
    horizontal midline (rows 49 -> 0) and negative scores downwards (rows 50 -> 99),
    so a score of +/-1 fills exactly its own 50 px half. NaN frames stay black.

    Parameters
    ----------
    datasetname : str
        Key into the global ``datasets`` dict.
    dimred : str
        Dimensionality-reduction technique of the run.
    visstrat : str
        Visual strategy of the run.

    Returns
    -------
    PIL.Image.Image
        RGB image of size (number of frames, 100).

    Raises
    ------
    ValueError
        If no metric instance of the dataset matches ``dimred`` and ``visstrat``.
    """
    expcolor = (115, 195, 108)
    halfline = 50  # rows 0..49 hold positive bars, rows 50..99 negative bars

    # Get the right metric instance
    metricinstance = next(
        (mi for mi in datasets[datasetname]
         if mi.dimred == dimred and mi.visstrat == visstrat),
        None,
    )
    if metricinstance is None:
        raise ValueError(
            f"No metric instance for dataset {datasetname!r} with "
            f"dimred={dimred!r} and visstrat={visstrat!r}."
        )
    metric = metricinstance.getMetric("Silhouette Score")

    values = metric.metricvalues
    img = Image.new('RGB', (len(values), 2 * halfline), color='black')

    for i, framevalue in enumerate(values):
        # Frames without a score (NaN) are left black
        if math.isnan(framevalue):
            continue

        # Clamp so a value marginally outside [-1, 1] cannot leave the image
        barheight = min(abs(int(framevalue * halfline)), halfline)

        for j in range(barheight):
            if framevalue > 0:
                # Upper half: grow upwards from the last row above the midline
                img.putpixel((i, halfline - 1 - j), expcolor)
            else:
                # Lower half: grow downwards from the first row below the midline
                img.putpixel((i, halfline + j), expcolor)

    return img

img = create_silhouette_histogram("NL 3Groups", "SAM", "ML")

def plot_image_pearsons(img):
    """Display a two-sided score image such as one from ``create_silhouette_histogram``.

    The image is placed over data coordinates x in [0, img.width] and y in [-50, 50];
    the y ticks are relabelled with the score range [-1, 1] (midline = 0).
    """
    figure, axes = plt.subplots()
    figure.set_size_inches(30, 10)

    tick_positions = [-50, -25, 0, 25, 50]
    tick_labels = [-1, -0.5, 0, 0.5, 1]

    axes.set_xlabel('Frame')
    axes.set_ylabel('Silhouette Score')
    axes.imshow(img, extent=[0, img.width, -50, 50])
    axes.set_yticks(tick_positions)
    axes.set_yticklabels(tick_labels)


plot_image_pearsons(img)

In [None]:
def _draw_half_bar(img, column, value, maxvalue, halfline, color, upwards, label):
    """Draw ``value`` (scaled so ``maxvalue`` fills ``halfline`` px) into one half of column ``column``.

    The bar height is clamped to [0, halfline]. ``upwards=True`` grows from the row
    just above the midline towards the top; ``False`` grows from the midline row
    towards the bottom. NaN values are reported and skipped.
    """
    if math.isnan(value):
        print(f"{label}: {value} at frame {column}")
        return
    barheight = min(max(int(value / maxvalue * halfline), 0), halfline)
    for j in range(barheight):
        row = halfline - 1 - j if upwards else halfline + j
        img.putpixel((column, row), color)


def spatial_quality_hist(dsname, dimred, visstrat, maxvalue_ks=0, maxvalue_ss=0):
    """Render both spatial-quality metrics of one run as a 100 px tall bar image.

    Each image column is one frame. 'Spatial Quality Dist' (KS) is drawn upwards
    from the midline and 'Spatial Quality Enc' (SS) downwards; negative values are
    drawn as empty bars and values above the maximum are clipped.

    Parameters
    ----------
    dsname : str
        Key into the global ``datasets`` dict.
    dimred : str
        Dimensionality-reduction technique of the run.
    visstrat : str
        Visual strategy of the run.
    maxvalue_ks, maxvalue_ss : float, default=0
        Value that fills a full half of the image for KS / SS. 0 means "use the run's
        maximum of that metric, rounded up to an integer"; each one is resolved
        independently.

    Returns
    -------
    (PIL.Image.Image, float, float)
        The image and the KS / SS maxima actually used for scaling.

    Raises
    ------
    ValueError
        If no metric instance of the dataset matches ``dimred`` and ``visstrat``.
    """
    spatcolor_ks = (234, 182, 118)
    spatcolor_ss = (158, 113, 58)
    halfline = 50

    # Get the right metric instance
    metricinstance = next(
        (mi for mi in datasets[dsname] if mi.dimred == dimred and mi.visstrat == visstrat),
        None,
    )
    if metricinstance is None:
        raise ValueError(
            f"No metric instance for dataset {dsname!r} with "
            f"dimred={dimred!r} and visstrat={visstrat!r}."
        )

    spat_ks_metric = metricinstance.getMetric("Spatial Quality Dist")
    spat_ss_metric = metricinstance.getMetric("Spatial Quality Enc")

    # Resolve each maximum on its own so that passing only one of them works.
    # An all-zero metric falls back to 1 to avoid dividing by zero.
    if maxvalue_ks == 0:
        maxvalue_ks = math.ceil(metricinstance.getMax("Spatial Quality Dist")) or 1
    if maxvalue_ss == 0:
        maxvalue_ss = math.ceil(metricinstance.getMax("Spatial Quality Enc")) or 1

    values_ks = spat_ks_metric.metricvalues
    values_ss = spat_ss_metric.metricvalues
    img = Image.new('RGB', (len(values_ks), 2 * halfline), color='black')

    for i in range(len(values_ks)):
        _draw_half_bar(img, i, values_ks[i], maxvalue_ks, halfline, spatcolor_ks,
                       upwards=True, label="Spatial Quality Dist")
        _draw_half_bar(img, i, values_ss[i], maxvalue_ss, halfline, spatcolor_ss,
                       upwards=False, label="Spatial Quality Enc")

    return img, maxvalue_ks, maxvalue_ss

def plot_img(img, maxvalue_ks, maxvalue_ss):
    """Show a two-sided spatial-quality image with value-scaled y tick labels.

    The image is placed over y in [-50, 50]. The upper half is labelled from 0 up to
    ``maxvalue_ks`` and the lower half from 0 down to ``maxvalue_ss``. The labels are
    magnitudes rounded to one decimal.
    """
    def _label(value):
        return round(float(value), 1)

    fig, ax = plt.subplots()
    fig.set_size_inches(30, 10)
    ax.imshow(img, extent=[0, img.width, -50, 50])

    # bottom -> top: SS maximum, SS half, midline, KS half, KS maximum
    ax.set_yticks([-50, -25, 0, 25, 50])
    ax.set_yticklabels([
        _label(maxvalue_ss),
        _label(maxvalue_ss / 2),
        0,
        _label(maxvalue_ks / 2),
        _label(maxvalue_ks),
    ])

    plt.show()

# Spatial-quality histogram for one run. The maxima are left at their default (0),
# so spatial_quality_hist derives them from the run's own data.
img, maxvalue_ks, maxvalue_ss = spatial_quality_hist("Fish Data", "UMAP", "ML")
plot_img(img, maxvalue_ks=maxvalue_ks, maxvalue_ss=maxvalue_ss)

In [None]:
def stability_hist(dsname, dimred, visstrat, maxvalue=0):
    """Render the per-frame 'Stability Dist' metric of one run as a 50 px tall bar image.

    Each image column is one frame. Bars grow upwards from the bottom row, and a
    value of ``maxvalue`` (or more) fills the full image height. Negative values
    give empty bars, and NaN frames stay black.

    Parameters
    ----------
    dsname : str
        Key into the global ``datasets`` dict.
    dimred : str
        Dimensionality-reduction technique of the run.
    visstrat : str
        Visual strategy of the run.
    maxvalue : float, default=0
        Value that fills the full height. 0 means "use the run's maximum, rounded
        up to an integer".

    Returns
    -------
    (PIL.Image.Image, float)
        The image and the maximum actually used for scaling.

    Raises
    ------
    ValueError
        If no metric instance of the dataset matches ``dimred`` and ``visstrat``.
    """
    stabcolor = (100, 164, 203)
    img_height = 50

    # Get the right metric instance
    metricinstance = next(
        (mi for mi in datasets[dsname] if mi.dimred == dimred and mi.visstrat == visstrat),
        None,
    )
    if metricinstance is None:
        raise ValueError(
            f"No metric instance for dataset {dsname!r} with "
            f"dimred={dimred!r} and visstrat={visstrat!r}."
        )
    stabmetric = metricinstance.getMetric("Stability Dist")

    if maxvalue == 0:
        # Round the run's maximum up to an integer; an all-zero run falls back to 1
        # so the scaling below never divides by zero
        maxvalue = math.ceil(metricinstance.getMax("Stability Dist")) or 1

    values = stabmetric.metricvalues
    img = Image.new('RGB', (len(values), img_height), color='black')

    for i, framevalue in enumerate(values):
        if math.isnan(framevalue):
            continue

        # Clamp to [0, img_height]: a full-scale value fills rows img_height-1 .. 0
        barheight = min(max(int(framevalue / maxvalue * img_height), 0), img_height)
        for j in range(barheight):
            img.putpixel((i, img_height - 1 - j), stabcolor)

    return img, maxvalue

img, maxvalue = stability_hist("Fish Data", "PCA", "ML")

# plot_img() labels its y-axis for a two-sided image (KS above, SS below the
# midline). The stability image is one-sided with its zero at the bottom edge,
# so it is plotted here with its own ticks instead.
fig, ax = plt.subplots()
fig.set_size_inches(30, 10)
ax.imshow(img, extent=[0, img.width, 0, img.height])
ax.set_yticks([0, img.height / 2, img.height])
ax.set_yticklabels([0, round(float(maxvalue / 2), 1), round(float(maxvalue), 1)])
ax.set_xlabel('Frame')
ax.set_ylabel('Stability')
plt.show()


### Combining All Metrics into One Overview Figure

In [None]:
# Imports
import numpy as np
from matplotlib.gridspec import GridSpec

In [None]:
# Function to plot a new image with all the metrics for a given dataset, dimred and visstrat.
def plot_metrics_overview(
        img_visual_summary,
        img_spatial,
        img_stability,
        img_silhouette,
        dsname,
        dimred,
        visstrat,
        maxstab=15,
        maxspat_ks=20,
        maxspat_ss=10,
        save = False,
        savepath=None):
    """Stack a visual summary and its three metric histograms into one figure.

    Rows, top to bottom: the visual summary, the spatial-quality image (KS above,
    SS below the midline), the stability image and the silhouette image. Row
    heights are proportional to the input image heights.

    Parameters
    ----------
    img_visual_summary, img_spatial, img_stability, img_silhouette : PIL.Image.Image
        Images to stack (anything with ``.height`` that ``imshow`` accepts).
    dsname, dimred, visstrat : str
        Identify the run; used to build the default file name when saving.
    maxstab, maxspat_ks, maxspat_ss : float
        Metric values at the outer edges of the corresponding bar images. They are
        only used for tick labels, so they should match the maxima the images were
        drawn with.
    save : bool, default=False
        Write the figure to disk before showing it.
    savepath : str or path-like, optional
        Output file. Defaults to ``overview_{dsname}_{dimred}_{visstrat}.png`` in
        the working directory.
    """

    def _edge_ticks(height, count):
        # imshow places pixel row r at y=r, so an image spans y in [-0.5, height-0.5].
        # Ticks on those edges (instead of 0..height) keep the bottom label inside
        # the visible range and put the midline tick between the two halves.
        return np.linspace(-0.5, height - 0.5, count)

    fig = plt.figure(figsize=(20, 7))
    gs = GridSpec(4, 1, height_ratios=[img_visual_summary.height, img_spatial.height,
                                       img_stability.height, img_silhouette.height], hspace=0.1)

    ax_vissum = fig.add_subplot(gs[0, 0])
    ax_spatial = fig.add_subplot(gs[1, 0])
    ax_stability = fig.add_subplot(gs[2, 0])
    ax_silhouette = fig.add_subplot(gs[3, 0])

    # Visual summary: picture only, no ticks
    ax_vissum.imshow(img_visual_summary, aspect="auto")
    ax_vissum.set_yticks([])
    ax_vissum.set_xticks([])

    # Spatial quality: KS magnitude above the midline, SS magnitude below it
    ax_spatial.set_ylabel('SS vs KS')
    ax_spatial.imshow(img_spatial, aspect="auto")
    ax_spatial.set_yticks(_edge_ticks(img_spatial.height, 5))
    ax_spatial.set_yticklabels([round(float(maxspat_ks), 1), round(float(maxspat_ks / 2), 1), 0,
                                round(float(maxspat_ss / 2), 1), round(float(maxspat_ss), 1)])
    ax_spatial.set_xticks([])

    # Stability: one-sided, zero at the bottom edge
    ax_stability.set_ylabel('Stability')
    ax_stability.imshow(img_stability, aspect="auto")
    ax_stability.set_yticks(_edge_ticks(img_stability.height, 3))
    ax_stability.set_yticklabels([round(float(maxstab), 1), round(float(maxstab / 2), 1), 0])
    ax_stability.set_xticks([])

    # Silhouette: score range [-1, 1] with 0 on the midline
    ax_silhouette.imshow(img_silhouette, aspect="auto")
    ax_silhouette.set_yticks(_edge_ticks(img_silhouette.height, 5))
    ax_silhouette.set_yticklabels([1, 0.5, 0, -0.5, -1])
    ax_silhouette.set_xlabel("Frame")
    ax_silhouette.set_ylabel("Silhouette\nScore")

    plt.tight_layout()
    # Only write to disk when asked to; previously savefig("") ran unconditionally
    if save:
        if savepath is None:
            savepath = f"overview_{dsname}_{dimred}_{visstrat}.png"
        plt.savefig(savepath, bbox_inches='tight')
    plt.show()


In [None]:
# Silhouette histogram for the "NL 3Groups" dataset (presumably the "tryagain"
# file — confirm against datasetnames_dict), PCA projection, nGR strategy.
img_silhouette = create_silhouette_histogram("NL 3Groups", "PCA", "nGR")

# Display it with the [-1, 1] silhouette tick labels.
plot_image_pearsons(img_silhouette)

In [None]:
datasetfilenames = ["fishdatamerge", "grouping", "mergeFocus", "tryagain", "200pop1"]
eps_dict = {
    "200pop1": 15.0,
    "fishdatamerge": 250.0,
    "grouping": 15.0,
    "mergeFocus": 15.0,
    "tryagain": 15.0
}

# Fixed y-axis maxima per dataset file: (maxspat_ks, maxspat_ss, maxstab)
axis_maxima = {
    "200pop1": (10, 8, 5),
    "fishdatamerge": (15, 250, 5),
    "grouping": (8, 15, 5),
    "mergeFocus": (19, 15, 5),
    "tryagain": (6, 10, 5),
}

# Dataset files for which no SNE projection exists
no_sne_datasets = {"200pop1", "fishdatamerge", "mergeFocus"}

for dataset in datasetfilenames:
    for dimred1 in dimrednames:
        for visstrat in visstratnames:
            if dimred1 == "SNE" and dataset in no_sne_datasets:
                continue

            maxspat_ks, maxspat_ss, maxstab = axis_maxima[dataset]
            # Not used further below; presumably meant for building img_vis_path
            epsilon = eps_dict[dataset]

            # NOTE(review): placeholder — fill in the visual-summary image path for
            # this dataset / dimred1 / visstrat before running.
            img_vis_path = ""
            img_vis = Image.open(img_vis_path)

            display_name = datasetnames_dict[dataset]
            img_spatial, _, _ = spatial_quality_hist(display_name, dimred1, visstrat,
                                                     maxvalue_ks=maxspat_ks, maxvalue_ss=maxspat_ss)
            img_stability, _ = stability_hist(display_name, dimred1, visstrat, maxstab)
            img_silhouette = create_silhouette_histogram(display_name, dimred1, visstrat)

            plot_metrics_overview(
                img_vis, img_spatial, img_stability, img_silhouette,
                display_name, dimred1, visstrat,
                maxstab=maxstab, maxspat_ks=maxspat_ks, maxspat_ss=maxspat_ss,
            )

# Data Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas
from pathlib import Path
import matplotlib


def add_line_from_ax_to_ax(
    fig, ax0, x0, y0, ax1, x1, y1, color="k", linewidth=0.5, linestyle="-"
):
    """Adds a line from a point in one axis to a point in another.

    The endpoints are converted to figure coordinates, so the line can cross axis
    boundaries. The figure is drawn first so the axes' transforms reflect the
    current layout. Because the endpoints are fixed in figure coordinates, the line
    does not follow later changes to axis limits or layout.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to add the line to.
    ax0 : matplotlib.axes.Axes
        The axis to start the line from.
    x0 : float
        The x coordinate of the start of the line in data coordinates of ax0.
    y0 : float
        The y coordinate of the start of the line in data coordinates of ax0.
    ax1 : matplotlib.axes.Axes
        The axis to end the line at.
    x1 : float
        The x coordinate of the end of the line in data coordinates of ax1.
    y1 : float
        The y coordinate of the end of the line in data coordinates of ax1.
    color : str, default='k'
        The color of the line.
    linewidth : float, default=0.5
        The width of the line.
    linestyle : str, default='-'
        The matplotlib line style of the line.

    Returns
    -------
    matplotlib.lines.Line2D
        The line that was added.
    """

    def from_ax_to_figure(ax, x, y):
        # data coordinates -> display (pixel) coordinates -> figure fraction
        display_coords = ax.transData.transform((x, y))
        return fig.transFigure.inverted().transform(display_coords)

    # Update all coordinates of the plot so the transforms are current
    fig.canvas.draw()

    # Transform the coordinates to figure coordinates
    coord0 = from_ax_to_figure(ax0, x0, y0)
    coord1 = from_ax_to_figure(ax1, x1, y1)

    new_line = matplotlib.lines.Line2D(
        (coord0[0], coord1[0]),  # x coords
        (coord0[1], coord1[1]),  # y coords
        color=color,
        linewidth=linewidth,
        linestyle=linestyle,
        transform=fig.transFigure,
    )

    # add_artist registers the line with the figure (sets its figure, marks the
    # figure stale) instead of mutating fig.lines directly
    fig.add_artist(new_line)

    return new_line


def load_positions(data_path, frame):
    """Return the (x, y) positions of every turtle recorded in ``frame``.

    Parameters
    ----------
    data_path : str or pathlib.Path
        CSV file with at least the columns ``frame``, ``x`` and ``y``.
    frame : int
        Frame number to select.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_turtles, 2); column 0 holds x, column 1 holds y.

    Raises
    ------
    ValueError
        If the file contains no rows for ``frame``.
    """
    table = pandas.read_csv(data_path)
    rows = table[table["frame"] == frame]

    if rows.empty:
        raise ValueError(f"No data for frame {frame} found in {data_path}.")

    return np.stack((rows["x"].values, rows["y"].values), axis=1)


def get_max_limits(
    data_path: Path, frame_indices: list[int], margin_factor: float = 1.1
):
    """Return x/y axis limits that contain every turtle in the given frames.

    The CSV is read once; the positions of all requested frames are pooled and
    their min/max are widened by a margin. NaN positions are ignored.

    Parameters
    ----------
    data_path : str or pathlib.Path
        The path to the .csv file containing the data (columns ``frame``, ``x``, ``y``).
    frame_indices : list of int
        The frames to load the positions for.
    margin_factor : float, default=1.1
        The factor to extend the limits by. A factor of 1.1 will extend the limits by
        10%, adding a margin of 5% to each side.

    Returns
    -------
    numpy.ndarray, numpy.ndarray
        ``xlim`` and ``ylim`` as float arrays ``[lower, upper]``.

    Raises
    ------
    ValueError
        If ``frame_indices`` is empty or one of the frames has no rows in the file.
    """
    if len(frame_indices) == 0:
        raise ValueError("frame_indices must contain at least one frame.")

    # Read the file once instead of once per frame
    data = pandas.read_csv(data_path)
    selected = data[data["frame"].isin(frame_indices)]

    # Report the first requested frame without data (same message as load_positions)
    present = set(selected["frame"].unique())
    for frame in frame_indices:
        if frame not in present:
            raise ValueError(f"No data for frame {frame} found in {data_path}.")

    xlim = np.array([selected["x"].min(), selected["x"].max()], dtype=float)
    ylim = np.array([selected["y"].min(), selected["y"].max()], dtype=float)

    # Apply margin to keep the turtles from the border
    x_margin = (xlim[1] - xlim[0]) * (margin_factor - 1) / 2
    y_margin = (ylim[1] - ylim[0]) * (margin_factor - 1) / 2

    xlim[0] -= x_margin
    xlim[1] += x_margin
    ylim[0] -= y_margin
    ylim[1] += y_margin

    return xlim, ylim


def plot_data(
    data_path: Path,
    img_visualization: np.array,
    frame_indices: list[int],
    turtle_colormap: np.array,
    line_color: str = "black",
    line_width: float = 0.5,
    turtle_color: str = "k",
    turtle_size: float = 3,
    figsize: tuple = (12, 6),
):
    """Creates a plot of the turtle positions in selected frames above the visualization image.

    The top row shows one scatterplot per frame in ``frame_indices``, all sharing
    the same limits. The bottom row shows ``img_visualization`` with a dashed
    vertical marker at each selected frame. Two lines connect the bottom corners of
    each top-row panel to its marker.

    Parameters
    ----------
    data_path : str or pathlib.Path
        The path to the .csv file containing the data (columns ``frame``, ``x``, ``y``).
    img_visualization : numpy.ndarray
        The visualization image to plot, i.e. grouprug, motion lines, etc. Its pixel
        columns are assumed to correspond to frames — TODO confirm per visualization.
    frame_indices : list of int
        The frames to plot.
    turtle_colormap : numpy.ndarray or None
        Background image drawn behind the turtles over the shared limits; ``None``
        gives a white background.
    line_color : str, default='black'
        The color of the dashed frame markers and of the lines connecting the top
        and bottom row.
    line_width : float, default=0.5
        The width of the frame markers and of the connecting lines.
    turtle_color : str, default='k'
        The color of the dots representing the turtles.
    turtle_size : float, default=3
        The size of the dots representing the turtles.
    figsize : tuple of float, default=(12, 6)
        The size of the figure.
    """

    # Create gridspec: turtle panels on top, visualization spanning the bottom row
    fig = plt.figure(figsize=figsize, dpi=300)
    gs = fig.add_gridspec(2, len(frame_indices), height_ratios=[2, 3])

    ax_visualization = fig.add_subplot(gs[1, :])
    ax_visualization.imshow(img_visualization, aspect="equal")
    ax_visualization.set_xlabel("frame")
    ax_visualization.set_yticks([])

    # Shared limits so the turtle panels are directly comparable
    turtle_xlim, turtle_ylim = get_max_limits(
        data_path, frame_indices, margin_factor=1.1
    )

    for n, frame in enumerate(frame_indices):

        # Draw vertical line at the current frame
        ax_visualization.axvline(
            frame, color=line_color, linestyle="--", linewidth=line_width
        )

        positions = load_positions(data_path, frame)

        # Add subplot for top row
        ax = fig.add_subplot(gs[0, n])
        ax.set_aspect("equal")
        ax.set_title(f"Frame {frame}")
        ax.scatter(positions[:, 0], positions[:, 1], c=turtle_color, s=turtle_size)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(turtle_xlim)
        ax.set_ylim(turtle_ylim)
        if turtle_colormap is not None:
            ax.imshow(turtle_colormap, extent=[*turtle_xlim, *turtle_ylim], interpolation='nearest')
        else:  # If no colormap provided, white background
            ax.set_facecolor("white")

        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        # Connect both bottom corners of this panel to the frame marker; y=0 is the
        # centre of the visualization image's first (top) row.
        for x_corner in xlim:
            add_line_from_ax_to_ax(
                fig=fig,
                ax0=ax,
                x0=x_corner,
                y0=ylim[0],
                ax1=ax_visualization,
                x1=frame,
                y1=0,
                color=line_color,
                linewidth=line_width,
            )

    plt.show()


# ======================================================================================
# Create the plot
# ======================================================================================
datapath = "" # Path to csv file (columns frame, x, y)
visualizationpath = "" # Path to visualization image

# Path("") silently means the current directory, which would only surface later as a
# confusing read error, so fail early with a clear message instead.
if not datapath or not visualizationpath:
    raise ValueError("Set `datapath` and `visualizationpath` before running this cell.")

path_data = Path(datapath)
path_visualization = Path(visualizationpath)

img_visualization = plt.imread(path_visualization)
width_img = img_visualization.shape[1]
frame_indices = [850, 1300]

turtle_colormap = None

plot_data(
    path_data,
    img_visualization,
    frame_indices,
    turtle_colormap=turtle_colormap,
    figsize=(12, 5),
)
