In [1]:
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..")
from src.gnn_models import GNN_7
import hls4ml
from torch_geometric.utils import to_dense_adj




In [2]:
class CustomGraphLayer(nn.Module):
    """Graph layer computing ``weight_node(x) + weight_adj(A @ x)`` for a batch of graphs.

    The per-graph adjacency matrices arrive flattened and concatenated; they
    are reassembled into one block-diagonal matrix so that a single matmul
    aggregates neighbours for every graph in the batch.

    Args:
        input_features: Size of each input node feature vector.
        output_features: Size of each output node feature vector.
    """

    def __init__(self, input_features, output_features):
        super().__init__()
        self.weight_node = nn.Linear(input_features, output_features, bias=False)
        self.weight_adj = nn.Linear(input_features, output_features, bias=True)

    def forward(self, node_features, Adj, batch):
        """Apply the layer to all nodes of all graphs in the batch.

        Args:
            node_features: Node features, one row per node.
            Adj: Flattened adjacency entries of all graphs stacked along dim 0
                with a single column (each graph contributes ``n * n`` rows,
                which are reshaped to ``(n, n)``).
            batch: 1-D tensor assigning each node to a graph; nodes of the
                same graph must be contiguous.

        Returns:
            Tensor with one row of ``output_features`` values per node.
        """
        # Length of every run of equal consecutive graph ids == nodes per graph.
        # This works on whatever device `batch` lives on (no hard-coded CPU tensors).
        _, nodes_per_graph = torch.unique_consecutive(batch, return_counts=True)

        # Slice the flattened adjacency list into one square matrix per graph.
        blocks = []
        offset = 0
        for num_nodes in nodes_per_graph.tolist():
            num_entries = num_nodes * num_nodes
            blocks.append(
                Adj[offset : offset + num_entries, :].reshape(num_nodes, num_nodes)
            )
            offset += num_entries

        # Assemble the block-diagonal matrix once instead of growing it per graph.
        if blocks:
            block_Adj = torch.block_diag(*blocks)
        else:
            # Empty batch: keep an empty (0, 0) adjacency on Adj's dtype/device.
            block_Adj = Adj.new_zeros((0, 0))

        node_term = self.weight_node(node_features)
        adjacency_sum = torch.matmul(block_Adj, node_features)
        adjacency_term = self.weight_adj(adjacency_sum)
        new_node_features = node_term + adjacency_term

        return new_node_features

class PMatMul(nn.Module):
    """Parameter-free module wrapping ``torch.matmul``.

    Wrapping the product in a module gives it a class name (``PMatMul``)
    that a converter layer handler can be registered for.
    """

    def forward(self, x1, x2):
        """Return the matrix product ``x1 @ x2``."""
        product = torch.matmul(x1, x2)
        return product

class HMatMul(hls4ml.model.layers.Layer):
    """hls4ml IR layer for the two-operand matrix product produced by ``PMatMul``."""

    def initialize(self):
        """Declare the output variable of the matrix product.

        For operands ``lhs`` and ``rhs`` (rank >= 2) the product keeps the
        leading dimensions of ``lhs`` and takes its last dimension from
        ``rhs``. Copying the first input's shape, as before, is only right
        when ``rhs`` has as many columns as ``lhs``. With fewer than two
        inputs, or a rank-1 second operand, the previous behaviour (mirror
        the first input) is kept.
        """
        lhs = self.get_input_variable()
        shape = lhs.shape
        dims = lhs.dim_names

        if len(self.inputs) >= 2:
            rhs = self.get_input_variable(self.inputs[1])
            if len(rhs.shape) >= 2:
                # (..., n, k) @ (k, m) -> (..., n, m)
                shape = list(lhs.shape[:-1]) + [rhs.shape[-1]]
                dims = list(lhs.dim_names[:-1]) + [rhs.dim_names[-1]]

        self.add_output_variable(shape, dims)
    
class CustomGraphConv(nn.Module):
    """Graph convolution on a dense adjacency: ``weight_node(x) + weight_adj(adj @ x)``.

    Args:
        input_features: Size of each input node feature vector.
        output_features: Size of each output node feature vector.
    """

    def __init__(self, input_features, output_features):
        super().__init__()

        # The product lives in its own module so it shows up as a named layer.
        self.matmul = PMatMul()
        # Only the adjacency branch carries a bias.
        self.weight_node = nn.Linear(input_features, output_features, bias=False)
        self.weight_adj = nn.Linear(input_features, output_features, bias=True)

    def forward(self, x, adj):
        """Combine each node's own features with the adjacency-weighted sum.

        Args:
            x: Node feature matrix.
            adj: Adjacency matrix multiplied onto ``x`` from the left.

        Returns:
            The sum of the node term and the adjacency term.
        """
        self_contribution = self.weight_node(x)

        aggregated = self.matmul(adj, x)
        neighbour_contribution = self.weight_adj(aggregated)

        return self_contribution + neighbour_contribution
        
