# Virgo Demo 2 - Testing

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.mixture import GaussianMixture

from virgo.cleaner import GaussianMixtureCleaner, LowDensityCleaner
from virgo.cluster import VirgoCluster
from virgo.kernel import VirgoKernel
from virgo.mixture import VirgoMixture

%load_ext autoreload
%autoreload 2
%matplotlib notebook
# https://towardsdatascience.com/transform-data-to-high-dimensional-kernel-space-87d62b670e0f

In [None]:
import os

# Load the data set, scale it, run the kernel on it and print the data statistics.
# The path can be overridden with the VIRGO_DATA_FILE environment variable so the
# notebook is runnable on machines other than the author's; the default is unchanged.
file_name = os.environ.get("VIRGO_DATA_FILE", "/home/max/Software/virgo/data/data.txt")
virgo_cluster = VirgoCluster(file_name=file_name)
virgo_cluster.scale_data()
virgo_kernel = VirgoKernel(virgo_cluster)
virgo_kernel()
virgo_cluster.print_datastats()

In [None]:
# Candidate mixture configurations. Change `mixture_choice` to try another one;
# this replaces an `if True / elif True / else` chain whose later branches were
# unreachable.
mixture_configs = {
    "default": dict(n_comp=12),
    "dim_subset": dict(n_comp=12, fit_dim_ind=[0, 1, 2, -2, -1]),
    "bayesian_dim_subset": dict(
        n_comp=15,
        mixture_type="bayesian_gaussian",
        fit_dim_ind=[0, 1, 2, -2, -1],
    ),
}
mixture_choice = "default"
virgo_mixture = VirgoMixture(virgo_cluster, **mixture_configs[mixture_choice])

# Fit the model and print the value returned by fit() and the component weights.
elbo = virgo_mixture.fit()
print(elbo)
print(virgo_mixture.model.weights_)

# Label counts with every point assigned to a class.
virgo_mixture.predict()
labels_all = virgo_cluster.get_labels(return_counts=True)
print(labels_all)

# Label counts after uncertain labels have been removed.
virgo_mixture.predict(remove_uncertain_labels=True)
labels_removed = virgo_cluster.get_labels(return_counts=True)
print(labels_removed)

# NOTE(review): the [1:] presumably skips the count of the "uncertain" label (-1)
# that appears first after removal. This only lines up if at least one point was
# removed and no class vanished completely - confirm against get_labels().
count_diff = labels_all[1] - labels_removed[1][1:]
print("Diff per class: ", count_diff)
print("Rel loss per class: ", count_diff / labels_all[1])

virgo_cluster.plot_cluster(n_step=50)

In [None]:
# Plot the clusters again, this time passing remove_uncertain=False
# (presumably keeps the points flagged as uncertain in the plot - confirm against plot_cluster).
virgo_cluster.plot_cluster(n_step=50, remove_uncertain=False)

In [None]:
# Plot only label -1 with a finer step (n_step=10).
# NOTE(review): -1 is presumably the label given to uncertain points - confirm.
virgo_cluster.plot_cluster(n_step=10, remove_uncertain=False, cluster_label=[-1])

## Visualize different clusters

The results of the most recent evaluation run are visualized below.

In [None]:
# One plot per label. The first entry returned by get_labels() is skipped
# (presumably the uncertain label -1 - confirm).
cluster_ids = virgo_cluster.get_labels()[1:]
for label in cluster_ids:
    virgo_cluster.plot_cluster(n_step=10, cluster_label=[label])

In [None]:
focus = 3
for i in virgo_cluster.get_labels()[1:]:
    if i == focus:
        continue
    virgo_cluster.plot_cluster(n_step=10, cluster_label=[focus, i])

## Testing different fit configurations

ELBO study for different mixture models and choices of input dimensions.

In [None]:
# Fit one Gaussian mixture per candidate number of components and record the
# value returned by fit() for each of them.
n_comp_values = list(range(2, 25, 3))
elbos = []
for i in n_comp_values:
    virgo_mixture = VirgoMixture(virgo_cluster, n_comp=i, mixture_type="gaussian", fit_dim_ind=[0, 1, 2, -2, -1])
    elbo = virgo_mixture.fit()
    elbos.append(elbo)
    print(i, elbo)

print(elbos)
# Plot against the tested component counts rather than the list index, so the
# x-axis can be read directly as n_comp.
plt.plot(n_comp_values, elbos)
plt.xlabel("n_comp")
plt.ylabel("ELBO")
plt.show()

### GaussianMixture cleaner

In [None]:
# Plot cluster 6 before the Gaussian-mixture cleaner is run, for comparison.
virgo_cluster.plot_cluster(n_step=10, cluster_label=[6])

In [None]:
# Run the Gaussian-mixture cleaner on the cluster and print its unique labels
# before and after clean() so the effect of the cleaning step is visible.
gm_cleaner = GaussianMixtureCleaner(virgo_cluster)
print(gm_cleaner.unique_labels)
gm_cleaner.clean()
print(gm_cleaner.unique_labels)

In [None]:
# Plot cluster 6 again, now after the Gaussian-mixture cleaner has run.
virgo_cluster.plot_cluster(n_step=10, cluster_label=[6])

