In [None]:
import os
import torch
import pickle
import random
import cupy as cp
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam
import torch.nn.functional as F
import cupyx.scipy.sparse as sp
import torch.backends.cudnn as cudnn
from cupyx.scipy.sparse import coo_matrix as cp_coo_matrix
from torch.utils.data import Dataset, DataLoader, TensorDataset
from cupyx.scipy.sparse import coo_matrix, dia_matrix, csr_matrix, triu, diags

# Global CUDA / cuDNN setup, executed once when this cell runs.
# Ask PyTorch's caching allocator to release unused cached GPU memory.
torch.cuda.empty_cache()
# Disable the cuDNN auto-tuner and request deterministic cuDNN kernels.
cudnn.benchmark = False  
cudnn.deterministic = True 
# Index of the currently selected CUDA device; the definitions below use it
# as the default value of their `device` argument.
devices=torch.cuda.current_device()

In [None]:
def sparse_collate_fn(batch):
    """Collate ``(features, adj_label, patient)`` samples into one batch.

    Args:
        batch: iterable of 3-tuples; the first two entries of every sample
            are tensors that ``torch.stack`` can combine.

    Returns:
        ``(features, adj_labels, patients)`` where the first two items are
        stacked along a new leading dimension and ``patients`` is the tuple
        of third entries, left untouched.
    """
    feature_items, label_items, patient_ids = zip(*batch)
    return torch.stack(feature_items), torch.stack(label_items), patient_ids

def sparse_to_tuple(sparse_mx):
    """Break a sparse matrix into ``(coords, values, shape)``.

    Anything that is not already a ``coo_matrix`` is converted with
    ``.tocoo()`` first.

    Returns:
        coords: array with one ``(row, col)`` pair per stored entry.
        values: the stored entries (the matrix's ``data`` attribute).
        shape:  the matrix shape.
    """
    coo = sparse_mx if isinstance(sparse_mx, coo_matrix) else sparse_mx.tocoo()
    index_pairs = cp.vstack((coo.row, coo.col)).T
    return index_pairs, coo.data, coo.shape

def preprocess_graph(adj):
    """Add self-loops to ``adj`` and scale it by the inverse square-root degree.

    Computes ``((A + I) @ D^-1/2)^T @ D^-1/2`` where ``D`` holds the row sums
    of ``A + I``, and returns the result as the ``(coords, values, shape)``
    triple produced by ``sparse_to_tuple``.
    """
    adj = coo_matrix(adj)
    self_loops = dia_matrix((cp.ones(adj.shape[0]), [0]), shape=adj.shape)
    adj_with_loops = adj + self_loops

    degrees = cp.array(adj_with_loops.sum(axis=1)).flatten()
    inv_sqrt_degree = diags(cp.power(degrees, -0.5))

    right_scaled = adj_with_loops.dot(inv_sqrt_degree)
    normalized = right_scaled.transpose().dot(inv_sqrt_degree)
    return sparse_to_tuple(normalized.tocoo())

def mask_test_edges(adj, ratio_val):
    """Split the edges of ``adj`` into a training graph and held-out edges.

    Args:
        adj: adjacency matrix as a dense ``torch.Tensor``, NumPy array,
            CuPy array, or a ``cupyx.scipy.sparse`` matrix.
        ratio_val: fraction of the upper-triangular edges to hold out.

    Returns:
        adj_train: symmetric CSR matrix built from the kept edges, each
            stored with weight 1.
        train_edges: ``(row, col)`` pairs of the kept upper-triangular edges.
        val_edges: ``(row, col)`` pairs of the held-out edges.
    """
    # Normalise every dense input to a cupyx CSR matrix. `csr_matrix` is the
    # cupyx class imported at the top of the file; the former
    # `cp.sparse.csr_matrix` spelling does not exist in current CuPy releases.
    if isinstance(adj, torch.Tensor):
        adj = cp.asarray(adj.detach().cpu().numpy())
    elif isinstance(adj, np.ndarray):
        adj = cp.asarray(adj)
    if isinstance(adj, cp.ndarray):
        adj = csr_matrix(adj)

    # Remove self-loops by subtracting the diagonal, then drop explicit zeros.
    adj = adj - dia_matrix((adj.diagonal()[cp.newaxis, :], [0]), shape=adj.shape)
    adj.eliminate_zeros()
    # Check the diagonal directly rather than materialising a dense matrix.
    assert adj.diagonal().sum() == 0

    # Only the upper triangle is sampled so every undirected edge appears once.
    edges = sparse_to_tuple(triu(adj))[0]
    num_edges = edges.shape[0]

    num_val = int(np.floor(num_edges * ratio_val))
    all_edge_idx = cp.arange(num_edges)
    cp.random.shuffle(all_edge_idx)
    val_edge_idx = all_edge_idx[:num_val]
    val_edges = edges[val_edge_idx]

    keep = cp.ones(num_edges, dtype=bool)
    keep[val_edge_idx] = False
    train_edges = edges[keep]

    data = cp.ones(train_edges.shape[0])
    adj_train = csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape)
    # Mirror the upper triangle to obtain a symmetric adjacency.
    adj_train = adj_train + adj_train.T

    return adj_train, train_edges, val_edges

