In [None]:
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib
import torch
from torch import nn
from torch.nn import functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, r2_score
from sklearn.utils import shuffle
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, roc_curve
import glob

import os
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

In [None]:
# Directory that holds the train/valid/test CSV splits read in the next cell.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
DATAPATH = '/home/jupyter/ADAPT_PCR_share/safe/dataset'
# List the directory as a quick check that the path is reachable.
!ls $DATAPATH

In [None]:
# Read the three dataset splits; the first two CSV columns become a
# two-level row index.
train_df = pd.read_csv(f'{DATAPATH}/0716_dataset_train.csv', index_col=[0, 1])
val_df = pd.read_csv(f'{DATAPATH}/0716_dataset_valid.csv', index_col=[0, 1])
test_df = pd.read_csv(f'{DATAPATH}/0716_dataset_test.csv', index_col=[0, 1])
for split_df in (train_df, val_df, test_df):
    print(split_df.shape)

# Disable pandas' display truncation (rows, columns, line width, cell width)
# so that long sequence columns are shown in full.
for option_name in ('display.max_rows', 'display.max_columns',
                    'display.width', 'display.max_colwidth'):
    pd.set_option(option_name, None)

train_df.head(1)

In [None]:
# Show which columns the training split provides.
column_names = train_df.columns
print(column_names)

In [None]:
# Number of distinct prod_Tm values in the training split.
print(len(train_df['prod_Tm'].unique()))

# Handcrafted Features

In [None]:
# Handcrafted numeric descriptors used as the tabular model input.
FEATURE_COLUMNS = [
    "f_length", "f_indel", "f_mm",
    "r_length", "r_indel", "r_mm",
    "f_Tm", "r_Tm", "prod_length", "prod_Tm",
]

inputs = train_df[FEATURE_COLUMNS]
# Fit the scaler on the training split only, then reuse the same statistics
# for the validation and test splits.
scaler = StandardScaler()
scaler.fit(inputs)
X_train_scaled = scaler.transform(inputs)
X_val_scaled = scaler.transform(val_df[FEATURE_COLUMNS])
X_test_scaled = scaler.transform(test_df[FEATURE_COLUMNS])

In [None]:
# Display the standardised test-split features.
X_test_scaled

# Sequence-based classification

In [None]:
class PCRDataset(Dataset):
    """Simple (features, target) dataset backed by two tensors.

    NOTE: a second ``PCRDataset`` taking three constructor arguments is
    defined later in this notebook; once that cell runs it replaces this one.

    Args:
        X: array-like of features; converted with ``np.array`` then ``torch.Tensor``.
        y: array-like of targets; converted the same way.
    """

    def __init__(self, X, y):
        feature_array = np.array(X)
        self.X = torch.Tensor(feature_array)
        target_array = np.array(y)
        self.y = torch.Tensor(target_array)
        # Cache the sample count once; __len__ just returns it.
        self.len = len(self.X)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        features = self.X[index]
        target = self.y[index]
        return features, target

In [None]:
# Longest encoded primer string (forward or reverse) in the training split.
# NOTE(review): one_hot_encode below defaults to length=28 rather than reading
# MAXPLEN — confirm the printed value does not exceed that default.
MAXPLEN = max(train_df['f_penc'].apply(len).max(), train_df['r_penc'].apply(len).max())
print(MAXPLEN)

