In [8]:
import pandas as pd
from src.deepromoter import DeePromoter
from src.utils import load_dataset, protein2num, get_list_kmer


from torch.utils.data import DataLoader, TensorDataset, random_split
import torch
import math
import argparse
import torch.optim as optim
from torch import nn
from icecream import ic
from pathlib import Path


# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [3]:
# Load the training table from train_data.tsv using the project's load_dataset helper.
data = load_dataset("train_data.tsv")

In [16]:
# Show which columns the loaded table provides.
data.columns

Index(['insert_chrom', 'insert_name', 'sequence', 'rna_dna_ratio',
       'is_active'],
      dtype='object')

In [22]:
# Row index assigned to each nucleotide in the one-hot encoding.
MAPPER = {
    "A": 0,
    "C": 1,
    "G": 2,
    "T": 3,
}


def one_hot_seq(sequences: pd.Series) -> torch.Tensor:
    """One-hot encode nucleotide sequences into a single float tensor.

    Args:
        sequences: Iterable of equal-length strings made up of the keys of
            ``MAPPER`` (``A``, ``C``, ``G``, ``T``).

    Returns:
        Tensor of shape ``(len(sequences), len(MAPPER), sequence_length)`` with
        a 1 in row ``MAPPER[base]`` at each position and 0 elsewhere.

    Raises:
        KeyError: If a sequence contains a character that is not in ``MAPPER``.
        RuntimeError: From ``torch.stack`` if ``sequences`` is empty or the
            sequences differ in length.
    """
    encoded = []
    for seq in sequences:
        # Translate the whole sequence to row indices up front so the tensor is
        # filled with one indexed assignment instead of one write per base.
        rows = torch.tensor([MAPPER[base] for base in seq], dtype=torch.long)
        cols = torch.arange(len(seq))
        one_hot = torch.zeros(size=(len(MAPPER), len(seq)))
        one_hot[rows, cols] = 1
        encoded.append(one_hot)
    return torch.stack(encoded, dim=0)

# Encode every sequence; one_hot_seq stacks the per-sequence (4, length) matrices along a new leading axis.
X = one_hot_seq(data["sequence"])
        

In [23]:
# Labels as a float tensor (the training loop casts them with .long() for the loss).
# Convert through a NumPy array so the conversion is positional: handing the
# Series itself to torch.tensor makes torch index it like a sequence, and integer
# lookups on a Series are label-based, so a non-default index can break it.
y = torch.tensor(data["is_active"].to_numpy(dtype="float32"), dtype=torch.float)


In [28]:
# Kernel widths handed to DeePromoter; the model is moved to `device` right after construction.
ker = [27, 14, 7]
net = DeePromoter(ker, input_shape=(32, 271, 4)).to(device)
# Display the module tree as the cell output.
net

DeePromoter(
  (pconv): ParallelCNN(
    (lseq): ModuleList(
      (0): Sequential(
        (0): Conv1d(4, 4, kernel_size=(27,), stride=(1,), padding=same)
        (1): ReLU()
        (2): MaxPool1d(kernel_size=6, stride=6, padding=0, dilation=1, ceil_mode=False)
        (3): Dropout(p=0.5, inplace=False)
      )
      (1): Sequential(
        (0): Conv1d(4, 4, kernel_size=(14,), stride=(1,), padding=same)
        (1): ReLU()
        (2): MaxPool1d(kernel_size=6, stride=6, padding=0, dilation=1, ceil_mode=False)
        (3): Dropout(p=0.5, inplace=False)
      )
      (2): Sequential(
        (0): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=same)
        (1): ReLU()
        (2): MaxPool1d(kernel_size=6, stride=6, padding=0, dilation=1, ceil_mode=False)
        (3): Dropout(p=0.5, inplace=False)
      )
    )
  )
  (bilstm): BidirectionalLSTM(
    (rnn): LSTM(12, 12, batch_first=True, bidirectional=True)
    (linear): Linear(in_features=24, out_features=12, bias=True)
  )
  (fla

In [24]:

# Fraction of samples used for training, mini-batch size, and the seed that
# makes the train/test split repeatable across kernel restarts.
TRAIN_FRACTION = 0.8
BATCH_SIZE = 32
SPLIT_SEED = 0

# Combine X and y into a TensorDataset
dataset = TensorDataset(X, y)

# Split the dataset into training and testing sets (80% train, 20% test)
train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size
# A dedicated, seeded generator keeps the held-out samples the same on every run
# so accuracies and saved checkpoints from different runs are comparable.
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create DataLoaders for training and testing sets
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [35]:
epoch_num = 1000
criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(net.parameters(), lr=0.001)

exp_folder = Path("exp")
# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
exp_folder.mkdir(parents=True, exist_ok=True)
running_loss = 0
# NOTE(review): the five values below are never read in this cell; they look
# like placeholders for metric tracking / early stopping — confirm before removing.
best_mcc = 0
best_precision = 0
best_recall = 0
break_after = 10
last_update_best = 0
pbar = range(epoch_num)
ic("Start training")

for epoch in pbar:
    # The evaluation branch below switches to eval mode, so training mode is
    # re-enabled once at the start of every epoch.
    net.train()
    running_loss = 0
    # Batch names are kept distinct from the notebook-level X / y tensors so
    # running this cell does not overwrite the full dataset.
    for batch_inputs, batch_labels in train_loader:
        # get the inputs
        inputs, labels = batch_inputs.to(device), batch_labels.to(device)
        # Swap the last two axes before the forward pass.
        inputs = inputs.permute(0, 2, 1)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward pass, loss, backward pass, parameter update
        outputs = net(inputs)
        # Labels are cast to integer class indices for CrossEntropyLoss.
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # print statistics: mean training loss over the batches of this epoch
    running_loss /= len(train_loader)
    ic(f"Epoch {epoch+1}, loss: {running_loss:.3f}")
    if epoch % 10 == 0:
        # Every 10th epoch (starting with the first): checkpoint the weights and
        # measure accuracy on the held-out loader.
        net.eval()
        torch.save(net.state_dict(), str(exp_folder.joinpath("epoch_" + str(epoch) + ".pth")))
        correct = 0
        total = 0
        with torch.no_grad():
            for test_inputs, test_labels in test_loader:
                test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)
                test_inputs = test_inputs.permute(0, 2, 1)
                test_outputs = net(test_inputs)
                # Predicted class = index of the largest output along dim 1.
                predicted = test_outputs.argmax(dim=1)
                total += test_labels.size(0)
                correct += (predicted == test_labels.long()).sum().item()

        accuracy = 100 * correct / total
        ic(f"Accuracy on test set: {accuracy}%")
        #        precision, recall, MCC = mcc(eval_data)


ic| 'Start training'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 1, loss: 0.633'
ic| f"Accuracy on test set: {accuracy}%": 'Accuracy on test set: 67.16185455852403%'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 2, loss: 0.624'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 3, loss: 0.618'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 4, loss: 0.617'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 5, loss: 0.612'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 6, loss: 0.607'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 7, loss: 0.606'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 8, loss: 0.602'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 9, loss: 0.602'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 10, loss: 0.600'
ic| f"Epoch {epoch+1}, loss: {running_loss:.3f}": 'Epoch 11, loss: 0.599'
ic| f"Accuracy on test set: {accuracy}%": 'Accuracy on test set: 60.189289565113214%'
ic|

KeyboardInterrupt: 