# Virgo Demo 3 - Advanced pipeline

In [None]:
from virgo.data.cluster import VirgoCluster
from virgo.data.cleaner import LowDensityCleaner
from virgo.models.kernel import VirgoSimpleKernel, VirgoKernel
from virgo.models.mixture import VirgoMixture, VirgoClustering

%load_ext autoreload
%autoreload 2

%matplotlib notebook

In [None]:
# Snapshot to load. NOTE: machine-specific absolute path; adjust locally.
filebase = "/home/max/Software/virgo/data/VIRGO/snap_800"

load_options = dict(io_mode=1, cut_mach_dim=-2, n_max_data=200000)
virgo_cluster = VirgoCluster(file_name=filebase, **load_options)

# Drop the last column of the loaded array before scaling.
virgo_cluster.data = virgo_cluster.data[:, :-1]
virgo_cluster.scale_data()
virgo_cluster.print_datastats()

# Three [low, high] pairs; presumably one per plotted dimension - TODO confirm.
hist_ranges = [[2000., 8000.], [-6000., 1000.], [-3000., 6000.]]
virgo_cluster.plot_raw_hists(bins=100, plot_range=hist_ranges)

In [None]:
# Vanilla FoF: run the group finder with use_scaled_data=False.
virgo_cluster.run_fof(
    linking_length=5000.,
    min_group_size=1000,
    use_scaled_data=False,
)

# Report the labels found and how many entries carry each one.
labels, counts = virgo_cluster.get_labels(return_counts=True)
print(labels, counts)

# NOTE(review): `maker_size` looks like a typo for `marker_size`; confirm
# against the plot_cluster signature before renaming.
virgo_cluster.plot_cluster(
    n_step=1,
    store_gif=False,
    gif_title="fof_vanilla_50",
    maker_size=1.,
)

In [None]:
# Earlier variant, kept for reference:
# virgo_kernel = VirgoKernel(virgo_cluster, spatial_dim=[0, 1, 2, 3, 4, 5], k_nystroem=4000, pca_comp=5)
kernel_options = dict(k_nystroem=100, pca_comp=5)
virgo_kernel = VirgoKernel(virgo_cluster, **kernel_options)

# The kernel object is callable; invoke it, then print the data statistics.
virgo_kernel()
virgo_cluster.print_datastats()

In [None]:
# FoF with use_scaled_data=True, using the settings noted as "rbf 800".
fof_settings = dict(linking_length=0.018, min_group_size=100, use_scaled_data=True)
virgo_cluster.run_fof(**fof_settings)

# Alternative settings kept from earlier experiments:
# # rbf 850
# virgo_cluster.run_fof(linking_length=0.019, min_group_size=300, use_scaled_data=True)
# # rbf 900
# virgo_cluster.run_fof(linking_length=0.019, min_group_size=300, use_scaled_data=True)
# raw
# virgo_cluster.run_fof(linking_length=30., min_group_size=3000, use_scaled_data=False)

labels, counts = virgo_cluster.get_labels(return_counts=True)
print(len(labels))
print(labels, counts)

# Plot twice: first with plot_kernel_space=True, then the default view.
virgo_cluster.plot_cluster(
    n_step=1,
    plot_kernel_space=True,
    store_gif=False,
    gif_title="fof_kernel_kspace",
)
virgo_cluster.plot_cluster(n_step=1, store_gif=False, gif_title="fof_kernel")

In [None]:
# Notes carried over from the original cell (meaning is not shown here;
# presumably per-snapshot threshold exponents - TODO confirm):
#   800 -16
#   850 -14
#   900 -14
threshold = 1e-10  # second positional argument of LowDensityCleaner
d_cleaner = LowDensityCleaner(virgo_cluster, threshold)
d_cleaner.clean()

# Show labels and counts after cleaning, then plot.
label_summary = virgo_cluster.get_labels(return_counts=True)
print(label_summary)
virgo_cluster.plot_cluster(n_step=1)

In [None]:
# Plot with an explicit cluster_label list (label 4 is not in the list).
selected_labels = [0, 1, 2, 3, 5]
virgo_cluster.plot_cluster(
    n_step=1,
    store_gif=False,
    gif_title="fof_kernel",
    cluster_label=selected_labels,
)

In [None]:
# Fit a two-component mixture on virgo_cluster and report the fit summary.
virgo_mixture = VirgoMixture(virgo_cluster, n_comp=2)
elbo = virgo_mixture.fit()
print(f"ELBO: {elbo}")
print(f"Mixture weights {virgo_mixture.model.weights_}")