def one_hot_encode(seq, length=28):
    """One-hot encode a nucleotide/gap string into an ``(n, 5)`` integer array.

    Channel order is ``A, T, C, G, '-'``; ``'N'`` (also used for padding)
    maps to an all-zero row. Lower-case characters are accepted.

    Args:
        seq: string over the alphabet ``A/T/C/G/N/-``.
        length: minimum number of rows. Shorter sequences are padded on the
            RIGHT with ``'N'``; longer sequences are not truncated.

    Returns:
        ``np.ndarray`` of shape ``(max(len(seq), length), 5)``.

    Raises:
        KeyError: if ``seq`` contains a character outside the alphabet.
    """
    mapping = { 'A':[1, 0, 0, 0, 0],
                'T':[0, 1, 0, 0, 0],
                'C':[0, 0, 1, 0, 0],
                'G':[0, 0, 0, 1, 0],
                'N':[0, 0, 0, 0, 0],
                '-':[0, 0, 0, 0, 1] }
    # ljust pads on the right: ('ATCG', 6) -> 'ATCGNN'
    padded = seq.ljust(length, 'N')
    rows = [mapping[char.upper()] for char in padded]
    # reshape keeps the 2-D (n, 5) contract even when there is nothing to
    # encode (an empty list would otherwise become a 1-D array of shape (0,)).
    return np.array(rows, dtype=int).reshape(-1, 5)

def one_hot_encode_full_gap(df_seqs, maxl=1421):
    """Encode every row as aligned full-length target and primer tracks.

    For each row a primer track is built that is ``'N'`` everywhere except at
    the forward/reverse primer positions (where ``f_penc`` / ``r_penc`` are
    placed), and a target track is built from ``target_seq`` with the two
    primer binding sites replaced by ``f_tenc`` / ``r_tenc``. Both tracks are
    one-hot encoded with ``one_hot_encode(..., maxl)`` and joined along the
    channel axis, target channels first.

    Args:
        df_seqs: DataFrame with a two-level index and the columns ``f_seq``,
            ``f_start``, ``r_seq``, ``r_start``, ``target_seq``, ``f_penc``,
            ``f_tenc``, ``r_penc``, ``r_tenc``.
        maxl: padding length handed to ``one_hot_encode``.

    Returns:
        float32 tensor; its shape is also printed.
    """
    primer_tracks, target_tracks = [], []
    for (_target_name, _primer_name), row in df_seqs.iterrows():
        fseq, fst, rseq, rst, tseq = row[['f_seq','f_start','r_seq','r_start','target_seq']]
        fenc, ftenc, renc, rtenc = row[['f_penc','f_tenc','r_penc','r_tenc']]
        f_end = fst + len(fseq)  # first position after the forward primer site
        r_end = rst + len(rseq)  # first position after the reverse primer site

        primer_track = ''.join([
            'N' * fst,
            fenc,
            'N' * (rst - f_end),
            renc,
            'N' * (len(tseq) - r_end),
        ])
        target_track = ''.join([
            tseq[:fst],
            ftenc,
            tseq[f_end:rst],
            rtenc,
            tseq[r_end:],
        ])
        primer_tracks.append(one_hot_encode(primer_track, maxl))
        target_tracks.append(one_hot_encode(target_track, maxl))

    stacked_targets = np.array(target_tracks)
    stacked_primers = np.array(primer_tracks)
    final_encoded = np.concatenate((stacked_targets, stacked_primers), axis=2)
    print(final_encoded.shape)
    return torch.tensor(final_encoded, dtype=torch.float32)
def one_hot_encode_sequences(df_seqs):
    """Encode only the primer binding regions ("core" sequences) of each row.

    Each of ``f_penc``, ``f_tenc``, ``r_penc`` and ``r_tenc`` is one-hot
    encoded with the default padding length of ``one_hot_encode``. Forward and
    reverse encodings are stacked along the position axis, then target and
    primer encodings are joined along the channel axis (target channels first).

    Args:
        df_seqs: DataFrame with a two-level index and the four ``*_penc`` /
            ``*_tenc`` string columns.

    Returns:
        float32 tensor; its shape is also printed.
    """
    encoded_columns = ('f_penc', 'f_tenc', 'r_penc', 'r_tenc')
    primer_blocks, target_blocks = [], []
    for (_target_name, _primer_name), row in df_seqs.iterrows():
        fenc, ftenc, renc, rtenc = (one_hot_encode(row[col]) for col in encoded_columns)
        primer_blocks.append(np.concatenate((fenc, renc), axis=0))
        target_blocks.append(np.concatenate((ftenc, rtenc), axis=0))

    final_encoded = np.concatenate(
        (np.array(target_blocks), np.array(primer_blocks)), axis=2
    )
    print(final_encoded.shape)
    return torch.tensor(final_encoded, dtype=torch.float32)

