In [7]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from model import SelfAttentionFeedForward

In [8]:
train_raw = pd.read_csv('../data/train_MPRA.txt', delimiter='\t', header=None)
test_raw = pd.read_csv('../data/test_MPRA.txt', delimiter='\t', header=None)
train_sol = pd.read_csv('../data/trainsolutions.txt', delimiter='\t', header=None)
train_raw.head()
strand_length = 295


In [9]:
# Get our x and y data
train_scores = np.array(train_raw.iloc[:, 2:297]) #Dimensions are 8000 (samples) by 295 (SHARPR scores per nucleotide)
raw_dna_strands_train = [list(train_raw[1][i]) for i in range(len(train_raw))] #List of lists holding DNA strands separated by character. Size 8000 lists each of length 290
embedded_dna_strands_train = [np.column_stack((np.array(pd.get_dummies(pd.concat([pd.Series(raw_dna_strands_train[i]), pd.Series(["A", "C", "T", "G"])]), dtype='int'))[:-4], np.arange(295))) for i in range(len(train_raw))] #One hot encoded dna strands, list of 8000 matrices, each (295,5)
embedded_dna_strands_train = [embedded_dna_strands_train[i] for i in range(len(embedded_dna_strands_train)) if not ("N" in raw_dna_strands_train[i])]
train_scores  = [train_scores[i] for i in range(len(raw_dna_strands_train)) if not ("N" in raw_dna_strands_train[i])]
#Repeat for test data
raw_dna_strands_test = [list(test_raw[1][i]) for i in range(len(test_raw))] #List of lists holding DNA strands separated by character. Size 8000 lists each of length 290
embedded_dna_strands_test = [np.column_stack((np.array(pd.get_dummies(pd.concat([pd.Series(raw_dna_strands_test[i]), pd.Series(["A", "C", "T", "G"])]), dtype='int'))[:-4], np.arange(295))) for i in range(len(test_raw))]
embedded_dna_strands_test = [embedded_dna_strands_test[i] for i in range(len(embedded_dna_strands_test)) if not ("N" in raw_dna_strands_test[i])]

In [10]:
#Add column with unique identifier for each nucleotide (sequence + location)
train_sol[3] = [str(train_sol.iloc[i, 1][5:]).zfill(4) + str(train_sol.iloc[i,2]).zfill(3) for i in range(len(train_sol))]

#Split by activators and repressors
train_sol_act = train_sol[train_sol[0] == 'A'][3]
train_sol_rep = train_sol[train_sol[0] == 'R'][3]

### ML Model

In [11]:
class DNADataset(Dataset):
    def __init__(self, embedded_dna_strands, train_scores):
        self.x = torch.tensor(embedded_dna_strands, dtype=torch.float32) # Convert x and y to tensors
        self.y = torch.tensor(train_scores, dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

In [13]:
class SelfAttentionFeedForward(nn.Module):
    #Initialize hyperparameters and NN matrices
    def __init__(self, attention_size, embed_size, hidden_size, hidden_layers, lr, train_len):
        super().__init__()
        self.attention_size = attention_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.hidden_layers = hidden_layers
        self.lr = lr
        self.train_len = train_len
        #self.dropout_rate = dropout_rate 

        self.initAttention()
        self.initFFN()

        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
            amsgrad=True,
        )

    #Initialize our weight matrices as torch objects, allows them to be automatically optimized
    def initAttention(self):
        self.W_Q = nn.Linear(self.embed_size, self.attention_size, bias=False)
        self.W_K = nn.Linear(self.embed_size, self.attention_size, bias=False)
        self.W_V = nn.Linear(self.embed_size, self.attention_size, bias=False)
        self.b = nn.Parameter(torch.rand(295)) # This is an addition term, analogous to y-intercept

        
    
    #Initialize Feed Forward layers, based on however many hidden layers we want
    def initFFN(self):

        layers = []

        layers.append(nn.Linear(self.attention_size, self.hidden_size))
        layers.append(nn.ReLU())
        #layers.append(nn.Dropout(self.dropout_rate))
        
        for _ in range(self.hidden_layers - 1):
            layers.append(nn.Linear(self.hidden_size, self.hidden_size))
            layers.append(nn.ReLU())
            #layers.append(nn.Dropout(self.dropout_rate)) #Add this later on

        layers.append(nn.Linear(self.hidden_size, 1))

        self.layers = nn.ModuleList(layers)
        self.criterion = nn.MSELoss() # Swithc to mean squared error instead of simple norm (this is better apparently?)

    def loss(self, predicted, y):
        return torch.norm(predicted - y)

    def forward(self, x):
        # x of size                                               (batch_size, sequence_length, embedding_size)
        batch_size, seq_len, emb_size = x.shape
        if emb_size != self.embed_size:
            raise ValueError

        queries = self.W_Q(x) #                                   (batch_size, sequence_length, attention_size)
        keys = self.W_K(x) #                                      (batch_size, sequence_length, attention_size)
        values = self.W_V(x) #                                    (batch_size, sequence_length, attention_size)

        #Compute attention and then normalize
        attention = torch.bmm(queries.transpose(1,2), keys) #                  (batch_size, attention_size, attention_size)
        weights = torch.nn.functional.softmax(attention, dim=2) # Apply this per sample

        # Use as weights for values
        context = torch.bmm(weights, values.transpose(1,2)).transpose(1,2) #    (batch_size, attention_size, sequence_length)
        # Run through all FFN layers
        for layer in self.layers:
            context = layer(context)
        # Return prediction with added b term
        return context

    def train_step(self, x, y):
        self.optimizer.zero_grad()
        pred = self(x)
        loss = self.criterion(pred.squeeze(-1), y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0) # This is for stability
        self.optimizer.step()
        return loss.item() # Diagnostic info

    def train(self, dataloader):
        losses = []
        for epoch in range(self.train_len):
            epoch_loss = 0
            for x_batch, y_batch in dataloader:
                loss = self.train_step(x_batch, y_batch)
                epoch_loss += loss
            avg_loss = epoch_loss / len(dataloader)
            losses.append(avg_loss)
            if (epoch + 1) % 5 == 0:
                print(f"Epoch {epoch+1}/{self.train_len}, Loss: {avg_loss:.4f}")
        return losses
            

