In [None]:
%load_ext blackcellmagic
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import sys
import os
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Build RF model to predict viral vs non-viral

In [None]:
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..', 'src')))

from utils import data_to_use, explain_col

from model_rf import (
    plot_roc,
    plot_confusion_matrices,
    get_standard_features,
    build_rf_model,
    get_feature_importances,
    plot_feature_importances,
    create_feature_sets,
    model_new_feature_sets,
    plot_feature_distributions
)

In [None]:
# Folder passed to build_rf_model / model_new_feature_sets as their
# train_test_data_folder argument (presumably holds the pre-computed
# train/test splits). "XXX" is a placeholder — set it before running.
train_test_data_folder = "XXX"

Get the full set of numerical and categorical features:

In [None]:
# Numerical and categorical feature column lists; passed to build_rf_model and
# model_new_feature_sets (run_pipeline reads them as notebook globals).
num_cols_present, cat_cols_present = get_standard_features()

Define pipeline:

In [None]:
def run_pipeline(pathogen, n_features_to_plot=20):
    """Train, evaluate and save the random-forest models for one prediction target.

    Steps:
      1. Build an RF model on all standard features (``build_rf_model``) and
         plot its ROC curve.
      2. Compute feature importances of the "true" model, print the head/tail,
         plot them, and plot the distributions of the first
         ``n_features_to_plot`` unique ``Feature_clean`` names.
      3. Retrain on the subsets from ``create_feature_sets`` and plot ROC curves
         and confusion matrices for each subset's "true" model, plus the "all"
         subset's "scrambled" model as a control.
      4. Pickle models, preprocessors and results to
         ``../data/model_rf/{pathogen}_rf.pkl`` and write the importances to
         ``../data/model_rf/{pathogen}_feature_importances_rf.csv``.

    Reads the notebook globals ``train_test_data_folder``, ``num_cols_present``
    and ``cat_cols_present``.

    Parameters
    ----------
    pathogen : str
        Target name forwarded to the model_rf helpers and used in output file
        names (e.g. ``"all-viral"``).
    n_features_to_plot : int, default 20
        How many unique feature names get a distribution plot.
    """
    output_dir = "../data/model_rf"

    # --- Model on all standard features -------------------------------------
    (
        local_preprocessor,
        models_rf,
        y_pred_prob_dict_rf,
        roc_curve_values_rf,
        y_test_dict_rf,
    ) = build_rf_model(
        pathogen=pathogen,
        train_test_data_folder=train_test_data_folder,
        num_cols_present=num_cols_present,
        cat_cols_present=cat_cols_present,
    )
    plot_roc(roc_curve_values_rf, models_rf, pathogen, colors=["tab:red", "black"])

    # --- Feature importances of the "true" all-features model ---------------
    importance_df = get_feature_importances(
        models_rf["true"], local_preprocessor, pathogen
    )
    print(importance_df.head())
    print(importance_df.tail())
    plot_feature_importances(importance_df, pathogen)

    # Several rows can share a Feature_clean name; keep each name once, in
    # first-seen order. NOTE(review): treating these as the "top" features
    # assumes importance_df is sorted by decreasing importance — confirm.
    features_to_plot = (
        importance_df["Feature_clean"]
        .drop_duplicates()
        .head(n_features_to_plot)
        .tolist()
    )
    plot_feature_distributions(features_to_plot, importance_df, pathogen)

    # --- Retrain on feature subsets -----------------------------------------
    feature_sets = create_feature_sets(importance_df, pathogen)
    (
        models_subset,
        preprocessors_subset,
        roc_curve_values_subset,
        y_pred_prob_dict_subset,
        y_test_dict_subset,
    ) = model_new_feature_sets(
        feature_sets, pathogen, train_test_data_folder, num_cols_present, cat_cols_present
    )

    def _true_with_control(per_subset):
        """Take each subset's "true" entry, then add the "all" scrambled control."""
        collected = {key: per_subset[key]["true"] for key in models_subset}
        collected["scrambled_control"] = per_subset["all"]["scrambled"]
        return collected

    temp_model_dict = _true_with_control(models_subset)
    temp_roc_dict = _true_with_control(roc_curve_values_subset)
    temp_y_pred_prob_dict = _true_with_control(y_pred_prob_dict_subset)
    temp_y_test_dict = _true_with_control(y_test_dict_subset)
    # Preprocessors are stored per subset only (no true/scrambled split); the
    # scrambled control reuses the "all" subset's preprocessor.
    temp_pp_dict = {key: preprocessors_subset[key] for key in models_subset}
    temp_pp_dict["scrambled_control"] = preprocessors_subset["all"]

    # One colour per curve (each subset, then the scrambled control) —
    # presumably consumed in dict order by plot_roc; update if subsets change.
    colors = ["red", "#3182bd", "#9ecae1", "#deebf7", "darkorange", "grey"]
    plot_roc(temp_roc_dict, temp_model_dict, pathogen, colors=colors)
    plot_confusion_matrices(
        temp_y_test_dict, temp_y_pred_prob_dict, temp_roc_dict, pathogen
    )

    # --- Save models and results --------------------------------------------
    # Create the output directory so writing does not fail on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)
    master_dict = {
        "models": temp_model_dict,
        "preprocessors": temp_pp_dict,
        "roc": temp_roc_dict,
        "y_pred_prob": temp_y_pred_prob_dict,
        "y_test": temp_y_test_dict,
    }
    with open(os.path.join(output_dir, f"{pathogen}_rf.pkl"), "wb") as f:
        pickle.dump(master_dict, f)

    importance_df.to_csv(
        os.path.join(output_dir, f"{pathogen}_feature_importances_rf.csv"),
        index=False,
    )