class PCRDataset(Dataset):
    """Dataset yielding ``(encoded sequence, handcrafted features, target)`` triples.

    Args:
        encoded_input: indexable of per-sample sequence encodings. The encoders
            in this notebook produce one ``(positions, 10)`` slice per sample:
            channels 0-4 are the target one-hot (A, T, C, G, gap) and channels
            5-9 the primer one-hot.
        custom_features: indexable of per-sample handcrafted feature vectors,
            aligned with ``encoded_input``.
        ct_values: indexable of per-sample targets, aligned with
            ``encoded_input`` (this notebook passes binary labels here).
    """

    def __init__(self, encoded_input, custom_features, ct_values):
        self.encoded_input, self.custom_features, self.ct_values = (
            encoded_input, custom_features, ct_values
        )

    def __getitem__(self, idx):
        sample = (
            self.encoded_input[idx],
            self.custom_features[idx],
            self.ct_values[idx],
        )
        return sample

    def __len__(self):
        # Length is driven by the sequence encodings alone.
        return len(self.encoded_input)

In [None]:
# Display the standardised training-split features.
X_train_scaled

In [None]:
# Display the same features as a tensor (the result is not stored).
torch.tensor(X_train_scaled)

In [None]:
### Change between these for either full-sequence or core sequence testing
# Every encoding call used to be commented out, which left `train_inps`,
# `val_inps`, `test_inps` and `train_and_val_inps` undefined on a fresh kernel.
# Select the encoding with this flag instead of (un)commenting code:
#   False -> one_hot_encode_sequences  (primer binding regions only)
#   True  -> one_hot_encode_full_gap   (full-length gapped tracks)
USE_FULL_SEQUENCE = False
encode_split = one_hot_encode_full_gap if USE_FULL_SEQUENCE else one_hot_encode_sequences

train_inps = encode_split(train_df)
val_inps = encode_split(val_df)
test_inps = encode_split(test_df)
# useful for cross-validation
train_and_val_inps = torch.cat((train_inps, val_inps))

# A sample is labelled positive when its score is strictly above the cutoff.
cutoff=0

train_labels = torch.tensor(np.where(train_df['score'] > cutoff, 1, 0))
val_labels = torch.tensor(np.where(val_df['score'] > cutoff, 1, 0))
test_labels = torch.tensor(np.where(test_df['score'] > cutoff, 1, 0))
train_val_labels = torch.cat((train_labels, val_labels))


# Create dataset objects
train_dataset = PCRDataset(train_inps, torch.tensor(X_train_scaled), train_labels)
val_dataset = PCRDataset(val_inps, torch.tensor(X_val_scaled), val_labels)
test_dataset = PCRDataset(test_inps, torch.tensor(X_test_scaled), test_labels)
train_val_data = PCRDataset(train_and_val_inps, torch.tensor(np.concatenate([X_train_scaled, X_val_scaled])), train_val_labels)


# Create data loaders
BATCH_SIZE = 64  # can play around
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
train_val_loader = DataLoader(train_val_data, batch_size=BATCH_SIZE, shuffle=False) # dont want it interfering with cross-val
# Historical name: this variable used to be rebound from the Dataset to its
# DataLoader. Kept as an alias of the loader so existing code keeps working;
# prefer `train_val_loader` in new code.
train_val_dataset = train_val_loader

# Model Definitions

In [None]:
import numpy as np
import pandas as pd
import os
import math
import matplotlib.pyplot as plt
import seaborn as sns

