In [None]:
from pathlib import Path
import numpy as np
import importlib
from matplotlib import pyplot as plt
import humanize
import sys
import os, psutil
import time
import shutil
import pandas as pd
import seaborn as sns
import warnings
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
import logging

import astrocast.reduction as red
import astrocast.clustering as clust
import astrocast.analysis as ana
import astrocast.autoencoders as AE
import astrocast.helper as helper
import astrocast.experiments as exp

# Reload the astrocast sub-modules so package edits are picked up without restarting
# the kernel. Each module is listed once (`helper` used to be reloaded twice).
for pack in (red, clust, helper, ana, AE, exp):
    importlib.reload(pack)


In [None]:
# Re-import astrocast.experiments so the `exp.Experiments` call at the end of this
# cell uses the current version of the module without a kernel restart.
importlib.reload(exp)

z_range = (0, 10000)
# Timings tuple has three slots; only the middle one holds frames (every 1000
# frames across z_range). NOTE(review): presumably one slot per group — confirm.
timings = (None, list(range(0, z_range[1], 1000)), None)

# Settings shared by every dummy experiment; each variant below only adds its
# own "name" and "generators".
dummy = {
    "name": "default",
    "num_rows": 1000,
    "z_range": z_range,  # (z0, z1) boundaries
    "timings": timings,
    "timing_jitter": None,
    "timing_offset": 5,
}

# Baseline signal-generator parameters the variants are derived from.
def_gen = {"noise_amplitude": 0.001, "trace_length": (50, 50), "parameter_fluctuations": 0.01}

# Identical: both generators are the same parameter dict.
dummy1 = {**dummy, "name": "identical", "generators": (def_gen, def_gen)}

# Big difference between the two generators.
gen2 = {**def_gen, "b": 2, "plateau_duration": 6}
dummy2 = {**dummy, "name": "big_diff", "generators": (def_gen, gen2)}

# Small difference between the two generators.
gen1 = {**def_gen, "b": 1.5, "plateau_duration": 2, "signal_amplitude": 1}
gen2 = {**gen1, "b": 1.9, "signal_amplitude": 1.02, "leaky_k": 0.15}
dummy3 = {**dummy, "name": "small_diff", "generators": (gen1, gen2)}

# Triplet: three generators sharing signal_amplitude=1.
gen1 = {**def_gen, "b": 1, "plateau_duration": 1, "signal_amplitude": 1}
gen2 = {**gen1, "b": 0.8, "plateau_duration": 2}
gen3 = {**gen1, "b": 3, "plateau_duration": 5}
dummy4 = {**dummy, "name": "tripplet", "generators": (gen1, gen2, gen3)}

# Variable length: ragged traces allowed, signal/abort amplitudes set to None.
gen1 = {**def_gen, "trace_length": 60, "ragged_allowed": True, "signal_amplitude": None, "abort_amplitude": None}
gen2 = {**gen1, "leaky_k": 0.2}
dummy5 = {**dummy, "name": "variable_length", "generators": (gen1, gen2)}

dummy_parameters = [dummy1, dummy2, dummy3, dummy4, dummy5]

# Build the experiment collection from the five dummy configurations defined above
# (`replicates=3` — presumably three replicate datasets per configuration; confirm)
# and plot their traces.
ex = exp.Experiments(dummy_parameters, replicates=3)

ex.plot_traces()

In [None]:
# Request the three embedding types; None presumably means "default settings" — TODO confirm.
ex.create_embedding({"FExt": None, "CNN": None, "RNN": None})

In [None]:
# Run the conditional-contrast evaluation provided by exp.Experiments.
ex.conditional_contrasts()

In [None]:
# Run the coincidence-detection evaluation provided by exp.Experiments.
ex.coincidence_detection()

In [None]:
# Collect all evaluation results and preview the first and last three rows.
results = ex.get_results()
display(results.head(3), results.tail(3))

In [None]:
importlib.reload(exp)

# Summary figure: one column per experiment (`panel_column`). The first row shows
# example traces; each following row is a heatmap of one evaluation type.
panel_column = "name"
panel_rows = ["conditional_contrasts", "coincidence_detection"]
figsize = (3, 3)  # size of a single panel (width, height) in inches
num_samples = 4
alpha = .9
linestyle = "--"

df = results.copy()

