In [1]:
import contextlib
import multiprocessing
import os
import random
from io import BytesIO
from typing import *

import chess
import lightning
import sentencepiece
import torch 
import tqdm
import wandb
from cairosvg import svg2png
from omegaconf import OmegaConf
from PIL import Image

from data.ActualBoardCommentaryDataset import ActualBoardCommentaryDataset
from model.commentary_models import ActualBoardTransformerMultipleHeadsModel
from serve_utils.SamplingStrategies import TopKSamplingStrategy, MultinomialSamplingStrategy


In [2]:
# Serialized TorchScript commentary model produced by the training pipeline.
MODEL_PATH = "./artifacts/model.pt"
model = torch.jit.load(MODEL_PATH)

In [3]:
class ActualBoardPredictor:
    """Autoregressive text generator for the board-conditioned commentary model.

    Holds a SentencePiece tokenizer and two sampling strategies. ``predict`` uses
    ``MultinomialSamplingStrategy`` when ``do_sample`` is true and
    ``TopKSamplingStrategy`` otherwise. Both are constructed with ``1.0``; what that
    argument controls lives in ``serve_utils.SamplingStrategies`` (presumably a
    temperature — TODO confirm).
    """

    def __init__(
            self,
            context_length: int,
            sp: "sentencepiece.SentencePieceProcessor"
    ):
        """
        Args:
            context_length: maximum number of text tokens fed to the model per step.
                Older tokens are dropped from the model's input window only, never
                from the returned text.
            sp: tokenizer used to encode the prompt and decode the result.
        """
        self.__sp = sp
        self.__context_length = context_length
        self.__topk = TopKSamplingStrategy(1.0)
        self.__multinomial = MultinomialSamplingStrategy(1.0)

    def tokens_to_string(self, tokens: torch.Tensor) -> str:
        """Decode a tensor of token ids of any shape (flattened in row-major order) to text."""
        # reshape (unlike view) also accepts non-contiguous tensors.
        return self.__sp.decode(tokens.reshape(-1).tolist())

    def predict(self, model, X_board, X_strength, X_reps, X_state, text: str, max_new_tokens: int, target_type: Optional[int] = None, do_sample: bool = False) -> str:
        """Generate commentary for a single example, one token per step.

        Args:
            model: callable invoked as ``model(X_board, X_strength, X_reps, X_state,
                text_window, padding_mask, target_type=...)`` that returns
                ``(logits, _)``. ``logits[:, -1, :]`` is read as the next-token scores.
            X_board, X_strength, X_reps, X_state: one unbatched example. A leading
                batch dimension of size 1 is added here.
            text: prompt text. It is encoded and prefixed with the BOS id.
            max_new_tokens: upper bound on generated tokens. Generation also stops
                right after EOS is produced.
            target_type: optional commentary-category id, passed to the model as a tensor.
            do_sample: use the multinomial strategy instead of the top-k one.

        Returns:
            The decoded sequence BOS + prompt + generated tokens (including EOS if it
            was produced). Only the model's *input* is limited to the last
            ``context_length`` tokens; the returned sequence is never truncated.
        """
        device = X_board.device
        prompt_ids = [self.__sp.bos_id()] + self.__sp.encode(text)
        # Full running sequence that is returned; model inputs are windows over it.
        generated = torch.tensor(prompt_ids, device=device).unsqueeze(0)
        X_board = X_board.unsqueeze(0)
        X_strength = X_strength.unsqueeze(0)
        X_reps = X_reps.unsqueeze(0)
        X_state = X_state.unsqueeze(0)
        target = None if target_type is None else torch.tensor(target_type, device=device)

        sampler = self.__multinomial if do_sample else self.__topk
        eos_id = self.__sp.eos_id()

        with torch.no_grad():
            for _ in range(max_new_tokens):
                window = generated[:, -self.__context_length:]
                # All-False mask: nothing in the window is padding.
                padding_mask = torch.zeros(1, window.size(1), dtype=torch.bool, device=device)
                logits, _ = model(X_board, X_strength, X_reps, X_state, window,
                                  padding_mask, target_type=target)
                next_token = sampler.execute(logits[:, -1, :])
                # Append exactly once; the previous version appended the last token twice.
                generated = torch.cat([generated, next_token], dim=1)
                if next_token.item() == eos_id:
                    break
        return self.tokens_to_string(generated)


