In [None]:
import torch
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from multiprocessing import Pool
import torch.backends.cudnn as cudnn
from sklearn.metrics import roc_curve, auc
# Hand unoccupied blocks held by PyTorch's CUDA caching allocator back to the driver.
torch.cuda.empty_cache()
# Prefer run-to-run reproducibility over cuDNN autotuned speed.
cudnn.deterministic = True
cudnn.benchmark = False
# NOTE(review): `devices` is not referenced in the visible cells; the training
# function takes its own `device` argument instead.
devices = torch.cuda.current_device()

In [None]:
class InterpretableTransformerLayer(nn.Module):
    """Post-norm transformer encoder layer that also returns its attention map.

    Structure: self-attention -> dropout -> residual add -> LayerNorm ->
    feed-forward -> dropout -> residual add -> LayerNorm. Unlike
    ``nn.TransformerEncoderLayer``, ``forward`` hands back the per-head
    attention weights so they can be inspected.

    Args:
        d_model: feature size of each token; must be divisible by ``nhead``.
        nhead: number of attention heads.
        dim_ff: hidden width of the position-wise feed-forward network.
        dropout: rate applied to the attention output and the feed-forward
            output before each residual add (not inside the attention itself).
        batch_first: forwarded to ``nn.MultiheadAttention``.
    """

    def __init__(self, d_model, nhead, dim_ff=128, dropout=0.0, batch_first=True):
        super().__init__()
        # Submodules are built and registered in this fixed order so that
        # seeded parameter initialisation and state_dict keys stay stable.
        attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=batch_first)
        feed_forward = nn.Sequential(nn.Linear(d_model, dim_ff), nn.ReLU(), nn.Linear(dim_ff, d_model))
        self.self_attn = attention
        self.ffn = feed_forward
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the layer to ``x``; return ``(output, attention_weights)``.

        ``output`` has the same shape as ``x``. ``attention_weights`` are kept
        per head (``average_attn_weights=False``), i.e. shaped
        ``(batch, nhead, seq, seq)`` for batched input.
        """
        context, head_weights = self.self_attn(x, x, x, need_weights=True, average_attn_weights=False)
        residual = x + self.dropout(context)
        hidden = self.norm1(residual)
        hidden = self.norm2(hidden + self.dropout(self.ffn(hidden)))
        return hidden, head_weights
    
class TransformerModel(nn.Module):
    """Stack of ``InterpretableTransformerLayer`` blocks plus a linear classifier.

    The (batch-first) input is treated as ``num_path`` tokens of ``input_size``
    features each. After the transformer stack every token is reduced to its
    maximum feature value (max over dim 2), and the resulting
    ``num_path``-vector is mapped to ``num_classes`` logits.

    Args:
        input_size: per-token feature size; must be divisible by ``nhead``.
        num_classes: number of output logits.
        nhead: attention heads per layer.
        dropout_prob: dropout rate passed to every layer.
        num_path: number of tokens; also the classifier's input width.
        transformer_layers: number of stacked layers.
        mlp_hidden: feed-forward hidden width inside each layer.
        batch_first: forwarded to each layer's attention module.
    """

    def __init__(self, input_size, num_classes, nhead=3, dropout_prob=0.5, num_path=230, transformer_layers=8, mlp_hidden=24, batch_first=True):
        super().__init__()
        encoder_blocks = []
        for _ in range(transformer_layers):
            encoder_blocks.append(
                InterpretableTransformerLayer(
                    d_model=input_size,
                    nhead=nhead,
                    dim_ff=mlp_hidden,
                    dropout=dropout_prob,
                    batch_first=batch_first,
                )
            )
        self.transformer = nn.ModuleList(encoder_blocks)
        # Wrapped in a Sequential so checkpoint keys stay "fc.0.weight"/"fc.0.bias".
        self.fc = nn.Sequential(nn.Linear(num_path, num_classes))
        # NOTE(review): dropout1 is registered but never applied in forward().
        self.dropout1 = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Return ``(logits, attention_weights)``.

        ``logits`` has shape ``(batch, num_classes)``; ``attention_weights`` is
        a list with one per-head attention tensor per transformer layer.
        """
        layer_attention = []
        hidden = x
        for block in self.transformer:
            hidden, weights = block(hidden)
            layer_attention.append(weights)
        token_max, _ = hidden.max(dim=2)
        return self.fc(token_max), layer_attention

