# Setup

In [None]:
import os

In [None]:
# Environment switch: when True the notebook installs its dependencies, mounts
# Google Drive and clones the project repository; it also selects which
# BASE_PATH is used for data and checkpoints.
run_in_colab = True

In [None]:
# Colab-only bootstrap: install the Python dependencies, mount Google Drive
# and fetch the project code.
if run_in_colab:
  # `!` lines are IPython shell escapes, executed by the notebook kernel.
  !pip install transformers
  #!pip install wandb
  !pip install git+https://github.com/google-research/bleurt.git
  !pip install datasets

  # Mount Google Drive under /content/drive.
  from google.colab import drive
  drive.mount('/content/drive')
  
  # Clone the project into ./chess and make it the working directory so the
  # project-local imports (Models, Dataset, Configs, ...) can be resolved.
  !git clone https://github.com/nofarmordehai/Learn-Chess-Commentary.git 'chess'
  CODE_DIR = 'chess'
  os.chdir(f'./{CODE_DIR}')

In [None]:
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import torch
import time
from Models.GPT2 import GPT2
from Models.BERT import BERT
from Dataset.MovesDataset import MovesDataset
from Configs.train_config import config

In [None]:
# Root folder for data and model checkpoints: the mounted Google Drive folder
# when running in Colab, a local path otherwise.
BASE_PATH = (
  '/content/drive/MyDrive/NLP/'
  if run_in_colab
  else '/home/joberant/nlp_fall_2021/nofarm/chess/'
)

In [None]:
# Common prefix of the pickled game files; a file index and the '.p' suffix
# are appended to it when the dataset is built.
games_data_path = f'{BASE_PATH}Data/NEW_attack/games_data'

# Model



In [None]:
# Other checkpoints (currently disabled):
# path_base = BASE_PATH+ u'Models-test/0_1616907487.3351743_0.bin'
# path_legal = BASE_PATH+ u'Models-test/0_1616699942.8052278_0-legal.bin'

# Checkpoint files of the two models loaded by this notebook.
path_attack = f'{BASE_PATH}Models-Final/gpt2.bin'
path_bert = f'{BASE_PATH}Models-Final/bert.bin'

In [None]:
# Disabled model variants; the trailing note lists the input fields of each.
# gpt2 = GPT2() # fen, move, comment
# gpt2_legal = GPT2() # fen, move, desc_move, legal_moves, comment

# Wrappers for the two models used below; weights are loaded from the
# checkpoint files in a later cell.
gpt2_attack = GPT2() # fen, move, desc_move, attacks and attack by, comment
bert = BERT() # fen, move, desc_move, attacks and attack by, comment

In [None]:
# Disabled: loading these checkpoints fails with "size mismatch for transformer".
# gpt2_legal.load_model(path_legal)
# gpt2.load_model(path_base)

# Load each wrapper's weights from its checkpoint file, GPT-2 first.
for _wrapper, _checkpoint in ((gpt2_attack, path_attack), (bert, path_bert)):
  _wrapper.load_model(_checkpoint)

In [None]:
# gpt2.model = gpt2.model.eval()
# gpt2_legal.model = gpt2_legal.model.eval()

# Put both models in evaluation mode and move them to the GPU.
for _wrapper in (gpt2_attack, bert):
  _wrapper.model = _wrapper.model.eval().cuda()

In [None]:
# Name of the model to evaluate; the selection cell that follows handles
# 'gpt2' and 'bert-base'.
model = 'gpt2'

In [None]:
# Pick the model under test together with the maximum sequence length and the
# end-of-text marker that go with it.
# NOTE: `max` shadows the builtin of the same name; the name is kept because
# other cells of this notebook read it.
if model == 'gpt2':
  tested_model = gpt2_attack
  max = 768
  eof = '<|endoftext|>'
elif model == 'bert-base':
  tested_model = bert
  max = 512
  eof = 'endoftext'
else:
  # Fail fast: without this branch the three names stay undefined, or keep the
  # values of a previous run of this cell, and the wrong model is evaluated
  # without any warning.
  raise ValueError(f"Unknown model {model!r}; expected 'gpt2' or 'bert-base'")

