In [15]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, BatchNormalization
from tensorflow.keras.activations import relu
from spektral.layers import GraphSageConv
from spektral.data.graph import Graph
from contextlib import suppress
import spektral
import numpy as np

import networkx as nx
import pandas as pd
from os import listdir

In [8]:
class GNN(tf.keras.Model):
    """Stack of three GraphSage blocks with concatenated ("jumping") outputs.

    Each block is ``GraphSageConv -> relu -> BatchNormalization``. The
    outputs of all three blocks are concatenated along the last axis and,
    when ``add_linear`` is true, projected by a ``Dense`` layer followed by
    ``relu``.

    Args:
        hidden_channels: number of channels of the first two convolutions.
        out_channels: number of channels of the third convolution and of the
            optional final ``Dense`` layer.
        add_linear: whether to append the ``Dense`` + ``relu`` head.
    """

    def __init__(self, hidden_channels, out_channels, add_linear=True):
        super().__init__()

        # Block 1 and 2 work at the hidden width, block 3 at the output width.
        self.conv1 = GraphSageConv(hidden_channels)
        self.bn1 = BatchNormalization()
        self.conv2 = GraphSageConv(hidden_channels)
        self.bn2 = BatchNormalization()
        self.conv3 = GraphSageConv(out_channels)
        self.bn3 = BatchNormalization()

        # ``None`` means "return the raw concatenation" in ``call``.
        self.lin = Dense(out_channels) if add_linear else None

    def call(self, inputs):
        """Run the three blocks on ``inputs = (x, adj)``.

        Returns:
            The concatenation of the three block outputs along the last
            axis, passed through ``relu(Dense(...))`` if the linear head is
            enabled.
        """
        x, adj = inputs

        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )

        hidden = x
        stage_outputs = []
        for conv, norm in stages:
            # Every block consumes the previous block's output and the same
            # adjacency matrix.
            hidden = norm(relu(conv([hidden, adj])))
            stage_outputs.append(hidden)

        merged = tf.concat(stage_outputs, axis=-1)

        if self.lin is None:
            return merged
        return relu(self.lin(merged))

In [9]:
def dense_diff_pool(x, adj, s):
    """Apply one DiffPool coarsening step to a single, un-batched graph.

    Port of PyTorch Geometric's ``dense_diff_pool``:
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/dense/diff_pool.html#dense_diff_pool

    No batch axis is added and no node mask is supported: ``s`` is transposed
    with the permutation ``(1, 0)``, so it must be a rank-2 tensor.

    Args:
        x: node features, used as the right operand of ``softmax(s)^T @ x``.
        adj: adjacency matrix, either a dense tensor or a ``tf.SparseTensor``.
        s: un-normalised cluster-assignment scores, dense or sparse. A
            softmax over the last axis is applied inside this function.

    Returns:
        A tuple ``(out, out_adj, link_loss, ent_loss)`` where

        * ``out`` is ``softmax(s)^T @ x``,
        * ``out_adj`` is ``softmax(s)^T @ adj @ softmax(s)`` converted to a
          ``tf.SparseTensor``,
        * ``link_loss`` is the 2-norm of ``adj - softmax(s) @ softmax(s)^T``
          divided by the number of entries of ``adj``,
        * ``ent_loss`` is the mean over rows of the entropy of ``softmax(s)``.
    """
    # Densify each argument on its own. The previous version converted both
    # inside one ``suppress(TypeError)`` block and assigned the densified
    # *adjacency* to ``s``; it only worked because that second call raised
    # and was swallowed.
    if isinstance(adj, tf.sparse.SparseTensor):
        adj = tf.sparse.to_dense(adj)
    if isinstance(s, tf.sparse.SparseTensor):
        s = tf.sparse.to_dense(s)

    s = tf.nn.softmax(s, axis=-1)
    st = tf.transpose(s, (1, 0))

    out = tf.matmul(st, x)
    out_adj = tf.matmul(tf.matmul(st, adj), s)

    link_loss = tf.norm(adj - tf.matmul(s, st), ord=2)
    # Cast the element count instead of asking ``tf.size`` for a float
    # ``out_type`` (the Size op is believed to support integer output types
    # only; the cast is safe either way).
    link_loss = link_loss / tf.cast(tf.size(adj), link_loss.dtype)

    # The small constant keeps log() finite when an assignment probability
    # is exactly zero.
    ent_loss = tf.reduce_mean(tf.reduce_sum(-s * tf.math.log(s + 1e-15), axis=-1))

    return out, tf.sparse.from_dense(out_adj), link_loss, ent_loss

