In [None]:
import argparse
import pickle
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model, GPT2Tokenizer
from torch.utils.data import IterableDataset, Dataset, DataLoader
import torch
import numpy as np

class UciGPT:
    """
    GPT-2 model fine-tuned on UCI notation with the help of gpt-2-simple.

    Everything is stored as class attributes, so the checkpoint and the
    tokenizer files are loaded when this class body is executed; the class is
    used directly (e.g. ``UciGPT.model``) rather than instantiated.
    See file uic_gpt2.py
    """
    # Start from the "gpt2-medium" configuration.
    config = GPT2Config.from_pretrained("gpt2-medium")
    # Ask for hidden states in the model output (presumably so that the
    # intermediate layers can be probed — TODO confirm).
    config.output_hidden_states = True
    # from_tf=True: the weights come from a TensorFlow checkpoint index file.
    model = GPT2Model.from_pretrained("uci_checkpoint/run1/model-18500.index", from_tf=True, config=config)
    # Tokenizer built from the encoder/BPE files stored next to the checkpoint.
    tokenizer = GPT2Tokenizer("uci_checkpoint/run1/encoder.json", "uci_checkpoint/run1/vocab.bpe")
    # Reuse the end-of-sequence token for padding and pad on the left side.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    # `name` is passed to the dataset classes as `model_name` to select which
    # move notation gets tokenized.
    name = "UciGPT"
    notation = "uci"

# create or open probing_dataset

def train_probing_classifier(chess_model, dataset):
    """Walk over ``dataset`` one sample at a time and print each model input.

    This is still a stub: ``chess_model`` is not used yet and no classifier
    is trained.

    Args:
        chess_model: Model wrapper the probing classifier is meant for.
        dataset: Dataset whose items are ``(model_input, board)`` pairs.
    """
    loader = DataLoader(dataset, batch_size=1)
    for batch in loader:
        # Every batch has to unpack into exactly (model input, board label).
        encoded, _board = batch
        print(encoded)
        # TODO: run chess_model.model on the encoded games, check that the
        #   output shape is (batch_size, sequence_length, hidden_size), and
        #   feed last_hidden_state to the probing classifier.


class ChessGamesDataset(IterableDataset):
    """Stream tokenized chess games and their board labels from a text file.

    Every line of the file is split on ``;``. The first field is tokenized
    for "PgnGPT", the second field for "UciGPT", and the last field holds the
    board label as space-separated integers (presumably wrapped in square
    brackets, like the ``str()`` of a NumPy array — TODO confirm against the
    code that writes the file).

    Args:
        filename: Path of the text file; it is read with the cp1252 encoding.
        model_name: Name of the model the samples are prepared for. Only
            "PgnGPT" and "UciGPT" are implemented.
        tokenizer: Callable that converts a move string into model inputs.
    """

    # Index of the ";"-separated field that each supported model tokenizes.
    _MOVE_FIELD = {"PgnGPT": 0, "UciGPT": 1}

    def __init__(self, filename, model_name, tokenizer):
        self.filename = filename
        self.model_name = model_name
        self.tokenizer = tokenizer

    @staticmethod
    def _parse_board(field):
        """Turn the raw board field into a 1-D integer array."""
        # Lines read from a file still end with their newline, so slicing off
        # the first and last character would remove the newline instead of
        # the closing bracket. Strip whitespace first, then the brackets.
        return np.fromstring(field.strip().strip("[]"), dtype=int, sep=' ')

    def game_mapper(self, file_iter):
        """Convert one line of the file into a ``(model_input, board)`` pair.

        Args:
            file_iter: A single line of the dataset file.

        Returns:
            A tuple of the tokenizer output for the selected move notation
            and the board label as an integer NumPy array.

        Raises:
            NotImplementedError: If no preprocessing exists for
                ``self.model_name`` (e.g. "SpecialGPT" or the pretrained
                model, which are still todo).
        """
        if self.model_name not in self._MOVE_FIELD:
            # Fail with a clear message instead of reaching the return
            # statement with no tokenized input available.
            raise NotImplementedError(
                f"preprocessing for model {self.model_name!r} is not implemented"
            )
        game = file_iter.split(";")
        moves = game[self._MOVE_FIELD[self.model_name]]
        encoded = self.tokenizer(moves, return_tensors='pt', padding='max_length', return_length=True)
        return encoded, self._parse_board(game[-1])

    def __iter__(self):
        """Yield one mapped sample per line, closing the file when done."""
        with open(self.filename, encoding='cp1252') as game_file:
            for line in game_file:
                yield self.game_mapper(line)
        
class ChessGamesDataset2(Dataset):
    """Map-style dataset over in-memory ``(pgn, uci, board)`` records.

    Args:
        data: Indexable collection whose items unpack into the PGN move
            text, the UCI move text and the board label.
        model_name: Name of the model the samples are prepared for. Only
            "UciGPT" and "PgnGPT" are implemented.
        tokenizer: Callable that converts a move string into model inputs.
    """

    def __init__(self, data, model_name, tokenizer):
        self.data = data
        self.model_name = model_name
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of stored games."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(model_input, board)`` pair for the game at ``index``.

        Raises:
            NotImplementedError: If no preprocessing exists for
                ``self.model_name``. Previously such models silently got
                ``None`` back, which only failed later during batching.
        """
        pgn, uci, board = self.data[index]
        if self.model_name == "UciGPT":
            return self.tokenizer(uci, return_tensors='pt', padding='max_length', return_length=True), board
        if self.model_name == "PgnGPT":
            # NOTE(review): unlike the UCI branch, no return_length is
            # requested here — confirm whether that difference is intended.
            return self.tokenizer(pgn, return_tensors='pt', padding='max_length'), board
        # todo SpecialGPT: convert strings to human readable text with
        #   feldman_gpt2, probably strip beginning tags
        # todo PretrainedGPT: check what preprocessing is necessary
        raise NotImplementedError(
            f"preprocessing for model {self.model_name!r} is not implemented"
        )
        
        
        
        
# Models to probe. Planned: pretrained_gpt, pgn_gpt, uci_gpt, special_gpt.
chess_models = [UciGPT]

# The probing records do not depend on the model, so they are loaded once and
# the file is closed right away.
# NOTE: pickle can execute arbitrary code while loading; only point this at a
# file from a trusted source.
with open("data/probing_dataset.txt", "rb") as data_file:
    data = pickle.load(data_file)

for model in chess_models:
    dataset = ChessGamesDataset2(data, model.name, model.tokenizer)
    # Streaming alternative:
    #   ChessGamesDataset("data/probing_dataset2.txt", model.name, model.tokenizer)

    train_probing_classifier(model, dataset)

    # todo save