##### This is ongoing notebook
Few things needs to be improved: 
- Create the graph using pure jax numpy and remove the few pytorch related things (torch.nn.Embedding to flax.nn.Embed)
- Making the graph Bidirectional instead Unidirectional (Optional)
- Training batchwise, instead of full pass
- Checking a bit more GraphMapFeatures, GraphNetwork classes from jraph and how they do the message passing
- Look more into the objective formulation. I.e. shall we do link prediction in unsupervised way, or since we have the rating use the rating as edge labels.
- And finally make this notebook as Indaba_Prac Template

In [1]:
import torch
import jax
import jraph
import networkx as nx
import pandas as pd
import numpy as np
from typing import Sequence
from flax import linen as nn
import jax.numpy as jnp
from data_prep import UniGraphDataPreparation
from data_prep import subset_dataset

In [2]:

# Load the dataset
ratings_df = pd.read_csv('../ml-25m/ratings.csv')
movies_df = pd.read_csv('../ml-25m/movies.csv')
ratings_df.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,296,5.0,1147880044
1,1,306,3.5,1147868817
2,1,307,5.0,1147868828
3,1,665,5.0,1147878820
4,1,899,3.5,1147868510


In [3]:
ratings_subset_df, movies_subset_df = subset_dataset(ratings_df=ratings_df, movies_df=movies_df)
ratings_subset_df.shape, movies_subset_df.shape

((4027019, 5), (1000, 3))

In [4]:
len(ratings_subset_df.userId.unique()), len(ratings_subset_df.movieId.unique())

(10000, 1000)

In [5]:
del ratings_df
del movies_df

In [6]:
graph = UniGraphDataPreparation(ratings_df = ratings_subset_df, movies_df = movies_subset_df)
train_graph, test_graph = graph.prepare_data()

In [7]:
train_graph.nodes.shape, test_graph.nodes.shape

((11000, 19), (11000, 19))

In [8]:
# node features
train_graph.nodes

array([[-1.4612067 , -0.2764672 , -1.2072662 , ...,  0.9139635 ,
        -1.2479409 , -1.8211188 ],
       [ 1.1204472 , -1.35316   ,  0.84818536, ...,  0.6455621 ,
         0.6479968 , -0.6489184 ],
       [-0.7850939 ,  1.2695131 , -0.65635425, ..., -0.09871044,
        -0.76065737,  1.56671   ],
       ...,
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 1.        ,  1.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 1.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]], dtype=float32)

In [9]:
# check source nodes, in our case this corresponds to users
train_graph.senders

array([4066, 6204, 4034, ..., 2560, 9460, 4233])

In [10]:
# check destination nodes, in our case this corresponds to movies
train_graph.senders

array([4066, 6204, 4034, ..., 2560, 9460, 4233])

In [45]:
# Adopted from https://github.com/deepmind/jraph/blob/master/jraph/ogb_examples/train_flax.py

class MLP(nn.Module):
  """A flax MLP."""
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate([nn.Dense(feat) for feat in self.features]):
      x = lyr(x)
      # if i != len(self.features) - 1:
      x = nn.relu(x)
    return x

class EdgePredictor(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, edge_features):
    x = edge_features
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat)(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
    return x


def make_embed_fn(latent_size):
  def embed(inputs):
    return nn.Dense(latent_size)(inputs)
  return embed


def make_mlp(features):
  @jraph.concatenated_args
  def update_fn(inputs):
    return MLP(features)(inputs)
  return update_fn


class GraphNetwork(nn.Module):
  """A flax GraphNetwork."""
  mlp_features: Sequence[int]
  edge_pred_features: Sequence[int]
  latent_size: int

  @nn.compact
  def __call__(self, graph):
    # Add a global parameter.
    graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1])) 

    embedder = jraph.GraphMapFeatures(
        embed_node_fn=make_embed_fn(self.latent_size),
        embed_edge_fn=make_embed_fn(self.latent_size))

    net = jraph.GraphNetwork(
        update_node_fn=make_mlp(self.mlp_features),
        update_edge_fn=make_mlp(self.mlp_features),
        )
      
    updated_graph = net(embedder(graph))
    edge_predictions = EdgePredictor(self.edge_pred_features)(updated_graph.edges)
    return edge_predictions

