In [None]:
import pickle
import os
import mlflow
from pathlib import Path
import torch
import numpy as np
from typing import List, Dict, Tuple
from torch.utils.data import Dataset
from abc import abstractmethod
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

In [None]:
# Turn on MLflow's automatic logging for PyTorch training.
mlflow.pytorch.autolog()

# Parameters

In [None]:
# Feature sizes of the raw input vectors.
game_state_global_dim = 2
n_players = 2
player_dim = 8
action_general_dim = 31
zone_vector_dim = 34

# Fixed sequence lengths that variable-length inputs are padded/truncated to.
max_n_zone_vectors = 120
max_n_action_source_cards = 10
max_n_action_target_cards = 10

# Model hyper-parameters (these five values are logged to MLflow in the next cell).
embedding_dim = 64
transformer_n_layers = 5
transformer_n_heads = 16
transformer_dim_feedforward = 128
dropout = 0.0

# Training settings - presumably consumed by training cells outside this view; TODO confirm.
n_epochs = 1000
batch_size = 10
early_stopping_patience = 50

In [None]:
# Record the model hyper-parameters with MLflow so runs can be compared later.
mlflow.log_param("embedding_dim", embedding_dim)
mlflow.log_param("transformer_n_layers", transformer_n_layers)
mlflow.log_param("transformer_n_heads", transformer_n_heads)
mlflow.log_param("transformer_dim_feedforward", transformer_dim_feedforward)
mlflow.log_param("dropout", dropout)

# Read game logs from pickle files

In [None]:
# Folder the pickled game logs are read from (relative to the working directory).
game_logs_folder_path = "../data/game_logs/"

In [None]:
# Collects the decision records kept from all game-log files.
dataset_from_pickles = []

In [None]:
# Start from an empty list so that re-running this cell does not append the
# same records a second time.
dataset_from_pickles = []

# NOTE(security): pickle.load can execute arbitrary code while loading.
# Only point game_logs_folder_path at files produced by a trusted source.
# sorted() gives a reproducible record order; os.listdir order is arbitrary.
for file_name in sorted(os.listdir(game_logs_folder_path)):
    game_log_file_path = os.path.join(game_logs_folder_path, file_name)
    # Skip sub-directories and hidden files (e.g. .ipynb_checkpoints, .DS_Store):
    # they are not game logs and would make open()/pickle.load fail.
    if file_name.startswith(".") or not os.path.isfile(game_log_file_path):
        continue
    print(f"Read game logs from '{game_log_file_path}'")
    with open(game_log_file_path, "rb") as f:
        data_dict = pickle.load(f)

    # Keep only decisions where the player actually had a choice.
    for item_dict in data_dict["dataset"]:
        n_possible_actions = len(item_dict["possible_actions"])
        if n_possible_actions >= 2:
            dataset_from_pickles.append(item_dict)

In [None]:
# Number of decision records kept after filtering.
len(dataset_from_pickles)

# Create a preprocessed dataset

In [None]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def pad_tensor(
    vec: torch.Tensor,
    pad: int,
    dim: int,
    device,
    return_pad_size: bool = False
):
    """
    Force `vec` to have exactly `pad` entries along dimension `dim`.

    Tensors that are too short are extended with trailing zeros; tensors that
    are too long are cut down to their first `pad` entries. The dtype of `vec`
    is preserved in both cases.

    args:
        vec - tensor to pad or truncate
        pad - the size to pad to
        dim - dimension to pad
        device - device the returned tensor is placed on
        return_pad_size - when True, also return how many zero entries were added

    return:
        a new tensor of size 'pad' in dimension 'dim', located on 'device'.
        With return_pad_size=True a tuple (tensor, pad_size) is returned, where
        pad_size is 0 when nothing had to be added.
    """
    pad_size = max(pad - vec.size(dim), 0)
    if pad_size > 0:
        pad_shape = list(vec.shape)
        pad_shape[dim] = pad_size
        vec = vec.to(device)
        # Allocate the zeros directly with the dtype/device of the data instead
        # of creating float32 CPU zeros and moving them afterwards.
        padding = torch.zeros(pad_shape, dtype=vec.dtype, device=vec.device)
        padded_tensor = torch.cat([vec, padding], dim=dim)
    else:
        # narrow() stays inside torch (no NumPy round trip), so it also works
        # for tensors that require grad; clone() keeps the "new tensor"
        # contract because narrow() alone only returns a view.
        padded_tensor = vec.narrow(dim, 0, pad).clone().to(device)
    if return_pad_size:
        return padded_tensor, pad_size
    return padded_tensor

