In [20]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from typing import *
import torch
import itertools
from torch.nn.utils.rnn import pad_sequence
import math
from timeit import default_timer as timer
import torchsummary 
from nltk.translate.bleu_score import sentence_bleu

In [2]:
class ChessDataset(torch.utils.data.Dataset):
    X: List[Tuple[str, str]]
    def __init__(self, base_path: str, split: str, category: str, transform: Callable[[Tuple[str, str]], Tuple[str, str]] =None):
        self.__data = []
        
        with open(f"{base_path}/{split}.che-eng.{category}.che", "r") as fin, open(f"{base_path}/{split}.che-eng.{category}.en", "r") as fout:
            for line_in, line_out in zip(fin, fout):
                tokens_line_in = line_in.strip()
                tokens_line_out = line_out.strip()
                if transform is not None:
                   tokens_line_in, tokens_line_out = transform((tokens_line_in, tokens_line_out))
                self.__data.append((tokens_line_in, tokens_line_out))
    
    def __len__(self) -> int:
        return len(self.__data)
    
    def __getitem__(self, idx) -> Tuple[List[str], List[str]]:
        return self.__data[idx]
    

In [32]:
# Possible idea - add all one-move chess moves

In [3]:
# Vocab&dataloader logic
BASE_PATH = "../dataset"
SPLITS = ["train", "test", "valid"]
CATEGORIES = ["0attack", "0score", "0simple", "1attack", "1score", "1simple", "2.comparitiveattack", "2.comparitivescore", "2.comparitivesimple"]

target_token_transform = get_tokenizer('spacy')

UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

train_datasets = [ChessDataset(BASE_PATH, "train", category) for category in CATEGORIES]


train_iter = itertools.chain.from_iterable(train_datasets)
source_vocab = build_vocab_from_iterator(
    map(lambda x: x[0].split(" "), train_iter),
    min_freq=1,
    specials=special_symbols,
    special_first=True
)

train_iter = itertools.chain.from_iterable(train_datasets)
target_vocab = build_vocab_from_iterator(
    map(lambda x: target_token_transform(x[1]), train_iter),
    min_freq=1,
    specials=special_symbols,
    special_first=True
)

source_vocab.set_default_index(UNK_IDX)
target_vocab.set_default_index(UNK_IDX)

def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

source_transform = sequential_transforms(
    lambda x: x.split(" "),
    source_vocab,
    tensor_transform
)

target_transform = sequential_transforms(
    target_token_transform,
    target_vocab,
    tensor_transform
)

def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_sample = src_sample.rstrip("\n")
        tgt_sample = tgt_sample.rstrip("\n")
        src_batch.append(source_transform(src_sample))
        tgt_batch.append(target_transform(tgt_sample))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

def get_dataloader_for(split, category, batch_size):
    return torch.utils.data.DataLoader(ChessDataset(BASE_PATH, split, category), batch_size=batch_size, collate_fn=collate_fn)




In [4]:
#model
class PositionalEncoding(torch.nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = torch.nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: torch.Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
    
class TokenEmbedding(torch.nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
    
class Seq2SeqTransformer(torch.nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = torch.nn.Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = torch.nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: torch.Tensor,
                trg: torch.Tensor,
                src_mask: torch.Tensor,
                tgt_mask: torch.Tensor,
                src_padding_mask: torch.Tensor,
                tgt_padding_mask: torch.Tensor,
                memory_key_padding_mask: torch.Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: torch.Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)
    


In [5]:

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def generate_square_subsequent_mask(sz):
    mask = torch.tril(torch.ones((sz, sz), device=DEVICE))
    mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [6]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer, dataloader, loss_fn):
    model.train()
    losses = 0
    for src, tgt in dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(list(dataloader))


def evaluate(model, dataloader, loss_fn):
    model.eval()
    losses = 0

    for src, tgt in dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(list(dataloader))

In [52]:
def train_model_for_category(category: str, transformer = None):
    NUM_EPOCHS = 100
    torch.manual_seed(0)
    SRC_VOCAB_SIZE = len(source_vocab)
    TGT_VOCAB_SIZE = len(target_vocab)
    EMB_SIZE = 64
    NHEAD = 8
    FFN_HID_DIM = 128
    BATCH_SIZE = 64
    NUM_ENCODER_LAYERS = 6
    NUM_DECODER_LAYERS = 6
    
    if transformer is None:
        transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                         NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
        
        for p in transformer.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)
    
    val_dataloader = get_dataloader_for( "valid", category, BATCH_SIZE)
    dataloader = get_dataloader_for( "train", category, BATCH_SIZE)
    
    transformer = transformer.to(DEVICE)
    
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    
    optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

    
    for epoch in range(1, NUM_EPOCHS+1):
        start_time = timer()
        train_loss = train_epoch(transformer, optimizer, dataloader, loss_fn)
        end_time = timer()
        val_loss = evaluate(transformer, val_dataloader, loss_fn)
        print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
    return transformer

