## Setup Instructions

Before running this notebook, ensure you have:

1. Cloned the BaLLM repository
2. Installed dependencies: `pip install -r requirements.txt`
3. Set the `BASE_DIR` variable in the next cell to point to your local copy of the repository

The notebook expects the following directory structure relative to `BASE_DIR`:
- `results/final_kg.csv` — Knowledge graph
- `results/KG model parameters/` — Pre-trained model files
- `data/Shelburne_VAMP.csv` and `data/Shelburne_phenotype.csv` — Test dataset files (genotype and phenotype CSVs)

In [None]:
import os

# Set this to the root of your BaLLM repository.
# Notebooks do not define __file__, so the default is derived from the current
# working directory: it is the directory two levels above wherever the kernel
# is running (assumes the kernel was started next to this notebook — adjust
# BASE_DIR by hand if that is not the case). abspath() collapses the ".."
# segments so every derived path is a clean absolute path.
BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))

# Paths used throughout the notebook
KG_PATH = os.path.join(BASE_DIR, "results", "final_kg.csv")
MODEL_DIR = os.path.join(BASE_DIR, "results", "KG model parameters")

## This Notebook is an illustration of how to load and run logistic regression models and Bayesian hierarchical models on a test dataset.

### To successfully run this script, you will need:

1. Test dataset genotype file (row as bacteria, column as genes)
2. Test dataset phenotype file (each row is a test of a specific bacterium against a specific antibiotic; the result column is resistant/susceptible)
3. Imported knowledge graph called final_kg.csv
4. Imported logistic regression model parameters
5. Imported Bayesian model parameters

In [1]:
import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy.special import expit
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
import pytensor.tensor as at
from sklearn.linear_model import LogisticRegression
import seaborn as sns
import pickle
import os
import networkx as nx
from sklearn.metrics import roc_curve, auc
import warnings
# Ignore every warning raised from this point on.
# NOTE(review): this blanket filter also hides deprecation and numerical
# RuntimeWarnings from the libraries used below — narrow it when debugging.
warnings.filterwarnings("ignore")
# Ignore pandas' PerformanceWarning explicitly (already covered by the blanket
# filter above; kept so the intent survives if that filter is ever narrowed).
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)

### 1. Input test data: genotype data and phenotype data

In [None]:
###############    test models on Shelburne dataset  ##################
# Read the genotype and phenotype CSVs of the test dataset from BASE_DIR/data.
shelburne_gene = pd.read_csv(os.path.join(BASE_DIR, 'data', 'Shelburne_VAMP.csv'))
shelburne_pheno = pd.read_csv(os.path.join(BASE_DIR, 'data', 'Shelburne_phenotype.csv'))

###  format column   ###
# Series.map turns any label that is not a key of the dict into NaN without
# complaint, so keep the raw labels and report the ones that were not encoded
# (they would otherwise only surface later as an integer-cast error).
raw_phenotype = shelburne_pheno['phenotype']
shelburne_pheno['phenotype'] = raw_phenotype.map({'resistant': 1, 'susceptible': 0})
unrecognised = raw_phenotype[shelburne_pheno['phenotype'].isna() & raw_phenotype.notna()]
if not unrecognised.empty:
    print(
        f"Note: {len(unrecognised)} phenotype rows have labels other than "
        f"'resistant'/'susceptible' and are encoded as NaN: "
        f"{sorted(unrecognised.astype(str).unique())}"
    )
shelburne_gene['variants'] = shelburne_gene['variants'].str.split('.').str[0]  #only extract Kegg genotype before .


### Merge the dataframes on `sample_id` ###
# Left join: every phenotype row is kept; rows whose sample has no genotype
# entry get NaN in the genotype columns.
merged_df = pd.merge(shelburne_pheno, shelburne_gene, on='sample_id', how='left')

# Last expression of the cell -> rendered with the notebook's rich table display.
merged_df.head(10)

### 2. Split data by antibiotics

In [3]:
###  Split by antibiotics ###
# Group once on the lower-cased antibiotic name instead of re-filtering the
# whole frame for every antibiotic. groupby leaves out rows whose antibiotic is
# missing (calling .lower() on a NaN would raise), and names that differ only
# in letter case are pooled under one key rather than the later one silently
# replacing the earlier. sort=False keeps the keys in order of first appearance.
antibiotic_key = merged_df['antibiotics'].str.lower()
shelburne_dict = {
    f'shelburne_{abx}': abx_rows.copy()
    for abx, abx_rows in merged_df.groupby(antibiotic_key, sort=False)
}

# Example: 
shelburne_dict['shelburne_cefepime']

Unnamed: 0,sample_id,bacteria,antibiotics,phenotype,variants
0,Enterobacter_MB019,Enterobacter cloacae,cefepime,0,K00116
1,Enterobacter_MB019,Enterobacter cloacae,cefepime,0,K00116
2,Enterobacter_MB019,Enterobacter cloacae,cefepime,0,K00116
3,Enterobacter_MB019,Enterobacter cloacae,cefepime,0,K00287
4,Enterobacter_MB019,Enterobacter cloacae,cefepime,0,K00563
...,...,...,...,...,...
307120,Pseudomonas_MB4396,Pseudomonas aeruginosa,cefepime,0,K21135
307121,Pseudomonas_MB4396,Pseudomonas aeruginosa,cefepime,0,K21135
307122,Pseudomonas_MB4396,Pseudomonas aeruginosa,cefepime,0,K21136
307123,Pseudomonas_MB4396,Pseudomonas aeruginosa,cefepime,0,K21137