class MultiHeadAttentionModule(nn.Module):
    """Multi-head attention followed by a residual LayerNorm and a linear layer.

    Args:
        embed_dim: feature size of the inputs and of the output.
        num_heads: number of attention heads.
        dropout: dropout probability used both inside the attention layer and
            on the attention output before the residual connection.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.5):
        super(MultiHeadAttentionModule, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout,
                                                    batch_first=True)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _as_dense(tensor):
        """Return a dense version of ``tensor``; dense inputs pass through unchanged."""
        # `torch.sparse.Tensor` is only `torch.Tensor` re-exported by the
        # `torch.sparse` module, so an isinstance check against it cannot tell
        # sparse from dense. The layout attribute can.
        if isinstance(tensor, torch.Tensor) and tensor.layout != torch.strided:
            return tensor.to_dense()
        return tensor

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value``.

        Sparse inputs are densified first. ``mask`` is forwarded as
        ``attn_mask``.

        Returns:
            ``(output, attn_weights)`` where ``output`` is
            ``fc(layer_norm(query + dropout(attention)))`` and
            ``attn_weights`` are the weights returned by the attention layer.
        """
        query = self._as_dense(query)
        key = self._as_dense(key)
        value = self._as_dense(value)

        attn_output, attn_weights = self.multihead_attn(query, key, value, attn_mask=mask)
        output = self.layer_norm(query + self.dropout(attn_output))
        output = self.fc(output)
        return output, attn_weights

