In [None]:
import logging
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from pyobsbox import Checker

In [None]:
# Keep the package logger quiet overall, but let the plotter sub-logger emit
# debug messages.
for logger_name, level in (
    ("pyobsbox", logging.WARNING),
    ("pyobsbox.plotter", logging.DEBUG),
):
    logging.getLogger(logger_name).setLevel(level)

# List models

In [None]:
# Folder that is searched for saved models. The path is relative, so it is
# resolved against the working directory the notebook is run from.
models_folder = Path("../../models")

In [None]:
# Names of every entry in the models folder matching the "model_*" pattern.
[entry.name for entry in models_folder.glob('model_*')]

In [None]:
# NOTE(review): "{model_name}" is a plain string, not an f-string, so the braces
# are kept literally in the path. Presumably this notebook is a template whose
# placeholder is substituted before execution — TODO confirm.
model_path = models_folder / "{model_name}"

In [None]:
# Build the Checker for the selected model path; all dataset loading,
# prediction, encoding and plotting calls below go through this object.
checker = Checker(model_path)

In [None]:
# Plot the model's loss through the checker
# (presumably the recorded training history — TODO confirm).
checker.plot_loss()

In [None]:
# Load the train dataset together with its metadata, then run the model on it.
# NOTE(review): the prediction is later compared against the input with
# checker.MSE, so it is presumably a reconstruction of each sample — confirm.
train_data, train_meta = checker.load_train_dataset()
train_prediction = checker.predict(train_data)

In [None]:
# Preview the first rows of the train metadata. Rows are looked up by position
# (train_meta.iloc[index]) alongside train_data[index] further down.
train_meta.head()

In [None]:
# Size of the train metadata table.
train_meta.shape

In [None]:
# Shape of the train inputs. The clustering helpers below read shape[2] of a
# subset of this array, so at least three axes are expected.
train_data.shape

In [None]:
# Shape of the model output for the train inputs
# (presumably identical to train_data.shape — TODO confirm).
train_prediction.shape

In [None]:
# Load the validation dataset together with its metadata, then run the model on it.
# NOTE(review): as for the train data, the prediction is later compared against
# the input with checker.MSE, so it is presumably a reconstruction — confirm.
validation_data, validation_meta = checker.load_validation_dataset()
validation_prediction = checker.predict(validation_data)

In [None]:
# Preview the first rows of the validation metadata.
validation_meta.head()

In [None]:
# Size of the validation metadata table.
validation_meta.shape

In [None]:
# Shape of the validation inputs.
validation_data.shape

In [None]:
# Shape of the model output for the validation inputs
# (presumably identical to validation_data.shape — TODO confirm).
validation_prediction.shape

In [None]:
def plot_train_index(index: int):
    """Show everything known about one train sample.

    Prints the index, the sample's entry in the notebook-level ``train_error``
    and its metadata row, then draws two figures: the true signal against the
    prediction (via ``checker.plot_sample``) and the encoding of the true
    signal (via ``checker.encode`` + ``plt.imshow``).

    Relies on the notebook globals ``train_data``, ``train_prediction``,
    ``train_meta``, ``train_error`` and ``checker``; ``train_error`` must have
    been computed before this function is called.

    Args:
        index: Position of the sample in the train dataset.
    """
    # Look everything up first so a bad index fails before anything is printed.
    sample, reconstruction, sample_meta = (
        train_data[index],
        train_prediction[index],
        train_meta.iloc[index],
    )

    print(f"index: {index}")
    print(f"Error: {train_error[index]}")
    print(sample_meta)

    # Figure 1: true signal vs. prediction.
    checker.plot_sample(sample_meta, sample, reconstruction)
    plt.show()

    # Figure 2: the encoded representation, rendered as an image.
    plt.imshow(checker.encode(sample))
    plt.show()

In [None]:
# Error between the train inputs and their predictions, as computed by
# checker.MSE. It is indexed per sample later on (train_error[index]).
# The error distribution is then plotted with 100 bins.
train_error = checker.MSE(train_data, train_prediction)
checker.plot_error(train_error, bins=100)

