In [None]:
import importlib

import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import warnings

import astrocast.helper as helper
import astrocast.reduction as red
import astrocast.clustering as clust
import astrocast.analysis as ana
import astrocast.autoencoders as AE
import astrocast.clustering as clust

# Re-execute the astrocast submodules so that edits made on disk are picked up
# without restarting the kernel. Each module is reloaded exactly once (`clust`
# used to be listed twice); `clust` goes last so that it is still the final
# module to be reloaded, as it was in the original sequence.
for _module in (helper, red, ana, AE, clust):
    importlib.reload(_module)


# Check SignalGenerator

In [None]:
# Evaluate SignalGenerator._richards_curve on an integer grid and plot the
# curve (top panel) together with its cumulative sum (bottom panel).
x = np.arange(-10, 20)
y = helper.SignalGenerator._richards_curve(x, m_0=8)
ysum = np.cumsum(y)

fig, axes = plt.subplots(nrows=2, ncols=1)
ax0, ax1 = axes

ax0.plot(x, y)
ax1.plot(x, ysum)

In [None]:
importlib.reload(helper)

# All settings for one generator, collected in a single mapping so they can be
# tweaked in one place before the generator is built.
signal_params = {
    "trace_length": 100,
    "plateau_duration": 2,
    "parameter_fluctuations": 0,
    "allow_negative_values": False,
    "signal_amplitude": None,
    "m_0": 8,
    "a": 0,
    "k": 1,
    "offset": (3, 3),
    "noise_amplitude": 0.05,
    "leaky_k": 0.2,
}
sg = helper.SignalGenerator(**signal_params)
signal = sg.generate_signal()

# Wide, flat figure showing the generated trace.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 2))
_ = ax.plot(signal)

# Test Dummy Generator

## Conditional Contrasts

In [None]:
importlib.reload(helper)
importlib.reload(ana)

# Settings shared by both signal generators.
default = {
    "trace_length": (50, 50),
    "allow_negative_values": False,
    "noise_amplitude": 0.01,
    "offset": (1, 1),
    "parameter_fluctuations": 0.05,
}

# The two generators differ only in `k` and `leaky_k`.
sg1, sg2 = [
    helper.SignalGenerator(
        plateau_duration=0,
        signal_amplitude=None,
        a=0, k=k_value, m_0=8,
        leaky_k=leak,
        **default,
    )
    for k_value, leak in ((1, 0.2), (2, 0.3))
]

# Build a population of events from both generators.
pop_size = 200
dg = helper.DummyGenerator(generators=[sg1, sg2], num_rows=pop_size)
eObj = dg.get_events()
# display(eObj)

# Plot every generated trace, split by the "group" column.
plot = ana.Plotting(events=eObj)
_ = plot.plot_traces(num_samples=len(eObj), by="group")

# Per-event "group" values, kept as the reference labels for later cells.
ids = eObj.events["group"].tolist()

## Timings

In [None]:
importlib.reload(helper)
importlib.reload(ana)

# Upper bound of the z axis; handed to the generator as z_range=(0, max_z).
max_z = 10000

default = {
    "trace_length": None,
    "allow_negative_values": False,
    "noise_amplitude": 0.01,
    "offset": (1, 1),
    "parameter_fluctuations": 0,
}

sg1 = helper.SignalGenerator(
    plateau_duration=0,
    signal_amplitude=None,
    a=0, k=1, m_0=8,
    leaky_k=0.2,
    **default,
)

# One timing entry per generator: none for the first, multiples of 1000 below
# max_z for the second.
t1 = None
t2 = list(range(0, max_z, 1000))

# The same generator is used twice; only the timings differ between the two.
pop_size = 250
dg = helper.DummyGenerator(
    generators=[sg1, sg1],
    num_rows=pop_size,
    timings=[t1, t2],
    timing_jitter=(5, 50),
    z_range=(0, max_z),
)
eObj = dg.get_events()
# display(eObj)

# plot = ana.Plotting(events=eObj)
# _ = plot.plot_traces(num_samples=len(eObj), by="group")

# ids = eObj.events["group"].tolist()

# Scatter the "z0" column against "dz", coloured by group, with a dashed
# vertical line at every requested timing. Warnings are silenced for the plot.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 3))
    sns.scatterplot(data=eObj.events, x="z0", y="dz", hue="group", ax=ax)

    for v in t2:
        ax.axvline(v, color="gray", linestyle="--")

# Test Unsupervised Clustering

## Feature extraction

