In [1]:
import numpy as np
import pandas as pd
import re
from scipy.sparse import coo_matrix, vstack

# Data pre-processing

In [2]:
with open("data/ss_out.txt", mode="r") as f:
    raw_data = f.read()

In [3]:
# Read file into a DataFrame

PATTERN = re.compile(r"""^([UCAG]{110})
([.()]{110}) \( {0,2}(-?[0-9]{1,2}\.[0-9]{2})\)
([.,|(){}]{110})""", re.IGNORECASE | re.MULTILINE)

data = (match.groups() for match in PATTERN.finditer(raw_data))
data = ((sequence, secondary_structure, float(free_energy), secondary_structure_prob)
        for (sequence, secondary_structure, free_energy, secondary_structure_prob) in data)

df = pd.DataFrame.from_records(data, columns=["sequence", "secondary_structure", "free_energy", "secondary_structure_prob"])

In [4]:
# Drop duplicates and transform sequences by replacing U with T

df.drop_duplicates(inplace=True)
df.reset_index(drop=True, inplace=True)
df["sequence"] = df["sequence"].str.replace("U", "T")
df

Unnamed: 0,sequence,secondary_structure,free_energy,secondary_structure_prob
0,TGTCCCCGGGTCTTCCAACGGACTGGCGTTGCCCCGGTTCACTGGG...,.(.((((((((....(((((......)))))((((((....)))))...,-45.92,".(.((((((((.,,,{{..(|||{((|{..,{{||||,,,.}))))..."
1,AGATTTTTGGTTCAATATGCTCCTTGAGTGGAGTCTTAGTGATTGC...,........(((((.....(((((......)))))...(((((..((...,-31.17,"........(((((,.,..(((((......))))}...{({({..({..."
2,ACCCGGCGCCGCTCGACCCGGAGCGAGGAGTTGACCCGGAGCGAGG...,....((((((.((((..((((..(((....)))..))))..))))....,-41.95,"....((((((.((((..((((.,({(....})).,))))..)}}),..."
3,ATGAGGGCTGGAATTTGCATTGAAACACTGGTCCAGTCGCTGTGTA...,.....(((((((...((........))....)))))))..(((......,-23.18,"...,((((((((.,.({........}}..,,))))))).,|((,,...."
4,CCTTAGTGCCCTTAAAATAATGATTTAAGCATTTTACTGTATGTAT...,....(.((((((((((.(((.(......((.(((((((((.(.......,-30.84,"....{.(((((((((({(((.{......((.((((((((,.{......."
...,...,...,...,...
89995,TTGTAGCTGTCAATTGTATTTAATATACTTTTTTGTCTTTTTAATT...,((((..(((.((((((.((((((...........(((............,-18.50,",{(((((((.,(({{((((.....}}}}......(((............"
89996,AAAACACCACTACATATGTTTCTCATAAGCGCAACTGTAGTGTTAT...,((((((((((((((..(((..((....)).)))..))))))).......,-19.40,"((((((((((((((..(((..,.....}}.)})..)))))))......."
89997,AGGATTTTTTTTTTCACCAATGCTCTTTAATACACACTTGCCTATA...,.(((..((((((((.......((................))........,-20.95,",((,((((((((({...,,,,({..............,,)}....)..."
89998,GGTGCTTCAAAGAGTGATTACCCACTAACTAATGAACCCAGACTGT...,((((..(((.....))).))))..((((..((((((((((((.(((...,-28.53,"((((..(((.....))).))))..((((..(((((((((,((.(((..."


In [5]:
# Read the sequence IDs

sequence_ids = \
  pd.read_table("data/3U_sequences_final.txt",
                header=None,
                names=["id", "sequence"],
                dtype={"id": str},
                # Drop the common prefix and suffix, and upcase the string
                converters={"sequence": lambda string: string[20:-20].upper()},
                index_col="sequence")

assert sequence_ids.index.is_unique
assert sequence_ids["id"].is_unique

sequence_ids

Unnamed: 0_level_0,id
sequence,Unnamed: 1_level_1
GATCAAATGCTAAAGAAAATATTGGTTTTAGTAATAATCTCTATGCTGAATTTAACTTTGGGAGATGCTGAAATTATTGAGGGTTAACATTACCGTTAAGTATTGAATCT,S1_H_T1
ATCTGGTAAATTAGGTTGATTTCTGGTTATGGAAAAAGCGCGAAAATGGGTCAGCAGTGTTCTTATTAAAATGAATTTCATAATAAATCAATTCAAGTAACGTGTACCTG,S1_H_T2
TAACTGAGCCTTATGATTATGACATTTGACTGAAGTATTTGTTTTTATTGTAATTCTGTTTATTTTTACACTTGCAAATAATTAATAAAACCAAGAAAGAGTATTTACAA,S1_H_T3
CTTGTGTACGACGAACTCAGAAGCCGCAAATAGGAGACTGTTTTCAGTTTTCTAGTTTGGACCCTTGCAAACAAGACCCTTTTTTGCGTCTGGTGTCGGAGGTGTTCATC,S1_H_T4
ATTTAAGATGTTTTCTCACGTTTGTATTCGCTTTTAATTAGGATGCAATGAAATTAAACCTTGATCTGATATTTCACTTTTCTTTAAATATAGACATGGACGAGCAGCTC,S1_H_T5
...,...
GTGGGCGGTGTGGACAGCGTGTCTGAGAGCACTGGCAGCATCCTCAGCAAGCTGGTCTGGAATGCCATCGAAGACATGGTGGCCAGCGTGGAGGACCAGGGCCTGTCTGT,S0_M_T1318
GCTCGTTAACAGCTGCTGTAACTAGTCTGGCCTACAATAGTGTGATTCATGTAGGACTTCTTTCATCAATTCAAAACCCCTAGAAAACGTATACAGATTATATAAGTAGG,S0_M_T1319
GCTCGTTAACAGCTGCTGTAACTAGTCTGGCCTACAATAGTGTGATTCATGTAGGCCTTCTTTCATCAATTCAAAACCCCTAGAAAACGTATACAGATTATATAAGTAGG,S0_M_T1320
AAGGGATGGTCCACATCAGAAAACTCACTAAATGTCATGTTAGAATCCCACATGGACTGCATGTGACCACCTACCATCCCTTTAGTACAAATTAAGCTATTAAAAATACA,S0_M_T1321


In [6]:
# Add ID to the main DataFrame

df = df.join(sequence_ids, on="sequence")

assert df["id"].is_unique
assert (df.sort_values(by="id", axis="index").reset_index()[["sequence", "id"]] ==
        sequence_ids.sort_values(by="id", axis="index").reset_index()) \
       .all().all()

df.set_index("id", inplace=True)

df

Unnamed: 0_level_0,sequence,secondary_structure,free_energy,secondary_structure_prob
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
S0_M_T1,TGTCCCCGGGTCTTCCAACGGACTGGCGTTGCCCCGGTTCACTGGG...,.(.((((((((....(((((......)))))((((((....)))))...,-45.92,".(.((((((((.,,,{{..(|||{((|{..,{{||||,,,.}))))..."
S0_M_T10,AGATTTTTGGTTCAATATGCTCCTTGAGTGGAGTCTTAGTGATTGC...,........(((((.....(((((......)))))...(((((..((...,-31.17,"........(((((,.,..(((((......))))}...{({({..({..."
S0_M_T100,ACCCGGCGCCGCTCGACCCGGAGCGAGGAGTTGACCCGGAGCGAGG...,....((((((.((((..((((..(((....)))..))))..))))....,-41.95,"....((((((.((((..((((.,({(....})).,))))..)}}),..."
S0_M_T1000,ATGAGGGCTGGAATTTGCATTGAAACACTGGTCCAGTCGCTGTGTA...,.....(((((((...((........))....)))))))..(((......,-23.18,"...,((((((((.,.({........}}..,,))))))).,|((,,...."
S0_M_T1001,CCTTAGTGCCCTTAAAATAATGATTTAAGCATTTTACTGTATGTAT...,....(.((((((((((.(((.(......((.(((((((((.(.......,-30.84,"....{.(((((((((({(((.{......((.((((((((,.{......."
...,...,...,...,...
S3_H_T9995,TTGTAGCTGTCAATTGTATTTAATATACTTTTTTGTCTTTTTAATT...,((((..(((.((((((.((((((...........(((............,-18.50,",{(((((((.,(({{((((.....}}}}......(((............"
S3_H_T9996,AAAACACCACTACATATGTTTCTCATAAGCGCAACTGTAGTGTTAT...,((((((((((((((..(((..((....)).)))..))))))).......,-19.40,"((((((((((((((..(((..,.....}}.)})..)))))))......."
S3_H_T9997,AGGATTTTTTTTTTCACCAATGCTCTTTAATACACACTTGCCTATA...,.(((..((((((((.......((................))........,-20.95,",((,((((((((({...,,,,({..............,,)}....)..."
S3_H_T9998,GGTGCTTCAAAGAGTGATTACCCACTAACTAATGAACCCAGACTGT...,((((..(((.....))).))))..((((..((((((((((((.(((...,-28.53,"((((..(((.....))).))))..((((..(((((((((,((.(((..."


In [7]:
# Read degradation rates

def load_deg_rate(filename: str) -> pd.DataFrame:
    return pd.read_table(filename,
                         header=None,
                         names=["id", "log2_deg_rate", "log2_x0", "onset_time"],
                         dtype={"id": str},
                         index_col="id")

deg_rate_a_plus = load_deg_rate("data/3U.models.3U.40A.seq1022_param.txt")
deg_rate_a_minus = load_deg_rate("data/3U.models.3U.00A.seq1022_param.txt")
assert (deg_rate_a_plus.index == deg_rate_a_minus.index).all()

deg_rate_a_plus = deg_rate_a_plus.add_suffix("_a_plus")
deg_rate_a_minus = deg_rate_a_minus.add_suffix("_a_minus")

In [8]:
# Add degradation rates to the main DataFrame

df = df.join([deg_rate_a_plus, deg_rate_a_minus])
df

Unnamed: 0_level_0,sequence,secondary_structure,free_energy,secondary_structure_prob,log2_deg_rate_a_plus,log2_x0_a_plus,onset_time_a_plus,log2_deg_rate_a_minus,log2_x0_a_minus,onset_time_a_minus
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
S0_M_T1,TGTCCCCGGGTCTTCCAACGGACTGGCGTTGCCCCGGTTCACTGGG...,.(.((((((((....(((((......)))))((((((....)))))...,-45.92,".(.((((((((.,,,{{..(|||{((|{..,{{||||,,,.}))))...",,,,,,
S0_M_T10,AGATTTTTGGTTCAATATGCTCCTTGAGTGGAGTCTTAGTGATTGC...,........(((((.....(((((......)))))...(((((..((...,-31.17,"........(((((,.,..(((((......))))}...{({({..({...",-2.7469,2.7887,1.0,-2.1721,2.5964,1.0
S0_M_T100,ACCCGGCGCCGCTCGACCCGGAGCGAGGAGTTGACCCGGAGCGAGG...,....((((((.((((..((((..(((....)))..))))..))))....,-41.95,"....((((((.((((..((((.,({(....})).,))))..)}}),...",,,,,,
S0_M_T1000,ATGAGGGCTGGAATTTGCATTGAAACACTGGTCCAGTCGCTGTGTA...,.....(((((((...((........))....)))))))..(((......,-23.18,"...,((((((((.,.({........}}..,,))))))).,|((,,....",,,,,,
S0_M_T1001,CCTTAGTGCCCTTAAAATAATGATTTAAGCATTTTACTGTATGTAT...,....(.((((((((((.(((.(......((.(((((((((.(.......,-30.84,"....{.(((((((((({(((.{......((.((((((((,.{.......",,,,,,
...,...,...,...,...,...,...,...,...,...,...
S3_H_T9995,TTGTAGCTGTCAATTGTATTTAATATACTTTTTTGTCTTTTTAATT...,((((..(((.((((((.((((((...........(((............,-18.50,",{(((((((.,(({{((((.....}}}}......(((............",,,,,,
S3_H_T9996,AAAACACCACTACATATGTTTCTCATAAGCGCAACTGTAGTGTTAT...,((((((((((((((..(((..((....)).)))..))))))).......,-19.40,"((((((((((((((..(((..,.....}}.)})..))))))).......",-2.5808,3.4966,1.0,-2.3105,3.3307,1.0
S3_H_T9997,AGGATTTTTTTTTTCACCAATGCTCTTTAATACACACTTGCCTATA...,.(((..((((((((.......((................))........,-20.95,",((,((((((((({...,,,,({..............,,)}....)...",,,,,,
S3_H_T9998,GGTGCTTCAAAGAGTGATTACCCACTAACTAATGAACCCAGACTGT...,((((..(((.....))).))))..((((..((((((((((((.(((...,-28.53,"((((..(((.....))).))))..((((..(((((((((,((.(((...",,,,,,


In [9]:
df.dropna(inplace=True)
df

Unnamed: 0_level_0,sequence,secondary_structure,free_energy,secondary_structure_prob,log2_deg_rate_a_plus,log2_x0_a_plus,onset_time_a_plus,log2_deg_rate_a_minus,log2_x0_a_minus,onset_time_a_minus
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
S0_M_T10,AGATTTTTGGTTCAATATGCTCCTTGAGTGGAGTCTTAGTGATTGC...,........(((((.....(((((......)))))...(((((..((...,-31.17,"........(((((,.,..(((((......))))}...{({({..({...",-2.746900,2.7887,1.0,-2.17210,2.5964,1.0
S0_M_T1006,TAGATAGAGATCATCTTTACAGTTCCTCGGGAAAATGTGCTTGTGA...,...(((((((((((...((((.((((...))))..))))...))))...,-26.55,".,,({(((({((((...((((.((((...))))..))))...))))...",-2.495200,3.5146,1.0,-1.94970,3.1963,1.0
S0_M_T1009,TAGTTATTGTGTGTTGCTAATCATTGACTGTAGTCCCAGTCTGGGA...,.....(((((((((((.(((.((..(((((.(((..((((((((((...,-33.05,"...,,(((((((((((.{((.{(..(((((.(((..((((((((((...",-2.550700,2.7105,1.0,-1.51500,2.8747,1.0
S0_M_T1013,TGATTCTAGTATATAATATTTTTGTCACGCACCTGCTGACTTAGGA...,.......................((((.((....))))))...((....,-20.70,".......,{,,............((((.((....))))))..,((....",-2.327900,3.7761,1.0,-1.89040,3.2967,1.0
S0_M_T1014,TTCTAGACTTTCCAAGTATGTTGTCTTTCCAATGGTGCGACAGAGC...,.............(((..(((((((((......(((((......))...,-22.82,"......,,,,...|||,.,(((((,,{{((.{((((((......))...",-1.623200,1.6160,4.8,-2.09580,2.1356,1.0
...,...,...,...,...,...,...,...,...,...,...
S3_H_T9985,GTCCTTATTTACATGTTTCATTGAGCCCTTTTTGATGTGATTCTTG...,.............((((..(((((.........(((((((.((((....,-18.22,",....,,,,...,((((,.((({{.........(((((((.((((....",-2.027200,2.6826,1.0,-1.59040,2.5908,1.0
S3_H_T9987,TCAATGGTTACAGGTTTCAAACATTCTTCAAAATATCTTCTTTTTG...,(((((((((((((((((.......((..(((((........)))))...,-30.56,(((((((((((((((((.......{{..(((((........)))))...,-2.589200,2.7555,1.0,-2.05310,2.5282,1.0
S3_H_T9989,TGAAAGCACAGAGGGGCTGAGATTCTAAGGGCACTTCATGTTTTTT...,.(((((((..((((.(((...........))).)))).)))))))....,-24.03,".((((({(.,((((,(((.,,...,....))).)))).))))))),...",0.023414,1.2948,4.5,-0.75861,1.4902,3.0
S3_H_T9990,AATTAAAGAGAGAGAGAGACGGAGAACACGGTGGGTTTACTAGCGC...,.........((.(((((((((.((....((.(((.....))).)))...,-27.22,".........{(.(((((((((.,{....((.(((.....))).)),...",-2.463100,1.5812,1.0,-2.71790,1.2301,1.0


In [10]:
# One-hot encode the sequences

sequences = df["sequence"].str.split("", expand=True)
sequences.drop(columns=[sequences.columns[0], sequences.columns[-1]], inplace=True)
sequences = sequences.add_prefix("sequence_")
sequences = pd.get_dummies(sequences, sparse=True)

In [11]:
# Merge into main DataFrame

df.drop(columns=["sequence"], inplace=True)
df = pd.concat([sequences, df], axis="columns")
df

Unnamed: 0_level_0,sequence_1_A,sequence_1_C,sequence_1_G,sequence_1_T,sequence_2_A,sequence_2_C,sequence_2_G,sequence_2_T,sequence_3_A,sequence_3_C,...,sequence_110_T,secondary_structure,free_energy,secondary_structure_prob,log2_deg_rate_a_plus,log2_x0_a_plus,onset_time_a_plus,log2_deg_rate_a_minus,log2_x0_a_minus,onset_time_a_minus
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
S0_M_T10,1,0,0,0,0,0,1,0,1,0,...,1,........(((((.....(((((......)))))...(((((..((...,-31.17,"........(((((,.,..(((((......))))}...{({({..({...",-2.746900,2.7887,1.0,-2.17210,2.5964,1.0
S0_M_T1006,0,0,0,1,1,0,0,0,0,0,...,1,...(((((((((((...((((.((((...))))..))))...))))...,-26.55,".,,({(((({((((...((((.((((...))))..))))...))))...",-2.495200,3.5146,1.0,-1.94970,3.1963,1.0
S0_M_T1009,0,0,0,1,1,0,0,0,0,0,...,1,.....(((((((((((.(((.((..(((((.(((..((((((((((...,-33.05,"...,,(((((((((((.{((.{(..(((((.(((..((((((((((...",-2.550700,2.7105,1.0,-1.51500,2.8747,1.0
S0_M_T1013,0,0,0,1,0,0,1,0,1,0,...,1,.......................((((.((....))))))...((....,-20.70,".......,{,,............((((.((....))))))..,((....",-2.327900,3.7761,1.0,-1.89040,3.2967,1.0
S0_M_T1014,0,0,0,1,0,0,0,1,0,1,...,0,.............(((..(((((((((......(((((......))...,-22.82,"......,,,,...|||,.,(((((,,{{((.{((((((......))...",-1.623200,1.6160,4.8,-2.09580,2.1356,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
S3_H_T9985,0,0,1,0,0,0,0,1,0,1,...,0,.............((((..(((((.........(((((((.((((....,-18.22,",....,,,,...,((((,.((({{.........(((((((.((((....",-2.027200,2.6826,1.0,-1.59040,2.5908,1.0
S3_H_T9987,0,0,0,1,0,1,0,0,1,0,...,0,(((((((((((((((((.......((..(((((........)))))...,-30.56,(((((((((((((((((.......{{..(((((........)))))...,-2.589200,2.7555,1.0,-2.05310,2.5282,1.0
S3_H_T9989,0,0,0,1,0,0,1,0,1,0,...,0,.(((((((..((((.(((...........))).)))).)))))))....,-24.03,".((((({(.,((((,(((.,,...,....))).)))).))))))),...",0.023414,1.2948,4.5,-0.75861,1.4902,3.0
S3_H_T9990,1,0,0,0,1,0,0,0,0,0,...,0,.........((.(((((((((.((....((.(((.....))).)))...,-27.22,".........{(.(((((((((.,{....((.(((.....))).)),...",-2.463100,1.5812,1.0,-2.71790,1.2301,1.0


In [12]:
def match_parens(string: str) -> np.ndarray:
    """
    Returns a matrix of matching parentheses. For each pair of indices i, j
    in the input string, the cell (i, j) in the matrix will have a value of 1
    iff i and j contain a matching pair of parens.
    """

    pairs_matrix = np.zeros((len(string), len(string)), dtype=np.uint8)

    stack = []
    for index, char in enumerate(string):
        if char == '(':
            stack.append(index)
        elif char == ')':
            open_index = stack.pop()
            pairs_matrix[open_index, index] = 1
            pairs_matrix[index, open_index] = 1
    assert not stack

    return pairs_matrix

In [13]:
# One-hot encode the secondary structure of each sequence

all_pairs_matrices = vstack(df["secondary_structure"].map(lambda struct: coo_matrix(match_parens(struct).reshape(-1))))
secondary_structures = pd.DataFrame.sparse.from_spmatrix(all_pairs_matrices,
                                                         index=df.index,
                                                         columns=pd.RangeIndex(1, all_pairs_matrices.shape[1] + 1))
secondary_structures = secondary_structures.add_prefix("secondary_structure_")

In [14]:
# Merge into main DataFrame

df.drop(columns=["secondary_structure"], inplace=True)
df = pd.concat([df, secondary_structures], axis="columns")
df

Unnamed: 0_level_0,sequence_1_A,sequence_1_C,sequence_1_G,sequence_1_T,sequence_2_A,sequence_2_C,sequence_2_G,sequence_2_T,sequence_3_A,sequence_3_C,...,secondary_structure_12091,secondary_structure_12092,secondary_structure_12093,secondary_structure_12094,secondary_structure_12095,secondary_structure_12096,secondary_structure_12097,secondary_structure_12098,secondary_structure_12099,secondary_structure_12100
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
S0_M_T10,1,0,0,0,0,0,1,0,1,0,...,0,0,0,0,0,0,0,0,0,0
S0_M_T1006,0,0,0,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
S0_M_T1009,0,0,0,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
S0_M_T1013,0,0,0,1,0,0,1,0,1,0,...,0,0,0,0,0,0,0,0,0,0
S0_M_T1014,0,0,0,1,0,0,0,1,0,1,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
S3_H_T9985,0,0,1,0,0,0,0,1,0,1,...,0,0,0,0,0,0,0,0,0,0
S3_H_T9987,0,0,0,1,0,1,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
S3_H_T9989,0,0,0,1,0,0,1,0,1,0,...,0,0,0,0,0,0,0,0,0,0
S3_H_T9990,1,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [15]:
sequences_tensor = df.filter(regex="^sequence", axis="columns").to_numpy()
sequences_tensor = sequences_tensor.reshape(-1, sequences_tensor.shape[1] // 4, 4)
sequences_tensor = sequences_tensor.transpose(0, 2, 1)
sequences_tensor = sequences_tensor.astype(np.float32)
sequences_tensor

array([[[1., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 1., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 1.]],

       [[0., 1., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [1., 0., 0., ..., 1., 0., 1.]],

       [[0., 1., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 1.]],

       ...,

       [[0., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 1.],
        [1., 0., 0., ..., 1., 1., 0.]],

       [[1., 1., 0., ..., 1., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.],
        [0., 0., 1., ..., 0., 0., 0.]],

       [[1., 1., 1., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 1., 0.]]], dtype=float32)

In [16]:
all_data = list(zip(sequences_tensor, df[["log2_deg_rate_a_plus", "log2_deg_rate_a_minus"]].to_numpy().astype(np.float32)))

# NN

In [17]:
import torch
import torch.nn as nn

In [18]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv1d(4, 96, 12)
        self.norm = nn.BatchNorm1d(96)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(96 * 99, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)

        x = x.view(-1, 96 * 99)
        x = self.linear(x)

        return x

net = Net()
net

Net(
  (conv1): Conv1d(4, 96, kernel_size=(12,), stride=(1,))
  (norm): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (linear): Linear(in_features=9504, out_features=2, bias=True)
)

In [19]:
import torch.optim as optim

criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [20]:
train_size = int(0.9 * len(all_data))
test_size = len(all_data) - train_size
train_data, test_data = torch.utils.data.random_split(all_data, [train_size, test_size])

In [21]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True, num_workers=2)

In [22]:
for epoch in range(10):
    for sequences, rates in train_loader:
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(sequences)
        loss = criterion(outputs, rates)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}")

Epoch 1
Epoch 2
Epoch 3
Epoch 4
Epoch 5
Epoch 6
Epoch 7
Epoch 8
Epoch 9
Epoch 10


In [23]:
test_loader = torch.utils.data.DataLoader(test_data, batch_size=4, shuffle=False, num_workers=2)

In [24]:
with torch.no_grad():
    errors = np.zeros((test_size, 2))
    for batch, (sequences, rates) in enumerate(test_loader):
        outputs = net(sequences)

        row = batch * test_loader.batch_size
        errors[row:row + test_loader.batch_size, :] = (outputs - rates) ** 2

In [25]:
mse = pd.DataFrame(errors, columns=["a_plus", "a_minus"]).mean(axis="index")
mse


a_plus     0.660364
a_minus    0.510896
dtype: float64