In [None]:
# Same error computation and 100-bin error plot for the validation dataset,
# for comparison with the train error distribution above.
validation_error = checker.MSE(validation_data, validation_prediction)
checker.plot_error(validation_error, bins=100)

In [None]:
# Histogram peak
# np.histogram returns len(hist) + 1 bin *edges*. Use the centre of the tallest
# bin (not its left edge) as the estimate of the most common train error, so
# the "around the peak" selection is not biased by half a bin width.
hist, bins = np.histogram(train_error, bins=100)
peak_bin = np.argmax(hist)
hist_peak = 0.5 * (bins[peak_bin] + bins[peak_bin + 1])

In [None]:
# plotting a single sample: index 0 of the train dataset (not randomly drawn)
plot_train_index(0)

In [None]:
# Train sample indices ordered by increasing error (np.argsort sorts ascending):
# the first entries are the best predictions, the last ones the worst.
sorted_error_indices = np.argsort(train_error)

# From left to right

In [None]:
# The 10 train samples with the smallest error, best first.
# Slicing (rather than indexing positions 0..9) also works when the dataset
# holds fewer than 10 samples instead of raising IndexError.
for index in sorted_error_indices[:10]:
    plot_train_index(index)

# From right to left

In [None]:
# The 50 train samples with the largest error, worst first.
# Reversing the ascending order and slicing yields the same sequence as
# indexing with -1, -2, ..., -50, but does not raise IndexError when fewer
# than 50 samples are available.
for index in sorted_error_indices[::-1][:50]:
    plot_train_index(index)

# Around the peak

In [None]:
# Distance of every train error from the histogram peak, and the sample
# indices ordered from closest to farthest from that peak.
error_around_peak = abs(train_error - hist_peak)
sorted_error_indices_around_peak = np.argsort(error_around_peak)

In [None]:
# The 10 train samples whose error lies closest to the histogram peak.
# Slicing also works when fewer than 10 samples are available.
for index in sorted_error_indices_around_peak[:10]:
    plot_train_index(index)

# Around a point

In [None]:
# point = 

In [None]:
# error_around_point = abs(train_error - point)
# sorted_error_indices_around_point = np.argsort(error_around_point)

In [None]:
# for i in range(10):
#     index = sorted_error_indices_around_point[i]
#     plot_train_index(index)

# Clustering

In [None]:
# Select the (up to) 1024 train samples with the largest error, worst first.
# Reversing the ascending argsort and slicing gives the same order as indexing
# with -1, -2, ..., -1024, but does not raise IndexError when the dataset has
# fewer than 1024 samples.
indices = sorted_error_indices[::-1][:1024]
encoding_true = train_data[indices]
encoding_pred = train_prediction[indices]
encoding_meta = train_meta.iloc[indices]
try:
    encoded = checker.encode(encoding_true)
except Exception:
    # Fall back to calling the encoder directly when checker.encode fails.
    # `except Exception` (instead of a bare `except:`) lets KeyboardInterrupt
    # and SystemExit propagate, so the cell can still be interrupted.
    encoded = np.array(checker.encoder(encoding_true)).squeeze()
# Flatten every encoding to one feature vector per sample for t-SNE/PCA/KMeans.
encoded_2d = encoded.reshape(encoded.shape[0], -1)

In [None]:
# (number of selected samples, flattened encoding size)
encoded_2d.shape

In [None]:
# Flattened encoding of the first selected sample (the one with the largest error).
print(encoded_2d[0])

In [None]:
# Smallest value over all flattened encodings.
encoded_2d.min()

In [None]:
# Largest value over all flattened encodings.
encoded_2d.max()

