# Setup

In [None]:
import os
import torch
# Build the wheel-index URL for the torch version that is currently installed.
torch_version = str(torch.__version__)
# NOTE(review): both URLs are built from the same template and are identical.
scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
# Install the torch-scatter / torch-sparse extensions using the index above
# (passed to pip via -f), then torch-geometric and texttable from the default index.
!pip install torch-scatter -f $scatter_src
!pip install torch-sparse -f $sparse_src
!pip install torch-geometric
!pip install texttable

import pandas as pd
import numpy as np
import torch
import os
import json
from torch_geometric.data import Data

# Mount Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Custom Dataset Class

This is the custom dataset class that we use to load our own scene graphs.

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """
    Custom PyTorch Geometric dataset class for loading scene graphs.

    Every row of ``df`` describes one graph: the ``graph_filename`` column
    names a JSON file inside ``output_folder_path`` and the ``X``/``Y``/``Z``
    columns hold the coordinates used for the pairwise distance matrix.
    """
    def __init__(self, df, output_folder_path):
        """
        Initialize the dataset
        :param df: Dataframe containing graph dataset information (filepaths and coordinates).
        :param output_folder_path: Path to the folder containing scene graph JSON files.
        """
        super(CustomDataset, self).__init__()
        self.output_folder_path = output_folder_path
        self.df = df
        # Calculate the normalized distance matrix
        self.ndist_matrix = self.calculate_ndist_matrix()

    def __len__(self):
        """
        Get the length of the dataset.
        """
        return len(self.df)

    def __getitem__(self, idx):
        """
        Get an item from the dataset.
        :param idx: Index of the item. Tensors, slices and lists of indices
            are accepted as well; slices and lists return a list of Data objects.
        :return data: PyTorch Geometric Data object containing graph information.
        """
        if torch.is_tensor(idx):  # Handle tensor indices
            idx = idx.tolist()

        if isinstance(idx, slice):  # Handle slice indexing
            idx = list(range(*idx.indices(len(self))))

        if isinstance(idx, list):  # Handle list of indices
            return [self[i] for i in idx]

        # Fetch the dataframe row once; it is reused for the filename and
        # for the per-column attributes below.
        row = self.df.iloc[idx]

        # Load graph data from JSON file
        graph_filename = os.path.join(self.output_folder_path,
                                      row['graph_filename'])
        with open(graph_filename, 'r') as f:
            graph_data = json.load(f)

        # Extract node features and create node features tensor
        nodes = graph_data['nodes']
        x = torch.tensor([node['embedding'] for node in nodes],
                         dtype=torch.float32)

        # Create id to index mapping
        node_id_map = {node['id']: i for i, node in enumerate(nodes)}

        # Extract edge information; edges that reference an unknown node id
        # are skipped.
        edge_index = []
        edge_attr = []
        for edge in graph_data['edges']:
            source_index = node_id_map.get(edge['source'])
            target_index = node_id_map.get(edge['target'])
            if source_index is not None and target_index is not None:
                edge_index.append([source_index, target_index])
                edge_attr.append(edge['embedding'])

        # reshape(-1, 2) keeps the (2, num_edges) layout even when the graph
        # has no edges, giving shape (2, 0) instead of (0,).
        edge_index = (torch.tensor(edge_index, dtype=torch.long)
                      .reshape(-1, 2).t().contiguous())
        edge_attr = torch.tensor(edge_attr, dtype=torch.float32)

        # Create PyTorch Geometric Data object
        # we add the idx as an attribute to help with indexing for the ndist_matrix
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, i=idx)

        # Add other columns as attributes
        for col in self.df.columns:
            setattr(data, col, row[col])

        return data

    def calculate_ndist_matrix(self):
        """
        Calculate the normalized distance matrix based on L2 distance between XYZ coordinates.
        :return: Float32 tensor of pairwise distances divided by the largest
            distance. When the largest distance is zero (single row or
            identical coordinates) or the dataframe is empty, the matrix is
            returned unscaled so that no NaN values are produced.
        """
        xyz = self.df[['X', 'Y', 'Z']].values
        # Calculate L2 distances
        distance_matrix = np.linalg.norm(xyz[:, None, :]
                                         - xyz[None, :, :], axis=-1)
        max_dist = np.max(distance_matrix) if distance_matrix.size else 0.0
        if max_dist > 0:
            distance_matrix = distance_matrix / max_dist
        return torch.tensor(distance_matrix, dtype=torch.float32)

# SimGNN

## Extended SimGNN Modules
These modules are directly borrowed from https://github.com/gospodima/Extended-SimGNN/tree/master, and cover the Attention layer, the Neural Tensor Network and the Diffpool modules.


In [None]:
import torch
import torch.nn.functional as F

from math import ceil
from torch.nn import Linear, ReLU
from torch_geometric.nn import (
    DenseSAGEConv,
    DenseGCNConv,
    DenseGINConv,
    dense_diff_pool,
    JumpingKnowledge,
)
from torch_scatter import scatter_mean, scatter_add


