In [1]:
import os

import pandas as pd
import numpy as np
from scipy.sparse import csr_matrix, save_npz

from MAGNN_utils.preprocess import (
    get_column,
    assign_index,
    map_index_to_relation_file,
    export_index2dat,
    split_date,
    save_split_data2npz,
    generate_triplet_array,
    generate_long_relationship_array,
    lexicographical_sort,
    process_and_save_metapath_idx_in_batches,
    process_and_save_adjlist_in_batches
)

# Input data preprocessing

In [3]:
# list all file paths for the original relation .dat files
file_path = os.getcwd()
# file 1, 2: microbe-disease
file1 = os.path.join(file_path, "../data", "MAGNN_data", "disbiome_taxid_mondo.dat")
file2 = os.path.join(file_path, "../data", "MAGNN_data", "gmmad2_taxid_mondo.dat")
# file 3, 4: microbe-metabolite
file3 = os.path.join(file_path, "../data", "MAGNN_data", "gmmad2_taxid_met.dat")
file4 = os.path.join(file_path, "../data", "MAGNN_data", "hmdb_taxid_met.dat")
# file 5: metabolite-disease
file5 = os.path.join(file_path, "../data", "MAGNN_data", "hmdb_met_disease.dat")

In [4]:
# Collect, per entity type, the matching column from every relation file that
# mentions it, then pass the per-file columns to assign_index.
# (assign_index presumably merges them into one entity -> index table;
# confirm in MAGNN_utils.preprocess.)
_md_cols = dict(colname1="Microbe", colname2="Disease")
_mm_cols = dict(colname1="Microbe", colname2="Metabolite")
_metd_cols = dict(colname1="Metabolite", colname2="Disease")

# microbes: first column of the microbe-disease and microbe-metabolite files
microbes1 = get_column(file1, col="col1", **_md_cols)
microbes2 = get_column(file2, col="col1", **_md_cols)
microbes3 = get_column(file3, col="col1", **_mm_cols)
microbes4 = get_column(file4, col="col1", **_mm_cols)
all_microbes = assign_index([microbes1, microbes2, microbes3, microbes4])

# diseases: second column of the microbe-disease and metabolite-disease files
disease1 = get_column(file1, col="col2", **_md_cols)
disease2 = get_column(file2, col="col2", **_md_cols)
disease3 = get_column(file5, col="col2", **_metd_cols)
all_diseases = assign_index([disease1, disease2, disease3])

# metabolites: second column of microbe-metabolite, first column of metabolite-disease
metabolite1 = get_column(file3, col="col2", **_mm_cols)
metabolite2 = get_column(file4, col="col2", **_mm_cols)
metabolite3 = get_column(file5, col="col1", **_metd_cols)
all_metabolites = assign_index([metabolite1, metabolite2, metabolite3])

In [5]:
# Write one index file per entity type under data/raw/.
for _index_table, _index_path in (
    (all_microbes, "data/raw/microbe_index.dat"),
    (all_metabolites, "data/raw/metabolite_index.dat"),
    (all_diseases, "data/raw/disease_index.dat"),
):
    export_index2dat(_index_table, _index_path)

In [6]:
# Re-express each relation in terms of the entity indices built above.
# Files that describe the same relation type are passed together and
# (presumably) come back merged into a single frame.
md_merged_df = map_index_to_relation_file(
    [file1, file2],
    "Microbe",
    "Disease",
    all_microbes,
    all_diseases,
)
mm_df = map_index_to_relation_file(
    [file3, file4],
    "Microbe",
    "Metabolite",
    all_microbes,
    all_metabolites,
)
metd_df = map_index_to_relation_file(
    [file5],
    "Metabolite",
    "Disease",
    all_metabolites,
    all_diseases,
)

In [7]:
# Write the index-mapped relation frames to data/raw/, one file per relation type.
for _relation_df, _relation_path in (
    (md_merged_df, "data/raw/microbe_disease_idx.dat"),
    (mm_df, "data/raw/microbe_metabolite_idx.dat"),
    (metd_df, "data/raw/metabolite_disease_idx.dat"),
):
    export_index2dat(_relation_df, _relation_path)

## Total counts of entities and edges

In [8]:
def _load_index_table(path, entity_column):
    """Read one exported index file; columns are named `entity_column` and "Idx"."""
    return pd.read_csv(path, header=None, names=[entity_column, "Idx"])


# One row per indexed entity, so the row count is the number of nodes of that type.
microbe_idx = _load_index_table("data/raw/microbe_index.dat", "Microbe")
print(f"Total number of microbes: {microbe_idx.shape[0]}")

metabolite_idx = _load_index_table("data/raw/metabolite_index.dat", "Metabolite")
print(f"Total number of metabolites: {metabolite_idx.shape[0]}")

