## VIRGO: Scaling to full datasets with stochastic variational deep kernel learning Gaussian process


In [None]:
%%capture
!pip install gpytorch
!pip install pyfof

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from sklearn.preprocessing import StandardScaler
import gpytorch
from torch.utils.data import TensorDataset, DataLoader
import os

from google.colab import drive
# Mount Google Drive inside the Colab runtime.
drive.mount('/content/drive')
# Work from the virgo folder on Drive; relative paths used later in the
# notebook (e.g. '../data/...') resolve against this directory.
os.chdir("/content/drive/MyDrive/virgo/")

# Seed the global torch and NumPy random number generators.
torch.manual_seed(2022)
np.random.seed(2022)

%matplotlib inline
%load_ext autoreload
%autoreload 2

# https://arxiv.org/pdf/1511.02222.pdf

In [None]:
from virgo.data.cluster import VirgoCluster
from virgo.data.cleaner import AutoDensityCleaner
from virgo.models.kernel import VirgoKernel
from virgo.models.mixture import VirgoMixture
from virgo.models.dklmodel import DKLModel
from virgo.models.dkltrainer import DKLTrainer

In [None]:
# Text files holding the per-point features and the matching labels.
# To use the box-fitted set instead, point the two names below at
#   '../data/virgo_data/vc_box_fitted_set0_cluster.txt'
#   '../data/virgo_data/vc_box_fitted_set0_cluster_labels.txt'
cluster_file = '../data/virgo_data/vc_fitted_790_cluster.txt'
label_file = '../data/virgo_data/vc_fitted_790_cluster_labels.txt'

# Only the first seven columns of the feature table are kept.
feature_cols = list(range(7))
all_clusters = np.loadtxt(cluster_file)[:, feature_cols]
all_labs = np.loadtxt(label_file)
print(all_clusters.shape, all_labs.shape)

In [None]:
# The data is expected to arrive shuffled already; shuffle once more anyway.
# Features and labels are glued into one table (labels as the last column) so
# that a single in-place row permutation keeps them aligned, then split again.
all_data = np.column_stack((all_clusters, all_labs))
np.random.shuffle(all_data)
all_clusters, all_labs = all_data[:, :-1], all_data[:, -1]
# Second name for the same label array (an alias, not an independent copy).
all_labs_cp = all_labs

In [None]:
# Use only spatial points and shock normal (i.e. no Mach number and no HSML length)
use_dim = list(range(6))
n_dim = len(use_dim)

# Points labelled -1 are left out when counting the classes.
is_labeled = all_labs_cp != -1.
n_classes = len(np.unique(all_labs_cp[is_labeled]))

# Keep the selected columns and standardise them. The fitted `scaler` is kept
# so that other data can be passed through the same transform later on.
all_clusters = all_clusters[:, use_dim]
scaler = StandardScaler().fit(all_clusters)
all_clusters = scaler.transform(all_clusters)

print(n_dim, n_classes)
print(all_clusters.min(), all_clusters.max(), all_clusters.mean())
print(all_clusters.shape, all_labs.shape)
print(all_clusters[:5], all_labs[:5])

In [None]:
def _as_float_tensor(array):
    """Copy a NumPy array into a float32 torch tensor."""
    return torch.tensor(array, dtype=torch.float32)


# Only points whose label is not -1 are used for training and validation.
is_labeled = all_labs_cp != -1.
train_x_np = all_clusters[is_labeled]
train_y_np = all_labs_cp[is_labeled]

# The first 90 % of the rows train the model, the remaining rows validate it.
n_cut = int(train_x_np.shape[0] * 0.9)
train_x = _as_float_tensor(train_x_np[:n_cut])
train_y = _as_float_tensor(train_y_np[:n_cut])
val_x = _as_float_tensor(train_x_np[n_cut:])
val_y = _as_float_tensor(train_y_np[n_cut:])

# Both loaders use the same batch size and reshuffle on every pass.
loader_options = dict(batch_size=1024, shuffle=True)
train_dataset = TensorDataset(train_x, train_y)
val_dataset = TensorDataset(val_x, val_y)
train_loader = DataLoader(train_dataset, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)

print(train_x.shape, train_y.shape, val_x.shape, val_y.shape)

In [None]:
# The model, and a SoftmaxLikelihood sized from the model's `num_feat`
# attribute and the number of classes counted from the labels.
model = DKLModel()
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=model.num_feat, num_classes=n_classes)

# Debugging aids, left disabled:
# summary(model, input_size=torch.rand(1024, n_dim).shape, device="cpu")
# for p in model.named_parameters():
#     print(p)

# Move both modules to the GPU whenever one is available.
if torch.cuda.is_available():
    model, likelihood = model.cuda(), likelihood.cuda()

