In [4]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import copy
import math
from scipy.interpolate import interp1d

# Tensor type handed to load_data(), which casts the loaded arrays with `.type(dtype)`.
dtype = torch.FloatTensor

# Data loading functions
def sort_data(path):
    ''' Read a survival CSV and order its rows by decreasing survival time.

    The file must provide the columns SAMPLE_ID, OS_MONTHS, OS_EVENT and AGE;
    every other column is kept as a feature.

    Input:
        path: anything accepted by pandas.read_csv (file path or buffer).
    Output:
        x: 2-D array of the feature columns.
        ytime: (n, 1) array of OS_MONTHS.
        yevent: (n, 1) array of OS_EVENT.
        age: (n, 1) array of AGE.
        All four share the same row order (largest OS_MONTHS first).
    '''
    frame = pd.read_csv(path).sort_values("OS_MONTHS", ascending=False)

    # Identifier, outcome and age columns are not model features.
    non_feature_columns = ["SAMPLE_ID", "OS_MONTHS", "OS_EVENT", "AGE"]
    x = frame.drop(columns=non_feature_columns).values

    # Double brackets keep each selection two-dimensional (one column).
    ytime = frame[["OS_MONTHS"]].values
    yevent = frame[["OS_EVENT"]].values
    age = frame[["AGE"]].values

    return x, ytime, yevent, age

def load_data(path, dtype):
    ''' Load the sorted survival data as PyTorch tensors.

    Input:
        path: CSV location forwarded to sort_data().
        dtype: tensor type given to Tensor.type() for every returned tensor.
    Output:
        X, YTIME, YEVENT, AGE: tensors built from the arrays returned by
        sort_data(), moved to the GPU when CUDA is available.
    '''
    # Convert the four arrays in the order sort_data() returns them.
    X, YTIME, YEVENT, AGE = (
        torch.from_numpy(array).type(dtype) for array in sort_data(path)
    )

    if not torch.cuda.is_available():
        return X, YTIME, YEVENT, AGE

    return X.cuda(), YTIME.cuda(), YEVENT.cuda(), AGE.cuda()

def load_pathway(path, dtype=torch.float32):
    ''' Read a pathway bi-adjacency matrix from CSV into a PyTorch tensor.

    Input:
        path: CSV location (or buffer); its first column is used as the row
            index and is therefore not part of the returned matrix.
        dtype: torch dtype of the returned tensor.
    Output:
        PATHWAY_MASK: tensor holding the CSV values, on the GPU when CUDA is
        available.
    '''
    mask_frame = pd.read_csv(path, index_col=0)
    PATHWAY_MASK = torch.tensor(mask_frame.values, dtype=dtype)

    return PATHWAY_MASK.cuda() if torch.cuda.is_available() else PATHWAY_MASK

# Model Class
class Cox_PASNet(nn.Module):
    ''' Feed-forward network with a masked first layer and a linear Cox output.

    Layers:
        sc1: In_Nodes -> Pathway_Nodes, weights multiplied by the pathway mask.
        sc2: Pathway_Nodes -> Hidden_Nodes.
        sc3: Hidden_Nodes -> Out_Nodes (no bias).
        sc4: (Out_Nodes + 1) -> 1 (no bias); the extra input is the x_2 column.
    do_m1 / do_m2 are multiplicative masks applied after sc1 / sc2 in training
    mode only; they start as all ones and can be reassigned by the caller.
    '''

    def __init__(self, In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, Pathway_Mask):
        ''' Build the layers and the initial (all-ones) dropout masks.

        Pathway_Mask is stored as given and must be multipliable with
        sc1.weight, whose shape is (Pathway_Nodes, In_Nodes).
        '''
        super().__init__()
        self.tanh = nn.Tanh()
        self.pathway_mask = Pathway_Mask

        # Layers are created in a fixed order so that seeded initialisation
        # stays reproducible.
        self.sc1 = nn.Linear(In_Nodes, Pathway_Nodes)
        self.sc2 = nn.Linear(Pathway_Nodes, Hidden_Nodes)
        self.sc3 = nn.Linear(Hidden_Nodes, Out_Nodes, bias=False)
        self.sc4 = nn.Linear(Out_Nodes + 1, 1, bias=False)
        # The output layer starts with very small weights.
        self.sc4.weight.data.uniform_(-0.001, 0.001)

        on_gpu = torch.cuda.is_available()
        pathway_ones = torch.ones(Pathway_Nodes)
        hidden_ones = torch.ones(Hidden_Nodes)
        self.do_m1 = pathway_ones.cuda() if on_gpu else pathway_ones
        self.do_m2 = hidden_ones.cuda() if on_gpu else hidden_ones

    def forward(self, x_1, x_2):
        ''' Return the linear predictor for each sample.

        Input:
            x_1: (n, In_Nodes) input features.
            x_2: (n, 1) additional column concatenated before the last layer.
        Output:
            lin_pred: (n, 1) tensor.
        '''
        # Re-apply the mask so masked-out connections stay exactly zero.
        self.sc1.weight.data = self.sc1.weight.data * self.pathway_mask

        hidden = x_1
        for layer, drop_mask in ((self.sc1, self.do_m1), (self.sc2, self.do_m2)):
            hidden = self.tanh(layer(hidden))
            if self.training:
                hidden = hidden.mul(drop_mask)

        hidden = self.tanh(self.sc3(hidden))
        return self.sc4(torch.cat((hidden, x_2), 1))