In [53]:
model = train_model_for_category(CATEGORIES[0])

Epoch: 1, Train loss: 8.331, Val loss: 7.178, Epoch time = 274.026s


KeyboardInterrupt: 

In [7]:
import pickle

In [51]:
pickle.dump(model, open("./save", "wb"))
model = train_model_for_category(CATEGORIES[0], model)

Epoch: 1, Train loss: 4.283, Val loss: 4.558, Epoch time = 283.741s
Epoch: 2, Train loss: 4.250, Val loss: 4.539, Epoch time = 275.639s
Epoch: 3, Train loss: 4.223, Val loss: 4.526, Epoch time = 281.756s


KeyboardInterrupt: 

In [13]:
import io
class CPU_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        else: return super().find_class(module, name)

model = CPU_Unpickler(open("save_drive.p", "rb")).load()

In [14]:
 def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    src = source_transform(src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join(target_vocab.lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")


In [15]:
ds = ChessDataset(BASE_PATH, "test", CATEGORIES[0])

In [16]:
dl = get_dataloader_for("test", CATEGORIES[0], 16)

In [17]:
for elem in dl:
    print(elem[0].shape)
    break

torch.Size([30, 16])


In [None]:
bleu_scores = {
    'bleu-2': [],
    'bleu-4': []
}

references = [
    target_transform(ds[i][1]) for i in range(len(ds))
]

for i in range(len(ds)):
    current_sample = ds[i][0]
    model_output = target_transform(translate(model, ds[i][0]))
    
    bleu_scores['bleu-2'].append(sentence_bleu(references, model_output, weights=(0, 1, 0, 0)))
    bleu_scores['bleu-4'].append(sentence_bleu(references, model_output, weights=(0, 0, 0, 1)))

In [None]:
print(sum(bleu_scores['bleu-2']) / len(bleu_scores['bleu-2']),
sum(bleu_scores['bleu-4']) / len(bleu_scores['bleu-4']))

In [38]:
ds[1][0]

'black black pawn h7 h5 <EOM> <EOMH> 17... h5 <EOR> white bishop <EOPA> <EOCA>'

In [58]:
ds[0][1]

'Stops the pawn and hopes for an exchange and maybe a pawn ...'

In [64]:
for i in range(0, 20): 
    print(ds[i][1], target_token_transform(ds[i][1]))

Stops the pawn and hopes for an exchange and maybe a pawn ... ['Stops', 'the', 'pawn', 'and', 'hopes', 'for', 'an', 'exchange', 'and', 'maybe', 'a', 'pawn', '...']
White can now no longer afford any passive move anymore , Black has a pawn phalanx marching straight toward the king , backed by a queen and a rook . ['White', 'can', 'now', 'no', 'longer', 'afford', 'any', 'passive', 'move', 'anymore', ',', 'Black', 'has', 'a', 'pawn', 'phalanx', 'marching', 'straight', 'toward', 'the', 'king', ',', 'backed', 'by', 'a', 'queen', 'and', 'a', 'rook', '.']
Will trade back my rook I lost for a minor piece earlier . ['Will', 'trade', 'back', 'my', 'rook', 'I', 'lost', 'for', 'a', 'minor', 'piece', 'earlier', '.']
need to remove his rook off the f file and temporarily put his b rook out of action . ['need', 'to', 'remove', 'his', 'rook', 'off', 'the', 'f', 'file', 'and', 'temporarily', 'put', 'his', 'b', 'rook', 'out', 'of', 'action', '.']
queen moves up . ['queen', 'moves', 'up', '.']
white to f