### 3. Import KG and extract antibiotics from KG that exist in Shelburne dataset

In [None]:
# Read the knowledge-graph CSV found at KG_PATH into a DataFrame.
final_kg = pd.read_csv(filepath_or_buffer=KG_PATH)

In [5]:
# Dictionary to hold wide-format bayes dataframes
shelburne_bayes_df_dict = {}


def _build_gene_presence_frame(abx_long_df, abx_genes):
    """Pivot one antibiotic's long-format rows into a wide 0/1 gene table.

    Parameters
    ----------
    abx_long_df : DataFrame with at least the columns ``sample_id``,
        ``bacteria``, ``phenotype`` and ``variants``.
    abx_genes : array-like of unique gene identifiers; each becomes a column.

    Returns
    -------
    DataFrame with one row per distinct ``sample_id`` (in order of first
    appearance, using the ``bacteria``/``phenotype`` of that first row),
    followed by one int64 column per gene holding 1 when the gene occurs
    among the sample's ``variants`` and 0 otherwise.
    """
    # Get unique sample_id, bacteria, and phenotype (first occurrence per sample)
    sample_info = (
        abx_long_df
        .drop_duplicates(subset=['sample_id'])[['sample_id', 'bacteria', 'phenotype']]
        .reset_index(drop=True)
    )

    # Distinct (sample, gene) pairs that are actually observed. Rows with a
    # missing sample_id or variant can never count as a hit, so drop them.
    hits = (
        abx_long_df.loc[abx_long_df['variants'].isin(abx_genes), ['sample_id', 'variants']]
        .dropna()
        .drop_duplicates()
    )

    # Translate the pairs into matrix coordinates and set them in one shot,
    # instead of scanning the long frame once per sample and per gene.
    row_of_sample = {sample: i for i, sample in enumerate(sample_info['sample_id'])}
    col_of_gene = {gene: j for j, gene in enumerate(abx_genes)}
    presence = np.zeros((len(sample_info), len(abx_genes)), dtype=np.int64)
    rows = hits['sample_id'].map(row_of_sample).to_numpy(dtype=np.intp)
    cols = hits['variants'].map(col_of_gene).to_numpy(dtype=np.intp)
    presence[rows, cols] = 1

    # Attach all gene columns at once (avoids column-by-column insertion).
    gene_frame = pd.DataFrame(presence, index=sample_info.index, columns=abx_genes)
    return pd.concat([sample_info, gene_frame], axis=1)


# Loop through each antibiotic in final_kg
for abx in final_kg['antibiotic'].unique():
    # A missing antibiotic shows up as NaN, which has no .lower(); skip it.
    if not isinstance(abx, str):
        continue

    abx_key = f'shelburne_{abx.lower()}'

    # Skip if antibiotic is not found in shelburne_dict
    if abx_key not in shelburne_dict:
        continue

    # Subset final_kg for this antibiotic
    abx_genes = final_kg[final_kg['antibiotic'] == abx]['gene'].unique()

    # Subset the long dataframe for this antibiotic
    abx_long_df = shelburne_dict[abx_key]

    # Build the wide frame and save it to the dictionary
    abx_bayes = _build_gene_presence_frame(abx_long_df, abx_genes)
    shelburne_bayes_df_dict[f'{abx.lower()}_bayes'] = abx_bayes

### 4. Write a function to load model parameters and evaluate

In [6]:
def _load_antibiotic_models(model_dir, abx):
    """Load the saved model artefacts for one antibiotic from ``model_dir``.

    Reads ``{abx}_logistic_model.pkl``, ``{abx}_posterior_params_2.5.npz``
    (arrays ``posterior_beta`` and ``posterior_beta0``) and
    ``{abx}_gene_strength.csv`` (column ``gene``).

    Returns
    -------
    (clf, posterior_beta, posterior_beta0, ordered_genes)

    Raises
    ------
    FileNotFoundError
        If any of the three files does not exist.
    """
    # NOTE(review): pickle.load can execute arbitrary code while unpickling —
    # only point model_dir at model files you trust.
    with open(os.path.join(model_dir, f"{abx}_logistic_model.pkl"), "rb") as f:
        clf = pickle.load(f)

    # np.load on an .npz keeps the archive open until it is closed, so read
    # both arrays inside a with-block to release the file handle.
    with np.load(os.path.join(model_dir, f"{abx}_posterior_params_2.5.npz")) as bayes:
        posterior_beta = bayes["posterior_beta"]
        posterior_beta0 = bayes["posterior_beta0"]

    gene_strength = pd.read_csv(os.path.join(model_dir, f"{abx}_gene_strength.csv"))
    ordered_genes = gene_strength['gene'].tolist()
    return clf, posterior_beta, posterior_beta0, ordered_genes