# Other functions (helpers)
def dropout_mask(n_node, drop_p):
    ''' Draw a random 0/1 mask used to drop nodes of a layer.

    Input:
        n_node: number of entries in the mask.
        drop_p: probability that an entry is 0 (dropped).
    Output:
        mask: tensor of n_node zeros/ones sampled with NumPy's global random
        generator, on the GPU when CUDA is available.
    '''
    draws = np.random.binomial(1, 1.0 - drop_p, size=n_node)
    mask = torch.Tensor(draws)
    return mask.cuda() if torch.cuda.is_available() else mask

def s_mask(sparse_level, param_matrix, nonzero_param_1D, dtype):
    ''' Construct a binary matrix w.r.t. a sparsity level of weights between two consecutive layers.

    Input:
        sparse_level: percentage (0-100) of the non-zero weights to prune.
        param_matrix: weight matrix the mask is built for.
        nonzero_param_1D: 1-D tensor of the non-zero weights of that layer.
        dtype: tensor type given to Tensor.type() for the returned mask.
    Output:
        param_mask: tensor shaped like param_matrix holding 1 where
        |weight| is strictly greater than the smallest of the top-k
        magnitudes and 0 elsewhere. When there is nothing to rank (no non-zero
        weights, or top-k is 0) an all-ones mask is returned, i.e. no
        sparsity is applied. Every returned mask has the requested dtype and
        is on the GPU when CUDA is available.
    '''
    def _finalize(mask):
        # Single exit path so fallback masks get the same dtype/device
        # treatment as the regular mask.
        mask = mask.type(dtype)
        if torch.cuda.is_available():
            mask = mask.cuda()
        return mask

    non_neg_param_1D = torch.abs(nonzero_param_1D)
    num_param = non_neg_param_1D.size(0)
    top_k = math.ceil(num_param * (100 - sparse_level) * 0.01)

    if num_param == 0 or top_k == 0:
        # Nothing to rank: keep every weight.
        return _finalize(torch.ones_like(param_matrix))

    top_magnitudes, _ = torch.topk(non_neg_param_1D, top_k)

    # Strict comparison: the k-th largest magnitude itself is masked out.
    param_mask = torch.abs(param_matrix) > top_magnitudes.min()
    return _finalize(param_mask)

def R_set(x):
    ''' Build the lower-triangular indicator matrix of risk sets.

    Input:
        x: tensor whose first dimension is the number of samples; only its
            length is used.
    Output:
        indicator_matrix: (n, n) tensor with ones on and below the diagonal
        and zeros above it.
    '''
    n = x.size(0)
    return torch.ones(n, n).tril()

def neg_par_log_likelihood(pred, ytime, yevent):
    ''' Compute the average negative Cox partial log-likelihood.

    Input:
        pred: (n, 1) linear predictors.
        ytime: (n, 1) survival times; only the number of rows is used, so the
            rows are expected to be pre-sorted by descending time (as
            sort_data() does) for the triangular risk sets to be meaningful.
        yevent: (n, 1) event indicators (1 = observed, 0 = censored).
    Output:
        cost: 1-D tensor with the negative partial log-likelihood divided by
        the number of observed events.
    '''
    n_events = yevent.sum(0)

    risk_sets = R_set(ytime)
    if torch.cuda.is_available():
        risk_sets = risk_sets.cuda()

    # log of the summed hazards over each sample's risk set
    log_risk = torch.log(risk_sets.mm(torch.exp(pred)))
    per_sample = pred - log_risk

    # Only samples with an observed event contribute to the likelihood.
    observed_total = per_sample.t().mm(yevent)
    return (-(observed_total / n_events)).reshape((-1,))