In [None]:

dataset = DNADataset(embedded_dna_strands_train, train_scores)

# Create a DataLoader for batching, shuffling, and parallel data loading
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

model = SelfAttentionFeedForward(50, 5, 20, 3, 1e-5, 100) # (attention_size, embed_size, hidden_size, hidden_layers, lr, train_len)
model.train(dataloader)


Epoch 5/100, Loss: 0.3679


In [42]:
model.predict(torch.Tensor(embedded_dna_strands_train[0])).detach().numpy() - train_scores[0]

array([[ 0.22174352,  0.71485429,  0.62327915, ...,  1.0669322 ,
         0.41677427,  1.01910806],
       [ 0.27477124,  0.76788204,  0.6763069 , ...,  1.11995983,
         0.46980199,  1.07213569],
       [ 0.34722375,  0.84033453,  0.74875938, ...,  1.19241238,
         0.54225451,  1.14458823],
       ...,
       [19.20031839, 19.69342905, 19.60185342, ..., 20.04550743,
        19.3953495 , 19.99768257],
       [19.26924997, 19.76236063, 19.670785  , ..., 20.11443901,
        19.46428108, 20.06661415],
       [19.33420473, 19.82731538, 19.73573976, ..., 20.17939377,
        19.52923584, 20.13156891]])

In [20]:
train_scores[0]

array([ 0.019,  0.019,  0.019,  0.018,  0.017,  0.017,  0.016,  0.015,
        0.037,  0.059,  0.082,  0.104,  0.126,  0.125,  0.125,  0.124,
        0.124,  0.123,  0.116,  0.109,  0.102,  0.095,  0.088,  0.07 ,
        0.053,  0.035,  0.018,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ,
        0.   ,  0.005,  0.01 ,  0.014,  0.019,  0.024,  0.019,  0.014,
        0.01 ,  0.005,  0.   , -0.052, -0.104, -0.156, -0.208, -0.26 ,
       -0.27 , -0.28 , -0.29 , -0.3  , -0.31 , -0.298, -0.286, -0.275,
       -0.263, -0.251, -0.253, -0.255, -0.258, -0.26 , -0.262, -0.248,
       -0.234, -0.219, -0.205, -0.191, -0.202, -0.213, -0.223, -0.234,
       -0.245, -0.235, -0.225, -0.214, -0.204, -0.194, -0.191, -0.188,
       -0.185, -0.182, -0.179, -0.181, -0.183, -0.184, -0.186, -0.188,
       -0.155, -0.122, -0.089, -0.056, -0.023, -0.007,  0.009,  0.025,
        0.041,  0.057,  0.086,  0.116,  0.145,  0.175,  0.204,  0.21 ,
        0.216,  0.222,  0.228,  0.234,  0.24 ,  0.246,  0.252,  0.258,
      