def evaluate_all_antibiotics(
    test_df_dict, model_dir
):
    """Score the saved logistic and Bayesian models on every test frame.

    Parameters
    ----------
    test_df_dict : dict
        Maps ``"<antibiotic>_bayes"`` to a DataFrame with the columns
        ``sample_id``, ``bacteria``, ``phenotype`` and one column per gene.
        Keys that do not end in ``"_bayes"`` are ignored.
    model_dir : str
        Directory holding the per-antibiotic model files
        (see ``_load_antibiotic_models``).

    Returns
    -------
    pandas.DataFrame
        One row per evaluated antibiotic with the columns ``antibiotic``,
        ``%_resistant``, ``logistic_acc``, ``bayesian_acc`` and
        ``bayes_majority_class``. The columns are present even when nothing
        could be evaluated.

    Notes
    -----
    Antibiotics whose model files are missing, or that have no sample with a
    known phenotype, are reported on stdout and skipped. Samples with a
    missing phenotype are excluded from the accuracy computation.
    """
    result_columns = [
        "antibiotic", "%_resistant", "logistic_acc",
        "bayesian_acc", "bayes_majority_class",
    ]
    non_gene_cols = ['sample_id', 'bacteria', 'phenotype']
    results = []

    for key, test_df in test_df_dict.items():
        if not key.endswith("_bayes"):
            continue

        abx = key.replace("_bayes", "")
        print(f"Evaluating antibiotic: {abx}")

        # Step 1: Extract data. A NaN phenotype cannot be cast to int and has
        # no ground truth to score against, so leave those samples out.
        labelled = test_df['phenotype'].notna()
        n_unlabelled = int((~labelled).sum())
        if n_unlabelled:
            print(f"  - Dropping {n_unlabelled} samples with missing phenotype")
        df = test_df.loc[labelled]
        if df.empty:
            print(f"No labelled samples for {abx}, skipping.")
            continue

        gene_cols = [col for col in df.columns if col not in non_gene_cols]
        X = df[gene_cols].astype(float)
        y = df['phenotype'].astype(int)

        # Load models and gene strength
        try:
            clf, posterior_beta, posterior_beta0, ordered_genes = _load_antibiotic_models(model_dir, abx)
        except FileNotFoundError:
            print(f"Model or data missing for {abx}, skipping.")
            continue

        # Put the columns in the order listed in the gene-strength file;
        # genes absent from the test data are filled with 0.
        X_filled = X.reindex(columns=ordered_genes, fill_value=0.0)

        # count how many columns are all 0:
        num_all_zero_columns = (X_filled == 0).all(axis=0).sum()
        percentage = num_all_zero_columns / X_filled.shape[1] * 100
        print(f"All-zero columns: {num_all_zero_columns} / {X_filled.shape[1]} ({percentage:.2f}%)")

        X_valid = X_filled.values

        # Logistic model prediction
        y_pred_logistic = clf.predict(X_valid)
        acc_logistic = accuracy_score(y, y_pred_logistic)

        # Bayesian model prediction.
        # assumes posterior_beta is (n_genes, n_draws) and posterior_beta0 is
        # (n_draws,), so logits is (n_samples, n_draws) — TODO confirm against
        # the code that wrote the .npz files.
        logits = posterior_beta0 + X_valid @ posterior_beta
        # expit is the logistic sigmoid, evaluated without overflowing for
        # large-magnitude logits.
        probs = expit(logits)

        # Average over axis 1 to get one probability per sample, then
        # threshold at 0.5.
        P_mean = probs.mean(axis=1)
        y_pred_bayes = (P_mean > 0.5).astype(int)

        acc_bayes = accuracy_score(y, y_pred_bayes)

        # Class balance info
        percent_1 = y.mean() * 100
        unique, counts = np.unique(y_pred_bayes, return_counts=True)
        bayes_prediction_counts = dict(zip(unique, counts))
        # True when the Bayesian model emitted a single class and that class
        # is the majority class of the true labels.
        always_majority_class = len(bayes_prediction_counts) == 1 and (
            list(bayes_prediction_counts.keys())[0] == int(y.mean() > 0.5)
        )

        print(f"  - Number of samples: {len(y)}")
        print(f"  - % of resistant (y=1): {percent_1:.2f}%")
        print(f"  - Logistic Accuracy:   {acc_logistic:.4f}")
        print(f"  - Bayesian Accuracy:   {acc_bayes:.4f}")
        if always_majority_class:
            print(f" - Bayesian model predicted only the majority class ({list(bayes_prediction_counts.keys())[0]})!")

        results.append({
            "antibiotic": abx,
            "%_resistant": percent_1,
            "logistic_acc": acc_logistic,
            "bayesian_acc": acc_bayes,
            "bayes_majority_class": always_majority_class
        })

    # Passing the columns explicitly keeps the schema stable when no
    # antibiotic could be evaluated.
    return pd.DataFrame(results, columns=result_columns)

In [None]:
results_df = evaluate_all_antibiotics(
    test_df_dict=shelburne_bayes_df_dict,
    model_dir=MODEL_DIR
)