In [1]:
import os
import sys

os.chdir('/container/mount/point/')

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

import pickle as pkl

from tqdm import tqdm
from sklearn import preprocessing
from scipy.cluster.hierarchy import linkage, leaves_list

from gglasso.problem import glasso_problem
from gglasso.helper.basic_linalg import adjacency_matrix
from gglasso.helper.basic_linalg import scale_array_by_diagonal

from utils.helper import transform_features, scale_array_by_diagonal
from utils.preprocessing import add_noise_to_data

In [3]:
# Full ASV count table (tab-separated); the first column becomes the row index.
asv = pd.read_csv("data/feature_table.tsv", sep='\t', index_col=0)

# Comma-separated side tables (comma is the read_csv default separator).
# control_sample_df and cluster_df are indexed by their first column;
# covariates is indexed by its 'u3_16s_id' column.
control_sample_df = pd.read_csv("data/control_samples.csv", index_col=0, low_memory=False)
cluster_df = pd.read_csv("data/cluster_df_latent.csv", index_col=0, low_memory=False)
covariates = pd.read_csv("data/54subset_latent.csv", index_col='u3_16s_id', low_memory=False)

asv_ige_508_filtered = pd.read_csv("data/asv_ige_508.csv", index_col=0, low_memory=False)
co_occurrence_matrix = pd.read_csv("data/co_occurrence_matrix.csv", index_col=0, low_memory=False)

# Taxonomy table; its index is duplicated into an "ASV" column so that the
# index can be used as a grouping level like any other taxonomy column.
taxa = pd.read_csv('data/taxonomy_clean.csv', index_col=0)
taxa["ASV"] = taxa.index

# Restrict the full ASV table to the sample columns present in
# asv_ige_508_filtered (presumably the 508 IgE samples, going by the file name),
# keeping every ASV row. The copy decouples the result from `asv`.
asv_ige_508_all = asv.loc[:, asv_ige_508_filtered.columns].copy()

### Create datasets

In [None]:
#### Build one ASV count table per sample group (each cluster, healthy, allergic, all)

clusters = ["A", "B", "C", "D", "E", "F", "I", "H", "G"]

# group name -> ASV table restricted to that group's sample columns
cluster_asv_data = {}

for cluster in clusters:
    # Sample IDs (cast to str) of the rows of cluster_df labelled with this cluster
    in_cluster = cluster_df['cluster'] == cluster
    sample_ids = cluster_df[in_cluster].index.astype(str)

    # Boolean column mask: IDs absent from asv_ige_508_all are silently ignored
    column_mask = asv_ige_508_all.columns.isin(sample_ids)
    cluster_asv = asv_ige_508_all.loc[:, column_mask]

    cluster_asv_data[cluster] = cluster_asv


# Control sample IDs, cast to str like the cluster sample IDs above
control_ids = control_sample_df.index.astype(str)

# Label-based selection: raises KeyError if a control ID is not a column
healthy_asv = asv_ige_508_all[control_ids]
# Everything that is not a control sample
allergic_asv = asv_ige_508_all.loc[:, ~asv_ige_508_all.columns.isin(control_ids)]

# Insertion order (healthy, allergic, all) follows the cluster entries
cluster_asv_data.update({
    "healthy": healthy_asv,
    "allergic": allergic_asv,
    "all": asv_ige_508_all,
})

In [9]:
# Taxonomic ranks to aggregate at; each must be a column of `taxa`
# ("ASV" is the copy of the taxonomy index added when `taxa` was loaded).
levels = ["phylum", "class", "order", "family", "genus", "species", "ASV"]

# Prevalence cut-offs to sweep: at every level a taxon is kept only if it is
# non-zero in strictly more than this fraction of the group's samples.
thresholds = [0.01, 0.05, 0.1, 0.2]

# threshold -> {group -> {level -> {'raw_counts', 'clr_counts', 'mclr_counts', 'corr'}}}
results_by_threshold = {}