disease_idx = _load_index_table("data/raw/disease_index.dat", "Disease")
print(f"Total number of diseases: {disease_idx.shape[0]}")

Total number of microbes: 8202
Total number of metabolites: 23823
Total number of diseases: 898


In [9]:
def _load_edge_table(path, columns):
    """Read one tab-separated, header-less edge file with the given column names."""
    return pd.read_csv(path, encoding='utf-8', delimiter='\t', names=columns)


# One row per edge, so the row count is the number of edges of that relation type.
microbe_disease = _load_edge_table("data/raw/microbe_disease_idx.dat", ['MicrobeIdx', 'DiseaseIdx'])
print(f"Total edges between microbe-disease: {microbe_disease.shape[0]}")

microbe_metabolite = _load_edge_table('data/raw/microbe_metabolite_idx.dat', ['MicrobeIdx', 'MetaboliteIdx'])
print(f"Total edges between microbe-metabolite: {microbe_metabolite.shape[0]}")

metabolite_disease = _load_edge_table('data/raw/metabolite_disease_idx.dat', ['MetaboliteIdx', 'DiseaseIdx'])
print(f"Total edges between metabolite-disease: {metabolite_disease.shape[0]}")

Total edges between microbe-disease: 505852
Total edges between microbe-metabolite: 599202
Total edges between metabolite-disease: 27546


# Create adjacency matrix

In [10]:
# Output directory prefix for the preprocessed artefacts; it is passed to the
# batch writers below and string-concatenated with file names when saving
# adjM / node types, so the trailing slash is required.
save_prefix = "data/preprocessed/"

In [11]:
# Load the three index-mapped edge lists (tab-separated, no header row).
microbe_disease = pd.read_csv("data/raw/microbe_disease_idx.dat", encoding='utf-8', delimiter='\t', names=['MicrobeIdx', 'DiseaseIdx'])
microbe_metabolite = pd.read_csv('data/raw/microbe_metabolite_idx.dat', encoding='utf-8', delimiter='\t', names=['MicrobeIdx', 'MetaboliteIdx'])
metabolite_disease = pd.read_csv('data/raw/metabolite_disease_idx.dat', encoding='utf-8', delimiter='\t', names=['MetaboliteIdx', 'DiseaseIdx'])

# Node counts are the row counts of the exported index files. Reading them here
# (with the same call the statistics section uses) keeps the counts in step with
# the data instead of relying on numbers copied from an earlier run's output.
num_microbe = pd.read_csv("data/raw/microbe_index.dat", header=None, names=["Microbe", "Idx"]).shape[0]
num_metabolite = pd.read_csv("data/raw/metabolite_index.dat", header=None, names=["Metabolite", "Idx"]).shape[0]
num_disease = pd.read_csv("data/raw/disease_index.dat", header=None, names=["Disease", "Idx"]).shape[0]

In [12]:
# Build the dense, symmetric adjacency matrix over all nodes.
# Global node numbering: microbes first, then diseases, then metabolites.
# Node-type codes: 0 for microbe, 1 for disease, 2 for metabolite.
dim = num_microbe + num_disease + num_metabolite
disease_offset = num_microbe
metabolite_offset = num_microbe + num_disease

type_mask = np.zeros(dim, dtype=np.int16)
type_mask[disease_offset:metabolite_offset] = 1
type_mask[metabolite_offset:] = 2

adjM = np.zeros((dim, dim), dtype=np.int16)


def _local_ids(edge_df, column, count):
    """Return edge_df[column] as an array after checking every id lies in [0, count).

    An out-of-range id would otherwise be written into another node type's block
    of the matrix (or wrap around, if negative) without any error.
    """
    ids = edge_df[column].to_numpy()
    if ids.size and (ids.min() < 0 or ids.max() >= count):
        raise ValueError(
            f"{column} contains ids outside [0, {count}): min={ids.min()}, max={ids.max()}"
        )
    return ids


def _link(rows, cols):
    """Set adjM[rows[k], cols[k]] and its mirror entry to 1 for every k (undirected edges)."""
    adjM[rows, cols] = 1
    adjM[cols, rows] = 1


# Vectorised assignment replaces the row-by-row iterrows() loops; the resulting
# matrix is the same, since every edge just sets two cells to 1.
_link(
    _local_ids(microbe_disease, "MicrobeIdx", num_microbe),
    disease_offset + _local_ids(microbe_disease, "DiseaseIdx", num_disease),
)
_link(
    _local_ids(microbe_metabolite, "MicrobeIdx", num_microbe),
    metabolite_offset + _local_ids(microbe_metabolite, "MetaboliteIdx", num_metabolite),
)
_link(
    metabolite_offset + _local_ids(metabolite_disease, "MetaboliteIdx", num_metabolite),
    disease_offset + _local_ids(metabolite_disease, "DiseaseIdx", num_disease),
)