class GraphConvSparse(nn.Module):
    """Graph convolution layer computing ``activation(adj @ (x @ weight))``.

    Args:
        input_dim: number of input features per node.
        output_dim: number of output features per node.
        adj: adjacency tensor (sparse or dense) multiplied from the left.
        activation: callable applied to the propagated features.
        device: device on which the weight is created.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, input_dim, output_dim, adj, activation=F.relu,device=devices, **kwargs):
        super(GraphConvSparse, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Re-wrap after the device move: `Parameter.to(...)` returns a plain
        # tensor whenever it has to copy, and a plain tensor would not be
        # registered with the module (so the optimiser would never see it).
        initial = glorot_init(input_dim, output_dim)
        self.weight = nn.Parameter(initial.detach().to(device=device, dtype=torch.float32))
        self.adj = adj
        self.activation = activation
        self.device=device

    def forward(self, x):
        """Propagate node features ``x`` through the graph.

        A 3-D input has its leading dimension squeezed before the
        multiplication.
        """
        if x.dim() == 3:
            x = x.squeeze(0)
        # Follow the weight's actual device so the layer keeps working after
        # the enclosing module has been moved with `.to(...)`.
        target = self.weight.device
        x = x.to(target, dtype=torch.float32)

        # The adjacency is a constant of the layer: detach it, but do not
        # clone it or re-assign the weight on every call.
        adj = self.adj.detach().to(target, dtype=torch.float32)
        if not adj.is_sparse:
            adj = adj.to_sparse()

        support = torch.mm(x, self.weight)
        return self.activation(torch.sparse.mm(adj, support))

def dot_product_decode(Z):
    """Decode embeddings into edge probabilities: ``sigmoid(Z @ Z^T)``."""
    pairwise_scores = Z.mm(Z.t())
    return pairwise_scores.sigmoid()

def glorot_init(input_dim, output_dim):
    """Create a Glorot/Xavier-uniform initialised weight parameter.

    Values are drawn with CuPy's random generator from ``U(-r, r)`` with
    ``r = sqrt(6 / (input_dim + output_dim))``.

    Returns:
        ``nn.Parameter`` of shape ``(input_dim, output_dim)``, dtype float32,
        located on the module-level default device ``devices``.
    """
    # Keep the bound a plain Python float: handing CUDA torch tensors to
    # cupy.random.uniform relies on implicit cross-library conversion.
    init_range = float(np.sqrt(6.0 / (input_dim + output_dim)))
    samples = cp.random.uniform(-init_range, init_range, size=(input_dim, output_dim))
    # The samples are copied to host memory and then placed on `devices`.
    samples_host = cp.asnumpy(samples)
    return nn.Parameter(torch.tensor(samples_host, device=devices, dtype=torch.float32), requires_grad=True)

class Identity(nn.Module):
    """Pass-through module, usable wherever an activation callable is expected."""

    def forward(self, x):
        """Hand ``x`` back untouched (the very same object)."""
        return x

class EarlyStopping:
    """Track a validation loss and signal when it has stopped improving.

    Args:
        save_path: directory that receives ``best_network.pth`` whenever the
            loss improves. It is created on demand; pass ``None`` to track
            the loss without writing checkpoints.
        patience: number of consecutive non-improving calls after which
            ``early_stop`` becomes ``True``.
        verbose: print a message every time a checkpoint is taken.
        delta: margin added to the best score; a new score below
            ``best_score + delta`` counts as "no improvement".
    """

    def __init__(self, save_path, patience=10, verbose=False, delta=0):

        self.save_path = save_path
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Plain Python infinity: no array library is needed for a constant.
        self.val_loss_min = float("inf")
        self.delta = delta

    def __call__(self, val_loss, model):
        """Record ``val_loss``; checkpoint ``model`` on improvement."""
        # Scores are negated losses, so larger is better.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Store ``model.state_dict()`` and remember ``val_loss`` as the minimum."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        if self.save_path is not None:
            # Create the target directory so the first checkpoint cannot fail
            # merely because the folder does not exist yet.
            os.makedirs(self.save_path, exist_ok=True)
            path = os.path.join(self.save_path, 'best_network.pth')
            torch.save(model.state_dict(), path)
        self.val_loss_min = val_loss

class Autoencoder(nn.Module):
    """Symmetric fully connected autoencoder.

    The encoder narrows ``input_dim -> 128 -> 32 -> 8 -> latent_dim`` and the
    decoder mirrors it. Hidden layers use ``LeakyReLU(0.05)``; the last layer
    of each half is followed by ``Tanh``.
    """

    _HIDDEN_WIDTHS = (128, 32, 8)
    _NEGATIVE_SLOPE = 0.05

    def __init__(self, input_dim, latent_dim=1):
        super().__init__()
        narrowing = (input_dim, *self._HIDDEN_WIDTHS, latent_dim)
        # Encoder first, decoder second: layers are created in a fixed order.
        self.encoder = self._build_stack(narrowing)
        self.decoder = self._build_stack(narrowing[::-1])

    @classmethod
    def _build_stack(cls, widths):
        """Chain ``Linear`` layers through ``widths``, ending in ``Tanh``."""
        last = len(widths) - 2
        layers = []
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if position == last:
                layers.append(nn.Tanh())
            else:
                layers.append(nn.LeakyReLU(negative_slope=cls._NEGATIVE_SLOPE))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(encoded, decoded)`` for the input ``x``."""
        latent = self.encoder(x)
        return latent, self.decoder(latent)

