In [None]:
from dataclasses import dataclass, field

import matplotlib.pyplot as plt
import pandas as pd
import torch

from src.models.gpt import GPTModel
from src.utils import Plotter, pad_to
from src.utils import Tokenizer

In [None]:
@dataclass
class Config:
    """Model / training hyper-parameters handed to ``GPTModel``.

    Every field is annotated so ``@dataclass`` actually generates fields:
    ``Config()`` keeps the original defaults, and individual values can be
    overridden, e.g. ``Config(lr=1e-3)``. ``head_size`` is derived from
    ``n_embed`` and ``num_heads`` in ``__post_init__`` so it stays consistent
    when either of them is overridden.

    Raises:
        ValueError: ``n_embed`` is not divisible by ``num_heads``.
    """

    batch_size: int = 2048
    epochs: int = 1000
    vocab_size: int = 550
    lr: float = 6e-4
    wd: float = 1e-5
    n_embed: int = 256
    num_blocks: int = 4
    num_heads: int = 4
    # Derived value, not an __init__ argument; the class-level default mirrors
    # the default n_embed // num_heads so `Config.head_size` still reads 64.
    head_size: int = field(init=False, default=n_embed // num_heads)
    context_len: int = 64
    attn_drop_value: float = 0.2
    multihead_drop_value: float = 0.2
    ffn_drop_value: float = 0.2
    min_tokens: int = 5

    def __post_init__(self):
        # Heads split the embedding evenly; a remainder would silently drop dims.
        if self.n_embed % self.num_heads != 0:
            raise ValueError(
                f"n_embed ({self.n_embed}) must be divisible by num_heads ({self.num_heads})"
            )
        self.head_size = self.n_embed // self.num_heads


# Presumably these hyper-parameters must match the ones the checkpoint was trained with.
config = Config()

# NOTE(review): confirm `generate` switches the model to eval mode; otherwise the
# 0.2 dropout values in Config stay active while sampling.
checkpoint_path = "checkpoints/main.ckpt"
model = GPTModel.load_from_checkpoint(checkpoint_path, config=config)

In [None]:
climbs_csv = "data/raw/climbs.csv"
# The first CSV column becomes the DataFrame index.
df = pd.read_csv(climbs_csv, index_col=0)
plotter = Plotter()
# Tokenizer is constructed from the loaded climbs frame.
tokenizer = Tokenizer.from_df(df)

In [None]:
prompts = ["p1151r12p1153r12", "p1234r13p1136r12", "p1379r14p1157r12", "p1340r13p1360r13"]

# torch.stack requires every prompt to encode to the same length.
# [:, :-1] drops each prompt's final token — presumably the end-of-sequence
# token, so the model continues the climb instead of stopping; TODO confirm.
prompt_batch = torch.stack([tokenizer.encode(text) for text in prompts])[:, :-1]
pad_id = tokenizer.encode_map[tokenizer.pad_token]
encoded_prompts = pad_to(prompt_batch, config.context_len, pad_id).to(model.device)
encoded_prompts

In [None]:
# `temperature` suggests `generate` samples; seed torch's RNG so that
# Restart & Run All reproduces the same generated climbs.
GENERATION_SEED = 0
torch.manual_seed(GENERATION_SEED)
generated = model.generate(encoded_prompts, 50, temperature=0.7)

In [None]:
def _strip_special_tokens(tokens) -> str:
    """Join decoded tokens into the bare frames string.

    Padding markers are removed, everything from the first ``[EOS]`` onward is
    dropped, and of what remains only the part after the last ``[BOS]`` is kept.
    """
    text = "".join(tokens).replace("[PAD]", "")
    up_to_eos = text.split("[EOS]")[0]
    after_bos = up_to_eos.split("[BOS]")[-1]
    return after_bos.strip()


decoded = [_strip_special_tokens(tokenizer.decode(sample)) for sample in generated]
decoded

In [None]:
df["name"].iloc[2]

In [None]:
# NOTE(review): `encdec` is only created by `EncoderDecoder()` in a later cell, so on
# Restart & Run All this cell raises NameError — move it below that cell.
# NOTE(review): cv2.imread loads images in BGR order while plt.imshow expects RGB;
# colours may appear channel-swapped — confirm.
plt.imshow(encdec.plot_climb(df.iloc[2]['frames']))

In [None]:
class EncoderDecoder:
    """Converts climb frames to tensors and back.

    A frames string is a sequence of ``p<hold_id>r<role>`` tokens, e.g.
    ``"p1151r12p1153r12"``. Roles 12-15 (start, hands, end, feet) map to
    tensor channels 0-3.

    If given a tensor - returns ``(frames, angle)``.
    If given frames and an angle - returns a ``(5, 48, 48)`` tensor: four role
    channels plus one constant channel holding ``angle / 70``.
    """

    # Board grid: a hold board coordinate ``v`` maps to grid index
    # ``(v - GRID_OFFSET) // GRID_STEP`` on a GRID_SIZE x GRID_SIZE grid.
    GRID_SIZE = 48
    GRID_STEP = 4
    GRID_OFFSET = 4
    # The angle is stored as ``angle / ANGLE_SCALE`` and decoded to the nearest ANGLE_STEP.
    ANGLE_SCALE = 70
    ANGLE_STEP = 5
    # Role codes are ROLE_OFFSET + channel index.
    ROLE_OFFSET = 12
    # Circle colour per role used by ``plot_climb``.
    # NOTE(review): cv2 images are BGR, so these tuples are drawn as BGR — confirm intent.
    ROLE_COLORS = {
        12: (0, 255, 0),  # start
        13: (0, 200, 255),  # hands
        14: (255, 0, 255),  # end
        15: (255, 165, 0),  # feet
    }

    def __init__(self):
        """Load hold positions and image coordinates and build the lookups."""
        holds = pd.read_csv("data/raw/holds.csv", index_col=0)
        image_coords = pd.read_csv("figs/image_coords.csv", index_col=0)
        self.coord_to_id = self._create_coord_to_id(holds)
        self.id_to_coord = self._create_id_to_coord(holds)
        self.image_coords = self._create_image_coords(image_coords)

    def _create_coord_to_id(self, holds: pd.DataFrame) -> np.ndarray:
        """Build a (GRID_SIZE, GRID_SIZE) int array mapping grid cell -> hold id.

        Cells without a hold contain 0. When several holds share a cell, the one
        appearing first in ``holds`` wins. Holds whose coordinates are not of the
        form ``GRID_OFFSET + k * GRID_STEP`` inside the grid are ignored.
        """
        size, step, offset = self.GRID_SIZE, self.GRID_STEP, self.GRID_OFFSET
        lookup = np.zeros((size, size), dtype=int)
        filled = np.zeros((size, size), dtype=bool)
        # One pass over the holds instead of filtering the frame once per cell.
        for hold_id, x, y in zip(holds.index, holds["x"], holds["y"]):
            # NaN or off-grid coordinates give a non-zero remainder and are skipped.
            if (x - offset) % step != 0 or (y - offset) % step != 0:
                continue
            i, j = int((x - offset) // step), int((y - offset) // step)
            if not (0 <= i < size and 0 <= j < size) or filled[i, j]:
                continue
            lookup[i, j] = int(hold_id)
            filled[i, j] = True
        return lookup

    def _create_id_to_coord(self, holds: pd.DataFrame) -> dict:
        """Map hold id -> ``[grid_x, grid_y]`` for every row of ``holds``."""
        id_to_coord = (holds[["x", "y"]] - self.GRID_OFFSET) // self.GRID_STEP
        return id_to_coord.transpose().to_dict(orient="list")

    def _create_image_coords(self, image_coords: pd.DataFrame) -> dict:
        """Map hold id (CSV index) -> ``(x, y)`` pixel position on the board image."""
        return {name: (row["x"], row["y"]) for name, row in image_coords.iterrows()}

    def str_to_tensor(self, frames: str, angle: float) -> torch.Tensor:
        """Encode ``frames`` and ``angle`` as a ``(5, GRID_SIZE, GRID_SIZE)`` float tensor.

        Raises:
            KeyError: a hold id is missing from ``holds.csv``.
            ValueError: a token is malformed or its role is outside 12-15.
        """
        size = self.GRID_SIZE
        angle_matrix = torch.ones((1, size, size), dtype=torch.float32) * (angle / self.ANGLE_SCALE)
        matrix = torch.zeros((4, size, size), dtype=torch.float32)
        for frame in frames.split("p")[1:]:
            hold_id, role = frame.split("r")
            channel = int(role) - self.ROLE_OFFSET
            # A negative channel would silently wrap onto another role's plane.
            if not 0 <= channel < matrix.shape[0]:
                raise ValueError(f"Unknown role {role!r} in frame {frame!r}")
            x, y = self.id_to_coord[int(hold_id)]
            matrix[channel, int(x), int(y)] = 1
        return torch.cat((matrix, angle_matrix), dim=0)

    def tensor_to_str(self, matrix: torch.Tensor) -> tuple:
        """Decode a ``(5, GRID_SIZE, GRID_SIZE)`` tensor into ``(frames, angle)``.

        Role channels are rounded; every non-zero cell that maps to a known hold
        becomes a ``p<hold_id>r<role>`` token. At most two start (channel 0) and
        two end (channel 2) holds are kept. Tokens are sorted by hold id. The
        angle is read from the mean of the last channel and rounded to the
        nearest ANGLE_STEP.
        """
        angle_step = self.ANGLE_STEP
        angle = ((matrix[-1].mean() * self.ANGLE_SCALE / angle_step).round() * angle_step).long().item()
        roles = matrix[:-1, :, :].round().long()
        frames = []
        counter = [0] * roles.shape[0]
        for channel, x, y in zip(*(idx.tolist() for idx in torch.where(roles))):
            hold_id = self.coord_to_id[x, y]
            # wrong hold position
            if hold_id == 0:
                continue
            # Count only real holds so an off-board cell cannot use up the quota.
            counter[channel] += 1
            # too many start/end holds
            if channel in (0, 2) and counter[channel] > 2:
                continue
            frames.append((int(hold_id), channel + self.ROLE_OFFSET))
        frames.sort(key=lambda frame: frame[0])
        return "".join(f"p{hold_id}r{role}" for hold_id, role in frames), angle

    def plot_climb(self, frames: str, radius: int = 30, thickness: int = 2):
        """Draw the holds of ``frames`` as coloured circles on the board image.

        Args:
            frames: ``p<hold_id>r<role>`` string.
            radius: circle radius in pixels.
            thickness: circle line thickness in pixels.

        Returns:
            The board image (as loaded by ``cv2.imread``) with circles drawn.
            Malformed tokens, holds without image coordinates and unknown roles
            are skipped one by one instead of aborting the whole drawing.

        Raises:
            FileNotFoundError: the board image cannot be read.
        """
        assert isinstance(frames, str), f"Input must be frames! Got {type(frames)}"
        board_path = "figs/full_board_commercial.png"
        image = cv2.imread(board_path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read board image {board_path!r}")
        for hold in frames.split("p")[1:]:
            try:
                hold_id, role = (int(part) for part in hold.split("r"))
            except ValueError:
                continue  # malformed token
            if hold_id not in self.image_coords or role not in self.ROLE_COLORS:
                continue
            x, y = self.image_coords[hold_id]
            image = cv2.circle(image, (int(x), int(y)), radius, self.ROLE_COLORS[role], thickness)
        return image

    def __call__(self, *args):
        """Dispatch: one arg -> ``tensor_to_str``; two args -> ``str_to_tensor``."""
        if len(args) == 1:
            return self.tensor_to_str(*args)
        if len(args) == 2:
            return self.str_to_tensor(*args)
        raise ValueError(f"Expected 1 arg (tensor) or 2 args (frames, angle); got {len(args)}")



In [None]:
# NOTE(review): these imports sit below the EncoderDecoder cell. `np` and `cv2` are
# only looked up when its methods run, so they must execute before
# `EncoderDecoder()` is constructed; consider moving them to the top imports cell.
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
# Reads data/raw/holds.csv and figs/image_coords.csv and builds the hold lookups.
encdec = EncoderDecoder()

In [None]:
sample_frames = df["frames"].iloc[150]
plotter.plot_climb(sample_frames, return_fig=True)