In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to library code are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Make the parent directory importable. The membership guard keeps the cell
# idempotent: re-running it does not stack duplicate '..' entries on sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [3]:
import sys

# Detect Colab by checking whether its package is already loaded in this interpreter.
IN_COLAB = 'google.colab' in sys.modules
# Both the Colab and the local runtime resolve the repository root to the parent directory.
REPO_DIR = '..'

# Code

In [4]:
import os
import itertools
import collections
import tqdm.auto as tqdm
import time

from IPython.display import display

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

import tensorflow as tf
import sonnet as snt
import graph_nets
from graph_nets.graphs import GraphsTuple
import graph_attribution as gatt

# Silence the tf/graph_nets UserWarning
# "Converting sparse IndexedSlices to a dense Tensor of unknown shape".
# Note: the filter below hides every UserWarning, not only that one.
import warnings
warnings.simplefilter(action="ignore", category=UserWarning)

# Report the versions of the core libraries used by this notebook.
for mod in (tf, snt, gatt):
    print(f'{mod.__name__:20s} = {mod.__version__}')

rdkit detected? True
tensorflow           = 2.4.0
sonnet               = 2.0.0
graph_attribution    = 1.0.0b


## Graph Attribution specific imports

In [5]:
from graph_attribution import tasks
from graph_attribution import graphnet_models as gnn_models
from graph_attribution import graphnet_techniques as techniques
from graph_attribution import datasets
from graph_attribution import experiments
from graph_attribution import templates
from graph_attribution import graphs as graph_utils

# Override the datasets module's DATA_DIR with <REPO_DIR>/data and echo the resulting path.
datasets.DATA_DIR = os.path.join(REPO_DIR, 'data')
print(f'Reading data from: {datasets.DATA_DIR}')

Reading data from: ../data


# Load Experiment data, a task and attribution techniques

In [6]:
# List the options exposed by the library: task names, GNN block types and
# attribution technique names. get_techniques_dict is called with (None, None)
# because only the keys of the returned dict are read here.
print(f'Available tasks: {[t.name for t in tasks.Task]}')
print(f'Available model types: {[m.name for m in gnn_models.BlockType]}')
print(f'Available ATT techniques: {list(techniques.get_techniques_dict(None,None).keys())}')

Available tasks: ['benzene', 'logic7', 'logic8', 'logic10', 'crippen', 'bashapes', 'treegrid', 'bacommunity']
Available model types: ['gcn', 'gat', 'mpnn', 'graphnet']
Available ATT techniques: ['Random', 'GradInput', 'SmoothGrad(GradInput)', 'GradCAM-last', 'GradCAM-all', 'CAM']


In [7]:
# Task and GNN block type used by the single-run cells below.
task_type = 'crippen'
block_type = 'gcn'

# NOTE(review): task_dir is not read again in the visible cells.
task_dir = datasets.get_task_dir(task_type)
# Experiment data (train/test splits), the task object and the attribution methods.
exp, task, methods = experiments.get_experiment_setup(task_type, block_type)
# Output activation and training loss supplied by the task.
task_act, task_loss = task.get_nn_activation_fn(), task.get_nn_loss_fn()
graph_utils.print_graphs_tuple(exp.x_train)
print(f'Experiment data fields:{list(exp.__dict__.keys())}')

Shapes of GraphsTuple's fields:
GraphsTuple(nodes=TensorShape([23201, 14]), edges=TensorShape([47078, 5]), receivers=TensorShape([47078]), senders=TensorShape([47078]), globals=TensorShape([901, 1]), n_node=TensorShape([901]), n_edge=TensorShape([901]))
Experiment data fields:['x_train', 'x_test', 'y_train', 'y_test', 'att_test', 'x_aug', 'y_aug']


# Create a GNN model

## Define hparams of the experiment

In [None]:
# Fetch the hyperparameters for the chosen block/task combination and display them.
hp = gatt.hparams.get_hparams({'block_type':block_type, 'task_type':task_type})
hp

## Instantiate model

In [None]:
# Build the GNN from the hyperparameters and the task's output settings.
model = experiments.GNN(node_size = hp.node_size,
               edge_size = hp.edge_size,
               global_size = hp.global_size,
               y_output_size = task.n_outputs,
               block_type = gnn_models.BlockType(hp.block_type),
               activation = task_act,
               target_type = task.target_type,
               n_layers = hp.n_layers)
# Call the model once on the training graphs; presumably this creates its
# variables before they are printed/trained (TODO confirm).
model(exp.x_train)
gnn_models.print_model(model)

# Train model on task



## Training loop

In [8]:
from typing import Callable, Tuple

def get_batch_indices(n: int, batch_size: int) -> tf.Tensor:
    """Gets shuffled constant size batch indices to train a model.

    Args:
      n: number of examples to draw indices from.
      batch_size: number of indices in every batch.

    Returns:
      An integer tensor of shape (n // batch_size, batch_size) holding part of
      a random permutation of range(n). The trailing n % batch_size indices of
      the permutation are dropped so that every batch has the same size; when
      n < batch_size the result has zero rows.
    """
    n_batches = n // batch_size
    shuffled = tf.random.shuffle(tf.range(n))
    # Keep only as many indices as fit into full batches.
    shuffled = shuffled[:n_batches * batch_size]
    return tf.reshape(shuffled, (n_batches, batch_size))

def _orthogonality_penalty(representations: GraphsTuple) -> tf.Tensor:
    """Sums the spectral norm of (W W^T - I) over every graph of a batch.

    W is the slice of `representations.nodes` that belongs to one graph
    (one row per node), located through `representations.n_node`.
    """
    n_node = representations.n_node
    # Exclusive cumsum gives, for each graph, the row where its nodes start.
    starts = tf.cumsum(n_node, exclusive=True)
    penalty = tf.constant(0.0, tf.float32)
    # Iterate over ALL graphs: graph i spans rows [starts[i], starts[i] + n_node[i]).
    for i in range(len(starts)):
        w = representations.nodes[starts[i]:starts[i] + n_node[i]]
        gram = w @ tf.transpose(w)
        gram -= tf.linalg.eye(n_node[i])
        penalty += tf.norm(gram, ord=2, axis=[-2, -1])
    return penalty


