In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score


#define one_hot encoding function
def one_hot(seq):
    """One-hot encode a nucleotide string.

    Rows are A, C, G and U/T (U and T share row 3); matching is
    case-insensitive. Any other character (for example N) leaves its
    column all zeros.

    Returns a float32 NumPy array of shape (4, len(seq)).
    """
    row_of = {'A': 0, 'C': 1, 'G': 2, 'U': 3, 'T': 3}
    encoded = np.zeros((4, len(seq)), dtype=np.float32)
    for col, base in enumerate(seq.upper()):
        row = row_of.get(base)
        # Unknown symbols have no row: their column stays zero.
        if row is not None:
            encoded[row, col] = 1.0
    return encoded

# Dataset object: expects one positive and one negative FASTA file per split (train, val, test).

class PasDataset(Dataset):
    """In-memory dataset of one-hot encoded, fixed-length sequences with per-file labels.

    All FASTA files are parsed and encoded once in the constructor; indexing
    afterwards only slices the pre-built tensors.
    """

    def __init__(self, fasta_paths, labels, seq_len=201):
        """
        fasta_paths: list of FASTA file paths
        labels: one label per path (e.g. [1, 0]); every record read from a
            file receives that file's label
        seq_len: required sequence length; records of any other length are
            reported and skipped (default 201)

        Raises ValueError if fasta_paths and labels differ in length.
        """
        fasta_paths, labels = list(fasta_paths), list(labels)
        # zip() would silently drop the unmatched paths/labels, so check first.
        if len(fasta_paths) != len(labels):
            raise ValueError(
                f"got {len(fasta_paths)} FASTA paths but {len(labels)} labels"
            )

        encoded, targets = [], []

        # Parse each FASTA, keep only records of the expected length, and
        # one-hot encode them.
        for path, y in zip(fasta_paths, labels):
            for rec in SeqIO.parse(path, "fasta"):
                seq = str(rec.seq)
                if len(seq) != seq_len:
                    print(f'skipping sequence that is not {seq_len}bp')
                    continue
                encoded.append(one_hot(seq))
                targets.append(y)

        if encoded:
            # Stack into one contiguous array before handing it to torch:
            # torch.tensor() on a list of ndarrays copies element by element
            # and is very slow for large datasets.
            self.seqs = torch.from_numpy(np.stack(encoded))
        else:
            # Keep a well-formed (N, 4, L) shape even when nothing was loaded.
            self.seqs = torch.empty((0, 4, seq_len), dtype=torch.float32)
        self.ys = torch.tensor(targets, dtype=torch.float32)

    def __len__(self):
        """Number of sequences kept."""
        return len(self.ys)

    def __getitem__(self, i):
        """Return the (encoded sequence, label) pair at index i."""
        return self.seqs[i], self.ys[i]
            

In [19]:
from Bio import SeqIO
import random

# Read every record of both source FASTA files into memory so they can be
# shuffled and sliced as plain lists.
pos = list(SeqIO.parse('../data/pos_201_hg19.fa', 'fasta'))
neg = list(SeqIO.parse('../data/neg_201_hg19.fa', 'fasta'))
# Fixed seed so the shuffles below (and therefore the splits) repeat across runs.
random.seed(42)
# Both shuffles draw from the same seeded global RNG, so their order matters.
random.shuffle(pos)
random.shuffle(neg)

def splits(lst, train_frac = 0.7, val_frac = 0.15):
    """Cut `lst` into consecutive (train, val, test) slices.

    The train and val sizes are the fractions of len(lst) truncated to whole
    items; the test slice is whatever remains after them.
    """
    total = len(lst)
    train_end = int(train_frac * total)
    val_end = train_end + int(val_frac * total)

    train_part = lst[:train_end]
    val_part = lst[train_end:val_end]
    test_part = lst[val_end:]
    return train_part, val_part, test_part

# Apply the default 70/15/15 split to each class independently.
pos_train, pos_val, pos_test = splits(pos)
neg_train, neg_val, neg_test = splits(neg)

#create separate fasta files for each set
from pathlib import Path

def write_split(records, path):
    """Write `records` to `path` in FASTA format, creating the output directory if needed.

    records: iterable of sequence records accepted by SeqIO.write
    path: destination FASTA file path
    """
    # parents = True: without it mkdir raises FileNotFoundError whenever more
    # than the last directory level is missing (e.g. a fresh checkout).
    Path(path).parent.mkdir(parents = True, exist_ok = True)
    SeqIO.write(records, path, 'fasta')

# (records, destination) pairs, written in order: positives first, then negatives.
split_outputs = [
    (pos_train, '../data/processed/pos_201_train.fa'),
    (pos_val,   '../data/processed/pos_201_val.fa'),
    (pos_test,  '../data/processed/pos_201_test.fa'),
    (neg_train, '../data/processed/neg_201_train.fa'),
    (neg_val,   '../data/processed/neg_201_val.fa'),
    (neg_test,  '../data/processed/neg_201_test.fa'),
]
for split_records, split_path in split_outputs:
    write_split(split_records, split_path)

