In [43]:
import polars as pl
import pandas as pd
import sys
from tqdm import tqdm
sys.path.append("../src")
import re
from gutatlas.data import filter_by_tag
from gutatlas.utils.constants import MENTAL_HEALTH_TAGS, GI_TAGS
from gutatlas.data import map_gi_status_binary, normalize_multilabel_gi_tags
from gutatlas.features import clean_feature_names
import janitor

## Build Training Sets

In [None]:
taxon_path = "../data/raw/taxonomic_table.csv"
metadata_path = "../data/raw/sample_metadata.tsv"
tags_path = "../data/raw/tags.tsv"

sample_metadata = pl.read_csv(metadata_path, separator="\t").with_columns(
    (pl.col("project") + "_" + pl.col("srr")).alias("sample")
)
sample_tags = pl.read_csv(tags_path, separator="\t").with_columns(
    (pl.col("project") + "_" + pl.col("srr")).alias("sample")
)

#taxon table is huge. open batches of 1000, normalize the abundances of each taxon on a per-sample basis
#then save each batch for easier processing going forward
reader = pl.read_csv_batched(taxon_path, batch_size=1000)
first_batch = True
i = 0
while True:
    batches = reader.next_batches(1)
    if not batches:
        break
    batch = batches[0]
    taxon_cols = batch.columns[2:]

    batch = batch.with_columns(
        pl.sum_horizontal([pl.col(col) for col in taxon_cols]).alias("total_reads")
    ).with_columns([
        (pl.col(col) / pl.col("total_reads")).alias(col) for col in taxon_cols
    ]).select(["sample"] + taxon_cols)

    merged = (
        batch.join(sample_metadata, on="sample", how="inner")
             .join(sample_tags, on="sample", how="left")
             .drop(["project_right", "srr_right", "srs_right", "total_bases", 
                    "instrument",'srs', 'project', 'srr', 'library_strategy', 'library_source'])
    )

    merged.write_parquet(f"../data/interim/batches/taxa_merged_batch_{i}.parquet")
    i += 1


#lazily read all batches, split by region, and save each region separately
batches = pl.scan_parquet('../data/interim/batches/taxa_merged_batch_*.parquet')

unique_regions = (
    batches.select("iso")
           .unique()
           .collect()
)

for region in unique_regions['iso']:
    split = batches.filter(pl.col('iso') == region).collect()

    split.write_parquet(f'../data/interim/regional_data/{region}_microbiome.parquet')

### GI tags

In [46]:
gi_merged = filter_by_tag('../data/interim/regional_data/',GI_TAGS)
gi_merged.write_parquet('../data/interim/filtered_and_merged/gi_microbiomes_merged.parquet')

#### binary classification set

In [53]:
inactive_cols = pd.read_csv('../data/processed/inactive_columns.csv')
#remove duplicate rows for each sample. only need to know if any disease is present or not 
merged_gi = pd.read_parquet('../data/interim/filtered_and_merged/gi_microbiomes_merged.parquet')
merged_gi['disease_present'] = merged_gi.value.apply(map_gi_status_binary)

gi_training = (merged_gi
                         .sort_values(by = ['disease_present','sample'],ascending=False)
                         .drop_duplicates(subset = 'sample',keep='first')
                         .reset_index(drop=True)
                         .drop(columns = ['pubdate','geo_loc_name','iso','region','tag','value'])
                         .rename(columns = {col:clean_feature_names(col) for col in merged_gi.columns})
                         .clean_names()
                         )


gi_training.to_parquet('../data/processed/gi_binary_training.parquet')

#### multilabel classification set

In [54]:
merged_gi = pd.read_parquet('../data/interim/filtered_and_merged/gi_microbiomes_merged.parquet')

In [55]:
normalized_gi_tags = normalize_multilabel_gi_tags(merged_gi)
merged_gi_normalized = pd.merge(left = merged_gi.drop(columns=['tag','value']), right= normalized_gi_tags, on = 'sample', how = 'left')
#keep disease classes specific
merged_gi_normalized = merged_gi_normalized[merged_gi_normalized.tag != 'GI_other'].dropna()

In [56]:
multilabels = merged_gi_normalized.pivot_table(index = 'sample',columns='tag',values='value', fill_value=0).reset_index()

gi_multilabel_training = (pd.merge(multilabels,merged_gi_normalized.drop_duplicates('sample'),'left','sample')
                          .drop(columns = ['pubdate','geo_loc_name','iso','region','tag','value'])
                          .clean_names()
)
gi_multilabel_training.to_parquet('../data/processed/gi_multilabel_training.parquet')


