In [3]:
import pandas as pd
import numpy as np
import skbio
from skbio import TreeNode
import dendropy
# from torch_geometric.data import InMemoryDataset, Data

In [4]:
from sklearn.decomposition import PCA

In [5]:
import torch.nn as nn

In [6]:
import biom
from biom.util import biom_open
import torch
from tqdm import tqdm

In [7]:
import numpy as np
import pandas as pd
from skbio import TreeNode
from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import linkage, fcluster, leaves_list

In [8]:
from sklearn.model_selection import KFold

In [9]:
import matplotlib.pyplot as plt

In [10]:
from sklearn.preprocessing import LabelEncoder

In [11]:
kf = KFold(n_splits=5, shuffle=True, random_state=42)

In [12]:
def gaussian_kernel(d, rho=2):
    """
    weight = exp(-2 * ρ * d^2)
    """
    return np.exp(-2 * rho * d ** 2)

In [13]:
def build_graph_from_sample(otu_list, abundance_list, label, sample_id, full_tree):
    """
    [私有方法] 根据单个样本的数据构建一个 PyG Data 对象。
    """
    leaf_abundance_map = {otu: abund for otu, abund in zip(otu_list, abundance_list)}
    
    # tree_sample = full_tree.shear(otu_list)
    tree_sample = full_tree
    all_nodes_in_tree = list(tree_sample.traverse(self_before=True))
    node_to_idx = {node.name: i for i, node in enumerate(all_nodes_in_tree)}

    full_abundance_map = {}
    x = torch.zeros((len(all_nodes_in_tree), 1), dtype=torch.float)
    edge_list, edge_weight_list = [], []
    node_id = []
    for node in tree_sample.postorder():
        node_id.append(node_to_idx[node.name])
        if node.is_tip():
            full_abundance_map[node.name] = leaf_abundance_map.get(node.name, 0.0)
            x[node_to_idx[node.name]] = full_abundance_map[node.name]
        else:
            child_abundances = [full_abundance_map.get(child.name, 0.0) for child in node.children]
            full_abundance_map[node.name] = sum(child_abundances)
            x[node_to_idx[node.name]] = full_abundance_map[node.name]
            
        if not node.is_root():
            node_idx = node_to_idx[node.name]
            parent_idx = node_to_idx[node.parent.name]
            parent_idx = node_to_idx[node.parent.name]
            edge_list.extend([[parent_idx, node_idx], [node_idx, parent_idx]])
            distance = node.length if node.length is not None else 0.0
            weight = gaussian_kernel(distance, rho=5e-5) 
            edge_weight_list.extend([weight, weight])
            
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    y = torch.tensor([label], dtype=torch.long)
    x = torch.tensor(x, dtype=torch.float)
    edge_attr = torch.tensor(edge_weight_list, dtype=torch.float).view(-1, 1)
    node_id = torch.tensor(node_id, dtype=torch.long)
    
    data = Data(x=x, node_id=node_id, edge_index=edge_index, edge_attr=edge_attr, y=y, sample_id=sample_id)

    return data, x, edge_index, edge_attr, node_to_idx

In [208]:
newick_str = (
    "("
    "(L1:0.1, L2:0.1)I1:0.1,"  # 2 leaves, 1 internal
    "(L3:0.1,"
      "(L4:0.1,"
        "(L5:0.1,"
          "(L6:0.1,"
            "(L7:0.1,"
              "(L8:0.1,"
                "(L9:0.1, L10:0.1)I8:0.1" # 2 leaves, 1 internal
              ")I7:0.1"
            ")I6:0.1"
          ")I5:0.1"
        ")I4:0.1"
      ")I3:0.1"
    ")I2:0.1"
    ")Root;"
)

tree = TreeNode.read([newick_str], format='newick')
if tree.name is None: tree.name = "Root"
    
all_nodes = list(tree.traverse())
leaves = list(tree.tips())
internal_nodes = [n for n in all_nodes if not n.is_tip()]

print(f"--- 树结构统计 ---")
print(f"总节点数: {len(all_nodes)}")
print(f"叶子节点数: {len(leaves)} (预期10)")
print(f"内部节点数: {len(internal_nodes)} (预期10)")

--- 树结构统计 ---
总节点数: 19
叶子节点数: 10 (预期10)
内部节点数: 9 (预期10)


In [209]:
print(tree.ascii_art())

                    /-L1
          /I1------|
         |          \-L2
-Root----|
         |          /-L3
          \I2------|
                   |          /-L4
                    \I3------|
                             |          /-L5
                              \I4------|
                                       |          /-L6
                                        \I5------|
                                                 |          /-L7
                                                  \I6------|
                                                           |          /-L8
                                                            \I7------|
                                                                     |          /-L9
                                                                      \I8------|
                                                                                \-L10


In [197]:
embed = nn.Embedding(19, 8)

In [206]:
embed(torch.tensor([0,2,4]))

tensor([[ 0.8767, -0.1226, -0.0716,  0.1161,  0.6367, -0.1893,  1.1291, -0.4544],
        [-0.5683,  1.0945, -0.3673, -0.9997, -0.6321, -0.6187, -0.0289, -1.5410],
        [-1.5725,  1.5881, -0.4956, -1.1765, -0.2684,  0.6021, -1.9315, -2.3289]],
       grad_fn=<EmbeddingBackward0>)

In [210]:
# 构造叶子名称列表
otu_list = [node.name for node in leaves] # ['L1', 'L2', ... 'L10']


abundance_list = [0.1] * 10

sample_id = "Test_Sample_001"
label = 1

print(f"\n--- 输入数据 ---")
print(f"OTU List: {otu_list}")
print(f"Abundances: {abundance_list}")


--- 输入数据 ---
OTU List: ['L1', 'L2', 'L3', 'L4', 'L5', 'L6', 'L7', 'L8', 'L9', 'L10']
Abundances: [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]


In [211]:
data, x, edge_index, edge_attr, node_to_idx = build_graph_from_sample(otu_list, abundance_list, label, sample_id, tree)

  x = torch.tensor(x, dtype=torch.float)


In [167]:
full_tree = TreeNode.read("/home/dongbiao/GCN/data/dietary_fiber/phylogeny.nwk")
table = biom.load_table("/home/dongbiao/GCN/data/dietary_fiber/table.biom")
fid = table.ids(axis='observation')
df_metadata = pd.read_csv("/home/dongbiao/GCN/data/dietary_fiber/metadata.tsv", index_col=0, sep='\t')

table = table.norm(axis='sample', inplace=False)
table = table.to_dataframe().T

In [168]:
all_nodes_in_tree = [node.name for node in full_tree.postorder()]
len(all_nodes_in_tree)

4905

In [170]:
data_list = []
processed_sample_ids = []

for sample_id in tqdm(df_metadata.index):
    sample_data = np.array(table.loc[sample_id])

    # nonzero_indices = sample_data > 0  
    # valid_otus = fid[nonzero_indices]
    # sample_data = sample_data[nonzero_indices]
    label = df_metadata.loc[sample_id, "group"]
    
    data = build_graph_from_sample(fid, sample_data, label, sample_id, full_tree)
    
    if data is not None:
        data_list.append(data)
        processed_sample_ids.append(sample_id)