# Predict labels without removing uncertain ones, then show labels/counts.
virgo_mixture.predict(remove_uncertain_labels=False)
labels_removed = virgo_cluster.get_labels(return_counts=True)
print(labels_removed)

# Relabel every entry currently labelled 1 as 2.
is_label_one = virgo_cluster.cluster_labels == 1
virgo_cluster.cluster_labels[is_label_one] = 2
# virgo_cluster.cluster_labels[virgo_cluster.cluster_labels == 1][0:100] = 4

# Both plots share these options, including the same gif_title.
plot_options = dict(
    cmap_vmax=4,
    n_step=5,
    store_gif=False,
    gif_title="nystroem_separation_kernelspace",
)
virgo_cluster.plot_cluster(plot_kernel_space=True, **plot_options)
virgo_cluster.plot_cluster(**plot_options)

In [None]:
# Run the LowDensityCleaner on virgo_cluster again (argument 1e-10), then
# print the labels with counts and plot.
d_cleaner = LowDensityCleaner(
    virgo_cluster,
    1e-10,
)
d_cleaner.clean()

cleaned_labels = virgo_cluster.get_labels(return_counts=True)
print(cleaned_labels)
virgo_cluster.plot_cluster(n_step=5)

In [None]:
# Build a second VirgoCluster without a file and fill it with the rows of
# virgo_cluster.cluster whose label is non-negative.
# NOTE(review): negative labels presumably mark unassigned or removed points - confirm.
vc_2 = VirgoCluster(file_name=None)

keep_mask = virgo_cluster.cluster_labels >= 0
vc_2.data = virgo_cluster.cluster[keep_mask]

vc_2.scale_data()
vc_2.print_datastats()

In [None]:
# FoF on vc_2 with use_scaled_data=False.
vc_2.run_fof(
    linking_length=10000.,
    min_group_size=200,
    use_scaled_data=False,
)

labels, counts = vc_2.get_labels(return_counts=True)
print(len(labels))
print(labels, counts)

# Sum of all counts except the first entry.
print(counts[1:].sum())

vc_2.plot_cluster(n_step=20, store_gif=False, gif_title="fof_kernel")

In [None]:
# Third stage: start from an empty VirgoCluster and copy over the rows of
# vc_2.cluster that have a label >= 0.
vc_3 = VirgoCluster(file_name=None)

labelled_rows = vc_2.cluster_labels >= 0
vc_3.data = vc_2.cluster[labelled_rows]

vc_3.scale_data()
vc_3.print_datastats()

In [None]:
# Apply a VirgoKernel to vc_3 and print the resulting data statistics.
vk_3 = VirgoKernel(vc_3, k_nystroem=1000, pca_comp=5)
vk_3()
vc_3.print_datastats()

# Fit a two-component mixture on vc_3 and report the fit summary.
virgo_mixture = VirgoMixture(vc_3, n_comp=2)
elbo = virgo_mixture.fit()
print(f"ELBO: {elbo}")
print(f"Mixture weights {virgo_mixture.model.weights_}")

# Predict labels without removing uncertain ones, then show labels/counts.
virgo_mixture.predict(remove_uncertain_labels=False)
labels_removed = vc_3.get_labels(return_counts=True)
print(labels_removed)

# Plot with plot_kernel_space=True, then the default view.
shared_plot_options = dict(n_step=25, store_gif=False)
vc_3.plot_cluster(
    plot_kernel_space=True, gif_title="gmm_kernel_kspace", **shared_plot_options
)
vc_3.plot_cluster(gif_title="gmm_kernel", **shared_plot_options)


# vk_3 = VirgoSimpleKernel(vc_3)
# vk_3()
# vc_3.print_datastats()

# virgo_mixture = VirgoMixture(vc_3, n_comp=4)
# elbo = virgo_mixture.fit()

# print(f"ELBO: {elbo}")
# print(f"Mixture weights {virgo_mixture.model.weights_}")

# virgo_mixture.predict(remove_uncertain_labels=True)
# labels_removed = vc_3.get_labels(return_counts=True)
# print(labels_removed)

# vc_3.plot_cluster(n_step=5, plot_kernel_space=True, store_gif=False, gif_title="gmm_kernel_kspace")
# vc_3.plot_cluster(n_step=5, store_gif=False, gif_title="gmm_kernel")

In [None]:
# Run the LowDensityCleaner on vc_3 (argument 1e-8), then report and plot
# the cluster that was just cleaned.
d_cleaner = LowDensityCleaner(vc_3, 1e-8)
d_cleaner.clean()