In [13]:
# Convert the dense adjacency matrix to CSR and report how sparse it is.
adjM_sparse = csr_matrix(adjM)

# sparsity = share of cells that are zero
_n_rows, _n_cols = adjM_sparse.shape
total_elements = _n_rows * _n_cols
non_zero_elements = adjM_sparse.nnz
sparsity = 1 - (non_zero_elements / total_elements)
sparsity_percentage = sparsity * 100

# The second line prints the same quantity as a fraction rounded to two decimals.
print(f"Sparsity of the adjacency matrix: {sparsity_percentage:.2f}%")
print(f"Sparsity of the sparse adjacency matrix: {sparsity:.2f}")

Sparsity of the adjacency matrix: 99.79%
Sparsity of the sparse adjacency matrix: 1.00


# Create edge metapath index arrays

In [14]:
def _neighbour_lists(row_offset, n_rows, col_slice):
    """Map each local row id to the local column ids it is connected to in adjM.

    Rows are adjM[row_offset + i] for i in range(n_rows); columns are restricted
    to col_slice, so the returned ids are relative to the start of that slice.
    Rows without any neighbour in the slice are left out of the dict.
    """
    lists = {}
    for local_id in range(n_rows):
        hits = adjM[row_offset + local_id, col_slice].nonzero()[0].astype(np.int16)
        if hits.size > 0:
            lists[local_id] = hits
    return lists


# Column ranges of the three node types in the global numbering.
_microbe_cols = slice(None, num_microbe)
_disease_cols = slice(num_microbe, num_microbe + num_disease)
_metabolite_cols = slice(num_microbe + num_disease, None)

# NOTE(review): the ids stored here are local and fit in int16, but the largest
# global id (dim - 1) can exceed the int16 maximum of 32767. Later cells add
# offsets in place to arrays derived from these lists; confirm that
# generate_triplet_array / generate_long_relationship_array widen the dtype first.

# microbe -> diseases, and disease -> microbes
microbe_disease_list = _neighbour_lists(0, num_microbe, _disease_cols)
disease_microbe_list = _neighbour_lists(num_microbe, num_disease, _microbe_cols)

# metabolite -> diseases, and disease -> metabolites
metabolite_disease_list = _neighbour_lists(num_microbe + num_disease, num_metabolite, _disease_cols)
disease_metabolite_list = _neighbour_lists(num_microbe, num_disease, _metabolite_cols)

# microbe -> metabolites, and metabolite -> microbes
microbe_metabolite_list = _neighbour_lists(0, num_microbe, _metabolite_cols)
metabolite_microbe_list = _neighbour_lists(num_microbe + num_disease, num_metabolite, _microbe_cols)

In [None]:
# 0-1-0 (microbe-disease-microbe)
# Intended de-duplication (presumably performed inside generate_triplet_array; confirm):
#   - a path and its reverse count as one path, e.g. (1, 0, 2) and (2, 0, 1)
#   - paths that start and end on the same microbe are dropped, e.g. (1, 4, 1) and (0, 4, 0)
microbe_disease_microbe = generate_triplet_array(disease_microbe_list)
# Column 1 is the disease: shift it into the global numbering (diseases start at num_microbe).
# Microbe columns need no shift because microbes start at 0.
microbe_disease_microbe[:, 1] += num_microbe
# NOTE(review): [0, 2, 1] presumably gives the column priority of the sort; confirm in lexicographical_sort.
microbe_disease_microbe = lexicographical_sort(microbe_disease_microbe, [0, 2, 1])

In [None]:
# Save the 0-1-0 metapath index data in batches.
# Target nodes are the microbes (local ids 0..num_microbe-1); offset=0 is where
# microbe ids start in the global numbering, and group_index=0 is the microbe type code.
process_and_save_metapath_idx_in_batches(
    metapath_type=(0, 1, 0),
    metapath_array=microbe_disease_microbe,
    target_idx_list=np.arange(num_microbe),
    offset=0,
    save_prefix=save_prefix,
    group_index=0
)

In [None]:
# Save the 0-1-0 metapath adjacency lists in batches, using the same target
# nodes (microbes), offset and group as the index data above.
process_and_save_adjlist_in_batches(
    metapath_type=(0, 1, 0),
    metapath_array=microbe_disease_microbe,
    target_idx_list=np.arange(num_microbe),
    offset=0,
    save_prefix=save_prefix,
    group_index=0,
)

