In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
import wandb

from src.models.gpt import GPTModel
from src.utils import Plotter
from src.utils import Tokenizer

In [None]:
# Load the trained GPT checkpoint for inference.
#
# The weights come from a W&B model artifact. To (re-)download them, uncomment:
# run = wandb.init(name="test")
# artifact = run.use_artifact("ilsenatorov/kilter-gpt/model-17j5ahs6:v0", type="model")
# artifact_dir = artifact.download()  # expected to land under artifacts/model-17j5ahs6:v0/
CHECKPOINT_PATH = "artifacts/model-17j5ahs6:v0/model.ckpt"

# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location places the weights on DEVICE regardless of where the checkpoint was saved;
# eval() switches off training-only behaviour (e.g. dropout, if the model uses it) so that
# sampling is not perturbed by it. `.eval()` returns the module, so `model` is still the GPTModel.
model = GPTModel.load_from_checkpoint(CHECKPOINT_PATH, map_location=DEVICE).eval()

In [None]:
# Set up the two helpers used throughout the notebook:
#  - `plotter` turns hold strings into climb images,
#  - `tokenizer` is fitted on the raw climbs table with angle and grade handling enabled.
plotter = Plotter()
df = pd.read_csv("data/raw/climbs.csv")
tokenizer = Tokenizer.from_df(df, grade=True, angle=True)

In [None]:
# Build one prompt per target grade from the same starting holds, show those holds,
# then tokenize all prompts into a single batch for generation.
hold_prompt = "p1135r12p1395r14"
PROMPT_ANGLE = 30  # wall angle shared by every prompt (shown later with a ° suffix)
PROMPT_GRADES = ["5a", "6a", "7a", "8a"]  # one generated climb per grade

# Each prompt is a (holds, angle, grade) tuple, unpacked positionally into tokenizer.encode.
prompts = [(hold_prompt, PROMPT_ANGLE, grade) for grade in PROMPT_GRADES]
plotter.plot_climb(hold_prompt, True)

# Tokenize without an EOS token so the model continues the sequence, padded to the model's
# context length (the original note says padding is on the left — confirm in Tokenizer.encode).
# The batch goes to whatever device the model lives on instead of a hard-coded "cuda",
# so it works on CPU-only machines and never mismatches the model's device.
tokenized_prompts = torch.stack(
    [
        tokenizer.encode(*prompt, eos=False, pad=model.config.context_len)
        for prompt in prompts
    ]
).to(model.device)

In [None]:
# Sample one continuation per prompt and show the resulting climbs side by side.

# Higher temperature -> more randomness. A value around 0.2 was noted to work well for this
# model; 0.1 keeps the samples more conservative.
TEMP = 0.1
MAX_NEW_TOKENS = 50  # second positional arg of generate, presumably the number of new tokens — confirm
SEED = 42  # sampling is stochastic; fix the seed so re-running the notebook reproduces the figure

torch.manual_seed(SEED)
generated = model.generate(tokenized_prompts, MAX_NEW_TOKENS, temperature=TEMP)
decoded = list(tokenizer.decode_batch(generated))

# One row with one panel per generated climb (not a fixed 4). squeeze=False keeps `axs`
# indexable even when there is a single prompt; ravel() gives a flat 1-D array of Axes.
n_climbs = len(decoded)
fig, axs = plt.subplots(1, n_climbs, figsize=(3.75 * n_climbs, 4), squeeze=False)
axs = axs.ravel()

for ax, sequence in zip(axs, decoded):
    frames, angle, grade = tokenizer.clean(sequence)
    ax.imshow(plotter.plot_climb(frames))
    # angle[1:] drops a one-character prefix from the angle token (assumed form like "a30" — confirm)
    ax.set_title(f"{grade} @ {angle[1:]}°")
    ax.axis("off")  # images only: hide ticks and labels

fig.suptitle(f"Temp: {TEMP}")
fig.tight_layout()
plt.show()

In [None]:
# Export step: run this only once the samples above look good. The model is compiled
# with TorchScript's `script` method and written to good_model.pt.
script = model.to_torchscript(method="script", file_path="good_model.pt")