In [None]:
import os

from virgo.cluster import VirgoCluster
from virgo.kernel import VirgoKernel, VirgoSimpleKernel
from virgo.mixture import VirgoMixture, VirgoClustering
from virgo.cleaner import AutoDensityCleaner

# Load IPython's autoreload extension and reload imported modules before each
# cell executes, so edits to the project sources are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the cell output; the commented-out line
# selects the `notebook` backend instead.
# %matplotlib notebook
%matplotlib inline

In [None]:
# Snapshot number to analyse; later cells read `snap_id` again (e.g. for the
# linking-length table lookup).
snap_id = 810

# Root directory holding the simulation snapshots. Defaults to the location
# this notebook was originally written against, but can be redirected with the
# VIRGO_DATA_DIR environment variable so the notebook also runs on machines
# where the data lives elsewhere.
DATA_DIR = os.environ.get("VIRGO_DATA_DIR") or "/home/max/Software/virgo/data"
filebase = f"{DATA_DIR}/250x_hd/snap_{snap_id}"

# Alternative input (MHD run):
# snap_id = 38
# filebase = f"{DATA_DIR}/250x_mhd/250x_mhd_snap_0{snap_id}"

# Build the cluster object from the snapshot file.
# NOTE(review): the meaning of io_mode / cut_mach_dim / n_max_data is defined
# in virgo.cluster.VirgoCluster and is not visible from this notebook.
virgo_cluster = VirgoCluster(
    file_name=filebase,
    io_mode=1,
    cut_mach_dim=-2,
    n_max_data=800000,
)
# virgo_cluster.data = virgo_cluster.data[:, :-1]

# Scale the data, print its statistics and plot raw histograms; three
# [low, high] pairs are passed as the plot range.
virgo_cluster.scale_data()
virgo_cluster.print_datastats()
virgo_cluster.plot_raw_hists(
    bins=100, plot_range=[[2000., 8000.], [-6000., 1000.], [-3000., 6000.]]
)

In [None]:
# Construct the kernel helper on the full cluster and invoke it with its
# default arguments, then print the cluster statistics again.
# NOTE(review): presumably this is a Nystroem-approximated kernel followed by
# PCA, judging by the keyword names -- confirm in virgo.kernel.
virgo_kernel = VirgoKernel(
    virgo_cluster,
    k_nystroem=100,
    pca_comp=5,
)
virgo_kernel()

virgo_cluster.print_datastats()

In [None]:
# Fit a two-component mixture on the cluster; `fit()` returns the value that
# is reported below as the ELBO.
virgo_mixture = VirgoMixture(virgo_cluster, n_comp=2)
elbo = virgo_mixture.fit()

mixture_weights = virgo_mixture.model.weights_
print(f"ELBO: {elbo}")
print(f"Mixture weights {mixture_weights}")

# Assign labels while keeping the uncertain ones, then report the label
# counts as returned by the cluster object.
virgo_mixture.predict(remove_uncertain_labels=False)
labels_removed = virgo_cluster.get_labels(return_counts=True)
print(labels_removed)

# virgo_cluster.cluster_labels[virgo_cluster.cluster_labels == 1] = 2
# virgo_cluster.cluster_labels[virgo_cluster.cluster_labels == 1][0:100] = 4

# virgo_cluster.plot_cluster(cmap_vmax=4, n_step=25, plot_kernel_space=True, store_gif=False, gif_title="nystroem_separation_kernelspace")
virgo_cluster.plot_cluster(
    cmap_vmax=4,
    n_step=25,
    store_gif=False,
    gif_title="nystroem_separation_kernelspace",
)

In [None]:
# Run the automatic density cleaner on the labelled cluster, then show the
# resulting label counts and plot the cluster.
d_cleaner = AutoDensityCleaner(virgo_cluster)
d_cleaner.clean()

label_counts = virgo_cluster.get_labels(return_counts=True)
print(label_counts)

virgo_cluster.plot_cluster(n_step=10)

In [None]:
# Second cluster object without a backing file: it is fed every 10th entry of
# the first cluster whose label is non-negative.
vc_2 = VirgoCluster(file_name=None)

keep_mask = virgo_cluster.cluster_labels >= 0
vc_2.data = virgo_cluster.cluster[keep_mask][::10]

vc_2.scale_data()
vc_2.print_datastats()

In [None]:
# Kernel helper for the sub-sampled cluster, using dimensions 0..5 as
# `spatial_dim`. Note that this rebinds `virgo_kernel` from the earlier cell.
virgo_kernel = VirgoKernel(
    vc_2,
    k_nystroem=500,
    pca_comp=6,
    spatial_dim=list(range(6)),
)

# Invoke the helper with its own `custom_kernel` instead of the default.
virgo_kernel(virgo_kernel.custom_kernel)

vc_2.print_datastats()

In [None]:
# Hand-tuned linking lengths per snapshot id (keys are the snapshot id as a
# string). The commented-out table holds the values used for the full data.
# full data
# ll = {
#     "750": 0.101,
#     "760": 0.14,
#     "770": 0.24,
#     "780": 0.190,
#     "790": 0.125,
#     "800": 0.115,
#     "810": 0.130,
#     "820": 0.13,
# }
#  n_max_data=800000
ll = {
    "750": 0.11,
    "760": 0.14,
    "770": 0.2,
    "780": 0.190,
    "790": 0.15,
    "800": 0.135,
    "810": 0.11,
    "820": 0.13,
}

# Informational only: the lookup result is not passed to run_fof below, so a
# snapshot missing from the table must not abort the cell with a KeyError.
# `.get` prints None for unknown snapshots instead.
print(snap_id, ll.get(str(snap_id)))

# Friends-of-friends grouping on the scaled data; the explicit linking-length
# options are kept commented out, so run_fof uses its own default here.
vc_2.run_fof(
#     linking_length=ll[f"{snap_id}"],
#     linking_length=0.21,
#     linking_length=ll_est,
    min_group_size=700,
    use_scaled_data=True,
)

labels, counts = vc_2.get_labels(return_counts=True)
print(labels, counts)
vc_2.plot_cluster(n_step=1, plot_kernel_space=True)
# NOTE(review): `maker_size` looks like a typo of `marker_size` -- confirm
# against the VirgoCluster.plot_cluster signature before renaming.
vc_2.plot_cluster(n_step=1, maker_size=3.0)
# vc_2.plot_cluster(n_step=1, remove_uncertain=False)
# vc_2.plot_cluster(n_step=1, remove_uncertain=False)

In [None]:
# labels, counts = vc_2.get_labels(return_counts=True)
# vc_2.cluster_labels[vc_2.cluster_labels <0] = labels.shape[0]
# vc_2.plot_cluster(n_step=1)

In [None]:
# vc_2.export_cluster(f"vc_methodD_{snap_id}", remove_uncertain=False, remove_evno=True)

In [None]:
# (snap_id, ll[f"{snap_id}"], vc_2.scaled_data.std(), vc_2.scaled_data.var(), get_avg_nn_dist(vc_2.scaled_data))