In [25]:
class POLYNET(nn.Module):
    """Small 1D CNN that maps a 4-channel sequence to a single probability.

    Pipeline: conv(4->32, k=8) -> ReLU -> max-pool(2) -> conv(32->64, k=6)
    -> ReLU -> global max-pool -> linear(64->1) -> sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept: they fix the
        # state_dict keys and the order in which weights are initialised.
        self.conv1 = nn.Conv1d(in_channels=4, out_channels=32, kernel_size=8, padding='same')
        self.pool1 = nn.MaxPool1d(kernel_size=2)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=6, padding='same')
        self.gmp = nn.AdaptiveMaxPool1d(output_size=1)
        self.fc = nn.Linear(in_features=64, out_features=1)

    def forward(self, x):
        """Return one probability per batch item for input of shape (batch, 4, length)."""
        hidden = self.pool1(F.relu(self.conv1(x)))
        hidden = F.relu(self.conv2(hidden))
        # Global max over the length axis, then drop that axis: (batch, 64).
        pooled = self.gmp(hidden).squeeze(-1)
        logits = self.fc(pooled)
        # Sigmoid, then drop the trailing unit dimension: (batch,).
        return torch.sigmoid(logits).squeeze(-1)


In [28]:
def evaluate(model, loader, device):
    """Score `model` on every batch of `loader` and return (AUROC, AUPRC).

    The model is switched to eval mode and run without gradient tracking;
    labels and predictions are gathered on the CPU before the two
    scikit-learn metrics are computed over the whole loader.
    """
    model.eval()
    targets = []
    scores = []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device).float()
            batch_scores = model(batch_x)
            targets.extend(batch_y.cpu().numpy())
            scores.extend(batch_scores.cpu().numpy())
    auroc = roc_auc_score(targets, scores)
    auprc = average_precision_score(targets, scores)
    return auroc, auprc

# Full pipeline: build datasets and loaders, train, validate each epoch, test, and save the model
def train_and_evaluate(
    train_files, val_files, test_files,
    batch_size = 64, lr = 1e-3, epochs = 10,
    save_path = "../models/POLYNET.pt", num_workers = 4
):
    """Train POLYNET, report validation metrics per epoch, test once, and save the weights.

    train_files, val_files, test_files: [positive_fasta, negative_fasta] path
        pairs; the first file is labelled 1 and the second 0
    batch_size: mini-batch size for all three loaders
    lr: Adam learning rate
    epochs: number of passes over the training set
    save_path: where the final state_dict is written; its directory is
        created if it does not exist
    num_workers: DataLoader worker processes for each loader

    Prints one line per epoch (mean training loss, validation AUROC/AUPRC)
    and a final test-set line. The weights saved are those after the last epoch.
    """
    from pathlib import Path

    # Create the output directory before training: torch.save does not create
    # directories, and failing only after the last epoch would lose the model.
    save_path = Path(save_path)
    save_path.parent.mkdir(parents = True, exist_ok = True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    #Datasets and loaders
    train_ds = PasDataset(train_files, [1,0])
    val_ds = PasDataset(val_files, [1,0])
    test_ds = PasDataset(test_files, [1,0])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle = True, num_workers = num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle = False, num_workers = num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle = False, num_workers = num_workers)

    model = POLYNET().to(device)
    opt = torch.optim.Adam(model.parameters(), lr = lr)
    # The model already applies a sigmoid, so plain BCELoss is the matching loss.
    lossf = nn.BCELoss()

    #training loop
    for epoch in range(1, epochs+1):
        # evaluate() leaves the model in eval mode, so re-enable train mode each epoch.
        model.train()
        running_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device).float()
            opt.zero_grad()
            preds = model(xb)
            loss = lossf(preds, yb)
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch average is per sample even
            # when the last batch is smaller.
            running_loss += loss.item() * xb.size(0)
        train_loss = running_loss / len(train_loader.dataset)

        #Validation
        val_auc, val_auprc = evaluate(model, val_loader, device)
        print(f"Epoch {epoch:2d} | train_loss {train_loss:.4f} | "
        f"val_AUROC {val_auc:.4f} | val_AUPRC {val_auprc:.4f}")

    # Final test
    test_auc, test_auprc = evaluate(model, test_loader, device)
    print(f"Test set - AUROC: {test_auc:.4f}, AUPRC: {test_auprc:.4f}")

    #Save model
    torch.save(model.state_dict(), save_path)


In [29]:
#training and evaluation

# Each split is a [positive, negative] pair of FASTA files, in that order.
train_files = [
    "../data/processed/pos_201_train.fa",
    "../data/processed/neg_201_train.fa",
]
val_files = [
    "../data/processed/pos_201_val.fa",
    "../data/processed/neg_201_val.fa",
]
test_files = [
    "../data/processed/pos_201_test.fa",
    "../data/processed/neg_201_test.fa",
]

train_and_evaluate(train_files, val_files, test_files, batch_size=64, lr=1e-3, epochs=10)


skipping sequence that is not 201bp
Epoch  1 | train_loss 0.5052 | val_AUROC 0.8557 | val_AUPRC 0.8469
Epoch  2 | train_loss 0.4704 | val_AUROC 0.8615 | val_AUPRC 0.8531
Epoch  3 | train_loss 0.4610 | val_AUROC 0.8636 | val_AUPRC 0.8551
Epoch  4 | train_loss 0.4551 | val_AUROC 0.8657 | val_AUPRC 0.8574
Epoch  5 | train_loss 0.4508 | val_AUROC 0.8664 | val_AUPRC 0.8580
Epoch  6 | train_loss 0.4477 | val_AUROC 0.8673 | val_AUPRC 0.8590
Epoch  7 | train_loss 0.4453 | val_AUROC 0.8662 | val_AUPRC 0.8585
Epoch  8 | train_loss 0.4429 | val_AUROC 0.8675 | val_AUPRC 0.8596
Epoch  9 | train_loss 0.4417 | val_AUROC 0.8678 | val_AUPRC 0.8597
Epoch 10 | train_loss 0.4399 | val_AUROC 0.8671 | val_AUPRC 0.8592
Test set - AUROC: 0.8662, AUPRC: 0.8586