In [4]:
# SentencePiece tokenizer shared by the dataset (prompt/targets) and the predictor (decoding).
sp = sentencepiece.SentencePieceProcessor("./artifacts/sp2000.model")
# Model input window of 512 text tokens — the same value as the dataset's context_length.
predictor = ActualBoardPredictor(context_length=512, sp=sp)

In [5]:
# Test-split dataset: each item is consumed below as
# (X_board, X_strength, X_reps, X_state, y_tokens, types).
conf = OmegaConf.create(dict(
    processed_path="./processed_data",
    split="test",
    count_past_boards=2,
    target_types=[0, 1, 2, 3, 4],
    context_length=512,
    stride_big_sequences=256,
))

# Engine settings forwarded to the dataset unchanged.
engine_conf = OmegaConf.create(dict(mate_value=10000))

ds = ActualBoardCommentaryDataset(conf, engine_conf, sp)

In [7]:
# Commentary category -> index of its flag in each example's `types` vector.
# That index is also passed to the model as `target_type`. "General" (None)
# disables both the filtering and the conditioning.
TARGET_TYPES_TO_IDS = {
    'MoveDesc': 0,
    'MoveQuality': 1,
    'Comparative': 2,
    "Strategy": 3,
    "Context": 4,
    "General": None
}

# One file per (category, output kind). Line i of each file belongs to the same
# kept example, so the three files can be compared line by line.
OUTPUT_KINDS = ("ground_truth", "generated_topk", "generated_sample")
MAX_NEW_TOKENS = 1024
# An example whose category flag is at or below this value is treated as not
# belonging to the category.
TYPE_PRESENCE_THRESHOLD = 1e-4
SEED = 0


def _one_line(text: str) -> str:
    """Collapse line breaks so each example occupies exactly one line in its output file."""
    return " ".join(text.splitlines())


def process_entry(target_type_entry):
    """Write ground truth plus greedy and sampled generations for one commentary category.

    Args:
        target_type_entry: ``(name, target_type)`` pair from ``TARGET_TYPES_TO_IDS``.
            When ``target_type`` is not None, examples whose ``types[target_type]`` is
            at or below ``TYPE_PRESENCE_THRESHOLD`` are skipped. ``None`` keeps
            every example.

    Side effects:
        Appends one line per kept example to each of the three files in ``FILES[name]``.
    """
    name, target_type = target_type_entry
    outputs = FILES[name]
    for X_board, X_strength, X_reps, X_state, y_tokens, types in tqdm.tqdm(ds, desc=name):
        if target_type is not None and types[target_type].item() <= TYPE_PRESENCE_THRESHOLD:
            continue
        outputs["ground_truth"].write(_one_line(predictor.tokens_to_string(y_tokens)) + "\n")
        for kind, do_sample in (("generated_topk", False), ("generated_sample", True)):
            generated = predictor.predict(model, X_board, X_strength, X_reps, X_state, '',
                                          MAX_NEW_TOKENS, target_type=target_type,
                                          do_sample=do_sample)
            outputs[kind].write(_one_line(generated) + "\n")


model.eval()
# Make the do_sample=True generations reproducible. The multinomial strategy
# presumably draws from torch's RNG — TODO confirm in serve_utils.SamplingStrategies.
random.seed(SEED)
torch.manual_seed(SEED)

# ExitStack closes (and so flushes) every file, even if this multi-hour run fails
# midway or opening one of the later files raises.
with contextlib.ExitStack() as _stack:
    FILES = {
        key: {
            kind: _stack.enter_context(
                open(f"./artifacts/commentary_{kind}_{key}.txt", "w", encoding="utf-8"))
            for kind in OUTPUT_KINDS
        }
        for key in TARGET_TYPES_TO_IDS
    }
    for entry in TARGET_TYPES_TO_IDS.items():
        process_entry(entry)

MoveDesc: 100%|██████████| 33373/33373 [16:39<00:00, 33.40it/s] 
MoveQuality: 100%|██████████| 33373/33373 [00:17<00:00, 1894.09it/s]
Comparative: 100%|██████████| 33373/33373 [01:36<00:00, 344.69it/s] 
Strategy: 100%|██████████| 33373/33373 [8:38:05<00:00,  1.07it/s]    
Context: 100%|██████████| 33373/33373 [1:44:41<00:00,  5.31it/s]   
General: 100%|██████████| 33373/33373 [8:58:41<00:00,  1.03it/s]    
