# Virgo scaling to full datasets with SV-DKL
This notebook is set up to run on Google Colab Pro and needs the labeled subset produced by the Virgo denoising and labeling notebook.

In [None]:
%%capture
!pip install gpytorch
!pip install pyfof

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from sklearn.preprocessing import StandardScaler
import gpytorch
from torch.utils.data import TensorDataset, DataLoader
import os

from google.colab import drive
drive.mount('/content/drive')
os.chdir("/content/drive/MyDrive/virgo/")

torch.manual_seed(2022)
np.random.seed(2022)

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from virgo.data.cluster import VirgoCluster
from virgo.data.cleaner import AutoDensityCleaner
from virgo.models.kernel import VirgoKernel
from virgo.models.mixture import VirgoMixture
from virgo.models.dklmodel import DKLModel
from virgo.models.dkltrainer import DKLTrainer

### Import labeled subset of data set and pre-processing

Please adjust the import path to match your setup.

In [None]:
# Snapshot id used in the file names of the labeled subset.
snap_id_dkl = 750
cluster_file = f'./vc_fitted_{snap_id_dkl}_cluster.txt'
labels_file = f'./vc_fitted_{snap_id_dkl}_cluster_labels.txt'

# Keep the first seven columns of the cluster table; labels live in a separate file.
keep_cols = list(range(7))
all_clusters = np.loadtxt(cluster_file)[:, keep_cols]
all_labs = np.loadtxt(labels_file)
print(all_clusters.shape, all_labs.shape)

In [None]:
# The data should be shuffled already; it is shuffled once more here.
# Labels are appended as the last column so that points and labels are permuted together.
stacked_rows = np.array([*all_clusters.T, all_labs])
all_data = stacked_rows.T
np.random.shuffle(all_data)  # in-place shuffle along the first axis (rows)

# Split the shuffled table back into feature columns and the label column.
all_clusters, all_labs = all_data[:, :-1], all_data[:, -1].T
# NOTE: this is a second name for the same array, not an independent copy.
all_labs_cp = all_labs

In [None]:
# Use only spatial points and shock normal (i.e. no Mach number and no HSML length)
use_dim = [0, 1, 2, 3, 4, 5]
n_dim = len(use_dim)

# Rows labeled -1 are left out when counting the classes.
labeled_mask = all_labs_cp != -1.
n_classes = np.unique(all_labs_cp[labeled_mask]).shape[0]

# Standardize the selected columns. The fitted `scaler` is reused later on the raw data set.
# NOTE: `all_clusters` is overwritten with its scaled version, so re-running this cell
# would re-fit the scaler on already scaled data — re-run from the loading cell instead.
selected = all_clusters[:, use_dim]
scaler = StandardScaler().fit(selected)
all_clusters = scaler.transform(selected)

print(n_dim, n_classes)
print(all_clusters.shape, all_labs.shape)

In [None]:
# Only rows with a label other than -1 take part in training / validation / testing.
has_label = all_labs_cp != -1.
train_x_np = all_clusters[has_label]
train_y_np = all_labs_cp[has_label]


def _as_float_tensor(array):
    """Copy a NumPy array into a float32 torch tensor."""
    return torch.tensor(array, dtype=torch.float32)


# First 80 % of the (already shuffled) rows form the training set ...
n_train = int(train_x_np.shape[0] * 0.8)
train_x = _as_float_tensor(train_x_np[:n_train])
train_y = _as_float_tensor(train_y_np[:n_train])

# ... and the remaining rows are halved into validation (first half) and test (second half).
holdout_x = _as_float_tensor(train_x_np[n_train:])
holdout_y = _as_float_tensor(train_y_np[n_train:])
n_val = int(holdout_x.shape[0] * 0.5)
val_x, test_x = holdout_x[:n_val], holdout_x[n_val:]
val_y, test_y = holdout_y[:n_val], holdout_y[n_val:]