class GAEWithAttention(nn.Module):
    """Graph autoencoder: self-attention over node features, then two GCN layers.

    Args:
        adj: adjacency tensor shared by both graph-convolution layers.
        input_dim: number of input features per node.
        hidden1_dim: output width of the first GCN layer.
        hidden2_dim: output width of the second GCN layer (the embedding).
        num_heads: number of attention heads.
        device: device used for the adjacency, the attention block and both
            GCN layers.
    """

    def __init__(self, adj, input_dim, hidden1_dim, hidden2_dim, num_heads,device=devices):

        super(GAEWithAttention, self).__init__()
        self.device=device
        self.adj = adj.to_sparse().to(device)
        self.attention = MultiHeadAttentionModule(embed_dim=input_dim, num_heads=num_heads).to(device)
        # Forward `device` explicitly; otherwise the GCN layers would fall back
        # to the module-level default and ignore the caller's choice.
        self.gcn1 = GraphConvSparse(input_dim, hidden1_dim, self.adj, device=device).to(device)
        self.gcn2 = GraphConvSparse(hidden1_dim, hidden2_dim, self.adj, activation=Identity(),
                                    device=device).to(device)

    def encode(self, x):
        """Return ``(z, attention_weights)``; ``z`` gets a leading dimension of 1."""
        x = x.to(self.device)
        # Self-attention: the features act as query, key and value.
        z_output,z_atten = self.attention(x, x, x)
        hidden = self.gcn1(z_output)
        z = self.gcn2(hidden)
        z = z.unsqueeze(0)
        return z,z_atten

    def decode(self, z):
        """Reconstruct the adjacency from embeddings via ``dot_product_decode``."""
        return dot_product_decode(z.squeeze(0))

    def forward(self, x):
        """Return ``(a_pred, z, attention_weights)`` for node features ``x``."""
        # `encode` already moves `x` to the model's device.
        z,z_atten = self.encode(x)
        a_pred = self.decode(z)
        return a_pred, z, z_atten

def GAE_function_with_attention_EP(net_m, feature_m_list, GAE_epochs=30000, learning_rate=0.01, num_heads=4,
                                   ratio_val=0, seed=666, hidden1_dim=10, hidden2_dim=1, save_path=None, patience=10,device=devices):
    """Train one attention graph autoencoder on a network, over all patients.

    Every epoch performs one optimisation step per patient, minimising the
    mean squared error between the reconstructed adjacency and the training
    adjacency (with self-loops). Training stops once the mean epoch loss has
    failed to improve for more than ``patience`` consecutive epochs.

    Args:
        net_m: dense adjacency matrix of the network.
        feature_m_list: mapping ``patient -> feature matrix``; the second
            dimension of the first matrix defines the model's input width.
        GAE_epochs: maximum number of epochs.
        learning_rate: Adam learning rate.
        num_heads: number of attention heads.
        ratio_val: fraction of edges removed from the training adjacency.
        seed: seed applied to Python, NumPy, CuPy and PyTorch generators.
        hidden1_dim, hidden2_dim: widths of the two GCN layers.
        save_path: accepted for interface compatibility; not used here.
        patience: tolerated number of non-improving epochs.
        device: device for the model and all tensors.

    Returns:
        ``(model, min_loss_val, best_epoch, feature_all)``: the model restored
        to the weights captured at the end of its best epoch, that epoch's
        mean loss, its index, and the per-patient sparse feature tensors.

    Raises:
        ValueError: if ``feature_m_list`` is empty.
    """
    if not feature_m_list:
        raise ValueError("feature_m_list must contain at least one patient")

    patients = list(feature_m_list.keys())
    input_dim = int(feature_m_list[patients[0]].shape[1])

    # Seed every generator involved *before* any random operation: CuPy drives
    # both the edge shuffle and the GCN weight initialisation.
    random.seed(seed)
    np.random.seed(seed)
    cp.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    def convert_to_sparse(matrix):
        """Build a float32 cupyx COO matrix from the non-zeros of a dense matrix."""
        if isinstance(matrix, torch.Tensor):
            matrix = cp.asarray(matrix.cpu().numpy())
        row, col = cp.nonzero(matrix)
        data = matrix[row, col].astype(cp.float32)
        # `sp` is cupyx.scipy.sparse; `cp.sparse` is gone from current CuPy.
        return sp.coo_matrix((data, (row, col)), shape=matrix.shape)

    def to_torch_sparse(coords, values, shape):
        """Turn a ``sparse_to_tuple`` triple into a torch sparse COO tensor."""
        indices = torch.as_tensor(cp.asnumpy(coords).T.copy(), dtype=torch.long)
        entries = torch.as_tensor(cp.asnumpy(values), dtype=torch.float32)
        return torch.sparse_coo_tensor(indices, entries, torch.Size(shape))

    # Row index covering every node of the network; patient features are
    # re-indexed against it below.
    node_index = pd.DataFrame(cp.asnumpy(net_m)).index
    net_sparse = convert_to_sparse(cp.asarray(net_m))

    adj_train, _, _ = mask_test_edges(net_sparse, ratio_val=ratio_val)
    # NOTE(review): the normalised adjacency is built from the full network,
    # not from `adj_train`; with ratio_val > 0 the held-out edges still reach
    # the encoder. Confirm this is intended.
    adj_norm = to_torch_sparse(*preprocess_graph(net_sparse))

    eye_matrix = sp.eye(adj_train.shape[0], dtype=adj_train.dtype, format='csr')
    adj_label = to_torch_sparse(*sparse_to_tuple(adj_train + eye_matrix)).to(device)
    # The reconstruction target never changes, so densify it once.
    target = adj_label.to_dense().view(-1)

    model = GAEWithAttention(adj=adj_norm, input_dim=input_dim, hidden1_dim=hidden1_dim, hidden2_dim=hidden2_dim,
                             num_heads=num_heads, device=device).to(device)
    optimizer_A = Adam(model.parameters(), lr=learning_rate, weight_decay=0.001)
    reconstruction_loss = nn.MSELoss()

    feature_all = {}
    for patient in patients:
        # Align feature rows with the network's nodes; missing rows become 0.
        aligned = pd.DataFrame(cp.asnumpy(feature_m_list[patient])).reindex(index=node_index).fillna(0).values
        feature_coo = convert_to_sparse(cp.asarray(aligned))
        feature_all[patient] = to_torch_sparse(*sparse_to_tuple(feature_coo)).to(device)

    best_epoch_loss = float("inf")
    min_loss_val = float("inf")
    best_epoch = 0
    best_state = None
    patience_count = 0

    for epoch in range(GAE_epochs):
        model.train()
        epoch_loss = 0.0
        for patient in patients:
            A_pred, _, _ = model(x=feature_all[patient])
            loss_1 = reconstruction_loss(A_pred.view(-1), target)
            optimizer_A.zero_grad()
            loss_1.backward()
            optimizer_A.step()
            epoch_loss += loss_1.item()
        epoch_loss = epoch_loss / len(patients)

        if epoch_loss < best_epoch_loss:
            best_epoch_loss = epoch_loss
            min_loss_val = epoch_loss
            best_epoch = epoch
            # Snapshot the weights: keeping a reference to `model` would only
            # alias the object that later epochs keep modifying.
            best_state = {name: value.detach().clone() for name, value in model.state_dict().items()}
            patience_count = 0
        else:
            patience_count += 1

        if patience_count > patience:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, min_loss_val, best_epoch,feature_all

