### LSTM Chess engine
A totally nonsensical chess engine that learns to predict the next move from PGN data.
- No checks for legal moves

In [269]:
# Imports
import numpy as np

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import pandas as pd
from collections import Counter
from pathlib import Path
import re

In [270]:
# Parameters
# Notebook-wide settings; later cells read these names as globals.
epochs = 30           # passed to train() as the number of passes over the dataset
batch_size = 200      # batch size handed to the DataLoader inside train()
sequence_length = 4   # tokens per training window built by Dataset

In [271]:
class Model(nn.Module):
    """Embedding -> stacked LSTM -> linear layer that scores every vocabulary entry.

    Args:
        dataset: Object exposing ``uniq_words``; its length is the vocabulary size
            used for both the embedding table and the output layer.
        lstm_size: Hidden size of each LSTM layer (default 128).
        embedding_dim: Width of the token embeddings (default 128).
        num_layers: Number of stacked LSTM layers (default 3).
        dropout: Dropout applied between LSTM layers (default 0.2). Ignored when
            ``num_layers`` is 1, where inter-layer dropout has no effect.

    NOTE(review): the LSTM is created without ``batch_first``, so PyTorch treats
    axis 0 of the embedded input as the time axis and axis 1 as the batch axis.
    ``train()`` in this notebook feeds DataLoader batches, whose axis 0 indexes
    samples -- confirm this layout is what is intended before changing it.
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embeddings, so its input width is the embedding
            # width (previously lstm_size, which only worked while both were 128).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            # nn.LSTM warns when dropout is non-zero on a single-layer network.
            dropout=dropout if self.num_layers > 1 else 0.0,
        )
        self.fully_connected = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Score the next token for every position of ``x``.

        Args:
            x: Integer tensor of vocabulary indices with two axes.
            prev_state: ``(h, c)`` tuple as returned by ``init_state`` or by a
                previous call; its middle axis must equal ``x.shape[1]``.

        Returns:
            ``(logits, state)`` where ``logits`` has shape ``(*x.shape, n_vocab)``
            and ``state`` is the updated ``(h, c)`` tuple.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fully_connected(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return zeroed ``(h, c)`` tensors of shape ``(num_layers, sequence_length, lstm_size)``.

        The middle axis has to match axis 1 of the tensors later passed to
        ``forward``. The state is created on the same device and dtype as the
        model weights so it keeps working after ``model.to(device)``.
        """
        weight = self.embedding.weight
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

In [272]:
# PGN parsing
# Reads one game per line, keeps the first 100 sufficiently long games and
# flattens each into a list of half-move tokens.
txt = Path("./datasets/pgn_small.txt").read_text()
games = txt.split("\n")
clean_moves = []
words = []          # every token of every kept game, in reading order
sequence_list = []  # one token list per kept game
c = 0               # number of games kept so far
for game in games:
    if c >= 100:
        break

    if game == "":
        continue
    # Split on move numbers ("1.", "23.", ...). Raw string: "\d" inside a plain
    # literal is an invalid escape sequence and triggers a warning.
    moves = re.split(r"\d+\.", game)

    # Don't include short games
    if len(moves) < 10:
        continue

    c = c + 1
    # Strip first, then filter, so whitespace-only segments are dropped too.
    clean_moves = [m.strip() for m in moves if m.strip()]
    seq = []
    for move in clean_moves:
        # Keep at most the first two tokens of the segment, as before. Using
        # split() plus slicing tolerates repeated spaces and a segment with a
        # single half-move, which previously raised IndexError on m[1].
        half_moves = move.split()[:2]
        words.extend(half_moves)
        seq.extend(half_moves)
    sequence_list.append(seq)

In [273]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window dataset of (input, target) index sequences.

    Each game in ``sequence_list`` is cut into overlapping windows of
    ``sequence_length`` tokens; the target of a window is the same window
    shifted one token to the right. Tokens are mapped to integer ids ordered by
    descending frequency in ``words``.

    The window length and batch size are read from the notebook-level
    ``sequence_length`` and ``batch_size`` variables at construction time.
    """

    def __init__(self, words, sequence_list):
        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.words = words
        self.sequence_list = sequence_list
        self.uniq_words = self.get_uniq_words()

        # Two-way vocabulary lookup; id 0 is the most frequent token.
        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {
            token: position for position, token in self.index_to_word.items()
        }

        self.words_indexes = list(map(self.word_to_index.__getitem__, self.words))

        # Pre-compute every training window once so __getitem__ is a plain lookup.
        span = self.sequence_length
        self.data = []
        for game in sequence_list:
            for start in range(len(game) - span):
                # span + 1 consecutive tokens: the first `span` are the input,
                # the last `span` are the one-step-shifted target.
                encoded = [
                    self.word_to_index[token]
                    for token in game[start:start + span + 1]
                ]
                self.data.append([encoded[:-1], encoded[1:]])

        print(f"Total sequences {len(sequence_list)}")
        print(f"Unique moves {len(self.uniq_words)}")
        print(f"Data length {len(self.data)}")

    def get_uniq_words(self):
        """Return the distinct tokens of ``self.words``, most frequent first."""
        return [token for token, _ in Counter(self.words).most_common()]

    def __len__(self):
        """Number of pre-computed windows."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``index``-th window as an ``(input, target)`` pair of tensors."""
        inputs, targets = self.data[index]
        return torch.tensor(inputs), torch.tensor(targets)