def make_tf_opt_epoch_fn(
        inputs: GraphsTuple, target: np.ndarray, batch_size: int, model: snt.Module,
        optimizer: snt.Optimizer, loss_fn: templates.LossFunction,
        l2_reg: float = 0.0, orth_lambda: float = 0.0) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
    """Make a tf.function of (inputs, target) for optimization.

    This function is useful for basic inference training of GNN models. Uses all
    variables to create a function that has a tf.function optimized input
    signature. Function uses pure tf.functions to build batches and aggregate
    losses. The result is a heavily optimized function that is at least 2x
    faster than a basic tf.function with experimental_relax_shapes=True.

    The model is expected to return a pair (output, representations), where
    representations is a GraphsTuple whose `nodes` and `n_node` fields are
    used for the orthogonality penalty.

    Args:
      inputs: graphs used for training.
      target: values to predict for training.
      batch_size: batch size. If batch_size == 1 or there is a single graph,
        the whole input is optimized as one batch per epoch.
      model: a GNN model.
      optimizer: optimizer, probably Adam or SGD.
      loss_fn: a loss function to optimize.
      l2_reg: l2 regularization weight; 0 disables the regularizer.
      orth_lambda: weight of the orthogonality penalty on node
        representations, averaged over the graphs of a batch; 0 disables it
        (the penalty is then not computed at all).

    Returns:
      optimize_one_epoch(inputs, target), a tf.function optimized
      callable that returns the mean batch loss of the epoch. When the number
      of graphs is smaller than batch_size (and not 1), no batch is formed and
      the returned loss is 0/0.

    """
    # Explicit input signature is faster than experimental relax shapes.
    input_signature = [
        graph_nets.utils_tf.specs_from_graphs_tuple(inputs),
        tf.TensorSpec.from_tensor(tf.convert_to_tensor(target))
    ]
    n = graph_utils.get_num_graphs(inputs)
    n_batches = tf.cast(n // batch_size, tf.float32)

    if l2_reg > 0.0:
        regularizer = snt.regularizers.L2(l2_reg)
        linear_variables = gnn_models.get_linear_variables(model)

    def train_step(x, y, n_graphs):
        """Runs one optimizer step on (x, y) and returns the total loss."""
        with tf.GradientTape() as tape:
            # The model returns (predictions, representations), not a bare tensor.
            output, representations = model(x)
            step_loss = loss_fn(y, output)
            if l2_reg > 0.0:
                step_loss += regularizer(linear_variables)
            # Python-level check: the per-graph SVDs are only traced when needed.
            if orth_lambda > 0.0:
                orth_loss = _orthogonality_penalty(representations)
                step_loss += orth_lambda / 2 * orth_loss / n_graphs

        grads = tape.gradient(step_loss, model.trainable_variables)
        optimizer.apply(grads, model.trainable_variables)
        return step_loss

    if batch_size == 1 or n == 1:
        # The single batch contains all n graphs.
        n_graphs_total = tf.cast(n, tf.float32)

        def optimize_one_epoch(inputs, target):
            """One epoch single-batch optimization."""
            return train_step(inputs, target, n_graphs_total)
    else:
        def optimize_one_epoch(inputs, target):
            """One epoch optimization."""
            loss = tf.constant(0.0, tf.float32)
            for batch in get_batch_indices(n, batch_size):
                x_batch = graph_utils.get_graphs_tf(inputs, batch)
                y_batch = tf.gather(target, batch)
                loss += train_step(x_batch, y_batch, batch_size)
            return loss / n_batches

    return tf.function(optimize_one_epoch, input_signature=input_signature)

In [None]:
optimizer = snt.optimizers.Adam(hp.learning_rate)

# Compile the per-epoch training function for this model/optimizer pair.
opt_one_epoch = make_tf_opt_epoch_fn(
    exp.x_train, exp.y_train, hp.batch_size, model,
    optimizer, task_loss, l2_reg=0.0, orth_lambda=0.000)

pbar = tqdm.tqdm(range(hp.epochs))
losses = collections.defaultdict(list)
start_time = time.time()
for _ in pbar:
    train_loss = opt_one_epoch(exp.x_train, exp.y_train).numpy()
    losses['train'].append(train_loss)
    # Run the test set through the model once per epoch and reuse the result;
    # the model returns (predictions, representations).
    r = model(exp.x_test)
    losses['test'].append(task_loss(exp.y_test, r[0]).numpy())
    # Show the latest train/test loss next to the progress bar.
    pbar.set_postfix({key: values[-1] for key, values in losses.items()})

# Freeze the per-epoch histories into arrays for plotting.
losses = {key: np.array(values) for key, values in losses.items()}

## Plot losses

In [None]:
# Draw one curve per loss history (train / test) on a single set of axes.
fig, ax = plt.subplots()
for key, values in losses.items():
    ax.plot(values, label=key)
ax.set_ylabel('loss')
ax.set_xlabel('epochs')
ax.legend()
plt.show()

In [None]:
# Evaluate every attribution method on the test split (inputs, targets and
# reference attributions) and show the collected results as a DataFrame.
results = []
for method in tqdm.tqdm(methods.values(), total=len(methods)):
  results.append(experiments.generate_result(model, method, task, exp.x_test, exp.y_test, exp.att_test))
pd.DataFrame(results)

# Evaluate predictions and attributions

In [9]:
def train_and_evaluate(task_type, block_type, orth=0.00):
    """Trains a fresh GNN on a task and evaluates all attribution methods.

    Args:
      task_type: name of the task (e.g. 'crippen').
      block_type: name of the GNN block type (e.g. 'gcn').
      orth: weight of the orthogonality penalty, forwarded to
        make_tf_opt_epoch_fn as orth_lambda; 0 disables the penalty.

    Returns:
      A pair (results, losses): a DataFrame built from one
      experiments.generate_result entry per attribution method, and a dict
      with keys 'train' and 'test' holding the per-epoch losses as arrays.
    """
    # Return value unused here; call kept in case get_task_dir has side
    # effects (TODO confirm and drop).
    task_dir = datasets.get_task_dir(task_type)
    exp, task, methods = experiments.get_experiment_setup(task_type, block_type)
    task_act, task_loss = task.get_nn_activation_fn(), task.get_nn_loss_fn()
    graph_utils.print_graphs_tuple(exp.x_train)

    hp = gatt.hparams.get_hparams({'block_type': block_type, 'task_type': task_type})

    model = experiments.GNN(node_size=hp.node_size,
                            edge_size=hp.edge_size,
                            global_size=hp.global_size,
                            y_output_size=task.n_outputs,
                            block_type=gnn_models.BlockType(hp.block_type),
                            activation=task_act,
                            target_type=task.target_type,
                            n_layers=hp.n_layers)
    # Same warm-up call as in the single-run cell above.
    model(exp.x_train)

    optimizer = snt.optimizers.Adam(hp.learning_rate)

    # Forward the requested penalty weight instead of a hard-coded 0, so that
    # sweeping over `orth` actually changes the training objective.
    opt_one_epoch = make_tf_opt_epoch_fn(
        exp.x_train, exp.y_train, hp.batch_size, model,
        optimizer, task_loss, l2_reg=0.0, orth_lambda=orth)

    pbar = tqdm.tqdm(range(hp.epochs))
    losses = collections.defaultdict(list)
    for _ in pbar:
        train_loss = opt_one_epoch(exp.x_train, exp.y_train).numpy()
        losses['train'].append(train_loss)
        # Single forward pass on the test set; the model returns
        # (predictions, representations).
        test_output = model(exp.x_test)[0]
        losses['test'].append(task_loss(exp.y_test, test_output).numpy())
        pbar.set_postfix({key: values[-1] for key, values in losses.items()})

    losses = {key: np.array(values) for key, values in losses.items()}

    results = []
    for method in tqdm.tqdm(methods.values(), total=len(methods)):
        results.append(experiments.generate_result(
            model, method, task, exp.x_test, exp.y_test, exp.att_test))

    return pd.DataFrame(results), losses

In [10]:
from itertools import product

In [None]:
n_trials = 5
# Distinct names for the sweep axes and the loop variables, so the loop does
# not rebind the lists or clobber the global `task` object used by other cells.
blocks = ['gcn', 'gat', 'mpnn', 'graphnet']
task_types = ['benzene', 'logic10', 'crippen']
orth_lambdas = [0.000, 0.001]

results = {}   # (block, task, orth, run) -> attribution results DataFrame
losses = {}    # (block, task, orth, run) -> {'train': ..., 'test': ...} loss arrays
failures = {}  # (block, task, orth, run) -> exception raised by that run

for block_name, task_name, orth_lambda, run_id in product(
        blocks, task_types, orth_lambdas, range(n_trials)):
    key = (block_name, task_name, orth_lambda, run_id)
    try:
        result, loss = train_and_evaluate(task_name, block_name, orth_lambda)
    except Exception as e:
        # Best effort: keep sweeping, but record and report the failure
        # instead of dropping it silently.
        failures[key] = e
        print(f'Run {key} failed: {e!r}')
        continue
    results[key] = result
    # Store into the sweep-level dict (not into the returned `loss` dict itself).
    losses[key] = loss

Shapes of GraphsTuple's fields:
GraphsTuple(nodes=TensorShape([205826, 14]), edges=TensorShape([436434, 5]), receivers=TensorShape([436434]), senders=TensorShape([436434]), globals=TensorShape([10000, 1]), n_node=TensorShape([10000]), n_edge=TensorShape([10000]))


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=300.0), HTML(value='')))

In [None]:
# NOTE(review): presumably IPython's automagic `ls`, listing the working directory; looks like leftover debugging.
ls

# Visualize attributions