In [None]:
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

import jax
import jax.numpy as jnp
from jax import random

# diffusionlab imports
from diffusionlab.dynamics import VariancePreservingProcess
from diffusionlab.losses import DiffusionLoss
from diffusionlab.vector_fields import VectorFieldType

# repo imports
from src.diffusion_mem_gen.models.gmm import (
    IsoHomGMMInitStrategy,
    iso_hom_gmm_create_initialization_parameters,
    IsoHomGMMSharedParametersEstimator,
)
from src.diffusion_mem_gen.utils.factories import compute_loss_factory
from investigating_diffusion_loss import AmbientDiffusionLoss

SEED = 0  # global seed; all later random draws are derived from `key` via random.split
key = random.PRNGKey(SEED)

## Crossover Analysis: How crossover point varies with N

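This notebook investigates how the crossover point varies with N (the number of training samples) and with $t_n$. The crossover point is the value of M (the number of PMEM mixture components) at which the PMEM model's $L_N$ loss equals the generalising model's $L_N$ loss.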

We create a 3D plot with:
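- x-axis: $t_n$ values (the diffusion time to which the training data is noised to form the ambient dataset)
- y-axis: $N$ values (50, 100, 150, 300, 600, 1000)
- z-axis: $M/N$ at the crossover point (linearly interpolated between the evaluated M values)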

In [None]:
# Dimension and other fixed parameters
laboratory_d = 100  # data dimension
laboratory_K = 12  # number of components in the GMM

# NOTE: despite the name, this is a standard deviation, not a variance: the true
# means are drawn as N(0, I) * u_means_variance, so each mean coordinate has
# variance 30 (std ~5.48).
u_means_variance = 30 ** (1/2)
sample_variance = 1  # isotropic variance for each component

# Diffusion process (used by L_N, the GMM denoisers and the forward noising)
diffusion_process = VariancePreservingProcess()

# N values to investigate
N_values = [50, 100, 150, 300, 600, 1000]

# Number of M values evaluated per N; the grid itself (10 .. N) is built per N.
M_points = 10

# t_n values to vary
t_n_values = np.linspace(0.01, 0.99, 9)
print(f"t_n values: {t_n_values}")

# Time grid for evaluating the L_N loss.
# NOTE: both grids end at 0.99, so for t_n = 0.99 no grid time lies strictly
# above t_n and the ambient part of L_N has no evaluation points.
t_val_array = jnp.linspace(0.01, 0.99, 26)


# Time weighting lambda(t) of the L_N loss (constant for now)
def lambda_fn(t):
    return 1

In [None]:
def _region_average(weighted_losses, width):
    '''Return ``width * mean(weighted_losses)``, or 0.0 when the region has no grid points.'''
    # np.mean([]) is NaN (plus a RuntimeWarning). An unsampled region should add
    # nothing to the mean-times-width estimate instead of turning the whole loss into NaN.
    if len(weighted_losses) == 0:
        return 0.0
    return width * np.mean(weighted_losses)


def L_N(model_callable, lambda_fn, t_val_array, t_n_value, X_0_dataset, X_t_n_dataset, key):
    '''
    Approximate the L_N loss of ``model_callable`` on a time grid split at ``t_n``.

    Grid times ``t <= t_n`` use the standard X0 denoising loss on the clean data
    ``X_0_dataset``; grid times ``t > t_n`` use the ambient loss on the dataset
    ``X_t_n_dataset``. In each region the ``lambda_fn(t)``-weighted per-time losses
    are averaged and multiplied by the region width (``t_n`` and ``1 - t_n``),
    approximating the time integral over [0, t_n] and [t_n, 1].

    Args:
        model_callable: denoiser passed to the loss functions.
        lambda_fn: time weighting, called as ``lambda_fn(t)``.
        t_val_array: 1-D array of evaluation times.
        t_n_value: split time t_n.
        X_0_dataset: clean training samples; ``X_0_dataset.shape[0]`` normalises every loss.
        X_t_n_dataset: training samples noised to time t_n.
        key: JAX PRNG key, split once per evaluated time.

    Returns:
        Scalar loss estimate. A region containing no grid points contributes 0
        (instead of NaN), e.g. the ambient region when t_n equals the last grid time.
    '''
    ambient_loss_obj = AmbientDiffusionLoss(diffusion_process, num_noise_draws_per_sample=1, t_n=t_n_value)
    standard_loss_obj = DiffusionLoss(diffusion_process, vector_field_type=VectorFieldType.X0, num_noise_draws_per_sample=1)
    num_samples = X_0_dataset.shape[0]

    def L_N_t_ambient(model_callable, t_n_value, t_val, X_t_n_dataset, key):
        '''Ambient denoising loss at a single time ``t_val > t_n`` on the t_n-noised data.'''
        assert t_val > t_n_value

        compute_loss = compute_loss_factory(ambient_loss_obj, jnp.array(t_val))
        return compute_loss(key, model_callable, X_t_n_dataset) / num_samples

    def L_N_t_standard_score(model_callable, t_val, X_t_dataset, key):
        '''Standard X0 denoising loss at a single time ``t_val <= t_n``.'''
        compute_loss = compute_loss_factory(standard_loss_obj, jnp.array(t_val))
        return compute_loss(key, model_callable, X_t_dataset) / num_samples

    # Split the grid into t <= t_n (standard) and t > t_n (ambient)
    less_than_n_mask = jnp.less_equal(t_val_array, t_n_value)
    greater_than_n_mask = jnp.logical_not(less_than_n_mask)

    standard_denoising_t_values = t_val_array[less_than_n_mask]
    ambient_denoising_t_values = t_val_array[greater_than_n_mask]

    # Standard denoising term; one fresh subkey per time point
    current_key = key
    standard_denoising_loss_values = []
    for t in standard_denoising_t_values:
        current_key, subk = random.split(current_key)
        loss_val = L_N_t_standard_score(model_callable, t, X_0_dataset, subk)
        standard_denoising_loss_values.append(lambda_fn(t) * loss_val)

    averaged_standard_denoising_loss = _region_average(standard_denoising_loss_values, t_n_value)

    # Ambient denoising term; continues the same key chain
    ambient_denoising_loss_values = []
    for t in ambient_denoising_t_values:
        current_key, subk = random.split(current_key)
        loss_val = L_N_t_ambient(model_callable, t_n_value, t, X_t_n_dataset, subk)
        ambient_denoising_loss_values.append(lambda_fn(t) * loss_val)

    averaged_ambient_denoising_loss = _region_average(ambient_denoising_loss_values, 1 - t_n_value)

    return averaged_standard_denoising_loss + averaged_ambient_denoising_loss

In [None]:
def find_crossover_point(M_values, pmem_losses, generalising_losses, N, tol=0.01):
    '''
    Find the M/N ratio at which the PMEM and generalising loss curves cross.

    Args:
        M_values: M values in evaluation order (ascending in this notebook).
        pmem_losses: PMEM loss for each entry of ``M_values``.
        generalising_losses: generalising loss for each entry of ``M_values``.
        N: number of training samples; the crossover is reported as M/N.
        tol: when the loss difference never changes sign, the point with the
            smallest ``|pmem - generalising|`` is still accepted if that gap is
            below ``tol``.

    Returns:
        The M/N ratio of the first crossover, linearly interpolated between the
        two bracketing M values, or None if no crossover is found (including for
        empty input). Non-finite loss differences (NaN/inf) are ignored.
    '''
    M_over_N = np.asarray(M_values, dtype=float) / N
    loss_difference = (np.asarray(pmem_losses, dtype=float)
                       - np.asarray(generalising_losses, dtype=float))

    # np.sign(nan) is nan, and np.diff/np.where treat nan as a sign change,
    # which previously produced NaN "crossovers"; drop non-finite points.
    finite = np.isfinite(loss_difference)
    M_over_N = M_over_N[finite]
    loss_difference = loss_difference[finite]

    if loss_difference.size == 0:
        return None

    sign_changes = np.flatnonzero(np.diff(np.sign(loss_difference)))

    if sign_changes.size == 0:
        # No sign change: accept the closest point if it is within tolerance
        closest = np.argmin(np.abs(loss_difference))
        if np.abs(loss_difference[closest]) < tol:
            return float(M_over_N[closest])
        return None

    # Only the first crossover is returned. A sign change implies y1 != y2, so
    # the linear interpolation below never divides by zero.
    idx = sign_changes[0]
    x1, x2 = M_over_N[idx], M_over_N[idx + 1]
    y1, y2 = loss_difference[idx], loss_difference[idx + 1]
    return float(x1 - y1 * (x2 - x1) / (y2 - y1))

In [None]:
# Store results for all (N, t_n) combinations
all_crossover_data = []  # List of (N, t_n, M/N at crossover)

# Forward-noise a whole dataset at once: vmap over samples (axis 0) and their
# noise draws (axis 0), broadcasting the scalar time. Loop-invariant, so built once.
batch_diffusion_forward = jax.vmap(
    diffusion_process.forward, in_axes=(0, None, 0)
)

# Evaluate for each N value
for laboratory_N in N_values:
    print(f"\n{'='*70}")
    print(f"Processing N = {laboratory_N}")
    print(f"{'='*70}")

    # M values: range to search for crossover (adjusted per N)
    M_values = np.linspace(10, laboratory_N, M_points).astype(int)
    print(f"M values: {M_values}")

    # Ground-truth GMM params.
    # NOTE(review): the true means are redrawn for every N, so the ground-truth
    # distribution also changes along the N axis; draw them once before this
    # loop if only the sample size is meant to vary.
    key, sk = random.split(key)
    true_means = random.normal(sk, (laboratory_K, laboratory_d)) * u_means_variance
    equal_weighted_prior = jnp.array([1/laboratory_K for _ in range(laboratory_K)])  # must sum to 1

    # Sample training set from the true GMM
    key, sk = random.split(key)
    comp_ids = random.choice(sk, laboratory_K, shape=(laboratory_N,), p=equal_weighted_prior)
    key, sk = random.split(key)
    X_train = true_means[comp_ids] + jnp.sqrt(sample_variance) * random.normal(sk, (laboratory_N, laboratory_d))

    print(f"Training data shape: {X_train.shape}")

    # Build Generalizing denoiser from the true GMM parameters (same for all M and t_n values)
    generalising_model = IsoHomGMMSharedParametersEstimator(
        dim=laboratory_d,
        num_components=laboratory_K,
        vf_type=VectorFieldType.X0,
        diffusion_process=diffusion_process,
        init_means=true_means,
        init_var=jnp.asarray(sample_variance),
        priors=equal_weighted_prior,
    )

    # Evaluate for each t_n value
    for t_n in t_n_values:
        print(f"\n  Processing t_n = {t_n:.3f}...", end=" ")

        t_n_value = jnp.array(t_n)

        # Generate X_t_n_dataset for this t_n value
        key, subk = random.split(key)
        X_t_n_eps = jax.random.normal(subk, X_train.shape)
        X_t_n_dataset = batch_diffusion_forward(X_train, t_n_value, X_t_n_eps)

        # The generalising model does not depend on M, so evaluate its L_N once
        # per t_n and use it as a constant baseline. Re-evaluating it for every M
        # with fresh keys only added Monte-Carlo noise to the comparison and
        # roughly doubled the cost of the M sweep.
        key, subk = random.split(key)
        generalising_loss = float(
            L_N(generalising_model, lambda_fn, t_val_array, t_n_value, X_train, X_t_n_dataset, subk)
        )

        # Evaluate L_N of the PMEM model for each M value
        pmem_losses = []
        for M in M_values:
            # Build PMEM denoiser for this M value (fresh context each time in
            # case the initializer mutates it)
            key, sk = random.split(key)
            context = {
                "X_train": X_train,
                "init_var_scale": 1e-6,
                "init_means_noise_var": 0.0,
            }
            means_pmem, var_pmem, priors_pmem = iso_hom_gmm_create_initialization_parameters(
                sk, IsoHomGMMInitStrategy.PMEM, laboratory_d, M, context
            )

            pmem_model = IsoHomGMMSharedParametersEstimator(
                dim=laboratory_d,
                num_components=M,
                vf_type=VectorFieldType.X0,
                diffusion_process=diffusion_process,
                init_means=means_pmem,
                init_var=var_pmem,
                priors=priors_pmem,
            )

            key, subk = random.split(key)
            pmem_loss = L_N(pmem_model, lambda_fn, t_val_array, t_n_value, X_train, X_t_n_dataset, subk)
            pmem_losses.append(float(pmem_loss))

        pmem_losses = np.array(pmem_losses)
        generalising_losses = np.full(len(M_values), generalising_loss)

        # Find crossover point
        crossover_M_over_N = find_crossover_point(M_values, pmem_losses, generalising_losses, laboratory_N)
        if crossover_M_over_N is not None and not np.isnan(crossover_M_over_N):
            all_crossover_data.append((laboratory_N, t_n, crossover_M_over_N))
            print(f"Crossover at M/N = {crossover_M_over_N:.4f}")
        else:
            print("No crossover found")

    print(f"\nCompleted N = {laboratory_N}")

print(f"\n\nTotal crossover points found: {len(all_crossover_data)}")

In [None]:
# 3D scatter of every (t_n, N) combination that produced a crossover,
# followed by a plain-text table of the same points.
if not all_crossover_data:
    print("No crossover points found. The losses may not cross in the evaluated range.")
else:
    # Unzip the (N, t_n, M/N) records into one array per coordinate
    N_array, t_n_array, M_over_N_array = (np.array(column) for column in zip(*all_crossover_data))

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Points are coloured by their crossover ratio as well as placed at that height
    scatter = ax.scatter(
        t_n_array, N_array, M_over_N_array,
        c=M_over_N_array, cmap='viridis',
        s=50, alpha=0.7, edgecolors='black', linewidth=0.5,
    )
    cbar = fig.colorbar(scatter, ax=ax, pad=0.1)
    cbar.set_label('M/N at Crossover', fontsize=11)

    ax.set_xlabel('$t_n$', fontsize=12, labelpad=10)
    ax.set_ylabel('$N$ (Number of Training Samples)', fontsize=12, labelpad=10)
    ax.set_zlabel('Crossover $M/N$', fontsize=12, labelpad=10)
    ax.set_title('Crossover Point Variation with $N$ and $t_n$', fontsize=14, pad=20)
    ax.set_yticks(N_values)  # one tick per investigated N
    ax.view_init(elev=20, azim=45)

    plt.tight_layout()
    plt.show()

    # Tabular summary of the plotted points
    print("\nCrossover points summary:")
    print(f"{'N':<6} {'t_n':<8} {'M/N':<10}")
    print("-" * 25)
    for N, t_n, m_over_n in all_crossover_data:
        print(f"{N:<6} {t_n:<8.3f} {m_over_n:<10.4f}")

In [None]:
# Alternative view: a surface over the full (N, t_n) grid. Grid cells whose
# (N, t_n) combination produced no crossover stay NaN; the crossover points
# themselves are overlaid in red.
if all_crossover_data:
    N_unique = np.array(N_values)
    t_n_unique = t_n_values

    # Rows follow N_values, columns follow t_n_values
    M_over_N_grid = np.full((len(N_unique), len(t_n_unique)), np.nan)
    for N, t_n, m_over_n in all_crossover_data:
        row = np.flatnonzero(N_unique == N)[0]
        col = np.flatnonzero(np.isclose(t_n_unique, t_n))[0]
        M_over_N_grid[row, col] = m_over_n

    T_n_grid, N_grid = np.meshgrid(t_n_unique, N_unique)

    fig = plt.figure(figsize=(14, 10))
    ax = fig.add_subplot(111, projection='3d')

    surf = ax.plot_surface(
        T_n_grid, N_grid, M_over_N_grid,
        cmap='viridis', alpha=0.8,
        linewidth=0, antialiased=True,
        edgecolor='none',
    )

    # Overlay the raw crossover points
    N_array, t_n_array, M_over_N_array = (np.array(column) for column in zip(*all_crossover_data))
    ax.scatter(
        t_n_array, N_array, M_over_N_array,
        c='red', s=30, alpha=1.0, edgecolors='black', linewidth=0.5,
    )

    cbar = fig.colorbar(surf, ax=ax, pad=0.1)
    cbar.set_label('M/N at Crossover', fontsize=11)

    ax.set_xlabel('$t_n$', fontsize=12, labelpad=10)
    ax.set_ylabel('$N$ (Number of Training Samples)', fontsize=12, labelpad=10)
    ax.set_zlabel('Crossover $M/N$', fontsize=12, labelpad=10)
    ax.set_title('Crossover Point Surface: Variation with $N$ and $t_n$', fontsize=14, pad=20)
    ax.set_yticks(N_values)  # one tick per investigated N
    ax.view_init(elev=25, azim=45)

    plt.tight_layout()
    plt.show()