In [25]:
import pandas as pd

# Source column name in diet.csv -> column name used in this notebook
diet_cols = {
    "SEQN": "SEQN",
    "DR1TIRON": "Iron_intake_mg",
    "DR1TTFAT": "Fat_intake_g",
    "DR1TFIBE": "Fiber_intake_g",
}

# Source column name in labs.csv -> column name used in this notebook
labs_cols = {
    "SEQN": "SEQN",
    "LBXSIR": "Iron_serum_ug/dL",
}

# Read only the mapped columns from each file, then apply the readable names
diet = pd.read_csv("data/NHANES/diet.csv", usecols=lambda name: name in diet_cols)
diet = diet.rename(columns=diet_cols)
labs = pd.read_csv("data/NHANES/labs.csv", usecols=lambda name: name in labs_cols)
labs = labs.rename(columns=labs_cols)

# Left join on SEQN: every diet row is kept, lab values are attached where available
nhanes = pd.merge(diet, labs, on="SEQN", how="left")

# Sample standard deviation of each dietary intake column
fiber_std, fat_std, iron_intake_std = (
    nhanes[column].std()
    for column in ("Fiber_intake_g", "Fat_intake_g", "Iron_intake_mg")
)

# For comparison
print(f'Fiber intake std: {fiber_std}')
print(f'Fat intake std: {fat_std}')
print(f'Iron intake std: {iron_intake_std}')

Fiber intake std: 10.132655979964545
Fat intake std: 45.50421029810836
Iron intake std: 8.544092329884053


In [26]:
import numpy as np
import pandas as pd

# Version 1: Links from microbiome -> diet

# Simulate dietary factors based on bacterial abundance
def simulate_dietary_intake(microbiome_sample, rng=None):
    """Simulate fiber, fat and iron intake for one microbiome sample.

    Each intake is a fixed baseline plus a weighted sum of selected taxon
    abundances plus Gaussian noise. The result is floored at zero and rounded
    to two decimals.

    Parameters
    ----------
    microbiome_sample : mapping or pandas.Series
        Must provide the keys "Bifidobacterium", "Prevotella", "Firmicutes",
        "Lactobacillus", "Bacteroides" and "Proteobacteria".
    rng : object with a ``normal(loc, scale)`` method, optional
        Noise source, e.g. ``np.random.default_rng(seed)``. When omitted, the
        global ``np.random`` state is used, so results are only reproducible
        if ``np.random.seed`` was called beforehand.

    Returns
    -------
    pandas.Series
        Entries 'Fiber_intake_g', 'Fat_intake_g' and 'Iron_intake_mg'.

    Raises
    ------
    KeyError
        If one of the required abundance keys is missing.
    """
    # Fall back to NumPy's global RNG so existing callers behave as before.
    noise_source = np.random if rng is None else rng

    # Use direct indexing for each bacterium
    bifidobacterium_abundance = microbiome_sample["Bifidobacterium"]
    prevotella_abundance = microbiome_sample["Prevotella"]
    firmicutes_abundance = microbiome_sample["Firmicutes"]
    lactobacillus_abundance = microbiome_sample["Lactobacillus"]
    bacteroides_abundance = microbiome_sample["Bacteroides"]
    proteobacteria_abundance = microbiome_sample["Proteobacteria"]

    # Noise is drawn in the order fiber -> fat -> iron; keep this order so a
    # given seed keeps producing the same values.
    fiber_intake = max(0, (
        15.28 +
        bacteroides_abundance * 25 +
        bifidobacterium_abundance * 15 +
        prevotella_abundance * 10 +
        noise_source.normal(0, 10.13)
    ))

    fat_intake = max(0, (
        75.10 +
        firmicutes_abundance * 50 +
        lactobacillus_abundance * 5 +
        noise_source.normal(0, 15)
    ))

    iron_intake = max(0, (
        14.06 +
        bacteroides_abundance * 10 +
        proteobacteria_abundance * 5 +
        lactobacillus_abundance * 5 +
        noise_source.normal(0, 2)
    ))

    return pd.Series({
        'Fiber_intake_g': np.round(fiber_intake, 2),
        'Fat_intake_g': np.round(fat_intake, 2),
        'Iron_intake_mg': np.round(iron_intake, 2)
    })


cleaned_ag_dataset = pd.read_csv('data/combined_ag_data.csv')

# simulate_dietary_intake draws its noise from NumPy's global RNG by default.
# Seed it here so a fresh Restart & Run All regenerates the same simulated diet.
# The seed deliberately differs from the seeds used for the serum iron
# simulation (42 and 123), so the diet noise does not reuse the same random
# stream as those draws.
DIET_SIM_SEED = 2024
np.random.seed(DIET_SIM_SEED)

# Apply simulation to each microbiome sample (one call per row)
simulated_diet = cleaned_ag_dataset.apply(simulate_dietary_intake, axis=1)
print(simulated_diet.describe())

# Combine simulated diet with microbiome data (column-wise, aligned on the row index)
combined_data = pd.concat([cleaned_ag_dataset, simulated_diet], axis=1)
combined_data

       Fiber_intake_g  Fat_intake_g  Iron_intake_mg
