### LSTM Chess engine
A deliberately nonsensical chess engine: an LSTM that learns to predict the next move token from PGN game data.
- Predicted moves are not checked for legality.

In [48]:
# Imports
import numpy as np

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import pandas as pd
from collections import Counter
from pathlib import Path
import re

In [81]:
# Parameters
epochs = 6            # number of passes over the data, handed to train()
batch_size = 200      # number of windows per DataLoader batch
sequence_length = 4   # number of tokens in each window served by Dataset

In [82]:
class Model(nn.Module):
    """Embedding -> stacked LSTM -> linear layer that scores the next token.

    The LSTM is created without ``batch_first``, so it steps along
    dimension 0 of its input; dimension 1 of the input must match
    dimension 1 of the recurrent state.

    Args:
        dataset: object exposing ``uniq_words``; its length is the
            vocabulary size.
        lstm_size: hidden size of every LSTM layer.
        embedding_dim: size of the token embedding vectors.
        num_layers: number of stacked LSTM layers.
        dropout: dropout applied between LSTM layers.
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128,
                 num_layers=3, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embedding vectors, so its input width is the
            # embedding size, not the hidden size (they only happen to be
            # equal with the defaults).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fully_connected = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Score the vocabulary at every position of ``x``.

        Args:
            x: integer tensor of token indexes with two dimensions.
            prev_state: ``(h, c)`` tuple as produced by ``init_state`` or
                returned by a previous call.

        Returns:
            ``(logits, state)`` where ``logits`` has the shape of ``x`` plus a
            trailing vocabulary dimension and ``state`` is the new ``(h, c)``.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fully_connected(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zeroed ``(h, c)`` pair.

        Each tensor has shape ``(num_layers, sequence_length, lstm_size)``.
        The middle dimension must equal dimension 1 of the tensors later
        passed to ``forward``.
        """
        shape = (self.num_layers, sequence_length, self.lstm_size)
        # Allocate next to the model's weights so the state shares their
        # device and dtype.
        weight = self.fully_connected.weight
        return (weight.new_zeros(shape), weight.new_zeros(shape))

In [83]:
# PGN parsing
# Builds `words`, the flat list of move tokens the model is trained on.
# NOTE(review): assumes one game per line written as "1. e4 e5 2. Nf3 ..." -
# confirm against the data file.
max_games = 2000        # stop after this many accepted games
min_split_parts = 10    # games that split into fewer pieces are skipped

txt = Path("./datasets/pgn_small.txt").read_text()
games = txt.split("\n")
clean_moves = []
words = []
c = 0  # number of games accepted so far
for game in games:
    if c >= max_games:
        break

    if game == "":
        continue
    # Raw string: "\d" in a normal string literal is an invalid escape.
    moves = re.split(r"\d+\.", game)

    # Don't include short games
    # (the split also yields the text before the first move number, so this
    # counts pieces rather than full moves)
    if len(moves) < min_split_parts:
        continue

    c = c + 1
    # Strip before filtering so whitespace-only pieces are dropped too.
    clean_moves = [m.strip() for m in moves if m.strip()]
    for move in clean_moves:
        # Keep at most the first two tokens after each move number. Slicing
        # avoids an IndexError when only one token is present, and split()
        # with no argument ignores repeated spaces.
        words.extend(move.split()[:2])

In [84]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window view over a flat list of tokens.

    Item ``i`` is a pair of index tensors: the window starting at token ``i``
    and the same window shifted one token to the right.

    ``sequence_length`` and ``batch_size`` are copied from the module-level
    settings when the dataset is constructed.
    """

    def __init__(self, words):
        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.words = words
        self.uniq_words = self.get_uniq_words()
        print(f"unique words {len(self.uniq_words)}")

        # Two-way mapping between a token and its position in uniq_words.
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def get_uniq_words(self):
        """Return the distinct tokens, most frequent first.

        Tokens with equal counts keep their first-seen order.
        """
        return [word for word, _ in Counter(self.words).most_common()]

    def __len__(self):
        """Number of full windows that also have a full shifted target.

        Clamped at zero: a negative length would make ``len()`` raise when
        there are fewer tokens than ``sequence_length``.
        """
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(inputs, targets)`` for the window starting at ``index``."""
        stop = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )

In [85]:
def train(dataset, model, epochs, lr=0.002, log_every=20):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Batches are drawn in dataset order (the DataLoader default) and the
    recurrent state is carried from one batch to the next within an epoch;
    it is reset at the start of every epoch.

    Args:
        dataset: dataset yielding ``(inputs, targets)`` index tensors and
            exposing ``batch_size`` and ``sequence_length``.
        model: module providing ``init_state(n)`` and
            ``forward(x, state) -> (logits, state)``.
        epochs: number of passes over ``dataset``.
        lr: Adam learning rate.
        log_every: print the current loss every this many batches.
    """
    model.train()

    # Use the settings stored on the dataset itself, so the batching and the
    # state shape always agree with the data actually being served.
    dataloader = DataLoader(dataset, batch_size=dataset.batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    print(f"Total batchs {len(dataloader)}")

    for epoch in range(epochs):
        state_h, state_c = model.init_state(dataset.sequence_length)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss wants the class scores on dimension 1.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Cut the autograd history so backprop stops at the batch
            # boundary instead of reaching into earlier batches.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            if (batch + 1) % log_every == 0:
                print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })
    print("All done!!!")

In [86]:
def predict(dataset, model, text, next_words=1):
    """Sample ``next_words`` continuation tokens for the prompt ``text``.

    The model's LSTM steps along dimension 0 of its input, so the prompt is
    fed as a column of shape ``(n_prompt_words, 1)``: every prompt word then
    conditions the first prediction. Afterwards only the newly sampled token
    is fed, together with the carried recurrent state.

    Args:
        dataset: object exposing ``word_to_index`` and ``index_to_word``.
        model: module providing ``init_state(n)`` and
            ``forward(x, state) -> (logits, (h, c))``.
        text: prompt tokens separated by single spaces.
        next_words: number of tokens to sample.

    Returns:
        The prompt tokens followed by the sampled tokens.

    Raises:
        KeyError: if a prompt token is not in ``dataset.word_to_index``.
    """
    model.eval()

    words = text.split(' ')
    # One sequence is generated, so the state's middle dimension is 1.
    state_h, state_c = model.init_state(1)
    pending = list(words)  # tokens the model has not consumed yet

    with torch.no_grad():
        for _ in range(next_words):
            x = torch.tensor([[dataset.word_to_index[w]] for w in pending])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            # Scores after the last consumed token of the single sequence.
            last_word_logits = y_pred[-1][0]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).numpy()
            word_index = int(np.random.choice(len(last_word_logits), p=p))

            next_word = dataset.index_to_word[word_index]
            words.append(next_word)
            pending = [next_word]

    return words

In [87]:
# Build the vocabulary and windows from the parsed tokens, size the network
# to that vocabulary, then fit it for the configured number of epochs.
dataset = Dataset(words=words)
model = Model(dataset=dataset)
train(dataset=dataset, model=model, epochs=epochs)

unique words 2367
Total batchs 676
{'epoch': 0, 'batch': 19, 'loss': 6.350759983062744}
{'epoch': 0, 'batch': 39, 'loss': 6.603416919708252}
{'epoch': 0, 'batch': 59, 'loss': 6.2618865966796875}
{'epoch': 0, 'batch': 79, 'loss': 6.656838893890381}
{'epoch': 0, 'batch': 99, 'loss': 6.025484561920166}
{'epoch': 0, 'batch': 119, 'loss': 6.6966938972473145}
{'epoch': 0, 'batch': 139, 'loss': 6.281108379364014}
{'epoch': 0, 'batch': 159, 'loss': 5.565322399139404}
{'epoch': 0, 'batch': 179, 'loss': 6.430304050445557}
{'epoch': 0, 'batch': 199, 'loss': 6.0007734298706055}
{'epoch': 0, 'batch': 219, 'loss': 6.396431922912598}
{'epoch': 0, 'batch': 239, 'loss': 5.893108367919922}
{'epoch': 0, 'batch': 259, 'loss': 6.198031425476074}
{'epoch': 0, 'batch': 279, 'loss': 5.9380693435668945}
{'epoch': 0, 'batch': 299, 'loss': 6.046158313751221}
{'epoch': 0, 'batch': 319, 'loss': 6.363229274749756}
{'epoch': 0, 'batch': 339, 'loss': 6.018549919128418}
{'epoch': 0, 'batch': 359, 'loss': 6.08706235885

{'epoch': 4, 'batch': 419, 'loss': 5.189685821533203}
{'epoch': 4, 'batch': 439, 'loss': 5.065160751342773}
{'epoch': 4, 'batch': 459, 'loss': 4.8421406745910645}
{'epoch': 4, 'batch': 479, 'loss': 4.982901573181152}
{'epoch': 4, 'batch': 499, 'loss': 5.232603549957275}
{'epoch': 4, 'batch': 519, 'loss': 5.111843585968018}
{'epoch': 4, 'batch': 539, 'loss': 5.08915901184082}
{'epoch': 4, 'batch': 559, 'loss': 5.013657569885254}
{'epoch': 4, 'batch': 579, 'loss': 5.232427597045898}
{'epoch': 4, 'batch': 599, 'loss': 4.856278896331787}
{'epoch': 4, 'batch': 619, 'loss': 5.163305759429932}
{'epoch': 4, 'batch': 639, 'loss': 5.10181188583374}
{'epoch': 4, 'batch': 659, 'loss': 5.398383617401123}
{'epoch': 5, 'batch': 19, 'loss': 4.8831562995910645}
{'epoch': 5, 'batch': 39, 'loss': 5.129215717315674}
{'epoch': 5, 'batch': 59, 'loss': 4.617417812347412}
{'epoch': 5, 'batch': 79, 'loss': 5.222854137420654}
{'epoch': 5, 'batch': 99, 'loss': 4.365396022796631}
{'epoch': 5, 'batch': 119, 'loss'

In [80]:
# Sample a single follow-up token for the opening move e4.
print(predict(dataset, model, text='e4', next_words=1))

['e4', 'Nc3']
