In [1]:
# --- synthetic Pfam-counts & media-label dataset ----------------------------
from pathlib import Path
import numpy as np
import pandas as pd

# 1. Parameters --------------------------------------------------------------
n_genomes   = 30           # rows
n_pfams     = 50           # genomic features
media_cols  = ["needs_glucose",
               "needs_biotin",
               "tolerates_salt",
               "requires_cysteine",
               "requires_pH7"]

out_dir     = Path("./demo_data")
pfam_file   = out_dir / "example_pfam_counts.tsv"
media_file  = out_dir / "example_media_labels.tsv"

# 2. Create directories ------------------------------------------------------
out_dir.mkdir(exist_ok=True)

# 3. Make synthetic Pfam count matrix ---------------------------------------
rng      = np.random.default_rng(seed=42)
taxids   = [f"T{i:03d}" for i in range(1, n_genomes + 1)]
pfams    = [f"PF{1000 + i}" for i in range(1, n_pfams + 1)]

pfam_counts = rng.poisson(lam=3, size=(n_genomes, n_pfams))
pfam_df     = pd.DataFrame(pfam_counts, index=taxids, columns=pfams)

# 4. Make synthetic multi-label media matrix ---------------------------------
# each label gets its own prevalence
prevalence = dict(zip(media_cols, [0.6, 0.3, 0.5, 0.4, 0.7]))
media_mat  = np.column_stack([
    rng.binomial(1, p=prevalence[label], size=n_genomes)
    for label in media_cols
])
media_df = pd.DataFrame(media_mat, index=taxids, columns=media_cols)

# 5. Save to TSV -------------------------------------------------------------
pfam_df.to_csv(pfam_file,  sep="\t")
media_df.to_csv(media_file, sep="\t")

print("✓ Synthetic data written to:", pfam_file, "&", media_file)
display(pfam_df.head())
display(media_df.head())


✓ Synthetic data written to: demo_data/example_pfam_counts.tsv & demo_data/example_media_labels.tsv


Unnamed: 0,PF1001,PF1002,PF1003,PF1004,PF1005,PF1006,PF1007,PF1008,PF1009,PF1010,...,PF1041,PF1042,PF1043,PF1044,PF1045,PF1046,PF1047,PF1048,PF1049,PF1050
T001,4,4,5,1,7,1,4,2,2,5,...,4,3,2,7,3,3,2,3,4,0
T002,1,2,4,2,4,3,2,3,4,3,...,2,2,5,2,3,7,4,2,4,3
T003,4,3,1,1,4,0,1,3,5,1,...,2,2,5,1,5,5,2,3,1,4
T004,5,6,6,3,0,7,4,1,3,3,...,1,2,0,2,1,2,3,5,6,3
T005,1,0,1,4,3,4,0,5,2,4,...,1,6,2,4,2,2,4,4,2,5


Unnamed: 0,needs_glucose,needs_biotin,tolerates_salt,requires_cysteine,requires_pH7
T001,1,0,0,0,0
T002,0,0,0,0,1
T003,1,0,0,0,1
T004,1,0,1,0,1
T005,0,1,1,0,1


In [3]:
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import classification_report, accuracy_score, f1_score

# --- predictors -------------------------------------------------------------
# pfam_df: rows = genomes (taxid), cols = Pfam counts   (built earlier)
# media_df: rows = same taxid index, cols = binary media traits
X = pfam_df.join(media_df, how="inner", lsuffix="_pfam")
y = X[media_df.columns]
X = X[pfam_df.columns]

X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=None)  # stratify if not too sparse

# --- model ------------------------------------------------------------------
base_rf = RandomForestClassifier(
            n_estimators=500,
            max_depth=None,
            n_jobs=-1,
            class_weight="balanced_subsample",  # handles label imbalance
            random_state=42)

clf = MultiOutputClassifier(base_rf)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

# --- evaluation -------------------------------------------------------------
for i, label in enumerate(y.columns):
    print(f"\nLabel: {label}")
    print(classification_report(y_test.iloc[:, i], y_pred[:, i], digits=3))
    
# Optional overall metrics
exact_match = accuracy_score(y_test, y_pred)      # all labels correct
macro_f1 = f1_score(y_test, y_pred, average="macro")
print(f"\nExact-match accuracy: {exact_match:.3f}")
print(f"Macro-averaged F1    : {macro_f1:.3f}")



Label: needs_glucose
              precision    recall  f1-score   support

           0      0.000     0.000     0.000         1
           1      0.833     1.000     0.909         5

    accuracy                          0.833         6
   macro avg      0.417     0.500     0.455         6
weighted avg      0.694     0.833     0.758         6


Label: needs_biotin
              precision    recall  f1-score   support

           0      0.167     1.000     0.286         1
           1      0.000     0.000     0.000         5

    accuracy                          0.167         6
   macro avg      0.083     0.500     0.143         6
weighted avg      0.028     0.167     0.048         6


Label: tolerates_salt
              precision    recall  f1-score   support

           0      0.500     0.200     0.286         5
           1      0.000     0.000     0.000         1

    accuracy                          0.167         6
   macro avg      0.250     0.100     0.143         6
weighted

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
importances = {
    label: est.feature_importances_
    for label, est in zip(y.columns, clf.estimators_)
}

# Show top 10 Pfams driving “requires_biotin”
pd.Series(importances[ingredientstr], index=X.columns) \
  .sort_values(ascending=False).head(10)

PF1046    0.092173
PF1040    0.086311
PF1035    0.083453
PF1008    0.045028
PF1043    0.039180
PF1050    0.033578
PF1002    0.031329
PF1025    0.027291
PF1013    0.027262
PF1016    0.023708
dtype: float64