In [None]:
# One plot per label after cleaning. The first entry returned by get_labels()
# is skipped (presumably the uncertain label -1 - confirm).
cleaned_ids = virgo_cluster.get_labels()[1:]
for label in cleaned_ids:
    virgo_cluster.plot_cluster(n_step=10, cluster_label=[label])

### Low density cleaner

In [None]:
# Baseline before low-density cleaning: plot the clusters and, as the cell's
# last expression, display the result of get_labels(return_counts=True).
virgo_cluster.plot_cluster(n_step=50)
virgo_cluster.get_labels(return_counts=True)

In [None]:
# Run the low-density cleaner, then plot the clusters and display the label
# counts again to compare with the baseline cell.
# NOTE(review): 1e-10 is presumably a density threshold - confirm against the
# LowDensityCleaner signature and consider passing it by keyword.
d_cleaner = LowDensityCleaner(virgo_cluster, 1e-10)
d_cleaner.clean()
virgo_cluster.plot_cluster(n_step=50)
virgo_cluster.get_labels(return_counts=True)

In [None]:
# Plot the clusters in their current state (after the low-density cleaner cell).
virgo_cluster.plot_cluster(n_step=50)

In [None]:
# Split clusters that are better described by two Gaussians than by one.
#
# For every label, a 1-component and a 2-component GaussianMixture are fitted
# to the cluster's points and the ratio of their `lower_bound_` values is
# compared with SPLIT_RATIO_THRESHOLD. Below the threshold the cluster is kept
# as is; otherwise it is split along the 2-component prediction and the second
# half receives a new label.
#
# NOTE(review): the ratio m1 / m2 depends on the sign of the lower bounds
# (for two negative values a ratio > 1 means the 2-component fit is better) -
# confirm that the threshold matches the sign of the values seen on this data.
SPLIT_RATIO_THRESHOLD = 1.075


def _show_scatter_3d(points, colors):
    """Scatter the first three columns of ``points`` in 3D, coloured by ``colors``."""
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')
    ax.scatter(points.T[0], points.T[1], points.T[2], c=colors, marker=".", cmap="plasma")
    plt.show()


# NOTE(review): `values` was not defined in any visible cell. It is derived here
# from the labels currently stored on the cluster so the cell runs on a fresh
# kernel; this includes every label present (also -1, if any) so no point is
# dropped when the clusters are reassembled - confirm this is the intended set.
values = np.unique(virgo_cluster.cluster_labels).tolist()

clusters = []
labels = []
elbo_rat = []

# Iterate over a snapshot because `values` grows whenever a cluster is split.
for val in list(values):
    member_mask = virgo_cluster.cluster_labels == val
    plot_dat = virgo_cluster.cluster[member_mask]
    plot_y = virgo_cluster.cluster_labels[member_mask]

    model = GaussianMixture(n_components=1)
    model.fit(plot_dat)
    m1 = model.lower_bound_

    model = GaussianMixture(n_components=2)
    model.fit(plot_dat)
    m2 = model.lower_bound_

    ratio = m1 / m2
    print(ratio, m1, m2)
    elbo_rat.append(ratio)

    if ratio < SPLIT_RATIO_THRESHOLD:
        # Keep the cluster unchanged.
        clusters.append(plot_dat)
        labels.append(plot_y)
        _show_scatter_3d(plot_dat, plot_y)
    else:
        # Split along the 2-component model: component 0 keeps the old label,
        # component 1 gets a fresh one.
        plot_pred_y = model.predict(plot_dat)
        in_first = plot_pred_y == 0
        in_second = plot_pred_y == 1

        # max + 1 cannot collide with an existing label, unlike len(values),
        # which is only safe when the labels are exactly 0..n-1.
        new_label = max(values) + 1

        clusters.append(plot_dat[in_first])
        labels.append(plot_y[in_first])

        clusters.append(plot_dat[in_second])
        plot_y[in_second] = new_label
        labels.append(plot_y[in_second])
        values.append(new_label)

        _show_scatter_3d(plot_dat, plot_pred_y)

# `clusters` and `labels` stay plain lists of arrays: the per-cluster arrays
# have different lengths, and np.array() on such a ragged list raises a
# ValueError on current NumPy versions.

In [None]:
# Stitch the per-cluster pieces back together and store them on the cluster object.
if len(clusters) == 0:
    raise ValueError("no clusters to reassemble; run the splitting cell first")

# Print the size of every piece, as a quick sanity check.
for piece_labels in labels:
    print(piece_labels.shape[0])

# A single concatenate per array instead of growing the result piece by piece.
all_clusters = np.concatenate(list(clusters))
all_labs = np.concatenate(list(labels))

# NOTE(review): this reorders the points (they are now grouped by label). Any
# other per-point array kept on virgo_cluster would no longer line up - confirm
# that `cluster` and `cluster_labels` are the only ones that matter here.
virgo_cluster.cluster = all_clusters
virgo_cluster.cluster_labels = all_labs

In [None]:
# Plot the clusters after the reassembled data and labels were written back.
virgo_cluster.plot_cluster(n_step=50)

In [None]:
# Plot a selection of two clusters with a fine step (n_step=4).
# NOTE(review): every other cell selects clusters with `cluster_label=`; confirm
# that plot_cluster still accepts `cluster_ind`, otherwise this call raises TypeError.
virgo_cluster.plot_cluster(n_step=4, cluster_ind=[2, 1])