## Xplore

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
import numpy as np
from torchvision.datasets import STL10
import pyrootutils  
from sklearn.decomposition import PCA
import pandas as pd
import seaborn as sns

# Register the project root with pyrootutils before any project-local import.
# `pythonpath=True` presumably puts the root on sys.path so that `toyxp` can be
# imported afterwards — verify against the pyrootutils documentation.
# NOTE(review): the path is absolute and user-specific; the notebook will not
# run on another machine without editing it.
root = pyrootutils.setup_root(
    "/home/jmordacq/Documents/IRBA/dev/sim-RIPS/", # path to the root directory
    project_root_env_var=True,
    pythonpath=True)

from toyxp.data_utils import Augment

#### utils

In [None]:
def imshow(img):
    """
    Show an ImageNet-normalized image tensor on the screen.

    The normalization is undone with the standard ImageNet per-channel mean
    and standard deviation before the image is handed to matplotlib.

    Parameters
    ----------
    img : torch.Tensor
        Channel-first image with 3 channels (it is transposed to
        height x width x channel for display). It may track gradients or live
        on an accelerator; a detached CPU copy is used for plotting.
    """
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
    # Inverting (x - mean) / std is itself a Normalize with
    # mean' = -mean / std and std' = 1 / std.
    unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    # .numpy() raises on tensors that require grad or are not on the CPU.
    npimg = unnormalize(img.detach().cpu()).numpy()
    # Round-off can leave values marginally outside [0, 1]; clipping here
    # yields the same picture without matplotlib's clipping warning.
    plt.imshow(np.clip(np.transpose(npimg, (1, 2, 0)), 0.0, 1.0))
    plt.show()


### Datasets

In [None]:
def flatten_samples(ds):
    """
    Read every sample of ``ds`` exactly once.

    Each ``ds[i]`` access re-runs the dataset transform, so the image and the
    label are taken from the same access instead of indexing twice.

    Returns
    -------
    (list, list)
        Flattened images as 1-D numpy arrays, and the matching labels.
    """
    flat_images, sample_labels = [], []
    for i in range(len(ds)):
        img, lbl = ds[i]
        flat_images.append(img.numpy().flatten())
        sample_labels.append(lbl)
    return flat_images, sample_labels


# NOTE(review): absolute, user-specific dataset location; consider a config constant.
dataset = STL10(root='/home/jmordacq/Documents/IRBA/misc/datasets', split='train', transform=Augment().test_transform, download=True)
images, labels = flatten_samples(dataset)

# One row per training image.
images_tensor = np.vstack(images)
images_tensor = torch.tensor(images_tensor, dtype=torch.float32)

test = STL10(root='/home/jmordacq/Documents/IRBA/misc/datasets', split='test', transform=Augment().test_transform, download=True)
images_test, labels_test = flatten_samples(test)

In [None]:
dataset_transform = STL10(root='/home/jmordacq/Documents/IRBA/misc/datasets', split='train', transform=Augment(), download=True)

# The views must be read from `dataset_transform`; indexing the plain `dataset`
# (whose samples are single tensors) would slice channels instead of views.
# Assumes Augment() returns an indexable pair of tensors (view 1, view 2) —
# TODO confirm against toyxp.data_utils.
# Both views and the label are taken from ONE access per sample so that they
# come from the same augmentation call.
images_view1, images_view2, labels = [], [], []
for i in range(len(dataset_transform)):
    views, lbl = dataset_transform[i]
    images_view1.append(views[0].numpy().flatten())
    images_view2.append(views[1].numpy().flatten())
    labels.append(lbl)

images_tensor_view1 = np.vstack(images_view1)
images_tensor_view1 = torch.tensor(images_tensor_view1, dtype=torch.float32)


images_tensor_view2 = np.vstack(images_view2)
# Build from the stacked array (as for view 1), not from the Python list.
images_tensor_view2 = torch.tensor(images_tensor_view2, dtype=torch.float32)

#### pca

In [None]:
# Project the flattened training images onto their first 3 principal components.
# With this many samples/features scikit-learn's default solver choice is the
# randomized one, so the seed is fixed to make the projection reproducible.
pca = PCA(n_components=3, random_state=0)
images_pca = pca.fit_transform(images)

In [None]:
# Pairwise scatter plots of the three PCA coordinates, coloured by class label.
pca_axes = ["x", "y", "z"]
df_pca = pd.DataFrame(images_pca, columns=pca_axes).assign(labels=labels)
sns.pairplot(df_pca, vars=pca_axes, hue="labels", palette="tab10")

In [None]:
# 3-D scatter of the PCA projection, one colour per class label.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.view_init(elev=0, azim=10, roll=0)
pc1, pc2, pc3 = images_pca.T  # the three principal-component coordinates
ax.scatter3D(pc1, pc2, pc3, c=labels, cmap="tab10")


### Persistent structure in the dataset?

#### dataset as is

In [None]:
import torch_topological.nn