class GraphWTorchNet(torch.nn.Module):
    """Graph network: stacked ``CustomGraphConv`` layers, mean pooling, then an MLP head.

    Args:
        hidden_channels_GCN: Output width of each graph convolution layer.
        hidden_channels_MLP: Output width of each dense layer after pooling.
        num_node_features: Number of input features per node.
        num_classes: Width of the final output layer (1 for each head).
        manual_seed: If not ``None``, passed to ``torch.manual_seed`` before
            the layers are created.
    """

    def __init__(
        self,
        hidden_channels_GCN=[32, 128, 256, 512, 512, 256, 256],
        hidden_channels_MLP=[256, 128, 64],
        num_node_features=5,
        num_classes=1,
        manual_seed=12345,
    ):
        super().__init__()
        if manual_seed is not None:
            torch.manual_seed(manual_seed)

        # Shared activation used after every graph and dense layer
        self.activation = nn.ReLU()

        # Graph convolution stack: consecutive widths form (in, out) pairs
        gcn_widths = [num_node_features] + hidden_channels_GCN
        graph_stack = []
        for fan_in, fan_out in zip(gcn_widths[:-1], gcn_widths[1:]):
            graph_stack.append(CustomGraphConv(fan_in, fan_out))
        self.graph_layers = nn.ModuleList(graph_stack)

        # Dense stack: starts from the width of the last graph layer
        mlp_widths = hidden_channels_GCN[-1:] + hidden_channels_MLP
        dense_stack = []
        for fan_in, fan_out in zip(mlp_widths[:-1], mlp_widths[1:]):
            dense_stack.append(nn.Linear(fan_in, fan_out))
        self.dense_layers = nn.ModuleList(dense_stack)

        # Output layer
        self.output_layer = nn.Linear(hidden_channels_MLP[-1], num_classes)

    def forward(self, x, adj):
        """Run the network on one graph.

        Args:
            x: Node feature matrix (nodes along dim 0).
            adj: Adjacency matrix handed to every graph layer.

        Returns:
            Output of the final linear layer applied to the pooled features.
        """
        # Node embeddings
        for graph_layer in self.graph_layers:
            x = self.activation(graph_layer(x, adj))

        # Global mean pool over the node dimension
        x = x.mean(axis=0)

        for dense_layer in self.dense_layers:
            x = self.activation(dense_layer(x))

        # Output
        return self.output_layer(x)
    
def parse_matmul_layer(torch_layer, input_names, input_shapes, data_reader):
    """Translate a ``PMatMul`` module into the hls4ml ``HMatMul`` layer description.

    Args:
        torch_layer: Mapping describing the layer; its name is read from
            ``torch_layer["config"]["name"]``.
        input_names: Names of the layer inputs, or ``None``.
        input_shapes: Shapes of the inputs, first operand first.
        data_reader: Unused; matmul has no weights to read.

    Returns:
        A ``(layer, output_shape)`` tuple, where ``layer`` is the layer
        dictionary and ``output_shape`` is a new list.

    NOTE(review): TODO confirm that the installed hls4ml version calls PyTorch
    layer handlers with this argument list.
    """
    lhs_shape = input_shapes[0]

    layer = {
        "class_name": "HMatMul",
        "name": torch_layer["config"]["name"],
        "n_in": lhs_shape[1],
    }

    if input_names is not None:
        layer["inputs"] = input_names

    # (..., n, k) @ (k, m) -> (..., n, m): the last output dim comes from the
    # second operand. With one input shape (or a rank-1 second operand) the
    # first operand's shape is returned, as before.
    output_shape = list(lhs_shape)
    if len(input_shapes) > 1 and len(input_shapes[1]) >= 2:
        output_shape[-1] = input_shapes[1][-1]

    return layer, output_shape

    

In [3]:
def print_dict(d, indent=0):
    """Pretty-print a (possibly nested) dictionary, one key per line.

    Nested dictionaries are printed under their key with two extra spaces of
    indentation per level; leaf values are aligned in a common column.

    Args:
        d: Dictionary to print. Keys may be of any type; they are rendered
            with ``str``.
        indent: Current nesting depth (used by the recursion).
    """
    prefix = '  ' * indent
    for key, value in d.items():
        # Measure the rendered key so that non-string keys do not break len().
        label = str(key)
        if isinstance(value, dict):
            print(prefix + label)
            print_dict(value, indent + 1)
        else:
            padding = ' ' * (20 - len(label) - 2 * indent)
            print(prefix + label + ':' + padding + str(value))

In [4]:
# Build the network with its default layer widths; the constructor seeds
# torch with its default manual_seed (12345) before creating the layers.
model = GraphWTorchNet()
# Bare expression so the notebook displays the module structure.
model

