# Advanced Kernel Options in KLRfome

This notebook explores kernel options and their effects on model performance. KLRfome supports two kernel approaches:

1. **Exact RBF Kernel**: Computes the full kernel matrix (O(n²) complexity)
2. **Random Fourier Features (RFF)**: Approximates RBF kernel with explicit feature maps (O(nD) complexity)

We'll compare these approaches and explore the effect of hyperparameters.


## Setup


In [None]:
import time

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from rasterio.transform import from_bounds
import geopandas as gpd
from shapely.geometry import Point

from klrfome import KLRfome, RBFKernel, RandomFourierFeatures, MeanEmbeddingKernel
from klrfome.data.simulation import create_simulated_raster_stack
from klrfome.visualization import plot_similarity_matrix, plot_predictions

# Shared seed: passed to the raster simulation, KLRfome models, and RFF kernels below
SEED = 42
# Seed NumPy's legacy global RNG, which the np.random.uniform site draws rely on
np.random.seed(SEED)


## Create Test Data

We'll use a small dataset for quick comparisons.


In [None]:
# Small 50x50, two-band simulated raster keeps the comparisons quick
raster_stack = create_simulated_raster_stack(cols=50, rows=50, n_bands=2, seed=SEED)

# Scatter a handful of site points; for each site x is drawn before y from the
# global NumPy random stream
# NOTE(review): assumes the raster extent covers the 0.1-0.9 coordinate range — confirm
n_sites = 5
site_points = []
for _site_idx in range(n_sites):
    x_coord = np.random.uniform(0.1, 0.9)
    y_coord = np.random.uniform(0.1, 0.9)
    site_points.append(Point(x_coord, y_coord))

sites_gdf = gpd.GeoDataFrame(geometry=site_points, crs=raster_stack.crs)

print(
    f"Test raster: {raster_stack.data.shape}",
    f"Sites: {len(sites_gdf)}",
    sep="\n",
)


## Part 1: Compare Exact RBF vs RFF Approximation

Compare the similarity matrices computed by exact RBF and RFF with different numbers of features.


In [None]:
# A throwaway model is used only for its data-preparation step; the resulting
# training data is reused by the kernel comparisons that follow
model_temp = KLRfome(sigma=1.0, n_rff_features=128, seed=SEED)
prepare_kwargs = dict(
    raster_stack=raster_stack,
    sites=sites_gdf,
    n_background=20,
    samples_per_location=15,
)
training_data = model_temp.prepare_data(**prepare_kwargs)

print(f"Training data: {training_data.n_locations} locations")

# Reference similarity matrix: exact RBF kernel wrapped in a mean-embedding kernel
exact_kernel = RBFKernel(sigma=1.0)
exact_dist_kernel = MeanEmbeddingKernel(exact_kernel)
K_exact = exact_dist_kernel.build_similarity_matrix(training_data.collections)

# Summarize the reference matrix in a single report
print(
    f"\nExact RBF similarity matrix:",
    f"  Shape: {K_exact.shape}",
    f"  Range: [{jnp.min(K_exact):.3f}, {jnp.max(K_exact):.3f}]",
    f"  Mean: {jnp.mean(K_exact):.3f}",
    sep="\n",
)


In [None]:
# Sweep the number of random Fourier features and measure how closely each
# approximate similarity matrix tracks the exact one
rff_features = [64, 128, 256, 512]
rff_matrices = {}

for n_features in rff_features:
    rff_kernel = RandomFourierFeatures(sigma=1.0, n_features=n_features, seed=SEED)
    rff_dist_kernel = MeanEmbeddingKernel(rff_kernel)
    K_rff = rff_dist_kernel.build_similarity_matrix(training_data.collections)
    rff_matrices[n_features] = K_rff

    # Element-wise deviation from the exact matrix, plus the off-diagonal entry
    # of the 2x2 correlation matrix between the flattened matrices
    diff = jnp.abs(K_rff - K_exact)
    corr_matrix = jnp.corrcoef(K_exact.flatten(), K_rff.flatten())

    print(
        f"RFF with {n_features:3d} features:",
        f"  Mean absolute error: {jnp.mean(diff):.6f}",
        f"  Max absolute error: {jnp.max(diff):.6f}",
        f"  Correlation with exact: {float(corr_matrix[0,1]):.4f}",
        "",
        sep="\n",
    )


## Part 2: Effect of Sigma (Bandwidth) Parameter

The `sigma` parameter controls the kernel bandwidth: larger values make the kernel smoother, while smaller values make it more local.


In [None]:
# Fit and predict once per bandwidth, recording the similarity matrix, the
# prediction surface, and a few summary statistics for each sigma
sigma_values = [0.1, 0.5, 1.0, 2.0, 5.0]
sigma_results = {}

for sigma in sigma_values:
    model = KLRfome(sigma=sigma, lambda_reg=0.1, n_rff_features=256, window_size=3, seed=SEED)
    model.fit(training_data)
    predictions = model.predict(raster_stack, show_progress=False)

    # The fitted similarity matrix is read from the model's private attribute
    similarity = model._similarity_matrix
    sim_lo, sim_hi = float(jnp.min(similarity)), float(jnp.max(similarity))
    pred_lo, pred_hi = float(jnp.min(predictions)), float(jnp.max(predictions))

    summary = {
        'similarity_matrix': similarity,
        'predictions': predictions,
        'similarity_range': (sim_lo, sim_hi),
        'prediction_range': (pred_lo, pred_hi),
        'prediction_mean': float(jnp.mean(predictions)),
    }
    sigma_results[sigma] = summary

    print(
        f"Sigma = {sigma:3.1f}:",
        f"  Similarity range: {summary['similarity_range']}",
        f"  Prediction range: {summary['prediction_range']}",
        f"  Mean prediction: {summary['prediction_mean']:.3f}",
        "",
        sep="\n",
    )


### Visualize Similarity Matrices for Different Sigma Values


In [None]:
# One heatmap panel per sigma, each with its own colorbar
fig, axes = plt.subplots(1, len(sigma_values), figsize=(20, 4))

for idx, (ax, sigma) in enumerate(zip(axes, sigma_values)):
    K = sigma_results[sigma]['similarity_matrix']
    im = ax.imshow(np.array(K), cmap='viridis', aspect='auto')
    ax.set_title(f'Sigma = {sigma}')
    ax.set_xlabel('Sample Index')
    # Only the left-most panel carries the y-axis label
    if idx == 0:
        ax.set_ylabel('Sample Index')
    fig.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()


### Visualize Predictions for Different Sigma Values


In [None]:
# One prediction map per sigma on a shared 0-1 color scale
fig, axes = plt.subplots(1, len(sigma_values), figsize=(20, 4))

for idx, (ax, sigma) in enumerate(zip(axes, sigma_values)):
    pred = sigma_results[sigma]['predictions']
    im = ax.imshow(np.array(pred), cmap='viridis', aspect='auto', vmin=0, vmax=1)
    ax.set_title(f'Sigma = {sigma}\nMean: {sigma_results[sigma]["prediction_mean"]:.3f}')
    ax.set_xlabel('Column')
    # Only the left-most panel carries the y-axis label
    if idx == 0:
        ax.set_ylabel('Row')
    fig.colorbar(im, ax=ax, label='Probability')

plt.tight_layout()
plt.show()


## Part 3: Performance Comparison

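Compare computation time and memory usage between exact RBF and RFF approximations. Memory here is the size (`nbytes`) of the resulting similarity matrix, not peak memory during the computation.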


In [None]:
# Larger 100x100, three-band raster with 20 random sites for the timing runs
large_raster = create_simulated_raster_stack(cols=100, rows=100, n_bands=3, seed=SEED)
large_site_points = [
    Point(np.random.uniform(0.1, 0.9), np.random.uniform(0.1, 0.9))
    for _ in range(20)
]
large_sites = gpd.GeoDataFrame(geometry=large_site_points, crs=large_raster.crs)

# The model instance here is only used to prepare the training collections
model_large = KLRfome(sigma=1.0, n_rff_features=128, seed=SEED)
large_training = model_large.prepare_data(
    raster_stack=large_raster,
    sites=large_sites,
    n_background=50,
    samples_per_location=20,
)

# Total sample count is summed over every collection in the training data
total_samples = sum(coll.n_samples for coll in large_training.collections)
print(
    f"Large dataset: {large_training.n_locations} locations",
    f"Total samples: {total_samples}",
    sep="\n",
)


In [None]:
# Time exact RBF: kernel construction plus building the full similarity matrix
print("Timing exact RBF kernel...")
# perf_counter is a monotonic, high-resolution clock intended for measuring durations
start = time.perf_counter()
exact_kernel = RBFKernel(sigma=1.0)
exact_dist = MeanEmbeddingKernel(exact_kernel)
# JAX dispatches work asynchronously, so block until the result is materialized;
# otherwise the timer can stop before the computation has actually finished.
# jax.block_until_ready passes non-JAX values through unchanged.
# NOTE(review): if build_similarity_matrix is jit-compiled, this first call also
# includes compilation time — confirm whether a warm-up run is wanted
K_exact_large = jax.block_until_ready(
    exact_dist.build_similarity_matrix(large_training.collections)
)
exact_time = time.perf_counter() - start

print(f"Exact RBF: {exact_time:.3f} seconds")
# Memory reported is the size of the resulting similarity matrix only
print(f"Memory: {K_exact_large.nbytes / 1024**2:.2f} MB")


In [None]:
# Time RFF with different feature counts, recording seconds and result size (MB)
rff_times = {}
rff_memory = {}

for n_features in [64, 128, 256, 512]:
    print(f"\nTiming RFF with {n_features} features...")
    # perf_counter is a monotonic, high-resolution clock intended for measuring durations
    start = time.perf_counter()
    rff_kernel = RandomFourierFeatures(sigma=1.0, n_features=n_features, seed=SEED)
    rff_dist = MeanEmbeddingKernel(rff_kernel)
    # JAX dispatches work asynchronously, so block until the result is materialized;
    # otherwise the timer can stop before the computation has actually finished.
    # jax.block_until_ready passes non-JAX values through unchanged.
    K_rff_large = jax.block_until_ready(
        rff_dist.build_similarity_matrix(large_training.collections)
    )
    rff_time = time.perf_counter() - start

    rff_times[n_features] = rff_time
    # Memory recorded is the size of the resulting similarity matrix only
    rff_memory[n_features] = K_rff_large.nbytes / 1024**2

    print(f"RFF ({n_features} features): {rff_time:.3f} seconds")
    print(f"Memory: {rff_memory[n_features]:.2f} MB")
    print(f"Speedup: {exact_time / rff_time:.2f}x")


### Performance Summary


In [None]:
# Two side-by-side panels: RFF timing and RFF result size versus feature count,
# each with the exact-RBF value drawn as a dashed horizontal reference line
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

features_list = sorted(rff_times)
times_list = [rff_times[f] for f in features_list]
memory_list = [rff_memory[f] for f in features_list]
exact_memory_mb = K_exact_large.nbytes / 1024**2

# (axis, RFF series, exact reference, y-axis label, panel title)
panels = [
    (ax1, times_list, exact_time, 'Time (seconds)', 'Computation Time'),
    (ax2, memory_list, exact_memory_mb, 'Memory (MB)', 'Memory Usage'),
]

for ax, rff_values, exact_value, y_label, panel_title in panels:
    ax.plot(features_list, rff_values, 'o-', label='RFF', linewidth=2, markersize=8)
    ax.axhline(exact_value, color='r', linestyle='--', label='Exact RBF')
    ax.set_xlabel('Number of RFF Features')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


22222222eeeeee -->