In [1]:
import os

import wandb
import lightning
import torch 
from model.commentary_models import ActualBoardTransformerMultipleHeadsModel
import sentencepiece
from typing import *
from data.ActualBoardCommentaryDataset import ActualBoardCommentaryDataset
from omegaconf import OmegaConf
import random 
import tqdm
import chess
from PIL import Image
from cairosvg import svg2png
from io import BytesIO
from serve_utils.SamplingStrategies import TopKSamplingStrategy, MultinomialSamplingStrategy
import multiprocessing
import evaluate

In [20]:
TARGET_TYPES_TO_IDS = {
    'MoveDesc': 0,
    'MoveQuality': 1,
    'Comparative': 2,
    "Strategy": 3,
    "Context": 4,
    "General": None
}

FILES = {
    key: {
        "ground_truth": open(f"./artifacts/commentary_ground_truth_{key}.txt", "r"), 
        "generated_topk": open(f"./artifacts/commentary_generated_topk_{key}.txt", "r"),
        "generated_sample": open(f"./artifacts/commentary_generated_sample_{key}.txt", "r")
    }
    for key in TARGET_TYPES_TO_IDS.keys()
}

bleu_eval = evaluate.load("bleu")
meteor_eval = evaluate.load("meteor")

for key in TARGET_TYPES_TO_IDS.keys():
    with    open(f"./artifacts/commentary_ground_truth_{key}.txt", "r") as f_ground_truth, open(f"./artifacts/commentary_generated_topk_{key}.txt", "r") as f_topk, open(f"./artifacts/commentary_generated_sample_{key}.txt", "r") as f_sample:
        ground_truths = []
        topks = []
        samples = []
        
        for line in f_ground_truth:
            ground_truths.append(line)
        for line in f_sample:
            samples.append(line)
        for line in f_topk:
            topks.append(line)
            
        minlen = min(len(ground_truths), len(samples), len(topks))
        
        ground_truths = ground_truths[:minlen]
        samples = samples[:minlen]
        topks = topks[:minlen]
        
        for predictions, mode in zip([topks, samples], ['topk', 'sample']):
            meteor = (meteor_eval.compute(predictions=predictions, references=ground_truths)['meteor'])
            bleu2 = (bleu_eval.compute(predictions=predictions, references=ground_truths, max_order=2)['bleu'])
            bleu4 = (bleu_eval.compute(predictions=predictions, references=ground_truths, max_order=4)['bleu'])
            
            print(f"Category {key} in {mode} mode has: Meteor: {meteor * 100:2f}%, Bleu-2: {bleu2 * 100:2f}%, Bleu-4: {bleu4 * 100:2f}%")
        

[nltk_data] Downloading package wordnet to
[nltk_data]     /home/georgerapeanu/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /home/georgerapeanu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /home/georgerapeanu/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Category MoveDesc in topk mode has: Meteor: 6.451449%, Bleu-2: 0.332758%, Bleu-4: 0.090764%
Category MoveDesc in sample mode has: Meteor: 8.519064%, Bleu-2: 3.171120%, Bleu-4: 0.422949%
Category MoveQuality in topk mode has: Meteor: 35.820571%, Bleu-2: 9.907897%, Bleu-4: 4.170502%
Category MoveQuality in sample mode has: Meteor: 21.308597%, Bleu-2: 6.462269%, Bleu-4: 0.000000%
Category Comparative in topk mode has: Meteor: 18.552224%, Bleu-2: 11.780189%, Bleu-4: 5.478214%
Category Comparative in sample mode has: Meteor: 16.559216%, Bleu-2: 10.465320%, Bleu-4: 3.520017%
Category Strategy in topk mode has: Meteor: 9.347631%, Bleu-2: 3.781002%, Bleu-4: 0.669128%
Category Strategy in sample mode has: Meteor: 10.003106%, Bleu-2: 3.735307%, Bleu-4: 0.535825%
Category Context in topk mode has: Meteor: 7.488191%, Bleu-2: 1.418690%, Bleu-4: 0.230641%
Category Context in sample mode has: Meteor: 8.814915%, Bleu-2: 3.588633%, Bleu-4: 0.594800%
Category General in topk mode has: Meteor: 8.038120%,