In [None]:
import plotly.express as px
import pandas as pd
import numpy as np
import torch
import seaborn as sns
from rindti.losses import SoftNearestNeighborLoss

In [None]:
class Plotter:
    """Four 2-D Gaussian point clouds used to probe ``SoftNearestNeighborLoss``.

    The clouds ``a``, ``b``, ``c`` and ``d`` are sampled once in ``__init__`` from a
    standard normal and re-used on every call. Only ``mult`` (how far the clouds
    are shifted) and ``temp`` (passed to the loss) vary between calls to
    :meth:`get_data` / :meth:`get_loss`.
    """

    # Direction each cloud is shifted in. The rows line up with a, b, c, d, and
    # the row index doubles as that cloud's family label (0..3).
    OFFSETS = torch.tensor([[1.0, 1.0], [-1.0, -1.0], [1.0, -1.0], [-1.0, 1.0]])

    def __init__(self, num=100, seed=None):
        """Sample ``num`` standard-normal 2-D points for each of the four clouds.

        Args:
            num: number of points per cloud.
            seed: optional integer seed. When given, the clouds are drawn from a
                private ``torch.Generator``, so results are reproducible without
                touching the global RNG. ``None`` uses the global torch RNG
                (the original behaviour).
        """
        self.num = num
        generator = None if seed is None else torch.Generator().manual_seed(seed)
        self.a = torch.randn((num, 2), generator=generator)
        self.b = torch.randn((num, 2), generator=generator)
        self.c = torch.randn((num, 2), generator=generator)
        self.d = torch.randn((num, 2), generator=generator)

    def get_data(self, mult, temp):
        """Shift the clouds by ``mult`` along their offsets and score them with the loss.

        Args:
            mult: scale applied to each cloud's unit offset (0 keeps all clouds
                centred at the origin).
            temp: value passed to ``SoftNearestNeighborLoss(temp)``.

        Returns:
            DataFrame with one row per point and the columns ``x``, ``y``, ``mult``,
            ``temp``, ``loss`` and ``symbol`` (the family label 0..3).
            ``loss`` holds whatever the loss's ``forward`` returns, converted with
            ``tolist()``. That is presumably one value per point; confirm against
            rindti's ``SoftNearestNeighborLoss``.
        """
        clouds = (self.a, self.b, self.c, self.d)
        points = torch.cat([cloud + offset * mult for cloud, offset in zip(clouds, self.OFFSETS)])
        # One label per point (num zeros, then num ones, ...), shaped as an (N, 1) column.
        fam_idx = torch.arange(len(clouds)).repeat_interleave(self.num).view(-1, 1)
        # Only the numeric values are plotted, so skip building an autograd graph.
        with torch.no_grad():
            losses = SoftNearestNeighborLoss(temp).forward(points, fam_idx).tolist()
        frame = pd.DataFrame(points.numpy(), columns=["x", "y"])
        return frame.assign(mult=mult, temp=temp, loss=losses, symbol=fam_idx.view(-1).tolist())

    def get_loss(self, mult, temp):
        """Return the mean of the ``loss`` column produced by :meth:`get_data`."""
        return self.get_data(mult, temp)["loss"].mean()

In [None]:
# Sweep the cluster offset multiplier (`mult`) against a log-spaced temperature
# grid and record the mean soft nearest neighbour loss for every pair.
torch.manual_seed(0)  # fix the random point clouds so the loss surface is reproducible
MULTS = np.linspace(0, 3, 100)
TEMPS = np.logspace(-2, 2, 100)  # 10**-2 .. 10**2, i.e. np.power(10, np.linspace(-2, 2, 100))
p = Plotter()
t = pd.DataFrame(
    [(mult, temp, p.get_loss(mult, temp)) for mult in MULTS for temp in TEMPS],
    columns=["mult", "temp", "loss"],
)

In [None]:
# Mean loss as a function of cluster offset and (log-scaled) temperature.
fig = px.scatter_3d(
    data_frame=t,
    x="mult",
    y="temp",
    z="loss",
    log_y=True,
    color="loss",
    width=1000,
    height=800,
    title="Mean soft nearest neighbour loss vs. cluster offset and temperature",
    labels={"mult": "offset multiplier (mult)", "temp": "temperature", "loss": "mean loss"},
)
fig.show()

In [None]:
# Animate how the per-point loss changes as the four clouds are pushed apart
# at a fixed temperature.
torch.manual_seed(0)  # same random clouds on every run
ANIM_TEMP = 1
p = Plotter()
data = pd.concat([p.get_data(mult, ANIM_TEMP) for mult in np.linspace(0, 3, 100)])
fig = px.scatter(
    data,
    x="x",
    y="y",
    symbol="symbol",
    animation_frame="mult",
    height=800,
    width=1000,
    color="loss",
    opacity=0.5,
    range_x=[-10, 10],
    range_y=[-10, 10],
    title=f"Per-point soft nearest neighbour loss at temp={ANIM_TEMP}",
)
fig.update_traces(marker=dict(size=10, line=dict(width=0.5, color="black")))
fig.show()