# KAN for Chemistry: Interpretable Yield Prediction

This notebook uses Kolmogorov-Arnold Networks (KANs) to predict reaction yields for the well-known **Buchwald-Hartwig C-N cross-coupling** dataset from [Ahneman et al. (Science, 2018)](https://www.science.org/doi/10.1126/science.aar5169).

## Why KANs for Chemistry?

In chemistry, we often use **Multivariate Linear Regression (MLR)** because:
1. Interpretability matters - we want to understand *why* a reaction works
2. Datasets are often small (50-500 reactions)
3. Physical meaning of parameters is important

**KANs offer advantages over both MLR and black-box ML:**
- Like MLR: Interpretable (can see what each input contributes)
- Unlike MLR: Can capture nonlinear relationships
- Unlike standard neural networks (MLPs): The learnable activation functions sit on the edges and can be plotted directly
- The learned univariate functions may reveal physical relationships!
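
To make the "activation functions on edges" idea concrete, here is a minimal NumPy sketch of a single KAN layer. It is an illustration only, not the `pycse` implementation: the helper name `kan_layer_sketch` is hypothetical, and piecewise-linear interpolation (`np.interp`) stands in for the B-splines a real KAN would learn.

```python
import numpy as np

def kan_layer_sketch(x, grid, coef):
    """Toy KAN layer: output_j = sum_i phi_ij(x_i).

    x    : (n_samples, n_in) inputs
    grid : (n_grid,) shared knot positions
    coef : (n_in, n_out, n_grid) value of each edge function at the knots
    Each edge (i, j) has its own univariate function phi_ij. Here it is
    piecewise linear (an assumption; real KANs typically use B-splines).
    """
    n_samples, n_in = x.shape
    n_out = coef.shape[1]
    out = np.zeros((n_samples, n_out))
    for i in range(n_in):
        for j in range(n_out):
            out[:, j] += np.interp(x[:, i], grid, coef[i, j])
    return out

# Two inputs, one output: edge 0 is roughly quadratic, edge 1 is linear
grid = np.linspace(-2, 2, 5)
coef = np.zeros((2, 1, 5))
coef[0, 0] = grid ** 2
coef[1, 0] = 3 * grid

x_demo = np.array([[1.0, 0.5], [-2.0, 0.0]])
y_demo = kan_layer_sketch(x_demo, grid, coef)
print(y_demo)  # [[2.5], [4.0]]
```

Because every edge function depends on a single input, each one can be plotted as an ordinary 1D curve, which is what the interpretability sections of this notebook rely on.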

## The Dataset

The Buchwald-Hartwig dataset explores Pd-catalyzed C-N cross-coupling:
- **4 Ligands**: Different phosphine ligands
- **22 Additives**: Heterocyclic additives (potential catalyst poisons)
- **3 Bases**: Different organic bases
- **15 Aryl halides**: Different substrates
- **Yield**: 0-100% (target variable)
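
Taking the component counts listed above at face value, the size of the full factorial design can be checked with a few lines of Python. The labels below are placeholders, not the actual molecules, and the real spreadsheet may contain fewer rows if some combinations were not run or were removed.

```python
from itertools import product

# Component counts as listed above; the labels are placeholders
n_components = {"Ligand": 4, "Additive": 22, "Base": 3, "Aryl halide": 15}
levels = [[f"{name}_{i}" for i in range(n)] for name, n in n_components.items()]

# Every reaction is one combination of ligand, additive, base and aryl halide
design = list(product(*levels))
n_reactions = len(design)
print(n_reactions)  # 3960
```

Since each reaction is one of these combinations, every ligand descriptor column can take at most 4 distinct values and every base descriptor column at most 3, which matters later when interpreting the fitted models.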

**References:**
- Ahneman et al. "Predicting reaction performance in C–N cross-coupling using machine learning" *Science* 2018
- [SigmanGroup](https://github.com/SigmanGroup) - Multivariate Linear Regression for catalysis
- [Paton Lab](https://patonlab.com/) - Data-driven chemistry tools

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error

# RDKit for molecular descriptors
from rdkit import Chem
from rdkit.Chem import Descriptors

# KAN from pycse
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
from pycse.sklearn.kan import KAN

import warnings
warnings.filterwarnings('ignore')

## 1. Load and Explore the Data

In [None]:
# Load the Buchwald-Hartwig dataset
# Download from: https://github.com/rxn4chemistry/rxn_yields

# For this demo, we'll download directly
import urllib.request
import os

data_url = "https://raw.githubusercontent.com/rxn4chemistry/rxn_yields/master/data/Buchwald-Hartwig/Dreher_and_Doyle_input_data.xlsx"
data_path = "/tmp/buchwald_hartwig.xlsx"

if not os.path.exists(data_path):
    print("Downloading Buchwald-Hartwig dataset...")
    urllib.request.urlretrieve(data_url, data_path)
    print("Done!")

# Load the data
df = pd.read_excel(data_path, sheet_name='Plates1-3')
print(f"Dataset shape: {df.shape}")
print(f"\nColumns: {list(df.columns)}")
print(f"\nYield statistics:")
print(df['Output'].describe())

In [None]:
# Visualize the yield distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Histogram
axes[0].hist(df['Output'], bins=50, edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Yield (%)')
axes[0].set_ylabel('Count')
axes[0].set_title('Distribution of Reaction Yields')
axes[0].axvline(df['Output'].mean(), color='red', linestyle='--', label=f'Mean: {df["Output"].mean():.1f}%')
axes[0].legend()

# Component counts
components = ['Ligand', 'Additive', 'Base', 'Aryl halide']
counts = [df[c].nunique() for c in components]
axes[1].bar(components, counts, color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'])
axes[1].set_ylabel('Number of Unique Components')
axes[1].set_title('Reaction Components')
for i, v in enumerate(counts):
    axes[1].text(i, v + 0.5, str(v), ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

## 2. Compute Molecular Descriptors

We use **RDKit** to compute molecular descriptors from SMILES strings.

For each reaction component (ligand, additive, base, aryl halide), we calculate:
- **MolWt**: Molecular weight
- **MolLogP**: Lipophilicity (octanol-water partition coefficient)
- **TPSA**: Topological polar surface area
- **NumRotatableBonds**: Flexibility measure
- **NumHAcceptors/Donors**: H-bonding capacity
- **NumAromaticRings**: Aromaticity
- **FractionCSP3**: Fraction sp3 carbons (3D character)

In [None]:
def get_mol_descriptors(smiles, prefix=''):
    """Calculate RDKit 2D descriptors for a molecule."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    
    # Selected descriptors relevant for reactivity
    desc_funcs = {
        'MolWt': Descriptors.MolWt,
        'MolLogP': Descriptors.MolLogP,
        'TPSA': Descriptors.TPSA,
        'NumRotatableBonds': Descriptors.NumRotatableBonds,
        'NumHAcceptors': Descriptors.NumHAcceptors,
        'NumHDonors': Descriptors.NumHDonors,
        'NumAromaticRings': Descriptors.NumAromaticRings,
        'FractionCSP3': Descriptors.FractionCSP3,
    }
    
    descriptors = {}
    for name, func in desc_funcs.items():
        try:
            descriptors[f"{prefix}{name}"] = func(mol)
        except Exception:  # descriptor failed for this molecule; record it as missing
            descriptors[f"{prefix}{name}"] = np.nan
    return descriptors

# Compute descriptors for each component
print("Computing molecular descriptors...")

# Create lookup tables for unique components
component_cols = ['Ligand', 'Additive', 'Base', 'Aryl halide']
prefixes = ['lig_', 'add_', 'base_', 'aryl_']

desc_lookup = {}
for col, prefix in zip(component_cols, prefixes):
    unique_smiles = df[col].unique()
    desc_lookup[col] = {}
    for smi in unique_smiles:
        desc = get_mol_descriptors(smi, prefix)
        if desc is not None:
            desc_lookup[col][smi] = desc
    print(f"  {col}: {len(desc_lookup[col])} molecules processed")

print("Done!")

In [None]:
# Build the feature matrix
def build_feature_row(row):
    """Build feature vector for a reaction."""
    features = {}
    for col, prefix in zip(component_cols, prefixes):
        smi = row[col]
        if smi in desc_lookup[col]:
            features.update(desc_lookup[col][smi])
        else:
            # Handle missing
            for name in ['MolWt', 'MolLogP', 'TPSA', 'NumRotatableBonds', 
                        'NumHAcceptors', 'NumHDonors', 'NumAromaticRings', 'FractionCSP3']:
                features[f"{prefix}{name}"] = np.nan
    return features

# Apply to all rows
feature_dicts = df.apply(build_feature_row, axis=1).tolist()
X_df = pd.DataFrame(feature_dicts)
y = df['Output'].values

# Remove rows with NaN
valid_mask = ~X_df.isna().any(axis=1)
X_df = X_df[valid_mask]
y = y[valid_mask]

print(f"Feature matrix shape: {X_df.shape}")
print(f"Feature columns: {list(X_df.columns)}")
print(f"\nValid samples: {len(y)} / {len(df)}")

## 3. Compare Models: MLR vs Random Forest vs KAN

We'll compare three approaches:
1. **Multivariate Linear Regression (MLR)** - Traditional approach in physical organic chemistry
2. **Random Forest** - What Ahneman et al. used (achieved R² ≈ 0.92 with DFT descriptors)
3. **KAN** - Our interpretable nonlinear model

In [None]:
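# Prepare data
X = X_df.values
feature_names = list(X_df.columns)

# Split first so the test set does not influence the scaling statistics
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Fit the scaler on the training set only, then apply it to both sets
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

# Full feature matrix in the same scaling (used later for partial-effect plots)
X_scaled = scaler.transform(X)

print(f"Training set: {len(X_train)} samples")
print(f"Test set: {len(X_test)} samples")
print(f"Number of features: {X_train.shape[1]}")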

In [None]:
# Train and evaluate models
results = {}

# 1. Linear Regression (MLR)
print("Training Multivariate Linear Regression...")
lr = LinearRegression()
lr.fit(X_train, y_train)
y_pred_lr = lr.predict(X_test)
results['MLR'] = {
    'model': lr,
    'r2': r2_score(y_test, y_pred_lr),
    'mae': mean_absolute_error(y_test, y_pred_lr),
    'y_pred': y_pred_lr
}
print(f"  R² = {results['MLR']['r2']:.4f}, MAE = {results['MLR']['mae']:.2f}%")

# 2. Random Forest
print("\nTraining Random Forest...")
rf = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=-1)
rf.fit(X_train, y_train)
y_pred_rf = rf.predict(X_test)
results['Random Forest'] = {
    'model': rf,
    'r2': r2_score(y_test, y_pred_rf),
    'mae': mean_absolute_error(y_test, y_pred_rf),
    'y_pred': y_pred_rf
}
print(f"  R² = {results['Random Forest']['r2']:.4f}, MAE = {results['Random Forest']['mae']:.2f}%")

# 3. KAN
print("\nTraining KAN...")
n_features = X_train.shape[1]
kan = KAN(
    layers=(n_features, 8, 1),  # Input -> 8 hidden -> 1 output
    grid_size=5,
    spline_order=3,
)
kan.fit(X_train, y_train, maxiter=300)
y_pred_kan = kan.predict(X_test)
results['KAN'] = {
    'model': kan,
    'r2': r2_score(y_test, y_pred_kan),
    'mae': mean_absolute_error(y_test, y_pred_kan),
    'y_pred': y_pred_kan
}
print(f"  R² = {results['KAN']['r2']:.4f}, MAE = {results['KAN']['mae']:.2f}%")

In [None]:
# Visualize predictions
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, (name, res) in zip(axes, results.items()):
    ax.scatter(y_test, res['y_pred'], alpha=0.3, s=10)
    ax.plot([0, 100], [0, 100], 'r--', linewidth=2, label='Perfect prediction')
    ax.set_xlabel('Actual Yield (%)')
    ax.set_ylabel('Predicted Yield (%)')
    ax.set_title(f"{name}\nR² = {res['r2']:.3f}, MAE = {res['mae']:.1f}%")
    ax.set_xlim(-5, 105)
    ax.set_ylim(-5, 105)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Model Comparison Summary

Let's compare the models in terms of performance and interpretability.
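
The two performance metrics used throughout are R² and mean absolute error (MAE). As a reminder of what they measure, here is a small NumPy sketch that should agree with `sklearn.metrics.r2_score` and `mean_absolute_error` for 1D targets. The function names and toy numbers are illustrative only.

```python
import numpy as np

def r2_sketch(y_true, y_pred):
    """R² = 1 - SS_res / SS_tot (fraction of variance explained)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def mae_sketch(y_true, y_pred):
    """Mean absolute error, in the same units as the target (% yield)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.mean(np.abs(y_true - y_pred))

# Toy yields (not from the dataset)
toy_true = np.array([10.0, 50.0, 90.0])
toy_pred = np.array([20.0, 50.0, 80.0])

toy_r2 = r2_sketch(toy_true, toy_pred)
toy_mae = mae_sketch(toy_true, toy_pred)
print(toy_r2)   # 0.9375
print(toy_mae)  # about 6.67
```

An R² of 0 corresponds to always predicting the mean yield, so a negative value means the model does worse than that baseline.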

In [None]:
# Summary table
print("=" * 60)
print("MODEL COMPARISON SUMMARY")
print("=" * 60)
print(f"{'Model':<20} {'R²':<10} {'MAE (%)':<10} {'Interpretable?'}")
print("-" * 60)

interpretability = {
    'MLR': 'Yes (linear coefs)',
    'Random Forest': 'Partial (feature importance)',
    'KAN': 'Yes (activation functions)'
}

for name, res in results.items():
    print(f"{name:<20} {res['r2']:<10.4f} {res['mae']:<10.2f} {interpretability[name]}")

print("=" * 60)

## 5. Interpretability: Understanding the Models

### 5.1 MLR: Linear Coefficients

In MLR, each coefficient tells us how much the predicted yield changes per one-standard-deviation change in that descriptor (the features are standardized), with all other descriptors held fixed.
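
One caveat when reading individual coefficients: each component has only a few distinct molecules, so its descriptor columns are strongly collinear and the coefficients are not uniquely determined. The synthetic sketch below (random numbers standing in for real descriptors) shows that 8 descriptor columns for 3 distinct bases span only 2 independent directions after centering.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in: 3 distinct bases, 8 descriptors each (not real values)
base_descriptors = rng.normal(size=(3, 8))

# 300 hypothetical reactions, each using one of the 3 bases
base_index = np.arange(300) % 3
X_base = base_descriptors[base_index]

# After centering, 3 distinct rows can span at most 2 directions
X_centered = X_base - X_base.mean(axis=0)
rank = np.linalg.matrix_rank(X_centered)
print(rank)  # 2, even though there are 8 columns
```

In practice this means several different coefficient vectors give the same predictions, so the sign and size of a single coefficient within one component should be interpreted with care.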

In [None]:
# MLR coefficients
lr_coefs = pd.DataFrame({
    'Feature': feature_names,
    'Coefficient': lr.coef_
}).sort_values('Coefficient', key=abs, ascending=False)

print("Top 10 MLR Coefficients (most influential features):")
print(lr_coefs.head(10).to_string(index=False))

# Plot
fig, ax = plt.subplots(figsize=(10, 6))
colors = ['green' if c > 0 else 'red' for c in lr_coefs['Coefficient']]
ax.barh(range(len(lr_coefs)), lr_coefs['Coefficient'], color=colors, alpha=0.7)
ax.set_yticks(range(len(lr_coefs)))
ax.set_yticklabels(lr_coefs['Feature'])
ax.set_xlabel('Coefficient (effect on yield per std. dev.)')
ax.set_title('MLR Coefficients: Linear Effects on Yield')
ax.axvline(0, color='black', linewidth=0.5)
ax.invert_yaxis()
plt.tight_layout()
plt.show()

### 5.2 Random Forest: Feature Importance

Random Forest reports impurity-based feature importances. These rank the features but give no insight into the direction or shape of each feature's effect on the outcome.

In [None]:
# Random Forest feature importance
rf_importance = pd.DataFrame({
    'Feature': feature_names,
    'Importance': rf.feature_importances_
}).sort_values('Importance', ascending=False)

print("Top 10 Random Forest Feature Importances:")
print(rf_importance.head(10).to_string(index=False))

# Plot
fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(range(len(rf_importance)), rf_importance['Importance'], color='steelblue', alpha=0.7)
ax.set_yticks(range(len(rf_importance)))
ax.set_yticklabels(rf_importance['Feature'])
ax.set_xlabel('Feature Importance')
ax.set_title('Random Forest: Feature Importances\n(No insight into direction or shape of effect)')
ax.invert_yaxis()
plt.tight_layout()
plt.show()

### 5.3 KAN: Learned Activation Functions

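**This is where KAN shines!** We can plot the nonlinear univariate functions learned on each edge. With an `(n_features, 8, 1)` network, each descriptor feeds 8 first-layer functions, whose outputs are then combined by the second layer.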

Unlike MLR (linear only) or Random Forest (black box), KAN shows us:
- The **shape** of each input's effect
- Whether the relationship is linear, quadratic, threshold-like, etc.
- Potential **physical interpretations** of the learned functions
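
One simple way to put a number on "linear vs. nonlinear" for a learned curve is to fit a straight line to it and check how much variance the line explains. The helper `linearity_r2` below is hypothetical and works on any sampled 1D curve. How to extract sampled curves from a fitted `pycse` KAN is not shown here and would depend on that library's API.

```python
import numpy as np

def linearity_r2(x, phi):
    """R² of the best straight-line fit to a sampled curve phi(x).

    Values near 1 mean the curve is essentially linear; lower values
    indicate curvature, saturation or threshold-like behavior.
    """
    x = np.asarray(x, dtype=float)
    phi = np.asarray(phi, dtype=float)
    slope, intercept = np.polyfit(x, phi, deg=1)
    residual = phi - (slope * x + intercept)
    ss_tot = np.sum((phi - phi.mean()) ** 2)
    return 1.0 - np.sum(residual ** 2) / ss_tot

# Toy curves on a standardized input range (not learned from data)
x_grid = np.linspace(-2, 2, 101)
score_linear = linearity_r2(x_grid, 3 * x_grid + 1)
score_quadratic = linearity_r2(x_grid, x_grid ** 2)
score_threshold = linearity_r2(x_grid, np.tanh(5 * x_grid))

print(f"linear:    {score_linear:.3f}")
print(f"quadratic: {score_quadratic:.3f}")
print(f"threshold: {score_threshold:.3f}")
```

A symmetric quadratic scores near 0 because no straight line explains it, while a steep sigmoid lands in between, so the score is a screening aid and not a substitute for looking at the plots.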

In [None]:
# Visualize KAN activations for the input layer
kan.plot_activations(layer_idx=0, figsize=(16, 12))
plt.suptitle('KAN Input Layer Activations: How Each Descriptor Affects Yield', y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Network visualization
kan.plot_network(figsize=(14, 10))
plt.title('KAN Network Structure with Learned Activations')
plt.show()

## 6. Chemical Interpretation

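The learned activation functions suggest hypotheses about the chemistry. They should be checked against the plots above and against the small number of distinct molecules per component, not read as established results:

### Key Observations

1. **Additive LogP**: May show a nonlinear effect, consistent with strongly hydrophobic and strongly hydrophilic additives interfering with the catalyst in different ways.

2. **Ligand Properties**: The phosphine ligand descriptors can appear threshold-like, with higher yields above a certain ligand size. With only 4 ligands, each curve is constrained at just 4 input values, so its shape between those points is not determined by the data.

3. **Aryl Halide Reactivity**: Polarity-related descriptors (TPSA, H-bond donor/acceptor counts) can show optimal ranges.

### Comparison with Literature

MLR analyses of the kind used by the Sigman group assume a linear relationship between each descriptor and the response. Comparing the KAN curves with the MLR fit shows where that assumption may break down.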

In [None]:
# Partial dependence-like analysis for top features
# Show how yield changes with each descriptor while holding others constant

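def plot_partial_effect(model, feature_idx, X, y, feature_name, n_points=100):
    """Return (x_range, y_pred) for one feature swept over its observed range.

    All other features are held at their column means. `y` and `feature_name`
    are unused and kept only so the existing calls keep working.
    """
    x_range = np.linspace(X[:, feature_idx].min(), X[:, feature_idx].max(), n_points)
    X_temp = np.tile(X.mean(axis=0), (n_points, 1))
    X_temp[:, feature_idx] = x_range

    # Flatten so MLR and KAN outputs have the same 1D shape for plotting
    y_pred = np.asarray(model.predict(X_temp)).ravel()
    return x_range, y_pred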

# Plot partial effects for top features
top_features = rf_importance.head(6)['Feature'].tolist()
top_indices = [feature_names.index(f) for f in top_features]

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

for ax, feat, idx in zip(axes, top_features, top_indices):
    # MLR prediction (linear)
    x_range_lr, y_pred_lr = plot_partial_effect(lr, idx, X_scaled, y, feat)
    
    # KAN prediction (nonlinear)
    x_range_kan, y_pred_kan = plot_partial_effect(kan, idx, X_scaled, y, feat)
    
    ax.plot(x_range_lr, y_pred_lr, 'b--', linewidth=2, label='MLR (linear)')
    ax.plot(x_range_kan, y_pred_kan, 'r-', linewidth=2, label='KAN (nonlinear)')
    ax.set_xlabel(f'{feat} (standardized)')
    ax.set_ylabel('Predicted Yield (%)')
    ax.set_title(feat)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Partial Effects: MLR vs KAN\nKAN captures nonlinear relationships!', fontsize=14)
plt.tight_layout()
plt.show()

## 7. Uncertainty Quantification with KAN

KAN also supports uncertainty quantification through ensemble predictions. The spread across ensemble members gives a per-reaction uncertainty estimate, which matters when deciding which predicted yields to trust.
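
As a sketch of what ensemble-based uncertainty means numerically, the example below summarizes an ensemble into a mean and a spread, checks how often the true value falls inside ±2σ, and computes the standard ensemble estimator of the CRPS (the loss named in the next cell). The array layout `(n_samples, n_members)`, the helper name `crps_ensemble`, and the synthetic data are assumptions for illustration; the actual `pycse` implementation may differ.

```python
import numpy as np

def crps_ensemble(members, y_true):
    """Per-sample CRPS estimate: E|X - y| - 0.5 * E|X - X'|.

    members : (n_samples, n_members) ensemble predictions
    y_true  : (n_samples,) observed values
    """
    members = np.asarray(members, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    term1 = np.abs(members - y_true[:, None]).mean(axis=1)
    pairwise = np.abs(members[:, :, None] - members[:, None, :])
    term2 = pairwise.mean(axis=(1, 2))
    return term1 - 0.5 * term2

# Synthetic, roughly calibrated ensemble (not model output)
rng = np.random.default_rng(0)
n_samples, n_members, sigma = 500, 32, 5.0
true_yield = rng.uniform(0, 100, size=n_samples)
center = true_yield + rng.normal(0, sigma, size=n_samples)
members = center[:, None] + rng.normal(0, sigma, size=(n_samples, n_members))

mean_pred = members.mean(axis=1)   # point prediction
std_pred = members.std(axis=1)     # per-sample uncertainty

# Fraction of true values inside mean ± 2σ
coverage = np.mean(np.abs(true_yield - mean_pred) <= 2 * std_pred)
print(f"±2σ coverage: {coverage:.2f}")

# Hand-checkable case: members {0, 2}, truth 1 gives CRPS = 1 - 0.5 * 1 = 0.5
crps_toy = crps_ensemble([[0.0, 2.0]], [1.0])
print(crps_toy)  # [0.5]
```

Lower CRPS is better: it rewards ensembles that are both close to the observed value and no wider than they need to be.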

In [None]:
# Train KAN with ensemble for uncertainty quantification
print("Training KAN with uncertainty quantification...")
kan_uq = KAN(
    layers=(n_features, 4, 1),
    n_ensemble=32,  # Ensemble for UQ
    grid_size=5,
    loss_type='crps',
)

# Split validation set for calibration
X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
kan_uq.fit(X_tr, y_tr, val_X=X_val, val_y=y_val, maxiter=200)

# Get predictions with uncertainty
y_pred_uq = kan_uq.predict(X_test)
ensemble_preds = kan_uq.predict_ensemble(X_test)
y_std = ensemble_preds.std(axis=1)

print(f"\nTest R²: {r2_score(y_test, y_pred_uq):.4f}")
print(f"Mean uncertainty (std): {y_std.mean():.2f}%")

In [None]:
# Visualize predictions with uncertainty
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Prediction vs actual with error bars
sort_idx = np.argsort(y_test)
axes[0].errorbar(range(len(y_test)), y_pred_uq[sort_idx], yerr=2*y_std[sort_idx], 
                 fmt='none', alpha=0.3, capsize=0, color='red')
axes[0].scatter(range(len(y_test)), y_test[sort_idx], s=10, alpha=0.5, label='Actual', color='blue')
axes[0].scatter(range(len(y_test)), y_pred_uq[sort_idx], s=10, alpha=0.5, label='Predicted', color='red')
axes[0].set_xlabel('Sample (sorted by actual yield)')
axes[0].set_ylabel('Yield (%)')
axes[0].set_title('Predictions with ±2σ Uncertainty')
axes[0].legend()

# Uncertainty vs error
errors = np.abs(y_test - y_pred_uq)
axes[1].scatter(y_std, errors, alpha=0.3, s=10)
axes[1].plot([0, y_std.max()], [0, y_std.max()], 'r--', label='Error = σ')
axes[1].set_xlabel('Predicted Uncertainty (σ)')
axes[1].set_ylabel('Actual Error (|y - ŷ|)')
axes[1].set_title('Uncertainty vs. Error\n(Gaussian σ, well calibrated: ~68% of points below red line)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Summary and Conclusions

### Key Takeaways

1. **KAN bridges the gap** between simple interpretable models (MLR) and powerful but opaque models (Random Forest, standard neural networks)

2. **Interpretability in chemistry**: KAN's learned activation functions can reveal:
   - Nonlinear structure-activity relationships
   - Optimal ranges for descriptors
   - Threshold effects that linear models miss

3. **Uncertainty quantification** is crucial for making actionable predictions in synthesis planning

### When to Use KAN for Chemistry

- When you want **interpretability** but suspect **nonlinear effects**
- For small-to-medium datasets where neural networks might overfit
- When **physical insight** into the learned relationships matters
- As a complement to traditional Sigman-style MLR analysis

### Future Directions

- Use KAN with **DFT-computed descriptors** for higher accuracy
- Apply to **enantioselectivity prediction** (Sigman's specialty)
- Combine with **symbolic regression** to extract explicit equations

In [None]:
# Final KAN report
print("=" * 60)
print("FINAL KAN MODEL REPORT")
print("=" * 60)
kan.report()
print(f"\nTest Performance:")
print(f"  R² = {results['KAN']['r2']:.4f}")
print(f"  MAE = {results['KAN']['mae']:.2f}%")
print(f"\nComparison:")
print(f"  MLR R² = {results['MLR']['r2']:.4f}")
print(f"  Random Forest R² = {results['Random Forest']['r2']:.4f}")
print("=" * 60)