class AttentionModule(torch.nn.Module):
    """
    SimGNN attention readout: pools node embeddings into one vector per graph.
    """

    def __init__(self, args):
        """
        :param args: Arguments object; ``args.filters_3`` gives the embedding size.
        """
        super().__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Allocate the square attention weight matrix.
        """
        dim = self.args.filters_3
        self.weight_matrix = torch.nn.Parameter(torch.Tensor(dim, dim))

    def init_parameters(self):
        """
        Fill the weight matrix with Xavier-uniform values.
        """
        torch.nn.init.xavier_uniform_(self.weight_matrix)

    def forward(self, x, batch, size=None):
        """
        Build a graph level representation from node embeddings.
        :param x: Result of the GNN.
        :param batch: Batch vector, which assigns each node to a specific example
        :param size: Number of graphs; when omitted it is taken as the last
            entry of ``batch`` plus one.
        :return representation: A graph level representation matrix.
        """
        if size is None:
            size = batch[-1].item() + 1

        # Per-graph mean embedding, passed through the learned transform.
        graph_mean = scatter_mean(x, batch, dim=0, dim_size=size)
        context = torch.tanh(torch.mm(graph_mean, self.weight_matrix))

        # Each node is scored against the context vector of its own graph.
        node_context = context[batch]
        coefs = torch.sigmoid((x * node_context).sum(dim=1))
        scaled_nodes = coefs.unsqueeze(-1) * x

        return scatter_add(scaled_nodes, batch, dim=0, dim_size=size)

    def get_coefs(self, x):
        """
        Return the attention coefficient of every node, treating ``x`` as
        the node embeddings of a single graph.
        :param x: Node embedding matrix.
        :return: Vector with one sigmoid coefficient per node.
        """
        context = torch.tanh(torch.matmul(x.mean(dim=0), self.weight_matrix))
        return torch.sigmoid(torch.matmul(x, context))


class DenseAttentionModule(torch.nn.Module):
    """
    SimGNN Dense Attention Module to make a pass on graph.

    Works on dense batches of shape (B, N, F) with an optional (B, N) mask
    that marks which of the N node slots are real nodes.
    """

    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        super(DenseAttentionModule, self).__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Defining weights.
        """
        self.weight_matrix = torch.nn.Parameter(
            torch.Tensor(self.args.filters_3, self.args.filters_3)
        )

    def init_parameters(self):
        """
        Initializing weights.
        """
        torch.nn.init.xavier_uniform_(self.weight_matrix)

    def forward(self, x, mask=None):
        """
        Making a forward propagation pass to create a graph level representation.
        :param x: Result of the GNN.
        :param mask: Mask matrix indicating the valid nodes for each graph.
        :return representation: A graph level representation matrix.
        """
        B, N, _ = x.size()

        node_mask = None
        if mask is not None:
            node_mask = mask.view(B, N, 1).to(x.dtype)
            # Zero out padded slots before averaging so that non-zero padding
            # (e.g. a bias added by an earlier layer) cannot skew the mean.
            # clamp avoids a division by zero for graphs without valid nodes.
            num_nodes = node_mask.sum(dim=1).clamp(min=1)
            mean = (x * node_mask).sum(dim=1) / num_nodes
        else:
            mean = x.mean(dim=1)

        transformed_global = torch.tanh(torch.mm(mean, self.weight_matrix))

        # One coefficient per node slot: sigmoid of the dot product between
        # the node embedding and its graph's transformed mean.
        koefs = torch.sigmoid(torch.matmul(x, transformed_global.unsqueeze(-1)))
        weighted = koefs * x

        if node_mask is not None:
            # Padded slots must not contribute to the pooled representation.
            weighted = weighted * node_mask

        return weighted.sum(dim=1)


class TensorNetworkModule(torch.nn.Module):
    """
    SimGNN neural tensor network: turns a pair of graph embeddings into a
    vector of ``tensor_neurons`` similarity scores.
    """

    def __init__(self, args):
        """
        :param args: Arguments object providing ``filters_3`` and ``tensor_neurons``.
        """
        super().__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Allocate the bilinear tensor, the linear block and the bias.
        """
        dim = self.args.filters_3
        neurons = self.args.tensor_neurons
        self.weight_matrix = torch.nn.Parameter(torch.Tensor(dim, dim, neurons))
        self.weight_matrix_block = torch.nn.Parameter(torch.Tensor(neurons, 2 * dim))
        self.bias = torch.nn.Parameter(torch.Tensor(neurons, 1))

    def init_parameters(self):
        """
        Xavier-uniform initialization of all three parameters.
        """
        for parameter in (self.weight_matrix, self.weight_matrix_block, self.bias):
            torch.nn.init.xavier_uniform_(parameter)

    def forward(self, embedding_1, embedding_2):
        """
        Compute the similarity vector for each pair of embeddings.
        :param embedding_1: Result of the 1st embedding after attention.
        :param embedding_2: Result of the 2nd embedding after attention.
        :return scores: A similarity score vector.
        """
        n_pairs = len(embedding_1)
        dim = self.args.filters_3

        # Bilinear term: for every neuron k, embedding_1^T W[:, :, k] embedding_2.
        flat_weights = self.weight_matrix.view(dim, -1)
        bilinear = torch.matmul(embedding_1, flat_weights)
        bilinear = bilinear.view(n_pairs, dim, -1).permute([0, 2, 1])
        right_side = embedding_2.view(n_pairs, dim, 1)
        bilinear = torch.matmul(bilinear, right_side).view(n_pairs, -1)

        # Linear term on the concatenated embeddings.
        stacked = torch.cat((embedding_1, embedding_2), 1)
        linear = torch.t(torch.mm(self.weight_matrix_block, torch.t(stacked)))

        return F.relu(bilinear + linear + self.bias.view(-1))


class Block(torch.nn.Module):
    """
    Two stacked DenseGINConv layers whose outputs are merged by
    JumpingKnowledge and projected by a final linear layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, mode="cat"):
        """
        :param in_channels: Size of the input node features.
        :param hidden_channels: Output size of the first convolution.
        :param out_channels: Output size of the second convolution and of the block.
        :param mode: JumpingKnowledge aggregation mode.
        """
        super().__init__()

        # DenseSAGEConv and DenseGCNConv were tried here as alternatives to
        # the GIN layers below.
        first_mlp = torch.nn.Sequential(
            Linear(in_channels, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels),
        )
        second_mlp = torch.nn.Sequential(
            Linear(hidden_channels, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        )

        self.conv1 = DenseGINConv(first_mlp, train_eps=True)
        self.conv2 = DenseGINConv(second_mlp, train_eps=True)

        self.jump = JumpingKnowledge(mode)
        # With "cat" both layer outputs are concatenated before the projection.
        lin_inputs = hidden_channels + out_channels if mode == "cat" else out_channels
        self.lin = Linear(lin_inputs, out_channels)

    def reset_parameters(self):
        """
        Re-initialize both convolutions and the output projection.
        """
        for layer in (self.conv1, self.conv2, self.lin):
            layer.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """
        :param x: Dense node feature tensor.
        :param adj: Dense adjacency tensor.
        :param mask: Optional mask of valid nodes, forwarded to the convolutions.
        :param add_loop: Forwarded to the convolutions.
        :return: Projected combination of both layer outputs.
        """
        first = F.relu(self.conv1(x, adj, mask, add_loop))
        second = F.relu(self.conv2(first, adj, mask, add_loop))
        merged = self.jump([first, second])
        return self.lin(merged)


class DiffPool(torch.nn.Module):
    """
    Hierarchical pooling module built from embedding and pooling Blocks,
    dense_diff_pool, and a DenseAttentionModule readout per level.
    """

    def __init__(self, args, num_nodes=10, num_layers=4, hidden=16, ratio=0.25):
        """
        :param args: Arguments object; ``args.filters_3`` is the feature size.
        :param num_nodes: Starting node count; each level keeps
            ``ceil(ratio * num_nodes)`` clusters.
        :param num_layers: ``(num_layers // 2) - 1`` extra levels are created.
        :param hidden: Hidden size of the Blocks.
        :param ratio: Pooling ratio applied at each level.
        """
        super().__init__()

        self.args = args
        feature_dim = self.args.filters_3

        self.att = DenseAttentionModule(self.args)

        num_nodes = ceil(ratio * num_nodes)
        self.embed_block1 = Block(feature_dim, hidden, hidden)
        self.pool_block1 = Block(feature_dim, hidden, num_nodes)

        self.embed_blocks = torch.nn.ModuleList()
        self.pool_blocks = torch.nn.ModuleList()
        extra_levels = (num_layers // 2) - 1
        for _ in range(extra_levels):
            num_nodes = ceil(ratio * num_nodes)
            self.embed_blocks.append(Block(hidden, hidden, hidden))
            self.pool_blocks.append(Block(hidden, hidden, num_nodes))

        self.jump = JumpingKnowledge(mode="cat")
        # One readout per level (first level + extra levels) is concatenated.
        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)
        self.lin2 = Linear(hidden, feature_dim)

    def reset_parameters(self):
        """
        Re-initialize every Block, the JumpingKnowledge layer and both linear layers.
        """
        self.embed_block1.reset_parameters()
        self.pool_block1.reset_parameters()
        for embed, pool in zip(self.embed_blocks, self.pool_blocks):
            embed.reset_parameters()
            pool.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, adj, mask):
        """
        :param x: Dense node feature tensor.
        :param adj: Dense adjacency tensor.
        :param mask: Mask of valid nodes, used only on the first level.
        :return: Graph level representation after the two linear layers.
        """
        assignment = self.pool_block1(x, adj, mask, add_loop=True)
        x = F.relu(self.embed_block1(x, adj, mask, add_loop=True))

        readouts = [self.att(x, mask)]
        x, adj, _, _ = dense_diff_pool(x, adj, assignment, mask)

        last_level = len(self.embed_blocks) - 1
        for level, (embed, pool) in enumerate(zip(self.embed_blocks, self.pool_blocks)):
            assignment = pool(x, adj)
            x = F.relu(embed(x, adj))
            readouts.append(self.att(x))
            # No pooling after the final level.
            if level < last_level:
                x, adj, _, _ = dense_diff_pool(x, adj, assignment)

        combined = self.jump(readouts)
        combined = F.relu(self.lin1(combined))
        return self.lin2(combined)

    def __repr__(self):
        """Return the bare class name."""
        return self.__class__.__name__

