In [122]:
from biom import load_table
import pandas as pd

# Read the BIOM file once, then derive both frames from the same table object:
# the table itself as a DataFrame and the per-observation metadata.
biom_table = load_table('./data/AG_100nt_even10k.biom')
otu_table, taxonomy = (
    biom_table.to_dataframe(),
    biom_table.metadata_to_dataframe('observation'),
)

# Genera that are tracked by name; extract_genus falls back to the
# taxonomy_1 (p__-prefixed) label for everything else.
genera_of_interest = [
    "Lactobacillus",
    "Bifidobacterium",
    "Clostridium",
    "Bacteroides",
    "Prevotella",
]

def extract_genus(row, genera=None):
    """Return the label used to group one taxonomy row.

    Parameters
    ----------
    row : mapping-like (e.g. a DataFrame row)
        Must provide string entries under "taxonomy_5" and "taxonomy_1".
    genera : container of str, optional
        Names that are kept at genus level. When omitted, the notebook-level
        ``genera_of_interest`` list is used, which is what
        ``taxonomy.apply(extract_genus, axis=1)`` relies on.

    Returns
    -------
    str
        ``row["taxonomy_5"]`` with "g__" removed when that name is in
        ``genera``; otherwise ``row["taxonomy_1"]`` with "p__" and any
        square brackets removed.
    """
    if genera is None:
        # Looked up at call time so the notebook-level list can be edited
        # without redefining this function.
        genera = genera_of_interest

    genus = row["taxonomy_5"].replace("g__", "")
    if genus in genera:
        return genus

    # Not a tracked genus: fall back to the taxonomy_1 label, dropping the
    # prefix and the brackets that wrap some names.
    return row["taxonomy_1"].replace("p__", "").replace("[", "").replace("]", "")

# Label every observation, then add up the rows of the OTU table that share
# a label.
taxonomy["Genus"] = taxonomy.apply(extract_genus, axis=1)
genus_abundance = otu_table.groupby(taxonomy["Genus"]).sum()

# A label is "rare" when its count summed over all samples is below 0.1% of
# the grand total; the named genera are never treated as rare.
total_abundance = genus_abundance.sum(axis=1)
threshold = 0.001 * total_abundance.sum()  # 0.1% of total counts
below_threshold = total_abundance[total_abundance < threshold].index
rare_taxa = [tax for tax in below_threshold if tax not in genera_of_interest]

# Relabel the rare rows as 'Other' and merge rows that now share a label
relabelled = genus_abundance.rename(index=dict.fromkeys(rare_taxa, "Other"))
genus_abundance = relabelled.groupby(relabelled.index).sum()

# Divide each column by its own total to get relative abundance
per_sample_total = genus_abundance.sum(axis=0)
genus_abundance = genus_abundance.div(per_sample_total, axis="columns")

# One row per sample, one column per label
samples = genus_abundance.transpose().copy().rename_axis(index='SampleID')
samples.head()

Genus,Actinobacteria,Bacteroides,Bacteroidetes,Bifidobacterium,Clostridium,Cyanobacteria,Firmicutes,Fusobacteria,Lactobacillus,Other,Prevotella,Proteobacteria,Tenericutes,Verrucomicrobia
SampleID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
7117.1075649,0.3244,0.0013,0.0315,0.0,0.0013,0.0352,0.2629,0.0082,0.0027,0.0461,0.0028,0.2821,0.001,0.0005
5636.1053788,0.2811,0.0,0.0055,0.0003,0.0,0.0134,0.5261,0.0059,0.0008,0.0019,0.0012,0.1633,0.0003,0.0002
5637.1053909,0.2378,0.0004,0.0088,0.0003,0.0001,0.012,0.4414,0.0052,0.0004,0.0029,0.0025,0.2878,0.0,0.0004
5634.1053886,0.4447,0.0005,0.0084,0.0002,0.0,0.0088,0.3892,0.0039,0.0,0.0037,0.0007,0.1399,0.0,0.0
7115.1075661,0.0347,0.0087,0.1735,0.0018,0.0149,0.0289,0.1783,0.0032,0.0166,0.0153,0.0025,0.5196,0.0,0.002