def train_pathway_model(pathways_matrix,features_matrix_list,GAE_epochs=30000, learning_rate=0.01, num_heads=4,
                                   ratio_val=0, seed=666, hidden1_dim=10, hidden2_dim=1, save_path=None, patience=10,device=devices):
    """Train one graph autoencoder per pathway.

    Args:
        pathways_matrix: mapping ``pathway -> adjacency matrix``.
        features_matrix_list: mapping ``patient -> feature matrix`` shared by
            all pathways.
        Remaining arguments are forwarded unchanged to
        ``GAE_function_with_attention_EP``.

    Returns:
        ``(pathway_model, patient_feature)``: two dicts keyed by pathway that
        hold the trained model and the per-patient feature tensors returned
        by the training function.
    """
    pathway_model = {}
    patient_feature = {}
    total = len(pathways_matrix)
    for done, path in enumerate(pathways_matrix.keys(), start=1):
        # Pass the caller's `device` through instead of the global default.
        model, _, _, features = GAE_function_with_attention_EP(
            net_m=pathways_matrix[path], feature_m_list=features_matrix_list, GAE_epochs=GAE_epochs,
            learning_rate=learning_rate, num_heads=num_heads, ratio_val=ratio_val, seed=seed,
            hidden1_dim=hidden1_dim, hidden2_dim=hidden2_dim, save_path=save_path, patience=patience,
            device=device)
        pathway_model[path] = model
        patient_feature[path] = features
        print(f"Pathway {done}/{total}")
    return pathway_model,patient_feature