train_dataset = TensorDataset(train_x, train_y)
val_dataset = TensorDataset(val_x, val_y)
test_dataset = TensorDataset(test_x, test_y)

# All three loaders use the same batch size and shuffle their data.
loader_kwargs = dict(batch_size=1024, shuffle=True)
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

print(train_x.shape, train_y.shape, val_x.shape, val_y.shape, test_x.shape, test_y.shape)

### Create model and trainer

In [None]:
# Build the DKL model and a softmax likelihood whose input width matches the
# model's `num_feat` attribute and whose output has one entry per class.
model = DKLModel()
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=model.num_feat, num_classes=n_classes)

# Move both modules to the GPU when one is available.
if torch.cuda.is_available():
    model, likelihood = model.cuda(), likelihood.cuda()

### Train and test

In [None]:
# A fresh model and likelihood are constructed here (same construction as in the cell
# above), so every run of this cell starts training from newly created modules.
model = DKLModel()
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=model.num_feat, num_classes=n_classes)

if torch.cuda.is_available():
    model, likelihood = model.cuda(), likelihood.cuda()

# The trainer receives the modules together with the three data loaders.
trainer = DKLTrainer(model=model, likelihood=likelihood, train_loader=train_loader,
                     val_loader=val_loader, test_loader=test_loader)

trainer.train()

# Print whatever trainer.test() returns (stored as test_acc).
test_acc = trainer.test()
print(test_acc)

### Classification results

In [None]:
# Plot full training data set
# 3-D scatter of the first three standardized feature columns, colored by the stored label
# (rows labeled -1 are included as well).
plot_data, plot_y = all_clusters, all_labs
xs, ys, zs = plot_data[:, 0], plot_data[:, 1], plot_data[:, 2]

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')
fig.suptitle("Full training data set")
ax.scatter(xs, ys, zs, c=plot_y, cmap="plasma", marker=".")

plt.show()

In [None]:
# Plot predictions
fig = plt.figure(figsize=(12, 12))
fig.suptitle("Model predictions of full training data set")
ax = fig.add_subplot(projection='3d')

eval_data = torch.tensor(all_clusters, dtype=torch.float32)
if torch.cuda.is_available():
    eval_data = eval_data.cuda()

# Pure inference: force evaluation mode and switch off autograd so that no graph is
# recorded for the whole data set (saves memory; no gradients are needed here).
model.eval()
likelihood.eval()
with torch.no_grad():
    output = likelihood(model(eval_data))
    # Average the class probabilities over the leading dimension of `probs`,
    # then take the most probable class per point.
    pred_labels = output.probs.mean(0).argmax(-1)

plot_y = pred_labels.cpu().numpy()
plot_data = eval_data.cpu().numpy()

ax.scatter(plot_data.T[0], plot_data.T[1], plot_data.T[2], c=plot_y, marker=".", cmap="plasma")
plt.show()

## Apply on raw data set for full reconstruction of resolution

Please adjust the import path for your setup. This section is almost identical to the other demo notebook, as we are using the pre-cleaned data set as input.

In [None]:
snap_id = 750

# The snapshot is looked up under ./data relative to the current working directory.
cdir = os.getcwd()
filebase = f"{cdir}/data/snap_{snap_id}"

# Loader options are passed through to VirgoCluster unchanged.
cluster_options = dict(io_mode=1, cut_mach_dim=-2, n_max_data=800000)
virgo_cluster = VirgoCluster(file_name=filebase, **cluster_options)

# Scale the data, then print its statistics and plot histograms with 100 bins.
virgo_cluster.scale_data()
virgo_cluster.print_datastats()
virgo_cluster.plot_raw_hists(bins=100)

In [None]:
# Construct a VirgoKernel for the loaded cluster object and invoke it.
# NOTE(review): the argument names suggest a Nystroem approximation (k_nystroem=100)
# combined with a 5-component PCA — confirm against virgo.models.kernel.
virgo_kernel = VirgoKernel(virgo_cluster, k_nystroem=100, pca_comp=5)
virgo_kernel()
# Print the data statistics again after the kernel call.
virgo_cluster.print_datastats()