In [274]:
def train(dataset, model, epochs):
    """Train ``model`` on ``dataset`` with Adam (lr=0.01) and cross-entropy loss.

    Reads the notebook-level ``batch_size`` (DataLoader batch size) and
    ``sequence_length`` (size passed to ``model.init_state``). The LSTM state is
    reset at the start of every epoch and carried, detached, from one batch to
    the next within an epoch. Progress is printed every 100 batches and once
    more for the last batch of each epoch.

    Args:
        dataset: Map-style dataset yielding ``(x, y)`` pairs of index tensors.
        model: Module called as ``model(x, (h, c))`` that returns
            ``(logits, (h, c))`` and provides ``init_state(n)``.
        epochs: Number of passes over the dataset.

    Raises:
        ValueError: If the dataset produces no batches.
    """
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    n_batches = len(dataloader)
    print(f"Total batchs {n_batches}")
    if n_batches == 0:
        # Previously this fell through to the epoch summary print and failed
        # with an unbound `batch`/`loss`; fail early with a clear message.
        raise ValueError("train() needs a non-empty dataset; got 0 batches")

    log_every = 100
    for epoch in range(epochs):
        state_h, state_c = model.init_state(sequence_length)

        for batch, (x, y) in enumerate(dataloader, start=1):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class axis second, so the vocabulary
            # axis (last) is swapped into position 1.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Cut the graph so backprop does not reach into earlier batches.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            # Periodic progress line plus exactly one summary line per epoch
            # (no duplicate when the batch count is a multiple of log_every).
            if batch % log_every == 0 or batch == n_batches:
                print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })
    print("All done!!!")

In [275]:
def predict(dataset, model, text, next_words=1):
    """Sample ``next_words`` continuation tokens for the space-separated ``text``.

    Args:
        dataset: Object providing ``word_to_index`` and ``index_to_word`` lookups.
        model: Module called as ``model(x, (h, c))`` that returns
            ``(logits, (h, c))`` and provides ``init_state(n)``.
        text: Seed tokens separated by single spaces; every token must be a key
            of ``dataset.word_to_index``.
        next_words: Number of tokens to sample and append.

    Returns:
        The seed tokens followed by the sampled tokens, as a list of strings.

    Raises:
        KeyError: If a seed token is not in ``dataset.word_to_index``.
    """
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i in range(next_words):
            # words grows by one per step, so words[i:] always holds as many
            # tokens as the seed and keeps matching the state created above.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            # Softmax in float64 and renormalise: np.random.choice requires the
            # probabilities to sum to 1 within a tight tolerance, which float32
            # rounding may miss for a large vocabulary.
            p = torch.softmax(last_word_logits.detach().double(), dim=0).numpy()
            p = p / p.sum()
            word_index = int(np.random.choice(len(p), p=p))
            words.append(dataset.index_to_word[word_index])

    return words

In [276]:
# Build the windowed dataset from the parsed games, size the model to its
# vocabulary, then train for the configured number of epochs.
dataset = Dataset(words=words, sequence_list=sequence_list)
model = Model(dataset=dataset)
train(dataset=dataset, model=model, epochs=epochs)

Total sequences 100
Unique moves 1071
Data length 6104
Total batchs 31
{'epoch': 0, 'batch': 31, 'loss': 6.710749626159668}
{'epoch': 1, 'batch': 31, 'loss': 6.1540446281433105}
{'epoch': 2, 'batch': 31, 'loss': 5.93922233581543}
{'epoch': 3, 'batch': 31, 'loss': 5.905770301818848}
{'epoch': 4, 'batch': 31, 'loss': 6.232414722442627}
{'epoch': 5, 'batch': 31, 'loss': 5.9441070556640625}
{'epoch': 6, 'batch': 31, 'loss': 5.647998332977295}
{'epoch': 7, 'batch': 31, 'loss': 5.270970821380615}
{'epoch': 8, 'batch': 31, 'loss': 4.948829650878906}
{'epoch': 9, 'batch': 31, 'loss': 4.686614036560059}
{'epoch': 10, 'batch': 31, 'loss': 4.235520362854004}
{'epoch': 11, 'batch': 31, 'loss': 3.987637519836426}
{'epoch': 12, 'batch': 31, 'loss': 3.9004037380218506}
{'epoch': 13, 'batch': 31, 'loss': 3.7037806510925293}
{'epoch': 14, 'batch': 31, 'loss': 3.233628511428833}
{'epoch': 15, 'batch': 31, 'loss': 2.8027713298797607}
{'epoch': 16, 'batch': 31, 'loss': 2.4834933280944824}
{'epoch': 17, 'b

In [281]:
# Sample one continuation (next_words defaults to 1) for a short opening.
print(predict(dataset=dataset, model=model, text='e4 e5 Nf3 Nc6'))

['e4', 'e5', 'Nf3', 'Nc6', 'Nc3']