def c_index(pred, ytime, yevent):
    '''Calculate concordance index to evaluate models.
    Input:
        pred: linear predictors from trained model, one value per sample.
        ytime: true survival time from load_data() (rows sorted by
            descending time; only the number of rows is used).
        yevent: true censoring status from load_data(), one value per sample
            (0 = censored).
    Output:
        concordance_index: c-index (between 0 and 1); NaN when there is no
        comparable pair.
    '''
    device = pred.device

    # Strictly lower-triangular matrix: entry (j, i) is 1 for i < j, i.e.
    # sample j comes later in the descending-time ordering than sample i.
    ytime_indicator = R_set(ytime).to(device)
    ytime_matrix = ytime_indicator - torch.diag(torch.diag(ytime_indicator))

    # T_j must be uncensored: drop every row that belongs to a censored sample.
    censored_rows = (yevent.reshape(-1) == 0).to(device)
    ytime_matrix[censored_rows, :] = 0

    # pred_matrix[j, i] = 1 if pred_i < pred_j; 0.5 if pred_i = pred_j.
    # Broadcasting replaces the per-element Python double loop.
    scores = pred.detach().reshape(-1)
    row_scores = scores.unsqueeze(1)  # pred_j down the rows
    col_scores = scores.unsqueeze(0)  # pred_i across the columns
    pred_matrix = (
        (col_scores < row_scores).type_as(ytime_matrix)
        + 0.5 * (col_scores == row_scores).type_as(ytime_matrix)
    )

    # numerator: concordant (or tied) comparable pairs
    concord = torch.sum(pred_matrix.mul(ytime_matrix))
    # denominator: all comparable pairs
    epsilon = torch.sum(ytime_matrix)
    # c-index = numerator / denominator
    concordance_index = torch.div(concord, epsilon)

    # if gpu is being used
    if torch.cuda.is_available():
        concordance_index = concordance_index.cuda()

    return concordance_index

