In [1]:
import os
import random 
from io import BytesIO
from typing import *

import wandb
import lightning
import torch 
import sentencepiece
from omegaconf import OmegaConf
import tqdm
import chess
import chess.svg
from PIL import Image
from cairosvg import svg2png

from model.commentary_models import ActualBoardTransformerMultipleHeadsModel
from data.ActualBoardCommentaryDataset import ActualBoardCommentaryDataset


In [2]:
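# Optional: uncomment to fetch the checkpoint from Weights & Biases instead of
# relying on a copy that is already on disk.
# artifact = wandb.Api().artifact('georgerapeanu/thesis/model-4l3kr92p:v42', type='model')
# artifact_dir = artifact.download()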

In [3]:
# Directory that holds the downloaded checkpoint. The THESIS_ARTIFACT_DIR
# environment variable overrides the default so the notebook is not tied to a
# single machine's home directory; when it is unset the original path is used.
artifact_dir = os.environ.get(
    "THESIS_ARTIFACT_DIR",
    "/home/georgerapeanu/Desktop/thesis/artifacts/model-4l3kr92p:v42",
)

In [4]:
# Restore the trained model from the checkpoint file inside the artifact directory.
model = ActualBoardTransformerMultipleHeadsModel.load_from_checkpoint(f"{artifact_dir}/model.ckpt")

In [5]:
class ActualBoardPredictor:
    """Thin helper that turns a text prompt plus board features into generated text.

    The tokenizer is used to encode the prompt and to decode generated token
    ids. Newlines are represented by the literal ``<n>`` marker on the token
    side and converted back to ``\\n`` when decoding.
    """

    def __init__(
            self,
            context_length: int,
            sp: "sentencepiece.SentencePieceProcessor"
    ):
        """Store the tokenizer and the context length.

        Args:
            context_length: Kept on the instance; not read by any method in
                this class.
            sp: Tokenizer providing ``encode``, ``decode`` and ``bos_id``.
        """
        self.__sp = sp
        self.__context_length = context_length

    def tokens_to_string(self, tokens: torch.Tensor) -> str:
        """Decode a tensor of token ids (any shape) into text.

        The tensor is flattened first. ``reshape`` is used rather than ``view``
        so that non-contiguous tensors (e.g. slices or transposes) also work.
        """
        flat_ids = tokens.reshape(-1).tolist()
        return self.__sp.decode(flat_ids).replace("<n>", "\n")

    def predict(self, model, X_board, X_strength, X_reps, X_state, text: str, max_new_tokens: int, target_type: Optional[int] = None) -> str:
        """Generate text for a single (unbatched) sample.

        Args:
            model: Object exposing ``device`` and
                ``generate(board, strength, reps, state, tokens, max_new_tokens, target_type=...)``.
            X_board, X_strength, X_reps, X_state: Unbatched input tensors; each
                is moved to ``model.device`` and given a leading batch
                dimension of size 1.
            text: Prompt the generation starts from; surrounding whitespace is
                stripped and newlines are replaced by ``<n>`` before encoding.
            max_new_tokens: Forwarded unchanged to ``model.generate``.
            target_type: Forwarded unchanged to ``model.generate``.

        Returns:
            The decoded contents of whatever ``model.generate`` returns.
        """
        device = model.device
        prompt_ids = [self.__sp.bos_id()] + self.__sp.encode(text.strip().replace('\n', '<n>'))
        # Build the integer tensor directly; the legacy torch.Tensor(list)
        # constructor would create a float32 tensor that then has to be cast.
        prompt = torch.tensor(prompt_ids, dtype=torch.int32, device=device).unsqueeze(0)

        # Inference only: do not record an autograd graph while sampling.
        with torch.no_grad():
            generated = model.generate(
                X_board.to(device).unsqueeze(0),
                X_strength.to(device).unsqueeze(0),
                X_reps.to(device).unsqueeze(0),
                X_state.to(device).unsqueeze(0),
                prompt,
                max_new_tokens,
                target_type=target_type,
            )
        return self.tokens_to_string(generated)

In [6]:
# Load the SentencePiece tokenizer and hand it to the predictor; the same
# tokenizer object is later passed to the dataset as well.
sp = sentencepiece.SentencePieceProcessor(model_file="./artifacts/sp1000.model")
predictor = ActualBoardPredictor(context_length=512, sp=sp)

In [7]:
# Options for the dataset (test split read from ./processed_data).
conf = OmegaConf.create(dict(
    processed_path="./processed_data",
    split="test",
    count_past_boards=2,
    target_types=[0, 1, 2, 3, 4],
    context_length=512,
    stride_big_sequences=256,
))

# Engine-related options passed to the dataset alongside `conf`.
engine_conf = OmegaConf.create(dict(mate_value=10000))

ds = ActualBoardCommentaryDataset(conf, engine_conf, sp)

In [8]:
# Pick a few random dataset indices to inspect by hand. The count is capped at
# the dataset size because random.sample raises ValueError when asked for more
# items than the population holds. Call random.seed(...) beforehand if the same
# picks are needed across runs.
NUM_SAMPLES = 10
choices = random.sample(range(len(ds)), min(NUM_SAMPLES, len(ds)))

# Model inputs and the matching raw records, kept in the same order.
to_predict = [ds[i] for i in choices]
to_predict_metadata = [ds.get_raw_data(i) for i in choices]


In [None]:
# Put the model in evaluation mode before generating.
model.eval()


def render_board(position: Optional[str]) -> Image.Image:
    """Render a position as an RGBA image.

    ``position`` is passed straight to ``chess.Board`` (i.e. a FEN string);
    ``None`` renders an empty board instead.
    """
    board = None if position is None else chess.Board(position)
    return Image.open(BytesIO(svg2png(chess.svg.board(board)))).convert('RGBA')


# For every sampled item: generate commentary from an empty prompt, decode the
# reference commentary, and show both next to the previous and current boards.
samples = zip(to_predict, to_predict_metadata)
for (X_board, X_strength, X_reps, X_state, y_tokens, _), (current_board, past_board, current_eval, past_eval) in tqdm.tqdm(samples, total=len(to_predict), desc="Prediction"):
    predicted_text = predictor.predict(model, X_board, X_strength, X_reps, X_state, '', 1024, target_type=4)
    actual_text = predictor.tokens_to_string(y_tokens)

    past = render_board(past_board)
    # Each board is guarded by its own value: previously the current board was
    # drawn empty whenever the *previous* board was missing.
    curr = render_board(current_board)

    print(f"predicted: {predicted_text}")
    print(f"actual: {actual_text}")
    display(past)
    display(curr)

In [2]:
# Construct a fresh model directly from hyperparameters (no checkpoint is loaded
# here); this rebinds the name `model`.
# NOTE(review): every argument is positional, so what each number means is
# defined by the constructor in model.commentary_models and cannot be read off
# this cell — consider passing them by keyword. TODO confirm against the class.
model = ActualBoardTransformerMultipleHeadsModel(
    2,
    16,
    16,
    8,
    8,
    16,
    16,
    16,
    1,
    2,
    2,
    OmegaConf.create([[0, 1], [0, 2]]),
    "sgd",
    0.01
)

In [3]:
# Convert the model to TorchScript, then write the result to disk.
scripted_model = model.to_torchscript()
torch.jit.save(scripted_model, "./artifacts/model.pt")