In [1]:
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import mean_squared_error
from torch import from_numpy
from torch.utils.data import DataLoader, TensorDataset

from neural_net import ResidualDegrade
from preprocessing import one_hot_encode_sequences, read_all_data, \
    read_original_predictions
from util import train_network

In [2]:
NETWORK = ResidualDegrade
PARAMS = {
    'stage4_conv_channels': 198,
    'stage3_pool_kernel_size': 8,
    'stage2_conv_kernel_size': 3,
    'stage1_conv_kernel_size': 7,
    'stage1_conv_channels': 97
}
DEG_MODEL = "a_minus"

# Data pre-processing

In [3]:
# Load data

df = read_all_data("data/ss_out.txt",
                   "data/3U_sequences_final.txt",
                   "data/3U.models.3U.40A.seq1022_param.txt",
                   "data/3U.models.3U.00A.seq1022_param.txt")
df

Unnamed: 0_level_0,sequence,secondary_structure,free_energy,secondary_structure_prob,log2_deg_rate_a_plus,log2_x0_a_plus,onset_time_a_plus,log2_deg_rate_a_minus,log2_x0_a_minus,onset_time_a_minus
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
S0_M_T1,TGTCCCCGGGTCTTCCAACGGACTGGCGTTGCCCCGGTTCACTGGG...,.(.((((((((....(((((......)))))((((((....)))))...,-45.92,".(.((((((((.,,,{{..(|||{((|{..,{{||||,,,.}))))...",,,,,,
S0_M_T10,AGATTTTTGGTTCAATATGCTCCTTGAGTGGAGTCTTAGTGATTGC...,........(((((.....(((((......)))))...(((((..((...,-31.17,"........(((((,.,..(((((......))))}...{({({..({...",-2.7469,2.7887,1.0,-2.1721,2.5964,1.0
S0_M_T100,ACCCGGCGCCGCTCGACCCGGAGCGAGGAGTTGACCCGGAGCGAGG...,....((((((.((((..((((..(((....)))..))))..))))....,-41.95,"....((((((.((((..((((.,({(....})).,))))..)}}),...",,,,,,
S0_M_T1000,ATGAGGGCTGGAATTTGCATTGAAACACTGGTCCAGTCGCTGTGTA...,.....(((((((...((........))....)))))))..(((......,-23.18,"...,((((((((.,.({........}}..,,))))))).,|((,,....",,,,,,
S0_M_T1001,CCTTAGTGCCCTTAAAATAATGATTTAAGCATTTTACTGTATGTAT...,....(.((((((((((.(((.(......((.(((((((((.(.......,-30.84,"....{.(((((((((({(((.{......((.((((((((,.{.......",,,,,,
...,...,...,...,...,...,...,...,...,...,...
S3_H_T9995,TTGTAGCTGTCAATTGTATTTAATATACTTTTTTGTCTTTTTAATT...,((((..(((.((((((.((((((...........(((............,-18.50,",{(((((((.,(({{((((.....}}}}......(((............",,,,,,
S3_H_T9996,AAAACACCACTACATATGTTTCTCATAAGCGCAACTGTAGTGTTAT...,((((((((((((((..(((..((....)).)))..))))))).......,-19.40,"((((((((((((((..(((..,.....}}.)})..))))))).......",-2.5808,3.4966,1.0,-2.3105,3.3307,1.0
S3_H_T9997,AGGATTTTTTTTTTCACCAATGCTCTTTAATACACACTTGCCTATA...,.(((..((((((((.......((................))........,-20.95,",((,((((((((({...,,,,({..............,,)}....)...",,,,,,
S3_H_T9998,GGTGCTTCAAAGAGTGATTACCCACTAACTAATGAACCCAGACTGT...,((((..(((.....))).))))..((((..((((((((((((.(((...,-28.53,"((((..(((.....))).))))..((((..(((((((((,((.(((...",,,,,,


In [4]:
train_df = df.dropna()
train_df

Unnamed: 0_level_0,sequence,secondary_structure,free_energy,secondary_structure_prob,log2_deg_rate_a_plus,log2_x0_a_plus,onset_time_a_plus,log2_deg_rate_a_minus,log2_x0_a_minus,onset_time_a_minus
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
S0_M_T10,AGATTTTTGGTTCAATATGCTCCTTGAGTGGAGTCTTAGTGATTGC...,........(((((.....(((((......)))))...(((((..((...,-31.17,"........(((((,.,..(((((......))))}...{({({..({...",-2.746900,2.7887,1.0,-2.17210,2.5964,1.0
S0_M_T1006,TAGATAGAGATCATCTTTACAGTTCCTCGGGAAAATGTGCTTGTGA...,...(((((((((((...((((.((((...))))..))))...))))...,-26.55,".,,({(((({((((...((((.((((...))))..))))...))))...",-2.495200,3.5146,1.0,-1.94970,3.1963,1.0
S0_M_T1009,TAGTTATTGTGTGTTGCTAATCATTGACTGTAGTCCCAGTCTGGGA...,.....(((((((((((.(((.((..(((((.(((..((((((((((...,-33.05,"...,,(((((((((((.{((.{(..(((((.(((..((((((((((...",-2.550700,2.7105,1.0,-1.51500,2.8747,1.0
S0_M_T1013,TGATTCTAGTATATAATATTTTTGTCACGCACCTGCTGACTTAGGA...,.......................((((.((....))))))...((....,-20.70,".......,{,,............((((.((....))))))..,((....",-2.327900,3.7761,1.0,-1.89040,3.2967,1.0
S0_M_T1014,TTCTAGACTTTCCAAGTATGTTGTCTTTCCAATGGTGCGACAGAGC...,.............(((..(((((((((......(((((......))...,-22.82,"......,,,,...|||,.,(((((,,{{((.{((((((......))...",-1.623200,1.6160,4.8,-2.09580,2.1356,1.0
...,...,...,...,...,...,...,...,...,...,...
S3_H_T9985,GTCCTTATTTACATGTTTCATTGAGCCCTTTTTGATGTGATTCTTG...,.............((((..(((((.........(((((((.((((....,-18.22,",....,,,,...,((((,.((({{.........(((((((.((((....",-2.027200,2.6826,1.0,-1.59040,2.5908,1.0
S3_H_T9987,TCAATGGTTACAGGTTTCAAACATTCTTCAAAATATCTTCTTTTTG...,(((((((((((((((((.......((..(((((........)))))...,-30.56,(((((((((((((((((.......{{..(((((........)))))...,-2.589200,2.7555,1.0,-2.05310,2.5282,1.0
S3_H_T9989,TGAAAGCACAGAGGGGCTGAGATTCTAAGGGCACTTCATGTTTTTT...,.(((((((..((((.(((...........))).)))).)))))))....,-24.03,".((((({(.,((((,(((.,,...,....))).)))).))))))),...",0.023414,1.2948,4.5,-0.75861,1.4902,3.0
S3_H_T9990,AATTAAAGAGAGAGAGAGACGGAGAACACGGTGGGTTTACTAGCGC...,.........((.(((((((((.((....((.(((.....))).)))...,-27.22,".........{(.(((((((((.,{....((.(((.....))).)),...",-2.463100,1.5812,1.0,-2.71790,1.2301,1.0


In [5]:
train_sequences = one_hot_encode_sequences(train_df["sequence"])
train_sequences

array([[[1., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 1., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 1.]],

       [[0., 1., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [1., 0., 0., ..., 1., 0., 1.]],

       [[0., 1., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 1.]],

       ...,

       [[0., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 1.],
        [1., 0., 0., ..., 1., 1., 0.]],

       [[1., 1., 0., ..., 1., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.],
        [0., 0., 1., ..., 0., 0., 0.]],

       [[1., 1., 1., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 1., 0.]]], dtype=float32)

In [6]:
X_train = from_numpy(train_sequences)
y_train = from_numpy(train_df[f"log2_deg_rate_{DEG_MODEL}"].to_numpy(np.float32).reshape(-1, 1))

In [7]:
compare_df, a_minus_clip, a_plus_clip = \
    read_original_predictions("data/models_full_dg.txt")

In [8]:
# We'll be testing against the sequences for which we do not have
# a calculated degradation rate
test_index = df.index.difference(train_df.index)

test_df = compare_df.loc[test_index]
test_df["sequence"] = df["sequence"].loc[test_index]
test_df

Unnamed: 0_level_0,a_plus,a_minus,sequence
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
S0_M_T1,-1.12880,-1.1807,TGTCCCCGGGTCTTCCAACGGACTGGCGTTGCCCCGGTTCACTGGG...
S0_M_T100,-1.58920,-1.5481,ACCCGGCGCCGCTCGACCCGGAGCGAGGAGTTGACCCGGAGCGAGG...
S0_M_T1000,-0.26523,-1.4689,ATGAGGGCTGGAATTTGCATTGAAACACTGGTCCAGTCGCTGTGTA...
S0_M_T1001,-0.76394,-3.3179,CCTTAGTGCCCTTAAAATAATGATTTAAGCATTTTACTGTATGTAT...
S0_M_T1002,-2.11520,-2.8701,GTAGGCCATGATAATAGGTCATATGTTGTGTTTGGTTCTGTGTTCA...
...,...,...,...
S3_H_T9994,-1.43950,-2.5863,TTTGGCTATAGAATCAGGCGGCCGTTTTATGTGGGATTTGACGACC...
S3_H_T9995,-0.96243,-3.6068,TTGTAGCTGTCAATTGTATTTAATATACTTTTTTGTCTTTTTAATT...
S3_H_T9997,-1.32350,-2.6724,AGGATTTTTTTTTTCACCAATGCTCTTTAATACACACTTGCCTATA...
S3_H_T9998,-1.99170,-2.3379,GGTGCTTCAAAGAGTGATTACCCACTAACTAATGAACCCAGACTGT...


In [9]:
test_sequences = one_hot_encode_sequences(test_df["sequence"])
test_sequences

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 1., 0.],
        [1., 0., 1., ..., 1., 0., 1.]],

       [[1., 0., 0., ..., 0., 0., 1.],
        [0., 1., 1., ..., 1., 1., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[1., 0., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.],
        [0., 0., 1., ..., 0., 0., 1.],
        [0., 1., 0., ..., 0., 0., 0.]],

       ...,

       [[1., 0., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 1., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 1., 0.]],

       [[0., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 1., 1., 0.],
        [0., 0., 1., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [0., 1., 0., ..., 1., 1., 1.]]], dtype=float32)

In [10]:
X_test = from_numpy(test_sequences)
y_test = from_numpy(test_df[DEG_MODEL].to_numpy(np.float32).reshape(-1, 1))

# NN

In [11]:
network = NETWORK(**PARAMS)
network

ResidualDegrade(
  (stage1): Sequential(
    (0): Conv1d(4, 97, kernel_size=(7,), stride=(1,))
    (1): BatchNorm1d(97, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (stage2): ResidualLayer(
    (conv1): Conv1d(97, 97, kernel_size=(3,), stride=(1,), padding=(1,))
    (norm1): BatchNorm1d(97, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(97, 97, kernel_size=(3,), stride=(1,), padding=(1,))
    (norm2): BatchNorm1d(97, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (stage3): AvgPool1d(kernel_size=(8,), stride=(8,), padding=(0,))
  (stage4): Sequential(
    (0): Conv1d(97, 198, kernel_size=(13,), stride=(1,))
    (1): BatchNorm1d(198, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (stage5): Linear(in_features=198, out_features=1, bias=True)
)

In [12]:
train_loader = DataLoader(TensorDataset(X_train, y_train),
                          batch_size=4, shuffle=True)
optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
train_network(network, train_loader, 5, nn.MSELoss(), optimizer)

Epoch 1
Epoch 2
Epoch 3
Epoch 4
Epoch 5


In [13]:
with torch.no_grad():
    y_pred = network(X_test)

In [14]:
mean_squared_error(y_true=y_test, y_pred=y_pred)

0.2820798