if isinstance(panel_rows, str):
    panel_rows = [panel_rows]

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    panels = df[panel_column].unique()
    N = 1 + len(panel_rows)
    M = len(panels)
    # squeeze=False keeps `axx` 2-D even when there is only a single panel column
    fig, axx = plt.subplots(N, M, figsize=(M * figsize[0], N * figsize[1]), squeeze=False)

    # first row: example traces of the experiment whose name matches the panel
    img_y = 0
    for img_x, name in enumerate(panels):
        ax = axx[img_y, img_x]

        plot = None
        for eObj in ex.experiments:
            if eObj.name == name:
                plot = eObj.plot
                break

        if plot is not None:
            _ = plot.plot_traces(title=f"{panel_column}: {name}", num_samples=num_samples, by="group", alpha=alpha,
                                 linestyle=linestyle, ax=ax)

    # remaining rows: one heatmap row per evaluation type
    for img_y, panel_row in enumerate(panel_rows, start=1):
        for img_x, panel in enumerate(panels):
            ax = axx[img_y, img_x]
            selected = df[(df.evaluation_type == panel_row) & (df[panel_column] == panel)]

            # only the right-most panel of each row shows the legend
            exp.Experiments.plot_heatmap(df=selected, evaluation_type=panel_row, index='embedding',
                                         columns='data_split',
                                         group_by="cluster_type", show_legend=(img_x == M - 1), ax=ax)

    heatmap_axes = axx[1:, :].flatten().tolist()

    # clear the per-panel x/y axis labels of the heatmaps ...
    for ax in heatmap_axes:
        ax.set_ylabel(None)
        ax.set_xlabel(None)

    # ... and label each heatmap row once, on its left-most panel
    for i, panel in enumerate(panel_rows):
        axx[i + 1, 0].set_ylabel(panel.replace("_", " "))

    # shorten the y tick labels to the text before the first "_"
    for ax in heatmap_axes:
        yticklabels = [y.get_text().split("_")[0] for y in ax.get_yticklabels()]
        ax.set_yticklabels(yticklabels, rotation='vertical')

In [None]:
# Inspect the first y tick label left over from the summary-figure cell. That cell
# rebinds `yticklabels` to plain strings, which have no `.get_text()`; only call it
# when the element is still a matplotlib Text object.
t = yticklabels[0]
t.get_text() if hasattr(t, "get_text") else t

In [None]:
# Deliberate stop so that "Run All" ends here. The cells below re-implement the
# experiment workflow by hand (presumably superseded by `exp.Experiments` above),
# and some of them are incomplete; run them individually if needed.
raise ValueError("Intentional stop: the cells below are the manual legacy workflow; run them individually.")

# Generate example datasets

In [None]:
# Reload astrocast.helper so edits to SignalGenerator / DummyGenerator (used in the
# dataset-simulation cell) take effect without a kernel restart.
importlib.reload(helper)

# Baseline signal-generator settings; every population below is derived from them.
def_gen = {"noise_amplitude": 0.001, "trace_length": (50, 50), "parameter_fluctuations": 0.01}

# Identical: the same parameter dict for both generators.
ident = (def_gen, def_gen)

# Big difference between the two generators.
def_diff = {**def_gen, "b": 2, "plateau_duration": 6}
diff = (def_gen, def_diff)

# Small difference between the two generators.
small_diff_1 = {**def_gen, "b": 1.5, "plateau_duration": 2, "signal_amplitude": 1}
small_diff_2 = {**def_gen, "signal_amplitude": 1}
small_diff = (small_diff_1, small_diff_2)

# Triplet: three generators, each with signal_amplitude=1 and its own (b, plateau_duration).
trip_1, trip_2, trip_3 = (
    {**def_gen, "b": b_value, "plateau_duration": plateau, "signal_amplitude": 1}
    for b_value, plateau in ((1, 1), (0.8, 2), (3, 5))
)
triplet = [trip_1, trip_2, trip_3]

# Variable length: ragged traces allowed, signal/abort amplitudes set to None.
def_var_1 = {**def_gen, "trace_length": 60, "ragged_allowed": True, "signal_amplitude": None, "abort_amplitude": None}
def_var_2 = {**def_var_1, "leaky_k": 0.2}
var_length = (def_var_1, def_var_2)


