In [1]:
import torch
import torch.nn as nn
import torch.utils.data as tud
from tqdm.notebook import tqdm
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import importlib
import language_modelling_seq2seq

In [2]:
# chess

with open("data/pg5614.txt", encoding = "utf-8") as file:
    text = file.read()
print(len(text))
print(text[:1000])

556949
﻿The Project Gutenberg EBook of Chess Strategy, by Edward Lasker

This eBook is for the use of anyone anywhere at no cost and with
almost no restrictions whatsoever.  You may copy it, give it away or
re-use it under the terms of the Project Gutenberg License included
with this eBook or online at www.gutenberg.org/license


Title: Chess Strategy

Author: Edward Lasker

Translator: J. Du Mont

Release Date: November 11, 2012 [EBook #5614]

Language: English


*** START OF THIS PROJECT GUTENBERG EBOOK CHESS STRATEGY ***




Produced by John Mamoun <mamounjo@umdnj.edu>, Charles
Franks, and the Online Distributed Proofreaders website.








INFORMATION ABOUT THIS E-TEXT EDITION



The following is an e-text of "Chess Strategy," second edition, (1915)
by Edward Lasker, translated by J. Du Mont.

This e-text contains the 167 chess and checkers board game
diagrams appearing in the original book, all in the form of
ASCII line drawings. The following is a key to the diagrams:

For chess

In [3]:
len(text)

556949

In [4]:
vocabulary = set(text)
len(vocabulary)

92

In [5]:
char2i = {c:i for i, c in enumerate(sorted(vocabulary), 3)}
char2i["<PAD>"] = 0
char2i["<START>"] = 1
char2i["<END>"] = 2
print(len(char2i))
i2char = {i:c for i, c in enumerate(sorted(vocabulary), 3)}
i2char[0] = "<PAD>"
i2char[1] = "<START>"
i2char[2] = "<END>"
print(len(i2char))

95
95


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

device(type='cuda')

In [7]:
length = 20
lines = []
for i in range(len(text))[:-length]:
    lines.append(text[i:length + i])
print(len(text))
print(len(lines))
print(lines[:5])
encoded = torch.tensor([[char2i[c] for c in l] for l in tqdm(lines)]).to(device).long()
print(encoded.shape)
source_1 = encoded[:, :length // 2]
print(source_1.shape)
target_1 = torch.cat((torch.ones(encoded.shape[0], 1).to(device).long(), encoded[:, length // 2:]), axis = 1)
print(target_1.shape)

556949
556929
['\ufeffThe Project Gutenbe', 'The Project Gutenber', 'he Project Gutenberg', 'e Project Gutenberg ', ' Project Gutenberg E']


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=556929.0), HTML(value='')))


torch.Size([556929, 20])
torch.Size([556929, 10])
torch.Size([556929, 11])


In [8]:
importlib.reload(language_modelling_seq2seq)    
net = language_modelling_seq2seq.Seq2SeqRNN(char2i, i2char, 
                                            encoder_hidden_units = 100, encoder_layers = 2,
                                            decoder_hidden_units = 100, decoder_layers = 2)
net.to(device)
net.fit(source_1, target_1, epochs = 5, batch_size = 150, lr = 0.0001, verbose = 3, 
        save_path = "checkpoints/seq2seq_rnn.pt")

Net parameters: 408,795


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3713.0), HTML(value='')))


Epoch:    1, Loss: 3.0290
Y
['ct Gutenbe', 't Gutenber', ' Gutenberg', 'Gutenberg ', 'utenberg E']

forward
['          ', '          ', '          ', '          ', '          ']

greedy_search
['          ', '          ', '          ', '          ', '          ']
tensor([-18.1500, -18.0000, -17.7500, -16.5600, -15.6700], device='cuda:0')

sample
[' ttcdtith ', 've eogi.eC', ' hgdiaht3s', 'N^eokre i9', '. v\natxynt']