In [None]:
class DeepLearningDataset(Dataset):
    """
    Dataset that converts raw game-log records into tensors on a target device.

    Every record of `player_dataset` is a dict holding the keys
    "action_history", "current_game_state", "possible_actions" and
    "chosen_action_index". Card and zone sequences are brought to fixed
    lengths so that all samples share a common layout.
    """

    def __init__(
        self,
        player_dataset: List[Dict],
        zone_vector_dim: int,
        max_n_zone_vectors: int,
        max_n_action_source_cards: int,
        max_n_action_target_cards: int,
        device
    ):
        super().__init__()
        self.player_dataset = player_dataset
        self.device = device
        # Feature size of a single card / zone vector.
        self.zone_vector_dim = zone_vector_dim
        # Fixed lengths the variable-length sequences are brought to.
        self.max_n_zone_vectors = max_n_zone_vectors
        self.max_n_action_source_cards = max_n_action_source_cards
        self.max_n_action_target_cards = max_n_action_target_cards

    def __len__(self) -> int:
        return len(self.player_dataset)

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]:
        """
        Return the sample at position `idx` as the tuple
        (action_history_vectors, current_game_state_vectors,
         possible_actions_vectors, target_action).
        """
        record = self.player_dataset[idx]
        history = record["action_history"]
        game_state = record["current_game_state"]
        candidates = record["possible_actions"]
        chosen = record["chosen_action_index"]

        history_vectors = self.__get_action_history_vectors(history)
        game_state_vectors = self.__get_current_game_state_vectors(game_state)
        candidate_vectors = self.__get_possible_actions_vectors(candidates)
        one_hot_target = self.__get_target_action(
            n_possible_actions=len(candidates),
            chosen_action_index=chosen
        )
        return history_vectors, game_state_vectors, candidate_vectors, one_hot_target

    def __cards_to_fixed_length(self, card_array: np.ndarray, max_n_cards: int) -> torch.Tensor:
        """
        Turn the card vectors of one action into a (max_n_cards, zone_vector_dim) tensor.

        An action without cards yields an all-zero block; otherwise the cards
        are padded or truncated along dimension 0.
        """
        cards = torch.from_numpy(card_array)
        if len(cards) == 0:
            return torch.zeros(size=(max_n_cards, self.zone_vector_dim)).to(self.device)
        return pad_tensor(cards, pad=max_n_cards, dim=0, device=self.device).to(self.device)

    def __action_list_to_tensors(self, action_list: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        """
        Inputs:

        action_list:
        [{
          general: (action_dim,)
          source_card_vectors: (n_action_source_cards, zone_vector_dim)
          target_card_vectors: (n_action_target_cards, zone_vector_dim)
        }] * n_actions

        Return:

        action_list_vectors
        - general: (n_actions, action_dim)
        - source_card_vectors: (n_actions, max_n_action_source_cards, zone_vector_dim)
        - target_card_vectors: (n_actions, max_n_action_target_cards, zone_vector_dim)
        """
        rows_by_key = {
            "general": [],
            "source_card_vectors": [],
            "target_card_vectors": [],
        }

        for action in action_list:
            general_row = torch.from_numpy(action["general"]).to(self.device)
            source_block = self.__cards_to_fixed_length(
                action["source_card_vectors"], self.max_n_action_source_cards
            )
            target_block = self.__cards_to_fixed_length(
                action["target_card_vectors"], self.max_n_action_target_cards
            )

            # A leading axis of size 1 lets the per-action pieces be
            # concatenated along dimension 0 below.
            rows_by_key["general"].append(general_row.unsqueeze(0))
            rows_by_key["source_card_vectors"].append(source_block.unsqueeze(0))
            rows_by_key["target_card_vectors"].append(target_block.unsqueeze(0))

        return {
            key: torch.cat(rows, dim=0).float().to(self.device)
            for key, rows in rows_by_key.items()
        }

    def __get_action_history_vectors(self, action_history: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        """
        Inputs:

        action_history:
        [{
          general: (action_dim,)
          source_card_vectors: (n_action_source_cards, zone_vector_dim)
          target_card_vectors: (n_action_target_cards, zone_vector_dim)
        }] * history_size

        Return:

        action_history_vectors
        - general: (history_size, action_dim)
        - source_card_vectors: (history_size, max_n_action_source_cards, zone_vector_dim)
        - target_card_vectors: (history_size, max_n_action_target_cards, zone_vector_dim)
        """
        return self.__action_list_to_tensors(action_list=action_history)

    def __get_current_game_state_vectors(self, current_game_state: Dict[str, torch.Tensor]):
        """
        Inputs:

        current_game_state:
        {
            global: (global_dim,)
            players: (n_players, player_dim)
            zones: (n_zone_vectors, zone_vector_dim)
        }

        Return:

        current_game_state_vectors:
        - global: (global_dim,)
        - players: (n_players, player_dim)
        - zones: (max_n_zone_vectors, zone_vector_dim)
        """
        global_vector = torch.from_numpy(current_game_state["global"]).float().to(self.device)
        players_matrix = torch.from_numpy(current_game_state["players"]).float().to(self.device)
        # Only the zones have a variable length, so only they are padded/truncated.
        zones_matrix = pad_tensor(
            torch.from_numpy(current_game_state["zones"]),
            pad=self.max_n_zone_vectors,
            dim=0,
            device=self.device
        )
        zones_matrix = zones_matrix.float().to(self.device)
        return {
            "global": global_vector,
            "players": players_matrix,
            "zones": zones_matrix,
        }

    def __get_possible_actions_vectors(self, possible_actions: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        """
        Inputs:

        possible_actions:
        [{
          general: (action_dim,)
          source_card_vectors: (n_action_source_cards, zone_vector_dim)
          target_card_vectors: (n_action_target_cards, zone_vector_dim)
        }] * n_possible_actions

        Return:

        possible_actions_vectors:
        - general: (n_possible_actions, action_dim)
        - source_card_vectors: (n_possible_actions, max_n_action_source_cards, zone_vector_dim)
        - target_card_vectors: (n_possible_actions, max_n_action_target_cards, zone_vector_dim)
        """
        return self.__action_list_to_tensors(action_list=possible_actions)

    def __get_target_action(self, n_possible_actions: int, chosen_action_index: int) -> torch.Tensor:
        """
        Build the one-hot encoding of the chosen action.

        Return:

        target_action: (n_possible_actions,)
        """
        one_hot = torch.zeros(n_possible_actions).float().to(self.device)
        one_hot[chosen_action_index] = 1
        return one_hot

In [None]:
# Wrap the filtered game-log records in a Dataset that converts them to
# fixed-layout tensors on the selected device.
deep_learning_dataset = DeepLearningDataset(
    player_dataset=dataset_from_pickles,
    zone_vector_dim=zone_vector_dim,
    max_n_zone_vectors=max_n_zone_vectors,
    max_n_action_source_cards=max_n_action_source_cards,
    max_n_action_target_cards=max_n_action_target_cards,
    device=device
)

In [None]:
def pad_possible_actions_collate_fn(
    samples: List[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]],
    device
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
    """
    Collate dataset samples into one batch, padding the possible actions.

    The number of possible actions differs between samples, so the
    possible-action tensors and the target vector of every sample are
    zero-padded (via `pad_tensor`) to the largest action count in the batch.
    Action-history and game-state tensors are concatenated unchanged, so they
    must already have identical shapes across samples.

    Args:
        samples: list of (action_history_vectors, current_game_state_vectors,
            possible_actions_vectors, target_action) tuples.
        device: device the batched tensors are placed on.

    Returns:
        A 5-tuple:
        - batch_action_history_vectors: dict of (batch_size, ...) tensors
        - batch_current_game_state_vectors: dict of (batch_size, ...) tensors
        - batch_possible_actions_vectors: dict of
          (batch_size, max_n_possible_actions, ...) tensors
        - batch_target_action: (batch_size, max_n_possible_actions)
        - batch_n_possible_actions: (batch_size,) integer tensor holding the
          real (unpadded) number of possible actions of every sample

    Raises:
        ValueError: if `samples` is empty.
    """
    if len(samples) == 0:
        raise ValueError("Cannot collate an empty list of samples")

    max_n_possible_actions = max(sample[2]["general"].shape[0] for sample in samples)

    action_history_parts: Dict[str, List[torch.Tensor]] = {}
    game_state_parts: Dict[str, List[torch.Tensor]] = {}
    possible_actions_parts: Dict[str, List[torch.Tensor]] = {}
    target_action_parts: List[torch.Tensor] = []
    n_possible_actions: List[int] = []

    for sample in samples:
        # tensor[None] adds the batch axis the pieces are concatenated along.
        for key, tensor in sample[0].items():
            action_history_parts.setdefault(key, []).append(tensor[None])
        for key, tensor in sample[1].items():
            game_state_parts.setdefault(key, []).append(tensor[None])
        for key, tensor in sample[2].items():
            padded = pad_tensor(tensor, pad=max_n_possible_actions, dim=0, device=device)
            possible_actions_parts.setdefault(key, []).append(padded[None])

        padded_target_action = pad_tensor(
            sample[3],
            pad=max_n_possible_actions,
            dim=0,
            device=device
        )
        target_action_parts.append(padded_target_action[None])
        n_possible_actions.append(int(sample[3].shape[0]))

    def concat_parts(parts: Dict[str, List[torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Join the per-sample pieces of every key along the batch axis.
        return {key: torch.cat(tensors, dim=0).to(device) for key, tensors in parts.items()}

    batch_target_action = torch.cat(target_action_parts, dim=0).to(device)
    # Built directly from the Python ints so the dtype is int64 on every platform.
    batch_n_possible_actions = torch.tensor(n_possible_actions, dtype=torch.int64, device=device)

    return (
        concat_parts(action_history_parts),
        concat_parts(game_state_parts),
        concat_parts(possible_actions_parts),
        batch_target_action,
        batch_n_possible_actions
    )

# Implement deep learning model

In [None]:
class BaseDeepLearningScorer(LightningModule):
    """
    Base class for models that score every possible action of a game state.

    Sub-classes implement `forward`; this class supplies the shared
    training/validation step (cross-entropy against the one-hot chosen action),
    the prediction step and the optimizer.
    """

    def __init__(self):
        super().__init__()
        self.loss = torch.nn.CrossEntropyLoss()

    @abstractmethod
    def forward(self, batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors):
        """
        Inputs:

        batch_action_history_vectors:
        {
            general: (batch_size, history_size, action_general_dim)
            source_card_vectors: (batch_size, history_size, max_n_action_source_cards, zone_vector_dim)
            target_card_vectors: (batch_size, history_size, max_n_action_target_cards, zone_vector_dim)
        }

        batch_current_game_state_vectors:
        {
            global: (batch_size, game_state_global_dim)
            players: (batch_size, n_players, player_dim)
            zones: (batch_size, max_n_zone_vectors, zone_vector_dim)
        }

        batch_possible_actions_vectors:
        {
            general: (batch_size, max_n_possible_actions_in_batch, action_general_dim)
            source_card_vectors: (batch_size, max_n_possible_actions_in_batch, max_n_action_source_cards, zone_vector_dim)
            target_card_vectors: (batch_size, max_n_possible_actions_in_batch, max_n_action_target_cards, zone_vector_dim)
        }

        Returns:
        - batch_predicted_target_action: (batch_size, max_n_possible_actions_in_batch)
        """
        raise NotImplementedError

    def __step(self, batch, batch_idx, base_metric_name):
        """
        Shared training/validation step: score the actions, compute and log the loss.

        `batch` is the 5-tuple (action history, game state, possible actions,
        target action, number of real possible actions per sample).
        Returns the scalar batch loss.
        """
        (
            batch_action_history_vectors,
            batch_current_game_state_vectors,
            batch_possible_actions_vectors,
            batch_target_action,
            batch_n_possible_actions,
        ) = batch
        batch_predicted_target_action = self.forward(
            batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors
        )

        # Samples with fewer actions than the longest one in the batch carry
        # zero-padded action slots. Exclude those slots from the softmax so the
        # loss does not depend on how much padding a batch happens to contain.
        # A large finite value is used instead of -inf because the targets are
        # probabilities: 0 * -inf would turn the loss into NaN.
        slot_index = torch.arange(
            batch_predicted_target_action.shape[1],
            device=batch_predicted_target_action.device
        )
        n_real_actions = batch_n_possible_actions.to(batch_predicted_target_action.device)
        is_padding = slot_index.unsqueeze(0) >= n_real_actions.unsqueeze(1)
        masked_scores = batch_predicted_target_action.masked_fill(
            is_padding, torch.finfo(batch_predicted_target_action.dtype).min
        )

        batch_loss = self.loss(masked_scores, batch_target_action)

        self.log(f"{base_metric_name}_loss", batch_loss, on_epoch=True, prog_bar=True)
        return batch_loss

    def training_step(self, batch, batch_idx):
        return self.__step(batch=batch, batch_idx=batch_idx, base_metric_name="training")

    def validation_step(self, batch, batch_idx):
        return self.__step(batch=batch, batch_idx=batch_idx, base_metric_name="validation")

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Score the possible actions of a batch.

        Only the first three elements of `batch` (action history, game state,
        possible actions) are used, so batches that additionally carry targets
        and action counts are accepted as well. The returned scores are the raw
        output of `forward`; scores of padded action slots are not masked here.
        """
        (
            batch_action_history_vectors,
            batch_current_game_state_vectors,
            batch_possible_actions_vectors,
        ) = batch[:3]
        batch_predicted_scores = self.forward(
            batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors
        )
        return batch_predicted_scores

    def configure_optimizers(self):
        # Adam with PyTorch's default hyper-parameters.
        return torch.optim.Adam(self.parameters())

    def get_n_parameters(self):
        """Return the total number of scalar parameters of the model."""
        return sum(p.numel() for p in self.parameters())

In [None]:
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F


class Encoder(nn.Module):
    """
    Encoder class for Pointer-Net
    """

    def __init__(self, embedding_dim,
                 hidden_dim,
                 n_layers,
                 dropout,
                 bidir):
        """
        Initiate Encoder

        :param int embedding_dim: Number of embedding channels
        :param int hidden_dim: Number of hidden units for the LSTM
        :param int n_layers: Number of layers for LSTMs
        :param float dropout: Float between 0-1
        :param bool bidir: Bidirectional
        """

        super(Encoder, self).__init__()
        # A bidirectional LSTM splits the hidden units between both directions
        # and has twice as many (layer, direction) hidden states.
        self.hidden_dim = hidden_dim // 2 if bidir else hidden_dim
        self.n_layers = n_layers * 2 if bidir else n_layers
        self.bidir = bidir
        self.lstm = nn.LSTM(embedding_dim,
                            self.hidden_dim,
                            n_layers,
                            dropout=dropout,
                            bidirectional=bidir)

        # Used for propagating .cuda() command
        self.h0 = Parameter(torch.zeros(1), requires_grad=False)
        self.c0 = Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, embedded_inputs,
                hidden):
        """
        Encoder - Forward-pass

        :param Tensor embedded_inputs: Embedded inputs of Pointer-Net, batch first
        :param Tensor hidden: Initiated hidden units for the LSTMs (h, c)
        :return: LSTMs outputs (batch first) and hidden units (h, c)
        """

        # The LSTM is sequence-first: (batch, seq, emb) -> (seq, batch, emb)
        sequence_first_inputs = embedded_inputs.permute(1, 0, 2)

        outputs, hidden = self.lstm(sequence_first_inputs, hidden)

        return outputs.permute(1, 0, 2), hidden

    def init_hidden(self, embedded_inputs):
        """
        Initiate hidden units

        :param Tensor embedded_inputs: The embedded input of Pointer-Net
        :return: Initiated hidden units for the LSTMs (h, c), each of shape
                 (n_layers, batch_size, hidden_dim)
        """

        batch_size = embedded_inputs.size(0)
        state_shape = (self.n_layers, batch_size, self.hidden_dim)

        # Reshaping (Expanding): (1,) -> (1, 1, 1) -> state_shape
        h0 = self.h0.unsqueeze(0).unsqueeze(0).repeat(*state_shape)
        # The cell state is built from its own parameter c0 (it used to be
        # built from h0 by mistake, leaving c0 unused).
        c0 = self.c0.unsqueeze(0).unsqueeze(0).repeat(*state_shape)

        return h0, c0


class Attention(nn.Module):
    """
    Attention model for Pointer-Net
    """

    def __init__(self, input_dim,
                 hidden_dim):
        """
        Initiate Attention

        :param int input_dim: Input's dimension
        :param int hidden_dim: Number of hidden units in the attention
        """

        super(Attention, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.input_linear = nn.Linear(input_dim, hidden_dim)
        # A kernel-size-1 convolution projects every context position independently.
        self.context_linear = nn.Conv1d(input_dim, hidden_dim, 1, 1)
        self.V = Parameter(torch.empty(hidden_dim), requires_grad=True)
        # Kept (together with init_inf) so that existing callers and saved
        # state dicts keep working; forward() no longer needs it.
        self._inf = Parameter(torch.tensor([float('-inf')]), requires_grad=False)
        self.tanh = nn.Tanh()
        # The attention scores are (batch, seq_len); normalize over seq_len.
        self.softmax = nn.Softmax(dim=1)

        # Initialize vector V
        nn.init.uniform_(self.V, -1, 1)

    def forward(self, input,
                context,
                mask):
        """
        Attention - Forward-pass

        :param Tensor input: Hidden state h, (batch, input_dim)
        :param Tensor context: Attention context, (batch, seq_len, input_dim)
        :param Tensor mask: Selection mask (bool or byte), (batch, seq_len);
                            positions that are True/non-zero get zero attention
        :return: tuple of - (Attentioned hidden state, Alphas)
        """

        # (batch, hidden_dim, seq_len)
        inp = self.input_linear(input).unsqueeze(2).expand(-1, -1, context.size(1))

        # (batch, hidden_dim, seq_len)
        context = context.permute(0, 2, 1)
        ctx = self.context_linear(context)

        # (batch, 1, hidden_dim)
        V = self.V.unsqueeze(0).expand(context.size(0), -1).unsqueeze(1)

        # (batch, seq_len)
        att = torch.bmm(V, self.tanh(inp + ctx)).squeeze(1)
        # Masked positions get a score of -inf, i.e. a probability of exactly 0.
        # masked_fill avoids in-place indexing and does not depend on init_inf()
        # having been called beforehand.
        att = att.masked_fill(mask.to(torch.bool), float('-inf'))
        alpha = self.softmax(att)

        hidden_state = torch.bmm(ctx, alpha.unsqueeze(2)).squeeze(2)

        return hidden_state, alpha

    def init_inf(self, mask_size):
        """
        Expand the stored -inf scalar to `mask_size` and keep it as `self.inf`.

        Retained for backward compatibility with callers that invoke it before
        forward(); forward() itself does not read `self.inf`.
        """
        self.inf = self._inf.unsqueeze(1).expand(*mask_size)


class Decoder(nn.Module):
    """
    Decoder model for Pointer-Net
    """

    def __init__(self, embedding_dim,
                 hidden_dim):
        """
        Initiate Decoder

        :param int embedding_dim: Number of embeddings in Pointer-Net
        :param int hidden_dim: Number of hidden units for the decoder's RNN
        """

        super(Decoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # Hand-written LSTM cell: both projections produce the 4 gates at once.
        self.input_to_hidden = nn.Linear(embedding_dim, 4 * hidden_dim)
        self.hidden_to_hidden = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.hidden_out = nn.Linear(hidden_dim * 2, hidden_dim)
        self.att = Attention(hidden_dim, hidden_dim)

        # Used for propagating .cuda() command
        self.mask = Parameter(torch.ones(1), requires_grad=False)
        # No longer read by forward(); kept so that the parameter set (and thus
        # saved state dicts) stays unchanged.
        self.runner = Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, embedded_inputs,
                decoder_input,
                hidden,
                context):
        """
        Decoder - Forward-pass

        :param Tensor embedded_inputs: Embedded inputs of Pointer-Net, (batch, seq_len, embedding_dim)
        :param Tensor decoder_input: First decoder's input, (batch, embedding_dim)
        :param Tensor hidden: First decoder's hidden states (h, c)
        :param Tensor context: Encoder's outputs
        :return: (Output probabilities, Pointers indices), last hidden state
        """

        batch_size = embedded_inputs.size(0)
        input_length = embedded_inputs.size(1)

        # (batch, seq_len); 1 = position still selectable, 0 = already pointed at
        mask = self.mask.repeat(input_length).unsqueeze(0).repeat(batch_size, 1)
        self.att.init_inf(mask.size())

        # Index helpers, created directly on the right device.
        positions = torch.arange(input_length, device=mask.device)
        batch_index = torch.arange(batch_size, device=mask.device)

        outputs = []
        pointers = []

        def step(x, hidden):
            """
            Recurrence step function

            :param Tensor x: Input at time t
            :param tuple(Tensor, Tensor) hidden: Hidden states at time t-1
            :return: Hidden states at time t (h, c), Attention probabilities (Alpha)
            """

            # Regular LSTM
            h, c = hidden

            gates = self.input_to_hidden(x) + self.hidden_to_hidden(h)
            input_gate, forget_gate, cell_candidate, output_gate = gates.chunk(4, 1)

            input_gate = torch.sigmoid(input_gate)
            forget_gate = torch.sigmoid(forget_gate)
            cell_candidate = torch.tanh(cell_candidate)
            output_gate = torch.sigmoid(output_gate)

            c_t = (forget_gate * c) + (input_gate * cell_candidate)
            h_t = output_gate * torch.tanh(c_t)

            # Attention section; `mask` is the current value of the enclosing
            # variable, so positions selected in earlier steps are excluded.
            hidden_t, output = self.att(h_t, context, torch.eq(mask, 0))
            hidden_t = torch.tanh(self.hidden_out(torch.cat((hidden_t, h_t), 1)))

            return hidden_t, c_t, output

        # Recurrence loop
        for _ in range(input_length):
            h_t, c_t, outs = step(decoder_input, hidden)
            hidden = (h_t, c_t)

            # Masking selected inputs
            masked_outs = outs * mask

            # Get indices of the maximum probabilities
            _, indices = masked_outs.max(1)
            one_hot_pointers = (positions.unsqueeze(0) == indices.unsqueeze(1)).to(mask.dtype)

            # Update mask to ignore seen indices
            mask = mask * (1 - one_hot_pointers)

            # Next decoder input: the embedding each sample just pointed at, (batch, embedding_dim)
            decoder_input = embedded_inputs[batch_index, indices]

            outputs.append(outs.unsqueeze(0))
            pointers.append(indices.unsqueeze(1))

        outputs = torch.cat(outputs).permute(1, 0, 2)
        pointers = torch.cat(pointers, 1)

        return (outputs, pointers), hidden


class PointerNet(nn.Module):
    """
    Pointer-Net
    """

    def __init__(self, embedding_dim,
                 hidden_dim,
                 lstm_layers,
                 dropout,
                 bidir=False):
        """
        Initiate Pointer-Net

        :param int embedding_dim: Number of embedding channels
        :param int hidden_dim: Encoders hidden units
        :param int lstm_layers: Number of layers for LSTMs
        :param float dropout: Float between 0-1
        :param bool bidir: Bidirectional
        """

        super(PointerNet, self).__init__()
        self.embedding_dim = embedding_dim
        self.bidir = bidir
        self.encoder = Encoder(embedding_dim,
                               hidden_dim,
                               lstm_layers,
                               dropout,
                               bidir)
        self.decoder = Decoder(embedding_dim, hidden_dim)
        # Fixed (non-trainable) first input of the decoder.
        self.decoder_input0 = Parameter(torch.empty(embedding_dim), requires_grad=False)

        # Initialize decoder_input0
        nn.init.uniform_(self.decoder_input0, -1, 1)

    def forward(self, inputs):
        """
        PointerNet - Forward-pass

        :param Tensor inputs: Input sequence; used as-is as the embedded inputs
        :return: Pointers probabilities and indices
        """

        batch_size = inputs.size(0)

        decoder_input0 = self.decoder_input0.unsqueeze(0).expand(batch_size, -1)

        # The inputs are already embeddings; no embedding layer is applied here.
        embedded_inputs = inputs

        encoder_hidden0 = self.encoder.init_hidden(embedded_inputs)
        encoder_outputs, encoder_hidden = self.encoder(embedded_inputs,
                                                       encoder_hidden0)
        final_h, final_c = encoder_hidden
        if self.bidir:
            # Join the last two hidden states (the two directions of the last
            # layer) along the feature axis. torch.cat needs a sequence of
            # tensors, so the two slices are passed explicitly.
            decoder_hidden0 = (torch.cat([final_h[-2], final_h[-1]], dim=-1),
                               torch.cat([final_c[-2], final_c[-1]], dim=-1))
        else:
            decoder_hidden0 = (final_h[-1],
                               final_c[-1])
        (outputs, pointers), decoder_hidden = self.decoder(embedded_inputs,
                                                           decoder_input0,
                                                           decoder_hidden0,
                                                           encoder_outputs)

        return outputs, pointers

In [None]:
class ActionProcessingBlock(LightningModule):
    """
    Encodes one action (general vector plus source/target cards) into a single embedding.

    The general vector and every card are projected to `output_dim - 3`
    features and extended by a 3-element one-hot tag marking the token kind
    (general / source card / target card). The tokens, preceded by an all-zero
    prediction token, pass through a transformer encoder; the output at the
    prediction token position is the action embedding.
    """

    def __init__(
        self,
        action_general_dim: int,
        max_n_action_source_cards: int,
        max_n_action_target_cards: int,
        zone_vector_dim: int,
        output_dim: int,
        transformer_n_layers: int = 1,
        transformer_n_heads: int = 1,
        transformer_dim_feedforward: int = 128,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.action_general_dim = action_general_dim
        self.max_n_action_source_cards = max_n_action_source_cards
        self.max_n_action_target_cards = max_n_action_target_cards
        self.zone_vector_dim = zone_vector_dim
        # 3 of the output features are reserved for the token-kind tag.
        assert output_dim > 3
        self.output_dim = output_dim
        self.transformer_n_layers = transformer_n_layers
        self.transformer_n_heads = transformer_n_heads
        self.transformer_dim_feedforward = transformer_dim_feedforward
        self.dropout = dropout

        # Modules
        projection_dim = self.output_dim - 3
        self.general_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.action_general_dim, out_features=projection_dim),
            torch.nn.ReLU()
        )
        # Source and target cards share this projection.
        self.card_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.zone_vector_dim, out_features=projection_dim),
            torch.nn.ReLU()
        )
        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=self.output_dim,
            nhead=self.transformer_n_heads,
            dim_feedforward=self.transformer_dim_feedforward,
            dropout=self.dropout,
            activation="relu",
            batch_first=True
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=self.transformer_n_layers
        )

    def forward(self, action_vectors: Dict[str, torch.Tensor]):
        """
        Inputs:
        - action_vectors:
        {
            general: (batch_size, action_general_dim)
            source_card_vectors: (batch_size, max_n_action_source_cards, zone_vector_dim)
            target_card_vectors: (batch_size, max_n_action_target_cards, zone_vector_dim)
        }

        Outputs:
        - action_embedding: (batch_size, output_dim)
        """
        batch_size = action_vectors["general"].shape[0]

        general_token = self.__get_action_general_embedding(action_vectors["general"])
        source_tokens = self.__get_action_source_card_embeddings(action_vectors["source_card_vectors"])
        target_tokens = self.__get_action_target_card_embeddings(action_vectors["target_card_vectors"])

        # All-zero token at position 0 whose transformer output is returned.
        prediction_token = general_token.new_zeros(batch_size, 1, self.output_dim)

        token_sequence = torch.cat(
            [prediction_token, general_token.unsqueeze(1), source_tokens, target_tokens],
            dim=1
        )
        encoded_sequence = self.transformer_encoder(token_sequence)

        return encoded_sequence[:, 0, :]

    def __get_action_general_embedding(self, action_general_vector: torch.Tensor) -> torch.Tensor:
        """Project the general action vector and append the tag [1, 0, 0]."""
        batch_size = action_general_vector.shape[0]
        projected = self.general_mlp(action_general_vector)
        kind_tag = projected.new_tensor([1.0, 0.0, 0.0]).expand(batch_size, 3)
        return torch.cat([projected, kind_tag], dim=1)

    def __get_action_source_card_embeddings(self, action_source_card_vectors: torch.Tensor) -> torch.Tensor:
        """Project the source cards and append the tag [0, 1, 0] to each of them."""
        batch_size = action_source_card_vectors.shape[0]
        projected = self.card_mlp(action_source_card_vectors)
        kind_tag = projected.new_tensor([0.0, 1.0, 0.0]).expand(
            batch_size, self.max_n_action_source_cards, 3
        )
        return torch.cat([projected, kind_tag], dim=2)

    def __get_action_target_card_embeddings(self, action_target_card_vectors: torch.Tensor) -> torch.Tensor:
        """Project the target cards and append the tag [0, 0, 1] to each of them."""
        batch_size = action_target_card_vectors.shape[0]
        projected = self.card_mlp(action_target_card_vectors)
        kind_tag = projected.new_tensor([0.0, 0.0, 1.0]).expand(
            batch_size, self.max_n_action_target_cards, 3
        )
        return torch.cat([projected, kind_tag], dim=2)


class GameStateProcessingBlock(LightningModule):
    """
    Encode one game state into a single embedding vector.

    The global vector, the per-player vectors and the zone vectors are each
    projected by their own Linear + ReLU layer to `output_dim - 3` features and
    extended with a 3-dim one-hot token type (global / player / zone). The
    tagged tokens are prefixed with an all-zero "prediction" token and passed
    through a Transformer encoder; the encoder output at the prediction-token
    position is returned as the game-state embedding.
    """

    # One-hot token types appended to every embedding of the matching group.
    GLOBAL_TOKEN_TYPE = (1.0, 0.0, 0.0)
    PLAYER_TOKEN_TYPE = (0.0, 1.0, 0.0)
    ZONE_TOKEN_TYPE = (0.0, 0.0, 1.0)

    def __init__(
        self,
        game_state_global_dim: int,
        n_players: int,
        player_dim: int,
        max_n_zone_vectors: int,
        zone_vector_dim: int,
        output_dim: int,
        transformer_n_layers: int = 1,
        transformer_n_heads: int = 1,
        transformer_dim_feedforward: int = 128,
        dropout: float = 0.0,
    ):
        """
        Args:
            game_state_global_dim: feature size of the global game-state vector.
            n_players: number of player vectors (stored as configuration; the
                forward pass sizes itself from the tensors it receives).
            player_dim: feature size of one player vector.
            max_n_zone_vectors: maximum number of zone vectors (stored as
                configuration; the forward pass sizes itself from its inputs).
            zone_vector_dim: feature size of one zone vector.
            output_dim: token / output embedding size; must be > 3 because the
                last 3 features of every token hold the one-hot token type.
            transformer_n_layers: number of Transformer encoder layers.
            transformer_n_heads: attention heads per encoder layer.
            transformer_dim_feedforward: hidden size of the encoder feed-forward part.
            dropout: dropout probability used inside the encoder layers.
        """
        super().__init__()
        self.game_state_global_dim = game_state_global_dim
        self.n_players = n_players
        self.player_dim = player_dim
        self.max_n_zone_vectors = max_n_zone_vectors
        self.zone_vector_dim = zone_vector_dim
        assert output_dim > 3, "output_dim must be > 3 (3 features are reserved for the token type)"
        self.output_dim = output_dim
        self.transformer_n_layers = transformer_n_layers
        self.transformer_n_heads = transformer_n_heads
        self.transformer_dim_feedforward = transformer_dim_feedforward
        self.dropout = dropout

        # Modules
        # Each MLP outputs output_dim - 3 features; the token type fills the rest.
        self.global_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.game_state_global_dim, out_features=self.output_dim - 3),
            torch.nn.ReLU()
        )
        self.player_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.player_dim, out_features=self.output_dim - 3),
            torch.nn.ReLU()
        )
        self.zone_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.zone_vector_dim, out_features=self.output_dim - 3),
            torch.nn.ReLU()
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(
                d_model=self.output_dim,
                nhead=self.transformer_n_heads,
                dim_feedforward=self.transformer_dim_feedforward,
                dropout=self.dropout,
                activation="relu",
                batch_first=True
            ),
            num_layers=self.transformer_n_layers
        )

    def forward(self, game_state_vectors: Dict[str, torch.Tensor]):
        """
        Inputs:
        - game_state_vectors:
        {
            global: (batch_size, game_state_global_dim)
            players: (batch_size, n_players, player_dim)
            zones: (batch_size, n_zone_vectors, zone_vector_dim)
        }

        Outputs:
        - game_state_embedding: (batch_size, output_dim)
        """
        global_embedding = self.__get_global_embedding(game_state_vectors["global"])
        player_embeddings = self.__get_player_embeddings(game_state_vectors["players"])
        zone_embeddings = self.__get_zone_embeddings(game_state_vectors["zones"])

        # All-zero token whose encoder output becomes the game-state embedding;
        # new_zeros allocates it with the embeddings' dtype and device.
        batch_size = global_embedding.shape[0]
        embedding_for_prediction = global_embedding.new_zeros(batch_size, 1, self.output_dim)

        embeddings_sequence = torch.cat(
            [
                embedding_for_prediction,
                global_embedding[:, None],
                player_embeddings,
                zone_embeddings
            ],
            dim=1
        )

        embeddings_sequence_after_transformer = self.transformer_encoder(embeddings_sequence)

        # Position 0 is the prediction token.
        return embeddings_sequence_after_transformer[:, 0, :]

    @staticmethod
    def __append_token_type(embeddings: torch.Tensor, token_type: Tuple[float, float, float]) -> torch.Tensor:
        """
        Concatenate the one-hot `token_type` to the last dimension of `embeddings`.

        The tag is sized from `embeddings` itself (all leading dimensions), so it
        always matches the tensor it is concatenated with, and it is created with
        the same dtype and device. `expand` yields a broadcast view, not a copy.
        """
        type_tag = embeddings.new_tensor(token_type).expand(*embeddings.shape[:-1], len(token_type))
        return torch.cat([embeddings, type_tag], dim=-1)

    def __get_global_embedding(self, game_state_global_vector: torch.Tensor) -> torch.Tensor:
        """Project the global vector with `global_mlp` and tag it as a global token."""
        return self.__append_token_type(self.global_mlp(game_state_global_vector), self.GLOBAL_TOKEN_TYPE)

    def __get_player_embeddings(self, game_state_player_vectors: torch.Tensor) -> torch.Tensor:
        """Project the player vectors with `player_mlp` and tag them as player tokens."""
        return self.__append_token_type(self.player_mlp(game_state_player_vectors), self.PLAYER_TOKEN_TYPE)

    def __get_zone_embeddings(self, game_state_zone_vectors: torch.Tensor) -> torch.Tensor:
        """Project the zone vectors with `zone_mlp` and tag them as zone tokens."""
        return self.__append_token_type(self.zone_mlp(game_state_zone_vectors), self.ZONE_TOKEN_TYPE)


class ClassificationBlock(LightningModule):
    """
    Score the possible actions given the action history and the current game state.

    Every incoming embedding is extended with a 3-dim one-hot token type
    (history / current state / possible action) and projected back to
    `input_dim` by a shared Linear + ReLU layer. The tokens are concatenated in
    the order history, current state, possible actions, run through a
    Transformer encoder, and the encoder outputs at the possible-action
    positions are handed to a `PointerNet`.
    """

    # One-hot token types appended before the shared preprocessing layer.
    HISTORY_TOKEN_TYPE = (1.0, 0.0, 0.0)
    GAME_STATE_TOKEN_TYPE = (0.0, 1.0, 0.0)
    POSSIBLE_ACTION_TOKEN_TYPE = (0.0, 0.0, 1.0)

    def __init__(
        self,
        input_dim: int,
        transformer_n_layers: int = 1,
        transformer_n_heads: int = 1,
        transformer_dim_feedforward: int = 128,
        pointer_net_n_lstm_layers: int = 1,
        pointer_net_hidden_dim: int = 128,
        dropout: float = 0.0,
    ):
        """
        Args:
            input_dim: size of every incoming embedding and of the encoder tokens.
            transformer_n_layers: number of Transformer encoder layers.
            transformer_n_heads: attention heads per encoder layer.
            transformer_dim_feedforward: hidden size of the encoder feed-forward part.
            pointer_net_n_lstm_layers: `lstm_layers` argument passed to `PointerNet`.
            pointer_net_hidden_dim: `hidden_dim` argument passed to `PointerNet`.
            dropout: dropout probability used by the encoder layers and `PointerNet`.
        """
        super().__init__()
        self.input_dim = input_dim
        self.transformer_n_layers = transformer_n_layers
        self.transformer_n_heads = transformer_n_heads
        self.transformer_dim_feedforward = transformer_dim_feedforward
        self.pointer_net_n_lstm_layers = pointer_net_n_lstm_layers
        self.pointer_net_hidden_dim = pointer_net_hidden_dim
        self.dropout = dropout

        # Modules
        # Maps (embedding + 3-dim token type) back to input_dim.
        self.preprocessing_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.input_dim + 3, out_features=self.input_dim),
            torch.nn.ReLU()
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(
                d_model=self.input_dim,
                nhead=self.transformer_n_heads,
                dim_feedforward=self.transformer_dim_feedforward,
                dropout=self.dropout,
                activation="relu",
                batch_first=True
            ),
            num_layers=self.transformer_n_layers
        )
        self.pointer_net = PointerNet(
            embedding_dim=self.input_dim,
            hidden_dim=self.pointer_net_hidden_dim,
            lstm_layers=self.pointer_net_n_lstm_layers,
            dropout=self.dropout,
            bidir=False
        )

    def forward(
        self,
        batch_action_history_embeddings: torch.Tensor,
        batch_current_game_state_embedding: torch.Tensor,
        batch_possible_actions_embeddings: torch.Tensor
    ) -> torch.Tensor:
        """
        Inputs:
        - batch_action_history_embeddings: (batch_size, history_size, input_dim)
        - batch_current_game_state_embedding: (batch_size, input_dim)
        - batch_possible_actions_embeddings: (batch_size, max_n_possible_actions_in_batch, input_dim)

        Outputs:
        - batch_predicted_target_action: (batch_size, max_n_possible_actions_in_batch)
        """
        batch_action_history_embeddings = self.__prepare_action_history_embeddings(
            batch_action_history_embeddings
        )
        batch_current_game_state_embedding = self.__prepare_current_game_state_embedding(
            batch_current_game_state_embedding
        )
        batch_possible_actions_embeddings = self.__prepare_possible_actions_embeddings(
            batch_possible_actions_embeddings
        )

        embeddings_sequence = torch.cat(
            [
                batch_action_history_embeddings,
                batch_current_game_state_embedding,
                batch_possible_actions_embeddings
            ],
            dim=1
        )

        embeddings_sequence_after_transformer = self.transformer_encoder(embeddings_sequence)

        # Slice from an explicit start index. A negative slice `[-n:]` silently
        # returns the WHOLE sequence when n == 0, which would score history and
        # game-state tokens as if they were actions.
        n_context_tokens = (
            batch_action_history_embeddings.shape[1]
            + batch_current_game_state_embedding.shape[1]
        )
        possible_actions_embeddings = embeddings_sequence_after_transformer[:, n_context_tokens:]

        # PointerNet returns (probabilities, pointers); only the first is used here.
        # NOTE(review): taking [..., 0] presumably selects the first decoding
        # step of the pointer distribution - confirm against PointerNet.
        predicted_target_action_probabilities, _ = self.pointer_net(possible_actions_embeddings)

        return predicted_target_action_probabilities[..., 0]

    @staticmethod
    def __append_token_type(embeddings: torch.Tensor, token_type: Tuple[float, float, float]) -> torch.Tensor:
        """
        Concatenate the one-hot `token_type` to the last dimension of `embeddings`.

        The tag is sized from `embeddings` itself and created with the same dtype
        and device. `expand` yields a broadcast view, not a copy.
        """
        type_tag = embeddings.new_tensor(token_type).expand(*embeddings.shape[:-1], len(token_type))
        return torch.cat([embeddings, type_tag], dim=-1)

    def __prepare_action_history_embeddings(self, batch_action_history_embeddings: torch.Tensor) -> torch.Tensor:
        """Tag the history embeddings as history tokens and project them to `input_dim`."""
        tagged = ClassificationBlock.__append_token_type(batch_action_history_embeddings, self.HISTORY_TOKEN_TYPE)
        return self.preprocessing_mlp(tagged)

    def __prepare_current_game_state_embedding(self, batch_current_game_state_embedding: torch.Tensor) -> torch.Tensor:
        """Tag the game-state embedding, add a length-1 sequence axis and project it to `input_dim`."""
        tagged = ClassificationBlock.__append_token_type(batch_current_game_state_embedding, self.GAME_STATE_TOKEN_TYPE)
        # (batch_size, input_dim + 3) -> (batch_size, 1, input_dim + 3)
        return self.preprocessing_mlp(tagged[:, None])

    def __prepare_possible_actions_embeddings(self, batch_possible_actions_embeddings: torch.Tensor) -> torch.Tensor:
        """Tag the possible-action embeddings as candidate tokens and project them to `input_dim`."""
        tagged = ClassificationBlock.__append_token_type(batch_possible_actions_embeddings, self.POSSIBLE_ACTION_TOKEN_TYPE)
        return self.preprocessing_mlp(tagged)


class DeepLearningScorerV1(BaseDeepLearningScorer):
    """
    Scorer that ranks the possible actions of a game state.

    Pipeline (see `forward`):
    1. every action of the history and every possible action is embedded by a
       shared `ActionProcessingBlock`;
    2. the current game state is embedded by a `GameStateProcessingBlock`;
    3. a `ClassificationBlock` combines the three and returns one value per
       possible action.
    """

    def __init__(
        self,
        game_state_global_dim: int,
        n_players: int,
        player_dim: int,
        max_n_zone_vectors: int,
        zone_vector_dim: int,
        action_general_dim: int,
        max_n_action_source_cards: int,
        max_n_action_target_cards: int,
        embedding_dim: int,
        transformer_n_layers: int,
        transformer_n_heads: int,
        transformer_dim_feedforward: int,
        dropout: float
    ):
        """
        Args:
            game_state_global_dim: feature size of the global game-state vector.
            n_players: number of player vectors per game state.
            player_dim: feature size of one player vector.
            max_n_zone_vectors: maximum number of zone vectors per game state.
            zone_vector_dim: feature size of one zone (card) vector.
            action_general_dim: feature size of the general action vector.
            max_n_action_source_cards: maximum number of source cards per action.
            max_n_action_target_cards: maximum number of target cards per action.
            embedding_dim: embedding size shared by all three sub-blocks.
            transformer_n_layers: encoder layers used in every sub-block.
            transformer_n_heads: attention heads used in every sub-block.
            transformer_dim_feedforward: encoder feed-forward size used in every sub-block.
            dropout: dropout probability used in every sub-block.
        """
        super().__init__()
        self.game_state_global_dim = game_state_global_dim
        self.n_players = n_players
        self.player_dim = player_dim
        self.max_n_zone_vectors = max_n_zone_vectors
        self.zone_vector_dim = zone_vector_dim
        self.action_general_dim = action_general_dim
        self.max_n_action_source_cards = max_n_action_source_cards
        self.max_n_action_target_cards = max_n_action_target_cards
        self.embedding_dim = embedding_dim
        self.transformer_n_layers = transformer_n_layers
        self.transformer_n_heads = transformer_n_heads
        self.transformer_dim_feedforward = transformer_dim_feedforward
        self.dropout = dropout

        # Modules
        # The action block previously hard-coded transformer_n_heads=1 and
        # transformer_dim_feedforward=128, ignoring the constructor arguments
        # (and the hyperparameters logged for the run). It now receives the same
        # configured values as the other two blocks.
        self.action_processing_block = ActionProcessingBlock(
            action_general_dim=self.action_general_dim,
            max_n_action_source_cards=self.max_n_action_source_cards,
            max_n_action_target_cards=self.max_n_action_target_cards,
            zone_vector_dim=self.zone_vector_dim,
            output_dim=self.embedding_dim,
            transformer_n_layers=self.transformer_n_layers,
            transformer_n_heads=self.transformer_n_heads,
            transformer_dim_feedforward=self.transformer_dim_feedforward,
            dropout=self.dropout,
        )
        self.game_state_processing_block = GameStateProcessingBlock(
            game_state_global_dim=self.game_state_global_dim,
            n_players=self.n_players,
            player_dim=self.player_dim,
            max_n_zone_vectors=self.max_n_zone_vectors,
            zone_vector_dim=self.zone_vector_dim,
            output_dim=self.embedding_dim,
            transformer_n_layers=self.transformer_n_layers,
            transformer_n_heads=self.transformer_n_heads,
            transformer_dim_feedforward=self.transformer_dim_feedforward,
            dropout=self.dropout,
        )
        self.classification_block = ClassificationBlock(
            input_dim=self.embedding_dim,
            transformer_n_layers=self.transformer_n_layers,
            transformer_n_heads=self.transformer_n_heads,
            transformer_dim_feedforward=self.transformer_dim_feedforward,
            dropout=self.dropout,
        )

    def forward(self, batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors):
        """
        Score every possible action of each sample in the batch.

        Inputs:
        - batch_action_history_vectors: dict of tensors describing the action
          history; must contain the key "general", whose dim 1 is the history size.
        - batch_current_game_state_vectors: dict consumed by
          `GameStateProcessingBlock` (keys "global", "players", "zones").
        - batch_possible_actions_vectors: dict of tensors describing the candidate
          actions; must contain the key "general", whose dim 1 is the number of
          (padded) possible actions.

        Outputs:
        - the output of `ClassificationBlock`: one value per possible action,
          (batch_size, max_n_possible_actions_in_batch).
        """
        batch_action_history_embeddings = self.__process_action_list(
            batch_action_list_vectors=batch_action_history_vectors,
            n_actions=batch_action_history_vectors["general"].shape[1]
        )

        batch_current_game_state_embedding = self.game_state_processing_block(batch_current_game_state_vectors)

        batch_possible_actions_embeddings = self.__process_action_list(
            batch_action_list_vectors=batch_possible_actions_vectors,
            n_actions=batch_possible_actions_vectors["general"].shape[1]
        )

        batch_predicted_target_action = self.classification_block(
            batch_action_history_embeddings,
            batch_current_game_state_embedding,
            batch_possible_actions_embeddings
        )

        return batch_predicted_target_action

    def __process_action_list(self, batch_action_list_vectors: Dict[str, torch.Tensor], n_actions: int) -> torch.Tensor:
        """
        Embed each action of a batched action list with the shared action block.

        `batch_action_list_vectors` is a dict of tensors whose dim 1 indexes the
        action; action `i` is extracted from every entry with `tensor[:, i]` and
        embedded separately. The per-action embeddings are stacked on a new
        dim 1, giving (batch_size, n_actions, embedding size).
        """
        action_embeddings = [
            self.action_processing_block(
                {key: tensor[:, action_index] for key, tensor in batch_action_list_vectors.items()}
            )
            for action_index in range(n_actions)
        ]
        # Stacking (batch, dim) tensors on dim 1 equals concatenating (batch, 1, dim) ones.
        return torch.stack(action_embeddings, dim=1)

In [None]:
# Collect the hyperparameters defined in the "Parameters" section once,
# then build the scorer from them.
model_hyperparameters = dict(
    # game-state input sizes
    game_state_global_dim=game_state_global_dim,
    n_players=n_players,
    player_dim=player_dim,
    max_n_zone_vectors=max_n_zone_vectors,
    zone_vector_dim=zone_vector_dim,
    # action input sizes
    action_general_dim=action_general_dim,
    max_n_action_source_cards=max_n_action_source_cards,
    max_n_action_target_cards=max_n_action_target_cards,
    # network architecture
    embedding_dim=embedding_dim,
    transformer_n_layers=transformer_n_layers,
    transformer_n_heads=transformer_n_heads,
    transformer_dim_feedforward=transformer_dim_feedforward,
    dropout=dropout,
)
model = DeepLearningScorerV1(**model_hyperparameters)

# Training

In [None]:
# Fractions of the preprocessed dataset used for training and validation.
training_fraction = 0.8
validation_fraction = 0.2

# A dedicated, explicitly seeded generator makes the split reproducible
# regardless of how much global RNG state earlier cells have consumed.
split_seed = 0
split_generator = torch.Generator().manual_seed(split_seed)

training_dataset, validation_dataset = torch.utils.data.random_split(
    deep_learning_dataset,
    lengths=[
        training_fraction,
        validation_fraction,
    ],
    generator=split_generator,
)

In [None]:
def collate_on_device(samples):
    """Collate `samples` with `pad_possible_actions_collate_fn`, passing the notebook-level `device`."""
    return pad_possible_actions_collate_fn(samples=samples, device=device)


# Options shared by both loaders.
shared_loader_options = dict(
    batch_size=batch_size,
    num_workers=0,
    collate_fn=collate_on_device,
)

# Training: reshuffled every epoch, incomplete final batch dropped.
training_data_loader = torch.utils.data.DataLoader(
    training_dataset,
    shuffle=True,
    drop_last=True,
    **shared_loader_options,
)

# Validation: fixed order, every sample kept.
validation_data_loader = torch.utils.data.DataLoader(
    validation_dataset,
    shuffle=False,
    drop_last=False,
    **shared_loader_options,
)

In [None]:
# model_folder_path = "results/"

In [None]:
# Checkpointing is currently disabled; the setup is kept here for reference.
# Path(model_folder_path).mkdir(parents=True, exist_ok=True)
# ModelCheckpoint(
#     dirpath=model_folder_path,
#     filename="deep_learning_scorer",
#     monitor="validation_loss",
#     mode="min",
#     save_top_k=1,
#     verbose=False
# )

# Stop once "validation_loss" has not improved for `early_stopping_patience` checks.
early_stopping_callback = EarlyStopping(
    monitor="validation_loss",
    mode="min",
    patience=early_stopping_patience,
    verbose=False,
)
callbacks = [early_stopping_callback]

trainer = Trainer(max_epochs=n_epochs, devices="auto", deterministic=True, callbacks=callbacks)

In [None]:
# Train the scorer; the validation loader feeds the early-stopping callback's metric.
trainer.fit(model, train_dataloaders=training_data_loader, val_dataloaders=validation_data_loader)