# Train Cox-PASNet with early stopping; returns only the model, not a tuple.
def trainCoxPASNet(train_x, train_age, train_ytime, train_yevent, 
                   eval_x, eval_age, eval_ytime, eval_yevent, pathway_mask,
                   In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes,
                   Learning_Rate, L2, Num_Epochs, Dropout_Rate, dtype=torch.FloatTensor,
                   patience=100):
    ''' Train a Cox_PASNet with full-batch Adam and early stopping.

    Input:
        train_x, train_age, train_ytime, train_yevent: training tensors.
        eval_x, eval_age, eval_ytime, eval_yevent: validation tensors used
            for early stopping and best-model selection.
        pathway_mask: mask applied to the first layer's weights.
        In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes: layer sizes.
        Learning_Rate: Adam learning rate.
        L2: Adam weight decay.
        Num_Epochs: training runs for epochs 0..Num_Epochs inclusive.
        Dropout_Rate: pair of drop probabilities for the first and second
            hidden layers.
        dtype: unused; kept for backward compatibility with existing callers.
        patience: number of consecutive validation checks (one check every
            10 epochs) without improvement before training stops.
    Output:
        net: the model restored to the weights with the lowest validation
        loss seen, always returned in evaluation mode.
    '''
    eval_interval = 10  # epochs between two validation checks

    net = Cox_PASNet(In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, pathway_mask)

    if torch.cuda.is_available():
        net = net.cuda()

    opt = optim.Adam(net.parameters(), lr=Learning_Rate, weight_decay=L2)

    best_eval_loss = float('inf')
    patience_counter = 0
    best_model = None

    for epoch in range(Num_Epochs + 1):
        net.train()
        opt.zero_grad()

        # Fresh dropout masks for every training step.
        net.do_m1 = dropout_mask(Pathway_Nodes, Dropout_Rate[0])
        net.do_m2 = dropout_mask(Hidden_Nodes, Dropout_Rate[1])

        pred = net(train_x, train_age)
        loss = neg_par_log_likelihood(pred, train_ytime, train_yevent)
        loss.backward()
        opt.step()

        # The optimizer step may revive masked connections; zero them again.
        net.sc1.weight.data = net.sc1.weight.data.mul(net.pathway_mask)

        if epoch % eval_interval != 0:
            continue

        net.eval()
        with torch.no_grad():
            eval_pred = net(eval_x, eval_age)
            eval_loss = neg_par_log_likelihood(eval_pred, eval_ytime, eval_yevent).view(1,)

        # Early stopping check
        if eval_loss < best_eval_loss:
            best_eval_loss = eval_loss
            patience_counter = 0
            best_model = copy.deepcopy(net.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= patience:
                # patience counts validation checks, not epochs; report both.
                print(f"Early stopping at epoch {epoch}. Validation loss did not improve for "
                      f"{patience} consecutive evaluations ({patience * eval_interval} epochs).")
                break

    if best_model is not None:
        net.load_state_dict(best_model)

    # The loop can end right after a training step (when Num_Epochs is not a
    # multiple of eval_interval); never hand back a model with dropout active.
    net.eval()

    return net  # Only return the model, not a tuple

       
             




dtype = torch.FloatTensor
# Net settings
In_Nodes = 5567  # number of genes
Pathway_Nodes = 860  # number of pathways
Hidden_Nodes = 100  # number of hidden nodes
Out_Nodes = 30  # number of hidden nodes in the last hidden layer
# Hyperparameter grid and training length
Initial_Learning_Rate = [0.00075]
L2_Lambda = [0.1]
num_epochs = 3000  # for grid search
Num_EPOCHS = 3000  # for training
# sub-network setup: drop probabilities of the first and second hidden layers
Dropout_Rate = [0.7, 0.5]
# Load data and pathway mask
pathway_mask = load_pathway("C:/Users/Admin/Downloads/pathway_mask.csv", dtype=torch.float32)

x_train, ytime_train, yevent_train, age_train = load_data("C:/Users/Admin/Downloads/train.csv", dtype)
x_valid, ytime_valid, yevent_valid, age_valid = load_data("C:/Users/Admin/Downloads/validation.csv", dtype)
x_test, ytime_test, yevent_test, age_test = load_data("C:/Users/Admin/Downloads/test.csv", dtype)

# Perform hyperparameter search; track the setting with the lowest validation loss.
opt_l2_loss = 0
opt_lr_loss = 0
opt_loss = torch.Tensor([float("Inf")])
# if gpu is being used
if torch.cuda.is_available():
    opt_loss = opt_loss.cuda()
opt_c_index_va = 0

for l2 in L2_Lambda:
    for lr in Initial_Learning_Rate:
        # Train with current hyperparameters
        trained_model = trainCoxPASNet(
            x_train, age_train, ytime_train, yevent_train,
            x_valid, age_valid, ytime_valid, yevent_valid, pathway_mask,
            In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes,
            lr, l2, num_epochs, Dropout_Rate, patience=500
        )

        # Validation metrics: evaluation mode (no dropout masks) and no
        # autograd graph, since nothing is back-propagated here.
        trained_model.eval()
        with torch.no_grad():
            val_pred = trained_model(x_valid, age_valid)
            loss_valid = neg_par_log_likelihood(val_pred, ytime_valid, yevent_valid)
            c_index_va = c_index(val_pred, ytime_valid, yevent_valid)

        # Update optimal parameters if the current validation loss is the best
        if loss_valid < opt_loss:
            opt_l2_loss = l2
            opt_lr_loss = lr
            opt_loss = loss_valid
            opt_c_index_va = c_index_va
        print(f"L2: {l2}, LR: {lr}, Validation Loss: {loss_valid.item()}, C-index VA: {c_index_va.item()}")

# Train Cox-PASNet with optimal hyperparameters on the train set, then evaluate on test set.
# Num_EPOCHS is the epoch budget reserved for this final training run.
trained_model = trainCoxPASNet(
    x_train, age_train, ytime_train, yevent_train,
    x_valid, age_valid, ytime_valid, yevent_valid, pathway_mask,
    In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes,
    opt_lr_loss, opt_l2_loss, Num_EPOCHS, Dropout_Rate, patience=100
)

# Evaluate on test set (evaluation mode, no gradient tracking)
trained_model.eval()
with torch.no_grad():
    test_pred = trained_model(x_test, age_test)
    loss_test = neg_par_log_likelihood(test_pred, ytime_test, yevent_test)
    c_index_te = c_index(test_pred, ytime_test, yevent_test)

print(f"Optimal L2: {opt_l2_loss}, Optimal LR: {opt_lr_loss}")
print(f"Test C-index: {c_index_te.item()}")


L2: 0.1, LR: 0.00075, Validation Loss: 3.4729888439178467, C-index VA: 0.6210272908210754
Early stopping at epoch 1040. Validation loss did not improve for 100 epochs.
Optimal L2: 0.1, Optimal LR: 0.00075
Test C-index: 0.6726282835006714
