In [32]:
# %matplotlib ipympl

import json
import matplotlib.pyplot as plt
import numpy as np
import os
import glob
import torch
from torch import nn
import networkx as nx
import matplotlib.patches as mpatches
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split

In [None]:
def add_self_edges(edges):
    n_nodes = torch.max(edges) + 1
    all_nodes = torch.arange(n_nodes)
    start_node = edges[0]
    end_node = edges[1]
    has_self_edge = torch.unique(start_node[start_node == end_node])
    
    has_self_edge_map = torch.zeros(n_nodes, dtype=np.bool)
    has_self_edge_map[has_self_edge] = True
    nodes_without_self_edge = all_nodes[~has_self_edge_map]
    
    self_edges = torch.vstack([nodes_without_self_edge, nodes_without_self_edge])
    return torch.hstack([edges, self_edges])



# adding bias in intermediate layers is an unnecessary complication
# same shift is a message will be reduced by softmax exponent ratio

class GATMessage(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.transform_fn = nn.Linear(in_features, out_features, bias=False)
    
    def forward(self, x):
        x = self.transform_fn(x)
        return x
    

class GATAttentionCoefficients(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.transform_fn = nn.Linear(n_features, 1, bias=False)
        self.activation_fn = nn.LeakyReLU()

    def forward(self, x):
        # x has dimensions [b, n_edges, n_features]
        x = self.transform_fn(x)
        x = self.activation_fn(x)
        return x
    

class GATLayer(nn.Module):
    def __init__(self, in_features, out_features):
        # edges = 2 x |E| numpy array
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.message_fn = GATMessage(in_features, out_features)
        self.attention_coef_fn = GATAttentionCoefficients(2 * out_features)
        self.activation_fn = nn.ReLU()
    

    def forward(self, in_states, edges):
        # in_states has dimensions [b, n_nodes, in_features]
        n_nodes = np.max(edges) + 1    
        messages = self.message_fn(in_states)
        # messages has dimensions [b, n_edges, out_features]
        m_from = messages[:, edges[0], :]
        m_to = messages[:, edges[1], :]
        m_stack = torch.cat(m_from, m_to, dim=2)
        # m_stack has dimensions [b, n_edges, 2 * out_features]
        attention_coef = self.attention_coef_fn(m_stack)
        # attention_coef has dimensions [b, n_edges]
        
        merged_msgs = torch.empty(in_states.shape[0], n_nodes, self.out_features)

        for i_node in range(n_nodes):
            in_edge_map_idx = edges[1] == i_node
            neighbors = edges[0][in_edge_map_idx]
            neighb_msgs = messages[:, neighbors, :]
            neighb_att = torch.softmax(attention_coef[:, in_edge_map_idx])
            # neighb_msgs - [b, n_nbh, out_f], neighb_att - [b, n_nbh, 1]
            # broadcast attention coefficients over the last dim = same over all instance features
            merged_msg = torch.sum(neighb_msgs * neighb_att, dim=1)
            # merged_msg = - [b, out_f],
            merged_msgs[:, i_node, :] = merged_msg
        
        # transformations correspond to and learn the residual value of the layer
        # new_states - in_states = "what layer learned"
        new_states_residual = self.activation_fn(merged_msgs)
        new_states = in_states + new_states_residual
        return new_states

In [None]:
class GATGraphSummary(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.attention_coef_fn = GATAttentionCoefficients(in_features)
        self.summary_fn = nn.Linear(in_features, out_features)
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, in_states, graph_ids):
        # graph_ids - [n_nodes, 1]
        # in_states - [n_nodes, in_features], graph_ids - n_nodes
        att_coeffs = self.attention_coef_fn(in_states)
        # att_coeffs - [n_nodes, 1]
        _, ord_ids_map = torch.unique(graph_ids, sorted=True, return_inverse=True)

        att = torch.empty_like(att_coeffs)
        for id in 
        att = torch.softmax(att_coeffs, dim=1)

        n_graphs = len(ord_ids_map)
        graph_sum = torch.zeros(
            in_states.shape[0], n_graphs, self.in_features,
            dtype=torch.float32
        )
        graph_sum.index_add_(1, ord_ids_map, in_states_scaled)


        merged_state = torch.sum(in_states * att, dim=1)
        # merged_state - [b, n_graphs, in_features]
        summary = self.summary_fn(merged_state)
        # merged_state - [b, n_graphs, out_features]
        return summary
    

class GATEdgeValue(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.edge_value_fn = nn.Linear(2 * in_features, out_features)

    def forward(self, in_states, edges_ids, edges_dict):
        # in_states - [b, n_nodes, in_features]
        state_from = in_states[:, edges[0], :]
        state_to = in_states[:, edges[1], :]
        edge_state = torch.cat(state_from, state_to, dim=2) 
        # edge_state - [b, n_edges, 2*in_features]
        value = self.edge_value_fn(edge_state)
        return value
    

class DiceWarsActionValueModel(nn.Module):
    def __init__(self):
        super().__init__()

        # there are 8 teams - each node has 8 corresponding values
        # for each node, there are 1-8 dice, all belonging to a single team
        # value for this team = # of dice, for other teams value = 0
        self.norm_fn = lambda x: x / 8  # downscale dice on a node
        self.gat_layers = [
            GATLayer(8, 8),
            GATLayer(8, 8),
            GATLayer(8, 8),
            GATLayer(8, 8)
        ]
        self.attack_value_fn = GATEdgeValue(8, 1)
        self.end_turn_value_fn = GATGraphSummary(8, 8)

    def forward(self, in_states, edges_ids, edges_dict):
        x = self.norm_fn(in_states)

        for gat in self.gat_layers:
            x = gat(x, edges)

        edge_attack_val = self.attack_value_fn(x, edges_ids, edges_dict)
        end_turn_val = self.end_turn_value_fn(x)
        return end_turn_val, edge_attack_val



In [24]:
class GameValues:
    win_value = 2000
    lose_value = -2000
    move_value = -1


def adj_matrix_to_edges(adj_mat, add_self_edges):
    n_nodes = len(adj_mat)
    edges = []
    for i in range(n_nodes):
        for j in range(n_nodes):
            is_self_edge = (i == j)
            if adj_mat[i][j] > 0 or (is_self_edge and add_self_edges):
                edges.append([i, j])

    return torch.tensor(edges, dtype=torch.int32).T


def extract_edges(history_data):
    return adj_matrix_to_edges(history_data["adjacency"], add_self_edges=True)


def extract_nodes_states(history_data):
    states_data = history_data["states"][:-1]  # drop the last "terminal" state
    dice_values = np.array([s["dice"] for s in states_data])
    player_values = np.array([s["teams"] for s in states_data])

    adj_matrix = history_data["adjacency"]

    n_nodes = len(adj_matrix)
    n_states = len(states_data)
    n_players = 8

    nodes_states = torch.zeros(n_states, n_nodes, n_players, dtype=torch.float32)

    state_idx, node_idx = np.ogrid[:n_states, :n_nodes]
    idx = (state_idx, node_idx, player_values)
    nodes_states[idx] = torch.tensor(dice_values, dtype=torch.float32)

    return nodes_states


def extract_action_values(history_data):
    last_state = history_data["states"][-1]
    action_count = {i: 0 for i in range(8)}
    attack_edges = []
    end_turn_players = []
    action_values = []

    winning_team = set(last_state["teams"])
    if len(winning_team) > 1:
        raise ValueError("There are multiple winning teams in the last state.")
    winning_team = winning_team.pop()

    for action in history_data["actions"]:
        player_id = action["player"]
        action_count[player_id] += 1


    for action in history_data["actions"]:
        player_id = action["player"] 
        if player_id == winning_team:
            player_values = GameValues.win_value
        else:
            player_values = GameValues.lose_value
        player_values += GameValues.move_value * action_count[player_id]

        if action["move_made"]:
            attack_edges.append([action["from"], action["to"]])
            action_values.append(player_values)
        else:
            attack_edges.append([-1, -1])
        
        if action["turn_end"]:
            end_turn_players.append(player_id)
            action_values.append(player_values)
        else:
            end_turn_players.append(-1)

    attack_edges = torch.tensor(attack_edges, dtype=torch.int32)
    end_turn_players = torch.tensor(end_turn_players, dtype=torch.int32)
    action_values = torch.tensor(action_values, dtype=torch.float32)

    return attack_edges, end_turn_players, action_values


def load_dicewars_data(json_path):
    uint64_limit = torch.iinfo(torch.uint64).max + 1
    graph_id = torch.tensor(hash(json_path) % uint64_limit, dtype=torch.uint64)

    history_data = json.load(open(json_path, "r"))
    nodes_states = extract_nodes_states(history_data)
    edges = extract_edges(history_data)
    attack_edges, end_turn_players, action_values = extract_action_values(history_data)

    return dict(
        graph_id=graph_id,
        nodes_states=nodes_states, 
        attack_edges=attack_edges, 
        end_turn_players=end_turn_players, 
        action_values=action_values, 
        edges=edges
    )


In [25]:
class DWGameDataset(Dataset):
    def __init__(self, json_path):
        data = load_dicewars_data(json_path)
        self.graph_id = data["graph_id"]
        self.nodes_states = data["nodes_states"]
        self.attack_edges = data["attack_edges"]
        self.end_turn_players = data["end_turn_players"]
        self.action_values = data["action_values"]
        self.edges = data["edges"]
        self.n_states = self.nodes_states.shape[0]
    
    def __len__(self):
        return self.n_states
    
    def __getitem__(self, index):
        return {
            # scalar
            "graph_id": self.graph_id, 
            # [n_nodes, n_features]
            "nodes_state": self.nodes_states[index],
            # [2]
            "attack_edge": self.attack_edges[index],
            # scalar
            "end_turn_player": self.end_turn_players[index],
            # scalar
            "action_value": self.action_values[index],
            # [2, n_edges]
            "edges": self.edges
        }

In [None]:
# graph data collate function that combines multiple graph staters into a single graph state
# node indices are shifted to become the merged super-graph indices
def collate_dw_data(sample_list):
    nodes_states = []
    attack_edges = []
    end_turn_players = []
    action_values = []
    edges = []
    graph_ids = []

    node_id_offset = 0
    for i, sample in enumerate(sample_list):
        nodes_states.append(sample["nodes_state"])
        attack_edges.append(sample["attack_edge"] + node_id_offset)
        end_turn_players.append(sample["end_turn_player"])
        action_values.append(sample["action_value"])
        edges.append(sample["action_value"] + node_id_offset)

        n_nodes = len(sample["nodes_state"])
        graph_ids.append(torch.full((n_nodes,), n_nodes, dtype=torch.int32))
        node_id_offset += n_nodes

    nodes_states = torch.cat(nodes_states)
    attack_edges = torch.stack(attack_edges)
    end_turn_players = torch.tensor(end_turn_players, dtype=torch.int32)
    action_values = torch.tensor(action_values, dtype=torch.float32)
    edges = torch.cat(edges, dim=1)
    graph_ids = torch.cat(graph_ids)

    return dict(
        nodes_states=nodes_states,
        attack_edges=attack_edges,
        end_turn_players=end_turn_players,
        action_values=action_values,
        graph_ids=graph_ids,
        edges=edges
    )


In [43]:
tuple([1])

(1,)

In [44]:
history_folder = "server/history/"
history_files = sorted(glob.glob(os.path.join(history_folder, "history_*.json")))

In [45]:
concatenated_dataset = ConcatDataset(
    [DWGameDataset(f) for f in history_files]
)
train_ds, val_ds, test_ds = random_split(
    concatenated_dataset,
    [0.8, 0.1, 0.1],
)

In [46]:
loader = DataLoader(train_ds, 10, shuffle=True, collate_fn=collate_dw_data)

In [47]:
batch = next(iter(loader))

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated