# Using GEBM
This an example of how GEBM could be used outside the framework of our project as a standalone module.

In [1]:
import torch
from graph_uq.gebm import GraphEBMWrapper

In [2]:
gebm = GraphEBMWrapper()

### Fit the GEBM to the logits and embeddings of your GNN
We fit GEBM to the logits and embeddings of your model. The embeddings are used to fit a latent space Gaussian regularizer. The logits are needed to compute a normalizer that scales the logit-based energy term and the regularizer to the same scale.

Note that the logits and embeddings should be computed in the presence of network effects, i.e. your normal GNN outputs. Additionally, you should provide a mask for the nodes you have labels for, i.e. training nodes. The regularizer will fit class-conditional Gaussians based on these labels.

Your GNN model that predicts some logits of shape [num_nodes, num_classes] and embeddings of shape [num_nodes, emb_dim]. Here, we have dummy code for generating the tensors

In [3]:
logits = torch.randn(
    200, 5
)  # logits outputted by your GNN model in the presence of edges
embeddings = torch.randn(
    200, 32
)  # embeddings outputted by your GNN model in the presence of edges
y = torch.randint(0, 5, (200,))  # true class labels
edge_index = torch.randint(0, 200, (2, 1000))  # edge index tensor
edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)  # make graph undirected
mask_train = (
    torch.rand(200) < 0.5
)  # train mask, the model will fit its regularizers to nodes of this mask

In [4]:
gebm.fit(logits, embeddings, edge_index, y, mask_train)

### Evaluate the GEBM on your model outputs

Now you can evaluate GEBM on outputs of your GNN, e.g. on data with a distribution shift. This time, logits and embeddings should be computed **without using network effects**, e.g. by setting the adjacency matrix to the identity or passing an empty edge index tensor to your model. The diffusion of GEBM itself, however, uses the graph.

In [5]:
logits_eval_no_network = torch.randn(
    200, 5
)  # logits outputted by your GNN model in the absence of edges
embeddings_eval_no_network = torch.randn(
    200, 32
)  # embeddings outputted by your GNN model in the absence of edges

In [6]:
uncertainty = gebm.get_uncertainty(
    logits_unpropagated=logits_eval_no_network,
    embeddings_unpropagated=embeddings_eval_no_network,
    edge_index=edge_index,
)
uncertainty.size()

torch.Size([200])

## Using GEBM within our framework

You can also use GEBM within our framework, i.e. by creating a model that inherits from `BaseModel` and using a dataset that inherits from `Data`, or simply using models and datasets from our codebase.

In [7]:
from graph_uq.config.data import default_data_config
from graph_uq.config.model import default_model_config
from graph_uq.config.trainer import default_trainer_config

from graph_uq.model.build import get_model
from graph_uq.data.build import apply_distribution_shift_and_split, get_base_data
from graph_uq.training import train_model
from graph_uq.logging.logger import Logger
from graph_uq.evaluation.uncertainty import binary_classification

In [8]:
# CoraML dataset, leave out classes setting

data_config = default_data_config.copy()
data_config["name"] = "cora_ml"
data_config["categorical_features"] = True
data_config["distribution_shift"]["type_"] = "leave_out_classes"
data_config["distribution_shift"]["leave_out_classes_type"] = "last"
data_config["distribution_shift"]["num_left_out_classes"] = 3

dataset = apply_distribution_shift_and_split(get_base_data(data_config), data_config)

In [9]:
model_config = default_model_config.copy()
model_config["name"] = "gcn"
model_config["type_"] = "gcn"
model_config["hidden_dims"] = [64]

model = get_model(default_model_config, dataset.data_train)

In [10]:
trainer_config = default_trainer_config.copy()

Lets train our model. This takes usually 500-1500 epochs on CoraML.

In [11]:
metrics = train_model(trainer_config, dataset.data_train, model, Logger())

  5%|▍         | 455/10000 [00:20<07:18, 21.77it/s]


In [12]:
model = model.eval()
model.reset_cache()

Now we can evaluate GEBM using the wrapper. It provides utility to interact very nicely with dataset and model classes of our framework.

In [19]:
gebm = GraphEBMWrapper()

In [20]:
model = model.cpu().eval()
model.reset_cache()
gebm.fit_from_model(
    dataset.data_train.cpu(),
    model.cpu(),
)

In [21]:
model.reset_cache()
uncertainty = gebm.get_uncertainty_from_model(
    dataset.data_shifted["loc"],
    model,
)

In [22]:
mask = dataset.data_shifted["loc"].get_mask("test")
binary_classification(
    uncertainty[mask], dataset.data_shifted["loc"].get_distribution_mask("ood")[mask]
)

{auc_roc: 0.8890266654532217, auc_pr: 0.8337349197105466}