# Fingerprinting

In [1]:
%load_ext autoreload
%autoreload 2

import torch
from copy import deepcopy

from torch_geometric.data import Data, Batch

from graphium.nn.architectures import FullGraphMultiTaskNetwork

# Seed torch's global random number generator so that the random tensors
# created later in the notebook are reproducible across runs.
# Assigning the return value to `_` keeps it out of the cell output.
_ = torch.manual_seed(42)

## Define batch

In [2]:
# Sizes of the toy features
in_dim = 5          # width of every input node-feature vector
in_dim_edges = 13   # width of every input edge-feature vector
out_dim = 11        # width wanted for the output node features


# Two small PyG graphs, each described by its edge index
# (one column per edge).
edge_idx1 = torch.tensor([
    [0, 1, 2],
    [1, 2, 3],
])
edge_idx2 = torch.tensor([
    [2, 0, 0, 1],
    [0, 1, 2, 0],
])

# Node features (conventionally named x): one random row per node, where the
# node count is the largest node id found in the edge index, plus one.
# NOTE: both node-feature tensors are drawn before any edge-feature tensor,
# so the sequence of random draws is x1, x2, e1, e2.
num_nodes1 = int(edge_idx1.max()) + 1
num_nodes2 = int(edge_idx2.max()) + 1
x1 = torch.randn(num_nodes1, in_dim, dtype=torch.float32)
x2 = torch.randn(num_nodes2, in_dim, dtype=torch.float32)

# Edge features (named e): one random row per column of the edge index
num_edges1 = edge_idx1.size(-1)
num_edges2 = edge_idx2.size(-1)
e1 = torch.randn(num_edges1, in_dim_edges, dtype=torch.float32)
e2 = torch.randn(num_edges2, in_dim_edges, dtype=torch.float32)

# Wrap connectivity and features into PyG `Data` objects
g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)
g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)

# Merge both graphs into one batch; it displays as a single graph
# holding the 4 + 3 = 7 nodes of its two members
bg = Batch.from_data_list([g1, g2])
print(bg)


DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])


## Define model

In [3]:
# Widths at the seams between the sub-networks
temp_dim_1 = 23   # output of the pre-NN == input of the GNN
temp_dim_2 = 17   # output of the GNN == width used by the graph-output networks


def _task_head_kwargs(task_level, out_dim, depth):
    """Return the settings of one task head.

    Every head in this notebook uses the same hidden width, activation,
    dropout and (absent) normalization / residual settings; only the task
    level, the output size and the depth differ between heads.
    A fresh dict is built on every call so heads never share state.
    """
    return {
        "task_level": task_level,
        "out_dim": out_dim,
        "hidden_dims": 32,
        "depth": depth,
        "activation": "relu",
        "last_activation": None,
        "dropout": 0.1,
        "normalization": None,
        "last_normalization": None,
        "residual_type": "none",
    }


def _graph_output_kwargs(pooling):
    """Return the settings of one graph-output network.

    The "graph" and "node" entries are identical single-layer networks of
    width `temp_dim_2`; they differ only in the `pooling` value.
    """
    return {
        "pooling": pooling,
        "out_dim": temp_dim_2,
        "hidden_dims": temp_dim_2,
        "depth": 1,
        "activation": "relu",
        "last_activation": None,
        "dropout": 0.1,
        "normalization": None,
        "last_normalization": None,
        "residual_type": "none",
    }


# Network applied to the node features before the GNN
pre_nn_kwargs = dict(
    in_dim=in_dim,
    out_dim=temp_dim_1,
    hidden_dims=4,
    depth=2,
    activation="relu",
    last_activation="none",
    dropout=0.2,
)

# The GNN itself: four GCN layers with layer normalization
gnn_kwargs = dict(
    in_dim=temp_dim_1,
    out_dim=temp_dim_2,
    hidden_dims=5,
    depth=4,
    activation="gelu",
    last_activation=None,
    dropout=0.1,
    normalization="layer_norm",
    last_normalization="layer_norm",
    residual_type="simple",
    virtual_node=None,
    layer_type="pyg:gcn",
    layer_kwargs=None,
)

# Two graph-level heads and one node-level head
task_heads_kwargs = {
    "graph-task-1": _task_head_kwargs("graph", out_dim=3, depth=4),
    "graph-task-2": _task_head_kwargs("graph", out_dim=4, depth=2),
    "node-task-1": _task_head_kwargs("node", out_dim=2, depth=3),
}

# Output networks per task level; only the graph level pools (by sum)
graph_output_nn_kwargs = {
    "graph": _graph_output_kwargs(pooling=["sum"]),
    "node": _graph_output_kwargs(pooling=None),
}

# Assemble the full multi-task network from the four configurations
model = FullGraphMultiTaskNetwork(
    gnn_kwargs=gnn_kwargs,
    pre_nn_kwargs=pre_nn_kwargs,
    task_heads_kwargs=task_heads_kwargs,
    graph_output_nn_kwargs=graph_output_nn_kwargs,
)

