In [None]:

EPOCHS = 10
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
writer = SummaryWriter(f"runs/piece_to_move_{timestamp}")

train_data, validation_data = torch.utils.data.random_split(train_data, [1-TEST_PERCENT, TEST_PERCENT])

training_loader = torch.utils.data.dataLoader(train_data, BATCH_SIZE, shuffle=True, pin_memory=True)
validation_loader = torch.utils.data.dataLoader(validation_data, BATCH_SIZE, shuffle=True, pin_memory=True)

def train_one_epoch(epoch_index: int, tb_writer, optimizer, training_loader, loss_fn): 
    running_loss = 0.
    last_loss = 0.

    for i, data in enumerate(training_loader):
        inputs = data[0]
        labels = data[1]

        optimizer.zero_grad()
        outputs = piece_to_move_net(inputs)

               
        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        if i % 1000 == 999: 
            last_loss = running_loss / 1000
            print(f" batch {i + 1}, loss: {last_loss}")
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar("Loss/train", last_loss, tb_x)
            running_loss = 0.

    return last_loss


def train_multiple_epochs(n_epochs, model, writer, validation_loader, loss_fn):
    for epoch in range(EPOCHS): 
        print(f"EPOCH {epoch + 1}: ")

        model.train(True)
        avg_loss = train_one_epoch(epoch, writer)

        model.train(False)
        running_vloss = 0.0
        for i, v_data in enumerate(validation_loader):
            vinputs = v_data[0]
            vlabels = v_data[1]

            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

        avg_vloss = running_vloss / (i + 1)
        print(f"LOSS train {avg_loss}, valid {avg_vloss}")

        writer.add_scalars("Training vs Validation Loss", {
            "Training": avg_loss, "Validation": avg_vloss}, 
            epoch + 1)
        writer.flush()
        
        if avg_vloss < best_vloss:
            best_vloss = avg_loss
            model_path = f"model_{timestamp}_{epoch}"
            torch.save(model.state_dict(), model_path)

        epoch += 1


In [9]:
import torch
from torch.masked import masked_tensor, as_masked_tensor
import aux_functions
import importlib

importlib.reload(aux_functions)
from aux_functions import *

# Create a 14x8x8 tensor filled with zeros
tensor = torch.zeros((14, 8, 8))

# Place the white pieces
tensor[6, 0, 0] = 1  # White rook
tensor[6, 0, 7] = 1  # White rook
tensor[6, 1, 1] = 1  # White knight
tensor[6, 1, 6] = 1  # White knight
tensor[6, 2, 2] = 1  # White bishop
tensor[6, 2, 5] = 1  # White bishop
tensor[6, 3, 3] = 1  # White queen
tensor[6, 4, 4] = 1  # White king
tensor[6, 5, 2] = 1  # White bishop
tensor[6, 5, 5] = 1  # White bishop
tensor[6, 6, 1] = 1  # White knight
tensor[6, 6, 6] = 1  # White knight
tensor[6, 7, 0] = 1  # White rook
tensor[6, 7, 7] = 1  # White rook
tensor[6, 1, 0] = 1  # White pawn
tensor[6, 1, 1] = 1  # White pawn
tensor[6, 1, 2] = 1  # White pawn
tensor[6, 1, 3] = 1  # White pawn
tensor[6, 1, 4] = 1  # White pawn
tensor[6, 1, 5] = 1  # White pawn
tensor[6, 1, 6] = 1  # White pawn
tensor[6, 1, 7] = 1  # White pawn

# Place the black pieces
tensor[0, 0, 0] = 1  # Black rook
tensor[0, 0, 7] = 1  # Black rook
tensor[0, 1, 1] = 1  # Black knight
tensor[0, 1, 6] = 1  # Black knight
tensor[0, 2, 2] = 1  # Black bishop
tensor[0, 2, 5] = 1  # Black bishop
tensor[0, 3, 3] = 1  # Black queen
tensor[0, 4, 4] = 1  # Black king
tensor[0, 5, 2] = 1  # Black bishop
tensor[0, 5, 5] = 1  # Black bishop
tensor[0, 6, 1] = 1  # Black knight
tensor[0, 6, 6] = 1  # Black knight
tensor[0, 7, 0] = 1  # Black rook
tensor[0, 7, 7] = 1  # Black rook
tensor[0, 6, 0] = 1  # Black pawn
tensor[0, 6, 1] = 1  # Black pawn
tensor[0, 6, 2] = 1  # Black pawn
tensor[0, 6, 3] = 1  # Black pawn
tensor[0, 6, 4] = 1  # Black pawn
tensor[0, 6, 5] = 1  # Black pawn
tensor[0, 6, 6] = 1  # Black pawn
tensor[0, 6, 7] = 1  # Black pawn

def generate_mask(tensor)-> list:
    """Generates a mask which contains the position of the pieces that can move"""

    # Initiate a list which will contain the position of the pieces that can move
    positions = []

    # If layer 12 has any 1 it will be whites turn
    if  torch.any(tensor[12] == 1):
        # White pieces are in layers 6 to 11
        for layer in range(6,12):
            # Apply a mask, if there is a piece it will be a 1
            mask = tensor[layer] == 1
            mt = masked_tensor(tensor, mask)

        # Sparse tensor to obtain index of the non zero elements
        sparse_coo_mt = mt.to_sparse_coo()

        # Access the index of the different elements
        for i in range(sparse_coo_mt.indices().size(1)): # Size 1 is for the number of columns
            # Obtain row and column
            row, col = sparse_coo_mt.indices()[:,i].to_list()
                
            # Obtain the position in a 64 board list
            position = row*8 + col
            # Add this position to the final list
            positions.append(position)

    # If layer 13 has any 1 it will be blacks turn
    if torch.any(tensor[13] == 1):
         # Black pieces are in range 0 to 6
        for layer in range(0,6):
            # Apply a mask, if there is a piece it will be a 1
            mask = tensor[layer] == 1
            mt = masked_tensor(tensor, mask)
        
        # Sparse tensor to obtain index of the non zero elements
        sparse_coo_mt = mt.to_sparse_coo()

        # Access the index of the different elements
        for i in range(sparse_coo_mt.indices().size(1)): # Size 1 is for the number of columns
            # Obtain row and column
            row, col = sparse_coo_mt.indices()[:,i].to_list()
                
            # Obtain the position in a 64 board list
            position = row*8 + col
            # Add this position to the final list
            positions.append(position)
    for pos in positions:
        print(pos)
    return positions

    

import torch
generate_mask(tensor)

[]