Build model to predict viral vs non-viral:

In [None]:
%%time
# Train, evaluate and save the viral vs non-viral RF models via run_pipeline.
pathogen = "all-viral"
run_pipeline(pathogen)

# Show predictions for unseen Malaria cases

In [None]:
# Full dataset; data_date identifies the matching rule-based selection file
# that is read further down in this notebook. "XXX" values are placeholders.
data_path = "XXX"
data_date = "XXX"
data_df = pd.read_csv(filepath_or_buffer=data_path, low_memory=False)


In [None]:
pathogen = "all-viral"

# Load the pickled model results (presumably the master_dict written by
# run_pipeline — same keys) and keep the model/preprocessor for the "all"
# feature set.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("XXX", "rb") as f:
    master_dict = pickle.load(f)

model, local_preprocessor = (
    master_dict["models"]["all"],
    master_dict["preprocessors"]["all"],
)

# Decision threshold used in the plots, chosen via Youden's J statistic.
threshold = 0.36

In [None]:
# Malaria-confirmed samples without co-infection: Malaria_label == 1 and no
# other *_label column equal to 1.
malaria_label_col = "Malaria_label"
other_label_cols = [
    col
    for col in data_df.columns
    if col.endswith("_label") and col != malaria_label_col
]
has_other_infection = data_df[other_label_cols].eq(1).any(axis=1)
malaria_confirmed_mask = data_df[malaria_label_col].eq(1) & ~has_other_infection

# Keep only samples with label 2 for the target (not used in training/testing).
malaria_data = data_df[
    malaria_confirmed_mask & data_df[f"{pathogen}_label"].eq(2)
].copy()
print(len(malaria_data))

X = local_preprocessor.transform(data_to_use(malaria_data))
y_pred_prob = model.predict_proba(X)[:, 1]
malaria_data["viral_prob"] = y_pred_prob

In [None]:
# Viral-model output for unseen, malaria-only samples; the dashed line marks
# the Youden's J decision threshold.
os.makedirs("figures", exist_ok=True)  # savefig raises if the folder is missing

