In [1]:
import matplotlib.pyplot as plt
import acm.observables.emc as emc
from pathlib import Path
import numpy as np
import torch
plt.style.use(['science','no-latex'])

In [8]:
stat_map = {
    'number_density': emc.GalaxyNumberDensity(
        select_filters={'cosmo_idx': [0], 'hod_idx': [30,], 'multipoles': [0, 2]},
        slice_filters={},
    ),
    'wp': emc.GalaxyProjectedCorrelationFunction(
        select_filters={'cosmo_idx': [0], 'hod_idx': [30,], 'multipoles': [0, 2]},
        slice_filters={}),
    'tpcf': emc.GalaxyCorrelationFunctionMultipoles(
        select_filters={'cosmo_idx': [0], 'hod_idx': [30,], 'multipoles': [0, 2]},
        slice_filters={}),
    'pk': emc.GalaxyPowerSpectrumMultipoles(
        select_filters={'cosmo_idx': [0], 'hod_idx': [30,]}, 
        slice_filters={}),
    'bk': emc.GalaxyBispectrumMultipoles(
        select_filters={'cosmo_idx': [0], 'hod_idx': [30,]}, 
        slice_filters={}),
    'wst': emc.WaveletScatteringTransform(
        select_filters={'cosmo_idx': [0], 'hod_idx': [30,]}, 
        slice_filters={}),
    'dt_voids': emc.DTVoidGalaxyCorrelationFunctionMultipoles(
        select_filters={'cosmo_idx': [0], 'hod_idx': [30,]}, 
        slice_filters={}),
    'dsc_pk': emc.DensitySplitPowerSpectrumMultipoles(
        select_filters={'cosmo_idx': [0], 'hod_idx': [30,]}, 
        slice_filters={}),
    'minkowski': emc.MinkowskiFunctionals(
        select_filters={'cosmo_idx': [0], 'hod_idx': [30,]}, 
        slice_filters={}),
}

In [4]:

import torch 
import torch.func as func

def get_gradient(statistic):
    fiducial_parameters = stat_map[statistic].lhc_x
    fiducial_parameters = torch.tensor(fiducial_parameters.astype(np.float32), requires_grad=True,).unsqueeze(0)
    def model_fn(x_batch):
        # Add batch dimension for the model
        return stat_map[statistic].model.get_prediction(x_batch)
    return func.jacrev(model_fn)(fiducial_parameters).detach().squeeze().numpy()

def get_full_gradients(statistics):
    return np.vstack([get_gradient(stat) for stat in statistics],)

def get_precision_matrix(full_covariance_vector):
    if full_covariance_vector.shape[1] == 1:
        # For 1D case, calculate variance manually and force a 1x1 matrix
        variance = np.var(full_covariance_vector, ddof=1)
        covariance_matrix = np.array([[variance]])
    else:
        covariance_matrix = np.cov(full_covariance_vector.T)
    correction = stat_map['tpcf'].get_covariance_correction(
        n_s=full_covariance_vector.shape[0],
        n_d=len(covariance_matrix),
        n_theta=20,
        method='percival',
    )
    precision_matrix = np.linalg.inv(correction * covariance_matrix)
    return precision_matrix


def get_fisher_log_det(statistics,):
    full_covariance_vector = np.hstack([stat_map[stat].create_small_box_y()[1] for stat in statistics])
    precision_matrix = get_precision_matrix(full_covariance_vector)
    gradients = get_full_gradients(statistics)
    fisher_matrix = np.dot(gradients.T, np.dot(precision_matrix, gradients))
    sign, fisher_log_det = np.linalg.slogdet(fisher_matrix)
    return fisher_log_det

In [5]:
get_fisher_log_det(['tpcf'])

46.38607297925267

In [6]:
get_fisher_log_det(['tpcf','bk'])

80.10231256664001

In [7]:
get_fisher_log_det(['tpcf','bk', 'minkowski'])

AttributeError: 'MinkowskiFunctionals' object has no attribute 'create_small_box_y'

In [9]:

def precompute_derivatives_and_covariance(statistics=['tpcf', 'bk']):
    precomputed = {}
    precomputed['derivatives'] = {}
    for stat_name in statistics:
        precomputed['derivatives'][stat_name] = get_gradient(stat_name)
    
    precomputed['covariance_data'] = {}
    for stat_name in statistics:
        precomputed['covariance_data'][stat_name] = stat_map[stat_name].create_small_box_y()[1]
    
    precomputed['bin_counts'] = {
        stat_name: precomputed['derivatives'][stat_name].shape[1] 
        for stat_name in statistics
    }
    return precomputed

In [10]:
precomputed = precompute_derivatives_and_covariance(statistics=['tpcf', 'bk'])

In [11]:
precomputed['derivatives']['tpcf'].shape, precomputed['derivatives']['bk'].shape

((150, 20), (478, 20))

In [14]:


