In [None]:
import plotly.express as px
import pandas as pd
import numpy as np
import torch
import seaborn as sns
from rindti.losses import GeneralisedLiftedStructureLoss

In [None]:
class Plotter:
    """Generate four 2-D Gaussian point clouds and score them with a metric-learning loss.

    The four clouds (``a``, ``b``, ``c``, ``d``) are sampled once in ``__init__`` and
    reused by every ``get_data`` call, so sweeping ``mult``/``comp`` only shifts and
    rescales the same underlying points.

    Args:
        num: Number of points per cloud (``4 * num`` points in total).
        seed: Optional seed for a private ``torch.Generator``. ``None`` (default)
            samples from torch's global RNG, exactly as before.
        loss_fn: Optional callable ``loss_fn(points, fam_idx) -> tensor``. ``None``
            (default) uses ``GeneralisedLiftedStructureLoss().forward``.
    """

    def __init__(self, num=100, seed=None, loss_fn=None):
        self.num = num
        self.loss_fn = loss_fn
        generator = torch.Generator().manual_seed(seed) if seed is not None else None
        # One standard-normal sample of shape (num, 2) per cloud.
        self.a = torch.randn((num, 2), generator=generator)
        self.b = torch.randn((num, 2), generator=generator)
        self.c = torch.randn((num, 2), generator=generator)
        self.d = torch.randn((num, 2), generator=generator)

    def get_data(self, mult, comp=1):
        """Place the clouds and evaluate the loss on them.

        Cloud ``i`` becomes ``base_i * comp + offset_i * mult``, with offsets
        (1, 1), (-1, -1), (1, -1), (-1, 1) for ``a``, ``b``, ``c``, ``d`` respectively.

        Args:
            mult: Scale applied to each cloud's offset (larger -> clouds further apart).
            comp: Scale applied to the stored noise (smaller -> tighter clouds).

        Returns:
            DataFrame with ``4 * num`` rows and columns ``x``, ``y`` (coordinates),
            ``mult`` (the ``mult`` argument), ``loss`` (the loss output squared;
            assumed to be one value per point -- TODO confirm against the loss) and
            ``symbol`` (cloud index 0-3).
        """
        offsets = torch.tensor([[1, 1], [-1, -1], [1, -1], [-1, 1]])
        clouds = (self.a, self.b, self.c, self.d)
        points = torch.cat([base * comp + offset * mult for base, offset in zip(clouds, offsets)])
        # Labels 0..0, 1..1, 2..2, 3..3 as a column vector, in the same row order as `points`.
        fam_idx = torch.arange(len(clouds)).repeat_interleave(self.num).view(-1, 1)
        loss_fn = self.loss_fn if self.loss_fn is not None else GeneralisedLiftedStructureLoss().forward
        with torch.no_grad():  # evaluation only; no autograd graph is needed
            losses = (loss_fn(points, fam_idx) ** 2).tolist()
        frame = pd.DataFrame(points.numpy(), columns=["x", "y"])
        frame["mult"] = mult
        frame["loss"] = losses
        frame["symbol"] = fam_idx.view(-1).tolist()
        return frame

    def get_loss(self, *args, **kwargs):
        """Return the mean of the ``loss`` column of ``get_data(*args, **kwargs)``."""
        return self.get_data(*args, **kwargs)["loss"].mean()

In [None]:
# Sweep the offset scale `mult` from 0 to 3 with tight clouds (comp=0.1) and
# record the mean loss at each setting.
torch.manual_seed(0)  # Plotter samples from torch's global RNG; seed it for reproducibility
p = Plotter()
t = pd.DataFrame(
    [(mult, p.get_loss(mult, comp=0.1)) for mult in np.linspace(0, 3, 100)],
    columns=["mult", "loss"],
)

In [None]:
# Mean loss vs. offset scale, with a log-scaled y axis.
fig = px.scatter(
    data_frame=t,
    x="mult",
    y="loss",
    log_y=True,
    color="loss",
    width=1000,
    height=800,
    title="Mean loss vs. offset scale",
    labels={"mult": "offset scale (mult)", "loss": "mean of squared loss"},
)
fig.show()

In [None]:
# Animate the point clouds and per-point loss as the offset scale `mult` grows
# from 1 to 3 while the noise scale `comp = 3 / mult` shrinks accordingly.
torch.manual_seed(0)  # Plotter samples from torch's global RNG; seed it for reproducibility
p = Plotter()
t = np.linspace(1, 3, 100)
datas = [p.get_data(mult, comp=3 / mult) for mult in t]
# Each frame's DataFrame restarts its index at 0; keep one unique index for the union.
data = pd.concat(datas, ignore_index=True)
fig = px.scatter(
    data,
    x="x",
    y="y",
    symbol="symbol",
    animation_frame="mult",
    height=800,
    width=1000,
    color="loss",
    opacity=0.5,
    range_x=[-10, 10],
    range_y=[-10, 10],
)
# Enable full redraws on play/slider so per-frame layout changes (the titles below) render.
for button in fig.layout.updatemenus[0].buttons:
    button["args"][1]["frame"]["redraw"] = True

for step in fig.layout.sliders[0].steps:
    step["args"][1]["frame"]["redraw"] = True

# Mean loss per frame, in first-appearance order of `mult` (the same order as `t`
# and therefore as the animation frames).
frame_losses = data.groupby("mult", sort=False)["loss"].mean()
frame_titles = ["Loss = {}".format(mean_loss) for mean_loss in frame_losses]
for frame, title in zip(fig.frames, frame_titles):
    frame.layout.update(title_text=title)
# Frame layouts only take effect once the animation runs; give the initial view
# the first frame's title as well.
fig.update_layout(title_text=frame_titles[0])
fig.update_traces(marker=dict(size=10, line=dict(width=0.5, color="black")))
fig.update_coloraxes(showscale=False)
fig.show()