In [1]:
import os
# Step up one directory unless the current working directory path already
# ends in 'irl-chess' (so the `project` import below resolves from there).
if not os.getcwd().endswith('irl-chess'):
    os.chdir('../')
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import MeanSquaredError
from project import alpha_beta_search, evaluate_board, get_midgame_boards, load_chess_df, san_to_move
from tqdm import tqdm
import chess
from copy import copy

In [41]:
class SimpleRegressionModel(pl.LightningModule):
    """
    A single linear layer trained with a mean-squared-error objective.

    :param input_size: number of input features of the linear layer.
    :param output_size: number of outputs of the linear layer.
    :param initial_weights: optional tensor copied into the layer's weight
        matrix. It must be broadcastable to ``(output_size, input_size)``;
        the bias keeps its default initialisation.
    :param lr: learning rate passed to the Adam optimizer.
    """

    def __init__(self, input_size, output_size, initial_weights=None, lr=1e-3):
        super(SimpleRegressionModel, self).__init__()
        self.lr = lr
        # Define the model architecture
        self.model = nn.Sequential(
            nn.Linear(input_size, output_size)
        )

        # Set initial weights if provided.
        # copy_ writes into the existing parameter, so the layer keeps its
        # declared shape and dtype, and an incompatible tensor raises here
        # instead of silently replacing the parameter's storage.
        if initial_weights is not None:
            with torch.no_grad():
                self.model[0].weight.copy_(initial_weights)
        self.loss_function = nn.MSELoss()
        # Metric object; not used by any method visible in this class.
        self.mse = MeanSquaredError()

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.model(x)

    def _mse_loss(self, batch):
        """Unpack ``(x, y)`` from ``batch`` and return the MSE between the squeezed predictions and targets."""
        x, y = batch
        predictions = self(x).squeeze()
        return self.loss_function(predictions, y.squeeze())

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch, log it as ``train_loss`` and return it."""
        loss = self._mse_loss(batch)
        self.log('train_loss', loss.item(), on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch, log it as ``val_loss`` and return it."""
        loss = self._mse_loss(batch)
        self.log('val_loss', loss.item(), on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [23]:
def get_trajectories(boards, R, heuristic=None, search_depth=None):
    """
    Run an alpha-beta search from every board and collect one feature row
    per board.

    :param boards: sized iterable of boards; each must expose ``.turn``,
        which is passed to the search as ``maximize``.
    :param R: 1-D tensor of reward weights. It is handed to the search as a
        numpy array, and fixes the row width, dtype and device of the result.
    :param heuristic: callable applied to the board returned by the search;
        its result is stored as that board's row. Defaults to a function
        returning ``torch.ones_like(R)``.
    :param search_depth: depth passed to ``alpha_beta_search``. When ``None``
        the notebook-level ``depth`` variable is used, which must then be
        defined before this function is called.
    :return: tensor of shape ``(len(boards), len(R))``.
    """
    if heuristic is None:
        heuristic = lambda board: torch.ones_like(R)
    if search_depth is None:
        # Backward-compatible fallback to the global configured in the notebook.
        search_depth = depth
    # Loop-invariant: convert the weights once instead of once per board.
    R_numpy = R.cpu().detach().numpy()
    trajectories = torch.empty((len(boards), len(R)), dtype=R.dtype, device=R.device)
    for i, board in tqdm(enumerate(boards), desc='Generating trajectories', total=len(boards)):
        # The search returns three values; only the second (a board,
        # presumably the position the search ends on - TODO confirm) is used.
        _, board_, _ = alpha_beta_search(board=board, depth=search_depth, R=R_numpy, maximize=board.turn)
        trajectories[i] = heuristic(board_)
    return trajectories

def eval_boards(boards, moves, R):
    """
    Evaluate each board after its paired move has been played.

    Every move is pushed onto its board, the resulting position is passed to
    ``evaluate_board(board, R=R, tensor=True)``, and the move is popped again
    so the caller's boards are left as they were, even if evaluation raises.

    :param boards: sized iterable of boards supporting ``push`` and ``pop``.
    :param moves: iterable of moves, one per board, in the same order.
    :param R: 1-D tensor of reward weights; fixes the row width, dtype and
        device of the result.
    :return: tensor of shape ``(len(boards), len(R))``.
    :raises ValueError: if ``boards`` and ``moves`` differ in length.
    """
    moves = list(moves)
    if len(boards) != len(moves):
        # zip() would silently truncate and leave uninitialised rows behind.
        raise ValueError(
            f'boards and moves must have the same length, got {len(boards)} and {len(moves)}'
        )
    trajectories = torch.empty((len(boards), len(R)), dtype=R.dtype, device=R.device)
    for i, (board, move) in tqdm(enumerate(zip(boards, moves)), total=len(boards), desc='Evaluating trajectories'):
        board.push(move)
        try:
            trajectories[i] = evaluate_board(board, R=R, tensor=True)
        finally:
            # Always undo the move so the input board is not left mutated.
            board.pop()
    return trajectories

In [42]:
# Board selection: how many boards to sample and the Elo window they come from.
n_trajectories = 10
min_elo, max_elo = 1100, 1300

# Search and training settings (n_train is not used by the cells shown here).
depth = 3
n_epochs = 2
n_train = 10

# Random initial reward weights in [0, 100); they also initialise the
# weight row of the linear model.
R = 100 * torch.rand(6)
model = SimpleRegressionModel(
    input_size=R.shape[0],
    output_size=1,
    initial_weights=R.reshape(1, -1),
)

In [5]:
# Load the games, then pull n_trajectories boards and their paired moves
# for players inside the configured Elo window.
df = load_chess_df(n_files=3, overwrite=False)
boards, moves = get_midgame_boards(
    df=df,
    n_boards=n_trajectories,
    min_elo=min_elo,
    max_elo=max_elo,
    sunfish=False,
    move_translation=san_to_move,
)



-------------------  1/12  -------------------


C:\Users\toell\OneDrive\Documents\GitHub\irl-chess\data\processed\lichess_db_standard_rated_2013-01.csv already exists and was not changed
Time taken: 0.00 seconds for file
Time taken: 0.00 seconds in total


-------------------  2/12  -------------------


C:\Users\toell\OneDrive\Documents\GitHub\irl-chess\data\processed\lichess_db_standard_rated_2013-02.csv already exists and was not changed
Time taken: 0.00 seconds for file
Time taken: 0.00 seconds in total


-------------------  3/12  -------------------


C:\Users\toell\OneDrive\Documents\GitHub\irl-chess\data\processed\lichess_db_standard_rated_2013-03.csv already exists and was not changed
Time taken: 0.00 seconds for file
Time taken: 0.00 seconds in total


Contatenating DataFrames: 100%|██████████| 2/2 [00:01<00:00,  1.16it/s]
Searching for boards:   0%|          | 1095/402536 [00:00<00:10, 39050.45it/s]


In [21]:
for epoch in tqdm(range(n_epochs), desc='Epoch'):
    # Rows from the alpha-beta search under the current weights R ...
    trajectories_new = get_trajectories(boards=boards, R=R)
    # ... and rows from evaluating each board after its paired move.
    trajectories_eval = eval_boards(boards=boards, moves=moves, R=R)
    # TODO: combine both sets, e.g. torch.concat((trajectories_eval, trajectories_new))
    

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]
Generating trajectories:   0%|          | 0/10 [00:00<?, ?it/s][A
Generating trajectories:  10%|█         | 1/10 [00:01<00:12,  1.42s/it][A
Generating trajectories:  20%|██        | 2/10 [00:01<00:06,  1.21it/s][A
Generating trajectories:  30%|███       | 3/10 [00:02<00:05,  1.23it/s][A
Generating trajectories:  50%|█████     | 5/10 [00:04<00:04,  1.01it/s][A
Generating trajectories:  70%|███████   | 7/10 [00:07<00:03,  1.07s/it][A
Epoch:   0%|          | 0/2 [00:07<?, ?it/s]


KeyboardInterrupt: 

In [44]:
def log_loss(discriminator, x_human, x_R):
    """
    Sum of the discriminator's log-outputs over both input batches.

    :param discriminator: callable mapping a tensor to a tensor of positive
        values (the log of its output is taken).
    :param x_human: input batch passed to the discriminator for the first term.
    :param x_R: input batch passed to the discriminator for the second term.
    :return: scalar tensor ``sum(log D(x_human)) + sum(log D(x_R))``.
    """
    human_term = torch.sum(torch.log(discriminator(x_human)))
    # NOTE(review): adversarial objectives usually use log(1 - D(x)) for the
    # generated samples; the term is kept as originally written - confirm.
    generated_term = torch.sum(torch.log(discriminator(x_R)))
    # The original computed both terms but discarded them and returned None.
    return human_term + generated_term


tensor([[7.7463e+04],
        [6.6845e+04],
        [4.4926e+04],
        [3.9697e-01],
        [3.9697e-01],
        [3.9697e-01],
        [3.9697e-01],
        [3.9697e-01],
        [3.9697e-01],
        [3.9697e-01]], grad_fn=<AddmmBackward0>)

In [35]:
# Forward pass on an all-ones (6, 6) batch, displayed as a 1-tuple.
(model(torch.ones(6, 6)),)

(tensor([[-0.9017],
         [-0.9017],
         [-0.9017],
         [-0.9017],
         [-0.9017],
         [-0.9017]], grad_fn=<AddmmBackward0>),)

In [29]:
# Dot product of R with a ones vector, i.e. the sum of the six weights.
torch.matmul(R, torch.ones(6))

tensor(254.1658)