# Fake News Detection using PyTorch Geometric

## Imports and Setup

In [None]:
import os
import random

In [None]:
# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

In [None]:
# Import PyTorch Metrics
try:
    import torchmetrics
except ModuleNotFoundError:
    !pip install -q torchmetrics
    import torchmetrics

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m806.1/806.1 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Import PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install -q pytorch-lightning
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m777.7/777.7 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Torch geometric
try:
    import torch_geometric
except ModuleNotFoundError:
    # Installing torch geometric packages with specific CUDA+PyTorch version.
    # See https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html for details
    TORCH = torch.__version__.split('+')[0]
    CUDA = 'cu' + torch.version.cuda.replace('.','')

    !pip install -q pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
    !pip install -q git+https://github.com/rusty1s/pytorch_geometric.git
    import torch_geometric
import torch_geometric.nn as geom_nn
from torch_geometric.loader import DataLoader
from torch_geometric.utils import get_laplacian

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m51.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m36.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m105.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m932.1/932.1 kB[0m [31m68.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone


In [None]:
# Ensure that all operations are deterministic on GPU (if used) for reproducibility
def set_seed(seed: int = 42) -> None:
    """Seed the RNGs used by this notebook and force deterministic cuDNN kernels.

    Args:
        seed: Value passed to ``pl.seed_everything`` and to the torch CPU /
            current-CUDA-device generators.
    """
    pl.seed_everything(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Deterministic cuDNN algorithms; disable the autotuner, which may pick different kernels per run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    print(f"Random seed set as {seed}")


set_seed()

INFO:lightning_fabric.utilities.seed:Seed set to 42


Random seed set as 42


In [None]:
# Root directory for Lightning logs and model checkpoints.
CHECKPOINT_PATH = "./saved_models/"
# Root directory passed to the UPFD dataset loader.
DATA_PATH = "./data"
# First GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Define dataset classes

In [None]:
from torch_geometric.datasets import UPFD

class UPFDDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the PyG ``UPFD`` fake-news datasets.

    Args:
        batch_size: Number of graphs per mini-batch.
        dataset: Passed to ``UPFD(name=...)`` (e.g. "politifact").
        feature: Passed to ``UPFD(feature=...)`` (e.g. "bert").
    """

    def __init__(
            self,
            batch_size: int = 32,
            dataset: str = "politifact",
            feature: str = "bert"
        ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

    @property
    def num_node_features(self):
        """Node feature dimension of the training split (requires ``setup("fit")``)."""
        return self.data_train.num_node_features

    @property
    def num_classes(self):
        """Number of classes of the training split (requires ``setup("fit")``)."""
        return self.data_train.num_classes

    def _load_split(self, split: str):
        """Instantiate one UPFD split using this module's hyper-parameters."""
        return UPFD(root=DATA_PATH, name=self.hparams.dataset, feature=self.hparams.feature, split=split)

    def setup(self, stage=None):
        """Load the splits needed for ``stage``.

        Lightning passes "fit", "validate", "test" or "predict"; ``None`` loads
        every split. Splits already loaded are reused, so calling ``setup`` both
        manually and from the Trainer does not re-instantiate the datasets.
        """
        if stage in (None, "fit") and getattr(self, "data_train", None) is None:
            self.data_train = self._load_split("train")

        # The validation split is needed both while fitting and for trainer.validate().
        if stage in (None, "fit", "validate") and getattr(self, "data_val", None) is None:
            self.data_val = self._load_split("val")

        if stage in (None, "test") and getattr(self, "data_test", None) is None:
            self.data_test = self._load_split("test")

    def train_dataloader(self):
        return DataLoader(dataset=self.data_train, batch_size=self.hparams.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(dataset=self.data_val, batch_size=self.hparams.batch_size)

    def test_dataloader(self):
        return DataLoader(dataset=self.data_test, batch_size=self.hparams.batch_size)

## Define Classifier models

### UPFD Model

In [None]:
class UPFDModel(nn.Module):
    """Single graph-convolution baseline for UPFD graph classification.

    One GCN / GraphSAGE / GAT layer is followed by global max pooling. When
    ``concat`` is set, the transformed features of each graph's root node (its
    first node in the batch) are concatenated to the pooled embedding before
    the output layer.
    """

    def __init__(self, c_in, c_hidden, c_out, concat, dropout_ratio=0.5, conv_layer='sage', **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            concat - If True, concatenate root-node features with the pooled graph embedding
            dropout_ratio - Accepted for interface compatibility; not used by this model
            conv_layer - Graph convolution to use: 'gcn', 'sage' (default) or 'gat'
            kwargs - Additional keyword arguments; ignored
        Raises:
            ValueError - if conv_layer is not one of the supported names
        """
        super().__init__()

        self.concat = concat
        # Name of the graph convolution in use (kept under its historical attribute name).
        self.model = conv_layer

        if self.model == 'gcn':
            self.conv1 = geom_nn.GCNConv(c_in, c_hidden)
        elif self.model == 'sage':
            self.conv1 = geom_nn.SAGEConv(c_in, c_hidden)
        elif self.model == 'gat':
            self.conv1 = geom_nn.GATConv(c_in, c_hidden)
        else:
            raise ValueError(f"Unsupported conv_layer {conv_layer!r}; expected 'gcn', 'sage' or 'gat'")

        if self.concat:
            self.lin0 = torch.nn.Linear(c_in, c_hidden)
            self.lin1 = torch.nn.Linear(c_hidden * 2, c_hidden)

        self.lin2 = torch.nn.Linear(c_hidden, c_out)

    @staticmethod
    def _root_indices(batch_idx, batch_size):
        """Return the index of the first node of every graph ``0 .. batch_size - 1``.

        Equivalent to ``(batch_idx == g).nonzero()[0]`` for each graph ``g``, but
        vectorised and also correct for graphs made of a single node.
        """
        num_nodes = batch_idx.size(0)
        node_ids = torch.arange(num_nodes, dtype=torch.long, device=batch_idx.device)
        # Start each slot at an out-of-range sentinel and keep the smallest node id seen.
        first = torch.full((batch_size,), num_nodes, dtype=torch.long, device=batch_idx.device)
        return first.scatter_reduce(0, batch_idx, node_ids, reduce='amin')

    def forward(self, x, edge_index, batch_idx, batch_size):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
            batch_idx - Index of batch element for each node
            batch_size - Number of graphs in the batch
        Returns:
            (logits, None) - the second element is a placeholder for an auxiliary pooling loss
        """
        raw_x = x

        x = F.relu(self.conv1(x, edge_index))
        x = geom_nn.global_max_pool(x, batch_idx)

        if self.concat:
            news = raw_x[self._root_indices(batch_idx, batch_size)]
            news = F.relu(self.lin0(news))
            x = torch.cat([x, news], dim=1)
            x = F.relu(self.lin1(x))

        x = self.lin2(x)

        return x, None

### MLP Used in GIN layers

In [None]:
class MLP(nn.Module):
    """Three-layer perceptron (used as the update network of GINConv).

    Two hidden blocks of Linear -> BatchNorm1d -> ReLU -> Dropout, followed by
    a final Linear projection to ``c_out``.
    """

    def __init__(self, c_in, c_hidden, c_out, dropout=0.2):
        super().__init__()

        self.fc1 = nn.Linear(in_features=c_in, out_features=c_hidden)
        self.fc2 = nn.Linear(in_features=c_hidden, out_features=c_hidden)
        self.fc3 = nn.Linear(in_features=c_hidden, out_features=c_out)

        self.bn1 = nn.BatchNorm1d(c_hidden)
        self.bn2 = nn.BatchNorm1d(c_hidden)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden_blocks = ((self.fc1, self.bn1), (self.fc2, self.bn2))
        for linear, norm in hidden_blocks:
            x = self.dropout(torch.relu(norm(linear(x))))
        return self.fc3(x)

### HGP-SL Pooling

In [None]:
import torch
import torch.nn as nn
from torch.autograd import Function
from torch_scatter import scatter_add, scatter_max


def scatter_sort(x, batch, fill_value=-1e16):
    """Sort the values of ``x`` in descending order within each group of ``batch``.

    Args:
        x: 1-D tensor with one value per element (it is scattered into a flat
            dense buffer below, so it must be 1-D with ``len(x) == len(batch)``).
        batch: group index of every element of ``x``. The position arithmetic
            below assumes ``batch`` is sorted (each group's elements contiguous);
            otherwise positions in the dense buffer can collide.
        fill_value: padding value for the dense per-group buffer. Entries equal
            to it are dropped from the result, so real values must never equal it.

    Returns:
        (sorted_x, cumsum_sorted_x): each group's values sorted descending,
        concatenated group by group in increasing group-index order, and the
        running cumulative sum of those values restarted at every group.
    """
    # Size of every group, and the largest group (width of the dense buffer).
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()

    # Offset of each group's first element (exclusive prefix sum of group sizes).
    cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)

    # Flat position of each element in a (batch_size, max_num_nodes) row-major buffer:
    # row = its group, column = its position inside the group.
    index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

    dense_x = x.new_full((batch_size * max_num_nodes,), fill_value)
    dense_x[index] = x
    dense_x = dense_x.view(batch_size, max_num_nodes)

    # Per-row descending sort; padding goes to the end of each row as long as every
    # real value exceeds fill_value, so the real prefix sums are not affected by it.
    sorted_x, _ = dense_x.sort(dim=-1, descending=True)
    cumsum_sorted_x = sorted_x.cumsum(dim=-1)
    cumsum_sorted_x = cumsum_sorted_x.view(-1)

    # Drop the padding entries again, keeping the group-by-group layout.
    sorted_x = sorted_x.view(-1)
    filled_index = sorted_x != fill_value

    sorted_x = sorted_x[filled_index]
    cumsum_sorted_x = cumsum_sorted_x[filled_index]

    return sorted_x, cumsum_sorted_x


def _make_ix_like(batch):
    num_nodes = scatter_add(batch.new_ones(batch.size(0)), batch, dim=0)
    idx = [torch.arange(1, i + 1, dtype=torch.long, device=batch.device) for i in num_nodes]
    idx = torch.cat(idx, dim=0)

    return idx


def _threshold_and_support(x, batch):
    """Sparsemax building block: compute the per-group threshold and support size.

    Args:
        x: input tensor to apply the sparsemax (1-D, one value per element)
        batch: group indicators (same length as ``x``)
    Returns:
        (tau, support_size): the per-group threshold ``tau`` and the number of
        elements in each group's support, both indexed by group id
    """
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    # Offset of each group's first element in the group-ordered layout of scatter_sort.
    cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)

    sorted_input, input_cumsum = scatter_sort(x, batch)
    input_cumsum = input_cumsum - 1.0
    # 1-based rank k of each sorted element within its group.
    rhos = _make_ix_like(batch).to(x.dtype)
    # Sparsemax support condition: k * z_(k) > (sum_{j<=k} z_(j)) - 1.
    support = rhos * sorted_input > input_cumsum

    # NOTE(review): ``support`` is in scatter_sort's group-ordered layout while ``batch`` is
    # in the original element order; they only line up when ``batch`` is sorted — confirm
    # for every caller.
    support_size = scatter_add(support.to(batch.dtype), batch)
    # Mask invalid indices: if batch does not start from 0 or is not contiguous, an empty
    # group gives support_size == 0 and the index below could be negative.
    idx = support_size + cum_num_nodes - 1
    mask = idx < 0
    idx[mask] = 0
    # Position idx holds (cumsum over the group's support) - 1; dividing by the support
    # size gives the threshold tau.
    tau = input_cumsum.gather(0, idx)
    tau /= support_size.to(x.dtype)

    return tau, support_size


class SparsemaxFunction(Function):
    """Autograd function applying sparsemax independently within each group of ``batch``."""

    @staticmethod
    def forward(ctx, x, batch):
        """sparsemax: normalizing sparse transform
        Parameters:
            ctx: context object
            x (Tensor): shape (N, )
            batch: group indicator
        Returns:
            output (Tensor): same shape as input
        """
        max_val, _ = scatter_max(x, batch)
        # Shift every group so its maximum is 0. Done out of place: an autograd Function
        # must not mutate its inputs without ctx.mark_dirty, and the caller still owns x.
        x = x - max_val[batch]
        tau, supp_size = _threshold_and_support(x, batch)
        output = torch.clamp(x - tau[batch], min=0)
        ctx.save_for_backward(supp_size, output, batch)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Sparsemax vector-Jacobian product.

        Inside the support (``output != 0``) the per-group mean of the incoming
        gradient over the support is subtracted; outside it the gradient is 0.
        ``batch`` receives no gradient.
        """
        supp_size, output, batch = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0

        v_hat = scatter_add(grad_input, batch) / supp_size.to(output.dtype)
        grad_input = torch.where(output != 0, grad_input - v_hat[batch], grad_input)

        return grad_input, None


sparsemax = SparsemaxFunction.apply


class Sparsemax(nn.Module):
    """Module wrapper around ``SparsemaxFunction``: ``Sparsemax()(x, batch)``."""

    def __init__(self):
        super(Sparsemax, self).__init__()

    def forward(self, x, batch):
        return sparsemax(x, batch)

In [None]:
from torch_scatter import scatter_add
from torch_sparse import spspmm, coalesce
from torch_geometric.utils import softmax, dense_to_sparse, add_remaining_self_loops
from torch.nn import Parameter
from torch_geometric.data import Data
from torch_geometric.nn.pool.connect.filter_edges import filter_adj
from torch_geometric.nn.pool.select.topk import topk

class TwoHopNeighborhood(object):
    """Transform that adds the two-hop edges of a PyG ``Data`` graph.

    The two-hop pairs come from a sparse product of the edge index with
    itself. Existing edges keep their attribute (duplicates are merged with
    ``op='min'``), newly added edges get attribute 0. The ``data`` object is
    modified in place and returned.
    """

    def __call__(self, data):
        num_nodes = data.num_nodes
        one_hop = data.edge_index
        attrs = data.edge_attr

        sentinel = 1e16
        fill_values = one_hop.new_full((one_hop.size(1),), sentinel, dtype=torch.float)

        two_hop, two_hop_val = spspmm(one_hop, fill_values, one_hop, fill_values,
                                      num_nodes, num_nodes, num_nodes, True)
        combined = torch.cat([one_hop, two_hop], dim=1)

        if attrs is not None:
            # Broadcast the scalar two-hop values to the trailing shape of edge_attr.
            broadcast_shape = [1] * (attrs.dim() - 1)
            trailing_shape = list(attrs.size())[1:]
            two_hop_val = two_hop_val.view(-1, *broadcast_shape).expand(-1, *trailing_shape)
            combined_attr = torch.cat([attrs, two_hop_val], dim=0)
            data.edge_index, combined_attr = coalesce(combined, combined_attr, num_nodes, num_nodes, op='min')
            # Entries still at (or above) the sentinel are edges that only exist at two hops.
            combined_attr[combined_attr >= sentinel] = 0
            data.edge_attr = combined_attr
        else:
            data.edge_index, _ = coalesce(combined, None, num_nodes, num_nodes)

        return data

    def __repr__(self):
        return f'{self.__class__.__name__}()'


class NodeInformationScore(geom_nn.conv.MessagePassing):
    """Message passing used by HGP-SL to measure how much information each node carries.

    Edge weights are normalized as ``d_row^-1/2 * w * d_col^-1/2`` (degree taken
    over ``row``). After ``add_remaining_self_loops`` the appended self-loops get
    coefficient ``1 - normalized weight``, every other edge gets ``-normalized weight``.
    For undirected graphs this presumably amounts to ``(I - D^-1/2 A D^-1/2) x`` —
    confirm against the edge-direction convention of the installed PyG version.
    HGPSLPool sums the absolute values of the output per node into a score.
    """

    def __init__(self, improved=False, cached=False, **kwargs):
        super(NodeInformationScore, self).__init__(aggr='add', **kwargs)

        self.improved = improved  # stored but not used by norm()/forward()
        self.cached = cached  # if True, reuse the first normalized edge set on later calls
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, dtype=None):
        """Return ``(edge_index_with_self_loops, per-edge coefficients)``; see class docstring."""
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # Isolated nodes have degree 0 -> inf after pow(-0.5); zero them out.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, 0, num_nodes)

        row, col = edge_index
        # Identity term: assumes add_remaining_self_loops appends exactly one self-loop per
        # node at the END of edge_index — NOTE(review): confirm for the installed PyG version.
        expand_deg = torch.zeros((edge_weight.size(0),), dtype=dtype, device=edge_index.device)
        expand_deg[-num_nodes:] = torch.ones((num_nodes,), dtype=dtype, device=edge_index.device)

        return edge_index, expand_deg - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight):
        # With caching enabled, the graph must have the same number of edges as when cached.
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1)))

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # Scale each neighbour's features by its per-edge coefficient.
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out


class HGPSLPool(torch.nn.Module):
    """HGP-SL pooling: top-k node selection by information score plus optional structure learning.

    Args:
        in_channels: Node feature dimension.
        ratio: Fraction of nodes kept per graph (passed to ``topk``).
        sample: If True, learn edge weights only between kept nodes that are
            connected after applying ``TwoHopNeighborhood`` twice to the input
            graph (fast mode). Otherwise learn them between every pair of kept
            nodes of the same graph (dense, slower).
        sparse: If True normalize learned weights with sparsemax, else softmax.
        sl: If False, skip structure learning and return the induced subgraph.
        lamb: Weight of the original edge attribute added to the learned score.
        negative_slop: Negative slope of the LeakyReLU on the attention scores.
    """

    def __init__(self, in_channels, ratio=0.8, sample=False, sparse=False, sl=True, lamb=1.0, negative_slop=0.2):
        super(HGPSLPool, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.sample = sample
        self.sparse = sparse
        self.sl = sl
        self.negative_slop = negative_slop
        self.lamb = lamb

        self.att = Parameter(torch.Tensor(1, self.in_channels * 2))
        nn.init.xavier_uniform_(self.att.data)
        self.sparse_attention = Sparsemax()
        self.neighbor_augment = TwoHopNeighborhood()
        self.calc_information_score = NodeInformationScore()

    def forward(self, x, edge_index, edge_attr, batch=None):
        """Pool the batch of graphs.

        Returns:
            (x, edge_index, edge_attr, batch) of the pooled graphs.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x_information_score = self.calc_information_score(x, edge_index, edge_attr)
        score = torch.sum(torch.abs(x_information_score), dim=1)

        # Graph Pooling
        original_x = x
        perm = topk(score, self.ratio, batch)
        x = x[perm]
        batch = batch[perm]
        induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=score.size(0))

        # Discard structure learning layer, directly return
        if not self.sl:
            return x, induced_edge_index, induced_edge_attr, batch

        if self.sample:
            new_edge_index, new_edge_attr = self._learn_sampled_structure(
                x, original_x, edge_index, edge_attr, perm, score.size(0))
        else:
            new_edge_index, new_edge_attr = self._learn_dense_structure(
                x, batch, induced_edge_index, induced_edge_attr)

        # The helpers' dense N x N adjacency goes out of scope on return; release cached GPU memory.
        torch.cuda.empty_cache()

        return x, new_edge_index, new_edge_attr, batch

    def _attention_scores(self, x, row, col):
        """LeakyReLU attention score of every (row, col) pair of pooled nodes."""
        weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
        return F.leaky_relu(weights, self.negative_slop)

    def _normalize_and_filter(self, adj, row, col, weights, num_nodes):
        """Normalize ``weights`` per source node, write them into ``adj`` and return its nonzero edges."""
        if self.sparse:
            new_edge_attr = self.sparse_attention(weights, row)
        else:
            # The third positional parameter of PyG's softmax is ``ptr``; pass num_nodes by keyword.
            new_edge_attr = softmax(weights, row, num_nodes=num_nodes)
        # filter out zero weight edges
        adj[row, col] = new_edge_attr
        return dense_to_sparse(adj)

    def _learn_sampled_structure(self, x, original_x, edge_index, edge_attr, perm, num_nodes):
        """Fast mode: learn edge weights only between kept nodes that are close in the input graph.

        In large graphs, learning the possible edge weights between each pair of nodes is time
        consuming, so the neighbourhood is first enlarged with TwoHopNeighborhood.
        """
        k_hop = 3
        if edge_attr is None:
            edge_attr = torch.ones((edge_index.size(1),), dtype=torch.float, device=edge_index.device)

        hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr)
        for _ in range(k_hop - 1):
            hop_data = self.neighbor_augment(hop_data)
        new_edge_index, new_edge_attr = filter_adj(hop_data.edge_index, hop_data.edge_attr, perm,
                                                   num_nodes=num_nodes)

        new_edge_index, new_edge_attr = add_remaining_self_loops(new_edge_index, new_edge_attr, 0, x.size(0))
        row, col = new_edge_index
        weights = self._attention_scores(x, row, col) + new_edge_attr * self.lamb
        adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
        adj[row, col] = weights
        # Re-read from the dense matrix so only nonzero entries remain.
        new_edge_index, weights = dense_to_sparse(adj)
        row, col = new_edge_index
        return self._normalize_and_filter(adj, row, col, weights, x.size(0))

    def _learn_dense_structure(self, x, batch, induced_edge_index, induced_edge_attr):
        """Learn edge weights between every pair of kept nodes of the same graph (slower)."""
        if induced_edge_attr is None:
            induced_edge_attr = torch.ones((induced_edge_index.size(1),), dtype=x.dtype,
                                           device=induced_edge_index.device)
        # Block-diagonal "fully connected within each graph" pattern in one vectorized comparison.
        adj = (batch.unsqueeze(0) == batch.unsqueeze(1)).to(device=x.device, dtype=torch.float)
        new_edge_index, _ = dense_to_sparse(adj)
        row, col = new_edge_index

        adj[row, col] = self._attention_scores(x, row, col)
        induced_row, induced_col = induced_edge_index
        # Add the (weighted) original edges of the induced subgraph on top of the attention scores.
        adj[induced_row, induced_col] += induced_edge_attr * self.lamb
        weights = adj[row, col]
        return self._normalize_and_filter(adj, row, col, weights, x.size(0))

### GCN Layer for HGP-SL

In [None]:
class GCN(geom_nn.conv.MessagePassing):
    """Graph convolution used inside HGP-SL: ``D^-1/2 A D^-1/2 (X W) + b``.

    Unlike ``GCNConv``, no self-loops are added; node degrees are summed over
    the ``row`` side of ``edge_index`` (all-ones weights when none are given).
    """

    def __init__(self, in_channels, out_channels, cached=False, bias=True, **kwargs):
        super(GCN, self).__init__(aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cached = cached
        self.cached_result = None
        self.cached_num_edges = None

        self.weight = Parameter(torch.empty(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialize weight (Xavier uniform) and bias (zeros) and drop the edge cache."""
        nn.init.xavier_uniform_(self.weight.data)
        if self.bias is not None:
            nn.init.zeros_(self.bias.data)
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, dtype=None):
        """Return ``edge_index`` and the symmetric-normalized edge weights."""
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # Isolated nodes: 0 ** -0.5 == inf, so zero their coefficient instead.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        x = torch.matmul(x, self.weight)

        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1)))

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels)

### HGP-SL Model

In [None]:
class HGPSLModel(nn.Module):
    """Graph classifier with three graph convolutions and two HGP-SL pooling layers.

    After each convolution a readout concatenates global max and mean pooling.
    The three readouts are summed after ReLU and fed to an MLP head, optionally
    fused with the features of each graph's root node (its first node in the batch).
    """

    def __init__(self, c_in, c_hidden, c_out, conv_layer, concat=False, pooling_ratio=0.8, sample=True, sparse=True, sl=True, lamb=1, dropout_ratio=0.0, **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            conv_layer - First graph convolution: 'gcn', 'sage', 'gat' or 'gin'
            concat - If True, concatenate root-node features into the MLP head
            pooling_ratio - Fraction of nodes kept by each HGPSLPool layer
            sample, sparse, sl, lamb - Forwarded to both HGPSLPool layers
            dropout_ratio - Dropout rate in the MLP head (and inside the GIN MLP)
            kwargs - Additional keyword arguments; ignored
        Raises:
            ValueError - if conv_layer is not one of the supported names
        """
        super().__init__()

        self.dropout_ratio = dropout_ratio
        self.concat = concat

        if conv_layer == 'gcn':
            self.conv1 = geom_nn.GCNConv(c_in, c_hidden)
        elif conv_layer == 'sage':
            self.conv1 = geom_nn.SAGEConv(c_in, c_hidden)
        elif conv_layer == 'gat':
            self.conv1 = geom_nn.GATConv(c_in, c_hidden)
        elif conv_layer == 'gin':
            self.conv1 = geom_nn.GINConv(MLP(c_in, c_hidden, c_hidden, dropout_ratio))
        else:
            raise ValueError(f"Unsupported conv_layer {conv_layer!r}; expected 'gcn', 'sage', 'gat' or 'gin'")

        self.conv2 = GCN(c_hidden, c_hidden)
        self.conv3 = GCN(c_hidden, c_hidden)

        self.pool1 = HGPSLPool(c_hidden, pooling_ratio, sample, sparse, sl, lamb)
        self.pool2 = HGPSLPool(c_hidden, pooling_ratio, sample, sparse, sl, lamb)

        # Concatenation Layers
        if self.concat:
            self.lin0 = torch.nn.Linear(c_in, c_hidden)
            self.lin1 = torch.nn.Linear(c_hidden * 3, 2 * c_hidden)

        self.lin2 = torch.nn.Linear(2 * c_hidden, c_hidden)
        self.lin3 = torch.nn.Linear(c_hidden, c_hidden // 2)
        self.lin4 = torch.nn.Linear(c_hidden // 2, c_out)

    @staticmethod
    def _root_indices(batch_idx, batch_size):
        """Return the index of the first node of every graph ``0 .. batch_size - 1``.

        Equivalent to ``(batch_idx == g).nonzero()[0]`` for each graph ``g``, but
        vectorised and also correct for graphs made of a single node.
        """
        num_nodes = batch_idx.size(0)
        node_ids = torch.arange(num_nodes, dtype=torch.long, device=batch_idx.device)
        # Start each slot at an out-of-range sentinel and keep the smallest node id seen.
        first = torch.full((batch_size,), num_nodes, dtype=torch.long, device=batch_idx.device)
        return first.scatter_reduce(0, batch_idx, node_ids, reduce='amin')

    @staticmethod
    def _readout(x, batch_idx):
        """Concatenate global max and mean pooling per graph: shape (num_graphs, 2 * features)."""
        return torch.cat([geom_nn.global_max_pool(x, batch_idx), geom_nn.global_mean_pool(x, batch_idx)], dim=1)

    def forward(self, x, edge_index, batch_idx, batch_size):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
            batch_idx - Index of batch element for each node
            batch_size - Number of graphs in the batch
        Returns:
            (logits, None) - the second element is a placeholder for an auxiliary pooling loss
        """
        raw_x = x
        batch_idx_raw = batch_idx
        edge_attr = None

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, edge_attr, batch_idx = self.pool1(x, edge_index, edge_attr, batch_idx)
        x1 = self._readout(x, batch_idx)

        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x, edge_index, edge_attr, batch_idx = self.pool2(x, edge_index, edge_attr, batch_idx)
        x2 = self._readout(x, batch_idx)

        x = F.relu(self.conv3(x, edge_index, edge_attr))
        x3 = self._readout(x, batch_idx)

        x = F.relu(x1) + F.relu(x2) + F.relu(x3)

        if self.concat:
            # Root nodes are looked up in the un-pooled batch.
            news = raw_x[self._root_indices(batch_idx_raw, batch_size)]
            news = F.relu(self.lin0(news))
            x = torch.cat([x, news], dim=1)
            x = F.relu(self.lin1(x))

        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        x = F.relu(self.lin3(x))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        x = self.lin4(x)

        return x, None

### PL Training Module

In [None]:
class GraphLevelGNN(pl.LightningModule):
    """Lightning module training a graph-level binary classifier with BCE-with-logits loss.

    ``model_name`` selects the backbone:
    - names containing "mewis" use ``MEWISModel``. NOTE(review): that class is not defined in this notebook.
    - names containing "hgpsl" use ``HGPSLModel``.
    - all other names use ``UPFDModel``.

    The remaining keyword arguments are forwarded to the backbone.
    """

    def __init__(self, model_name, learning_rate, **model_kwargs):
        super().__init__()
        # Record all constructor arguments so load_from_checkpoint can rebuild the module.
        self.save_hyperparameters()
        self.lr = learning_rate
        self.model_name = model_name

        lowered_name = self.model_name.lower()
        if "mewis" in lowered_name:
            backbone_cls = MEWISModel
        elif "hgpsl" in lowered_name:
            backbone_cls = HGPSLModel
        else:
            backbone_cls = UPFDModel
        self.model = backbone_cls(**model_kwargs)

        self.loss_module = nn.BCEWithLogitsLoss()
        # One metric object per split so their running states stay separate.
        self.train_acc = torchmetrics.classification.BinaryAccuracy()
        self.valid_acc = torchmetrics.classification.BinaryAccuracy()
        self.test_acc = torchmetrics.classification.BinaryAccuracy()
        self.test_f1 = torchmetrics.classification.BinaryF1Score()
        self.test_prec = torchmetrics.classification.BinaryPrecision()
        self.test_rec = torchmetrics.classification.BinaryRecall()

    def forward(self, data, mode="train"):
        """Return ``(loss, preds)`` for a PyG batch; ``mode`` is accepted but unused.

        Side effect: ``data.y`` is converted to float in place.
        """
        logits, loss_pool = self.model(data.x, data.edge_index, data.batch, data.batch_size)
        logits = logits.squeeze(dim=-1)

        # A positive logit means class 1.
        preds = (logits > 0).float()
        data.y = data.y.float()

        loss = self.loss_module(logits, data.y)
        if "mewis" in self.model_name.lower():
            # The MEWIS backbone returns an auxiliary pooling loss, weighted by 0.01.
            loss = loss + 0.01 * loss_pool
        return loss, preds

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=0.001)

    def training_step(self, batch, batch_idx):
        loss, preds = self.forward(batch, mode="train")
        n_graphs = batch.batch_size
        self.log('train_loss', loss, on_step=False, on_epoch=True, batch_size=n_graphs)
        self.train_acc(preds, batch.y)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, batch_size=n_graphs)
        return loss

    def validation_step(self, batch, batch_idx):
        _, preds = self.forward(batch, mode="val")
        self.valid_acc(preds, batch.y)
        self.log('valid_acc', self.valid_acc, on_step=False, on_epoch=True, batch_size=batch.batch_size)

    def test_step(self, batch, batch_idx):
        _, preds = self.forward(batch, mode="test")
        test_metrics = (
            ('test_acc', self.test_acc),
            ('test_f1', self.test_f1),
            ('test_prec', self.test_prec),
            ('test_rec', self.test_rec),
        )
        for metric_name, metric in test_metrics:
            metric(preds, batch.y)
            self.log(metric_name, metric, on_step=False, on_epoch=True, batch_size=batch.batch_size)

### Model training

In [None]:
def train_graph_classifier(model_name, dataset, batch_size, feature, max_epochs, patience, **model_kwargs):
    """Train a GraphLevelGNN on UPFD, then reload the best checkpoint and run the test set.

    Args:
        model_name: Backbone selector for GraphLevelGNN; also names the log directory.
        dataset: UPFD dataset name (e.g. "politifact").
        batch_size: Graphs per mini-batch.
        feature: UPFD node feature type (e.g. "bert").
        max_epochs: Upper bound on training epochs.
        patience: Early-stopping patience (epochs without ``valid_acc`` improvement).
        **model_kwargs: Forwarded to GraphLevelGNN (e.g. learning_rate, c_hidden).

    Returns:
        The GraphLevelGNN restored from the checkpoint with the highest ``valid_acc``.
    """
    root_dir = os.path.join(CHECKPOINT_PATH, "GraphLevel" + model_name)
    os.makedirs(root_dir, exist_ok=True)

    checkpointing = ModelCheckpoint(save_weights_only=True, mode="max", monitor="valid_acc")
    early_stopping = EarlyStopping(monitor="valid_acc", min_delta=0.00, patience=patience, verbose=True, mode="max")
    use_gpu = str(device).startswith("cuda")
    trainer = pl.Trainer(
        default_root_dir=root_dir,
        callbacks=[checkpointing, early_stopping],
        accelerator="gpu" if use_gpu else "cpu",
        devices=1,
        max_epochs=max_epochs,
        enable_progress_bar=False,
    )
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    dm = UPFDDataModule(batch_size=batch_size, dataset=dataset, feature=feature)
    # The training split must be loaded up front to read the feature/class counts.
    dm.setup(stage="fit")

    num_classes = dm.num_classes
    model = GraphLevelGNN(
        model_name=model_name,
        c_in=dm.num_node_features,
        c_out=1 if num_classes == 2 else num_classes,
        **model_kwargs,
    )

    trainer.fit(model, datamodule=dm)

    best_path = trainer.checkpoint_callback.best_model_path
    model = GraphLevelGNN.load_from_checkpoint(best_path)
    dm.setup(stage="test")
    trainer.test(model, datamodule=dm, verbose=True)
    return model

In [None]:
# Train HGP-SL (GCN first layer) on the "politifact" UPFD dataset with "bert" node features.
model = train_graph_classifier(
    # Experiment / data
    model_name="hgpsl_gcn_pol",
    dataset="politifact",
    feature="bert",
    batch_size=128,
    # Training schedule
    max_epochs=150,
    patience=30,
    learning_rate=1e-2,
    # HGPSLModel hyper-parameters
    conv_layer="gcn",
    c_hidden=128,
    dropout_ratio=0.3,
    pooling_ratio=0.3,
    concat=False,
)

In [None]:
%reload_ext tensorboard
%tensorboard --logdir=./saved_models/