In [16]:
import numpy as np
from sklearn.decomposition import TruncatedSVD
from deep_generative_models.dataset import create_dataloader
from config.paths import CELL_DATA

def analyze_latent_dimensions(tile_size, batch_size=32, tiles_per_epoch=1000, hdf5_file_path=CELL_DATA, brains=["B20", "B02", "B01", "B05"]):
    """Analyzes optimal latent dimensions using incremental SVD analysis."""
    
    # Initialize dataloader
    dataloader = create_dataloader(
        hdf5_file_path, brains, tile_size, batch_size, 
        tiles_per_epoch, num_workers=0
    )
    
    # Calculate total dimensions
    sample_batch = next(iter(dataloader))
    total_dims = np.prod(sample_batch.shape[1:])
    
    # Prepare data matrix efficiently
    data_matrix = []
    total_samples = 0
    max_samples = 1000  # Limit samples for memory efficiency
    
    for batch in dataloader:
        if total_samples >= max_samples:
            break
        flat_batch = batch.numpy().reshape(batch.shape[0], -1)
        data_matrix.append(flat_batch)
        total_samples += batch.shape[0]
    
    data_matrix = np.vstack(data_matrix)
    
    # Analyze variance with increasing components
    target_variance = 0.95  # 95% explained variance threshold
    n_components_list = [32, 64, 128, 256, 512]
    results = {}
    
    for n_comp in n_components_list:
        svd = TruncatedSVD(n_components=min(n_comp, total_dims - 1))
        svd.fit(data_matrix)
        cumulative_variance = np.cumsum(svd.explained_variance_ratio_)
        results[n_comp] = cumulative_variance[-1]
        
        if cumulative_variance[-1] >= target_variance:
            # Find exact number of components needed
            for i, var in enumerate(cumulative_variance):
                if var >= target_variance:
                    return {
                        'optimal_dimensions': i + 1,
                        'explained_variance': var,
                        'all_tested_dimensions': results
                    }
    
    return {
        'optimal_dimensions': n_components_list[-1],
        'explained_variance': results[n_components_list[-1]],
        'all_tested_dimensions': results,
        'warning': 'Target variance not reached, may need more dimensions'
    }
    
    
def print_results(results):
    print(f"Optimal latent dimensions: {results['optimal_dimensions']}")
    print(f"Explained variance: {results['explained_variance']:.3f}")
    print("\nTested dimensions and their explained variance:")
    for dims, var in results['all_tested_dimensions'].items():
        print(f"{dims} dimensions: {var:.3f}")




In [17]:
results = analyze_latent_dimensions(
    hdf5_file_path=CELL_DATA,
    tile_size=64
)

print_results(results)

Optimal latent dimensions: 150
Explained variance: 0.950

Tested dimensions and their explained variance:
32 dimensions: 0.781
64 dimensions: 0.874
128 dimensions: 0.940
256 dimensions: 0.976


In [18]:
results = analyze_latent_dimensions(
    hdf5_file_path=CELL_DATA,
    tile_size=128
)

print_results(results)

Optimal latent dimensions: 376
Explained variance: 0.950

Tested dimensions and their explained variance:
32 dimensions: 0.610
64 dimensions: 0.707
128 dimensions: 0.815
256 dimensions: 0.910
512 dimensions: 0.973


In [19]:
results = analyze_latent_dimensions(
    hdf5_file_path=CELL_DATA,
    tile_size=256
)

print_results(results)

Optimal latent dimensions: 512
Explained variance: 0.920

Tested dimensions and their explained variance:
32 dimensions: 0.519
64 dimensions: 0.587
128 dimensions: 0.682
256 dimensions: 0.800
512 dimensions: 0.920
