# GATDep

## Data preprocessing

### Download 

Download the public datasets listed below from `DepMap`, `TCGA`, `GEO`, `DrugComb`, `STRING`, and `MSigDB`.

#### DepMap

[DepMap](https://depmap.org/portal/data_page/?tab=overview)

+ CRISPRGeneEffect.csv
+ CRISPRInferredCommonEssentials.csv
+ Model.csv
+ OmicsCNGene.csv
+ OmicsExpressionProteinCodingGenesTPMLogp1BatchCorrected.csv
+ PortalOmicsCNGeneLog2.csv
+ OmicsSomaticMutationsProfile.csv

#### TCGA

[TCGA](https://www.cancer.gov/ccg/research/genome-sequencing/tcga)

[cBioPortal](https://www.cbioportal.org/)

#### GEO

[GEO](https://www.ncbi.nlm.nih.gov/geo/)

+ GSE272107
+ GSE259249
+ GSE219938/GSE219474
+ GSE221475

#### DrugComb

[DrugComb](https://drugcomb.org/)

+ summary_v_1_5.csv

#### STRING

[STRING](https://cn.string-db.org/cgi/download)

+ 9606.protein.links.v10.5.txt.gz

#### MSigDB

[MSigDB](https://www.gsea-msigdb.org/gsea/msigdb/index.jsp)

+ C2 category
+ C5 category

## Model construction

### GATDep 

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from torch_geometric.nn import GATConv, LayerNorm, global_mean_pool

def evaluate(model, loader):
    """Score a GATDep model on every graph batch in ``loader``.

    The model runs in eval mode without gradients. Node predictions and
    targets are flattened, concatenated, and reduced to regression metrics.

    Args:
        model: GATDep model called as ``model(x, edge_index, batch)``.
        loader: iterable of PyG batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y``.

    Returns:
        dict with keys 'MAE', 'MSE', 'RMSE', 'R2' and 'Pearson'.

    Note:
        Uses the module-level ``device`` defined elsewhere in the notebook.
    """
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for graph_batch in loader:
            graph_batch = graph_batch.to(device)
            prediction = model(graph_batch.x, graph_batch.edge_index, graph_batch.batch)
            predictions.append(prediction.cpu().numpy().flatten())
            targets.append(graph_batch.y.cpu().numpy().flatten())

    y_true = np.concatenate(targets)
    y_pred = np.concatenate(predictions)

    mse = mean_squared_error(y_true, y_pred)
    return {
        'MAE': mean_absolute_error(y_true, y_pred),
        'MSE': mse,
        'RMSE': np.sqrt(mse),
        'R2': r2_score(y_true, y_pred),
        'Pearson': pearsonr(y_true.ravel(), y_pred.ravel())[0],
    }

class GeneDependencyGAT(nn.Module):
    """GATDep: two-layer GAT that predicts one dependency score per gene node.

    Node features are projected to 512 dims (Linear, LayerNorm, ReLU) and
    passed through two GATConv layers. The graph mean is broadcast back to
    every node, and a linear head maps [node embedding | graph mean] to
    ``out_channels``.

    Args:
        in_channels: width of the input node features.
        hidden_channels: per-head width of the GAT layers.
        out_channels: number of outputs per node.
        heads: number of attention heads of the first GAT layer; its head
            outputs are concatenated. The default of 2 matches the
            previously hard-coded value, so existing checkpoints still load.
        dropout: attention dropout of the first GAT layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels=1, heads=2, dropout=0.2):
        super(GeneDependencyGAT, self).__init__()

        self.layer = nn.Linear(in_channels, 512)
        self.ln1 = nn.LayerNorm(512)
        # NOTE(review): never applied in forward(); kept for attribute
        # compatibility (Dropout has no parameters, so state_dict is unaffected).
        self.dropout1 = nn.Dropout(0.1)

        # GAT layers. The first concatenates its `heads` outputs, so both the
        # following norm and the second GAT layer are sized hidden * heads.
        self.gat1 = GATConv(512, hidden_channels, heads=heads, dropout=dropout, concat=True)
        self.norm1 = LayerNorm(hidden_channels * heads)

        self.gat2 = GATConv(hidden_channels * heads, hidden_channels)
        self.norm2 = LayerNorm(hidden_channels)

        # Node-level regression head over [node embedding | graph mean].
        self.lin = nn.Linear(hidden_channels * 2, out_channels)

    def forward(self, x, edge_index, batch):
        """Predict one value per node.

        Args:
            x: node features, [num_nodes, in_channels].
            edge_index: graph connectivity, [2, num_edges].
            batch: graph index of every node, used for mean pooling.

        Returns:
            Tensor of shape [num_nodes, out_channels].
        """
        x = self.layer(x)
        x = self.ln1(x)
        x = torch.relu(x)

        x = self.gat1(x, edge_index)
        # NOTE(review): torch_geometric's LayerNorm defaults to mode='graph'.
        # Called without `batch`, it normalises over all graphs of the
        # mini-batch together, so a node's output can depend on the graphs
        # batched with it. Confirm this is intended before changing it:
        # trained checkpoints rely on the current behaviour.
        x = self.norm1(x)
        x = F.relu(x)

        x = self.gat2(x, edge_index)
        x = self.norm2(x)
        x = F.relu(x)

        # Graph-level context: per-graph node mean, broadcast back to the nodes.
        x_global = global_mean_pool(x, batch)[batch]  # [num_nodes, hidden_channels]

        x = torch.cat([x, x_global], dim=1)  # [num_nodes, hidden_channels * 2]
        return self.lin(x)
    
# Default GATDep configuration: 6561 input features per gene node.
model = GeneDependencyGAT(in_channels=6561, hidden_channels=64, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
# Halve the learning rate after 3 scheduler steps without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

### GATDep_Mut 

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, LayerNorm
import torch.optim as optim
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from scipy.stats import pearsonr
import numpy as np

# === Loss functions ===
# Default-configured criteria; evaluate() below reports huber_loss.
huber_loss, mse_loss, mae_loss = nn.HuberLoss(), nn.MSELoss(), nn.L1Loss()
def evaluate(model, loader):
    """Evaluate the expression + mutation GATDep model on ``loader``.

    Args:
        model: called as ``model(x, x_mut, edge_index, batch)``.
        loader: iterable of PyG batches exposing ``x``, ``x_mut``,
            ``edge_index``, ``batch`` and ``y``.

    Returns:
        dict with 'MAE', 'MSE', 'RMSE', 'R2' and 'Pearson' over all flattened
        node predictions. 'Loss' is the mean of the per-batch ``huber_loss``
        values, so batches are weighted equally regardless of size.

    Note:
        Uses the module-level ``device`` and ``huber_loss``.
    """
    model.eval()
    predictions, targets, batch_losses = [], [], []
    with torch.no_grad():
        for graph_batch in loader:
            graph_batch = graph_batch.to(device)
            prediction = model(graph_batch.x, graph_batch.x_mut,
                               graph_batch.edge_index, graph_batch.batch)
            predictions.append(prediction.cpu().numpy().flatten())
            targets.append(graph_batch.y.cpu().numpy().flatten())
            batch_losses.append(huber_loss(prediction, graph_batch.y).item())

    y_true = np.concatenate(targets)
    y_pred = np.concatenate(predictions)

    mse = mean_squared_error(y_true, y_pred)
    return {
        'MAE': mean_absolute_error(y_true, y_pred),
        'MSE': mse,
        'RMSE': np.sqrt(mse),
        'R2': r2_score(y_true, y_pred),
        'Pearson': pearsonr(y_true.ravel(), y_pred.ravel())[0],
        'Loss': np.mean(batch_losses),
    }

class MutationLayer(nn.Module):
    """Encode each row of scalar features into ``hidden_dim`` values.

    Every scalar is lifted to ``embed_dim`` values by a shared
    ``nn.Linear(1, embed_dim)``, and the lifted values are averaged over the
    feature axis. The result goes through LazyLinear -> BatchNorm1d -> ReLU ->
    Dropout(0.1).

    NOTE(review): the embedding is affine and is followed by a mean, so
    ``embed(x).mean(dim=1)`` equals ``embed(x.mean(dim=1))``. Only the
    average of each row's features reaches the encoder — confirm this is the
    intended summary.

    Args:
        embed_dim: width of the per-scalar embedding.
        hidden_dim: output width.
    """

    def __init__(self, embed_dim=32, hidden_dim=256):
        super().__init__()
        self.embed_dim = embed_dim
        self.embed = nn.Linear(1, embed_dim)
        # LazyLinear infers its input width (embed_dim) on the first call.
        self.encoder = nn.Sequential(
            nn.LazyLinear(hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        """Map ``x`` of shape (rows, features[, 1]) to (rows, hidden_dim).

        The caller concatenates the result with per-node features along
        dim 1, so rows are presumably graph nodes — TODO confirm.
        """
        lifted = self.embed(x[..., None] if x.dim() == 2 else x)
        return self.encoder(lifted.mean(dim=1))

class GeneDependencyGAT(nn.Module):
    """GATDep_Mut: GATDep with a mutation branch fused before the GAT layers.

    Per node:
    - expression features -> 512 (Linear, LayerNorm, ReLU);
    - ``x_mut`` -> 128 via a MutationLayer.

    The 640-wide concatenation feeds two GATConv layers. The graph mean is
    broadcast back to every node, and a linear head maps [node | graph mean]
    to ``out_channels``.

    NOTE(review): ``heads`` is accepted but unused — gat1 always uses 2 heads
    and norm1/gat2 are sized for that. ``self.dropout1`` is never applied in
    forward().
    """

    def __init__(self, in_channels, hidden_channels, out_channels=1, heads=4, dropout=0.2):
        super(GeneDependencyGAT, self).__init__()

        # Expression branch: in_channels -> 512.
        self.layer = nn.Linear(in_channels, 512)
        self.ln1 = nn.LayerNorm(512)
        self.dropout1 = nn.Dropout(0.1)  # NOTE(review): unused in forward()

        # Mutation branch: -> 128 per row.
        self.mut_layer = MutationLayer(embed_dim=32, hidden_dim=128)

        # 512 (expression) + 128 (mutation) = 640 input channels.
        self.gat1 = GATConv(640, hidden_channels, heads=2, dropout=dropout, concat=True)
        self.norm1 = LayerNorm(hidden_channels * 2)

        self.gat2 = GATConv(hidden_channels * 2, hidden_channels)
        self.norm2 = LayerNorm(hidden_channels)

        # Node-level head over [node embedding | graph mean].
        self.lin = nn.Linear(hidden_channels * 2, out_channels)

    def forward(self, x, x_mut, edge_index, batch):
        """Return per-node predictions of shape [num_nodes, out_channels]."""
        expr = torch.relu(self.ln1(self.layer(x)))
        fused = torch.cat([expr, self.mut_layer(x_mut)], dim=1)

        h = F.relu(self.norm1(self.gat1(fused, edge_index)))
        h = F.relu(self.norm2(self.gat2(h, edge_index)))

        # Broadcast each graph's mean embedding back to its nodes.
        graph_mean = global_mean_pool(h, batch)[batch]
        return self.lin(torch.cat([h, graph_mean], dim=1))
    
# Default GATDep_Mut configuration: 6561 expression-derived features per node.
model = GeneDependencyGAT(in_channels=6561, hidden_channels=64, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
# Halve the learning rate after 3 scheduler steps without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

### GATDep_Mut_CNV 

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, LayerNorm
import torch.optim as optim
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from scipy.stats import pearsonr
import numpy as np

# Default-configured criteria; evaluate() below reports huber_loss.
huber_loss, mse_loss, mae_loss = nn.HuberLoss(), nn.MSELoss(), nn.L1Loss()
def evaluate(model, loader):
    """Evaluate the expression + mutation + CNV GATDep model on ``loader``.

    Args:
        model: called as ``model(x, x_mut, x_cnv, edge_index, batch)``.
        loader: iterable of PyG batches exposing ``x``, ``x_mut``, ``x_cnv``,
            ``edge_index``, ``batch`` and ``y``.

    Returns:
        dict with 'MAE', 'MSE', 'RMSE', 'R2' and 'Pearson' over all flattened
        node predictions. 'Loss' is the mean of the per-batch ``huber_loss``
        values.

    Note:
        Uses the module-level ``device`` and ``huber_loss``.
    """
    model.eval()
    predictions, targets, batch_losses = [], [], []
    with torch.no_grad():
        for graph_batch in loader:
            graph_batch = graph_batch.to(device)
            prediction = model(graph_batch.x, graph_batch.x_mut, graph_batch.x_cnv,
                               graph_batch.edge_index, graph_batch.batch)
            predictions.append(prediction.cpu().numpy().flatten())
            targets.append(graph_batch.y.cpu().numpy().flatten())
            batch_losses.append(huber_loss(prediction, graph_batch.y).item())

    y_true = np.concatenate(targets)
    y_pred = np.concatenate(predictions)

    mse = mean_squared_error(y_true, y_pred)
    return {
        'MAE': mean_absolute_error(y_true, y_pred),
        'MSE': mse,
        'RMSE': np.sqrt(mse),
        'R2': r2_score(y_true, y_pred),
        'Pearson': pearsonr(y_true.ravel(), y_pred.ravel())[0],
        'Loss': np.mean(batch_losses),
    }

class MutationLayer(nn.Module):
    """Encode each row of scalar features into ``hidden_dim`` values.

    Every scalar is lifted to ``embed_dim`` values by a shared
    ``nn.Linear(1, embed_dim)``, and the lifted values are averaged over the
    feature axis. The result goes through LazyLinear -> BatchNorm1d -> ReLU ->
    Dropout(0.1). This cell uses it for both the mutation and the CNV branch.

    NOTE(review): the embedding is affine and is followed by a mean, so
    ``embed(x).mean(dim=1)`` equals ``embed(x.mean(dim=1))``. Only the
    average of each row's features reaches the encoder — confirm this is the
    intended summary.

    Args:
        embed_dim: width of the per-scalar embedding.
        hidden_dim: output width.
    """

    def __init__(self, embed_dim=32, hidden_dim=256):
        super().__init__()
        self.embed_dim = embed_dim
        self.embed = nn.Linear(1, embed_dim)
        # LazyLinear infers its input width (embed_dim) on the first call.
        self.encoder = nn.Sequential(
            nn.LazyLinear(hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        """Map ``x`` of shape (rows, features[, 1]) to (rows, hidden_dim).

        The caller concatenates the result with per-node features along
        dim 1, so rows are presumably graph nodes — TODO confirm.
        """
        lifted = self.embed(x[..., None] if x.dim() == 2 else x)
        return self.encoder(lifted.mean(dim=1))

class GeneDependencyGAT(nn.Module):
    """GATDep_Mut_CNV: GATDep with mutation and copy-number branches.

    Per node:
    - expression features -> 512 (Linear, LayerNorm, ReLU);
    - ``x_mut`` -> 128 and ``x_cnv`` -> 128 via two independent MutationLayer
      encoders.

    The 768-wide concatenation feeds two GATConv layers. The graph mean is
    broadcast back to every node, and a linear head maps [node | graph mean]
    to ``out_channels``.

    NOTE(review): ``heads`` is accepted but unused — gat1 always uses 2 heads
    and norm1/gat2 are sized for that. ``self.dropout1`` is never applied in
    forward().
    """

    def __init__(self, in_channels, hidden_channels, out_channels=1, heads=4, dropout=0.2):
        super(GeneDependencyGAT, self).__init__()

        # Expression branch: in_channels -> 512.
        self.layer = nn.Linear(in_channels, 512)
        self.ln1 = nn.LayerNorm(512)
        self.dropout1 = nn.Dropout(0.1)  # NOTE(review): unused in forward()

        # Mutation and CNV branches: -> 128 each.
        self.mut_layer = MutationLayer(embed_dim=32, hidden_dim=128)
        self.CNV_layer = MutationLayer(embed_dim=32, hidden_dim=128)

        # 512 + 128 + 128 = 768 input channels.
        self.gat1 = GATConv(768, hidden_channels, heads=2, dropout=dropout, concat=True)
        self.norm1 = LayerNorm(hidden_channels * 2)

        self.gat2 = GATConv(hidden_channels * 2, hidden_channels)
        self.norm2 = LayerNorm(hidden_channels)

        # Node-level head over [node embedding | graph mean].
        self.lin = nn.Linear(hidden_channels * 2, out_channels)

    def forward(self, x, x_mut, x_cnv, edge_index, batch):
        """Return per-node predictions of shape [num_nodes, out_channels]."""
        expr = torch.relu(self.ln1(self.layer(x)))
        mut = self.mut_layer(x_mut)
        cnv = self.CNV_layer(x_cnv)
        # Exp | Mut | CNV -> 512 + 128 + 128 = 768 channels per node.
        h = torch.cat([expr, mut, cnv], dim=1)

        h = F.relu(self.norm1(self.gat1(h, edge_index)))
        h = F.relu(self.norm2(self.gat2(h, edge_index)))

        # Broadcast each graph's mean embedding back to its nodes.
        graph_mean = global_mean_pool(h, batch)[batch]
        return self.lin(torch.cat([h, graph_mean], dim=1))
    
# Default GATDep_Mut_CNV configuration: 6561 expression-derived features per node.
model = GeneDependencyGAT(in_channels=6561, hidden_channels=64, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
# Halve the learning rate after 3 scheduler steps without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

### DNN 

In [None]:

def _dense_block(n_in, n_out, p_drop=0.2):
    """Return the layers Linear -> LayerNorm -> ReLU -> Dropout as a list."""
    return [nn.Linear(n_in, n_out), nn.LayerNorm(n_out), nn.ReLU(), nn.Dropout(p_drop)]


# DNN baseline: maps each 6561-wide feature row to a single score (no graph).
model = nn.Sequential(
    *_dense_block(6561, 512),
    *_dense_block(512, 128),
    nn.Linear(128, 1),
)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

## Model Training 

### GATDep

In [None]:

import pandas as pd
import numpy as np
import json

import torch
from torch_geometric.utils import from_networkx
import networkx as nx
from torch_geometric.data import Data, DataLoader
import torch.nn.functional as F

from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
from torch_geometric.data import Data, DataLoader

from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_max_pool

import torch
from torch_geometric.data import Data, Batch
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_max_pool
from torch_geometric.nn import global_mean_pool

import torch
from torch_geometric.data import Data, Batch
import numpy as np


def _load_json(path):
    """Read and return a UTF-8 encoded JSON file."""
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# Gene-gene (PPI) graph: each row of ggi_id.txt is an edge (g1_id, g2_id).
ppi_df = pd.read_csv("./datasets/ggi_id.txt", sep=" ")
edges = ppi_df[['g1_id', 'g2_id']].values
edge_index = torch.tensor(edges.T, dtype=torch.long)  # [2, num_edges]

# Lookup tables: gene name -> row of gene_geneset, sample name -> row of gsva.
genes2ID_2k = _load_json(r'./datasets/genes2ID_2k.json')
samples2ID = _load_json(r'./datasets/samples2ID.json')
gene_geneset = np.load(r"./datasets/adj_gene_geneset.npy")
gsva = np.load(r"./datasets/gsva_score_go_kegg_v3_format.npy")

# Pre-formatted CRISPR table, kept as strings (shape noted as (2556800, 4)).
# Columns used later: 0 = sample, 1 = gene, 2 = score, 3 = CV fold label.
crispr_score = np.genfromtxt(
    r"./datasets/crispr_score_v4_20250423_scale_in_each_sample.txt",
    dtype=str,
    delimiter=' ',
    skip_header=1,
    filling_values='',
    autostrip=True,
)
                     
def get_train_test_dataset(crispr_dataset):
    """Build per-sample node features and CRISPR scores from a score table.

    Args:
        crispr_dataset: string array whose column 0 holds sample IDs. It is
            passed through to ``get_score_and_gene_features_per_sample``.

    Returns:
        ``(gene_feature_list, score_list)``: two lists aligned with the
        sorted unique sample IDs of ``crispr_dataset``.
    """
    per_sample = [
        get_score_and_gene_features_per_sample(sample, crispr_dataset)
        for sample in np.unique(crispr_dataset[:, 0])  # samples present in the CRISPR data
    ]
    score_list = [score for score, _ in per_sample]
    gene_feature_list = [features for _, features in per_sample]
    return gene_feature_list, score_list

def get_score_and_gene_features_per_sample(sample, crispr_score):
    """Return the CRISPR scores and the node-feature matrix of one sample.

    Feature row ``i`` is ``gene_geneset[id_i] * gsva[samples2ID[sample]]``,
    with the gene IDs of ``genes2ID_2k`` taken in ascending order. This is
    the order the scores are sorted into below and, presumably, the node
    numbering of ``edge_index`` — TODO confirm.

    Args:
        sample: sample ID, matched against column 0 of ``crispr_score``.
        crispr_score: string array with columns (sample, gene, score, ...).

    Returns:
        ``(score_per_sample, gene_features)``:
        - ``score_per_sample``: float32 column of the sample's scores, sorted
          by gene ID and rounded to 4 decimals;
        - ``gene_features``: one feature row per gene in ``genes2ID_2k``.

    Uses the module-level ``genes2ID_2k``, ``samples2ID``, ``gene_geneset``
    and ``gsva``.
    """
    # Feature rows in ascending gene-ID order, so that they line up with the
    # ID-sorted scores below. Iterating the JSON keys only gave the same
    # order when the keys happened to be stored in ascending-ID order.
    gene_ids = sorted(genes2ID_2k.values())
    gene_features = gene_geneset[gene_ids, :] * gsva[samples2ID[sample]]

    # This sample's rows, sorted by the ID of the gene in column 1.
    crispr_score_sample = crispr_score[crispr_score[:, 0] == sample]
    sort_indices = np.argsort([genes2ID_2k[gene] for gene in crispr_score_sample[:, 1]])
    sorted_crispr_score_sample = crispr_score_sample[sort_indices]

    # Column 2 holds the score as text: parse to float32, round to 4 decimals,
    # and return it as a single column.
    score_per_sample = np.array(
        [np.round(np.float32(m), 4) for m in sorted_crispr_score_sample[:, 2]]
    ).reshape(-1, 1)

    return score_per_sample, gene_features
    
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, LayerNorm
import torch.optim as optim
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from scipy.stats import pearsonr
import numpy as np

def evaluate(model, loader):
    """Score a GATDep model on every graph batch in ``loader``.

    Args:
        model: called as ``model(x, edge_index, batch)``.
        loader: iterable of PyG batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y``.

    Returns:
        dict with keys 'MAE', 'MSE', 'RMSE', 'R2' and 'Pearson', computed over
        all flattened node predictions.

    Note:
        Uses the module-level ``device`` defined in the training loop.
    """
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for graph_batch in loader:
            graph_batch = graph_batch.to(device)
            prediction = model(graph_batch.x, graph_batch.edge_index, graph_batch.batch)
            predictions.append(prediction.cpu().numpy().flatten())
            targets.append(graph_batch.y.cpu().numpy().flatten())

    y_true = np.concatenate(targets)
    y_pred = np.concatenate(predictions)

    mse = mean_squared_error(y_true, y_pred)
    return {
        'MAE': mean_absolute_error(y_true, y_pred),
        'MSE': mse,
        'RMSE': np.sqrt(mse),
        'R2': r2_score(y_true, y_pred),
        'Pearson': pearsonr(y_true.ravel(), y_pred.ravel())[0],
    }

class GeneDependencyGAT(nn.Module):
    """GATDep: two-layer GAT that predicts one value per gene node.

    Node features are projected to 512 dims (Linear, LayerNorm, ReLU) and
    passed through two GATConv layers. The graph mean is broadcast back to
    every node, and a linear head maps [node | graph mean] to
    ``out_channels``.

    NOTE(review): ``heads`` is accepted but unused — gat1 always uses 2 heads
    and norm1/gat2 are sized for that. ``self.dropout1`` is never applied in
    forward().
    """

    def __init__(self, in_channels, hidden_channels, out_channels=1, heads=4, dropout=0.2):
        super(GeneDependencyGAT, self).__init__()

        self.layer = nn.Linear(in_channels, 512)
        self.ln1 = nn.LayerNorm(512)
        self.dropout1 = nn.Dropout(0.1)  # NOTE(review): unused in forward()

        # Two GAT layers; the first concatenates its 2 heads.
        self.gat1 = GATConv(512, hidden_channels, heads=2, dropout=dropout, concat=True)
        self.norm1 = LayerNorm(hidden_channels * 2)

        self.gat2 = GATConv(hidden_channels * 2, hidden_channels)
        self.norm2 = LayerNorm(hidden_channels)

        # Node-level regression head.
        self.lin = nn.Linear(hidden_channels * 2, out_channels)

    def forward(self, x, edge_index, batch):
        """Predict one value per node.

        Args:
            x: node features, [num_nodes, in_channels].
            edge_index: graph connectivity, [2, num_edges].
            batch: graph index of every node (used when several graphs are
                batched together).

        Returns:
            Tensor of shape [num_nodes, out_channels].
        """
        h = torch.relu(self.ln1(self.layer(x)))
        h = F.relu(self.norm1(self.gat1(h, edge_index)))
        h = F.relu(self.norm2(self.gat2(h, edge_index)))

        # Graph-level context, broadcast to every node of its graph.
        graph_mean = global_mean_pool(h, batch)[batch]  # [num_nodes, hidden]
        return self.lin(torch.cat([h, graph_mean], dim=1))  # [num_nodes, out_channels]
    
# 10-fold cross-validation of GATDep.
# - Column 3 of crispr_score holds the fold label.
# - Fold k+1 is the test set; the remaining samples are split 80/20 into
#   train and validation.
# NOTE(review): after the loop, `model` holds the last fold's final-epoch
# weights, not the best checkpoint saved to path_to_save_model; reload it
# before downstream use.
for k in np.arange(10):
    print(f'kx: {k+1}')
    fold = str(k + 1)
    crispr_score_test = crispr_score[crispr_score[:, 3] == fold, :]
    crispr_score_train = crispr_score[crispr_score[:, 3] != fold, :]

    # Fresh model, optimizer and scheduler for every fold.
    model = GeneDependencyGAT(in_channels=6561, hidden_channels=64, out_channels=1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    path_to_save_model = f'trained_models/best_model_k_{k+1}_basedon_v4_scale_data.pt'
    best_val_loss = float('inf')
    early_stop_counter = 0
    early_stop_patience = 10
    num_epochs = 150

    # Hold out 20% of the training-fold samples (fixed seed) for validation.
    gene_feature_list, score_list = get_train_test_dataset(crispr_score_train)
    X_train, X_val, y_train, y_val = train_test_split(gene_feature_list, score_list, test_size=0.2, random_state=42)

    data_list_val = [Data(x=torch.from_numpy(x_i), edge_index=edge_index, y=torch.from_numpy(y_i)) for x_i, y_i in zip(X_val, y_val)]
    val_loader = DataLoader(data_list_val, batch_size=5)

    data_list_train = [Data(x=torch.from_numpy(x_i), edge_index=edge_index, y=torch.from_numpy(y_i)) for x_i, y_i in zip(X_train, y_train)]
    train_loader = DataLoader(data_list_train, batch_size=5, shuffle=True)

    gene_feature_list_test, score_list_test = get_train_test_dataset(crispr_score_test)
    data_list_test = [Data(x=torch.from_numpy(x_i), edge_index=edge_index, y=torch.from_numpy(y_i)) for x_i, y_i in zip(gene_feature_list_test, score_list_test)]
    test_loader = DataLoader(data_list_test, batch_size=5)

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            # Move the whole batch (x, edge_index, batch vector and y) to the
            # model's device. The loss is then computed there, instead of
            # copying every output back to the CPU to meet a CPU-side `y`.
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)
            loss = F.huber_loss(out, batch.y, delta=1.0)
            loss.backward()
            optimizer.step()

        # Validation MSE drives the LR scheduler and early stopping.
        val_metrics = evaluate(model, val_loader)
        train_metrics = evaluate(model, train_loader)
        val_loss = val_metrics['MSE']

        test_metrics = evaluate(model, test_loader)

        current_lr = optimizer.param_groups[0]['lr']
        # "Train Loss" is the loss of the epoch's last mini-batch.
        print(f"[Epoch {epoch+1}] Train Loss: {loss.item():.4f} | \
        Train MSE: {train_metrics['MSE']:.4f} | \
        Val MSE: {val_metrics['MSE']:.4f} | \
        Train MAE: {train_metrics['MAE']:.4f} | \
        Val MAE: {val_metrics['MAE']:.4f} | \
        learning rate: {current_lr} | \
        Train R2: {train_metrics['R2']:.4f} | \
        Val R2: {val_metrics['R2']:.4f} | \
        Train Pearson: {train_metrics['Pearson']:.4f} | \
        Val Pearson: {val_metrics['Pearson']:.4f} | \
        Test MSE: {test_metrics['MSE']:.4f} | \
        Test MAE: {test_metrics['MAE']:.4f} | \
        Test R2: {test_metrics['R2']:.4f} | \
        Test Pearson: {test_metrics['Pearson']:.4f} ")

        scheduler.step(val_loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
            # Keep overwriting the checkpoint with the best model so far.
            torch.save(model.state_dict(), path_to_save_model)
        else:
            early_stop_counter += 1
            if early_stop_counter >= early_stop_patience:
                print("Early stopping triggered.")
                break


### DNN 

In [None]:
import pandas as pd
import numpy as np
import json

import torch
from torch_geometric.utils import from_networkx
import networkx as nx
from torch_geometric.data import Data, DataLoader
import torch.nn.functional as F

from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
from torch_geometric.data import Data, DataLoader

from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_max_pool

import torch
from torch_geometric.data import Data, Batch
import numpy as np


def _load_json(path):
    """Read and return a UTF-8 encoded JSON file."""
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# Lookup tables: gene name -> row of gene_geneset, sample name -> row of gsva.
genes2ID_2k = _load_json(r'datasets/genes2ID_2k.json')
samples2ID = _load_json(r'datasets/samples2ID.json')
gene_geneset = np.load(r"datasets/adj_gene_geneset.npy")
gsva = np.load(r"datasets/gsva_score_go_kegg_v3_format.npy")

# Pre-formatted CRISPR table, kept as strings (shape noted as (2556800, 4)).
# Columns used later: 0 = sample, 1 = gene, 2 = score, 3 = CV fold label.
crispr_score = np.genfromtxt(
    r"datasets/crispr_score_v4_20250423_scale_in_each_sample.txt",
    dtype=str,
    delimiter=' ',
    skip_header=1,
    filling_values='',
    autostrip=True,
)
                     

def get_score_and_gene_features_per_sample(sample, crispr_score):
    """Return the CRISPR scores and the per-gene feature matrix of one sample.

    Feature row ``i`` is ``gene_geneset[id_i] * gsva[samples2ID[sample]]``,
    with the gene IDs of ``genes2ID_2k`` taken in ascending order. This is
    the order the scores are sorted into below.

    Args:
        sample: sample ID, matched against column 0 of ``crispr_score``.
        crispr_score: string array with columns (sample, gene, score, ...).

    Returns:
        ``(score_per_sample, gene_features)``:
        - ``score_per_sample``: float32 column of the sample's scores, sorted
          by gene ID and rounded to 4 decimals;
        - ``gene_features``: one feature row per gene in ``genes2ID_2k``.

    Uses the module-level ``genes2ID_2k``, ``samples2ID``, ``gene_geneset``
    and ``gsva``.
    """
    # Feature rows in ascending gene-ID order, so that they line up with the
    # ID-sorted scores below. Iterating the JSON keys only gave the same
    # order when the keys happened to be stored in ascending-ID order.
    gene_ids = sorted(genes2ID_2k.values())
    gene_features = gene_geneset[gene_ids, :] * gsva[samples2ID[sample]]

    # This sample's rows, sorted by the ID of the gene in column 1.
    crispr_score_sample = crispr_score[crispr_score[:, 0] == sample]
    sort_indices = np.argsort([genes2ID_2k[gene] for gene in crispr_score_sample[:, 1]])
    sorted_crispr_score_sample = crispr_score_sample[sort_indices]

    # Column 2 holds the score as text: parse to float32, round to 4 decimals,
    # and return it as a single column.
    score_per_sample = np.array(
        [np.round(np.float32(m), 4) for m in sorted_crispr_score_sample[:, 2]]
    ).reshape(-1, 1)

    return score_per_sample, gene_features

def get_train_test_dataset(crispr_dataset):
    """Build per-sample features and CRISPR scores from a score table.

    Args:
        crispr_dataset: string array whose column 0 holds sample IDs. It is
            passed through to ``get_score_and_gene_features_per_sample``.

    Returns:
        ``(gene_feature_list, score_list)``: two lists aligned with the
        sorted unique sample IDs of ``crispr_dataset``.
    """
    per_sample = [
        get_score_and_gene_features_per_sample(sample, crispr_dataset)
        for sample in np.unique(crispr_dataset[:, 0])  # samples present in the CRISPR data
    ]
    score_list = [score for score, _ in per_sample]
    gene_feature_list = [features for _, features in per_sample]
    return gene_feature_list, score_list



import torch.optim as optim
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from scipy.stats import pearsonr
import numpy as np
# === Loss functions ===
huber_loss, mse_loss, mae_loss = nn.HuberLoss(), nn.MSELoss(), nn.L1Loss()

# === Early stopping state (re-initialised per fold in the training loop) ===
best_val_loss, early_stop_counter, early_stop_patience = float('inf'), 0, 10

def combined_loss(pred, true):
    """Smooth-L1 loss plus ``1 - mean column-wise Pearson correlation``.

    The correlation is computed for every column independently, across the
    rows of the batch (dim 0), and then averaged.

    Args:
        pred: predictions, e.g. (batch_size, num_genes).
        true: targets with the same shape as ``pred``.

    Returns:
        Scalar tensor ``SmoothL1(pred, true) + 1 - mean(corr)``. The 1e-8 in
        the denominator guards against zero-variance columns.
    """
    dp = pred - pred.mean(dim=0, keepdim=True)
    dt = true - true.mean(dim=0, keepdim=True)

    covariance = (dp * dt).sum(dim=0)
    scale = torch.sqrt(dp.pow(2).sum(dim=0)) * torch.sqrt(dt.pow(2).sum(dim=0))
    mean_corr = (covariance / (scale + 1e-8)).mean()

    return nn.SmoothL1Loss()(pred, true) + 1 - mean_corr

# === Evaluation function ===
def evaluate(model, loader):
    """Evaluate the DNN baseline on a loader of ``(x, y)`` tensor batches.

    Returns:
        dict with 'MAE', 'MSE', 'RMSE', 'R2' and 'Pearson' over all flattened
        predictions. 'Loss' is the mean of the per-batch ``huber_loss``
        values.

    Note:
        Uses the module-level ``device`` and ``huber_loss``.
    """
    model.eval()
    flat_preds, flat_targets, batch_losses = [], [], []

    with torch.no_grad():
        for features, targets in loader:
            features, targets = features.to(device), targets.to(device)
            preds = model(features)
            batch_losses.append(huber_loss(preds, targets).item())
            flat_preds.append(preds.cpu().numpy().flatten())
            flat_targets.append(targets.cpu().numpy().flatten())

    y_true = np.concatenate(flat_targets)
    y_pred = np.concatenate(flat_preds)

    mse = mean_squared_error(y_true, y_pred)
    return {
        'MAE': mean_absolute_error(y_true, y_pred),
        'MSE': mse,
        'RMSE': np.sqrt(mse),
        'R2': r2_score(y_true, y_pred),
        'Pearson': pearsonr(y_true.ravel(), y_pred.ravel())[0],
        'Loss': np.mean(batch_losses),
    }
    
# Per-epoch metric history. Every column starts with a placeholder 0, which
# becomes the first data row (k=0, Epoch=0) of each saved CSV.
_log_columns = (
    'k', 'Epoch',
    'Train_MSE', 'Val_MSE', 'Test_MSE',
    'Train_MAE', 'Val_MAE', 'Test_MAE',
    'Learning_rate',
    'Train_R2', 'Val_R2', 'Test_R2',
    'Train_Pearson', 'Val_Pearson', 'Test_Pearson',
    'Train_Loss', 'Val_Loss', 'Test_Loss',
)
logs = {column: [0] for column in _log_columns}

# 10-fold cross-validation of the DNN baseline.
# - Column 3 of crispr_score holds the fold label.
# - Fold k+1 is the test set; the remaining samples are split 80/20 into
#   train and validation.
# NOTE(review): `logs` is never reset between folds, so logs_k{n}.csv holds
# the rows of folds 1..n (they can be told apart by the 'k' column).
for k in np.arange(10):
    crispr_score_test = crispr_score[crispr_score[:, 3] == str(k+1), :]
    crispr_score_train = crispr_score[crispr_score[:, 3] != str(k+1), :]

    model = nn.Sequential(
        nn.Linear(6561, 512),
        nn.LayerNorm(512),
        nn.ReLU(),
        nn.Dropout(0.2),

        nn.Linear(512, 128),
        nn.LayerNorm(128),
        nn.ReLU(),
        nn.Dropout(0.2),

        nn.Linear(128, 1),
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # === Optimizer & Scheduler ===
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    path_to_save_model = f'trained_models/best_model_k{k+1}_DNN_scale_data.pt'
    best_val_loss = float('inf')
    early_stop_counter = 0
    early_stop_patience = 10
    num_epochs = 150

    gene_feature_list, score_list = get_train_test_dataset(crispr_score_train)
    X_train, X_val, y_train, y_val = train_test_split(gene_feature_list, score_list, test_size=0.2, random_state=42)

    # Stack the per-sample arrays first: torch.tensor() on a Python list of
    # ndarrays is very slow (PyTorch warns about it). The values are unchanged.
    data_list_val = TensorDataset(torch.tensor(np.stack(X_val), dtype=torch.float32),
                                  torch.tensor(np.stack(y_val), dtype=torch.float32))
    val_loader = DataLoader(data_list_val, batch_size=5)

    data_list_train = TensorDataset(torch.tensor(np.stack(X_train), dtype=torch.float32),
                                    torch.tensor(np.stack(y_train), dtype=torch.float32))
    train_loader = DataLoader(data_list_train, batch_size=5, shuffle=True)

    gene_feature_list_test, score_list_test = get_train_test_dataset(crispr_score_test)
    data_list_test = TensorDataset(torch.tensor(np.stack(gene_feature_list_test), dtype=torch.float32),
                                   torch.tensor(np.stack(score_list_test), dtype=torch.float32))
    test_loader = DataLoader(data_list_test, batch_size=5)

    for epoch in range(num_epochs):
        model.train()

        for x_batch, y_batch in train_loader:
            optimizer.zero_grad()
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            preds = model(x_batch)
            loss = huber_loss(preds, y_batch)
            loss.backward()
            optimizer.step()

        # === Evaluate; the validation Huber loss drives the scheduler and early stopping ===
        val_metrics = evaluate(model, val_loader)
        val_loss = val_metrics['Loss']

        train_metrics = evaluate(model, train_loader)
        test_metrics = evaluate(model, test_loader)

        current_lr = optimizer.param_groups[0]['lr']
        logs['k'].append(k+1)
        logs['Epoch'].append(epoch+1)
        logs['Learning_rate'].append(current_lr)
        for split, metrics in (('Train', train_metrics), ('Val', val_metrics), ('Test', test_metrics)):
            for metric in ('MSE', 'MAE', 'R2', 'Pearson', 'Loss'):
                logs[f'{split}_{metric}'].append(metrics[metric])

        scheduler.step(val_loss)
        # === Early stopping ===
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
            torch.save(model.state_dict(), path_to_save_model)
        else:
            early_stop_counter += 1
            if early_stop_counter >= early_stop_patience:
                print("Early stopping triggered.")
                break

    # Save the log after every fold, whether or not early stopping fired.
    # Writing it only inside the early-stopping branch meant a fold that ran
    # all num_epochs epochs was never saved.
    pd.DataFrame(logs).to_csv(f'./logs/logs_k{k+1}.csv')


## Applications

### GSE272107

In [1]:
# Load the GSE272107 GSVA scores, and the gene-set -> ID map whose key order
# is used to reindex the GSVA table below.
gsva_scores_df = pd.read_csv(r"GSE272107_gsva.txt", delimiter=' ')
with open(r'datasets/genesets_select_6k.json', 'r', encoding='utf-8') as handle:
    geneset2ID = json.load(handle)

def get_rows_with_fallback(df, row_names):
    """Return ``df`` reindexed to ``row_names``, with NaNs replaced by 0.

    Rows missing from ``df`` are added. Every NaN in the result is then
    filled with 0 — both in the added rows and any NaN already present in
    ``df``.
    """
    return df.reindex(row_names).fillna(0)
# Put the gene sets in geneset2ID order (missing ones become 0), then
# transpose back to a float32 matrix of samples x gene sets (6561 columns).
gsva_by_geneset = gsva_scores_df.T
result_df = get_rows_with_fallback(gsva_by_geneset, geneset2ID.keys())
gsva_scores_GSE272107 = np.array(result_df.T, dtype='float32')
def get_gene_features_per_sample(sample):
    """Return the gene-level feature matrix for one GSE272107 sample.

    Each output row is the ``gene_geneset`` row of one gene, scaled
    element-wise by that sample's GSVA score vector. Rows follow the key order
    of ``genes2ID_2k``, i.e. the same order the genes are iterated in.

    Args:
        sample: Row index into ``gsva_scores_GSE272107``.

    Returns:
        np.ndarray: One row per entry of ``genes2ID_2k``.
    """
    # NOTE(review): assumes gene_geneset is a 2-D NumPy array indexed by the
    # integer IDs stored as values of genes2ID_2k — confirm where it is built.
    gene_rows = list(genes2ID_2k.values())
    # A single fancy-indexed gather plus a broadcast multiply replaces the
    # per-gene Python loop and list-to-array conversion; values are identical.
    return gene_geneset[gene_rows, :] * gsva_scores_GSE272107[sample]

# Build the per-sample gene-feature matrices for GSE272107 samples 0..9.
samples = np.arange(10)
score_list = []  # left empty: no scores are collected for this cohort here
gene_feature_list = [get_gene_features_per_sample(idx) for idx in samples]

In [None]:
# Wrap each sample's feature matrix as a graph on the shared PPI edge_index.
data_list_GSE272107 = [
    Data(x=torch.from_numpy(features), edge_index=edge_index)
    for features in gene_feature_list
]
data_list_GSE272107_loader = DataLoader(data_list_GSE272107, batch_size=5)
all_preds = []
all_labels = []  # not filled: these graphs carry no `y`
# NOTE(review): `model` is whatever is currently in memory; the training loop
# saved its best weights to disk separately — confirm which one is intended.
model.eval()
with torch.no_grad():
    for batch_data in data_list_GSE272107_loader:
        batch_data = batch_data.to(device)
        batch_out = model(batch_data.x, batch_data.edge_index, batch_data.batch)
        all_preds.append(batch_out.cpu())

In [None]:
# Stack all batch outputs and write them as a single 'preds' column.
y_pred = torch.cat(all_preds).numpy()
pd.DataFrame(y_pred.ravel(), columns=['preds']).to_csv('GSE272107_predictions.csv')

## GNNExplainer 

### Select the node (gene) and graph (sample) to explain with GNNExplainer

In [None]:
# Select the node (gene) and graphs (samples) to explain with GNNExplainer.
score_list = []
gene_feature_list = []
samples = ['ACH-000587', 'ACH-000448']
local_node_idx = 921  # egfr
node_idx_in_graph = local_node_idx

In [None]:
# Gather the (score, gene-feature) pair of every selected sample.
for sample in samples:
    score, gene_feature = get_score_and_gene_features_per_sample(sample)
    score_list += [score]
    gene_feature_list += [gene_feature]

# Graph connectivity from the gene-gene interaction table (columns g1_id, g2_id).
ppi_df = pd.read_csv("./datasets/ggi_id.txt", sep=" ")
edges = ppi_df[['g1_id', 'g2_id']].values
edge_index = torch.tensor(edges.T, dtype=torch.long)

# One graph per sample; the loader uses its default batch size.
data_list_train = [
    Data(
        x=torch.from_numpy(features),
        edge_index=edge_index,
        y=torch.from_numpy(target),
    )
    for features, target in zip(gene_feature_list, score_list)
]
train_loader = DataLoader(data_list_train)

In [None]:
# Position, within the batch, of the graph (sample) whose node is explained.
graph_index = 0
def get_global_node_index_from_local(data, graph_index, local_node_idx):
    """Map a node index inside one graph of a batch to its batch-wide index.

    Args:
        data: Batched graph object whose ``batch`` vector assigns each node to
            its graph. If ``batch`` is ``None`` or missing (an un-batched single
            graph), every node is treated as belonging to graph 0 and
            ``data.num_nodes`` gives the node count.
        graph_index: Which graph in the batch the node belongs to.
        local_node_idx: Position of the node inside that graph; negative values
            count from the end, as with ordinary indexing.

    Returns:
        int: Index of the node in the batch's concatenated node set.

    Raises:
        IndexError: If ``local_node_idx`` is out of range for that graph.
    """
    batch = getattr(data, 'batch', None)
    if batch is None:
        # Single, un-batched graph: all nodes belong to graph 0.
        batch = torch.zeros(data.num_nodes, dtype=torch.long)
    # Batch-wide indices of the nodes in the requested graph, in ascending order.
    node_indices = (batch == graph_index).nonzero(as_tuple=False).view(-1)
    count = node_indices.numel()
    if not -count <= local_node_idx < count:
        raise IndexError(
            f"local_node_idx {local_node_idx} is out of range for graph "
            f"{graph_index}, which has {count} node(s)"
        )
    return int(node_indices[local_node_idx].item())

# First batch of the explainer loader, and the batch-wide index of the chosen gene.
data = next(iter(train_loader))
global_node_idx = get_global_node_index_from_local(
    data=data, graph_index=graph_index, local_node_idx=local_node_idx
)

### Explainer 

In [None]:
# Rebuild GATDep (hyper-parameters must match the checkpoint) and load weights.
# map_location='cpu' lets a checkpoint saved from a GPU load on CPU-only
# machines; neither the model nor `data` is moved to a device in this section,
# so both stay on CPU.
model = GeneDependencyGAT(in_channels=6561, hidden_channels=64, out_channels=1)
model.load_state_dict(
    torch.load(
        './trained_models/GATDep/best_model_v1_basedon_v4_scale_data.pt',
        map_location='cpu',
    )
)

# GNNExplainer on the raw node-level regression output; 'object' masks learn
# one importance value per node and one per edge.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=100),
    explanation_type='model',
    model_config=ModelConfig(
        mode='regression',
        task_level='node',
        return_type='raw',
    ),
    node_mask_type='object',
    edge_mask_type='object',
)

# Explain the prediction of the selected gene; `index` is the batch-wide node id.
explanation = explainer(
    x=data.x,
    edge_index=data.edge_index,
    batch=data.batch,
    index=global_node_idx,
)

In [None]:
# 2. subset node importance
# Node mask learned by the explainer above (one value per node, since
# node_mask_type='object').
node_mask = explanation.node_mask
importance = node_mask.detach().cpu().numpy()
# NOTE(review): keys below are batch-wide node indices; they coincide with
# per-graph gene indices only for a single-graph batch (graph_index 0) — confirm.
node_color = {}
for i, score in enumerate(importance):
    if score >= 0.05:  # keep nodes whose importance reaches the 0.05 cut-off
        node_color[i] = score
# Always include the explained (center) node, whatever its score.
if node_idx_in_graph is not None:
    node_color[node_idx_in_graph] = importance[node_idx_in_graph]

### Visualization 

In [None]:
# 3. top 20 genes
# Keep the 20 highest-importance nodes (stable sort: ties keep insertion
# order), then make sure the explained node is present.
top_nodes = sorted(node_color.items(), key=lambda item: item[1], reverse=True)[:20]
node_color = dict(top_nodes)
node_color[node_idx_in_graph] = np.array(node_mask[node_idx_in_graph])
node_color

In [None]:
# Undirected networkx graph of the full PPI edge list.
edge_list = edge_index.cpu().numpy().T.tolist()
G = nx.Graph()
G.add_edges_from(edge_list)
# Register the selected nodes too, so a gene without any PPI edge still appears
# in the subgraph and gets a layout position instead of being dropped (which
# would make later position/label lookups fail). Adding them after the edges
# leaves the node order, and therefore the seeded layout, unchanged whenever
# every selected node already has an edge.
G.add_nodes_from(node_color)

subG1 = G.subgraph(node_color.keys()).copy()
pos = nx.spring_layout(subG1, seed=42)

# Get node labels: node id -> gene name for every selected node.
node_labels = {node: id2gene[node] for node in node_color}

In [None]:
subG = G.subgraph(node_color.keys()).copy()

# Seeded spring layout; k=0.9 spreads the nodes out.
pos = nx.spring_layout(subG, seed=42, k=0.9)
from matplotlib.colors import Normalize, LinearSegmentedColormap
# grey -> white -> green colormap for node importance.
cmap = LinearSegmentedColormap.from_list("my_colormap", ["grey", "white", "#548235"])

values = list(node_color.values())
nodes = list(node_color.keys())

# One normalization shared by the node colors and the colorbar.
norm = plt.Normalize(vmin=min(values), vmax=max(values))

fig, ax = plt.subplots(figsize=(12, 6))
nodes_drawn = nx.draw_networkx_nodes(
    subG, pos, nodelist=nodes,
    node_color=[node_color[n] for n in nodes],
    cmap=cmap,
    vmin=norm.vmin,
    vmax=norm.vmax,
    node_size=2000,
    alpha=0.9,
    linewidths=0,
    edgecolors="black",
)
# Redraw the explained (center) node in grey with a dark-red outline.
nx.draw_networkx_nodes(
    subG, pos, nodelist=[local_node_idx],
    node_color='grey',
    node_size=2000,
    label='Center Node',
    edgecolors="darkred",
)
# Edges are drawn once, at width 3.
nx.draw_networkx_edges(subG, pos, width=3)
nx.draw_networkx_labels(
    subG, pos, font_size=18,
    font_color="black",
    labels=node_labels,
)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, shrink=0.8, aspect=20)
cbar.set_label('Node Importance (mask)', size=20)
cbar.ax.tick_params(labelsize=20)

ax.set_title('')
ax.axis('off')
plt.tight_layout()
plt.savefig(f'gene_id_{id2gene[local_node_idx]}_{local_node_idx}.png')
plt.show()