In [None]:
# Vietoris-Rips complex configured with dim=1; features whose death is infinite
# are kept (they are drawn separately by the plotting helper).
rips = torch_topological.nn.VietorisRipsComplex(dim=1, keep_infinite_features=True)
# Persistence information for the whole training set (one row per flattened image).
pi = rips(images_tensor)

In [None]:
def plot_diagram(pi, title="Persistent Diagram"):
    """
    Plot the persistence diagrams stored in ``pi`` on one set of axes.

    Finite (birth, death) pairs are drawn as crosses. Pairs whose death is
    infinite are drawn as circles, at their birth value, on a single
    horizontal line placed above every finite point.

    Parameters
    ----------
    pi : sequence
        Indexable collection with one entry per homology dimension. Each
        entry must expose a ``diagram`` tensor with births in column 0 and
        deaths in column 1.
    title : str
        Title of the figure.
    """
    figure = plt.figure()
    ax = figure.add_subplot(111)
    colors = plt.cm.viridis(np.linspace(0, 1, len(pi)))

    # Pass 1: split every diagram into finite / infinite features and find the
    # global extent, so all dimensions share one axis range and one infinity line.
    split = []
    max_birth, max_death = 0.0, 0.0
    for dim in range(len(pi)):
        diag = pi[dim].diagram.detach().cpu().numpy()
        if len(diag) == 0:
            continue
        infinite = np.isinf(diag[:, 1])
        finite_pts = diag[~infinite]
        inf_births = diag[infinite, 0]
        # Births of infinite features count too, otherwise they may fall
        # outside the x-range.
        max_birth = max(max_birth, float(np.max(diag[:, 0])))
        # Guard: a dimension may contain only infinite features.
        if len(finite_pts) > 0:
            max_death = max(max_death, float(np.max(finite_pts[:, 1])))
        split.append((dim, finite_pts, inf_births))

    extent = max(max_birth, max_death)
    # Infinity line sits 10% above the largest finite coordinate.
    inf_level = 1.1 * extent if extent > 0 else 1.0
    # A little headroom so markers on the infinity line are not cut by the frame.
    maxi = 1.05 * inf_level

    # Pass 2: draw. `color=` is used instead of `c=` because a single RGBA row
    # passed as `c` is ambiguous when a diagram has exactly four points.
    for dim, finite_pts, inf_births in split:
        if len(finite_pts) > 0:
            ax.scatter(finite_pts[:, 0],
                       finite_pts[:, 1],
                       color=colors[dim],
                       marker="x",
                       label=f"$H_{dim}$ ")
        if len(inf_births) > 0:
            ax.scatter(inf_births,
                       np.full(len(inf_births), inf_level),
                       marker="o",
                       s=30,
                       color=colors[dim],
                       label=f"$H_{dim}$ - inf")

    ax.plot([0, maxi], [0, maxi], 'k--')
    ax.set_xlim([0, maxi])
    ax.set_ylim([0, maxi])
    ax.set_xlabel("Birth")
    ax.set_ylabel("Death")
    ax.set_title(title)
    ax.legend(loc="lower right")
    plt.show()


In [None]:
# Diagram of the whole training set (default title).
plot_diagram(pi)

In [None]:
# One persistence diagram per class label.
# Convert the labels once; rebuilding the array inside the loop is wasted work.
labels_arr = np.array(labels)
for i in range(10):  # class ids 0..9
    class_idx = np.where(labels_arr == i)
    # Put the class in the title so the ten figures can be told apart.
    plot_diagram(rips(images_tensor[class_idx]), title=f"Persistent Diagram - class {i}")
    print(f"Label {i}: {labels.count(i)}")

### Considering the transforms

In [None]:
# Same Vietoris-Rips configuration as before (dim=1, infinite features kept).
rips = torch_topological.nn.VietorisRipsComplex(dim=1, keep_infinite_features=True)
# Only the first 500 samples of each augmented view are used here.
# NOTE(review): 500 is a magic number — presumably chosen to bound runtime; confirm.
pi_1 = rips(images_tensor_view1[:500])
pi_2 = rips(images_tensor_view2[:500])

In [None]:
# One figure per augmented view, numbered from 1.
for view_no, view_pi in enumerate((pi_1, pi_2), start=1):
    plot_diagram(view_pi, title=f"Persistent diagram - View {view_no}")

In [None]:
# Distance between the persistence information of the two augmented views.
wasserstein = torch_topological.nn.WassersteinDistance()
wasserstein(pi_1, pi_2)

### Considering the transforms per class

In [None]:
# For every class: diagrams of both augmented views, their Wasserstein
# distance, and the number of samples in the class.
labels_arr = np.array(labels)
for i in range(10):
    class_idx = np.where(labels_arr == i)
    pi_1, pi_2 = (
        rips(view[class_idx])
        for view in (images_tensor_view1, images_tensor_view2)
    )
    plot_diagram(pi_1, title=f"Persistent diagram - View 1 - class {i}")
    plot_diagram(pi_2, title=f"Persistent diagram - View 2 - class {i}")

    print(wasserstein(pi_1, pi_2))

    print(f"Label {i}: {labels.count(i)}")