In [None]:
# Hand the model, the likelihood and both data loaders to the trainer,
# then start training.
trainer_args = dict(
    model=model,
    likelihood=likelihood,
    train_loader=train_loader,
    val_loader=val_loader,
)
trainer = DKLTrainer(**trainer_args)
trainer.train()

In [None]:
# Switch model and likelihood to evaluation mode.
model.eval()
likelihood.eval()

# One panel per input dimension. squeeze=False keeps `axs` two-dimensional,
# so the indexing below also works when n_dim == 1.
fig, axs = plt.subplots(n_dim, 1, figsize=(4, 3 * n_dim), squeeze=False)
with torch.no_grad():
    # Only the first validation batch is plotted (note the `break` below).
    for x_batch, y_batch in trainer.val_loader:
        if torch.cuda.is_available():
            x_batch, y_batch = x_batch.cuda(), y_batch.cuda()

        # Plot only every 10th point of the batch.
        x_batch = x_batch[::10]
        y_batch = y_batch[::10]

        # The likelihood returns class probabilities for several samples of
        # the latent function; how many is governed by
        # gpytorch.settings.num_likelihood_samples, which this notebook never
        # overrides. Average over the sample axis, then pick the most
        # probable class per point.
        output = likelihood(model(x_batch))
        mean = output.probs.mean(0).argmax(-1).cpu()

        # Convert to NumPy once instead of once per panel.
        x_np = x_batch.cpu().numpy()
        y_np = y_batch.cpu().numpy()
        pred_np = mean.numpy()

        for xdim in range(n_dim):
            ax = axs[xdim, 0]

            ax.plot(x_np[:, xdim], pred_np, '*b')
            ax.plot(x_np[:, xdim], y_np, 'xr', alpha=0.99)
            ax.legend([ 'Mean', 'Observed Data'])
            ax.set_title(f'Dim {xdim}')
        break

fig.tight_layout()
plt.show()

In [None]:
# 3-D scatter of the first three columns of `all_clusters`, coloured by the
# labels in `all_labs`.
plot_data = all_clusters
plot_y = all_labs
print(plot_data.shape, plot_y.sum())

fig = plt.figure(figsize=(12, 12))
fig.suptitle("Full training data set")
ax = fig.add_subplot(projection='3d')

xs, ys, zs = plot_data[:, 0], plot_data[:, 1], plot_data[:, 2]
ax.scatter(xs, ys, zs, c=plot_y, marker=".", cmap="plasma")

plt.show()

In [None]:
# Predict a class for every point of the data set and show the result as a
# 3-D scatter of the first three feature columns.

# Make sure both modules are in evaluation mode even if this cell is run on
# its own (calling eval() again is harmless).
model.eval()
likelihood.eval()

eval_data = torch.tensor(all_clusters, dtype=torch.float32)
if torch.cuda.is_available():
    eval_data = eval_data.cuda()

# Pure inference: without no_grad() autograd would record the whole forward
# pass over the full data set and hold on to the memory for nothing.
with torch.no_grad():
    output = likelihood(model(eval_data))
    # Average the class probabilities over the sample axis, then take the
    # most probable class per point.
    mean = output.probs.mean(0).argmax(-1).cpu()

plot_y = mean.numpy()
plot_data = eval_data.cpu().numpy()

print(mean.min(), mean.max())
print(plot_data.shape, mean.sum())

fig = plt.figure(figsize=(12, 12))
fig.suptitle("Model predictions of full training data set")
ax = fig.add_subplot(projection='3d')
ax.scatter(plot_data[:, 0], plot_data[:, 1], plot_data[:, 2], c=plot_y, marker=".", cmap="plasma")
plt.show()

## Apply to the raw data set

In [None]:
# Snapshot number; it is interpolated into the base file name below.
snap_id = 790
filebase = f"../data/virgo_data/250x_hd/snap_{snap_id}"

# Build the cluster object from the snapshot files.
# NOTE(review): the meaning of io_mode, cut_mach_dim and n_max_data is defined
# in virgo.data.cluster and is not visible here; check there before changing.
virgo_cluster = VirgoCluster(
    file_name=filebase, io_mode=1, cut_mach_dim=-2, n_max_data=800000, 
)

# Going by the method names: rescale the data, print its statistics, then
# histogram the raw data with 100 bins over three value ranges
# (presumably one range per spatial axis; confirm in VirgoCluster).
virgo_cluster.scale_data()
virgo_cluster.print_datastats()
virgo_cluster.plot_raw_hists(
    bins=100, plot_range=[[2000., 8000.], [-6000., 1000.], [-3000., 6000.]]
)

In [None]:
# Create the kernel step for this cluster and run it by calling the object.
# NOTE(review): k_nystroem and pca_comp presumably set the Nystroem
# approximation size and the number of PCA components; confirm in
# virgo.models.kernel.
virgo_kernel = VirgoKernel(virgo_cluster, k_nystroem=100, pca_comp=5)
virgo_kernel()
# Print the data statistics again after the kernel step has run.
virgo_cluster.print_datastats()

