# Introduction
Here we train neural networks to predict the GERP conservation score at each position from the surrounding DNA sequence. We try several network architectures and report the results of each.


The score is symmetric under inversion (reverse complementation), which implies that the local window should extend equally on both sides of the scored position: if a sequence affects scores in only one direction, its reverse complement will affect them in the opposite direction.
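To make the symmetry argument concrete, here is a minimal sketch. It assumes each window is one-hot encoded as (positions, 4) in A, C, G, T channel order and then flattened; the actual encoding inside `get_train_test_x_y` is not shown, so this is an assumption. Under that ordering, reverse-complementing a window is a flip of both axes, and with an odd window length the scored centre base stays at the centre. Averaging a predictor over a window and its reverse complement then makes it exactly symmetric. `encode`, `reverse_complement_onehot` and `symmetrize` are hypothetical helpers, not part of the `data` or `score_modeling` modules.

```python
import jax
import jax.numpy as jnp

# Assumption: windows are one-hot (positions, 4) with channels A, C, G, T,
# flattened to positions * 4 features.
def encode(seq, alphabet="ACGT"):
    """Hypothetical helper: one-hot encode a DNA string and flatten it."""
    idx = jnp.array([alphabet.index(base) for base in seq])
    return jax.nn.one_hot(idx, len(alphabet)).reshape(-1)

def reverse_complement_onehot(x_flat, n_channels=4):
    """Reverse-complement a flattened one-hot DNA window."""
    x = x_flat.reshape(-1, n_channels)
    # Flipping the position axis reverses the sequence; flipping the channel
    # axis maps A<->T and C<->G under the A, C, G, T ordering.
    return x[::-1, ::-1].reshape(-1)

def symmetrize(predict_fn):
    """Wrap a predictor so its output is invariant to reverse complementation."""
    def wrapped(x_flat):
        return 0.5 * (predict_fn(x_flat) + predict_fn(reverse_complement_onehot(x_flat)))
    return wrapped
```

For example, the encoding of `ACGTT` maps to the encoding of `AACGT`, and applying the flip twice returns the original window.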

## Dependencies

In [40]:
import logging
from data.load import read_sequence, read_annotation_generator, examine_annotation, read_gerp_scorer
from data.paths import chr17_paths # paths to source data files
from data.process import get_train_test_x_y

from score_modeling import LocalWindowModel, LocalTransformerEncoderModel, ModelTrainer

import jax
import jax.numpy as jnp
import pandas as pd
import optax

In [4]:
# random seed
SEED = 200

# display options
pd.set_option('display.max_rows', 100)
pd.set_option('display.precision', 3)

In [19]:
%load_ext autoreload
%autoreload 2



## Load data

In [6]:
# start the analysis with human chromosome 17
paths = chr17_paths

# get the raw sequence dictionary, from a FASTA file
seq_dict = read_sequence(paths.sequence)

# # (optional) examine annotations, from the annotation GFF file
# examine_annotation(paths.annotation)

# get the annotated sequence generator function, from the annotation GFF file
seq_records_gen = read_annotation_generator(paths.annotation, seq_dict=seq_dict)

# get GERP retrieval function, from the BigWig file
gerp_scorer = read_gerp_scorer(paths.gerp)

In [53]:
# One-hot encode the sequence in a sliding window over CDS features and split into train/test sets
x_train, y_train, x_test, y_test = get_train_test_x_y(seq_records_gen, gerp_scorer, ['CDS'])

INFO:root:Sequence 17 ...
INFO:root:Traversing feature tree ...
INFO:root:Traversed feature tree in 202,358 iterations. Extracted 54,544 features.
INFO:root:Processing feature 0. x_train shape: (0, 76)
INFO:root:Processing feature 100. x_train shape: (9992, 76)
INFO:root:Processing feature 200. x_train shape: (15608, 76)
INFO:root:Processing feature 300. x_train shape: (23115, 76)
INFO:root:Processing feature 400. x_train shape: (30534, 76)
INFO:root:Processing feature 500. x_train shape: (37718, 76)
INFO:root:Processing feature 600. x_train shape: (45469, 76)
INFO:root:Processing feature 700. x_train shape: (52701, 76)


KeyboardInterrupt: 
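The log above shows 76 features per example. One reading consistent with that shape is a 19-base window (9 bases either side of the scored position) times 4 one-hot channels; this is inferred from the shape, not stated by the source. A minimal sketch of such an encoding in JAX, using a hypothetical helper `sliding_window_onehot`, assuming A, C, G, T channel order:

```python
import jax
import jax.numpy as jnp

# Assumption: half_width=9 gives a 19-base window, i.e. 19 * 4 = 76 features.
def sliding_window_onehot(seq, half_width=9, alphabet="ACGT"):
    """Hypothetical helper: one row of flattened one-hot bases per window."""
    # str.find returns -1 for bases outside the alphabet (e.g. N);
    # jax.nn.one_hot encodes out-of-range indices as all-zero rows.
    idx = jnp.array([alphabet.find(b) for b in seq.upper()])
    onehot = jax.nn.one_hot(idx, len(alphabet))  # (L, 4)
    width = 2 * half_width + 1
    n = max(onehot.shape[0] - width + 1, 0)  # number of complete windows
    # Row i of `positions` holds the indices i, i+1, ..., i+width-1
    positions = jnp.arange(n)[:, None] + jnp.arange(width)[None, :]
    windows = onehot[positions]  # (n, width, 4)
    return windows.reshape(n, width * len(alphabet))
```

How the real pipeline treats ambiguous bases such as N is not shown; encoding them as zeros here is an assumption.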

# Train Models
Here we train several neural network architectures and evaluate each with the mean squared error (MSE) loss and with R^2, the proportion of the variance in the target scores that is explained by the model.
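The two metrics are directly related: R^2 = 1 - MSE / Var(y), the same formula the test-set evaluation cell uses. A predictor that always outputs the mean of y scores 0, and a model that does worse than that scores below 0, which is why R^2 can be negative at epoch 0. A small sketch with hypothetical helpers `mse` and `r_squared`:

```python
import jax.numpy as jnp

def mse(y_pred, y_true):
    """Mean squared error."""
    return jnp.mean((y_pred - y_true) ** 2)

def r_squared(y_pred, y_true):
    """Fraction of the variance of y_true explained: 1 - MSE / Var(y)."""
    return 1.0 - mse(y_pred, y_true) / jnp.var(y_true)
```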

## Local Window Fully Connected Model

In [11]:
model1 = LocalWindowModel()

In [21]:
model_trainer = ModelTrainer(model1, epochs=10000, optimizer=optax.sgd)

INFO:root:Epoch 0. Training loss: 7.9045. R^2: -0.1179.
INFO:root:Epoch 10. Training loss: 7.0631. R^2: 0.0011.
INFO:root:Epoch 20. Training loss: 7.0472. R^2: 0.0033.
INFO:root:Epoch 30. Training loss: 7.0077. R^2: 0.0089.
INFO:root:Epoch 40. Training loss: 6.8796. R^2: 0.0270.
INFO:root:Epoch 50. Training loss: 6.5478. R^2: 0.0740.
INFO:root:Epoch 60. Training loss: 6.4314. R^2: 0.0904.
INFO:root:Epoch 70. Training loss: 6.9480. R^2: 0.0174.
INFO:root:Epoch 80. Training loss: 6.4350. R^2: 0.0899.
INFO:root:Epoch 90. Training loss: 6.3442. R^2: 0.1028.
INFO:root:Epoch 100. Training loss: 6.7097. R^2: 0.0511.
INFO:root:Epoch 110. Training loss: 6.4268. R^2: 0.0911.
INFO:root:Epoch 120. Training loss: 6.3589. R^2: 0.1007.
INFO:root:Epoch 130. Training loss: 6.3735. R^2: 0.0986.
INFO:root:Epoch 140. Training loss: 6.3497. R^2: 0.1020.
INFO:root:Epoch 150. Training loss: 6.3288. R^2: 0.1049.
INFO:root:Epoch 160. Training loss: 6.3154. R^2: 0.1068.
INFO:root:Epoch 170. Training loss: 6.303

LocalWindowModel(
  layers=[
    Linear(
      weight=f32[38,76],
      bias=f32[38],
      in_features=76,
      out_features=38,
      use_bias=True
    ),
    Linear(
      weight=f32[38,38],
      bias=f32[38],
      in_features=38,
      out_features=38,
      use_bias=True
    ),
    Linear(
      weight=f32[38,38],
      bias=f32[38],
      in_features=38,
      out_features=38,
      use_bias=True
    ),
    Linear(
      weight=f32[1,38],
      bias=f32[1],
      in_features=38,
      out_features=1,
      use_bias=True
    )
  ],
  extra_bias=f32[1]
)
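The printed repr shows four linear layers, 76 → 38 → 38 → 38 → 1, plus a separate `extra_bias` parameter. Below is a plain-JAX sketch of a network with those shapes. The ReLU activation, the uniform initialization, and adding `extra_bias` to the output are assumptions, since the repr does not show them; `init_mlp` and `mlp_forward` are hypothetical names.

```python
import jax
import jax.numpy as jnp

LAYER_SIZES = [76, 38, 38, 38, 1]  # shapes taken from the printed repr

def init_mlp(key, sizes=LAYER_SIZES):
    """Initialise (weight, bias) pairs with weight shape (out, in), as in the repr."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        bound = n_in ** -0.5  # assumption: uniform(-1/sqrt(n_in), 1/sqrt(n_in))
        w = jax.random.uniform(k, (n_out, n_in), minval=-bound, maxval=bound)
        params.append((w, jnp.zeros((n_out,))))
    extra_bias = jnp.zeros((1,))
    return params, extra_bias

def mlp_forward(params, extra_bias, x):
    """Forward pass for a single flattened window x of shape (76,)."""
    for w, b in params[:-1]:
        x = jax.nn.relu(w @ x + b)  # activation is an assumption
    w, b = params[-1]
    return w @ x + b + extra_bias  # assumption: extra_bias shifts the output
```

With these shapes the network has 5,930 parameters, small relative to the tens of thousands of training windows in the log above. A batch can be evaluated with `jax.vmap(lambda x: mlp_forward(params, extra_bias, x))(x_batch)`.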

In [34]:
model_trainer.train(x_train, y_train)

INFO:root:Epoch 0. Training loss: 3.4208. R^2: 0.5162.
INFO:root:Epoch 100. Training loss: 3.2799. R^2: 0.5361.
INFO:root:Epoch 200. Training loss: 3.2976. R^2: 0.5336.
INFO:root:Epoch 300. Training loss: 3.4027. R^2: 0.5188.
INFO:root:Epoch 400. Training loss: 3.2364. R^2: 0.5423.
INFO:root:Epoch 500. Training loss: 3.4585. R^2: 0.5109.
INFO:root:Epoch 600. Training loss: 3.4584. R^2: 0.5109.
INFO:root:Epoch 700. Training loss: 3.4705. R^2: 0.5092.
INFO:root:Epoch 800. Training loss: 3.1952. R^2: 0.5481.
INFO:root:Epoch 900. Training loss: 3.3950. R^2: 0.5199.
INFO:root:Epoch 1,000. Training loss: 3.2941. R^2: 0.5341.
INFO:root:Epoch 1,100. Training loss: 3.3498. R^2: 0.5262.
INFO:root:Epoch 1,200. Training loss: 3.2208. R^2: 0.5445.
INFO:root:Epoch 1,300. Training loss: 3.2711. R^2: 0.5374.
INFO:root:Epoch 1,400. Training loss: 3.6242. R^2: 0.4874.
INFO:root:Epoch 1,500. Training loss: 3.5987. R^2: 0.4911.
INFO:root:Epoch 1,600. Training loss: 3.2915. R^2: 0.5345.
INFO:root:Epoch 1,7

KeyboardInterrupt: 

In [36]:
# evaluate on held-out data; R^2 = 1 - MSE / Var(y)
loss_test = model_trainer.model.loss(x_test, y_test)
var_y = jnp.var(y_test)
logging.info(f'Test loss {loss_test:.3f}. R^2: {1-loss_test/var_y:.3f}.')

INFO:root:Test loss 3.742. R^2: 0.483.


## Local Transformer Model

In [49]:
model2 = LocalTransformerEncoderModel()
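The internals of `LocalTransformerEncoderModel` are not shown either. The sketch below illustrates one way a transformer encoder could consume the same 76-feature input: treat the window as 19 tokens of 4 one-hot channels (inferred from the shape), embed them, apply one head of scaled dot-product self-attention with a residual connection, and read the prediction off the centre token. The embedding size, single head, learned position embeddings and centre-token readout are all assumptions, and `init_params` and `forward` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Assumptions: 19 positions x 4 channels = 76 features, embedding size 16.
WINDOW, CHANNELS, D_MODEL = 19, 4, 16

def init_params(key, scale=0.1):
    ks = jax.random.split(key, 6)
    return {
        "embed": scale * jax.random.normal(ks[0], (CHANNELS, D_MODEL)),
        "pos": scale * jax.random.normal(ks[1], (WINDOW, D_MODEL)),  # learned positions
        "wq": scale * jax.random.normal(ks[2], (D_MODEL, D_MODEL)),
        "wk": scale * jax.random.normal(ks[3], (D_MODEL, D_MODEL)),
        "wv": scale * jax.random.normal(ks[4], (D_MODEL, D_MODEL)),
        "readout": scale * jax.random.normal(ks[5], (D_MODEL,)),
    }

def forward(params, x_flat):
    """Predict a score of shape (1,) for one flattened window."""
    tokens = x_flat.reshape(WINDOW, CHANNELS) @ params["embed"] + params["pos"]  # (W, D)
    q, k, v = tokens @ params["wq"], tokens @ params["wk"], tokens @ params["wv"]
    attn = jax.nn.softmax(q @ k.T / jnp.sqrt(D_MODEL), axis=-1)  # (W, W), rows sum to 1
    h = tokens + attn @ v  # residual connection
    centre = h[WINDOW // 2]  # token for the scored position
    return jnp.atleast_1d(centre @ params["readout"])
```

Reading the output from the centre token keeps the prediction tied to the scored position while letting attention draw on both sides of the window.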

In [50]:
model2(x_train[0])

Array([2.5000381], dtype=float32)

In [51]:
model2.loss(x_train, y_train)

Array(11.390394, dtype=float32)

In [46]:
model_trainer2 = ModelTrainer(model2, epochs=100, optimizer=optax.sgd)

In [47]:
model_trainer2.train(x_train, y_train)

INFO:root:Epoch 0. Training loss: 11.3904. R^2: -0.6109.


KeyboardInterrupt: 