In [10]:
class Net(tf.keras.Model):
    """DiffPool graph classifier with two pooling stages.

    Each pooling stage runs two ``GNN`` modules on the current graph: one
    producing cluster-assignment scores and one producing node embeddings.
    ``dense_diff_pool`` then coarsens the graph. After the second stage a
    third embedding ``GNN`` is applied, the node embeddings are averaged and
    two ``Dense`` layers produce class log-probabilities.

    Args:
        num_classes: size of the output distribution.
        max_nodes: reference node count; the first pooling stage produces
            ``ceil(0.5 * max_nodes)`` clusters and the second stage half of
            that again (rounded up).
    """

    def __init__(self, num_classes=6, max_nodes=200):
        super().__init__()

        # Stage 1: halve the node count.
        num_nodes = int(np.ceil(0.5 * max_nodes))
        self.gnn1_pool = GNN(64, num_nodes)
        self.gnn1_embed = GNN(64, 64, add_linear=False)

        # Stage 2: halve it again.
        num_nodes = int(np.ceil(0.5 * num_nodes))
        self.gnn2_pool = GNN(64, num_nodes)
        self.gnn2_embed = GNN(64, 64, add_linear=False)

        self.gnn3_embed = GNN(64, 64, add_linear=False)

        self.lin1 = Dense(64)
        self.lin2 = Dense(num_classes)

    def call(self, inputs):
        """Classify one graph given as ``inputs = (x, adj)``.

        Returns:
            A tuple ``(log_probs, link_loss, ent_loss)``: the log-softmax of
            the class scores, and the sums of the link and entropy losses
            returned by the two ``dense_diff_pool`` calls.
        """
        x, adj = inputs
        s = self.gnn1_pool([x, adj])
        x = self.gnn1_embed([x, adj])

        x, adj, l1, e1 = dense_diff_pool(x, adj, s)

        s = self.gnn2_pool([x, adj])
        x = self.gnn2_embed([x, adj])

        x, adj, l2, e2 = dense_diff_pool(x, adj, s)

        x = self.gnn3_embed([x, adj])

        # Global mean readout over the *node* axis (second to last). The
        # graphs here carry no batch axis, so the previous ``axis=1``
        # averaged over features and left one value per cluster instead of
        # one value per feature.
        x = tf.reduce_mean(x, axis=-2)

        # Restore a leading batch axis of size one for the dense head.
        if len(x.shape) == 1:
            x = tf.expand_dims(x, axis=0)

        x = relu(self.lin1(x))
        x = self.lin2(x)

        return tf.nn.log_softmax(x, axis=-1), l1 + l2, e1 + e2

In [11]:
# Smoke test: push one random graph through an untrained Net and display the
# class log-probabilities.
num_nodes = 200
num_features = 10

x = tf.Variable(tf.random.normal((num_nodes, num_features)))
# Rounding uniform [0, 1) samples yields a random 0/1 adjacency matrix.
adj = tf.sparse.from_dense(tf.round(tf.random.uniform((num_nodes, num_nodes))))
# Net derives its pooling sizes from max_nodes, so pass the graph size
# explicitly instead of relying on the default happening to match.
net = Net(max_nodes=num_nodes)
net([x, adj])[0]

<tf.Tensor: shape=(1, 6), dtype=float32, numpy=
array([[-1.7215043, -1.7951338, -1.7954525, -1.7804478, -1.7885032,
        -1.875537 ]], dtype=float32)>