from einops import rearrange, repeat
from scipy.stats import spearmanr
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler
from joblib import dump, load
from tqdm.auto import tqdm
from Bio import Seq

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import TensorDataset, Subset
from sklearn.model_selection import KFold
# Fix torch's global RNG seed for reproducibility.
torch.manual_seed(42)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
device

In [None]:
class PGC(nn.Module):
    def __init__(self,d_model,expansion_factor = 1.0,dropout = 0.0):
        super().__init__()
        self.d_model = d_model
        self.expansion_factor = expansion_factor
        self.dropout = dropout
        self.conv = nn.Conv1d(int(d_model * expansion_factor), int(d_model * expansion_factor),
                              kernel_size=3, padding=1, groups=int(d_model * expansion_factor))
        self.in_proj = nn.Linear(d_model, int(d_model * expansion_factor * 2))
        self.out_norm = nn.RMSNorm(int(d_model), eps=1e-8)
        self.in_norm = nn.RMSNorm(int(d_model * expansion_factor * 2), eps=1e-8)
        self.out_proj = nn.Linear(int(d_model * expansion_factor), d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, u):
        xv = self.in_norm(self.in_proj(u))
        x,v = xv.chunk(2,dim=-1)
        x_conv = self.conv(x.transpose(-1,-2)).transpose(-1,-2)
        gate =  v * x_conv
        x = self.out_norm(self.out_proj(gate))
        return x
    