In [None]:
def plot_clusters(cluster_indices):
    """Overlay the true signals of the given samples, one subplot per last-axis slice.

    Reads the notebook global ``encoding_true`` and creates one row of subplots
    for every index along its third axis (``encoding_true.shape[2]``). In each
    subplot the selected samples are drawn as pixel markers without connecting
    lines.

    Args:
        cluster_indices: Indices into ``encoding_true`` selecting the samples to draw.

    Returns:
        The ``(fig, axes)`` pair; ``axes`` is always a numpy array, even when
        only a single subplot was created.
    """
    n_rows = encoding_true.shape[2]
    fig, axes = plt.subplots(n_rows, 1, figsize=(12, 3))
    # plt.subplots hands back a bare Axes for a single subplot; wrap it so the
    # loop below and the caller can always index into an array.
    axes = axes if isinstance(axes, np.ndarray) else np.array([axes])
    members = encoding_true[cluster_indices]
    for row, ax in enumerate(axes):
        ax.plot(members[:, :, row].T, linewidth=0, marker=',', alpha=0.5)
    return fig, axes

def plot_cluster_imgs(cluster_indices):
    """Show the encodings of the given samples side by side as images.

    Reads the notebook global ``encoded`` and creates a single row with
    ``len(cluster_indices)`` subplots, one ``imshow`` per selected encoding.

    Args:
        cluster_indices: Indices into ``encoded`` selecting the samples to draw.

    Returns:
        The ``(fig, axes)`` pair; ``axes`` is always a numpy array.
    """
    fig, axes = plt.subplots(1, len(cluster_indices))
    # A single subplot comes back as a bare Axes; normalise to an array.
    axes = axes if isinstance(axes, np.ndarray) else np.array([axes])
    for ax, img in zip(axes, encoded[cluster_indices]):
        ax.imshow(img)
    return fig, axes

def plot_clusters_grid(cluster_indices):
    """Draw the true signals of the given samples on a square grid of subplots.

    The grid side is ``int(sqrt(len(cluster_indices)))``, so only the first
    ``side * side`` entries of ``cluster_indices`` are drawn; any remainder is
    ignored. Signals come from the notebook global ``encoding_true`` and are
    drawn as pixel markers without connecting lines.

    Args:
        cluster_indices: Indices into ``encoding_true``.

    Returns:
        The ``(fig, axes)`` pair, with ``axes`` flattened to one dimension.
    """
    side = int(np.sqrt(len(cluster_indices)))
    fig, axes = plt.subplots(side, side, figsize=(6, 6))
    # A 1x1 grid comes back as a bare Axes; normalise before flattening.
    axes = (axes if isinstance(axes, np.ndarray) else np.array([axes])).flatten()
    for position, ax in enumerate(axes):
        ax.plot(encoding_true[cluster_indices[position]], linewidth=0, marker=',')
    return fig, axes

def plot_cluster_imgs_grid(cluster_indices):
    """Show the encodings of the given samples as images on a square grid.

    The grid side is ``int(sqrt(len(cluster_indices)))``, so only the first
    ``side * side`` entries of ``cluster_indices`` are shown; any remainder is
    ignored. Images come from the notebook global ``encoded``.

    Args:
        cluster_indices: Indices into ``encoded``.

    Returns:
        The ``(fig, axes)`` pair, with ``axes`` flattened to one dimension.
    """
    side = int(np.sqrt(len(cluster_indices)))
    fig, axes = plt.subplots(side, side, figsize=(6, 6))
    # A 1x1 grid comes back as a bare Axes; normalise before flattening.
    axes = (axes if isinstance(axes, np.ndarray) else np.array([axes])).flatten()
    for position, ax in enumerate(axes):
        ax.imshow(encoded[cluster_indices[position]])
    return fig, axes

def plot_avg_encoding(cluster_indices):
    """Display the element-wise mean of the selected encodings as an image.

    Averages the notebook global ``encoded`` over the selected samples
    (axis 0) and draws the result on the current axes with ``plt.imshow``.

    Args:
        cluster_indices: Indices into ``encoded`` selecting the samples to average.
    """
    mean_encoding = encoded[cluster_indices].mean(axis=0)
    plt.imshow(mean_encoding)

In [None]:
import sklearn
from sklearn.manifold import TSNE

