# Active Learning Example

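This notebook demonstrates loading the D2R dataset, generating molecular representations (ECFP8 fingerprints and MACCS keys, plus ChemBERTa embeddings when a precomputed file is available),
running a Gaussian-process (GP) based active-learning experiment via the `explainable_al` package, plotting top-2% and top-5% recall,
and saving the per-cycle results as CSV files. The cells below follow the example script provided in the repository.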

In [None]:
# Imports and dataset load
%matplotlib inline
import os
import numpy as np
import pandas as pd
from explainable_al.featuriser import smiles_to_ecfp8_df, get_maccs_from_smiles_list, load_chemberta_embeddings
from explainable_al import experiments
from explainable_al.metrics_plots import make_plot_recall

# Resolve the dataset location relative to the current working directory: prefer ../data
# (e.g. when the notebook runs from a subdirectory of the repo), otherwise ./data.
data_candidates = [os.path.join('..', 'data', 'D2R.csv'), os.path.join('data', 'D2R.csv')]
data_path = next((p for p in data_candidates if os.path.exists(p)), None)
if data_path is None:
    # Fail with every location we tried, instead of only the last fallback path.
    raise FileNotFoundError(f'D2R.csv not found; looked in: {data_candidates}')
df = pd.read_csv(data_path)
print('Loaded dataset:', data_path)
df.head()

In [None]:
# ECFP8 fingerprints built from the 'SMILES' column of the loaded dataset.
# df.pipe(f, arg) is exactly f(df, arg).
ecfp = df.pipe(smiles_to_ecfp8_df, 'SMILES')
print('ECFP shape:', ecfp.shape)
ecfp.head()

In [None]:
# MACCS keys computed from the raw SMILES strings (passed as a plain Python list).
maccs = get_maccs_from_smiles_list(df['SMILES'].to_list())
print('MACCS shape:', maccs.shape)
# Preview only the first five entries.
maccs[0:5]

In [None]:
# ChemBERTa embeddings: this cell only loads a precomputed .npz file; it does not compute embeddings.
npz_candidates = [
    os.path.join('..', 'data', 'chemberta_embeddings.npz'),
    os.path.join('data', 'chemberta_embeddings.npz'),
]
# Prefer ../data; otherwise fall back to ./data (which may also be missing -- checked below).
npz_path = next((p for p in npz_candidates if os.path.exists(p)), npz_candidates[-1])
if os.path.exists(npz_path):
    emb = load_chemberta_embeddings(npz_path)
    print('Loaded ChemBERTa embeddings:', getattr(emb, 'shape', None))
else:
    # Reset explicitly so a stale `emb` from an earlier run cannot leak into later cells.
    emb = None
    print('No precomputed ChemBERTa file found at', npz_path)
    print('To compute embeddings: from explainable_al.featuriser import smiles_to_chemberta; emb = smiles_to_chemberta(df)')

In [None]:
# Run the packaged experiment for the D2R dataset (GP surrogate, pre-defined protocols).
# `results` is keyed by protocol name; `dataset_size` is used below to normalise recall.
results, dataset_size = experiments.run_experiment(data_path, 'D2R')
print('Protocols returned:', list(results))

In [None]:
# Prepare results for plotting and visualize recall (2% and 5%).
# Each protocol in `results` holds a sequence of dict-like per-cycle records. Recall for a
# threshold is the record's hit count (presumably top-k% compounds acquired so far) divided by
# fraction * dataset_size.
# NOTE(review): the denominator is not rounded to an integer count -- confirm this matches how
# experiments.run_experiment defines the top-2% / top-5% sets.
RECALL_COLUMNS = {
    'Recall (2%)': ('top_2p', 0.02),
    'Recall (5%)': ('top_5p', 0.05),
}


def _recall(hits, fraction, n_total):
    """Return hits / (fraction * n_total), clamping the denominator to >= 1 to avoid division by zero."""
    return hits / max(1, fraction * n_total)


rows = []
for protocol_name, cycles in results.items():
    for c in cycles:
        row = {
            'Protocol': protocol_name,
            'Compounds acquired': c.get('compounds_acquired', 0),
        }
        for column, (key, fraction) in RECALL_COLUMNS.items():
            row[column] = _recall(c.get(key, 0), fraction, dataset_size)
        rows.append(row)
# Explicit columns keep the expected schema (e.g. 'Recall (2%)') even if no cycles were returned.
plot_df = pd.DataFrame(rows, columns=['Protocol', 'Compounds acquired', *RECALL_COLUMNS])
print('Prepared', len(plot_df), 'rows for plotting')
make_plot_recall(plot_df, y='Recall (2%)')
make_plot_recall(plot_df, y='Recall (5%)')

In [None]:
# Save cycle-level CSV outputs: one file per protocol at results/results_<protocol>.csv,
# without the DataFrame index.
out_dir = 'results'
os.makedirs(out_dir, exist_ok=True)
for name, cycle_records in results.items():
    csv_path = os.path.join(out_dir, f'results_{name}.csv')
    cycle_frame = pd.DataFrame(cycle_records)
    cycle_frame.to_csv(csv_path, index=False)
print('Saved results to', out_dir)