In [None]:
# 2-0-1-0-2 (metabolite-microbe-disease-microbe-metabolite)
# Extends the 0-1-0 triplets at both ends with the metabolites of each microbe.
# NOTE(review): num_offset2 is the global start of the metabolite ids, so it is
# presumably added to the new end nodes taken from relational_list; confirm in helper.
meta_micro_d_micro_meta = generate_long_relationship_array(
    relational_list=microbe_metabolite_list,
    intermediate_triplet=microbe_disease_microbe,
    num_offset2=(num_microbe + num_disease)
)

# NOTE(review): the list presumably gives the column priority of the sort; confirm in helper.
meta_micro_d_micro_meta = lexicographical_sort(meta_micro_d_micro_meta, [0, 2, 1, 3, 4])

In [None]:
# Save the 2-0-1-0-2 metapath index data in batches.
# Target nodes are the metabolites; offset is where metabolite ids start in the
# global numbering, and group_index=2 is the metabolite type code.
process_and_save_metapath_idx_in_batches(
    metapath_type=(2, 0, 1, 0, 2),
    metapath_array=meta_micro_d_micro_meta,
    target_idx_list=np.arange(num_metabolite),
    offset=num_microbe + num_disease,
    save_prefix=save_prefix,
    group_index=2
)

In [None]:
# Save the 2-0-1-0-2 metapath adjacency lists in batches (target nodes: metabolites).
process_and_save_adjlist_in_batches(
    metapath_type=(2, 0, 1, 0, 2),
    metapath_array=meta_micro_d_micro_meta,
    target_idx_list=np.arange(num_metabolite),
    offset=num_microbe + num_disease,
    save_prefix=save_prefix,
    group_index=2,
)

In [None]:
# Drop the references to the 0-1-0 and 2-0-1-0-2 arrays so their memory can be
# reclaimed before the next metapath family is built.
del microbe_disease_microbe
del meta_micro_d_micro_meta

In [None]:
# 0-2-0 (microbe-metabolite-microbe)
microbe_metabolite_microbe = generate_triplet_array(metabolite_microbe_list)
# Column 1 is the metabolite: shift it into the global numbering
# (metabolites start at num_microbe + num_disease). Microbe columns need no shift.
# NOTE(review): the shifted ids can exceed 32767; confirm the array dtype is wider than int16.
microbe_metabolite_microbe[:, 1] += num_microbe + num_disease
# NOTE(review): [0, 2, 1] presumably gives the column priority of the sort; confirm in helper.
microbe_metabolite_microbe = lexicographical_sort(microbe_metabolite_microbe, [0, 2, 1])

In [None]:
# Save the 0-2-0 metapath index data in batches.
# Target nodes are the microbes (global offset 0, type code 0).
process_and_save_metapath_idx_in_batches(
    metapath_type=(0, 2, 0),
    metapath_array=microbe_metabolite_microbe,
    target_idx_list=np.arange(num_microbe),
    offset=0,
    save_prefix=save_prefix,
    group_index=0
)

In [None]:
# Save the 0-2-0 metapath adjacency lists in batches (target nodes: microbes).
process_and_save_adjlist_in_batches(
    metapath_type=(0, 2, 0),
    metapath_array=microbe_metabolite_microbe,
    target_idx_list=np.arange(num_microbe),
    offset=0,
    save_prefix=save_prefix,
    group_index=0,
)

In [None]:
# 1-0-2-0-1 (disease-microbe-metabolite-microbe-disease)
# Extends the 0-2-0 triplets at both ends with the diseases of each microbe.
# NOTE(review): num_offset2 is the global start of the disease ids, so it is
# presumably added to the new end nodes taken from relational_list; confirm in helper.
d_micro_meta_micro_d = generate_long_relationship_array(
    relational_list=microbe_disease_list,
    intermediate_triplet=microbe_metabolite_microbe,
    num_offset2=num_microbe,
)

# NOTE(review): the list presumably gives the column priority of the sort; confirm in helper.
d_micro_meta_micro_d = lexicographical_sort(d_micro_meta_micro_d, [0, 2, 1, 3, 4])

In [None]:
# Save the 1-0-2-0-1 metapath index data in batches.
# Target nodes are the diseases; offset=num_microbe is where disease ids start
# in the global numbering, and group_index=1 is the disease type code.
process_and_save_metapath_idx_in_batches(
    metapath_type=(1, 0, 2, 0, 1),
    metapath_array=d_micro_meta_micro_d,
    target_idx_list=np.arange(num_disease),
    offset=num_microbe,
    save_prefix=save_prefix,
    group_index=1
)

In [None]:
# Save the 1-0-2-0-1 metapath adjacency lists in batches (target nodes: diseases).
process_and_save_adjlist_in_batches(
    metapath_type=(1, 0, 2, 0, 1),
    metapath_array=d_micro_meta_micro_d,
    target_idx_list=np.arange(num_disease),
    offset=num_microbe,
    save_prefix=save_prefix,
    group_index=1,
)

