In [None]:
from pathlib import Path
import numpy as np
import importlib
from matplotlib import pyplot as plt
import humanize
import sys
import os, psutil
import time
import shutil
import pandas as pd
import seaborn as sns
import warnings
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
import logging

import astrocast.reduction as red
import astrocast.clustering as clust
import astrocast.analysis as ana
import astrocast.autoencoders as AE
import astrocast.helper as helper
import astrocast.experiments as exp

# Re-import the astrocast modules so edits to the library source are picked up
# without restarting the kernel. Each module is listed (and reloaded) exactly once.
for pack in [red, clust, helper, ana, AE, exp]:
    importlib.reload(pack)


In [None]:
# !pip install matplotlib==3.7.3

In [None]:
importlib.reload(exp)

# Recording window (z0, z1) shared by every synthetic dataset.
z_range = (0, 10000)
# Timing specs: presumably one entry per generator group (None = untimed);
# the middle entry is 0, 1000, ..., 9000 — TODO confirm against astrocast.experiments.
timings = (None, list(range(0, z_range[1], 1000)), None)

# Base settings shared by all synthetic datasets; each dataset below is a
# copy of this dict with its own "name" and "generators" entries.
dummy = dict(
        name="default",
        num_rows=1000,
        z_range=z_range,  # z0, z1 boundaries
        timings=timings,
        timing_jitter=None,
        timing_offset=10,
        )

# Baseline signal-generator settings; every generator below starts from a copy.
def_gen = dict(noise_amplitude=0.001, trace_length=(50, 50), parameter_fluctuations=0.01)


def _with(base, **changes):
    """Return a shallow copy of ``base`` with ``changes`` applied; ``base`` is not modified."""
    merged = base.copy()
    merged.update(changes)
    return merged


# identical: both groups come from the very same generator settings
dummy1 = _with(dummy, name="identical", generators=(def_gen, def_gen))

# big diff
gen2 = _with(def_gen, b=2, plateau_duration=6)
dummy2 = _with(dummy, name="big_diff", generators=(def_gen, gen2))

# small diff
gen1 = _with(def_gen, b=1.5, plateau_duration=2, signal_amplitude=1)
gen2 = _with(gen1, b=1.9, signal_amplitude=1.02, leaky_k=0.15)
dummy3 = _with(dummy, name="small_diff", generators=(gen1, gen2))

# triplet: three generator groups
gen1 = _with(def_gen, b=1, plateau_duration=1, signal_amplitude=1)
gen2 = _with(gen1, b=0.8, plateau_duration=2)
gen3 = _with(gen1, b=3, plateau_duration=5)
dummy4 = _with(dummy, name="tripplet", generators=(gen1, gen2, gen3))

# variable length (ragged traces)
gen1 = _with(def_gen, trace_length=60, ragged_allowed=True, signal_amplitude=None, abort_amplitude=None)
gen2 = _with(gen1, leaky_k=0.2)
dummy5 = _with(dummy, name="variable_length", generators=(gen1, gen2))

dummy_parameters = [dummy1, dummy2, dummy3, dummy4, dummy5]

# create experiments (3 replicates per parameter set) and preview example traces
ex = exp.Experiments(dummy_parameters, replicates=3)

ex.plot_traces()

In [None]:
# Embed the events with feature extraction (FExt), the convolutional autoencoder (CNN)
# and the recurrent autoencoder (RNN); None presumably selects each method's default
# settings — TODO confirm against astrocast.experiments.
ex.create_embedding({"FExt": None, "CNN": None, "RNN": None})

In [None]:
# Run the conditional-contrasts evaluation; per the figure caption below, this scores
# how well each embedding distinguishes events from the different generator groups.
ex.conditional_contrasts()

In [None]:
# Run the coincidence-detection evaluation; per the figure caption below, this scores
# how well events coinciding with the stimulus timings can be predicted.
ex.coincidence_detection()

In [None]:
# Collect the evaluation scores into one table and show both ends of it.
results = ex.get_results()
display(results.head(3), results.tail(3))

In [None]:
importlib.reload(exp)

# Figure layout: one column per dataset (`panel_column`). Row 0 shows example traces;
# each further row shows one evaluation type as a heatmap.
panel_column = "name"
panel_rows = ["conditional_contrasts", "coincidence_detection"]
figsize = (3, 3)  # size of a single panel (width, height) in inches
num_samples = 4
alpha = .9
linestyle = "--"

# Alternative: restrict to one split, e.g. results[results.data_split == "test"]
df = results.copy()