def calculate_pathway_activity_3d(pathway_model, patient_feature, latent_dim=4, device='cuda:0'):
    """Compress every patient's pathway embedding into ``latent_dim`` values.

    For each pathway and patient the trained graph autoencoder produces node
    embeddings ``z``. A fresh ``Autoencoder`` is then fitted on ``z``
    transposed for 50 steps, and the encoding observed at the lowest training
    loss is kept.

    Args:
        pathway_model: mapping ``pathway -> trained graph autoencoder``.
        patient_feature: mapping ``pathway -> {patient -> feature tensor}``.
            The patient list is taken from the first pathway.
        latent_dim: size of the per-pathway code.
        device: device on which the per-patient autoencoders are trained.

    Returns:
        float32 ``torch.Tensor`` obtained by stacking the per-patient codes
        of each pathway (axis 0) and then the pathways (axis 1).

    Raises:
        RuntimeError: if no finite training loss is observed for a patient.
    """
    pathways = list(pathway_model.keys())
    a = list(patient_feature.keys())[0]
    patients = list(patient_feature[a].keys())

    criterion = nn.MSELoss()
    num_epochs = 50
    all_pathways_activity = []

    for i, path in enumerate(pathways, start=1):
        model = pathway_model[path].cuda()
        # Inference only: switch dropout off so embeddings are deterministic,
        # and restore the previous mode afterwards.
        was_training = model.training
        model.eval()
        features = patient_feature[path]
        sample_pathway_dim = []

        for patient in patients:
            with torch.no_grad():
                _, z, _ = model(x=features[patient])
            z = z.squeeze(0)
            # Training input of the autoencoder, placed on the requested
            # device rather than whichever CUDA device is current.
            aaa = z.detach().T.to(device=device, dtype=torch.float32)

            model1 = Autoencoder(z.shape[0], latent_dim).to(device)
            optimizer = optim.Adam(model1.parameters(), lr=0.001)

            best_encoded = None
            best_train_loss = float('inf')
            for epoch in range(num_epochs):
                model1.train()
                optimizer.zero_grad()
                encoded, decoded = model1(aaa)
                loss = criterion(decoded, aaa)
                loss.backward()
                optimizer.step()

                if loss.item() < best_train_loss:
                    best_train_loss = loss.item()
                    best_encoded = encoded.detach().cpu().numpy()

            if best_encoded is None:
                raise RuntimeError(
                    f"no finite autoencoder loss for pathway {path!r}, patient {patient!r}")
            sample_pathway_dim.append(best_encoded.squeeze())

        model.train(was_training)
        all_pathways_activity.append(np.stack(sample_pathway_dim, axis=0))
        print(f"Processed Pathway {i}/{len(pathways)}")

    flattened_pathways_activity = torch.tensor(np.stack(all_pathways_activity, axis=1), dtype=torch.float32)
    return flattened_pathways_activity


In [None]:
# Load the pathway adjacency matrices and the per-patient feature matrices.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("Data/pathways_adjacency_matrix_without_disease.pkl", "rb") as f:
    pathways_matrix = pickle.load(f)

with open("Data/Liu/com_data_144_3.pkl", "rb") as f:
    com_data = pickle.load(f)

# Train one graph autoencoder per pathway.
pathway_model_144, patient_feature_144 = train_pathway_model(pathways_matrix = pathways_matrix, features_matrix_list = com_data, GAE_epochs = 30000, learning_rate = 0.01, num_heads = 1,
                                   ratio_val = 0, seed = 666, hidden1_dim = 3, hidden2_dim = 1, save_path = 'Data', patience = 20,device = devices)

# Persist the results: the following cell reads exactly these two files, so
# writing them here keeps the notebook runnable from a fresh kernel.
with open("Data/Liu/pathway_model_144.pkl", "wb") as f:
    pickle.dump(pathway_model_144, f, protocol=pickle.HIGHEST_PROTOCOL)

with open("Data/Liu/patient_feature_144.pkl", "wb") as f:
    pickle.dump(patient_feature_144, f, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
# Reload the per-pathway patient features and trained models from disk.
# Both files must exist before this cell runs; loading them replaces any
# in-memory variables of the same name produced by the training cell.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("Data/Liu//patient_feature_144.pkl", "rb") as f:
    patient_feature_144 = pickle.load(f)

with open("Data/Liu//pathway_model_144.pkl", "rb") as f:
    pathway_model_144 = pickle.load(f)

# Compress every patient's pathway embedding into a 6-value code.
pathway_144_activity = calculate_pathway_activity_3d(pathway_model = pathway_model_144, patient_feature = patient_feature_144, latent_dim = 6)

# Store the resulting activity tensor next to the inputs.
with open(f'Data/Liu//pathway_144_activity.pkl', "wb") as f:
    pickle.dump(pathway_144_activity, f, protocol=pickle.HIGHEST_PROTOCOL)