In [None]:
# Drop the references to the 0-2-0 and 1-0-2-0-1 arrays so their memory can be
# reclaimed before the next metapath family is built.
del microbe_metabolite_microbe
del d_micro_meta_micro_d

In [None]:
# 1-2-1 (disease-metabolite-disease)
disease_metabolite_disease = generate_triplet_array(metabolite_disease_list)
# Columns 0 and 2 are diseases: shift them into the global numbering (diseases start at num_microbe).
disease_metabolite_disease[:, (0, 2)] += num_microbe
# Column 1 is the metabolite: metabolites start at num_microbe + num_disease.
# NOTE(review): the shifted ids can exceed 32767; confirm the array dtype is wider than int16.
disease_metabolite_disease[:, 1] += num_microbe + num_disease
# NOTE(review): [0, 2, 1] presumably gives the column priority of the sort; confirm in helper.
disease_metabolite_disease = lexicographical_sort(disease_metabolite_disease, [0, 2, 1])

In [None]:
# Save the 1-2-1 metapath index data in batches.
# Target nodes are the diseases (global offset num_microbe, type code 1).
process_and_save_metapath_idx_in_batches(
    metapath_type=(1, 2, 1),
    metapath_array=disease_metabolite_disease,
    target_idx_list=np.arange(num_disease),
    offset=num_microbe,
    save_prefix=save_prefix,
    group_index=1
)

In [None]:
# Save the 1-2-1 metapath adjacency lists in batches (target nodes: diseases).
process_and_save_adjlist_in_batches(
    metapath_type=(1, 2, 1),
    metapath_array=disease_metabolite_disease,
    target_idx_list=np.arange(num_disease),
    offset=num_microbe,
    save_prefix=save_prefix,
    group_index=1,
)

In [None]:
# 0-1-2-1-0 (microbe-disease-metabolite-disease-microbe)
# Extends the 1-2-1 triplets at both ends with the microbes of each disease.
# NOTE(review): the triplet's disease ids are already global while the keys of
# disease_microbe_list are local, so num_offset1=num_microbe is presumably used
# to translate between the two; confirm in helper. No num_offset2 is passed
# because the new end nodes are microbes, whose global ids start at 0.
micro_d_meta_d_micro = generate_long_relationship_array(
    relational_list=disease_microbe_list,
    intermediate_triplet=disease_metabolite_disease,
    num_offset1=num_microbe
)

# NOTE(review): the list presumably gives the column priority of the sort; confirm in helper.
micro_d_meta_d_micro = lexicographical_sort(micro_d_meta_d_micro, [0, 2, 1, 3, 4])

In [None]:
# Save the 0-1-2-1-0 metapath index data in batches.
# Target nodes are the microbes (global offset 0, type code 0).
process_and_save_metapath_idx_in_batches(
    metapath_type=(0, 1, 2, 1, 0),
    metapath_array=micro_d_meta_d_micro,
    target_idx_list=np.arange(num_microbe),
    offset=0,
    save_prefix=save_prefix,
    group_index=0
)

In [None]:
# Save the 0-1-2-1-0 metapath adjacency lists in batches (target nodes: microbes).
process_and_save_adjlist_in_batches(
    metapath_type=(0, 1, 2, 1, 0),
    metapath_array=micro_d_meta_d_micro,
    target_idx_list=np.arange(num_microbe),
    offset=0,
    save_prefix=save_prefix,
    group_index=0,
)

In [None]:
# Drop the references to the 1-2-1 and 0-1-2-1-0 arrays so their memory can be
# reclaimed before the next metapath family is built.
del disease_metabolite_disease
del micro_d_meta_d_micro

In [None]:
# 2-1-2 (metabolite-disease-metabolite)
metabolite_disease_metabolite = generate_triplet_array(disease_metabolite_list)
# Columns 0 and 2 are metabolites: shift them into the global numbering
# (metabolites start at num_microbe + num_disease).
# NOTE(review): the shifted ids can exceed 32767; confirm the array dtype is wider than int16.
metabolite_disease_metabolite[:, (0, 2)] += num_microbe + num_disease
# Column 1 is the disease: diseases start at num_microbe.
metabolite_disease_metabolite[:, 1] += num_microbe
# NOTE(review): [0, 2, 1] presumably gives the column priority of the sort; confirm in helper.
metabolite_disease_metabolite = lexicographical_sort(metabolite_disease_metabolite, [0, 2, 1])

In [None]:
# Save the 2-1-2 metapath index data in batches.
# Target nodes are the metabolites (global offset num_microbe + num_disease, type code 2).
process_and_save_metapath_idx_in_batches(
    metapath_type=(2, 1, 2),
    metapath_array=metabolite_disease_metabolite,
    target_idx_list=np.arange(num_metabolite),
    offset=num_microbe+num_disease,
    save_prefix=save_prefix,
    group_index=2
)