In [None]:
importlib.reload(ana)

z_range = (0, 10000)
# Only the middle slot holds timings (every 1000 frames).
# NOTE(review): presumably one slot per group — confirm.
timings = (None, list(range(0, z_range[1], 1000)), None)

# Settings shared by every simulated dataset; only the generators differ per population type.
def_dummy = dict(num_rows=1000, z_range=z_range, timings=timings, timing_jitter=None, timing_offset=5)
gen_params = {"ident": ident, "diff": diff, "small_diff": small_diff, "triplet": triplet, "var_length": var_length}

# Styling for the example-trace figures.
trace_plot_kwargs = dict(num_samples=4, by="group", alpha=.9, linestyle="--")

# experiments[e_id] -> {"eObj", "population_type", "plot", "embeddings"}
experiments = {}
for e_id, (k, gen_param) in enumerate(gen_params.items()):
    # one SignalGenerator per parameter dict of this population type
    generators = [helper.SignalGenerator(**param) for param in gen_param]
    dg = helper.DummyGenerator(generators=generators, **def_dummy)

    eObj = dg.get_events()
    eObj.name = k

    # plot example traces of the simulated population
    plot = ana.Plotting(eObj)
    _ = plot.plot_traces(figsize=(4, 2), title=f"Exp {e_id} ({k})", **trace_plot_kwargs)

    experiments[e_id] = {
        "eObj": eObj,
        "population_type": eObj.name,
        "plot": plot,
        # filled by the embedding cells (FExt / CNN / RNN); must exist or they raise KeyError
        "embeddings": {},
    }

# Embedding

### Feature Extraction

In [None]:
# Store a feature-extraction embedding per experiment.
# NOTE(review): `features` is not computed anywhere in this notebook, so this cell raises
# NameError on a fresh kernel. TODO compute the feature matrix for `eObj` here.
# The loop variable is not called `exp`, so the astrocast.experiments alias is not shadowed.
for i, exp_data in experiments.items():
    eObj = exp_data["eObj"]

    exp_data.setdefault("embeddings", {})["FExt"] = features


### CNN Autoencoder

In [None]:
# Store a CNN-autoencoder embedding per experiment.
# NOTE(review): `embedding` is not computed before this cell. The only other binding is
# the loop variable of the classifier cells further down, so on re-runs this silently
# stores a stale embedding. TODO train/apply the CNN autoencoder for `eObj` here.
# The loop variable is not called `exp`, so the astrocast.experiments alias is not shadowed.
for i, exp_data in experiments.items():
    eObj = exp_data["eObj"]

    exp_data.setdefault("embeddings", {})["CNN"] = embedding

### RNN Autoencoder

In [None]:
# Store an RNN-autoencoder embedding per experiment.
# NOTE(review): like the CNN cell, this reads a global `embedding` that is not computed
# here, so it stores whatever object that name currently refers to (e.g. the same one as
# "CNN"). TODO train/apply the RNN autoencoder for `eObj` here.
# The loop variable is not called `exp`, so the astrocast.experiments alias is not shadowed.
for i, exp_data in experiments.items():
    eObj = exp_data["eObj"]

    exp_data.setdefault("embeddings", {})["RNN"] = embedding

# Conditional Contrasts

## Classifier (Predict condition)

In [None]:
importlib.reload(clust)

# For every (experiment, embedding), train a classifier that predicts each event's
# group and collect the evaluation metrics per data split.
classifier_metrics = ['cm', 'accuracy', 'precision', 'recall', 'f1']
results = {k: [] for k in ['eid', 'dataset', 'embedding', 'data split', *classifier_metrics]}

# loop variable is not called `exp`, which would shadow the astrocast.experiments module
for i, exp_data in experiments.items():
    eObj = exp_data["eObj"]

    for emb_name, embedding in exp_data['embeddings'].items():
        discr = clust.Discriminator(eObj)

        discr.train_classifier(embedding=embedding, category_vector=eObj.events.group.tolist())
        res = discr.evaluate(show_plot=False, title=f"condition: {eObj.name} [{emb_name}]", figsize=(8, 4))

        # one row per data split returned by `evaluate`
        for split, scores in res.items():
            results['eid'].append(i)
            results['dataset'].append(exp_data['population_type'])
            results['embedding'].append(emb_name)
            results['data split'].append(split)
            for metric in classifier_metrics:
                results[metric].append(scores[metric])