def _to_pathway_tensor(values, num_path, device):
    # Convert a DataFrame/array/tensor of features to float32 (n, num_path, features_per_path) on `device`.
    if not torch.is_tensor(values):
        values = np.asarray(values)
    tensor = torch.as_tensor(values, dtype=torch.float32).detach().to(device)
    return tensor.reshape(tensor.shape[0], num_path, -1)


def _flat_labels(labels):
    # Flatten a Series / 1-column DataFrame / array of class labels to a 1-D int64 array.
    return np.asarray(labels).reshape(-1).astype(np.int64)


def _positive_class_auc(model, features, y_true):
    # ROC AUC of the model's class-1 softmax probability on `features` (caller sets eval mode).
    with torch.no_grad():
        logits, _ = model(features)
        scores = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
    fpr, tpr, _ = roc_curve(y_true, scores)
    return auc(fpr, tpr)


def train_end_to_end_predict_model_nonEP(pathways_activity, patient_labels, path_act_val, patients_val, num_heads=5, seed=666, global_epo=10000,
                                           save_best_model_path=None, batch_size=8,dropout=0.5,num_path = 230,device='cuda:0'):
    """Train a TransformerModel classifier and track train/validation AUC per epoch.

    Each training row is reshaped to ``(num_path, input_size)`` with
    ``input_size = n_columns // num_path``. After every epoch the model is
    scored on the full training set and on the validation set (ROC AUC of the
    class-1 softmax probability). The epoch with the highest validation AUC is
    remembered and, if it also clears the save thresholds, written to disk.

    Args:
        pathways_activity: training features, one row per patient (DataFrame/array).
        patient_labels: integer class labels (0/1) row-aligned with ``pathways_activity``.
        path_act_val: validation features; anything reshapeable to
            ``(n_val, num_path, input_size)``.
        patients_val: validation labels (Series / 1-column DataFrame / array).
        num_heads: attention heads per layer; must divide ``input_size``
            (NOTE: the default 5 does not divide 6, so callers must override it).
        seed: only used to tag checkpoint file names.
        global_epo: number of training epochs.
        save_best_model_path: directory for checkpoints; ``None`` disables saving.
        batch_size, dropout, num_path: loader / model hyper-parameters.
        device: torch device used for the model and all tensors.

    Returns:
        dict with ``best_epoch`` (None if validation AUC never exceeded 0),
        ``train_auc`` / ``val_auc`` at that epoch, ``best_state_dict`` (a copy
        of the weights at that epoch, or None) and ``history`` (DataFrame with
        per-epoch ``train_auc`` and ``val_auc`` columns).
    """
    import copy
    import os

    # A new best epoch is only written to disk if it also clears these.
    min_train_auc_to_save = 0.9
    min_val_auc_to_save = 0.75

    # Loop-invariant tensors, built once instead of every epoch.
    x_train = _to_pathway_tensor(pathways_activity, num_path, device)
    y_train = _flat_labels(patient_labels)
    x_val = _to_pathway_tensor(path_act_val, num_path, device)
    y_val = _flat_labels(patients_val)

    input_size = x_train.shape[2]
    model_fc = TransformerModel(input_size, num_classes=2, nhead=num_heads, dropout_prob=dropout,
                                num_path=num_path, transformer_layers=2, mlp_hidden=24).to(device)
    optimizer_fc = optim.Adam(model_fc.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-7)
    # Up-weights class 1. 88/17 is presumably a class-0/class-1 count ratio from
    # the cohort this was tuned on -- TODO confirm or derive it from the labels.
    class_weight = torch.tensor([1.0, 88 / 17], device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weight)

    train_dataset = torch.utils.data.TensorDataset(
        x_train, torch.as_tensor(y_train, dtype=torch.long, device=device))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    train_history, val_history = [], []
    best_epoch, best_train_auc, best_val_auc, best_state = None, 0, 0.0, None
    for epo in range(global_epo):
        model_fc.train()
        for x_batch, y_batch in train_loader:
            optimizer_fc.zero_grad()
            logits, _ = model_fc(x_batch)
            # Sum of per-parameter (unsquared) L2 norms over every parameter,
            # including biases and LayerNorm weights; not an L1 penalty.
            weight_penalty = sum(torch.norm(param, p=2) for param in model_fc.parameters())
            total_loss = criterion(logits, y_batch) + 0.1 * weight_penalty
            total_loss.backward()
            optimizer_fc.step()

        model_fc.eval()
        train_auc = _positive_class_auc(model_fc, x_train, y_train)
        val_auc = _positive_class_auc(model_fc, x_val, y_val)
        train_history.append(train_auc)
        val_history.append(val_auc)

        if val_auc > best_val_auc:
            best_epoch, best_train_auc, best_val_auc = epo, train_auc, val_auc
            # Snapshot the weights: a plain reference to model_fc would keep
            # changing as training continues.
            best_state = copy.deepcopy(model_fc.state_dict())
            if (save_best_model_path is not None
                    and train_auc > min_train_auc_to_save
                    and val_auc > min_val_auc_to_save):
                os.makedirs(save_best_model_path, exist_ok=True)
                fold_model_path = os.path.join(save_best_model_path, f"best_model_epo_{epo + 1}at_{seed}.pt")
                torch.save(model_fc, fold_model_path)
                print(f"Saved best model at epoch {epo} to {fold_model_path}")

    print(f"Best epo: {best_epoch}")
    print(f"AUC_train: {best_train_auc}")
    print(f"AUC_val: {best_val_auc}")
    history = pd.DataFrame({"train_auc": train_history, "val_auc": val_history})
    return {
        "best_epoch": best_epoch,
        "train_auc": best_train_auc,
        "val_auc": best_val_auc,
        "best_state_dict": best_state,
        "history": history,
    }