In [None]:
# Id of the '[PAD]' token, later passed to generation as `pad_token_id`.
# add_special_tokens=False stops the tokenizer from wrapping the text in
# markers such as BERT's [CLS]/[SEP]; with them, element 0 would be the [CLS]
# id instead of the '[PAD]' id.
# NOTE(review): assumes `tested_model.tokenizer` is a Hugging Face tokenizer - confirm.
pad_token_id = tested_model.tokenizer('[PAD]', add_special_tokens=False)['input_ids'][0]

# Dataset

In [None]:
# The last pickle, whose index is config['NUMER_OF_DATA_DIRS'], is our test-set.
test_set_files = [f"{games_data_path}{config['NUMER_OF_DATA_DIRS']}.p"]
dataset = MovesDataset(test_set_files, tested_model.tokenizer, max_length=max)

In [None]:
# Shuffled batches over the test-set, sized by the training configuration.
dataloader = DataLoader(
  dataset,
  batch_size=config['batch_size'],
  shuffle=True,
)

# Human test

In [None]:
def get_results():
  """Generate one text for every sample of a single batch drawn from `dataloader`.

  Each sample is truncated right after its first `dataset.comment_encoding`
  token and the remaining prefix is used as the prompt for
  `tested_model.model.generate` (beam search with 2 beams, no repeated
  bigrams, at most `max + 1` tokens).

  Returns:
    list[str]: the decoded generation (special tokens skipped) for each sample
    of the batch, in batch order.

  Raises:
    ValueError: if a sample does not contain `dataset.comment_encoding`.
  """
  # Only the token ids are needed for free generation.
  proccessed_data, _attn_masks, _labels = next(iter(dataloader))

  results = []
  with torch.no_grad():
    # Iterate over the samples actually present in the batch: it can be shorter
    # than config['batch_size'] (e.g. a test-set smaller than one batch), in
    # which case indexing up to the configured size would raise IndexError.
    for token_ids in proccessed_data:
      # Keep everything up to and including the comment marker as the prompt.
      comment_idx = list(token_ids).index(dataset.comment_encoding) + 1
      input_encoding = token_ids[:comment_idx].unsqueeze(0).cuda()

      # `max` is the notebook-level maximum sequence length, not the builtin.
      outputs = tested_model.model.generate(input_encoding, num_beams=2, no_repeat_ngram_size=2, max_length=max+1, pad_token_id=pad_token_id)
      results.append(tested_model.tokenizer.decode(outputs[0], skip_special_tokens=True))

  return results

In [None]:
# Generate texts for one batch and print them for manual inspection.
results = get_results()
print(results)

# Evaluation

Every t iterations, calculate the evaluation metrics for the current model and save the results.

In [None]:
from Evaluation.Metrics import perplexity, bleurt, bleu

In [None]:
from Utils import get_targets_and_outputs

In [None]:
# Perplexity of the selected model, computed by the project's `perplexity`
# metric over the test-set dataloader.
test_perplexity = perplexity(tested_model.model, dataloader)

In [None]:
# Show the test-set perplexity.
print(test_perplexity)

In [None]:
# Reference texts and generated texts for the test-set, produced by the
# project helper `get_targets_and_outputs`.
target_texts, output_texts = get_targets_and_outputs(
  tested_model,
  dataset,
  dataset.comment_encoding,
  pad_token_id,
  max_length=max,
  eof=eof,
)

In [None]:
# BLEURT scores of the generated texts against the reference texts; the result
# is averaged in the next cell.
test_bleurt = bleurt(target_texts, output_texts)

In [None]:
# Average BLEURT score over the test-set.
mean_bleurt = sum(test_bleurt) / len(test_bleurt)
print(mean_bleurt)

In [None]:
# BLEU scores of the generated texts against the reference texts; the result
# is averaged in the next cell.
test_bleu = bleu(target_texts, output_texts)

In [None]:
# Average BLEU score over the test-set.
mean_bleu = sum(test_bleu) / len(test_bleu)
print(mean_bleu)