In [123]:
# Merge demographic data onto samples.
# '#SampleID' is forced to str so identifiers are never parsed as numbers, and
# low_memory=False makes pandas infer each column's dtype from the whole file
# instead of chunk by chunk (chunked inference is what yields mixed-type
# columns and the warning pandas emits for this call).
ag = pd.read_csv(
    './data/AG.txt',
    sep="\t",
    dtype={'#SampleID': str},
    low_memory=False,
)
ag = ag.rename(columns={'#SampleID': 'SampleID'})
columns_of_interest = ['SampleID', 'AGE', 'SEX', 'PREGNANT']

# The rename only matters if the samples index is unnamed ('index' column).
samples_index_df = samples.reset_index().rename(columns={'index': 'SampleID'})

ag_merged = samples_index_df.merge(
    ag[columns_of_interest],  # metadata
    on='SampleID',
    how='left',  # keeps all rows from samples
)
ag_merged = ag_merged.rename(columns={"AGE": "Age", "SEX": "Sex"})

# Remove pregnant participants (rows with a missing PREGNANT value are kept)
ag_filtered = ag_merged[ag_merged['PREGNANT'].str.lower() != 'yes']
ag_filtered = ag_filtered.drop(columns=['PREGNANT'])

# Normalise once, on a new frame rather than on a slice: lower-case the sex
# labels so the filter and the 0/1 mapping below agree on case, and turn
# unparseable ages into NaN.
ag_filtered = ag_filtered.assign(
    Sex=ag_filtered['Sex'].str.lower(),
    Age=pd.to_numeric(ag_filtered['Age'], errors='coerce'),
)

# Remove participants with missing age or sex
ag_filtered = ag_filtered[
    ag_filtered['Sex'].isin(['male', 'female']) & ag_filtered['Age'].notna()
]

# Convert age to integer and sex to 0 (female) and 1 (male)
ag_filtered = ag_filtered.assign(
    Age=ag_filtered['Age'].astype(int),
    Sex=ag_filtered['Sex'].map({'female': 0, 'male': 1}),
)

# Remove children
ag_filtered = ag_filtered[ag_filtered['Age'] >= 18]

ag_filtered

  ag = pd.read_csv('./data/AG.txt', sep="\t")


Unnamed: 0,SampleID,Actinobacteria,Bacteroides,Bacteroidetes,Bifidobacterium,Clostridium,Cyanobacteria,Firmicutes,Fusobacteria,Lactobacillus,Other,Prevotella,Proteobacteria,Tenericutes,Verrucomicrobia,Age,Sex
0,000007117.1075649,0.3244,0.0013,0.0315,0.0,0.0013,0.0352,0.2629,0.0082,0.0027,0.0461,0.0028,0.2821,0.001,0.0005,59,1
4,000007115.1075661,0.0347,0.0087,0.1735,0.0018,0.0149,0.0289,0.1783,0.0032,0.0166,0.0153,0.0025,0.5196,0.0,0.002,59,1
5,000007123.1075697,0.0493,0.0,0.0153,0.0,0.0002,0.1265,0.0451,0.0025,0.1222,0.0024,0.0007,0.6347,0.0,0.0011,59,1
6,000009713.1130401,0.4052,0.001,0.0091,0.0,0.0002,0.0265,0.4219,0.0026,0.0015,0.0055,0.0062,0.1191,0.0,0.0012,70,0
7,000005598.1130569,0.2394,0.0072,0.0401,0.0001,0.001,0.0255,0.2582,0.0013,0.017,0.0209,0.0024,0.3828,0.0003,0.0038,72,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3100,000003719.1257129,0.0018,0.0096,0.0028,0.3933,0.0,0.0001,0.1794,0.0,0.0016,0.0,0.0004,0.4108,0.0001,0.0001,57,0
3103,000015353.fixed1024,0.083,0.0006,0.3497,0.0,0.0001,0.0,0.0971,0.0029,0.0,0.0019,0.0279,0.436,0.0008,0.0,35,1
3104,000011980.1210764,0.005,0.0008,0.0003,0.0,0.0,0.0001,0.0023,0.0,0.9781,0.0,0.0003,0.0117,0.0014,0.0,52,0
3105,000005567.1131812,0.043,0.0006,0.0,0.0,0.0,0.0,0.0187,0.0003,0.0,0.0001,0.0012,0.9361,0.0,0.0,51,1


