# DIABIMMUNE Microbiome Data Preprocessing

This notebook preprocesses microbiome data from the **DIABIMMUNE** study for **early prediction of food allergies** (milk, egg, peanut).

## Key Challenges Solved

1. **SRS → Sample Linking**: Public SRS (Sequence Read Archive Sample) IDs from MicrobeAtlas need to be linked to internal study metadata, but there's no direct mapping column.

2. **Timepoint Recovery**: The public SRS data doesn't include `collection_month`, which is critical for longitudinal analysis. We recover this by querying the ENA (European Nucleotide Archive) API for `host_age`.

3. **Longitudinal Labeling**: For early prediction, we label ALL samples from a patient based on whether they **ever develop** a food allergy (not just their status at that timepoint).

## Output

- `diabimmune_longitudinal_labels.csv` — Full dataset with longitudinal labels
- `preprocessed_diabimmune_longitudinal/Month_X.csv` — Per-month files for training/evaluation

In [23]:
!pip install requests beautifulsoup4 pandas
!pip install atlasclient


## Step 1: Load Source Data

We have two main data sources:

1. **`download_samples_tsv-2.tsv`** — SRS IDs from MicrobeAtlas (public microbiome database)
2. **`metadata.csv`** — DIABIMMUNE study metadata with patient IDs, collection months, and allergy data

**Problem**: These don't share a common key! The SRS file has no `subjectID` or `collection_month`.

In [13]:
import pandas as pd

df_srs = pd.read_csv("download_samples_tsv-2.tsv", sep="\t", dtype=str)
df_srs = df_srs.rename(columns={"#sid": "SRS"})
df_srs["SRS"] = df_srs["SRS"].astype(str)

df_meta = pd.read_csv("metadata.csv", dtype=str)

print("SRS columns:", df_srs.columns.tolist())
print("Meta columns:", df_meta.columns.tolist())

SRS columns: ['SRS', 'name', 'note', 'sample_env', 'keywords_clean', 'taxa_stats', 'num_rids', 'num_hq_runs', 'rids', 'projects', 'publications']
Meta columns: ['subjectID', 'SampleID', 'age_at_collection', 'collection_month', 'delivery', 'gest_time', 'gender', 'country', 'Exclusive_breast_feeding', 'Breast_feeding_end', 'Regular_formula', 'hydrosylated_formula', 'partly_hydrosylated_formula', 'Any_baby_formula', 'Fruits_and_berries', 'Corn', 'Rice', 'Wheat', 'Oat', 'Barley', 'Rye', 'Cereal', 'Other_grains', 'Root_vegetables', 'Vegetables', 'Eggs', 'Soy', 'Milk', 'Meat', 'Fish', 'Other_food', 'Other_than_BF', 'bf_length', 'num_abx_treatments', 'num_abx_first_year', 'abx_first_year', 'after_abx', 'num_preceeding_abx', 'hla_risk_class', 'seroconverted', 'num_aabs', 'totalige', 'totalige_log', 'allergy_milk', 'allergy_egg', 'allergy_peanut', 'allergy_dustmite', 'totalige_high', 'allergy_cat', 'allergy_dog', 'allergy_birch', 'allergy_timothy', 'gid_wgs', 'mgx_reads', 'mgx_pool', 'mgx_reads

## Step 2: Link SRS to Patient IDs via ENA API

**Solution**: Each SRS accession has metadata stored in the ENA (European Nucleotide Archive). We can query it to extract `host_subject_id` (the patient ID).

The function below queries the ENA XML API and parses the `host_subject_id` attribute.

In [14]:
import requests
import re

def get_host_subject_id(srs_id: str) -> str | None:
    """
    Use ENA browser XML API to get host_subject_id for a given SRS accession.
    This matches the 'Original MetaData' you see in MicrobeAtlas.
    """
    url = f"https://www.ebi.ac.uk/ena/browser/api/xml/{srs_id}"
    try:
        r = requests.get(url, timeout=30)
        r.raise_for_status()
    except Exception as e:
        print(f"[ERROR] SRS={srs_id}: {e}")
        return None

    xml = r.text

    # ENA attribute structure is like:
    # <SAMPLE_ATTRIBUTE>
    #   <TAG>host_subject_id</TAG>
    #   <VALUE>P018832</VALUE>
    # </SAMPLE_ATTRIBUTE>
    m = re.search(
        r"<TAG>\s*host_subject_id\s*</TAG>\s*<VALUE>\s*([^<\s]+)\s*</VALUE>",
        xml,
        flags=re.IGNORECASE,
    )
    if m:
        return m.group(1)

    print(f"[WARN] no host_subject_id found in XML for {srs_id}")
    return None
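
The regex in `get_host_subject_id` assumes each attribute is a `<TAG>` immediately followed by a `<VALUE>`. As an alternative (not what this notebook uses), the same attributes can be read with Python's standard-library XML parser. This sketch assumes the ENA response follows the SRA sample layout `SAMPLE_SET > SAMPLE > SAMPLE_ATTRIBUTES > SAMPLE_ATTRIBUTE` with no default XML namespace. `parse_sample_attributes` and the toy XML, including its made-up accession and subject ID, are illustrative only.

```python
import xml.etree.ElementTree as ET


def parse_sample_attributes(xml_text: str) -> dict[str, str]:
    """Return {TAG: VALUE} for every SAMPLE_ATTRIBUTE in a sample XML string (hypothetical helper)."""
    root = ET.fromstring(xml_text)
    attrs = {}
    # iter() walks all descendants, so the exact nesting depth does not matter
    for attr in root.iter("SAMPLE_ATTRIBUTE"):
        tag = (attr.findtext("TAG") or "").strip()
        value = (attr.findtext("VALUE") or "").strip()
        if tag:
            attrs[tag] = value
    return attrs


# Illustrative fragment shaped like the structure described in get_host_subject_id
toy_xml = """<SAMPLE_SET>
  <SAMPLE accession="SRS0000000">
    <SAMPLE_ATTRIBUTES>
      <SAMPLE_ATTRIBUTE><TAG>host_subject_id</TAG><VALUE>P000000</VALUE></SAMPLE_ATTRIBUTE>
      <SAMPLE_ATTRIBUTE><TAG>host_age</TAG><VALUE>686</VALUE></SAMPLE_ATTRIBUTE>
    </SAMPLE_ATTRIBUTES>
  </SAMPLE>
</SAMPLE_SET>"""

print(parse_sample_attributes(toy_xml))
```

A real parser avoids surprises such as extra whitespace or a `<UNITS>` element between tags. The trade-off is that it fails loudly on malformed XML, where the regex would simply return `None`.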

In [15]:
test_srs = "SRS1719087"
print("Testing SRS:", test_srs)
print("host_subject_id:", get_host_subject_id(test_srs))

Testing SRS: SRS1719087
host_subject_id: P018832


In [16]:
import pandas as pd
import time

df_srs = pd.read_csv("download_samples_tsv-2.tsv", sep="\t", dtype=str)
df_srs = df_srs.rename(columns={"#sid": "SRS"})
df_srs["SRS"] = df_srs["SRS"].astype(str)

unique_srs = df_srs["SRS"].dropna().unique().tolist()
print("Number of unique SRS:", len(unique_srs))

rows = []
for i, srs in enumerate(unique_srs, start=1):
    subj = get_host_subject_id(srs)
    rows.append({"SRS": srs, "host_subject_id": subj})

    if i % 50 == 0:
        print(f"{i}/{len(unique_srs)} SRS processed")
        time.sleep(0.2)

df_map = pd.DataFrame(rows)
df_map.to_csv("srs_to_host_subject_id.csv", index=False)
df_map.head()

Number of unique SRS: 785
50/785 SRS processed
100/785 SRS processed
150/785 SRS processed
200/785 SRS processed
250/785 SRS processed
300/785 SRS processed
350/785 SRS processed
400/785 SRS processed
450/785 SRS processed
500/785 SRS processed
550/785 SRS processed
600/785 SRS processed
650/785 SRS processed
700/785 SRS processed
750/785 SRS processed


          SRS host_subject_id
0  SRS1719087         P018832
1  SRS1719088         P017743
2  SRS1719089         P000648
3  SRS1719090         T022883
4  SRS1719091         T012374


## Step 3: Extract Collection Month from ENA

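**Problem**: Even with `subjectID`, we can't merge directly, because one patient has multiple samples across different months. Merging on `subjectID` alone is a many-to-many join. Each SRS row is paired with every metadata row for that patient (a per-patient **Cartesian product**), so a sample picks up labels and covariates from other timepoints (data leakage!).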

**Solution**: The ENA metadata includes `host_age` (age in days at sample collection). We convert this to `collection_month`:

```
collection_month = round(host_age_days / 30.44)
```

This gives us the **SRS → collection_month** mapping we need.
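
As a quick sanity check of the conversion, here is the formula as a small function applied to three `host_age_days` values that appear in the output further down (686, 173 and 493 days). The divisor 30.44 is the average month length, 365.25 / 12 ≈ 30.44. The function name is illustrative.

```python
def days_to_collection_month(age_days: int, days_per_month: float = 30.44) -> int:
    """Convert an age in days to the nearest whole month (hypothetical helper)."""
    return round(age_days / days_per_month)


for days in (686, 173, 493):
    print(days, "days ->", days_to_collection_month(days), "months")
```

Because this rounds to the nearest month, an infant younger than about 15 days maps to month 0.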

In [24]:
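import os
import re
import time

import pandas as pd
import requests


def get_sra_metadata(srs_id: str) -> dict | None:
    """
    Return the ENA SAMPLE_ATTRIBUTE tag/value pairs for an SRS accession as a dict.
    NOTE: get_sra_metadata is not defined in the cells shown here; this is a minimal
    reconstruction that assumes it returns {lower-cased tag: value}.
    """
    url = f"https://www.ebi.ac.uk/ena/browser/api/xml/{srs_id}"
    try:
        r = requests.get(url, timeout=30)
        r.raise_for_status()
    except Exception as e:
        print(f"[ERROR] SRS={srs_id}: {e}")
        return None
    # Same <TAG>...</TAG><VALUE>...</VALUE> layout that get_host_subject_id parses
    pairs = re.findall(
        r"<TAG>\s*([^<]*?)\s*</TAG>\s*<VALUE>\s*([^<]*?)\s*</VALUE>",
        r.text,
        flags=re.IGNORECASE,
    )
    return {tag.lower(): value for tag, value in pairs}


srs_month_csv = "srs_to_collection_month.csv"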

if os.path.exists(srs_month_csv):
    print("="*70)
    print(f"{srs_month_csv} already exists. Loading instead of extracting...")
    df_srs_month = pd.read_csv(srs_month_csv, dtype=str)
    print(f"✓ Loaded {len(df_srs_month)} SRS samples from {srs_month_csv}")
    print(df_srs_month.head(10))
else:
    # SOLUTION: Extract host_age for all SRS samples and convert to collection_month
    print("="*70)
    print("EXTRACTING host_age FOR ALL SRS SAMPLES")
    print("="*70)

    # Extract for all unique SRS
    unique_srs = df_map['SRS'].dropna().unique().tolist()
    print(f"\nProcessing {len(unique_srs)} SRS samples...")

    srs_age_data = []
    failed_srs = []

    for i, srs in enumerate(unique_srs, start=1):
        metadata = get_sra_metadata(srs)
        
        if metadata and 'host_age' in metadata:
            try:
                age_days = int(metadata['host_age'])
                # Convert days to months (using 30.44 days/month average)
                age_months = round(age_days / 30.44)
                
                srs_age_data.append({
                    'SRS': srs,
                    'host_age_days': age_days,
                    'collection_month': age_months,
                    'host_subject_id_ena': metadata.get('host_subject_id', None)
                })
            except ValueError:
                print(f"  [{i}] {srs}: Invalid host_age value: {metadata['host_age']}")
                failed_srs.append(srs)
        else:
            print(f"  [{i}] {srs}: No host_age found")
            failed_srs.append(srs)
        
        # Progress indicator
        if i % 50 == 0:
            print(f"  Processed {i}/{len(unique_srs)} samples...")
            time.sleep(0.5)  # Be nice to the API

    print(f"\n✓ Successfully extracted {len(srs_age_data)} samples")
    print(f"✗ Failed for {len(failed_srs)} samples")

    # Create DataFrame with SRS → collection_month mapping
    df_srs_month = pd.DataFrame(srs_age_data)
    df_srs_month.to_csv(srs_month_csv, index=False)
    print(f"\nSaved to: {srs_month_csv}")
    print(df_srs_month.head(10))

srs_to_collection_month.csv already exists. Loading instead of extracting...
✓ Loaded 785 SRS samples from srs_to_collection_month.csv
          SRS host_age_days collection_month host_subject_id_ena
0  SRS1719087           686               23             P018832
1  SRS1719088           173                6             P017743
2  SRS1719089           493               16             P000648
3  SRS1719090           229                8             T022883
4  SRS1719091           502               16             T012374
5  SRS1719092           390               13             T016811
6  SRS1719093           427               14             T017394
7  SRS1719094           587               19             T007750
8  SRS1719095           598               20             T003950
9  SRS1719096           594               20             T004341


## Step 4: Correct Merge on (subjectID, collection_month)

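Now we can perform a **correct merge** using BOTH `subjectID` AND `collection_month` as keys. The intent is that:

- Each SRS maps to exactly one timepoint
- No data leaks between timepoints
- Allergy labels correspond to the correct sample

The first guarantee can still fail when the metadata lists several samples for one patient in the same month. The duplicate check in the cell catches this, and Step 5 resolves it.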
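
Two things can still go wrong with this merge. A (`subjectID`, `collection_month`) key may match several metadata rows, or none at all; a miss could happen, for example, if the month derived from `host_age` differs from the metadata's own `collection_month`. The hypothetical toy example below (all IDs invented) shows how pandas' `merge` can flag both cases. `indicator=True` adds a `_merge` column marking SRS rows with no metadata match as `left_only`. `validate="many_to_one"` raises `pandas.errors.MergeError` when the right-hand keys are not unique.

```python
import pandas as pd

# Toy SRS -> (subject, month) table
srs = pd.DataFrame({
    "SRS": ["S1", "S2", "S3"],
    "subjectID": ["P1", "P1", "P2"],
    "collection_month": [4, 10, 6],
})
# Toy metadata: P1 has two samples in month 10; P2 has month 7 instead of 6
meta = pd.DataFrame({
    "subjectID": ["P1", "P1", "P1", "P2"],
    "collection_month": [4, 10, 10, 7],
    "SampleID": ["m1", "m2", "m3", "m4"],
})

keys = ["subjectID", "collection_month"]
merged = srs.merge(meta, on=keys, how="left", indicator=True)

rows_per_srs = merged.groupby("SRS").size()       # S2 is duplicated
unmatched = merged[merged["_merge"] == "left_only"]  # S3 found no metadata row
print(rows_per_srs)
print(unmatched[["SRS", "subjectID", "collection_month"]])

# validate= turns the duplicate-key situation into an explicit error
try:
    srs.merge(meta, on=keys, how="left", validate="many_to_one")
    validate_failed = False
except pd.errors.MergeError as err:
    validate_failed = True
    print("validate failed:", err)
```

Counting `left_only` rows in the real merge would show how many SRS samples end up with all-missing metadata, separately from the duplicate count.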

In [25]:
# CORRECT MERGE: Use SRS → collection_month mapping
print("="*70)
print("PERFORMING CORRECT MERGE WITH COLLECTION MONTH")
print("="*70)

# Load the extracted SRS → collection_month mapping
df_srs_month = pd.read_csv("srs_to_collection_month.csv", dtype=str)
df_srs_month['collection_month'] = pd.to_numeric(df_srs_month['collection_month'], errors='coerce').astype('Int64')

# Load metadata
df_meta = pd.read_csv("metadata.csv", dtype=str)
df_meta['collection_month'] = pd.to_numeric(df_meta['collection_month'], errors='coerce').astype('Int64')

# Detect subject column in metadata
subj_col = None
for c in df_meta.columns:
    if c.lower() in ("subjectid", "subject_id", "host_subject_id", "host_subjectid"):
        subj_col = c
        break

if subj_col:
    df_meta = df_meta.rename(columns={subj_col: "subjectID"})

# The ENA-derived subject column needs renaming regardless of how the metadata names its column
df_srs_month = df_srs_month.rename(columns={"host_subject_id_ena": "subjectID"})

print(f"✓ Loaded {len(df_srs_month)} SRS samples with collection months")
print(f"✓ Loaded {len(df_meta)} metadata rows")

# CORRECT MERGE: on BOTH subjectID AND collection_month
print("\nPerforming merge on ['subjectID', 'collection_month']...")
df_merged_correct = df_srs_month.merge(
    df_meta, 
    on=['subjectID', 'collection_month'], 
    how='left'
)

print(f"\n✓ Merge complete: {len(df_merged_correct)} rows")
print(f"  (Should be same as number of SRS samples: {len(df_srs_month)})")

# Check for any duplicates
duplicate_check = df_merged_correct.groupby('SRS').size()
duplicates = duplicate_check[duplicate_check > 1]

if len(duplicates) > 0:
    print(f"\n⚠ WARNING: {len(duplicates)} SRS samples still have duplicates after merge!")
    print("First few duplicates:")
    print(duplicates.head(10))
else:
    print("\n✓ SUCCESS: No sample appears in multiple rows!")

# Show sample of merged data
print("\nSample merged data:")
print(df_merged_correct[['SRS', 'subjectID', 'country', 'collection_month', 'allergy_milk', 'allergy_egg', 'allergy_peanut']].head(10))

PERFORMING CORRECT MERGE WITH COLLECTION MONTH
✓ Loaded 785 SRS samples with collection months
✓ Loaded 1946 metadata rows

Performing merge on ['subjectID', 'collection_month']...

✓ Merge complete: 826 rows
  (Should be same as number of SRS samples: 785)

First few duplicates:
SRS
SRS1719147    2
SRS1719148    2
SRS1719152    3
SRS1719166    2
SRS1719199    2
SRS1719227    2
SRS1719237    2
SRS1719283    2
SRS1719304    2
SRS1719316    2
dtype: int64

Sample merged data:
          SRS subjectID country  collection_month allergy_milk allergy_egg  \
0  SRS1719087   P018832     RUS                23        False       False   
1  SRS1719088   P017743     RUS                 6          NaN         NaN   
2  SRS1719089   P000648     RUS                16        False       False   
3  SRS1719090   T022883     EST                 8         True        True   
4  SRS1719091   T012374     EST                16        False       False   
5  SRS1719092   T016811     EST                13    

## Step 5: Handle Duplicate Rows

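Some patients have **multiple physical samples** (different `SampleID`s) collected in the same month. Because these metadata rows share the same (`subjectID`, `collection_month`) key, the merge repeats the matching SRS row once per metadata row.

**Solution**: Aggregate by SRS, taking the maximum allergy value (if ANY row shows an allergy, mark the sample as allergic). Non-allergy columns are taken from the first row, and missing allergy values are treated as `False`.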
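
The `aggregate_allergy_data` cell uses `groupby(...).apply`, which calls a Python function once per group. Below is a vectorized sketch of the same rule: missing values become `False`, then any-True per SRS, with the other columns taken from the first row. It runs on a hypothetical toy frame. `aggregate_any_allergic` is an illustrative name. Unlike the original, it returns real booleans rather than the strings `'True'`/`'False'`.

```python
import pandas as pd

TRUTHY = {"1", "true", "yes"}


def aggregate_any_allergic(df: pd.DataFrame, key: str, cols: list[str]) -> pd.DataFrame:
    """One row per `key`; each column in `cols` is True if ANY duplicate row is truthy."""
    # Missing -> "False", then compare case-insensitively against the truthy strings
    flags = df[cols].fillna("False").astype(str).apply(lambda s: s.str.lower().isin(TRUTHY))
    flags[key] = df[key]
    any_flags = flags.groupby(key, as_index=False)[cols].any()
    # Non-allergy columns: keep the first row per key, as iloc[0] does in the original
    first_rows = df.drop(columns=cols).drop_duplicates(subset=key)
    return first_rows.merge(any_flags, on=key)


toy = pd.DataFrame({
    "SRS": ["A", "A", "B"],
    "SampleID": ["s1", "s2", "s3"],
    "allergy_milk": [None, "True", "False"],
    "allergy_egg": ["False", None, None],
})
print(aggregate_any_allergic(toy, "SRS", ["allergy_milk", "allergy_egg"]))
```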

In [26]:
# Investigate and resolve duplicates
print("="*70)
print("INVESTIGATING DUPLICATES")
print("="*70)

# Find which SRS have duplicates
duplicate_srs = df_merged_correct.groupby('SRS').size()
duplicate_srs = duplicate_srs[duplicate_srs > 1].index.tolist()

print(f"\n{len(duplicate_srs)} SRS samples have duplicates")
print("\nExample: SRS1719152 (has 3 rows)")
example = df_merged_correct[df_merged_correct['SRS'] == 'SRS1719152'][['SRS', 'subjectID', 'country', 'collection_month', 'SampleID', 'allergy_milk', 'allergy_egg', 'allergy_peanut']]
print(example)

print("\nREASON: Some patients have multiple physical samples (different SampleIDs) collected in the same month.")
print("SOLUTION: Aggregate allergy data - if ANY sample in that month shows allergy, mark as allergic.")

# Aggregate: for each SRS, take max of allergy values (True > False)
# This ensures if any sample that month was allergic, we mark it as allergic

def aggregate_allergy_data(group):
    """For duplicate rows, aggregate allergy info"""
    result = group.iloc[0].copy()  # Start with first row
    
    # For allergy columns, take the max (True > False)
    allergy_cols = ['allergy_milk', 'allergy_egg', 'allergy_peanut', 
                    'allergy_dustmite', 'totalige_high', 'allergy_cat', 
                    'allergy_dog', 'allergy_birch', 'allergy_timothy']
    
    for col in allergy_cols:
        if col in group.columns:
            # Convert to boolean, taking True if any row is True
            vals = group[col].fillna('False').astype(str).str.lower()
            result[col] = 'True' if any(v in {'1', 'true', 'yes'} for v in vals) else 'False'
    
    return result

print("\nAggregating duplicate rows...")
df_final = df_merged_correct.groupby('SRS', as_index=False).apply(aggregate_allergy_data)
df_final = df_final.reset_index(drop=True)

print(f"✓ After aggregation: {len(df_final)} rows (one per SRS)")
print(f"✓ No duplicates: {df_final['SRS'].nunique() == len(df_final)}")

print("\nFinal data sample:")
print(df_final[['SRS', 'subjectID', 'country', 'collection_month', 'allergy_milk', 'allergy_egg', 'allergy_peanut']].head(10))

INVESTIGATING DUPLICATES

40 SRS samples have duplicates

Example: SRS1719152 (has 3 rows)
           SRS subjectID country  collection_month SampleID allergy_milk  \
67  SRS1719152   P020604     RUS                13  3104057          NaN   
68  SRS1719152   P020604     RUS                13  3104056          NaN   
69  SRS1719152   P020604     RUS                13  3104053          NaN   

   allergy_egg allergy_peanut  
67         NaN            NaN  
68         NaN            NaN  
69         NaN            NaN  

REASON: Some patients have multiple physical samples (different SampleIDs) collected in the same month.
SOLUTION: Aggregate allergy data - if ANY sample in that month shows allergy, mark as allergic.

Aggregating duplicate rows...
✓ After aggregation: 785 rows (one per SRS)
✓ No duplicates: True

Final data sample:
          SRS subjectID country  collection_month allergy_milk allergy_egg  \
0  SRS1719087   P018832     RUS                23        False       False   
1 


## Step 6: Longitudinal Labeling for Early Prediction

**This is the key step for our prediction task!**

For **early allergy prediction**, we want to identify patients who will **eventually develop** an allergy, even from samples collected before the allergy manifests.

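**Strategy**: For each patient, compute whether each food allergy is recorded at **any** timepoint (the maximum across timepoints), then apply this status to ALL of the patient's samples. Because missing values count as non-allergic, a patient with no recorded allergy data is labeled non-allergic.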

| Patient | Month 4 Status | Month 10 Status | Label Applied |
|---------|----------------|-----------------|---------------|
| T022883 | Not allergic   | Milk allergy    | **Allergic** (all samples) |

This enables the model to learn early biomarkers that predict future allergy development.
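
The propagation can be illustrated on a hypothetical toy frame (patient IDs `P1`/`P2` are made up). `groupby(...).transform("any")` is a compact equivalent of the `agg` + merge used in the cell below. It computes each patient's any-True result and broadcasts it back onto every row of that patient.

```python
import pandas as pd

samples = pd.DataFrame({
    "subjectID": ["P1", "P1", "P1", "P2", "P2"],
    "collection_month": [4, 10, 16, 6, 12],
    # P1 only shows the allergy at month 10
    "allergy_milk": [False, True, False, False, False],
})

# Every sample inherits "ever allergic" for its patient, including earlier months
samples["ever_milk"] = samples.groupby("subjectID")["allergy_milk"].transform("any")
samples["label"] = samples["ever_milk"].astype(int)
print(samples)
```

P1's month-4 sample is labeled 1 even though the allergy was only recorded at month 10. This is what lets early samples serve as training examples for eventual allergy.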

In [27]:
# LONGITUDINAL LABELING: Option 3 - Latest/Maximum allergy status per patient
print("="*70)
print("APPLYING LONGITUDINAL LABELING STRATEGY")
print("="*70)

# Food allergy columns
food_allergy_cols = [
    "allergy_milk",
    "allergy_egg",
    "allergy_peanut",
]

def is_allergic(val):
    """Check if a value represents allergic status"""
    v = str(val).strip().lower()
    return v in {"1", "true", "yes"}

# For each patient, compute their MAXIMUM/LATEST allergy status across all timepoints
print("\nStep 1: Computing maximum allergy status per patient...")

patient_max_allergy = df_final.groupby('subjectID').agg({
    'allergy_milk': lambda x: any(is_allergic(v) for v in x),
    'allergy_egg': lambda x: any(is_allergic(v) for v in x),
    'allergy_peanut': lambda x: any(is_allergic(v) for v in x)
}).reset_index()

print(f"✓ Computed max allergy status for {len(patient_max_allergy)} patients")
print("\nSample patient max allergy status:")
print(patient_max_allergy.head(10))

# Apply longitudinal labels: all samples from a patient get labeled with their max allergy
print("\nStep 2: Applying longitudinal labels to all samples...")

df_longitudinal = df_final.drop(columns=food_allergy_cols).merge(
    patient_max_allergy, 
    on='subjectID', 
    how='left'
)

# Create binary label and allergen class
def get_food_allergy_label(row):
    """Binary label: 1 if any food allergy, 0 otherwise"""
    return 1 if any([row['allergy_milk'], row['allergy_egg'], row['allergy_peanut']]) else 0

def get_allergen_class(row):
    """
    Allergen class label:
    0 = non-allergic
    1 = milk allergy only
    2 = egg allergy only
    3 = peanut allergy only
    4 = multiple food allergies
    """
    allergies = []
    if row['allergy_milk']:
        allergies.append(1)
    if row['allergy_egg']:
        allergies.append(2)
    if row['allergy_peanut']:
        allergies.append(3)
    
    if len(allergies) == 0:
        return 0
    elif len(allergies) == 1:
        return allergies[0]
    else:
        return 4

df_longitudinal["label"] = df_longitudinal.apply(get_food_allergy_label, axis=1)
df_longitudinal["allergen_class"] = df_longitudinal.apply(get_allergen_class, axis=1)

print(f"✓ Labeled {len(df_longitudinal)} samples")
print("\nSample labeled data:")
print(df_longitudinal[['SRS', 'subjectID', 'country', 'collection_month', 'allergy_milk', 'allergy_egg', 'allergy_peanut', 'label', 'allergen_class']].head(20))

APPLYING LONGITUDINAL LABELING STRATEGY

Step 1: Computing maximum allergy status per patient...
✓ Computed max allergy status for 212 patients

Sample patient max allergy status:
  subjectID  allergy_milk  allergy_egg  allergy_peanut
0   E002338          True        False           False
1   E002473         False        False           False
2   E002681         False        False           False
3   E002825          True        False           False
4   E003393         False        False           False
5   E004071         False        False           False
6   E004080         False        False           False
7   E004781          True         True           False
8   E004934         False         True           False
9   E005804         False        False           False

Step 2: Applying longitudinal labels to all samples...
✓ Labeled 785 samples

Sample labeled data:
           SRS subjectID country  collection_month  allergy_milk  allergy_egg  \
0   SRS1719087   P018832     RUS  

## Step 7: Verification

Let's verify the longitudinal labeling is working correctly by checking a patient who develops an allergy.
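
The cell below inspects one patient by eye. A hypothetical helper can check the same property for every patient at once: the longitudinal label should take a single value per `subjectID`.

```python
import pandas as pd


def labels_constant_per_patient(df: pd.DataFrame, id_col: str = "subjectID", label_col: str = "label") -> bool:
    """True if every patient has at most one distinct label across all of their samples."""
    return bool(df.groupby(id_col)[label_col].nunique().le(1).all())


consistent = pd.DataFrame({"subjectID": ["P1", "P1", "P2"], "label": [1, 1, 0]})
inconsistent = pd.DataFrame({"subjectID": ["P1", "P1"], "label": [1, 0]})
print(labels_constant_per_patient(consistent), labels_constant_per_patient(inconsistent))
```

On `df_longitudinal` this would be expected to return `True`, since every label is derived from the per-patient table.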

In [28]:
# Verify longitudinal labeling
print("="*70)
print("VERIFICATION: Checking longitudinal labeling")
print("="*70)

# Example: check a patient who develops an allergy later
# Find the first patient with a (longitudinal) milk allergy
milk_mask = df_longitudinal['allergy_milk'] == True
allergic_patient = df_longitudinal.loc[milk_mask, 'subjectID'].iloc[0] if milk_mask.any() else None

if allergic_patient:
    print(f"\nExample: Patient {allergic_patient} (develops milk allergy)")
    patient_samples = df_longitudinal[df_longitudinal['subjectID'] == allergic_patient][
        ['SRS', 'country', 'collection_month', 'allergy_milk', 'allergy_egg', 'allergy_peanut', 'label']
    ].sort_values('collection_month')
    print(patient_samples)
    # Check the labels instead of printing the success message unconditionally
    if (patient_samples['label'] == 1).all():
        print("\n✓ All samples from this patient are labeled as allergic (label=1)")
    else:
        print("\n⚠ Some samples from this patient are NOT labeled as allergic!")
else:
    print("\nNo allergic patients found in dataset")

# Summary statistics
print("\n" + "="*70)
print("SUMMARY STATISTICS")
print("="*70)
print(f"Total samples: {len(df_longitudinal)}")
print(f"Unique patients: {df_longitudinal['subjectID'].nunique()}")
print(f"Collection months range: {df_longitudinal['collection_month'].min()} - {df_longitudinal['collection_month'].max()}")
print(f"\nLabel distribution:")
print(df_longitudinal['label'].value_counts().sort_index())
print(f"\nAllergen class distribution:")
print(df_longitudinal['allergen_class'].value_counts().sort_index())

VERIFICATION: Checking longitudinal labeling

Example: Patient T022883 (develops milk allergy)
            SRS country  collection_month  allergy_milk  allergy_egg  \
535  SRS1735472     EST                 4          True         True   
3    SRS1719090     EST                 8          True         True   
444  SRS1719531     EST                10          True         True   
92   SRS1719179     EST                13          True         True   
113  SRS1719200     EST                16          True         True   

     allergy_peanut  label  
535           False      1  
3             False      1  
444           False      1  
92            False      1  
113           False      1  

✓ All samples from this patient are labeled as allergic (label=1)

SUMMARY STATISTICS
Total samples: 785
Unique patients: 212
Collection months range: 1 - 38

Label distribution:
label
0    527
1    258
Name: count, dtype: int64

Allergen class distribution:
allergen_class
0    527
1    101
2    

## Step 7.5: Impute Missing Country Information

Some samples may have missing country data. Since a patient's country doesn't change over time, we can impute missing country values using the available country information from other samples of the same patient.
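
An alternative to running `impute_country` through `groupby(...).apply` is `transform`, which returns a Series aligned to the original rows. Below is a minimal sketch on a hypothetical toy frame. `fill_country_by_patient` is an illustrative name. Like the cell below, it fills each patient's gaps with that patient's most common non-missing country. Patients with no country at all stay missing.

```python
import pandas as pd


def fill_country_by_patient(df: pd.DataFrame) -> pd.DataFrame:
    """Fill missing `country` values with the patient's most common observed country."""
    def patient_mode(s: pd.Series):
        m = s.mode()  # mode() ignores missing values by default
        return m.iloc[0] if not m.empty else pd.NA

    # transform broadcasts each patient's scalar result back to all of their rows
    fill = df.groupby("subjectID")["country"].transform(patient_mode)
    return df.assign(country=df["country"].fillna(fill))


toy = pd.DataFrame({
    "subjectID": ["P1", "P1", "P2", "P3"],
    "country": ["RUS", None, "FIN", None],
})
print(fill_country_by_patient(toy))
```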

In [31]:
print("="*70)
print("IMPUTING MISSING COUNTRY VALUES")
print("="*70)

# Check for missing country values
missing_country_before = df_longitudinal['country'].isna().sum()
print(f"\nBefore imputation: {missing_country_before} samples with missing country")

if missing_country_before > 0:
    # Show samples with missing country
    missing_samples = df_longitudinal[df_longitudinal['country'].isna()][['SRS', 'subjectID', 'country', 'collection_month']]
    print(f"\nSamples with missing country:")
    print(missing_samples)
    
    # For each patient, fill missing country with the most common country value for that patient
    def impute_country(group):
        """Fill missing country values within a patient group"""
        # Get non-null country values for this patient
        valid_countries = group['country'].dropna()
        
        if len(valid_countries) > 0:
            # Use the most common country (should be unique per patient, but just in case)
            most_common_country = valid_countries.mode()
            if len(most_common_country) > 0:
                group['country'] = group['country'].fillna(most_common_country.iloc[0])
        
        return group
    
    # Apply imputation by patient
    df_longitudinal = df_longitudinal.groupby('subjectID', group_keys=False).apply(impute_country)
    
    # Check results
    missing_country_after = df_longitudinal['country'].isna().sum()
    imputed_count = missing_country_before - missing_country_after
    
    print(f"\n✓ Imputed {imputed_count} country values")
    print(f"After imputation: {missing_country_after} samples still missing country")
    
    if imputed_count > 0:
        # Show what was imputed
        imputed_samples = df_longitudinal[df_longitudinal['SRS'].isin(missing_samples['SRS'])][['SRS', 'subjectID', 'country', 'collection_month']]
        print(f"\nSamples after imputation:")
        print(imputed_samples)
else:
    print("\n✓ No missing country values found!")

IMPUTING MISSING COUNTRY VALUES

Before imputation: 1 samples with missing country

Samples with missing country:
            SRS subjectID country  collection_month
501  SRS1735438   P005558     NaN                36

✓ Imputed 1 country values
After imputation: 0 samples still missing country

Samples after imputation:
            SRS subjectID country  collection_month
501  SRS1735438   P005558     RUS                36



## Step 8: Save Output Files

Save the final preprocessed dataset with longitudinal labels.

In [29]:
# Save the longitudinally-labeled dataset
output_file = "diabimmune_longitudinal_labels.csv"
df_longitudinal.to_csv(output_file, index=False)

print("="*70)
print(f"✓ Saved longitudinally-labeled dataset to: {output_file}")
print("="*70)
print(f"Columns: {df_longitudinal.columns.tolist()}")
print(f"\nFirst 10 rows:")
print(df_longitudinal[['SRS', 'subjectID', 'country', 'collection_month', 'label', 'allergen_class']].head(10))

✓ Saved longitudinally-labeled dataset to: diabimmune_longitudinal_labels.csv
Columns: ['SRS', 'host_age_days', 'collection_month', 'subjectID', 'SampleID', 'age_at_collection', 'delivery', 'gest_time', 'gender', 'country', 'Exclusive_breast_feeding', 'Breast_feeding_end', 'Regular_formula', 'hydrosylated_formula', 'partly_hydrosylated_formula', 'Any_baby_formula', 'Fruits_and_berries', 'Corn', 'Rice', 'Wheat', 'Oat', 'Barley', 'Rye', 'Cereal', 'Other_grains', 'Root_vegetables', 'Vegetables', 'Eggs', 'Soy', 'Milk', 'Meat', 'Fish', 'Other_food', 'Other_than_BF', 'bf_length', 'num_abx_treatments', 'num_abx_first_year', 'abx_first_year', 'after_abx', 'num_preceeding_abx', 'hla_risk_class', 'seroconverted', 'num_aabs', 'totalige', 'totalige_log', 'allergy_dustmite', 'totalige_high', 'allergy_cat', 'allergy_dog', 'allergy_birch', 'allergy_timothy', 'gid_wgs', 'mgx_reads', 'mgx_pool', 'mgx_reads_filtered', 'read_count_16S', 'sequencing_PDO_16S', 'gid_16s', 'allergy_milk', 'allergy_egg', 'all

## Step 9: Create Per-Month Files

Generate individual CSV files for each collection month. These are useful for:
- Month-specific model training
- Temporal analysis
- Cross-validation strategies

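**Important**: Each SRS appears in exactly ONE month, so no sample is duplicated across files. A patient can still contribute samples to several month files (with the same longitudinal label), so train/test splits that combine months should be made at the patient level.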
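
If several Month_X.csv files are concatenated for cross-validation, one simple way to keep each patient's samples together is to assign folds by patient rather than by row. The sketch below is a hypothetical round-robin assignment over sorted patient IDs; scikit-learn's `GroupKFold` is a library alternative. It uses the `patient_id` column name written to the per-month files, and `assign_patient_folds` is an illustrative name.

```python
import pandas as pd


def assign_patient_folds(df: pd.DataFrame, n_folds: int = 5, id_col: str = "patient_id") -> pd.Series:
    """Map every row to a fold so that all rows of a patient share the same fold."""
    # Round-robin over sorted unique patient IDs: each patient lands in exactly one fold
    patients = sorted(df[id_col].unique())
    fold_of = {pid: i % n_folds for i, pid in enumerate(patients)}
    return df[id_col].map(fold_of)


toy = pd.DataFrame({
    "sid": ["S1", "S2", "S3", "S4", "S5"],
    "patient_id": ["P1", "P1", "P2", "P3", "P3"],
})
toy["fold"] = assign_patient_folds(toy, n_folds=2)
print(toy)
```

Round-robin does not balance labels or sample counts across folds. A stratified group split would be a refinement if class balance per fold matters.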

In [30]:
# Optional: Create per-month files for downstream analysis
import os

output_dir = "preprocessed_diabimmune_longitudinal"
os.makedirs(output_dir, exist_ok=True)

print("="*70)
print(f"Creating per-month files in: {output_dir}/")
print("="*70)

# Sanity check: Each SRS should appear in exactly ONE month
sample_month_counts = df_longitudinal.groupby("SRS")["collection_month"].nunique()
leaked_samples = sample_month_counts[sample_month_counts > 1]

if len(leaked_samples) > 0:
    print(f"⚠ WARNING: {len(leaked_samples)} samples appear in multiple months!")
    print("This should not happen with correct preprocessing!")
else:
    print(f"✓ All {len(df_longitudinal)} samples appear in exactly one month (no leakage)\n")

# Write per-month files
for month, grp in df_longitudinal.groupby("collection_month"):
    out = grp[["SRS", "subjectID", "country", "label", "allergen_class"]].copy()
    out = out.rename(columns={"SRS": "sid", "subjectID": "patient_id"})
    
    fname = f"Month_{int(month)}.csv"
    fpath = os.path.join(output_dir, fname)
    out.to_csv(fpath, index=False)
    
    # Count labels
    allergic_count = (out['label'] == 1).sum()
    non_allergic_count = (out['label'] == 0).sum()
    print(f"Month {int(month):2d}: {len(out):3d} samples ({allergic_count} allergic, {non_allergic_count} non-allergic)")

print("\n✓ Per-month files created successfully!")

Creating per-month files in: preprocessed_diabimmune_longitudinal/
✓ All 785 samples appear in exactly one month (no leakage)

Month  1:  20 samples (9 allergic, 11 non-allergic)
Month  2:  17 samples (11 allergic, 6 non-allergic)
Month  3:   8 samples (2 allergic, 6 non-allergic)
Month  4:  39 samples (24 allergic, 15 non-allergic)
Month  5:  11 samples (1 allergic, 10 non-allergic)
Month  6:  15 samples (5 allergic, 10 non-allergic)
Month  7:  58 samples (16 allergic, 42 non-allergic)
Month  8:  16 samples (4 allergic, 12 non-allergic)
Month  9:  21 samples (6 allergic, 15 non-allergic)
Month 10:  55 samples (17 allergic, 38 non-allergic)
Month 11:  27 samples (4 allergic, 23 non-allergic)
Month 12:  20 samples (5 allergic, 15 non-allergic)
Month 13:  60 samples (24 allergic, 36 non-allergic)
Month 14:  25 samples (5 allergic, 20 non-allergic)
Month 15:  11 samples (1 allergic, 10 non-allergic)
Month 16:  72 samples (28 allergic, 44 non-allergic)
Month 17:  25 samples (4 allergic, 21