# Tanaka Dataset Preprocessing

Generates per-month metadata CSVs from two raw source files:
- **final_trimmed_db.csv**: sid → patient_id, age mapping
- **fix099_Supp.csv**: patient_id → allergen details

Output: `month_X.csv` files with columns `[sid, patient_id, label, allergen_class]`
- `label`: 0=healthy, 1=food allergy
- `allergen_class`: 'healthy', 'egg', 'milk', 'egg_milk_wheat', etc.

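> **Note:** Patients in the OA (Other Allergy) group are excluded. Only samples whose patient appears in `fix099_Supp.csv` as healthy or FA are kept, and each month file contains at most one sample per patient.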

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path

# Ages in months; one month_<N>.csv file is written per entry.
MONTHS = [1, 2, 3, 6, 12, 24, 36]

# Raw input CSVs are resolved relative to the current working directory.
SCRIPT_DIR = Path('.')
# Directory that receives the generated month_<N>.csv files.
OUTPUT_DIR = Path('../../../../huggingface_datasets/Tanaka/metadata')

# Seed NumPy's global random generator for reproducibility.
np.random.seed(42)

## 1. Load Raw Data

In [None]:
# Load the sid -> patient_id, age mapping. Everything is read as text first so
# identifiers are not reinterpreted as numbers.
sid_mapping = pd.read_csv(SCRIPT_DIR / 'final_trimmed_db.csv', dtype=str)
# The file is expected to have exactly four columns; the first one
# (presumably a row index - TODO confirm) is discarded below.
sid_mapping.columns = ['idx', 'sid', 'patient_id', 'age_months']
sid_mapping = sid_mapping[['sid', 'patient_id', 'age_months']].copy()
# The allergen table strips whitespace from its patient IDs before the join;
# normalise the key here the same way so padded IDs are not silently dropped
# by the inner merge on patient_id.
sid_mapping['patient_id'] = sid_mapping['patient_id'].str.strip()
# Ages are compared against integer month values later on.
sid_mapping['age_months'] = sid_mapping['age_months'].astype(int)
print(f"Loaded {len(sid_mapping)} samples from final_trimmed_db.csv")

In [None]:
# Column names for the supplement table, in file order. The file's own header
# spans five rows, so those rows are skipped and these names are used instead.
supp_columns = [
    'ID', 'Group', 'Gender', 'Mode_delivery', 'Milk_feeding',
    'Antibiotics_m1', 'Antibiotics_m1_2', 'Antibiotics_m2_6', 'Antibiotics_m6_y1',
    'Mother_allergy', 'Egg', 'Milk', 'Soybean', 'Wheat', 'Onset_age',
    'Atopic_dermatitis', 'Asthmatic', 'Rhinitis',
]

# Read every cell as text; empty cells become NaN.
allergen_raw = pd.read_csv(
    SCRIPT_DIR / 'fix099_Supp.csv',
    names=supp_columns,
    header=None,
    skiprows=5,
    dtype=str,
)

print(f"Loaded {len(allergen_raw)} patients")
# dropna=False so that patients with an empty Group cell are counted as well.
print(f"Group distribution: {allergen_raw['Group'].value_counts(dropna=False).to_dict()}")

## 2. Create Labels

In [None]:
# Keep patients whose Group is missing (healthy) or 'FA' (food allergy);
# every other group, e.g. OA, is dropped. Then add the cleaned patient ID and
# a binary label (1 for FA, 0 otherwise).
allergen_df = allergen_raw.loc[
    allergen_raw['Group'].eq('FA') | allergen_raw['Group'].isna()
].assign(
    patient_id=lambda frame: frame['ID'].str.strip(),
    label=lambda frame: frame['Group'].eq('FA').astype(int),
)

def get_allergen_class(row):
    """Return the allergen class string for one patient row.

    Rows whose ``Group`` is not ``'FA'`` are classed as ``'healthy'``. For FA
    rows, every allergen column (Egg, Milk, Soybean, Wheat) whose value starts
    with ``'+'`` contributes its lowercase name; the names are sorted and
    joined with ``'_'``. An FA row with no positive allergen is classed as
    ``'food_allergy_unspecified'``.

    ``row`` must support ``row['Group']`` and ``row.get(column, default)``.
    """
    if row['Group'] != 'FA':
        return 'healthy'

    # (column in the supplement table, name used in the class string)
    allergen_columns = (
        ('Egg', 'egg'),
        ('Milk', 'milk'),
        ('Soybean', 'soybean'),
        ('Wheat', 'wheat'),
    )
    # str() makes missing values (NaN/None) compare as non-positive text.
    positives = sorted(
        name
        for column, name in allergen_columns
        if str(row.get(column, '')).startswith('+')
    )
    if not positives:
        return 'food_allergy_unspecified'
    return '_'.join(positives)

# Classify every patient row by row, then keep only the columns needed for
# the join with the sample table.
allergen_df['allergen_class'] = allergen_df.apply(get_allergen_class, axis='columns')
allergen_df = allergen_df.loc[:, ['patient_id', 'label', 'allergen_class']]

print(f"Filtered to {len(allergen_df)} patients (healthy + FA)")
print(f"\nAllergen classes:\n{allergen_df['allergen_class'].value_counts()}")

## 3. Join and Generate Month Files

In [None]:
# Attach label and allergen class to every sample. The inner join drops
# samples whose patient is absent from the filtered allergen table.
all_data = pd.merge(sid_mapping, allergen_df, how='inner', on='patient_id')
print(f"Total samples with allergen info: {len(all_data)}")

# Make sure the destination exists before any file is written.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

for month in MONTHS:
    month_data = all_data.loc[all_data['age_months'] == month].copy()

    if len(month_data) == 0:
        # No sample at this age: still emit a header-only CSV.
        result = pd.DataFrame(columns=['sid', 'patient_id', 'label', 'allergen_class'])
    else:
        # Keep one sample per patient: shuffle with a fixed seed, then keep
        # the first row seen for each patient_id.
        month_data = month_data.sample(frac=1, random_state=42)
        month_data = month_data.drop_duplicates(subset='patient_id', keep='first')
        result = month_data[['sid', 'patient_id', 'label', 'allergen_class']]
        result = result.sort_values('sid').reset_index(drop=True)

    result.to_csv(OUTPUT_DIR / f'month_{month}.csv', index=False)

    if len(result) > 0:
        status = f"Labels: {result['label'].value_counts().to_dict()}"
    else:
        status = "(no data)"
    print(f"month_{month}.csv: {len(result)} samples | {status}")

## 4. Summary

In [None]:
# Read back every generated month file and stack them into one table to
# report overall counts.
all_dfs = list(map(pd.read_csv, (OUTPUT_DIR / f'month_{m}.csv' for m in MONTHS)))
combined = pd.concat(all_dfs, ignore_index=True)

print(f"Total samples: {len(combined)}")
print(f"\nLabel distribution:\n{combined['label'].value_counts()}")
print(f"\nAllergen classes:\n{combined['allergen_class'].value_counts()}")