## Utils
These cover functions for evaluation metrics and a helper function for batch creation.

In [None]:
import math
import numpy as np
import networkx as nx
import torch
import random
from texttable import Texttable
from torch_geometric.utils import erdos_renyi_graph, to_undirected, to_networkx
from torch_geometric.data import Data
import matplotlib.pyplot as plt

def calculate_ranking_correlation(rank_corr_function, prediction, target):
    """
    Convert both vectors to ranks and correlate the ranks.
    :param rank_corr_function: Ranking correlation function.
    :param prediction: Vector of predicted values.
    :param target: Vector of ground-truth values.
    :return ranking: Ranking correlation value.
    """
    def to_ranks(values):
        # Position of every element in the ascending ordering of `values`.
        order = values.argsort()
        ranks = np.empty_like(order)
        ranks[order] = np.arange(len(values))
        return ranks

    predicted_ranks = to_ranks(prediction)
    target_ranks = to_ranks(target)
    return rank_corr_function(predicted_ranks, target_ranks).correlation

def calculate_prec_at_k(k, prediction, target):
    """
    Precision at k for distance-like scores (smaller means more relevant).
    :param k: Number of top items to consider.
    :param prediction: Vector of predicted distance values.
    :param target: Vector of ground-truth distance values.
    :return: Precision at k.
    """
    # Widen the ground-truth set so that every item tied with the k-th
    # smallest target distance counts as relevant.
    ascending_target = np.sort(target)
    kth_distance = ascending_target[k - 1]
    tied_count = (ascending_target <= kth_distance).sum()
    target_k = max(k, tied_count)

    # Indices with the smallest distances on each side.
    predicted_top = set(prediction.argsort()[:k])
    relevant = set(target.argsort()[:target_k])

    hits = predicted_top & relevant
    return len(hits) / k

def random_sample_from_closest_vectorized(given_numbers, number_list, neighbor_sample_size=5, close_fraction=0.5):
    """
    Pick one partner from ``number_list`` for every entry of ``given_numbers``.

    The first ``int(close_fraction * len(given_numbers))`` entries get a
    partner drawn at random from their ``neighbor_sample_size`` closest values
    (by absolute difference); the remaining entries get a partner drawn at
    random from the whole ``number_list``.
    :param given_numbers: Sequence of reference numbers.
    :param number_list: Sequence of candidate numbers to sample from.
    :param neighbor_sample_size: Size of the closest-candidate pool.
    :param close_fraction: Fraction of entries that use the closest-candidate pool.
    :return: List with one sampled number per entry of ``given_numbers``.
    """
    total = len(given_numbers)
    anchors = np.array(given_numbers)[:int(close_fraction*total)]
    candidates = np.array(number_list)

    # Pairwise absolute gaps, one row per anchor.
    gaps = np.abs(anchors[:, None] - candidates[None, :])

    # Candidate pool of the nearest values for each anchor.
    nearest_positions = np.argsort(gaps, axis=1)[:, :neighbor_sample_size]
    nearest_values = candidates[nearest_positions]

    # Biased picks first, then unbiased picks for the rest of the batch.
    picks = [np.random.choice(pool) for pool in nearest_values]
    remaining = total - len(nearest_values)
    picks.extend(np.random.choice(candidates) for _ in range(remaining))

    return picks

## Custom EdgeGCNConv Layer
An EdgeGCN is a modified (but simpler) version of a GCN that includes the edge embeddings as part of its message passing. To do this, the message function simply adds the node and edge embeddings, after each has been linearly transformed to a common dimension. Mathematically, the message sent from node $u$ to node $v$ is $$m_{u \to v}^{(l)} = \mathbf{W}_n h_u + \mathbf{W}_e e_{uv}$$ where $\mathbf{W}_n$ and $\mathbf{W}_e$ are learnable weight matrices. To further increase expressiveness, we also change the aggregation function to addition instead of the default mean. This could help differentiate between graphs with a higher number of nodes, potentially corresponding to richer scenes, and graphs with fewer nodes. These changes do not affect the update function itself, which therefore stays the same.



