In [1]:
from biom import load_table
import pandas as pd

# Load table
biom_table = load_table('./data/AG_100nt_even10k.biom')
otu_table = biom_table.to_dataframe()
taxonomy = biom_table.metadata_to_dataframe('observation')

genera_of_interest = ["Lactobacillus", "Bifidobacterium", "Clostridium", "Bacteroides", "Prevotella"]

def extract_genus(row):
    g = row["taxonomy_5"].replace("g__", "")
    if g in genera_of_interest:
       return g
    else:
        p = row["taxonomy_1"].replace("p__", "")
        p = p.replace("[", "").replace("]", "")
        return p

taxonomy["Genus"] = taxonomy.apply(extract_genus, axis=1)

genus_abundance = otu_table.groupby(taxonomy["Genus"]).sum()

total_abundance = genus_abundance.sum(axis=1)
threshold = 0.001 * total_abundance.sum()  # 0.1% of total counts
rare_taxa = [tax for tax in total_abundance[total_abundance < threshold].index 
             if tax not in genera_of_interest]

# Replace rare taxa with 'Other' and sum
genus_abundance = genus_abundance.rename(index=lambda x: "Other" if x in rare_taxa else x)
genus_abundance = genus_abundance.groupby(genus_abundance.index).sum()

# Normalize to relative abundance
genus_abundance = genus_abundance.div(genus_abundance.sum(axis=0), axis=1)

# Check results
samples = genus_abundance.T.copy()
samples.index.name = 'SampleID'
samples.head()

Genus,Actinobacteria,Bacteroides,Bacteroidetes,Bifidobacterium,Clostridium,Cyanobacteria,Firmicutes,Fusobacteria,Lactobacillus,Other,Prevotella,Proteobacteria,Tenericutes,Verrucomicrobia
SampleID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
7117.1075649,0.3244,0.0013,0.0315,0.0,0.0013,0.0352,0.2629,0.0082,0.0027,0.0461,0.0028,0.2821,0.001,0.0005
5636.1053788,0.2811,0.0,0.0055,0.0003,0.0,0.0134,0.5261,0.0059,0.0008,0.0019,0.0012,0.1633,0.0003,0.0002
5637.1053909,0.2378,0.0004,0.0088,0.0003,0.0001,0.012,0.4414,0.0052,0.0004,0.0029,0.0025,0.2878,0.0,0.0004
5634.1053886,0.4447,0.0005,0.0084,0.0002,0.0,0.0088,0.3892,0.0039,0.0,0.0037,0.0007,0.1399,0.0,0.0
7115.1075661,0.0347,0.0087,0.1735,0.0018,0.0149,0.0289,0.1783,0.0032,0.0166,0.0153,0.0025,0.5196,0.0,0.002


In [2]:
# Merge demographic data onto samples
ag = pd.read_csv('./data/AG.txt', sep="\t")
ag = ag.rename(columns={'#SampleID': 'SampleID'})
columns_of_interest = ['SampleID', 'AGE', 'SEX', 'PREGNANT']

samples_index_df = samples.reset_index().rename(columns={'index': 'SampleID'})

ag_merged = samples_index_df.merge(
    ag[columns_of_interest],  # metadata
    on='SampleID',
    how='left'  # keeps all rows from samples
)

ag_merged = ag_merged.rename(columns={"AGE": "Age", "SEX": "Sex"})

# Remove pregnant participants
ag_filtered = ag_merged[ag_merged['PREGNANT'].str.lower() != 'yes']
ag_filtered = ag_filtered.drop(columns=['PREGNANT'])

# Remove participants with missing age or sex
ag_filtered = ag_filtered[ag_filtered['Sex'].str.lower().isin(['male', 'female'])]
ag_filtered['Age'] = pd.to_numeric(ag_filtered['Age'], errors='coerce')
ag_filtered = ag_filtered[ag_filtered['Age'].notna()]

# Convert age to integer and sex to 0 (female) and 1 (male)
ag_filtered['Age'] = ag_filtered['Age'].astype(int)
ag_filtered['Sex'] = ag_filtered['Sex'].map({'female': 0, 'male': 1})

# Remove children
ag_filtered = ag_filtered[ag_filtered['Age'] >= 18]

ag_filtered

  ag = pd.read_csv('./data/AG.txt', sep="\t")


