In [1]:
import torch
from torch import nn
import torch.autograd as autograd

from itertools import product
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import utils as u

In [2]:
# Enumerate every length-8 string over the DNA alphabet, in
# itertools.product order (the last position varies fastest).
seqs8 = list(map(''.join, product('ACGT', repeat=8)))
print('Total 8mers:',len(seqs8))

Total 8mers: 65536


In [3]:
# Scoring helpers: each nucleotide has a fixed value, and a sequence's
# score is derived from the values of its letters.
score_dict = dict(A=20, C=17, G=14, T=11)

def score_seqs(seqs):
    '''
    Score every sequence as the mean of its letters' score_dict values.

    Returns a DataFrame with the columns 'seq' and 'score', one row per
    input sequence, in input order.
    '''
    rows = [
        [seq, np.mean([score_dict[base] for base in seq])]
        for seq in seqs
    ]
    return pd.DataFrame(rows, columns=['seq','score'])
                  
def score_seqs_motif(seqs):
    '''
    Score sequences by mean letter value, adjusted for two motifs.

    The base score is the average of the letters' score_dict values.
    A sequence containing 'TAT' gains 10 and one containing 'GCG' loses 10;
    a sequence containing both receives both adjustments.

    Returns a DataFrame with the columns 'seq' and 'score'.
    '''
    # (motif, adjustment) pairs, applied in this order
    motif_deltas = (('TAT', 10), ('GCG', -10))

    rows = []
    for seq in seqs:
        value = np.mean([score_dict[base] for base in seq])
        for motif, delta in motif_deltas:
            if motif in seq:
                value = value + delta
        rows.append([seq, value])

    return pd.DataFrame(rows, columns=['seq','score'])

In [4]:
# Score all 8-mers with the plain per-base average and preview the first rows.
mer8 = score_seqs(seqs8)
mer8.head()

Unnamed: 0,seq,score
0,AAAAAAAA,20.0
1,AAAAAAAC,19.625
2,AAAAAAAG,19.25
3,AAAAAAAT,18.875
4,AAAAAACA,19.625


In [5]:
# Score all 8-mers again, this time with the TAT (+10) / GCG (-10) motif
# adjustment, and display the resulting frame.
mer8_motif = score_seqs_motif(seqs8)
mer8_motif

Unnamed: 0,seq,score
0,AAAAAAAA,20.000
1,AAAAAAAC,19.625
2,AAAAAAAG,19.250
3,AAAAAAAT,18.875
4,AAAAAACA,19.625
...,...,...
65531,TTTTTTGT,11.375
65532,TTTTTTTA,12.125
65533,TTTTTTTC,11.750
65534,TTTTTTTG,11.375


In [6]:
# Wrap the motif-scored 8-mers in PyTorch DataLoaders via the project helper,
# which also hands back the matching DataFrames (presumably the train/test
# splits - confirm against utils.build_dataloaders_single).
# batch_size=11 is a deliberately distinctive value: it makes the batch
# dimension easy to recognise while debugging tensor-shape errors.
(
    mer8motif_train_dl,
    mer8motif_test_dl,
    mer8motif_train_df,
    mer8motif_test_df,
) = u.build_dataloaders_single(mer8_motif, batch_size=11)

In [7]:
# Display the training loader object returned by the helper.
mer8motif_train_dl

<torch.utils.data.dataloader.DataLoader at 0x7fed9ca1a1f0>

In [8]:
# Check the batch size the training loader reports.
print(mer8motif_train_dl.batch_size)

11


In [9]:
def loss_batch(model, loss_func, xb, yb, opt=None,verbose=False):
    '''
    Run one batch through the model and score it with the loss function.

    When an optimizer is supplied, the loss is back-propagated, one
    optimizer step is taken and the gradients are cleared afterwards;
    without one, the batch is only evaluated.

    Both xb and yb are cast with .float() before use. With verbose=True
    the input, target and prediction shapes are printed.

    Returns a (loss value as a Python number, len(xb)) pair.
    '''
    if verbose:
        print('\n\nloss batch ****')
        for label, tensor in (("xb shape:", xb), ("yb shape:", yb)):
            print(label, tensor.shape)

    predictions = model(xb.float())
    if verbose:
        print("model out pre loss", predictions.shape)

    batch_loss = loss_func(predictions, yb.float())
    batch_len = len(xb)

    # Evaluation-only call: nothing to update.
    if opt is None:
        return batch_loss.item(), batch_len

    batch_loss.backward()
    opt.step()
    opt.zero_grad()
    return batch_loss.item(), batch_len

