In [1]:
import torch
from torch.utils.data import TensorDataset,DataLoader,Dataset
from torch import nn

import logomaker
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
random.seed(7) #10

import utils as u

In [2]:
def load_data(
        upstream_region_file,
        data_mat_file, 
        sample2cond_file, 
        sample_file, 
        condition_file,
        coded_meta_file
        ):
    '''
    Wrapper function to load data from files into relavent objects
    '''
    # load upstream seq regions
    seqs = u.load_promoter_seqs(upstream_region_file)
    loc2seq = dict([(x,z) for (x,y,z) in seqs])
    
    # load TPM data
    tpm_df = pd.read_csv(data_mat_file,sep='\t').fillna('')

    
    # load mapping from sample to condition
    with open(sample2cond_file,'r') as f:
        sample2condition = dict(x.strip().split() for x in f.readlines())

    
    # load sample to include file
    if sample_file:
        with open(sample_file,'r') as f:
            samples = list(x.strip() for x in f.readlines())
    # if none provided, just use all the samples from the sample2condition dict
    else: 
        samples = list(sample2condition.keys())

        
    # load the conditions to include file
    if condition_file:
        with open(condition_file,'r') as f:
            conditions = list(x.strip() for x in f.readlines())
    # if none provided, just use all the conditions
    else:
        conditions = list(set([sample2condition[x] for x in sample2condition]))

    # load coded metadata file
    meta_df = pd.read_csv(coded_meta_file,sep='\t')
    meta_df['sample'] = meta_df['#sample']+'_tpm'

    return loc2seq, tpm_df, sample2condition, samples, conditions, meta_df

In [3]:
upstream_region_file = 'all_seq_info/all_loci_upstream_regions_w100_min20.fa'
data_mat_file = 'data/extract_TPM_counts.tsv'
sample2cond_file = 'data/sample2condition.txt'
sample_file = None
condition_file = 'data/conditions_to_include.txt'
coded_meta_file = 'data/5G_exp_metadata_coded.tsv'
COND_COLS = ['carbon_source','oxygen_level','nitrate_level','copper_level','lanthanum_level','growth_rate','growth_mode']


loc2seq, tpm_df, sample2condition, samples, conditions, meta_df = load_data(
    upstream_region_file,
    data_mat_file, 
    sample2cond_file, 
    sample_file, 
    condition_file,
    coded_meta_file
)

In [4]:
tpm_df.head()

Unnamed: 0,locus_tag,product,type,gene_symbol,locus,start_coord,end_coord,note,translation,gene_len,...,5GB1_pA9_red_tpm,5GB1_pA9_yellow_tpm,5GB1C-5G-La-BR1_tpm,5GB1C-5G-La-BR2_tpm,5GB1C-5G-N-BR1_tpm,5GB1C-5G-N-BR2_tpm,5GB1C-JG15-La-BR1_tpm,5GB1C-JG15-La-BR2_tpm,5GB1C-JG15-N-BR1_tpm,5GB1C-JG15-N-BR2_tpm
0,EQU24_RS00005,chromosomal replication initiator protein DnaA,CDS,dnaA,NZ_CP035467.1,0,1317,Derived by automated computational analysis us...,MSALWNNCLAKLENEISSSEFSTWIRPLQAIETDGQIKLLAPNRFV...,1318,...,38.557373,38.810668,37.444214,40.246006,40.100118,33.432274,39.880174,38.355431,30.247582,41.248441
1,EQU24_RS00010,DNA polymerase III subunit beta,CDS,,NZ_CP035467.1,1502,2603,Derived by automated computational analysis us...,MKYIINREQLLVPLQQIVSVIEKRQTMPILSNVLMVFRENTLVMTG...,1102,...,52.552767,52.461746,42.676553,49.210083,46.798476,48.142385,45.465136,46.498139,37.152951,52.90241
2,EQU24_RS00015,DNA replication/repair protein RecF,CDS,recF,NZ_CP035467.1,3060,4140,Derived by automated computational analysis us...,MSLQKLDIFNVRNIRQASLQPSPGLNLIYGANASGKSSVLEAIFIL...,1081,...,31.350991,34.914128,21.479309,24.204682,22.171104,22.006566,22.658157,22.753325,19.407103,29.834124
3,EQU24_RS00020,DNA topoisomerase (ATP-hydrolyzing) subunit B,CDS,gyrB,NZ_CP035467.1,4185,6600,Derived by automated computational analysis us...,MSENIKQYDSTNIQVLKGLDAVRKRPGMYIGDTDDGTGLHHMVFEV...,2416,...,74.848501,80.850761,54.959319,64.911376,59.653059,64.648318,69.119079,65.643179,57.590223,68.306759
4,EQU24_RS00025,hypothetical protein,CDS,,NZ_CP035467.1,6825,7062,Derived by automated computational analysis us...,VKTTKYFLTTRMRPDREIIKDEWIQYVVRFPENEHIQFDGRIRRWA...,238,...,50.324948,49.349547,34.539657,36.521074,37.789611,39.358066,38.992158,35.870964,41.462392,40.227192


In [5]:
locus2info = u.make_info_dict(tpm_df)
locus2info['EQU24_RS19315']

{'gene': 'pmoC',
 'product': 'methane monooxygenase/ammonia monooxygenase subunit C',
 'type': 'CDS'}

In [6]:
df_means = u.get_gene_means_by_condition(tpm_df,samples,sample2condition)

In [7]:
df_means

locus_tag,exp_condition,EQU24_RS00005,EQU24_RS00010,EQU24_RS00015,EQU24_RS00020,EQU24_RS00025,EQU24_RS00030,EQU24_RS00035,EQU24_RS00040,EQU24_RS00045,...,EQU24_RS22110,EQU24_RS22115,EQU24_RS22120,EQU24_RS22125,EQU24_RS22130,EQU24_RS22135,EQU24_RS22140,EQU24_RS22145,EQU24_RS22150,EQU24_RS22155
0,LanzaTech,25.626702,55.71192,11.804042,76.880335,43.737438,27.940983,35.296053,25.926725,23.461781,...,7420.641716,16.014544,15.947067,16.286482,9.514666,46.013885,354.315359,157.364073,571.458102,613.084675
1,MeOH,23.323155,18.905775,18.443916,18.257805,16.950643,12.367795,43.805536,9.66095,7.292145,...,1298.247682,15.614619,20.198066,25.994364,20.950234,28.709983,93.606437,161.518124,496.980651,280.334047
2,NO3_lowO2_slow_growth,32.040358,43.64676,21.341623,62.257687,41.674925,31.911455,57.839768,16.875694,14.916147,...,6497.858109,26.263485,28.935133,23.515245,26.422667,35.157264,178.986199,164.073806,433.428735,493.885115
3,NoCu,44.338687,59.61936,28.258717,56.808319,49.829406,38.384652,81.520362,40.491969,36.5665,...,8345.775345,43.055124,34.370565,44.409579,34.591933,65.329879,253.598495,273.274694,731.04219,1087.611126
4,NoLanthanum,33.434023,43.679839,23.162675,57.287047,42.357072,41.931657,102.503601,30.216787,19.452312,...,5085.627409,16.413284,35.578138,44.613117,43.191743,21.91726,109.77333,67.267718,211.565175,328.933746
5,WT_control,34.988452,38.272163,15.49593,54.855025,35.396786,26.817623,59.558466,24.797995,27.925137,...,4939.367129,18.248719,21.027514,31.490596,45.810743,36.195898,172.298263,181.468991,403.551659,648.122601
6,WithLanthanum,35.452185,41.782237,20.634554,57.120166,34.248335,46.191637,110.711781,31.803805,19.428086,...,3942.947792,15.962203,34.308829,49.206725,39.990662,21.210809,98.09061,73.106973,194.379586,319.988959
7,aa3_KO,26.849583,45.489035,11.962099,55.846659,33.097353,28.718155,50.249664,21.490715,24.790508,...,5355.094603,16.236806,20.836285,28.298842,40.627394,30.06563,181.790564,153.928317,376.664997,657.228922
8,crotonic_acid,38.684021,52.507256,33.132559,77.849631,49.837247,37.332021,69.675355,32.001867,35.85626,...,8680.904376,42.175374,47.879804,58.961083,43.069352,81.481463,365.899345,315.553709,614.920997,676.638496
9,highCu,47.851477,79.09949,33.524043,73.320408,48.652214,33.976359,92.989818,51.940784,50.360579,...,8132.537467,48.884308,35.59873,46.0993,30.115207,89.70015,342.971435,386.483127,1021.443762,1692.391154


In [8]:
# filter to only the copper conditions
cu_conds = ['NoCu','lowCu','medCu','highCu']

df_means_cu = df_means[df_means['exp_condition'].isin(cu_conds)]
df_means_cu

locus_tag,exp_condition,EQU24_RS00005,EQU24_RS00010,EQU24_RS00015,EQU24_RS00020,EQU24_RS00025,EQU24_RS00030,EQU24_RS00035,EQU24_RS00040,EQU24_RS00045,...,EQU24_RS22110,EQU24_RS22115,EQU24_RS22120,EQU24_RS22125,EQU24_RS22130,EQU24_RS22135,EQU24_RS22140,EQU24_RS22145,EQU24_RS22150,EQU24_RS22155
3,NoCu,44.338687,59.61936,28.258717,56.808319,49.829406,38.384652,81.520362,40.491969,36.5665,...,8345.775345,43.055124,34.370565,44.409579,34.591933,65.329879,253.598495,273.274694,731.04219,1087.611126
9,highCu,47.851477,79.09949,33.524043,73.320408,48.652214,33.976359,92.989818,51.940784,50.360579,...,8132.537467,48.884308,35.59873,46.0993,30.115207,89.70015,342.971435,386.483127,1021.443762,1692.391154
12,lowCu,42.963556,61.199155,28.818713,61.563321,50.956799,31.309574,75.037593,40.02867,35.321019,...,7157.334557,43.376082,33.564108,36.862718,31.889782,66.733497,293.589291,313.731841,843.597251,1123.659681
15,medCu,44.910897,65.001074,29.409165,65.383162,48.719958,26.307981,80.114744,39.235705,32.118273,...,5934.158113,43.944096,36.386184,34.983692,32.897138,84.616668,340.021564,377.209038,1060.040784,1429.583942


In [9]:
# transpose to get genes as examples again and add upstream region
XYdf = df_means_cu.set_index('exp_condition').T.reset_index()
XYdf['upstream_region'] = XYdf['locus_tag'].apply(lambda x: loc2seq[x])
XYdf = XYdf[['locus_tag','upstream_region']+cu_conds]

XYdf