Unnamed: 0,SampleID,Actinobacteria,Bacteroides,Bacteroidetes,Bifidobacterium,Clostridium,Cyanobacteria,Firmicutes,Fusobacteria,Lactobacillus,Other,Prevotella,Proteobacteria,Tenericutes,Verrucomicrobia,Age,Sex
0,000007117.1075649,0.3244,0.0013,0.0315,0.0,0.0013,0.0352,0.2629,0.0082,0.0027,0.0461,0.0028,0.2821,0.001,0.0005,59,1
4,000007115.1075661,0.0347,0.0087,0.1735,0.0018,0.0149,0.0289,0.1783,0.0032,0.0166,0.0153,0.0025,0.5196,0.0,0.002,59,1
5,000007123.1075697,0.0493,0.0,0.0153,0.0,0.0002,0.1265,0.0451,0.0025,0.1222,0.0024,0.0007,0.6347,0.0,0.0011,59,1
6,000009713.1130401,0.4052,0.001,0.0091,0.0,0.0002,0.0265,0.4219,0.0026,0.0015,0.0055,0.0062,0.1191,0.0,0.0012,70,0
7,000005598.1130569,0.2394,0.0072,0.0401,0.0001,0.001,0.0255,0.2582,0.0013,0.017,0.0209,0.0024,0.3828,0.0003,0.0038,72,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3100,000003719.1257129,0.0018,0.0096,0.0028,0.3933,0.0,0.0001,0.1794,0.0,0.0016,0.0,0.0004,0.4108,0.0001,0.0001,57,0
3103,000015353.fixed1024,0.083,0.0006,0.3497,0.0,0.0001,0.0,0.0971,0.0029,0.0,0.0019,0.0279,0.436,0.0008,0.0,35,1
3104,000011980.1210764,0.005,0.0008,0.0003,0.0,0.0,0.0001,0.0023,0.0,0.9781,0.0,0.0003,0.0117,0.0014,0.0,52,0
3105,000005567.1131812,0.043,0.0006,0.0,0.0,0.0,0.0,0.0187,0.0003,0.0,0.0001,0.0012,0.9361,0.0,0.0,51,1


In [3]:
diet_cols = {
    "SEQN": "SEQN",
    "DR1TIRON": "Iron_intake_mg",
    "DR1TTFAT": "Fat_intake_g",
    "DR1TFIBE": "Fiber_intake_g"
}

labs_cols = {
    "SEQN": "SEQN",
    "LBXSIR": "Iron_serum_ug/dL"
}

diet = pd.read_csv("data/NHANES/diet.csv", usecols=lambda c: c in diet_cols.keys()).rename(columns=diet_cols)
labs = pd.read_csv("data/NHANES/labs.csv", usecols=lambda c: c in labs_cols.keys()).rename(columns=labs_cols)

nhanes = diet.merge(labs, on="SEQN", how="left")
nhanes

Unnamed: 0,SEQN,Fiber_intake_g,Fat_intake_g,Iron_intake_mg,Iron_serum_ug/dL
0,73557,10.8,52.81,8.41,58.0
1,73558,16.7,124.29,26.88,79.0
2,73559,9.9,65.97,17.57,98.0
3,73560,10.6,58.27,14.19,
4,73561,12.3,55.36,17.72,91.0
...,...,...,...,...,...
9808,83727,30.4,193.51,47.01,73.0
9809,83728,9.3,52.39,6.62,
9810,83729,25.7,110.30,15.06,49.0
9811,83730,,,,


In [4]:
fiber_std = nhanes['Fiber_intake_g'].std()
fat_std = nhanes['Fat_intake_g'].std()
iron_intake_std = nhanes['Iron_intake_mg'].std()


print(f'Fiber intake std: {fiber_std}')
print(f'Fat intake std: {fat_std}')
print(f'Iron intake std: {iron_intake_std}')

Fiber intake std: 10.132655979964545
Fat intake std: 45.50421029810836
Iron intake std: 8.544092329884053


In [5]:
import numpy as np
import pandas as pd

# Simulate dietary factors based on bacterial abundance
def simulate_dietary_intake(microbiome_sample):
    # Use direct indexing for each bacterium
    bifidobacterium_abundance = microbiome_sample["Bifidobacterium"]
    prevotella_abundance = microbiome_sample["Prevotella"]
    firmicutes_abundance = microbiome_sample["Firmicutes"]
    lactobacillus_abundance = microbiome_sample["Lactobacillus"]
    bacteroides_abundance = microbiome_sample["Bacteroides"]
    proteobacteria_abundance = microbiome_sample["Proteobacteria"]
    
    # Calculate dietary factors using these abundances
    fiber_intake = max(0, (
        15.28 +
        bacteroides_abundance * 25 +
        bifidobacterium_abundance * 15 +
        prevotella_abundance * 10 +
        np.random.normal(0, 10.13)
    ))
    
    fat_intake = max(0, (
        75.10 +
        firmicutes_abundance * 50 +
        lactobacillus_abundance * 5 +
        np.random.normal(0, 15)
    ))
    
    iron_intake = max(0, (
        14.06 +
        bacteroides_abundance * 10 +
        proteobacteria_abundance * 5 + 
        lactobacillus_abundance * 5 +
        np.random.normal(0, 2)
    ))

    return pd.Series({
        'Fiber_intake_g': np.round(fiber_intake, 2),
        'Fat_intake_g': np.round(fat_intake, 2),
        'Iron_intake_mg': np.round(iron_intake, 2)
    })