In [None]:
# Fit a VirgoMixture with two components on the cluster object.
virgo_mixture = VirgoMixture(virgo_cluster, n_comp=2)
# The return value of fit() is kept in `elbo`; it is not read again in the visible cells.
elbo = virgo_mixture.fit()

# Predict labels while keeping uncertain ones (remove_uncertain_labels=False).
virgo_mixture.predict(remove_uncertain_labels=False)
# Print the labels currently stored on the cluster, together with their counts.
labels_removed = virgo_cluster.get_labels(return_counts=True)
print(labels_removed)

virgo_cluster.plot_cluster(cmap_vmax=4, n_step=25)

In [None]:
# Run the AutoDensityCleaner on the cluster object, then plot the result.
# NOTE(review): the next cell keeps only points with cluster_labels >= 0, so this step
# presumably marks removed points with a negative label — confirm in virgo.data.cleaner.
d_cleaner = AutoDensityCleaner(virgo_cluster)
d_cleaner.clean()
virgo_cluster.plot_cluster(n_step=10)

In [None]:
# Keep the points whose cluster label is non-negative and standardize their feature
# columns with the scaler that was fitted on the labeled subset.
# NOTE(review): columns 1-6 are used here, whereas the labeled text file used columns 0-5;
# presumably `virgo_cluster.cluster` carries one extra leading column — confirm that both
# selections refer to the same six features.
kept_points = virgo_cluster.cluster_labels >= 0
feature_cols = [1, 2, 3, 4, 5, 6]
eval_data = scaler.transform(virgo_cluster.cluster[kept_points][:, feature_cols])

In [None]:
# Quick look at the data to be classified: every 10th point, first three feature columns.
fig = plt.figure(figsize=(8, 8))
fig.suptitle("Cleaned raw data set (every 10th point)")
ax = fig.add_subplot(projection='3d')
# No `c=` values are given, so a colormap would be ignored (matplotlib emits a warning
# in that case); the points are drawn in the default color instead.
ax.scatter(eval_data.T[0][::10], eval_data.T[1][::10], eval_data.T[2][::10], marker=".")
plt.show()

### Final predictions and full scalability

In [None]:
eval_data = torch.tensor(eval_data, dtype=torch.float32)
if torch.cuda.is_available():
    eval_data = eval_data.cuda()

# Pure inference on the full data set: force evaluation mode and disable autograd so
# that no computation graph is kept for every point.
model.eval()
likelihood.eval()
with torch.no_grad():
    output = likelihood(model(eval_data))
    # Average the class probabilities over the leading dimension of `probs`,
    # then take the most probable class per point.
    mean = output.probs.mean(0).argmax(-1).cpu()

plot_y = mean.numpy()
plot_data = eval_data.cpu().numpy()

# Summary of the predictions; identical for every view, so printed once.
print(mean.min(), mean.max())
print(plot_data.shape, mean.sum())

# Plot every `plot_steps`-th point; the subsampling does not depend on the view angle.
plot_steps = 2
xs = plot_data.T[0][::plot_steps]
ys = plot_data.T[1][::plot_steps]
zs = plot_data.T[2][::plot_steps]
colors = plot_y[::plot_steps]

# Same scatter plot from seven azimuth angles at a fixed elevation of 30 degrees.
for azim in [0., 45., 90., 135., 180., 225., 270.]:
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')
    ax.scatter(xs, ys, zs, c=colors, marker=".", s=0.75, cmap="plasma")
    # view_init replaces direct writes to ax.azim / ax.elev; the former `ax.dist = 10`
    # was the default value and the attribute is deprecated in recent matplotlib.
    ax.view_init(elev=30, azim=azim)
    plt.show()