df = pd.DataFrame(results)
df

In [None]:
# Figure: example traces (top row) and condition-classifier accuracy (bottom row),
# one column per dataset.
plot_type = "barplot"  # "pointplot", "barplot" or "violinplot"

if plot_type not in ("pointplot", "barplot", "violinplot"):
    raise ValueError(f"unknown plot type: {plot_type!r}")

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    datasets = df.dataset.unique()
    n_ds = len(datasets)
    # squeeze=False keeps `axx` 2-D even when there is only one dataset
    fig, axx = plt.subplots(2, n_ds, figsize=(int(3 * n_ds), 8), squeeze=False)

    for n, ds in enumerate(datasets):

        # traces of the first experiment with this population type (the loop variable is
        # not named `exp`, which would shadow the astrocast.experiments module)
        plot = next((exp_data['plot'] for exp_data in experiments.values()
                     if exp_data["population_type"] == ds), None)
        if plot is not None:
            _ = plot.plot_traces(num_samples=4, by="group", ax=axx[0, n], alpha=.9, linestyle="--")

        # classifier accuracy per embedding and data split
        data = df[df.dataset == ds]
        if plot_type == "pointplot":
            sns.pointplot(data=data, ax=axx[1, n], x="data split", y="accuracy", hue="embedding", dodge=True)
        elif plot_type == "barplot":
            sns.barplot(data=data, ax=axx[1, n], x="embedding", y="accuracy", hue="data split")
        else:  # "violinplot"
            sns.violinplot(data=data, ax=axx[1, n], x="embedding", y="accuracy", hue="data split", split=True)

        axx[0, n].set_title(ds)

    # shared y range and a dashed reference line at 0.5
    # (presumably chance level for balanced two-group datasets)
    for ax in axx[1, :]:
        ax.set_ylim(0, 1.1)
        ax.axhline(0.5, linestyle="--", color="gray")

    # drop trace legends except in the last column; panels without a legend are skipped
    for ax in axx[0, :-1]:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    # move the last-column legends outside the axes (move_legend raises if there is none)
    for ax in axx[:, -1]:
        if ax.get_legend() is not None:
            sns.move_legend(ax, loc="upper left", bbox_to_anchor=(1.04, 1))

fig.savefig("conditional_classifier.png", dpi=480)

## Hierarchical Clustering

In [None]:
importlib.reload(clust)

# Hierarchical clustering of the events with two distance types.
# `criterion='maxclust'` with `cutoff=num_groups` presumably yields at most one
# cluster per group — confirm against clust.Linkage.
link = clust.Linkage()
# loop variable is not called `exp`, which would shadow the astrocast.experiments module
for i, exp_data in experiments.items():
    eObj = exp_data["eObj"]
    num_groups = len(eObj.events.group.unique())  # independent of the distance type
    distances = exp_data.setdefault("distance", {})

    for correlation_type in ['pearson', 'dtw']:
        barycenters, cluster_lookup_table = link.get_barycenters(eObj.events,
                                                                 cutoff=num_groups, criterion='maxclust',
                                                                 distance_type=correlation_type
                                                                 )
        # population type, number of groups, number of distinct clusters found
        print(exp_data["population_type"], num_groups,
              len(np.unique(list(cluster_lookup_table.values()))))

        distances[correlation_type] = dict(barycenters=barycenters,
                                           cluster_lookup_table=cluster_lookup_table)

In [None]:
importlib.reload(clust)

# Compare the hierarchical cluster assignment against the true groups.
score_names = ['adjusted_mutual_info_score', 'adjusted_rand_score', 'homogeneity_score', 'rand_score']
results = {k: [] for k in ['eid', 'dataset', 'distance', *score_names]}