In [124]:
# Source column code -> descriptive column name, one mapping per NHANES file.
# Only the columns listed here are read from disk.
diet_cols = dict(
    SEQN="SEQN",
    DR1TIRON="Iron_intake_mg",
    DR1TTFAT="Fat_intake_g",
    DR1TFIBE="Fiber_intake_g",
)
labs_cols = dict([("SEQN", "SEQN"), ("LBXSIR", "Iron_serum_ug/dL")])

diet = pd.read_csv("data/NHANES/diet.csv", usecols=lambda name: name in diet_cols)
diet = diet.rename(columns=diet_cols)

labs = pd.read_csv("data/NHANES/labs.csv", usecols=lambda name: name in labs_cols)
labs = labs.rename(columns=labs_cols)

# Left join on SEQN: every diet row is kept, lab values are NaN where absent.
nhanes = pd.merge(diet, labs, on="SEQN", how="left")
nhanes

Unnamed: 0,SEQN,Fiber_intake_g,Fat_intake_g,Iron_intake_mg,Iron_serum_ug/dL
0,73557,10.8,52.81,8.41,58.0
1,73558,16.7,124.29,26.88,79.0
2,73559,9.9,65.97,17.57,98.0
3,73560,10.6,58.27,14.19,
4,73561,12.3,55.36,17.72,91.0
...,...,...,...,...,...
9808,83727,30.4,193.51,47.01,73.0
9809,83728,9.3,52.39,6.62,
9810,83729,25.7,110.30,15.06,49.0
9811,83730,,,,


In [125]:
# Standard deviation of each intake column in the NHANES table.
fiber_std, fat_std, iron_intake_std = (
    nhanes[column].std()
    for column in ('Fiber_intake_g', 'Fat_intake_g', 'Iron_intake_mg')
)

print(
    f'Fiber intake std: {fiber_std}',
    f'Fat intake std: {fat_std}',
    f'Iron intake std: {iron_intake_std}',
    sep='\n',
)

Fiber intake std: 10.132655979964545
Fat intake std: 45.50421029810836
Iron intake std: 8.544092329884053


In [143]:
import numpy as np
import pandas as pd

# Simulate dietary factors based on bacterial abundance
def simulate_dietary_intake(microbiome_sample, rng=None):
    """Simulate fiber, fat and iron intake for one microbiome sample.

    Each intake is a fixed baseline plus weighted abundances plus Gaussian
    noise, floored at zero and rounded to two decimals.

    Parameters
    ----------
    microbiome_sample : mapping-like (e.g. a DataFrame row)
        Must provide "Bifidobacterium", "Prevotella", "Firmicutes",
        "Lactobacillus", "Bacteroides" and "Proteobacteria".
    rng : object with a ``normal(loc, scale)`` method, optional
        Noise source, e.g. ``np.random.default_rng(seed)``. When omitted the
        global ``np.random`` state is used, so ``np.random.seed`` still
        controls the draws exactly as before.

    Returns
    -------
    pd.Series
        Entries 'Fiber_intake_g', 'Fat_intake_g' and 'Iron_intake_mg'.
    """
    noise_source = np.random if rng is None else rng

    # Use direct indexing for each bacterium
    bifidobacterium_abundance = microbiome_sample["Bifidobacterium"]
    prevotella_abundance = microbiome_sample["Prevotella"]
    firmicutes_abundance = microbiome_sample["Firmicutes"]
    lactobacillus_abundance = microbiome_sample["Lactobacillus"]
    bacteroides_abundance = microbiome_sample["Bacteroides"]
    proteobacteria_abundance = microbiome_sample["Proteobacteria"]

    # Noise is drawn in the order fiber -> fat -> iron; keep that order so a
    # given seed keeps producing the same values.
    fiber_intake = max(0, (
        15.28 +
        bacteroides_abundance * 25 +
        bifidobacterium_abundance * 15 +
        prevotella_abundance * 10 +
        noise_source.normal(0, 10.13)
    ))

    fat_intake = max(0, (
        75.10 +
        firmicutes_abundance * 50 +
        lactobacillus_abundance * 5 +
        noise_source.normal(0, 15)
    ))

    iron_intake = max(0, (
        14.06 +
        bacteroides_abundance * 10 +
        proteobacteria_abundance * 5 +
        lactobacillus_abundance * 5 +
        noise_source.normal(0, 2)
    ))

    return pd.Series({
        'Fiber_intake_g': np.round(fiber_intake, 2),
        'Fat_intake_g': np.round(fat_intake, 2),
        'Iron_intake_mg': np.round(iron_intake, 2)
    })