def fit(epochs, model, loss_func, opt, train_dl, test_dl):
    '''
    Train the model for a number of epochs, evaluating after each one.

    Every epoch runs all training batches through loss_batch with the
    optimizer (so parameters are updated), then all test batches without
    it under torch.no_grad(). Per-epoch losses are batch-size-weighted
    means of the per-batch losses, so unequal batch sizes are handled.
    The epoch index and its test loss are printed as training proceeds.

    Returns (train_losses, test_losses), one entry per epoch each.
    '''
    def _size_weighted_mean(batch_losses, batch_sizes):
        # per-batch losses -> one epoch loss, weighting each batch by its size
        return np.sum(np.multiply(batch_losses, batch_sizes)) / np.sum(batch_sizes)

    train_history, test_history = [], []

    for epoch in range(epochs):
        # --- training pass: optimizer supplied, so weights get updated ---
        model.train()
        epoch_losses, epoch_sizes = [], []
        for xb, yb in train_dl:
            batch_loss, batch_size = loss_batch(model, loss_func, xb, yb, opt=opt)
            epoch_losses.append(batch_loss)
            epoch_sizes.append(batch_size)
        train_history.append(_size_weighted_mean(epoch_losses, epoch_sizes))

        # --- evaluation pass: no optimizer, no gradient tracking ---
        model.eval()
        with torch.no_grad():
            eval_pairs = [loss_batch(model, loss_func, xb, yb) for xb, yb in test_dl]
            # split the (loss, size) pairs into two parallel tuples
            eval_losses, eval_sizes = zip(*eval_pairs)
        test_loss = _size_weighted_mean(eval_losses, eval_sizes)

        print(epoch, test_loss)
        test_history.append(test_loss)

    return train_history, test_history

def run_model(train_dl,test_dl, model, lr=0.01, epochs=20):
    '''
    Train a model on the given dataloaders using MSE loss and plain SGD.

    lr is the SGD learning rate and epochs the number of passes handed to
    fit(). Returns the (train_losses, test_losses) pair produced by fit().
    '''
    criterion = torch.nn.MSELoss()
    sgd = torch.optim.SGD(model.parameters(), lr=lr)

    # hand everything to the generic training loop
    train_history, test_history = fit(
        epochs=epochs,
        model=model,
        loss_func=criterion,
        opt=sgd,
        train_dl=train_dl,
        test_dl=test_dl,
    )
    return train_history, test_history

## Attempt to build an LSTM