100%|██████████| 2270/2270 [01:57<00:00, 19.26it/s]


In [171]:
data

Data(x=[4905], edge_index=[2, 9808], edge_attr=[9808, 1], y=[1], sample_id='SRR6445039')

In [774]:
def filter_low_abundance_asvs(abundance_df, min_relative_abundance=0.001, min_sample_proportion=0.01):
    """
    """
    n_samples = abundance_df.shape[1]
    min_samples = int(np.ceil(n_samples * min_sample_proportion))  
    
    above_threshold = abundance_df >= min_relative_abundance
    samples_above_threshold = above_threshold.sum(axis=1)
    
    asv_to_keep = samples_above_threshold >= min_samples
    asv_retained = abundance_df.index[asv_to_keep].tolist()
    asv_removed = abundance_df.index[~asv_to_keep].tolist()
            
    return asv_retained, asv_removed

### IBD

In [939]:
table_1 = biom.load_table("/home/dongbiao/GCN/data/IBD/merged_table.biom")
table_2 = biom.load_table("/home/dongbiao/GCN/data/IBD/PRJNA324147.biom")
table_3 = biom.load_table("/home/dongbiao/GCN/data/IBD/PRJNA368966.biom")
table_4 = biom.load_table("/home/dongbiao/GCN/data/IBD/PRJNA450340.biom")

In [940]:
table = table_1.merge(table_2)
table = table.merge(table_3)
table = table.merge(table_4)

In [941]:
metadata = pd.read_csv("/home/dongbiao/GCN/data/IBD/metadata_copy.tsv", sep="\t", index_col = 0)

In [942]:
metadata = metadata.fillna("id")

In [943]:
metadata.loc[metadata.subject_id.values == "id", "subject_id"] = metadata.index.values[metadata.subject_id.values == "id"]

In [944]:
metadata.study.unique()

array(['RISK_PRISM_f', 'qiita_1629', 'qiita_2538', 'PRJNA324147',
       'PRJNA368966', 'PRJNA422193', 'PRJNA431126', 'PRJNA450340'],
      dtype=object)

In [945]:
metadata.head()

Unnamed: 0_level_0,study,diagnosis,group,age,sex,bmi,country,region,pcr_primers,platform,disease_name,title,disease_name_ab,site,subject_id
sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
ERR1368879,RISK_PRISM_f,CD,1,19.0,male,id,515F-806R,V4,Illumina MiSeq,USA,id,id,IBD,feces,ERR1368879
ERR1368880,RISK_PRISM_f,CD,1,26.0,male,id,515F-806R,V4,Illumina MiSeq,USA,id,id,IBD,feces,ERR1368880
ERR1368881,RISK_PRISM_f,UC,1,55.0,male,id,515F-806R,V4,Illumina MiSeq,USA,id,id,IBD,feces,ERR1368881
ERR1368882,RISK_PRISM_f,CD,1,57.0,female,id,515F-806R,V4,Illumina MiSeq,USA,id,id,IBD,feces,ERR1368882
ERR1368883,RISK_PRISM_f,IC,1,46.0,male,id,515F-806R,V4,Illumina MiSeq,USA,id,id,IBD,feces,ERR1368883


In [946]:
table = table.filter(metadata.index.values, axis="sample")
table.remove_empty()
fid = table.ids(axis="observation")
prevalence = table.nonzero_counts(axis="observation")
table = table.filter(fid[prevalence > 10], axis='observation', inplace=False)
table.remove_empty()

2387 x 2491 <class 'biom.table.Table'> with 365600 nonzero entries (6% dense)

In [947]:
otu_retained, otu_removed = filter_low_abundance_asvs(table.norm(axis='sample', inplace=False).to_dataframe())

In [948]:
len(otu_retained)

805

In [949]:
table = table.filter(otu_retained, axis='observation', inplace=False)
table.remove_empty()

805 x 2491 <class 'biom.table.Table'> with 282395 nonzero entries (14% dense)

In [950]:
metadata = metadata.loc[table.ids(axis="sample")]

In [951]:
metadata.to_csv("/home/dongbiao/GCN/data/IBD/metadata.tsv", sep="\t")

