# GIN(E)-based-cmsiRpred (module A+B2+C)

2-gin-v3.2-draw_0708-Copy1

In [1]:
import time
import copy
import re
import os
from tqdm import tqdm

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torchinfo import summary
from torch.utils.data import Dataset, DataLoader, Subset
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, mean_absolute_error
from sklearn.model_selection import KFold

BATCH_SIZE = 256  # mini-batch size for the test-set loaders built at the end of this section

# Seed the global RNGs once so that "Restart & Run All" is reproducible:
# the train/validation pool is shuffled with DataFrame.sample(frac=1) without a
# random_state (pandas then draws from NumPy's global RNG), and model weight
# initialisation and dropout draw from torch's global RNG.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyG_DataLoader
import torch.nn.functional as F
from torch_geometric.nn import GINEConv,global_add_pool
import matplotlib.pyplot as plt

import networkx as nx

## Load data

In [3]:
# Location of the pre-encoded dataset. The default is a machine-specific
# absolute path; set the CMSIRNA_DF_PKL environment variable to point
# elsewhere without editing the notebook.
# NOTE: pd.read_pickle unpickles arbitrary objects (which can execute code),
# so only load files from a trusted source.
DF_STRUCTURED_ENCODED_PKL = os.environ.get(
    'CMSIRNA_DF_PKL',
    '/home/ken/MyStorage/siRNA_2503/Data/df_structured_encoded_0409.pkl',
)
df_structured_encoded = pd.read_pickle(DF_STRUCTURED_ENCODED_PKL)

In [4]:
# Split by the pre-assigned `dataset_usage` tag. Only the train/validation
# pool is shuffled; it is later cut into unshuffled KFold folds.
_usage = df_structured_encoded['dataset_usage']
df_structured_encoded_iid_trvl = df_structured_encoded[_usage == 'IID_trvl'].sample(frac=1)
df_structured_encoded_iid_test = df_structured_encoded[_usage == 'IID_test']
df_structured_encoded_ood_test = df_structured_encoded[_usage == 'OOD_test']
print(*(split.shape for split in (df_structured_encoded_iid_trvl,
                                  df_structured_encoded_iid_test,
                                  df_structured_encoded_ood_test)))

(20626, 165) (2568, 165) (2588, 165)


## Define the graph representation of cm-siRNA data

In [5]:
# Per-sample strand lengths (number of characters) for each split.
TRVL_sense_len_list = df_structured_encoded_iid_trvl['siRNA_sense_seq'].str.len().tolist()
TRVL_antis_len_list = df_structured_encoded_iid_trvl['siRNA_antisense_seq'].str.len().tolist()

TEST_OOD_sense_len_list = df_structured_encoded_ood_test['siRNA_sense_seq'].str.len().tolist()
TEST_OOD_antis_len_list = df_structured_encoded_ood_test['siRNA_antisense_seq'].str.len().tolist()

TEST_IID_sense_len_list = df_structured_encoded_iid_test['siRNA_sense_seq'].str.len().tolist()
TEST_IID_antis_len_list = df_structured_encoded_iid_test['siRNA_antisense_seq'].str.len().tolist()

---

### edges

In [6]:
def pp_mtx_2_edge_info(np_mtx_pp): # for 1_v3 output dataset df_structured_encoded_0326
    """Turn a pair-probability matrix into a symmetric weighted graph.

    Self-loops (main diagonal) and backbone bonds (first super-diagonal) with
    weight 1 are added on top of ``np_mtx_pp``, then the matrix is mirrored
    so it is symmetric without double-counting the diagonal.

    NOTE(review): assumes ``np_mtx_pp`` is square and holds pair values only
    in its upper triangle; a matrix that is already symmetric would have its
    off-diagonal pair weights doubled by the mirroring step — confirm upstream.

    Args:
        np_mtx_pp: 2-D numpy array (one row/column per nucleotide).

    Returns:
        (np_adjmtx, tnsr_edges, tnsr_eweight): the dense symmetric adjacency
        matrix, a (2, E) tensor of nonzero (row, col) indices in row-major
        order, and the (E,) tensor of corresponding weights.
    """
    n_nodes = np_mtx_pp.shape[1]
    with_backbone = np.eye(n_nodes) + np.eye(n_nodes, k=1) + np_mtx_pp
    np_adjmtx = with_backbone + with_backbone.T - np.diag(with_backbone.diagonal())

    adj_tensor = torch.tensor(np_adjmtx)
    tnsr_edges = adj_tensor.nonzero().t()
    src, dst = tnsr_edges
    tnsr_eweight = adj_tensor[src, dst]
    return np_adjmtx, tnsr_edges, tnsr_eweight

def get_edge_info_v2(df_mtx_pp):
    """Run pp_mtx_2_edge_info over a Series of pair-probability matrices.

    Args:
        df_mtx_pp: a pandas Series (or any iterable) of 2-D matrices, e.g.
            ``df_structured_encoded['mtx_pp_sense']``.

    Returns:
        dict with 'ls_edges' (list of (2, E) edge-index tensors) and
        'ls_eweight' (list of (E,) edge-weight tensors), in input order.
    """
    # Collect directly instead of round-tripping through a DataFrame of
    # objects, which also kept every dense adjacency matrix alive.
    ls_edges = []
    ls_eweight = []
    for np_mtx_pp in df_mtx_pp:
        _, tnsr_edges, tnsr_eweight = pp_mtx_2_edge_info(np_mtx_pp)
        ls_edges.append(tnsr_edges)
        ls_eweight.append(tnsr_eweight)
    return {'ls_edges':ls_edges,'ls_eweight':ls_eweight}

In [7]:
# Secondary-structure graphs: self-loops + backbone bonds + base-pair-probability edges

# One get_edge_info_v2 call per split and strand.
edge_info_iid_test_sense, edge_info_iid_test_antis = (
    get_edge_info_v2(df_structured_encoded_iid_test[col]) for col in ('mtx_pp_sense', 'mtx_pp_antis'))

edge_info_ood_test_sense, edge_info_ood_test_antis = (
    get_edge_info_v2(df_structured_encoded_ood_test[col]) for col in ('mtx_pp_sense', 'mtx_pp_antis'))

edge_info_trvl_sense, edge_info_trvl_antis = (
    get_edge_info_v2(df_structured_encoded_iid_trvl[col]) for col in ('mtx_pp_sense', 'mtx_pp_antis'))

In [8]:
# "Full" adjacency matrix: every node pair (including self-pairs) is connected, all edges weighted 1

def get_edge_info_all1_v2(df_str_enc,sn_or_an):
    """Fully connected edge info (self-loops included) with all weights 1.0.

    Only the shape of each ``mtx_pp_<sn_or_an>`` matrix is used.

    Args:
        df_str_enc: DataFrame with a ``'mtx_pp_' + sn_or_an`` column of 2-D arrays.
        sn_or_an: strand suffix, e.g. 'sense' or 'antis'.

    Returns:
        dict with 'ls_edges' (list of (2, R*C) int64 tensors, row-major) and
        'ls_eweight' (list of float64 tensors of ones), one entry per row.
    """
    ls_edges, ls_eweight = [], []
    for np_mtx_pp in df_str_enc['mtx_pp_' + sn_or_an]:
        dense_ones = torch.tensor(np.ones(np_mtx_pp.shape))
        edge_index = dense_ones.nonzero().t()
        ls_edges.append(edge_index)
        ls_eweight.append(dense_ones[edge_index[0], edge_index[1]])
    return {'ls_edges': ls_edges, 'ls_eweight': ls_eweight}

# Fully connected graphs for every split and strand.
edge_info_all1_iid_test_sense, edge_info_all1_iid_test_antis = (
    get_edge_info_all1_v2(df_structured_encoded_iid_test, strand) for strand in ('sense', 'antis'))

edge_info_all1_ood_test_sense, edge_info_all1_ood_test_antis = (
    get_edge_info_all1_v2(df_structured_encoded_ood_test, strand) for strand in ('sense', 'antis'))

edge_info_all1_trvl_sense, edge_info_all1_trvl_antis = (
    get_edge_info_all1_v2(df_structured_encoded_iid_trvl, strand) for strand in ('sense', 'antis'))

In [9]:
# "Primary" adjacency matrix: backbone bonds and self-loops only (pair probabilities zeroed)

def _edge_info_primary_only(s_mtx_pp):
    """get_edge_info_v2 on all-zero pair matrices: only backbone and self-loop edges remain."""
    return get_edge_info_v2(s_mtx_pp.apply(lambda x: np.zeros_like(x)))

edge_info_str1_iid_test_sense = _edge_info_primary_only(df_structured_encoded_iid_test['mtx_pp_sense'])
edge_info_str1_iid_test_antis = _edge_info_primary_only(df_structured_encoded_iid_test['mtx_pp_antis'])

edge_info_str1_ood_test_sense = _edge_info_primary_only(df_structured_encoded_ood_test['mtx_pp_sense'])
edge_info_str1_ood_test_antis = _edge_info_primary_only(df_structured_encoded_ood_test['mtx_pp_antis'])

edge_info_str1_trvl_sense = _edge_info_primary_only(df_structured_encoded_iid_trvl['mtx_pp_sense'])
edge_info_str1_trvl_antis = _edge_info_primary_only(df_structured_encoded_iid_trvl['mtx_pp_antis'])

In [10]:
# "Null" adjacency matrix: no edges at all

edges_null = torch.tensor([[],[]])
eweights_null = torch.tensor([])