In [None]:
# Save the 2-1-2 metapath adjacency lists in batches (target nodes: metabolites).
process_and_save_adjlist_in_batches(
    metapath_type=(2, 1, 2),
    metapath_array=metabolite_disease_metabolite,
    target_idx_list=np.arange(num_metabolite),
    offset=num_microbe + num_disease,
    save_prefix=save_prefix,
    group_index=2,
)

In [None]:
# 0-2-1-2-0 (microbe-metabolite-disease-metabolite-microbe)
# Extends the 2-1-2 triplets at both ends with the microbes of each metabolite.
# NOTE(review): the triplet's metabolite ids are already global while the keys of
# metabolite_microbe_list are local, so num_offset1 is presumably used to
# translate between the two; confirm in helper. No num_offset2 is passed
# because the new end nodes are microbes, whose global ids start at 0.
micro_meta_d_meta_micro = generate_long_relationship_array(
    relational_list=metabolite_microbe_list,
    intermediate_triplet=metabolite_disease_metabolite,
    num_offset1=(num_microbe + num_disease)
)

# NOTE(review): the list presumably gives the column priority of the sort; confirm in helper.
micro_meta_d_meta_micro = lexicographical_sort(micro_meta_d_meta_micro, [0, 2, 1, 3, 4])

In [None]:
# Save the 0-2-1-2-0 metapath index data in batches.
# Target nodes are the microbes (global offset 0, type code 0).
process_and_save_metapath_idx_in_batches(
    metapath_type=(0, 2, 1, 2, 0),
    metapath_array=micro_meta_d_meta_micro,
    target_idx_list=np.arange(num_microbe),
    offset=0,
    save_prefix=save_prefix,
    group_index=0
)

In [None]:
# Save the 0-2-1-2-0 metapath adjacency lists in batches (target nodes: microbes).
process_and_save_adjlist_in_batches(
    metapath_type=(0, 2, 1, 2, 0),
    metapath_array=micro_meta_d_meta_micro,
    target_idx_list=np.arange(num_microbe),
    offset=0,
    save_prefix=save_prefix,
    group_index=0,
)

In [None]:
# Drop the references to the 2-1-2 and 0-2-1-2-0 arrays so their memory can be
# reclaimed before the next metapath family is built.
del metabolite_disease_metabolite
del micro_meta_d_meta_micro

In [None]:
# 1-0-1 (disease-microbe-disease)
disease_microbe_disease = generate_triplet_array(microbe_disease_list)
# Columns 0 and 2 are diseases: shift them into the global numbering (diseases start at num_microbe).
# Column 1 is the microbe and needs no shift because microbes start at 0.
disease_microbe_disease[:, (0, 2)] += num_microbe
# NOTE(review): [0, 2, 1] presumably gives the column priority of the sort; confirm in helper.
disease_microbe_disease = lexicographical_sort(disease_microbe_disease, [0, 2, 1])

In [None]:
# Save the 1-0-1 metapath index data in batches.
# Target nodes are the diseases (global offset num_microbe, type code 1).
process_and_save_metapath_idx_in_batches(
    metapath_type=(1, 0, 1),
    metapath_array=disease_microbe_disease,
    target_idx_list=np.arange(num_disease),
    offset=num_microbe,
    save_prefix=save_prefix,
    group_index=1
)

In [None]:
# Save the 1-0-1 metapath adjacency lists in batches (target nodes: diseases).
process_and_save_adjlist_in_batches(
    metapath_type=(1, 0, 1),
    metapath_array=disease_microbe_disease,
    target_idx_list=np.arange(num_disease),
    offset=num_microbe,
    save_prefix=save_prefix,
    group_index=1,
)

In [None]:
# 2-1-0-1-2 (metabolite-disease-microbe-disease-metabolite)
# Extends the 1-0-1 triplets at both ends with the metabolites of each disease.
# NOTE(review): num_offset1 is the global start of the disease ids (the triplet
# end nodes / keys of relational_list) and num_offset2 the global start of the
# metabolite ids (the new end nodes); confirm how the helper applies them.
meta_d_micro_d_meta = generate_long_relationship_array(
    relational_list=disease_metabolite_list,
    intermediate_triplet=disease_microbe_disease,
    num_offset1=num_microbe,
    num_offset2=(num_microbe + num_disease)
)

# NOTE(review): the list presumably gives the column priority of the sort; confirm in helper.
meta_d_micro_d_meta = lexicographical_sort(meta_d_micro_d_meta, [0, 2, 1, 3, 4])