In [4]:
# Create the module map to navigate the modules.
# The map is indexed by module names such as 'gnn' and
# 'task_heads/graph-task-1'; the entries are used through their `depth`,
# `_keep_readouts()` and `readouts` members.
# NOTE(review): `_module_map` is a private attribute, presumably populated by
# `create_module_map` — confirm against the graphium API.
model.create_module_map(level="module")
module_map = model._module_map

## Fingerprints on graph level

Let's use the layer outputs (readouts) of a graph-level TaskHead as fingerprints.

In [5]:
# The graph-level head whose layer outputs serve as fingerprints
graph_head = module_map['task_heads/graph-task-1']

# Number of layers in that head
print("TaskHead depth:", graph_head.depth)

# Ask the head to keep track of its readouts
graph_head._keep_readouts()

# One forward pass on a copy of the batch, so `bg` itself stays untouched
batch = deepcopy(bg)
_ = model(batch)

# Collect the readouts and show the shape of each one
task_head_readouts = graph_head.readouts

print("TaskHead readout shapes:")
for layer_readout in task_head_readouts.values():
    print(layer_readout.shape)

TaskHead depth: 4
TaskHead readout shapes:
torch.Size([2, 32])
torch.Size([2, 32])
torch.Size([2, 32])
torch.Size([2, 3])


## Fingerprints on node level

Let's take as fingerprints the outputs of the 2nd and 4th layers of the GNN and of the first two layers of a node-level TaskHead.

In [6]:
# Modules whose layer outputs we want, keyed by the label used when printing
tracked = {
    "GNN": module_map['gnn'],
    "TaskHead": module_map['task_heads/node-task-1'],
}

# How many layers per module?
for label, module in tracked.items():
    print(f"{label} depth:", module.depth)

# Keep track of readouts for each of these modules
for module in tracked.values():
    module._keep_readouts()

# Run one forward pass on a copy of the batch
batch = deepcopy(bg)
_ = model(batch)

# Extract the readouts
gnn_readouts = tracked["GNN"].readouts
task_head_readouts = tracked["TaskHead"].readouts

# Show the shape of every readout, module by module
for label, readouts in (("GNN", gnn_readouts), ("TaskHead", task_head_readouts)):
    print(f"{label} readout shapes:")
    for readout in readouts.values():
        print(readout.shape)

GNN depth: 4
TaskHead depth: 3
GNN readout shapes:
torch.Size([7, 5])
torch.Size([7, 5])
torch.Size([7, 5])
torch.Size([7, 17])
TaskHead readout shapes:
torch.Size([7, 32])
torch.Size([7, 32])
torch.Size([7, 2])


In [7]:
# Fingerprint configuration: module name -> 0-based indices of the layers
# whose readouts are concatenated into the fingerprint
gnn_layers = [1, 3]          # 2nd and 4th layers of the GNN
node_head_layers = [0, 1]    # first two layers of the node-level task head

config = {
    'gnn': gnn_layers,
    'task_heads/node-task-1': node_head_layers,
}

In [8]:
# Function for creating fingerprints
def create_fingerprint(model, config, batch):
    """Concatenate selected layer readouts of `model` into one fingerprint.

    The forward pass runs in evaluation mode and without gradient tracking,
    so stochastic layers such as dropout do not make the fingerprint vary
    between calls. The model's previous train/eval mode is restored
    afterwards.

    Args:
        model: A `torch.nn.Module` exposing `create_module_map(level=...)`
            and `_module_map`. Every module named in `config` must provide
            `_keep_readouts()` and a `readouts` container indexable by layer.
        config: Mapping from module name (a key of the module map) to the
            list of layer indices whose readouts are part of the fingerprint.
        batch: Input forwarded unchanged to `model(batch)`.

    Returns:
        A tuple `(fingerprint, readout_list)`. `readout_list` holds the
        selected readouts, ordered as in `config`; `fingerprint` is their
        concatenation along the last dimension. All selected readouts must
        therefore agree on every other dimension.
    """
    # Use the same map level as the rest of the notebook: entries must be the
    # modules themselves, since `_keep_readouts()` / `readouts` are used below.
    model.create_module_map(level="module")
    module_map = model._module_map
    for module_name in config:
        module_map[module_name]._keep_readouts()

    # Run one forward pass in inference conditions, then put the model back
    # into whatever mode the caller had it in.
    # NOTE: `train(was_training)` applies one mode to all submodules.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            _ = model(batch)
    finally:
        model.train(was_training)

    # Gather the requested readouts, module by module, in `config` order
    readout_list = [
        module_map[module_name].readouts[layer]
        for module_name, layers in config.items()
        for layer in layers
    ]

    return torch.cat(readout_list, dim=-1), readout_list

In [9]:
# Create the fingerprint on a copy of the batch, so `bg` itself stays untouched.
# `r` receives the list of individual readouts that were concatenated.
# Expected width, from the readout shapes printed earlier:
#   5 + 17 (GNN layers 1 and 3) + 32 + 32 (task-head layers 0 and 1) = 86
fingerprint, r =  create_fingerprint(model, config, deepcopy(bg))
print(fingerprint.shape)

torch.Size([7, 86])