In [None]:
# Compare the t-SNE embedding of the flattened encodings at two perplexities.
# t-SNE is stochastic; a fixed random_state keeps the scatter plots
# reproducible across kernel restarts.
for perplexity in [10, 30]:
    tsne = TSNE(perplexity=perplexity, random_state=0)
    encoded_2d_tsne = tsne.fit_transform(encoded_2d)
    plt.scatter(encoded_2d_tsne[:, 0], encoded_2d_tsne[:, 1], marker='.', s=2)
    plt.title(f"perp={perplexity}")
    plt.show()

In [None]:
# Embedding kept for later use, at perplexity 30. The fixed random_state makes
# it reproducible between runs (t-SNE is otherwise stochastic).
tsne = TSNE(perplexity=30, random_state=0)
encoded_2d_tsne = tsne.fit_transform(encoded_2d)

In [None]:
from sklearn.decomposition import PCA

# Reduce the flattened encodings to 8 principal components; these are the
# features passed to KMeans below. Only the first two components are shown
# in the scatter plot.
pca = PCA(n_components=8)
encoded_2d_pca = pca.fit_transform(encoded_2d)
plt.scatter(encoded_2d_pca[:, 0], encoded_2d_pca[:, 1], marker='.', s=4)

In [None]:
# Fraction of the total variance retained by the 8 kept components.
pca.explained_variance_ratio_.sum()

In [None]:
from sklearn.cluster import KMeans
from yellowbrick.cluster import KElbowVisualizer
import matplotlib as mpl
# Switch matplotlib back to its default style.
# NOTE(review): presumably undoes style changes applied when yellowbrick is
# imported — TODO confirm.
mpl.style.use("default")

In [None]:
# Elbow search for the number of clusters: KElbowVisualizer fits KMeans on the
# PCA features over the k range given by k=(2, 64) and exposes the selected
# value as viz.elbow_value_, which the next cell uses.
kmeans = KMeans()
viz = KElbowVisualizer(kmeans, k=(2, 64))
viz.fit(encoded_2d_pca)
viz.show()

In [None]:
# Cluster the PCA features with the k chosen by the elbow search, then show up
# to 8 example samples (signals and encodings) for every cluster.
n_clusters = viz.elbow_value_
if n_clusters is None:
    # The visualizer leaves elbow_value_ unset when it detects no elbow; fail
    # with a clear message instead of passing None on to KMeans.
    raise ValueError(
        "KElbowVisualizer found no elbow; set n_clusters manually."
    )
kmeans = KMeans(n_clusters=n_clusters)
clustering = kmeans.fit(encoded_2d_pca)
print('-------------')
print("n_clusters:", n_clusters)
for cluster in np.unique(clustering.labels_):
    cluster_indices = np.where(clustering.labels_ == cluster)[0]
    print("cluster", cluster)
    print(f"N samples: {len(cluster_indices)}")
    # replace=False: np.random.choice samples with replacement by default,
    # which could show the same sample several times. The min() guard keeps
    # the request within the cluster size, as sampling without replacement
    # requires.
    cluster_indices = np.random.choice(cluster_indices,
                                       min(8, len(cluster_indices)),
                                       replace=False)

    fig, axes = plot_clusters(cluster_indices)
    plt.show()

    fig, axes = plot_cluster_imgs(cluster_indices)
    plt.show()

In [None]:
# from sklearn.cluster import DBSCAN
# dbscan = DBSCAN(eps=4)
# clustering = dbscan.fit(encoded_2d_tsne)
# sc = plt.scatter(encoded_2d_tsne[:, 0], encoded_2d_tsne[:, 1], marker='.', s=2, c=clustering.labels_, cmap='tab10')
# plt.legend(*sc.legend_elements())
# plt.show()
# for cluster in np.unique(clustering.labels_):
#     cluster_indices = np.where(clustering.labels_ == cluster)[0]
#     print("cluster", cluster)
#     fig, axes = plot_clusters(cluster_indices)
#     plt.show()
#     plot_avg_encoding(cluster_indices)
#     plt.show()