## **1. Building recommender systems using GNNs**

This tutorial assumes that you have finished the graph neural networks and recommender systems tutorials. In this section we demonstrate how to use graph neural networks in recommender systems.

**Prerequisites:**

- Familiarity with JAX, especially Flax and Jraph
- A grasp of the basics of neural networks
- Completion of the graph neural networks (GNNs) tutorial
- Completion of the recommender systems tutorial

**Aims/Learning Objectives:**
- Frame link prediction in the context of movie recommendations
- Implement a GCN for movie recommendations


### **1.1. Introduction**

Graphs are a powerful and general representation of data with a wide range of applications. Most people are familiar with their use in contexts like social networks and biological systems. Another use case for graphs is recommender systems, which are the focus of this tutorial.

### **1.2. Graph Prediction Tasks**

As demonstrated in the GNN tutorial, there are three main types of prediction task on graphs:

1. **Node Classification**: E.g. what is the topic of a paper given a citation network of papers?
2. **Link Prediction / Edge Classification**: E.g. are two people in a social network friends?
3. **Graph Classification**: E.g. is this protein molecule (represented as a graph) likely going to be effective?

<image src="https://storage.googleapis.com/dm-educational/assets/graph-nets/graph_tasks.png" width="700px">

*Image source: Petar Veličković.*


### **1.3. Recommender systems as a link prediction problem**

A recommender system can be visualized as a graph, where entities (such as users and items) are nodes, and the interactions between them (such as ratings or purchase history) are edges. In the context of a movie recommendation system:

- Nodes might represent:
    - Users: Individuals consuming the content.
    - Movies: Content items to be recommended.
- Edges represent:
    - Ratings: A directed edge from a user to a movie, annotated with a weight that indicates the rating (e.g., on a scale of 1 to 5).
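
As a minimal sketch of this representation, the toy example below encodes three hypothetical ratings as index arrays. The names (`ratings`, `num_users`, `senders`, `receivers`, `edges`) and the data are illustrative assumptions, and placing movie nodes after user nodes is just one common way to share a single node index space.

```python
import numpy as np

# Hypothetical toy data: (user index, movie index, rating on a 1-5 scale).
ratings = [(0, 0, 5.0), (0, 2, 3.0), (1, 1, 4.0)]
num_users, num_movies = 2, 3

users = np.array([u for u, _, _ in ratings])
movies = np.array([m for _, m, _ in ratings])

# Users occupy node ids [0, num_users); movie ids are offset so that they
# occupy [num_users, num_users + num_movies).
senders = users
receivers = movies + num_users

# One rating per directed user -> movie edge, stored as a column vector.
edges = np.array([[r] for _, _, r in ratings], dtype=np.float32)

print(senders)      # [0 0 1]
print(receivers)    # [2 4 3]
print(edges.shape)  # (3, 1)
```

Each position `i` describes one edge: user `senders[i]` gave movie `receivers[i] - num_users` the rating `edges[i]`.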



#### **1.3.1. Link (or Edge) Prediction**

Link prediction tries to predict whether a link (or edge) should exist between two nodes, even if it's currently absent. For our movie recommendation system, this translates to predicting whether a user would like (or dislike) a movie they haven't yet rated.

The process works as follows:
1. Train on existing edges: Use known ratings (edges) from users to movies to train a model.
2. Predict missing edges: For a given user, predict ratings for movies they haven't seen or rated. This is akin to predicting missing or potential edges in our graph.
3. Recommend based on predictions: Movies with the highest predicted ratings are recommended to the user.
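
The last two steps can be sketched with plain NumPy. This is only an illustration: the random embeddings stand in for the output of a trained model, and `recommend`, `rated`, `user_emb`, and `movie_emb` are hypothetical names that do not appear elsewhere in this tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)
num_users, num_movies, dim = 4, 6, 8

# Assumption: these embeddings stand in for the output of a trained model.
user_emb = rng.normal(size=(num_users, dim))
movie_emb = rng.normal(size=(num_movies, dim))

# Known edges: rated[u, m] is True if user u has already rated movie m.
rated = np.zeros((num_users, num_movies), dtype=bool)
rated[0, [1, 3]] = True


def recommend(user, top_n=3):
    """Score every movie for `user` and return the best unrated ones."""
    scores = movie_emb @ user_emb[user]  # one score per candidate edge
    scores[rated[user]] = -np.inf        # never recommend already-rated movies
    return np.argsort(-scores)[:top_n]   # indices of the highest scores


top = recommend(user=0)
print(top)
```

Masking out the already-rated movies before sorting is what turns "predict every edge" into "predict the missing edges".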

In this section of the tutorial, we use graph neural networks (GNNs) for this link prediction task. The GNN operates on the graph, aggregating information from neighboring nodes to predict ratings for unseen edges. The figure below shows a user-item knowledge subgraph: each icon denotes an entity or concept, and the connecting edges represent the relationships between them.

In [None]:
%%capture
# @title Installations
# The %%capture cell magic on the first line hides the output of this cell.
!pip install annoy clu
!pip install git+https://github.com/deepmind/jraph.git
!pip install torch-geometric
!pip install networkx

In [None]:
# @title Import required packages. (Run Cell)

import gc
import jax
import jraph
import optax
import torch
import numpy as np
import pandas as pd
import networkx as nx
import tensorflow as tf
from flax import struct
from clu import metrics
from flax import linen as nn
from jax import numpy as jnp
import jax.tree_util as tree
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from jraph import GraphConvolution
from flax.training import train_state
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from typing import Iterable, Mapping, Sequence, Tuple, Callable

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

![Description](graph_rc.png)

*Image source: [HI2Rec: Exploring Knowledge in Heterogeneous Information for Movie Recommendation](https://www.semanticscholar.org/paper/HI2Rec%3A-Exploring-Knowledge-in-Heterogeneous-for-He-Wang/038eb4e6839352c8fa8f9c4f5ae5ff958e14c5a3)*

### **1.4. Data download**

We have already preprocessed the data into a graph. If you would like to better understand how we converted the raw movie rating data into a graph, feel free to read section 3 of this tutorial.

In [None]:
#@title Download data


In [None]:
#@title Read data


In [None]:
#@title Explore data


We will use a graph convolutional network (GCN) as our model. For a more in-depth understanding of GCNs, please go through the resources provided above in your free time.
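
To give a rough idea of what a graph convolution computes, here is a minimal sketch of one symmetric-normalized aggregation step written directly with `jax.ops.segment_sum`. It is not the `jraph.GraphConvolution` implementation (which also applies `update_node_fn` and can add self-edges); `gcn_propagate` and the toy graph are hypothetical.

```python
import jax
import jax.numpy as jnp


def gcn_propagate(nodes, senders, receivers):
    """One symmetric-normalized sum aggregation over directed edges."""
    num_nodes = nodes.shape[0]
    ones = jnp.ones_like(senders, dtype=nodes.dtype)
    # Out-degree of each sender and in-degree of each receiver, clipped to
    # at least 1 so that isolated nodes do not cause a division by zero.
    sender_deg = jnp.maximum(jax.ops.segment_sum(ones, senders, num_nodes), 1.0)
    receiver_deg = jnp.maximum(jax.ops.segment_sum(ones, receivers, num_nodes), 1.0)
    # Scale each message by 1/sqrt(deg(sender)), sum messages at each
    # receiver, then scale the result by 1/sqrt(deg(receiver)).
    messages = (nodes / jnp.sqrt(sender_deg)[:, None])[senders]
    aggregated = jax.ops.segment_sum(messages, receivers, num_nodes)
    return aggregated / jnp.sqrt(receiver_deg)[:, None]


# Toy graph with 3 nodes and edges 0->1, 0->2, 1->2.
nodes = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
senders = jnp.array([0, 0, 1])
receivers = jnp.array([1, 2, 2])
out = gcn_propagate(nodes, senders, receivers)
print(out)
```

Node 0 receives no messages, so its output row is zero; node 2 receives a mix of the features of nodes 0 and 1, down-weighted by the degrees involved.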

In [None]:
#@title Model
class MLP(nn.Module):
  """A flax MLP."""
  features: Sequence[int]
  kernel_init: Callable = jax.nn.initializers.he_uniform()
  bias_init: Callable = jax.nn.initializers.zeros

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
        lyr = nn.Dense(feat, kernel_init=self.kernel_init, bias_init=self.bias_init, name=f"mlp_dense_{i}")
        x = lyr(x)
        # Apply the non-linearity after every layer except the last one.
        if i != len(self.features) - 1:
            x = nn.relu(x)
    return x

def make_embed_fn(latent_size):
    def embed(inputs):
        return nn.Dense(latent_size)(inputs)
    return embed

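class GraphConvLayer(nn.Module):
  """A single GCN layer followed by a dot-product edge decoder that predicts ratings."""

  output_decoder_dim: int
  latent_size: int  # Not used inside this layer.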
  update_node_fn: Callable
  aggregate_nodes_fn: Callable = jax.ops.segment_sum
  add_self_edges: bool = False
  symmetric_normalization: bool = True
  layer_norm: bool = False

  @nn.compact
  def __call__(self, graph):
    gcn = GraphConvolution(
        update_node_fn=self.update_node_fn,
        aggregate_nodes_fn=self.aggregate_nodes_fn,
        add_self_edges=self.add_self_edges,
        symmetric_normalization=self.symmetric_normalization
    )
    graph = gcn(graph)
    if self.layer_norm:
      # Apply layer normalization to the node embeddings
      normalized_nodes = nn.LayerNorm()(graph.nodes)
      # Update the graph with the normalized node embeddings
      graph = graph._replace(nodes=normalized_nodes)

    # Score each edge with the dot product of its sender and receiver embeddings.
    edge_predictions = jnp.sum(graph.nodes[graph.senders] * graph.nodes[graph.receivers], axis=-1)
    edge_predictions = jnp.expand_dims(edge_predictions, axis=1)
    edge_predictions = nn.Dense(self.output_decoder_dim, name="mlp_dense_output")(edge_predictions)

    # Squash with a sigmoid and rescale so that predictions lie in the rating range (1, 5).
    edge_predictions = 4 * jax.nn.sigmoid(edge_predictions) + 1
    return edge_predictions
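
The decoder at the end of `GraphConvLayer` can be illustrated in isolation. The sketch below is a simplified stand-in: `decode_rating` is a hypothetical helper, and its `scale` and `bias` arguments take the place of the learned `Dense` layer.

```python
import jax
import jax.numpy as jnp


def decode_rating(sender_emb, receiver_emb, scale=1.0, bias=0.0):
    """Map pairs of node embeddings to ratings between 1 and 5."""
    # `scale` and `bias` are hypothetical stand-ins for the learned Dense(1) layer.
    score = jnp.sum(sender_emb * receiver_emb, axis=-1)  # dot product per edge
    return 4.0 * jax.nn.sigmoid(scale * score + bias) + 1.0


user = jnp.array([[1.0, -2.0], [0.0, 0.0], [3.0, 3.0]])
movie = jnp.array([[1.0, 2.0], [5.0, 5.0], [3.0, 3.0]])
ratings = decode_rating(user, movie)
print(ratings)
```

A dot product of zero maps to the midpoint rating of 3, strongly negative scores approach 1, and strongly positive scores approach 5.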

In [None]:
# Define a one-layer GCN
gcn_layer = GraphConvLayer(
    output_decoder_dim = 1,
    latent_size = 64,
    update_node_fn=lambda n: MLP(features=[64, 64])(n),
    aggregate_nodes_fn=jax.ops.segment_sum,
    add_self_edges=False,
    symmetric_normalization=True
)

# Initialize to see the output shapes
params = gcn_layer.init(jax.random.PRNGKey(42), graph)
output = gcn_layer.apply(params, graph)
output.shape, graph.edges.shape

In [None]:
# Define loss function
def l1_loss(logits: jnp.ndarray, y: jnp.ndarray, mask: jnp.ndarray, reduction: str = "mean") -> jnp.ndarray:
    """Masked L1 (mean absolute error) loss.

    Args:
        logits: model outputs (predicted ratings), one per edge.
        y: target ratings, same shape as logits.
        mask: boolean mask selecting the edges that contribute to the loss.
        reduction: "mean" averages over the masked edges, "sum" sums over them.

    Returns:
       The L1 loss over the masked edges.
    """
    mask = mask.astype(logits.dtype)
    abs_error = jnp.abs(logits - y) * mask
    if reduction == "mean":
        # Divide by the number of masked edges, not by the total number of edges.
        return jnp.sum(abs_error) / jnp.maximum(jnp.sum(mask), 1.0)
    if reduction == "sum":
        return jnp.sum(abs_error)
    raise ValueError(f"Unsupported reduction: {reduction!r}")

# check random loss without training
l1_loss(output, graph.edges, mask=val_mask)
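
When averaging a masked loss, the divisor matters. The toy example below, with hypothetical values, contrasts dividing by the number of selected edges with taking a plain mean over all edges, which shrinks the loss in proportion to how few edges the mask selects.

```python
import numpy as np

preds = np.array([[4.0], [2.0], [5.0], [1.0]])
labels = np.array([[5.0], [2.0], [3.0], [4.0]])
mask = np.array([[True], [False], [True], [False]])  # hypothetical validation mask

# Absolute errors, zeroed out wherever the mask is False.
abs_err = np.abs(preds - labels) * mask

masked_mean = abs_err.sum() / mask.sum()  # average over the 2 selected edges
diluted_mean = abs_err.mean()             # average over all 4 edges

print(masked_mean)   # 1.5
print(diluted_mean)  # 0.75
```

Only the first variant is comparable across masks of different sizes, for example when plotting training and validation losses together.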

In [None]:
#@title Define flax train state
def create_train_state(
    model, graph, tx, rngs
):
    """Train state. This function initializes the model."""

    @jax.jit
    def initialize(params_rng):
        variables = model.init(
            params_rng,
            graph,
        )
        return variables

    variables = initialize(rngs)
    state = train_state.TrainState.create(apply_fn=model.apply, params=variables["params"], tx=tx)

    param_count = sum(x.size for x in jax.tree_util.tree_leaves(state.params))
    print("---> number of model parameters: ", param_count)
    return state

In [None]:
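# Initialize model, train state, and other hparams
optimizer = optax.adam(learning_rate=0.001)
rng = jax.random.PRNGKey(0)
# Use separate keys for parameter initialization and dropout, and keep `rng` for training.
rng, params_rng, dropout_rng = jax.random.split(rng, 3)
rngs = {"params": params_rng, "dropout": dropout_rng}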

model = GraphConvLayer(
    output_decoder_dim = 1,
    latent_size = 128,
    update_node_fn=lambda n: MLP(features=[128, 128, 64])(n),
    aggregate_nodes_fn=jax.ops.segment_sum,
    add_self_edges=True,
    layer_norm=True,
    symmetric_normalization=True
)
state = create_train_state(
    model=model,
    graph=graph,
    tx=optimizer,
    rngs=rngs,
)

As the printed output shows, our model has 27458 parameters. You can change this by adjusting hyperparameters such as the MLP `features` (note that `latent_size` is not used inside `GraphConvLayer` as defined above). Next, we will define the train and evaluation steps and train the model.

In [None]:
# @title Train step
@jax.jit
def train_step(
    state: train_state.TrainState,
    graph: jraph.GraphsTuple,
    labels: jnp.ndarray,
    mask: jnp.ndarray,
    rngs: dict,
) -> Tuple[train_state.TrainState, tuple]:
    """Performs one update step over the graph.

    Args:
        state: training state.
        graph: input graph (node features, edges, senders, and receivers).
        labels: edge labels (ratings).
        mask: mask selecting the edges used for optimization.
        rngs: PRNG keys, e.g. for dropout.

    Returns:
        The updated training state and a tuple of (loss, logits).
    """
    step = state.step
    rngs = {name: jax.random.fold_in(rng, step) for name, rng in rngs.items()}

    def loss_fn(params, graph, labels):
        # Compute logits and resulting loss.
        variables = {"params": params}
        logits = state.apply_fn(
            variables,
            graph=graph,
            rngs=rngs,
        )
        loss = l1_loss(logits=logits, y=labels, mask=mask)
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params, graph, labels)
    new_state = state.apply_gradients(grads=grads)
    return new_state, (loss, logits)
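
The pattern used in `train_step`, `jax.value_and_grad` with `has_aux=True`, can be seen in isolation on a much smaller problem. This sketch assumes a hypothetical linear model and a plain SGD update in place of `GraphConvLayer` and the optax Adam optimizer.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    preds = x @ params["w"] + params["b"]
    loss = jnp.mean(jnp.abs(preds - y))
    return loss, preds  # (value to differentiate, auxiliary output)


params = {"w": jnp.array([0.0, 0.0]), "b": jnp.array(0.0)}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])

# has_aux=True tells JAX that only the first returned value is differentiated;
# the second one (the predictions) is passed through unchanged.
(loss, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)

# Plain SGD update, used here only to keep the sketch dependency-free.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
new_loss, _ = loss_fn(new_params, x, y)
print(loss, new_loss)
```

The gradients have the same tree structure as `params`, which is what allows `tree_map` here, and `state.apply_gradients` in the tutorial, to update every parameter in one call.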


In [None]:
# @title Evaluation step
@jax.jit
def evaluate_step(
    state: train_state.TrainState,
    graph: jraph.GraphsTuple,
    labels: jnp.ndarray,
    mask: jnp.ndarray,
    dropout_rng: dict = None,
) -> tuple:
    """Performs an evaluation step over the graph and returns (loss, logits)."""
    variables = {"params": state.params}
    logits = state.apply_fn(
        variables,
        graph=graph,
        rngs=dropout_rng,
    )
    loss = l1_loss(logits=logits, y=labels, mask=mask)
    return (loss, logits)

In [None]:
# @title Train loop function
def train_eval(state, graph, train_mask, val_mask, rng, epochs = 10):
    final_train_loss = []
    final_val_loss = []
    for epoch in range(epochs):
        rng, epoch_rng = jax.random.split(rng)
        epoch_rng = {"dropout": epoch_rng}

        state, (train_loss, train_logits) = train_step(state=state, graph=graph, labels=graph.edges, mask=train_mask, rngs = epoch_rng)
        val_loss, val_logits = evaluate_step(state=state, graph=graph, mask=val_mask, labels=graph.edges)
        print(f"Epoch: {epoch}, train_loss: {train_loss}, val_loss: {val_loss}")
        final_train_loss.append(train_loss.item())
        final_val_loss.append(val_loss.item())

    return (state, final_train_loss, final_val_loss)

In [None]:
# train the model
epochs = 100
final_state, final_train_loss, final_val_loss = train_eval(state=state, graph=graph, train_mask=train_mask, val_mask=val_mask, rng=rng, epochs = epochs)

In [None]:
# Plotting losses
plt.figure(figsize=(10, 6))
plt.plot(range(epochs), final_train_loss, label="Train Loss", marker='o')
plt.plot(range(epochs), final_val_loss, label="Validation Loss", marker='*')
plt.title("Training and Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

#### **2. Inference**

In this section, we show how to recommend the top N movies for a specific user with the trained GNN. Since we already have a set of test edges, we use it to demonstrate inference. We will do the following:

1. Extract test edges: With the test_mask, obtain the edges that belong to the test set.
2. Predict ratings: For each of these test edges, use the trained GNN to predict a rating.
3. Map to original data: Using the movie_mapping and user_mapping from the GraphDataPreparation class, convert the predicted edges back to the original movie and user IDs.
4. Get movie names: For the mapped movie IDs, fetch the actual movie names.
5. Select the top N: For a particular user, sort the movies by predicted rating and pick the top N.


Below we will implement two functions to do the above steps for us.

In [None]:
def infer_movie_ratings(graph, state, test_mask, data_prep: GraphDataPreparation):
    """
    Infer movie ratings for the test set and map them to the original movies and users.

    Args:
        graph: The graph containing nodes and edges.
        state: The trained state of the model.
        test_mask: The mask indicating which edges belong to the test set.
        data_prep: The GraphDataPreparation instance used to process the data.

    Returns:
        A DataFrame containing user IDs, movie names, and predicted ratings.
    """

    # Using the test_mask, get the edges corresponding to the test set.
    # Convert to NumPy so the indices can be used as dictionary keys below.
    edge_mask = np.asarray(test_mask).squeeze()
    test_edge_senders = np.asarray(graph.senders)[edge_mask]
    # Movie nodes are indexed after the user nodes, so subtract the number of users.
    test_edge_receivers = np.asarray(graph.receivers)[edge_mask] - len(data_prep.user_mapping)

    # Predict the ratings using the trained model
    _, logits = evaluate_step(state, graph, labels=graph.edges, mask=test_mask)
    predicted_ratings = np.asarray(logits)[np.asarray(test_mask)]

    # Map back to original user and movie IDs
    reverse_user_mapping = {v: k for k, v in data_prep.user_mapping.items()}
    reverse_movie_mapping = {v: k for k, v in data_prep.movie_mapping.items()}

    original_user_ids = np.array([reverse_user_mapping[int(v)] for v in test_edge_senders]).astype(int)
    original_movie_ids = np.array([reverse_movie_mapping[int(v)] for v in test_edge_receivers]).astype(int)

    # Revert the remapping on the movie_id column of the movies_df
    movies_df_original = data_prep.movies_df.copy()
    movies_df_original['movie_id'] = movies_df_original['movie_id'].map(reverse_movie_mapping).astype(int)

    # Get movie names using the original movie_ids
    movie_title = movies_df_original.set_index('movie_id').loc[original_movie_ids, 'movie_title'].values

    # Create a DataFrame with the results
    result_df = pd.DataFrame({
        'user_id': original_user_ids,
        'movie_name': movie_title,
        'predicted_rating': np.round(predicted_ratings, 3)
    })

    return result_df

predicted_ratings_df = infer_movie_ratings(graph, final_state, test_mask, graph_prep)
predicted_ratings_df.head()

In [None]:
def get_top_n_recommendations(df, user_id, N=10):
    """
    Get the top N recommended movies for a user.

    Args:
        df (pd.DataFrame): DataFrame containing user IDs, movie titles, and predicted ratings.
        user_id (int): The user ID for whom the recommendations are to be fetched.
        N (int): The number of top movies to fetch.

    Returns:
        A list of the top N movie titles for the user.
    """

    # Filter out the movies for the given user and sort them based on predicted ratings in descending order
    top_movies = df[df['user_id'] == user_id].sort_values(by='predicted_rating', ascending=False).head(N)

    return top_movies['movie_name'].tolist()

# Choose a user ID and fetch the top N recommendations for that user
user_id = 600
N = 10
top_10_movies = get_top_n_recommendations(predicted_ratings_df, user_id, N)
print(f'Top {len(top_10_movies)} movies recommended for user ID: {user_id} are: ')

# Number the recommendations starting from 1
for number, movie_title in enumerate(top_10_movies, start=1):
    print(f'{number}: {movie_title}')