exp_condition,locus_tag,upstream_region,NoCu,lowCu,medCu,highCu
0,EQU24_RS00005,CGCCGGTTTATGTCAATTATGCCGGCACTGATTTGATTGCTGTATA...,44.338687,42.963556,44.910897,47.851477
1,EQU24_RS00010,AACGCCGGTTTTACAGTTCATAAGCTATTGATAAATAAAATAAAAA...,59.619360,61.199155,65.001074,79.099490
2,EQU24_RS00015,ATCGCAGTCATTATTAAATGTGGAAGCAACAAAAAAACGAGCTTGT...,28.258717,28.818713,29.409165,33.524043
3,EQU24_RS00020,AACTTAATAACTATAAAATGTTCCACGTGGAACATGGTGAAATTAA...,56.808319,61.563321,65.383162,73.320408
4,EQU24_RS00025,CTTTGCCGAACACCCCGCACCTCCACGCGTCAACAACGAAATTTGA...,49.829406,50.956799,48.719958,48.652214
...,...,...,...,...,...,...
4208,EQU24_RS22135,CCCGGCCGGTTTGGTCTTGTACTGGGTGGTCAACAATACGCTGTCG...,65.329879,66.733497,84.616668,89.700150
4209,EQU24_RS22140,GCCGCCCAGGGCACCTATCTTACAGTCCGAAGAGTATTAAAGTGTC...,253.598495,293.589291,340.021564,342.971435
4210,EQU24_RS22145,AATATTGATGTTGTTGTTATGGCCCGAAAAGATGCACTCAATGCAT...,273.274694,313.731841,377.209038,386.483127
4211,EQU24_RS22150,AAGAACTCACGGCTTTCGTGCCAGAATGGCGACCAAAGGCGGCCGT...,731.042190,843.597251,1060.040784,1021.443762


In [10]:
loc2seq["EQU24_RS19315"]

'GCAGCACCAAATTGGACTGGTAGAGCTTAAATAAAAGCGTTAAAGGGATGTTTTAAAACAACCGCCCTTCGGGGTTTTTAAAAATTTTTTAGGAGGTAGA'

In [11]:
def one_hot_encode(seq):
    #print("one hot encoding...")
    
    # Dictionary returning one-hot encoding of nucleotides. 
    nuc_d = {'A':[1.0,0.0,0.0,0.0],
             'C':[0.0,1.0,0.0,0.0],
             'G':[0.0,0.0,1.0,0.0],
             'T':[0.0,0.0,0.0,1.0],
             'N':[0.0,0.0,0.0,0.0]}
    
    # Creat empty matrix.
    #vec=torch.tensor([nuc_d[x] for x in seq])
    vec=np.array([nuc_d[x] for x in seq]).flatten()
        
    return vec

In [12]:
class MultiTaskDataset(Dataset):
    def __init__(self,df):
        self.loci = list(df['locus_tag'])
        self.seqs = list(df['upstream_region'])
        self.seq_len = len(self.seqs[0])
    
        self.nocu_labels = list(df['NoCu'])
        self.lowcu_labels = list(df['lowCu'])
        self.medcu_labels = list(df['medCu'])
        self.highcu_labels = list(df['highCu'])
        
    def __len__(self): return len(self.loci)
    
    def __getitem__(self,idx):
        seq = torch.tensor(one_hot_encode(self.seqs[idx]))
        
        nocu = torch.tensor(float(self.nocu_labels[idx]))
        lowcu = torch.tensor(float(self.lowcu_labels[idx]))
        medcu = torch.tensor(float(self.medcu_labels[idx]))
        highcu = torch.tensor(float(self.highcu_labels[idx]))
        
        return seq, (nocu, lowcu, medcu, highcu)
        # ??? ^ how is this handled? Does the data loader know what to do with
        # a list of labels...?
        

In [13]:
def quick_load_and_split_input(df, split_frac=0.8, verbose=False):
    '''
    Given a df, randomly split between
    train and test. Not a formal train/test split, just a quick n dirty version.
    
    '''

    # train test split
    idxs = list(range(df.shape[0]))
    random.shuffle(idxs)

    split = int(len(idxs)*split_frac)
    train_idxs = idxs[:split]
    test_idxs = idxs[split:]
    
    # split df and convert to tensors
    train_df = df[df.index.isin(train_idxs)]
    test_df = df[df.index.isin(test_idxs)]
        
    return train_df, test_df


def build_dataloaders(df,batch_size=32):
    # train test split
    train_df, test_df = quick_load_and_split_input(df)
    
    # put into dataloader object
    train_ds = MultiTaskDataset(train_df)
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    
    test_ds = MultiTaskDataset(test_df)
    test_dl = DataLoader(test_ds, batch_size=batch_size * 2)
    
    return train_dl, test_dl,train_df,test_df

In [14]:
# did this actually work???
train_dl, test_dl,train_df, test_df = build_dataloaders(XYdf)

In [15]:
train_dl

<torch.utils.data.dataloader.DataLoader at 0x7f9bd5b04640>

In [16]:
for xb, yb in train_dl:
    print(xb)
    print(yb)
    print('*******')

tensor([[0., 0., 1.,  ..., 1., 0., 0.],
        [0., 1., 0.,  ..., 1., 0., 0.],
        [0., 0., 1.,  ..., 1., 0., 0.],
        ...,
        [0., 0., 1.,  ..., 0., 0., 1.],
        [0., 0., 1.,  ..., 0., 1., 0.],
        [0., 1., 0.,  ..., 0., 1., 0.]], dtype=torch.float64)