class DropoutNd(nn.Module):
    """Dropout whose mask can be shared ("tied") across trailing dimensions."""

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        p: probability of zeroing an element; must satisfy 0 <= p < 1
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        transposed: if False, the input is rearranged from (batch, ..., dim)
                    to (batch, dim, ...) before masking and back afterwards
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError("dropout probability has to be in [0, 1), but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        # Dropout is a no-op outside training.
        if not self.training:
            return X

        if not self.transposed:
            X = rearrange(X, 'b ... d -> b d ...')

        # A tied mask has size 1 on every axis after (batch, dim), so the same
        # keep/drop decision is broadcast over all trailing positions.
        if self.tie:
            mask_shape = X.shape[:2] + (1,)*(X.ndim-2)
        else:
            mask_shape = X.shape

        # Sampling with torch.rand directly on X's device avoids the slow
        # CPU -> GPU copy that sampling from self.binomial would incur.
        keep_mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p
        X = X * keep_mask * (1.0/(1-self.p))

        if not self.transposed:
            X = rearrange(X, 'b d ... -> b ... d')
        return X

class S4DKernel(nn.Module):
    """Generate convolution kernel from diagonal SSM parameters.

    Holds, per feature h in range(d_model), N//2 complex state coefficients:
    ``log_dt`` (H,), ``C`` (H, N//2, complex stored as a real view),
    ``log_A_real`` (H, N//2) and ``A_imag`` (H, N//2).

    Args:
        d_model: number of independent features H.
        N: state size; N//2 complex coefficients are kept per feature.
        dt_min, dt_max: range from which the initial step size is drawn.
        lr: forwarded to ``register``; 0.0 freezes log_dt / log_A_real /
            A_imag as buffers, any other value registers them as parameters.
    """

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        # Generate dt
        H = d_model
        # Uniform in log space between log(dt_min) and log(dt_max).
        log_dt = torch.rand(H) * (
            math.log(dt_max) - math.log(dt_min)
        ) + math.log(dt_min)

        # Random complex C, stored as a real tensor so it can be a Parameter.
        C = torch.randn(H, N // 2, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(C))
        self.register("log_dt", log_dt, lr)

        # Real part of A starts at -0.5 (applied as -exp(log_A_real) in
        # forward); imaginary part starts at pi * n for n = 0 .. N//2 - 1.
        log_A_real = torch.log(0.5 * torch.ones(H, N//2))
        A_imag = math.pi * repeat(torch.arange(N//2), 'n -> h n', h=H)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """
        L: kernel length (number of time steps).
        returns: real tensor of shape (H, L), one length-L kernel per feature.
        """

        # Materialize parameters
        dt = torch.exp(self.log_dt) # (H)
        C = torch.view_as_complex(self.C) # (H N)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N)

        # Vandermonde multiplication
        dtA = A * dt.unsqueeze(-1)  # (H N)
        K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device) # (H N L)
        # Rescale C by (exp(dt*A) - 1) / A before contracting over the state axis.
        C = C * (torch.exp(dtA)-1.) / A
        # Sum over the N axis; only the real part is kept (times 2).
        K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real

        return K

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay

        With ``lr == 0.0`` the tensor becomes a buffer. Otherwise it becomes a
        Parameter carrying an ``_optim`` dict (``weight_decay`` 0.0 and, if
        given, ``lr``); that dict is only a hint and takes effect only if the
        code building the optimizer reads it.
        """

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {"weight_decay": 0.0}
            if lr is not None: optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)


class S4D(nn.Module):
    """Diagonal state-space layer: long convolution, skip term, GELU, GLU mixing.

    Args:
        d_model: number of features H.
        d_state: state size passed to ``S4DKernel`` as ``N``.
        dropout: if > 0, a ``DropoutNd`` with this probability is applied
            after the activation; otherwise no dropout is used.
        transposed: True when inputs are (B, H, L); False when they are
            (B, L, H), in which case they are transposed on entry and exit.
        **kernel_args: forwarded to ``S4DKernel``.
    """

    def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args):
        super().__init__()

        self.h = d_model
        self.n = d_state
        self.d_output = self.h
        self.transposed = transposed

        # Per-feature weight of the direct input -> output skip term.
        self.D = nn.Parameter(torch.randn(self.h))

        # Generator of the per-feature convolution kernel.
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)

        # Pointwise non-linearity and optional dropout.
        self.activation = nn.GELU()
        if dropout > 0.0:
            self.dropout = DropoutNd(dropout)
        else:
            self.dropout = nn.Identity()

        # Position-wise feature mixing: 1x1 conv to 2H channels, then GLU
        # over the channel axis brings it back to H.
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2*self.h, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs): # absorbs return_output and transformer src mask
        """ Input and output shape (B, H, L) """
        if not self.transposed:
            u = u.transpose(-1, -2)
        seq_len = u.size(-1)
        fft_len = 2*seq_len  # zero-padded FFT length used for the convolution

        # Build the (H, L) kernel for this sequence length.
        ssm_kernel = self.kernel(L=seq_len)

        # FFT-based convolution, truncated back to the first L outputs.
        kernel_freq = torch.fft.rfft(ssm_kernel, n=fft_len)
        input_freq = torch.fft.rfft(u, n=fft_len)
        y = torch.fft.irfft(input_freq*kernel_freq, n=fft_len)[..., :seq_len]

        # D term of the state-space equation: a per-feature skip connection.
        y = y + u * self.D.unsqueeze(-1)

        y = self.output_linear(self.dropout(self.activation(y)))
        if not self.transposed:
            y = y.transpose(-1, -2)
        return y
    
class Janus(nn.Module):
    """Sequence model: linear encoder, two PGC blocks, residual S4D layer, mean pool, linear head.

    Args:
        d_input: feature size of each input position.
        d_output: size of the returned vector per sample.
        d_model: hidden width used throughout.
        d_state: state size of the S4D layer.
        dropout: dropout probability handed to the PGC blocks, the S4D layer
            and the dropout applied to the S4D output.
        transposed: forwarded to ``S4D``.
        **kernel_args: forwarded to ``S4D``.
    """

    def __init__(self, d_input, d_output, d_model, d_state=64, dropout=0.2, transposed=False, **kernel_args):
        super().__init__()
        self.encoder = nn.Linear(d_input, d_model)
        self.pgc1 = PGC(d_model, expansion_factor=0.25, dropout=dropout)
        self.pgc2 = PGC(d_model, expansion_factor=2, dropout=dropout)
        self.s4d = S4D(d_model, d_state=d_state, dropout=dropout, transposed=transposed, **kernel_args)
        self.norm = nn.RMSNorm(d_model)
        self.decoder = nn.Linear(d_model, d_output)
        self.dropout = nn.Dropout(dropout)

    def forward(self, u):
        """Map ``u`` to one ``d_output``-sized vector per sample.

        Assumes ``u`` is (batch, length, d_input): the result is averaged over
        dim 1 before the decoder.
        """
        features = self.pgc2(self.pgc1(self.encoder(u)))
        # Pre-norm residual: S4D sees the normalised features, while the
        # un-normalised features are added back.
        mixed = self.dropout(self.s4d(self.norm(features))) + features
        pooled = mixed.mean(dim=1)
        return self.decoder(pooled)
    
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: size of each input vector.
        hidden_dim: width of the hidden layer.
        output_dim: size of each output vector.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x`` (last dimension must be ``input_dim``)."""
        return self.model(x)
    
class CombinedModel(nn.Module):
    """Fuse an MLP on handcrafted features with a Janus sequence model.

    Args:
        mlp_dims: positional arguments for ``MLP``; the last entry is its
            output size.
        ssm_dims: positional arguments for ``Janus``; entry 1 is its output
            size.
        combined_hidden: hidden width of the fusion head.
        final_output: number of outputs of the fusion head.
    """

    def __init__(self, mlp_dims, ssm_dims, combined_hidden, final_output):
        super().__init__()

        # The two branches.
        self.mlp = MLP(*mlp_dims)
        self.ssm = Janus(*ssm_dims)

        # Fusion head over the concatenated branch outputs. It ends in a
        # sigmoid, so the model returns values in (0, 1).
        fused_width = mlp_dims[-1] + ssm_dims[1]
        self.combiner = nn.Sequential(
            nn.Linear(fused_width, combined_hidden),
            nn.ReLU(),
            nn.Linear(combined_hidden, final_output),
            nn.Sigmoid()
        )

    def forward(self, mlp_input, ssm_input):
        """Return the fusion head's output for one batch.

        Note the argument order: handcrafted features first, sequence
        encoding second.
        """
        branch_outputs = (self.mlp(mlp_input), self.ssm(ssm_input))
        fused = torch.cat(branch_outputs, dim=1)
        return self.combiner(fused)

In [None]:
from sklearn.metrics import confusion_matrix


def train_loop(model, lr, num_epochs=50):
    """Train ``model`` with BCE loss and report validation TPR/FPR after every epoch.

    Uses the notebook-level globals ``train_loader``, ``val_loader`` and
    ``device``. The model is trained in place and is left holding the weights
    of the LAST epoch; a copy of the weights from the epoch with the highest
    validation TPR is returned so the caller can restore it with
    ``model.load_state_dict(result["best_model_state"])`` if wanted.

    Args:
        model: module called as ``model(mlp_inputs, inputs)`` and returning
            one probability per sample.
        lr: learning rate for AdamW (weight decay is fixed at 0.01).
        num_epochs: number of passes over ``train_loader``.

    Returns:
        dict with per-epoch ``trues``, ``preds``, ``tprs``, ``fprs`` and
        ``train_losses`` (each keyed by epoch index), plus ``best_tpr``,
        ``best_cross_loss`` (lowest summed validation loss) and
        ``best_model_state``.
    """
    import copy  # function-scope import: only needed to snapshot the best weights

    criterion = nn.BCELoss()

    # NOTE: every parameter gets the same lr / weight decay here; the
    # per-tensor `_optim` hints attached by S4DKernel.register are not read.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) #lr = 0.001
    best_cross_loss = float('inf')
    best_tpr = float('-inf')  # was never initialised -> UnboundLocalError after epoch 1
    best_model_state = None

    trues, preds, tprs, fprs, train_losses = {}, {}, {}, {}, {}

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for inputs, mlp_inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} - Janus Training'):
            inputs, mlp_inputs, labels = inputs.to(device).float(), mlp_inputs.to(device).float(), labels.to(device).float()
            # reshape(-1) instead of squeeze(): squeeze() turns a final batch
            # of size 1 into a 0-d tensor that no longer matches `labels`.
            probs = model(mlp_inputs, inputs).reshape(-1)
            loss = criterion(probs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
        train_losses[epoch] = train_loss / max(len(train_loader), 1)

        # Evaluate on the validation loader
        model.eval()
        cross_loss = 0.0
        cross_true = []
        cross_pred = []

        with torch.no_grad():
            for inputs, mlp_inputs, labels in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} - Cross Validation'):
                inputs, mlp_inputs, labels = inputs.to(device).float(), mlp_inputs.to(device).float(), labels.to(device).float()
                probs = model(mlp_inputs, inputs).reshape(-1)
                cross_loss += criterion(probs, labels).item()
                cross_true.append(labels.cpu().numpy())
                cross_pred.append((probs >= 0.5).int().cpu().numpy())

        cross_true = np.concatenate(cross_true).astype(int)
        cross_pred = np.concatenate(cross_pred)

        # Compute confusion matrix: [[TN, FP], [FN, TP]].
        # labels=[0, 1] forces a 2x2 result even if one class is absent,
        # so the four-way unpacking below cannot fail.
        tn, fp, fn, tp = confusion_matrix(cross_true, cross_pred, labels=[0, 1]).ravel()

        # Compute TPR and FPR
        tpr = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0

        # Print statistics
        print(f"\nEpoch {epoch+1} Cross-validation Statistics:")
        print(f"Cross-validation Loss: {cross_loss/len(val_loader):.4f}")
        print(f"TPR (Recall): {tpr:.4f}")
        print(f"FPR: {fpr:.4f}")

        # Save metrics
        trues[epoch] = cross_true
        preds[epoch] = cross_pred
        tprs[epoch] = tpr
        fprs[epoch] = fpr

        # Save the best model. state_dict() returns references to the live
        # tensors, so it must be deep-copied or later epochs overwrite it.
        if tpr > best_tpr:
            best_tpr = tpr
            best_model_state = copy.deepcopy(model.state_dict())

        if cross_loss < best_cross_loss:
            best_cross_loss = cross_loss

    epochs = list(range(num_epochs))
    plt.title(f"lr = {lr}")
    plt.plot(epochs, list(tprs.values()))
    plt.plot(epochs, list(fprs.values()))
    plt.legend(["TPR", "FPR"])
    plt.show()

    return {
        "trues": trues,
        "preds": preds,
        "tprs": tprs,
        "fprs": fprs,
        "train_losses": train_losses,
        "best_tpr": best_tpr,
        "best_cross_loss": best_cross_loss,
        "best_model_state": best_model_state,
    }
    
      

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


def evaluate_auroc(model, data_loader, device, results_df=None, model_name="Model"):
    """Score ``model`` on ``data_loader`` and append one metrics row to ``results_df``.

    Accuracy, precision, recall and F1 are computed on predictions thresholded
    at 0.5; AUROC is computed on the raw probabilities.

    Args:
        model: module called as ``model(mlp_inputs, inputs)`` and returning
            one probability per sample.
        data_loader: iterable of ``(inputs, mlp_inputs, labels)`` batches.
        device: device the batches are moved to.
        results_df: DataFrame to extend, or ``None`` / an empty frame to
            start a new one.
        model_name: value stored in the ``Model`` column.

    Returns:
        ``(all_labels, all_preds, results_df)``: two flat lists (true labels,
        predicted probabilities) and the extended DataFrame.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, mlp_inputs, labels in data_loader:
            inputs = inputs.to(device).float()
            mlp_inputs = mlp_inputs.to(device).float()
            labels = labels.to(device)

            outputs = model(mlp_inputs, inputs)

            # reshape(-1) rather than squeeze(): with a batch of size 1,
            # squeeze() gives a 0-d array that list.extend cannot iterate.
            preds = outputs.reshape(-1).cpu().numpy() # probability values for each in [0,1]

            all_preds.extend(preds)
            all_labels.extend(labels.reshape(-1).cpu().numpy())

    output = (np.array(all_preds) >= 0.5).astype(int) # thresholding to convert to 0s and 1s
    # Compute metrics
    accuracy = accuracy_score(all_labels, output)
    precision = precision_score(all_labels, output, zero_division=0)
    recall = recall_score(all_labels, output, zero_division=0)
    f1 = f1_score(all_labels, output, zero_division=0)
    try:
        auroc = roc_auc_score(all_labels, all_preds) # the only one we actually want preds from
    except ValueError:
        auroc = float('nan')  # in case only one class in y_true

    # Prepare results
    result_row = pd.DataFrame([{
        'Model': model_name,
        'Accuracy': accuracy,
        'Precision': precision,
        'Recall': recall,
        'F1': f1,
        'AUROC': auroc
    }])

    # Append to existing DataFrame. A completely empty frame (no rows, no
    # columns) is treated like None so pandas is not asked to concatenate an
    # empty entry.
    start_fresh = results_df is None or (results_df.empty and len(results_df.columns) == 0)
    if start_fresh:
        results_df = result_row
    else:
        results_df = pd.concat([results_df, result_row], ignore_index=True)

    return all_labels, all_preds, results_df

def plot_auroc(labels, preds):
    """Print the AUROC and draw the ROC curve with a chance-level diagonal.

    Args:
        labels: true binary labels.
        preds: predicted scores/probabilities for the positive class.
    """
    auroc_value = roc_auc_score(labels, preds)
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(labels, preds)
    print(auroc_value)

    plt.figure()
    plt.plot(false_pos_rate, true_pos_rate, label=f'AUROC = {auroc_value:.4f}')
    # Dashed diagonal = performance of random guessing.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('AUROC Curve')
    plt.legend(loc='lower right')
    plt.grid(True)
    plt.show()

In [None]:
# Sequence branch: 10 input channels (5 target + 5 primer one-hot channels).
janus_inp_dim = 10
janus_out_dim = 4
janus_mod_dims = [64, 128]

# Tabular branch: one input per standardised handcrafted feature.
mlp_inp_dim = X_test_scaled.shape[1]
mlp_hid_dims = [32, 64]
mlp_out_dim = 4

# Fusion head.
com_hid_dim = 32
final_out_dim = 1

results_df = pd.DataFrame()

lrs = [0.001]

# Every (Janus width, MLP width, learning rate) combination, in the same
# order as the equivalent triple-nested loop.
sweep = [
    (janus_width, mlp_width, rate)
    for janus_width in janus_mod_dims
    for mlp_width in mlp_hid_dims
    for rate in lrs
]

for janus_mod_dim, mlp_hid_dim, lr in sweep:
    model = CombinedModel(
        mlp_dims=(mlp_inp_dim, mlp_hid_dim, mlp_out_dim),
        ssm_dims=(janus_inp_dim, janus_out_dim, janus_mod_dim),
        combined_hidden=com_hid_dim,
        final_output=final_out_dim,
    ).to(device)
    train_loop(model, lr=lr)
    labels, preds, results_df = evaluate_auroc(
        model, test_loader, device, results_df,
        model_name=f"Combined, janus:{janus_mod_dim}, mlp:{mlp_hid_dim}, lr:{lr}",
    )
    plot_auroc(labels, preds)

display(results_df)