## Introduction

In this notebook, I compare several aggregation strategies for combining replicate-level morphological profiles into compound-level features. Cell Painting experiments often include multiple replicates per compound, and the method of aggregation can significantly affect downstream model performance.

The main goal is to evaluate how different aggregation techniques influence the accuracy and generalization ability of MoA classification models.

Specifically, I assess the following strategies:
- Arithmetic mean
- Geometric mean
- Arithmetic–geometric mean (AGM)
- Selection of the closest replicate to each mean-based reference

Each approach is applied to the same underlying data, and the resulting models are compared using accuracy and macro F1-score. The objective is to identify the most effective and reliable method for morphological profile summarization in multimodal pipelines.
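The notebook loads CSV files that were aggregated elsewhere, so the aggregation code is not shown here. As a rough illustration only, the arithmetic and geometric means could be computed per compound with a pandas `groupby`. The column name `Metadata_compound` is a hypothetical stand-in for whatever identifies a compound. The geometric mean assumes strictly positive feature values; normalized Cell Painting features often take negative values, so a shift or another transform would be needed in practice.

```python
import numpy as np
import pandas as pd

def aggregate_replicates(df: pd.DataFrame, group_col: str, feature_cols: list) -> dict:
    """Collapse replicate-level rows into one row per compound.

    Assumes `group_col` (hypothetical name) identifies the compound and that all
    feature values are strictly positive, which the geometric mean requires.
    """
    arithmetic = df.groupby(group_col)[feature_cols].mean()
    # Geometric mean computed in the log domain: exp(mean(log(x)))
    geometric = np.exp(np.log(df[feature_cols]).groupby(df[group_col]).mean())
    return {"mean": arithmetic, "geometric_mean": geometric}

# Toy example: two compounds with two replicates each
toy = pd.DataFrame({
    "Metadata_compound": ["A", "A", "B", "B"],
    "morphology_area": [1.0, 4.0, 2.0, 8.0],
})
aggs = aggregate_replicates(toy, "Metadata_compound", ["morphology_area"])
print(aggs["mean"])            # A: 2.5, B: 5.0
print(aggs["geometric_mean"])  # A ≈ 2.0, B ≈ 4.0
```

In the toy example the geometric mean is always at or below the arithmetic mean. This is the AM–GM inequality, and it is one reason the two aggregations can yield different feature distributions.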

In [None]:
import os
import glob
from datetime import date

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# paths to data
save_path = "result/"

def load_latest_file(suffix: str) -> pd.DataFrame:
    """Load the most recently modified CSV in `save_path` whose name ends in '<digit>_<suffix>'."""
    pattern = os.path.join(save_path, '*[0-9]_' + suffix)
    files = glob.glob(pattern)
    if not files:
        raise FileNotFoundError(f"No file matches {pattern}")

    # pick the most recently modified file
    latest_file = max(files, key=os.path.getmtime)

    return pd.read_csv(latest_file)

# load the arithmetic-mean-aggregated dataset
df_merged_exp_mean = load_latest_file('merged_exp_mean.csv')

Next, we extract the column groups the pipeline needs:

In [27]:
metadata_cols = [col for col in df_merged_exp_mean.columns if col.startswith('Metadata_')]
binary_cols = [col for col in df_merged_exp_mean.columns if col.startswith('binary_')]
chemical_cols = [col for col in df_merged_exp_mean.columns if col.startswith('chemical_')]
moa_cols = [col for col in df_merged_exp_mean.columns if col.startswith('moa_')]
drug_status_cols = [col for col in df_merged_exp_mean.columns if col.startswith('drug_status_')]
fingerprints_cols = [col for col in df_merged_exp_mean.columns if col.startswith('fp_')]
morphology_cols = [col for col in df_merged_exp_mean.columns if col.startswith('morphology_')]

In [28]:
moa_counts = df_merged_exp_mean[moa_cols].sum().sort_values(ascending=False)
top_moa = moa_counts[moa_counts > 100].index.tolist()
top_moa

['moa_inhibitor', 'moa_antagonist', 'moa_agonist']
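Later cells collapse these one-hot columns into a single label with `idxmax(axis=1)`. That call has two silent edge cases. A row with none of the `top_moa` flags set receives the first column's label (`moa_inhibitor`). A row with several flags set receives the first positive column in list order. The notebook does not show how many rows fall into either case.

A stricter alternative is sketched below. `single_label_targets` is a hypothetical helper, not part of the pipeline. It keeps only rows with exactly one positive flag.

```python
import pandas as pd

def single_label_targets(df: pd.DataFrame, target_cols: list) -> pd.Series:
    """Hypothetical helper: collapse one-hot MoA columns into one label per row,
    keeping only rows with exactly one positive flag among `target_cols`."""
    flags = df[target_cols].fillna(0).astype(int)
    keep = flags.sum(axis=1) == 1  # drop unlabeled and multi-label rows
    return flags.loc[keep].idxmax(axis=1)

# Toy one-hot table: row 2 has no top MoA, row 3 has two
cols = ['moa_inhibitor', 'moa_antagonist', 'moa_agonist']
toy = pd.DataFrame([[1, 0, 0], [0, 1, 0], [0, 0, 0], [1, 1, 0]], columns=cols)

print(toy[cols].idxmax(axis=1).tolist())
# ['moa_inhibitor', 'moa_antagonist', 'moa_inhibitor', 'moa_inhibitor']
print(single_label_targets(toy, cols).to_dict())
# {0: 'moa_inhibitor', 1: 'moa_antagonist'}
```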

We will use our previous pipeline:

In [29]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from sklearn.metrics import precision_recall_curve, average_precision_score, roc_curve, auc, roc_auc_score
from sklearn.preprocessing import StandardScaler

class MultimodalMoAPipeline:
    """
    A pipeline for multimodal classification tasks using morphological, chemical,
    and fingerprint features.

    Parameters
    ----------
    morph_cols : list of str, default=[]
        List of column names representing morphological features.

    chem_cols : list of str, default=[]
        List of column names representing chemical descriptors.

    fp_cols : list of str, default=[]
        List of column names representing fingerprint features.

    use_morph : bool, default=True
        Whether to include morphological features in the model.

    use_chem : bool, default=True
        Whether to include chemical features in the model.

    use_fp : bool, default=True
        Whether to include fingerprint features in the model.

    scaler : str or None, default='standard'
        Type of scaler to apply to features. Options:
            - 'standard': StandardScaler from sklearn
            - None or any other value: no scaling will be applied

        Note: Some models like CatBoost do not require scaling.

    model : sklearn-like classifier, default=None
        A scikit-learn compatible classifier. If None, defaults to
        RandomForestClassifier with predefined parameters.

    random_state : int, default=42
        Random seed for reproducibility.

    use_gridsearch : bool, default=False
        If True, tune a RandomForestClassifier with GridSearchCV (scored by
        macro F1) and use the best estimator; the `model` argument is ignored.
    """
    def __init__(self, morph_cols=[], chem_cols=[], fp_cols=[],
                 use_morph=True, use_chem=True, use_fp=True,
                 scaler='standard', model=None, random_state=42,
                 use_gridsearch=False):
        self.morph_cols = morph_cols
        self.chem_cols = chem_cols
        self.fp_cols = fp_cols
        self.use_morph = use_morph
        self.use_chem = use_chem
        self.use_fp = use_fp
        
        self.scaler_type = scaler
        self.random_state = random_state
        self.use_gridsearch = use_gridsearch
        
        self.model = model if model is not None else RandomForestClassifier(n_estimators=200, random_state=random_state, class_weight='balanced', min_samples_leaf=3)

    def _get_feature_set(self, df):
        cols = []
        if self.use_morph:
            cols += self.morph_cols
        if self.use_chem:
            cols += self.chem_cols
        if self.use_fp:
            cols += self.fp_cols
        return df[cols].copy()

    def _scale(self, X_train, X_test):
        """Fit the scaler on the training split only, then apply it to both splits."""
        if self.scaler_type != 'standard':
            return X_train, X_test  # no scaling
        self.scaler = StandardScaler()
        X_train_scaled = pd.DataFrame(self.scaler.fit_transform(X_train),
                                      columns=X_train.columns, index=X_train.index)
        X_test_scaled = pd.DataFrame(self.scaler.transform(X_test),
                                     columns=X_test.columns, index=X_test.index)
        return X_train_scaled, X_test_scaled

    def fit(self, df, target_col):
        X = self._get_feature_set(df)

        # If target_col is a list of one-hot columns, collapse it into a single
        # multiclass label with idxmax. Caveat: a row with no positive column is
        # assigned the first column, and a row with several positives gets the
        # first positive column in list order.
        if isinstance(target_col, list):
            y = df[target_col].idxmax(axis=1)
        else:
            y = df[target_col]

        # Split before scaling so that test-set statistics do not leak into the scaler
        X_train, X_test, self.y_train, self.y_test = train_test_split(
            X, y, stratify=y, test_size=0.2, random_state=self.random_state
        )
        self.X_train, self.X_test = self._scale(X_train, X_test)

        if self.use_gridsearch:
            # Tunes a RandomForestClassifier; the `model` passed to __init__ is replaced
            param_grid = {
                'n_estimators': [100, 200],
                'max_depth': [None, 10, 20],
                'min_samples_leaf': [1, 3, 5],
                'class_weight': ['balanced']
            }
            base_model = RandomForestClassifier(random_state=self.random_state)
            cv_strategy = StratifiedKFold(n_splits=3, shuffle=True, random_state=self.random_state)
            grid = GridSearchCV(base_model, param_grid, scoring='f1_macro', cv=cv_strategy, n_jobs=-1)
            grid.fit(self.X_train, self.y_train)
            print("Best params from GridSearchCV:", grid.best_params_)
            self.model = grid.best_estimator_
        else:
            self.model.fit(self.X_train, self.y_train)

        self.y_pred = self.model.predict(self.X_test)


    def evaluate(self, show_plots=True):
        acc = accuracy_score(self.y_test, self.y_pred)
        # zero_division=0 scores never-predicted classes as 0 instead of emitting warnings
        f1 = f1_score(self.y_test, self.y_pred, average='macro', zero_division=0)

        print(f"\n🎯 Accuracy: {acc:.4f}")
        print(f"🎯 Macro F1-score: {f1:.4f}\n")
        print("Classification Report:\n")
        print(classification_report(self.y_test, self.y_pred, zero_division=0))
        
        if not show_plots:
            return

        plt.figure(figsize=(10, 8))
        cm = confusion_matrix(self.y_test, self.y_pred, labels=self.model.classes_)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=self.model.classes_, yticklabels=self.model.classes_)
        plt.title("Confusion Matrix")
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.tight_layout()
        plt.show()
        
        # PR and ROC curves only for binary classification
        if len(self.model.classes_) == 2:
            # predict_proba columns follow model.classes_, so column 1 scores classes_[1];
            # pos_label must be given explicitly when labels are strings (e.g. 'moa_inhibitor')
            pos_label = self.model.classes_[1]
            y_proba = self.model.predict_proba(self.X_test)[:, 1]

            # Precision-Recall Curve
            precision, recall, _ = precision_recall_curve(self.y_test, y_proba, pos_label=pos_label)
            avg_precision = average_precision_score(self.y_test, y_proba, pos_label=pos_label)

            plt.figure(figsize=(8, 6))
            plt.plot(recall, precision, marker='.')
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title(f'Precision-Recall Curve (AP = {avg_precision:.4f})')
            plt.grid()
            plt.tight_layout()
            plt.show()

            # ROC Curve (AUC computed from the same curve, so it uses the same positive class)
            fpr, tpr, _ = roc_curve(self.y_test, y_proba, pos_label=pos_label)
            roc_auc = auc(fpr, tpr)

            plt.figure(figsize=(8, 6))
            plt.plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc:.4f})')
            plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('ROC Curve')
            plt.legend(loc='lower right')
            plt.grid()
            plt.tight_layout()
            plt.show()

    def plot_importance(self, top_n=30):
        if not hasattr(self.model, 'feature_importances_'):
            print("This model does not support feature importances.")
            return

        feat_imp = pd.DataFrame({
            'feature': self.X_train.columns,
            'importance': self.model.feature_importances_
        })

        # Figure out the feature groups
        def get_group(feature):
            if feature in self.morph_cols:
                return 'morphology'
            elif feature in self.chem_cols:
                return 'chemistry'
            elif feature in self.fp_cols:
                return 'fingerprint'
            else:
                return 'other'

        feat_imp['group'] = feat_imp['feature'].apply(get_group)

        # Group by feature group and sum importances
        grouped = feat_imp.groupby('group')['importance'].sum().sort_values(ascending=False)
        print("\n📊 Feature Importance by Group:")
        print(grouped)

        # Sort and select top_n features
        feat_imp = feat_imp.sort_values(by='importance', ascending=False).head(top_n)

        plt.figure(figsize=(12, 8))
        sns.barplot(data=feat_imp, x='importance', y='feature', hue='group', dodge=False, palette='viridis')
        plt.title(f"Top {top_n} Feature Importances by Group")
        plt.tight_layout()
        plt.show()

        return feat_imp
    

Our baseline model is a `CatBoostClassifier` with balanced class weights:

In [30]:
from catboost import CatBoostClassifier
from sklearn.utils.class_weight import compute_class_weight

X = df_merged_exp_mean[morphology_cols + chemical_cols + fingerprints_cols]
y = df_merged_exp_mean[top_moa].idxmax(axis=1)

# Reproduce the pipeline's split (same stratify, test_size and random_state)
# so that class weights are computed from the training labels only
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)

classes = y_train.unique()
weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
class_weights = dict(zip(classes, weights))

cat_boost_model = CatBoostClassifier(
    iterations=300,
    depth=6,
    learning_rate=0.1,
    loss_function='MultiClass',
    class_weights=class_weights,
    eval_metric='TotalF1',
    random_seed=42,
    verbose=50,
)
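`class_weight='balanced'` in scikit-learn assigns each class the weight `n_samples / (n_classes * count_of_class)`. Rare classes therefore count more in the loss. The toy check below uses made-up labels, not the notebook's data. It confirms that `compute_class_weight` matches this formula.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Toy imbalanced labels: 8 inhibitors, 1 agonist, 1 antagonist
y_toy = np.array(['inhibitor'] * 8 + ['agonist'] + ['antagonist'])
toy_classes = np.unique(y_toy)
toy_weights = compute_class_weight(class_weight='balanced', classes=toy_classes, y=y_toy)

# Same weights by hand: n_samples / (n_classes * count_of_class)
counts = np.array([(y_toy == c).sum() for c in toy_classes])
manual = len(y_toy) / (len(toy_classes) * counts)

for c, w in zip(toy_classes, toy_weights):
    print(f"{c}: {w:.3f}")  # agonist and antagonist: 3.333, inhibitor: 0.417
```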

In [31]:
cat_boost_pipe = MultimodalMoAPipeline(
    morph_cols=morphology_cols,
    chem_cols=chemical_cols,
    fp_cols=fingerprints_cols,
    use_gridsearch=False,
    model=cat_boost_model
)

Now we will test our pipeline on arithmetic mean aggregation:

In [32]:
cat_boost_pipe.fit(df_merged_exp_mean, target_col=top_moa)
cat_boost_pipe.evaluate(show_plots=False)

0:	learn: 0.5930239	total: 64.2ms	remaining: 19.2s
50:	learn: 0.9769683	total: 1.75s	remaining: 8.55s
100:	learn: 0.9940559	total: 3.37s	remaining: 6.63s
150:	learn: 1.0000000	total: 4.83s	remaining: 4.77s
200:	learn: 1.0000000	total: 6.26s	remaining: 3.08s
250:	learn: 1.0000000	total: 7.67s	remaining: 1.5s
299:	learn: 1.0000000	total: 9.05s	remaining: 0us

🎯 Accuracy: 0.8795
🎯 Macro F1-score: 0.3535

Classification Report:

                precision    recall  f1-score   support

   moa_agonist       0.50      0.07      0.12        14
moa_antagonist       0.00      0.00      0.00        12
 moa_inhibitor       0.89      0.99      0.94       198

      accuracy                           0.88       224
     macro avg       0.46      0.35      0.35       224
  weighted avg       0.82      0.88      0.83       224



Now we will test our pipeline on geometric mean aggregation:

In [None]:
df_merged_exp_geometric_mean = load_latest_file('merged_exp_geometric_mean.csv')

cat_boost_pipe.fit(df_merged_exp_geometric_mean, target_col=top_moa)
cat_boost_pipe.evaluate(show_plots=False)

0:	learn: 0.5081643	total: 47.7ms	remaining: 14.3s
50:	learn: 0.9704265	total: 1.43s	remaining: 6.98s
100:	learn: 0.9953284	total: 3.08s	remaining: 6.08s
150:	learn: 0.9995759	total: 4.65s	remaining: 4.59s
200:	learn: 1.0000000	total: 6.08s	remaining: 2.99s
250:	learn: 1.0000000	total: 7.52s	remaining: 1.47s
299:	learn: 1.0000000	total: 8.9s	remaining: 0us

🎯 Accuracy: 0.8616
🎯 Macro F1-score: 0.3086

Classification Report:

                precision    recall  f1-score   support

   moa_agonist       0.00      0.00      0.00        14
moa_antagonist       0.00      0.00      0.00        12
 moa_inhibitor       0.88      0.97      0.93       198

      accuracy                           0.86       224
     macro avg       0.29      0.32      0.31       224
  weighted avg       0.78      0.86      0.82       224



Now we will test our pipeline on geometric mean aggregation with selection of the closest replicate to each mean-based reference:
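The notebook does not show how the closest replicate was chosen. One plausible reading works in two steps. First, compute a reference profile per compound, such as the geometric mean. Then keep the real replicate with the smallest Euclidean distance to that reference.

The sketch below follows that reading. Euclidean distance and the column name `Metadata_compound` are both assumptions. For simplicity, the toy example uses the arithmetic mean as the reference; any per-compound reference with the same columns would work.

```python
import numpy as np
import pandas as pd

def closest_to_reference(df, group_col, feature_cols, reference):
    """For each compound, return the replicate row nearest to its reference profile.

    `reference` is a DataFrame indexed by compound with the same feature columns.
    Euclidean distance is an assumption, not necessarily what the notebook used.
    """
    ref_rows = reference.loc[df[group_col].to_numpy(), feature_cols].to_numpy()
    dist = np.linalg.norm(df[feature_cols].to_numpy() - ref_rows, axis=1)
    # Index label of the minimum-distance replicate within each compound
    best_idx = pd.Series(dist, index=df.index).groupby(df[group_col]).idxmin()
    return df.loc[best_idx.to_numpy()].set_index(group_col)

toy = pd.DataFrame({
    "Metadata_compound": ["A", "A", "A", "B", "B", "B"],
    "morphology_area": [1.0, 2.0, 4.0, 2.0, 8.0, 6.0],
    "morphology_shape": [1.0, 1.0, 1.0, 3.0, 5.0, 4.0],
})
features = ["morphology_area", "morphology_shape"]
ref = toy.groupby("Metadata_compound")[features].mean()
result = closest_to_reference(toy, "Metadata_compound", features, ref)
print(result)  # A keeps the (2.0, 1.0) replicate, B keeps (6.0, 4.0)
```

Unlike the averaged profiles, this output is always a real, observed replicate.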

In [34]:
df_merged_exp_closest_geometric_mean = load_latest_file('merged_exp_closest_geometric_mean.csv')

cat_boost_pipe.fit(df_merged_exp_closest_geometric_mean, target_col=top_moa)
cat_boost_pipe.evaluate(show_plots=False)

0:	learn: 0.6084567	total: 45.6ms	remaining: 13.6s
50:	learn: 0.9790918	total: 1.47s	remaining: 7.19s
100:	learn: 0.9961783	total: 2.86s	remaining: 5.64s
150:	learn: 0.9995759	total: 4.28s	remaining: 4.22s
200:	learn: 1.0000000	total: 5.74s	remaining: 2.83s
250:	learn: 1.0000000	total: 7.3s	remaining: 1.42s
299:	learn: 1.0000000	total: 8.69s	remaining: 0us

🎯 Accuracy: 0.8839
🎯 Macro F1-score: 0.3128

Classification Report:

                precision    recall  f1-score   support

   moa_agonist       0.00      0.00      0.00        14
moa_antagonist       0.00      0.00      0.00        12
 moa_inhibitor       0.88      1.00      0.94       198

      accuracy                           0.88       224
     macro avg       0.29      0.33      0.31       224
  weighted avg       0.78      0.88      0.83       224





Now we will test our pipeline on arithmetic–geometric mean aggregation:
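The classical arithmetic–geometric mean is defined for two positive numbers. It repeatedly applies `a ← (a + g) / 2` and `g ← sqrt(a · g)` until the two values converge, and the limit always lies between the geometric and arithmetic mean.

The notebook does not show how this was extended to many replicates. The sketch below makes one assumption: it seeds the iteration, feature by feature, with the arithmetic and geometric means of each compound's replicates. Both `agm` and `agm_of_replicates` are illustrative helpers.

```python
import numpy as np

def agm(a, g, tol=1e-12, max_iter=100):
    """Element-wise arithmetic-geometric mean of two positive inputs.

    Iterates a <- (a + g) / 2, g <- sqrt(a * g) until the two sequences agree.
    """
    a = np.asarray(a, dtype=float)
    g = np.asarray(g, dtype=float)
    for _ in range(max_iter):
        a, g = (a + g) / 2.0, np.sqrt(a * g)
        if np.all(np.abs(a - g) <= tol * np.maximum(np.abs(a), 1.0)):
            break
    return (a + g) / 2.0

def agm_of_replicates(replicates):
    """Hypothetical replicate-level AGM: seed with the column-wise arithmetic and
    geometric means of a (n_replicates, n_features) array of positive values."""
    x = np.asarray(replicates, dtype=float)
    arithmetic = x.mean(axis=0)
    geometric = np.exp(np.log(x).mean(axis=0))
    return agm(arithmetic, geometric)

reps = np.array([[1.0, 2.0], [4.0, 8.0]])  # 2 replicates x 2 features
print(agm_of_replicates(reps))  # each value lies between the geometric and arithmetic mean
```

The iteration converges quadratically, so a handful of steps is enough in practice.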

In [35]:
df_merged_exp_agm = load_latest_file('merged_exp_agm.csv')

cat_boost_pipe.fit(df_merged_exp_agm, target_col=top_moa)
cat_boost_pipe.evaluate(show_plots=False)

0:	learn: 0.5849560	total: 42.5ms	remaining: 12.7s
50:	learn: 0.9743312	total: 1.51s	remaining: 7.37s
100:	learn: 0.9974533	total: 3.15s	remaining: 6.2s
150:	learn: 0.9995759	total: 4.66s	remaining: 4.59s
200:	learn: 1.0000000	total: 6.13s	remaining: 3.02s
250:	learn: 1.0000000	total: 7.71s	remaining: 1.5s
299:	learn: 1.0000000	total: 9.11s	remaining: 0us

🎯 Accuracy: 0.8661
🎯 Macro F1-score: 0.3606

Classification Report:

                precision    recall  f1-score   support

   moa_agonist       0.00      0.00      0.00        14
moa_antagonist       1.00      0.08      0.15        12
 moa_inhibitor       0.89      0.97      0.93       198

      accuracy                           0.87       224
     macro avg       0.63      0.35      0.36       224
  weighted avg       0.84      0.87      0.83       224



Now we will test our pipeline on arithmetic–geometric mean aggregation with selection of the closest replicate to each mean-based reference:

In [36]:
df_merged_exp_closest_agm = load_latest_file('merged_exp_closest_agm.csv')

cat_boost_pipe.fit(df_merged_exp_closest_agm, target_col=top_moa)
cat_boost_pipe.evaluate(show_plots=False)

0:	learn: 0.6147883	total: 33.1ms	remaining: 9.91s
50:	learn: 0.9821050	total: 1.89s	remaining: 9.22s
100:	learn: 0.9974533	total: 3.63s	remaining: 7.14s
150:	learn: 0.9991518	total: 5.11s	remaining: 5.04s
200:	learn: 1.0000000	total: 6.6s	remaining: 3.25s
250:	learn: 1.0000000	total: 8.1s	remaining: 1.58s
299:	learn: 1.0000000	total: 9.56s	remaining: 0us

🎯 Accuracy: 0.8839
🎯 Macro F1-score: 0.3128

Classification Report:

                precision    recall  f1-score   support

   moa_agonist       0.00      0.00      0.00        14
moa_antagonist       0.00      0.00      0.00        12
 moa_inhibitor       0.88      1.00      0.94       198

      accuracy                           0.88       224
     macro avg       0.29      0.33      0.31       224
  weighted avg       0.78      0.88      0.83       224





## Summary of results

As part of this investigation, I evaluated several strategies for aggregating replicate-level morphological profiles into compound-level features, using a **CatBoostClassifier** for MoA (mechanism of action) classification. The objective was to determine which aggregation approach leads to better predictive performance, especially for underrepresented classes.

### Aggregation Methods Compared

1. **Mean** (arithmetic mean)
2. **Geometric mean**
3. **Arithmetic–Geometric Mean (AGM)**
4. **Closest replicate to the geometric mean**
5. **Closest replicate to the arithmetic–geometric mean**

### Results Overview

| Aggregation Method            | Accuracy | Macro F1 | Notes |
|-------------------------------|----------|----------|-------|
| **Mean**                      | 0.8795   | 0.3535   | Second-best macro F1; only method with non-zero `agonist` recall (0.07) |
| **Geometric Mean**            | 0.8616   | 0.3086   | Lowest accuracy and macro F1; no minority-class sample classified correctly |
| **AGM**                       | 0.8661   | 0.3606   | Best macro F1; only method with non-zero `antagonist` recall (0.08) |
| **Closest to Geometric Mean** | 0.8839   | 0.3128   | Highest accuracy, equal to always predicting `inhibitor` (198/224) |
| **Closest to AGM**            | 0.8839   | 0.3128   | Identical to Closest to Geometric Mean; predicts only `inhibitor` |
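Why accuracy and macro F1 disagree follows from how macro F1 is computed. It is the unweighted mean of the per-class F1 scores, so each of the two minority classes counts as much as `moa_inhibitor`.

The snippet below recomputes macro F1 from the rounded per-class values in the classification reports above. It matches the table to within rounding. It also computes the accuracy of always predicting the majority class (198 of 224 test samples), which equals the two "Closest" rows exactly.

```python
# Per-class F1 (agonist, antagonist, inhibitor) copied from the rounded
# classification reports above, with the reported macro F1 for comparison
reports = {
    "Mean": ([0.12, 0.00, 0.94], 0.3535),
    "Geometric Mean": ([0.00, 0.00, 0.93], 0.3086),
    "AGM": ([0.00, 0.15, 0.93], 0.3606),
    "Closest to Geometric Mean": ([0.00, 0.00, 0.94], 0.3128),
    "Closest to AGM": ([0.00, 0.00, 0.94], 0.3128),
}

def macro_f1(per_class):
    """Macro F1 is the unweighted mean of per-class F1 scores."""
    return sum(per_class) / len(per_class)

for method, (per_class, reported) in reports.items():
    print(f"{method:26s} recomputed {macro_f1(per_class):.3f}  reported {reported:.4f}")

# Accuracy of always predicting the majority class (198 inhibitors out of 224 test samples)
majority_baseline = 198 / 224
print(f"Majority-class accuracy: {majority_baseline:.4f}")  # 0.8839
```

A model that is perfect on `moa_inhibitor` but never right on the other two classes therefore sits at roughly 0.94 / 3 ≈ 0.31 macro F1. This is close to where the geometric-mean and closest-replicate variants land.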

### Method-by-Method Analysis

#### Mean
- A strong baseline with **high accuracy (0.8795)** and a moderate **macro F1-score (0.3535)**.
- Performs well on the dominant class (`moa_inhibitor`, F1 0.94). It is the only method with non-zero recall on `moa_agonist` (0.07), but it still misses almost all minority-class samples.
- **Conclusion**: A reliable default aggregation; it outperforms the geometric mean and both closest-replicate variants on macro F1.

#### Geometric Mean
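- Performs worse than the arithmetic mean on both accuracy (0.8616 vs 0.8795) and macro F1 (0.3086 vs 0.3535).
- Does not correctly classify a single `moa_agonist` or `moa_antagonist` sample: precision and recall are 0.00 for both.
- **Conclusion**: Too conservative; averaging in the log domain may compress the signal that separates the minority classes.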

#### Arithmetic–Geometric Mean (AGM)
- Achieves the **highest macro F1-score (0.3606)** among all methods tested.
- It is the only method with **non-zero recall** on `moa_antagonist` (0.08, i.e. one of 12 test samples). However, it loses the single correct `moa_agonist` prediction that the arithmetic mean made.
- **Conclusion**: The most promising aggregation method when class balance matters.

#### Closest to Geometric Mean
- Achieves the **highest accuracy (0.8839)**, but that is exactly the accuracy of always predicting `moa_inhibitor` (198/224); inhibitor recall is 1.00.
- Precision and recall for both `moa_agonist` and `moa_antagonist` are 0.
- **Conclusion**: Collapses to the majority class; not suitable when class diversity matters.

#### Closest to AGM
- Produces exactly the same metrics as Closest to Geometric Mean: every test sample is predicted as `moa_inhibitor`.
- Choosing the replicate closest to the AGM reference profile does not recover any minority-class predictions either.
- **Conclusion**: Keeping a single real replicate, rather than an averaged profile, appears to discard information that the minority classes depend on.

### Final Recommendation

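The **Arithmetic–Geometric Mean (AGM)** stands out as the most effective aggregation strategy in this study. It does not achieve the highest accuracy, but it provides the best balance between sensitivity to minority classes and overall classification performance. This makes AGM the most suitable candidate for future modeling and benchmarking, particularly in settings with class imbalance.

That said, its margin over the arithmetic mean (0.3606 vs 0.3535 macro F1) rests on a single correctly classified minority sample in each case: one `agonist` for the mean, one `antagonist` for AGM. The result comes from one 224-sample test split, so the ranking should be confirmed over repeated splits before it is treated as definitive.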
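All results above come from a single stratified 80/20 split with only 14 `agonist` and 12 `antagonist` test samples. One or two predictions can therefore move macro F1 noticeably.

A more robust comparison would repeat stratified K-fold cross-validation and report the spread of macro F1 for each aggregation. The sketch below makes three substitutions. It uses a `RandomForestClassifier` in place of CatBoost, to keep dependencies to scikit-learn. It runs on synthetic imbalanced data rather than the notebook's files. `macro_f1_over_splits` is a hypothetical helper.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score

def macro_f1_over_splits(X, y, model, n_splits=5, n_repeats=3, seed=42):
    """Return macro F1 for every fold of a repeated stratified K-fold CV."""
    cv = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=seed)
    return cross_val_score(model, X, y, scoring="f1_macro", cv=cv)

# Synthetic, imbalanced three-class data standing in for one aggregated dataset
X_syn, y_syn = make_classification(
    n_samples=600, n_features=20, n_informative=6, n_classes=3,
    weights=[0.85, 0.08, 0.07], random_state=0,
)
rf = RandomForestClassifier(n_estimators=100, class_weight="balanced", random_state=0)
scores = macro_f1_over_splits(X_syn, y_syn, rf)
print(f"macro F1 = {scores.mean():.3f} ± {scores.std():.3f} over {len(scores)} folds")
```

Using the same `seed` for every aggregated file keeps the folds identical across methods. This assumes the files list compounds in the same order and with the same labels, so differences in the score distributions reflect the aggregation rather than the split.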

### Future Directions
- Investigating attention-based or learned aggregation schemes
- Using ensemble methods across multiple aggregation strategies
- Integrating biological priors (e.g., tissue type, cell origin) to weight contributions