In [None]:
import plotly.express as px
import pandas as pd
import numpy as np
import torch
import seaborn as sns
from rindti.losses import GeneralisedLiftedStructureLoss, SoftNearestNeighborLoss
from sklearn.manifold import TSNE

import plotly.graph_objs as go

import numpy as np
from ipywidgets import interact
from sklearn.decomposition import PCA
from umap import UMAP

In [None]:
def compress(dist, scale):
    """Shrink points by a factor that grows with their distance from the sample mean.

    Row ``i`` of ``dist`` is divided by ``1 + scale * ||dist_i - mean(dist)||``,
    so points far from the mean are shrunk more than points close to it.

    NOTE(review): the division is applied to the raw coordinates, so points move
    towards the origin rather than towards the mean. The two coincide only when
    the mean is (close to) zero, e.g. for samples drawn from ``torch.normal(0, 1)``.

    Args:
        dist: 2-D tensor of points, one point per row.
        scale: compression strength; ``0`` returns the points unchanged.

    Returns:
        Tensor with the same shape as ``dist``.
    """
    offsets = dist - dist.mean(dim=0)
    radius = offsets.norm(dim=1, keepdim=True)
    shrink = 1 + scale * radius
    return dist / shrink

class Plotter:
    """Build two shifted, compressed point clouds and score them with a loss.

    A single base sample ``self.dist`` is drawn once in ``__init__``. Every call
    to :meth:`get_data` reuses it, so varying ``distance``/``scale`` changes only
    the geometry, not the random draw.

    Args:
        num: number of points per cluster (both clusters have ``num`` points).
        dims: dimensionality of each point.
        loss: loss class. It is instantiated inside :meth:`get_data` with the
            extra ``*args``/``**kwargs`` given there. Its ``forward(points, labels)``
            result must hold one value per point, because it is stored row-wise
            in the returned DataFrame.
        seed: optional seed for the base sample. ``None`` (default) keeps the
            unseeded behaviour.
    """

    def __init__(self, num=100, dims=2, loss=GeneralisedLiftedStructureLoss, seed=None):
        self.num = num
        self.dims = dims
        self.loss = loss
        # A dedicated generator gives a reproducible sample without touching
        # torch's global RNG state.
        generator = None if seed is None else torch.Generator().manual_seed(seed)
        self.dist = torch.normal(0, 1, size=(num, dims), generator=generator)

    def get_data(self, distance, scale, *args, **kwargs):
        """Return both clusters as a 2-D DataFrame annotated with per-point losses.

        The base sample is compressed with ``compress(self.dist, scale)``. The
        result is then shifted by ``+distance`` and ``-distance`` along the
        all-ones vector, giving two clusters of ``self.num`` points each.

        Args:
            distance: shift added to every coordinate (with opposite signs for
                the two clusters).
            scale: compression strength forwarded to ``compress``.
            *args, **kwargs: forwarded to the loss class constructor.

        Returns:
            DataFrame with one row per point and these columns:
            ``x``, ``y``: coordinates. They are PCA-projected when
            ``dims > 2``, and ``y`` is zero-padded when ``dims == 1``.
            ``distance``, ``scale``: the inputs.
            ``loss``: the per-point loss.
            ``symbol``: the cluster label.
        """
        # Compress once and reuse: both clusters are shifted copies of one cloud.
        compressed = compress(self.dist, scale)
        offset = torch.ones(self.dims) * distance
        points = torch.cat((compressed + offset, compressed - offset))
        # Labels 0 and 3 are presumably chosen because the interactive cell uses
        # them directly as plotly marker symbols (0 = circle, 3 = cross).
        fam_idx = torch.tensor([0] * self.num + [3] * self.num).view(-1, 1)
        loss = self.loss(*args, **kwargs)
        losses = loss.forward(points, fam_idx).tolist()

        coords = points.numpy()
        if coords.shape[1] > 2:
            coords = PCA(n_components=2).fit_transform(coords)
        elif coords.shape[1] == 1:
            # Give 1-D data a zero y-coordinate so it still fits the x/y columns.
            coords = np.column_stack((coords, np.zeros(len(coords))))

        frame = pd.DataFrame(coords, columns=["x", "y"])
        frame["distance"] = distance
        frame["scale"] = scale
        frame["loss"] = losses
        frame["symbol"] = fam_idx.view(-1).tolist()
        return frame

    def get_loss(self, *args, **kwargs):
        """Return the mean per-point loss; accepts the same arguments as :meth:`get_data`."""
        return self.get_data(*args, **kwargs)["loss"].mean()

