In [68]:
import pandas as pd
import numpy as np
import plotly.express as px
from src.utils import EncoderDecoder, plot_climb, jaccard_similarity, get_climb_score
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from src.models.diffusion_unet import DiffusionUNet
from src.models.predict_vit import KilterModel
from typing import Union

In [102]:
# Load the raw tables. `holds` feeds the EncoderDecoder; `grades` maps the
# `difficulty` column of grades.csv to its `boulder_name` label.
RAW_DATA_DIR = "data/raw"
climbs = pd.read_csv(f"{RAW_DATA_DIR}/all_climbs.csv", index_col=0)
holds = pd.read_csv(f"{RAW_DATA_DIR}/holds.csv", index_col=0)
encdec = EncoderDecoder(holds)
grade_table = pd.read_csv(f"{RAW_DATA_DIR}/grades.csv", index_col=0)
grades = grade_table.set_index("difficulty")["boulder_name"].to_dict()

In [73]:
# Load the diffusion generator and the classifier from their Lightning checkpoints.
# Both are used for inference only, so switch them to eval mode; this matters for the
# classifier in particular, whose config enables dropout (0.1) and would otherwise give
# stochastic predictions. The classifier is moved to the GPU explicitly, like the
# generator, because the inference cells feed it CUDA tensors.
model = DiffusionUNet.load_from_checkpoint(
    "logs/lightning_logs/3ilwseos/checkpoints/epoch=34-step=8855.ckpt",
    config={"dim": 64, "timesteps": 1000, "lr": 1e-4},
).cuda().eval()
classifier = KilterModel.load_from_checkpoint(
    "logs/lightning_logs/uprwb2gy/checkpoints/epoch=36-step=1924.ckpt",
    config={"embedding_dim": 256, "dim": 1024, "depth": 4, "heads": 8, "mlp_dim": 512, "dropout": 0.1},
).cuda().eval()

In [56]:
# Diffusion sampling is stochastic: seed torch (all devices) so Restart & Run All
# regenerates the same climbs, and name the sample count instead of a bare 100.
SEED = 0
N_SAMPLES = 100
torch.manual_seed(SEED)
samples = model.diffusion.sample(N_SAMPLES).cpu()

sampling loop time step: 100%|██████████| 1000/1000 [00:50<00:00, 19.95it/s]


In [57]:
# Binarize the sampled tensors at 0.5 into an int64 tensor of 0/1 values.
# NOTE(review): presumably 1 = hold used, 0 = hold unused — confirm against EncoderDecoder.
saclimb = torch.gt(samples, 0.5).to(torch.int64)

In [58]:
from tqdm.auto import tqdm  # renders as a notebook widget instead of a text bar

# Score each generated climb (iterating the tensor yields one row per sample).
# TODO: novelty check — Jaccard similarity of each sample against every climb in
# `climbs` (via jaccard_similarity / encdec) is not implemented yet; the previous
# placeholder only allocated an unused, never-filled similarity matrix.
scores = [get_climb_score(sampled_climb) for sampled_climb in tqdm(saclimb, desc="scoring")]

100it [00:00, 28735.98it/s]


In [59]:
# Per-sample summary table, built in one step so re-running this cell is idempotent
# (uses `scores` from the scoring cell, one entry per generated climb).
stats = pd.DataFrame({"score": scores})

In [60]:
# NOTE(review): jac_max / jac_mean columns were disabled here. There is no populated
# Jaccard-similarity matrix to read them from (that computation was never implemented),
# so re-enabling them as-is would report zeros. TODO: compute per-sample similarity
# against `climbs` first, then add those columns.
stats["score"] = scores
stats.describe()

In [114]:
ANGLES = [10, 25, 40, 55, 70]  # wall angles (degrees) at which the classifier is queried

# Inference only: disable dropout, and put the inputs on whatever device the classifier
# actually lives on rather than assuming CUDA.
classifier.eval()
device = classifier.device
climb_batch = saclimb.to(device=device, dtype=torch.float)

for deg in ANGLES:
    # One angle per generated climb; sized from the batch instead of a hard-coded 100.
    angles = torch.full((len(climb_batch),), deg, dtype=torch.long, device=device)
    with torch.no_grad():
        preds = classifier(climb_batch, angles).cpu()
    predlist = preds.round().long().view(-1).tolist()
    # Sorting by numeric difficulty orders the histogram's categories; predictions
    # outside the grades table are labelled instead of raising KeyError mid-loop.
    grade_labels = [grades.get(x, f"unknown ({x})") for x in sorted(predlist)]
    fig = px.histogram(
        grade_labels,
        width=500,
        height=300,
        title=f"Grade distribution for generated routes on angle {deg}",
    )
    fig.update_layout(showlegend=False, xaxis_title="grade", yaxis_title="count")
    fig.write_image(f"angle{deg}.png")
    fig.show()

tensor([[23],
        [26],
        [26],
        [24],
        [20],
        [26],
        [23],
        [21],
        [28],
        [22],
        [23],
        [26],
        [19],
        [26],
        [23],
        [23],
        [23],
        [26],
        [20],
        [21],
        [26],
        [27],
        [25],
        [25],
        [26],
        [29],
        [24],
        [26],
        [23],
        [25],
        [25],
        [26],
        [20],
        [23],
        [28],
        [28],
        [25],
        [25],
        [27],
        [24],
        [25],
        [17],
        [22],
        [23],
        [24],
        [28],
        [28],
        [24],
        [28],
        [25],
        [20],
        [26],
        [22],
        [24],
        [26],
        [29],
        [26],
        [26],
        [26],
        [25],
        [27],
        [24],
        [29],
        [24],
        [27],
        [21],
        [25],
        [18],
        [26],
        [24],
        [23],
      

In [67]:
# Render a single generated climb for visual inspection (index 42 — presumably arbitrary).
example_idx = 42
example_climb = saclimb[example_idx]
px.imshow(plot_climb(example_climb))