In [1]:
import sys
import h5py
import torch
import mlflow
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import Trainer

sys.path.append("../")

from magic_the_gathering.players.deep_learning_based.models.v1 import DeepLearningScorerV1

# Load H5 data

In [2]:
h5_file_path = "../data/preprocessed_game_logs.h5"

In [3]:
h5_file = h5py.File(h5_file_path, "r")

In [4]:
h5_file.keys()

<KeysViewHDF5 ['action', 'game_id', 'game_state', 'label']>

In [5]:
n_players = h5_file["game_state"]["players"].shape[1]
player_dim = h5_file["game_state"]["players"].shape[-1]
card_dim = h5_file["game_state"]["zones"].shape[-1]
action_general_dim = h5_file["action"]["general"].shape[-1]

In [6]:
print(f"n_players: {n_players}")
print(f"player_dim: {player_dim}")
print(f"card_dim: {card_dim}")
print(f"action_general_dim: {action_general_dim}")

n_players: 2
player_dim: 8
card_dim: 34
action_general_dim: 16


In [7]:
mlflow.log_param("n_players", n_players)
mlflow.log_param("player_dim", player_dim)
mlflow.log_param("card_dim", card_dim)
mlflow.log_param("action_general_dim", action_general_dim)

16

In [29]:
print(f"Number of label -1: {np.sum(h5_file['label'][:] == -1)}")
print(f"Number of label 0: {np.sum(h5_file['label'][:] == 0)}")
print(f"Number of label 1: {np.sum(h5_file['label'][:] == 1)}")

Number of label -1: 0
Number of label 0: 656
Number of label 1: 1106


# Define model

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
final_common_dim = 32

In [10]:
mlflow.log_param("final_common_dim", final_common_dim)

32

In [11]:
mlflow.pytorch.autolog()



In [12]:
model = DeepLearningScorerV1(
    n_players=n_players,
    player_dim=player_dim,
    card_dim=card_dim,
    action_general_dim=action_general_dim,
    final_common_dim=final_common_dim
).to(device)

In [13]:
model

DeepLearningScorerV1(
  (loss): BCELoss()
  (players_mlp): PlayersMLP(
    (fc1): Linear(in_features=16, out_features=32, bias=True)
    (relu): ReLU()
  )
  (zones_transformer_encoder): ZonesTransformerEncoder(
    (initial_fc): Linear(in_features=34, out_features=32, bias=True)
    (relu): ReLU()
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
          )
          (linear1): Linear(in_features=32, out_features=32, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=32, out_features=32, bias=True)
          (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1

# Define datasets

In [14]:
class MTGDataset(torch.utils.data.Dataset):
    def __init__(self, h5_file_path, device, return_labels=True):
        self.h5_file_path = h5_file_path
        self.h5_file = h5py.File(self.h5_file_path, "r")
        self.return_labels = return_labels
        self.device = device
        self.indices_with_a_label = np.where(self.h5_file["label"][:] != -1)[0]

    def __len__(self):
        return len(self.indices_with_a_label)

    def __getitem__(self, idx):
        idx = self.indices_with_a_label[idx]
        batch_game_state_vectors = {
            "global": torch.from_numpy(self.h5_file["game_state"]["global"][idx]).to(self.device),
            "players": torch.from_numpy(self.h5_file["game_state"]["players"][idx]).to(self.device),
            "zones": torch.from_numpy(self.h5_file["game_state"]["zones"][idx]).to(self.device),
            "zones_padding_mask": torch.from_numpy(self.h5_file["game_state"]["zones_padding_mask"][idx]).to(self.device),
        }
        batch_action_vectors = {
            "general": torch.from_numpy(self.h5_file["action"]["general"][idx]).to(self.device),
            "source_card_uuids": torch.from_numpy(self.h5_file["action"]["source_card_uuids"][idx]).to(self.device),
            "source_card_uuids_padding_mask": torch.from_numpy(self.h5_file["action"]["source_card_uuids_padding_mask"][idx]).to(self.device),
            "target_card_uuids": torch.from_numpy(self.h5_file["action"]["target_card_uuids"][idx]).to(self.device),
            "target_card_uuids_padding_mask": torch.from_numpy(self.h5_file["action"]["target_card_uuids_padding_mask"][idx]).to(self.device),
        }
        if self.return_labels:
            batch_labels = torch.from_numpy(self.h5_file["label"][[idx]]).float().to(self.device)
            return batch_game_state_vectors, batch_action_vectors, batch_labels
        return batch_game_state_vectors, batch_action_vectors

In [15]:
dataset = MTGDataset(
    h5_file_path=h5_file_path,
    return_labels=True,
    device=device
)

In [16]:
dataset.indices_with_a_label

array([   0,    1,    2, ..., 1759, 1760, 1761])

In [17]:
dataset[100]

({'global': tensor([12.,  1.], device='cuda:0'),
  'players': tensor([[20.,  0.,  0.,  0.,  0.,  0.,  1.,  0.],
          [18.,  0.,  0.,  3.,  0.,  0.,  1.,  1.]], device='cuda:0'),
  'zones': tensor([[0., 1., 0.,  ..., 0., 0., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
  'zones_padding_mask': tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True,  True, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False], device='cuda:0')},
 {'general': tensor([0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         device='cuda:0'),
  

In [18]:
training_dataset, validation_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])

# Define data loaders

In [19]:
batch_size = 10

In [20]:
training_data_loader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

In [21]:
validation_data_loader = torch.utils.data.DataLoader(
    validation_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=0,
)

# Define Trainer

In [22]:
n_epochs = 100

In [23]:
callbacks = [
    ModelCheckpoint(
        dirpath="models/",
        filename="deep_learning_scorer_v1",
        monitor="validation_loss",
        mode="min",
        save_top_k=1,
        verbose=False,
    ),
    EarlyStopping(
        mode="min",
        verbose=False,
        monitor="validation_loss",
        patience=10,
    ),
]

In [24]:
trainer = Trainer(
    max_epochs=n_epochs,
    devices="auto",
    deterministic=True,
    callbacks=callbacks,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Start training

In [25]:
trainer.fit(
    model=model,
    train_dataloaders=training_data_loader,
    val_dataloaders=validation_data_loader,
)

You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                       | Type                     | Params
------------------------------------------------------------------------
0 | loss                       | BCELoss                  | 0     
1 | players_mlp                | PlayersMLP               | 544   
2 | zones_transformer_encoder  | ZonesTransformerEncoder  | 14.0 K
3 | action_general_mlp         | ActionGeneralMLP         | 493   
4 | action_card_mlp            | ActionCardMLP            | 1.0 K 
5 | action_transformer_encoder | ActionTr

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
