## Introduction

We're now going to take a look at the Drosha measurements and how to featurize them onto graphs.

In [None]:
from pyprojroot import here
import pandas as pd

df_bioc = pd.read_csv(here() / "data/df_bioc.csv", index_col=0)
df_bioc.columns

There are a lot of columns in there; however, the ones we are most interested in are:

- `frac_avg`: Gives us the activity
- `dot_bracket`: Gives us the dot-bracket notation

Things that we may be interested in include:

- The `shannon_{pos}` series of columns, which give us the Shannon entropy at each position in the folded RNA.

Our goal here is to predict `frac_avg` (or some mathematical transform of it) from the `dot_bracket` structure.
Our hypothesis here is that the `dot_bracket` structure represented as a graph
gives us sufficient information to predict `frac_avg` accurately;
alternatively, we might want to add in the shannon entropy,
as we found previously that it was visually\* correlated with RNA cleavage (`frac_avg`).


> \* by visually correlated, we refer to Fig. 2 of [our previously-published paper](https://www.sciencedirect.com/science/article/abs/pii/S1097276520307358).

In [None]:
from drosha_gnn.graph import to_networkx
import janitor

## Make graphs from dot-bracket
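For intuition, a dot-bracket string can be read as a graph like this: consecutive positions are joined by backbone edges, and each matching `(`/`)` pair adds a base-pair edge. The sketch below is a hypothetical stand-in in plain Python; `dot_bracket_to_edges` is an illustrative name, and the actual `to_networkx` in `drosha_gnn` may build its graph differently.

```python
def dot_bracket_to_edges(dot_bracket: str):
    """Return (backbone_edges, pair_edges) for a dot-bracket string.

    Hypothetical helper for illustration only; not the drosha_gnn implementation.
    """
    # Backbone: each position is bonded to the next one along the strand.
    backbone = [(i, i + 1) for i in range(len(dot_bracket) - 1)]

    # Base pairs: match each ")" with the most recent unmatched "(".
    pairs = []
    stack = []
    for i, char in enumerate(dot_bracket):
        if char == "(":
            stack.append(i)
        elif char == ")":
            if not stack:
                raise ValueError(f"Unmatched ')' at position {i}")
            pairs.append((stack.pop(), i))
    if stack:
        raise ValueError(f"Unmatched '(' at position {stack[-1]}")
    return backbone, pairs


# A tiny hairpin: two stacked base pairs closing a two-nucleotide loop.
backbone_edges, pair_edges = dot_bracket_to_edges("((..))")
```

Both edge lists could then be added to a single `networkx` graph whose nodes are the sequence positions.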

In [None]:
from jax.scipy.special import logit
import jax.numpy as np

def logit_transform(value, tol: float = 1e-5):
    return logit(np.clip(value, tol, 1-tol))

df = (
    df_bioc
    .join_apply(lambda row: to_networkx(row["dot_bracket"]), "graph")
    .transform_column("frac_avg", logit_transform, "frac_avg_logit")
)
df.head()

In [None]:
import jax.numpy as np

def ecdf(data):
    x = np.sort(data)
    # Cumulative fraction of points at or below each sorted value: 1/n, 2/n, ..., 1.
    y = np.arange(1, len(data) + 1) / len(data)
    return x, y

In [None]:
import matplotlib.pyplot as plt

x, y = ecdf(df_bioc["frac_avg"].values)
plt.scatter(x, y)

We have a pretty even distribution of points here.

One thing that we'll definitely want to do is to regress on the logits,
so we'll have to transform the `frac_avg` column to logit space instead.
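As a worked example of why the clipping in `logit_transform` matters: the logit is log(p / (1 - p)), which diverges at p = 0 and p = 1. The sketch below uses plain NumPy (imported as `onp` so that it does not shadow the notebook's `jax.numpy as np`) instead of `jax.scipy.special.logit`; `clipped_logit` is an illustrative name.

```python
import numpy as onp


def clipped_logit(p, tol: float = 1e-5):
    """Logit of p after clipping p into [tol, 1 - tol]."""
    p = onp.clip(onp.asarray(p, dtype=float), tol, 1 - tol)
    # log(p / (1 - p)) written as a difference of logs.
    return onp.log(p) - onp.log1p(-p)


fractions = onp.array([0.0, 0.5, 1.0])
logits = clipped_logit(fractions)
# The endpoints map to large but finite values, symmetric around zero.
```

Without the clip, the two endpoints would come out as negative and positive infinity, which would break a squared-error loss.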

In [None]:
df_bioc.transform_column("frac_avg", logit_transform, "frac_avg_logit")["frac_avg_logit"]

We'll also need to annotate each node on the graph with its nucleotide identifier.
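To make the integer encoding concrete, here is a small sketch of the letter-to-index mapping on a made-up sequence. The indices start at 1 so that 0 stays free for padded nodes, which the embedding layer later in this notebook maps to a zero vector; `nucleotide_to_idx` and the example sequence are illustrative.

```python
# sorted("AUGC") gives alphabetical order, so the index does not depend on
# the order in which the letters were typed.
nucleotides = sorted("AUGC")  # ['A', 'C', 'G', 'U']

# Index 0 is left unused so that it can stand for "no nucleotide" (padding).
nucleotide_to_idx = {letter: i + 1 for i, letter in enumerate(nucleotides)}

sequence = "GGAUCC"  # made-up sequence for illustration
indices = [nucleotide_to_idx[letter] for letter in sequence]
```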

In [None]:
graphs = df["graph"].to_dict()

Next, we annotate each node with its nucleotide and its Shannon entropy.

In [None]:
import networkx as nx 

def annotate_nucleotide(G: nx.Graph, sequence: str):
    nucleotides = sorted("AUGC")
    for i, letter in enumerate(sequence):
        G.nodes[i]["nucleotide"] = letter
        G.nodes[i]["nucleotide_idx"] = nucleotides.index(letter) + 1
    return G

def annotate_entropy(G: nx.Graph, entropy_vector: np.ndarray):
    for node, entropy in zip(G.nodes(), entropy_vector):
        G.nodes[node]["entropy"] = entropy
    return G

In [None]:
def get_entropy_vector(df, row):
    r = df.loc[row]
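    # Note: plain `sorted` is lexicographic, so "shannon_10" would come before
    # "shannon_2" unless the position suffixes are zero-padded.
    entropy_cols = sorted([c for c in df.columns if "shannon" in c])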
    return r[entropy_cols].values

In [None]:
for idx, g in graphs.items():
    seq = df.loc[idx]["seq"]
    g = annotate_nucleotide(g, seq)
    
    entropy_vec = get_entropy_vector(df, idx)
    g = annotate_entropy(g, entropy_vec)
    graphs[idx] = g

In [None]:
graphs[761].nodes(data=True)

## Transformation to graph data structures

We're now going to make the feature matrix and adjacency matrix for each graph.
The key here is that we have to pad every graph to a common size
so that the matrices can be stacked and batched.
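As a small illustration of the padding, the sketch below pads a toy 3-node graph up to 5 nodes with plain NumPy (imported as `onp`). The graph, the feature values, and the target size are made up; the padded nodes end up with all-zero feature rows and no edges.

```python
import numpy as onp

size = 5  # padded number of nodes (hypothetical)

# A 3-node path graph 0-1-2 with two features per node.
A = onp.array([[0, 1, 0],
               [1, 0, 1],
               [0, 1, 0]])
F = onp.array([[1.0, 0.2],
               [3.0, 0.5],
               [2.0, 0.1]])

# Pad rows only for features; pad rows and columns for the adjacency matrix.
F_padded = onp.pad(F, [(0, size - F.shape[0]), (0, 0)])
A_padded = onp.pad(A, [(0, size - A.shape[0]), (0, size - A.shape[1])])
```

`onp.pad` fills with zeros by default, so the original entries sit in the top-left corner and everything else is zero.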

In [None]:
from jax import jit

def prep_feats(F, size):
    # F is of shape (n_nodes, n_feats)
    return np.pad(
        F,
        [
            (0, size - F.shape[0]),
            (0, 0)
        ],
    )

def prep_adjs(A, size):
    # A is of shape (n_nodes, n_nodes)
    return np.pad(
        A,
        [
            (0, size-A.shape[0]),
            (0, size-A.shape[1]),
        ],
    )

In [None]:
largest_graph_size = max(len(g) for g in graphs.values())

pd.Series(graphs)

In [None]:
def feat_matrix(G):
    feats = []
    for n, d in G.nodes(data=True):
        feat_vect = np.array([d["nucleotide_idx"], d["entropy"]])
        feats.append(feat_vect)
    feats = np.stack(feats)
    return feats

In [None]:
from tqdm.auto import tqdm
feat_matrices = dict()
for idx, graph in tqdm(graphs.items()):
    feat_matrices[idx] = prep_feats(feat_matrix(graph), largest_graph_size)


In [None]:
adj_matrices = dict()
for idx, graph in tqdm(graphs.items()):
    adj_matrices[idx] = prep_adjs(np.array(nx.adjacency_matrix(graph).todense()), largest_graph_size)

In [None]:
pd.Series(adj_matrices, name="adj")

In [None]:
pd.Series(feat_matrices, name="feats")

In [None]:
graph_matrices = dict()
# Look both matrices up by key rather than relying on the two dicts sharing an order.
for idx, adj in adj_matrices.items():
    graph_matrices[idx] = np.concatenate([adj, feat_matrices[idx]], axis=1)

Now, we can start designing a graph attention network to do this!

Firstly, we need a node embedding layer. 
For this, we will borrow inspiration from the language modelling world.
Our "vocabulary" is the letters "AUGC",
so we'll use a learnable embedding for each letter.
The first slot of every node feature vector holds an integer
that we can use to index into the embedding matrix.
We'll make the embedding vector length 256,
just for funzies.
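The lookup itself is just row selection. The sketch below shows it with plain NumPy (imported as `onp`) on a small made-up index vector; the embedding size is shrunk to 8 to keep the example readable, and the extra all-zeros first row is there so that padded nodes (index 0) get a zero embedding.

```python
import numpy as onp

rng_np = onp.random.default_rng(0)
vocab_size, embedding_size = 4, 8  # 8 instead of 256 to keep the example small

# One row per letter, plus a leading all-zeros row for padded (index 0) nodes.
embedding_matrix = onp.concatenate(
    [onp.zeros((1, embedding_size)),
     rng_np.normal(size=(vocab_size, embedding_size))]
)

# First feature column of a padded graph: three real nodes, two padding nodes.
indices = onp.array([3.0, 1.0, 4.0, 0.0, 0.0]).astype(int)
node_embeddings = onp.take(embedding_matrix, indices, axis=0)
```

Each node's integer picks out one row, so the result has one embedding vector per node.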

In [None]:
from jax import random

rng = random.PRNGKey(99)
# vocab_size = 4
# embedding_size = 256
# embedding_matrix = random.normal(rng, shape=(vocab_size, embedding_size))

# indices = np.array([0.0, 1.0, 1.0, 3.0, 2.0]).astype(int)
# np.take(embedding_matrix, indices, axis=0).shape

When it comes to GNN operations,
some involve the feature matrix only,
others involve the adjacency matrix only,
and yet others involve both the adjacency and feature matrices.

To simplify the representation of a graph,
let's pack everything into a single 2D matrix
of size `(num_nodes, num_nodes + num_features)`.
What do its two parts mean?

- The (num_nodes, num_nodes) portion (left side of the matrix) is the adjacency matrix.
- The (num_nodes, num_features) portion (right side of the matrix) is the feature matrix.

In each step, we can accept the entire thing as one piece and then split it accordingly.

Let's call this matrix the "graph matrix".

Because one graph is one sample, the shape of a single sample is `(num_nodes, num_nodes + num_features)`.
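A toy version of the graph matrix makes the layout easier to see. The sketch below, in plain NumPy (imported as `onp`) with made-up values, glues an adjacency matrix and a feature matrix together and then recovers both by slicing at column `num_nodes`.

```python
import numpy as onp

num_nodes, num_features = 3, 2
adj = onp.array([[0.0, 1.0, 0.0],
                 [1.0, 0.0, 1.0],
                 [0.0, 1.0, 0.0]])
feats = onp.arange(num_nodes * num_features, dtype=float).reshape(
    num_nodes, num_features
)

# Adjacency on the left, features on the right.
graph_mat = onp.concatenate([adj, feats], axis=1)

# Splitting on column `num_nodes` recovers both pieces.
adj_again = graph_mat[:, :num_nodes]
feats_again = graph_mat[:, num_nodes:]
```

Because the adjacency block is always `num_nodes` columns wide, a layer only needs to know `num_nodes` to split the matrix.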

In [None]:
from jax import random

In [None]:
def select_feats(graph_mat, num_nodes: int):
    return graph_mat[:, num_nodes:]

def select_adj(graph_mat, num_nodes: int):
    return graph_mat[:, :num_nodes]

def RnaGraphEmbedding(num_nodes: int, embedding_size: int):
    vocab_size = 4
    def init_fun(rng, input_shape):
        """
        :param input_shape: Shape tuple whose first entry is ``num_nodes``.
            The calls in this notebook pass (num_nodes, num_features), e.g. (170, 2).
        """
        num_nodes, num_nodes_features = input_shape
        num_features = num_nodes_features - num_nodes

        embedding_matrix = random.normal(rng, shape=(vocab_size, embedding_size))
        # Add a zeros vector to the beginning for padded vector.
        embedding_matrix = np.concatenate([np.zeros((1, embedding_size)), embedding_matrix])
        return (num_nodes, num_nodes + embedding_size,), embedding_matrix
    
    def apply_fun(params, inputs, **kwargs):
        """
        :param inputs: The graph matrix, (num_nodes, num_nodes + num_features).
            We assume that the first column of its feature portion
            is the embedding index.
        """
        embedding_matrix = params
        adj = select_adj(inputs, num_nodes)
        feats = select_feats(inputs, num_nodes)

        indices = np.take(feats, 0, axis=1).astype(int)
        embedding = np.take(embedding_matrix, indices, axis=0)
        
        output = np.concatenate([adj, embedding], axis=1)
        return output
        
    return init_fun, apply_fun

In [None]:
init_fun, apply_fun = RnaGraphEmbedding(num_nodes=170, embedding_size=256)
output_shape, params = init_fun(rng, input_shape=(170, 2))

out = apply_fun(params, (graph_matrices[763]))
# out[0].shape, out[1].shape
out.shape

We also need a layer that simply extracts the rest of the node features.

In [None]:
from jax import lax
def NodeFeatureExtractor(num_nodes: int):
    def init_fun(rng, input_shape):
        num_nodes, num_feats = input_shape
        return (num_nodes, num_nodes + input_shape[-1] - 1,), ()
    
    def apply_fun(params, inputs, **kwargs):
        adj = select_adj(inputs, num_nodes)
        feats = select_feats(inputs, num_nodes)

        return np.concatenate([adj, feats[:, 1:]], axis=1)
    
    return init_fun, apply_fun
    
    
init_fun, apply_fun = NodeFeatureExtractor(num_nodes=170)
_, params = init_fun(rng, (170, 2))
out = apply_fun(params, (graph_matrices[763]))
out.shape, _

Now we can do the fan-out operation.
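Conceptually, fan-out hands the same graph matrix to several branches, each branch transforms the feature portion in its own way, and a fan-in step then stitches the feature portions back together next to a single copy of the adjacency matrix. The NumPy sketch below mimics that data flow with two made-up branches; it does not use `stax`, and `branch_scale` and `branch_drop_first` are illustrative stand-ins, not the layers defined in this notebook.

```python
import numpy as onp

num_nodes = 3
adj = onp.eye(num_nodes)
feats = onp.array([[1.0, 0.2], [3.0, 0.5], [2.0, 0.1]])
graph_mat = onp.concatenate([adj, feats], axis=1)


def branch_scale(g):
    # Keeps the adjacency, doubles every feature column.
    return onp.concatenate([g[:, :num_nodes], 2 * g[:, num_nodes:]], axis=1)


def branch_drop_first(g):
    # Keeps the adjacency, drops the first feature column.
    return onp.concatenate([g[:, :num_nodes], g[:, num_nodes + 1:]], axis=1)


# Fan-out: every branch receives the same input.
branch_outputs = [branch(graph_mat) for branch in (branch_scale, branch_drop_first)]

# Fan-in: one adjacency matrix, all branch features side by side.
merged_feats = onp.concatenate([b[:, num_nodes:] for b in branch_outputs], axis=1)
merged = onp.concatenate([branch_outputs[0][:, :num_nodes], merged_feats], axis=1)
```

The merged result is again a graph matrix, so it can be passed to any layer that expects one.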

In [None]:
from jax.experimental import stax  # on newer JAX releases: from jax.example_libraries import stax

In [None]:
node_featurization = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        RnaGraphEmbedding(num_nodes=170, embedding_size=256),
        NodeFeatureExtractor(num_nodes=170),
    ),
)

In [None]:
# Test-drive
init_fun, apply_fun = node_featurization
output_shape, params = init_fun(rng, input_shape=(170, 2,))

inputs = apply_fun(params, (graph_matrices[763]))
output_shape

In [None]:
from jax.tree_util import tree_map
from functools import partial


def GraphFanInConcat(num_nodes: int, axis: int = -1):
    def init_fun(rng, input_shape):
        # Each branch reports (num_nodes, num_nodes + branch_feats);
        # sum the per-branch feature counts with plain Python ints.
        num_feats = sum(int(i[1]) - num_nodes for i in input_shape)
        return (num_nodes, num_nodes + num_feats), ()

    def apply_fun(params, inputs, **kwargs):
        # Use the layer's own num_nodes rather than a hard-coded 170.
        adj = tree_map(partial(select_adj, num_nodes=num_nodes), inputs)
        feats = tree_map(partial(select_feats, num_nodes=num_nodes), inputs)
        feats = np.concatenate(feats, axis=1)
        return np.concatenate([adj[0], feats], axis=1)

    return init_fun, apply_fun


init_fun, apply_fun = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        RnaGraphEmbedding(num_nodes=170, embedding_size=256),
        NodeFeatureExtractor(num_nodes=170),
    ),
    GraphFanInConcat(num_nodes=170)
)
output_shape, params = init_fun(rng, input_shape=(170, 2))
out = apply_fun(params, graph_matrices[763])

After that, we build the graph attention layer.
I've written this layer a few times,
but I'd like to write it in a fashion
that makes sense for this problem.

The graph attention layer accepts a graph matrix.
It then computes a node-by-node similarity matrix based on the node information.
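One way to obtain a node-by-node score is to first pair every node's features with every other node's features, which a small network can then reduce to one number per pair. The NumPy sketch below builds that pairing with broadcasting on a toy feature matrix; it only illustrates the shapes involved, and the notebook's own version uses `vmap`.

```python
import numpy as onp

node_feats = onp.array([[1.0, 10.0],
                        [2.0, 20.0],
                        [3.0, 30.0]])  # shape (n_nodes, n_features)
n_nodes, n_features = node_feats.shape

# left[i, j] = f_i and right[i, j] = f_j, both (n_nodes, n_nodes, n_features).
left = onp.broadcast_to(node_feats[:, None, :], (n_nodes, n_nodes, n_features))
right = onp.broadcast_to(node_feats[None, :, :], (n_nodes, n_nodes, n_features))

# pairs[i, j] = (f_i, f_j), shape (n_nodes, n_nodes, 2 * n_features).
pairs = onp.concatenate([left, right], axis=-1)
```

Reducing the last axis of `pairs` to a single number then gives an `(n_nodes, n_nodes)` matrix that can be multiplied elementwise with the adjacency matrix.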


In [None]:
from functools import partial
from jax import vmap

def concat_nodes(node1, node2):
    """Concatenate two nodes together."""
    return np.concatenate([node1, node2])


def concatenate(node: np.ndarray, node_feats: np.ndarray):
    """Concatenate node with each node in node_feats.

    Behaviour is as follows.
    Given a node with features `f_0` and stacked node features
    `[f_0, f_1, f_2, ..., f_N]`,
    return a stacked concatenated feature array:
    `[(f_0, f_0), (f_0, f_1), (f_0, f_2), ..., (f_0, f_N)]`.
    
    :param node: A vector embedding of a single node in the graph.
        Should be of shape (n_input_features,)
    :param node_feats: Stacked vector embedding of all nodes in the graph.
        Should be of shape (n_nodes, n_input_features)
    :returns: A stacked array of concatenated node features.
    """
    return vmap(partial(concat_nodes, node))(node_feats)


def concatenate_node_features(node_feats):
    """Return node-by-node concatenated features.
    
    Given a node feature matrix of shape (n_nodes, n_features),
    this returns a matrix of shape (n_nodes, n_nodes, 2*n_features).
    """
    outputs = vmap(partial(concatenate, node_feats=node_feats))(node_feats)
    return outputs

# Pair up the feature portion only; `out` is a full graph matrix.
outputs = concatenate_node_features(select_feats(out, num_nodes=170))


In [None]:
from jax import nn

def AttentiveMessagePassingLayer(num_nodes: int, hidden_dims: int):
    """Attentive message passing on a graph.
    
    We use a feed forward neural network to learn
    the weights on which a message passing operator should work.
    
    The input is the graph matrix. Should be of size (num_nodes, num_nodes + num_feats).
    The output is also of the size (num_nodes, num_nodes + num_feats).
    """

    def init_fun(rng, input_shape):
        num_nodes, n_node_feats = input_shape
        num_feats = n_node_feats - num_nodes
        k1, k2, k3, k4 = random.split(rng, 4)
        
        # Params for neural network transformation of node concatenated features.
        w1 = random.normal(k1, shape=(2 * num_feats, hidden_dims)) * 0.001
        b1 = random.normal(k2, shape=(hidden_dims,)) * 0.001
        w2 = random.normal(k3, shape=(hidden_dims,)) * 0.001
        b2 = random.normal(k4, shape=(1,)) * 0.001
        
        params = w1, b1, w2, b2
        output_shape = (num_nodes, num_nodes + num_feats)
        return output_shape, params
    
    def apply_fun(params, inputs, **kwargs):
        ### START ATTENTIVE MATRIX CALCULATION ###
        w1, b1, w2, b2 = params
        adj = select_adj(inputs, num_nodes)
        feats = select_feats(inputs, num_nodes)
        node_by_node_concat = concatenate_node_features(feats)
        
        # Neural network piece here.
        a1 = nn.relu(np.dot(node_by_node_concat, w1) + b1)
        a2 = np.dot(a1, w2) + b2

        attentive_adj = adj * a2
        ### END ATTENTIVE MATRIX CALCULATION ###
        mp = np.dot(attentive_adj, feats)
        return np.concatenate([adj, mp], axis=1)

    return init_fun, apply_fun

Now, we need an attentive graph summation layer.

If we define attention as just "sample-driven, fancy ways of calculating linear weighting...",
then we can use a neural network to calculate weights per sample.

The input is the graph matrix of shape `(num_nodes, num_nodes + num_feats)`.
Inside this function, we take the feature portion of the matrix and do a neural network forward pass
to produce a vector of shape `(num_nodes,)`,
which we can call the "attentive weights".
Finally, we take the dot product of the attentive weights
with the feature matrix portion of the graph matrix
to arrive at a weighted sum over nodes,
which is effectively a graph-level vector of shape `(num_feats,)`.
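The final weighted sum is a single dot product. The NumPy sketch below (with `onp` as the NumPy alias) uses made-up weights in place of the network's output to show the shapes; note that a padded node whose feature row is all zeros contributes nothing, whatever weight it receives.

```python
import numpy as onp

feats = onp.array([[1.0, 2.0],
                   [3.0, 4.0],
                   [0.0, 0.0]])              # last row is a padded node
node_weights = onp.array([0.5, -0.25, 0.9])  # made-up weights, one per node

# Weighted sum over nodes: (num_nodes,) . (num_nodes, num_feats) -> (num_feats,)
graph_vector = onp.dot(node_weights, feats)
```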

In [None]:
def AttentiveGraphSummation(num_nodes, hidden_dims: int = 2048):
    def init_fun(rng, input_shape):
        num_nodes, num_node_feats = input_shape
        num_feats = num_node_feats - num_nodes
        
        k1, k2, k3, k4 = random.split(rng, 4)
        # Params for the neural network that maps node features to attention weights.
        w1 = random.normal(k1, shape=(num_feats, hidden_dims)) * 0.001
        b1 = random.normal(k2, shape=(hidden_dims,)) * 0.001
        w2 = random.normal(k3, shape=(hidden_dims, num_nodes)) * 0.001
        b2 = random.normal(k4, shape=(num_nodes,)) * 0.001
        params = w1, b1, w2, b2
        output_shape = (num_feats,)
        return output_shape, params
    
    def apply_fun(params, inputs, **kwargs):
        w1, b1, w2, b2 = params
        feats = select_feats(inputs, num_nodes)
        
        # Neural network piece here.
        a1 = nn.relu(np.dot(feats, w1) + b1)
        a2 = np.tanh(np.dot(a1, w2) + b2)
        node_attn_weights = np.tanh(np.sum(a2, axis=0))

        # Weighted summation happens here
        out = np.dot(node_attn_weights, feats)
        return out
    return init_fun, apply_fun

In [None]:
def AttentionEverywhereGNN(num_nodes: int):

    init_fun, apply_fun = stax.serial(
        stax.FanOut(2),
        stax.parallel(
            RnaGraphEmbedding(num_nodes=num_nodes, embedding_size=256),
            NodeFeatureExtractor(num_nodes=num_nodes),
        ),
        GraphFanInConcat(num_nodes=num_nodes),
        AttentiveMessagePassingLayer(num_nodes=num_nodes, hidden_dims=256),
        AttentiveGraphSummation(num_nodes=num_nodes),
        stax.Dense(256),
        stax.Relu,
        stax.Dense(256),
        stax.Relu,
        stax.Dense(1),
    )
    
    return init_fun, apply_fun

init_fun, apply_fun = AttentionEverywhereGNN(170)

output_shape, params = init_fun(rng, input_shape=(170, 2))
out = apply_fun(params, graph_matrices[763])


In [None]:
out

## Train neural network


We can now train the model!
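Training amounts to repeatedly nudging the parameters against the gradient of the loss. As a minimal sketch of that loop, the NumPy example below fits a toy linear model with a hand-derived MSE gradient; the linear model, the synthetic data, the learning rate, and the step count are all assumptions for illustration, and in this notebook the gradient would instead come from `jax.grad` applied to the loss.

```python
import numpy as onp

rng_np = onp.random.default_rng(0)
X_toy = rng_np.normal(size=(50, 3))
true_w = onp.array([1.0, -2.0, 0.5])
y_toy = X_toy @ true_w


def mse_toy(w):
    return onp.mean((y_toy - X_toy @ w) ** 2)


def mse_grad(w):
    # d/dw mean((y - Xw)^2) = -2/n * X^T (y - Xw)
    return -2.0 / len(X_toy) * X_toy.T @ (y_toy - X_toy @ w)


w = onp.zeros(3)
learning_rate = 0.1
initial_loss = mse_toy(w)
for step in range(200):
    # Plain gradient descent: step against the gradient of the loss.
    w = w - learning_rate * mse_grad(w)
final_loss = mse_toy(w)
```

With a JAX model, the same update would be applied to every array in the parameter tree rather than to a single weight vector.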

In [None]:
X = np.stack(pd.Series(graph_matrices).values)
y = np.stack(df["frac_avg_logit"].values).reshape(-1, 1)

In [None]:
def train_test_split(rng, X, y, train_fraction=0.7):
    indices = np.arange(len(X))
    indices = random.permutation(rng, indices)
    num_train = int(len(X) * train_fraction)
    train_idxs = indices[:num_train]
    test_idxs = indices[num_train:]
    return train_idxs, test_idxs

train_idxs, test_idxs = train_test_split(rng, X, y)
X_train = X[train_idxs]
X_test = X[test_idxs]
y_train = y[train_idxs]
y_test = y[test_idxs]


In [None]:
from jax import grad, vmap

## Training loop
def mse(y_true: np.ndarray, y_pred: np.ndarray):
    return np.mean(np.power(y_true - y_pred, 2))


def mseloss(params, model, X, y):
    """MSE loss."""
    y_pred = vmap(partial(model, params))(X)
    return mse(y, y_pred)


dmseloss = grad(mseloss)
init_fun, model = AttentionEverywhereGNN(170)
_, params = init_fun(rng, input_shape=(170, 2))

mseloss(params, model, X, y)