In [None]:
# Save the 2-1-0-1-2 metapath index data in batches.
# Target nodes are the metabolites (global offset num_microbe + num_disease, type code 2).
process_and_save_metapath_idx_in_batches(
    metapath_type=(2, 1, 0, 1, 2),
    metapath_array=meta_d_micro_d_meta,
    target_idx_list=np.arange(num_metabolite),
    offset=num_microbe + num_disease,
    save_prefix=save_prefix,
    group_index=2
)

In [None]:
# Save the 2-1-0-1-2 metapath adjacency lists in batches (target nodes: metabolites).
process_and_save_adjlist_in_batches(
    metapath_type=(2, 1, 0, 1, 2),
    metapath_array=meta_d_micro_d_meta,
    target_idx_list=np.arange(num_metabolite),
    offset=num_microbe + num_disease,
    save_prefix=save_prefix,
    group_index=2,
)

In [None]:
# Drop the references to the 1-0-1 and 2-1-0-1-2 arrays so their memory can be
# reclaimed before the next metapath family is built.
del disease_microbe_disease
del meta_d_micro_d_meta

In [None]:
# 2-0-2 (metabolite-microbe-metabolite)
metabolite_microbe_metabolite = generate_triplet_array(microbe_metabolite_list)
# Columns 0 and 2 are metabolites: shift them into the global numbering
# (metabolites start at num_microbe + num_disease). Column 1 is the microbe and needs no shift.
# NOTE(review): the shifted ids can exceed 32767; confirm the array dtype is wider than int16.
metabolite_microbe_metabolite[:, (0, 2)] += num_microbe + num_disease
# NOTE(review): [0, 2, 1] presumably gives the column priority of the sort; confirm in helper.
metabolite_microbe_metabolite = lexicographical_sort(metabolite_microbe_metabolite, [0, 2, 1])

In [None]:
# Save the 2-0-2 metapath index data in batches.
# Target nodes are the metabolites (global offset num_microbe + num_disease, type code 2).
process_and_save_metapath_idx_in_batches(
    metapath_type=(2, 0, 2),
    metapath_array=metabolite_microbe_metabolite,
    target_idx_list=np.arange(num_metabolite),
    offset=num_microbe + num_disease,
    save_prefix=save_prefix,
    group_index=2
)

In [None]:
# Save the 2-0-2 metapath adjacency lists in batches (target nodes: metabolites).
process_and_save_adjlist_in_batches(
    metapath_type=(2, 0, 2),
    metapath_array=metabolite_microbe_metabolite,
    target_idx_list=np.arange(num_metabolite),
    offset=num_microbe + num_disease,
    save_prefix=save_prefix,
    group_index=2,
)

In [None]:
# 1-2-0-2-1 (disease-metabolite-microbe-metabolite-disease)
# Extends the 2-0-2 triplets at both ends with the diseases of each metabolite.
# NOTE(review): num_offset1 is the global start of the metabolite ids (the triplet
# end nodes / keys of relational_list) and num_offset2 the global start of the
# disease ids (the new end nodes); confirm how the helper applies them.
d_meta_micro_meta_d = generate_long_relationship_array(
    relational_list=metabolite_disease_list,
    intermediate_triplet=metabolite_microbe_metabolite,
    num_offset1=(num_microbe + num_disease),
    num_offset2=num_microbe
)

# NOTE(review): the list presumably gives the column priority of the sort; confirm in helper.
d_meta_micro_meta_d = lexicographical_sort(d_meta_micro_meta_d, [0, 2, 1, 3, 4])

In [None]:
# Save the 1-2-0-2-1 metapath index data in batches.
# Target nodes are the diseases (global offset num_microbe, type code 1).
process_and_save_metapath_idx_in_batches(
    metapath_type=(1, 2, 0, 2, 1),
    metapath_array=d_meta_micro_meta_d,
    target_idx_list=np.arange(num_disease),
    offset=num_microbe,
    save_prefix=save_prefix,
    group_index=1
)

In [None]:
# Save the 1-2-0-2-1 metapath adjacency lists in batches (target nodes: diseases).
process_and_save_adjlist_in_batches(
    metapath_type=(1, 2, 0, 2, 1),
    metapath_array=d_meta_micro_meta_d,
    target_idx_list=np.arange(num_disease),
    offset=num_microbe,
    save_prefix=save_prefix,
    group_index=1,
)

In [None]:
# Drop the references to the 2-0-2 and 1-2-0-2-1 arrays so their memory can be reclaimed.
del metabolite_microbe_metabolite
del d_meta_micro_meta_d

In [None]:
# Persist the graph structure: the adjacency matrix in SciPy CSR .npz form,
# and the per-node type codes (0 microbe, 1 disease, 2 metabolite) as .npy.
save_npz(f"{save_prefix}adjM.npz", csr_matrix(adjM))
np.save(f"{save_prefix}node_types.npy", type_mask)