def _null_edge_info(n_graphs):
    """Edge info for `n_graphs` edgeless graphs; every entry aliases the shared empty tensors."""
    return {'ls_edges' : [edges_null] * n_graphs,
            'ls_eweight' : [eweights_null] * n_graphs}

edge_info_none_iid_test_sense = _null_edge_info(df_structured_encoded_iid_test.shape[0])
edge_info_none_iid_test_antis = _null_edge_info(df_structured_encoded_iid_test.shape[0])

edge_info_none_ood_test_sense = _null_edge_info(df_structured_encoded_ood_test.shape[0])
edge_info_none_ood_test_antis = _null_edge_info(df_structured_encoded_ood_test.shape[0])

edge_info_none_trvl_sense = _null_edge_info(df_structured_encoded_iid_trvl.shape[0])
edge_info_none_trvl_antis = _null_edge_info(df_structured_encoded_iid_trvl.shape[0])

### nodes

In [11]:
def get_nt_strtype_vec(dot_bracket,SEQ_MAX_LEN=28):
    """Encode each position of a dot-bracket string by its forgi element type.

    The forgi element string is mapped character-by-character through a
    fixed code table ('s'->0, 'h'->1, 'i'->2, 'm'->3, 'f'->4, 't'->5,
    'P'->7); any other character raises KeyError. Presumably these are
    forgi's stem/hairpin/interior/multiloop/5'/3' codes — confirm against
    the forgi documentation. ``SEQ_MAX_LEN`` is accepted but unused.

    Returns:
        1-D tensor with one integer code per character of the element string.
    """
    import forgi.graph.bulge_graph as fgb
    element_codes = {'P':7,'s':0,'h':1,'i':2,'m':3,'f':4,'t':5}
    element_string = fgb.BulgeGraph.from_dotbracket(dot_bracket).to_element_string()
    return torch.tensor([element_codes[ch] for ch in element_string])
    
def str_num_2_tnsr_int(num_str):
    """Convert a string of decimal digits (e.g. '0312') into a 1-D tensor of those digits."""
    return torch.tensor([int(digit) for digit in num_str])

def get_node_ftr_v2(df_structured_encoded):
    """Build per-sample node-feature matrices and labels.

    Returns a dict with:
      - 'label_tensor': list of float32 tensors of shape (1,), one per row,
        taken from ``mRNA_remaining_pct``;
      - 'ls_ftr_sense' / 'ls_ftr_antis': lists of (L, 3) tensors whose columns
        are [nucleotide code, modification code, structure-element code].

    NOTE(review): for the sense strand the nucleotide and modification codes
    are reversed (torch.flip) while the structure codes from ``dp_MEA_sense``
    (and the sense edges built from ``mtx_pp_sense``) are not. This is only
    consistent if the stored sense digit strings run in the opposite
    direction to the dot-bracket / pair matrix — confirm against the encoder.
    """
    def _digit_tensors(column, reverse):
        # digit string -> 1-D tensor, optionally reversed along positions
        tensors = df_structured_encoded[column].apply(str_num_2_tnsr_int)
        if reverse:
            tensors = tensors.apply(lambda t: torch.flip(t, dims=[0]))
        return tensors.to_list()

    def _stack_columns(per_channel):
        # per-sample (3, L) stack of channels -> (L, 3) node-feature matrix
        return [torch.transpose(torch.stack(channels), 0, 1) for channels in zip(*per_channel)]

    labels = torch.tensor(list(df_structured_encoded['mRNA_remaining_pct']))
    node_ftr = {'label_tensor': list(labels.reshape([len(labels), 1]).to(torch.float32))}
    node_ftr['ls_ftr_sense'] = _stack_columns([
        _digit_tensors('seq_agct_int_sense', reverse=True),
        _digit_tensors('seq_modi_int_sense', reverse=True),
        df_structured_encoded['dp_MEA_sense'].apply(get_nt_strtype_vec).to_list(),
    ])
    node_ftr['ls_ftr_antis'] = _stack_columns([
        _digit_tensors('seq_agct_int_anti', reverse=False),
        _digit_tensors('seq_modi_int_anti', reverse=False),
        df_structured_encoded['dp_MEA_antis'].apply(get_nt_strtype_vec).to_list(),
    ])
    return node_ftr

In [12]:
node_ftr_iid_test, node_ftr_ood_test, node_ftr_trvl = (
    get_node_ftr_v2(df_split) for df_split in (df_structured_encoded_iid_test,
                                              df_structured_encoded_ood_test,
                                              df_structured_encoded_iid_trvl))

## Transform data into graph

In [13]:
def get_graph_ls(ftrmtx,edgels,eweightls,g_index_ls):
    """Assemble one PyG ``Data`` object per sample.

    Example (train/validation pool, sense strand)::

        get_graph_ls(node_ftr_trvl['ls_ftr_sense'],
                     edge_info_trvl_sense['ls_edges'],
                     edge_info_trvl_sense['ls_eweight'],
                     list(df_structured_encoded_iid_trvl.index))

    Args:
        ftrmtx: list of node-feature tensors (``Data.x``).
        edgels: list of (2, E) edge-index tensors; cast to ``torch.long``.
        eweightls: list of (E,) edge-weight tensors (``Data.edge_attr``).
        g_index_ls: list of sample IDs stored as ``Data.g_id`` so batches can
            be re-aligned with the tabular data.

    Raises:
        Exception: if the four lists do not all have the same length (the
            weight list is checked too, so a short or long ``eweightls`` no
            longer surfaces as an IndexError or is silently truncated).
    """
    n_graphs = len(g_index_ls)
    if not (len(ftrmtx) == len(edgels) == len(eweightls) == n_graphs):
        raise Exception('Graph numbers are not match between ftrmtx, edgels, eweightls and g_index_ls')
    return [Data(x=x,
                 edge_index=edges.to(torch.long),
                 edge_attr=eweight,
                 g_id=g_id)
            for x, edges, eweight, g_id in zip(ftrmtx, edgels, eweightls, g_index_ls)]

#### Complete