GraphWTorchNet(
  (activation): ReLU()
  (graph_layers): ModuleList(
    (0): CustomGraphConv(
      (matmul): PMatMul()
      (weight_node): Linear(in_features=5, out_features=32, bias=False)
      (weight_adj): Linear(in_features=5, out_features=32, bias=True)
    )
    (1): CustomGraphConv(
      (matmul): PMatMul()
      (weight_node): Linear(in_features=32, out_features=128, bias=False)
      (weight_adj): Linear(in_features=32, out_features=128, bias=True)
    )
    (2): CustomGraphConv(
      (matmul): PMatMul()
      (weight_node): Linear(in_features=128, out_features=256, bias=False)
      (weight_adj): Linear(in_features=128, out_features=256, bias=True)
    )
    (3): CustomGraphConv(
      (matmul): PMatMul()
      (weight_node): Linear(in_features=256, out_features=512, bias=False)
      (weight_adj): Linear(in_features=256, out_features=512, bias=True)
    )
    (4): CustomGraphConv(
      (matmul): PMatMul()
      (weight_node): Linear(in_features=512, out_features=512, 

In [5]:
from pathlib import Path

# Text of the per-layer HLS config struct; the `{index}` and `{n_in}`
# placeholders are filled via str.format in HMatMulConfigTemplate.format.
matmul_config_template = """struct config{index} : nnet::matmul_config {{
    static const unsigned n_in = {n_in};
}};\n"""

# Text of the HLS call emitted for the layer; the placeholders are filled via
# str.format in HMatMulFunctionTemplate.format.
# NOTE(review): confirm the function params actually provide `x_t`, `y_t`,
# `x` and `y`, otherwise formatting raises KeyError.
matmul_function_template = 'nnet::product<{x_t}, {y_t}>({x}, {y});'
# Header(s) passed as `include_header` to the function-call template.
# NOTE(review): the source registered further down is `nnet_mult.h`, not
# `nnet_matmul.h` - verify which file name is intended.
matmul_include_list = ['nnet_utils/nnet_matmul.h']

class HMatMulConfigTemplate(hls4ml.backends.template.LayerConfigTemplate):
    """Config template for ``HMatMul`` nodes, rendered from ``matmul_config_template``."""

    def __init__(self):
        super().__init__(HMatMul)
        self.template = matmul_config_template

    def format(self, node):
        """Return the config text for ``node``, filled with its default config params."""
        config_params = self._default_config_params(node)
        rendered = self.template.format(**config_params)
        return rendered


class HMatMulFunctionTemplate(hls4ml.backends.template.FunctionCallTemplate):
    """Function-call template for ``HMatMul`` nodes, rendered from ``matmul_function_template``."""

    def __init__(self):
        super().__init__(HMatMul, include_header=matmul_include_list)
        self.template = matmul_function_template

    def format(self, node):
        """Return the HLS call text for ``node``.

        The template refers to the two operands as ``x``/``y`` and to their
        types as ``x_t``/``y_t``. These keys are filled from the node's first
        and second input variables when the default function params do not
        already contain them, so formatting does not fail with ``KeyError``.
        """
        params = self._default_function_params(node)

        if len(node.inputs) >= 2:
            lhs = node.get_input_variable(node.inputs[0])
            rhs = node.get_input_variable(node.inputs[1])
            # setdefault: never override a value the layer already provides.
            params.setdefault('x', lhs.name)
            params.setdefault('x_t', lhs.type.name)
            params.setdefault('y', rhs.name)
            params.setdefault('y_t', rhs.type.name)

        return self.template.format(**params)
    
# Register the converter for the custom PyTorch module
hls4ml.converters.register_pytorch_layer_handler('PMatMul', parse_matmul_layer)

# Register the hls4ml's IR layer
hls4ml.model.layers.register_layer('HMatMul', HMatMul)


# No optimization passes are registered for this layer.
backend = hls4ml.backends.get_backend("Vivado")

# Register template passes for the given backend
backend.register_template(HMatMulConfigTemplate)
backend.register_template(HMatMulFunctionTemplate)

# Register HLS implementation. The header is located relative to the installed
# hls4ml package instead of a user-specific absolute path, so the cell runs in
# any environment.
# NOTE(review): verify the sub-path against your hls4ml install (the packaged
# headers may live under `templates/vivado/nnet_utils`), and that `nnet_mult.h`
# is the file the function template's include list expects.
path = (
    Path(hls4ml.__file__).resolve().parent
    / "backends" / "templates" / "vivado" / "nnet_utils" / "nnet_mult.h"
)
backend.register_source(path)

In [6]:
# Input shapes: node features (num_nodes, num_node_features) and the dense
# adjacency matrix, which must be (num_nodes, num_nodes) for `adj @ x`.
NUM_NODES = 100
NUM_NODE_FEATURES = 5  # matches the model's default num_node_features

config = hls4ml.utils.config_from_pytorch_model(model)
input_shapes = [(NUM_NODES, NUM_NODE_FEATURES), (NUM_NODES, NUM_NODES)]
# Pass the generated config on; previously it was computed but never used.
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model, input_shape=input_shapes, backend="Vivado", hls_config=config
)

Interpreting Model ...
Topology:
Layer name: graph_layers_0_weight_node, layer type: Dense, input shape: [[100, 5]]


  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


Exception: Unsupported function matmul

In [8]:
# Display the backend's custom sources to check what was registered.
backend.get_custom_source()

{'nnet_utils\\nnet_mult.h': WindowsPath('C:/Users/isakb/miniforge3/envs/dml_cpu/Lib/site-packages/hls4ml/backends/templates/vivado/nnet_utils/nnet_mult.h')}