In [139]:
import pickle
import torch
import numpy as np
from typing import List, Dict, Tuple
from torch.utils.data import Dataset
from abc import abstractmethod
from pytorch_lightning import LightningModule

# Read one game-log pickle

In [2]:
# Location of the recorded game log inspected in this notebook.
game_logs_dir = "../data/game_logs"
game_log_file_path = f"{game_logs_dir}/game_9f9f5c4b-82a6-4051-8bb7-96b3431d569e.pickle"

In [3]:
# Load one recorded game log into a dict.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load game logs produced by our own pipeline.
with open(game_log_file_path, "rb") as f:
    data_dict = pickle.load(f)

In [4]:
data_dict.keys()

dict_keys(['game_id', 'dataset', 'winner_player_index'])

In [5]:
# Top-level fields of the game log: its id, the indexable collection of
# recorded decisions, and winner_player_index (0 for this game).
game_id = data_dict["game_id"]
dataset = data_dict["dataset"]
winner_player_index = data_dict["winner_player_index"]

In [6]:
game_id

'9f9f5c4b-82a6-4051-8bb7-96b3431d569e'

In [7]:
winner_player_index

0

In [8]:
# Pick one recorded decision to inspect (index 44 looks like an arbitrary example).
item_dict = dataset[44]

In [9]:
item_dict.keys()

dict_keys(['action_history', 'current_game_state', 'possible_actions', 'chosen_action_index'])

In [10]:
# Unpack the decision record. `action_list` holds the action history; the
# action actually taken is possible_actions[chosen_action_index], and its
# "source_player_index" identifies the player who took it.
action_list = item_dict["action_history"]
game_state_vectors = item_dict["current_game_state"]
possible_actions = item_dict["possible_actions"]
chosen_action_index = item_dict["chosen_action_index"]
source_player_index = possible_actions[chosen_action_index]["source_player_index"]

In [231]:
# Length of the 1-D global game-state vector.
game_state_global_dim = game_state_vectors["global"].shape[0]

In [232]:
game_state_global_dim

2

In [236]:
# Number of player rows in the game state (first axis of "players").
n_players = game_state_vectors["players"].shape[0]

In [237]:
n_players

2

In [238]:
# Length of one player vector (second axis of "players").
player_dim = game_state_vectors["players"].shape[1]

In [239]:
player_dim

8

In [140]:
# Length of an action's "general" vector, read from the first action in the history.
action_general_dim = action_list[0]["general"].shape[0]

In [141]:
action_general_dim

31

In [46]:
# Length of one zone vector (second axis of "zones"); the card vectors inside
# actions have the same length.
zone_vector_dim = game_state_vectors["zones"].shape[1]

In [47]:
zone_vector_dim

34

In [11]:
action_list[0]

{'source_player_index': 0,
 'general': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'source_card_vectors': array([], dtype=float32),
 'target_card_vectors': array([[1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 1., 0.,
         0., 0.]], dtype=float32)}

In [12]:
game_state_vectors