# loop variable is not called `exp`, which would shadow the astrocast.experiments module
for i, exp_data in experiments.items():
    eObj = exp_data["eObj"]
    true_labels = eObj.events.group.tolist()

    for corr_type, v in exp_data['distance'].items():
        # lookup values shifted down by 1 (presumably 1-based cluster ids -> 0-based).
        # NOTE(review): assumes the lookup keys are the positional event indices 0..n-1 — confirm
        predicted_labels = [v['cluster_lookup_table'][n] - 1 for n in range(len(true_labels))]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            scores = clust.Discriminator.compute_scores(true_labels, predicted_labels, scoring="clustering")

            # sklearn's signature is (y_true, y_pred): rows = true group, columns = cluster
            v['cm'] = confusion_matrix(true_labels, predicted_labels, normalize=None)

        results['eid'].append(i)
        results['dataset'].append(exp_data['population_type'])
        results['distance'].append(corr_type)
        for score_name in score_names:
            results[score_name].append(scores[score_name])

df = pd.DataFrame(results)
df

In [None]:
# Figure: example traces (top row) and rand score of the hierarchical clustering per
# distance type (bottom row), one column per dataset.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    datasets = df.dataset.unique()
    n_ds = len(datasets)
    # squeeze=False keeps `axx` 2-D even when there is only one dataset
    fig, axx = plt.subplots(2, n_ds, figsize=(int(3 * n_ds), 8), squeeze=False)

    for n, ds in enumerate(datasets):

        # traces of the first experiment with this population type (the loop variable is
        # not named `exp`, which would shadow the astrocast.experiments module)
        plot = next((exp_data['plot'] for exp_data in experiments.values()
                     if exp_data["population_type"] == ds), None)
        if plot is not None:
            _ = plot.plot_traces(num_samples=4, by="group", ax=axx[0, n], alpha=.9, linestyle="--")

        # rand score per distance type
        data = df[df.dataset == ds]
        sns.barplot(data=data, ax=axx[1, n], x="distance", y="rand_score")

        axx[0, n].set_title(ds)

    # shared y range and a dashed reference line at 0.5
    for ax in axx[1, :]:
        ax.set_ylim(0, 1.1)
        ax.axhline(0.5, linestyle="--", color="gray")

    # drop trace legends except in the last column; panels without a legend are skipped
    for ax in axx[0, :-1]:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    # move_legend raises if the axes has no legend
    if axx[0, -1].get_legend() is not None:
        sns.move_legend(axx[0, -1], loc="upper left", bbox_to_anchor=(1.04, 1))

fig.savefig("conditional_hierarchical.png", dpi=480)

# Coincidence detection

In [None]:
# Scatter of `z0` against `dz` for the first dataset, coloured by group, with the
# timings from `timings[1]` drawn as dashed vertical lines. Gaussian jitter (sd 5)
# is added to dz to separate overlapping points; a fixed seed keeps the figure reproducible.
eObj = experiments[0]["eObj"]
jitter_rng = np.random.default_rng(0)

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    fig, ax = plt.subplots(1, 1, figsize=(10, 3))

    Y = eObj.events.dz + jitter_rng.normal(0, 5, size=len(eObj.events))
    sns.scatterplot(data=eObj.events, x="z0", y=Y, hue="group",
                    ax=ax)
    ax.set_ylabel("dz (jittered)")

    for v in timings[1]:
        ax.axvline(v, color="gray", linestyle="--")

## Classifier (Predict Incidence Occurrence)

In [None]:
importlib.reload(clust)

# For every (experiment, embedding), train a binary coincidence classifier against the
# incidences in `timings[1]` and collect the evaluation metrics per data split.
classifier_metrics = ['cm', 'accuracy', 'precision', 'recall', 'f1']
results = {k: [] for k in ['eid', 'dataset', 'embedding', 'data split', *classifier_metrics]}

# loop variable is not called `exp`, which would shadow the astrocast.experiments module
for i, exp_data in experiments.items():
    eObj = exp_data["eObj"]

    for emb_name, embedding in exp_data['embeddings'].items():
        cDetect = clust.CoincidenceDetection(events=eObj, incidences=timings[1], embedding=embedding)

        _, res = cDetect.predict_coincidence(binary_classification=True)

        # one row per data split returned by `predict_coincidence`
        for split, scores in res.items():
            results['eid'].append(i)
            results['dataset'].append(exp_data['population_type'])
            results['embedding'].append(emb_name)
            results['data split'].append(split)
            for metric in classifier_metrics:
                results[metric].append(scores[metric])

df = pd.DataFrame(results)
df

In [None]:
# Figure: example traces (top row) and coincidence-classifier accuracy (bottom row),
# one column per dataset.
plot_type = "barplot"  # "pointplot", "barplot" or "violinplot"