In [None]:
# Export the microbe-disease edge list as an (n_edges, 2) NumPy array.
# After this cell `microbe_disease` is an ndarray, not a DataFrame.
microbe_disease = (
    pd.read_csv(
        'data/raw/microbe_disease_idx.dat',
        encoding='utf-8',
        delimiter='\t',
        names=['MicrobeID', 'DiseaseID'],
    )[['MicrobeID', 'DiseaseID']]
    .to_numpy()
)
np.save(save_prefix + 'microbe_disease.npy', microbe_disease)

## Split data into training, validation and testing sets

In [None]:
# Split the microbe-disease edges 70/20/10 into train/validation/test and store the split.
# The saved file is read back below with keys 'train', 'val' and 'test', and those
# arrays are used to index rows of microbe_disease, so the split presumably holds
# row positions rather than the edges themselves (confirm in split_date).
# NOTE(review): no random seed is set before this call; confirm split_date seeds
# its own RNG, otherwise the split changes on every run.
md_train, md_val, md_test = split_date(microbe_disease, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1)
save_split_data2npz(md_train, md_val, md_test, "data/raw/micro_disease_train_val_test_idx.npz")

In [None]:
# Load the stored split (training: 70%, validation: 20%, testing: 10%).
train_val_test_idx = np.load("data/raw/micro_disease_train_val_test_idx.npz")
train_idx = train_val_test_idx['train']
val_idx = train_val_test_idx['val']
test_idx = train_val_test_idx['test']

# `microbe_disease` is a NumPy array at this point (it was converted with
# .to_numpy() before being saved), so DataFrame accessors such as .loc and
# .head() are not available. Select the training rows positionally, and keep
# the full edge list untouched so that re-running this cell stays safe.
microbe_disease_train = microbe_disease[train_idx]
print(f"Length of Training data: {len(microbe_disease_train)}")

## Output positive and negative samples for training, validation and testing sets

In [None]:
# Output positive and negative microbe-disease samples for training, validation and testing.
np.random.seed(453289)
save_prefix = 'data/preprocessed/microbe_disease_neg_pos_processed/'
# np.savez does not create missing directories, so make sure the target exists.
os.makedirs(save_prefix, exist_ok=True)

# Node counts come from the exported index files (one row per entity) rather
# than from numbers copied out of an earlier run.
num_microbe = pd.read_csv("data/raw/microbe_index.dat", header=None, names=["Microbe", "Idx"]).shape[0]
num_disease = pd.read_csv("data/raw/disease_index.dat", header=None, names=["Disease", "Idx"]).shape[0]

microbe_disease = np.load('data/preprocessed/microbe_disease.npy')
train_val_test_idx = np.load('data/raw/micro_disease_train_val_test_idx.npz')
train_idx = train_val_test_idx['train']
val_idx = train_val_test_idx['val']
test_idx = train_val_test_idx['test']


def _negative_pairs(positive_pairs, n_rows, n_cols):
    """Return every (row, col) pair of an n_rows x n_cols grid that is not a positive pair.

    positive_pairs is an (n, 2) integer array. The result is an (m, 2) array in
    row-major order, i.e. the order a nested `for row: for col:` loop would
    produce. Unlike a single-pass scan with a running counter, this does not
    require positive_pairs to be sorted or free of duplicates.
    """
    is_positive = np.zeros((n_rows, n_cols), dtype=bool)
    is_positive[positive_pairs[:, 0], positive_pairs[:, 1]] = True
    return np.argwhere(~is_positive)


# Validation/test negatives: sampled without replacement from pairs that are
# not an edge anywhere in the full data set.
neg_candidates = _negative_pairs(microbe_disease, num_microbe, num_disease)

idx = np.random.choice(len(neg_candidates), len(val_idx) + len(test_idx), replace=False)
val_neg_candidates = neg_candidates[np.sort(idx[:len(val_idx)])]
test_neg_candidates = neg_candidates[np.sort(idx[len(val_idx):])]

# Training negatives: every pair that is not a *training* edge. Validation and
# test positives are therefore included here, as in the original scan.
train_microbe_disease = microbe_disease[train_idx]
train_neg_candidates = _negative_pairs(train_microbe_disease, num_microbe, num_disease)

np.savez(save_prefix + 'train_val_test_neg_microbe_disease.npz',
         train_neg_micro_dis=train_neg_candidates,
         val_neg_micro_dis=val_neg_candidates,
         test_neg_micro_dis=test_neg_candidates)
np.savez(save_prefix + 'train_val_test_pos_microbe_disease.npz',
         train_pos_micro_dis=microbe_disease[train_idx],
         val_pos_micro_dis=microbe_disease[val_idx],
         test_pos_micro_dis=microbe_disease[test_idx])