In [2]:
import mlx.core as mx
import mlx.nn as nn

import mlx.optimizers as optim
from mlx.utils import tree_map
from functools import partial
import mlx.core as mx
import mlx.nn as nn

from tqdm import tqdm
import numpy as np
import argparse
import random
import csv
import json

In [2]:
class gat_layer(nn.Module):
    """One multi-head graph-attention layer working on an edge list.

    For every edge ``source -> target`` an unnormalised score is computed per
    head from the source and target node features, the scores are normalised
    with a softmax over all edges that share the same target node, and the
    source features are summed into the target node weighted by those
    attention values.

    Args:
        num_nodes: Number of nodes in the graph; sizes the per-node
            accumulation buffers and the returned array.
        dim_proj: Feature width of a single attention head.
        num_att_heads: Number of attention heads.
    """

    def __init__(self, num_nodes: int, dim_proj: int, num_att_heads: int):
        super().__init__()

        # Rows of ``adjacency_matrix`` that hold the source / target node ids.
        self.source_dim = 0
        self.target_dim = 1

        self.dim_proj = dim_proj
        self.num_nodes = num_nodes
        self.num_att_heads = num_att_heads

        # Learnable per-head scoring vectors of shape (1, heads, dim_proj).
        # They must be float arrays of that *shape*; `mx.array([1, H, D])`
        # would instead create a 3-element vector containing the shape values.
        init_bound = (1.0 / dim_proj) ** 0.5
        param_shape = (1, num_att_heads, dim_proj)
        self.source_scores_fn = mx.random.uniform(
            low=-init_bound, high=init_bound, shape=param_shape
        )
        self.target_scores_fn = mx.random.uniform(
            low=-init_bound, high=init_bound, shape=param_shape
        )

        self.leakyReLU = nn.LeakyReLU(0.02)

    def __call__(self, node_proj, adjacency_matrix):
        """Run attention-weighted neighbourhood aggregation.

        Args:
            node_proj: Node features that can be reshaped to
                ``(-1, num_att_heads, dim_proj)``; the leading dimension must
                equal ``num_nodes`` for the final reshape to succeed.
            adjacency_matrix: Edge index; ``adjacency_matrix[0]`` and
                ``adjacency_matrix[1]`` must be integer arrays of source and
                target node ids (one entry per edge).

        Returns:
            Array of shape ``(num_nodes, num_att_heads * dim_proj)``.
        """
        source_idx = adjacency_matrix[self.source_dim]
        target_idx = adjacency_matrix[self.target_dim]

        # (N, H * D) -> (N, H, D)
        node_proj = node_proj.reshape([-1, self.num_att_heads, self.dim_proj])

        # Per-node, per-head scalar scores: (N, H). MLX reductions take `axis`.
        source_scores = (node_proj * self.source_scores_fn).sum(axis=-1)
        target_scores = (node_proj * self.target_scores_fn).sum(axis=-1)

        # Lift node-level quantities onto edges: (E, H, D) and (E, H).
        edge_node_proj = mx.take(node_proj, source_idx, axis=0)
        edge_source_scores = mx.take(source_scores, source_idx, axis=0)
        edge_target_scores = mx.take(target_scores, target_idx, axis=0)

        edge_scores = self.leakyReLU(edge_source_scores + edge_target_scores)
        # Subtracting a constant before exp() leaves the softmax unchanged and
        # keeps the exponentials from overflowing.
        edge_scores = (edge_scores - edge_scores.max()).exp()

        # Sum of exponentiated scores over the incoming edges of each node:
        # (N, H). `.at[...]` accumulates correctly for repeated target ids.
        softmax_denominator = mx.zeros(
            [self.num_nodes, self.num_att_heads], dtype=edge_scores.dtype
        )
        softmax_denominator = softmax_denominator.at[target_idx].add(edge_scores)

        # Bring each edge's own denominator back to edge granularity: (E, H).
        edge_denominator = mx.take(softmax_denominator, target_idx, axis=0)
        attention_scores = edge_scores / (edge_denominator + 1e-16)

        # Broadcast the per-head weight across the feature axis: (E, H, 1).
        weighted_edge_proj = edge_node_proj * mx.expand_dims(attention_scores, -1)

        # Scatter-add the weighted source features into their target nodes.
        new_node_proj = mx.zeros(
            [self.num_nodes, self.num_att_heads, self.dim_proj],
            dtype=weighted_edge_proj.dtype,
        )
        new_node_proj = new_node_proj.at[target_idx].add(weighted_edge_proj)

        new_node_proj = self.leakyReLU(new_node_proj)
        # Concatenate heads: (N, H, D) -> (N, H * D)
        new_node_proj = new_node_proj.reshape(
            (self.num_nodes, self.num_att_heads * self.dim_proj)
        )

        return new_node_proj


