In [1]:
import sys

sys.path.append("..")

import torch
import pickle
from magic_the_gathering.players.deep_learning_based.models.v1 import DeepLearningScorerV1

In [2]:
pickle_path = "../logs/game_382bda50-8299-40fe-b5e8-bfb4e49014b1.pickle"

In [3]:
with open(pickle_path, "rb") as f:
    pickle_dict = pickle.load(f)

In [4]:
len(pickle_dict["dataset"])

1230

In [5]:
indices = [250, 251, 252]

In [6]:
game_states = [pickle_dict["dataset"][index]["game_state"] for index in indices]
actions = [pickle_dict["dataset"][index]["action"] for index in indices]

In [7]:
game_state_vectors_numpy = [game_state.to_vectors() for game_state in game_states]

In [8]:
n_players = game_state_vectors_numpy[0]["players"].shape[0]
player_dim = game_state_vectors_numpy[0]["players"].shape[1]
card_dim = game_state_vectors_numpy[0]["zones"].shape[1]

In [32]:
print(n_players, player_dim, card_dim)

2 8 34


In [9]:
action_vectors_numpy = [action.to_vectors(game_state=game_state) for action, game_state in zip(actions, game_states)]

In [10]:
action_general_dim = action_vectors_numpy[0]["general"].shape[0]

In [31]:
print(action_general_dim)

16


In [11]:
batch_game_state_vectors_torch = {
    "global": torch.tensor([game_state_vector["global"] for game_state_vector in game_state_vectors_numpy]).float(),
    "players": torch.tensor([game_state_vector["players"] for game_state_vector in game_state_vectors_numpy]).float(),
    "zones": torch.tensor([game_state_vector["zones"] for game_state_vector in game_state_vectors_numpy]).float(),
}

  "global": torch.tensor([game_state_vector["global"] for game_state_vector in game_state_vectors_numpy]).float(),


In [12]:
batch_game_state_vectors_torch["global"].shape, batch_game_state_vectors_torch["players"].shape, batch_game_state_vectors_torch["zones"].shape

(torch.Size([3, 2]), torch.Size([3, 2, 8]), torch.Size([3, 120, 34]))

In [13]:
batch_action_vectors_torch = {
    "general": torch.tensor([action_vector["general"] for action_vector in action_vectors_numpy]).float(),
    "source_card_uuids": torch.tensor([action_vector["source_card_uuids"] for action_vector in action_vectors_numpy]).bool(),
    "target_card_uuids": torch.tensor([action_vector["target_card_uuids"] for action_vector in action_vectors_numpy]).bool(),
}

In [14]:
batch_action_vectors_torch["general"].shape, batch_action_vectors_torch["source_card_uuids"].shape, batch_action_vectors_torch["target_card_uuids"].shape

(torch.Size([3, 16]), torch.Size([3, 120]), torch.Size([3, 120]))

In [15]:
max_n_cards = 128
max_action_sequence_length = 16
final_common_dim = 32

In [16]:
scorer = DeepLearningScorerV1(
    n_players=n_players,
    player_dim=player_dim,
    card_dim=card_dim,
    action_general_dim=action_general_dim,
    max_n_cards=max_n_cards,
    max_action_sequence_length=max_action_sequence_length,
    final_common_dim=final_common_dim,
)

In [19]:
scorer.eval()

DeepLearningScorerV1(
  (loss): BCELoss()
  (players_mlp): PlayersMLP(
    (fc1): Linear(in_features=16, out_features=32, bias=True)
    (relu): ReLU()
  )
  (zones_transformer_encoder): ZonesTransformerEncoder(
    (initial_fc): Linear(in_features=34, out_features=32, bias=True)
    (relu): ReLU()
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
          )
          (linear1): Linear(in_features=32, out_features=32, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=32, out_features=32, bias=True)
          (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1

In [20]:
scorer.get_n_parameters()

29128

In [28]:
score = scorer.forward(
    batch_game_state_vectors=batch_game_state_vectors_torch,
    batch_action_vectors=batch_action_vectors_torch
).cpu().detach().numpy()[0][0]