if plot_type not in ("pointplot", "barplot", "violinplot"):
    raise ValueError(f"unknown plot type: {plot_type!r}")

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    datasets = df.dataset.unique()
    n_ds = len(datasets)
    # squeeze=False keeps `axx` 2-D even when there is only one dataset
    fig, axx = plt.subplots(2, n_ds, figsize=(int(3 * n_ds), 8), squeeze=False)

    for n, ds in enumerate(datasets):

        # traces of the first experiment with this population type (the loop variable is
        # not named `exp`, which would shadow the astrocast.experiments module)
        plot = next((exp_data['plot'] for exp_data in experiments.values()
                     if exp_data["population_type"] == ds), None)
        if plot is not None:
            _ = plot.plot_traces(num_samples=4, by="group", ax=axx[0, n], alpha=.9, linestyle="--")

        # classifier accuracy per embedding and data split
        data = df[df.dataset == ds]
        if plot_type == "pointplot":
            sns.pointplot(data=data, ax=axx[1, n], x="data split", y="accuracy", hue="embedding", dodge=True)
        elif plot_type == "barplot":
            sns.barplot(data=data, ax=axx[1, n], x="embedding", y="accuracy", hue="data split")
        else:  # "violinplot"
            sns.violinplot(data=data, ax=axx[1, n], x="embedding", y="accuracy", hue="data split", split=True)

        axx[0, n].set_title(ds)

    # shared y range and a dashed reference line at 0.5 (presumably chance for a binary target)
    for ax in axx[1, :]:
        ax.set_ylim(0, 1.1)
        ax.axhline(0.5, linestyle="--", color="gray")

    # drop trace legends except in the last column; panels without a legend are skipped
    for ax in axx[0, :-1]:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    # move the last-column legends outside the axes (move_legend raises if there is none)
    for ax in axx[:, -1]:
        if ax.get_legend() is not None:
            sns.move_legend(ax, loc="upper left", bbox_to_anchor=(1.04, 1))

# Use a distinct file name so the conditional-classifier figure is not overwritten.
# fig.savefig("coincidence_classifier.png", dpi=480)

## Regression (Predict timing)

In [None]:
importlib.reload(clust)

# For every (experiment, embedding), regress the incidence location from the
# embedding and collect the resulting score.
results = {k: [] for k in ['eid', 'dataset', 'embedding', 'score']}

# loop variable is not called `exp`, which would shadow the astrocast.experiments module
for i, exp_data in experiments.items():
    eObj = exp_data["eObj"]

    for emb_name, embedding in exp_data['embeddings'].items():
        cDetect = clust.CoincidenceDetection(events=eObj, incidences=timings[1], embedding=embedding)

        _, res = cDetect.predict_incidence_location()

        results['eid'].append(i)
        results['dataset'].append(exp_data['population_type'])
        results['embedding'].append(emb_name)
        results['score'].append(res['score'])

df = pd.DataFrame(results)
df

In [None]:
# Figure: example traces (top row) and timing-regression score per embedding
# (bottom row), one column per dataset.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    datasets = df.dataset.unique()
    n_ds = len(datasets)
    # squeeze=False keeps `axx` 2-D even when there is only one dataset
    fig, axx = plt.subplots(2, n_ds, figsize=(int(3 * n_ds), 8), squeeze=False)

    for n, ds in enumerate(datasets):

        # traces of the first experiment with this population type (the loop variable is
        # not named `exp`, which would shadow the astrocast.experiments module)
        plot = next((exp_data['plot'] for exp_data in experiments.values()
                     if exp_data["population_type"] == ds), None)
        if plot is not None:
            _ = plot.plot_traces(num_samples=4, by="group", ax=axx[0, n], alpha=.9, linestyle="--")

        # regression score per embedding
        data = df[df.dataset == ds]
        sns.barplot(data=data, ax=axx[1, n], x="embedding", y="score")

        axx[0, n].set_title(ds)

    # common y range for the score panels
    for ax in axx[1, :]:
        ax.set_ylim(-1.1, 1.1)

    # drop trace legends except in the last column; panels without a legend are skipped
    for ax in axx[0, :-1]:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

# Use a distinct file name so the conditional-classifier figure is not overwritten.
# fig.savefig("coincidence_regression.png", dpi=480)