for threshold in thresholds:
    print(f"\n{'='*60}")
    print(f"PROCESSING THRESHOLD: {threshold}")
    print(f"{'='*60}")

    # Rebuilt from scratch for every threshold
    data_dict = dict()

    for cluster, asv_table in cluster_asv_data.items():
        print(f"\n=== Cluster: {cluster} ===")

        # Work on a copy so the tables stored in cluster_asv_data stay untouched
        ASV_table = asv_table.copy()
        data_dict[cluster] = {}

        #################### DROP ZERO FEATURES ##############################
        print(f"ASVs BEFORE dropping zero features: p = {ASV_table.shape[0]}")

        ### Row sums: total of each ASV across the samples of this group
        taxa_freq = ASV_table.sum(axis=1)

        ### Keep only ASVs whose row sum is positive
        non_zero_taxa = taxa_freq[taxa_freq > 0]
        ASV_table_non_zero = ASV_table[ASV_table.index.isin(non_zero_taxa.index)]

        ### Columns (samples) whose mean over the remaining ASVs is exactly zero
        means = ASV_table_non_zero.mean()
        zero_mean_cols = means[means == 0].index

        ### Drop them if there are any. The emptiness check is explicit:
        ### any(index) would test the truthiness of the column LABELS, so a
        ### zero-mean column labelled 0 or "" would be silently kept.
        if len(zero_mean_cols) > 0:
            print(f"Zero mean features found and removed: {list(zero_mean_cols)}")
            ASV_table_non_zero = ASV_table_non_zero.drop(zero_mean_cols, axis=1)

        print(f"ASVs AFTER dropping zero features: p = {ASV_table_non_zero.shape[0]}")

        for level in levels:
            print(f"\n--- Level: {level} ---")

            ##################### AGGREGATION ##############################
            # Attach the taxonomy label of this level to every ASV row, then
            # sum the counts of all ASVs sharing the same label.
            counts_plus_label = ASV_table_non_zero.join(taxa[level])
            counts = counts_plus_label.groupby(level).sum()

            ### Prevalence filter, applied at every level:
            ### fraction of samples in which the taxon is non-zero
            counts_freq = counts.astype(bool).sum(axis=1) / counts.shape[1]
            filter_threshold = counts_freq[counts_freq > threshold]
            counts_filtered = counts[counts.index.isin(filter_threshold.index)]

            ### Apply the CLR and mCLR transformations (project helper)
            clr_counts = transform_features(counts_filtered, transformation='clr')
            mclr_counts = transform_features(counts_filtered, transformation='mclr')

            ### Covariance between taxa: np.cov treats each ROW as a variable,
            ### and bias=True normalises by N instead of N - 1.
            S0 = np.cov(clr_counts.values, bias = True)
            # NOTE: scale_array_by_diagonal is imported from both gglasso and
            # utils.helper; the later utils.helper import is the one in effect.
            # Presumably it rescales the covariance to a correlation matrix — TODO confirm.
            S = scale_array_by_diagonal(S0)
            corr_df = pd.DataFrame(S, index=clr_counts.index, columns=clr_counts.index)

            data_dict[cluster][level] = {'raw_counts': counts_filtered, 
                                         'clr_counts': clr_counts, 
                                         'mclr_counts': mclr_counts,
                                         "corr": corr_df}

            print(f"{level} count table shape: p = {counts_filtered.shape[0]}, N = {counts_filtered.shape[1]}")

    # Store the data_dict for this threshold
    results_by_threshold[threshold] = data_dict

    # Save to a pickle file in the current working directory
    with open(f'data_dict_threshold_{threshold}.pkl', 'wb') as f:
        pkl.dump(data_dict, f)
    print(f"\nSaved data_dict for threshold {threshold} to 'data_dict_threshold_{threshold}.pkl'")


PROCESSING THRESHOLD: 0.01

=== Cluster: A ===
ASVs BEFORE dropping zero features: p = 15170
ASVs AFTER dropping zero features: p = 758

--- Level: phylum ---
phylum count table shape: p = 10, N = 12

--- Level: class ---
class count table shape: p = 15, N = 12

--- Level: order ---
order count table shape: p = 30, N = 12

--- Level: family ---
family count table shape: p = 62, N = 12

--- Level: genus ---
genus count table shape: p = 249, N = 12

--- Level: species ---
species count table shape: p = 710, N = 12

--- Level: ASV ---
ASV count table shape: p = 758, N = 12

=== Cluster: B ===
ASVs BEFORE dropping zero features: p = 15170
ASVs AFTER dropping zero features: p = 1276

--- Level: phylum ---
phylum count table shape: p = 11, N = 24

--- Level: class ---
class count table shape: p = 17, N = 24

--- Level: order ---
order count table shape: p = 43, N = 24

--- Level: family ---
family count table shape: p = 89, N = 24

--- Level: genus ---
genus count table shape: p = 369, N = 

In [4]:
# Reload the per-group tables for one prevalence threshold from the pickle
# written by the threshold sweep (one file per threshold, in the working directory).
# NOTE: unpickling can execute arbitrary code; only load files produced by this notebook.
threshold = 0.2

pickle_path = f'data_dict_threshold_{threshold}.pkl'
with open(pickle_path, 'rb') as f:
    data_dict = pkl.load(f)

In [5]:
# Taxonomic level whose raw count tables are exported to CSV
level = 'family'

# Groups to export, in the order they are written and reported.
# A single list drives extraction, saving and reporting, so the printed paths
# can never drift from the files actually written.
export_groups = ("healthy", "allergic", "A", "E", "D", "B")

# Extract all tables first, so a missing group/level fails before any file is written
export_tables = [data_dict[group][level]["raw_counts"] for group in export_groups]

# Keep the individual names bound for any later cell that uses them
counts_healthy, counts_allergic, counts_A, counts_E, counts_D, counts_B = export_tables

# Save as CSV files
export_paths = [f'data/counts_{group}_{level}.csv' for group in export_groups]
for table, path in zip(export_tables, export_paths):
    table.to_csv(path)

print("✓ Saved:")
for path in export_paths:
    print(f"  - {path}")

✓ Saved:
  - data/counts_healthy_family.csv
  - data/counts_allergic_family.csv
  - data/counts_A_family.csv
  - data/counts_E_family.csv
  - data/counts_D_family.csv
  - data/counts_B_family.csv


#### We use the R package NetCoMi for the visualization (the code is available in the accompanying R Markdown file)