This notebook illustrates how to use Graph Neural Networks to encode structured data for use in Large Language Models. It is Part 2 of a two-part tutorial from KDD'24.

**This notebook requires a TPUv2 runtime**

If you find this tutorial useful or want to know more, please consider our publication:
Let your graph do the talking: Encoding structured data for LLMs
```
@article{perozzi2024let,
  title={Let your graph do the talking: Encoding structured data for llms},
  author={Perozzi, Bryan and Fatemi, Bahare and Zelle, Dustin and Tsitsulin, Anton and Kazemi, Mehran and Al-Rfou, Rami and Halcrow, Jonathan},
  journal={arXiv preprint arXiv:2402.05862},
  year={2024}
}
```

## Tutorial Part II: GNN Encoding of Graph Information
This notebook extends the work from the first part of the tutorial. Instead of describing the graph with a text encoding, as in the previous part, it uses a Graph Neural Network (GNN) to encode a representation of the graph directly into the prompt.
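
As a rough illustration of what "encoding a graph into the prompt" means here, the sketch below overwrites one row of a toy embedding table with a stand-in graph vector, so that a prompt beginning with that token begins with the graph representation. The sizes, the token id, and the all-ones "GNN output" are assumptions chosen for illustration; the notebook itself performs the equivalent update on Gemma's `input_embedding` table with JAX's `.at[...].set(...)`.

```python
import numpy as np

# Toy sizes; the real model uses Gemma's vocabulary and embedding width.
vocab_size, embed_dim = 8, 4
placeholder_token_id = 3  # hypothetical id standing in for a reserved token

rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(vocab_size, embed_dim)).astype(np.float32)

# Stand-in for the GNN output: one vector with the LLM's embedding width.
graph_embedding = np.ones(embed_dim, dtype=np.float32)

# Overwrite the placeholder token's row with the graph embedding.
patched_table = embedding_table.copy()
patched_table[placeholder_token_id] = graph_embedding

# A prompt that starts with the placeholder token now starts with the graph vector.
prompt_token_ids = np.array([placeholder_token_id, 5, 1])
prompt_embeddings = patched_table[prompt_token_ids]
print(prompt_embeddings.shape)  # (3, 4)
```

All other rows of the table are left unchanged, so the rest of the prompt is embedded exactly as before.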

## Notebook Outline:

* Setup (Install Dependencies, download Gemma weights)
* Dataset creation
* Graph-to-Text conversion
* Evaluation
* Exercise: Graph Encoding Challenge
* Exercise: DBLP Dataset

## Setup

## Prework!

Sign up for Kaggle and consent to the Gemma TOS. This is required to download the Gemma weights used in this notebook.
https://www.kaggle.com/models/google/gemma/license/consent?returnUrl=%2Fmodels%2Fgoogle%2Fgemma%2FFlax%2F2b-it%2F2

In [None]:
%%capture
# @title Install Dependencies
!pip install git+https://github.com/google-deepmind/gemma.git
!pip install --user kaggle
!pip install sparse_deferred
!git clone https://github.com/google-research/talk-like-a-graph.git
import sys
sys.path.insert(0, "/content/talk-like-a-graph")


## Login to Kaggle
Follow the link in the login dialog to get an API key if you don't already have one. Also make sure to approve the [Gemma TOS](https://www.kaggle.com/models/google/gemma/license/consent?returnUrl=%2Fmodels%2Fgoogle%2Fgemma%2FFlax%2F2b-it%2F2).

In [None]:
import kagglehub

kagglehub.login()

## Download Gemma

In [None]:
import os
VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = os.path.join(weights_dir, VARIANT)
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

## Import dependencies

In [None]:
# @title
import os
from collections.abc import Sequence
import dataclasses
from typing import Any, Callable, Mapping
import sys


import chex
from flax import linen as nn
import jax
import jax.numpy as jnp
import networkx as nx
import numpy as np


from gemma import params as params_lib
from gemma import transformer as transformer_lib
from gemma import sampler as sampler_lib
import sentencepiece as spm
import sparse_deferred as sd
from sparse_deferred import jax as sdjnp
from sparse_deferred.structs import graph_struct
from sparse_deferred import np as sdnp


## GraphToken library code
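
The library code in this section gives every node a Laplacian positional encoding (the `lpe` feature) before the graph is passed to the GNN. The following is a minimal NumPy sketch of that idea on a three-node path graph. It builds the normalized Laplacian by hand rather than through NetworkX, and it assumes every node has at least one edge so the degree normalization is well defined.

```python
import numpy as np

# Path graph 0 - 1 - 2 as a dense adjacency matrix.
adjacency = np.array([[0, 1, 0],
                      [1, 0, 1],
                      [0, 1, 0]], dtype=np.float32)
degree = adjacency.sum(axis=1)

# Symmetric normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}.
# Assumes no isolated nodes (every degree is positive).
inv_sqrt_degree = 1.0 / np.sqrt(degree)
laplacian = (
    np.eye(3, dtype=np.float32)
    - inv_sqrt_degree[:, None] * adjacency * inv_sqrt_degree[None, :]
)

# Left singular vectors of L; each row is one node's positional encoding.
u, _, _ = np.linalg.svd(laplacian)

units = 4
if units > u.shape[1]:
  # Zero-pad the feature axis when the graph has fewer nodes than `units`.
  u = np.pad(u, ((0, 0), (0, units - u.shape[1])))
lpe = u[:, :units]
print(lpe.shape)  # (3, 4)
```

Because this graph has only three nodes, the fourth feature column is all zeros; larger graphs keep only the first `units` columns.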

In [None]:
# @title
import collections
from collections.abc import Iterable
import io
import json
from typing import Any, Callable, NamedTuple, Sequence

import numpy as np
import tqdm

# Code for converting NetworkX graphs to graph tensor
def laplacian_pos_embedding(graph: nx.Graph, units: int = 4) -> nx.Graph:
  """Adds the laplacian positional encoding."""
  m = nx.normalized_laplacian_matrix(
      graph, nodelist=sorted(graph.nodes), weight=None
  ).astype(np.float32)
  u, _, _ = np.linalg.svd(m.todense(), compute_uv=True)
  if units > u.shape[1]:
    u = np.pad(u, ((0, 0), (0, units - u.shape[1])))
  nx.set_node_attributes(
      graph, dict(zip(sorted(graph.nodes), u[:, :units])), name='lpe'
  )
  return graph


def to_graph_struct(graph: nx.Graph, node_ids: list[int] | None = None) -> graph_struct.GraphStruct:
  if graph.edges(data=True):
    s, t, w = zip(*[
        (s, t, (d['weight'] if d and 'weight' in d else None))
        for s, t, d in graph.edges(data=True)
    ])
  else:
    s, t, w = (), (), ()
  # tfgnn assumes graphs are directed. Adding the rev edges for an undirected
  # graph.
  if not graph.is_directed():
    s, t, w = s + t, t + s, w + w

  graph = laplacian_pos_embedding(graph, units=4)
  return graph_struct.GraphStruct.new(
    nodes={'nodes': {'lpe': np.stack([graph.nodes('lpe')[i] for i in range(graph.number_of_nodes())])}},
    edges={'edges': ((np.array(s, dtype=np.int32), np.array(t, dtype=np.int32)), {})}
  )


Tensor = sd.matrix.Tensor
Features = dict[str, Tensor]
FeatureSets = dict[str, Features]
Edge = tuple[tuple[Tensor, ...], Features]  # (endpoints, edge features)
Edges = dict[str, Edge]
Nodes = FeatureSets
Schema = dict[str, tuple[str, ...]]
_Schema = dict[str, tuple[dict[str, int], ...]]



class FixedSizePadder:
  """Adds padding to `GraphStruct` instances for fixed-sized tensors.

  Fixed-size tensors can be preferred when running on TPU accelerators.

  To use this class, you must first initialize it with statistics of your graphs
  then use it to pad graphs. The statistics can be initialized by invoking
  `calculate_pad_statistics`: this function records the *maximum* observed size
  of every node and edge set, as well as the standard deviation (std) of sizes.

  Once initialized, the function: `pad_graph()` will add padding to the graph.
  Specifically, the node feature (tensors) will be padded with zeros. Similarly,
  edges will be inserted, among newly-added virtual nodes.

  Each node (or edge) size will become:

  `max observed [per calculate_pad_statistics] + slack*std + 1`

  NOTE: there will always be at least one more node or edge, even if the
  statistics show zero std. This is required for making virtual nodes.

  This applies to all node-set sizes (features) and edge-set sizes (features and
  adjacency list).
  """

  def __init__(self, engine: sd.ComputeEngine, slack: float = 1.0):
    # `('edge'|'node', NodeOrEdgeName) -> target size`
    # where `target size` is maximum observed size for node (or edge) set, plus
    # one, plus slack-times-std of observed sizes.
    self.sizes: dict[tuple[str, str], int] = {}
    self.slack = slack
    self._engine = engine

  def calculate_pad_statistics(
      self, examples: Iterable[graph_struct.GraphStruct], num_steps: int = 100):
    """Measures the max and std of node & edge sizes of elements of `examples`.

    Calling this function is necessary before invoking `pad_graph`.

    Args:
      examples: iterable that yields `GraphStruct` examples.
      num_steps: If positive, considers this many samples of `examples`.
        Otherwise, iterates over all `examples`. Warning: this may run
        infinitely on infinite iterators (e.g., `dataset.repeat()`).
    """
    sizes: dict[tuple[str, str], list[int]] = collections.defaultdict(list)
    for i, graph in enumerate(examples):
      assert isinstance(graph, graph_struct.GraphStruct)
      if num_steps > 0 and i >= num_steps:
        break
      for node_name, features in graph.nodes.items():
        value_list = sizes[('nodes', node_name)]
        if not features:
          value_list.append(0)
        else:
          value_list.append(list(features.values())[0].shape[0])

      for edge_name, edges_tuple in graph.edges.items():
        value_list = sizes[('edges', edge_name)]
        source_nodes = edges_tuple[0][0]
        # if len(value_list) and edge_set.sizes.shape != value_list[-1].shape:
        #   continue
        value_list.append(source_nodes.shape[0])

    self.sizes = {k: int(1 + max(v) + self.slack * np.std(v))
                  for k, v in sizes.items()}

  def pad_graph(self, graph: graph_struct.GraphStruct) -> graph_struct.GraphStruct:
    """Pads node-sets and edge-sets, with zeros, to max-seen during `calc..`.

    This function is useful for running on TPU hardware.

    Args:
      graph: contains any number of nodes and edges.

    Returns:
      graph with deterministic number of nodes and edges. See class docstring.
    """
    if not self.sizes:
      raise ValueError(
          'No statistics have been initialized. '
          'Perhaps you forgot to invoke "calculate_pad_statistics"?')
    # Edge set name -> (1D vectors containing endpoints**), {"feature": Tensor})
    edges: Edges = {}
    # ** tuple should have 2 entries for directed graphs

    nodes: Nodes = {}

    # For every key in `edges`, store names of node sets that `key` edge
    # connects.
    schema = graph.schema

    e = self._engine  # for short.
    for node_name, node_features in graph.nodes.items():
      padded_features = {}
      desired_size = self.sizes[('nodes', node_name)]

      for feature_name, feature in node_features.items():
        feature = feature[:desired_size]  # if `is_oversized`.
        pad = self._engine.maximum(
            desired_size - self._engine.shape(feature)[0], 0)
        zeros = e.zeros(
            tuple([pad] + list(feature.shape[1:])), dtype=feature.dtype)
        padded_feature = e.concat([feature, zeros], axis=0)
        padded_feature = e.reshape(
            padded_feature, [desired_size] + list(padded_feature.shape[1:]))
        padded_features[feature_name] = padded_feature

      nodes[node_name] = padded_features

    for edge_name, (edge_endpoints, features) in graph.edges.items():
      padded_features = {}
      padded_endpoints = []
      desired_size = self.sizes[('edges', edge_name)]
      current_size = e.shape(edge_endpoints[0])[0]

      pad = e.maximum(desired_size - current_size, 0)
      e.assert_greater(pad, -1)

      for feature_name, feature in features.items():
        feature = feature[:desired_size]  # if `is_oversized`.
        zeros = e.zeros(
            tuple([pad] + list(feature.shape[1:])), dtype=feature.dtype
        )
        padded_feature = e.concat([feature, zeros], axis=0)
        padded_feature = e.reshape(
            padded_feature, [desired_size] + list(padded_feature.shape[1:])
        )
        padded_features[feature_name] = padded_feature

      edge_endpoints = [node_ids[:desired_size] for node_ids in edge_endpoints]
      # [[src1_is_valid, src2_is_valid, ...], [tgt1_is_valid, ...]]
      valid = e.cast(
          [
              ids < self.sizes[('nodes', node_name)]
              for ids, node_name in zip(edge_endpoints, schema[edge_name])
          ],
          dtype=bool,
      )
      valid = e.reduce_all(valid, axis=0)

      for node_ids, node_name in zip(edge_endpoints, schema[edge_name]):
        # Universe size (e.g., of source or target).
        max_endpoint = self.sizes[('nodes', node_name)] - 1
        node_ids = node_ids[:desired_size]
        node_ids = e.boolean_mask(node_ids, valid)
        pad = desired_size - e.shape(node_ids)[0]  # Need only to compute once.

        padded_ids = e.concat([
            node_ids,
            e.ones((pad), dtype=node_ids.dtype) * max_endpoint
        ], axis=0)
        padded_ids = e.reshape(padded_ids, [desired_size])
        padded_endpoints.append(padded_ids)

      edges[edge_name] = (tuple(padded_endpoints), padded_features)

    graph = graph_struct.GraphStruct.new(nodes=nodes, edges=edges, schema=schema)
    return graph



## gnn.py
class GIN(nn.Module):
  """Graph Isomorphism Network: https://arxiv.org/pdf/1810.00826.pdf."""

  output_dim: int
  num_hidden_layers: int = 1
  hidden_dim: int = 32
  epsilon: float = 0.1  # See GIN paper (link above)

  def setup(self):
    layer_dims = [self.hidden_dim] * self.num_hidden_layers
    self.layers = [
        nn.Dense(dim, use_bias=False, dtype=jnp.bfloat16) for dim in layer_dims
    ]
    self.out_layer = nn.Dense(
        self.output_dim, use_bias=False, dtype=jnp.bfloat16
    )

  def __call__(self, graph: graph_struct.GraphStruct) -> jax.Array:
    x = graph.nodes['nodes']['lpe']
    adj = graph.adj(sdjnp.engine, 'edges')
    adj = adj.add_eye(1 + self.epsilon)  # self connections with 1+eps weight.

    for i, layer in enumerate(self.layers):
      x = layer(adj @ x)
      if i < self.num_hidden_layers:
        x = nn.relu(x)
    x = jnp.concat(x, axis=-1)
    return self.out_layer(x)


class GCN(nn.Module):
  """Graph convolutional network: https://arxiv.org/pdf/1609.02907.pdf."""

  output_dim: int
  num_hidden_layers: int = 1
  hidden_dim: int = 32

  def setup(self):
    layer_dims = [self.hidden_dim] * self.num_hidden_layers
    self.layers = [nn.Dense(dim, use_bias=False) for dim in layer_dims]
    self.out_layer = nn.Dense(
        self.output_dim, use_bias=False, dtype=jnp.bfloat16
    )

  def __call__(self, graph: graph_struct.GraphStruct) -> jax.Array:
    x = graph.nodes['nodes']['lpe']
    adj = graph.adj(sdjnp.engine, 'edges')
    adj_symnorm = (adj + adj.transpose()).add_eye().normalize_symmetric()

    for i, layer in enumerate(self.layers):
      x = layer(adj_symnorm @ x)
      if i < self.num_hidden_layers:
        x = nn.relu(x)
    x = jnp.concat(x, axis=-1)
    return self.out_layer(x)


## sampler.py

@dataclasses.dataclass
class SamplerOutput:

  # Decoded samples from the model.
  text: list[str]

  # Per-step logits used during sampling.
  logits: list[list[float]]

  # Tokens corresponding to the generated samples.
  tokens: list[list[int]]

  graph_embeddings: list[jnp.ndarray]


class GraphTokenSampler:
  """Sampler for GraphToken."""

  def __init__(
      self,
      gnn: nn.Module,
      llm: transformer_lib.Transformer,
      vocab: spm.SentencePieceProcessor,
      params: Mapping[str, Any],
      gnn_token_template: str = r'<unused%d>',
  ):
    """Initializes the sampler.

    Args:
      gnn: The GNN model.
      llm: The LLM model.
      vocab: The vocab used by the LLM.
      params: The parameters for the GNN and LLM. This should contain the params
        for the gnn under params['gnn'] and the params for the llm under
        params['transformer']
      gnn_token_template: The token used to represent the GNN embedding.
    """

    self._gnn = gnn
    self._llm = llm
    self._params = params
    self._vocab = vocab
    self._gnn_token_template = gnn_token_template
    self._sampler = sampler_lib.Sampler(
        transformer=self._llm,
        vocab=self._vocab,
        params=self._params['transformer'],
    )

  def __call__(
      self,
      input_strings: Sequence[str],
      input_graphs: Sequence[graph_struct.GraphStruct],
      total_generation_steps: int,
      echo: bool = False,
      return_logits: bool = True,
      forbidden_tokens: Sequence[str] | None = None,
  ) -> SamplerOutput:
    """Samples from the model.

    Args:
      input_strings: The input strings.
      input_graphs: The input graphs.
      total_generation_steps: The number of steps to generate.
      echo: Whether to echo the input.
      return_logits: Whether to return the logits.
      forbidden_tokens: Tokens that are forbidden, in addition to the GNN token.

    Returns:
      The sampled output.
    """
    assert len(input_graphs) == len(input_strings), (
        len(input_graphs),
        len(input_strings),
    )
    augmented_inputs = []
    full_forbidden_tokens = []
    if forbidden_tokens is not None:
      full_forbidden_tokens += forbidden_tokens
    graph_embeddings = []
    augmented_transformer_params = self._params['transformer']

    placeholder_token = PLACEHOLDER_TOKEN
    full_forbidden_tokens.append(placeholder_token)
    placeholder_token_id = self._vocab.EncodeAsIds(placeholder_token)
    assert len(placeholder_token_id) == 1, placeholder_token
    placeholder_token_id = placeholder_token_id[0]

    for prompt, graph in zip(input_strings, input_graphs):
      embed = self._gnn.apply(self._params['gnn'], graph)
      assert (
          self._params['transformer']['embedder']['input_embedding'][
              placeholder_token_id
          ].shape
          == embed.shape
      )
      augmented_transformer_params['embedder']['input_embedding'] = (
          augmented_transformer_params['embedder']['input_embedding']
          .at[placeholder_token_id]
          .set(embed)
      )
      graph_embeddings.append(embed)
      augmented_inputs.append(placeholder_token + prompt)

    self._sampler.params = augmented_transformer_params
    o = self._sampler(
        input_strings=augmented_inputs,
        total_generation_steps=total_generation_steps,
        echo=echo,
        return_logits=return_logits,
        forbidden_tokens=full_forbidden_tokens,
    )
    return SamplerOutput(
        **dataclasses.asdict(o),
        graph_embeddings=graph_embeddings,
    )



@chex.dataclass(frozen=True)
class TrainingInput:
  """Batch of training data for a GraphToken model."""

  # Input tokens given to the model
  input_tokens: np.ndarray  # size [B, L]

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: np.ndarray  # size [B, L]

  input_graphs: list[graph_struct.GraphStruct]  # size [B]

  # Ground truth for the input tokens, if representable as an integer.
  # For boolean classification tasks, this is 0/1.
  parsed_ground_truth: np.ndarray | None  # size [B]


def parse_int(s: str) -> int:
  """Parse a string as an integer."""
  return int(float(s.strip()))


def parse_yes_no(s: str) -> bool:
  """Parse a string as a yes/no answer, looking at the first 10 chars."""
  return 'yes' in s.lower()[:10]

PLACEHOLDER_TOKEN = '<unused0>'

def graphqa_ds(
    vocab: spm.SentencePieceProcessor,
    encoded_examples: list,
    padder: FixedSizePadder | None = None,
    max_tokens: int = 100,
    gt_parser: Callable[[str], Any] | None = None,
) -> tuple[FixedSizePadder, list[TrainingInput]]:
  """Load a GraphQA dataset as a list of TrainingInput.

  Args:
    vocab: The vocab to use for tokenization.
    encoded_examples: List of encoded examples generated by GraphQA
    padder: The padder to use for padding the graph. If None, a new padder will
      be created and returned. This is so a padder can be shared across multiple
      datasets / splits.
    max_tokens: The maximum number of tokens to allow in the input. For
      'task_only' prompting this can be quite small (100 tokens is plenty)
    gt_parser: A function to parse the ground truth from the answer string, used
      to supply the 'parsed_ground_truth' field in the TrainingInput.

  Returns:
    The padder used for padding the graphs, and a list of TrainingInput.
  """

  output = []
  for ex in encoded_examples:
      query = PLACEHOLDER_TOKEN + ex['question'][ex['question'].find('Q:'):]
      answer = ex['answer']
      graph = to_graph_struct(ex['graph'])
      query_tokens = vocab.EncodeAsIds(query)
      answer_tokens = vocab.EncodeAsIds(answer) + [vocab.eos_id()]
      input_tokens = np.array([vocab.bos_id()] + query_tokens + answer_tokens)
      target_mask = np.zeros_like(input_tokens, dtype=jnp.int32)
      # Add one for BOS token
      target_mask[len(query_tokens) + 1 :] = 1
      orig_len = len(query_tokens) + len(answer_tokens) + 1
      input_tokens = np.pad(
          input_tokens,
          [[0, max_tokens - orig_len]],
          constant_values=vocab.pad_id(),
      )

      target_mask = np.pad(target_mask, [[0, max_tokens - orig_len]])


      # The GNN library that we are using requires a global feature. We set
      # a fake value here, but it is unused otherwise.
      #graph = graph.update(nodes={'g': {'foo': np.zeros([1])}})

      output.append(
          TrainingInput(
              input_tokens=np.array([input_tokens]),
              target_mask=np.array([target_mask]),
              input_graphs=[graph],
              parsed_ground_truth=np.array(gt_parser(answer)) if gt_parser else None,
          )
      )
  if padder is None:
    padder = FixedSizePadder(sdnp.engine)
    padder.calculate_pad_statistics(
        [e.input_graphs[0] for e in output], len(output)
    )
  for o in output:
    o.input_graphs[0] = padder.pad_graph(o.input_graphs[0])
  return padder, output


def decode_questions(
    training_input: TrainingInput, vocab: spm.SentencePieceProcessor
) -> list[str]:
  """Decode the question from the input tokens. (ignoring the first 2)."""
  b, l = training_input.input_tokens.shape
  question_tokens = []
  for i in range(b):
    question_tokens.append([])
    # Skip the first two tokens (BOS and control token).
    for j in range(2, l):
      if training_input.target_mask[i, j] == 1:
        break
      question_tokens[i].append(int(training_input.input_tokens[i, j]))
  return [''.join(vocab.DecodeIds(q)) for q in question_tokens]


## training_loop
import functools
from typing import Any, MutableMapping

import chex
from flax import linen as nn
from gemma import transformer as transformer_lib
import jax
import jax.numpy as jnp
import optax
import tqdm

Params = MutableMapping[str, Any]


def get_attention_mask_and_positions(
    example: jax.Array,
    pad_id: int,
) -> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask


def forward_and_loss_fn(
    params: Params,
    *,
    gnn: nn.Module,
    llm: transformer_lib.Transformer,
    input_tokens: jax.Array,  # Shape [B, L]
    input_graphs: list[graph_struct.GraphStruct],  # Shape [B]
    input_mask: jax.Array,  # Shape [B, L]
    positions: jax.Array,  # Shape [B, L]
    attention_mask: jax.Array,  # [B, L, L]
    placeholder_token_id: int,
) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: Params for the gnn and transformer. The gnn params are stored in
      params['gnn'] and the llm params are stored in params['transformer'].
    gnn: gnn model to call.
    llm: gemma transformer model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_graphs: input graphs.
    input_mask: 1 for tokens that contribute to the loss, 0 for tokens to
      ignore, shape [B, L].
    positions: relative position of each token, shape [B, L].
    attention_mask: input attention mask, shape [B, L, L].
    placeholder_token_id: Index in the LLM vocabulary that we are using for passing
      graph embeddings.

  Returns:
    Softmax cross-entropy loss for the next-token prediction task.
  """
  # Right now we only support batch_size = 1
  chex.assert_axis_dimension(input_tokens, 0, 1)
  chex.assert_equal_shape([input_tokens, input_mask, positions])
  chex.assert_axis_dimension(attention_mask, 0, 1)
  chex.assert_equal(len(input_graphs), 1)

  # Get the GNN embedding and update the transformer input embedding for a
  # control token.
  graph_embed = gnn.apply(params['gnn'], input_graphs[0])
  params['transformer']['embedder']['input_embedding'] = (
      params['transformer']['embedder']['input_embedding']
      .at[placeholder_token_id]
      .set(graph_embed)
  )
  # Forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = llm.apply(
      {'params': params['transformer']},
      input_tokens,
      positions,
      None,  # Attention cache is None.
      attention_mask,
  )

  # Exclude the last step as it does not appear in the targets.
  logits = logits[0, :-1]

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[0, 1:]
  target_mask = input_mask[0, 1:]

  # Convert the target labels into one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., jnp.newaxis]

  # Normalisation factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the nll loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor


@functools.partial(
    jax.jit,
    static_argnames=['gnn', 'llm', 'optimizer', 'pad_id', 'placeholder_token_id'],
)
def train_step(
    llm: transformer_lib.Transformer,
    gnn: nn.Module,
    params: MutableMapping[str, Any],
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
    pad_id: int,
    example: TrainingInput,
    placeholder_token_id: int,
) -> tuple[jax.Array, Params, optax.OptState]:
  """Train step.

  Args:
    llm: gemma transformer model.
    gnn: gnn model.
    params: model's input parameters.
    optimizer: optax optimizer to use.
    opt_state: input optimizer's state.
    pad_id: id of the pad token.
    example: input batch.
    placeholder_token_id: Index in the LLM vocabulary that we are using for passing
      graph embeddings.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """

  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(
      jnp.array(example.input_tokens), pad_id
  )

  # Forward and backward passes
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(
      params,
      gnn=gnn,
      llm=llm,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      input_graphs=example.input_graphs,
      positions=positions,
      attention_mask=attention_mask,
      placeholder_token_id=placeholder_token_id,
  )

  updates, opt_state = optimizer.update(
      grads['gnn'], opt_state, params=params['gnn']
  )
  params['gnn'] = optax.apply_updates(params['gnn'], updates)

  return train_loss, params, opt_state


@functools.partial(
    jax.jit, static_argnames=['gnn', 'llm', 'pad_id', 'placeholder_token_id']
)
def validation_step(
    gnn: nn.Module,
    llm: transformer_lib.Transformer,
    params: MutableMapping[str, Any],
    pad_id: int,
    example: TrainingInput,
    placeholder_token_id: int,
) -> jax.Array:
  """Validation step.

  Args:
    gnn: gnn model.
    llm: gemma transformer model.
    params: model's input parameters. The gnn params are stored in params['gnn']
      and the llm params are stored in params['transformer'].
    pad_id: id of the pad token.
    example: input batch
    placeholder_token_id: Index in the LLM vocabulary that we are using for passing
      graph embeddings.

  Returns:
    Validation loss.
  """
  jax_input = jax.tree.map(jnp.array, example)
  positions, attention_mask = get_attention_mask_and_positions(
      jax_input.input_tokens, pad_id
  )
  val_loss = forward_and_loss_fn(
      params,
      gnn=gnn,
      llm=llm,
      input_tokens=jax_input.input_tokens,
      input_mask=jax_input.target_mask,
      input_graphs=jax_input.input_graphs,
      positions=positions,
      attention_mask=attention_mask,
      placeholder_token_id=placeholder_token_id,
  )
  return val_loss


@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None


def train_loop(
    llm: transformer_lib.Transformer,
    gnn: nn.Module,
    train_ds: list[TrainingInput],
    validation_ds: list[TrainingInput],
    params: Params,
    training_cfg: TrainingConfig,
    vocab: spm.SentencePieceProcessor,
) -> Params:
  """Main training loop for GraphToken.

  Args:
    llm: Gemma transformer model.
    gnn: gnn model.
    train_ds: training dataset.
    validation_ds: validation dataset.
    params: Combined params for both the LLM and GNN. The GNN params are stored
      in params['gnn'] and the LLM params are stored in params['transformer'].
    training_cfg: training configuration.
    vocab: sentence piece vocabulary.

  Returns:
    Updated model's input parameters.
  """
  optimizer = optax.lion(training_cfg.learning_rate)
  opt_state = optimizer.init(params['gnn'])

  avg_loss = 0

  placeholder_token_id = vocab.EncodeAsIds(PLACEHOLDER_TOKEN)
  assert (
      len(placeholder_token_id) == 1
  ), f'Placeholder token multiple ids: {placeholder_token_id}'
  placeholder_token_id = placeholder_token_id[0]
  # A first round of validation loss
  n_steps_eval = 0
  eval_loss = 0

  with tqdm.tqdm(range(training_cfg.num_epochs * len(train_ds))) as pbar:
    averaged_steps = 0
    for n_steps in pbar:
      train_example = train_ds[n_steps % len(train_ds)]
      train_loss, params, opt_state = train_step(
          gnn=gnn,
          llm=llm,
          params=params,
          optimizer=optimizer,
          opt_state=opt_state,
          pad_id=vocab.pad_id(),
          example=train_example,
          placeholder_token_id=placeholder_token_id,
      )
      averaged_steps += 1
      avg_loss += train_loss
      if n_steps and n_steps % training_cfg.eval_every_n == 0:
        val_iterator = validation_ds
        avg_loss /= averaged_steps
        averaged_steps = 0
        pbar.write(
            f'STEP {n_steps} training loss: {avg_loss}'
        )
        avg_loss = 0
      if (
          training_cfg.max_steps is not None
          and n_steps > training_cfg.max_steps
      ):
        break
    if averaged_steps != 0:
      avg_loss /= averaged_steps
      pbar.write(
            f'STEP {n_steps} training loss: {avg_loss}'
        )
  return params

def merge_params(llm_params, gnn_params):
  out = {}
  out.update(llm_params)
  out['gnn'] = gnn_params
  return out

## Train a GraphToken Model
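
Training minimizes a masked next-token loss: only the answer tokens, marked by `target_mask`, contribute, and the logits are shifted by one position relative to the targets. The sketch below restates that computation in plain NumPy for a single unbatched sequence. The function name and the toy inputs are illustrative assumptions, not part of the library code.

```python
import numpy as np

def masked_next_token_loss(logits, input_tokens, target_mask):
  """Mean negative log-likelihood over positions where target_mask is 1.

  logits: [L, V] scores, input_tokens: [L] token ids, target_mask: [L] of 0/1.
  """
  # Position t predicts token t + 1, so drop the last logit and the first token.
  logits = logits[:-1]
  targets = input_tokens[1:]
  mask = target_mask[1:].astype(np.float64)

  # Numerically stable log-softmax over the vocabulary axis.
  shifted = logits - logits.max(axis=-1, keepdims=True)
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

  # Log-probability assigned to each target token.
  token_log_probs = log_probs[np.arange(len(targets)), targets]
  return -(token_log_probs * mask).sum() / (mask.sum() + 1e-8)

# Uniform logits over a 4-token vocabulary give a loss of about log(4).
logits = np.zeros((5, 4))
tokens = np.array([0, 1, 2, 3, 1])
mask = np.array([0, 0, 0, 1, 1])  # only the last two tokens (the answer) count
loss = masked_next_token_loss(logits, tokens, mask)
```

Since the Gemma weights are never passed to the optimizer, this loss only drives updates to the GNN parameters.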

Load Gemma Weights

In [None]:
params = params_lib.load_and_format_params(ckpt_path)

# Reshard params over TPU device mesh
from jax.sharding import PartitionSpec as P
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
def try_to_shard(x):
  try:
    return jax.device_put(x, sharding)
  except Exception:  # Leave arrays that cannot be sharded this way unchanged.
    return x
params = jax.tree.map(try_to_shard, params)


config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=128  # Number of time steps in the transformer's cache
)
model_2b = transformer_lib.Transformer(config=config_2b)

# Load vocabulary
vocab = spm.SentencePieceProcessor()
assert vocab.Load(vocab_path)

Generate some training data, for the CycleCheck task
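
CycleCheck asks whether an undirected graph contains a cycle. For sanity-checking labels independently of the task library, one option is a small union-find pass over the edge list, sketched below. This is not how `graph_tasks.CycleCheck` computes its answers (that implementation is not shown here), and it assumes each undirected edge appears exactly once in `edges`.

```python
def has_cycle(num_nodes, edges):
  """Returns True if the undirected graph contains a cycle."""
  parent = list(range(num_nodes))

  def find(node):
    # Walk to the root, halving the path as we go.
    while parent[node] != node:
      parent[node] = parent[parent[node]]
      node = parent[node]
    return node

  for u, v in edges:
    root_u, root_v = find(u), find(v)
    if root_u == root_v:
      # u and v were already connected, so this edge closes a cycle.
      return True
    parent[root_u] = root_v
  return False

triangle = [(0, 1), (1, 2), (2, 0)]
path = [(0, 1), (1, 2)]
print(has_cycle(3, triangle), has_cycle(3, path))  # True False
```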

In [None]:
from talk_like_a_graph import graph_generators
from talk_like_a_graph import graph_tasks
random_seed = 9876

train_graphs = graph_generators.generate_graphs(number_of_graphs=500,
                         algorithm='er', # Erdos-Renyi random graphs
                         directed=False,
                         random_seed=random_seed)
test_graphs = graph_generators.generate_graphs(number_of_graphs=10,
                         algorithm='er', # Erdos-Renyi random graphs
                         directed=False,
                         random_seed=random_seed + 12385)
task = graph_tasks.CycleCheck()
train_examples = list(task.prepare_examples_dict(
    train_graphs,
    generator_algorithms = ['er']*len(train_graphs),
    encoding_method='adjacency').values())
test_examples = list(task.prepare_examples_dict(
    test_graphs,
    generator_algorithms = ['er']*len(test_graphs),
    encoding_method='adjacency').values())
padder, train_ds = graphqa_ds(vocab, train_examples, max_tokens=25)
_, test_ds = graphqa_ds(vocab, test_examples, max_tokens=25, padder=padder)

Train GraphToken
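
The `GIN` encoder used here aggregates each node's neighbours together with a `(1 + epsilon)`-weighted self connection, applies a dense layer and ReLU, and then projects the result to the LLM's embedding width. The NumPy sketch below walks through one hidden layer with random stand-in weights. It assumes that the `jnp.concat` call in `GIN.__call__` flattens the per-node states into a single vector, which is consistent with the sampler asserting that the GNN output has the shape of one embedding row.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, feature_dim, hidden_dim, output_dim = 3, 4, 4, 6
epsilon = 0.1

adjacency = np.array([[0, 1, 0],
                      [1, 0, 1],
                      [0, 1, 0]], dtype=np.float32)
x = rng.normal(size=(num_nodes, feature_dim)).astype(np.float32)

# Random stand-ins for the learned, bias-free dense layers.
w_hidden = rng.normal(size=(feature_dim, hidden_dim)).astype(np.float32)
w_out = rng.normal(size=(num_nodes * hidden_dim, output_dim)).astype(np.float32)

# GIN-style aggregation: neighbours plus a (1 + epsilon)-weighted self loop.
aggregate = adjacency + (1.0 + epsilon) * np.eye(num_nodes, dtype=np.float32)
hidden = np.maximum(aggregate @ x @ w_hidden, 0.0)  # dense layer + ReLU

# Flatten all node states into one vector, then project to the output width.
graph_embedding = hidden.reshape(-1) @ w_out
print(graph_embedding.shape)  # (6,)
```

Flattening ties the output layer to a fixed number of nodes, which is one reason the graphs are padded to a common size before training.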

In [None]:
gin = GIN(config_2b.embed_dim, num_hidden_layers=3, hidden_dim=4)
key = jax.random.PRNGKey(0)
gnn_params = gin.init(key, train_ds[0].input_graphs[0])


train_config = TrainingConfig(
    learning_rate=0.0001, num_epochs=3, eval_every_n=250, batch_size=1
)
params_learned = train_loop(
    llm=model_2b,
    gnn=gin,
    train_ds=train_ds,
    validation_ds=test_ds,
    params=merge_params(params, gnn_params),
    training_cfg=train_config,
    vocab=vocab,
)

Sample outputs
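
Beyond reading the samples by eye, the generations can be scored against the ground truth. The sketch below does this with the same yes/no heuristic as the `parse_yes_no` helper defined earlier. The `accuracy` function and the example strings are hypothetical and are not actual model output.

```python
def parse_yes_no(text):
  """Mirrors the notebook helper: looks for 'yes' in the first 10 characters."""
  return 'yes' in text.lower()[:10]

def accuracy(outputs, ground_truths):
  """Fraction of outputs whose parsed yes/no answer matches the ground truth."""
  matches = [
      parse_yes_no(output) == parse_yes_no(truth)
      for output, truth in zip(outputs, ground_truths)
  ]
  return sum(matches) / len(matches)

# Hypothetical strings for illustration only.
outputs = ['Yes, there is a cycle.', 'No, there is no cycle.', 'Yes.']
ground_truths = [
    'Yes, there is a cycle.',
    'No, there is no cycle.',
    'No, there is no cycle.',
]
score = accuracy(outputs, ground_truths)
```

With only 10 test graphs, a score computed this way is a rough signal rather than a reliable estimate.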

In [None]:
from IPython.display import Markdown, display

graph_token_sampler = GraphTokenSampler(
    params=params_learned, llm=model_2b, gnn=gin, vocab=vocab
)


def get_graph_qa_question(ex):
  with_graph = ex['question']
  q_index = with_graph.find('Q:')
  return with_graph[q_index:]

for i in range(len(test_examples)):
  tokenized_input = test_ds[i]
  ex = test_examples[i]
  prompt = get_graph_qa_question(ex)
  if i == 0:
    display(
        Markdown(
            '**Prompt:** '
            + prompt
            + '\n\n'
        )
    )
  llm_output = graph_token_sampler(
      [prompt],
      tokenized_input.input_graphs,
      total_generation_steps=15,
      return_logits=False,
  ).text[0]
  display(Markdown(f'**LLM Output:** "{llm_output}"'))
  display(
      Markdown(f"**Ground Truth:** {ex['answer']}")
  )
  display(Markdown('-' * 80))
  print()



## Exercise: Train a model for a different task.

Take the above code and modify it for the NodeCount task.
How does your model perform?

If your kernel runs out of memory, run the following code to clear the TPU memory, then re-run the code block labeled 'Load Gemma Weights' and retry.
```
for a in jax.live_arrays():
  a.delete()
```