In [17]:
class WICO(spektral.data.Dataset):
    """Graph dataset read from a ``<path>/<graph_type>/<graph_id>/`` tree.

    Every entry of ``path`` is treated as one class (``graph_type``) and
    every decimal-named entry inside it as one graph of that class.

    Args:
        path: root directory of the dataset.
        **kwargs: forwarded to ``spektral.data.Dataset``.

    Attributes:
        labels: mapping from class folder name to integer label, assigned in
            sorted folder-name order.
    """

    def __init__(self, path="./dataset/WICO/", **kwargs):
        # Everything read() needs must exist *before* the base constructor
        # runs, because spektral's Dataset.__init__ calls read().
        self._path = path
        self.labels = {name: idx for idx, name in enumerate(sorted(listdir(path)))}
        super().__init__(**kwargs)

    @property
    def path(self):
        # The base class exposes ``path`` as a read-only property, so the
        # custom location is provided by overriding it rather than by
        # assigning to ``self.path``.
        return self._path

    def read(self):
        """Collect one graph per ``<graph_type>/<graph_id>/`` directory.

        Returns:
            A list with the objects returned by ``_load_graph``, ordered by
            class label and then by numeric graph id.
        """
        root = self.path.rstrip("/")
        graphs = []
        for graph_type, label in self.labels.items():
            type_dir = f"{root}/{graph_type}/"
            # isdecimal() guarantees that int() can parse the name, which
            # lets the ids be visited in numeric instead of lexical order.
            graph_ids = sorted(filter(str.isdecimal, listdir(type_dir)), key=int)
            for graph_id in graph_ids:
                graphs.append(self._load_graph(f"{type_dir}{graph_id}/", label))
        return graphs

    def _load_graph(self, graph_dir, label):
        """Build the ``Graph`` stored in ``graph_dir`` with class ``label``.

        TODO: the original cell stopped before this step and the on-disk
        file format is not visible here, so parsing is left unimplemented.
        """
        raise NotImplementedError(
            f"Loading a graph from {graph_dir!r} is not implemented yet"
        )

In [16]:
path = "./dataset/WICO"

# Show each class folder next to the index enumerate() assigns to it.
for i, graph_type in enumerate(listdir(path)):
    print(i, graph_type)

0 5G_Conspiracy_Graphs
1 Non_Conspiracy_Graphs
2 Other_Graphs


In [20]:
# Entries of the 5G class folder whose names are purely numeric.
[entry for entry in listdir("./dataset/WICO/5G_Conspiracy_Graphs/") if entry.isnumeric()]

['1',
 '10',
 '100',
 '101',
 '102',
 '103',
 '104',
 '105',
 '106',
 '107',
 '108',
 '109',
 '11',
 '110',
 '111',
 '112',
 '113',
 '114',
 '115',
 '116',
 '117',
 '118',
 '119',
 '12',
 '120',
 '121',
 '122',
 '123',
 '124',
 '125',
 '126',
 '127',
 '128',
 '129',
 '13',
 '130',
 '131',
 '132',
 '133',
 '134',
 '135',
 '136',
 '137',
 '138',
 '139',
 '14',
 '140',
 '141',
 '142',
 '143',
 '144',
 '145',
 '146',
 '147',
 '148',
 '149',
 '15',
 '150',
 '151',
 '152',
 '153',
 '154',
 '155',
 '156',
 '157',
 '158',
 '159',
 '16',
 '160',
 '161',
 '162',
 '163',
 '164',
 '165',
 '166',
 '167',
 '168',
 '169',
 '17',
 '170',
 '171',
 '172',
 '173',
 '174',
 '175',
 '176',
 '177',
 '178',
 '179',
 '18',
 '180',
 '181',
 '182',
 '183',
 '184',
 '185',
 '186',
 '187',
 '188',
 '189',
 '19',
 '190',
 '191',
 '192',
 '193',
 '194',
 '195',
 '196',
 '197',
 '198',
 '199',
 '2',
 '20',
 '200',
 '201',
 '202',
 '203',
 '204',
 '205',
 '206',
 '207',
 '208',
 '209',
 '21',
 '210',
 '211',
 '212',
