In [3]:
# imports
import polars as pl
import numpy as np
import sklearn.tree
import pandas as pd
import pyarrow
import matplotlib.pyplot as plt

import polars as pl
import numpy as np
import sklearn.tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold

In [11]:
#THRESHOLD_HASHES=4 # scaled=500k
THRESHOLD_HASHES=20 # scaled=100k

In [12]:
# get metadata, only compare WGS
metadata_df = (
    pl.scan_parquet("/group/ctbrowngrp5/sra-metagenomes/20241128-metadata.parquet")
    .filter(pl.col("acc") != "NP")
    .filter(pl.col("assay_type") == "WGS")
    .select(["acc", "organism", "bioproject"]) 
    .collect()
)
print(len(metadata_df))

703916


In [13]:
# remove shitty entries with shitty metatada
def remove_unnannotated(df: pl.DataFrame) -> pl.DataFrame:
    bad_values = [
        "metagenome",
        "gut metagenome",
        "feces metagenome",
        "manure metagenome",
        "bacterium",
        "unidentified",
        "null"
    ]
    return df.filter(~pl.col("organism").is_in(bad_values))

# group by organism and count number of bacteria per SRA 
def pivot_count(df: pl.DataFrame) -> pl.DataFrame:
    return (
        df
        .group_by(["organism", "count"])
        .len()  # count occurrences
        .pivot(
            values="len",
            index="organism",
            columns="count"
        )
        .fill_null(0)  # optional: fill missing with 0
    )

# set mapping for "categories of interest" I like human v pig v neither. 
category_map_simple = {
    'human associated': ['human', 'homo', 'sapiens'],
    'pig': ['pig', 'sus', 'scrofa']}


# Function to assign category based on keywords above
def get_broad_cat_simple(organism):
    for cat, keywords in category_map_simple.items():
        if any(keyword.lower() in str(organism).lower() for keyword in keywords):
            return cat
    return 'other'  

# grouping data for plotting (either cumulative or infdividually)
def group_for_plot(df):
    numeric_cols = sorted(df.select_dtypes(include='number').columns, key=lambda x: int(x))
    df_grouped = df.groupby('broad_cat')[numeric_cols].sum().reset_index()
    df_cumulative = df_grouped.copy()
    df_cumulative[numeric_cols] = df_cumulative[numeric_cols].iloc[:, ::-1].cumsum(axis=1).iloc[:, ::-1]
    return df_grouped, df_cumulative

# colormap plotting
colors = {
    'pig': '#264653',
    'human associated': '#e76f51',
    'other': '#f4a261',
    'pig/other': '#2a9d8f',
    'pig-associated': '#e9c46a'
    }

# Plotting in numbers (xx metagenomes)
def plot_cat(df, *, ax=None):
    numeric_cols = sorted(df.select_dtypes(include='number').columns, key=lambda x: int(x))
    df_abs = df.set_index('broad_cat')[numeric_cols]
    ax = df_abs.T.plot(
        kind='bar',
        stacked=True,
        figsize=(5,5),
        color=[colors[c] for c in df_abs.index],
        ax=ax,
    )
    ax.set_ylabel('Number of metagenomes')
    ax.set_xlabel('Number of bacteria (out of 16)')
    plt.legend(title='Metagenome origin', loc='upper right')
    #plt.tight_layout()