In [None]:
# NOTE(review): this cell sits above the cells it depends on. `features` is
# assigned in the following cell and `evaluate` is defined in a later cell, so
# on a fresh kernel with "Run All" this line raises NameError. It only works
# once those cells have been executed; consider moving it below them.
evaluate(features, ids)

In [None]:
# Extract features from the current events object.
fe = red.FeatureExtraction(eObj)
features = fe.all_features(dropna=True)
# display(features)

# Reference labels taken from the same `eObj` the features were computed from.
# `ids` used to be assigned only in an earlier cell, so once a later cell
# rebuilt `eObj` (with a different population size) the labels no longer
# belonged to the events being clustered.
# NOTE(review): if `dropna=True` removes rows from `features`, `ids` has to be
# filtered the same way - confirm against FeatureExtraction.all_features.
ids = eObj.events["group"].tolist()

# Cluster the features and show which cluster labels were produced.
hdb = clust.HdbScan()
lbls = hdb.fit(features)
print(np.unique(lbls))

# Compare the cluster labels with the reference labels.
# NOTE(review): the labels are compared as-is; presumably cluster ids line up
# with the "group" values - verify before trusting these numbers.
accuracy = accuracy_score(ids, lbls)
precision = precision_score(ids, lbls, average='macro')  # Use 'binary' for binary classification
recall = recall_score(ids, lbls, average='macro')  # Use 'binary' for binary classification
f1 = f1_score(ids, lbls, average='macro')  # Use 'binary' for binary classification

print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")

# Embed the features with UMAP and plot them coloured by cluster label.
# `umap` stays a notebook-level name because best_score_plot reads it.
umap = red.UMAP()
embedding = umap.train(features)
umap.plot(data=embedding, labels=lbls, plot_type="umap", size=10)

## CNN

In [None]:
importlib.reload(AE)

# NOTE(review): `dummy_Object` is not assigned in any cell above; presumably it
# is the events object returned by DummyGenerator.get_events() - confirm and
# define it (or substitute the right name) before running this cell.
data = np.array(dummy_Object.events.trace.tolist())
print(data.shape)

# The autoencoder's target length is the second dimension of the trace array.
cnn = AE.CNN_Autoencoder(
    target_length=data.shape[1],
    latent_size=384,
    add_noise=0.01,
    use_cuda=True,
)
X_train, X_val, X_test = cnn.split_dataset(data)

_ = cnn.train_autoencoder(X_train, X_val, X_test, epochs=50)

# Embedding of every trace, reused by the classifier cell.
embedding = cnn.embed(data)

# Show example plots for the training split, then for the test split.
for _subset in (X_train, X_test):
    _ = cnn.plot_examples_pytorch(_subset, show_diff=True, trim_zeros=False)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

importlib.reload(clust)

# NOTE(review): `dummy_Object` is not assigned in any cell above - confirm which
# events object is meant here.
disc = clust.Discriminator(events=dummy_Object)

# disc.get_available_models()

# Train a classifier on the embedding with `ids` as the category vector.
clf = disc.train_classifier(
    embedding=embedding,
    category_vector=ids,
)
evaluation = disc.evaluate(show_plot=True)
train, test = evaluation

# Display the first element returned by disc.evaluate as a confusion matrix.
disp = ConfusionMatrixDisplay(
    confusion_matrix=train,
    display_labels=clf.classes_,
)
disp.plot()

In [None]:
# Reload the reduction module so the helpers defined below use its current code.
importlib.reload(red)
import warnings


def compute_scores(true_labels, predicted_labels):
    """Summarise how well predicted labels agree with the reference labels.

    Precision, recall and F1 are macro-averaged.

    Args:
      true_labels: Reference (correct) labels.
      predicted_labels: Labels to be scored against the reference.

    Returns:
      A dictionary with the keys 'accuracy', 'precision', 'recall' and 'f1'.
    """
    macro_metrics = (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )

    scores = {'accuracy': accuracy_score(true_labels, predicted_labels)}
    for key, metric in macro_metrics:
        scores[key] = metric(true_labels, predicted_labels, average='macro')

    return scores