cleaned_ag_dataset = ag_filtered.copy()

# Seed NumPy's global generator so the simulated diet is identical on every
# Restart & Run All. The value differs from the 42 used for the serum-iron
# baseline so the two noise streams do not start from the same state.
np.random.seed(0)

# Apply simulation to each microbiome sample (one row in, three intakes out)
simulated_diet = cleaned_ag_dataset.apply(simulate_dietary_intake, axis=1)
print(simulated_diet.describe())

# Combine simulated diet with microbiome data (column-wise, aligned on index)
combined_data = pd.concat([cleaned_ag_dataset, simulated_diet], axis=1)
combined_data

       Fiber_intake_g  Fat_intake_g  Iron_intake_mg
count     2512.000000   2512.000000     2512.000000
mean        21.906819     97.863658       17.050840
std         10.970136     17.751831        2.717056
min          0.000000     21.540000        8.050000
25%         14.102500     85.547500       15.230000
50%         21.450000     97.800000       16.965000
75%         29.462500    110.190000       18.922500
max         58.850000    151.290000       26.460000


Unnamed: 0,SampleID,Actinobacteria,Bacteroides,Bacteroidetes,Bifidobacterium,Clostridium,Cyanobacteria,Firmicutes,Fusobacteria,Lactobacillus,Other,Prevotella,Proteobacteria,Tenericutes,Verrucomicrobia,Age,Sex,Fiber_intake_g,Fat_intake_g,Iron_intake_mg
0,000007117.1075649,0.3244,0.0013,0.0315,0.0,0.0013,0.0352,0.2629,0.0082,0.0027,0.0461,0.0028,0.2821,0.001,0.0005,59,1,20.90,84.43,15.25
4,000007115.1075661,0.0347,0.0087,0.1735,0.0018,0.0149,0.0289,0.1783,0.0032,0.0166,0.0153,0.0025,0.5196,0.0,0.002,59,1,18.87,85.39,12.39
5,000007123.1075697,0.0493,0.0,0.0153,0.0,0.0002,0.1265,0.0451,0.0025,0.1222,0.0024,0.0007,0.6347,0.0,0.0011,59,1,12.96,65.19,18.19
6,000009713.1130401,0.4052,0.001,0.0091,0.0,0.0002,0.0265,0.4219,0.0026,0.0015,0.0055,0.0062,0.1191,0.0,0.0012,70,0,45.61,101.71,14.05
7,000005598.1130569,0.2394,0.0072,0.0401,0.0001,0.001,0.0255,0.2582,0.0013,0.017,0.0209,0.0024,0.3828,0.0003,0.0038,72,1,24.82,95.34,16.97
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3100,000003719.1257129,0.0018,0.0096,0.0028,0.3933,0.0,0.0001,0.1794,0.0,0.0016,0.0,0.0004,0.4108,0.0001,0.0001,57,0,15.94,54.43,15.23
3103,000015353.fixed1024,0.083,0.0006,0.3497,0.0,0.0001,0.0,0.0971,0.0029,0.0,0.0019,0.0279,0.436,0.0008,0.0,35,1,12.49,75.26,17.48
3104,000011980.1210764,0.005,0.0008,0.0003,0.0,0.0,0.0001,0.0023,0.0,0.9781,0.0,0.0003,0.0117,0.0014,0.0,52,0,35.42,81.96,18.58
3105,000005567.1131812,0.043,0.0006,0.0,0.0,0.0,0.0,0.0187,0.0003,0.0,0.0001,0.0012,0.9361,0.0,0.0,51,1,12.68,77.90,17.09


In [None]:
import numpy as np

# Mean and spread of the NHANES serum iron column parameterise the baseline.
nhanes_serum_iron = nhanes['Iron_serum_ug/dL']
serum_iron_std = nhanes_serum_iron.std()
serum_iron_mean = nhanes_serum_iron.mean()

# One baseline draw per row of combined_data, at 80% of the NHANES spread.
np.random.seed(42)
base = np.random.normal(serum_iron_mean, serum_iron_std * 0.8, len(combined_data))

# Multiplier applied to the three dietary terms only (not to age or sex).
signal_strength = 2