In [None]:
# Training cohort (Liu) and validation cohort (Snyder).
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("../Data/Liu/pathway_activity_144.pkl", "rb") as f:
    pathway_train = pickle.load(f)
with open("../Data/Liu/patient_response_144.pkl", "rb") as f:
    patient_response_train = pickle.load(f)
# One flattened row per patient; reshape (unlike view) also works on
# non-contiguous tensors.
pathway_train = pd.DataFrame(pathway_train.reshape(pathway_train.size(0), -1).detach().cpu().numpy())
pathway_train.index = patient_response_train.index

with open("../Data/Snyder/pathway_activity_60.pkl", "rb") as f:
    pathway_val = pickle.load(f)
with open("../Data/Snyder/patient_response_60.pkl", "rb") as f:
    patient_response_val = pickle.load(f)

# Per-run seeds are drawn from a fixed master seed so the whole sweep can be
# reproduced; the unseeded global numpy RNG gave different seeds every run.
MASTER_SEED = 666
seeds = np.random.default_rng(MASTER_SEED).integers(1000, 10000, size=10)
results = {}
for seed in map(int, seeds):
    print(seed)
    torch.manual_seed(seed)
    # Trailing slash so checkpoints land inside ../Model/ rather than being
    # named "../Modelbest_model_...".
    model = train_end_to_end_predict_model_nonEP(pathways_activity=pathway_train, patient_labels=patient_response_train,
                                                 path_act_val=pathway_val, patients_val=patient_response_val,
                                                 batch_size=8, num_heads=1, global_epo=300, num_path=230,
                                                 save_best_model_path='../Model/', dropout=0.3, seed=seed)
    results[seed] = model