In [None]:
import plotly.express as px
import pandas as pd
import numpy as np
import torch
import seaborn as sns
from rindti.losses import GeneralisedLiftedStructureLoss, SoftNearestNeighborLoss
from sklearn.manifold import TSNE

In [None]:
class Plotter:
    """Sample two Gaussian clusters centred at ``+mean`` and ``-mean`` and score
    them with ``SoftNearestNeighborLoss``.

    Parameters
    ----------
    num : int
        Points drawn per cluster; every sample has ``2 * num`` rows.
    dims : int
        Dimensionality of each sampled point.
    """

    def __init__(self, num=100, dims=2):
        self.num = num
        self.dims = dims
        # One generator shared by all calls (not explicitly seeded here), so
        # successive calls draw fresh samples.
        self.gen = torch.Generator()

    def _sample(self, mean, std):
        """Draw ``num`` points around ``+mean`` and ``num`` around ``-mean``.

        Returns the stacked ``(2 * num, dims)`` tensor and a ``(2 * num, 1)``
        tensor of family labels: 0 for the first cluster, 1 for the second.
        """
        size = (self.num, self.dims)
        a = torch.normal(mean, std, size=size, generator=self.gen)
        b = torch.normal(-mean, std, size=size, generator=self.gen)
        fam_idx = torch.tensor([0] * self.num + [1] * self.num).view(-1, 1)
        return torch.cat((a, b)), fam_idx

    @staticmethod
    def _embed_2d(points):
        """Return ``points`` as an ``(n, 2)`` numpy array.

        t-SNE is applied only when the *feature* dimension is not already 2;
        the previous check looked at the row count instead.
        """
        if isinstance(points, torch.Tensor):
            points = points.detach().cpu().numpy()
        points = np.asarray(points)
        if points.shape[1] != 2:
            points = TSNE().fit_transform(points)
        return points

    @staticmethod
    def _loss(data, fam_idx, temp):
        """Evaluate ``SoftNearestNeighborLoss(temp)`` and convert it via ``tolist``.

        NOTE(review): assumes the loss yields either a scalar or one value per
        row of ``data`` — TODO confirm against rindti.losses.
        """
        return SoftNearestNeighborLoss(temp).forward(data, fam_idx).tolist()

    def get_data(self, mean, std, temp):
        """Sample both clusters and return a DataFrame for plotting.

        Columns: ``x``, ``y`` (2-D coordinates, t-SNE-projected when
        ``dims != 2``), ``mean``, ``std``, ``loss`` and ``symbol`` (family label).
        """
        data, fam_idx = self._sample(mean, std)
        losses = self._loss(data, fam_idx, temp)
        frame = pd.DataFrame(self._embed_2d(data), columns=["x", "y"])
        frame["mean"] = mean
        frame["std"] = std
        frame["loss"] = losses
        frame["symbol"] = fam_idx.view(-1).tolist()
        return frame

    def get_loss(self, mean, std, temp):
        """Return the mean loss of one fresh sample.

        Skips the 2-D projection, which the loss does not need and which can be
        expensive (t-SNE) for large ``dims``.
        """
        data, fam_idx = self._sample(mean, std)
        return pd.Series(self._loss(data, fam_idx, temp)).mean()

In [None]:
# Sweep cluster separation (mean, outer) and spread (std, inner) over a
# 10 x 10 grid on [0, 3], recording the mean loss at temperature 1.
p = Plotter(dims=128, num=25)
t = pd.DataFrame(
    [
        (mean, std, p.get_loss(mean, std, 1))
        for mean in np.linspace(0, 3, 10)
        for std in np.linspace(0, 3, 10)
    ],
    columns=['mean', "std", "loss"],
)

In [None]:
# 3-D view of the loss over the (mean, std) sweep. The std grid is linear and
# starts at 0; a log-scaled std axis cannot place std == 0, so those points
# would silently vanish — keep the axis linear.
fig = px.scatter_3d(
    data_frame=t,
    x="mean",
    y="std",
    z="loss",
    color="loss",
    width=1000,
    height=800,
)
fig.show()

In [None]:
# Animate the 2-D view and per-point loss as the two clusters move together
# (mean from 3 down to 0, std fixed at 0.5, temperature 1).
# NOTE(review): each get_data call computes its own 2-D embedding, so point
# positions are presumably not directly comparable across frames.
p = Plotter(dims=128)
# Distinct name: `t` holds the loss-grid DataFrame from the sweep cell, and
# overwriting it breaks re-running the 3-D scatter cell.
means = np.linspace(3, 0, 10)
data = pd.concat([p.get_data(mean, 0.5, 1) for mean in means])
fig = px.scatter(
    data,
    x="x",
    y="y",
    symbol="symbol",
    animation_frame="mean",
    height=800,
    width=1000,
    color="loss",
    opacity=0.5,
)

# Request a full redraw on every frame transition, for both the Play button
# and the slider steps.
for button in fig.layout.updatemenus[0].buttons:
    button['args'][1]['frame']['redraw'] = True
for step in fig.layout.sliders[0].steps:
    step["args"][1]["frame"]["redraw"] = True

# Title each frame with the mean loss of its points. Frames follow the order of
# first appearance in `data`, which is the order of `means`.
mean_loss = data.groupby("mean")["loss"].mean()
for frame, mean in zip(fig.frames, means):
    frame['layout'].update(title_text=f"Loss = {mean_loss.loc[mean]:.4f}")
# The initially displayed state is the first frame; give it the same title.
fig.update_layout(title_text=f"Loss = {mean_loss.loc[means[0]]:.4f}")

fig.update_traces(marker=dict(size=10, line=dict(width=0.5, color="black")))
fig.update_coloraxes(showscale=False)
fig.show()