In [None]:
# Interactive view of two 2-D clusters scored with SoftNearestNeighborLoss.
# The names below are specific to this cell. The slider callback looks up its
# globals every time it runs, so generic names like ``fig``/``p`` would be
# silently replaced when a later cell rebinds them.
snn_widget = go.FigureWidget()
snn_widget.add_trace(go.Scatter())
snn_widget.update_layout(width=1000, height=600)
snn_widget.update_traces(marker=dict(size=12), mode="markers")
# ``add_trace`` returns the figure itself; take the trace object explicitly.
snn_trace = snn_widget.data[0]
snn_plotter = Plotter(num=100, dims=2, loss=SoftNearestNeighborLoss)


@interact(distance=(0, 1, 0.1), scale=(0, 1.0, 0.1), temperature=(-3, 3, 1))
def update(distance=0, scale=0, temperature=0, optim_temperature=False):
    """Redraw the clusters for the current slider values.

    Args:
        distance: cluster shift forwarded to ``Plotter.get_data``.
        scale: compression strength forwarded to ``Plotter.get_data``.
        temperature: base-10 exponent; the loss is built with ``10**temperature``.
        optim_temperature: forwarded unchanged to the loss constructor.
    """
    with snn_widget.batch_update():
        data = snn_plotter.get_data(
            distance, scale, temperature=10**temperature, optim_temperature=optim_temperature
        )
        snn_trace.x = data["x"]
        snn_trace.y = data["y"]
        snn_trace.marker.color = data["loss"] + 1
        snn_trace.marker.symbol = data["symbol"]
        snn_trace.hovertext = data["loss"]
        snn_widget.layout.title = "Loss = {}".format(data["loss"].mean())


snn_widget

In [None]:
# Sweep the distance between the two clusters and their compression, recording
# the mean SoftNearestNeighborLoss for every (distance, scale) pair.
# The names are specific to this cell, so re-running it cannot rebind generic
# globals that other cells or widget callbacks may read.
GRID_STEPS = 25  # resolution of the distance/scale grid (both span 0..1)
sweep_plotter = Plotter(dims=2, num=25, loss=SoftNearestNeighborLoss)
sweep_records = [
    {"distance": distance, "scale": scale, "loss": sweep_plotter.get_loss(distance, scale)}
    for distance in np.linspace(0, 1, GRID_STEPS)
    for scale in np.linspace(0, 1, GRID_STEPS)
]
loss_sweep = pd.DataFrame(sweep_records, columns=["distance", "scale", "loss"])

sweep_fig = px.scatter_3d(
    data_frame=loss_sweep, x="distance", y="scale", z="loss", color="loss", width=1000, height=1000
)
sweep_fig.update_coloraxes(showscale=False)
sweep_fig.update_layout(title="How distance between 2 clusters and their std affect SNN Loss")
# Uncomment to export an interactive HTML copy:
# sweep_fig.write_html("scale_dist_loss_2d.html")
sweep_fig.show()

In [None]:
n = 256
t = torch.normal(0,1, size=(n, 32))
fam_idx = torch.tensor([0] * (n//2) + [1] * (n//2))

loss = SoftNearestNeighborLoss(optim_temperature=False)

%%profile
loss(t, fam_idx)