# 🤖 Predictive Model Performance
## How do you decide which predictive model to use?
In this notebook, we evaluate several machine learning models that predict the primary medicinal property of a medicinal plant (for example, tonic, stomachic or astringent) from its taxonomy and other descriptive features. Below is a brief overview of the models used and how each one approaches this multiclass classification task:

### 1. **Logistic Regression**
- **Type**: Linear model
- **Concept**: Estimates the probability of each class (e.g., that a plant's primary use is as a tonic) from a weighted combination of the input features.
- **Strengths**: Simple, interpretable, fast.
- **Limitations**: Assumes a linear relationship between features and the log-odds of the outcome; struggles with complex patterns.
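As a quick illustration of the "weighted combination" idea, the sketch below fits scikit-learn's `LogisticRegression` on synthetic binary data (a hypothetical stand-in, not the plant dataset) and recomputes the predicted probability by hand: the log-odds are `X @ coef_ + intercept_`, and the sigmoid turns them into a probability. For a multiclass target such as the medicinal groups here, `coef_` holds one weight vector per class and the default `lbfgs` solver combines the class scores with a softmax.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Synthetic binary data (hypothetical): two numeric features per sample
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.5, size=200) > 0).astype(int)

clf = LogisticRegression().fit(X, y)

# Log-odds: a weighted sum of the features plus an intercept
log_odds = X @ clf.coef_.ravel() + clf.intercept_[0]
# The sigmoid maps log-odds to a probability in (0, 1)
manual_proba = 1.0 / (1.0 + np.exp(-log_odds))

print(np.allclose(manual_proba, clf.predict_proba(X)[:, 1]))
```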

### 2. **Decision Tree**
- **Type**: Non-linear, rule-based
- **Concept**: Splits the data into branches based on feature thresholds to arrive at a prediction at the leaves.
- **Strengths**: Easy to visualize and understand; captures non-linear relationships.
- **Limitations**: Can overfit the training data if not pruned or regularized.
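To make "splits on feature thresholds" concrete, here is a minimal sketch of how a single split can be chosen with Gini impurity, the default `criterion` of scikit-learn's `DecisionTreeClassifier`. It scans midpoints between distinct values of one feature and keeps the threshold with the lowest weighted impurity of the two children. The toy feature and labels are hypothetical, and a real tree repeats this search over every feature and then recursively within each child.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2 over the class proportions p_k."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_threshold(x, y):
    """Return the threshold on feature x with the lowest weighted child impurity."""
    best_t, best_score = None, np.inf
    values = np.unique(x)
    # Candidate thresholds: midpoints between consecutive distinct values
    for t in (values[:-1] + values[1:]) / 2:
        left, right = y[x <= t], y[x > t]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score

# Hypothetical example: a rating-like feature and a binary label
x = np.array([0, 1, 1, 2, 3, 4, 4, 5])
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
t, s = best_threshold(x, y)
print(t, s)
```

Here the threshold between 2 and 3 produces two pure children, so its weighted impurity is zero.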

### 3. **Random Forest**
- **Type**: Ensemble (of decision trees)
- **Concept**: Trains many decision trees, each on a bootstrap sample of the data and with a random subset of features considered at each split, then averages their predictions to reduce variance.
- **Strengths**: Usually more accurate and robust than a single tree; much less prone to overfitting.
- **Limitations**: Less interpretable; slower than simpler models.
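The bagging idea can be sketched with plain decision trees: each tree is trained on a bootstrap sample (rows drawn with replacement), and `max_features='sqrt'` limits each split to a random subset of features. The data are synthetic, and this version combines the trees by majority vote; scikit-learn's `RandomForestClassifier` instead averages the trees' predicted probabilities, so treat this as an illustration of the idea rather than a reimplementation.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=300, n_features=10, n_informative=5, random_state=0)
rng = np.random.default_rng(0)

trees = []
for i in range(25):
    # Bootstrap sample: draw n row indices with replacement
    idx = rng.integers(0, len(X), size=len(X))
    # max_features='sqrt': each split looks at a random subset of features
    tree = DecisionTreeClassifier(max_features='sqrt', random_state=i)
    trees.append(tree.fit(X[idx], y[idx]))

# Majority vote across the 25 trees (binary 0/1 labels, so no ties)
votes = np.stack([t.predict(X) for t in trees])
forest_pred = (votes.mean(axis=0) >= 0.5).astype(int)
print((forest_pred == y).mean())
```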

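### 4. **Gradient Boosting (GBT)**
- **Type**: Ensemble (boosted decision trees)
- **Concept**: Trains shallow trees sequentially; each new tree is fit to the negative gradient of the loss of the current ensemble (the residuals, for squared error), so it corrects the remaining errors, and its output is added with a small learning rate.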
- **Strengths**: High accuracy; handles complex patterns well.
- **Limitations**: Can overfit if not tuned properly; computationally intensive.
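The sequential error-correcting loop is easiest to see with squared-error loss, where the negative gradient is simply the residual. The sketch below (synthetic 1-D regression data; the learning rate and tree depth are arbitrary choices) starts from a constant prediction and repeatedly fits a shallow `DecisionTreeRegressor` to the residuals. `GradientBoostingClassifier` follows the same pattern with the log-loss gradient, so this is a simplified sketch rather than its exact algorithm.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0]) + rng.normal(scale=0.1, size=200)

learning_rate = 0.1
pred = np.full_like(y, y.mean())  # stage 0: a constant prediction
losses = []
for _ in range(100):
    # Negative gradient of 0.5 * (y - pred)^2 with respect to pred
    residuals = y - pred
    tree = DecisionTreeRegressor(max_depth=2).fit(X, residuals)
    # Take a shrunken step toward the fitted residuals
    pred += learning_rate * tree.predict(X)
    losses.append(np.mean((y - pred) ** 2))

print(f"training MSE after 1 round: {losses[0]:.3f}, after 100: {losses[-1]:.3f}")
```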

### 5. **Support Vector Machine (SVM)**
- **Type**: Maximum-margin classifier
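- **Concept**: Finds the boundary (a hyperplane, possibly in a kernel-induced feature space) that separates the classes with the largest margin. For more than two classes, scikit-learn's `SVC` combines pairwise (one-vs-one) binary classifiers.
- **Strengths**: Works well in high-dimensional spaces; the margin acts as regularization (controlled by `C`), which limits overfitting.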
- **Limitations**: Not ideal for large datasets; performance depends on kernel choice.
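The dependence on kernel choice shows up clearly on data that no straight line can separate. In this sketch (synthetic XOR-style blobs, not the plant data, with default hyperparameters), a linear-kernel `SVC` cannot reach high accuracy even on its own training set, while an RBF kernel fits the pattern.

```python
import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(0)
# Four Gaussian blobs at the corners of a square; diagonal corners share a label (XOR)
centers = np.array([[1, 1], [-1, -1], [1, -1], [-1, 1]])
labels = np.array([1, 1, 0, 0])
idx = np.repeat(np.arange(4), 100)  # 100 points per blob
X = centers[idx] + rng.normal(scale=0.3, size=(400, 2))
y = labels[idx]

linear_acc = SVC(kernel='linear').fit(X, y).score(X, y)
rbf_acc = SVC(kernel='rbf').fit(X, y).score(X, y)
print(f"linear kernel: {linear_acc:.2f}, RBF kernel: {rbf_acc:.2f}")
```

No single line can classify more than three of the four blobs correctly, which caps the linear kernel at roughly 75% here.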

### 6. **XGBoost**
- **Type**: Gradient-boosted tree ensemble (optimized)
- **Concept**: An efficient, regularized implementation of gradient boosting that adds trees iteratively, each one correcting the errors of the ensemble so far.
- **Strengths**: Often among the strongest performers on structured (tabular) data; fast and scalable.
- **Limitations**: Complex; requires tuning; less interpretable.

---

Each model captures different aspects of the underlying patterns in the data. By comparing their performance under two class-balancing strategies (SMOTE over-sampling and random down-sampling), we aim to identify not only which models are accurate, but also which ones hold up when the class balance of the training data changes.


In [8]:
import ast
import warnings
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import ipywidgets as widgets
from IPython.display import display

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler, MultiLabelBinarizer
from sklearn.metrics import (classification_report, confusion_matrix,
                             ConfusionMatrixDisplay, auc, roc_curve)
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from xgboost import XGBClassifier

warnings.filterwarnings('ignore')

In [12]:
#####################################
# DATA LOADING
#####################################
# Load the plant data and the native-countries table
df = pd.read_csv('pfaf_plants_merged.csv')
df_countries = pd.read_excel('plants_native_countries.xlsx')
df_countries = df_countries.rename(columns={'Scientific name': 'Scientific Name'})

# Add the 'native_countries' column to the main dataframe.
# Note: the join key is 'Family', so if a family appears on several rows of
# df_countries, each plant row is duplicated once per matching row.
countries = df_countries[['Family', 'native_countries']]
df = df.merge(countries, on='Family', how='left')

# Drop rows where 'native_countries' is NaN
df = df.dropna(subset=['native_countries'])
# Drop duplicate rows based on 'Scientific Name' (keeps the first match per species)
df = df.drop_duplicates(subset=['Scientific Name'])
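The merge above joins on `'Family'`, even though the cell first renames `'Scientific name'` to `'Scientific Name'`. If `plants_native_countries.xlsx` lists several species per family, every plant row is duplicated once per matching family row, and `drop_duplicates` then keeps the first copy, whose country list may belong to a different species. The toy frames below are hypothetical and only demonstrate the mechanics; `validate='many_to_one'` is pandas' built-in check that the right-hand key is unique.

```python
import pandas as pd

# Hypothetical toy frames: two plants in one family, countries listed per species
plants = pd.DataFrame({'Scientific Name': ['A a', 'A b'], 'Family': ['Fam1', 'Fam1']})
countries = pd.DataFrame({'Scientific Name': ['A a', 'A c'],
                          'Family': ['Fam1', 'Fam1'],
                          'native_countries': ["['France']", "['Peru']"]})

# Joining on a key that repeats in the right table duplicates the left rows
by_family = plants.merge(countries[['Family', 'native_countries']], on='Family', how='left')
print(len(by_family))  # 2 plants x 2 matching rows = 4

# 'A b' has no entry of its own, yet keeps a country list after de-duplication
deduped = by_family.drop_duplicates(subset=['Scientific Name'])
print(deduped)

# validate='many_to_one' raises MergeError when the right-hand key is not unique
raised = False
try:
    plants.merge(countries[['Family', 'native_countries']], on='Family',
                 how='left', validate='many_to_one')
except pd.errors.MergeError as err:
    raised = True
    print("MergeError:", err)
```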

In [13]:
#####################################
# DATA PREPROCESSING WITH GROUPING
# ----------- Using SMOTE Over-Sampling -----------
#####################################

# Drop rows with missing key info
df = df.dropna(subset=[
    'Family', 'Scientific Name', 
    'medicinal_rating_search', 'use_keyword'
])
df.columns = df.columns.str.strip()

# Filter only medicinal plants & extract first medicinal property
df = df[df['medicinal_rating_search'] > 0]
df['medicinal_property'] = (
    df['use_keyword']
      .astype(str)
      .str.lower()
      .str.split(';|,')
      .str[0]
      .str.strip()
)

# Rename features
df = df.rename(columns={
    'Edibility Rating': 'edibility',
    'Other Uses Rating': 'other_uses'
})

# Drop any rows missing our core features
df = df.dropna(subset=[
    'Family', 'Scientific Name', 
    'medicinal_property', 'edibility', 'other_uses', 'native_countries'
])

# Map properties to groups (identity here)
df['medicinal_group'] = df['medicinal_property']
df = df.dropna(subset=['medicinal_group'])

# Encode categorical 'Family'
df['Family'] = LabelEncoder().fit_transform(df['Family'])

# Keep only groups with at least 50 samples
group_counts = df['medicinal_group'].value_counts()
keep_groups  = group_counts[group_counts >= 50].index
df = df[df['medicinal_group'].isin(keep_groups)]
print("Groups retained (>=50 samples):")
print(df['medicinal_group'].value_counts(), "\n")

# Encode the target labels as integers
label_encoder = LabelEncoder()
y_encoded     = label_encoder.fit_transform(df['medicinal_group'])

# Parse and binarize the multi-label country lists
# If native_countries is stored as a string repr of a list, first turn it into an actual list:
df['native_countries'] = df['native_countries'].apply(
    lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)

# Fit the MultiLabelBinarizer on your full dataset
mlb = MultiLabelBinarizer()
country_dummies = pd.DataFrame(
    mlb.fit_transform(df['native_countries']),
    columns=mlb.classes_,
    index=df.index
)

# Combine with your other numeric features
X_other = df[['Family', 'edibility', 'other_uses']]
X_full  = pd.concat([X_other, country_dummies], axis=1)

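# Apply SMOTE on the fully numeric matrix.
# Note: resampling before the split means synthetic samples also end up in the
# test set, interpolated from rows that sit in the training set, so the test
# scores below may be optimistic.
sm = SMOTE(random_state=42, k_neighbors=1)
X_resampled, y_resampled = sm.fit_resample(X_full, y_encoded)

# Train / test split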
X_train, X_test, y_train, y_test = train_test_split(
    X_resampled, y_resampled,
    test_size=0.2,
    random_state=42,
    stratify=y_resampled
)

Groups retained (>=50 samples):
medicinal_group
tonic          604
stomachic      551
skin           212
astringent     178
diuretic       168
poultice       158
febrifuge      158
sedative       152
miscellany     134
vermifuge      129
ophthalmic     114
styptic         86
hypotensive     81
stimulant       78
laxative        77
salve           75
purgative       73
pectoral        60
Name: count, dtype: int64 



In [14]:
#####################################
# MODEL TRAINING AND EVALUATION
#####################################
# Optional second round of SMOTE on the training split
use_smote = False  # Set to True to apply SMOTE again
scale_pos_weight = 1  # Only meaningful for binary targets; unused here
if use_smote:
    print("Before SMOTE:", Counter(y_train))
    sm = SMOTE(random_state=42)
    X_train_res, y_train_res = sm.fit_resample(X_train, y_train)
    print("After SMOTE:", Counter(y_train_res))
else:
    X_train_res, y_train_res = X_train, y_train

# Define models
models = {
    'Logistic Regression': LogisticRegression(max_iter=1000, class_weight='balanced'),
    'Decision Tree': DecisionTreeClassifier(class_weight='balanced'),
    'Random Forest': RandomForestClassifier(class_weight='balanced'),
    'Gradient Boosting': GradientBoostingClassifier(),  # doesn't support class_weight
    'SVM': SVC(probability=True, class_weight='balanced'),
    'XGBoost': XGBClassifier(eval_metric='mlogloss')  # multiclass log loss
}

results = {}

for name, model in models.items():
    model.fit(X_train_res, y_train_res)
    y_pred = model.predict(X_test)
    
    # Decode back to original string labels
    y_test_decoded = label_encoder.inverse_transform(y_test)
    y_pred_decoded = label_encoder.inverse_transform(y_pred)
    
    # Only binary metrics use proba; skip for now
    y_proba = None

    # Save results
    results[name] = {
        'model': model,
        'pred': y_pred_decoded,
        'proba': y_proba,
        'report': classification_report(y_test_decoded, y_pred_decoded, output_dict=True)
    }

    # Print readable classification report
    print(f"\n=== {name} ===")
    print(classification_report(y_test_decoded, y_pred_decoded))

df_smote_results = results

# Build a Tab, one child per model
tabs = widgets.Tab()
children = []
titles   = []

for name, result in results.items():
    out = widgets.Output()
    with out:
        # Decode predictions & truth
        y_pred_decoded = result['pred']
        y_test_decoded = label_encoder.inverse_transform(y_test)

        # Compute CM
        cm = confusion_matrix(
            y_test_decoded,
            y_pred_decoded,
            labels=label_encoder.classes_
        )
        disp = ConfusionMatrixDisplay(
            confusion_matrix=cm,
            display_labels=label_encoder.classes_
        )

        # Plot
        fig, ax = plt.subplots(figsize=(8,8))
        disp.plot(ax=ax, xticks_rotation='vertical', cmap='Blues')
        ax.set_title(f"Confusion Matrix — {name}")
        plt.tight_layout()
        plt.show()

    children.append(out)
    titles.append(name)

tabs.children = children
for i, title in enumerate(titles):
    tabs.set_title(i, title)

display(tabs)


=== Logistic Regression ===
              precision    recall  f1-score   support

  astringent       0.62      0.39      0.48       120
    diuretic       0.34      0.28      0.30       120
   febrifuge       0.25      0.45      0.32       121
 hypotensive       0.36      0.88      0.51       121
    laxative       0.14      0.23      0.18       121
  miscellany       0.28      0.16      0.20       121
  ophthalmic       0.11      0.07      0.09       121
    pectoral       0.13      0.08      0.10       121
    poultice       0.27      0.16      0.20       120
   purgative       0.25      0.40      0.31       121
       salve       0.15      0.12      0.13       121
    sedative       0.23      0.31      0.26       121
        skin       0.28      0.22      0.25       121
   stimulant       0.25      0.31      0.28       121
   stomachic       0.20      0.10      0.13       121
     styptic       0.18      0.18      0.18       121
       tonic       0.21      0.02      0.04       12
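The `class_weight='balanced'` models above weight each class by `n_samples / (n_classes * np.bincount(y))`, as documented for scikit-learn. The sketch below checks that formula with `compute_class_weight` (imported at the top of the notebook) on hypothetical labels. Since the training split here has already been balanced by resampling, these weights come out close to 1 for every class, so `class_weight='balanced'` changes little in this setup.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 2])  # hypothetical imbalanced labels
classes = np.unique(y)

weights = compute_class_weight(class_weight='balanced', classes=classes, y=y)
# 'balanced' is defined as n_samples / (n_classes * count_per_class)
manual = len(y) / (len(classes) * np.bincount(y))
print(weights, manual)
```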


In [15]:
#####################################
# DATA PREPROCESSING WITH GROUPING
# ----------- Using Down-Sampling -----------
#####################################

from imblearn.under_sampling import RandomUnderSampler

# Drop rows with missing key info
df = df.dropna(subset=[
    'Family', 'Scientific Name', 
    'medicinal_rating_search', 'use_keyword'
])
df.columns = df.columns.str.strip()

# Filter only medicinal plants & extract first medicinal property
df = df[df['medicinal_rating_search'] > 0]
df['medicinal_property'] = (
    df['use_keyword']
      .astype(str)
      .str.lower()
      .str.split(';|,')
      .str[0]
      .str.strip()
)

# Rename features
df = df.rename(columns={
    'Edibility Rating': 'edibility',
    'Other Uses Rating': 'other_uses'
})

# Drop any rows missing our core features
df = df.dropna(subset=[
    'Family', 'Scientific Name', 
    'medicinal_property', 'edibility', 'other_uses', 'native_countries'
])

# Map properties to groups (identity here)
df['medicinal_group'] = df['medicinal_property']
df = df.dropna(subset=['medicinal_group'])

# Encode categorical 'Family'
df['Family'] = LabelEncoder().fit_transform(df['Family'])

# Keep only groups with at least 50 samples
group_counts = df['medicinal_group'].value_counts()
keep_groups  = group_counts[group_counts >= 50].index
df = df[df['medicinal_group'].isin(keep_groups)]
print("Groups retained (>=50 samples):")
print(df['medicinal_group'].value_counts(), "\n")

# Encode the target labels as integers
label_encoder = LabelEncoder()
y_encoded     = label_encoder.fit_transform(df['medicinal_group'])

# Parse and binarize the multi-label country lists
# If native_countries is stored as a string repr of a list, first turn it into an actual list:
df['native_countries'] = df['native_countries'].apply(
    lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)

# Fit the MultiLabelBinarizer on your full dataset
mlb = MultiLabelBinarizer()
country_dummies = pd.DataFrame(
    mlb.fit_transform(df['native_countries']),
    columns=mlb.classes_,
    index=df.index
)

# Combine with your other numeric features
X_other = df[['Family', 'edibility', 'other_uses']]
X_full  = pd.concat([X_other, country_dummies], axis=1)

# Apply random under-sampling on the fully numeric matrix
rus = RandomUnderSampler(random_state=42)
X_resampled, y_resampled = rus.fit_resample(X_full, y_encoded)

# Train / test split
X_train, X_test, y_train, y_test = train_test_split(
    X_resampled, y_resampled,
    test_size=0.2,
    random_state=42,
    stratify=y_resampled
)

Groups retained (>=50 samples):
medicinal_group
tonic          604
stomachic      551
skin           212
astringent     178
diuretic       168
poultice       158
febrifuge      158
sedative       152
miscellany     134
vermifuge      129
ophthalmic     114
styptic         86
hypotensive     81
stimulant       78
laxative        77
salve           75
purgative       73
pectoral        60
Name: count, dtype: int64 



In [16]:
#####################################
# MODEL TRAINING AND EVALUATION
#####################################
# Set seed for reproducibility
np.random.seed(42)

# --- Optional: apply SMOTE to the training set ---
# The under-sampled training split is already balanced, so SMOTE leaves the
# class counts unchanged here (see the "After SMOTE" output below).
use_smote = True
scale_pos_weight = 1  # Only meaningful for binary targets; recomputed below when SMOTE is off

if use_smote:
    # Check class distribution in y_train
    class_counts = Counter(y_train)
    min_class_size = min(class_counts.values())

    # Set k_neighbors to one less than the smallest class count
    # SMOTE requires: n_neighbors < min_class_size
    k_neighbors = max(1, min_class_size - 1)  # Must be at least 1

    # Apply SMOTE with adjusted k_neighbors
    sm = SMOTE(random_state=42, k_neighbors=k_neighbors)
    X_train_res, y_train_res = sm.fit_resample(X_train, y_train)

    print(f"SMOTE used k_neighbors={k_neighbors}")
    print("After SMOTE:", Counter(y_train_res))
else:
    X_train_res, y_train_res = X_train, y_train
    # Compute scale_pos_weight for XGBoost (used when not using SMOTE)
    class_counts = np.bincount(y_train)
    if len(class_counts) == 2:  # Binary classification
        neg, pos = class_counts
        scale_pos_weight = neg / pos
    else:  # Multiclass classification
        scale_pos_weight = 1  # Default value for multiclass

# Define models with class_weight or equivalent
models = {
    'Logistic Regression': LogisticRegression(class_weight='balanced'),
    'Decision Tree': DecisionTreeClassifier(class_weight='balanced'),
    'Random Forest': RandomForestClassifier(class_weight='balanced'),
    'Gradient Boosting': GradientBoostingClassifier(),  # Cannot set class_weight
    'SVM': SVC(probability=True, class_weight='balanced'),
    'XGBoost': XGBClassifier(use_label_encoder=False, eval_metric='logloss', scale_pos_weight=scale_pos_weight)
}

results = {}

for name, model in models.items():
    model.fit(X_train_res, y_train_res)
    y_pred = model.predict(X_test)
    # Keep the full (n_samples, n_classes) probability matrix for the multiclass target
    y_proba = model.predict_proba(X_test) if hasattr(model, 'predict_proba') else None
    y_pred_labels = label_encoder.inverse_transform(y_pred)
    y_test_labels = label_encoder.inverse_transform(y_test)
    results[name] = {
        'model': model,
        'pred': y_pred_labels,
        'proba': y_proba,
        'report': classification_report(y_test_labels, y_pred_labels, output_dict=True)
    }
    print(f"=== {name} ===")
    print(classification_report(y_test_labels, y_pred_labels))

# Build a Tab, one child per model
tabs = widgets.Tab()
children = []
titles   = []

for name, result in results.items():
    out = widgets.Output()
    with out:
        # Decode predictions & truth
        y_pred_decoded = result['pred']
        y_test_decoded = label_encoder.inverse_transform(y_test)

        # Compute CM
        cm = confusion_matrix(
            y_test_decoded,
            y_pred_decoded,
            labels=label_encoder.classes_
        )
        disp = ConfusionMatrixDisplay(
            confusion_matrix=cm,
            display_labels=label_encoder.classes_
        )

        # Plot
        fig, ax = plt.subplots(figsize=(8,8))
        disp.plot(ax=ax, xticks_rotation='vertical', cmap='Blues')
        ax.set_title(f"Confusion Matrix — {name}")
        plt.tight_layout()
        plt.show()

    children.append(out)
    titles.append(name)

tabs.children = children
for i, title in enumerate(titles):
    tabs.set_title(i, title)

display(tabs)

df_downsampled_results = results

SMOTE used k_neighbors=47
After SMOTE: Counter({4: 48, 1: 48, 9: 48, 0: 48, 6: 48, 12: 48, 7: 48, 8: 48, 16: 48, 17: 48, 13: 48, 2: 48, 11: 48, 14: 48, 15: 48, 5: 48, 10: 48, 3: 48})
=== Logistic Regression ===
              precision    recall  f1-score   support

  astringent       0.14      0.17      0.15        12
    diuretic       0.20      0.17      0.18        12
   febrifuge       0.28      0.42      0.33        12
 hypotensive       0.26      0.58      0.36        12
    laxative       0.00      0.00      0.00        12
  miscellany       0.00      0.00      0.00        12
  ophthalmic       0.00      0.00      0.00        12
    pectoral       0.07      0.08      0.08        12
    poultice       0.00      0.00      0.00        12
   purgative       0.22      0.17      0.19        12
       salve       0.00      0.00      0.00        12
    sedative       0.14      0.33      0.20        12
        skin       0.06      0.08      0.07        12
   stimulant       0.21      0.2
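`GradientBoostingClassifier` has no `class_weight` parameter, as the comment in the model dictionary notes, but its `fit` method accepts `sample_weight`. A minimal sketch of the usual workaround on synthetic imbalanced data: `compute_sample_weight('balanced', y)` gives each row the balanced weight of its class, so every class contributes the same total weight to the loss.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.utils.class_weight import compute_sample_weight

X, y = make_classification(n_samples=400, n_features=6, n_informative=4,
                           n_classes=3, weights=[0.7, 0.2, 0.1], random_state=0)

# One weight per row: n_samples / (n_classes * count of that row's class)
sample_weight = compute_sample_weight(class_weight='balanced', y=y)

gb = GradientBoostingClassifier(random_state=0)
gb.fit(X, y, sample_weight=sample_weight)
print(gb.predict(X[:5]))
```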


In [None]:
#####################################
# VISUALIZATION AND INTERACTIVE WIDGETS
#####################################
def extract_classification_metrics(results_dict, eval_set_name):
    rows = []
    for model_name, result in results_dict.items():
        report = result['report']
        for cls, metrics in report.items():
            # Skip summary rows
            if cls in ['accuracy', 'macro avg', 'weighted avg']:
                continue
            rows.append({
                'Model': model_name,
                'Metric': 'F1',
                'Class': cls,
                'Value': metrics.get('f1-score', 0.0),
                'Set': eval_set_name
            })
            rows.append({
                'Model': model_name,
                'Metric': 'Precision',
                'Class': cls,
                'Value': metrics.get('precision', 0.0),
                'Set': eval_set_name
            })
            rows.append({
                'Model': model_name,
                'Metric': 'Recall',
                'Class': cls,
                'Value': metrics.get('recall', 0.0),
                'Set': eval_set_name
            })
    df_metrics = pd.DataFrame(rows)
    if df_metrics.empty:
        print("🚨 WARNING: Extracted metrics dataframe is empty!")
    return df_metrics

# 1. Use the results dictionaries from the two evaluation runs above
smote_results = df_smote_results
downsampled_results = df_downsampled_results

# 2. Extract performance metrics into tidy DataFrames
df_smote_metrics = extract_classification_metrics(smote_results, 'SMOTE')
df_downsampled_metrics = extract_classification_metrics(downsampled_results, 'Downsampled')

# 3. Combine them for plotting
df_viz = pd.concat([df_smote_metrics, df_downsampled_metrics], ignore_index=True)

# Create widgets
metric_dropdown = widgets.Dropdown(
    options=['F1', 'Precision', 'Recall'],
    value='F1',
    description='Metric:'
)
# Dynamically pull all class names from df_viz
unique_classes = sorted(df_viz['Class'].unique())

class_dropdown = widgets.Dropdown(
    options=unique_classes,
    value=unique_classes[0],
    description='Class:'
)

def update_plot(metric, target_class):
    filtered = df_viz[
        (df_viz['Metric'] == metric) &
        (df_viz['Class'] == target_class)
    ]
    fig = px.bar(
        filtered,
        x='Model',
        y='Value',
        color='Set',
        barmode='group',
        title=f"{metric} Score Comparison - Class {target_class}"
    )
    fig.update_layout(yaxis=dict(range=[0, 1]))
    # save out
    filename = f"{metric}_{target_class}_comparison.html"
    fig.write_html(filename, include_plotlyjs='cdn')
    print(f"▶️ Saved bar chart to {filename}")
    fig.show()

out = widgets.interactive_output(update_plot, {
    'metric': metric_dropdown,
    'target_class': class_dropdown
})

# Display widgets and output
display(widgets.HBox([metric_dropdown, class_dropdown]), out)

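`extract_classification_metrics` builds the long-format table row by row. A more compact sketch of the same reshaping with pandas is shown below; the labels are hypothetical, and it assumes the structure of `classification_report(..., output_dict=True)`, where each class maps to a dict of metrics and the summary keys (`'accuracy'`, `'macro avg'`, `'weighted avg'`) are filtered out.

```python
import pandas as pd
from sklearn.metrics import classification_report

# Hypothetical true and predicted labels
y_true = ['tonic', 'tonic', 'skin', 'skin', 'salve']
y_pred = ['tonic', 'skin', 'skin', 'skin', 'tonic']
report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)

# Keep only per-class entries (each one is a dict of metrics); rows become classes
per_class = pd.DataFrame({
    cls: m for cls, m in report.items()
    if isinstance(m, dict) and cls not in ('macro avg', 'weighted avg')
}).T

# Rename to the metric labels used in the dropdown, then melt to long format
tidy = (per_class.rename(columns={'precision': 'Precision', 'recall': 'Recall', 'f1-score': 'F1'})
        [['F1', 'Precision', 'Recall']]
        .rename_axis('Class')
        .reset_index()
        .melt(id_vars='Class', var_name='Metric', value_name='Value'))
print(tidy)
```

Adding a `'Model'` and `'Set'` column to `tidy` would give the same shape as `df_viz`.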

In [None]:
#######################################
# ROC CURVE VISUALIZATION
#######################################
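# 1) Extract the trained XGBoost model and the matching test set.
# X_test / y_test now hold the down-sampling split (the last one created),
# so use the model trained on that split; the SMOTE model's training rows
# very likely overlap this test set.
xgb_model = df_downsampled_results['XGBoost']['model']
y_true    = y_test                            # integer-encoded labels
classes   = label_encoder.classes_            # array of string labels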

# 2) Get probability estimates
y_proba = xgb_model.predict_proba(X_test)     # shape (n_samples, n_classes)

# 3) Compute ROC + AUC for each class
roc_data = {}
for idx, name in enumerate(classes):
    y_true_bin = (y_true == idx).astype(int)
    y_score    = y_proba[:, idx]
    fpr, tpr, _ = roc_curve(y_true_bin, y_score)
    roc_data[name] = (fpr, tpr, auc(fpr, tpr))

# 4) Filter to only those with AUC > 0.7
auc_threshold    = 0.7
filtered_roc_data = {
    name: data 
    for name, data in roc_data.items() 
    if data[2] > auc_threshold
}
print(f"Plotting {len(filtered_roc_data)} classes with AUC > {auc_threshold}:")
print(list(filtered_roc_data.keys()))

# 5) Build an interactive ROC plot over the filtered classes
import plotly.graph_objs as go

# Create the figure
fig = go.Figure()

# add the diagonal
fig.add_trace(go.Scatter(
    x=[0,1], y=[0,1],
    mode='lines',
    line=dict(color='black', dash='dash'),
    showlegend=False,
    hoverinfo='none'
))

# add one ROC line per class
for name, (fpr, tpr, score) in filtered_roc_data.items():
    fig.add_trace(go.Scatter(
        x=fpr, y=tpr,
        mode='lines',
        name=f"{name} (AUC={score:.2f})",
        line=dict(width=2),
        opacity=0.7,
        hoverinfo='name+x+y'
    ))

# Update layout: clicking a legend entry isolates that class's curve

fig.update_layout(
    title=f"ROC Curves (AUC > {auc_threshold:.2f})",
    xaxis_title="False Positive Rate",
    yaxis_title="True Positive Rate",
    legend=dict(
        title="Click to isolate trace",
        orientation="h",
        x=0, y=-0.1,
        itemclick='toggleothers',
        itemdoubleclick='toggle',
    ),
    margin=dict(l=50, r=50, t=50, b=100),
    width=800, height=600,
    clickmode='none'
)

# ▶️ Save to HTML:
roc_filename = f"roc_curves_auc_above_{int(auc_threshold*100)}.html"
fig.write_html(roc_filename, include_plotlyjs='cdn')
print(f"▶️ Saved ROC curves to {roc_filename}")

fig.show()

Plotting 18 classes with AUC > 0.7:
['astringent', 'diuretic', 'febrifuge', 'hypotensive', 'laxative', 'miscellany', 'ophthalmic', 'pectoral', 'poultice', 'purgative', 'salve', 'sedative', 'skin', 'stimulant', 'stomachic', 'styptic', 'tonic', 'vermifuge']
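The loop above computes a one-vs-rest ROC curve per class. scikit-learn can also summarize these curves in a single number: `roc_auc_score(..., multi_class='ovr', average='macro')` is the unweighted mean of the per-class one-vs-rest AUCs. The sketch below checks that equivalence on synthetic four-class data with a logistic regression; the scores are computed on the training data purely to illustrate the calculation.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_auc_score, roc_curve

X, y = make_classification(n_samples=500, n_features=8, n_informative=5,
                           n_classes=4, random_state=0)
proba = LogisticRegression(max_iter=1000).fit(X, y).predict_proba(X)

# Per-class one-vs-rest AUC, the same construction as the ROC loop above
per_class = []
for k in range(proba.shape[1]):
    fpr, tpr, _ = roc_curve((y == k).astype(int), proba[:, k])
    per_class.append(auc(fpr, tpr))

# Macro-averaged one-vs-rest AUC in one call
macro_ovr = roc_auc_score(y, proba, multi_class='ovr', average='macro')
print(np.round(per_class, 3), round(macro_ovr, 3))
```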
