# BHAI VAE Paper Figures

This notebook generates all figures for the paper.

**Prerequisites:**
- Trained models in `models/` directory
- Training data in `data/` directory
- LILY datasets in `data/lily-datasets/` directory

In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from pathlib import Path
import sys

# Add parent directory for imports.
# NOTE(review): '..' is resolved relative to the current working directory, so
# this assumes the notebook is started from a folder directly below the
# directory that contains the `models` package — confirm.
sys.path.insert(0, '..')
from models.vae import VAE, SemiSupervisedVAE, DistributionAwareScaler, FEATURE_COLS

# Display figures inline in the notebook at 150 dpi.
%matplotlib inline
plt.rcParams['figure.dpi'] = 150

In [None]:
# Configuration -- every location is relative to the parent of the current
# working directory.
PROJECT_ROOT = Path('..')
DATA_DIR = PROJECT_ROOT / 'data'
MODEL_DIR = PROJECT_ROOT / 'models'
OUTPUT_DIR = PROJECT_ROOT / 'figures'
OUTPUT_DIR.mkdir(exist_ok=True)  # figure files are written here

# Trained checkpoints and the training table
UNSUP_MODEL, SEMISUP_MODEL = (MODEL_DIR / fname for fname in ('unsup.pt', 'semisup.pt'))
TRAINING_DATA = DATA_DIR / 'vae_training_data_v2_20cm.csv'

## Load Data and Models

In [None]:
def _drop_incomplete_rows(df, cols):
    """Drop rows with a NaN in any of `cols`.

    Returns the filtered frame (re-indexed from 0), the matching feature
    matrix, and the boolean row mask that was applied.
    """
    values = df[cols].values
    keep = ~np.isnan(values).any(axis=1)
    return df[keep].reset_index(drop=True), values[keep], keep


# Load the training table, keeping only rows where every feature is present so
# that train_df stays row-aligned with the feature matrix.
train_df = pd.read_csv(TRAINING_DATA)
train_df, X_raw, valid_mask = _drop_incomplete_rows(train_df, FEATURE_COLS)

# Fit the scaler on the cleaned features and transform them in one step
scaler = DistributionAwareScaler()
X_scaled = scaler.fit_transform(X_raw)

print(f"Samples: {len(X_scaled):,}")

In [None]:
def _load_for_inference(net, ckpt_path):
    """Restore `net`'s weights from `ckpt_path` (mapped to CPU) and set eval mode."""
    net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    net.eval()
    return net


# Construct each model (sizes must match the saved checkpoints) and restore it
model_unsup = _load_for_inference(VAE(input_dim=6, latent_dim=10), UNSUP_MODEL)
model_semisup = _load_for_inference(
    SemiSupervisedVAE(input_dim=6, latent_dim=10, n_classes=139),
    SEMISUP_MODEL,
)

print("Models loaded")

In [None]:
# Encode the full scaled dataset with both models; no gradients are needed.
X_t = torch.FloatTensor(X_scaled)
with torch.no_grad():
    emb_unsup, emb_semisup = [
        vae.get_embeddings(X_t).numpy() for vae in (model_unsup, model_semisup)
    ]

print(f"Embeddings shape: {emb_unsup.shape}")

## Figure: Zero-Shot Scatter Plot

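Compares the R² scores obtained when predicting LILY variables from each model's embeddings (semi-supervised on the x-axis, unsupervised on the y-axis). The scores are read from a pre-computed CSV produced by `scripts/compute_zeroshot.py`.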

In [None]:
def _marker_size(n_samples):
    """Return scatter marker areas (points²) for the given sample count(s).

    Counts are clipped to [100, 1e6]; log10 of the count (2..6) is then mapped
    linearly onto marker areas 30..180. Used for both the data points and the
    size legend so the two always agree.
    """
    log_n = np.log10(np.clip(n_samples, 100, 1e6))
    return 30 + 150 * (log_n - 2) / 4


def fig_zeroshot_scatter(results_df, save_path=None):
    """
    Create zero-shot prediction scatter plot.

    Each row of ``results_df`` becomes one point: x is the semi-supervised
    model's R² (``r2_v214``), y is the unsupervised model's R² (``r2_v267``).
    Points are blue where the semi-supervised R² is higher and orange
    otherwise (ties are drawn orange). Marker area grows with log10 of
    ``n_samples``.

    Parameters
    ----------
    results_df : pd.DataFrame
        Must have columns: r2_v214, r2_v267, n_samples. Rows where either R²
        is <= -1 or >= 1.1 are dropped as outliers.
    save_path : str or pathlib.Path, optional
        If given, the figure is also saved there at 300 dpi.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Filter outliers
    results_df = results_df[
        (results_df['r2_v267'] > -1) &
        (results_df['r2_v214'] > -1) &
        (results_df['r2_v267'] < 1.1) &
        (results_df['r2_v214'] < 1.1)
    ].copy()

    fig, ax = plt.subplots(figsize=(10, 10))

    x = results_df['r2_v214'].values  # Semi-supervised
    y = results_df['r2_v267'].values  # Unsupervised
    sizes = _marker_size(results_df['n_samples'].values)

    # Color: blue if semi-sup higher, orange if unsup higher (or equal)
    colors = ['#1f77b4' if xi > yi else '#ff7f0e' for xi, yi in zip(x, y)]

    # Default axis range is [-0.2, 1.05]. The filter above keeps R² down to -1,
    # so extend the lower limit when needed instead of silently clipping
    # retained points out of view.
    lo, hi = -0.2, 1.05
    if len(x):
        lo = min(lo, float(min(x.min(), y.min())) - 0.05)
        # One vectorised call instead of one scatter() per point
        ax.scatter(x, y, s=sizes, c=colors, alpha=0.7,
                   edgecolors='white', linewidth=0.5)

    # Diagonal (equal R²) reference line
    ax.plot([lo, hi], [lo, hi], 'k--', lw=1.5, alpha=0.5)

    # Labels
    ax.set_xlabel('Semi-supervised R²', fontsize=14)
    ax.set_ylabel('Unsupervised R²', fontsize=14)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_aspect('equal')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Color legend
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#1f77b4',
               markersize=10, label='Semi-supervised higher'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#ff7f0e',
               markersize=10, label='Unsupervised higher'),
    ]
    ax.legend(handles=legend_elements, loc='lower right', fontsize=11)

    # Size legend, drawn as an inset in the upper-left corner of the main axes.
    # Axes-relative placement keeps it attached to the plot when tight_layout
    # repositions the axes; a figure-level add_axes would not be handled by
    # tight_layout.
    size_legend_ax = ax.inset_axes([0.04, 0.72, 0.17, 0.2])
    size_legend_ax.set_xlim(0, 1)
    size_legend_ax.set_ylim(0, 1)
    size_legend_ax.axis('off')

    sample_sizes = [100, 1000, 10000, 100000, 1000000]
    sample_labels = ['100', '1k', '10k', '100k', '1M']
    y_positions = [0.85, 0.65, 0.45, 0.25, 0.05]

    for ss, label, yp in zip(sample_sizes, sample_labels, y_positions):
        size_legend_ax.scatter(0.3, yp, s=_marker_size(ss), c='gray', alpha=0.7,
                               edgecolors='white')
        size_legend_ax.text(0.6, yp, label, va='center', fontsize=10)

    size_legend_ax.text(0.4, 1.0, 'Samples', ha='center', fontsize=11, fontweight='bold')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved: {save_path}")

    return fig

In [None]:
# Plot the pre-computed zero-shot results. This cell only reads them; it does
# not compute them.
zeroshot_results_path = DATA_DIR / 'zeroshot_results.csv'
if not zeroshot_results_path.exists():
    print(f"Zeroshot results not found at {zeroshot_results_path}")
    print("Run scripts/compute_zeroshot.py first")
else:
    zeroshot_df = pd.read_csv(zeroshot_results_path)
    fig = fig_zeroshot_scatter(
        zeroshot_df,
        OUTPUT_DIR / 'fig_zeroshot_scatter.png',
    )
    plt.show()

## Figure: Reconstruction Scatter

Compares reconstruction quality (predicted vs. true values, in scaled feature space) for both models.

In [None]:
def fig_reconstruction_scatter(X_scaled, pred_unsup, pred_semisup, save_path=None):
    """
    Create reconstruction scatter plot comparing both models.

    Layout is a 3 x 4 grid of predicted-vs-true panels:
      columns: Physical (unsup, semisup) | Optical (unsup, semisup)
      rows:    Bulk/R, MS/G, NGR/B
    Feature columns 0-2 fill the physical half and columns 3-5 the optical
    half. Each panel shows a random subsample of at most 10,000 rows; the R²
    annotation is computed on all rows.

    Parameters
    ----------
    X_scaled : np.ndarray
        True (scaled) values, indexed as [row, feature]; needs at least 6
        feature columns.
    pred_unsup : np.ndarray
        Reconstruction from the unsupervised model (labelled 'v2.6.7'),
        indexed like ``X_scaled``.
    pred_semisup : np.ndarray
        Reconstruction from the semi-supervised model (labelled 'v2.14'),
        indexed like ``X_scaled``.
    save_path : str or pathlib.Path, optional
        If given, the figure is also saved there at 300 dpi.

    Returns
    -------
    matplotlib.figure.Figure
    """
    def r2(y_true, y_pred):
        # Coefficient of determination (nan/inf if y_true is constant)
        return 1 - np.sum((y_true - y_pred)**2) / np.sum((y_true - np.mean(y_true))**2)

    # (feature indices, row labels, per-row colours, first grid column)
    panel_groups = [
        ([0, 1, 2], ['Bulk Density', 'Mag. Susc.', 'NGR'], ['#2c3e50'] * 3, 0),
        ([3, 4, 5], ['R', 'G', 'B'], ['#e74c3c', '#27ae60', '#3498db'], 2),
    ]
    models = [(pred_unsup, 'v2.6.7'), (pred_semisup, 'v2.14')]

    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    fig.suptitle('Reconstruction Quality: Predicted vs True', fontsize=14, fontweight='bold', y=0.98)

    # Section headers
    fig.text(0.28, 0.93, 'Physical Properties', ha='center', fontsize=12, fontweight='bold')
    fig.text(0.72, 0.93, 'Optical Properties', ha='center', fontsize=12, fontweight='bold')

    # Subsample rows for plotting only. A private RandomState(42) draws the
    # same indices as np.random.seed(42) + np.random.choice would, without
    # reseeding NumPy's global RNG as a side effect for the caller.
    n_plot = min(10000, len(X_scaled))
    plot_idx = np.random.RandomState(42).choice(len(X_scaled), n_plot, replace=False)

    for feat_indices, row_labels, row_colors, first_col in panel_groups:
        for row, (feat_idx, row_label, color) in enumerate(zip(feat_indices, row_labels, row_colors)):
            true_vals = X_scaled[plot_idx, feat_idx]
            for offset, (pred, model_name) in enumerate(models):
                ax = axes[row, first_col + offset]
                pred_vals = pred[plot_idx, feat_idx]
                r2_val = r2(X_scaled[:, feat_idx], pred[:, feat_idx])

                ax.scatter(true_vals, pred_vals, alpha=0.3, s=1, c=color)
                # Shared x/y limits, padded by 0.5, with a y = x reference line
                lims = [min(true_vals.min(), pred_vals.min()) - 0.5,
                        max(true_vals.max(), pred_vals.max()) + 0.5]
                ax.plot(lims, lims, '--', color='gray', alpha=0.5)
                ax.set_xlim(lims)
                ax.set_ylim(lims)

                ax.text(0.05, 0.95, f'R²={r2_val:.3f}', transform=ax.transAxes,
                        fontsize=9, va='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

                # Row label only on the first column of each half
                if offset == 0:
                    ax.set_ylabel(f'{row_label}\nPredicted', fontsize=10)
                if row == 0:
                    ax.set_title(model_name, fontsize=11, fontweight='bold')
                if row == 2:
                    ax.set_xlabel('True', fontsize=10)
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)

    # Leave room at the top for the suptitle and section headers
    fig.tight_layout(rect=[0, 0, 1, 0.92])

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved: {save_path}")

    return fig

In [None]:
# Reconstruct the full scaled dataset with both models. Only the first element
# of each model's output tuple is used; the other outputs are discarded.
X_t = torch.FloatTensor(X_scaled)
with torch.no_grad():
    recon_unsup, _, _ = model_unsup(X_t)
    recon_semisup, _, _, _ = model_semisup(X_t)
pred_unsup = recon_unsup.numpy()
pred_semisup = recon_semisup.numpy()

fig = fig_reconstruction_scatter(
    X_scaled, pred_unsup, pred_semisup,
    OUTPUT_DIR / 'fig_reconstruction_scatter.png',
)
plt.show()

## Figure: ROC Comparison

Compares micro-averaged ROC curves for lithology classification, obtained by training the same CatBoost classifier on each model's embeddings.

In [None]:
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from catboost import CatBoostClassifier

def _micro_roc(emb, labels, n_train, n_classes):
    """Fit CatBoost on emb[:n_train] and return micro-averaged (fpr, tpr, auc) on emb[n_train:]."""
    clf = CatBoostClassifier(iterations=500, verbose=False, random_state=42)
    clf.fit(emb[:n_train], labels[:n_train])
    prob = clf.predict_proba(emb[n_train:])

    y_test = labels[n_train:]

    # Handle class mismatch: the classifier only has columns for classes seen
    # in training, so scatter them into a full (n_test, n_classes) matrix;
    # unseen classes keep probability 0. classes_ is cast to int so it can be
    # used as column indices.
    prob_full = np.zeros((len(y_test), n_classes))
    prob_full[:, np.asarray(clf.classes_).astype(int)] = prob

    # One-hot test labels (all-zero rows for ids outside [0, n_classes)).
    # Built directly because label_binarize returns a single column when
    # n_classes == 2, which would no longer line up with prob_full.
    y_test_bin = (y_test[:, None] == np.arange(n_classes)).astype(int)

    # Micro-average: pool every (sample, class) pair into one binary problem
    fpr, tpr, _ = roc_curve(y_test_bin.ravel(), prob_full.ravel())
    return fpr, tpr, auc(fpr, tpr)


def fig_roc_comparison(emb_unsup, emb_semisup, labels, n_classes,
                       max_samples=50000, save_path=None):
    """
    Create fair ROC comparison using same classifier on both embeddings.

    Both embeddings are evaluated with an identical CatBoost classifier
    (500 iterations, random_state=42). The classifier is trained on the same
    random 80% of rows and scored on the remaining 20%. Curves are
    micro-averaged over all classes.

    Parameters
    ----------
    emb_unsup, emb_semisup : np.ndarray
        Row-aligned embeddings from the unsupervised and semi-supervised models.
    labels : array-like of int
        Class id per row, expected in [0, n_classes).
    n_classes : int
        Total number of classes.
    max_samples : int, default 50000
        At most this many rows (a random subsample) are used.
    save_path : str or pathlib.Path, optional
        If given, the figure is also saved there at 300 dpi.

    Returns
    -------
    matplotlib.figure.Figure
    """
    labels = np.asarray(labels)
    n_total = len(labels)

    # Always shuffle before the 80/20 split, so the test rows are a random draw
    # even if the inputs are sorted; subsample at the same time when there are
    # more than max_samples rows. A private RandomState(42) gives the same
    # stream as np.random.seed(42) without resetting NumPy's global RNG.
    rng = np.random.RandomState(42)
    if n_total > max_samples:
        idx = rng.choice(n_total, max_samples, replace=False)
    else:
        idx = rng.permutation(n_total)
    emb_unsup = emb_unsup[idx]
    emb_semisup = emb_semisup[idx]
    labels = labels[idx]

    # Train/test split (rows are already shuffled)
    n_train = int(0.8 * len(labels))
    fpr_unsup, tpr_unsup, auc_unsup = _micro_roc(emb_unsup, labels, n_train, n_classes)
    fpr_semisup, tpr_semisup, auc_semisup = _micro_roc(emb_semisup, labels, n_train, n_classes)

    # Plot
    fig, ax = plt.subplots(figsize=(8, 8))

    ax.plot(fpr_unsup, tpr_unsup, '#ff7f0e', lw=2.5,
            label=f'Unsupervised (v2.6.7) AUC = {auc_unsup:.3f}')
    ax.plot(fpr_semisup, tpr_semisup, '#1f77b4', lw=2.5,
            label=f'Semi-supervised (v2.14) AUC = {auc_semisup:.3f}')
    ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.5, label='Random')

    ax.set_xlabel('False Positive Rate', fontsize=14)
    ax.set_ylabel('True Positive Rate', fontsize=14)
    ax.set_title('Lithology Classification: ROC Comparison', fontsize=12, fontweight='bold')
    ax.legend(loc='lower right', fontsize=11)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1.02])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved: {save_path}")

    return fig

In [None]:
# Encode the 'Principal' lithology column as integer class ids, numbered in
# sorted label order.
labels_raw = train_df['Principal'].values
unique_labels, inverse = np.unique(labels_raw, return_inverse=True)
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
y = inverse.astype(int)
n_classes = unique_labels.size

print(f"Classes: {n_classes}")

# Slow: trains one CatBoost classifier per embedding
fig = fig_roc_comparison(
    emb_unsup, emb_semisup, y, n_classes,
    save_path=OUTPUT_DIR / 'fig_roc_comparison.png',
)
plt.show()

## Summary

All figures are saved to the `figures/` directory.

In [None]:
import os
# List every PNG in the output directory with its size (bytes / 1024)
print("Generated figures:")
for png_path in sorted(OUTPUT_DIR.glob('*.png')):
    kib = png_path.stat().st_size / 1024
    print(f"  {png_path.name}: {kib:.1f} KB")