In [None]:
import torch
from torch_geometric.nn import MessagePassing

class EdgeGCNConv(MessagePassing):
    """
    GCN-style layer whose messages combine neighbour node features with
    the features of the connecting edge; messages are aggregated by summation.
    """
    def __init__(self, in_channels, out_channels, edge_dim):
        """
        :param in_channels: Size of the input node features.
        :param out_channels: Size of the produced node features.
        :param edge_dim: Size of the edge features.
        """
        super().__init__(aggr='add')
        # Separate linear maps bring node and edge features to the same size.
        self.node_mlp = torch.nn.Linear(in_channels, out_channels)
        self.edge_mlp = torch.nn.Linear(edge_dim, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """
        Run message passing over the graph.
        :param x: Node feature matrix.
        :param edge_index: Edge indices.
        :param edge_attr: Edge feature matrix.
        :return: Result of ``propagate``.
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """
        Message for one edge: transformed neighbour features plus
        transformed edge features.
        """
        node_term = self.node_mlp(x_j)
        edge_term = self.edge_mlp(edge_attr)
        return node_term + edge_term

    def update(self, aggr_out):
        """
        Return the aggregated messages unchanged.
        """
        return aggr_out

## Main classes
Primary SimGNN model class that represents the full model architecture, and the SimGNN trainer class that trains and scores the model on the input dataset. Some functions have been customized for our use. Original code source - https://github.com/gospodima/Extended-SimGNN/tree/master.

In [None]:
from types import new_class
import torch
import random
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm, trange
from scipy.stats import spearmanr, kendalltau

from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.data import DataLoader, Batch
from torch_geometric.utils import to_dense_batch, to_dense_adj, degree
from torch_geometric.datasets import GEDDataset
from torch_geometric.transforms import OneHotDegree

import matplotlib.pyplot as plt


class SimGNN(torch.nn.Module):
    """
    SimGNN: A Neural Network Approach to Fast Graph Similarity Computation
    https://arxiv.org/abs/1808.05689
    """

    def __init__(self, args, number_of_labels):
        """
        :param args: Arguments object.
        :param number_of_labels: Number of node labels.
        """
        super(SimGNN, self).__init__()
        self.args = args
        self.number_labels = number_of_labels
        self.setup_layers()

    def calculate_bottleneck_features(self):
        """
        Deciding the shape of the bottleneck layer.
        """
        if self.args.histogram:
            self.feature_count = self.args.tensor_neurons + self.args.bins
        else:
            self.feature_count = self.args.tensor_neurons

    def setup_layers(self):
        """
        Creating the layers.
        :raises NotImplementedError: If ``args.gnn_operator`` is not one of
            "gcn", "egcn" or "gin".
        """
        self.calculate_bottleneck_features()
        if self.args.gnn_operator == "gcn":
            self.convolution_1 = GCNConv(self.number_labels, self.args.filters_1)
            self.convolution_2 = GCNConv(self.args.filters_1, self.args.filters_2)
            self.convolution_3 = GCNConv(self.args.filters_2, self.args.filters_3)
        elif self.args.gnn_operator == "egcn":
            self.convolution_1 = EdgeGCNConv(self.number_labels, self.args.filters_1, self.args.edge_dim)
            self.convolution_2 = EdgeGCNConv(self.args.filters_1, self.args.filters_2, self.args.edge_dim)
            self.convolution_3 = EdgeGCNConv(self.args.filters_2, self.args.filters_3, self.args.edge_dim)
        elif self.args.gnn_operator == "gin":
            nn1 = torch.nn.Sequential(
                torch.nn.Linear(self.number_labels, self.args.filters_1),
                torch.nn.ReLU(),
                torch.nn.Linear(self.args.filters_1, self.args.filters_1),
                torch.nn.BatchNorm1d(self.args.filters_1),
            )

            nn2 = torch.nn.Sequential(
                torch.nn.Linear(self.args.filters_1, self.args.filters_2),
                torch.nn.ReLU(),
                torch.nn.Linear(self.args.filters_2, self.args.filters_2),
                torch.nn.BatchNorm1d(self.args.filters_2),
            )

            nn3 = torch.nn.Sequential(
                torch.nn.Linear(self.args.filters_2, self.args.filters_3),
                torch.nn.ReLU(),
                torch.nn.Linear(self.args.filters_3, self.args.filters_3),
                torch.nn.BatchNorm1d(self.args.filters_3),
            )

            self.convolution_1 = GINConv(nn1, train_eps=True)
            self.convolution_2 = GINConv(nn2, train_eps=True)
            self.convolution_3 = GINConv(nn3, train_eps=True)
        else:
            raise NotImplementedError("Unknown GNN-Operator.")

        if self.args.diffpool:
            self.attention = DiffPool(self.args)
        else:
            self.attention = AttentionModule(self.args)

        self.tensor_network = TensorNetworkModule(self.args)
        self.fully_connected_first = torch.nn.Linear(
            self.feature_count, self.args.bottle_neck_neurons
        )
        self.scoring_layer = torch.nn.Linear(self.args.bottle_neck_neurons, 1)

    def calculate_histogram(
        self, abstract_features_1, abstract_features_2, batch_1, batch_2
    ):
        """
        Calculate histogram from similarity matrix.
        :param abstract_features_1: Feature matrix for the first graphs of the pairs.
        :param abstract_features_2: Feature matrix for the second graphs of the pairs.
        :param batch_1: Batch vector for the first graphs, which assigns each node to a specific example
        :param batch_2: Batch vector for the second graphs, which assigns each node to a specific example
        :return hist: Histogram of similarity scores, one row of ``args.bins`` values per pair.
        """
        abstract_features_1, mask_1 = to_dense_batch(abstract_features_1, batch_1)
        abstract_features_2, mask_2 = to_dense_batch(abstract_features_2, batch_2)

        B1, N1, _ = abstract_features_1.size()
        B2, N2, _ = abstract_features_2.size()

        mask_1 = mask_1.view(B1, N1)
        mask_2 = mask_2.view(B2, N2)
        # Per pair, the larger of the two node counts bounds the score matrix.
        num_nodes = torch.max(mask_1.sum(dim=1), mask_2.sum(dim=1))

        # The histogram is detached: no gradient flows through it.
        scores = torch.matmul(
            abstract_features_1, abstract_features_2.permute([0, 2, 1])
        ).detach()

        hist_list = []
        for i, mat in enumerate(scores):
            mat = torch.sigmoid(mat[: num_nodes[i], : num_nodes[i]]).view(-1)
            hist = torch.histc(mat, bins=self.args.bins)
            hist = hist / torch.sum(hist)
            hist = hist.view(1, -1)
            hist_list.append(hist)

        return torch.stack(hist_list).view(-1, self.args.bins)

    def convolutional_pass(self, edge_index, features, edge_attr=None):
        """
        Making convolutional pass.
        :param edge_index: Edge indices.
        :param features: Feature matrix.
        :param edge_attr: Edge feature matrix; only used (and required) by
            the "egcn" operator.
        :return features: Abstract feature matrix.
        :raises ValueError: If the "egcn" operator is used without edge features.
        :raises NotImplementedError: If the operator is unknown.
        """
        if self.args.gnn_operator in ["gcn", "gin"]:
            features = self.convolution_1(features, edge_index)
        elif self.args.gnn_operator == "egcn":
            if edge_attr is None:
                raise ValueError("The 'egcn' operator requires edge attributes.")
            features = self.convolution_1(features, edge_index, edge_attr)
        else:
            raise NotImplementedError("Unknown GNN-Operator.")
        features = F.relu(features)
        features = F.dropout(features, p=self.args.dropout, training=self.training)
        ## we only use a single convolutional layer for our use case
        # features = self.convolution_2(features, edge_index, edge_attr)
        # features = F.relu(features)
        # features = F.dropout(features, p=self.args.dropout, training=self.training)
        # features = self.convolution_3(features, edge_index, edge_attr)
        return features

    def diffpool(self, abstract_features, edge_index, batch):
        """
        Making differentiable pooling.
        :param abstract_features: Node feature matrix.
        :param edge_index: Edge indices
        :param batch: Batch vector, which assigns each node to a specific example
        :return pooled_features: Graph feature matrix.
        """
        x, mask = to_dense_batch(abstract_features, batch)
        adj = to_dense_adj(edge_index, batch)
        return self.attention(x, adj, mask)

    @staticmethod
    def _batch_vector(graph, features):
        """
        Return the batch vector of ``graph``; a missing or ``None`` batch is
        replaced by an all-zero vector (single graph) on the features' device.
        """
        batch = getattr(graph, "batch", None)
        if batch is None:
            batch = torch.zeros(
                graph.num_nodes, dtype=torch.long, device=features.device
            )
        return batch

    def forward(self, data):
        """
        Forward pass with graphs.
        :param data: Data dictionary with the graph (batch) pair under "g1" and "g2".
        :return score: Similarity score, one value in (0, 1) per graph pair.
        """
        graph_1 = data["g1"]
        graph_2 = data["g2"]
        edge_index_1 = graph_1.edge_index
        edge_index_2 = graph_2.edge_index
        features_1 = graph_1.x
        features_2 = graph_2.x
        # getattr with a default keeps the names bound when a graph carries
        # no edge attributes (they are only needed by the "egcn" operator).
        edge_attr_1 = getattr(graph_1, "edge_attr", None)
        edge_attr_2 = getattr(graph_2, "edge_attr", None)
        # The attribute may exist but be None, so test the value rather
        # than relying on hasattr.
        batch_1 = self._batch_vector(graph_1, features_1)
        batch_2 = self._batch_vector(graph_2, features_2)

        abstract_features_1 = self.convolutional_pass(edge_index_1, features_1, edge_attr_1)
        abstract_features_2 = self.convolutional_pass(edge_index_2, features_2, edge_attr_2)

        if self.args.histogram:
            hist = self.calculate_histogram(
                abstract_features_1, abstract_features_2, batch_1, batch_2
            )

        if self.args.diffpool:
            pooled_features_1 = self.diffpool(
                abstract_features_1, edge_index_1, batch_1
            )
            pooled_features_2 = self.diffpool(
                abstract_features_2, edge_index_2, batch_2
            )
        else:
            pooled_features_1 = self.attention(abstract_features_1, batch_1)
            pooled_features_2 = self.attention(abstract_features_2, batch_2)

        scores = self.tensor_network(pooled_features_1, pooled_features_2)

        if self.args.histogram:
            scores = torch.cat((scores, hist), dim=1)

        scores = F.relu(self.fully_connected_first(scores))
        score = torch.sigmoid(self.scoring_layer(scores)).view(-1)
        return score


class SimGNNTrainer(object):
    """
    SimGNN model trainer.
    """

    def __init__(self, args, data):
        """
        Store the arguments, unpack the dataset and build the model.
        :param args: Arguments object.
        :param data: Dataset dictionary.
        """
        self.args = args
        # process_dataset must run first: setup_model reads
        # self.number_of_labels, which process_dataset sets.
        self.process_dataset(data)
        self.setup_model()

    def setup_model(self):
        """
        Creating a SimGNN.
        The model is built from self.args and self.number_of_labels and
        stored as self.model.
        """
        self.model = SimGNN(self.args, self.number_of_labels)

    def save(self, model_path=None):
        """
        Saving model.
        :param model_path: Path to save model.
        """
        if model_path:
            self.args.save = model_path
        torch.save(self.model.state_dict(), self.args.save)
        print(f"Model is saved under {self.args.save}.")

    def load(self, model_path=None):
        """
        Loading model.
        :param model_path: Path to load model.
        """
        if model_path:
            self.args.load = model_path
        self.model.load_state_dict(torch.load(self.args.load))
        print(f"Model is loaded from {self.args.load}.")

    def process_dataset(self, data):
        """
        Extract attrributes from dataset dictionary.
        :param data: Dataset dictionary
        """
        self.training_graphs = data['training_graphs']
        self.testing_graphs = data['testing_graphs']
        self.ndist_matrix = data['ndist_matrix']
        self.real_data_size = self.ndist_matrix.size(0)
        self.number_of_labels = self.training_graphs[0].x.size(1)

    def create_batches(self):
        """
        Creating batches from the training graph list.

        Every shuffled source batch is paired with a target batch of the same
        size, drawn from the training graphs by the "biased" sampling of
        random_sample_from_closest_vectorized over the graphs' "i" indices.
        The candidates are taken from self.training_graphs rather than from
        notebook-level globals, so the method does not depend on names that
        are defined outside the trainer.
        :return batches: List of (source_batch, target_batch) pairs.
        """
        # create a dataloader using training graphs with given batch size and shuffle=True
        source_loader = DataLoader(self.training_graphs, batch_size=self.args.batch_size, shuffle=True)

        # Map each training graph's dataset index "i" to the graph itself;
        # the keys double as the candidate index list for sampling.
        graph_by_index = {int(graph["i"]): graph for graph in self.training_graphs}
        train_indices = list(graph_by_index)

        new_batch_pair_list = []
        for source_batch in source_loader:
            # for each source batch, create a corresponding target batch based on "biased" random sampling logic
            sampled = random_sample_from_closest_vectorized(source_batch["i"], train_indices)
            target_batch = Batch.from_data_list(
                [graph_by_index[int(index)] for index in sampled]
            )
            new_batch_pair_list.append((source_batch, target_batch))

        return new_batch_pair_list

    def transform(self, data):
        """
        Getting distance for graph pair and grouping with data into dictionary.
        :param data: Graph pair.
        :return new_data: Dictionary with data.
        """
        new_data = dict()

        new_data["g1"] = data[0]
        new_data["g2"] = data[1]

        # fetch normalized distance values from ndist_matrix
        normalized_dist = self.ndist_matrix[
            data[0]["i"].reshape(-1).tolist(), data[1]["i"]
        ].tolist()

        new_data["target"] = (
            torch.from_numpy(np.array(normalized_dist)).view(-1).float()
        )
        return new_data

    def process_batch(self, data):
        """
        Forward pass with a data.
        :param data: Data that is essentially pair of batches, for source and target graphs.
        :return loss: Loss on the data.
        """
        self.optimizer.zero_grad()
        data = self.transform(data)
        target = data["target"]
        prediction = self.model(data)
        loss = F.mse_loss(prediction, target, reduction="sum")
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def fit(self):
        """
        Training a model.

        Creates an Adam optimizer, then for every epoch builds fresh batch
        pairs and runs process_batch on each. The per-epoch loss (summed loss
        divided by the number of source graphs seen) is appended to
        self.loss_list. When ``args.plot`` is set, a validation loss is
        computed every 10th epoch and both curves are saved to a PDF.
        """
        print("\nModel training.\n")
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
        )
        self.model.train()

        epochs = trange(self.args.epochs, leave=True, desc="Epoch")
        self.loss_list = []
        loss_list_test = []
        for epoch in epochs:

            if self.args.plot and epoch % 10 == 0:
                self.model.train(False)
                cnt_test = 20
                cnt_train = 100
                t = tqdm(
                    total=cnt_test * cnt_train,
                    position=2,
                    leave=False,
                    desc="Validation",
                )
                scores = torch.empty((cnt_test, cnt_train))

                # Validation only needs forward passes, so autograd
                # bookkeeping is switched off.
                # NOTE(review): the graph containers must support slicing
                # and .shuffle() here; plain Python lists do not - confirm
                # against the caller.
                with torch.no_grad():
                    for i, g in enumerate(self.testing_graphs[:cnt_test].shuffle()):
                        source_batch = Batch.from_data_list([g] * cnt_train)
                        target_batch = Batch.from_data_list(
                            self.training_graphs[:cnt_train].shuffle()
                        )
                        data = self.transform((source_batch, target_batch))
                        target = data["target"]
                        prediction = self.model(data)

                        scores[i] = F.mse_loss(
                            prediction, target, reduction="none"
                        ).detach()
                        t.update(cnt_train)

                t.close()
                loss_list_test.append(scores.mean().item())
                self.model.train(True)

            batches = self.create_batches()
            main_index = 0
            loss_sum = 0
            for batch_pair in tqdm(
                batches, total=len(batches), desc="Batches", leave=False
            ):
                loss_score = self.process_batch(batch_pair)
                main_index = main_index + batch_pair[0].num_graphs
                loss_sum = loss_sum + loss_score
            loss = loss_sum / main_index
            epochs.set_description("Epoch (Loss=%g)" % round(loss, 5))
            self.loss_list.append(loss)

        if self.args.plot:
            plt.plot(self.loss_list, label="Train")
            # Validation was recorded on every 10th epoch only.
            plt.plot(
                [*range(0, self.args.epochs, 10)], loss_list_test, label="Validation"
            )
            plt.ylim([0, 0.01])
            plt.legend()
            filename = self.args.dataset
            filename += "_" + self.args.gnn_operator
            if self.args.diffpool:
                filename += "_diffpool"
            if self.args.histogram:
                filename += "_hist"
            filename = filename + str(self.args.epochs) + ".pdf"
            plt.savefig(filename)

    def measure_time(self):
        """
        Time one forward pass for every (testing graph, training graph) pair.

        Each pair is batched and transformed before the clock starts, so only
        the ``self.model(data)`` call is measured, using the CPU time of the
        current process. The mean and standard deviation are printed in
        milliseconds.
        """
        import time

        self.model.eval()
        n_pairs = len(self.testing_graphs) * len(self.training_graphs)

        elapsed = np.empty(n_pairs)
        progress = tqdm(total=n_pairs, desc="Graph pairs")
        slot = 0
        for query in self.testing_graphs:
            for candidate in self.training_graphs:
                # Single-graph batches: one pair per forward pass.
                query_batch = Batch.from_data_list([query])
                candidate_batch = Batch.from_data_list([candidate])
                data = self.transform((query_batch, candidate_batch))

                tick = time.process_time()
                self.model(data)
                elapsed[slot] = time.process_time() - tick

                slot += 1
                progress.update()
        progress.close()

        # Seconds -> milliseconds, rounded to 5 decimals for display.
        mean_ms = round(elapsed.mean() * 1000, 5)
        std_ms = round(elapsed.std() * 1000, 5)
        print(f"Average time (ms): {mean_ms}; Standard deviation: {std_ms}")

    def score(self):
        """
        Evaluate the model on every (testing graph, training graph) pair.

        For each testing graph the model is run against all training graphs;
        the per-pair squared error is stored in the matching row of
        ``self.scores`` and Spearman's rho, Kendall's tau, p@10 and p@20 are
        appended to their lists. The averages are stored on ``self.rho``,
        ``self.tau``, ``self.prec_at_10``, ``self.prec_at_20`` and
        ``self.model_error`` and printed through ``print_evaluation``.

        A testing graph whose evaluation raises is reported and skipped: its
        row of ``self.scores`` stays NaN and is excluded from
        ``self.model_error``, and none of its metrics are recorded.
        """
        print("\n\nModel evaluation.\n")
        self.model.eval()

        n_test = len(self.testing_graphs)
        n_train = len(self.training_graphs)

        # NaN-initialised (not np.empty) so rows of skipped graphs cannot leak
        # uninitialised memory into the reported error.
        self.scores = np.full((n_test, n_train), np.nan)
        ground_truth = np.empty((n_test, n_train))
        prediction_mat = np.empty((n_test, n_train))

        self.rho_list = []
        self.tau_list = []
        self.prec_at_10_list = []
        self.prec_at_20_list = []

        t = tqdm(total=n_test * n_train)

        for i, g in enumerate(self.testing_graphs):
            try:
                source_batch = Batch.from_data_list([g] * n_train)
                target_batch = Batch.from_data_list(self.training_graphs)

                data = self.transform((source_batch, target_batch))
                target = data["target"]
                ground_truth[i] = target
                # Pure inference: no need to build an autograd graph.
                with torch.no_grad():
                    prediction = self.model(data)
                prediction_mat[i] = prediction.detach().numpy()

                # Compute everything first and commit afterwards, so a failure
                # halfway cannot leave the metric lists with unequal lengths.
                row_scores = (
                    F.mse_loss(prediction, target, reduction="none").detach().numpy()
                )
                rho = calculate_ranking_correlation(
                    spearmanr, prediction_mat[i], ground_truth[i]
                )
                tau = calculate_ranking_correlation(
                    kendalltau, prediction_mat[i], ground_truth[i]
                )
                prec_at_10 = calculate_prec_at_k(10, prediction_mat[i], ground_truth[i])
                prec_at_20 = calculate_prec_at_k(20, prediction_mat[i], ground_truth[i])
            except Exception as exc:
                # Best effort: keep evaluating the remaining graphs, but say
                # what went wrong. KeyboardInterrupt is not an Exception
                # subclass and therefore still propagates.
                print(f"Ignoring error for testing graph {i}: {exc!r}")
            else:
                self.scores[i] = row_scores
                self.rho_list.append(rho)
                self.tau_list.append(tau)
                self.prec_at_10_list.append(prec_at_10)
                self.prec_at_20_list.append(prec_at_20)
            t.update(n_train)

        t.close()

        self.rho = np.mean(self.rho_list).item()
        self.tau = np.mean(self.tau_list).item()
        self.prec_at_10 = np.mean(self.prec_at_10_list).item()
        self.prec_at_20 = np.mean(self.prec_at_20_list).item()
        # nanmean ignores the rows of graphs that were skipped above.
        self.model_error = np.nanmean(self.scores).item()
        self.print_evaluation()

    def print_evaluation(self):
        """
        Print the evaluation metrics stored on the instance.

        Reads ``self.model_error`` (shown scaled by 1000), ``self.rho``,
        ``self.tau``, ``self.prec_at_10`` and ``self.prec_at_20``; each value
        is rounded to 5 decimals.
        """
        report = (
            ("\nmse(10^-3)", self.model_error * 1000),
            ("Spearman's rho", self.rho),
            ("Kendall's tau", self.tau),
            ("p@10", self.prec_at_10),
            ("p@20", self.prec_at_20),
        )
        for label, value in report:
            print(f"{label}: {round(value, 5)}.")

# Parameters

In [None]:
import argparse

def get_default_parameters():
    """
    Build the default hyper-parameter set for the trainer.
    :return args: argparse.Namespace holding one attribute per default below.
    """
    defaults = {
        # Model architecture
        "gnn_operator": "egcn",
        "epochs": 100,
        "filters_1": 64,
        "filters_2": 64,
        "filters_3": 64,
        "tensor_neurons": 64,
        "bottle_neck_neurons": 64,
        "batch_size": 32,
        "bins": 64,
        # Optimisation
        "dropout": 0.0,
        "learning_rate": 0.001,
        "weight_decay": 5e-4,
        # Feature switches
        "histogram": True,
        "diffpool": False,
        "plot": False,
        "synth": False,
        "save": None,
        "load": None,
        "measure_time": False,
        "notify": False,
        "edge_dim": 768,
    }
    return argparse.Namespace(**defaults)

# Notebook-wide hyper-parameters; override attributes on `args` to change a run.
args = get_default_parameters()

# Create Dataset

In [None]:
# Folder on the mounted Drive holding the scene-graph outputs and the index CSV.
output_folder_path = "/content/drive/Shareddrives/CS224W/outputs/991_double_fixed/"
# Load the index table, then discard the "Unnamed: 0" column.
df = pd.read_csv(os.path.join(output_folder_path, "df_final.csv"))
df = df.drop(columns=["Unnamed: 0"])
df

In [None]:
# Build the dataset; the constructor also computes the normalized distance matrix.
dataset = CustomDataset(df, output_folder_path)
len(dataset)  # shown as the cell output

In [None]:
%%time
# Train/test split: 95% of the graphs for training, the remainder for testing.
random.seed(42)  # fixed seed so the split is reproducible
train_indices = random.sample(range(len(dataset)), int(0.95 * len(dataset)))
# Membership is tested against a set so the complement is built in O(n)
# rather than scanning the whole list of train indices for every index.
train_index_set = set(train_indices)
test_indices = [i for i in range(len(dataset)) if i not in train_index_set]
# List indexing returns a list of graphs in the order of the given indices.
training_graphs = dataset[train_indices]
testing_graphs = dataset[test_indices]

In [None]:
# Bundle everything that is handed to the trainer in a single dict.
data = {
    'training_graphs': training_graphs,
    'testing_graphs': testing_graphs,
    # Normalized distance matrix computed when the dataset was constructed.
    'ndist_matrix': dataset.ndist_matrix,
    'output_folder_path': output_folder_path
}

# Train and Save Model

In [None]:
# Create the trainer from the hyper-parameters and the prepared data bundle.
trainer = SimGNNTrainer(args, data)

In [None]:
# Train, show the per-epoch training loss, then persist the model in the
# dataset output folder.
trainer.fit()
plt.plot(trainer.loss_list)
# os.path.join (as already used for df_final.csv) keeps the path valid even if
# output_folder_path is given without a trailing slash.
trainer.save(os.path.join(output_folder_path, 'saved_model_egcn_95_split_1_layer_biased.pkl'))

# Load and Evaluate Model

In [None]:
# Restore the saved model; os.path.join keeps the path valid even if
# output_folder_path is given without a trailing slash.
trainer.load(os.path.join(output_folder_path, 'saved_model_egcn_95_split_1_layer_biased.pkl'))

In [None]:
# Evaluate every testing graph against all training graphs and print the metrics.
trainer.score()

## Test on query graph

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.pyplot import imread

def display_graph(graph_data, image_path, coefs=None):
    """
    Show an image side by side with a drawing of its scene graph.

    :param graph_data: Node-link dict with a "nodes" list (entries carrying
        "id" and "feature") and an "edges" list (entries carrying "source",
        "target" and "type").
    :param image_path: Path of the image for the left panel; the part after
        the last "/" becomes the panel title.
    :param coefs: Optional values used to colour the nodes on a red colormap
        with a colorbar (presumably one value per node - confirm against the
        caller). Without it, all nodes are drawn sky blue.
    """
    # Directed graph from the node-link data, laid out on a shell layout.
    graph = nx.node_link_graph(graph_data, edges="edges", directed=True)
    layout = nx.shell_layout(graph)

    # One row, two panels: image on the left, graph on the right.
    _, (image_ax, graph_ax) = plt.subplots(1, 2, figsize=(16, 8))

    # Left panel: the image, without axes.
    picture = imread(image_path)
    image_ax.imshow(picture)
    image_ax.axis("off")
    image_ax.set_title(image_path.rsplit("/", 1)[-1])

    # Right panel: nodes, either uniformly coloured or shaded by coefficient.
    if coefs is None:
        nx.draw_networkx_nodes(
            graph,
            layout,
            node_size=300,
            node_color="skyblue",
            edgecolors="black",
            ax=graph_ax,
        )
    else:
        # Pad the colour range slightly beyond the observed extremes.
        lower = coefs.min().item() - 0.005
        upper = coefs.max().item() + 0.005
        drawn_nodes = nx.draw_networkx_nodes(
            graph,
            layout,
            node_size=300,
            node_color=coefs.tolist(),
            cmap=plt.cm.Reds,
            vmin=lower,
            vmax=upper,
            edgecolors="black",
            ax=graph_ax,
        )
        colorbar = plt.colorbar(drawn_nodes, ax=graph_ax)
        colorbar.set_label('Coefficient Values')
        colorbar.ax.tick_params(labelsize=8)

    nx.draw_networkx_edges(graph, layout, arrowstyle="->", arrowsize=30, ax=graph_ax)

    # The leading newlines presumably push each label below its node marker.
    node_labels = {}
    for node in graph_data["nodes"]:
        node_labels[node['id']] = f"\n\n\n\n\n\n{node['feature']}"
    nx.draw_networkx_labels(
        graph,
        layout,
        verticalalignment="center",
        horizontalalignment="center",
        font_color="blue",
        labels=node_labels,
        font_size=8,
        ax=graph_ax,
    )

    # Edge labels are keyed by (source, target) and show the relation type.
    relation_labels = {
        (edge["source"], edge["target"]): edge["type"]
        for edge in graph_data["edges"]
    }
    nx.draw_networkx_edge_labels(
        graph,
        layout,
        edge_labels=relation_labels,
        font_size=8,
        label_pos=0.3,
        ax=graph_ax,
    )
    graph_ax.set_title("Scene Graph Visualization")
    graph_ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
def _display_graph_with_attention(graph, image_dataset_folder):
    """Compute the attention coefficients of `graph` and draw it beside its image."""
    # Only the values are needed for colouring, so skip autograd tracking.
    with torch.no_grad():
        features = trainer.model.convolutional_pass(
            graph.edge_index, graph.x, graph.edge_attr
        )
        coefs = trainer.model.attention.get_coefs(features)

    image_path = os.path.join(image_dataset_folder, graph.image_filename)
    with open(os.path.join(output_folder_path, graph.graph_filename), 'r') as f:
        graph_data = json.load(f)

    display_graph(graph_data, image_path, coefs)


def retrieve_nearest_neighbors(g, top_k=5):
    """
    Display a query graph followed by the dataset graphs the model ranks first.

    The query is paired with every graph of the notebook-level ``dataset`` and
    scored with ``trainer.model``; the ``top_k`` entries with the smallest
    predicted values are shown together with their ground-truth target and
    prediction. Relies on the notebook globals ``dataset``, ``trainer`` and
    ``output_folder_path``.

    :param g: Query graph providing edge_index, x, edge_attr, image_filename
        and graph_filename.
    :param top_k: Number of retrieved graphs to display (default 5).
    """
    image_dataset_folder = "/content/drive/Shareddrives/CS224W/Dataset/all"

    # Pair the query with every graph in the dataset.
    source_batch = Batch.from_data_list([g] * len(dataset))
    target_batch = Batch.from_data_list(dataset)

    # Model predictions for all pairs (inference only, no autograd graph).
    data = trainer.transform((source_batch, target_batch))
    target = data["target"]
    with torch.no_grad():
        prediction = trainer.model(data)
    prediction_mat = prediction.detach().numpy()

    # Show the query itself first.
    _display_graph_with_attention(g, image_dataset_folder)

    # Retrieve and display the top-k images and graphs.
    # NOTE(review): ascending order assumes a smaller prediction means a closer
    # match (distance-like output) - confirm against the training target.
    print(f"\nTop {top_k} Retrieved Images and Graphs:")
    retrieved_indices = np.argsort(prediction_mat)[:top_k]
    for rank, idx in enumerate(retrieved_indices, start=1):
        print(f"\nRetrieved Image #{rank}:")
        print(f"Ground Truth: {target[idx]}")
        print(f"Prediction: {prediction_mat[idx]}")
        _display_graph_with_attention(dataset[idx], image_dataset_folder)

In [None]:
# Pick a query graph by its position in the full dataset.
index = 500
g = dataset[index]
# Alternative: query with a held-out graph instead.
# g = testing_graphs[0]

# Show the query and the graphs the model ranks closest to it.
retrieve_nearest_neighbors(g)