fig, ax = plt.subplots()
ax.hist(malaria_data["viral_prob"], bins=30, alpha=0.7)
ax.axvline(threshold, lw=1, ls="--", c="red", label=f"Threshold ({threshold:.2f})")
ax.set_xlabel("Viral Model Output")
ax.set_ylabel("Count")
ax.set_title("Unseen Malaria positive (no co-infections) samples")
ax.legend()

fig.savefig("figures/rf_model_preds_unseen_malaria", dpi=300, bbox_inches="tight")

plt.show()

# Predict unseen remaining 'other' samples

In [None]:
# Samples with label 2 for the target that are negative for both malaria and
# syphilis.
is_unseen = data_df[f"{pathogen}_label"].eq(2)
no_malaria_or_syphilis = data_df["Malaria_label"].eq(0) & data_df["Syphilis_label"].eq(0)
new_data = data_df[is_unseen & no_malaria_or_syphilis].copy()

X = local_preprocessor.transform(data_to_use(new_data))
y_pred_prob = model.predict_proba(X)[:, 1]
new_data["viral_prob"] = y_pred_prob

In [None]:
# Viral-model output for the remaining unseen 'other' samples; the dashed line
# marks the Youden's J decision threshold.
os.makedirs("figures", exist_ok=True)  # savefig raises if the folder is missing

fig, ax = plt.subplots()
ax.hist(new_data["viral_prob"], bins=30, alpha=0.7)
ax.axvline(threshold, lw=1, ls="--", c="red", label=f"Threshold ({threshold:.2f})")
ax.set_xlabel("Viral Model Output")
ax.set_ylabel("Count")
ax.set_title("Unseen 'other' samples")
ax.legend()

fig.savefig("figures/rf_model_preds_unseen", dpi=300, bbox_inches="tight")

plt.show()

In [None]:
# Rule-based selection of high-priority zoonotic cases for the same data date.
category = "high_priority_zoonotic"
zoonotic_path = f"../data_small/rule_based_selection_{category}_{data_date}.csv"
zoonotic_df = pd.read_csv(zoonotic_path)
zoonotic_df

In [None]:
# Attach the rule-based matches (features, num_features) to each sample.
# Columns from an earlier run of this cell are dropped first so the cell is
# idempotent; re-merging would otherwise create *_x / *_y columns and the
# `num_features` lookup below would raise KeyError.
# validate="many_to_one" fails loudly if zoonotic_df repeats a record_id,
# which would otherwise silently duplicate samples.
new_data = (
    new_data
    .drop(columns=["features", "num_features"], errors="ignore")
    .merge(
        zoonotic_df[["record_id", "features", "num_features"]],
        on="record_id",
        how="left",
        validate="many_to_one",
    )
)
# Samples with no rule-based match count as having zero such features.
new_data["num_features"] = new_data["num_features"].fillna(0)
new_data

In [None]:
# Count number of points at each (num_features, viral_prob) pair
counts = new_data.groupby(["num_features", "viral_prob"]).size().reset_index(name='count')
# Merge counts back to new_data for coloring
new_data_merged = new_data.merge(counts, on=["num_features", "viral_prob"], how="left")

fig, ax = plt.subplots()
sc = ax.scatter(
    new_data_merged["num_features"],
    new_data_merged["viral_prob"],
    c=new_data_merged["count"],
    cmap="viridis",
    alpha=0.6
)
ax.set_xlabel("Number of High Priority Clinical Features")
ax.set_ylabel("Viral Model Output (RF)")
ax.axhline(threshold, lw=1, ls="--", c="red")
cbar = plt.colorbar(sc, ax=ax)
cbar.set_label("Count")

fig.savefig("figures/rf_model_preds_over_clinical_features", dpi=300, bbox_inches="tight")

plt.show()

# Generate RF predictions for comparison with LLMs

In [None]:
pathogen = "all-viral"
# Rows with label 2 were used for neither training nor testing.
unseen_mask = data_df[f"{pathogen}_label"] == 2
pred_df = data_df.loc[unseen_mask].copy()