In [952]:
with biom.util.biom_open("/home/dongbiao/GCN/data/IBD/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

In [953]:
table = biom.load_table("/home/dongbiao/GCN/data/IBD/table.biom")

In [954]:
table

805 x 2491 <class 'biom.table.Table'> with 282395 nonzero entries (14% dense)

In [955]:
metadata = pd.read_csv("/home/dongbiao/GCN/data/IBD/metadata.tsv", sep="\t", index_col=0)

In [898]:
study_id = metadata.study.unique()

In [957]:
# n = 1
# for i in study_id:
#     sid = metadata.loc[metadata.study != i].index.values
#     table_train = table.filter(sid, axis="sample", inplace=False)
#     with biom.util.biom_open(f"/home/dongbiao/GCN/data/IBD/data/train_{n}.biom", 'w') as f:
#         table_train.to_hdf5(f, 'example')
#     sid = metadata.loc[metadata.study == i].index.values
#     table_test = table.filter(sid, axis="sample", inplace=False)
#     with biom.util.biom_open(f"/home/dongbiao/GCN/data/IBD/data/test_{n}.biom", 'w') as f:
#         table_test.to_hdf5(f, 'example')
        
#     table_train = table_train.norm(axis='sample', inplace=False).to_dataframe().T
#     table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
#     table_train.to_csv(f"/home/dongbiao/GCN/data/IBD/data/train_{n}.csv")
#     table_test = table_test.norm(axis='sample', inplace=False).to_dataframe().T
#     table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
#     table_test.to_csv(f"/home/dongbiao/GCN/data/IBD/data/test_{n}.csv")
    
#     n = n + 1
n = 1
suject_id = metadata.subject_id.unique()
for fold, (train_idx, test_idx) in enumerate(kf.split(suject_id)):

    train_sid = metadata.index.values[[i in suject_id[train_idx] for i in metadata.subject_id.values]]
    test_sid = metadata.index.values[[i in suject_id[test_idx] for i in metadata.subject_id.values]]
    
    table_train = table.filter(train_sid, axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/IBD/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    table_test = table.filter(test_sid, axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/IBD/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')
        
    table_train = table_train.norm(axis='sample', inplace=False).to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/GCN/data/IBD/data/train_{n}.csv")
    table_test = table_test.norm(axis='sample', inplace=False).to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/GCN/data/IBD/data/test_{n}.csv")
    
    n = n + 1

### IBD version 2

In [977]:
table = pd.read_csv("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/16S_IBD.csv")

In [979]:
metadata = pd.DataFrame({"sample_id": table.iloc[:, 0].values, "group": table["Group"].values})
metadata["group"] = [0 if i == "HC" else 1 for i in metadata["group"].values]
metadata.to_csv("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/metadata.tsv", sep="\t", index=False)

In [980]:
table = biom.Table(data=table.iloc[:, 1:417].values.T, observation_ids=table.iloc[:, 1:417].columns.values, sample_ids=table.iloc[:, 0].values)
with biom.util.biom_open("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

In [981]:
metadata = metadata.set_index("sample_id")

In [982]:
table = biom.load_table("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/table.biom")
kf = KFold(n_splits=5, shuffle=True, random_state=42)
sid = table.ids(axis="sample")
n = 1
for train_idx, test_idx in kf.split(sid):
    table_train = table.filter(sid[train_idx], axis="sample", inplace=False)
    table_train.remove_empty()
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    
    table_test = table.filter(sid[test_idx], axis="sample", inplace=False)
    table_test.remove_empty()
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/data/train_{n}.csv")
    table_test = table_test.to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/data/test_{n}.csv")
        
    n = n + 1

In [336]:
for i in range(1, 6):
    table_train = biom.load_table(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/data/train_{i}.biom")
    table_test = biom.load_table(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/data/test_{i}.biom")
    for m in ["without_low", "without_high"]:
        for n in [0.1, 0.2, 0.4, 0.5, 0.6, 0.8]:
            tree = TreeNode.read(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/phylogeny_{m}_{n}.nwk")
            leaf_names = [node.name for node in tree.tips()]
            
            table_train_pick = table_train.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/data/train_{m}_{n}_{i}.biom", 'w') as f:
                table_train_pick.to_hdf5(f, 'example')
                
            table_test_pick = table_test.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/data/test_{m}_{n}_{i}.biom", 'w') as f:
                table_test_pick.to_hdf5(f, 'example')

### CRC

In [825]:
table_1 = biom.load_table("/home/dongbiao/GCN/data/CRC/merged_table.biom")
table_2 = biom.load_table("/home/dongbiao/GCN/data/CRC/PRJDB11845.biom")
table_3 = biom.load_table("/home/dongbiao/GCN/data/CRC/PRJEB6070.biom")
table_4 = biom.load_table("/home/dongbiao/GCN/data/CRC/PRJNA290926.biom")

In [826]:
table = table_1.merge(table_2)
table = table.merge(table_3)
table = table.merge(table_4)

In [827]:
metadata = pd.read_csv("/home/dongbiao/GCN/data/CRC/metadata_copy.tsv", sep="\t", index_col = 0)

In [828]:
fid = table.ids(axis="observation")
prevalence = table.nonzero_counts(axis="observation")
table = table.filter(fid[prevalence > 5], axis='observation', inplace=False)
table.remove_empty()

3041 x 1016 <class 'biom.table.Table'> with 174928 nonzero entries (5% dense)

In [829]:
otu_retained, otu_removed = filter_low_abundance_asvs(table.norm(axis='sample', inplace=False).to_dataframe())

In [830]:
table = table.filter(otu_retained, axis='observation', inplace=False)
table.remove_empty()

1105 x 1015 <class 'biom.table.Table'> with 136819 nonzero entries (12% dense)

In [831]:
metadata = metadata.loc[table.ids(axis="sample")]
metadata.to_csv("/home/dongbiao/GCN/data/CRC/metadata.tsv", sep="\t")

In [832]:
with biom.util.biom_open("/home/dongbiao/GCN/data/CRC/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

In [762]:
table = biom.load_table("/home/dongbiao/GCN/data/CRC/table.biom")

In [833]:
study_id = metadata.study.unique()
n = 1
for i in study_id:
    sid = metadata.loc[metadata.study != i].index.values
    table_train = table.filter(sid, axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/CRC/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    sid = metadata.loc[metadata.study == i].index.values
    table_test = table.filter(sid, axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/CRC/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.norm(axis='sample', inplace=False).to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/GCN/data/CRC/data/train_{n}.csv")
    table_test = table_test.norm(axis='sample', inplace=False).to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/GCN/data/CRC/data/test_{n}.csv")
    
    n = n + 1

### CRC 16S

In [1028]:
table = pd.read_csv("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/16S_CRC.csv")

metadata = pd.DataFrame({"sample_id": table.iloc[:, 0].values, "group": table["Group"].values})
metadata["group"] = [0 if i == "HC" else 1 for i in metadata["group"].values]
metadata.to_csv("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/metadata.tsv", sep="\t", index=False)

table = biom.Table(data=table.iloc[:, 1:-1].values.T, observation_ids=table.iloc[:, 1:-1].columns.values, sample_ids=table.iloc[:, 0].values)
with biom.util.biom_open("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

metadata = metadata.set_index("sample_id")
table = biom.load_table("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/table.biom")

kf = KFold(n_splits=5, shuffle=True, random_state=42)
sid = table.ids(axis="sample")
n = 1
for train_idx, test_idx in kf.split(sid):
    table_train = table.filter(sid[train_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    
    table_test = table.filter(sid[test_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/data/train_{n}.csv")
    table_test = table_test.to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/data/test_{n}.csv")
        
    n = n + 1

In [334]:
for i in range(1, 6):
    table_train = biom.load_table(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/data/train_{i}.biom")
    table_test = biom.load_table(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/data/test_{i}.biom")
    for m in ["without_low", "without_high"]:
        for n in [0.1, 0.2, 0.4, 0.5, 0.6, 0.8]:
            tree = TreeNode.read(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/phylogeny_{m}_{n}.nwk")
            leaf_names = [node.name for node in tree.tips()]
            
            table_train_pick = table_train.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/data/train_{m}_{n}_{i}.biom", 'w') as f:
                table_train_pick.to_hdf5(f, 'example')
                
            table_test_pick = table_test.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/data/test_{m}_{n}_{i}.biom", 'w') as f:
                table_test_pick.to_hdf5(f, 'example')

### CRC WGS

In [1031]:
table = pd.read_csv("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/WGS_CRC.csv", sep="\t")

metadata = pd.DataFrame({"sample_id": table.iloc[:, 0].values, "group": table["Group"].values})
metadata["group"] = [0 if i == "CTR" else 1 for i in metadata["group"].values]
metadata.to_csv("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/metadata.tsv", sep="\t", index=False)

fid = table.iloc[:, 1:-1].columns.values.astype(str)
fid = np.char.replace(fid, '__', '.')
fid = np.char.replace(fid, '_', '.')
table = biom.Table(data=table.iloc[:, 1:-1].values.T, observation_ids=fid, sample_ids=table.iloc[:, 0].values)
with biom.util.biom_open("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

metadata = metadata.set_index("sample_id")
table = biom.load_table("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/table.biom")

kf = KFold(n_splits=5, shuffle=True, random_state=42)
sid = table.ids(axis="sample")
n = 1
for train_idx, test_idx in kf.split(sid):
    table_train = table.filter(sid[train_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    
    table_test = table.filter(sid[test_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/data/train_{n}.csv")
    table_test = table_test.to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/data/test_{n}.csv")
        
    n = n + 1

In [325]:
for i in range(1, 6):
    table_train = biom.load_table(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/data/train_{i}.biom")
    table_test = biom.load_table(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/data/test_{i}.biom")
    for m in ["without_low", "without_high"]:
        for n in [0.1, 0.2, 0.4, 0.5, 0.6, 0.8]:
            tree = TreeNode.read(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/phylogeny_{m}_{n}.nwk")
            leaf_names = [node.name for node in tree.tips()]
            
            table_train_pick = table_train.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/data/train_{m}_{n}_{i}.biom", 'w') as f:
                table_train_pick.to_hdf5(f, 'example')
                
            table_test_pick = table_test.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/data/test_{m}_{n}_{i}.biom", 'w') as f:
                table_test_pick.to_hdf5(f, 'example')

### T2D WGS

In [1043]:
table = pd.read_csv("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/WGS_T2D.csv")

metadata = pd.DataFrame({"sample_id": table.iloc[:, 0].values, "group": table["Group"].values})
metadata["group"] = [0 if i == "Health" else 1 for i in metadata["group"].values]
metadata.to_csv("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/metadata.tsv", sep="\t", index=False)

fid = table.iloc[:, 1:-1].columns.values.astype(str)
fid = np.char.replace(fid, '__', '.')
fid = np.char.replace(fid, '_', '.')
table = biom.Table(data=table.iloc[:, 1:-1].values.T, observation_ids=fid, sample_ids=table.iloc[:, 0].values)
with biom.util.biom_open("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

metadata = metadata.set_index("sample_id")
table = biom.load_table("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/table.biom")

kf = KFold(n_splits=5, shuffle=True, random_state=42)
sid = table.ids(axis="sample")
n = 1
for train_idx, test_idx in kf.split(sid):
    table_train = table.filter(sid[train_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    
    table_test = table.filter(sid[test_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/data/train_{n}.csv")
    table_test = table_test.to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/data/test_{n}.csv")
        
    n = n + 1

In [331]:
for i in range(1, 6):
    table_train = biom.load_table(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/data/train_{i}.biom")
    table_test = biom.load_table(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/data/test_{i}.biom")
    for m in ["without_low", "without_high"]:
        for n in [0.1, 0.2, 0.4, 0.5, 0.6, 0.8]:
            tree = TreeNode.read(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/phylogeny_{m}_{n}.nwk")
            leaf_names = [node.name for node in tree.tips()]
            
            table_train_pick = table_train.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/data/train_{m}_{n}_{i}.biom", 'w') as f:
                table_train_pick.to_hdf5(f, 'example')
                
            table_test_pick = table_test.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/data/test_{m}_{n}_{i}.biom", 'w') as f:
                table_test_pick.to_hdf5(f, 'example')

### Dietary fiber

In [277]:
metadata = pd.read_excel("/home/dongbiao/word_embedding_microbiome/programe_test/dietary_fber/data/metadata.xls", index_col=0)
table = biom.load_table("/home/dongbiao/word_embedding_microbiome/programe_test/dietary_fber/data/projects/table_6721_2378.biom")
metadata = metadata.loc[table.ids(axis="sample")]
metadata = metadata.loc[metadata.study.values != "PRJEB2165"]
metadata = metadata.loc[metadata.study.values != "PRJNA385004"]
table = table.filter(metadata.index.values)
table.remove_empty()

6693 x 2270 <class 'biom.table.Table'> with 313231 nonzero entries (2% dense)

In [278]:
fid = table.ids(axis="observation")
prevalence = table.nonzero_counts(axis="observation")
table = table.filter(fid[prevalence > 50], axis='observation', inplace=False)
table.remove_empty()

930 x 2270 <class 'biom.table.Table'> with 252062 nonzero entries (11% dense)

In [279]:
metadata = metadata.loc[table.ids(axis="sample")]
table = table.norm(axis='sample', inplace=False)

In [280]:
with biom.util.biom_open(f"/home/dongbiao/GCN/data/dietary_fiber/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

In [281]:
kf = KFold(n_splits=5, shuffle=True, random_state=42)
sid = table.ids(axis="sample")
n = 1
for train_idx, test_idx in kf.split(sid):
    table_train = table.filter(sid[train_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/dietary_fiber/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    
    table_test = table.filter(sid[test_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/dietary_fiber/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/GCN/data/dietary_fiber/data/train_{n}.csv")
    table_test = table_test.to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/GCN/data/dietary_fiber/data/test_{n}.csv")
        
    n = n + 1

In [324]:
for i in range(1, 6):
    table_train = biom.load_table(f"/home/dongbiao/GCN/data/dietary_fiber/data/train_{i}.biom")
    table_test = biom.load_table(f"/home/dongbiao/GCN/data/dietary_fiber/data/test_{i}.biom")
    for m in ["without_low", "without_high"]:
        for n in [0.1, 0.2, 0.4, 0.5, 0.6, 0.8]:
            tree = TreeNode.read(f"/home/dongbiao/GCN/data/dietary_fiber/phylogeny_{m}_{n}.nwk")
            leaf_names = [node.name for node in tree.tips()]
            
            table_train_pick = table_train.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/GCN/data/dietary_fiber/data/train_{m}_{n}_{i}.biom", 'w') as f:
                table_train_pick.to_hdf5(f, 'example')
                
            table_test_pick = table_test.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/GCN/data/dietary_fiber/data/test_{m}_{n}_{i}.biom", 'w') as f:
                table_test_pick.to_hdf5(f, 'example')

In [250]:
study_id = metadata.study.unique()
study_id

array(['PRJEB41443', 'PRJNA560950', 'PRJNA780023', 'PRJNA891951',
       'PRJNA293971', 'PRJNA306884', 'PRJNA428736'], dtype=object)

In [251]:
# PRJEB41443
temp = metadata.loc[metadata.study == study_id[0]]
sid_1 = temp.loc[temp.timepoint_numeric == 1].index.values
sid_2 = temp.loc[temp.group == 1].index.values
PRJEB41443_sid = list(sid_1) + list(sid_2)

In [252]:
# PRJNA560950
metadata.loc[metadata.study == study_id[1]].groupby(["subject_id", "group", "timepoint_id"]).count()
temp = metadata.loc[metadata.study == study_id[1]]
PRJNA560950_sid = temp.loc[temp.timepoint_id == "W4"].index.tolist()

In [253]:
# PRJNA780023
temp = metadata.loc[metadata.study == study_id[2]]
sid_1 = temp.loc[temp.group == 0].drop_duplicates(subset='subject_id').index.values
sid_2 = temp.loc[temp.group == 1].drop_duplicates(subset='subject_id').index.values
PRJNA780023_sid = list(sid_1) + list(sid_2) 

In [254]:
# PRJNA891951
temp = metadata.loc[metadata.study == study_id[3]]
PRJNA891951_sid = temp.loc[temp.timepoint == "after"].index.tolist()

In [255]:
# PRJNA293971
temp = metadata.loc[metadata.study == study_id[4]]
PRJNA293971_sid = temp.loc[[i in ["before_1", "after_3"] for i in temp.timepoint_id.values]].index.tolist()

In [256]:
# PRJNA306884
temp = metadata.loc[metadata.study == study_id[5]]
PRJNA306884_sid = temp.loc[[i in [4, 8] for i in temp.timepoint_numeric.values]].index.tolist()

In [257]:
# PRJNA428736
temp = metadata.loc[metadata.study == study_id[6]]
sid_1 = temp.loc[temp.group == 0].drop_duplicates(subset='subject_id').index.values
sid_2 = temp.loc[temp.group == 1].drop_duplicates(subset='subject_id').index.values
PRJNA428736_sid = list(sid_1) + list(sid_2) 

In [258]:
pick_sid = PRJEB41443_sid + PRJNA560950_sid + PRJNA780023_sid + PRJNA891951_sid + PRJNA293971_sid + PRJNA306884_sid + PRJNA428736_sid
table.filter(pick_sid, axis="sample")
table.remove_empty()

6235 x 731 <class 'biom.table.Table'> with 105646 nonzero entries (2% dense)

In [259]:
fid = table.ids(axis="observation")
prevalence = table.nonzero_counts(axis="observation")
table = table.filter(fid[prevalence > 20], axis='observation', inplace=False)
table.remove_empty()

936 x 731 <class 'biom.table.Table'> with 80180 nonzero entries (11% dense)

In [260]:
with biom.util.biom_open(f"/home/dongbiao/GCN/data/dietary_fiber/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

In [261]:
metadata = metadata.loc[table.ids(axis="sample")]

In [262]:
table = table.norm(axis='sample', inplace=False)

In [263]:
kf = KFold(n_splits=5, shuffle=True, random_state=42)
sid = table.ids(axis="sample")
n = 1
for train_idx, test_idx in kf.split(sid):
    table_train = table.filter(sid[train_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/dietary_fiber/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    
    table_test = table.filter(sid[test_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/dietary_fiber/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/GCN/data/dietary_fiber/data/train_{n}.csv")
    table_test = table_test.to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/GCN/data/dietary_fiber/data/test_{n}.csv")
        
    n = n + 1

In [308]:
for i in range(1, 6):
    table_train = biom.load_table(f"/home/dongbiao/GCN/data/dietary_fiber/data/train_{i}.biom")
    table_test = biom.load_table(f"/home/dongbiao/GCN/data/dietary_fiber/data/test_{i}.biom")
    for m in ["without_low", "without_high"]:
        for n in [0.1, 0.2, 0.4, 0.5, 0.6, 0.8]:
            tree = TreeNode.read(f"/home/dongbiao/GCN/data/dietary_fiber/phylogeny_{m}_{n}.nwk")
            leaf_names = [node.name for node in tree.tips()]
            
            table_train_pick = table_train.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/GCN/data/dietary_fiber/data/train_{m}_{n}_{i}.biom", 'w') as f:
                table_train_pick.to_hdf5(f, 'example')
                
            table_test_pick = table_test.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/GCN/data/dietary_fiber/data/test_{m}_{n}_{i}.biom", 'w') as f:
                table_test_pick.to_hdf5(f, 'example')

### OSCC

In [1046]:
table = biom.load_table("/home/dongbiao/word_embedding_microbiome/all_data/oral/global-oral-microbiome/data/table_gg2_id99_24409_56564_v5.biom")

In [1047]:
metadata = pd.read_csv("/home/dongbiao/GCN/data/OSCC/metadata.txt", sep="\t", index_col=0, low_memory=False)
metadata = metadata.loc[metadata.diagnose != "OPMD"]
metadata.loc[:, "group"] = [1 if i == "OSCC" else 0 for i in metadata.diagnose.values]

In [1048]:
table_filter = None  
for i, study_id in enumerate(metadata.project_ID.unique()):
    metadata_study = metadata.loc[metadata.project_ID.values == study_id]
    sid = np.intersect1d(metadata_study.index.values, table.ids(axis="sample"))
    temp = table.filter(sid, inplace=False)
    temp.remove_empty()
    
    fid = temp.ids(axis="observation")
    prevalence = temp.nonzero_counts(axis="observation")
    temp = temp.filter(fid[prevalence > 10], axis='observation', inplace=False)
    temp.remove_empty()
    
    if table_filter is None:
        table_filter = temp
    else:
        table_filter = table_filter.merge(temp)

In [1049]:
table_filter

864 x 631 <class 'biom.table.Table'> with 88061 nonzero entries (16% dense)

In [1050]:
otu_retained, otu_removed = filter_low_abundance_asvs(table_filter.norm(axis='sample', inplace=False).to_dataframe())
table_filter = table_filter.filter(otu_retained, axis='observation', inplace=False)
table_filter.remove_empty()

569 x 631 <class 'biom.table.Table'> with 79776 nonzero entries (22% dense)

In [1051]:
with biom.util.biom_open("/home/dongbiao/GCN/data/OSCC/table.biom", 'w') as f:
    table_filter.to_hdf5(f, 'example')

In [862]:
metadata = metadata.loc[table_filter.ids(axis="sample")]
metadata.to_csv("/home/dongbiao/GCN/data/OSCC/metadata.tsv", sep="\t")

In [1052]:
metadata = pd.read_csv("/home/dongbiao/GCN/data/OSCC/metadata.tsv", sep="\t", index_col=0)
table = biom.load_table("/home/dongbiao/GCN/data/OSCC/table.biom")

In [1053]:
# study_id = metadata.project_ID.unique()
# n = 1
# for i in study_id:
#     sid = metadata.loc[metadata.project_ID != i].index.values
#     table_train = table.filter(sid, axis="sample", inplace=False)
#     with biom.util.biom_open(f"/home/dongbiao/GCN/data/OSCC/data/train_{n}.biom", 'w') as f:
#         table_train.to_hdf5(f, 'example')
#     sid = metadata.loc[metadata.project_ID == i].index.values
#     table_test = table.filter(sid, axis="sample", inplace=False)
#     with biom.util.biom_open(f"/home/dongbiao/GCN/data/OSCC/data/test_{n}.biom", 'w') as f:
#         table_test.to_hdf5(f, 'example')

#     table_train = table_train.norm(axis='sample', inplace=False).to_dataframe().T
#     table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
#     table_train.to_csv(f"/home/dongbiao/GCN/data/OSCC/data/train_{n}.csv")
#     table_test = table_test.norm(axis='sample', inplace=False).to_dataframe().T
#     table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
#     table_test.to_csv(f"/home/dongbiao/GCN/data/OSCC/data/test_{n}.csv")
    
#     n = n + 1
kf = KFold(n_splits=5, shuffle=True, random_state=42)
sid = table.ids(axis="sample")
n = 1
for train_idx, test_idx in kf.split(sid):
    table_train = table.filter(sid[train_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/OSCC/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    
    table_test = table.filter(sid[test_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/OSCC/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/GCN/data/OSCC/data/train_{n}.csv")
    table_test = table_test.to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/GCN/data/OSCC/data/test_{n}.csv")
        
    n = n + 1

### Multi-status classsification

In [53]:
table = pd.read_csv("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/Multi-classification.csv")
metadata = pd.DataFrame({"sample_id": table.iloc[:, 0].values, "disease": table["Group"].values,
                         "group": encoder.fit_transform(table["Group"].values)})
metadata.to_csv("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/metadata.tsv", sep="\t", index=False)

fid = table.iloc[:, 1:-1].columns.values.astype(str)
table = biom.Table(data=table.iloc[:, 1:-1].values.T, observation_ids=fid, sample_ids=table.iloc[:, 0].values)
with biom.util.biom_open("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

metadata = metadata.set_index("sample_id")
table = biom.load_table("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/table.biom")

kf = KFold(n_splits=5, shuffle=True, random_state=42)
sid = table.ids(axis="sample")
n = 1
for train_idx, test_idx in kf.split(sid):
    table_train = table.filter(sid[train_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    
    table_test = table.filter(sid[test_idx], axis="sample", inplace=False)
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/data/train_{n}.csv")
    table_test = table_test.to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/data/test_{n}.csv")
        
    n = n + 1

In [337]:
for i in range(1, 6):
    table_train = biom.load_table(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/data/train_{i}.biom")
    table_test = biom.load_table(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/data/test_{i}.biom")
    for m in ["without_low", "without_high"]:
        for n in [0.1, 0.2, 0.4, 0.5, 0.6, 0.8]:
            tree = TreeNode.read(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/phylogeny_{m}_{n}.nwk")
            leaf_names = [node.name for node in tree.tips()]
            
            table_train_pick = table_train.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/data/train_{m}_{n}_{i}.biom", 'w') as f:
                table_train_pick.to_hdf5(f, 'example')
                
            table_test_pick = table_test.filter(leaf_names, axis="observation", inplace=False)
            with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/data/test_{m}_{n}_{i}.biom", 'w') as f:
                table_test_pick.to_hdf5(f, 'example')

### Synthetic data

In [225]:
class GreengenesSyntheticGenerator:
    def __init__(self, tree_path, n_features=60, random_seed=42):
        """
        初始化生成器：基于真实的 Greengenes 树文件。
        """
        self.n_features = n_features
        self.rng = np.random.default_rng()
        
        print(f"正在加载进化树: {tree_path} ...")
        self.full_tree = TreeNode.read(tree_path)
        
        # 随机选择 OTU
        all_tips = [node.name for node in self.full_tree.tips()]
            
        self.selected_otus = self.rng.choice(all_tips, size=n_features, replace=False)
        self.leaf_names = list(self.selected_otus)
        
        # 剪切树并计算距离
        print("正在计算系统发育距离矩阵...")
        self.sheared_tree = self.full_tree.shear(self.leaf_names)
        dm = self.sheared_tree.tip_tip_distances()
        
        self.phylo_dist_matrix = np.zeros((n_features, n_features))
        for i, otu1 in enumerate(self.leaf_names):
            for j, otu2 in enumerate(self.leaf_names):
                self.phylo_dist_matrix[i, j] = dm[otu1, otu2]

        # =================================================================
        # 【修改重点】：强制平衡划分 Clade 1 和 Clade 2
        # =================================================================
        # 1. 使用 'ward' 方法进行聚类，它倾向于产生大小更均匀的簇（相比 average/single）
        #    注意：虽然 UPGMA (average) 更符合生物学进化，但为了模拟平衡数据，Ward 更稳健
        Z = linkage(squareform(self.phylo_dist_matrix), method='ward')
        
        # 2. 获取按树结构排序后的叶子节点索引列表
        #    leaves_list 返回的顺序保证了相邻的索引在树上是亲缘关系最近的
        ordered_indices = leaves_list(Z)
        
        # 3. 强制一分为二 (例如 60 -> 30 + 30)
        #    这样既保证了数量平衡，又因为基于树排序，保证了组内的系统发育相似性
        midpoint = len(ordered_indices) // 2
        clade1_indices = ordered_indices[:midpoint]
        clade2_indices = ordered_indices[midpoint:]
        
        # 4. 生成标签数组
        self.cluster_labels = np.zeros(n_features, dtype=int)
        self.cluster_labels[clade1_indices] = 1
        self.cluster_labels[clade2_indices] = 2
        
        print(f"聚类完成: Clade 1 ({len(clade1_indices)} 个), Clade 2 ({len(clade2_indices)} 个)")

    def _generate_base_abundance(self, n_samples, pattern_type):
        data = np.zeros((n_samples, self.n_features))
        
        clade1_idx = np.where(self.cluster_labels == 1)[0]
        clade2_idx = np.where(self.cluster_labels == 2)[0]
        
        if pattern_type == 'S1':
            # S1 主要在 Clade 1 (Cluster 1)
            active_idx = clade1_idx
            # if len(clade2_idx) > 0:
            #     active_idx = np.concatenate([active_idx, self.rng.choice(clade2_idx, size=int(len(clade2_idx)*0.1))])
        else:
            # S2 主要在 Clade 2 (Cluster 2)
            active_idx = clade2_idx
            # if len(clade1_idx) > 0:
            #     active_idx = np.concatenate([active_idx, self.rng.choice(clade1_idx, size=int(len(clade1_idx)*0.1))])
            
        for i in range(n_samples):
            base_vals = self.rng.dirichlet(np.ones(len(active_idx)), size=1)[0]
            data[i, active_idx] = base_vals
        return data

    def generate_dataset_1(self):
        """
        生成 Synthetic data set 1
        修改说明：
        - S3 (从 S1 扰动而来): 被扰动的特征从 S2 的主要特征 (Clade 2) 中选取。
        - S4 (从 S2 扰动而来): 被扰动的特征从 S1 的主要特征 (Clade 1) 中选取。
        """
        n_total_group1 = 77
        n_total_group2 = 71
        
        # 1. 生成基础数据
        data_g1 = self._generate_base_abundance(n_total_group1, 'S1')
        data_g2 = self._generate_base_abundance(n_total_group2, 'S2')
        
        labels_g1 = np.zeros(n_total_group1)
        labels_g2 = np.ones(n_total_group2)
        
        # 2. 添加噪声
        data_g1 += self.rng.normal(0, 1e-6, data_g1.shape)
        data_g2 += self.rng.normal(0, 1e-6, data_g2.shape)
        
        # 获取 S1 (Clade 1) 和 S2 (Clade 2) 对应的特征索引
        # self.cluster_labels 在 __init__ 中已经生成
        clade1_idx = np.where(self.cluster_labels == 1)[0]
        clade2_idx = np.where(self.cluster_labels == 2)[0]

        # 3. 引入特征错位 (Feature Misalignment)
        # 修改：增加 target_pool 参数，用于指定从哪个集合中选取要被扰动的特征
        def apply_misalignment(data, target_pool, move_id):
            n_samples = data.shape[0]
            # 随机选择 20%-30% 的样本进行扰动 (根据你的上一段代码逻辑调整回论文常见比例，或者保持你设定的0.8)
            # 这里我保持你代码中的 0.8-0.9，如果需要论文原比例请改为 0.2-0.3
            perturb_ratio = 0.90
            n_perturb = int(n_samples * perturb_ratio)
            perturb_indices = self.rng.choice(n_samples, n_perturb, replace=False)

            for idx in perturb_indices:
                # 随机选择一定比例的 OTU 进行移动
                otu_ratio = self.rng.uniform(0.45, 0.45)
                # 计算需要移动的特征数量
                actual_n_move = int(len(move_id) * otu_ratio)
                source_otus = self.rng.choice(move_id, actual_n_move, replace=False)
                replace_otus = self.rng.choice(target_pool, actual_n_move, replace=False)

                data[idx, replace_otus] = data[idx, source_otus]
                data[idx, source_otus] = 0

            return data

        # S1 -> S3: 扰动 S2 的特征 (即从 clade2_idx 中选特征来捣乱)
        # 逻辑：S1 样本中本来 S2 特征很少，现在我们特意去动这些特征(如果存在的话)
        data_g1 = apply_misalignment(data_g1, target_pool=clade2_idx, move_id=clade1_idx)
        
        # S2 -> S4: 扰动 S1 的特征 (即从 clade1_idx 中选特征来捣乱)
        data_g2 = apply_misalignment(data_g2, target_pool=clade1_idx, move_id=clade2_idx)
        
        # 非负截断
        data_g1 = np.maximum(data_g1, 0)
        data_g2 = np.maximum(data_g2, 0)
        
        X = np.vstack([data_g1, data_g2])
        y = np.concatenate([labels_g1, labels_g2])
        
        return X, y, self.leaf_names

In [604]:
tree_path = "/home/dongbiao/GCN/data/synthetic_data/phylogeny.nwk"

generator = GreengenesSyntheticGenerator(tree_path, n_features=200)

X, y, otu_ids = generator.generate_dataset_1()

正在加载进化树: /home/dongbiao/GCN/data/synthetic_data/phylogeny.nwk ...
正在计算系统发育距离矩阵...
聚类完成: Clade 1 (100 个), Clade 2 (100 个)


In [605]:
s = [f"S_{i}" for i in range(1, X.shape[0]+1)]

In [606]:
table = biom.Table(data=X.T, observation_ids=otu_ids, sample_ids=s)
with biom.util.biom_open("/home/dongbiao/GCN/data/synthetic_data/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

In [607]:
metadata = pd.DataFrame({"sample_id": s, "group": y.astype(np.int32)})
metadata.to_csv("/home/dongbiao/GCN/data/synthetic_data/metadata.txt", sep="\t", index=False)

In [608]:
metadata = metadata.set_index("sample_id")

In [609]:
table = biom.load_table("/home/dongbiao/GCN/data/synthetic_data/table.biom")
kf = KFold(n_splits=5, shuffle=True, random_state=42)
sid = table.ids(axis="sample")
n = 1
for train_idx, test_idx in kf.split(sid):
    table_train = table.filter(sid[train_idx], axis="sample", inplace=False)
    table_train.remove_empty()
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/synthetic_data/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    
    table_test = table.filter(sid[test_idx], axis="sample", inplace=False)
    table_test.remove_empty()
    with biom.util.biom_open(f"/home/dongbiao/GCN/data/synthetic_data/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/GCN/data/synthetic_data/data/train_{n}.csv")
    table_test = table_test.to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/GCN/data/synthetic_data/data/test_{n}.csv")
        
    n = n + 1

### Synthetic data suxiaoquan

In [641]:
table = pd.read_csv("/home/dongbiao/software/Phylo-Spec/data/Synthetic_Dataset_1/Synthetic_Dataset_1.csv")

In [643]:
metadata = pd.DataFrame({"sample_id": table.iloc[:, 0].values, "group": table["Group"].values})
metadata["group"] = [0 if i == "Group_S1_S3" else 1 for i in metadata["group"].values]
metadata.to_csv("/home/dongbiao/software/Phylo-Spec/data/Synthetic_Dataset_1/metadata.txt", sep="\t", index=False)

In [635]:
table = biom.Table(data=table.iloc[:, 1:61].values.T, observation_ids=table.iloc[:, 1:61].columns.values, sample_ids=table.iloc[:, 0].values)
with biom.util.biom_open("/home/dongbiao/software/Phylo-Spec/data/Synthetic_Dataset_1/table.biom", 'w') as f:
    table.to_hdf5(f, 'example')

In [638]:
metadata = metadata.set_index("sample_id")

In [639]:
table = biom.load_table("/home/dongbiao/software/Phylo-Spec/data/Synthetic_Dataset_1/table.biom")
kf = KFold(n_splits=5, shuffle=True, random_state=42)
sid = table.ids(axis="sample")
n = 1
for train_idx, test_idx in kf.split(sid):
    table_train = table.filter(sid[train_idx], axis="sample", inplace=False)
    table_train.remove_empty()
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Synthetic_Dataset_1/data/train_{n}.biom", 'w') as f:
        table_train.to_hdf5(f, 'example')
    
    table_test = table.filter(sid[test_idx], axis="sample", inplace=False)
    table_test.remove_empty()
    with biom.util.biom_open(f"/home/dongbiao/software/Phylo-Spec/data/Synthetic_Dataset_1/data/test_{n}.biom", 'w') as f:
        table_test.to_hdf5(f, 'example')

    table_train = table_train.to_dataframe().T
    table_train.loc[:, "Group"] = metadata.loc[table_train.index.values, "group"]
    table_train.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Synthetic_Dataset_1/data/train_{n}.csv")
    table_test = table_test.to_dataframe().T
    table_test.loc[:, "Group"] = metadata.loc[table_test.index.values, "group"]
    table_test.to_csv(f"/home/dongbiao/software/Phylo-Spec/data/Synthetic_Dataset_1/data/test_{n}.csv")
        
    n = n + 1

In [644]:
tree_1 = TreeNode.read("/home/dongbiao/software/Phylo-Spec/data/Synthetic_Dataset_1/phylogeny.nwk")
tree_2 = TreeNode.read("/home/dongbiao/GCN/data/synthetic_data/phylogeny.nwk")

In [645]:
tree_1_distance = tree_1.tip_tip_distances()
tree_2_distance = tree_2.tip_tip_distances()

In [647]:
tree_1_distance.data

array([[0.     , 0.21668, 0.2518 , ..., 0.38665, 0.37041, 0.35561],
       [0.21668, 0.     , 0.22174, ..., 0.41933, 0.40309, 0.38829],
       [0.2518 , 0.22174, 0.     , ..., 0.45445, 0.43821, 0.42341],
       ...,
       [0.38665, 0.41933, 0.45445, ..., 0.     , 0.2272 , 0.24656],
       [0.37041, 0.40309, 0.43821, ..., 0.2272 , 0.     , 0.23032],
       [0.35561, 0.38829, 0.42341, ..., 0.24656, 0.23032, 0.     ]])

In [648]:
tree_2_distance.data

array([[  0.      ,  25.260833, 117.193324, ..., 332.71476 , 310.231344,
        294.66027 ],
       [ 25.260833,   0.      , 120.236223, ..., 335.757659, 313.274243,
        297.703169],
       [117.193324, 120.236223,   0.      , ..., 308.432318, 285.948902,
        270.377828],
       ...,
       [332.71476 , 335.757659, 308.432318, ...,   0.      , 199.988276,
        372.112816],
       [310.231344, 313.274243, 285.948902, ..., 199.988276,   0.      ,
        349.6294  ],
       [294.66027 , 297.703169, 270.377828, ..., 372.112816, 349.6294  ,
          0.      ]])

In [655]:
def gaussian_kernel(d, rho=2):
    """
    weight = exp(-2 * ρ * d^2)
    """
    return np.exp(-2 * rho * d ** 2)

In [665]:
gaussian_kernel(0.4, rho=5e-5)

0.9999840001279994

In [38]:
def normalize_tree_for_gnn(input_path, output_path):
    tree = TreeNode.read(input_path)
    max_depth = 0.0
    for tip in tree.tips():
        dist = tip.accumulate_to_ancestor(tree)
        if dist > max_depth:
            max_depth = dist

    count = 0
    for node in tree.traverse():
        if node.length is not None:
            node.length = node.length / max_depth
            count += 1

    tips = list(tree.tips())
    if len(tips) >= 2:
        d = tips[0].distance(tips[1])

    tree.write(output_path)

In [867]:
input_tree_file = "/home/dongbiao/GCN/data/synthetic_data/phylogeny.nwk" 
output_tree_file = "/home/dongbiao/GCN/data/synthetic_data/phylogeny_normalize.nwk"
normalize_tree_for_gnn(input_tree_file, output_tree_file)

In [805]:
input_tree_file = "/home/dongbiao/GCN/data/IBD/phylogeny.nwk" 
output_tree_file = "/home/dongbiao/GCN/data/IBD/phylogeny_normalize.nwk"
normalize_tree_for_gnn(input_tree_file, output_tree_file)

In [834]:
input_tree_file = "/home/dongbiao/GCN/data/CRC/phylogeny.nwk" 
output_tree_file = "/home/dongbiao/GCN/data/CRC/phylogeny_normalize.nwk"
normalize_tree_for_gnn(input_tree_file, output_tree_file)

In [854]:
input_tree_file = "/home/dongbiao/GCN/data/dietary_fiber/phylogeny.nwk" 
output_tree_file = "/home/dongbiao/GCN/data/dietary_fiber/phylogeny_normalize.nwk"
normalize_tree_for_gnn(input_tree_file, output_tree_file)

In [1055]:
input_tree_file = "/home/dongbiao/GCN/data/OSCC_16S/phylogeny.nwk" 
output_tree_file = "/home/dongbiao/GCN/data/OSCC_16S/phylogeny_normalize.nwk"
normalize_tree_for_gnn(input_tree_file, output_tree_file)

In [49]:
input_tree_file = "/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/16S_CRC.nwk" 
output_tree_file = "/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/phylogeny_normalize.nwk"
normalize_tree_for_gnn(input_tree_file, output_tree_file)

In [282]:
input_tree_file = "/home/dongbiao/GCN/data/dietary_fiber/phylogeny.nwk" 
output_tree_file = "/home/dongbiao/GCN/data/dietary_fiber/phylogeny_normalize.nwk"
normalize_tree_for_gnn(input_tree_file, output_tree_file)

In [1019]:
input_tree_file = "/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/phylogeny.nwk" 
output_tree_file = "/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/phylogeny_normalize.nwk"
normalize_tree_for_gnn(input_tree_file, output_tree_file)

In [1044]:
input_tree_file = "/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/phylogeny.nwk" 
output_tree_file = "/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/phylogeny_normalize.nwk"
normalize_tree_for_gnn(input_tree_file, output_tree_file)

In [39]:
input_tree_file = "/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/phylogeny.nwk" 
output_tree_file = "/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/phylogeny_normalize.nwk"
normalize_tree_for_gnn(input_tree_file, output_tree_file)

### Deepphylo Embedding

In [54]:
reducer = PCA(n_components=32)

In [57]:
### synthetic data
tree = TreeNode.read("/home/dongbiao/software/Phylo-Spec/data/Synthetic_Dataset_1/phylogeny.nwk")
dm = tree.tip_tip_distances()
dm = pd.DataFrame(data=dm.data, index=dm.ids, columns=dm.ids)
phy_embedding = reducer.fit_transform(dm.values)
phy_embedding = pd.DataFrame(data=phy_embedding, index=dm.index.values)
phy_embedding.to_csv("/home/dongbiao/GCN/data/synthetic_data/PCA_32.txt", sep="\t")

In [58]:
### IBD 16S
tree = TreeNode.read("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_IBD/phylogeny_normalize.nwk")
dm = tree.tip_tip_distances()
dm = pd.DataFrame(data=dm.data, index=dm.ids, columns=dm.ids)
phy_embedding = reducer.fit_transform(dm.values)
phy_embedding = pd.DataFrame(data=phy_embedding, index=dm.index.values)
phy_embedding.to_csv("/home/dongbiao/GCN/data/IBD_16S/PCA_32.txt", sep="\t")

In [59]:
### CRC 16S
tree = TreeNode.read("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_16S_CRC/phylogeny_normalize.nwk")
dm = tree.tip_tip_distances()
dm = pd.DataFrame(data=dm.data, index=dm.ids, columns=dm.ids)
phy_embedding = reducer.fit_transform(dm.values)
phy_embedding = pd.DataFrame(data=phy_embedding, index=dm.index.values)
phy_embedding.to_csv("/home/dongbiao/GCN/data/CRC_16S//PCA_32.txt", sep="\t")

In [283]:
### dietary fiber 16S
tree = TreeNode.read("/home/dongbiao/GCN/data/dietary_fiber/phylogeny_normalize.nwk")
dm = tree.tip_tip_distances()
dm = pd.DataFrame(data=dm.data, index=dm.ids, columns=dm.ids)
phy_embedding = reducer.fit_transform(dm.values)
phy_embedding = pd.DataFrame(data=phy_embedding, index=dm.index.values)
phy_embedding.to_csv("/home/dongbiao/GCN/data/dietary_fiber/PCA_32.txt", sep="\t")

In [60]:
### CRC WGC
tree = TreeNode.read("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_CRC/phylogeny_normalize.nwk")
dm = tree.tip_tip_distances()
dm = pd.DataFrame(data=dm.data, index=dm.ids, columns=dm.ids)
phy_embedding = reducer.fit_transform(dm.values)
phy_embedding = pd.DataFrame(data=phy_embedding, index=dm.index.values)
phy_embedding.to_csv("/home/dongbiao/GCN/data/CRC_WGS/PCA_32.txt", sep="\t")

In [61]:
### T2D WGS
tree = TreeNode.read("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_WGS_T2D/phylogeny_normalize.nwk")
dm = tree.tip_tip_distances()
dm = pd.DataFrame(data=dm.data, index=dm.ids, columns=dm.ids)
phy_embedding = reducer.fit_transform(dm.values)
phy_embedding = pd.DataFrame(data=phy_embedding, index=dm.index.values)
phy_embedding.to_csv("/home/dongbiao/GCN/data/T2D_WGS/PCA_32.txt", sep="\t")

In [62]:
### Multi-classification 16S
tree = TreeNode.read("/home/dongbiao/software/Phylo-Spec/data/Real_Dateset_Multi-classification/phylogeny_normalize.nwk")
dm = tree.tip_tip_distances()
dm = pd.DataFrame(data=dm.data, index=dm.ids, columns=dm.ids)
phy_embedding = reducer.fit_transform(dm.values)
phy_embedding = pd.DataFrame(data=phy_embedding, index=dm.index.values)
phy_embedding.to_csv("/home/dongbiao/GCN/data/Multi_classification/PCA_32.txt", sep="\t")