In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from virgo.data.cluster import VirgoCluster

%load_ext autoreload
%autoreload 2

%matplotlib notebook

In [None]:
# --- Animation settings (read as module-level globals by the functions below) ---
store_gif = True          # True: render all frames to a GIF; False: draw a single static frame
remove_uncertain = True   # drop points whose cluster label is negative
n_step = 1                # keep every n_step-th point (1 = keep all points)
cluster_label = None      # iterable of labels to show, or None to show every label
gif_title = None          # output file stem for the GIF; None falls back to "test"
axs_label = [
    "x [c kpc / h]",
    "y [c kpc / h]",
    "z [c kpc / h]",
]
maker_size = 6.0          # scatter marker size, passed as `s=`

# One 8x8-inch figure holding a single 3D axes that every frame redraws into.
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw={"projection": "3d"})

def get_plot_data(
    snap_id: int,
    file_prefix: str = "/home/max/Software/virgo/demo_notebooks/vc_methodD_",
):
    """Load the clustered points of one snapshot and filter them for plotting.

    Reads ``{file_prefix}{snap_id}_cluster.txt`` (the point array) and
    ``{file_prefix}{snap_id}_cluster_labels.txt`` (one label per point) with
    ``np.loadtxt``, then applies the notebook-level settings:

    * ``n_step``: keep every ``n_step``-th point.
    * ``remove_uncertain``: drop points whose label is negative.
    * ``cluster_label``: if not None, keep only points whose label is listed.
      Points are grouped in the order the labels are listed. An empty
      selection yields empty arrays.

    Args:
        snap_id: Snapshot number; cast with ``int`` before being inserted into
            the file names.
        file_prefix: Path prefix shared by the per-snapshot files. The default
            is the original author's local directory; override it for other data.

    Returns:
        ``(plot_data, plot_label)``: the selected rows of the cluster array and
        their matching labels.
    """
    snap = int(snap_id)
    virgo_cluster = VirgoCluster(None)
    virgo_cluster.cluster = np.loadtxt(f"{file_prefix}{snap}_cluster.txt")
    virgo_cluster.cluster_labels = np.loadtxt(f"{file_prefix}{snap}_cluster_labels.txt")
    plot_data = virgo_cluster.cluster[::n_step]
    plot_label = virgo_cluster.cluster_labels[::n_step]

    if remove_uncertain:
        # Negative labels are treated as "uncertain" (per the flag's name);
        # this mask keeps the certain, non-negative ones.
        certain_mask = plot_label >= 0
        plot_data = plot_data[certain_mask]
        plot_label = plot_label[certain_mask]

    if cluster_label is not None:
        # Gather indices label by label so the output stays grouped in the
        # order given by ``cluster_label`` (same order as the previous
        # concatenate loop). Concatenate once, and handle an empty selection
        # explicitly instead of leaving the result variable unbound.
        per_label_idx = [np.flatnonzero(plot_label == target) for target in cluster_label]
        keep_idx = np.concatenate(per_label_idx) if per_label_idx else np.array([], dtype=int)
        plot_data = plot_data[keep_idx]
        plot_label = plot_label[keep_idx]

    return plot_data, plot_label



def _snap_for_frame(i):
    """Map an animation frame index to the snapshot number drawn in that frame.

    Indices up to 30 (including negative ones) map to 750. Each further block
    of 30 frames advances by 10, capped at 820 (reached for ``i > 210``).
    """
    thresholds = (30, 60, 90, 120, 150, 180, 210)
    return 750 + 10 * sum(i > t for t in thresholds)


def animate(i=-60):
    """Draw frame ``i``: scatter one snapshot's points and rotate the camera.

    Called by ``FuncAnimation`` with the frame index, or directly with the
    default for a single static view.

    Args:
        i: Frame index. It selects the snapshot via ``_snap_for_frame`` and
            sets the camera azimuth to ``-60 + 0.33 * i`` degrees.
            NOTE(review): the default ``i=-60`` gives azim = -79.8 deg, not
            matplotlib's default -60 deg. Confirm this default was intended.

    Returns:
        ``(fig,)``: the module-level figure, as a 1-tuple.
    """
    snap = _snap_for_frame(i)
    plot_data, plot_label = get_plot_data(snap)

    # ax.cla() resets titles, labels and limits, so they are re-applied on
    # every frame. The limits are fixed, presumably copied from the
    # auto-scaled limits of an earlier run, so the view does not jump between
    # snapshots.
    ax.cla()
    ax.title.set_text(f"Snap {snap}")
    ax.set(xlabel=axs_label[0], ylabel=axs_label[1], zlabel=axs_label[2])
    ax.set_xlim((1774.3607604980468, 7015.55903930664))
    ax.set_ylim((-4975.025365447998, 378.096150970459))
    ax.set_zlim((-112.7749122619629, 4446.502069854736))

    ax.scatter(
        plot_data.T[0],
        plot_data.T[1],
        plot_data.T[2],
        c=plot_label,
        marker=".",
        cmap="plasma",
        s=maker_size,
    )
    # Azimuth grows by 0.33 deg per frame. Over frames 0..239 it sweeps from
    # -60 to about 19 deg (~79 deg of rotation), not a full 360-degree turn.
    ax.view_init(elev=10, azim=(-60 + i * 0.33))
    return (fig,)

def init():
    """Init hook for ``FuncAnimation`` that intentionally draws nothing.

    Returns:
        None. No artists are prepared up front, because every frame is drawn
        from scratch by ``animate``.
    """
    return None


if store_gif:
    # Render all frames and write them to a GIF.
    # blit=False: every frame calls ax.cla() and redraws the whole 3D axes, so
    # blitting gains nothing. FuncAnimation also raises a RuntimeError on
    # save() when blit=True and init_func returns None.
    ani = animation.FuncAnimation(
        fig, animate, init_func=init, frames=240, interval=100, blit=False
    )
    file_name = "test" if gif_title is None else gif_title
    # The GIF plays at 15 fps, faster than the 100 ms interval (10 fps) used
    # for the live preview.
    ani.save(file_name + ".gif", writer="imagemagick", fps=15)
else:
    # Static preview: draw a single frame using animate's default index.
    init()
    animate()

plt.show()