# 🤖 Predictive Model Performance
## How do you decide which predictive model to use?
In this notebook, we evaluate several machine learning models that predict a medicinal plant's primary medicinal property (the first use keyword recorded for it, such as diuretic or astringent) from its botanical family and its edibility and other-uses ratings. Below is a brief overview of each model and how it approaches this multiclass classification task:

### 1. **Logistic Regression**
- **Type**: Linear model
- **Concept**: Estimates the probability of each class from a weighted (linear) combination of the input features; with more than two classes, scikit-learn's default solver fits a multinomial (softmax) model.
- **Strengths**: Simple, interpretable, fast.
- **Limitations**: Assumes a linear relationship between features and the log-odds of the outcome; struggles with complex patterns.
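
As a minimal sketch of the idea (pure Python, with made-up weights rather than anything fitted to this dataset), a multinomial logistic model turns one linear score per class into probabilities with a softmax. With two classes, and one class's score pinned to zero, this reduces to the familiar sigmoid of a linear log-odds score. `softmax_probs` is a hypothetical helper, not part of the notebook.

```python
import math

def softmax_probs(x, weights, biases):
    """Class probabilities of a multinomial logistic model.

    x       -- feature vector
    weights -- one weight vector per class (illustrative values, not fitted)
    biases  -- one intercept per class
    """
    # Linear score per class: w_k . x + b_k
    scores = [sum(w_i * x_i for w_i, x_i in zip(w, x)) + b
              for w, b in zip(weights, biases)]
    top = max(scores)  # subtract the max before exp() for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Two classes, second class pinned to a zero score: the result is a sigmoid.
p = softmax_probs([1.0, 2.0], weights=[[0.5, -0.25], [0.0, 0.0]], biases=[0.1, 0.0])
print(p)
```

Here the log of `p[0] / p[1]` equals the linear score 0.1, which is exactly the "linear in the log-odds" assumption listed above.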

### 2. **Decision Tree**
- **Type**: Non-linear, rule-based
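- **Concept**: Recursively splits the data on feature thresholds, choosing each split to reduce class impurity (Gini impurity by default in scikit-learn); each leaf predicts the most common class (after any class weighting) among the training samples that reach it.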
- **Strengths**: Easy to visualize and understand; captures non-linear relationships.
- **Limitations**: Can overfit the training data if not pruned or regularized.
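
To make "splits on feature thresholds" concrete, here is a small pure-Python sketch that scores every candidate threshold on one feature by weighted Gini impurity and keeps the best. It is similar in spirit to what scikit-learn's `DecisionTreeClassifier` does at each node (thresholds midway between sorted values), though far less efficient; the feature values and group names are hypothetical.

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions (0 means a pure node)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

def best_threshold(values, labels):
    """Best split `value <= threshold` on one feature, by weighted Gini impurity.
    Returns (threshold, weighted_impurity); threshold is None if no split helps."""
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best = (None, gini(labels))
    for i in range(1, n):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # a threshold can only fall between distinct values
        left = [lab for _, lab in pairs[:i]]
        right = [lab for _, lab in pairs[i:]]
        impurity = (len(left) * gini(left) + len(right) * gini(right)) / n
        if impurity < best[1]:
            best = ((pairs[i - 1][0] + pairs[i][0]) / 2, impurity)
    return best

# Hypothetical feature values for plants in two property groups
values = [0, 1, 1, 4, 5, 5]
groups = ['tonic', 'tonic', 'tonic', 'diuretic', 'diuretic', 'diuretic']
print(best_threshold(values, groups))
```

The split at 2.5 separates the two groups perfectly, so its weighted impurity is 0. Growing a tree means repeating this search inside each child until a stopping rule (depth, minimum samples) is hit, which is where pruning and regularization come in.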

### 3. **Random Forest**
- **Type**: Ensemble (of decision trees)
- **Concept**: Trains many decision trees, each on a bootstrap sample of the data with a random subset of features considered at each split, and averages their predicted class probabilities to reduce variance.
- **Strengths**: More accurate and robust than a single tree; reduces overfitting.
- **Limitations**: Less interpretable; slower than simpler models.
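
The two ingredients of a random forest, bootstrap sampling and averaging, can be sketched in a few lines of pure Python. The per-tree probabilities below are hypothetical numbers for a single plant, not output from the notebook's models; `bootstrap_indices` and `average_tree_probabilities` are illustrative helpers.

```python
import random

def bootstrap_indices(n_rows, rng):
    """One bootstrap sample: n_rows row indices drawn with replacement."""
    return [rng.randrange(n_rows) for _ in range(n_rows)]

def average_tree_probabilities(per_tree_probs):
    """Average per-tree class-probability dicts; return (predicted class, averages)."""
    classes = sorted({c for probs in per_tree_probs for c in probs})
    n_trees = len(per_tree_probs)
    avg = {c: sum(probs.get(c, 0.0) for probs in per_tree_probs) / n_trees
           for c in classes}
    return max(avg, key=avg.get), avg

rng = random.Random(42)
sample = bootstrap_indices(10, rng)  # some rows repeat, others are left out

# Hypothetical outputs of three trees for one plant (not real model output)
trees = [{'diuretic': 0.6, 'tonic': 0.4},
         {'diuretic': 0.3, 'tonic': 0.7},
         {'diuretic': 0.8, 'tonic': 0.2}]
label, avg = average_tree_probabilities(trees)
print(sample, label, avg)
```

Because each tree sees a different bootstrap sample (on average about 63% of the distinct rows), their errors are partly independent, and averaging them lowers variance compared with a single tree.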

### 4. **Gradient Boosting (GBT)**
- **Type**: Ensemble (boosted decision trees)
- **Concept**: Trains shallow trees sequentially; each new tree is fit to the negative gradient of the loss (the pseudo-residuals) of the current ensemble, so it corrects the errors the earlier trees still make.
- **Strengths**: High accuracy; handles complex patterns well.
- **Limitations**: Can overfit if not tuned properly; computationally intensive.
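
A deliberately simplified sketch of the boosting loop, assuming squared-error regression on a single feature with one-split "stumps". For squared loss, the negative gradient is just the residual `y - F(x)`, so each round fits a stump to the current residuals and adds a damped copy of it. This is not what `GradientBoostingClassifier` does for multiclass problems (it fits one regression tree per class per stage on the gradient of the log-loss); `fit_stump` and `gradient_boost` are hypothetical helpers.

```python
def fit_stump(xs, residuals):
    """Regression stump: the single threshold on x that minimises squared error."""
    best = None
    for t in sorted(set(xs))[:-1]:  # candidate thresholds (needs >= 2 distinct x values)
        left = [r for x, r in zip(xs, residuals) if x <= t]
        right = [r for x, r in zip(xs, residuals) if x > t]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = (sum((r - left_mean) ** 2 for r in left)
               + sum((r - right_mean) ** 2 for r in right))
        if best is None or sse < best[0]:
            best = (sse, t, left_mean, right_mean)
    _, t, left_mean, right_mean = best
    return lambda x: left_mean if x <= t else right_mean


def gradient_boost(xs, ys, n_rounds=50, learning_rate=0.1):
    """Boosting for squared loss: each stump is fit to the current residuals,
    i.e. the negative gradient of 0.5 * (y - F(x))**2 with respect to F(x)."""
    base = sum(ys) / len(ys)  # initial constant prediction
    preds = [base] * len(xs)
    stumps = []
    for _ in range(n_rounds):
        residuals = [y - p for y, p in zip(ys, preds)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        preds = [p + learning_rate * stump(x) for p, x in zip(preds, xs)]
    return lambda x: base + learning_rate * sum(s(x) for s in stumps)


model = gradient_boost([1, 2, 3, 4, 5, 6], [1, 1, 1, 5, 5, 5])
print(model(1), model(6))
```

The `learning_rate` shrinks each correction; smaller values need more rounds but usually overfit less, which is the tuning trade-off mentioned above.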

### 5. **Support Vector Machine (SVM)**
- **Type**: Maximum-margin classifier
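- **Concept**: Finds the boundary (a hyperplane, possibly in a kernel-induced feature space) that separates two classes with the largest margin; scikit-learn's `SVC` extends this to the many medicinal-property classes with a one-vs-one scheme.
- **Strengths**: Works well in high-dimensional spaces; maximizing the margin acts as regularization, which limits overfitting.
- **Limitations**: Training scales poorly to large datasets; performance depends on the kernel and its hyperparameters.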
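
"Maximizing the margin" can be illustrated by scoring candidate boundaries by hand. The pure-Python sketch below computes the geometric margin, the smallest value of `y * f(x) / ||w||` over the data for a linear decision function `f(x) = w . x + b`, for two lines that both separate the same two hypothetical points. An SVM solver searches over all such boundaries for the widest one; this sketch only compares two.

```python
import math

def decision_value(w, b, x):
    """Linear decision function f(x) = w . x + b."""
    return sum(w_i * x_i for w_i, x_i in zip(w, x)) + b

def geometric_margin(w, b, points, labels):
    """Smallest distance y * f(x) / ||w|| from the boundary; labels are +1 / -1.
    A negative value means some point is on the wrong side."""
    norm = math.sqrt(sum(w_i * w_i for w_i in w))
    return min(y * decision_value(w, b, x) / norm for x, y in zip(points, labels))

points = [(0.0, 0.0), (1.0, 1.0)]
labels = [-1, 1]
wide = geometric_margin((1.0, 1.0), -1.0, points, labels)    # boundary x1 + x2 = 1
narrow = geometric_margin((1.0, 0.0), -0.5, points, labels)  # boundary x1 = 0.5
print(wide, narrow)
```

Both boundaries classify the points correctly, but the diagonal one keeps them farther away (about 0.71 versus 0.5), so it is the one a maximum-margin classifier would prefer.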

### 6. **XGBoost**
- **Type**: Gradient-boosted tree ensemble (optimized)
- **Concept**: An efficient, regularized implementation of gradient boosting that adds trees iteratively to correct the ensemble's remaining errors, with L1/L2 penalties on the leaf weights.
- **Strengths**: Often among the strongest performers on structured (tabular) data; fast and scalable.
- **Limitations**: Complex; requires tuning; less interpretable.
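
XGBoost's regularization shows up directly in how it scores trees. Its documentation ("Introduction to Boosted Trees") derives, for a leaf whose rows have gradient sum G and Hessian sum H, the optimal leaf weight -G / (H + lambda), and for a candidate split the gain 0.5 * [G_L^2/(H_L+lambda) + G_R^2/(H_R+lambda) - (G_L+G_R)^2/(H_L+H_R+lambda)] - gamma, where lambda and gamma correspond to the `reg_lambda` and `gamma` parameters of `XGBClassifier`. A pure-Python sketch of those two formulas, using squared loss (gradient `pred - y`, Hessian 1) on made-up labels; the helper names are hypothetical.

```python
def leaf_weight(grads, hessians, reg_lambda=1.0):
    """Optimal leaf value w* = -G / (H + lambda)."""
    return -sum(grads) / (sum(hessians) + reg_lambda)

def split_gain(grad_left, hess_left, grad_right, hess_right, reg_lambda=1.0, gamma=0.0):
    """Loss reduction from splitting a node into left/right children, minus gamma."""
    def score(g, h):
        return g * g / (h + reg_lambda)
    parent = score(grad_left + grad_right, hess_left + hess_right)
    return 0.5 * (score(grad_left, hess_left) + score(grad_right, hess_right) - parent) - gamma

# Squared loss 0.5 * (y - pred)**2: gradient = pred - y, Hessian = 1
y = [1.0, 1.0, -1.0, -1.0]
pred = [0.0] * 4
grads = [p - t for p, t in zip(pred, y)]  # [-1, -1, 1, 1]
hess = [1.0] * 4

w_left = leaf_weight(grads[:2], hess[:2])  # shrunk from 1.0 towards 0 by lambda
gain = split_gain(sum(grads[:2]), sum(hess[:2]), sum(grads[2:]), sum(hess[2:]))
print(w_left, gain)
```

With `reg_lambda=0` the left leaf would predict the plain mean residual (1.0); with the default of 1 it predicts 2/3. A positive `gamma` would subtract a fixed cost from every split's gain, so weak splits get pruned.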

---

Each model captures different aspects of the underlying patterns in the data. By comparing their performance under two class-balancing strategies (SMOTE over-sampling and random down-sampling), we aim to identify not only which models are accurate, but also which ones hold up when tested on real, non-synthetic samples.
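
SMOTE (Synthetic Minority Over-sampling Technique) creates new minority-class rows by interpolating between an existing row and one of its k nearest same-class neighbours. The pure-Python sketch below shows that single step; it is a simplified stand-in for imbalanced-learn's `SMOTE`, and the points and the `smote_sample` helper are hypothetical. One consequence worth keeping in mind here: because `Family` is label-encoded as an integer, interpolation can produce fractional family codes that match no real family; imbalanced-learn's `SMOTENC` is designed for such categorical columns.

```python
import math
import random

def smote_sample(minority, k, rng):
    """One synthetic point: pick a minority row, pick one of its k nearest
    minority neighbours, and place a new point at a random spot between them."""
    i = rng.randrange(len(minority))
    base = minority[i]
    others = minority[:i] + minority[i + 1:]
    neighbours = sorted(others, key=lambda p: math.dist(base, p))[:k]
    neighbour = rng.choice(neighbours)
    gap = rng.random()  # uniform in [0, 1)
    return tuple(b + gap * (n - b) for b, n in zip(base, neighbour))

rng = random.Random(0)
minority = [(1.0, 1.0), (2.0, 1.0), (1.0, 2.0), (8.0, 8.0)]
synthetic = [smote_sample(minority, k=1, rng=rng) for _ in range(5)]
print(synthetic)
```

Every synthetic point lies on a segment between two real minority points, so SMOTE fills in the space the minority class already occupies rather than inventing new regions.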


In [9]:
import warnings
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import ipywidgets as widgets
from IPython.display import display

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import (classification_report, confusion_matrix,
                             ConfusionMatrixDisplay, auc, roc_curve)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from xgboost import XGBClassifier

warnings.filterwarnings('ignore')

In [2]:
#####################################
# DATA LOADING
#####################################
#Load the pfaf_plants_merged.csv file
df = pd.read_csv('pfaf_plants_merged.csv')
df.head()

Unnamed: 0,use_keyword,latin_name_search,common_name_search,edibility_rating_search,medicinal_rating_search,plant_url,Care Requirements,Common Name,Common Names,Cultivation Details,...,Native Range,Other Uses,Other Uses Rating,Propagation,Range,Scientific Name,Special Uses,Summary,USDA hardiness,Weed Potential
0,Stomachic,Abelmoschus moschatus,"Musk Mallow,Musk Okra",2,3,https://pfaf.org/user/Plant.aspx?LatinName=Abe...,Moist Soil; Half Hardy; Full sun,"Musk Mallow,Musk Okra","Musk Mallow,Musk Okra",Easily grown in a rich well-drained soil in a ...,...,"TEMPERATE ASIA: China (Hunan Sheng, Jiangxi Sh...",Essential Fibre Insecticide Oil Size An essent...,3.0,Seed - sow April in a greenhouse. The seed ger...,S.E. Asia - Himalayas to China and Vietnam.,Abelmoschus moschatus - Medik.,Scented Plants,,8-11,No
1,Stomachic,Abies grandis,"Grand Fir, Giant Fir, Lowland White Fir",2,2,https://pfaf.org/user/Plant.aspx?LatinName=Abi...,Semi-shade; Fully Hardy; Moist Soil; Full shad...,"Grand Fir, Giant Fir, Lowland White Fir","Grand Fir, Giant Fir, Lowland White Fir","Landscape Uses:Screen, Specimen. Prefers a goo...",...,"NORTHERN AMERICA: Canada (British Columbia), U...",Baby care Dye Gum Incense Repellent Roofing Wo...,3.0,Seed - sow early February in a greenhouse or o...,Western N. America - British Columbia to Calif...,Abies grandis - (Douglas. ex D.Don.)Lindl.,Food Forest Scented Plants,"Form: Columnar, Upright or erect.",5-6,No
2,Stomachic,Abies spectabilis,Himalayan Fir,0,2,https://pfaf.org/user/Plant.aspx?LatinName=Abi...,Semi-shade; Fully Hardy; Moist Soil; Full shad...,Himalayan Fir,Himalayan Fir,Prefers a good moist but not water-logged soil...,...,"TEMPERATE ASIA: Afghanistan, China (Xizang Ziz...",Essential Fuel Incense Wood An essential oil i...,3.0,Seed - sow early February in a greenhouse or o...,E. Asia - Himalayas from Afghanistan to Nepal.,Abies spectabilis - (D.Don.)Spach.,Scented Plants,,6-9,No
3,Stomachic,Abutilon theophrasti,"China Jute, Velvetleaf, Butterprint Buttonweed...",3,2,https://pfaf.org/user/Plant.aspx?LatinName=Abu...,Semi-shade; Fully Hardy; Well drained soil; Mo...,"China Jute, Velvetleaf, Butterprint Buttonweed...","China Jute, Velvetleaf, Butterprint Buttonweed...",Requires full sun or part day shade and a fert...,...,"TEMPERATE ASIA: Afghanistan, Egypt (Sinai), Ir...",Fibre Oil Paper A fibre obtained from the stem...,4.0,Seed - sow early April in a greenhouse. Germin...,Asia - tropical. Naturalised in S.E. Europe an...,Abutilon theophrasti - Medik.,,Form: Upright or erect.,Coming soon,Yes
4,Stomachic,Acacia farnesiana,"Sweet Acacia, Perfume Acacia, Huisache",2,2,https://pfaf.org/user/Plant.aspx?LatinName=Aca...,Moist Soil; Frost Hardy; Full sun; Well draine...,"Sweet Acacia, Perfume Acacia, Huisache","Sweet Acacia, Perfume Acacia, Huisache",Landscape Uses: Pest tolerant. Originally trop...,...,NORTHERN AMERICA: United States (Florida (nort...,Adhesive Dye Essential Gum Gum Ink Tannin Teet...,4.0,Seed - best sown as soon as it is ripe in a su...,"The original range is uncertain, but is probab...",Acacia farnesiana - (L.)Willd.,Nitrogen Fixer Scented Plants,Bloom Color: Yellow. Main Bloom Time: Early su...,9-11,Yes
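
The preprocessing cells that follow reduce each row's `use_keyword` string to its first keyword (lower-cased, split on `;` or `,`) and keep only property groups with at least 50 plants. A pure-Python sketch of the same two steps, using made-up strings and a threshold of 2 so the example stays small (`first_keyword` and `keep_frequent` are hypothetical helpers, not part of the notebook):

```python
import re
from collections import Counter

def first_keyword(use_keyword):
    """Mimic .astype(str).str.lower().str.split(';|,').str[0].str.strip()."""
    return re.split(r';|,', str(use_keyword).lower())[0].strip()

def keep_frequent(labels, min_count):
    """Keep only labels whose group has at least min_count members."""
    counts = Counter(labels)
    return [lab for lab in labels if counts[lab] >= min_count]

raw = ['Stomachic', 'Diuretic; Tonic', 'diuretic, astringent', 'Tonic']
props = [first_keyword(r) for r in raw]
kept = keep_frequent(props, min_count=2)
print(props, kept)
```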


In [3]:
#####################################
# DATA PREPROCESSING WITH GROUPING
# ----------- Using SMOTE Over-Sampling -----------
#####################################

from sklearn.preprocessing import LabelEncoder
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split

# Normalize column names, then drop rows with missing key info
df.columns = df.columns.str.strip()
df = df.dropna(subset=[
    'Family', 'Scientific Name',
    'medicinal_rating_search', 'use_keyword'
])

# Filter only medicinal plants (copy to avoid chained-assignment warnings)
# and extract the first medicinal property
df = df[df['medicinal_rating_search'] > 0].copy()
df['medicinal_property'] = (
    df['use_keyword']
      .astype(str)
      .str.lower()
      .str.split(';|,')
      .str[0]
      .str.strip()
)

# Rename features
df = df.rename(columns={
    'Edibility Rating': 'edibility',
    'Other Uses Rating': 'other_uses'
})

# Drop any rows missing our core features
df = df.dropna(subset=[
    'Family', 'Scientific Name', 
    'medicinal_property', 'edibility', 'other_uses'
])

# Map properties to groups (identity here)
df['medicinal_group'] = df['medicinal_property']
df = df.dropna(subset=['medicinal_group'])

# Encode categorical 'Family' as integer codes
# (note: integer codes impose an arbitrary order on families)
df['Family'] = LabelEncoder().fit_transform(df['Family'])

# ---- Keep only groups with >= 50 samples ----
group_counts = df['medicinal_group'].value_counts()
keep_groups  = group_counts[group_counts >= 50].index
df = df[df['medicinal_group'].isin(keep_groups)]
print("Groups retained (>=50 samples):")
print(df['medicinal_group'].value_counts(), "\n")

# 1) Fit the encoder on your entire target column
label_encoder = LabelEncoder()
y_encoded     = label_encoder.fit_transform(df['medicinal_group'])

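# 2) Prepare X and apply SMOTE
# Note: SMOTE runs before the train/test split, so the test set below contains
# synthetic rows interpolated from plants that may also appear in the training set.
X = df[['Family', 'edibility', 'other_uses']]
sm = SMOTE(random_state=42, k_neighbors=1)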
X_resampled, y_resampled = sm.fit_resample(X, y_encoded)

# 3) Train / test split
X_train, X_test, y_train, y_test = train_test_split(
    X_resampled, y_resampled,
    test_size=0.2,
    random_state=42,
    stratify=y_resampled
)

Groups retained (>=50 samples):
medicinal_group
diuretic        998
astringent      837
tonic           604
stomachic       551
febrifuge       471
               ... 
antiemetic       62
antiperiodic     62
hypnotic         61
emollient        60
antipruritic     56
Name: count, Length: 61, dtype: int64 
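
In the cell above, SMOTE resamples every group up to the size of the largest one (diuretic, 998 rows) before `train_test_split`, which is why each class shows a support of about 200 (20% of 998) in the report below, and why synthetic rows built from the same original plants can sit on both sides of the split. A leak-free ordering splits first and resamples only the training portion. The pure-Python sketch below shows that ordering; to keep it short it substitutes random duplication for SMOTE and uses a simplified per-class split, so both helpers (`stratified_split`, `oversample_train`) are hypothetical stand-ins rather than what the notebook runs.

```python
import random
from collections import defaultdict

def stratified_split(rows, labels, test_size, rng):
    """Hold out roughly test_size of each class (simplified stratification)."""
    by_class = defaultdict(list)
    for row, lab in zip(rows, labels):
        by_class[lab].append(row)
    train, test = [], []
    for lab, members in by_class.items():
        members = members[:]
        rng.shuffle(members)
        n_test = max(1, round(test_size * len(members)))
        test += [(r, lab) for r in members[:n_test]]
        train += [(r, lab) for r in members[n_test:]]
    return train, test

def oversample_train(train, rng):
    """Duplicate minority rows until every class matches the largest one.
    (Random oversampling stands in for SMOTE to keep the sketch short.)"""
    by_class = defaultdict(list)
    for row, lab in train:
        by_class[lab].append(row)
    target = max(len(m) for m in by_class.values())
    out = list(train)
    for lab, members in by_class.items():
        out += [(rng.choice(members), lab) for _ in range(target - len(members))]
    return out

rng = random.Random(42)
rows = list(range(30))               # row ids
labels = ['a'] * 20 + ['b'] * 10     # imbalanced classes
train, test = stratified_split(rows, labels, test_size=0.2, rng=rng)
train_bal = oversample_train(train, rng)  # resample the training part only
```

Because resampling happens after the split, no test row (or copy of one) can reach the training set. With imbalanced-learn, `imblearn.pipeline.Pipeline` enforces the same ordering during cross-validation, since its samplers are applied only when fitting.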



In [4]:
#####################################
# MODEL TRAINING AND EVALUATION
#####################################
# Optional second round of SMOTE
use_smote = False  # Set to True if you want to apply SMOTE again
if use_smote:
    print("Before SMOTE:", Counter(y_train))
    sm = SMOTE(random_state=42)
    X_train_res, y_train_res = sm.fit_resample(X_train, y_train)
    print("After SMOTE:", Counter(y_train_res))
    scale_pos_weight = 1
else:
    X_train_res, y_train_res = X_train, y_train
    # Only valid for binary targets; unused here
    scale_pos_weight = 1

# Define models
models = {
    'Logistic Regression': LogisticRegression(max_iter=1000, class_weight='balanced'),
    'Decision Tree': DecisionTreeClassifier(class_weight='balanced'),
    'Random Forest': RandomForestClassifier(class_weight='balanced'),
    'Gradient Boosting': GradientBoostingClassifier(),  # doesn't support class_weight
    'SVM': SVC(probability=True, class_weight='balanced'),
    'XGBoost': XGBClassifier(eval_metric='mlogloss')  # multiclass log-loss
}

results = {}

for name, model in models.items():
    model.fit(X_train_res, y_train_res)
    y_pred = model.predict(X_test)
    
    # Decode back to original string labels
    y_test_decoded = label_encoder.inverse_transform(y_test)
    y_pred_decoded = label_encoder.inverse_transform(y_pred)
    
    # Only binary metrics use proba; skip for now
    y_proba = None

    # Save results
    results[name] = {
        'model': model,
        'pred': y_pred_decoded,
        'proba': y_proba,
        'report': classification_report(y_test_decoded, y_pred_decoded, output_dict=True)
    }

    # Print readable classification report
    print(f"\n=== {name} ===")
    print(classification_report(y_test_decoded, y_pred_decoded))

df_smote_results = results

# Build a Tab, one child per model
tabs = widgets.Tab()
children = []
titles   = []

for name, result in results.items():
    out = widgets.Output()
    with out:
        # Decode predictions & truth
        y_pred_decoded = result['pred']
        y_test_decoded = label_encoder.inverse_transform(y_test)

        # Compute CM
        cm = confusion_matrix(
            y_test_decoded,
            y_pred_decoded,
            labels=label_encoder.classes_
        )
        disp = ConfusionMatrixDisplay(
            confusion_matrix=cm,
            display_labels=label_encoder.classes_
        )

        # Plot
        fig, ax = plt.subplots(figsize=(8,8))
        disp.plot(ax=ax, xticks_rotation='vertical', cmap='Blues')
        ax.set_title(f"Confusion Matrix — {name}")
        plt.tight_layout()
        plt.show()

    children.append(out)
    titles.append(name)

tabs.children = children
for i, title in enumerate(titles):
    tabs.set_title(i, title)

display(tabs)


=== Logistic Regression ===
                  precision    recall  f1-score   support

   abortifacient       0.00      0.00      0.00       199
      alterative       0.00      0.00      0.00       200
       analgesic       0.00      0.00      0.00       200
         anodyne       0.05      0.13      0.07       199
    anthelmintic       0.00      0.00      0.00       199
   antiasthmatic       0.05      0.02      0.02       199
   antibacterial       0.02      0.05      0.03       199
        antidote       0.00      0.00      0.00       200
      antiemetic       0.00      0.00      0.00       199
      antifungal       0.01      0.04      0.02       200
antiinflammatory       0.00      0.00      0.00       200
    antiperiodic       0.00      0.00      0.00       200
  antiphlogistic       0.04      0.02      0.02       199
    antipruritic       0.00      0.00      0.00       199
     antipyretic       0.50      0.01      0.01       200
   antirheumatic       0.00      0.00     
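
The per-class numbers in these reports come from treating one class at a time as "positive". A pure-Python sketch of that computation, on made-up labels (`per_class_metrics` is a hypothetical helper; undefined 0/0 ratios are reported as 0.0, which matches scikit-learn's default behaviour apart from its warning). It also shows how a class the model rarely predicts can have high precision but low recall, as antipyretic does above (precision 0.50, recall 0.01).

```python
def per_class_metrics(y_true, y_pred, cls):
    """Precision, recall and F1 for one class, treating it as the positive class."""
    tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
    fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
    fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

y_true = ['antipyretic'] * 4 + ['tonic'] * 4
y_pred = ['antipyretic', 'tonic', 'tonic', 'tonic',
          'antipyretic', 'tonic', 'tonic', 'tonic']
precision, recall, f1 = per_class_metrics(y_true, y_pred, 'antipyretic')
print(precision, recall, f1)
```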


In [5]:
#####################################
# TESTING FOR REAL-WORLD PERFORMANCE
# ----------- Using Down-Sampling -----------
#####################################

from sklearn.preprocessing import LabelEncoder
from imblearn.under_sampling import RandomUnderSampler
from sklearn.model_selection import train_test_split

# Reload the raw data so this cell does not depend on state left by the SMOTE cell
df = pd.read_csv('pfaf_plants_merged.csv')
df.columns = df.columns.str.strip()

# Drop rows with missing essential info
df = df.dropna(subset=['Family','Scientific Name','medicinal_rating_search','use_keyword'])

# Keep only truly medicinal plants
df = df[df['medicinal_rating_search'] > 0].copy()

# Extract the first medicinal_property keyword
df['medicinal_property'] = (
    df['use_keyword']
      .astype(str)
      .str.lower()
      .str.split(';|,')
      .str[0]
      .str.strip()
)

# Rename feature columns
df = df.rename(columns={
    'Edibility Rating':'edibility',
    'Other Uses Rating':'other_uses'
})

# Drop any rows missing our three features
features = ['Family','edibility','other_uses']
df = df.dropna(subset=features)

# Map properties to groups (identity, as in the SMOTE cell)
df['medicinal_group'] = df['medicinal_property']

# Only keep groups with >= 50 samples
group_counts = df['medicinal_group'].value_counts()
keep_groups  = group_counts[group_counts >= 50].index
df = df[df['medicinal_group'].isin(keep_groups)]

# Encode Family
df['Family'] = LabelEncoder().fit_transform(df['Family'])

# Encode the target
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(df['medicinal_group'])

# Our X matrix
X = df[features]

# Balance by under‐sampling the majority classes
rus = RandomUnderSampler(random_state=42)
X_res, y_res = rus.fit_resample(X, y)

# Finally, split into train/test
X_train, X_test, y_train, y_test = train_test_split(
    X_res, y_res,
    test_size=0.2,
    random_state=42,
    stratify=y_res
)
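
`RandomUnderSampler` with its default strategy cuts every class down to the size of the smallest one by sampling rows without replacement. Here the smallest retained group, antipruritic, has 56 plants, so each of the 61 groups keeps 56 rows (3,416 in total). A pure-Python sketch of that behaviour, on hypothetical row ids and labels (`random_undersample` is an illustrative helper, not imbalanced-learn's implementation):

```python
import random
from collections import defaultdict

def random_undersample(rows, labels, rng):
    """Keep a random subset of each class, sized to the smallest class."""
    by_class = defaultdict(list)
    for row, lab in zip(rows, labels):
        by_class[lab].append(row)
    n_min = min(len(m) for m in by_class.values())
    out_rows, out_labels = [], []
    for lab, members in by_class.items():
        for row in rng.sample(members, n_min):  # sample without replacement
            out_rows.append(row)
            out_labels.append(lab)
    return out_rows, out_labels

rows = list(range(18))
labels = ['diuretic'] * 10 + ['tonic'] * 5 + ['emollient'] * 3
out_rows, out_labels = random_undersample(rows, labels, random.Random(42))
print(out_labels)
```

Unlike SMOTE, every remaining row is a real plant, so a test split taken afterwards contains no synthetic data; the cost is that most rows of the large groups are discarded.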

In [6]:
#####################################
# MODEL TRAINING AND EVALUATION
#####################################
# Set seed for reproducibility
np.random.seed(42)
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
from collections import Counter
from xgboost import XGBClassifier

# --- Optional: Apply SMOTE to training set ---
# Down-sampling already balanced the classes; SMOTE here only tops up any class
# that the stratified split left with one fewer training row than the others.
use_smote = True

if use_smote:
    # Check class distribution in y_train
    class_counts = Counter(y_train)
    min_class_size = min(class_counts.values())

    # Set k_neighbors to one less than the smallest class count
    # SMOTE requires: n_neighbors < min_class_size
    k_neighbors = max(1, min_class_size - 1)  # Must be at least 1

    # Apply SMOTE with adjusted k_neighbors
    sm = SMOTE(random_state=42, k_neighbors=k_neighbors)
    X_train_res, y_train_res = sm.fit_resample(X_train, y_train)

    print(f"SMOTE used k_neighbors={k_neighbors}")
    print("After SMOTE:", Counter(y_train_res))
    # scale_pos_weight only matters for binary targets; keep XGBoost's default here
    scale_pos_weight = 1
else:
    X_train_res, y_train_res = X_train, y_train
    # Compute scale_pos_weight for XGBoost (used when not using SMOTE)
    class_counts = np.bincount(y_train)
    if len(class_counts) == 2:  # Binary classification
        neg, pos = class_counts
        scale_pos_weight = neg / pos
    else:  # Multiclass classification
        scale_pos_weight = 1  # Default value for multiclass

# Define models with class_weight or equivalent
models = {
    'Logistic Regression': LogisticRegression(max_iter=1000, class_weight='balanced'),
    'Decision Tree': DecisionTreeClassifier(class_weight='balanced'),
    'Random Forest': RandomForestClassifier(class_weight='balanced'),
    'Gradient Boosting': GradientBoostingClassifier(),  # Cannot set class_weight
    'SVM': SVC(probability=True, class_weight='balanced'),
    'XGBoost': XGBClassifier(eval_metric='mlogloss', scale_pos_weight=scale_pos_weight)
}

results = {}

for name, model in models.items():
    model.fit(X_train_res, y_train_res)
    y_pred = model.predict(X_test)
    # Keep the full (n_samples, n_classes) probability matrix; [:, 1] would keep only class 1
    y_proba = model.predict_proba(X_test) if hasattr(model, 'predict_proba') else None
    # Decode numeric predictions back to strings
    y_pred_labels = label_encoder.inverse_transform(y_pred)
    y_test_labels = label_encoder.inverse_transform(y_test)
    results[name] = {
        'model': model,
        'pred': y_pred_labels,
        'proba': y_proba,
        'report': classification_report(y_test_labels, y_pred_labels, output_dict=True)
    }
    print(f"=== {name} ===")
    print(classification_report(y_test_labels, y_pred_labels))

# Build a Tab, one child per model
tabs = widgets.Tab()
children = []
titles   = []

for name, result in results.items():
    out = widgets.Output()
    with out:
        # Decode predictions & truth
        y_pred_decoded = result['pred']
        y_test_decoded = label_encoder.inverse_transform(y_test)

        # Compute CM
        cm = confusion_matrix(
            y_test_decoded,
            y_pred_decoded,
            labels=label_encoder.classes_
        )
        disp = ConfusionMatrixDisplay(
            confusion_matrix=cm,
            display_labels=label_encoder.classes_
        )

        # Plot
        fig, ax = plt.subplots(figsize=(8,8))
        disp.plot(ax=ax, xticks_rotation='vertical', cmap='Blues')
        ax.set_title(f"Confusion Matrix — {name}")
        plt.tight_layout()
        plt.show()

    children.append(out)
    titles.append(name)

tabs.children = children
for i, title in enumerate(titles):
    tabs.set_title(i, title)

display(tabs)

df_downsampled_results = results

SMOTE used k_neighbors=43
After SMOTE: Counter({41: 45, 57: 45, 29: 45, 42: 45, 21: 45, 53: 45, 36: 45, 50: 45, 5: 45, 28: 45, 45: 45, 27: 45, 18: 45, 16: 45, 51: 45, 7: 45, 35: 45, 2: 45, 46: 45, 32: 45, 52: 45, 1: 45, 8: 45, 54: 45, 3: 45, 33: 45, 25: 45, 24: 45, 14: 45, 48: 45, 6: 45, 44: 45, 56: 45, 34: 45, 23: 45, 39: 45, 26: 45, 38: 45, 55: 45, 13: 45, 30: 45, 37: 45, 43: 45, 31: 45, 60: 45, 17: 45, 15: 45, 0: 45, 9: 45, 10: 45, 12: 45, 22: 45, 58: 45, 49: 45, 20: 45, 19: 45, 59: 45, 47: 45, 40: 45, 4: 45, 11: 45})
=== Logistic Regression ===
                  precision    recall  f1-score   support

   abortifacient       0.07      0.09      0.08        11
      alterative       0.00      0.00      0.00        11
       analgesic       0.00      0.00      0.00        11
         anodyne       0.00      0.00      0.00        11
    anthelmintic       0.00      0.00      0.00        11
   antiasthmatic       0.00      0.00      0.00        11
   antibacterial       0.00      0.00 
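
The printed `k_neighbors=43` and the uniform count of 45 rows per class after SMOTE follow from simple arithmetic on the down-sampled data. `train_test_split` rounds a fractional test size up, and the sketch below assumes the stratified split spreads the held-out rows as evenly as possible across the 61 equal-sized classes.

```python
import math

n_classes, per_class = 61, 56        # groups retained, rows per group after down-sampling
n_samples = n_classes * per_class    # 3416 rows in total
n_test = math.ceil(0.2 * n_samples)  # test size is rounded up: 684
n_train = n_samples - n_test         # 2732

# Spread evenly, every class gets 11 test rows and 13 classes get a 12th
base_test = n_test // n_classes
extra = n_test - n_classes * base_test
min_train_per_class = per_class - (base_test + (1 if extra else 0))  # 44
k_neighbors = max(1, min_train_per_class - 1)                        # 43, as printed
print(n_test, n_train, extra, min_train_per_class, k_neighbors)
```

So most classes already have 45 training rows, and SMOTE only adds one synthetic row to each of the 13 classes that have 44.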


In [7]:
#####################################
# VISUALIZATION AND INTERACTIVE WIDGETS
#####################################
def extract_classification_metrics(results_dict, eval_set_name):
    """Flatten per-class F1 / precision / recall from each model's report into tidy rows."""
    metric_keys = {'F1': 'f1-score', 'Precision': 'precision', 'Recall': 'recall'}
    rows = []
    for model_name, result in results_dict.items():
        report = result['report']
        for cls, metrics in report.items():
            # Skip summary rows
            if cls in ['accuracy', 'macro avg', 'weighted avg']:
                continue
            for metric_name, key in metric_keys.items():
                rows.append({
                    'Model': model_name,
                    'Metric': metric_name,
                    'Class': cls,
                    'Value': metrics.get(key, 0.0),
                    'Set': eval_set_name
                })
    df_metrics = pd.DataFrame(rows)
    if df_metrics.empty:
        print("🚨 WARNING: Extracted metrics dataframe is empty!")
    return df_metrics

# 1. Use the results dictionaries from the two evaluation runs
smote_results = df_smote_results
downsampled_results = df_downsampled_results

# 2. Extract performance metrics into tidy DataFrames
df_smote_metrics = extract_classification_metrics(smote_results, 'SMOTE')
df_downsampled_metrics = extract_classification_metrics(downsampled_results, 'Downsampled')

# 3. Combine them for plotting
df_viz = pd.concat([df_smote_metrics, df_downsampled_metrics], ignore_index=True)

# Create widgets
metric_dropdown = widgets.Dropdown(
    options=['F1', 'Precision', 'Recall'],
    value='F1',
    description='Metric:'
)
# Dynamically pull all class names from df_viz
unique_classes = sorted(df_viz['Class'].unique())

class_dropdown = widgets.Dropdown(
    options=unique_classes,
    value=unique_classes[0],
    description='Class:'
)

def update_plot(metric, target_class):
    filtered = df_viz[
        (df_viz['Metric'] == metric) &
        (df_viz['Class'] == target_class)
    ]
    fig = px.bar(
        filtered,
        x='Model',
        y='Value',
        color='Set',
        barmode='group',
        title=f"{metric} Score Comparison - Class {target_class}"
    )
    fig.update_layout(yaxis=dict(range=[0, 1]))
    fig.show()

out = widgets.interactive_output(update_plot, {
    'metric': metric_dropdown,
    'target_class': class_dropdown
})

# Display widgets and output
display(widgets.HBox([metric_dropdown, class_dropdown]), out)


In [10]:
#######################################
# ROC CURVE VISUALIZATION
#######################################
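# 1) Extract the trained XGBoost model and the current test set.
# Note: X_test, y_test and label_encoder now come from the down-sampled split,
# whereas this model was trained on the SMOTE data; since SMOTE keeps every original
# row, many of these test rows were in that model's training set, so AUCs may be optimistic.
xgb_model = df_smote_results['XGBoost']['model']
y_true    = y_test                            # integer-encoded labels
classes   = label_encoder.classes_            # array of string labels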

# 2) Get probability estimates
y_proba = xgb_model.predict_proba(X_test)     # shape (n_samples, n_classes)

# 3) Compute ROC + AUC for each class
roc_data = {}
for idx, name in enumerate(classes):
    y_true_bin = (y_true == idx).astype(int)
    y_score    = y_proba[:, idx]
    fpr, tpr, _ = roc_curve(y_true_bin, y_score)
    roc_data[name] = (fpr, tpr, auc(fpr, tpr))

# 4) Keep only classes whose AUC exceeds the threshold
auc_threshold    = 0.85
filtered_roc_data = {
    name: data 
    for name, data in roc_data.items() 
    if data[2] > auc_threshold
}
print(f"Plotting {len(filtered_roc_data)} classes with AUC > {auc_threshold}:")
print(list(filtered_roc_data.keys()))

# 5) Build interactive ROC widget over the filtered set
dropdown = widgets.Dropdown(
    options=['(none)'] + list(filtered_roc_data.keys()),
    value='(none)',
    description='Highlight:'
)
out = widgets.Output()

def plot_roc(change):
    choice = change['new']
    out.clear_output(wait=True)
    with out:
        fig, ax = plt.subplots(figsize=(10,6))
        ax.plot([0,1],[0,1],'k--', alpha=0.3)
        for name, (fpr, tpr, score) in filtered_roc_data.items():
            if name == choice:
                ax.plot(fpr, tpr, lw=3, alpha=1.0,
                        label=f"{name} (AUC={score:.2f})")
            else:
                ax.plot(fpr, tpr, lw=1, alpha=0.2, label=name)
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title(f"ROC Curves (AUC > {auc_threshold})")

        # shrink axis to make room on right
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
        # legend in 3 columns, outside
        ax.legend(
            loc='center left',
            bbox_to_anchor=(1.02, 0.5),
            fontsize='small',
            ncol=3,
            borderaxespad=0.
        )

        plt.tight_layout()
        plt.show()

dropdown.observe(plot_roc, names='value')
plot_roc({'new': dropdown.value})

display(widgets.VBox([dropdown, out]))

Plotting 33 classes with AUC > 0.85:
['antiasthmatic', 'antibacterial', 'antidote', 'antifungal', 'antiperiodic', 'antiphlogistic', 'antipruritic', 'antipyretic', 'antiscorbutic', 'antispasmodic', 'antitussive', 'aperient', 'appetizer', 'aromatic', 'demulcent', 'depurative', 'digestive', 'emetic', 'emollient', 'galactogogue', 'haemostatic', 'hepatic', 'hypnotic', 'hypoglycaemic', 'kidney', 'narcotic', 'nervine', 'odontalgic', 'purgative', 'rubefacient', 'salve', 'styptic', 'vermifuge']
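
The ROC cell binarizes each class one-vs-rest (`y_true == idx`) and integrates the curve from `roc_curve` with `auc`. That area has an equivalent pairwise reading: the probability that a randomly chosen row of the class receives a higher score than a randomly chosen row of any other class, with ties counted as half. A pure-Python sketch of that definition on made-up scores (`auc_by_pairs` is a hypothetical helper, quadratic in the number of rows and meant only for illustration):

```python
def auc_by_pairs(y_true_bin, scores):
    """ROC AUC as P(score of a random positive > score of a random negative),
    counting ties as half a win."""
    pos = [s for s, y in zip(scores, y_true_bin) if y == 1]
    neg = [s for s, y in zip(scores, y_true_bin) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# One-vs-rest for a single class: 1 where the true label is that class, else 0
y_bin = [1, 1, 0, 0, 0]
score = [0.9, 0.4, 0.5, 0.2, 0.1]  # hypothetical predicted probabilities for the class
print(auc_by_pairs(y_bin, score))
```

On this reading, an AUC of 0.5 means the class's probability column ranks its own rows no better than chance, regardless of how rare the class is, which is what makes the 0.85 cut-off comparable across the 61 classes.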