cleaned_ag_dataset = ag_filtered.copy()

# Apply simulation to each microbiome sample
simulated_diet = cleaned_ag_dataset.apply(simulate_dietary_intake, axis=1)
print(simulated_diet.describe())

# Combine simulated diet with microbiome data
combined_data = pd.concat([cleaned_ag_dataset, simulated_diet], axis=1)
combined_data

       Fiber_intake_g  Fat_intake_g  Iron_intake_mg
count     2512.000000   2512.000000     2512.000000
mean        22.286947     97.916350       17.065322
std         10.967446     18.168486        2.729833
min          0.000000     35.680000        7.410000
25%         14.575000     85.880000       15.140000
50%         22.300000     97.475000       16.970000
75%         29.465000    110.287500       18.930000
max         67.010000    160.020000       26.110000


Unnamed: 0,SampleID,Actinobacteria,Bacteroides,Bacteroidetes,Bifidobacterium,Clostridium,Cyanobacteria,Firmicutes,Fusobacteria,Lactobacillus,Other,Prevotella,Proteobacteria,Tenericutes,Verrucomicrobia,Age,Sex,Fiber_intake_g,Fat_intake_g,Iron_intake_mg
0,000007117.1075649,0.3244,0.0013,0.0315,0.0,0.0013,0.0352,0.2629,0.0082,0.0027,0.0461,0.0028,0.2821,0.001,0.0005,59,1,19.11,94.11,14.50
4,000007115.1075661,0.0347,0.0087,0.1735,0.0018,0.0149,0.0289,0.1783,0.0032,0.0166,0.0153,0.0025,0.5196,0.0,0.002,59,1,10.20,104.56,15.41
5,000007123.1075697,0.0493,0.0,0.0153,0.0,0.0002,0.1265,0.0451,0.0025,0.1222,0.0024,0.0007,0.6347,0.0,0.0011,59,1,6.71,77.74,18.66
6,000009713.1130401,0.4052,0.001,0.0091,0.0,0.0002,0.0265,0.4219,0.0026,0.0015,0.0055,0.0062,0.1191,0.0,0.0012,70,0,12.07,102.76,12.42
7,000005598.1130569,0.2394,0.0072,0.0401,0.0001,0.001,0.0255,0.2582,0.0013,0.017,0.0209,0.0024,0.3828,0.0003,0.0038,72,1,0.00,90.42,16.59
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3100,000003719.1257129,0.0018,0.0096,0.0028,0.3933,0.0,0.0001,0.1794,0.0,0.0016,0.0,0.0004,0.4108,0.0001,0.0001,57,0,27.43,111.52,14.32
3103,000015353.fixed1024,0.083,0.0006,0.3497,0.0,0.0001,0.0,0.0971,0.0029,0.0,0.0019,0.0279,0.436,0.0008,0.0,35,1,28.72,89.47,12.67
3104,000011980.1210764,0.005,0.0008,0.0003,0.0,0.0,0.0001,0.0023,0.0,0.9781,0.0,0.0003,0.0117,0.0014,0.0,52,0,7.05,76.98,17.72
3105,000005567.1131812,0.043,0.0006,0.0,0.0,0.0,0.0,0.0187,0.0003,0.0,0.0001,0.0012,0.9361,0.0,0.0,51,1,11.22,75.42,20.87


In [6]:
import numpy as np

serum_iron_std = nhanes['Iron_serum_ug/dL'].std()
serum_iron_mean = nhanes['Iron_serum_ug/dL'].mean()

np.random.seed(42)

base = np.random.normal(serum_iron_mean, serum_iron_std * 0.5, len(combined_data))

signal_strength = 1