beam_search
[['          ', '          ', '          ', '          ', '          '],
 ['   e      ', '   e      ', '   e      ', '    e     ', '   e      '],
 ['    e     ', '    e     ', '    e     ', '   e      ', '    e     '],
 ['  e       ', '  e       ', '  e       ', '     e    ', '  e       '],
 ['     e    ', '     e    ', '     e    ', '  e       ', '     e    ']]
tensor([[-18.1500, -18.0000, -17.7500, -16.5600, -15.6700],
        [-18.7400, -18.6100, -18.4200, -17.4900, -16.7600],
        [-18.7500, -18.6300, -18.4300, -17.4900, -16.7700],
        [-18.7500, -

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3713.0), HTML(value='')))


Epoch:    2, Loss: 2.2000
Y
['ct Gutenbe', 't Gutenber', ' Gutenberg', 'Gutenberg ', 'utenberg E']

forward
[' e thne  e', 'eethne se ', 'etone  e  ', ' tne  e   ', ' ne  e   t']

greedy_search
[' the the t', 'e the the ', 'e the the ', ' the the t', ' the the t']
tensor([-11.8800, -11.9000, -11.4100, -11.6500, -11.4400], device='cuda:0')

sample
['t pinnmsva', 'h pont in ', ' B1\nhatfeo', 'he Qp. toe', 'in tiwfbrs']

beam_search
[[' the the t', 'e the the ', 'e the the ', '          ', '          '],
 ['e the the ', 'n the the ', ' the the t', '         -', '         -'],
 [' the thes ', 'te the the', 'in the the', ' the the t', ' the the t'],
 [' the the a', 'e the thes', ' the the a', '         1', '        --'],
 [' the the o', 'n the thes', ' the the o', '         h', '         1']]
tensor([[-11.8800, -11.9000, -11.4100,  -8.6700,  -9.5000],
        [-12.1400, -12.1500, -12.2400, -10.9600, -11.4200],
        [-12.7600, -12.9400, -13.0900, -11.6500, -11.4400],
        [-12.9600, -

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3713.0), HTML(value='')))


Epoch:    3, Loss: 1.6847
Y
['ct Gutenbe', 't Gutenber', ' Gutenberg', 'Gutenberg ', 'utenberg E']

forward
['sh afterte', 'n aoth te ', 'eaoth te  ', 'tonh te  i', 'tn   e   t']

greedy_search
['s the the ', 'n the the ', 'e the the ', 'the the th', 'the the th']
tensor([-10.3700, -11.0300, -10.1900, -10.1400, -10.7500], device='cuda:0')

sample
['g\nplarices', '2 "fatutlu', 'aico bseve', 'isdrconvou', '    9Caslo']

beam_search
[['s the the ', 's the the ', 'e the the ', 'the the th', 'the the th'],
 ['s the thes', 'n the the ', 'e the thes', 'the the to', 'the the pa'],
 ['s the ther', 'on the the', 'e the ther', 'the the pa', 'the the to'],
 ['s of the t', 'n the thes', 'e the thit', 'the thes t', 'the there '],
 ['s of the p', 's the thes', 'e the thec', 'the the an', 'the the po']]
tensor([[-10.3700, -11.0100, -10.1900, -10.1400, -10.7500],
        [-11.1300, -11.0300, -11.1800, -10.9200, -11.5800],
        [-11.3400, -11.2300, -11.2700, -11.0300, -11.7200],
        [-11.4900, -

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3713.0), HTML(value='')))


Epoch:    4, Loss: 1.4636
Y
['ct Gutenbe', 't Gutenber', ' Gutenberg', 'Gutenberg ', 'utenberg E']

forward
['nk autente', 'n aothnten', ' aothnten ', 'tot n enta', 'hn n en  t']

greedy_search
['ng the pas', 'n the past', ' and the p', 'the paster', 'he packing']
tensor([-12.8200, -12.4300, -10.0700, -11.8400, -13.1300], device='cuda:0')

sample
['m boces a ', 'den, puch ', 'idisanf an', 'frectentis', 'e. 0H-Q4 f']