In [46]:
net = GraphNetwork(mlp_features=(128, 128), latent_size=128, edge_pred_features=(64, 1))
# Initialize the network.
params = net.init(jax.random.PRNGKey(42), test_graph)
output = net.apply(params, test_graph)
output.shape, test_graph.edges.shape

((805404, 1), (805404, 1))

In [48]:
def l1_loss(logits: np.ndarray, y: np.ndarray, reduction: str = "mean") -> np.ndarray:
    """Implementation of l1_loss.

    Args:
        logits: model output logits.
        y: class labels.
        reduction: if reduction is mean, the average is returned, else if it is sum, the sum is returned.

    Returns:
       l1 loss.
    """
    if reduction == "mean":
        loss = jnp.mean(jnp.abs(logits - y))
    if reduction == "sum":
        loss = jnp.sum(jnp.abs(logits - y))

    return loss

l1_loss(output, test_graph.edges)

Array(3.1936445, dtype=float32)

In [55]:
from flax.training import train_state
from typing import Tuple
import optax

def create_train_state(
    model, graph, tx, rngs
):
    """Train state. This function initializes the model."""

    @jax.jit
    def initialize(params_rng):
        variables = model.init(
            params_rng,
            graph,
            # train=False,
        )
        return variables

    variables = initialize(rngs)
    state = train_state.TrainState.create(apply_fn=model.apply, params=variables["params"], tx=tx)

    param_count = sum(x.size for x in jax.tree_util.tree_leaves(state.params))
    print("--- number of model parameters: ", param_count)
    return state

optimizer = optax.adam(learning_rate=1e-4)
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
rngs = {"params": rng, "dropout": init_rng}

model = GraphNetwork(mlp_features=(128, 128), latent_size=128, edge_pred_features=(64, 1))
state = create_train_state(
    model=model,
    graph=test_graph,
    tx=optimizer,
    rngs=rngs,
)

--- number of model parameters:  142977


In [56]:
@jax.jit
def train_step(
    state: train_state.TrainState,
    graph: jnp.array,
    labels: jnp.ndarray,
    rngs: dict,
) -> Tuple[train_state.TrainState, tuple]:
    """Performs one update step over the graph.

    Args:
        state: training state.
        graph: graph node features.
        labels: graph classification labels.
        rngs: rngs for droupout

    Returns:
        Current training state, the loss, and logits.
    """
    step = state.step
    rngs = {name: jax.random.fold_in(rng, step) for name, rng in rngs.items()}

    def loss_fn(params, graph, labels):
        # Compute logits and resulting loss.
        variables = {"params": params}
        logits = state.apply_fn(
            variables,
            graph=graph,
            rngs=rngs,
        )
        loss = l1_loss(logits, labels)
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params, graph, labels)
    new_state = state.apply_gradients(grads=grads)
    return new_state, (loss, logits)


In [57]:
@jax.jit
def evaluate_step(
    state: train_state.TrainState,
    graph: jnp.array,
    labels: jnp.ndarray,
    dropout_rng: dict = None,
) -> tuple:
    """Performs evaluation step over a set of inputs."""
    # Get predicted logits, and corresponding probabilities.
    variables = {"params": state.params}
    logits = state.apply_fn(
        variables,
        graph=graph,
        rngs=dropout_rng,
    )
    loss = l1_loss(logits, labels)
    return (loss, logits)

In [58]:

def train_eval(state, train_graph, test_graph, rng):
    # Train for 10 epochs
    for epoch in range(4):
        # Train for one epoch.
        rng, epoch_rng = jax.random.split(rng)
        epoch_rng = {"dropout": epoch_rng}

        state, (train_loss, train_logits) = train_step(state=state, graph=train_graph, labels=train_graph.edges, rngs = epoch_rng)

        test_loss, test_logits = evaluate_step(state=state, graph=test_graph, labels=test_graph.edges)
        print(f"Epoch: {epoch}, train_loss: {train_loss}, val_loss: {test_loss}")

    return state

final_state = train_eval(state, train_graph, test_graph, rng)

Epoch: 0, train_loss: 3.847727060317993, val_loss: 3.684267997741699
Epoch: 1, train_loss: 3.6857426166534424, val_loss: 3.5305914878845215
Epoch: 2, train_loss: 3.5319674015045166, val_loss: 3.385692834854126
Epoch: 3, train_loss: 3.3869986534118652, val_loss: 3.2482481002807617