[tensor([ 57.4526,  15.8513,  25.2045,  57.2232,  95.0641,  20.9470, 157.2000,
        565.0566,  89.2762,  58.7207, 886.0788,   0.0000, 206.4296,  53.3444,
         31.4181,  15.9749,   6.6779,   3.8572, 275.9895,  30.2502,  20.6546,
          8.9382,  48.0568,  25.6551,  11.5600,  15.0710,  94.1441, 133.8691,
         64.3297, 116.5948,  14.9377,  80.9924]), tensor([ 57.0908,  12.4401,  20.5634,  57.3638,  90.0727,  19.1010, 126.9318,
        586.9841,  67.6189,  52.8856, 994.4968,   0.0000, 150.3905,  54.9581,
         30.2176,  15.8771,   6.2018,   3.7174, 238.7724,  30.5601,  19.4110,
          7.3877,  47.3123,  24.2728,   9.3808,  15.0326,  83.3939, 110.1302,
         50.4473, 144.7868,  13.8450,  63.1770]), 

tensor([[0., 1., 0.,  ..., 0., 0., 1.],
        [1., 0., 0.,  ..., 1., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 1., 0.,  ..., 1., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 1.],
        [1., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)
[tensor([  64.4531,   32.6824,   89.5751,   25.2323,   56.0682,   59.9594,
          30.1208,   51.8238,   26.9085,   34.1492,   14.6181,   54.1326,
          62.8292,   38.5142,   27.9057,  140.5682, 1281.2480,  106.3304,
          30.4411, 1646.5765,  178.8220,   49.9468,  146.3354,  133.8345,
          77.2065,   12.4251,  118.8393,  603.4210,   53.3132,   66.7904,
          51.9758,   74.0779]), tensor([  62.1313,   36.9827,   82.8692,   23.0107,   53.2902,   53.3887,
          36.4628,   44.8918,   26.6823,   31.7016,   14.4856,   48.3111,
          67.3286,   37.6580,   23.8296,  129.4805, 1294.1810,  115.3336,
          26.4534, 1039.6091,  160.3840,   45.3788,  147.7827,  157.6620,
          71.5275,   10.93

[tensor([  3.2572,  17.6607, 151.4836,  95.7364,  38.8680, 157.8991,  14.8089,
          0.0000,  94.1744,   3.1564,  89.7540,  21.3440,   8.4027,   0.0000,
         23.7760,   8.6159,  41.8675, 319.8055,  36.1004,  40.9664,  26.2311,
        430.4978, 114.9411,  21.0129,   0.0000,  29.5927,  60.0767,  62.5589,
         26.2852, 138.9278,  63.7577, 135.1808]), tensor([  2.9559,  17.3504, 151.5593,  91.5878,  34.5746, 155.2073,  15.7662,
          0.0000,  88.2011,   2.9540, 105.5980,  17.7547,   9.5096,   0.0000,
         23.3294,   9.2897,  42.2034, 287.7898,  31.8812,  52.1509,  29.0843,
        509.9711, 127.3358,  21.7592,   0.0000,  28.6576,  58.0124,  49.0111,
         21.0329,  97.3232,  59.6665, 141.3152]), tensor([  2.5713,  15.7243, 138.0378,  88.7348,  28.4701, 138.1975,  14.2322,
          0.0000,  91.6558,   3.0243, 112.6045,  16.7883,   8.8817,   0.0000,
         19.5151,   9.6584,  37.0015, 282.2116,  27.2799,  56.1730,  38.2692,
        528.9481, 148.3075,  18.0215,   0

tensor([[0., 0., 1.,  ..., 1., 0., 0.],
        [0., 0., 1.,  ..., 0., 1., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 1.],
        [0., 1., 0.,  ..., 1., 0., 0.],
        [0., 0., 1.,  ..., 1., 0., 0.]], dtype=torch.float64)
[tensor([   4.1831,   92.1448,   15.5674,  128.1308,   37.6818,  132.2145,
           9.1057, 3034.4563,  481.6253,  240.7278,  112.2584,  231.4232,
          17.0662,   12.9056,   26.8021,  115.6876,   58.6155,  211.9420,
        2892.0647,   42.2188,  100.8872,   24.8384,   61.4486, 1212.4197,
          90.0458,   21.4523,   52.1914,  432.1953,   40.1351,    6.6155,
          27.5036,   64.3018]), tensor([   4.2611,   87.1331,   12.9371,  120.8802,   35.2343,  120.4923,
           9.9639, 3260.9917,  468.9299,  231.6390,   99.5569,  165.4391,
          12.7714,   11.3402,   25.3764,  118.9057,   53.5123,  246.4681,
        3331.9507,   37.6329,   85.6990,   29.1140,   54.4734, 1424.6825,
          81.3513,   17.58

tensor([[0., 1., 0.,  ..., 0., 0., 1.],
        [0., 0., 1.,  ..., 1., 0., 0.],
        [1., 0., 0.,  ..., 1., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 1., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]], dtype=torch.float64)
[tensor([3.4238e+02, 5.5804e+01, 3.7398e+01, 1.6983e+02, 3.5241e+01, 1.5995e+02,
        3.3091e+02, 3.3115e+01, 2.0946e+01, 2.9697e+01, 8.1910e+01, 1.1621e+02,
        5.1274e+01, 1.7696e+01, 1.7486e+01, 1.5473e+01, 3.5044e+01, 8.0592e+01,
        7.4800e+01, 1.9134e+01, 2.4601e+02, 5.0347e+01, 1.1167e+02, 2.9313e+01,
        2.8582e+01, 5.3721e+03, 1.0162e+02, 3.7699e+01, 4.7836e+01, 1.3139e+02,
        3.7488e+00, 6.8149e+01]), tensor([3.5102e+02, 5.5576e+01, 3.2368e+01, 1.6157e+02, 3.3062e+01, 1.4590e+02,
        3.3726e+02, 2.8530e+01, 1.8186e+01, 2.9630e+01, 8.1447e+01, 1.0659e+02,
        5.2919e+01, 1.7357e+01, 1.5400e+01, 1.2049e+01, 3.1464e+01, 8.6491e+01,
        7.3177e+01, 1.6173e+01, 2.4860e+02, 4.4929

[tensor([1.6896e+02, 5.9703e+01, 5.7703e+01, 9.6886e+01, 3.6896e+01, 1.6355e+01,
        7.1880e+00, 3.6380e+01, 0.0000e+00, 1.6041e+01, 8.0718e+01, 2.1405e+00,
        1.2104e+02, 0.0000e+00, 9.1076e+01, 2.0349e+01, 6.3565e+01, 1.6888e+02,
        3.7753e+01, 1.3329e+01, 9.4498e+00, 3.2389e+01, 3.2138e+02, 4.1536e+01,
        2.2068e+03, 1.3699e+01, 2.4582e+02, 1.6240e+01, 8.3169e+00, 1.1748e+02,
        4.0098e+01, 1.3213e+01]), tensor([1.6046e+02, 5.5210e+01, 3.8388e+01, 8.8079e+01, 3.4504e+01, 1.4464e+01,
        6.9382e+00, 3.6350e+01, 0.0000e+00, 1.6112e+01, 7.7167e+01, 2.1220e+00,
        1.1292e+02, 0.0000e+00, 1.0869e+02, 1.6787e+01, 5.2128e+01, 1.6531e+02,
        3.4444e+01, 1.2753e+01, 8.5148e+00, 3.0237e+01, 2.9233e+02, 3.4733e+01,
        2.6875e+03, 1.1598e+01, 2.5372e+02, 1.7503e+01, 9.9307e+00, 1.1206e+02,
        4.0245e+01, 9.6233e+00]), tensor([1.3436e+02, 5.4461e+01, 2.8850e+01, 7.9548e+01, 3.2593e+01, 1.2042e+01,
        5.9246e+00, 3.7218e+01, 0.0000e+00, 1.4357e

tensor([[1., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        ...,
        [0., 0., 1.,  ..., 0., 0., 1.],
        [0., 0., 1.,  ..., 0., 0., 1.],
        [1., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)
[tensor([ 8.2193, 86.4518, 76.7569, 47.9235, 14.0302, 63.3489, 38.3847, 25.2129,
        50.3810,  5.2269]), tensor([ 7.5014, 84.0678, 73.5852, 45.0933, 16.3062, 62.0295, 31.3096, 24.9907,
        41.6491,  5.0146]), tensor([ 6.9560, 89.8796, 77.6459, 52.3342, 23.4459, 62.2770, 26.3080, 13.6674,
        37.6137,  6.1552]), tensor([  7.9818, 104.7552,  74.9064,  59.1036,  19.6287,  68.4303,  33.9764,
          5.5580,  36.7909,   6.6637])]
*******


In [None]:
# ^^ yb seems ot be a list of tensors, each the length of a batch?

In [21]:
class DNA_Linear_Multi(nn.Module):
    def __init__(self,seq_len,h1_size,h2_size):
        super().__init__()
        
        # some arbitrary arch of a few linear layers
        self.lin_share = nn.Sequential(
            nn.Linear(4*seq_len, h1_size),
            nn.ReLU(inplace=True),
            nn.Linear(h1_size, h2_size),
            nn.ReLU(inplace=True)
        )
        
        # define the multi task objectectives?
        self.nocu_obj   = nn.Linear(h2_size,1)
        self.lowcu_obj  = nn.Linear(h2_size,1)
        self.medcu_obj  = nn.Linear(h2_size,1)
        self.highcu_obj = nn.Linear(h2_size,1)
        
        

    def forward(self, xb):
        # Linear wraps up the weights/bias dot product operations
        out = self.lin_share(xb)
        out_nocu = self.nocu_obj(out)
        out_lowcu = self.lowcu_obj(out)
        out_medcu = self.medcu_obj(out)
        out_highcu = self.highcu_obj(out)
        
        return [out_nocu, out_lowcu, out_medcu, out_highcu]
        # ^^ where does this go and what does it do?
    


In [68]:
h1 = 400
h2 = 200
multi_lin_model = DNA_Linear_Multi(100, h1,h2)

In [69]:
loss_func = torch.nn.MSELoss() 
optimizer = torch.optim.SGD(multi_lin_model.parameters(), lr=0.0001) 

In [70]:
def loss_batch(model, loss_func, xb, yb, opt=None):
    '''
    Apply loss function to a batch of inputs. If no optimizer
    is provided, skip the back prop step.
    '''
    #print("in Loss Batch...")

    #loss = loss_func(model(xb.float()), yb.float())
    # yb is a list of tensors... what should it be?
    # how to adapt loss for multi-task?
    #print("yb[0]:", yb[0])
    
    # try looping through ybs?
    res = model(xb.float()) # get predictions? returns list
    
    loss_nocu = loss_func(res[0], yb[0].float())
    loss_lowcu = loss_func(res[1], yb[1].float())
    loss_medcu = loss_func(res[2], yb[2].float())
    loss_highcu = loss_func(res[3], yb[3].float())
    
    loss = sum([loss_nocu, loss_lowcu, loss_medcu, loss_highcu])
    
    print("what is 'loss'?:")
    print(loss_nocu)
    print(loss_lowcu)
    print(loss_medcu)
    print(loss_highcu)
    print("*****")

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    #print("lb returning:",loss.item(), len(xb))
    return loss.item(), len(xb)


def fit(epochs, model, loss_func, opt, train_dl, test_dl):
    '''
    Fit the model params to the training data, eval on unseen data.
    Loop for a number of epochs and keep train of train and val losses 
    along the way
    '''
    # keep track of losses
    train_losses = []    
    val_losses = []
    
    # loops through epochs
    for epoch in range(epochs):
        print('Epoch:',epoch)
        #print("TRAIN")
        model.train()
        ts = []
        ns = []
        # collect train loss; provide opt so backpropo happens
        for xb, yb in train_dl:
            t, n = loss_batch(model, loss_func, xb, yb, opt)
            ts.append(t)
            ns.append(n)
        train_loss = np.sum(np.multiply(ts, ns)) / np.sum(ns)
        train_losses.append(train_loss)
        
        #print("EVAL")
        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                # loop through test batches
                # returns loss calc for test set batch size
                # unzips into two lists
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in test_dl]
                # Note: no opt provided, backprop won't happen
            )
        # Gets average MSE loss across all batches (may be of diff sizes, hence the multiply)
        #print("losses", losses)
        #print("nums", nums)
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print(epoch, val_loss)
        val_losses.append(val_loss)

    return train_losses, val_losses

In [71]:
epochs = 20

train_losses, val_losses = fit(epochs, multi_lin_model, loss_func, optimizer, train_dl, test_dl)

Epoch: 0
what is 'loss'?:
tensor(64548.2930, grad_fn=<MseLossBackward>)
tensor(98027.2031, grad_fn=<MseLossBackward>)
tensor(110896.1562, grad_fn=<MseLossBackward>)
tensor(99123.4688, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(243965.7812, grad_fn=<MseLossBackward>)
tensor(298228.7188, grad_fn=<MseLossBackward>)
tensor(315704.4375, grad_fn=<MseLossBackward>)
tensor(322020.7500, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(66252.6562, grad_fn=<MseLossBackward>)
tensor(75454.8906, grad_fn=<MseLossBackward>)
tensor(97184.6875, grad_fn=<MseLossBackward>)
tensor(108635.1406, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(97658.7266, grad_fn=<MseLossBackward>)
tensor(138933.5625, grad_fn=<MseLossBackward>)
tensor(146969.5156, grad_fn=<MseLossBackward>)
tensor(180067.1094, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(17498.4922, grad_fn=<MseLossBackward>)
tensor(16036.3896, grad_fn=<MseLossBackward>)
tensor(13601.4814, grad_fn=<MseLossBackward>)

tensor(14105872., grad_fn=<MseLossBackward>)
tensor(23151940., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(63618.5938, grad_fn=<MseLossBackward>)
tensor(98763.7812, grad_fn=<MseLossBackward>)
tensor(115884.1562, grad_fn=<MseLossBackward>)
tensor(109312.6562, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(976066.6875, grad_fn=<MseLossBackward>)
tensor(879121.4375, grad_fn=<MseLossBackward>)
tensor(797854.5625, grad_fn=<MseLossBackward>)
tensor(1308351.1250, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(3808.9521, grad_fn=<MseLossBackward>)
tensor(3346.6875, grad_fn=<MseLossBackward>)
tensor(3208.8340, grad_fn=<MseLossBackward>)
tensor(4243.4287, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(9504.0283, grad_fn=<MseLossBackward>)
tensor(9730.0674, grad_fn=<MseLossBackward>)
tensor(9911.0781, grad_fn=<MseLossBackward>)
tensor(11577.5840, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(245625.2656, grad_fn=<MseLossBackward>)
tensor(26620

what is 'loss'?:
tensor(111599.5625, grad_fn=<MseLossBackward>)
tensor(126275.4375, grad_fn=<MseLossBackward>)
tensor(147456.5469, grad_fn=<MseLossBackward>)
tensor(199218.7656, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(45836.8359, grad_fn=<MseLossBackward>)
tensor(61937.8750, grad_fn=<MseLossBackward>)
tensor(72932.2109, grad_fn=<MseLossBackward>)
tensor(89716.7891, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(81890.0625, grad_fn=<MseLossBackward>)
tensor(119229.0312, grad_fn=<MseLossBackward>)
tensor(172414.5000, grad_fn=<MseLossBackward>)
tensor(127658.5547, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(246284.6719, grad_fn=<MseLossBackward>)
tensor(356817.8750, grad_fn=<MseLossBackward>)
tensor(728525.3125, grad_fn=<MseLossBackward>)
tensor(484301.6250, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(112707.7734, grad_fn=<MseLossBackward>)
tensor(125595.5000, grad_fn=<MseLossBackward>)
tensor(228053., grad_fn=<MseLossBackward>)
tensor(

what is 'loss'?:
tensor(4980660.5000, grad_fn=<MseLossBackward>)
tensor(5151231.5000, grad_fn=<MseLossBackward>)
tensor(3788013., grad_fn=<MseLossBackward>)
tensor(3574328., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(367055.5312, grad_fn=<MseLossBackward>)
tensor(454764.4062, grad_fn=<MseLossBackward>)
tensor(452884.2188, grad_fn=<MseLossBackward>)
tensor(393365.1875, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(80922.2578, grad_fn=<MseLossBackward>)
tensor(120676.6484, grad_fn=<MseLossBackward>)
tensor(125310.7031, grad_fn=<MseLossBackward>)
tensor(153692.8750, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(19809.1602, grad_fn=<MseLossBackward>)
tensor(15863.1152, grad_fn=<MseLossBackward>)
tensor(11860.7031, grad_fn=<MseLossBackward>)
tensor(13539.5195, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(128060.6484, grad_fn=<MseLossBackward>)
tensor(58751.1016, grad_fn=<MseLossBackward>)
tensor(25271.2773, grad_fn=<MseLossBackward>)
tensor(34

what is 'loss'?:
tensor(59531.9297, grad_fn=<MseLossBackward>)
tensor(107629.7266, grad_fn=<MseLossBackward>)
tensor(173262.4688, grad_fn=<MseLossBackward>)
tensor(123057.5312, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(17072.4551, grad_fn=<MseLossBackward>)
tensor(14571.6191, grad_fn=<MseLossBackward>)
tensor(13051.3311, grad_fn=<MseLossBackward>)
tensor(20436.0781, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(172003.1250, grad_fn=<MseLossBackward>)
tensor(247970.5000, grad_fn=<MseLossBackward>)
tensor(472694.3125, grad_fn=<MseLossBackward>)
tensor(377295.7812, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(148544.9531, grad_fn=<MseLossBackward>)
tensor(192437.6250, grad_fn=<MseLossBackward>)
tensor(202502.2969, grad_fn=<MseLossBackward>)
tensor(197904.3594, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(57156.9219, grad_fn=<MseLossBackward>)
tensor(56867.1836, grad_fn=<MseLossBackward>)
tensor(67053.0703, grad_fn=<MseLossBackward>)
tensor

what is 'loss'?:
tensor(176338.4688)
tensor(199602.8750)
tensor(242783.3125)
tensor(285429.8750)
*****
what is 'loss'?:
tensor(43200.6797)
tensor(38176.4688)
tensor(50864.8047)
tensor(64938.5859)
*****
what is 'loss'?:
tensor(146190.2656)
tensor(143460.9688)
tensor(111918.4609)
tensor(140721.7656)
*****
what is 'loss'?:
tensor(68760.3750)
tensor(57195.8750)
tensor(68883.7266)
tensor(62057.9531)
*****
what is 'loss'?:
tensor(888372.5625)
tensor(972664.3125)
tensor(913195.8750)
tensor(957132.3750)
*****
what is 'loss'?:
tensor(55019676.)
tensor(46920084.)
tensor(78789240.)
tensor(71038888.)
*****
what is 'loss'?:
tensor(93690.1094)
tensor(78516.7500)
tensor(80943.6797)
tensor(115979.3906)
*****
what is 'loss'?:
tensor(6323565.)
tensor(4649998.5000)
tensor(3195619.)
tensor(6004368.)
*****
1 20927704.46619217
Epoch: 2
what is 'loss'?:
tensor(941274.1250, grad_fn=<MseLossBackward>)
tensor(1013970.5625, grad_fn=<MseLossBackward>)
tensor(1510896.3750, grad_fn=<MseLossBackward>)
tensor(1332480

what is 'loss'?:
tensor(85004.5469, grad_fn=<MseLossBackward>)
tensor(118488.2188, grad_fn=<MseLossBackward>)
tensor(137454.4219, grad_fn=<MseLossBackward>)
tensor(140731.1250, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(144926.8438, grad_fn=<MseLossBackward>)
tensor(159972.5469, grad_fn=<MseLossBackward>)
tensor(132625.5156, grad_fn=<MseLossBackward>)
tensor(110351.5703, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(54678.5547, grad_fn=<MseLossBackward>)
tensor(73572.6250, grad_fn=<MseLossBackward>)
tensor(85556.7188, grad_fn=<MseLossBackward>)
tensor(88653.7969, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(3926559.2500, grad_fn=<MseLossBackward>)
tensor(2985448.5000, grad_fn=<MseLossBackward>)
tensor(2488142.7500, grad_fn=<MseLossBackward>)
tensor(33998.1133, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(12304909., grad_fn=<MseLossBackward>)
tensor(5572234., grad_fn=<MseLossBackward>)
tensor(261985.1875, grad_fn=<MseLossBackward>)
tensor

what is 'loss'?:
tensor(5658.2529, grad_fn=<MseLossBackward>)
tensor(4686.4048, grad_fn=<MseLossBackward>)
tensor(4602.3164, grad_fn=<MseLossBackward>)
tensor(6774.1548, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(79390.1641)
tensor(88724.6484)
tensor(77468.3359)
tensor(68085.3906)
*****
what is 'loss'?:
tensor(91398.4062)
tensor(98297.4297)
tensor(131429.4688)
tensor(147526.1875)
*****
what is 'loss'?:
tensor(1118640.5000)
tensor(978899.8125)
tensor(894922.3125)
tensor(1500485.7500)
*****
what is 'loss'?:
tensor(5395428.5000)
tensor(2817683.2500)
tensor(301971.5312)
tensor(118714.4766)
*****
what is 'loss'?:
tensor(52394.5039)
tensor(78837.3125)
tensor(107535.7109)
tensor(86627.8828)
*****
what is 'loss'?:
tensor(26057.1426)
tensor(32406.9922)
tensor(41905.1875)
tensor(42204.8359)
*****
what is 'loss'?:
tensor(175158.7344)
tensor(198404.0156)
tensor(241574.7656)
tensor(284154.4062)
*****
what is 'loss'?:
tensor(42334.0781)
tensor(37329.5430)
tensor(49940.8750)
tensor(6390

what is 'loss'?:
tensor(50136.1992, grad_fn=<MseLossBackward>)
tensor(52047.7969, grad_fn=<MseLossBackward>)
tensor(62856.6055, grad_fn=<MseLossBackward>)
tensor(57533.0234, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(53267.3906, grad_fn=<MseLossBackward>)
tensor(70049.2656, grad_fn=<MseLossBackward>)
tensor(83309.3281, grad_fn=<MseLossBackward>)
tensor(82415.6328, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(18679010., grad_fn=<MseLossBackward>)
tensor(31538706., grad_fn=<MseLossBackward>)
tensor(42384596., grad_fn=<MseLossBackward>)
tensor(64610428., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(13767.3125, grad_fn=<MseLossBackward>)
tensor(9907.7227, grad_fn=<MseLossBackward>)
tensor(8352.3682, grad_fn=<MseLossBackward>)
tensor(13974.8975, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(193349.7656, grad_fn=<MseLossBackward>)
tensor(177918.6719, grad_fn=<MseLossBackward>)
tensor(203443.6250, grad_fn=<MseLossBackward>)
tensor(231853.8594, 

tensor(270109.9062, grad_fn=<MseLossBackward>)
tensor(354544.7812, grad_fn=<MseLossBackward>)
tensor(230202.8438, grad_fn=<MseLossBackward>)
tensor(205096.0938, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(202199., grad_fn=<MseLossBackward>)
tensor(229625.1875, grad_fn=<MseLossBackward>)
tensor(190667.6719, grad_fn=<MseLossBackward>)
tensor(228482.5938, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(90503.2109, grad_fn=<MseLossBackward>)
tensor(144728.8438, grad_fn=<MseLossBackward>)
tensor(178077.6719, grad_fn=<MseLossBackward>)
tensor(123224.7656, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(8119.6465, grad_fn=<MseLossBackward>)
tensor(7431.6938, grad_fn=<MseLossBackward>)
tensor(6454.1221, grad_fn=<MseLossBackward>)
tensor(7483.4609, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(29725.8945, grad_fn=<MseLossBackward>)
tensor(40690.5312, grad_fn=<MseLossBackward>)
tensor(52573.4648, grad_fn=<MseLossBackward>)
tensor(74289.5547, grad_fn=<Mse

what is 'loss'?:
tensor(168002.3125, grad_fn=<MseLossBackward>)
tensor(217662.2188, grad_fn=<MseLossBackward>)
tensor(364175.0625, grad_fn=<MseLossBackward>)
tensor(427619.4688, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(33582.9648, grad_fn=<MseLossBackward>)
tensor(28150.4160, grad_fn=<MseLossBackward>)
tensor(27737.2051, grad_fn=<MseLossBackward>)
tensor(54428.7812, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(121785.0781, grad_fn=<MseLossBackward>)
tensor(155381.4375, grad_fn=<MseLossBackward>)
tensor(231639.6719, grad_fn=<MseLossBackward>)
tensor(232396.9375, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(80950.9922, grad_fn=<MseLossBackward>)
tensor(86563.3906, grad_fn=<MseLossBackward>)
tensor(93308.9375, grad_fn=<MseLossBackward>)
tensor(112280.6016, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(9822.0215, grad_fn=<MseLossBackward>)
tensor(11760.8730, grad_fn=<MseLossBackward>)
tensor(17906.9492, grad_fn=<MseLossBackward>)
tensor(15

what is 'loss'?:
tensor(1145156.1250, grad_fn=<MseLossBackward>)
tensor(771333., grad_fn=<MseLossBackward>)
tensor(907942.2500, grad_fn=<MseLossBackward>)
tensor(1098600.6250, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(9187917., grad_fn=<MseLossBackward>)
tensor(8148332., grad_fn=<MseLossBackward>)
tensor(7793025., grad_fn=<MseLossBackward>)
tensor(4984514., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(52140.2266, grad_fn=<MseLossBackward>)
tensor(58237.1055, grad_fn=<MseLossBackward>)
tensor(65835.9766, grad_fn=<MseLossBackward>)
tensor(78998.7969, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(66821.9375, grad_fn=<MseLossBackward>)
tensor(84796.6641, grad_fn=<MseLossBackward>)
tensor(125748.3594, grad_fn=<MseLossBackward>)
tensor(118774.1250, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(233309.9531, grad_fn=<MseLossBackward>)
tensor(273621.7500, grad_fn=<MseLossBackward>)
tensor(223099.5781, grad_fn=<MseLossBackward>)
tensor(77892.8281,

tensor(1492169.5000)
*****
what is 'loss'?:
tensor(5386787.)
tensor(2810780.5000)
tensor(298845.1562)
tensor(116745.3203)
*****
what is 'loss'?:
tensor(50557.9102)
tensor(76873.6250)
tensor(105535.0703)
tensor(84570.5391)
*****
what is 'loss'?:
tensor(24440.9863)
tensor(30747.1367)
tensor(40179.4023)
tensor(40314.0195)
*****
what is 'loss'?:
tensor(173002.1406)
tensor(196214.0625)
tensor(239366.0312)
tensor(281820.5312)
*****
what is 'loss'?:
tensor(40779.8906)
tensor(35816.4609)
tensor(48279.1562)
tensor(62040.2109)
*****
what is 'loss'?:
tensor(143562.3125)
tensor(140786.3438)
tensor(109459.1250)
tensor(137933.4688)
*****
what is 'loss'?:
tensor(66102.8906)
tensor(54674.3594)
tensor(66397.5859)
tensor(59567.9297)
*****
what is 'loss'?:
tensor(881930.)
tensor(966115.9375)
tensor(906602.2500)
tensor(950103.5625)
*****
what is 'loss'?:
tensor(54982616.)
tensor(46883060.)
tensor(78744912.)
tensor(70999584.)
*****
what is 'loss'?:
tensor(91825.6484)
tensor(76793.2266)
tensor(79254.8125)
t

what is 'loss'?:
tensor(136998.5000, grad_fn=<MseLossBackward>)
tensor(182137.1250, grad_fn=<MseLossBackward>)
tensor(251709.6094, grad_fn=<MseLossBackward>)
tensor(235744.0781, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(19859.5723, grad_fn=<MseLossBackward>)
tensor(28044.4219, grad_fn=<MseLossBackward>)
tensor(35985.4453, grad_fn=<MseLossBackward>)
tensor(31997.8770, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(40954.2930, grad_fn=<MseLossBackward>)
tensor(67901.5703, grad_fn=<MseLossBackward>)
tensor(87205.2656, grad_fn=<MseLossBackward>)
tensor(79953.5234, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(44507.4219, grad_fn=<MseLossBackward>)
tensor(66212.3594, grad_fn=<MseLossBackward>)
tensor(102791.9453, grad_fn=<MseLossBackward>)
tensor(112222.1875, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(152436., grad_fn=<MseLossBackward>)
tensor(204662.3281, grad_fn=<MseLossBackward>)
tensor(368919.7812, grad_fn=<MseLossBackward>)
tensor(31083

what is 'loss'?:
tensor(6160.8286, grad_fn=<MseLossBackward>)
tensor(5751.3247, grad_fn=<MseLossBackward>)
tensor(5567.1592, grad_fn=<MseLossBackward>)
tensor(7173.9072, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(36046.4766, grad_fn=<MseLossBackward>)
tensor(30487.6621, grad_fn=<MseLossBackward>)
tensor(29125.5371, grad_fn=<MseLossBackward>)
tensor(53042.3359, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(12231362., grad_fn=<MseLossBackward>)
tensor(5493265., grad_fn=<MseLossBackward>)
tensor(179597.9844, grad_fn=<MseLossBackward>)
tensor(7616.2661, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(214751.9688, grad_fn=<MseLossBackward>)
tensor(280403.7188, grad_fn=<MseLossBackward>)
tensor(409883.9688, grad_fn=<MseLossBackward>)
tensor(419814.4062, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(97050.3672, grad_fn=<MseLossBackward>)
tensor(105780.4844, grad_fn=<MseLossBackward>)
tensor(75962.5234, grad_fn=<MseLossBackward>)
tensor(80744.8906, 

what is 'loss'?:
tensor(205166.3125, grad_fn=<MseLossBackward>)
tensor(240431.9375, grad_fn=<MseLossBackward>)
tensor(227395.3125, grad_fn=<MseLossBackward>)
tensor(261705.7656, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(6824.2471, grad_fn=<MseLossBackward>)
tensor(6469.6372, grad_fn=<MseLossBackward>)
tensor(8034.7217, grad_fn=<MseLossBackward>)
tensor(11459.4678, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(172545.4688, grad_fn=<MseLossBackward>)
tensor(214619.0469, grad_fn=<MseLossBackward>)
tensor(311579.1562, grad_fn=<MseLossBackward>)
tensor(253127.9375, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(94729.0156, grad_fn=<MseLossBackward>)
tensor(103645.6562, grad_fn=<MseLossBackward>)
tensor(213669.5938, grad_fn=<MseLossBackward>)
tensor(279860.2500, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(46004.5312, grad_fn=<MseLossBackward>)
tensor(76959.2344, grad_fn=<MseLossBackward>)
tensor(104846.2969, grad_fn=<MseLossBackward>)
tensor(6

what is 'loss'?:
tensor(515042.5000, grad_fn=<MseLossBackward>)
tensor(479787.1250, grad_fn=<MseLossBackward>)
tensor(757075.1250, grad_fn=<MseLossBackward>)
tensor(659577.8125, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(84088.8516, grad_fn=<MseLossBackward>)
tensor(99024.6484, grad_fn=<MseLossBackward>)
tensor(129803.2969, grad_fn=<MseLossBackward>)
tensor(165818.6562, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(66677.7812, grad_fn=<MseLossBackward>)
tensor(78613.4922, grad_fn=<MseLossBackward>)
tensor(71458.0312, grad_fn=<MseLossBackward>)
tensor(81235.7656, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(933515.5000, grad_fn=<MseLossBackward>)
tensor(1207811.7500, grad_fn=<MseLossBackward>)
tensor(819348.5625, grad_fn=<MseLossBackward>)
tensor(926845.8750, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(1086532.3750, grad_fn=<MseLossBackward>)
tensor(692644.2500, grad_fn=<MseLossBackward>)
tensor(776458.6250, grad_fn=<MseLossBackward>)
te

what is 'loss'?:
tensor(50460.4531, grad_fn=<MseLossBackward>)
tensor(53606.6992, grad_fn=<MseLossBackward>)
tensor(101307., grad_fn=<MseLossBackward>)
tensor(92484.9531, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(39369.1680, grad_fn=<MseLossBackward>)
tensor(36867.5898, grad_fn=<MseLossBackward>)
tensor(47570.1875, grad_fn=<MseLossBackward>)
tensor(59663.7656, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(6866.0225, grad_fn=<MseLossBackward>)
tensor(6704.4761, grad_fn=<MseLossBackward>)
tensor(5979.1084, grad_fn=<MseLossBackward>)
tensor(6574.3950, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(45361.0273, grad_fn=<MseLossBackward>)
tensor(45312.9297, grad_fn=<MseLossBackward>)
tensor(55394.9766, grad_fn=<MseLossBackward>)
tensor(58416.6406, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(74806.2812, grad_fn=<MseLossBackward>)
tensor(89735.3828, grad_fn=<MseLossBackward>)
tensor(88763.0469, grad_fn=<MseLossBackward>)
tensor(139266.0938, grad

what is 'loss'?:
tensor(75737.6250, grad_fn=<MseLossBackward>)
tensor(80935.3516, grad_fn=<MseLossBackward>)
tensor(138464.9375, grad_fn=<MseLossBackward>)
tensor(148142.4062, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(18160978., grad_fn=<MseLossBackward>)
tensor(31140990., grad_fn=<MseLossBackward>)
tensor(41758720., grad_fn=<MseLossBackward>)
tensor(64075348., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(12256424., grad_fn=<MseLossBackward>)
tensor(5521355.5000, grad_fn=<MseLossBackward>)
tensor(214197.3750, grad_fn=<MseLossBackward>)
tensor(35360.5938, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(48923.0625, grad_fn=<MseLossBackward>)
tensor(55303.6602, grad_fn=<MseLossBackward>)
tensor(58715.0312, grad_fn=<MseLossBackward>)
tensor(53875.7461, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(8757.7920, grad_fn=<MseLossBackward>)
tensor(8397.1523, grad_fn=<MseLossBackward>)
tensor(9705.1875, grad_fn=<MseLossBackward>)
tensor(12847.9531, g

what is 'loss'?:
tensor(170167.8281)
tensor(193337.)
tensor(236456.6094)
tensor(278733.9375)
*****
what is 'loss'?:
tensor(38809.8555)
tensor(33910.7266)
tensor(46155.3750)
tensor(59625.5508)
*****
what is 'loss'?:
tensor(141396.6719)
tensor(138583.5781)
tensor(107454.6719)
tensor(135622.8750)
*****
what is 'loss'?:
tensor(63909.3984)
tensor(52616.1328)
tensor(64367.7891)
tensor(57539.1992)
*****
what is 'loss'?:
tensor(876163.6875)
tensor(960255.5625)
tensor(900692.0625)
tensor(943785.2500)
*****
what is 'loss'?:
tensor(54947928.)
tensor(46848416.)
tensor(78703360.)
tensor(70962736.)
*****
what is 'loss'?:
tensor(90380.7812)
tensor(75488.4688)
tensor(77978.1875)
tensor(112220.6172)
*****
what is 'loss'?:
tensor(6283313.5000)
tensor(4614977.5000)
tensor(3166487.2500)
tensor(5964364.5000)
*****
7 20873668.43297746
Epoch: 8
what is 'loss'?:
tensor(4805375.5000, grad_fn=<MseLossBackward>)
tensor(5071206.5000, grad_fn=<MseLossBackward>)
tensor(3768856.5000, grad_fn=<MseLossBackward>)
tenso

what is 'loss'?:
tensor(88978.9453, grad_fn=<MseLossBackward>)
tensor(115836.6094, grad_fn=<MseLossBackward>)
tensor(144283.2031, grad_fn=<MseLossBackward>)
tensor(131902.8594, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(11844.9170, grad_fn=<MseLossBackward>)
tensor(11822.2744, grad_fn=<MseLossBackward>)
tensor(10023.6201, grad_fn=<MseLossBackward>)
tensor(13550.3672, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(321960.1250, grad_fn=<MseLossBackward>)
tensor(290705.0312, grad_fn=<MseLossBackward>)
tensor(278059.2500, grad_fn=<MseLossBackward>)
tensor(333078.3438, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(172475.5312, grad_fn=<MseLossBackward>)
tensor(214020.0781, grad_fn=<MseLossBackward>)
tensor(313891.9688, grad_fn=<MseLossBackward>)
tensor(256241.8281, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(39307.2148, grad_fn=<MseLossBackward>)
tensor(55376.1836, grad_fn=<MseLossBackward>)
tensor(57276.1875, grad_fn=<MseLossBackward>)
tensor

what is 'loss'?:
tensor(99757.5312, grad_fn=<MseLossBackward>)
tensor(138617.4062, grad_fn=<MseLossBackward>)
tensor(192750.0781, grad_fn=<MseLossBackward>)
tensor(206009.0781, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(8206.3340, grad_fn=<MseLossBackward>)
tensor(9268.3066, grad_fn=<MseLossBackward>)
tensor(11479.7793, grad_fn=<MseLossBackward>)
tensor(9910.7910, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(286660.1875, grad_fn=<MseLossBackward>)
tensor(345959.7812, grad_fn=<MseLossBackward>)
tensor(341676.2500, grad_fn=<MseLossBackward>)
tensor(140568.2969, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(77192.2578, grad_fn=<MseLossBackward>)
tensor(86844.5156, grad_fn=<MseLossBackward>)
tensor(194953.9375, grad_fn=<MseLossBackward>)
tensor(252989.4844, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(4078.9888, grad_fn=<MseLossBackward>)
tensor(4985.3794, grad_fn=<MseLossBackward>)
tensor(6447.8120, grad_fn=<MseLossBackward>)
tensor(8865.37

what is 'loss'?:
tensor(155227.1094, grad_fn=<MseLossBackward>)
tensor(195570.5469, grad_fn=<MseLossBackward>)
tensor(218178.7344, grad_fn=<MseLossBackward>)
tensor(274180.3438, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(271336.9375, grad_fn=<MseLossBackward>)
tensor(296666.9688, grad_fn=<MseLossBackward>)
tensor(204029.6719, grad_fn=<MseLossBackward>)
tensor(19640.5293, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(51954.8477, grad_fn=<MseLossBackward>)
tensor(55976.5859, grad_fn=<MseLossBackward>)
tensor(101278.5938, grad_fn=<MseLossBackward>)
tensor(88469.0469, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(20142.5898, grad_fn=<MseLossBackward>)
tensor(17057.4961, grad_fn=<MseLossBackward>)
tensor(18123.6172, grad_fn=<MseLossBackward>)
tensor(35276.0938, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(13962.9990, grad_fn=<MseLossBackward>)
tensor(13921.5713, grad_fn=<MseLossBackward>)
tensor(24928.4492, grad_fn=<MseLossBackward>)
tensor(23

what is 'loss'?:
tensor(955418.8125, grad_fn=<MseLossBackward>)
tensor(880177.5000, grad_fn=<MseLossBackward>)
tensor(794187.2500, grad_fn=<MseLossBackward>)
tensor(1228184.7500, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(79513.1406, grad_fn=<MseLossBackward>)
tensor(106433.7109, grad_fn=<MseLossBackward>)
tensor(134586.1094, grad_fn=<MseLossBackward>)
tensor(120386.2734, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(91771.7109, grad_fn=<MseLossBackward>)
tensor(100678.2734, grad_fn=<MseLossBackward>)
tensor(167340.1250, grad_fn=<MseLossBackward>)
tensor(190962.5781, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(127021.5703, grad_fn=<MseLossBackward>)
tensor(185023.4375, grad_fn=<MseLossBackward>)
tensor(218516.2656, grad_fn=<MseLossBackward>)
tensor(244292.4219, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(4411.8291, grad_fn=<MseLossBackward>)
tensor(3192.1138, grad_fn=<MseLossBackward>)
tensor(2906.0229, grad_fn=<MseLossBackward>)
tenso

what is 'loss'?:
tensor(845595.1250, grad_fn=<MseLossBackward>)
tensor(939930., grad_fn=<MseLossBackward>)
tensor(1427466.8750, grad_fn=<MseLossBackward>)
tensor(1233526.6250, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(81195.4297, grad_fn=<MseLossBackward>)
tensor(108659.4609, grad_fn=<MseLossBackward>)
tensor(134462.7500, grad_fn=<MseLossBackward>)
tensor(136775.5625, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(164759.2500, grad_fn=<MseLossBackward>)
tensor(69745.0312, grad_fn=<MseLossBackward>)
tensor(14163.7754, grad_fn=<MseLossBackward>)
tensor(15376.1758, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(8860.0674, grad_fn=<MseLossBackward>)
tensor(10054.6836, grad_fn=<MseLossBackward>)
tensor(10542.2012, grad_fn=<MseLossBackward>)
tensor(13359.4463, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(4270.0977, grad_fn=<MseLossBackward>)
tensor(4552.0083, grad_fn=<MseLossBackward>)
tensor(7697.8311, grad_fn=<MseLossBackward>)
tensor(6327.864

what is 'loss'?:
tensor(77205.6875, grad_fn=<MseLossBackward>)
tensor(86790.4922, grad_fn=<MseLossBackward>)
tensor(197415.0469, grad_fn=<MseLossBackward>)
tensor(255643.4688, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(48694.4883, grad_fn=<MseLossBackward>)
tensor(67970.1484, grad_fn=<MseLossBackward>)
tensor(76636.7500, grad_fn=<MseLossBackward>)
tensor(83680.9375, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(26965.6035, grad_fn=<MseLossBackward>)
tensor(25541.2012, grad_fn=<MseLossBackward>)
tensor(34412.4492, grad_fn=<MseLossBackward>)
tensor(34749.5352, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(72797.7734, grad_fn=<MseLossBackward>)
tensor(106308.7344, grad_fn=<MseLossBackward>)
tensor(109213.7500, grad_fn=<MseLossBackward>)
tensor(107934.2188, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(76267.8125, grad_fn=<MseLossBackward>)
tensor(83463.1562, grad_fn=<MseLossBackward>)
tensor(183864.0625, grad_fn=<MseLossBackward>)
tensor(1592

what is 'loss'?:
tensor(1797.0767, grad_fn=<MseLossBackward>)
tensor(1476.0051, grad_fn=<MseLossBackward>)
tensor(1603.0604, grad_fn=<MseLossBackward>)
tensor(2110.2661, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(43311.5117, grad_fn=<MseLossBackward>)
tensor(62414.6250, grad_fn=<MseLossBackward>)
tensor(73966.8203, grad_fn=<MseLossBackward>)
tensor(79542.0938, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(238967.2656, grad_fn=<MseLossBackward>)
tensor(340819.1250, grad_fn=<MseLossBackward>)
tensor(743465.8125, grad_fn=<MseLossBackward>)
tensor(589814.5625, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(30145.8750, grad_fn=<MseLossBackward>)
tensor(27458.0020, grad_fn=<MseLossBackward>)
tensor(36208.5078, grad_fn=<MseLossBackward>)
tensor(36773.0312, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(304766.0625, grad_fn=<MseLossBackward>)
tensor(351054.9375, grad_fn=<MseLossBackward>)
tensor(299946.4375, grad_fn=<MseLossBackward>)
tensor(259730.

tensor(78114.6328, grad_fn=<MseLossBackward>)
tensor(31198.0996, grad_fn=<MseLossBackward>)
tensor(44916.8164, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(16847.6992, grad_fn=<MseLossBackward>)
tensor(14324.9717, grad_fn=<MseLossBackward>)
tensor(14668.6934, grad_fn=<MseLossBackward>)
tensor(21811.4238, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(1806936., grad_fn=<MseLossBackward>)
tensor(1871221.6250, grad_fn=<MseLossBackward>)
tensor(2851622.2500, grad_fn=<MseLossBackward>)
tensor(3063220.2500, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(2451.0991, grad_fn=<MseLossBackward>)
tensor(2156.4258, grad_fn=<MseLossBackward>)
tensor(1813.8121, grad_fn=<MseLossBackward>)
tensor(3033.5308, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(3462.7903, grad_fn=<MseLossBackward>)
tensor(3364.3013, grad_fn=<MseLossBackward>)
tensor(2815.6736, grad_fn=<MseLossBackward>)
tensor(3691.9775, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(6463.467

what is 'loss'?:
tensor(85287.9375, grad_fn=<MseLossBackward>)
tensor(78695.2188, grad_fn=<MseLossBackward>)
tensor(62704.4609, grad_fn=<MseLossBackward>)
tensor(85996.2969, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(8568.8555, grad_fn=<MseLossBackward>)
tensor(7679.5537, grad_fn=<MseLossBackward>)
tensor(12114.3223, grad_fn=<MseLossBackward>)
tensor(32867.2695, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(27882.7207, grad_fn=<MseLossBackward>)
tensor(25132.2246, grad_fn=<MseLossBackward>)
tensor(30335.9238, grad_fn=<MseLossBackward>)
tensor(32139.3379, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(57520.5742, grad_fn=<MseLossBackward>)
tensor(72364.4219, grad_fn=<MseLossBackward>)
tensor(83129.9062, grad_fn=<MseLossBackward>)
tensor(103967.6328, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(6627.4756, grad_fn=<MseLossBackward>)
tensor(6019.7920, grad_fn=<MseLossBackward>)
tensor(4231.3535, grad_fn=<MseLossBackward>)
tensor(5935.5264, gra

what is 'loss'?:
tensor(4573461., grad_fn=<MseLossBackward>)
tensor(2462536., grad_fn=<MseLossBackward>)
tensor(435280.1875, grad_fn=<MseLossBackward>)
tensor(413310.9688, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(4956.8970, grad_fn=<MseLossBackward>)
tensor(4428.0327, grad_fn=<MseLossBackward>)
tensor(3373.2290, grad_fn=<MseLossBackward>)
tensor(5384.9312, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(218453.4688, grad_fn=<MseLossBackward>)
tensor(188454.6250, grad_fn=<MseLossBackward>)
tensor(268979.8125, grad_fn=<MseLossBackward>)
tensor(323845.1875, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(232370.8281, grad_fn=<MseLossBackward>)
tensor(299980., grad_fn=<MseLossBackward>)
tensor(426608.4062, grad_fn=<MseLossBackward>)
tensor(379445.0312, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(46025.4023, grad_fn=<MseLossBackward>)
tensor(64821.9297, grad_fn=<MseLossBackward>)
tensor(115370.6016, grad_fn=<MseLossBackward>)
tensor(101001.0938

what is 'loss'?:
tensor(31590.0098, grad_fn=<MseLossBackward>)
tensor(46672.2617, grad_fn=<MseLossBackward>)
tensor(61036.3789, grad_fn=<MseLossBackward>)
tensor(78773.1172, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(43470.3242, grad_fn=<MseLossBackward>)
tensor(36805.8242, grad_fn=<MseLossBackward>)
tensor(51127.5977, grad_fn=<MseLossBackward>)
tensor(77330.8125, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(220933.9219, grad_fn=<MseLossBackward>)
tensor(147696.9375, grad_fn=<MseLossBackward>)
tensor(131543.2500, grad_fn=<MseLossBackward>)
tensor(135633.3906, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(137612.2344, grad_fn=<MseLossBackward>)
tensor(173279.9219, grad_fn=<MseLossBackward>)
tensor(179017.7188, grad_fn=<MseLossBackward>)
tensor(153710.6250, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(272205.4688, grad_fn=<MseLossBackward>)
tensor(212720.0156, grad_fn=<MseLossBackward>)
tensor(168987.1719, grad_fn=<MseLossBackward>)
tensor

what is 'loss'?:
tensor(4482936., grad_fn=<MseLossBackward>)
tensor(2346118.2500, grad_fn=<MseLossBackward>)
tensor(323460.4062, grad_fn=<MseLossBackward>)
tensor(182784.4375, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(1850.3634, grad_fn=<MseLossBackward>)
tensor(2153.5076, grad_fn=<MseLossBackward>)
tensor(2909.3394, grad_fn=<MseLossBackward>)
tensor(3231.2646, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(64691.2695, grad_fn=<MseLossBackward>)
tensor(71640.7891, grad_fn=<MseLossBackward>)
tensor(91428.6250, grad_fn=<MseLossBackward>)
tensor(117963.0156, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(178988.7031, grad_fn=<MseLossBackward>)
tensor(220634.0938, grad_fn=<MseLossBackward>)
tensor(314598.9375, grad_fn=<MseLossBackward>)
tensor(260810.7969, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(39402.2188, grad_fn=<MseLossBackward>)
tensor(45255.7812, grad_fn=<MseLossBackward>)
tensor(90128.4922, grad_fn=<MseLossBackward>)
tensor(101628.

what is 'loss'?:
tensor(73856.8672, grad_fn=<MseLossBackward>)
tensor(97965.7500, grad_fn=<MseLossBackward>)
tensor(112510.8125, grad_fn=<MseLossBackward>)
tensor(138160.8906, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(60348.4805, grad_fn=<MseLossBackward>)
tensor(91985.9062, grad_fn=<MseLossBackward>)
tensor(158458.8438, grad_fn=<MseLossBackward>)
tensor(169382., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(4948.4111, grad_fn=<MseLossBackward>)
tensor(4772.3838, grad_fn=<MseLossBackward>)
tensor(5657.9580, grad_fn=<MseLossBackward>)
tensor(7816.4692, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(1651.6932, grad_fn=<MseLossBackward>)
tensor(1863.3571, grad_fn=<MseLossBackward>)
tensor(2463.1780, grad_fn=<MseLossBackward>)
tensor(2811.7761, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(30672.4805, grad_fn=<MseLossBackward>)
tensor(31789.7891, grad_fn=<MseLossBackward>)
tensor(18069.7988, grad_fn=<MseLossBackward>)
tensor(26212.6523, grad_f

what is 'loss'?:
tensor(208443.5781, grad_fn=<MseLossBackward>)
tensor(297543.0625, grad_fn=<MseLossBackward>)
tensor(552024.5000, grad_fn=<MseLossBackward>)
tensor(470479.9688, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(25336.2930, grad_fn=<MseLossBackward>)
tensor(19570.7695, grad_fn=<MseLossBackward>)
tensor(21044.4082, grad_fn=<MseLossBackward>)
tensor(35723.1875, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(11563.6738, grad_fn=<MseLossBackward>)
tensor(12505.4824, grad_fn=<MseLossBackward>)
tensor(14650.7168, grad_fn=<MseLossBackward>)
tensor(14728.5518, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(199820.4375, grad_fn=<MseLossBackward>)
tensor(178603.4844, grad_fn=<MseLossBackward>)
tensor(289159.6562, grad_fn=<MseLossBackward>)
tensor(290386.7500, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(21720.8242, grad_fn=<MseLossBackward>)
tensor(17690.6250, grad_fn=<MseLossBackward>)
tensor(14888.6084, grad_fn=<MseLossBackward>)
tensor(12

what is 'loss'?:
tensor(118859.5781, grad_fn=<MseLossBackward>)
tensor(168136.1562, grad_fn=<MseLossBackward>)
tensor(252057.7031, grad_fn=<MseLossBackward>)
tensor(205248.3750, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(893758.6875, grad_fn=<MseLossBackward>)
tensor(988558.0625, grad_fn=<MseLossBackward>)
tensor(1476611.7500, grad_fn=<MseLossBackward>)
tensor(1307148.3750, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(377849.7812, grad_fn=<MseLossBackward>)
tensor(315221.7500, grad_fn=<MseLossBackward>)
tensor(187614.3906, grad_fn=<MseLossBackward>)
tensor(228824.2500, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(447843.1250, grad_fn=<MseLossBackward>)
tensor(393621.8125, grad_fn=<MseLossBackward>)
tensor(441710.5000, grad_fn=<MseLossBackward>)
tensor(582466.5000, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(8.9454e+08, grad_fn=<MseLossBackward>)
tensor(9.7702e+08, grad_fn=<MseLossBackward>)
tensor(9.1035e+08, grad_fn=<MseLossBackward>)

what is 'loss'?:
tensor(90570.4062, grad_fn=<MseLossBackward>)
tensor(112833.5078, grad_fn=<MseLossBackward>)
tensor(172410.5000, grad_fn=<MseLossBackward>)
tensor(203163.7188, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(25237.8574, grad_fn=<MseLossBackward>)
tensor(37897.5820, grad_fn=<MseLossBackward>)
tensor(44881.2070, grad_fn=<MseLossBackward>)
tensor(40291.3516, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(69039.1719, grad_fn=<MseLossBackward>)
tensor(78973.4219, grad_fn=<MseLossBackward>)
tensor(178481.1875, grad_fn=<MseLossBackward>)
tensor(152312.4375, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(46942.6484, grad_fn=<MseLossBackward>)
tensor(66473.2578, grad_fn=<MseLossBackward>)
tensor(132551.6406, grad_fn=<MseLossBackward>)
tensor(108254.6953, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(23207.9043, grad_fn=<MseLossBackward>)
tensor(27035.6992, grad_fn=<MseLossBackward>)
tensor(37117.3203, grad_fn=<MseLossBackward>)
tensor(381

what is 'loss'?:
tensor(2842.8843, grad_fn=<MseLossBackward>)
tensor(3112.6616, grad_fn=<MseLossBackward>)
tensor(2388.2209, grad_fn=<MseLossBackward>)
tensor(2696.6997, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(78748.8906, grad_fn=<MseLossBackward>)
tensor(101980.0938, grad_fn=<MseLossBackward>)
tensor(125310.5703, grad_fn=<MseLossBackward>)
tensor(164693.8281, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(24701.5371, grad_fn=<MseLossBackward>)
tensor(22053.9707, grad_fn=<MseLossBackward>)
tensor(44505.4961, grad_fn=<MseLossBackward>)
tensor(47827.8945, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(4793738., grad_fn=<MseLossBackward>)
tensor(5057788.5000, grad_fn=<MseLossBackward>)
tensor(3780090.5000, grad_fn=<MseLossBackward>)
tensor(3596834.7500, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(319389.6875, grad_fn=<MseLossBackward>)
tensor(401469.4062, grad_fn=<MseLossBackward>)
tensor(499778.2188, grad_fn=<MseLossBackward>)
tensor(4250

what is 'loss'?:
tensor(5346919.)
tensor(2779382.2500)
tensor(285776.3438)
tensor(109344.3984)
*****
what is 'loss'?:
tensor(43750.8906)
tensor(69492.9922)
tensor(97943.5156)
tensor(76740.6484)
*****
what is 'loss'?:
tensor(18705.0508)
tensor(24843.8750)
tensor(33925.2188)
tensor(33294.2227)
*****
what is 'loss'?:
tensor(164640.2656)
tensor(187733.4375)
tensor(230762.)
tensor(272645.6250)
*****
what is 'loss'?:
tensor(35345.0391)
tensor(30626.8203)
tensor(42336.6914)
tensor(55139.5898)
*****
what is 'loss'?:
tensor(137465.0781)
tensor(134590.2969)
tensor(103920.8828)
tensor(131385.0625)
*****
what is 'loss'?:
tensor(59911.3008)
tensor(48968.0820)
tensor(60773.5391)
tensor(53973.6016)
*****
what is 'loss'?:
tensor(863638.3750)
tensor(947527.8750)
tensor(887832.1250)
tensor(929989.3750)
*****
what is 'loss'?:
tensor(54866408.)
tensor(46766952.)
tensor(78605416.)
tensor(70876168.)
*****
what is 'loss'?:
tensor(88169.3594)
tensor(73639.6719)
tensor(76182.3906)
tensor(109535.6484)
*****
wha

what is 'loss'?:
tensor(25635.3691, grad_fn=<MseLossBackward>)
tensor(37816.9688, grad_fn=<MseLossBackward>)
tensor(45822.7891, grad_fn=<MseLossBackward>)
tensor(41890.7539, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(205109.9531, grad_fn=<MseLossBackward>)
tensor(287483.5625, grad_fn=<MseLossBackward>)
tensor(429342.5625, grad_fn=<MseLossBackward>)
tensor(312598., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(9794.1895, grad_fn=<MseLossBackward>)
tensor(11558.9170, grad_fn=<MseLossBackward>)
tensor(15878.9355, grad_fn=<MseLossBackward>)
tensor(17755.7695, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(936036.8125, grad_fn=<MseLossBackward>)
tensor(2364653.5000, grad_fn=<MseLossBackward>)
tensor(3668269.5000, grad_fn=<MseLossBackward>)
tensor(8019707., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(70536.2031, grad_fn=<MseLossBackward>)
tensor(85807.5391, grad_fn=<MseLossBackward>)
tensor(77481.0938, grad_fn=<MseLossBackward>)
tensor(102647.2

what is 'loss'?:
tensor(1476427.3750, grad_fn=<MseLossBackward>)
tensor(5058144., grad_fn=<MseLossBackward>)
tensor(14089999., grad_fn=<MseLossBackward>)
tensor(23224070., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(3476.0464, grad_fn=<MseLossBackward>)
tensor(3171.8235, grad_fn=<MseLossBackward>)
tensor(2833.1719, grad_fn=<MseLossBackward>)
tensor(3460.6289, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(6538.1514, grad_fn=<MseLossBackward>)
tensor(6031.6626, grad_fn=<MseLossBackward>)
tensor(7408.0054, grad_fn=<MseLossBackward>)
tensor(9250.6807, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(903323.8750, grad_fn=<MseLossBackward>)
tensor(1180275.1250, grad_fn=<MseLossBackward>)
tensor(795331.1250, grad_fn=<MseLossBackward>)
tensor(902430.5000, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(4522.7402, grad_fn=<MseLossBackward>)
tensor(4030.0391, grad_fn=<MseLossBackward>)
tensor(5229.1958, grad_fn=<MseLossBackward>)
tensor(7750.5254, grad_fn

what is 'loss'?:
tensor(58418.6562, grad_fn=<MseLossBackward>)
tensor(72798.9375, grad_fn=<MseLossBackward>)
tensor(110545.5000, grad_fn=<MseLossBackward>)
tensor(134771.6562, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(73994.6406, grad_fn=<MseLossBackward>)
tensor(63107.0273, grad_fn=<MseLossBackward>)
tensor(44008.8320, grad_fn=<MseLossBackward>)
tensor(56647.8516, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(156119.2656, grad_fn=<MseLossBackward>)
tensor(226604.9062, grad_fn=<MseLossBackward>)
tensor(448374.6250, grad_fn=<MseLossBackward>)
tensor(356098.0938, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(43926.6875, grad_fn=<MseLossBackward>)
tensor(51231.9180, grad_fn=<MseLossBackward>)
tensor(70705.5781, grad_fn=<MseLossBackward>)
tensor(61144.3906, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(84077.9688, grad_fn=<MseLossBackward>)
tensor(96834.7422, grad_fn=<MseLossBackward>)
tensor(197971., grad_fn=<MseLossBackward>)
tensor(174580.

what is 'loss'?:
tensor(48610.6328, grad_fn=<MseLossBackward>)
tensor(45650.0391, grad_fn=<MseLossBackward>)
tensor(50081.7383, grad_fn=<MseLossBackward>)
tensor(59196.0664, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(61847.5000, grad_fn=<MseLossBackward>)
tensor(98302.7344, grad_fn=<MseLossBackward>)
tensor(142230.4219, grad_fn=<MseLossBackward>)
tensor(104206.7109, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(654995.5625, grad_fn=<MseLossBackward>)
tensor(514418.0625, grad_fn=<MseLossBackward>)
tensor(588464., grad_fn=<MseLossBackward>)
tensor(781797.6875, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(35281.5586, grad_fn=<MseLossBackward>)
tensor(56680.2070, grad_fn=<MseLossBackward>)
tensor(70441.5703, grad_fn=<MseLossBackward>)
tensor(67225.3359, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(6010.3389, grad_fn=<MseLossBackward>)
tensor(6058.6411, grad_fn=<MseLossBackward>)
tensor(6750.1919, grad_fn=<MseLossBackward>)
tensor(7950.4014, 

what is 'loss'?:
tensor(4509.7622, grad_fn=<MseLossBackward>)
tensor(4208.9370, grad_fn=<MseLossBackward>)
tensor(3925.2744, grad_fn=<MseLossBackward>)
tensor(5221.0439, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(8454.8564, grad_fn=<MseLossBackward>)
tensor(7434.7310, grad_fn=<MseLossBackward>)
tensor(6816.2744, grad_fn=<MseLossBackward>)
tensor(10215.0371, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(4754853., grad_fn=<MseLossBackward>)
tensor(5015517.5000, grad_fn=<MseLossBackward>)
tensor(3715116.5000, grad_fn=<MseLossBackward>)
tensor(3507356.5000, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(10339.2393, grad_fn=<MseLossBackward>)
tensor(9361.3340, grad_fn=<MseLossBackward>)
tensor(11463.3643, grad_fn=<MseLossBackward>)
tensor(12952.6504, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(61392.4766, grad_fn=<MseLossBackward>)
tensor(64273.0508, grad_fn=<MseLossBackward>)
tensor(110325.4453, grad_fn=<MseLossBackward>)
tensor(113791.0391, 

what is 'loss'?:
tensor(21070.3750, grad_fn=<MseLossBackward>)
tensor(22259.9824, grad_fn=<MseLossBackward>)
tensor(33772.1211, grad_fn=<MseLossBackward>)
tensor(35675.9922, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(4480750.5000, grad_fn=<MseLossBackward>)
tensor(2329610.5000, grad_fn=<MseLossBackward>)
tensor(167855.7500, grad_fn=<MseLossBackward>)
tensor(109069.1953, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(92740.2969, grad_fn=<MseLossBackward>)
tensor(111469., grad_fn=<MseLossBackward>)
tensor(224493.4531, grad_fn=<MseLossBackward>)
tensor(191522.3125, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(42613.3945, grad_fn=<MseLossBackward>)
tensor(36541.0352, grad_fn=<MseLossBackward>)
tensor(37121.4102, grad_fn=<MseLossBackward>)
tensor(59873.6094, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(139865.4375, grad_fn=<MseLossBackward>)
tensor(181582.3438, grad_fn=<MseLossBackward>)
tensor(198714.7188, grad_fn=<MseLossBackward>)
tensor(18

what is 'loss'?:
tensor(455476.8438, grad_fn=<MseLossBackward>)
tensor(398629.9688, grad_fn=<MseLossBackward>)
tensor(457166.7812, grad_fn=<MseLossBackward>)
tensor(605607.9375, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(16489.7305, grad_fn=<MseLossBackward>)
tensor(17656.7344, grad_fn=<MseLossBackward>)
tensor(22604.7695, grad_fn=<MseLossBackward>)
tensor(20599.6367, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(95638.9297, grad_fn=<MseLossBackward>)
tensor(33347.1836, grad_fn=<MseLossBackward>)
tensor(3401.1992, grad_fn=<MseLossBackward>)
tensor(3941.8459, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(171812.8594, grad_fn=<MseLossBackward>)
tensor(249418.4844, grad_fn=<MseLossBackward>)
tensor(482088.7188, grad_fn=<MseLossBackward>)
tensor(390597., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(72117.6406, grad_fn=<MseLossBackward>)
tensor(68314.0078, grad_fn=<MseLossBackward>)
tensor(135855.3906, grad_fn=<MseLossBackward>)
tensor(193344.

what is 'loss'?:
tensor(11123.4775, grad_fn=<MseLossBackward>)
tensor(9658.5977, grad_fn=<MseLossBackward>)
tensor(8822.7656, grad_fn=<MseLossBackward>)
tensor(14341.3105, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(101448.3672, grad_fn=<MseLossBackward>)
tensor(157932.8594, grad_fn=<MseLossBackward>)
tensor(176782.2656, grad_fn=<MseLossBackward>)
tensor(198656.5469, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(60981.7383, grad_fn=<MseLossBackward>)
tensor(67314.2656, grad_fn=<MseLossBackward>)
tensor(78624.5391, grad_fn=<MseLossBackward>)
tensor(90101.9453, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(1908619.2500, grad_fn=<MseLossBackward>)
tensor(2889851.2500, grad_fn=<MseLossBackward>)
tensor(4200878., grad_fn=<MseLossBackward>)
tensor(8809002., grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(346537.2188, grad_fn=<MseLossBackward>)
tensor(317920.5312, grad_fn=<MseLossBackward>)
tensor(291445.5312, grad_fn=<MseLossBackward>)
tensor(21570

what is 'loss'?:
tensor(165266.1406, grad_fn=<MseLossBackward>)
tensor(183334.1250, grad_fn=<MseLossBackward>)
tensor(200807.7500, grad_fn=<MseLossBackward>)
tensor(125216.1172, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(8720.2227, grad_fn=<MseLossBackward>)
tensor(8970.5801, grad_fn=<MseLossBackward>)
tensor(12771.5293, grad_fn=<MseLossBackward>)
tensor(10397.3115, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(29251.4922, grad_fn=<MseLossBackward>)
tensor(20582.0957, grad_fn=<MseLossBackward>)
tensor(19309.0391, grad_fn=<MseLossBackward>)
tensor(23161.0684, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(18548.6719, grad_fn=<MseLossBackward>)
tensor(19372.2969, grad_fn=<MseLossBackward>)
tensor(17966.9062, grad_fn=<MseLossBackward>)
tensor(20941.2793, grad_fn=<MseLossBackward>)
*****
what is 'loss'?:
tensor(295712.5000, grad_fn=<MseLossBackward>)
tensor(137629.3438, grad_fn=<MseLossBackward>)
tensor(70923.8125, grad_fn=<MseLossBackward>)
tensor(198485

### Ok this seems to be running. The loss is exploding upward. Not yet sure if that's because it's a terrible model or if I've done something incorrectly. TBD

In [30]:
for xb, yb in train_dl:
    pass

In [53]:
type(xb[0])

torch.Tensor

In [57]:
type(multi_lin_model(xb.float()))

list

In [11]:
def format_XY_dfs(tpm_df, meta_df):
    # melt tpm df so every sample is in a row
    tpm_melt = tpm_df[['locus_tag']+samples].melt(id_vars=['locus_tag'],var_name='sample',value_name='tpm')
    tpm_melt['condition'] = tpm_melt['sample'].apply(lambda x: sample2condition[x])

    # also add in upstream seq
    tpm_melt['upstream_region'] = tpm_melt['locus_tag'].apply(lambda x: loc2seq[x])
    
    # get coded metadata conditions from meta_df
    samp2cond_df = meta_df[['sample']+COND_COLS]

    # merge back onto the tpm df
    df = tpm_melt.merge(samp2cond_df,on='sample',how='left')
    
    # reformat full df
    full_df = df[['locus_tag','upstream_region', 'sample','condition']+COND_COLS+['tpm']]
    
    # separate out just X (features) and Y (labels)
    X = full_df[['locus_tag','upstream_region']+COND_COLS]
    Y = full_df['tpm']
    
    return full_df,X,Y

In [12]:
full_df, X, Y = format_XY_dfs(tpm_df,meta_df)
full_df.head()

Unnamed: 0,locus_tag,upstream_region,sample,condition,carbon_source,oxygen_level,nitrate_level,copper_level,lanthanum_level,growth_rate,growth_mode,tpm
0,EQU24_RS00005,CGCCGGTTTATGTCAATTATGCCGGCACTGATTTGATTGCTGTATA...,5GB1_ferm_Ack_QC_tpm,lowO2_slow_growth,2,0,0,3,0,0,0,2.933003
1,EQU24_RS00010,AACGCCGGTTTTACAGTTCATAAGCTATTGATAAATAAAATAAAAA...,5GB1_ferm_Ack_QC_tpm,lowO2_slow_growth,2,0,0,3,0,0,0,1.607784
2,EQU24_RS00015,ATCGCAGTCATTATTAAATGTGGAAGCAACAAAAAAACGAGCTTGT...,5GB1_ferm_Ack_QC_tpm,lowO2_slow_growth,2,0,0,3,0,0,0,1.415515
3,EQU24_RS00020,AACTTAATAACTATAAAATGTTCCACGTGGAACATGGTGAAATTAA...,5GB1_ferm_Ack_QC_tpm,lowO2_slow_growth,2,0,0,3,0,0,0,3.200081
4,EQU24_RS00025,CTTTGCCGAACACCCCGCACCTCCACGCGTCAACAACGAAATTTGA...,5GB1_ferm_Ack_QC_tpm,lowO2_slow_growth,2,0,0,3,0,0,0,1.522728


In [13]:
tpm_df.head()

Unnamed: 0,locus_tag,product,type,gene_symbol,locus,start_coord,end_coord,note,translation,gene_len,...,5GB1_pA9_red_tpm,5GB1_pA9_yellow_tpm,5GB1C-5G-La-BR1_tpm,5GB1C-5G-La-BR2_tpm,5GB1C-5G-N-BR1_tpm,5GB1C-5G-N-BR2_tpm,5GB1C-JG15-La-BR1_tpm,5GB1C-JG15-La-BR2_tpm,5GB1C-JG15-N-BR1_tpm,5GB1C-JG15-N-BR2_tpm
0,EQU24_RS00005,chromosomal replication initiator protein DnaA,CDS,dnaA,NZ_CP035467.1,0,1317,Derived by automated computational analysis us...,MSALWNNCLAKLENEISSSEFSTWIRPLQAIETDGQIKLLAPNRFV...,1318,...,38.557373,38.810668,37.444214,40.246006,40.100118,33.432274,39.880174,38.355431,30.247582,41.248441
1,EQU24_RS00010,DNA polymerase III subunit beta,CDS,,NZ_CP035467.1,1502,2603,Derived by automated computational analysis us...,MKYIINREQLLVPLQQIVSVIEKRQTMPILSNVLMVFRENTLVMTG...,1102,...,52.552767,52.461746,42.676553,49.210083,46.798476,48.142385,45.465136,46.498139,37.152951,52.90241
2,EQU24_RS00015,DNA replication/repair protein RecF,CDS,recF,NZ_CP035467.1,3060,4140,Derived by automated computational analysis us...,MSLQKLDIFNVRNIRQASLQPSPGLNLIYGANASGKSSVLEAIFIL...,1081,...,31.350991,34.914128,21.479309,24.204682,22.171104,22.006566,22.658157,22.753325,19.407103,29.834124
3,EQU24_RS00020,DNA topoisomerase (ATP-hydrolyzing) subunit B,CDS,gyrB,NZ_CP035467.1,4185,6600,Derived by automated computational analysis us...,MSENIKQYDSTNIQVLKGLDAVRKRPGMYIGDTDDGTGLHHMVFEV...,2416,...,74.848501,80.850761,54.959319,64.911376,59.653059,64.648318,69.119079,65.643179,57.590223,68.306759
4,EQU24_RS00025,hypothetical protein,CDS,,NZ_CP035467.1,6825,7062,Derived by automated computational analysis us...,VKTTKYFLTTRMRPDREIIKDEWIQYVVRFPENEHIQFDGRIRRWA...,238,...,50.324948,49.349547,34.539657,36.521074,37.789611,39.358066,38.992158,35.870964,41.462392,40.227192


In [15]:
full_df.head()

Unnamed: 0,locus_tag,upstream_region,sample,condition,carbon_source,oxygen_level,nitrate_level,copper_level,lanthanum_level,growth_rate,growth_mode,tpm
0,EQU24_RS00005,CGCCGGTTTATGTCAATTATGCCGGCACTGATTTGATTGCTGTATA...,5GB1_ferm_Ack_QC_tpm,lowO2_slow_growth,2,0,0,3,0,0,0,2.933003
1,EQU24_RS00010,AACGCCGGTTTTACAGTTCATAAGCTATTGATAAATAAAATAAAAA...,5GB1_ferm_Ack_QC_tpm,lowO2_slow_growth,2,0,0,3,0,0,0,1.607784
2,EQU24_RS00015,ATCGCAGTCATTATTAAATGTGGAAGCAACAAAAAAACGAGCTTGT...,5GB1_ferm_Ack_QC_tpm,lowO2_slow_growth,2,0,0,3,0,0,0,1.415515
3,EQU24_RS00020,AACTTAATAACTATAAAATGTTCCACGTGGAACATGGTGAAATTAA...,5GB1_ferm_Ack_QC_tpm,lowO2_slow_growth,2,0,0,3,0,0,0,3.200081
4,EQU24_RS00025,CTTTGCCGAACACCCCGCACCTCCACGCGTCAACAACGAAATTTGA...,5GB1_ferm_Ack_QC_tpm,lowO2_slow_growth,2,0,0,3,0,0,0,1.522728


In [18]:
len(X.upstream_region.values[0])

100

In [20]:
X

Unnamed: 0,locus_tag,upstream_region,carbon_source,oxygen_level,nitrate_level,copper_level,lanthanum_level,growth_rate,growth_mode
0,EQU24_RS00005,CGCCGGTTTATGTCAATTATGCCGGCACTGATTTGATTGCTGTATA...,2,0,0,3,0,0,0
1,EQU24_RS00010,AACGCCGGTTTTACAGTTCATAAGCTATTGATAAATAAAATAAAAA...,2,0,0,3,0,0,0
2,EQU24_RS00015,ATCGCAGTCATTATTAAATGTGGAAGCAACAAAAAAACGAGCTTGT...,2,0,0,3,0,0,0
3,EQU24_RS00020,AACTTAATAACTATAAAATGTTCCACGTGGAACATGGTGAAATTAA...,2,0,0,3,0,0,0
4,EQU24_RS00025,CTTTGCCGAACACCCCGCACCTCCACGCGTCAACAACGAAATTTGA...,2,0,0,3,0,0,0
...,...,...,...,...,...,...,...,...,...
412869,EQU24_RS22135,CCCGGCCGGTTTGGTCTTGTACTGGGTGGTCAACAATACGCTGTCG...,2,1,0,3,0,2,1
412870,EQU24_RS22140,GCCGCCCAGGGCACCTATCTTACAGTCCGAAGAGTATTAAAGTGTC...,2,1,0,3,0,2,1
412871,EQU24_RS22145,AATATTGATGTTGTTGTTATGGCCCGAAAAGATGCACTCAATGCAT...,2,1,0,3,0,2,1
412872,EQU24_RS22150,AAGAACTCACGGCTTTCGTGCCAGAATGGCGACCAAAGGCGGCCGT...,2,1,0,3,0,2,1


## Linear model arch

In [None]:
class DNA_Linear(nn.Module):
    def __init__(self, seq_len,h1_size,num_heads):
        super().__init__()
        self.lin = nn.Linear(4*seq_len, h1_size)

    def forward(self, xb):
        # Linear wraps up the weights/bias dot product operations
        return self.lin(xb)