# Fixed: this used to print virgo_cluster's labels, a copy-paste slip from the
# earlier cleaning cell. The cleaner and the plot below both operate on vc_3,
# so the printed labels must come from vc_3 as well.
print(vc_3.get_labels(return_counts=True))
vc_3.plot_cluster(n_step=25)

In [None]:
# Fourth stage: a new VirgoCluster holding only the rows of vc_3.cluster
# whose label is non-negative.
vc_4 = VirgoCluster(file_name=None)

has_label = vc_3.cluster_labels >= 0
vc_4.data = vc_3.cluster[has_label]

vc_4.scale_data()
vc_4.print_datastats()

In [None]:
# Apply the VirgoSimpleKernel to vc_4 and print the data statistics.
vk_4 = VirgoSimpleKernel(vc_4)
vk_4()
vc_4.print_datastats()

# Fit a six-component mixture on vc_4 and report the fit summary.
virgo_mixture = VirgoMixture(vc_4, n_comp=6)
elbo = virgo_mixture.fit()
print(f"ELBO: {elbo}")
print(f"Mixture weights {virgo_mixture.model.weights_}")

# Predict labels with remove_uncertain_labels=True, then show labels/counts.
virgo_mixture.predict(remove_uncertain_labels=True)
labels_removed = vc_4.get_labels(return_counts=True)
print(labels_removed)

# Plot with plot_kernel_space=True, then the default view.
vc_4.plot_cluster(
    n_step=5,
    plot_kernel_space=True,
    store_gif=False,
    gif_title="gmm_kernel_kspace",
)
vc_4.plot_cluster(
    n_step=5,
    store_gif=False,
    gif_title="gmm_kernel",
)

In [None]:
# Replace vk_4 with a VirgoKernel (k_nystroem=4000, pca_comp=5) on vc_4.
nystroem_options = dict(k_nystroem=4000, pca_comp=5)
vk_4 = VirgoKernel(vc_4, **nystroem_options)
vk_4()
vc_4.print_datastats()

# Fit an eight-component mixture on vc_4 and report the fit summary.
virgo_mixture = VirgoMixture(vc_4, n_comp=8)
elbo = virgo_mixture.fit()
print(f"ELBO: {elbo}")
print(f"Mixture weights {virgo_mixture.model.weights_}")

# Predict labels with remove_uncertain_labels=True, then show labels/counts.
virgo_mixture.predict(remove_uncertain_labels=True)
labels_removed = vc_4.get_labels(return_counts=True)
print(labels_removed)

# Plot with plot_kernel_space=True, then the default view.
common_plot_options = dict(n_step=5, store_gif=False)
vc_4.plot_cluster(
    plot_kernel_space=True, gif_title="gmm_kernel_kspace", **common_plot_options
)
vc_4.plot_cluster(gif_title="gmm_kernel", **common_plot_options)

In [None]:
# Re-run FoF on vc_2 with a smaller linking length and minimum group size
# than the earlier vc_2 run (80. / 25 instead of 10000. / 200).
vc_2.run_fof(
    linking_length=80.,
    min_group_size=25,
    use_scaled_data=False,
)

labels, counts = vc_2.get_labels(return_counts=True)
print(labels, counts)

# Sum of all counts except the first entry.
print(counts[1:].sum())

# vc_2.plot_cluster(n_step=1, plot_kernel_space=True, store_gif=False, gif_title="fof_kernel_kspace")
vc_2.plot_cluster(
    n_step=1,
    store_gif=False,
    gif_title="fof_kernel",
)

In [None]:
# Rebuild vc_3 from the rows of vc_2.cluster with a label >= 0.
vc_3 = VirgoCluster(file_name=None)

non_negative = vc_2.cluster_labels >= 0
vc_3.data = vc_2.cluster[non_negative]

# Unlike the earlier cells, scale_data is given use_dim=[0, 1, 2] here.
scaled_dims = [0, 1, 2]
vc_3.scale_data(use_dim=scaled_dims)
vc_3.print_datastats()

In [None]:
# Apply a VirgoKernel (k_nystroem=4000, pca_comp=5) to the rebuilt vc_3,
# then print the data statistics.
vk_3_options = dict(k_nystroem=4000, pca_comp=5)
vk_3 = VirgoKernel(vc_3, **vk_3_options)
vk_3()
vc_3.print_datastats()

In [None]:
# rbf 800
# NOTE(review): despite the label above, this run uses use_scaled_data=False.
vc_3.run_fof(
    linking_length=50.,
    min_group_size=200,
    use_scaled_data=False,
)