### Mental health tags. ended up being too few positives for use in modelling unfortunately :(

In [None]:
mental_health_merged = filter_by_tag('../data/interim/regional_data/',MENTAL_HEALTH_TAGS)
mental_health_merged.write_parquet('../data/interim/filtered_and_merged/mental_health_microbiomes_merged.parquet')

In [None]:
#i'm only interested in specific mental illnesses, not general/mixed values. I'm dropping anything vague/general thats mapped to -1
#i want to use this for multilabel classification

tag_mapping = {
    # Core combined
    "depression_bipolar_schizophrenia": -1,

    # General
    "mental_illness": -1,
    "mental_illness_type": -1,

    # Specific disorders
    "mental_illness_type_anorexia_nervosa": "anorexia",
    "mental_illness_type_bipolar_disorder": "bipolar_disorder",
    "mental_illness_type_bulimia_nervosa": "bulimia",
    "mental_illness_type_depression": "depression",
    "mental_illness_type_schizophrenia": "schizophrenia",
    "mental_illness_type_substance_abuse": "substance_abuse",
    "mental_illness_type_unspecified": -1,

    # PTSD variants (normalize both spellings)
    "mental_illness_type_ptsd_posttraumatic_stress_disorder": "ptsd",
    "mental_illness_type_ptsd_post_traumatic_stress_disorder": "ptsd",

    # Keep both depression-related signals
    "has_depression1": "depression",
    "has_depression2": "depression",

    #i wasn't able to locate what scale was being used for these numerical tags, so i'm dropping them
    "depression_index1": -1,
    "depression_index2": -1,

    "anxiety_index1": -1,
    "anxiety_index2": -1,

    "depression_level": -1,
    "depression_status": -1,
    "stress_level": -1,
    "stress_status": -1,
}

value_mapping = {
    'false':0,
    'no':0,
    'not provided':-1,
    'i do not have this condition':0,
    'self-diagnosed':1,
    'not collected':-1,
    'diagnosed by a medical professional (doctor, physician assistant)':1,
    'labcontrol test':1,
    'unspecified':-1
}

def bin_promis_scale(score):
    # https://www.sciencedirect.com/science/article/abs/pii/S0889159119315314?via%3Dihub
    #based on the above paper (where the promis score tag/values originated), under 21 and below is a good threshold for binary binning
    try:
        score =  int(score)
        return 0 if score < 22 else 1
    except (ValueError, TypeError):
        return score
    
mental_health_merged = pd.read_parquet('../data/interim/filtered_and_merged/mental_health_microbiomes_merged.parquet')
mental_health_merged.value = mental_health_merged.value.apply(bin_promis_scale)
mental_health_merged.tag = mental_health_merged.tag.map(tag_mapping)
mental_health_merged.value = mental_health_merged.value.map(value_mapping)
mental_health_merged = (mental_health_merged[~((mental_health_merged.tag == -1) | (mental_health_merged.value == -1))]
                        .dropna(how='any')
                        .reset_index(drop=True))


In [8]:
import os
import pandas as pd
merged = pd.DataFrame()
batch_dir = os.listdir('../data/interim/regional_data/')
for idx,batch in enumerate(batch_dir):
    print('starting',idx+1,'of',len(batch_dir))
    df = pd.read_parquet(f'../data/interim/regional_data/{batch}')
    merged = pd.concat([merged,df[['tag','value']]],ignore_index=True)
    del df
    print('finished',idx+1)



starting 1 of 69
finished 0
starting 2 of 69
finished 1
starting 3 of 69
finished 2
starting 4 of 69
finished 3
starting 5 of 69
finished 4
starting 6 of 69
finished 5
starting 7 of 69
finished 6
starting 8 of 69
finished 7
starting 9 of 69
finished 8
starting 10 of 69
finished 9
starting 11 of 69
finished 10
starting 12 of 69
finished 11
starting 13 of 69
finished 12
starting 14 of 69
finished 13
starting 15 of 69
finished 14
starting 16 of 69
finished 15
starting 17 of 69
finished 16
starting 18 of 69
finished 17
starting 19 of 69
finished 18
starting 20 of 69
finished 19
starting 21 of 69
finished 20
starting 22 of 69
finished 21
starting 23 of 69
finished 22
starting 24 of 69
finished 23
starting 25 of 69
finished 24
starting 26 of 69
finished 25
starting 27 of 69
finished 26
starting 28 of 69
finished 27
starting 29 of 69
finished 28
starting 30 of 69
finished 29
starting 31 of 69
finished 30
starting 32 of 69
finished 31
starting 33 of 69
finished 32
starting 34 of 69
finished 33