count     2512.000000   2512.000000     2512.000000
mean        21.906819     97.863658       17.050840
std         10.970136     17.751831        2.717056
min          0.000000     21.540000        8.050000
25%         14.102500     85.547500       15.230000
50%         21.450000     97.800000       16.965000
75%         29.462500    110.190000       18.922500
max         58.850000    151.290000       26.460000


Unnamed: 0,SampleID,Actinobacteria,Bacteroides,Bacteroidetes,Bifidobacterium,Clostridium,Cyanobacteria,Firmicutes,Fusobacteria,Lactobacillus,Other,Prevotella,Proteobacteria,Tenericutes,Verrucomicrobia,Age,Sex,Fiber_intake_g,Fat_intake_g,Iron_intake_mg
0,000007117.1075649,0.3244,0.0013,0.0315,0.0000,0.0013,0.0352,0.2629,0.0082,0.0027,0.0461,0.0028,0.2821,0.0010,0.0005,59,1,20.90,84.43,15.25
1,000007115.1075661,0.0347,0.0087,0.1735,0.0018,0.0149,0.0289,0.1783,0.0032,0.0166,0.0153,0.0025,0.5196,0.0000,0.0020,59,1,18.87,85.39,12.39
2,000007123.1075697,0.0493,0.0000,0.0153,0.0000,0.0002,0.1265,0.0451,0.0025,0.1222,0.0024,0.0007,0.6347,0.0000,0.0011,59,1,12.96,65.19,18.19
3,000009713.1130401,0.4052,0.0010,0.0091,0.0000,0.0002,0.0265,0.4219,0.0026,0.0015,0.0055,0.0062,0.1191,0.0000,0.0012,70,0,45.61,101.71,14.05
4,000005598.1130569,0.2394,0.0072,0.0401,0.0001,0.0010,0.0255,0.2582,0.0013,0.0170,0.0209,0.0024,0.3828,0.0003,0.0038,72,1,24.82,95.34,16.97
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2507,000003719.1257129,0.0018,0.0096,0.0028,0.3933,0.0000,0.0001,0.1794,0.0000,0.0016,0.0000,0.0004,0.4108,0.0001,0.0001,57,0,15.94,54.43,15.23
2508,000015353.fixed1024,0.0830,0.0006,0.3497,0.0000,0.0001,0.0000,0.0971,0.0029,0.0000,0.0019,0.0279,0.4360,0.0008,0.0000,35,1,12.49,75.26,17.48
2509,000011980.1210764,0.0050,0.0008,0.0003,0.0000,0.0000,0.0001,0.0023,0.0000,0.9781,0.0000,0.0003,0.0117,0.0014,0.0000,52,0,35.42,81.96,18.58
2510,000005567.1131812,0.0430,0.0006,0.0000,0.0000,0.0000,0.0000,0.0187,0.0003,0.0000,0.0001,0.0012,0.9361,0.0000,0.0000,51,1,12.68,77.90,17.09


In [27]:
import numpy as np

# Simulate a serum iron value for every row of combined_data.
# The value is a Gaussian baseline (centred on the NHANES serum iron mean) plus a
# weighted mix of diet, microbiome and interaction terms, demographic adjustments,
# a constant offset and heteroscedastic interaction noise; it is then clipped and rounded.

# Reference distribution taken from the NHANES serum iron column.
# pandas skips missing values by default, so rows without a lab match do not count.
serum_iron_std = nhanes['Iron_serum_ug/dL'].std()
serum_iron_mean = nhanes['Iron_serum_ug/dL'].mean()

# Fixed seed so the baseline draw is reproducible
np.random.seed(42)

# One baseline value per sample, with 60% of the NHANES spread
base = np.random.normal(serum_iron_mean, serum_iron_std * 0.6, len(combined_data))

# Multiplier applied to the diet/microbiome signal block only
# (the demographic terms and the offset below are not scaled by it)
signal_strength = 1

simulated_iron = (
    base
    + signal_strength * (
        # Dietary factors
        0.3 * combined_data["Iron_intake_mg"]
        + 0.05 * combined_data["Fiber_intake_g"]
        - 0.03 * combined_data["Fat_intake_g"]
        # Microbiome factors
        + 150 * combined_data["Lactobacillus"]
        + 100 * combined_data["Bifidobacterium"]
        + 80 * combined_data["Bacteroides"]
        - 100 * combined_data["Clostridium"]
        # Diet x microbiome
        + 3 * combined_data["Iron_intake_mg"] * combined_data["Lactobacillus"]
        + 1 * combined_data["Fiber_intake_g"] * combined_data["Bifidobacterium"]
        + 0.5 * combined_data["Fat_intake_g"] * combined_data["Lactobacillus"]
        # Diet x diet
        + 0.05 * combined_data["Iron_intake_mg"] * combined_data["Fiber_intake_g"]
        + 0.03 * combined_data["Iron_intake_mg"] * combined_data["Fat_intake_g"]
        # Microbiome x microbiome
        + 200 * combined_data["Lactobacillus"] * combined_data["Bifidobacterium"]
        + 150 * combined_data["Bacteroides"] * combined_data["Lactobacillus"]
        - 250 * combined_data["Firmicutes"] * combined_data["Clostridium"]
    )
    # Demographic factors: linear and quadratic decrease with age
    - 0.1 * combined_data["Age"] - 0.001 * combined_data["Age"]**2
    # Sex adjustment: -5 for code 0, +5 for code 1 (females typically have lower iron).
    # NOTE(review): this presumes 0 = female and 1 = male; confirm against the dataset.
    # Any other Sex value maps to NaN and makes that row's result NaN.
    + combined_data["Sex"].map({0: -5, 1: 5})
    - 60 # constant offset applied to every sample
)