labels, counts = vc_3.get_labels(return_counts=True)
print(len(labels))
print(labels, counts)

vc_3.plot_cluster(
    n_step=1,
    store_gif=False,
    gif_title="fof_kernel",
)

In [None]:
# rbf 800: FoF on vc_3 with use_scaled_data=True.
scaled_fof_settings = dict(
    linking_length=0.036,
    min_group_size=200,
    use_scaled_data=True,
)
vc_3.run_fof(**scaled_fof_settings)

labels, counts = vc_3.get_labels(return_counts=True)
print(len(labels))
print(labels, counts)

# Plot with plot_kernel_space=True, then the default view.
vc_3.plot_cluster(
    n_step=1,
    plot_kernel_space=True,
    store_gif=False,
    gif_title="fof_kernel_kspace",
)
vc_3.plot_cluster(n_step=1, store_gif=False, gif_title="fof_kernel")

In [None]:
# Fit a six-component mixture on vc_3 and report the fit summary.
virgo_mixture = VirgoMixture(vc_3, n_comp=6)
elbo = virgo_mixture.fit()
print(f"ELBO: {elbo}")
print(f"Mixture weights {virgo_mixture.model.weights_}")

# Predict labels with remove_uncertain_labels=True, then show labels/counts.
virgo_mixture.predict(remove_uncertain_labels=True)
labels_removed = vc_3.get_labels(return_counts=True)
print(labels_removed)

# Plot with plot_kernel_space=True, then the default view.
gmm_plot_options = dict(n_step=1, store_gif=False)
vc_3.plot_cluster(
    plot_kernel_space=True, gif_title="gmm_kernel_kspace", **gmm_plot_options
)
vc_3.plot_cluster(gif_title="gmm_kernel", **gmm_plot_options)

In [None]:
# Cluster vc_3 with VirgoClustering. No clustering_type is passed, so the
# class default applies (not visible from this notebook).
virgo_clustering = VirgoClustering(vc_3, min_samples=100)
virgo_clustering.predict()

# Post-process the labels: remove_small_groups with remove_thresh=200,
# then sort_labels.
vc_3.remove_small_groups(remove_thresh=200)
vc_3.sort_labels()

labels_removed = vc_3.get_labels(return_counts=True)
print(labels_removed)

# Plot with plot_kernel_space=True, then the default view.
vc_3.plot_cluster(
    n_step=1,
    plot_kernel_space=True,
)
vc_3.plot_cluster(n_step=1)

In [None]:
# virgo_clustering = VirgoClustering(vc_3, n_clusters=8, clustering_type="agglo")
# virgo_clustering.predict()
# labels_removed = vc_3.get_labels(return_counts=True)
# print(labels_removed)

# vc_3.plot_cluster(n_step=1, plot_kernel_space=True)
# vc_3.plot_cluster(n_step=1)

In [None]:
# virgo_clustering = VirgoClustering(virgo_cluster, n_clusters=10, clustering_type="spectral")
# virgo_clustering.predict()
# labels_removed = virgo_cluster.get_labels(return_counts=True)
# print(labels_removed)

# virgo_cluster.plot_cluster(n_step=5, plot_kernel_space=True)
# virgo_cluster.plot_cluster(n_step=5)

In [None]:
# Cluster vc_3 with clustering_type="dbscan" and min_samples=10.
dbscan_options = dict(min_samples=10, clustering_type="dbscan")
virgo_clustering = VirgoClustering(vc_3, **dbscan_options)
virgo_clustering.predict()

# Post-process the labels: remove_small_groups with remove_thresh=200,
# then sort_labels.
vc_3.remove_small_groups(remove_thresh=200)
vc_3.sort_labels()

labels_removed = vc_3.get_labels(return_counts=True)
print(labels_removed)

# Plot with plot_kernel_space=True, then the default view.
vc_3.plot_cluster(n_step=1, plot_kernel_space=True)
vc_3.plot_cluster(n_step=1)

In [None]:
# import matplotlib.pyplot as plt

In [None]:
# elbos = []
# bics = []
# for i in range(2, 45, 3):
#     virgo_mixture = VirgoMixture(virgo_cluster, n_comp=i, mixture_type="gaussian")
#     elbo = virgo_mixture.fit()
#     elbos.append(elbo)
#     bic = virgo_mixture.model.bic(virgo_cluster.scaled_data)
#     bics.append(bic)
#     print(i, elbo, bic)

# print(elbos)
# print(bics)

In [None]:
# plt.plot(elbos)
# plt.show()   

In [None]:
# plt.plot(bics)
# plt.show()  