X = local_preprocessor.transform(data_to_use(pred_df))
y_pred_prob = model.predict_proba(X)[:, 1]

# Store the probability both as a fraction and as a percentage, plus the
# yes/no call at the Youden's J threshold.
pred_df["probability_of_viral_dec"] = y_pred_prob
pred_df["probability_of_viral"] = y_pred_prob * 100
is_viral = pred_df["probability_of_viral_dec"] > threshold
pred_df["viral"] = np.where(is_viral, "yes", "no")
pred_df

In [None]:
# Inner join on record_id: only rows that received an RF prediction are kept,
# with the percentage probability and yes/no call appended.
data_df_rf = pd.merge(
    data_df,
    pred_df[["record_id", "probability_of_viral", "viral"]],
    how="inner",
    on="record_id",
)
data_df_rf

In [None]:
# Save data frame with RF predictions
# Save data frame with RF predictions next to the input file, with an "_RF"
# suffix. Only the trailing ".csv" is rewritten: str.replace would also change
# a ".csv" elsewhere in the path, and if data_path had no ".csv" at all it
# would return data_path unchanged and overwrite the source data.
if not data_path.endswith(".csv"):
    raise ValueError(f"Expected data_path to end with '.csv', got {data_path!r}")
rf_data_path = data_path[: -len(".csv")] + "_RF.csv"
print(rf_data_path)
data_df_rf.to_csv(rf_data_path, index=False)

# Re-generate predictions for testing and training data for LLM training
**Open question:** predictions on the training split are optimistic, because these are the very samples the RF model was trained on — but it is unclear what else to include in the LLM knowledge prompt instead.

In [None]:
# Held-out train/test feature frames for the all-viral target.
# NOTE(review): read_pickle can execute arbitrary code — only load trusted files.
split_dir = "../data/test_train_splits"
training_df = pd.read_pickle(split_dir + "/X_train_all-viral.pkl")
testing_df = pd.read_pickle(split_dir + "/X_test_all-viral.pkl")

In [None]:
# NOTE: The per-seed models were created manually by changing SEED in
# model_rf.py and re-running the pipeline (SEEDS: 42, 1, 120).
pathogen = "all-viral"

# Model-file suffix for each seed -> its Youden's J decision threshold.
seed_thresholds = {"": 0.36, "_02": 0.29, "_03": 0.32}

# Model inputs are identical for every seed (everything except record_id).
# They are built once from the untouched frames, so prediction columns added
# for one seed can never leak into the features scored for the next seed
# (the previous version mutated training_df/testing_df inside the loop).
X_train_input = training_df.drop(columns=["record_id"])
X_test_input = testing_df.drop(columns=["record_id"])

# Loop-local names (seed_model, seed_preprocessor, ...) avoid clobbering the
# notebook-level `model`, `local_preprocessor` and `threshold` used above.
for seed, seed_threshold in seed_thresholds.items():
    # NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
    with open(f"../data/model_rf/{pathogen}_rf{seed}.pkl", "rb") as f:
        seed_master_dict = pickle.load(f)
    seed_model = seed_master_dict["models"]["all"]
    seed_preprocessor = seed_master_dict["preprocessors"]["all"]

    # Training predictions feed the LLM knowledge summary; testing predictions
    # are saved alongside for evaluation.
    for split_name, split_df, split_X in (
        ("train", training_df, X_train_input),
        ("test", testing_df, X_test_input),
    ):
        prob = seed_model.predict_proba(seed_preprocessor.transform(split_X))[:, 1]
        split_with_preds = split_df.assign(
            probability_of_viral_rf=prob * 100,
            # Compare the raw probability (no *100/100 round trip).
            viral_rf=np.where(prob > seed_threshold, "yes", "no"),
        )
        out_path = f"../data/test_train_splits/X_{split_name}_{pathogen}_rf{seed}.pkl"
        with open(out_path, "wb") as f:
            pickle.dump(split_with_preds, f)