{'global': array([9., 1.], dtype=float32),
 'players': array([[19.,  0.,  0.,  0.,  0.,  0.,  1.,  1.],
        [20.,  0.,  0.,  0.,  0.,  0.,  1.,  0.]], dtype=float32),
 'zones': array([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 7., 1., 0.,
         0., 0., 0., 3., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 2., 1., 0., 1.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.,
         0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 2., 5., 0., 1.,
         0., 0., 0., 3., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.,
         0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 4., 0., 2.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.,
         0., 0.],


In [13]:
possible_actions[2]

{'source_player_index': 0,
 'general': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'source_card_vectors': array([], dtype=float32),
 'target_card_vectors': array([[1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 1., 0.,
         0., 0.]], dtype=float32)}

In [14]:
chosen_action_index

2

In [15]:
source_player_index

0

# Create a preprocessed dataset

In [34]:
def pad_tensor(vec, pad, dim):
    """
    Zero-pad or truncate a tensor so that its size along one dimension is exactly `pad`.

    args:
        vec - tensor to pad
        pad - the size to pad to; tensors that are already longer along 'dim'
              are truncated to their first 'pad' entries
        dim - dimension to pad (negative indices are allowed)

    return:
        a new tensor (never a view of 'vec') whose size along 'dim' is 'pad'.
        Padding entries are zeros with the same dtype and device as 'vec'.
    """
    current_size = vec.size(dim)
    if pad > current_size:
        pad_size = list(vec.shape)
        pad_size[dim] = pad - current_size
        # new_zeros inherits vec's dtype and device, so this also works for
        # GPU tensors and does not silently change the dtype.
        return torch.cat([vec, vec.new_zeros(pad_size)], dim=dim)
    # Truncate with a pure-torch op. A numpy round-trip would fail for CUDA
    # tensors and for tensors that require grad. clone() keeps the
    # "new tensor" contract instead of returning a view.
    return vec.narrow(dim, 0, pad).clone()

In [35]:
# Small integer scratch tensor, presumably for trying out pad_tensor by hand
# (it is not used by the cells below).
x = torch.from_numpy(np.arange(1, 7).reshape(2, 3))

In [153]:
class DeepLearningDataset(Dataset):
    """
    Map-style dataset that turns recorded decisions into padded tensors.

    Each element of ``player_dataset`` is a dict with the keys
    ``action_history``, ``current_game_state``, ``possible_actions`` and
    ``chosen_action_index``. Per-action card lists and the game-state zone
    list are zero-padded (or truncated) to the configured maxima so that items
    can be stacked into batches. The number of possible actions is left
    variable and must be padded when batching.
    """

    def __init__(
        self,
        player_dataset: List[Dict],
        zone_vector_dim: int,
        max_n_zone_vectors: int,
        max_n_action_source_cards: int,
        max_n_action_target_cards: int
    ):
        """
        Args:
            player_dataset: list of decision records (see class docstring).
            zone_vector_dim: length of one zone/card vector.
            max_n_zone_vectors: rows the game-state zones are padded/truncated to.
            max_n_action_source_cards: rows each action's source cards are padded/truncated to.
            max_n_action_target_cards: rows each action's target cards are padded/truncated to.
        """
        super().__init__()
        self.player_dataset = player_dataset
        self.zone_vector_dim = zone_vector_dim
        self.max_n_zone_vectors = max_n_zone_vectors
        self.max_n_action_source_cards = max_n_action_source_cards
        self.max_n_action_target_cards = max_n_action_target_cards

    def __len__(self) -> int:
        return len(self.player_dataset)

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]:
        """
        Return ``(action_history_vectors, current_game_state_vectors,
        possible_actions_vectors, target_action)`` for decision ``idx``.

        ``target_action`` is a one-hot vector over the possible actions.
        """
        item_dict = self.player_dataset[idx]
        action_history = item_dict["action_history"]
        current_game_state = item_dict["current_game_state"]
        possible_actions = item_dict["possible_actions"]
        chosen_action_index = item_dict["chosen_action_index"]

        action_history_vectors = self.__get_action_history_vectors(action_history)
        current_game_state_vectors = self.__get_current_game_state_vectors(current_game_state)
        possible_actions_vectors = self.__get_possible_actions_vectors(possible_actions)
        target_action = self.__get_target_action(
            n_possible_actions=len(possible_actions),
            chosen_action_index=chosen_action_index
        )

        return (
            action_history_vectors,
            current_game_state_vectors,
            possible_actions_vectors,
            target_action
        )

    def __pad_card_vectors(self, card_vectors: np.ndarray, max_n_cards: int) -> torch.Tensor:
        """
        Convert a stack of card/zone vectors to a float32 tensor of shape
        (max_n_cards, zone_vector_dim), zero-padding or truncating rows.

        Empty inputs may be stored as a 1-D ``array([])`` that lacks the
        zone_vector_dim axis. The target shape cannot be derived from such an
        input, so it is built explicitly.
        """
        cards = torch.from_numpy(card_vectors)
        if cards.numel() == 0:
            return torch.zeros(size=(max_n_cards, self.zone_vector_dim))
        return pad_tensor(cards, pad=max_n_cards, dim=0).float()

    def __action_list_to_tensors(self, action_list: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        """
        Inputs:

        action_list:
        [{
          general: (action_dim,)
          source_card_vectors: (n_action_source_cards, zone_vector_dim)
          target_card_vectors: (n_action_target_cards, zone_vector_dim)
        }] * n_actions

        Return:

        action_list_vectors (all float32)
        - general: (n_actions, action_dim)
        - source_card_vectors: (n_actions, max_n_action_source_cards, zone_vector_dim)
        - target_card_vectors: (n_actions, max_n_action_target_cards, zone_vector_dim)
        """
        general_vectors = [torch.from_numpy(action_dict["general"]).float() for action_dict in action_list]
        source_card_vectors = [
            self.__pad_card_vectors(action_dict["source_card_vectors"], self.max_n_action_source_cards)
            for action_dict in action_list
        ]
        target_card_vectors = [
            self.__pad_card_vectors(action_dict["target_card_vectors"], self.max_n_action_target_cards)
            for action_dict in action_list
        ]

        return {
            "general": torch.stack(general_vectors, dim=0),
            "source_card_vectors": torch.stack(source_card_vectors, dim=0),
            "target_card_vectors": torch.stack(target_card_vectors, dim=0),
        }

    def __get_action_history_vectors(self, action_history: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        """
        Inputs:

        action_history:
        [{
          general: (action_dim,)
          source_card_vectors: (n_action_source_cards, zone_vector_dim)
          target_card_vectors: (n_action_target_cards, zone_vector_dim)
        }] * history_size

        Return:

        action_history_vectors
        - general: (history_size, action_dim)
        - source_card_vectors: (history_size, max_n_action_source_cards, zone_vector_dim)
        - target_card_vectors: (history_size, max_n_action_target_cards, zone_vector_dim)
        """
        return self.__action_list_to_tensors(action_list=action_history)

    def __get_current_game_state_vectors(self, current_game_state: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
        """
        Inputs:

        current_game_state (numpy arrays):
        {
            global: (global_dim,)
            players: (n_players, player_dim)
            zones: (n_zone_vectors, zone_vector_dim)
        }

        Return:

        current_game_state_vectors (float32 tensors):
        - global: (global_dim,)
        - players: (n_players, player_dim)
        - zones: (max_n_zone_vectors, zone_vector_dim)
        """
        return {
            "global": torch.from_numpy(current_game_state["global"]).float(),
            "players": torch.from_numpy(current_game_state["players"]).float(),
            "zones": self.__pad_card_vectors(current_game_state["zones"], self.max_n_zone_vectors),
        }

    def __get_possible_actions_vectors(self, possible_actions: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        """
        Inputs:

        possible_actions:
        [{
          general: (action_dim,)
          source_card_vectors: (n_action_source_cards, zone_vector_dim)
          target_card_vectors: (n_action_target_cards, zone_vector_dim)
        }] * n_possible_actions

        Return:

        possible_actions_vectors:
        - general: (n_possible_actions, action_dim)
        - source_card_vectors: (n_possible_actions, max_n_action_source_cards, zone_vector_dim)
        - target_card_vectors: (n_possible_actions, max_n_action_target_cards, zone_vector_dim)
        """
        return self.__action_list_to_tensors(action_list=possible_actions)

    def __get_target_action(self, n_possible_actions: int, chosen_action_index: int) -> torch.Tensor:
        """
        Return:

        target_action: (n_possible_actions,) one-hot float vector with a 1 at chosen_action_index
        """
        target_action = torch.zeros(n_possible_actions).float()
        target_action[chosen_action_index] = 1
        return target_action

In [154]:
# Padding/truncation limits for DeepLearningDataset: zone rows per game state,
# and source / target card rows per action.
max_n_zone_vectors, max_n_action_source_cards, max_n_action_target_cards = 120, 10, 10

In [155]:
# Tensorize this game's decisions using the max_n_* padding limits.
deep_learning_dataset = DeepLearningDataset(
    player_dataset=dataset,
    zone_vector_dim=zone_vector_dim,
    max_n_zone_vectors=max_n_zone_vectors,
    max_n_action_source_cards=max_n_action_source_cards,
    max_n_action_target_cards=max_n_action_target_cards
)

In [156]:
len(deep_learning_dataset)

183

In [157]:
# Tensorize decision 44 (the record inspected earlier) to check the output shapes.
action_history_vectors, current_game_state_vectors, possible_actions_vectors, target_action = deep_learning_dataset[44]

In [158]:
action_history_vectors.keys()

dict_keys(['general', 'source_card_vectors', 'target_card_vectors'])

In [159]:
action_history_vectors["general"].shape

torch.Size([10, 31])

In [160]:
action_history_vectors["source_card_vectors"].shape

torch.Size([10, 10, 34])

In [161]:
action_history_vectors["target_card_vectors"].shape

torch.Size([10, 10, 34])

In [162]:
current_game_state_vectors.keys()

dict_keys(['global', 'players', 'zones'])

In [163]:
current_game_state_vectors["global"].shape

torch.Size([2])

In [164]:
current_game_state_vectors["players"].shape

torch.Size([2, 8])

In [165]:
current_game_state_vectors["zones"].shape

torch.Size([120, 34])

In [166]:
possible_actions_vectors.keys()

dict_keys(['general', 'source_card_vectors', 'target_card_vectors'])

In [167]:
possible_actions_vectors["general"].shape

torch.Size([3, 31])

In [168]:
possible_actions_vectors["source_card_vectors"].shape

torch.Size([3, 10, 34])

In [169]:
possible_actions_vectors["target_card_vectors"].shape

torch.Size([3, 10, 34])

In [170]:
target_action

tensor([0., 0., 1.])

In [171]:
possible_actions

[{'source_player_index': 0,
  'general': array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
  'source_card_vectors': array([], dtype=float32),
  'target_card_vectors': array([], dtype=float32)},
 {'source_player_index': 0,
  'general': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
  'source_card_vectors': array([], dtype=float32),
  'target_card_vectors': array([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 2., 3., 2., 0.,
          0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 1., 0.,
          0., 0.]], dtype=float32)},
 {'source_player_index': 0,
  'general': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
  'source_card_vectors': array([], dtype=float32),
  'target_card_vectors': array([[1., 0

In [172]:
def pad_possible_actions_collate_fn(
    samples: List[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]:
    """
    Collate ``(action_history, game_state, possible_actions, target)`` samples into one batch.

    History and game-state tensors are stacked on a new leading batch axis as-is.
    Possible-action tensors and the target vector have one row per possible
    action, and that count varies between samples. They are therefore
    zero-padded along their first axis to the largest count in the batch
    before stacking.

    Returns the same 4-tuple layout with a leading batch dimension on every tensor.
    """
    history_dicts = [sample[0] for sample in samples]
    state_dicts = [sample[1] for sample in samples]
    possible_dicts = [sample[2] for sample in samples]

    largest_n_actions = max(possible["general"].shape[0] for possible in possible_dicts)

    def pad_actions(tensor: torch.Tensor) -> torch.Tensor:
        return pad_tensor(tensor, pad=largest_n_actions, dim=0)

    def stack_per_key(dicts, prepare=None):
        # Group tensors by key in first-seen order, then stack each group.
        grouped: Dict[str, List[torch.Tensor]] = {}
        for tensors_by_key in dicts:
            for key, tensor in tensors_by_key.items():
                grouped.setdefault(key, []).append(tensor if prepare is None else prepare(tensor))
        return {key: torch.stack(group, dim=0) for key, group in grouped.items()}

    batch_target_action = torch.stack([pad_actions(sample[3]) for sample in samples], dim=0)

    return (
        stack_per_key(history_dicts),
        stack_per_key(state_dicts),
        stack_per_key(possible_dicts, prepare=pad_actions),
        batch_target_action
    )

In [173]:
# Batch decisions; the collate fn pads the variable number of possible actions.
# shuffle defaults to False, so batches follow dataset order.
dataloader = torch.utils.data.DataLoader(
    deep_learning_dataset,
    batch_size=5,
    collate_fn=pad_possible_actions_collate_fn
)

In [174]:
# Take the first batch (dataset items 0-4, since the loader does not shuffle).
batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors, batch_target_action = next(iter(dataloader))

In [175]:
batch_action_history_vectors.keys()

dict_keys(['general', 'source_card_vectors', 'target_card_vectors'])

In [176]:
batch_action_history_vectors["general"].shape

torch.Size([5, 10, 31])

In [177]:
batch_action_history_vectors["source_card_vectors"].shape

torch.Size([5, 10, 10, 34])

In [178]:
batch_action_history_vectors["target_card_vectors"].shape

torch.Size([5, 10, 10, 34])

In [179]:
batch_current_game_state_vectors.keys()

dict_keys(['global', 'players', 'zones'])

In [180]:
batch_current_game_state_vectors["global"].shape

torch.Size([5, 2])

In [181]:
batch_current_game_state_vectors["players"].shape

torch.Size([5, 2, 8])

In [182]:
batch_current_game_state_vectors["zones"].shape

torch.Size([5, 120, 34])

In [183]:
batch_possible_actions_vectors["general"].shape

torch.Size([5, 5, 31])

In [184]:
batch_possible_actions_vectors["source_card_vectors"].shape

torch.Size([5, 5, 10, 34])

In [185]:
batch_possible_actions_vectors["target_card_vectors"].shape

torch.Size([5, 5, 10, 34])

In [186]:
batch_target_action

tensor([[0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.]])

# Implement deep learning model

In [187]:
class BaseDeepLearningScorer(LightningModule):
    """
    Lightning base class for models that score the possible actions of a decision.

    Subclasses implement ``forward``. This class provides the BCE loss, the
    training/validation steps, prediction and the optimizer. Batches are the
    4-tuples ``(action_history, game_state, possible_actions, target_action)``
    built by the collate function.
    """

    def __init__(self):
        super().__init__()
        # BCELoss expects probabilities in [0, 1], so forward must return them.
        self.loss = torch.nn.BCELoss()

    # NOTE(review): unless LightningModule uses ABCMeta, @abstractmethod does
    # not block instantiation; the NotImplementedError below is the real guard.
    @abstractmethod
    def forward(self, batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors):
        """
        Inputs:

        batch_action_history_vectors:
        {
            general: (batch_size, history_size, action_general_dim)
            source_card_vectors: (batch_size, history_size, max_n_action_source_cards, zone_vector_dim)
            target_card_vectors: (batch_size, history_size, max_n_action_target_cards, zone_vector_dim)
        }

        batch_current_game_state_vectors:
        {
            global: (batch_size, game_state_global_dim)
            players: (batch_size, n_players, player_dim)
            zones: (batch_size, max_n_zone_vectors, zone_vector_dim)
        }

        batch_possible_actions_vectors:
        {
            general: (batch_size, max_n_possible_actions_in_batch, action_general_dim)
            source_card_vectors: (batch_size, max_n_possible_actions_in_batch, max_n_action_source_cards, zone_vector_dim)
            target_card_vectors: (batch_size, max_n_possible_actions_in_batch, max_n_action_target_cards, zone_vector_dim)
        }

        Returns:
        - batch_predicted_target_action: (batch_size, max_n_possible_actions_in_batch), values in [0, 1]
        """
        raise NotImplementedError

    def __step(self, batch, batch_idx, base_metric_name):
        """Compute and log the loss of one batch; shared by training and validation."""
        batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors, batch_target_action = batch
        batch_predicted_target_action = self.forward(
            batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors
        )
        # FIXME: Multiply with a mask. Zero-padded possible actions (added when
        # batching) currently contribute to the loss.
        batch_loss = self.loss(batch_predicted_target_action, batch_target_action)
        self.log(f"{base_metric_name}_loss", batch_loss, on_epoch=True)
        return batch_loss

    def training_step(self, batch, batch_idx):
        return self.__step(batch=batch, batch_idx=batch_idx, base_metric_name="training")

    def validation_step(self, batch, batch_idx):
        return self.__step(batch=batch, batch_idx=batch_idx, base_metric_name="validation")

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Score the possible actions of a batch.

        ``batch`` has the same layout as in training. A trailing target tensor,
        if present, is ignored, so the training dataloader can also be used for
        prediction.

        Returns (batch_size, max_n_possible_actions_in_batch) scores.
        """
        batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors = batch[:3]
        batch_predicted_scores = self.forward(
            batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors
        )
        return batch_predicted_scores

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    def get_n_parameters(self):
        """Total number of parameter elements in the model."""
        return sum(p.numel() for p in self.parameters())

In [241]:
class ActionProcessingBlock(torch.nn.Module):
    """
    Encode a single action into one embedding vector.

    The action becomes the token sequence
    [prediction token, general token, source-card tokens..., target-card tokens...].
    Each non-prediction token is an MLP projection with ``output_dim - 3``
    features, followed by a 3-dim one-hot type flag (general / source card /
    target card). The prediction token is all zeros. After the Transformer
    encoder, the output at the prediction token is returned.

    NOTE(review): no key-padding mask is passed to the encoder, so zero-padded
    card rows are attended to like real cards.
    """

    def __init__(
        self,
        action_general_dim: int,
        max_n_action_source_cards: int,
        max_n_action_target_cards: int,
        zone_vector_dim: int,
        output_dim: int,
        transformer_n_layers: int = 1,
        transformer_n_heads: int = 1,
        transformer_dim_feedforward: int = 128,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.action_general_dim = action_general_dim
        self.max_n_action_source_cards = max_n_action_source_cards
        self.max_n_action_target_cards = max_n_action_target_cards
        self.zone_vector_dim = zone_vector_dim
        # 3 of the output features are reserved for the token-type flag.
        assert output_dim > 3
        self.output_dim = output_dim
        self.transformer_n_layers = transformer_n_layers
        self.transformer_n_heads = transformer_n_heads
        self.transformer_dim_feedforward = transformer_dim_feedforward
        self.dropout = dropout

        # Modules
        self.general_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.action_general_dim, out_features=self.output_dim - 3),
            torch.nn.ReLU()
        )
        # Shared by source and target cards; the type flag tells them apart.
        self.card_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.zone_vector_dim, out_features=self.output_dim - 3),
            torch.nn.ReLU()
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(
                d_model=self.output_dim,
                nhead=self.transformer_n_heads,
                dim_feedforward=self.transformer_dim_feedforward,
                dropout=self.dropout,
                activation="relu",
                batch_first=True
            ),
            num_layers=self.transformer_n_layers
        )

    def forward(self, action_vectors: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Inputs:
        - action_vectors:
        {
            general: (batch_size, action_general_dim)
            source_card_vectors: (batch_size, n_action_source_cards, zone_vector_dim)
            target_card_vectors: (batch_size, n_action_target_cards, zone_vector_dim)
        }
        The card counts are normally max_n_action_source_cards /
        max_n_action_target_cards, but any count is accepted.

        Outputs:
        - action_embedding: (batch_size, output_dim)
        """
        action_general_embedding = self.__get_action_general_embedding(action_vectors["general"])
        action_source_card_embeddings = self.__get_action_source_card_embeddings(action_vectors["source_card_vectors"])
        action_target_card_embeddings = self.__get_action_target_card_embeddings(action_vectors["target_card_vectors"])

        batch_size = action_general_embedding.shape[0]
        # Create the prediction token from an existing embedding so it has the
        # same device and dtype (a plain torch.zeros would live on the CPU).
        action_embedding_for_prediction = action_general_embedding.new_zeros((batch_size, 1, self.output_dim))

        action_embeddings_sequence = torch.cat(
            [
                action_embedding_for_prediction,
                action_general_embedding[:, None],
                action_source_card_embeddings,
                action_target_card_embeddings
            ],
            dim=1
        )

        action_embeddings_sequence_after_transformer = self.transformer_encoder(action_embeddings_sequence)

        return action_embeddings_sequence_after_transformer[:, 0, :]

    @staticmethod
    def _with_type_flag(embeddings: torch.Tensor, type_index: int) -> torch.Tensor:
        """Append a 3-dim one-hot token-type flag (1 at ``type_index``) along the last axis."""
        # Shape, device and dtype follow the embeddings, so any number of
        # tokens and any device work.
        flag = embeddings.new_zeros((*embeddings.shape[:-1], 3))
        flag[..., type_index] = 1.0
        return torch.cat([embeddings, flag], dim=-1)

    def __get_action_general_embedding(self, action_general_vector: torch.Tensor) -> torch.Tensor:
        """(batch_size, action_general_dim) -> (batch_size, output_dim), flag [1, 0, 0]."""
        return self._with_type_flag(self.general_mlp(action_general_vector), type_index=0)

    def __get_action_source_card_embeddings(self, action_source_card_vectors: torch.Tensor) -> torch.Tensor:
        """(batch_size, n_cards, zone_vector_dim) -> (batch_size, n_cards, output_dim), flag [0, 1, 0]."""
        return self._with_type_flag(self.card_mlp(action_source_card_vectors), type_index=1)

    def __get_action_target_card_embeddings(self, action_target_card_vectors: torch.Tensor) -> torch.Tensor:
        """(batch_size, n_cards, zone_vector_dim) -> (batch_size, n_cards, output_dim), flag [0, 0, 1]."""
        return self._with_type_flag(self.card_mlp(action_target_card_vectors), type_index=2)


class GameStateProcessingBlock(torch.nn.Module):
    """
    Encode a game state (global vector, player vectors, zone vectors) into one embedding.

    The state becomes the token sequence
    [prediction token, global token, player tokens..., zone tokens...].
    Each non-prediction token is an MLP projection with ``output_dim - 3``
    features, followed by a 3-dim one-hot type flag (global / player / zone).
    The prediction token is all zeros. After the Transformer encoder, the
    output at the prediction token is returned.

    NOTE(review): no key-padding mask is passed to the encoder, so zero-padded
    zone rows are attended to like real zones.
    """

    def __init__(
        self,
        game_state_global_dim: int,
        n_players: int,
        player_dim: int,
        max_n_zone_vectors: int,
        zone_vector_dim: int,
        output_dim: int,
        transformer_n_layers: int = 1,
        transformer_n_heads: int = 1,
        transformer_dim_feedforward: int = 128,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.game_state_global_dim = game_state_global_dim
        self.n_players = n_players
        self.player_dim = player_dim
        self.max_n_zone_vectors = max_n_zone_vectors
        self.zone_vector_dim = zone_vector_dim
        # 3 of the output features are reserved for the token-type flag.
        assert output_dim > 3
        self.output_dim = output_dim
        self.transformer_n_layers = transformer_n_layers
        self.transformer_n_heads = transformer_n_heads
        self.transformer_dim_feedforward = transformer_dim_feedforward
        self.dropout = dropout

        # Modules
        self.global_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.game_state_global_dim, out_features=self.output_dim - 3),
            torch.nn.ReLU()
        )
        self.player_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.player_dim, out_features=self.output_dim - 3),
            torch.nn.ReLU()
        )
        self.zone_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.zone_vector_dim, out_features=self.output_dim - 3),
            torch.nn.ReLU()
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(
                d_model=self.output_dim,
                nhead=self.transformer_n_heads,
                dim_feedforward=self.transformer_dim_feedforward,
                dropout=self.dropout,
                activation="relu",
                batch_first=True
            ),
            num_layers=self.transformer_n_layers
        )

    def forward(self, game_state_vectors: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Inputs:
        - game_state_vectors:
        {
            global: (batch_size, game_state_global_dim)
            players: (batch_size, n_players, player_dim)
            zones: (batch_size, n_zone_vectors, zone_vector_dim)
        }
        n_zone_vectors is normally max_n_zone_vectors, but any count is accepted.

        Outputs:
        - game_state_embedding: (batch_size, output_dim)
        """
        global_embedding = self.__get_global_embedding(game_state_vectors["global"])
        player_embeddings = self.__get_player_embeddings(game_state_vectors["players"])
        zone_embeddings = self.__get_zone_embeddings(game_state_vectors["zones"])

        batch_size = global_embedding.shape[0]
        # Create the prediction token from an existing embedding so it has the
        # same device and dtype (a plain torch.zeros would live on the CPU).
        embedding_for_prediction = global_embedding.new_zeros((batch_size, 1, self.output_dim))

        embeddings_sequence = torch.cat(
            [
                embedding_for_prediction,
                global_embedding[:, None],
                player_embeddings,
                zone_embeddings
            ],
            dim=1
        )

        embeddings_sequence_after_transformer = self.transformer_encoder(embeddings_sequence)

        return embeddings_sequence_after_transformer[:, 0, :]

    @staticmethod
    def _with_type_flag(embeddings: torch.Tensor, type_index: int) -> torch.Tensor:
        """Append a 3-dim one-hot token-type flag (1 at ``type_index``) along the last axis."""
        # Shape, device and dtype follow the embeddings, so any number of
        # tokens and any device work.
        flag = embeddings.new_zeros((*embeddings.shape[:-1], 3))
        flag[..., type_index] = 1.0
        return torch.cat([embeddings, flag], dim=-1)

    def __get_global_embedding(self, game_state_global_vector: torch.Tensor) -> torch.Tensor:
        """(batch_size, game_state_global_dim) -> (batch_size, output_dim), flag [1, 0, 0]."""
        return self._with_type_flag(self.global_mlp(game_state_global_vector), type_index=0)

    def __get_player_embeddings(self, game_state_player_vectors: torch.Tensor) -> torch.Tensor:
        """(batch_size, n_players, player_dim) -> (batch_size, n_players, output_dim), flag [0, 1, 0]."""
        return self._with_type_flag(self.player_mlp(game_state_player_vectors), type_index=1)

    def __get_zone_embeddings(self, game_state_zone_vectors: torch.Tensor) -> torch.Tensor:
        """(batch_size, n_zones, zone_vector_dim) -> (batch_size, n_zones, output_dim), flag [0, 0, 1]."""
        return self._with_type_flag(self.zone_mlp(game_state_zone_vectors), type_index=2)




class DeepLearningScorerV1(BaseDeepLearningScorer):
    """
    First scorer model.

    Every action (from the history and among the possible actions) is embedded
    with a shared ActionProcessingBlock. Each possible action is then scored
    from [its embedding, mean of the history embeddings] with a linear head
    followed by a sigmoid, giving probabilities suitable for BCELoss.

    NOTE(review): the current game state is accepted but not used yet.
    """

    def __init__(
        self,
        action_general_dim: int,
        max_n_action_source_cards: int,
        max_n_action_target_cards: int,
        zone_vector_dim: int,
        embedding_size: int,
    ):
        super().__init__()
        self.action_general_dim = action_general_dim
        self.max_n_action_source_cards = max_n_action_source_cards
        self.max_n_action_target_cards = max_n_action_target_cards
        self.zone_vector_dim = zone_vector_dim
        self.embedding_size = embedding_size

        # Modules
        self.action_processing_block = ActionProcessingBlock(
            action_general_dim=action_general_dim,
            max_n_action_source_cards=max_n_action_source_cards,
            max_n_action_target_cards=max_n_action_target_cards,
            zone_vector_dim=zone_vector_dim,
            output_dim=self.embedding_size,
            transformer_n_layers=1,
            transformer_n_heads=1,
            transformer_dim_feedforward=128,
        )
        # Maps [possible-action embedding, history summary] to one logit.
        self.scoring_head = torch.nn.Linear(in_features=2 * self.embedding_size, out_features=1)

    def forward(self, batch_action_history_vectors, batch_current_game_state_vectors, batch_possible_actions_vectors):
        """
        See BaseDeepLearningScorer.forward for the input layout.

        Returns:
        - batch_predicted_target_action: (batch_size, max_n_possible_actions_in_batch),
          sigmoid probabilities
        """
        # (batch_size, history_size, embedding_size)
        batch_action_history_embeddings = self.__process_action_list(
            batch_action_list_vectors=batch_action_history_vectors
        )
        # (batch_size, n_possible_actions, embedding_size)
        batch_possible_actions_embeddings = self.__process_action_list(
            batch_action_list_vectors=batch_possible_actions_vectors
        )

        history_summary = batch_action_history_embeddings.mean(dim=1)
        n_possible_actions = batch_possible_actions_embeddings.shape[1]
        scoring_input = torch.cat(
            [
                batch_possible_actions_embeddings,
                history_summary[:, None].expand(-1, n_possible_actions, -1),
            ],
            dim=2
        )
        batch_logits = self.scoring_head(scoring_input).squeeze(-1)
        return torch.sigmoid(batch_logits)

    def __process_action_list(self, batch_action_list_vectors: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Embed every action of every batch element.

        Inputs:
        - batch_action_list_vectors:
        {
            general: (batch_size, n_actions, action_general_dim)
            source_card_vectors: (batch_size, n_actions, n_source_cards, zone_vector_dim)
            target_card_vectors: (batch_size, n_actions, n_target_cards, zone_vector_dim)
        }

        Returns:
        - (batch_size, n_actions, embedding_size)
        """
        batch_size, n_actions = batch_action_list_vectors["general"].shape[:2]
        # Fold the action axis into the batch axis. The block encodes each
        # sequence independently, so one call replaces a per-action loop.
        flat_action_vectors = {
            key: tensor.reshape(batch_size * n_actions, *tensor.shape[2:])
            for key, tensor in batch_action_list_vectors.items()
        }
        flat_action_embeddings = self.action_processing_block(flat_action_vectors)
        return flat_action_embeddings.reshape(batch_size, n_actions, -1)

In [242]:
# Try the game-state encoder on its own, with a 64-dim output embedding.
game_state_processing_block = GameStateProcessingBlock(
    game_state_global_dim=game_state_global_dim,
    n_players=n_players,
    player_dim=player_dim,
    max_n_zone_vectors=max_n_zone_vectors,
    zone_vector_dim=zone_vector_dim,
    output_dim=64,
)

In [243]:
# Encode the first batch's game states into one 64-dim embedding per item.
game_state_embedding = game_state_processing_block(batch_current_game_state_vectors)

In [245]:
game_state_embedding.shape

torch.Size([5, 64])