In [None]:
# Fit a two-component mixture on the cluster; the value returned by fit() is
# reported below as the ELBO.
virgo_mixture = VirgoMixture(virgo_cluster, n_comp=2)
elbo = virgo_mixture.fit()

print(f"ELBO: {elbo}")
print(f"Mixture weights {virgo_mixture.model.weights_}")

# Assign labels from the fitted mixture, here without removing uncertain
# labels, then print what get_labels(return_counts=True) reports.
virgo_mixture.predict(remove_uncertain_labels=False)
labels_removed = virgo_cluster.get_labels(return_counts=True)
print(labels_removed)

# Disabled experiments with manual relabelling:
# virgo_cluster.cluster_labels[virgo_cluster.cluster_labels == 1] = 2
# virgo_cluster.cluster_labels[virgo_cluster.cluster_labels == 1][0:100] = 4
virgo_cluster.plot_cluster(cmap_vmax=4, n_step=25)

In [None]:
# Run the automatic density-based cleaner on the cluster.
# NOTE(review): what clean() changes is defined in virgo.data.cleaner; the
# labels are printed again afterwards, presumably because it relabels points.
d_cleaner = AutoDensityCleaner(virgo_cluster)
d_cleaner.clean()
print(virgo_cluster.get_labels(return_counts=True))
virgo_cluster.plot_cluster(n_step=10)

In [None]:
print(virgo_cluster.data.shape, virgo_cluster.cluster.shape, virgo_cluster.cluster_labels.shape)

# Rows whose cluster label is non-negative.
keep = virgo_cluster.cluster_labels >= 0
print(virgo_cluster.data[keep].shape)

# Take columns 1-6 of the kept rows and pass them through the scaler that was
# fitted on the training features.
# NOTE(review): the training features were columns 0-5 of the text file;
# confirm that columns 1-6 of `virgo_cluster.cluster` hold the same
# quantities in the same order.
eval_data = scaler.transform(virgo_cluster.cluster[keep][:, [1, 2, 3, 4, 5, 6]])

In [None]:
# Quick look at the evaluation points: every 10th row, first three columns.
# No `c=` values are passed, so there is nothing for a colormap to map; the
# former `cmap="plasma"` argument was ignored and only made Matplotlib warn.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')
ax.scatter(eval_data[::10, 0], eval_data[::10, 1], eval_data[::10, 2], marker=".")
plt.show()

In [None]:
# Predict a class for every evaluation point and show the result from several
# viewing angles.

# Make sure both modules are in evaluation mode even if this cell is run on
# its own (calling eval() again is harmless).
model.eval()
likelihood.eval()

eval_data = torch.tensor(eval_data, dtype=torch.float32)
if torch.cuda.is_available():
    eval_data = eval_data.cuda()

# Pure inference: without no_grad() autograd would record the forward pass
# over every evaluation point and keep the memory alive for nothing.
with torch.no_grad():
    output = likelihood(model(eval_data))
    # Average the class probabilities over the sample axis, then take the
    # most probable class per point.
    mean = output.probs.mean(0).argmax(-1).cpu()

# These do not depend on the viewing angle, so convert them once up front.
plot_y = mean.numpy()
plot_data = eval_data.cpu().numpy()

# One figure per azimuth angle, each showing every 5th point.
for azim_deg in [0., 45., 90., 135., 180., 225., 270.]:
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')

    print(mean.min(), mean.max())
    print(plot_data.shape, mean.sum())
    ax.scatter(plot_data[::5, 0], plot_data[::5, 1], plot_data[::5, 2], c=plot_y[::5], marker=".", cmap="plasma")
    ax.azim = azim_deg
    # NOTE(review): newer Matplotlib releases deprecate Axes3D.dist; confirm
    # against the installed version.
    ax.dist = 10
    ax.elev = 30
    plt.show()

In [None]:
# SVGP https://docs.gpytorch.ai/en/stable/examples/04_Variational_and_Approximate_GPs/SVGP_Regression_CUDA.html
# SVGP classification https://docs.gpytorch.ai/en/stable/examples/04_Variational_and_Approximate_GPs/Non_Gaussian_Likelihoods.html
# DKL Multiclass https://docs.gpytorch.ai/en/stable/examples/06_PyTorch_NN_Integration_DKL/Deep_Kernel_Learning_DenseNet_CIFAR_Tutorial.html
# Exact Dirichlet https://docs.gpytorch.ai/en/stable/examples/01_Exact_GPs/GP_Regression_on_Classification_Labels.html?highlight=dirichlet

# https://github.com/cornellius-gp/gpytorch/issues/1396