In [14]:
# Complete graphs: edge info from the secondary-structure pair matrices.
TRVL_sense_glist, TRVL_antis_glist = (
    get_graph_ls(node_ftr_trvl['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_iid_trvl.index))
    for strand, info in (('sense', edge_info_trvl_sense), ('antis', edge_info_trvl_antis)))

TEST_IID_sense_glist, TEST_IID_antis_glist = (
    get_graph_ls(node_ftr_iid_test['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_iid_test.index))
    for strand, info in (('sense', edge_info_iid_test_sense), ('antis', edge_info_iid_test_antis)))

TEST_OOD_sense_glist, TEST_OOD_antis_glist = (
    get_graph_ls(node_ftr_ood_test['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_ood_test.index))
    for strand, info in (('sense', edge_info_ood_test_sense), ('antis', edge_info_ood_test_antis)))

#### Full

In [15]:
# Fully connected graphs.
TRVL_sense_glist_all1, TRVL_antis_glist_all1 = (
    get_graph_ls(node_ftr_trvl['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_iid_trvl.index))
    for strand, info in (('sense', edge_info_all1_trvl_sense), ('antis', edge_info_all1_trvl_antis)))

TEST_IID_sense_glist_all1, TEST_IID_antis_glist_all1 = (
    get_graph_ls(node_ftr_iid_test['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_iid_test.index))
    for strand, info in (('sense', edge_info_all1_iid_test_sense), ('antis', edge_info_all1_iid_test_antis)))

TEST_OOD_sense_glist_all1, TEST_OOD_antis_glist_all1 = (
    get_graph_ls(node_ftr_ood_test['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_ood_test.index))
    for strand, info in (('sense', edge_info_all1_ood_test_sense), ('antis', edge_info_all1_ood_test_antis)))

#### Primary

In [16]:
# Primary-structure-only graphs (backbone + self-loops).
TRVL_sense_glist_str1, TRVL_antis_glist_str1 = (
    get_graph_ls(node_ftr_trvl['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_iid_trvl.index))
    for strand, info in (('sense', edge_info_str1_trvl_sense), ('antis', edge_info_str1_trvl_antis)))

TEST_IID_sense_glist_str1, TEST_IID_antis_glist_str1 = (
    get_graph_ls(node_ftr_iid_test['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_iid_test.index))
    for strand, info in (('sense', edge_info_str1_iid_test_sense), ('antis', edge_info_str1_iid_test_antis)))

TEST_OOD_sense_glist_str1, TEST_OOD_antis_glist_str1 = (
    get_graph_ls(node_ftr_ood_test['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_ood_test.index))
    for strand, info in (('sense', edge_info_str1_ood_test_sense), ('antis', edge_info_str1_ood_test_antis)))

#### Null

In [17]:
# Edgeless graphs.
TRVL_sense_glist_none, TRVL_antis_glist_none = (
    get_graph_ls(node_ftr_trvl['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_iid_trvl.index))
    for strand, info in (('sense', edge_info_none_trvl_sense), ('antis', edge_info_none_trvl_antis)))

TEST_IID_sense_glist_none, TEST_IID_antis_glist_none = (
    get_graph_ls(node_ftr_iid_test['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_iid_test.index))
    for strand, info in (('sense', edge_info_none_iid_test_sense), ('antis', edge_info_none_iid_test_antis)))

TEST_OOD_sense_glist_none, TEST_OOD_antis_glist_none = (
    get_graph_ls(node_ftr_ood_test['ls_ftr_' + strand], info['ls_edges'], info['ls_eweight'],
                 list(df_structured_encoded_ood_test.index))
    for strand, info in (('sense', edge_info_none_ood_test_sense), ('antis', edge_info_none_ood_test_antis)))

### Experimental context

In [18]:
class NonGraph_dataset(Dataset):
    """Tabular (non-graph) view of the samples.

    Features are every column whose name contains a ``!word!`` token; the
    target is ``mRNA_remaining_pct``. Each item is ``(g_id, features, label)``
    where ``g_id`` is the DataFrame index label, used to align tabular batches
    with the graph batches. Features and labels are float32; labels have
    shape (1,) per item.
    """
    def __init__(self, df_structured_encoded):
        is_tabular = df_structured_encoded.columns.str.contains(r'!\w+!')
        tabular_values = df_structured_encoded.loc[:, is_tabular].values
        targets = torch.tensor(list(df_structured_encoded['mRNA_remaining_pct']))
        self.g_id = list(df_structured_encoded.index)
        self.features_tensor = torch.tensor(tabular_values).to(torch.float32)
        self.label_tensor = targets.reshape([len(targets), 1]).to(torch.float32)

    def __len__(self):
        return len(self.g_id)

    def __getitem__(self, index):
        return (self.g_id[index], self.features_tensor[index], self.label_tensor[index])

TRVL_NonGraph_dataset, TEST_OOD_NonGraph_dataset, TEST_IID_NonGraph_dataset = map(
    NonGraph_dataset,
    (df_structured_encoded_iid_trvl, df_structured_encoded_ood_test, df_structured_encoded_iid_test))

## Model Design

### GINE

In [19]:
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding.

    Precomputes a buffer ``pe`` of shape (1, max_len, d_model) with sine
    values on even channels and cosine values on odd channels. ``forward``
    adds its first ``x.size(1)`` positions to ``x``. Odd ``d_model`` is
    supported (it has one fewer cosine channel than sine channels).
    Not used by modisiR_GINE in this notebook section.
    """
    def __init__(self, d_model=48, max_len=28):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # 1::2 has d_model // 2 columns, which is one fewer than div_term for odd d_model.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0)) # shape (1, max_len, d_model)

    def forward(self, x):
        # assumes x is (batch, seq_len, d_model) with seq_len <= max_len
        return x + self.pe[:, :x.size(1)]

class Embeddings(nn.Module):
    """Per-node embedding of (nucleotide, modification) codes.

    Input: (num_nodes, >=2) tensor whose column 0 is the nucleotide code and
    column 1 the modification code (each in 0..7). The two embeddings are
    summed and projected embed_dim -> 32 -> 16 with two linear layers (no
    activation in between). Output: (num_nodes, 16).

    ``max_len`` is accepted for interface compatibility but unused.
    NOTE(review): column 2 of the node features (secondary-structure element
    code) is not embedded — confirm this ablation is intended.
    """
    def __init__(self, embed_dim=16, max_len=28):
        super().__init__()
        self.seq_embed = nn.Embedding(8, embed_dim)
        self.modi_embed = nn.Embedding(8, embed_dim)
        self.combine_fc = nn.Linear(embed_dim, 32)
        self.output_fc = nn.Linear(32,16)

    def forward(self, seq_modi_str):
        # Callers pass float features; embedding lookups need integer indices.
        codes = seq_modi_str.int()
        combine_emb = self.seq_embed(codes[:, 0]) + self.modi_embed(codes[:, 1])
        features = self.combine_fc(combine_emb)
        return self.output_fc(features)

# GINE as instantiated in modisiR_GINE: node_xlen=16 (the Embeddings output
# width), edge_xlen=1 (scalar edge weight), dim_h=64.

class GINE(nn.Module):
    """Three stacked GINEConv layers with per-layer edge projections.

    Each layer receives its own linear projection of the edge features (to
    node_xlen for layer 1, dim_h for layers 2 and 3). After every layer the
    node embeddings are sum-pooled per graph; the three pooled vectors are
    concatenated into a (num_graphs, 3 * dim_h) output.
    """
    @staticmethod
    def _gin_mlp(in_dim, dim_h):
        """Linear-BatchNorm-ReLU-Linear-ReLU MLP used inside each GINEConv."""
        return nn.Sequential(nn.Linear(in_dim, dim_h),
                             nn.BatchNorm1d(dim_h),
                             nn.ReLU(),
                             nn.Linear(dim_h, dim_h),
                             nn.ReLU())

    def __init__(self,node_xlen,edge_xlen,dim_h):
        super().__init__()
        self.edge_lins = nn.ModuleList([nn.Linear(edge_xlen, width)
                                        for width in (node_xlen, dim_h, dim_h)])
        self.conv1 = GINEConv(self._gin_mlp(node_xlen, dim_h))
        self.conv2 = GINEConv(self._gin_mlp(dim_h, dim_h))
        self.conv3 = GINEConv(self._gin_mlp(dim_h, dim_h))

    def forward(self,node_x,edge_index,edge_x,batch):
        pooled = []
        h = node_x
        for edge_lin, conv in zip(self.edge_lins, (self.conv1, self.conv2, self.conv3)):
            h = conv(h, edge_index, edge_lin(edge_x))
            pooled.append(global_add_pool(h, batch))
        return torch.cat(pooled, dim=1)

# TfxMLP: 109 tabular input features -> 192-dim output (512 -> 256 hidden units).

class TfxMLP(nn.Module):
    """MLP encoder for the tabular (experimental-context) features.

    (batch, in_dim) -> 512 -> 256 -> (batch, out_dim), with ReLU after every
    layer, including the last. The defaults (109 -> 192) reproduce the
    original hard-coded sizes used by modisiR_GINE.
    """
    def __init__(self, in_dim=109, out_dim=192):
        super().__init__()
        self.dense1 = nn.Linear(in_dim,512)
        self.dense2 = nn.Linear(512,256)
        self.dense3 = nn.Linear(256,out_dim)

    def forward(self,x):
        x = F.relu(self.dense1(x))
        x = F.relu(self.dense2(x))
        x = F.relu(self.dense3(x))
        return x.view(x.size(0),-1)

class CombineMLP(nn.Module):
    """Regression head over the concatenated strand and tabular embeddings.

    (batch, in_dim) -> 256 -> 128 -> 64 -> dropout -> (batch, 1), with ReLU
    after each hidden layer. The defaults reproduce the original hard-coded
    sizes (576 = two 192-dim GINE outputs + the 192-dim TfxMLP output in
    modisiR_GINE, dropout p=0.5).
    """
    def __init__(self, in_dim=576, dropout=0.5):
        super().__init__()
        self.dense1 = nn.Linear(in_dim,256)
        self.dense2 = nn.Linear(256,128)
        self.dense3 = nn.Linear(128,64)
        self.linear_out = nn.Linear(64,1)
        # Plain float attribute (not a submodule), so state_dict keys are unchanged.
        self.dropout_p = dropout

    def forward(self,x):
        x = F.relu(self.dense1(x))
        x = F.relu(self.dense2(x))
        x = F.relu(self.dense3(x))
        x = F.dropout(x,p=self.dropout_p,training=self.training)
        x = self.linear_out(x)
        return x.view(x.size(0),-1)

class modisiR_GINE(nn.Module):
    """Full model: sense and antisense graph encoders plus a tabular MLP.

    Each strand is embedded (Embeddings -> 16 dims per node) and encoded by
    its own GINE (3 pooled layers x 64 = 192 dims per graph). The tabular
    features go through TfxMLP (192 dims). The concatenation (576 dims) is fed
    to CombineMLP, which returns one prediction per sample, shape (batch, 1).
    """
    def __init__(self):
        super(modisiR_GINE,self).__init__()
        self.embed_sense = Embeddings(embed_dim=16, max_len=28)
        self.embed_antis = Embeddings(embed_dim=16, max_len=28)
        self.gine_sense = GINE(node_xlen=16,edge_xlen=1,dim_h=64)
        self.gine_antis = GINE(node_xlen=16,edge_xlen=1,dim_h=64)
        self.tfxmlp = TfxMLP()
        self.combinemlp = CombineMLP()

    @staticmethod
    def _encode_strand(gbatch, embed, gine):
        """Embed one strand's node codes and run its GINE; returns pooled graph features."""
        node_x = embed(gbatch.x.to(torch.float32))
        # Scalar edge weights -> (num_edges, 1) features for the edge projections.
        edge_attr = torch.unsqueeze(gbatch.edge_attr,1).to(torch.float32)
        return gine(node_x, gbatch.edge_index, edge_attr, gbatch.batch)

    def forward(self,sense_gbatch,antis_gbatch,nongrf_batch):
        """
        Args:
            sense_gbatch / antis_gbatch: PyG batches for the two strands.
            nongrf_batch: (batch, 109) tensor of tabular features (the
                features element of a NonGraph_dataset batch).
        """
        sense_graph_h = self._encode_strand(sense_gbatch, self.embed_sense, self.gine_sense)
        antis_graph_h = self._encode_strand(antis_gbatch, self.embed_antis, self.gine_antis)
        # nongrf_batch is already the feature tensor; the former unused
        # `nongrf_batch[1]` lookup raised IndexError on a batch of size 1.
        tfxftr_h = self.tfxmlp(nongrf_batch)

        graph_tfx = torch.cat((sense_graph_h,antis_graph_h,tfxftr_h),dim=1)
        pred = self.combinemlp(graph_tfx)
        return pred.reshape([len(pred),1])

### train_func

In [20]:
def train_GINE_ERM(dataload_TRAIN_zip,model,optimizer,criterion):
    """Run one epoch of plain empirical-risk-minimisation training.

    Args:
        dataload_TRAIN_zip: list of (sense_batch, antis_batch, tabular_batch)
            triples; tabular_batch is (g_id, features, labels).
        model: called as model(sense_batch, antis_batch, features).
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss(prediction, label).

    Returns:
        Mean of the per-batch losses (not weighted by batch size).
    """
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(DEVICE)
    model.train()
    loss_train = 0

    for data_triple in dataload_TRAIN_zip:
        sense_gbatch = data_triple[0].to(DEVICE)
        antis_gbatch = data_triple[1].to(DEVICE)
        nongrf_batch = data_triple[2][1].to(DEVICE)
        # Labels must live on the same device as the predictions, otherwise the
        # loss fails with a device mismatch as soon as CUDA is used.
        y_batch_lbl = data_triple[2][2].to(device=DEVICE, dtype=torch.float32)

        y_batch_pred = model(sense_gbatch,antis_gbatch,nongrf_batch)
        loss_batch = criterion(y_batch_pred, y_batch_lbl)

        optimizer.zero_grad()
        loss_batch.backward()
        optimizer.step()

        loss_train += loss_batch.item()
    return loss_train/len(dataload_TRAIN_zip)

In [21]:
def calculate_metrics(y_pred, y_true, threshold=30):
    """Composite score of overall MAE, in-range MAE and F1 (higher is better).

    Predictions and labels are clipped to [0, 100]. A sample is "positive"
    when its value is below ``threshold``. ``range_mae`` is the MAE over
    samples whose clipped prediction is <= ``threshold`` (100 if none).

        score = 0.5 * (1 - MAE/100) + 0.5 * (1 - range_MAE/100) * F1

    F1 is defined as 0 when precision + recall == 0 (e.g. no positive is
    predicted), instead of dividing by zero.

    Args:
        y_pred, y_true: numpy arrays of the same shape.
        threshold: positive-class cut-off.
    """
    import warnings

    # Silence metric warnings (e.g. undefined precision) only inside this
    # block; catch_warnings restores the caller's filters afterwards.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        y_true = y_true.clip(0,100)
        y_pred = y_pred.clip(0,100)

        mae = np.mean(np.abs(y_true - y_pred))

        y_true_binary = (y_true < threshold).astype(int)
        y_pred_binary = (y_pred < threshold).astype(int)

        # Note: inclusive (<=) here, while the binary labels use strict <.
        mask = (y_pred >= 0) & (y_pred <= threshold)
        range_mae = mean_absolute_error(y_true[mask], y_pred[mask]) if mask.sum() > 0 else 100

        precision = precision_score(y_true_binary, y_pred_binary, average='binary')
        recall = recall_score(y_true_binary, y_pred_binary, average='binary')

    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0

    score = (1 - mae / 100) * 0.5 + (1 - range_mae / 100) * f1 * 0.5
    return score

In [22]:
def validate_GINE(dataload_VAL_zip,model,criterion):
    """Evaluate ``model`` without gradient tracking.

    Args:
        dataload_VAL_zip: list of (sense_batch, antis_batch, tabular_batch)
            triples; tabular_batch is (g_id, features, labels).
        model: called as model(sense_batch, antis_batch, features).
        criterion: loss(prediction, label).

    Returns:
        (mean per-batch loss, calculate_metrics score over all samples).
    """
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(DEVICE)
    model.eval()
    loss_val = 0
    y_val_lbl = []
    y_val_pred = []
    with torch.no_grad():
        for data_triple in dataload_VAL_zip:
            sense_gbatch = data_triple[0].to(DEVICE)
            antis_gbatch = data_triple[1].to(DEVICE)
            nongrf_batch = data_triple[2][1].to(DEVICE)
            # Same device as the predictions, or the loss fails on CUDA.
            y_batch_lbl = data_triple[2][2].to(device=DEVICE, dtype=torch.float32)

            y_batch_pred = model(sense_gbatch,antis_gbatch,nongrf_batch)
            loss_batch = criterion(y_batch_pred,y_batch_lbl)

            loss_val += loss_batch.item()
            y_val_lbl.extend(y_batch_lbl.cpu().numpy())
            y_val_pred.extend(y_batch_pred.cpu().numpy())

    y_val_pred = np.array(y_val_pred)
    y_val_lbl = np.array(y_val_lbl)
    model_score = calculate_metrics(y_val_pred, y_val_lbl)
    return loss_val/len(dataload_VAL_zip),model_score

---

## Training

In [23]:
def id_consistent_check(graphload_sense,graphload_antis,dataload_tab):
    """Verify that three unshuffled loaders yield the same sample IDs per batch.

    The graph loaders' batches expose ``g_id``; the tabular loader's batches
    are (g_id, features, labels) sequences whose element 0 holds the IDs.

    Raises:
        Exception: if the loaders yield different numbers of batches.
        AssertionError: if any batch's IDs differ (values or count) between
            loaders. Raised explicitly, so the check also runs under ``python -O``.
    """
    sense_gid = [b.g_id for b in graphload_sense]
    antis_gid = [b.g_id for b in graphload_antis]
    tabular_id = [b[0] for b in dataload_tab]

    if len(sense_gid)!=len(antis_gid) or len(tabular_id)!=len(antis_gid):
        raise Exception('Error: Sample numbers are not match between dataloads!')
    print('ID consistent check between dataloads')

    def _same_ids(left, right):
        # Shape check first: `==` on differently sized tensors would raise a
        # broadcasting RuntimeError instead of reporting a mismatch.
        left, right = torch.as_tensor(left), torch.as_tensor(right)
        return left.shape == right.shape and bool((left == right).all())

    for i, (s_ids, a_ids, t_ids) in enumerate(zip(sense_gid, antis_gid, tabular_id)):
        print(i,end='\r')
        if not _same_ids(s_ids, a_ids):
            raise AssertionError(f'Batch {i}: sense and antisense graph IDs differ')
        if not _same_ids(t_ids, a_ids):
            raise AssertionError(f'Batch {i}: tabular and graph IDs differ')
    return

def _fold_batches(glist_sense, glist_antis, tabular_dataset, indices, batch_size):
    """Return zipped (sense_batch, antis_batch, tabular_batch) triples for one CV split.

    The three loaders are unshuffled and aligned only by position, so
    id_consistent_check verifies the alignment before zipping.
    """
    graphload_sense = PyG_DataLoader([glist_sense[q] for q in indices], batch_size=batch_size, shuffle=False)
    graphload_antis = PyG_DataLoader([glist_antis[q] for q in indices], batch_size=batch_size, shuffle=False)
    dataload_tab = DataLoader(dataset=Subset(tabular_dataset, indices), batch_size=batch_size, shuffle=False)
    id_consistent_check(graphload_sense, graphload_antis, dataload_tab)
    return list(zip(graphload_sense, graphload_antis, dataload_tab))


def _cv_log_to_frame(cv_log):
    """Combine per-fold epoch logs into one DataFrame with ('Model_<i>', field) columns."""
    fields = ['epoch','loss_train','loss_val','val_score','iid_score','ood_score','lr']
    frames = []
    for i, log_train in enumerate(cv_log):
        df_log = pd.DataFrame(log_train)
        df_log.columns = pd.MultiIndex.from_product([['Model_'+str(i)], fields])
        frames.append(df_log)
    # Folds may stop at different epochs; concat aligns rows by position and
    # leaves NaN where a fold has no entry.
    return pd.concat(frames, axis=1) if frames else pd.DataFrame()


def cv_train_GINE(model_type:str,TRVL_NonGraph_dataset,TRVL_sense_glist,TRVL_antis_glist,
                  dataload_TEST_IID_zip,dataload_TEST_OOD_zip):
    """10-fold cross-validated training of modisiR_GINE.

    For each unshuffled KFold split a fresh model is trained for up to EPOCHS
    epochs (batches in a fixed order every epoch). After the warm-up, the
    learning rate is halved whenever validation loss has not improved for
    ``loss_tolerance_epoch_num`` epochs. The state dict with the best
    validation score is kept per fold; a fold stops early once that score
    reaches ``early_stop_score``. IID/OOD test scores are logged every epoch.

    Returns:
        (model_list, df_cv_log): best state dict per fold, and a DataFrame of
        per-epoch logs with columns ('Model_<i>', field), where 'lr' is the
        learning rate used for that epoch's training.

    Raises:
        NotImplementedError: for model_type 'vrex'.
        ValueError: for any model_type other than 'erm' or 'vrex'.
    """
    # Fail fast on unsupported modes instead of crashing mid-training.
    if model_type == 'vrex':
        raise NotImplementedError('model_type "vrex" is not available yet')
    if model_type != 'erm':
        raise ValueError('No such model type. Type should be erm or vrex.')

    lr = 0.002
    EPOCHS = 80
    BATCH_SIZE = 256

    early_stop_score = 1
    warm_up_epoch_num = 10
    loss_tolerance_epoch_num = 5

    model_list = []
    cv_log = []

    kfold = KFold(n_splits=10,shuffle=False)
    splits = kfold.split(list(range(len(TRVL_NonGraph_dataset))))

    for train_index, val_index in splits:
        dataload_TRAIN_zip = _fold_batches(TRVL_sense_glist, TRVL_antis_glist,
                                           TRVL_NonGraph_dataset, train_index, BATCH_SIZE)
        dataload_VAL_zip = _fold_batches(TRVL_sense_glist, TRVL_antis_glist,
                                         TRVL_NonGraph_dataset, val_index, BATCH_SIZE)

        model = modisiR_GINE()
        lowest_loss_epoch = {'loss_val':float("inf"),'epoch':0}
        best_score = -float('inf')
        best_OOD = -float('inf')
        # Fallback so a fold whose validation score never improves (e.g. NaN)
        # still contributes a state dict instead of raising NameError.
        best_model = copy.deepcopy(model.state_dict())
        optimizer = optim.AdamW([{'params':model.parameters(),'lr':lr}])
        criterion = nn.MSELoss(reduction='mean')

        log_train = []

        for epoch in range(EPOCHS):
            epoch_lr = optimizer.param_groups[0]['lr']
            loss_train = train_GINE_ERM(dataload_TRAIN_zip,model,optimizer,criterion)
            loss_val,model_score = validate_GINE(dataload_VAL_zip,model,criterion)

            _,test_score_iid = validate_GINE(dataload_TEST_IID_zip,model,criterion)
            _,test_score_ood = validate_GINE(dataload_TEST_OOD_zip,model,criterion)

            if epoch > warm_up_epoch_num:
                if loss_val < lowest_loss_epoch['loss_val']:
                    lowest_loss_epoch['epoch'] = epoch
                    lowest_loss_epoch['loss_val'] = loss_val
                elif (epoch-lowest_loss_epoch['epoch']) >= loss_tolerance_epoch_num:
                    # Restart the patience window after each decay.
                    lowest_loss_epoch['epoch'] = epoch
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= 0.5

            log_train.append((epoch,loss_train,loss_val,model_score,test_score_iid,test_score_ood,epoch_lr))

            if model_score > best_score:
                best_score = model_score
                best_OOD = test_score_ood
                best_model = copy.deepcopy(model.state_dict())

            print(f'\r{epoch}\t{model_score:.4f}\t{test_score_iid:.4f}\t{test_score_ood:.4f}\t({best_score:.4f},{best_OOD:.4f})',sep='',end='')

            if best_score >= early_stop_score:
                break
        model_list.append(best_model)
        cv_log.append(log_train)
        print('')

    df_cv_log = _cv_log_to_frame(cv_log)
    return model_list,df_cv_log

def cvmodel_test(model_list,dataload_TEST_zip):
    """Evaluate every cross-validation model on one zipped test loader.

    Each state dict in ``model_list`` is loaded in turn into a single
    ``modisiR_GINE()`` instance, switched to eval mode and scored with
    ``validate_GINE`` using a mean-reduced MSE loss. Each fold's score is
    printed as it is produced.

    Returns the list of per-model scores, in ``model_list`` order.
    """
    evaluator = modisiR_GINE()
    mse_criterion = nn.MSELoss(reduction='mean')
    fold_scores = []
    for fold_state in model_list:
        evaluator.load_state_dict(fold_state)
        evaluator.eval()
        # validate_GINE returns (loss, score); only the score is kept here
        _, fold_score = validate_GINE(dataload_TEST_zip, evaluator, mse_criterion)
        print(fold_score)
        fold_scores.append(fold_score)
    return fold_scores

In [24]:
def save_state_dict_2cpu(PATH_SAVE,model_list,model):
    """Save every state dict in ``model_list`` as a CPU checkpoint file.

    Each state dict is loaded into ``model``, the module is moved to the CPU,
    and its state dict is written to
    ``PATH_SAVE + 'models/state_dict_cpu_<i>.pth'``. The device of the first
    tensor is printed before and after the move.

    Parameters
    ----------
    PATH_SAVE : str
        Output prefix. It is concatenated as a plain string, so it should end
        with a path separator.
    model_list : list of state dicts
        One state dict per model (e.g. per CV fold).
    model : torch.nn.Module
        A module whose architecture matches the state dicts. Note that it is
        mutated: it ends up holding the last state dict, on the CPU.
    """
    # exist_ok makes the cell safe to re-run (os.mkdir raised FileExistsError);
    # checkpoints with the same names are simply overwritten.
    os.makedirs(PATH_SAVE+'models', exist_ok=True)
    for i, state_dict in enumerate(model_list):
        print(i,list(state_dict.values())[0].device,end='\t')
        model.load_state_dict(state_dict)
        model.to('cpu')
        print('to',list(model.state_dict().values())[0].device)
        torch.save(model.state_dict(), PATH_SAVE+'models/state_dict_cpu_'+str(i)+'.pth')

### train_ERM: Complete

In [25]:
# Test loaders for the IID and OOD splits. The sense graphs, antisense graphs and
# tabular rows of a split are zipped batch by batch, so all three loaders use the
# same batch size and shuffle=False to keep the streams aligned by position.
dataset_TEST_IID = {'sense_graph':TEST_IID_sense_glist,
                 'antis_graph':TEST_IID_antis_glist,
                 'tabular':TEST_IID_NonGraph_dataset}
# zip() stops at the shortest iterable, so a length mismatch would silently drop
# test samples from evaluation; fail loudly instead.
assert len(dataset_TEST_IID['sense_graph']) == len(dataset_TEST_IID['antis_graph']) == len(dataset_TEST_IID['tabular']), \
    'IID test split: sense graphs, antisense graphs and tabular data differ in length'

graphload_TEST_IID_sense = PyG_DataLoader(dataset_TEST_IID['sense_graph'],batch_size=BATCH_SIZE,shuffle=False)
graphload_TEST_IID_antis = PyG_DataLoader(dataset_TEST_IID['antis_graph'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_IID = DataLoader(dataset=dataset_TEST_IID['tabular'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_IID_zip = list(zip(graphload_TEST_IID_sense,graphload_TEST_IID_antis,dataload_TEST_IID))

dataset_TEST_OOD = {'sense_graph':TEST_OOD_sense_glist,
                 'antis_graph':TEST_OOD_antis_glist,
                 'tabular':TEST_OOD_NonGraph_dataset}
assert len(dataset_TEST_OOD['sense_graph']) == len(dataset_TEST_OOD['antis_graph']) == len(dataset_TEST_OOD['tabular']), \
    'OOD test split: sense graphs, antisense graphs and tabular data differ in length'

graphload_TEST_OOD_sense = PyG_DataLoader(dataset_TEST_OOD['sense_graph'],batch_size=BATCH_SIZE,shuffle=False)
graphload_TEST_OOD_antis = PyG_DataLoader(dataset_TEST_OOD['antis_graph'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_OOD = DataLoader(dataset=dataset_TEST_OOD['tabular'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_OOD_zip = list(zip(graphload_TEST_OOD_sense,graphload_TEST_OOD_antis,dataload_TEST_OOD))

In [26]:
# Cross-validated ERM training; yields one best state dict per fold plus the per-epoch log.
ERM_model_list, ERM_cv_log = cv_train_GINE(
    'erm',
    TRVL_NonGraph_dataset,
    TRVL_sense_glist,
    TRVL_antis_glist,
    dataload_TEST_IID_zip,
    dataload_TEST_OOD_zip,
)

ID consistent check between dataloads
ID consistent check between dataloads
79	0.8159	0.8165	0.5521	(0.8172,0.5565)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8110	0.8106	0.5605	(0.8120,0.5621)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8217	0.8106	0.5336	(0.8234,0.5346)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8190	0.8150	0.5804	(0.8195,0.5738)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8157	0.8201	0.5642	(0.8157,0.5642)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8095	0.8221	0.5613	(0.8095,0.5613)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8251	0.8144	0.5695	(0.8291,0.5682)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8205	0.8169	0.5745	(0.8205,0.5745)
ID consistent check between dataloads
ID consistent check between datalo

In [27]:
# Per-fold test scores (printed by cvmodel_test), followed by the fold mean.
IID_test_scores = cvmodel_test(ERM_model_list, dataload_TEST_IID_zip)
print('IID:', np.array(IID_test_scores).mean(), '\n')
OOD_test_scores = cvmodel_test(ERM_model_list, dataload_TEST_OOD_zip)
print('OOD:', np.array(OOD_test_scores).mean(), '\n')

0.8181822166711684
0.8111163987835635
0.8095424122671009
0.8170279305688393
0.8201071098627466
0.8221191507577896
0.8121104845781938
0.8168754730530697
0.8196733327071575
0.8078007415303551
IID: 0.8154555250779986 

0.5564907883746284
0.5620661156098827
0.5345839387585928
0.5737923000405409
0.5641624794189506
0.5613489875489689
0.5682139381042308
0.5745393467561577
0.5806921317118591
0.6041925777772248
OOD: 0.5680082604101037 



In [28]:
PATH_SAVE = './Models_out/GINE-ERM-80epo-0708v3_2/'
cv_log = ERM_cv_log
mdl_list = ERM_model_list

# Output directory for this run's checkpoints and CV log.
os.makedirs(PATH_SAVE, exist_ok=True)

# Write every fold's weights as CPU tensors, loading them through a freshly built model.
mdlsave = modisiR_GINE()
save_state_dict_2cpu(PATH_SAVE, mdl_list, mdlsave)

# Per-epoch training log of all folds.
cv_log.to_pickle(PATH_SAVE+'/df_cv_log.pickle')

0 cpu	to cpu
1 cpu	to cpu
2 cpu	to cpu
3 cpu	to cpu
4 cpu	to cpu
5 cpu	to cpu
6 cpu	to cpu
7 cpu	to cpu
8 cpu	to cpu
9 cpu	to cpu


### train_ERM: Full

In [29]:
# Test loaders for the IID and OOD splits built from the *_all1 graph lists.
# The three streams of a split are zipped batch by batch, so all loaders use the
# same batch size and shuffle=False to keep them aligned by position.
dataset_TEST_IID_all1 = {'sense_graph':TEST_IID_sense_glist_all1,
                 'antis_graph':TEST_IID_antis_glist_all1,
                 'tabular':TEST_IID_NonGraph_dataset}
# zip() stops at the shortest iterable, so a length mismatch would silently drop
# test samples from evaluation; fail loudly instead.
assert len(dataset_TEST_IID_all1['sense_graph']) == len(dataset_TEST_IID_all1['antis_graph']) == len(dataset_TEST_IID_all1['tabular']), \
    'IID test split (all1): sense graphs, antisense graphs and tabular data differ in length'

graphload_TEST_IID_sense_all1 = PyG_DataLoader(dataset_TEST_IID_all1['sense_graph'],batch_size=BATCH_SIZE,shuffle=False)
graphload_TEST_IID_antis_all1 = PyG_DataLoader(dataset_TEST_IID_all1['antis_graph'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_IID_all1 = DataLoader(dataset=dataset_TEST_IID_all1['tabular'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_IID_zip_all1 = list(zip(graphload_TEST_IID_sense_all1,graphload_TEST_IID_antis_all1,dataload_TEST_IID_all1))

dataset_TEST_OOD_all1 = {'sense_graph':TEST_OOD_sense_glist_all1,
                 'antis_graph':TEST_OOD_antis_glist_all1,
                 'tabular':TEST_OOD_NonGraph_dataset}
assert len(dataset_TEST_OOD_all1['sense_graph']) == len(dataset_TEST_OOD_all1['antis_graph']) == len(dataset_TEST_OOD_all1['tabular']), \
    'OOD test split (all1): sense graphs, antisense graphs and tabular data differ in length'

graphload_TEST_OOD_sense_all1 = PyG_DataLoader(dataset_TEST_OOD_all1['sense_graph'],batch_size=BATCH_SIZE,shuffle=False)
graphload_TEST_OOD_antis_all1 = PyG_DataLoader(dataset_TEST_OOD_all1['antis_graph'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_OOD_all1 = DataLoader(dataset=dataset_TEST_OOD_all1['tabular'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_OOD_zip_all1 = list(zip(graphload_TEST_OOD_sense_all1,graphload_TEST_OOD_antis_all1,dataload_TEST_OOD_all1))

In [30]:
# Cross-validated ERM training on the *_all1 graphs.
ERM_model_list_all1, ERM_cv_log_all1 = cv_train_GINE(
    'erm',
    TRVL_NonGraph_dataset,
    TRVL_sense_glist_all1,
    TRVL_antis_glist_all1,
    dataload_TEST_IID_zip_all1,
    dataload_TEST_OOD_zip_all1,
)

ID consistent check between dataloads
ID consistent check between dataloads
79	0.7934	0.7907	0.5329	(0.7969,0.5337)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7997	0.7955	0.5551	(0.8027,0.5561)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8028	0.8000	0.5634	(0.8057,0.5757)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7993	0.8016	0.5792	(0.8037,0.5812)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7955	0.7942	0.5279	(0.7976,0.5276)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7722	0.7919	0.5241	(0.7740,0.5242)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7935	0.7886	0.5576	(0.7962,0.5568)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7929	0.7959	0.5431	(0.7981,0.5457)
ID consistent check between dataloads
ID consistent check between datalo

In [31]:
PATH_SAVE = './Models_out/GINE_all1-ERM-80epo-0708v3_2/'
cv_log = ERM_cv_log_all1
mdl_list = ERM_model_list_all1

# Output directory for this run's checkpoints and CV log.
os.makedirs(PATH_SAVE, exist_ok=True)

# Write every fold's weights as CPU tensors, loading them through a freshly built model.
mdlsave = modisiR_GINE()
save_state_dict_2cpu(PATH_SAVE, mdl_list, mdlsave)

# Per-epoch training log of all folds.
cv_log.to_pickle(PATH_SAVE+'/df_cv_log.pickle')

0 cpu	to cpu
1 cpu	to cpu
2 cpu	to cpu
3 cpu	to cpu
4 cpu	to cpu
5 cpu	to cpu
6 cpu	to cpu
7 cpu	to cpu
8 cpu	to cpu
9 cpu	to cpu


### train_ERM: Null

In [32]:
# Test loaders for the IID and OOD splits built from the *_none graph lists.
# The three streams of a split are zipped batch by batch, so all loaders use the
# same batch size and shuffle=False to keep them aligned by position.
dataset_TEST_IID_none = {'sense_graph':TEST_IID_sense_glist_none,
                 'antis_graph':TEST_IID_antis_glist_none,
                 'tabular':TEST_IID_NonGraph_dataset}
# zip() stops at the shortest iterable, so a length mismatch would silently drop
# test samples from evaluation; fail loudly instead.
assert len(dataset_TEST_IID_none['sense_graph']) == len(dataset_TEST_IID_none['antis_graph']) == len(dataset_TEST_IID_none['tabular']), \
    'IID test split (none): sense graphs, antisense graphs and tabular data differ in length'

graphload_TEST_IID_sense_none = PyG_DataLoader(dataset_TEST_IID_none['sense_graph'],batch_size=BATCH_SIZE,shuffle=False)
graphload_TEST_IID_antis_none = PyG_DataLoader(dataset_TEST_IID_none['antis_graph'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_IID_none = DataLoader(dataset=dataset_TEST_IID_none['tabular'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_IID_zip_none = list(zip(graphload_TEST_IID_sense_none,graphload_TEST_IID_antis_none,dataload_TEST_IID_none))

dataset_TEST_OOD_none = {'sense_graph':TEST_OOD_sense_glist_none,
                 'antis_graph':TEST_OOD_antis_glist_none,
                 'tabular':TEST_OOD_NonGraph_dataset}
assert len(dataset_TEST_OOD_none['sense_graph']) == len(dataset_TEST_OOD_none['antis_graph']) == len(dataset_TEST_OOD_none['tabular']), \
    'OOD test split (none): sense graphs, antisense graphs and tabular data differ in length'

graphload_TEST_OOD_sense_none = PyG_DataLoader(dataset_TEST_OOD_none['sense_graph'],batch_size=BATCH_SIZE,shuffle=False)
graphload_TEST_OOD_antis_none = PyG_DataLoader(dataset_TEST_OOD_none['antis_graph'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_OOD_none = DataLoader(dataset=dataset_TEST_OOD_none['tabular'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_OOD_zip_none = list(zip(graphload_TEST_OOD_sense_none,graphload_TEST_OOD_antis_none,dataload_TEST_OOD_none))

In [33]:
# Cross-validated ERM training on the *_none graphs.
ERM_model_list_none, ERM_cv_log_none = cv_train_GINE(
    'erm',
    TRVL_NonGraph_dataset,
    TRVL_sense_glist_none,
    TRVL_antis_glist_none,
    dataload_TEST_IID_zip_none,
    dataload_TEST_OOD_zip_none,
)

ID consistent check between dataloads
ID consistent check between dataloads
79	0.7621	0.7715	0.5213	(0.7654,0.5225)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7805	0.7644	0.5643	(0.7851,0.5598)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7815	0.7687	0.5363	(0.7815,0.5363)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7681	0.7782	0.5739	(0.7692,0.5769)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7572	0.7635	0.5859	(0.7587,0.5861)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7446	0.7770	0.5231	(0.7455,0.5241)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7676	0.7630	0.5699	(0.7676,0.5699)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7644	0.7611	0.5962	(0.7728,0.5928)
ID consistent check between dataloads
ID consistent check between datalo

In [34]:
PATH_SAVE = './Models_out/GINE_none-ERM-80epo-0708v3_2/'
cv_log = ERM_cv_log_none
mdl_list = ERM_model_list_none

# Output directory for this run's checkpoints and CV log.
os.makedirs(PATH_SAVE, exist_ok=True)

# Write every fold's weights as CPU tensors, loading them through a freshly built model.
mdlsave = modisiR_GINE()
save_state_dict_2cpu(PATH_SAVE, mdl_list, mdlsave)

# Per-epoch training log of all folds.
cv_log.to_pickle(PATH_SAVE+'/df_cv_log.pickle')

0 cpu	to cpu
1 cpu	to cpu
2 cpu	to cpu
3 cpu	to cpu
4 cpu	to cpu
5 cpu	to cpu
6 cpu	to cpu
7 cpu	to cpu
8 cpu	to cpu
9 cpu	to cpu


### train_ERM: Primary

In [35]:
# Test loaders for the IID and OOD splits built from the *_str1 graph lists.
# The three streams of a split are zipped batch by batch, so all loaders use the
# same batch size and shuffle=False to keep them aligned by position.
dataset_TEST_IID_str1 = {'sense_graph':TEST_IID_sense_glist_str1,
                 'antis_graph':TEST_IID_antis_glist_str1,
                 'tabular':TEST_IID_NonGraph_dataset}
# zip() stops at the shortest iterable, so a length mismatch would silently drop
# test samples from evaluation; fail loudly instead.
assert len(dataset_TEST_IID_str1['sense_graph']) == len(dataset_TEST_IID_str1['antis_graph']) == len(dataset_TEST_IID_str1['tabular']), \
    'IID test split (str1): sense graphs, antisense graphs and tabular data differ in length'

graphload_TEST_IID_sense_str1 = PyG_DataLoader(dataset_TEST_IID_str1['sense_graph'],batch_size=BATCH_SIZE,shuffle=False)
graphload_TEST_IID_antis_str1 = PyG_DataLoader(dataset_TEST_IID_str1['antis_graph'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_IID_str1 = DataLoader(dataset=dataset_TEST_IID_str1['tabular'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_IID_zip_str1 = list(zip(graphload_TEST_IID_sense_str1,graphload_TEST_IID_antis_str1,dataload_TEST_IID_str1))

dataset_TEST_OOD_str1 = {'sense_graph':TEST_OOD_sense_glist_str1,
                 'antis_graph':TEST_OOD_antis_glist_str1,
                 'tabular':TEST_OOD_NonGraph_dataset}
assert len(dataset_TEST_OOD_str1['sense_graph']) == len(dataset_TEST_OOD_str1['antis_graph']) == len(dataset_TEST_OOD_str1['tabular']), \
    'OOD test split (str1): sense graphs, antisense graphs and tabular data differ in length'

graphload_TEST_OOD_sense_str1 = PyG_DataLoader(dataset_TEST_OOD_str1['sense_graph'],batch_size=BATCH_SIZE,shuffle=False)
graphload_TEST_OOD_antis_str1 = PyG_DataLoader(dataset_TEST_OOD_str1['antis_graph'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_OOD_str1 = DataLoader(dataset=dataset_TEST_OOD_str1['tabular'],batch_size=BATCH_SIZE,shuffle=False)
dataload_TEST_OOD_zip_str1 = list(zip(graphload_TEST_OOD_sense_str1,graphload_TEST_OOD_antis_str1,dataload_TEST_OOD_str1))

In [36]:
# Cross-validated ERM training on the *_str1 graphs.
ERM_model_list_str1, ERM_cv_log_str1 = cv_train_GINE(
    'erm',
    TRVL_NonGraph_dataset,
    TRVL_sense_glist_str1,
    TRVL_antis_glist_str1,
    dataload_TEST_IID_zip_str1,
    dataload_TEST_OOD_zip_str1,
)

ID consistent check between dataloads
ID consistent check between dataloads
79	0.8081	0.8067	0.5934	(0.8103,0.5986)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8126	0.8045	0.5826	(0.8147,0.5842)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8161	0.8101	0.5785	(0.8164,0.5777)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8072	0.8076	0.6424	(0.8194,0.6280)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8087	0.8094	0.5908	(0.8124,0.5915)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.7852	0.8073	0.5873	(0.7891,0.5831)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8120	0.8018	0.5892	(0.8143,0.5900)
ID consistent check between dataloads
ID consistent check between dataloads
79	0.8104	0.8134	0.6031	(0.8120,0.6112)
ID consistent check between dataloads
ID consistent check between datalo

In [37]:
PATH_SAVE = './Models_out/GINE_str1-ERM-80epo-0708v3_2/'
cv_log = ERM_cv_log_str1
mdl_list = ERM_model_list_str1

# Output directory for this run's checkpoints and CV log.
os.makedirs(PATH_SAVE, exist_ok=True)

# Write every fold's weights as CPU tensors, loading them through a freshly built model.
mdlsave = modisiR_GINE()
save_state_dict_2cpu(PATH_SAVE, mdl_list, mdlsave)

# Per-epoch training log of all folds.
cv_log.to_pickle(PATH_SAVE+'/df_cv_log.pickle')

0 cpu	to cpu
1 cpu	to cpu
2 cpu	to cpu
3 cpu	to cpu
4 cpu	to cpu
5 cpu	to cpu
6 cpu	to cpu
7 cpu	to cpu
8 cpu	to cpu
9 cpu	to cpu


---

In [181]:
# Shared plotting palette: blue shades b4 (darkest) .. b1 (lightest), a neutral
# near-white 'pearl', and warm shades y1 (lightest) .. y4 (dark red).
colors = dict(
    b4='#060666',
    b3='#186898',
    b2='#96BCCB',
    b1='#DDF0FF',
    pearl='#F0F2F0',
    y1='#FFF0DD',
    y2='#F2C892',
    y3='#F88808',
    y4='#960606',
)

## Loss

In [None]:
def loss_plot(data1,data2,save_name=None):
    """Plot per-fold training and validation loss curves and their fold averages.

    Parameters
    ----------
    data1 : array-like, shape (n_folds, n_epochs)
        Training loss, one row per cross-validation fold.
    data2 : array-like, shape (n_folds, n_epochs)
        Validation loss, laid out like ``data1``.
    save_name : str or None
        If given, the figure is also written to ``save_name + 'loss_plot.pdf'``.
    """
    data1 = np.asarray(data1)
    data2 = np.asarray(data2)
    # 1-based epoch axis sized from the data instead of a hard-coded 80 epochs,
    # so runs of any length plot correctly.
    x = np.arange(1, data1.shape[1] + 1)

    fig, ax = plt.subplots(figsize=(5, 4))

    # Individual training curves (thin, translucent). Only the first one carries
    # a legend label; curves with an empty label are left out of the legend.
    for k, y in enumerate(data1):
        ax.plot(x, y, color=colors['b2'], alpha=0.6, linewidth=1,
                label='Individual_Train' if k == 0 else "")
    # nanmean: folds that stopped early (e.g. via early stopping) are NaN-padded
    # when the per-fold logs are concatenated, and must not turn the mean into NaN.
    ax.plot(x, np.nanmean(data1, axis=0), color=colors['b3'], linewidth=2,
            label="Average_Train", zorder=10)

    # Individual validation curves and their average, styled the same way.
    for k, y in enumerate(data2):
        ax.plot(x, y, color=colors['y2'], alpha=0.6, linewidth=1,
                label='Individual_Validate' if k == 0 else "")
    ax.plot(x, np.nanmean(data2, axis=0), color=colors['y3'], linewidth=2,
            label="Average_Validate", zorder=10)

    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.grid(True, alpha=0.3)
    ax.legend()
    # Lay the figure out before saving so the PDF matches what is displayed.
    fig.tight_layout()
    if save_name is not None:
        fig.savefig(save_name+'loss_plot.pdf', bbox_inches='tight', dpi=300)
    plt.show()

In [None]:
SAVE_PATH = '../DATA/GINE_ERM/'

# Put the metric name on the outer column level, then transpose each metric's
# (epoch x model) table into one row per CV model.
ERM_log_by_metric = ERM_cv_log.swaplevel(axis=1)
loss_plot(ERM_log_by_metric['loss_train'].to_numpy().T,
          ERM_log_by_metric['loss_val'].to_numpy().T,
          SAVE_PATH)

## PCC

In [None]:
from scipy.stats import pearsonr

def GINE_pred(model,dataload_zip):
    """Run ``model`` over a zipped loader and collect predictions and targets.

    Each item of ``dataload_zip`` is a (sense graph batch, antisense graph batch,
    tabular batch) triple. Element 1 of the tabular batch is passed to the model
    as its non-graph input, and element 2 is taken as the target and cast to
    float32. The model is put in eval mode and run without gradient tracking.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Predictions and targets, concatenated over all batches in loader order.
    """
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for sense_gbatch, antis_gbatch, tabular_batch in dataload_zip:
            batch_out = model(sense_gbatch, antis_gbatch, tabular_batch[1])
            preds.extend(batch_out.cpu().numpy())
            labels.extend(tabular_batch[2].to(torch.float32).cpu().numpy())
    return np.array(preds), np.array(labels)

def cvmodel_pred_GINE(model_list,dataload_test_zip):
    """Return one prediction array per model, in ``model_list`` order.

    Prints a '*' progress marker before each model is evaluated; the targets
    returned by ``GINE_pred`` are discarded.
    """
    fold_predictions = []
    for fold_model in model_list:
        print('*')
        fold_pred, _ = GINE_pred(fold_model, dataload_test_zip)
        fold_predictions.append(fold_pred)
    return fold_predictions

def PCC(true,pred,size=(4,4),save_name=None):
    """Scatter true vs. predicted values, annotated with Pearson's r.

    Prints r and its p-value and draws the y = x reference line plus a
    least-squares linear fit of ``pred`` on ``true``.

    Parameters
    ----------
    true, pred : np.ndarray
        Ground-truth and predicted values, aligned sample by sample.
    size : tuple
        Figure size in inches.
    save_name : str or None
        If given, the figure is also written to ``save_name + '.pdf'``.
    """
    r, p_value = pearsonr(true, pred)
    print(r,p_value)

    fig, ax = plt.subplots(figsize=size)
    ax.scatter(true, pred, s=15, color=colors['b3'], alpha=0.5, label='Data Points')

    # Identity reference line spanning the joint range of both variables.
    lo = min(true.min(), pred.min())
    hi = max(true.max(), pred.max())
    ax.plot([lo, hi], [lo, hi], 'k--', linewidth=2, label='y = x')

    # First-degree least-squares fit, drawn across the range of the true values.
    slope, intercept = np.polyfit(true, pred, 1)
    xs = np.linspace(min(true), max(true), 100)
    ax.plot(xs, slope * xs + intercept, color=colors['y3'], linewidth=3, label='Best Fit Line')

    # Correlation coefficient in the upper-left corner (axes coordinates).
    ax.text(0.04, 0.95, f'r = {r.item():.3f}', transform=ax.transAxes,
            fontsize=12, verticalalignment='top',
            bbox=dict(facecolor='white', alpha=0.8))

    ax.set_xlabel('True mRNA_remains %')
    ax.set_ylabel('Predicted mRNA_remains %')
    ax.grid(True, linestyle='--', alpha=0.7)
    fig.tight_layout()

    if save_name is not None:
        fig.savefig(save_name+'.pdf', bbox_inches='tight', dpi=300)

    plt.show()

In [None]:
# Rebuild one model per CV fold from the saved state dicts.
ls_models_for_PCC = []
for fold_state in ERM_model_list:
    model = modisiR_GINE(node_xlen=43, edge_xlen=1, dim_h=64)
    model.load_state_dict(fold_state)
    ls_models_for_PCC.append(model)

In [None]:
SAVE_PATH = '../DATA/GINE_ERM/'

# Targets come from the loader, not the model, so any fold model yields the same labels.
y_true_IID = GINE_pred(ls_models_for_PCC[0], dataload_TEST_IID_zip)[1]

# Ensemble prediction: sample-wise mean over the CV fold models.
ls_y_pred_IID_ERM = cvmodel_pred_GINE(ls_models_for_PCC, dataload_TEST_IID_zip)
mean_y_pred_IID_ERM = np.stack(ls_y_pred_IID_ERM).mean(axis=0)
print(mean_y_pred_IID_ERM.shape)

PCC(y_true_IID, mean_y_pred_IID_ERM, size=(5, 4), save_name=SAVE_PATH + 'PCC_IID')

In [None]:
SAVE_PATH = '../DATA/GINE_ERM/'

# Targets come from the loader, not the model, so any fold model yields the same labels.
y_true_OOD = GINE_pred(ls_models_for_PCC[0], dataload_TEST_OOD_zip)[1]

# Ensemble prediction: sample-wise mean over the CV fold models.
ls_y_pred_OOD_ERM = cvmodel_pred_GINE(ls_models_for_PCC, dataload_TEST_OOD_zip)
mean_y_pred_OOD_ERM = np.stack(ls_y_pred_OOD_ERM).mean(axis=0)
print(mean_y_pred_OOD_ERM.shape)

PCC(y_true_OOD, mean_y_pred_OOD_ERM, size=(5, 4), save_name=SAVE_PATH + 'PCC_OOD')

## scores

In [None]:
from sklearn.metrics import roc_curve, auc

def models_scores(y_true_np,y_preds_np,threshold=30):
    """Compute evaluation metrics for each model's predictions.

    Values are clipped to [0, 100]. A sample counts as positive when its value is
    below ``threshold``. For every prediction array the combined score is
    ``0.5 * (1 - MAE/100) + 0.5 * (1 - range_MAE/100) * F1``, where range_MAE is
    the MAE over samples predicted within [0, threshold] (100 if there are none).

    Parameters
    ----------
    y_true_np : np.ndarray
        Ground-truth values, one per sample.
    y_preds_np : iterable of np.ndarray
        One prediction array per model, each aligned with ``y_true_np``.
    threshold : float, default 30
        Cut-off separating positives (below) from negatives.

    Returns
    -------
    pd.DataFrame
        One row per model with columns Score, Precision, Recall, F1, AUC.
    """
    # The ground truth is the same for every model, so flatten, clip and
    # binarize it once. Flattening also guards against (N, 1) column vectors,
    # which would otherwise broadcast against (N,) into an N x N difference.
    y_true = np.ravel(y_true_np).clip(0,100)
    y_true_binary = (y_true < threshold).astype(int)

    score_list = []
    for y_pred in y_preds_np:
        y_pred = np.ravel(y_pred).clip(0,100)

        mae = np.mean(np.abs(y_true - y_pred))

        # Binary labels: below the threshold -> 1, at or above -> 0.
        y_pred_binary = (y_pred < threshold).astype(int)

        # MAE restricted to samples whose prediction lies in [0, threshold].
        mask = (y_pred >= 0) & (y_pred <= threshold)
        range_mae = mean_absolute_error(y_true[mask], y_pred[mask]) if mask.sum() > 0 else 100

        precision = precision_score(y_true_binary, y_pred_binary, average='binary')
        recall = recall_score(y_true_binary, y_pred_binary, average='binary')
        # Define F1 = 0 when precision and recall are both 0, instead of computing 0/0.
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        # Lower predicted remaining value -> higher score for the positive class.
        fpr, tpr, _ = roc_curve(y_true_binary, 1-y_pred/100)
        auc_value = auc(fpr, tpr)

        score = (1 - mae / 100) * 0.5 + (1 - range_mae / 100) * f1 * 0.5
        score_list.append((score,precision,recall,f1,auc_value))

    # Passing columns= also yields a well-formed (empty) frame when no models are given.
    return pd.DataFrame(score_list, columns=['Score','Precision','Recall','F1','AUC'])


def scores_bar(df_scores,ylim=(0,1.05),save_name=None):
    """Bar chart of each metric's mean ± std across models, with per-model points.

    Parameters
    ----------
    df_scores : pd.DataFrame
        One row per model, one column per metric.
    ylim : tuple
        Lower and upper y-axis limits.
    save_name : str or None
        If given, the figure is also written to ``save_name + 'scores.pdf'``.
    """
    metric_names = df_scores.columns
    metric_mean = df_scores.mean()
    metric_std = df_scores.std()
    metric_max = df_scores.max()

    fig, ax = plt.subplots(figsize=(5, 4))
    # Bars show the mean; error bars show the standard deviation across models.
    bars = ax.bar(metric_names, metric_mean, yerr=metric_std, capsize=10,
                  color=colors['b2'], edgecolor=colors['b3'])

    # Overlay the individual model values with slight horizontal jitter.
    for pos, name in enumerate(metric_names):
        values = df_scores[name]
        jitter_x = np.random.normal(pos, 0.02, size=len(values))
        ax.scatter(jitter_x, values, s=15, color=colors['b3'], alpha=0.6, label=name)

    # Print the mean just above the highest individual value of each metric.
    for bar, mean_val, top in zip(bars, metric_mean, metric_max):
        ax.text(bar.get_x() + bar.get_width() / 2, top + 0.01, f'{mean_val:.2f}',
                ha='center', va='bottom')

    ax.set_ylabel('Mean ± STD')
    ax.set_ylim(ylim[0], ylim[1])
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    if save_name is not None:
        fig.savefig(save_name+'scores.pdf', bbox_inches='tight', dpi=300)
    plt.show()

In [None]:
# One row of metrics per CV model on the IID test split.
IID_pred_matrix = np.stack(ls_y_pred_IID_ERM)
df_scores_IID = models_scores(y_true_IID, IID_pred_matrix, threshold=30)
df_scores_IID

In [None]:
SAVE_PATH = '../DATA/GINE_ERM/IID_'
scores_bar(df_scores_IID, ylim=(0, 1.05), save_name=SAVE_PATH)

In [None]:
# One row of metrics per CV model on the OOD test split, then the summary chart.
OOD_pred_matrix = np.stack(ls_y_pred_OOD_ERM)
df_scores_OOD = models_scores(y_true_OOD, OOD_pred_matrix)
SAVE_PATH = '../DATA/GINE_ERM/OOD_'
scores_bar(df_scores_OOD, save_name=SAVE_PATH)

## ROC

In [None]:
def roc_auc(y_true,y_pred,threshold=30,size=(5,4),save_name=None):
    """Plot the ROC curve for detecting samples below ``threshold`` and return its AUC.

    Both inputs are clipped to [0, 100]. The ground truth is binarized as
    ``y_true < threshold`` (positive class), and ``1 - y_pred/100`` is used as
    the score, so lower predicted values rank as more positive.

    Parameters
    ----------
    y_true, y_pred : np.ndarray
        Ground-truth and predicted values, aligned sample by sample.
    threshold : float, default 30
        Cut-off separating positives (below) from negatives.
    size : tuple
        Figure size in inches.
    save_name : str or None
        If given, the figure is also written to ``save_name + 'ROC.pdf'``.

    Returns
    -------
    float
        Area under the ROC curve.
    """
    # Flatten to 1-D so (N, 1) column vectors are handled like (N,) arrays.
    y_true = np.ravel(y_true).clip(0,100)
    y_pred = np.ravel(y_pred).clip(0,100)

    y_true_binary = (y_true < threshold).astype(int)
    fpr, tpr, _ = roc_curve(y_true_binary, 1-y_pred/100)
    # Named auc_value so the local does not shadow this function's own name.
    auc_value = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=size)
    ax.plot(fpr, tpr, color=colors['y3'], lw=2, label='ROC curve (area = %0.2f)' % auc_value)
    ax.fill_between(fpr, tpr, 0, where=(tpr > 0), color=colors['y2'], alpha=0.3, interpolate=True)
    # Chance-level diagonal.
    ax.plot([0, 1], [0, 1], color='grey', lw=2, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])

    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower right")

    if save_name is not None:
        fig.savefig(save_name+'ROC.pdf', bbox_inches='tight', dpi=300)

    return auc_value

In [None]:
# ROC of the fold-averaged IID predictions; the cell displays the returned AUC.
roc_auc(y_true_IID, mean_y_pred_IID_ERM, size=(5, 4), save_name='../DATA/GINE_ERM/IID_')

In [None]:
# ROC of the fold-averaged OOD predictions; the cell displays the returned AUC.
roc_auc(y_true_OOD, mean_y_pred_OOD_ERM, size=(5, 4), save_name='../DATA/GINE_ERM/OOD_')