In [None]:
import torch
import torch.nn as nn
import dgl
import torch.nn.functional as F
from dgl.nn import GraphConv

class Graph(torch.nn.Module):
    """GCN encoder that pools every layer's node states and classifies the concatenation.

    A stack of ``GraphConv`` layers (widths ``g_in_feats -> g_hidden_size[0]
    -> ... -> g_hidden_size[-1]``) is applied to the node features stored in
    ``g.ndata['x']``. After each layer (ReLU + dropout), node states are
    averaged per graph with ``dgl.mean_nodes``. The per-layer readouts are
    concatenated (width ``sum(g_hidden_size)``), passed through dropout again,
    and projected to ``g_out_feats`` by a linear layer.

    Args:
        g_in_feats: Width of the input node features ``g.ndata['x']``.
        g_hidden_size: Output width of each GraphConv layer, one entry per
            layer. A bare ``int`` is treated as a single-layer list.
        g_out_feats: Width of the returned per-graph output.
        dropout: Dropout probability. It is applied after every layer and
            again on the concatenated pooled representation.
        seed: If not ``None``, ``torch.manual_seed(seed)`` is called before
            any parameters are created. NOTE: this reseeds the *global* torch
            RNG, which also affects modules constructed afterwards. Defaults to
            1234 to keep the original behaviour; pass ``None`` to skip it.

    Raises:
        ValueError: If ``g_hidden_size`` is empty.
    """

    def __init__(self, g_in_feats, g_hidden_size, g_out_feats, dropout=0.5, seed=1234):
        super().__init__()
        if seed is not None:
            torch.manual_seed(seed)

        if isinstance(g_hidden_size, int):
            g_hidden_size = [g_hidden_size]
        hidden_sizes = list(g_hidden_size)
        if not hidden_sizes:
            raise ValueError("g_hidden_size must contain at least one layer width")

        # Consecutive (in, out) widths for each GraphConv. Layers are created
        # in the same order as before, so seeded initialisation is unchanged.
        # allow_zero_in_degree=True stops GraphConv from raising on nodes
        # that have no incoming edges.
        dims = [g_in_feats, *hidden_sizes]
        self.layers = torch.nn.ModuleList(
            GraphConv(d_in, d_out, allow_zero_in_degree=True)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        # The classifier sees the concatenation of every layer's pooled readout.
        self.classify = torch.nn.Linear(sum(hidden_sizes), g_out_feats)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, g):
        """Encode ``g`` and return one row of ``g_out_feats`` logits per graph.

        Args:
            g: A (possibly batched) DGL graph with node features in
                ``g.ndata['x']``.

        Returns:
            The output of ``self.classify`` applied to the concatenated
            per-layer mean readouts.
        """
        pooled = []
        # local_scope() makes the temporary 'h' node field disappear on exit,
        # so the caller's graph is not mutated by this forward pass.
        with g.local_scope():
            h = g.ndata['x']
            for layer in self.layers:
                h = self.dropout(F.relu(layer(g, h)))
                g.ndata['h'] = h
                pooled.append(dgl.mean_nodes(g, 'h'))
        hg = torch.cat(pooled, dim=1)
        return self.classify(self.dropout(hg))

class Finger(nn.Module):
    """Three-layer MLP branch for a dense per-sample feature vector.

    The input is presumably a fingerprint/descriptor vector (inferred from the
    class name). It is mapped through ``d_in_feats -> d_hidden_size ->
    d_hidden_size -> d_out_feats``, with ReLU after the first two linear
    layers and no activation on the output layer.

    Args:
        d_in_feats: Width of the input vector.
        d_hidden_size: Width of both hidden layers.
        d_out_feats: Width of the output.
        dropout: NOTE(review): accepted but currently unused; no dropout is
            applied in this branch. Kept for signature compatibility.
    """

    def __init__(self, d_in_feats, d_hidden_size, d_out_feats, dropout=0.5):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.linear1 = nn.Linear(d_in_feats, d_hidden_size)
        self.linear2 = nn.Linear(d_hidden_size, d_hidden_size)
        self.out_layer = nn.Linear(d_hidden_size, d_out_feats)

    def forward(self, input):
        """Return the raw (pre-activation) output of the final linear layer."""
        hidden = input
        for projection in (self.linear1, self.linear2):
            hidden = F.relu(projection(hidden))
        return self.out_layer(hidden)

class MultiFusion(nn.Module):
    """Late-fusion model that adds a graph branch to a dense-feature branch.

    Args:
        g_in_feats, g_hidden_size, g_out_feats: Forwarded to ``Graph``.
        d_in_feats, d_hidden_size, d_out_feats: Forwarded to ``Finger``.

    The two branch outputs are summed element-wise, so they must be
    broadcast-compatible. In practice this means ``g_out_feats == d_out_feats``
    and matching batch sizes.
    """

    def __init__(self, g_in_feats, g_hidden_size, g_out_feats, d_in_feats, d_hidden_size, d_out_feats):
        super().__init__()
        # Order matters: Graph reseeds the global torch RNG in its constructor,
        # so Finger's initialisation depends on being created afterwards.
        self.graph = Graph(g_in_feats, g_hidden_size, g_out_feats)
        self.finger = Finger(d_in_feats, d_hidden_size, d_out_feats)

    def forward(self, g, x):
        """Return ``graph(g) + finger(x)``.

        Args:
            g: DGL graph (possibly batched) consumed by the graph branch.
            x: Dense input tensor consumed by the MLP branch.
        """
        g_output = self.graph(g)
        d_output = self.finger(x)
        # Element-wise sum fusion. (A concatenation used to be computed here
        # and immediately discarded; it was dead work and has been removed.)
        return g_output + d_output