In [97]:
class gat(nn.Module):
    """Graph attention network: embedding projection, GAT stack, linear head.

    Args:
        num_nodes: Number of nodes in the graph (forwarded to every
            ``gat_layer``).
        dim_embed: Width of the input node embeddings.
        dim_proj: Feature width of a single attention head.
        num_att_heads: Number of attention heads per GAT layer.
        num_layers: Number of stacked ``gat_layer`` modules.
        skip_connections: When true, each GAT layer's input is added to its
            output.
        num_out_layers: Number of hidden linear layers placed before the
            final classification layer.
        num_classes: Width of the final linear layer. Defaults to 7, the
            value that was previously hard-coded.
    """

    def __init__(
        self,
        num_nodes: int,
        dim_embed: int,
        dim_proj: int,
        num_att_heads: int,
        num_layers: int,
        skip_connections: bool,
        num_out_layers: int,
        num_classes: int = 7,
    ):
        super().__init__()

        total_att_size = dim_proj * num_att_heads

        # Both are read in __call__, so they have to be stored here.
        self.dim_embed = dim_embed
        self.skip_connections = skip_connections

        self.embed_proj = nn.Linear(dim_embed, total_att_size)

        # Plain lists of modules: they are iterable in __call__, and a fresh
        # instance is built per layer (`[layer] * n` would repeat one object
        # and silently share its weights across all layers).
        self.gat_layers = [
            gat_layer(num_nodes, dim_proj, num_att_heads)
            for _ in range(num_layers)
        ]
        self.out_layers = [
            nn.Linear(total_att_size, total_att_size)
            for _ in range(num_out_layers)
        ] + [nn.Linear(total_att_size, num_classes)]

        self.leakyReLU = nn.LeakyReLU(.02)

    def __call__(self, node_embeddings, adjacency_matrix):
        """Compute per-node class probabilities.

        Args:
            node_embeddings: Array whose second dimension equals
                ``dim_embed``.
            adjacency_matrix: Edge index handed unchanged to every GAT layer.

        Returns:
            Softmax over the last axis of the final linear layer's output.

        Raises:
            AssertionError: If ``node_embeddings.shape[1] != dim_embed``.
        """
        assert node_embeddings.shape[1] == self.dim_embed, 'Incorrect node embedding size'

        node_proj = self.embed_proj(node_embeddings)

        for layer in self.gat_layers:
            new_node_proj = layer(node_proj, adjacency_matrix)
            if self.skip_connections:
                new_node_proj = new_node_proj + node_proj
            node_proj = new_node_proj

        # The head consumes the GAT output, not the raw input embeddings.
        *hidden_layers, final_layer = self.out_layers
        for layer in hidden_layers:
            # Non-linearity between hidden layers; without it the stacked
            # linear layers would collapse into a single linear map.
            node_proj = self.leakyReLU(layer(node_proj))
        logits = final_layer(node_proj)

        # Normalise per node (last axis) rather than over the whole array.
        return mx.softmax(logits, axis=-1)