simulated_iron = (
    base
    + signal_strength * (
        # Dietary factors
        0.3 * combined_data["Iron_intake_mg"]
        + 0.05 * combined_data["Fiber_intake_g"]
        - 0.03 * combined_data["Fat_intake_g"]
        # Microbiome factors
        + 150 * combined_data["Lactobacillus"]
        + 100 * combined_data["Bifidobacterium"]
        + 80 * combined_data["Bacteroides"]
        - 100 * combined_data["Clostridium"]
        # Diet x microbiome
        + 4.5 * combined_data["Iron_intake_mg"] * combined_data["Lactobacillus"]
        + 1 * combined_data["Fiber_intake_g"] * combined_data["Bifidobacterium"]
        + 0.5 * combined_data["Fat_intake_g"] * combined_data["Lactobacillus"]
        # Diet x diet
        + 0.05 * combined_data["Iron_intake_mg"] * combined_data["Fiber_intake_g"]
        + 0.03 * combined_data["Iron_intake_mg"] * combined_data["Fat_intake_g"]
        # Microbiome x microbiome
        + 200 * combined_data["Lactobacillus"] * combined_data["Bifidobacterium"]
        + 150 * combined_data["Bacteroides"] * combined_data["Lactobacillus"]
        - 250 * combined_data["Firmicutes"] * combined_data["Clostridium"]
    )
    # Demographic factors
    - 0.1 * combined_data["Age"] - 0.001 * combined_data["Age"]**2
    + combined_data["Sex"].map({0: -5, 1: 5}) # females typically have lower iron
    - 60 # offset
)

# interaction_noise = (
#     np.random.normal(0, 0.5, len(combined_data)) * combined_data["Iron_intake_mg"] * combined_data["Bifidobacterium"]
#     + np.random.normal(0, 0.3, len(combined_data)) * combined_data["Fiber_intake_g"] * combined_data["Clostridium"]
#     + np.random.normal(0, 0.2, len(combined_data)) * combined_data["Fat_intake_g"] * combined_data["Bacteroides"]
# )

# simulated_iron += interaction_noise

# keep within physiological range
simulated_iron = np.clip(simulated_iron, 5, 557)
simulated_iron = np.round(np.array(simulated_iron), 2)

simulated_iron


array([ 85.98,  78.02, 117.55, ..., 299.54,  93.05,  63.22])

In [7]:
# Combine simulated diet with microbiome data
combined_data["Serum_iron_ug"] = simulated_iron

# Drop demographics since we won't use them for predicition
final_data = combined_data.drop(columns=['Age', 'Sex'])

print(final_data["Serum_iron_ug"].describe())
final_data

count    2512.000000
mean      109.830494
std        34.791174
min         7.770000
25%        85.527500
50%       108.225000
75%       133.195000
max       299.540000
Name: Serum_iron_ug, dtype: float64


Unnamed: 0,SampleID,Actinobacteria,Bacteroides,Bacteroidetes,Bifidobacterium,Clostridium,Cyanobacteria,Firmicutes,Fusobacteria,Lactobacillus,Other,Prevotella,Proteobacteria,Tenericutes,Verrucomicrobia,Fiber_intake_g,Fat_intake_g,Iron_intake_mg,Serum_iron_ug
0,000007117.1075649,0.3244,0.0013,0.0315,0.0,0.0013,0.0352,0.2629,0.0082,0.0027,0.0461,0.0028,0.2821,0.001,0.0005,19.11,94.11,14.50,85.98
4,000007115.1075661,0.0347,0.0087,0.1735,0.0018,0.0149,0.0289,0.1783,0.0032,0.0166,0.0153,0.0025,0.5196,0.0,0.002,10.20,104.56,15.41,78.02
5,000007123.1075697,0.0493,0.0,0.0153,0.0,0.0002,0.1265,0.0451,0.0025,0.1222,0.0024,0.0007,0.6347,0.0,0.0011,6.71,77.74,18.66,117.55
6,000009713.1130401,0.4052,0.001,0.0091,0.0,0.0002,0.0265,0.4219,0.0026,0.0015,0.0055,0.0062,0.1191,0.0,0.0012,12.07,102.76,12.42,81.72
7,000005598.1130569,0.2394,0.0072,0.0401,0.0001,0.001,0.0255,0.2582,0.0013,0.017,0.0209,0.0024,0.3828,0.0003,0.0038,0.00,90.42,16.59,64.09
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3100,000003719.1257129,0.0018,0.0096,0.0028,0.3933,0.0,0.0001,0.1794,0.0,0.0016,0.0,0.0004,0.4108,0.0001,0.0001,27.43,111.52,14.32,153.46
3103,000015353.fixed1024,0.083,0.0006,0.3497,0.0,0.0001,0.0,0.0971,0.0029,0.0,0.0019,0.0279,0.436,0.0008,0.0,28.72,89.47,12.67,53.92
3104,000011980.1210764,0.005,0.0008,0.0003,0.0,0.0,0.0001,0.0023,0.0,0.9781,0.0,0.0003,0.0117,0.0014,0.0,7.05,76.98,17.72,299.54
3105,000005567.1131812,0.043,0.0006,0.0,0.0,0.0,0.0,0.0187,0.0003,0.0,0.0001,0.0012,0.9361,0.0,0.0,11.22,75.42,20.87,93.05


In [8]:
# Save the data to a CSV
final_data.to_csv("data/cleaned_data.csv", index=False)