def best_score_plot(embedding, true_labels, lbls1, lbls2, axx=None):
    """Plot the embedding with whichever label set has the higher F1 score.

    Both candidate label sets are scored against `true_labels`; the one with
    the strictly higher F1 wins (ties go to `lbls2`). The winner is drawn on
    the first axis, titled with its scores, and its confusion matrix is drawn
    on the second axis.

    Args:
      embedding: The embedding to be plotted.
      true_labels: Ground truth (correct) labels for the embedding.
      lbls1, lbls2: Two sets of labels to compare.
      axx: Optional pair of Matplotlib axes `(ax0, ax1)`. When None, a new
        figure with two stacked axes is created.
    """

    scores1 = compute_scores(true_labels, lbls1)
    scores2 = compute_scores(true_labels, lbls2)

    # Decide which labels have the best F1 score
    if scores1['f1'] > scores2['f1']:
        best_labels, best_scores = lbls1, scores1
    else:
        best_labels, best_scores = lbls2, scores2

    if axx is None:
        fig, (ax0, ax1) = plt.subplots(2, 1)
    else:
        ax0, ax1 = axx

    # Plotting. `umap` is the notebook-level UMAP object created in an earlier
    # cell, not an argument of this function.
    umap.plot(data=embedding, labels=best_labels, true_labels=true_labels, plot_type='matplotlib', size=10, alpha=0.75,
              ax=ax0)
    title = f"Accuracy: {best_scores['accuracy'] * 100:.1f}%, Precision: {best_scores['precision'] * 100:.1f}%, Recall: {best_scores['recall'] * 100:.1f}%, F1: {best_scores['f1'] * 100:.1f}%"
    # Title the axis that was just drawn on. This used to go to `ax`, a name
    # that is neither a parameter nor a local here, so the title landed on a
    # figure left over from an earlier cell (or raised NameError).
    ax0.set_title(title)

    # Generating and plotting confusion matrix for the best model
    cm = confusion_matrix(true_labels, best_labels)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(ax=ax1)


def evaluate(embedding, true_labels, min_dist=0.11, n_neighbors=50, min_sample=2, min_cluster_size=50):
    """Cluster an embedding with and without UMAP reduction and plot the better result.

    The embedding is reduced with UMAP, clustered both in its original form and
    in its reduced form, and the two label sets are handed to `best_score_plot`.

    Args:
      embedding: Feature matrix / embedding to cluster.
      true_labels: Ground truth labels for the rows of `embedding`.
      min_dist: Forwarded to `red.UMAP` (previously ignored in favour of a
        hard-coded 0.11).
      n_neighbors: Forwarded to `red.UMAP` (previously ignored in favour of a
        hard-coded 50).
      min_sample: Currently unused, see the note in the body.
      min_cluster_size: Currently unused, see the note in the body.
    """

    # Named `reducer` rather than `umap` so it is not mistaken for the
    # notebook-level `umap` object that best_score_plot draws with.
    reducer = red.UMAP(min_dist=min_dist, n_neighbors=n_neighbors)
    um_embedding = reducer.train(embedding)

    # NOTE(review): `hdb` is the notebook-level HdbScan instance created in an
    # earlier cell; `min_sample` and `min_cluster_size` are not forwarded to
    # it. Confirm the HdbScan constructor arguments before wiring them in.
    lbls1 = hdb.fit(embedding)
    lbls2 = hdb.fit(um_embedding)

    # NOTE(review): the un-reduced `embedding` is what gets plotted, even when
    # the winning labels come from `um_embedding` - confirm this is intended.
    best_score_plot(embedding, true_labels, lbls1, lbls2)

## RNN

In [None]:
# pdl = AE.PaddedDataLoader(data=dummy_Object.events.trace)
# X_train, X_val, X_test = pdl.get_datasets(batch_size=128,
#                                           val_size=0.1,
#                                           test_size=0.05)

In [None]:
# tRAE = AE.TimeSeriesRnnAE(use_cuda=True, rnn_hidden_dim=64, num_layers=1,
#                           dropout=0.01, encoder_lr=0.001, decoder_lr=0.001)
# _ = tRAE.train_epochs(dataloader_train=X_train,
#                       dataloader_val=X_val,
#                       num_epochs=100,
#                       patience=1000,
#                       diminish_learning_rate=0.98,
#                       safe_after_epoch=None,
#                       show_mode='notebook'
#                       )

In [None]:
# fig, x_val, y_val, latent, losses = tRAE.plot_traces(dataloader=X_test, figsize=(20, 20))
# fig.savefig("tRAE_performance.png")

In [None]:
# X = pdl.get_dataloader(dummy_Object.events.trace, batch_size=16, shuffle=False)
# _, _, latent, _ = tRAE.embedd(X)
# latent = np.array(latent)

# hdb = clust.HdbScan()
# lbls = hdb.fit(latent)
# uniq = np.unique(lbls)
# print(len(uniq))
# print(uniq)

# # umap = red.UMAP()
# # ulatent = umap.train(latent)
# # print(ulatent.shape)
# # 
# # umap.plot(data=ulatent, labels=lbls, use_napari=False)