In [None]:
import pandas as pd
import torch
from src.models.diffusion import TextDiffusionModel
import torch.nn.functional as F
from src.utils import Tokenizer

In [None]:
# Checkpoint to sample from (kept as a named constant so it is easy to swap runs).
CHECKPOINT_PATH = "logs/lightning_logs/bxspp7rv/checkpoints/epoch=91-step=38916.ckpt"
# eval() switches nn.Module submodules out of training mode (e.g. dropout) before sampling;
# it returns the module itself, so the assignment still yields the model.
model = TextDiffusionModel.load_from_checkpoint(CHECKPOINT_PATH).eval()

In [None]:
# Sampling is inference only: no_grad avoids recording an autograd graph
# (and holding its intermediate activations) while the sampler runs.
with torch.no_grad():
    samples = model.diffusion.sample()

In [None]:
def emb2indices(output, emb_layer):
    """Map continuous embedding vectors back to their nearest token indices.

    Each sequence position's embedding vector is compared with every row of
    ``emb_layer.weight`` using L1 distance (sum of absolute differences), and the
    index of the closest row is returned. Ties resolve like ``torch.argmin``.

    Args:
        output: tensor of shape ``[batch, emb_dim, seq_len]``. The embedding axis
            is axis 1, which is what the shape arithmetic requires for the result
            to broadcast against the embedding table.
        emb_layer: object exposing a ``weight`` tensor of shape
            ``[num_tokens, emb_dim]`` (e.g. ``torch.nn.Embedding``).

    Returns:
        Long tensor of shape ``[batch, seq_len]`` holding token indices.
    """
    emb_weights = emb_layer.weight
    # [batch, emb_dim, seq] -> [batch, seq, emb_dim]. This must be a transpose:
    # a .view() with swapped sizes would reinterpret memory and mix values from
    # different positions and embedding dimensions.
    per_position = output.transpose(1, 2)
    # Broadcast [batch, seq, 1, emb_dim] against [num_tokens, emb_dim]
    # -> [batch, seq, num_tokens, emb_dim], then reduce over emb_dim.
    distances = (per_position.unsqueeze(2) - emb_weights).abs().sum(dim=-1)
    return distances.argmin(dim=-1)

In [None]:
# Raw climbs table; the Tokenizer built from it reads the "frames" column.
CLIMBS_CSV = "data/kiltertextdiffuse/raw/all_climbs.csv"
df = pd.read_csv(CLIMBS_CSV)
class Tokenizer:
    """Token-level encoder/decoder for climb ``frames`` strings.

    A frames string is a concatenation of ``p<hold>r<color>`` pairs; each pair is
    split into two tokens, ``"p<hold>"`` and ``"r<color>"``. The vocabulary is built
    from every ``frames`` value in ``df``, and id 0 is reserved for ``"[PAD]"``.

    NOTE(review): this class shadows the ``Tokenizer`` imported from ``src.utils``;
    confirm both produce the same vocabulary as the one used to train the model.
    """

    def __init__(self, df: pd.DataFrame, max_len: int = 64):
        self.df = df
        self.max_len = max_len
        self.token_map = self._get_token_map()
        self.decode_map = {v: k for k, v in self.token_map.items()}

    @staticmethod
    def split_tokens(frames: str) -> list[str]:
        """Split a frames string into alternating hold / color tokens.

        Any text before the first ``"p"`` is discarded. Raises ``ValueError`` when a
        pair does not contain exactly one ``"r"``.
        """
        res = []
        for pair in frames.split("p")[1:]:
            hold, color = pair.split("r")
            res += [f"p{hold}", f"r{color}"]
        return res

    def __call__(self, frames: str) -> torch.Tensor:
        """Encode ``frames`` as a length-``max_len`` long tensor of token ids.

        Longer inputs are truncated and shorter ones are right-padded with
        ``"[PAD]"``. Raises ``KeyError`` for tokens that do not occur in ``df``.
        """
        tokens = self.split_tokens(frames)[: self.max_len]
        tokens += ["[PAD]"] * (self.max_len - len(tokens))
        return torch.tensor([self.token_map[tok] for tok in tokens], dtype=torch.long)

    def decode(self, samples: list[list[int]], skip_pad: bool = False) -> list[str]:
        """Turn sequences of token ids back into frames strings.

        Args:
            samples: one list of token ids per climb.
            skip_pad: when True, ``"[PAD]"`` ids are dropped instead of being
                rendered as the literal ``"[PAD]"`` (default keeps them).
        """
        pad_id = self.token_map["[PAD]"]
        return [
            "".join(
                self.decode_map[idx]
                for idx in climb
                if not (skip_pad and idx == pad_id)
            )
            for climb in samples
        ]

    def _get_token_map(self) -> dict[str, int]:
        """Build the token -> id vocabulary from every row's ``frames``.

        Tokens are sorted before numbering. Enumerating a ``set`` of strings
        directly depends on string-hash randomization (PYTHONHASHSEED), so ids
        would change on every interpreter restart. ``"[PAD]"`` is inserted first
        so that iterating ``token_map`` yields tokens in id order.
        """
        tokens = set()
        for frames in self.df["frames"]:
            tokens.update(self.split_tokens(frames))
        token_map = {"[PAD]": 0}
        token_map.update({token: idx for idx, token in enumerate(sorted(tokens), start=1)})
        return token_map
# Vocabulary built from the climbs CSV; max_len spelled out (same as the class default).
# NOTE(review): ids must match the vocabulary the checkpoint was trained with — confirm.
T = Tokenizer(df, max_len=64)

In [None]:
from sklearn.manifold import TSNE
import plotly.express as px

# Project the learned token embeddings to 2-D for visual inspection.
# Only the first vocab_size rows are used, mirroring the tokenizer vocabulary.
vocab_size = len(T.token_map)
emb_2d = TSNE(random_state=0).fit_transform(
    model.embedding.weight.detach()[:vocab_size].cpu().numpy()
)
# Separate name so the climbs `df` loaded earlier is not clobbered.
emb_df = pd.DataFrame(data=emb_2d, columns=["x", "y"])
# Label by id: embedding row i is decoded as token id i elsewhere in this notebook.
emb_df["token"] = [T.decode_map[i] for i in range(vocab_size)]
emb_df["type"] = "hold"
# Set only the "type" column; assigning to emb_df[mask] would overwrite x/y/token too.
emb_df.loc[emb_df["token"].str.startswith("r"), "type"] = "color"
emb_df.loc[emb_df["token"] == "[PAD]", "type"] = "pad"
px.scatter(emb_df, x="x", y="y", hover_name="token", color="type", opacity=0.4)

In [None]:
# Snap each sampled embedding to the id of the nearest row of the model's embedding table.
t = emb2indices(output=samples, emb_layer=model.embedding)

In [None]:
# Decode ids to frames strings and drop every "[PAD]" token. str.strip("[PAD]") would
# strip the *characters* "[", "P", "A", "D", "]" from the ends only, so padding that
# the model emits between real tokens would survive.
res = [x.replace("[PAD]", "") for x in T.decode(t.tolist())]

In [None]:
# Display the decoded frames string for each generated sample.
res