def greedy_bin_selection(precomputed, max_bins=10,):
    derivatives = precomputed['derivatives']
    covariance_data = precomputed['covariance_data']
    statistics = list(derivatives.keys())
    
    available_bins = {
        stat_name: list(range(derivatives[stat_name].shape[0]))
        for stat_name in statistics
    }
    selected_bins = {stat_name: [] for stat_name in statistics}
    
    current_log_det = float('-inf')
    total_selected_bins = 0
    
    all_bins = []
    for stat_name in statistics:
        for bin_idx in available_bins[stat_name]:
            all_bins.append((stat_name, bin_idx))
    
    print(f"Total bins to evaluate: {len(all_bins)}")
    
    while total_selected_bins < max_bins and all_bins:
        best_bin = None
        best_log_det = current_log_det
        best_improvement = 0
        
        for i, (stat_name, bin_idx) in enumerate(all_bins):
            temp_selected = {
                stat: selected_bins[stat].copy() for stat in statistics
            }
            temp_selected[stat_name].append(bin_idx)
            
            new_log_det = evaluate_selected_bins_precomputed(
                temp_selected, 
                derivatives, 
                covariance_data
            )
            
            improvement = new_log_det - current_log_det
            if improvement > best_improvement:
                best_improvement = improvement
                best_log_det = new_log_det
                best_bin = i

        
        if best_bin is None or best_improvement < 1e-6:
            print("No significant improvement found, stopping early")
            break
        
        stat_name, bin_idx = all_bins.pop(best_bin)
        selected_bins[stat_name].append(bin_idx)
        total_selected_bins += 1
        current_log_det = best_log_det
        
        if total_selected_bins % 5 == 0 or total_selected_bins == max_bins:
            print(f"Selected {total_selected_bins}/{max_bins} bins, current log-det: {current_log_det:.4f}")
            distribution = ", ".join([f"{stat}: {len(bins)}" for stat, bins in selected_bins.items()])
            print(f"Distribution: {distribution}")
    
    return selected_bins, current_log_det

def evaluate_selected_bins_precomputed(selected_bins, derivatives, covariance_data):

    selected_cov_data = []
    selected_grads = []
    
    for stat_name, bin_indices in selected_bins.items():
        if not bin_indices:
            continue
            
        full_cov_data = covariance_data[stat_name]
        selected_cov_data.append(full_cov_data[:, bin_indices])
        
        full_derivatives = derivatives[stat_name]
        selected_grads.append(full_derivatives[bin_indices,:])
    
    full_covariance_vector = np.hstack(selected_cov_data)
    gradients = np.vstack(selected_grads)
    precision_matrix = get_precision_matrix(full_covariance_vector)
    fisher_matrix = gradients.T @ precision_matrix @ gradients
    
    sign, log_det = np.linalg.slogdet(fisher_matrix)
    if sign <= 0:
        return float('-inf')
    
    return log_det

def run_optimization(statistics=['tpcf', 'bk'], max_bins=100):
    print(f"Precomputing data for statistics: {statistics}")
    precomputed_data = precompute_derivatives_and_covariance(statistics)
    
    print(f"\nRunning greedy selection with max_bins={max_bins}")
    selected_bins, final_log_det = greedy_bin_selection(precomputed_data, max_bins=max_bins)
    
    # Print final results
    print("\nFinal selection:")
    for stat, bins in selected_bins.items():
        print(f"{stat}: {len(bins)} bins selected")
    print(f"Total: {sum(len(bins) for bins in selected_bins.values())} bins")
    print(f"Final log-determinant: {final_log_det:.4f}")
    
    return selected_bins, final_log_det

In [15]:
selected_bins, final_log_det = run_optimization(statistics=['tpcf', 'bk'], max_bins=20)

Precomputing data for statistics: ['tpcf', 'bk']

Running greedy selection with max_bins=20
Total bins to evaluate: 628
Selected 5/20 bins, current log-det: -401.9209
Distribution: tpcf: 4, bk: 1
Selected 10/20 bins, current log-det: -238.6040
Distribution: tpcf: 7, bk: 3
Selected 15/20 bins, current log-det: -90.0780
Distribution: tpcf: 9, bk: 6
Selected 20/20 bins, current log-det: 43.3513
Distribution: tpcf: 11, bk: 9

Final selection:
tpcf: 11 bins selected
bk: 9 bins selected
Total: 20 bins
Final log-determinant: 43.3513


In [16]:
selected_bins, final_log_det = run_optimization(statistics=['tpcf', 'bk'], max_bins=100)

Precomputing data for statistics: ['tpcf', 'bk']

Running greedy selection with max_bins=100
Total bins to evaluate: 628
Selected 5/100 bins, current log-det: -401.9209
Distribution: tpcf: 4, bk: 1
Selected 10/100 bins, current log-det: -238.6040
Distribution: tpcf: 7, bk: 3
Selected 15/100 bins, current log-det: -90.0780
Distribution: tpcf: 9, bk: 6
Selected 20/100 bins, current log-det: 43.3513
Distribution: tpcf: 11, bk: 9
Selected 25/100 bins, current log-det: 54.0682
Distribution: tpcf: 12, bk: 13
Selected 30/100 bins, current log-det: 58.8847
Distribution: tpcf: 14, bk: 16
Selected 35/100 bins, current log-det: 61.9713
Distribution: tpcf: 15, bk: 20
Selected 40/100 bins, current log-det: 64.3042
Distribution: tpcf: 15, bk: 25
Selected 45/100 bins, current log-det: 66.0838
Distribution: tpcf: 15, bk: 30
Selected 50/100 bins, current log-det: 67.8263
Distribution: tpcf: 17, bk: 33
Selected 55/100 bins, current log-det: 69.0499
Distribution: tpcf: 17, bk: 38
Selected 60/100 bins, cu