# plot relative percentages (x percent of the metagenomes (total 100%))
def plot_cat_percent(df, *, ax=None):
    numeric_cols = sorted(df.select_dtypes(include='number').columns, key=lambda x: int(x))
    df_norm = df.set_index('broad_cat')[numeric_cols].div(df[numeric_cols].sum(axis=0), axis=1) * 100
    ax = df_norm.T.plot(
        kind='bar',
        stacked=True,
        figsize=(6,6),
        color=[colors[c] for c in df_norm.index],
        ax=ax,
    )
    ax.set_ylabel('Percent of metagenomes')
    ax.set_xlabel('Number of bacteria (out of 16)')
    plt.legend(title='Metagenome origin', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    #plt.tight_layout()


In [14]:
hash_df = pl.scan_parquet('../byhash/*.parquet').collect()

In [15]:
hash_df['hashval'].n_unique()

2625

In [16]:
thresh_df = hash_df.group_by(['species', 'acc']).len(name='num').filter(pl.col('num') >= THRESHOLD_HASHES)
thresh_df

species,acc,num
str,str,u32
"""s__Roseburia inulinivorans""","""SRR23937520""",77
"""s__JAFBIX01 sp021531895""","""ERR3648472""",36
"""s__Roseburia inulinivorans""","""SRR29449285""",40
"""s__Prevotella sp000434975""","""SRR13236208""",39
"""s__Holdemanella porci""","""SRR27538999""",34
…,…,…
"""s__Roseburia inulinivorans""","""SRR11073036""",95
"""s__Roseburia inulinivorans""","""SRR17643896""",76
"""s__Roseburia inulinivorans""","""ERR9250244""",40
"""s__Gemmiger qucibialis""","""ERR4097201""",70


In [17]:
# should be ~118 x 100,000? or 11.8mb?
thresh_df.filter((pl.col("species") == "s__Bariatricus sp004560705") & (pl.col("acc") == "ERR2597482"))

species,acc,num
str,str,u32
"""s__Bariatricus sp004560705""","""ERR2597482""",118


In [18]:
hash2_df = thresh_df.join(hash_df, on=['species', 'acc'], how='inner')
hash2_df

species,acc,num,hashval
str,str,u32,i64
"""s__Bariatricus sp004560705""","""SRR8239394""",21,1819434176410
"""s__Bariatricus sp004560705""","""SRR5866270""",20,1819434176410
"""s__Bariatricus sp004560705""","""ERR11504214""",20,1819434176410
"""s__Bariatricus sp004560705""","""SRR8960347""",99,1819434176410
"""s__Bariatricus sp004560705""","""ERR7745903""",20,1819434176410
…,…,…,…
"""s__UBA2868 sp004552595""","""SRR25936234""",22,184340211879066
"""s__UBA2868 sp004552595""","""SRR25936303""",37,184340211879066
"""s__UBA2868 sp004552595""","""SRR20632250""",22,184340211879066
"""s__UBA2868 sp004552595""","""SRR11125875""",33,184340211879066


In [19]:
hash2_df.group_by(['species', 'acc']).len(name='num2').sort(by='num2')

species,acc,num2
str,str,u32
"""s__Lactobacillus amylovorus""","""ERR3224610""",20
"""s__Cryptobacteroides sp9005469…","""SRR25571832""",20
"""s__Gemmiger qucibialis""","""SRR6228382""",20
"""s__Holdemanella porci""","""ERR2749150""",20
"""s__Mogibacterium_A kristiansen…","""SRR8239395""",20
…,…,…
"""s__Roseburia inulinivorans""","""SRR30924210""",194
"""s__Roseburia inulinivorans""","""ERR1224342""",196
"""s__Roseburia inulinivorans""","""ERR1224350""",200
"""s__Roseburia inulinivorans""","""ERR1224353""",204


In [20]:
rename_df = []

thresh_metadata_df = thresh_df.join(metadata_df, how='inner', on='acc')
for organism in thresh_metadata_df['organism'].unique().to_list():
    simpleorg = 'unknown'
    if organism:
        organism = organism.lower()
        for kw in ['human', 'homo']:
            if kw in organism.lower():
                simpleorg = 'human'
                break
        for kw in ['pig', 'sus', 'scrofa']:
            if kw in organism.lower():
                simpleorg = 'pig'
                break
    rename_df.append(dict(organism=organism, simpleorg=simpleorg))

rename_df = pl.DataFrame(rename_df)
rename_df['simpleorg'].value_counts().sort(by='count', descending=True)


simpleorg,count
str,u32
"""unknown""",253
"""human""",15
"""pig""",6


In [21]:
print(len(hash2_df))
bw_df = hash2_df.join(metadata_df, on='acc', how='inner')
print(len(bw_df))
bw_df = bw_df.join(rename_df, on='organism', how='inner')
print(len(bw_df))
bw_df

27075556
24829901
23480430


species,acc,num,hashval,organism,bioproject,simpleorg
str,str,u32,i64,str,str,str
"""s__Bariatricus sp004560705""","""SRR5866270""",20,1819434176410,"""human gut metagenome""","""PRJNA395744""","""human"""
"""s__Bariatricus sp004560705""","""ERR11504214""",20,1819434176410,"""human gut metagenome""","""PRJEB51353""","""human"""
"""s__Bariatricus sp004560705""","""SRR8960347""",99,1819434176410,"""pig gut metagenome""","""PRJNA526405""","""pig"""
"""s__Bariatricus sp004560705""","""ERR7745903""",20,1819434176410,"""human gut metagenome""","""PRJEB49206""","""human"""
"""s__Bariatricus sp004560705""","""SRR11183406""",40,1819434176410,"""pig gut metagenome""","""PRJNA526405""","""pig"""
…,…,…,…,…,…,…
"""s__UBA2868 sp004552595""","""SRR25936234""",22,184340211879066,"""pig gut metagenome""","""PRJNA1010706""","""pig"""
"""s__UBA2868 sp004552595""","""SRR25936303""",37,184340211879066,"""pig gut metagenome""","""PRJNA1010706""","""pig"""
"""s__UBA2868 sp004552595""","""SRR20632250""",22,184340211879066,"""pig gut metagenome""","""PRJNA798835""","""pig"""
"""s__UBA2868 sp004552595""","""SRR11125875""",33,184340211879066,"""pig gut metagenome""","""PRJNA526405""","""pig"""


In [22]:
def make_matrices(df):
    acc_df = df['acc'].unique().to_frame().with_row_index(name='acc_index')
    hashval_df = df['hashval'].unique().to_frame().with_row_index(name='hashval_index')
    org_df = df['simpleorg'].unique().to_frame().with_row_index(name='org_index')

    df = df.join(org_df, on='simpleorg', how='left')
    df = df.join(acc_df, on='acc', how='left')
    df = df.join(hashval_df, on='hashval', how='left')

    print(len(acc_df), len(hashval_df))

    obs = np.zeros((len(acc_df), len(hashval_df)))
    target = np.zeros((len(acc_df)))

    for row in df.iter_rows(named=True):
        acc_id = row["acc_index"]
        org_id = row["org_index"]
        hashval_id = row["hashval_index"]

        obs[acc_id, hashval_id] = 1
        target[acc_id] = org_id

    print(f'observations matrix shape is: {obs.shape}')

    return df, obs, target


In [23]:
human_pig_only = bw_df.filter(pl.col("simpleorg") != "unknown")
hp_df, hp_obs, hp_target = make_matrices(human_pig_only)

138075 2620
observations matrix shape is: (138075, 2620)


In [24]:
kf = StratifiedKFold(n_splits=6)

accuracies = []
i = 1
for train_sub, test_sub in kf.split(hp_obs, hp_target):
    dt = DecisionTreeClassifier(random_state=42)
    tree = dt.fit(hp_obs[train_sub], hp_target[train_sub])
    pred = tree.predict(hp_obs[test_sub])

    accuracy = balanced_accuracy_score(hp_target[test_sub], pred)
    accuracies.append(accuracy)
    print(f"iteration {i}: accuracy {accuracy:.3f}")
    i += 1

print(f"mean accuracy across {i-1} splits: {np.mean(accuracies):.3f}")

iteration 1: accuracy 0.990
iteration 2: accuracy 0.992
iteration 3: accuracy 0.992
iteration 4: accuracy 0.991
iteration 5: accuracy 0.994
iteration 6: accuracy 0.993
mean accuracy across 6 splits: 0.992


In [25]:
dt = DecisionTreeClassifier(random_state=42)
tree = dt.fit(hp_obs, hp_target)
pred = tree.predict(hp_obs)

accuracy = balanced_accuracy_score(hp_target, pred)
print(f"full classifier: accuracy {accuracy:.3f}")

full classifier: accuracy 1.000


In [26]:
hashval_index_df = hp_df.select(['hashval_index', 'hashval']).unique()
hashval_index_df

hashval_index,hashval
u32,i64
46,3395815561515
2140,153366926424798
1024,70227126060793
270,18481026195974
1259,85547927735671
…,…
962,66774703493179
2012,143249423639993
1352,91846248478214
233,16049106833996


In [28]:
importances_df = []
for (hashval_index, importance) in enumerate(tree.feature_importances_):
    d = dict(hashval_index=hashval_index, importance=importance)
    importances_df.append(d)

importances_df = pl.DataFrame(importances_df)
importances_df = importances_df.join(hashval_index_df, on='hashval_index', how='left')
importances_df = importances_df.sort(by='importance', descending=True)
importances_df

hashval_index,importance,hashval
i64,f64,i64
657,0.841089,46686150318053
1859,0.029961,130978579836892
2168,0.026687,155155266550172
1449,0.021561,98897342060808
2156,0.01049,154531024533774
…,…,…
1455,4.7923e-7,100170150600801
379,3.0912e-7,26617407661648
2213,2.0263e-7,158086748834657
1108,3.8408e-8,74796698212505


In [29]:
importances_df.filter(pl.col("importance") > 0.0)

hashval_index,importance,hashval
i64,f64,i64
657,0.841089,46686150318053
1859,0.029961,130978579836892
2168,0.026687,155155266550172
1449,0.021561,98897342060808
2156,0.01049,154531024533774
…,…,…
1455,4.7923e-7,100170150600801
379,3.0912e-7,26617407661648
2213,2.0263e-7,158086748834657
1108,3.8408e-8,74796698212505


In [30]:
top_hashvals = set(importances_df.filter(pl.col("importance") > 0.0)['hashval'])

In [31]:
top_human_pig_only = human_pig_only.filter(pl.col("hashval").is_in(top_hashvals))

In [32]:
hp2_df, hp2_obs, hp2_target = make_matrices(top_human_pig_only)

137924 145
observations matrix shape is: (137924, 145)


In [33]:
kf = StratifiedKFold(n_splits=6)

accuracies = []
i = 1
for train_sub, test_sub in kf.split(hp2_obs, hp2_target):
    dt = DecisionTreeClassifier(random_state=42)
    tree = dt.fit(hp_obs[train_sub], hp_target[train_sub])
    pred = tree.predict(hp_obs[test_sub])

    accuracy = balanced_accuracy_score(hp_target[test_sub], pred)
    accuracies.append(accuracy)
    print(f"iteration {i}: accuracy {accuracy:.3f}")
    i += 1

print(f"mean accuracy across {i-1} splits: {np.mean(accuracies):.3f}")

iteration 1: accuracy 0.991
iteration 2: accuracy 0.991
iteration 3: accuracy 0.992
iteration 4: accuracy 0.992
iteration 5: accuracy 0.993
iteration 6: accuracy 0.992
mean accuracy across 6 splits: 0.992


In [34]:
dt = DecisionTreeClassifier(random_state=42)
tree = dt.fit(hp2_obs, hp2_target)
pred = tree.predict(hp2_obs)

accuracy = balanced_accuracy_score(hp2_target, pred)
print(f"full classifier: accuracy {accuracy:.3f}")

full classifier: accuracy 1.000


In [35]:
top_hashvals

{2276810424501,
 3007401320634,
 3035107886934,
 3251579571603,
 3359889922663,
 7447595582446,
 8162040142414,
 8316542247078,
 11990840081531,
 12747921740537,
 14869673020998,
 15798690698179,
 18481026195974,
 20505615293657,
 25359547637592,
 26617407661648,
 26964148698269,
 27722487130859,
 28078339860085,
 30116475662313,
 31475178625788,
 32524744015764,
 33140055463751,
 34114804210232,
 34435720331744,
 35060787050688,
 35633148197870,
 35766251147702,
 36485762550599,
 39883632124890,
 40943651421503,
 45436004968317,
 45643311818038,
 45738548022836,
 46686150318053,
 46835491623951,
 47528319887378,
 48448886320946,
 48882494780142,
 50748318191662,
 54718860558479,
 54808437975348,
 55265762384077,
 56575601536939,
 56818280950635,
 58369902641009,
 62070338511618,
 63951121195305,
 65580746466396,
 69329214095517,
 70626174406795,
 71628595607876,
 72373383410919,
 73913779896132,
 74796698212505,
 75163282972818,
 75825592941266,
 75863513103981,
 76237398318112,
 7901

In [36]:
import sourmash, sourmash.sourmash_args

In [37]:
mh = sourmash.MinHash(n=0, ksize=21, scaled=100_000)
mh.add_many(top_hashvals)

In [38]:
ss = sourmash.SourmashSignature(mh, name='pigger100k')

with sourmash.sourmash_args.SaveSignaturesToLocation('pigger100k.sig.zip') as save_ss:
    save_ss.add(ss)

In [39]:
metadata_df

acc,organism,bioproject
str,str,str
"""SRR28523869""","""human metagenome""","""PRJNA1095378"""
"""SRR19901137""","""Streptococcus suis""","""PRJNA854064"""
"""SRR24962475""","""Escherichia coli""","""PRJNA982891"""
"""SRR29161330""","""biofilm metagenome""","""PRJNA1116137"""
"""SRR26668739""","""bovine gut metagenome""","""PRJNA1035769"""
…,…,…
"""ERR4174288""","""human gut metagenome""","""PRJEB36461"""
"""ERR3160109""","""human gut metagenome""","""PRJEB23292"""
"""ERR2709457""","""human gut metagenome""","""PRJEB27799"""
"""ERR2709367""","""human gut metagenome""","""PRJEB27799"""


In [40]:
rename_df

organism,simpleorg
str,str
"""lichen metagenome""","""unknown"""
"""parabacteroides""","""unknown"""
"""mixed culture metagenome""","""unknown"""
"""dama dama""","""unknown"""
"""odoribacter laneus""","""unknown"""
…,…
"""soil metagenome""","""unknown"""
"""coprococcus""","""unknown"""
"""escherichia coli""","""unknown"""
"""escherichia""","""unknown"""


In [41]:
metadata2_df = metadata_df.join(rename_df, on='organism', how='inner')

In [42]:
pigger_df = pl.read_csv('../pigger100k.x.bw.csv').with_columns(
    pl.col("match_name").alias("acc")
)

pigger_df = pigger_df.join(metadata2_df, on='acc', how='inner')
pigger_df['simpleorg'].value_counts()

simpleorg,count
str,u32
"""human""",179026
"""pig""",7174
"""unknown""",143847


In [43]:
xx = []
for thresh in range(1, len(top_hashvals)):
    dd = pigger_df.filter(pl.col("intersect_hashes") >= thresh)['simpleorg'].value_counts()
    dd = dd.to_dict(as_series=False)
    counts = dict(zip(dd['simpleorg'], dd['count']))
    print(f"threshold={thresh}    human={counts.get('human', 0)}  pig={counts.get('pig', 0)}   unk={counts.get('unknown', 0)}")    

threshold=1    human=179026  pig=7174   unk=143847
threshold=2    human=179026  pig=7174   unk=143847
threshold=3    human=167627  pig=7067   unk=115469
threshold=4    human=159170  pig=6953   unk=96271
threshold=5    human=153717  pig=6842   unk=82409
threshold=6    human=149153  pig=6754   unk=72249
threshold=7    human=144938  pig=6670   unk=65312
threshold=8    human=140933  pig=6604   unk=60310
threshold=9    human=136919  pig=6530   unk=56528
threshold=10    human=132791  pig=6469   unk=53534
threshold=11    human=128633  pig=6407   unk=50950
threshold=12    human=124357  pig=6347   unk=48840
threshold=13    human=119804  pig=6304   unk=46829
threshold=14    human=115297  pig=6264   unk=44948
threshold=15    human=110496  pig=6222   unk=43195
threshold=16    human=105585  pig=6185   unk=41317
threshold=17    human=100535  pig=6137   unk=39422
threshold=18    human=95380  pig=6088   unk=37559
threshold=19    human=90011  pig=6031   unk=35609
threshold=20    human=84502  pig=5967  