In [10]:
class DNA_LSTM(nn.Module):
    '''
    Single-layer LSTM followed by a linear layer that maps one-hot encoded
    DNA to one scalar per sequence.

    The input to forward() is reshaped to (batch, seq_len, 4), i.e. four
    one-hot channels per position. The LSTM output at the final position of
    each sequence is passed to the linear layer, giving a (batch, 1) tensor.
    '''
    def __init__(self,seq_len,hidden_dim=10):
        '''
        seq_len: number of positions per sequence (used to reshape the input).
        hidden_dim: size of the LSTM hidden-state and cell-state vectors.
        '''
        super().__init__()
        self.seq_len = seq_len

        self.hidden_dim = hidden_dim
        # (hidden state, cell state) tuple from the most recent
        # init_hidden()/forward() call; None until one of them has run
        self.hidden = None

        self.rnn = nn.LSTM(4, hidden_dim,batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def init_hidden(self,batch_size, device=None, dtype=None):
        '''
        Create all-zero hidden and cell states, store them on self.hidden
        and return them as an (h, c) tuple of shape (1, batch_size, hidden_dim).

        device / dtype default to those of the model's own parameters, so
        the states stay compatible after the model is moved with
        .to(...), .cuda() or .double(). Previously the states were always
        float32 CPU tensors, which nn.LSTM rejects when the input differs.
        '''
        reference = self.fc.weight
        if device is None:
            device = reference.device
        if dtype is None:
            dtype = reference.dtype

        state_shape = (1, batch_size, self.hidden_dim)
        self.hidden = (torch.zeros(state_shape, device=device, dtype=dtype),
                       torch.zeros(state_shape, device=device, dtype=dtype))
        return self.hidden

    def forward(self, xb,verbose=False):
        '''
        Predict one value per sequence.

        xb is reshaped to (batch, seq_len, 4). The LSTM state is reset to
        zeros on every call, so nothing is carried over between batches.
        With verbose=True intermediate shapes and tensors are printed.

        Returns a tensor of shape (batch, 1).
        '''
        if verbose:
            print("original xb.shape:", xb.shape)
            print(xb)

        # group the flat one-hot vector into 4 channels per position
        xb = xb.view(-1,self.seq_len,4)
        if verbose:
            print("re-viewed xb.shape:", xb.shape) # (batch, seq_len, 4)
            print(xb)

        # Fresh zero states for this batch, created on the input's device and
        # dtype because nn.LSTM requires input and states to match.
        batch_size = xb.shape[0]
        if verbose:
            print("batch_size:",batch_size)
        (h,c) = self.init_hidden(batch_size, device=xb.device, dtype=xb.dtype)

        # nn.LSTM takes the initial states as an (h_0, c_0) tuple and returns
        # the per-position outputs plus the final (h_n, c_n) tuple.
        lstm_out, self.hidden = self.rnn(xb, (h,c))
        if verbose:
            print("lstm_out shape:",lstm_out.shape) # (batch, seq_len, hidden_dim)
            print("lstm_out[-1] shape:",lstm_out[-1].shape) # (seq_len, hidden_dim)
            print("lstm_out[-1][-1] shape:",lstm_out[-1][-1].shape) # (hidden_dim,)

            print("hidden len:",len(self.hidden)) # 2: hidden state and cell state
            print("hidden[0] shape:", self.hidden[0].shape) # (1, batch, hidden_dim)
            print("hidden[0][-1] shape:", self.hidden[0][-1].shape) # (batch, hidden_dim)
            print("hidden[0][-1][-1] shape:", self.hidden[0][-1][-1].shape) # (hidden_dim,)

            print("*****")
            # For a single-layer, unidirectional LSTM the output at the last
            # position equals the final hidden state, so these two vectors
            # (both for the last sequence in the batch) should match.
            A = lstm_out[-1][-1]
            B = self.hidden[0][-1][-1]
            print("lstm_out[-1][-1]:",A)
            print("self.hidden[0][-1][-1]",B)
            print("==?", A==B)
            print("*****")

        # With batch_first=True, lstm_out is (batch, seq_len, hidden_dim), so
        # this picks the last position of every sequence: (batch, hidden_dim).
        last_layer = lstm_out[:,-1,:]
        if verbose:
            print("last layer:", last_layer.shape)

        out = self.fc(last_layer)
        if verbose:
            print("LSTM->FC out shape:",out.shape)

        return out

In [11]:
# Take the sequence length from the first training sequence so the model's
# reshape to (batch, seq_len, 4) matches the data.
seq_len = len(mer8motif_train_df['seq'].values[0])

# Build the LSTM with the default hidden size and display its layer summary.
mer8motif_model_lstm = DNA_LSTM(seq_len)
mer8motif_model_lstm

DNA_LSTM(
  (rnn): LSTM(4, 10, batch_first=True)
  (fc): Linear(in_features=10, out_features=1, bias=True)
)

In [12]:
# Train the LSTM on the motif-scored 8-mers with MSE loss and SGD;
# epochs is left at run_model's default (20).
train_losses,test_losses= run_model(
    mer8motif_train_dl,
    mer8motif_test_dl,
    mer8motif_model_lstm, 
    lr=0.01
)


  Variable._execution_engine.run_backward(


0 5.120449012796217
1 0.5644370661318066
2 0.2616728325850459
3 0.15103856184447925
4 0.10665999610672727
5 0.07707490203751499
6 0.22454623091280423
7 0.06064956730692331
8 0.05526983230087712
9 0.11212700311615176
10 0.09519580356221956
11 0.04957094306356723
12 0.03402372083281669
13 0.032571924564013414
14 0.06015142847803474
15 0.03031152778510874
16 0.0589027988395023
17 0.08019384578666697
18 0.03782405983770325
19 0.043076120334176356


In [17]:
def quick_seq_pred(model, seqs, oracle):
    '''
    Print the model's prediction next to the known score for each sequence.

    model: callable taking the float one-hot tensor of a single sequence.
    seqs: iterable of DNA strings.
    oracle: mapping from DNA string to its true score.

    For every sequence one line is printed: the prediction, the actual
    value and their difference (actual - predicted). Nothing is returned.
    '''
    # Inference only: disable autograd so no graph is built per prediction.
    with torch.no_grad():
        for dna in seqs:
            s = torch.tensor(u.one_hot_encode(dna))
            # the model is expected to yield a single-element tensor here
            pred = model(s.float()).item()
            actual = oracle[dna]
            diff = actual - pred
            print(f"{dna}: {pred} actual:{actual} ({diff})")
        
def quick_test8(model):
    '''
    Spot-check a model on hand-picked 8-mers, printing predicted vs actual.

    The sequences are grouped (homopolymers, mixed, TAT-containing,
    GCG-containing, both motifs); each group is passed to quick_seq_pred
    and followed by a blank line. True scores come from mer8_motif.
    '''
    seq_groups = [
        ['AAAAAAAA', 'CCCCCCCC','GGGGGGGG','TTTTTTTT'],            # homopolymers
        ['AACCAACA','CCGGCGCG','GGGTAAGG', 'TTTCGTTT','TGTAATAC'],  # mixed
        ['TATAAAAA','CCTATCCC','GTATGGGG','TTTATTTT'],              # contain TAT
        ['AAGCGAAA','CGCGCCCC','GGGCGGGG','TTGCGTTT'],              # contain GCG
        ['ATATGCGA','TGCGTATT'],                                    # contain both
    ]

    # seq -> true score lookup
    scoring = dict(mer8_motif[['seq','score']].values)

    for group in seq_groups:
        quick_seq_pred(model, group, scoring)
        print()

In [18]:
# Spot-check the trained LSTM on the hand-picked sequences.
quick_test8(mer8motif_model_lstm)

AAAAAAAA: 19.20769500732422 actual:20.0 (0.7923049926757812)
CCCCCCCC: 17.158016204833984 actual:17.0 (-0.15801620483398438)
GGGGGGGG: 13.617986679077148 actual:14.0 (0.38201332092285156)
TTTTTTTT: 11.39276123046875 actual:11.0 (-0.39276123046875)

AACCAACA: 18.69882583618164 actual:18.875 (0.17617416381835938)
CCGGCGCG: 5.050249099731445 actual:5.5 (0.4497509002685547)
GGGTAAGG: 15.031036376953125 actual:15.125 (0.093963623046875)
TTTCGTTT: 12.184320449829102 actual:12.125 (-0.05932044982910156)
TGTAATAC: 15.418512344360352 actual:15.5 (0.08148765563964844)

TATAAAAA: 27.15608787536621 actual:27.75 (0.5939121246337891)
CCTATCCC: 25.800508499145508 actual:25.875 (0.07449150085449219)
GTATGGGG: 23.516021728515625 actual:24.0 (0.483978271484375)
TTTATTTT: 21.531230926513672 actual:22.125 (0.5937690734863281)

AAGCGAAA: 7.887281894683838 actual:8.125 (0.2377181053161621)
CGCGCCCC: 6.319345474243164 actual:6.25 (-0.06934547424316406)
GGGCGGGG: 4.545219898223877 actual:4.375 (-0.17021989822

In [18]:
# Preview the scored frame again.
# NOTE(review): the recorded output shows an extra 'oh' column that
# score_seqs_motif does not create - presumably added in place by
# u.build_dataloaders_single; confirm, since it means this frame is mutated
# between cells.
mer8_motif.head()

Unnamed: 0,seq,score,oh
0,AAAAAAAA,20.0,"[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, ..."
1,AAAAAAAC,19.625,"[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, ..."
2,AAAAAAAG,19.25,"[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, ..."
3,AAAAAAAT,18.875,"[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, ..."
4,AAAAAACA,19.625,"[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, ..."
