In [1]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from tqdm.auto import tqdm

from src.models.diffusion import SimpleDiffusionModel
from src.models.predict import KilterModel
from src.utils import EncoderDecoder, get_climb_score

In [None]:
# Trained diffusion checkpoint and the config it was trained with.
DIFFUSION_CHECKPOINT = "logs/lightning_logs/2263tvzt/checkpoints/epoch=24-step=84525.ckpt"
DIFFUSION_CONFIG = {"dim": 64, "timesteps": 1000, "lr": 1e-4, "objective": "eps"}

# load_from_checkpoint does not switch the module to eval mode; do it explicitly so
# train-only layers (e.g. dropout, if the model has any) are disabled while sampling.
model = (
    SimpleDiffusionModel.load_from_checkpoint(DIFFUSION_CHECKPOINT, config=DIFFUSION_CONFIG)
    .cuda()
    .eval()
)

# NOTE: optional grade classifier, currently unused in this notebook.
# classifier = KilterModel.load_from_checkpoint(
#     "logs/lightning_logs/3pitto95/checkpoints/epoch=4-step=744.ckpt",
#     config={"embedding_dim": 246, "dim": 1024, "depth": 4, "heads": 8, "mlp_dim": 412, "dropout": 0.1},
# )

In [2]:
# Decoder for model samples: used below as `climb, angle = encdec(sample)` and
# `encdec.plot_climb(climb)`, whose result is passed as image data to go.Image.
encdec = EncoderDecoder()

In [None]:
# Diffusion sampling is stochastic: seed every torch device RNG so the sampled climbs
# (and all downstream scores and plots) are reproducible on Restart & Run All.
SAMPLE_SEED = 0
torch.manual_seed(SAMPLE_SEED)

# Inference only; no autograd graph is needed for the generated samples.
with torch.no_grad():
    samples = model.diffusion.sample()

In [None]:
# Score each sampled climb (one per element along the first axis of `samples`).
# Iterating a materialised list, not `enumerate(...)`, lets tqdm report a total / ETA.
sampled_climbs = list(samples)
scores = pd.DataFrame(
    [get_climb_score(sampled_climb) for sampled_climb in tqdm(sampled_climbs, desc="scoring")]
)

In [None]:
# Render the first n*n sampled climbs as an n x n grid of images.
n = 2  # grid side length
fig = make_subplots(rows=n, cols=n, horizontal_spacing=0.02, vertical_spacing=0.02)

# Guard against drawing more panels than there are samples (previously an IndexError).
n_panels = min(n * n, len(samples))
for k in range(n_panels):
    row, col = divmod(k, n)  # row-major placement, same as samples[i*n + j]
    climb, angle = encdec(samples[k])
    fig.add_trace(
        go.Image(z=encdec.plot_climb(climb), name=f"Angle - {angle}"),
        row=row + 1,
        col=col + 1,
    )

# With no row/col selector these apply to every subplot axis at once.
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.update_layout(
    width=1000,
    height=1000,
    margin=dict(t=5, b=5, l=5, r=5),
)
fig.show()