# Separate fixed seed for the noise terms; the three draws below must stay in this order
# for the output to be reproducible.
np.random.seed(123)
# Noise whose magnitude scales with a diet x microbiome product
interaction_noise = (
    np.random.normal(0, 0.5, len(combined_data)) * combined_data["Iron_intake_mg"] * combined_data["Bifidobacterium"]
    + np.random.normal(0, 0.3, len(combined_data)) * combined_data["Fiber_intake_g"] * combined_data["Clostridium"]
    + np.random.normal(0, 0.2, len(combined_data)) * combined_data["Fat_intake_g"] * combined_data["Bacteroides"]
)

simulated_iron += interaction_noise

# keep within physiological range (values are bounded to [5, 557])
simulated_iron = np.clip(simulated_iron, 5, 557)
# Convert to a plain NumPy array (dropping the pandas index) and round to 2 decimals
simulated_iron = np.round(np.array(simulated_iron), 2)

simulated_iron


array([ 88.14,  64.16, 113.68, ..., 306.59,  85.52,  91.97])

In [28]:
# Attach the simulated serum iron values to the combined table as a new column
combined_data["Serum_iron_ug"] = simulated_iron

# Drop the demographic columns since we won't use them for prediction
final_data = combined_data.drop(['Age', 'Sex'], axis=1)

# Summary statistics of the simulated target, followed by the table itself
print(final_data.loc[:, "Serum_iron_ug"].describe())
final_data

count    2512.000000
mean      109.540076
std        37.257427
min         5.000000
25%        82.297500
50%       107.110000
75%       134.055000
max       306.590000
Name: Serum_iron_ug, dtype: float64


Unnamed: 0,SampleID,Actinobacteria,Bacteroides,Bacteroidetes,Bifidobacterium,Clostridium,Cyanobacteria,Firmicutes,Fusobacteria,Lactobacillus,Other,Prevotella,Proteobacteria,Tenericutes,Verrucomicrobia,Fiber_intake_g,Fat_intake_g,Iron_intake_mg,Serum_iron_ug
0,000007117.1075649,0.3244,0.0013,0.0315,0.0000,0.0013,0.0352,0.2629,0.0082,0.0027,0.0461,0.0028,0.2821,0.0010,0.0005,20.90,84.43,15.25,88.14
1,000007115.1075661,0.0347,0.0087,0.1735,0.0018,0.0149,0.0289,0.1783,0.0032,0.0166,0.0153,0.0025,0.5196,0.0000,0.0020,18.87,85.39,12.39,64.16
2,000007123.1075697,0.0493,0.0000,0.0153,0.0000,0.0002,0.1265,0.0451,0.0025,0.1222,0.0024,0.0007,0.6347,0.0000,0.0011,12.96,65.19,18.19,113.68
3,000009713.1130401,0.4052,0.0010,0.0091,0.0000,0.0002,0.0265,0.4219,0.0026,0.0015,0.0055,0.0062,0.1191,0.0000,0.0012,45.61,101.71,14.05,118.54
4,000005598.1130569,0.2394,0.0072,0.0401,0.0001,0.0010,0.0255,0.2582,0.0013,0.0170,0.0209,0.0024,0.3828,0.0003,0.0038,24.82,95.34,16.97,88.68
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2507,000003719.1257129,0.0018,0.0096,0.0028,0.3933,0.0000,0.0001,0.1794,0.0000,0.0016,0.0000,0.0004,0.4108,0.0001,0.0001,15.94,54.43,15.23,123.12
2508,000015353.fixed1024,0.0830,0.0006,0.3497,0.0000,0.0001,0.0000,0.0971,0.0029,0.0000,0.0019,0.0279,0.4360,0.0008,0.0000,12.49,75.26,17.48,48.26
2509,000011980.1210764,0.0050,0.0008,0.0003,0.0000,0.0000,0.0001,0.0023,0.0000,0.9781,0.0000,0.0003,0.0117,0.0014,0.0000,35.42,81.96,18.58,306.59
2510,000005567.1131812,0.0430,0.0006,0.0000,0.0000,0.0000,0.0000,0.0187,0.0003,0.0000,0.0001,0.0012,0.9361,0.0000,0.0000,12.68,77.90,17.09,85.52


In [29]:
# Write the final table to disk as CSV, leaving out the row index
final_data.to_csv(path_or_buf="data/cleaned_data.csv", index=False)