In [1]:
import torch

from src.data.dataset import ChessBoardDataset
from src.train.train_utils import train_test_split, reward_fn

In [2]:
dataset = ChessBoardDataset(root_dir='../sample_data',
                            transform=True,
                            return_moves=True,
                            return_outcome=True,
                            include_draws=False)

In [3]:
train_dataset, test_dataset = train_test_split(dataset=dataset,
                                               seed=0,
                                               train_size=0.8)

In [4]:
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=64,
                                               shuffle=True,
                                               collate_fn=lambda x: x)

test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=64,
                                              shuffle=True,
                                              collate_fn=lambda x: x)

In [5]:
train_boards, train_moves, train_outcomes = next(iter(train_dataloader))

[32m2024-04-18 17:46:40.211[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__getitems__[0m:[36m147[0m - [1mTransforming the boards to tensors...[0m
100%|██████████| 64/64 [00:00<00:00, 9891.86it/s]
[32m2024-04-18 17:46:40.229[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__getitems__[0m:[36m149[0m - [1mTransforming the legal moves to tensors...[0m
100%|██████████| 64/64 [00:00<00:00, 30219.01it/s]
[32m2024-04-18 17:46:40.233[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__getitems__[0m:[36m153[0m - [1mTransforming the outcomes to tensors...[0m
100%|██████████| 64/64 [00:00<00:00, 674460.94it/s]


In [6]:
test_boards, test_moves, test_outcomes = next(iter(train_dataloader))

[32m2024-04-18 17:46:41.819[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__getitems__[0m:[36m147[0m - [1mTransforming the boards to tensors...[0m
100%|██████████| 64/64 [00:00<00:00, 10655.58it/s]
[32m2024-04-18 17:46:41.827[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__getitems__[0m:[36m149[0m - [1mTransforming the legal moves to tensors...[0m
100%|██████████| 64/64 [00:00<00:00, 53644.18it/s]
[32m2024-04-18 17:46:41.829[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__getitems__[0m:[36m153[0m - [1mTransforming the outcomes to tensors...[0m
100%|██████████| 64/64 [00:00<00:00, 958698.06it/s]


In [7]:
from loguru import logger

In [8]:
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = torch.nn.Flatten()
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(12*8*8, 12*8),
            torch.nn.ReLU(),
            torch.nn.Linear(12*8, 12),
            torch.nn.ReLU(),
            torch.nn.Linear(12, 1),
            torch.nn.Tanh()
        )

    def forward(self, x):
        x = x.float()       
        x = self.flatten(x)
        score = self.linear_relu_stack(x)
        return score

In [9]:
model = NeuralNetwork()

In [10]:
model(train_boards)

tensor([[0.1707],
        [0.1681],
        [0.1729],
        [0.1636],
        [0.1679],
        [0.1622],
        [0.1625],
        [0.1652],
        [0.1802],
        [0.1695],
        [0.1621],
        [0.1833],
        [0.1648],
        [0.1560],
        [0.1629],
        [0.1582],
        [0.1726],
        [0.1819],
        [0.1671],
        [0.1734],
        [0.1734],
        [0.1719],
        [0.1608],
        [0.1638],
        [0.1642],
        [0.1561],
        [0.1756],
        [0.1601],
        [0.1699],
        [0.1769],
        [0.1752],
        [0.1723],
        [0.1826],
        [0.1847],
        [0.1782],
        [0.1668],
        [0.1666],
        [0.1806],
        [0.1662],
        [0.1637],
        [0.1627],
        [0.1720],
        [0.1646],
        [0.1591],
        [0.1658],
        [0.1798],
        [0.1549],
        [0.1717],
        [0.1617],
        [0.1700],
        [0.1780],
        [0.1795],
        [0.1681],
        [0.1618],
        [0.1600],
        [0

In [11]:
loss = torch.nn.MSELoss()

In [12]:
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)

In [14]:
for batch in train_dataloader:
    boards, moves, outcomes = batch
    optimizer.zero_grad()
    pred = model(boards).reshape(-1)
    targets = reward_fn(outcome=outcomes, gamma=0.99)
    loss_value = loss(pred, targets)
    loss_value.backward()
    optimizer.step()
    logger.info(f'Loss: {loss_value.item()}')

[32m2024-04-09 10:04:36.940[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__getitems__[0m:[36m147[0m - [1mTransforming the boards to tensors...[0m
100%|██████████| 64/64 [00:00<00:00, 10502.17it/s]
[32m2024-04-09 10:04:36.948[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__getitems__[0m:[36m149[0m - [1mTransforming the legal moves to tensors...[0m
100%|██████████| 64/64 [00:00<00:00, 42581.77it/s]
[32m2024-04-09 10:04:36.952[0m | [1mINFO    [0m | [36msrc.data.dataset[0m:[36m__getitems__[0m:[36m153[0m - [1mTransforming the outcomes to tensors...[0m
100%|██████████| 64/64 [00:00<00:00, 1048576.00it/s]
[32m2024-04-09 10:04:36.956[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mLoss: 0.4149627983570099[0m


KeyboardInterrupt: 