iron_coef = 0.8
fiber_coef = 0.25  # author's rationale: fiber typically means better gut health
fat_coef = -0.15  # author's rationale: possible inflammation/metabolic effects
age_coef = -0.15  # author's rationale: decreased absorption with age

# Individual contributions; missing intakes or ages contribute 0.
iron_term = signal_strength * iron_coef * combined_data["Iron_intake_mg"].fillna(0)
fiber_term = signal_strength * fiber_coef * combined_data["Fiber_intake_g"].fillna(0)
fat_term = signal_strength * fat_coef * combined_data["Fat_intake_g"].fillna(0)
age_term = age_coef * combined_data["Age"].fillna(0)
# Sex is coded 0 = female, 1 = male; author's rationale: females typically
# have lower iron.
sex_term = combined_data["Sex"].map({0: -5, 1: 5})

simulated_iron = base + iron_term + fiber_term + fat_term + age_term + sex_term

# Bound the values to [5, 557] (the author's physiological range) and round
# to two decimals.
simulated_iron = simulated_iron.clip(lower=5, upper=557).round(2)
simulated_iron


0       103.58
4        79.21
5       114.48
6       127.06
7        81.79
         ...  
3100    122.14
3103     55.49
3104     54.99
3105    105.14
3106    102.00
Length: 2512, dtype: float64

In [148]:
# Attach the simulated serum iron values as a new column
combined_data["Serum_iron_ug"] = simulated_iron

# Demographics are not used for prediction, so they are left out of the
# final table
demographic_columns = ['Age', 'Sex']
final_data = combined_data.drop(demographic_columns, axis="columns")
final_data

Unnamed: 0,SampleID,Actinobacteria,Bacteroides,Bacteroidetes,Bifidobacterium,Clostridium,Cyanobacteria,Firmicutes,Fusobacteria,Lactobacillus,Other,Prevotella,Proteobacteria,Tenericutes,Verrucomicrobia,Fiber_intake_g,Fat_intake_g,Iron_intake_mg,Serum_iron_ug
0,000007117.1075649,0.3244,0.0013,0.0315,0.0,0.0013,0.0352,0.2629,0.0082,0.0027,0.0461,0.0028,0.2821,0.001,0.0005,20.90,84.43,15.25,103.58
4,000007115.1075661,0.0347,0.0087,0.1735,0.0018,0.0149,0.0289,0.1783,0.0032,0.0166,0.0153,0.0025,0.5196,0.0,0.002,18.87,85.39,12.39,79.21
5,000007123.1075697,0.0493,0.0,0.0153,0.0,0.0002,0.1265,0.0451,0.0025,0.1222,0.0024,0.0007,0.6347,0.0,0.0011,12.96,65.19,18.19,114.48
6,000009713.1130401,0.4052,0.001,0.0091,0.0,0.0002,0.0265,0.4219,0.0026,0.0015,0.0055,0.0062,0.1191,0.0,0.0012,45.61,101.71,14.05,127.06
7,000005598.1130569,0.2394,0.0072,0.0401,0.0001,0.001,0.0255,0.2582,0.0013,0.017,0.0209,0.0024,0.3828,0.0003,0.0038,24.82,95.34,16.97,81.79
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3100,000003719.1257129,0.0018,0.0096,0.0028,0.3933,0.0,0.0001,0.1794,0.0,0.0016,0.0,0.0004,0.4108,0.0001,0.0001,15.94,54.43,15.23,122.14
3103,000015353.fixed1024,0.083,0.0006,0.3497,0.0,0.0001,0.0,0.0971,0.0029,0.0,0.0019,0.0279,0.436,0.0008,0.0,12.49,75.26,17.48,55.49
3104,000011980.1210764,0.005,0.0008,0.0003,0.0,0.0,0.0001,0.0023,0.0,0.9781,0.0,0.0003,0.0117,0.0014,0.0,35.42,81.96,18.58,54.99
3105,000005567.1131812,0.043,0.0006,0.0,0.0,0.0,0.0,0.0187,0.0003,0.0,0.0001,0.0012,0.9361,0.0,0.0,12.68,77.90,17.09,105.14


In [149]:
# Write the final table to CSV; the row index is not written
final_data.to_csv(path_or_buf="data/cleaned_data.csv", index=False)