In [None]:
# Load IPython's autoreload extension and set mode 2, so imported modules are
# reloaded before each cell runs and project-code edits are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render inline figures at retina (high-DPI) resolution.
%config InlineBackend.figure_format='retina'

## Cross-validation of model performance

In this notebook, we cross-validate our GNN model's performance. In particular, we want to see whether the trends we observed are reproducible across different PRNGKeys.

## Load data

We start with a pickled version of our graph data.

In [None]:
# 1. Start a W&B run
# wandb.init(project='drosha-gnn', entity='ericmjl')

# 2. Save model inputs and hyperparameters
# config = wandb.config
# config.learning_rate = 0.01

# Model training here

# 3. Log metrics over time to visualize performance
# wandb.log({"loss": loss})
    

## Imports

In [None]:
from sklearn.pipeline import Pipeline

import pickle as pkl
from pyprojroot import here
from drosha_gnn.graph import to_networkx
from drosha_gnn import annotate
import pandas as pd
import janitor
import jax.numpy as np
import networkx as nx

## Read Data

### Raw Dataframe

In [None]:
# Location of the combined table; the query string is sent to the data server
# unchanged.
combined_csv_url = (
    "https://drosha-data.fly.dev/drosha/combined.csv"
    "?_stream=on&_sort=rowid&replicate__exact=1&_size=max"
)


def _structure_to_graph(row):
    """Turn one row's ``dot_bracket`` entry into a graph via ``to_networkx``."""
    return to_networkx(row["dot_bracket"])


# ``join_apply`` (pyjanitor) applies the function row-wise and stores the
# result in a new "graph" column.
df = pd.read_csv(combined_csv_url).join_apply(_structure_to_graph, "graph")
df.head()

### Nucleotide Entropy Data

In [None]:
# Nucleotide entropy table, fetched as CSV over HTTP; the query string is sent
# to the data server unchanged.
entropy_csv_url = (
    "https://drosha-data.fly.dev/drosha/entropy.csv"
    "?_labels=on&_stream=on&_sort=rowid&rowid__lte=847&_size=max"
)
entropy = pd.read_csv(entropy_csv_url)

In [None]:
# Display the entropy table loaded in the previous cell for a visual check.
entropy

In [None]:
from tqdm.auto import tqdm
from drosha_gnn.data import make_graph, make_graph_matrices
from drosha_gnn.data import prep_feats, prep_adjs, feat_matrix
# Two lookup tables keyed by the dataframe index: one holding the output of
# `make_graph`, the other the output of `make_graph_matrices`, for each sample.
graphs, graph_matrices = {}, {}
for sample_idx in tqdm(df.index):
    # Both helpers receive the sample's index plus the full tables.
    graphs[sample_idx] = make_graph(sample_idx, df, entropy)
    graph_matrices[sample_idx] = make_graph_matrices(sample_idx, df, entropy)

## Train test splits

We need different train-test splits in order to estimate the uncertainty in model performance.

In [None]:
from jax.random import PRNGKey, split

# Root random key for this notebook, built from a fixed seed so that the same
# key is produced on every run.
SEED = 99
key = PRNGKey(SEED)


In [None]:
from drosha_gnn.training import train_test_split
from drosha_gnn.data import split_graph_data

In [None]:
# Just test-driving: build one train/test split driven by `key` and check the
# shape of the training targets.
# Note that the arrays produced here are also the ones the later cells fit and
# evaluate on.
X_train, X_test, y_train, y_test = split_graph_data(key, graph_matrices, df)
y_train.shape

In [None]:
from drosha_gnn.models import AttentionEverywhereGNN

In [None]:

from jax.random import split

from drosha_gnn.models import make_model_and_params

# `key` has already been passed to `split_graph_data` earlier in the notebook.
# JAX keys should not be reused for independent random operations, so derive a
# fresh subkey for parameter initialisation instead of passing `key` again.
_, model_key = split(key)

# The same value is used for the first input dimension and for `num_nodes`;
# naming it keeps the two arguments in sync.
# presumably the per-graph node count after padding — TODO confirm
NUM_NODES = 170

model, params = make_model_and_params(
    model_key,
    AttentionEverywhereGNN,
    input_shape=(NUM_NODES, 2),
    num_nodes=NUM_NODES,
)

In [None]:
# Smoke test: run the freshly initialised model on a single training example
# (index 12 is an arbitrary pick).
model(params, X_train[12])

In [None]:
## Test-drive model
from jax import vmap
from functools import partial

# Bind `params` with `partial` to get a one-argument function, then `vmap` it
# over the leading axis of `X_train` so the model runs on every training
# example in one call.
vmap(partial(model, params))(X_train)

In [None]:
from drosha_gnn.training import fit, dmseloss, mseloss

In [None]:
# Fit the model on the training split. Judging by the names, `fit` returns the
# training losses, the recorded optimizer states, and a getter that a later
# cell passes to `best_params` together with those states — TODO confirm
# against `drosha_gnn.training.fit`.
losses_train, states, opt_get_params = fit(model, params, X_train, y_train)

In [None]:
# from jax.experimental.optimizers import adam
# from jax import jit, vmap
# from typing import Callable
# from jax.tree_util import Partial

from drosha_gnn.training import states_losses, best_params

# Select a parameter set (and its index) from the recorded `states`, scored
# with `mseloss` on the data passed in.
# NOTE(review): selection here uses X_test/y_test, and the following cells
# report error on that same split, so the reported number is likely optimistic.
# Consider selecting on a separate validation split — confirm against the
# implementation of `best_params`.
best_param, best_idx = best_params(states, model, X_test, y_test, opt_get_params, mseloss)

In [None]:
# `plt` was used here without ever being imported in the notebook (the only
# import is commented out in the last cell), which raises NameError on a
# fresh kernel. Import it where it is first needed.
import matplotlib.pyplot as plt

# Predictions of the selected parameters for every test example.
y_preds = vmap(partial(model, best_param))(X_test)

# Predicted vs. observed targets. With an equal aspect ratio, perfect
# predictions would fall on the 45-degree diagonal.
fig, ax = plt.subplots()
ax.scatter(y_preds.squeeze(), y_test.squeeze())
ax.set_xlabel("predicted")
ax.set_ylabel("observed")
ax.set_aspect("equal")

In [None]:
from drosha_gnn.training import mse

# Error between the test targets and the predictions computed in the previous cell.
mse(y_test, y_preds)

In [None]:
# Index returned by `best_params` alongside the selected parameters.
best_idx

In [None]:
# import matplotlib.pyplot as plt
# plt.plot(losses_train)
# plt.plot(test_losses)