if isinstance(panel_rows, str):
    panel_rows = [panel_rows]

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    panels = df[panel_column].unique()
    N = 1 + len(panel_rows)
    M = len(panels)
    # squeeze=False keeps `axx` 2-D even with a single panel column, so the
    # axx[row, col] indexing below works for any number of datasets.
    fig, axx = plt.subplots(N, M, figsize=(M * figsize[0], N * figsize[1]), squeeze=False)

    # Row 0: example traces from the first experiment whose name matches the panel.
    for img_x, name in enumerate(panels):
        plot = next((eObj.plot for eObj in ex.experiments if eObj.name == name), None)
        if plot is not None:
            _ = plot.plot_traces(title=f"{panel_column}: {name}", num_samples=num_samples, by="group",
                                 alpha=alpha, linestyle=linestyle, ax=axx[0, img_x])

    # Rows 1..N-1: one heatmap per (evaluation type, panel); legend only in the last column.
    for img_y, panel_row in enumerate(panel_rows, start=1):
        for img_x, panel in enumerate(panels):
            selected = df[(df.evaluation_type == panel_row) & (df[panel_column] == panel)]
            exp.Experiments.plot_heatmap(df=selected, evaluation_type=panel_row, index='embedding',
                                         columns='data_split', group_by="cluster_type",
                                         show_legend=(img_x == M - 1), ax=axx[img_y, img_x])

    # Remove per-panel x/y axis labels from the heatmap rows ...
    for ax in axx[1:, :].flat:
        ax.set_ylabel(None)
        ax.set_xlabel(None)

    # ... and label each heatmap row once, on its leftmost panel.
    for img_y, panel_row in enumerate(panel_rows, start=1):
        axx[img_y, 0].set_ylabel(panel_row.replace("_", " "))

    # Shorten the y tick labels of the heatmaps to the text before the first underscore.
    for ax in axx[1:, :].flat:
        short_labels = [label.get_text().split("_")[0] for label in ax.get_yticklabels()]
        ax.set_yticklabels(short_labels, rotation='vertical')

# Apply the layout to this figure explicitly rather than to pyplot's "current" figure.
fig.tight_layout()

fig_name = 6
save_path = Path.cwd().parent.joinpath(f"{fig_name}.png")
print(save_path)
fig.savefig(save_path, dpi=260)

legend = """
Performance of Different Algorithms on Analyzing Various Synthetic Datasets. A) Showcase of synthetic calcium events designed to represent various levels of analytical difficulty, where color coding corresponds to events generated under different parameter sets that simulate diverse conditions or event types. All events include a random noise amplitude of 0.001 and parameter fluctuations of 0.01, subtly varying each event's parameters. B) Conditional contrasts analysis assesses algorithmic efficiency in distinguishing events from differing conditions (groups 1-3). Events are characterized using different methods: FExt for Feature Extraction, CNN for Convolutional Autoencoder, and RNN for Recurrent Autoencoder. The hierarchical clustering leverages distance metrics between events (Pearson correlation or dynamic time warping), depicted by the absence of training dependency in grey. CNN's inability to process variable-length events results in its omission in the final panel. C) Coincidence detection analysis gauges algorithm performance in predicting the occurrence of stimulus events. This encompasses two groups: one with events exclusively occurring during a stimulus and another with randomly occurring events. The embedding classifier and prediction methods are consistent with panel B, where the classifier identifies stimulus occurrence, and regression determines the exact timing of the stimulus in coinciding events. For all scores the worst score out of 3 replicates is displayed.
"""
# Write the caption next to the figure; the explicit encoding keeps the file portable.
legend_path = Path.cwd().parent.joinpath(f"{fig_name}.txt")
legend_path.write_text(legend, encoding="utf-8")

In [None]:
# Scratch cell: rebinds `legend` to an empty placeholder string. The caption file
# written by the previous cell is not touched by this.
legend = """

"""

# TODO(review): the caption says "groups 1-3", but only the "tripplet" dataset defines
# three generators; the other datasets define two — confirm the caption wording.

# Troubleshoot regression

In [None]:
# Gather every experiment built from the "big_diff" parameter set
# (presumably one per replicate).
big = [experiment for experiment in ex.experiments if experiment.name == "big_diff"]
display(len(big))

# Inspect a copy of the first match.
dataset = big[0].copy()
display(dataset)

# Unpack the three timing specs defined with the dataset parameters.
tim0, tim1, tim2 = timings

In [None]:
# Select events whose z0 lies in the (940, 1060) window around the t=1000 entry of
# `timings`; inplace=False presumably returns a new object and leaves `dataset` as is.
ev = dataset.filter({"z0": (940, 1060)}, inplace=False)
ev