beam_search
[['s of the p', 'n of the p', ' and the p', 'the Black ', 'his of the'],
 ['s of the B', 'n the the ', 'ing the pa', "the Black'", 'his the pa'],
 ['s of the K', 'n of the s', 'ing the Kt', 'of the Bla', 'his the th'],
 ['s of the c', 'n of the c', ' and the s', 'the Black\n', 'his the Kt'],
 ['s of the a', 'n of the B', 'ing the Bl', 'the Black,', 'his the pe']]
tensor([[-10.2700, -10.3700, -10.0700,  -7.7400,  -9.8700],
        [-10.8000, -10.4800, -10.5600,  -9.3900, -10.3900],
        [-10.8400, -10.9600, -10.6500,  -9.5300, -10.6900],
        [-10.9000, -

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3713.0), HTML(value='')))


Epoch:    5, Loss: 1.3464
Y
['ct Gutenbe', 't Gutenber', ' Gutenberg', 'Gutenberg ', 'utenberg E']

forward
['nk outente', 'n outenten', ' autentert', 'autentent ', 'hnhn ert t']

greedy_search
['ng to play', 'n and the ', ' and the p', 'and the pa', 'his the pa']
tensor([-13.0700, -10.7800, -10.0100, -10.3500, -10.7200], device='cuda:0')

sample
['ged (P. 4)', 'n ...\n\n   ', ', of K5 Kt', 'won foatib', 'hin connta']

beam_search
[['ck of the ', 'k, and the', ', and the ', 'for the Kt', 'his there '],
 ['ck on the ', 'n of the p', ', and and ', 'for the pa', 'ould the p'],
 ['ck, and th', 'k, and to ', ', and ther', 'and the pa', 'his the pa'],
 ['ck and the', 'n of the K', ', and thes', 'of the Kt-', 'hing the p'],
 ['ck of the\n', 'n of the s', ' and the p', 'the pawn a', 'hing the K']]
tensor([[ -9.1200,  -8.8400,  -8.3000, -10.0000, -10.2200],
        [ -9.7200,  -9.8000,  -9.5500, -10.0500, -10.3600],
        [-10.0800, -10.2200,  -9.6000, -10.3500, -10.7200],
        [-10.0900,

In [9]:
importlib.reload(language_modelling_seq2seq)    
net = language_modelling_seq2seq.Seq2SeqRNN(char2i, i2char, 
                                            encoder_hidden_units = 100,
                                            encoder_layers = 2,
                                            decoder_hidden_units = 100,
                                            decoder_layers = 2)
net.load_state_dict(torch.load("checkpoints/seq2seq_rnn.pt"))
net.to(device)
test = net.beam_search(source_1[:10000], verbose = 1, batch_size = 50)

test[1][:20]

Net parameters: 408,795


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=19.0), HTML(value='')))




tensor([[-17.8335, -18.0513, -18.4524, -18.6212, -19.1344],
        [-17.8188, -18.4778, -19.9400, -19.9586, -19.9911],
        [-18.0875, -18.1003, -18.3117, -18.5102, -18.6515],
        [-18.8604, -19.6848, -19.7341, -19.8503, -19.8990],
        [-18.7424, -19.4441, -20.1285, -20.5640, -20.6350],
        [-17.5809, -17.8223, -17.8357, -17.8608, -17.8852],
        [-17.3872, -17.6513, -17.6753, -17.8806, -18.1069],
        [-16.1265, -16.2314, -16.6199, -18.0132, -18.0527],
        [-15.9027, -16.4403, -17.8933, -18.0407, -18.0589],
        [-16.4433, -16.6209, -17.2826, -17.5410, -17.7742],
        [-15.8254, -16.5886, -16.6080, -17.6020, -17.7822],
        [-16.7700, -17.0002, -17.0462, -17.3763, -17.4078],
        [-17.2418, -17.2750, -17.4296, -17.8269, -17.9135],
        [-15.5557, -16.1395, -16.5601, -16.5798, -16.5995],
        [-16.7301, -17.7979, -18.4939, -19.0436, -19.3501],
        [-13.5500, -14.4816, -14.6108, -14.7599, -15.2486],
        [-16.1356, -17.0112, -17.0463, -