In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = '0'
%cd ..

/home/adrish/dark-matter-halos-2


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
from src.pipeline.pipeline import *

import numpy as np
from pathlib import Path
import pickle
import ray
import gzip
import jax
from jax import numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, LinearSegmentedColormap
from matplotlib.cm import ScalarMappable


# NOTE(review): dashboard_host="0.0.0.0" exposes the Ray dashboard on every network
# interface of this machine; prefer the default (localhost) plus an SSH tunnel
# unless remote access is really required.
ray.init(dashboard_host="0.0.0.0", ignore_reinit_error=True)
rng = np.random.default_rng(seed=0)

# All paths are relative to the repository root (the working directory set by the first cell).
root_path = Path(".")
data_path = root_path / "data"
generated_data_path = root_path / "generated_data_extended"
cache_dir = generated_data_path / "cache"
# parents=True so a fresh checkout without generated_data_extended/ does not fail here.
cache_dir.mkdir(parents=True, exist_ok=True)

2025-02-16 15:35:32.421379: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.8.61). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
  self.pid = _posixsubprocess.fork_exec(
  self.pid = _posixsubprocess.fork_exec(
2025-02-16 15:35:34,452	INFO worker.py:1812 -- Started a local Ray instance. View the dashboard at [1m[32m10.210.1.81:8265 [39m[22m


In [3]:
# 12 log-spaced tau values from 10**-0.8 (~0.158) to 10**3 = 1000.
# Each value's exact float repr is embedded in the cached subsample filenames
# (kmeans_subsampled_tau{tau}_...), so changing this expression invalidates the cache.
taus = np.logspace(-0.8, 3, num=12, base=10)

In [4]:
# The halo point clouds, their velocities and the selected-halo feature/target
# table, stored together in a single gzip-compressed pickle.
FILE = 'halo_pointclouds_extended.pkl.gz'
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with gzip.open(data_path / FILE, mode="rb") as f:
    (points_list, velocities_list, halos_sel) = pickle.load(f)

# 12 subsample sizes, log-spaced between 2**2 = 4 and 2**7.2 (~147) points.
# These exact integers appear in the cached .npz filenames, so keep the formula stable.
downsample_sizes = np.round(np.logspace(2, 7.2, num=12, base=2.0)).astype(np.int32)
n_trials = 1

In [5]:
from src.subsampling.subsample import *

# Generate (and cache on disk) the k-means subsampled point clouds for every
# (subsample size, tau) pair that is not already present, plus the halo
# feature/target table used by the regression experiments below.
#
# NOTE(review): .npz files produced by earlier versions of this cell may carry the
# wrong tau in their filename (the tau index used to come from iterating an
# unordered set); consider deleting and regenerating them.


def _kmeans_subsample_path(tau, downsample_size):
    """Return the cache path of the subsample for one (tau, subsample size) pair.

    The filename embeds ``str(tau)`` / ``str(downsample_size)``, so numpy scalars and
    the equal Python numbers map to the same file.
    """
    return generated_data_path / f"kmeans_subsampled_tau{tau}_n{n_trials}_s{downsample_size}.npz"


# For each subsample size, the positions in `taus` whose file is still missing.
# Tracking (size, tau) *pairs*, not sizes and taus separately, ensures a pair that is
# missing for only one size still gets generated.
missing_tau_idx_by_size = {
    size: [
        tau_idx
        for tau_idx, tau in enumerate(taus.tolist())
        if not _kmeans_subsample_path(tau, size).exists()
    ]
    for size in downsample_sizes.tolist()
}
missing_tau_idx_by_size = {size: idxs for size, idxs in missing_tau_idx_by_size.items() if idxs}

if missing_tau_idx_by_size:
    # Bind the preprocessed clouds to a new name: rebinding `points_list` would
    # preprocess already-preprocessed clouds again if this cell were re-run.
    preprocessed_points_list = [preprocess_pointcloud(p) for p in points_list]

    for downsample_size, tau_indices in missing_tau_idx_by_size.items():
        print("Generating Samples for Downsample Size:", downsample_size)

        # Assumes one entry per element of `taus`, in the same order (the results
        # are indexed by tau position) — confirm against kmeans_downsample_points.
        sampled_positions_all, sampled_velocities_all, sampled_weights_all = kmeans_downsample_points(
            preprocessed_points_list, velocities_list, taus, downsample_size, n_trials, pbar=True
        )

        # Index the results by the tau's position in `taus`, so each file is named
        # after the tau its data was actually computed for.
        for tau_idx in tau_indices:
            jnp.savez(
                _kmeans_subsample_path(taus[tau_idx].item(), downsample_size),
                points=sampled_positions_all[tau_idx],
                weights=sampled_weights_all[tau_idx],
                velocities=sampled_velocities_all[tau_idx],
            )

features_and_targets_path = generated_data_path / "features_and_targets.npz"
if not features_and_targets_path.exists():
    print("Writing Features and Targets")
    jnp.savez(features_and_targets_path, **halos_sel)

### Running the experiment across different k-means subsample sizes

In [6]:

# Fixed train/test split over halo indices, persisted to disk so every run and
# every experiment below sees the same partition.
n_halos = 500  # NOTE(review): assumed to equal the number of halos in the dataset — confirm

# Keep the split seed fixed at 0. A dedicated name avoids shadowing the NumPy
# `rng` created in the setup cell with a JAX key.
split_key = jax.random.PRNGKey(0)
prop_train = 0.75
prop_test = 0.25  # informational only: the test set is always the complement of the train set

n_train = int(n_halos * prop_train)
n_test = n_halos - n_train

train_indices_path = generated_data_path / "train_indices.txt"
test_indices_path = generated_data_path / "test_indices.txt"

regenerated_train = not train_indices_path.exists()
if regenerated_train:
    # Convert to a plain int32 NumPy array so the type matches the loaded-from-disk branch.
    train_indices = np.asarray(
        jax.random.choice(split_key, n_halos, (n_train,), replace=False), dtype=np.int32
    )
    np.savetxt(train_indices_path, train_indices, fmt="%i")
    print("[TRAIN] Generating new indices")
else:
    train_indices = np.loadtxt(str(train_indices_path)).astype(np.int32)

# Rebuild the test set whenever the train set was regenerated; an old test file
# could otherwise overlap the new train indices.
if regenerated_train or not test_indices_path.exists():
    # Sorted complement of the train indices.
    test_indices = np.setdiff1d(np.arange(n_halos), train_indices).astype(np.int32)
    np.savetxt(test_indices_path, test_indices, fmt="%i")
    print("[TEST] Generating new indices")
else:
    test_indices = np.loadtxt(str(test_indices_path)).astype(np.int32)

assert np.intersect1d(train_indices, test_indices).size == 0, "train/test indices overlap"


In [7]:
# Map every (subsample size, tau) pair to the filename of its cached k-means
# subsample inside `generated_data_path`. Keys keep the numpy scalar types of
# `downsample_sizes` / `taus`; the filename pattern must match the one the files
# were written with.
downsampled_data = dict(
    ((size, tau_value), f"kmeans_subsampled_tau{tau_value}_n{n_trials}_s{size}.npz")
    for size in downsample_sizes
    for tau_value in taus
)

# Function to preprocess training data
def preprocess_data(features_and_targets, indices, subsampled_data, mass_range=(-11.1, 101.3)):
    """Select halos from ``indices`` by halo mass and gather their inputs and targets.

    A halo is kept when ``log10(Group_M_Crit200)`` lies strictly between
    ``mass_range[0]`` and ``mass_range[1]`` (the default range is very wide).

    Parameters
    ----------
    features_and_targets : mapping of str -> array
        Per-halo columns, indexed by halo index.
    indices : integer array
        Candidate halo indices (e.g. the train split).
    subsampled_data : mapping of str -> array
        Must provide "weights", "points" and "velocities"; axis 0 is the halo
        index and only index 0 of axis 1 is used (presumably the first trial).
    mass_range : (float, float)
        Exclusive bounds on log10 halo mass.

    Returns
    -------
    (processed_data, weights, points, velocities)
        ``processed_data`` maps target names to per-halo values for the kept halos.
    """
    lower, upper = mass_range[0], mass_range[1]
    log_halo_mass = np.log10(features_and_targets["Group_M_Crit200"][indices])
    keep = np.flatnonzero((log_halo_mass > lower) & (log_halo_mass < upper))
    selected = indices[keep]

    weights = subsampled_data["weights"][selected, 0]
    points = subsampled_data["points"][selected, 0]
    velocities = subsampled_data["velocities"][selected, 0]

    table = features_and_targets
    processed_data = {
        # NOTE(review): presumably converts stored units of 1e10 Msun/h to Msun with h = 0.677 — confirm.
        "stellar_mass": jnp.log10(table["StellarMass"][selected] * 1e10 / 0.677),
        "concentration": table["SubhaloC200"][selected],
        "mass": jnp.log10(table["Group_M_Crit200"][selected]),
        "stellar_metallicity": table["StellarMetallicity"][selected],
        "star_formation_rate": table["StarFormRate"][selected],
    }
    return processed_data, weights, points, velocities

# Function to create the problem context
def create_problem_context(data, weights, points, velocities, label):
    """Bundle subsampled halos and one regression target into a ``ProblemContext``.

    ``data`` is the dict produced by ``preprocess_data``; its "mass" and
    "concentration" entries are passed as masses/concentrations and ``data[label]``
    becomes the labels (KeyError if ``label`` is not one of its keys).
    """
    context_kwargs = dict(
        points=points,
        weights=weights,
        velocities=velocities,
        masses=data["mass"],
        concentrations=data["concentration"],
        labels=data[label],
    )
    return ProblemContext(**context_kwargs)

# Function to evaluate and collect loss arrays
def evaluate_loss(problem_context, hyperparams):
    """Run ``get_oat_losses`` with its inner progress bar disabled.

    Returns the ``(result_dict, loss_array)`` pair it produces.
    """
    outputs = get_oat_losses(problem_context, hyperparams, inner_pbar=False)
    results, losses = outputs
    return results, losses

In [8]:
def truncate_colormap(cmap, min_val=0.1, max_val=0.9, n_colors=256):
    """Return a colormap built from the ``[min_val, max_val]`` slice of ``cmap``.

    ``cmap`` is sampled at ``n_colors`` evenly spaced points in that interval and the
    samples become the colours of a new ``LinearSegmentedColormap`` named
    'truncated_cmap', dropping the extreme ends of the original map.
    """
    sample_points = np.linspace(min_val, max_val, n_colors)
    return LinearSegmentedColormap.from_list('truncated_cmap', cmap(sample_points))

def plot_variable_effect(loss_data, variable_values, n_neighbors, variable_name, xlabel, ylabel, title, save_path):
    """Plot mean loss against the number of neighbours, one curve per swept value.

    Parameters
    ----------
    loss_data : sequence of array-like
        ``loss_data[i]`` holds the losses for ``variable_values[i]``; it is averaged
        over its first axis and plotted against ``n_neighbors``.
    variable_values : sequence of float
        Swept values. They must be positive, because curve colours use a log scale.
    n_neighbors : array-like
        x-axis values (drawn on a log scale).
    variable_name, xlabel, ylabel, title : str
        Colour-bar label prefix and axis/figure labels.
    save_path : str or path-like
        Output file; the figure is closed after saving.
    """
    fig, ax = plt.subplots()

    # Colour each curve by its variable value on a log scale, using Blues minus its
    # palest and darkest ends.
    colormap = truncate_colormap(plt.cm.Blues, 0.1, 0.9)
    normalize = LogNorm(vmin=min(variable_values), vmax=max(variable_values))

    for idx, var_value in enumerate(variable_values):
        mean_losses = np.mean(loss_data[idx], axis=0)
        ax.plot(n_neighbors, mean_losses, color=colormap(normalize(var_value)))

    # The colour bar identifies the curves. The former ax.legend() had no labelled
    # artists and only emitted a "No artists with labels" warning.
    sm = ScalarMappable(cmap=colormap, norm=normalize)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label(f"{variable_name} (log-scale)")

    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
    ax.set_xscale("log")

    # Save/close this specific figure rather than relying on pyplot's "current figure".
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)

In [9]:
@ray.remote(num_gpus=1)
def process_downsampled_data_item(
    size,
    filename,
    train_indices,
    generated_data_path,
    fixed_hyperparams,
    variable_value,
    n_neighbors,
    label
):
    """Ray task (reserves one GPU): evaluate losses on one cached subsample file.

    Loads ``filename`` and ``features_and_targets.npz`` from ``generated_data_path``,
    restricts them to ``train_indices`` via ``preprocess_data``, builds a problem for
    target ``label`` and runs ``evaluate_loss``. It reads "p", "q", "alpha_C",
    "alpha_M" and "alpha_SLB" from ``fixed_hyperparams``.

    Returns ``(size, result_dict, loss_array)``. ``size`` is echoed back unchanged so
    the caller can sort results by subsample size.

    NOTE(review): ``variable_value`` is unused, and ``tau`` is hard-coded to 1. The
    swept tau enters only through which subsample file is loaded — confirm this is
    intended.
    """
    print("Processing File: ", filename)
    subsampled_data = jnp.load(generated_data_path / filename)
    features_and_targets = jnp.load(generated_data_path / "features_and_targets.npz")

    train_data, weights, points, velocities = preprocess_data(features_and_targets, train_indices, subsampled_data)
    problem_context = create_problem_context(train_data, weights, points, velocities, label)

    hyperparametrization = Hyperparametrization(
        rescale_strategy=["unitless"],
        tau=1,
        n_neighbors=n_neighbors,
        **{name: fixed_hyperparams[name] for name in ("p", "q", "alpha_C", "alpha_M", "alpha_SLB")},
    )

    result_dict, loss_array = evaluate_loss(problem_context, hyperparametrization)
    return size, result_dict, loss_array


def generate_results(
    train_indices,
    downsampled_data,
    hyperparam_ranges,
    variable_name,
    generated_data_path,
    label
):
    """Compute (or load from cache) the losses for each value of one swept hyperparameter.

    Every hyperparameter other than ``variable_name`` and ``n_neighbors`` is held at
    the first value of its list in ``hyperparam_ranges``. For ``variable_name ==
    "tau"`` the swept values are the distinct taus in ``downsampled_data``'s keys, and
    each value is evaluated on the subsample files with that tau (one Ray task per
    file).

    Returns
    -------
    loss_data : list of np.ndarray
        One array per swept value: the per-file loss arrays stacked in order of
        subsample size.
    result_data : list of list
        The matching per-file result dicts.
    variable_values : sequence
        The swept values, in the order used above.
    n_neighbors : array-like
        ``hyperparam_ranges["n_neighbors"]``.

    Notes
    -----
    Results are cached in ``cache_dir / label`` (``cache_dir`` is the notebook-level
    global) under a name built only from ``variable_name`` and the value. Cached
    files are NOT invalidated when the fixed hyperparameters, ``n_neighbors`` or
    ``train_indices`` change — delete that directory after changing them.
    """
    # Hold everything except the swept variable (and n_neighbors, which is passed
    # through as a whole grid) at its first listed value.
    ignore_variables = {variable_name, "n_neighbors"}
    fixed_hyperparams = {key: values[0] for key, values in hyperparam_ranges.items() if key not in ignore_variables}

    # Extract unique tau values from downsampled_data keys
    tau_values = sorted({key[1] for key in downsampled_data.keys()})
    variable_values = tau_values if variable_name == "tau" else hyperparam_ranges[variable_name]
    n_neighbors = hyperparam_ranges["n_neighbors"]

    label_cache_dir = cache_dir / label
    label_cache_dir.mkdir(parents=True, exist_ok=True)

    loss_data = []
    result_data = []
    for var_value in variable_values:
        cache_file = label_cache_dir / f"{variable_name}_{var_value}.npz"
        if cache_file.exists():
            print(f"Loading cached results for {variable_name} = {var_value}")
            # The context manager closes the .npz handle. The dicts were stored as an
            # object array (hence allow_pickle) and are turned back into a list so both
            # branches return the same type.
            with np.load(cache_file, allow_pickle=True) as cached_data:
                losses_for_var = cached_data["losses_for_var"]
                result_dict_for_var = list(cached_data["result_dict_for_var"])
        else:
            if variable_name == "tau":
                # tau is selected through which subsample files are processed.
                hyperparams_for_value = fixed_hyperparams
            else:
                # `fixed_hyperparams` deliberately omits the swept key, but
                # process_downsampled_data_item reads p/q/alpha_* from this dict,
                # so the current value must be supplied explicitly.
                hyperparams_for_value = {**fixed_hyperparams, variable_name: var_value}

            # Launch remote tasks
            result_refs = [
                process_downsampled_data_item.remote(
                    size,
                    filename,
                    train_indices,
                    generated_data_path,
                    hyperparams_for_value,
                    var_value,  # Pass tau or other variable value
                    n_neighbors,
                    label
                )
                for (size, tau), filename in downsampled_data.items()
                if tau == var_value or variable_name != "tau"  # Match tau when processing tau
            ]

            # Collect results, ordered by subsample size
            results = ray.get(result_refs)
            results.sort(key=lambda x: x[0])
            losses_for_var = np.stack([loss_array for _, _, loss_array in results])
            result_dict_for_var = [result_dict for _, result_dict, _ in results]

            np.savez(cache_file, losses_for_var=losses_for_var, result_dict_for_var=result_dict_for_var)
            print(f"Cached results for {variable_name} = {var_value}")

        loss_data.append(losses_for_var)
        result_data.append(result_dict_for_var)

    return loss_data, result_data, variable_values, n_neighbors

In [10]:
def plot_results_minimum(
    loss_data,
    variable_values,
    n_neighbors,
    variable_name,
    output_dir,
    regression_type,
    find_minimum=False
):
    """Save the loss-vs-k plot for one regression target as a PDF in ``output_dir``.

    The file is named ``{regression_type}_{variable_name}_effect[_minimum].pdf``.
    With ``find_minimum=True``, only the swept value whose loss array contains the
    smallest loss is plotted, and ``_minimum`` is appended to the name.
    """
    os.makedirs(output_dir, exist_ok=True)

    suffix = ""
    if find_minimum:
        # First index whose minimum is strictly smallest. A NaN minimum never wins,
        # and index 0 is kept if nothing beats +inf.
        best_idx = 0
        best_loss = float('inf')
        for candidate_idx, candidate_losses in enumerate(loss_data):
            candidate_min = np.min(candidate_losses)
            if candidate_min < best_loss:
                best_idx, best_loss = candidate_idx, candidate_min
        loss_data = loss_data[best_idx:best_idx + 1]
        variable_values = variable_values[best_idx:best_idx + 1]
        suffix = "_minimum"

    target_pdf = os.path.join(output_dir, f"{regression_type}_{variable_name}_effect{suffix}.pdf")
    plot_variable_effect(
        loss_data,
        variable_values,
        n_neighbors,
        variable_name,
        xlabel="Number of Neighbors ($k$)",
        ylabel="RMSE Loss",
        title=f"{regression_type} ({variable_name.capitalize()} Effect)",
        save_path=target_pdf,
    )


def process_and_plot_variable(
    train_indices,
    downsampled_data,
    hyperparam_ranges,
    variable_name,
    output_dir,
    regression_type,
    generated_data_path,
    label,
    find_minimum=False
):
    """Run (or load cached) losses for a sweep of ``variable_name`` and save the plot.

    Thin composition of ``generate_results`` and ``plot_results_minimum``. The
    per-file result dicts are not used here, and the function returns None.
    """
    loss_data, _unused_result_data, variable_values, n_neighbors = generate_results(
        train_indices, downsampled_data, hyperparam_ranges,
        variable_name, generated_data_path, label,
    )

    plot_results_minimum(
        loss_data, variable_values, n_neighbors, variable_name,
        output_dir, regression_type, find_minimum=find_minimum,
    )

In [11]:
# Neighbour counts k evaluated for every configuration: 10, 13, ..., 22.
# (An earlier, wider grid was np.logspace(0, 5.01, num=10, base=np.e).round().astype(np.int32).)
n_neighbors = np.arange(10, 23, 3, dtype=np.int32)
print(n_neighbors)

[10 13 16 19 22]


### Tau Variability

In [12]:
# Per-target hyperparameters. "tau" is the swept variable, and "n_neighbors" is
# passed through whole; generate_results holds every other entry at the first
# (here: only) value of its list.
stellar_mass_hyperparam_ranges = {
    "p": [1.0],
    "q": [1.0],
    "tau": taus,
    "alpha_C": [14.672736961511486],
    "alpha_M": [90.9960841116981],
    "alpha_SLB": [56.26711116650133],
    "n_neighbors": n_neighbors,
}

stellar_metallicity_hyperparam_ranges = {
    "p": [1.0],
    "q": [1.0],
    "tau": taus,
    "alpha_C": [41.38652023058744],
    "alpha_M": [94.95040630340328],
    "alpha_SLB": [70.89754616927755],
    "n_neighbors": n_neighbors,
}

star_formation_rate_hyperparam_ranges = {
    "p": [2.0],
    "q": [2.0],
    "tau": taus,
    "alpha_C": [22.954853753245022],
    "alpha_M": [94.96140687714441],
    "alpha_SLB": [14.423051676594431],
    "n_neighbors": n_neighbors,
}

# One row per regression target:
# (console header, hyperparameter ranges, output directory, plot title prefix, target key)
tau_experiments = [
    ("STELLAR MASS", stellar_mass_hyperparam_ranges,
     "mass_experiment_results", "Stellar Mass", "stellar_mass"),
    ("STELLAR METALLICITY", stellar_metallicity_hyperparam_ranges,
     "metallicity_experiment_results", "Stellar Metallicity", "stellar_metallicity"),
    ("STAR FORMATION", star_formation_rate_hyperparam_ranges,
     "star_formation_rate_experiment_results", "Star Formation Rate", "star_formation_rate"),
]

for header, ranges_for_target, target_output_dir, target_regression_type, target_label in tau_experiments:
    print(header)
    process_and_plot_variable(
        train_indices,
        downsampled_data,
        ranges_for_target,
        variable_name="tau",
        output_dir=target_output_dir,
        regression_type=target_regression_type,
        generated_data_path=generated_data_path,
        label=target_label,
        find_minimum=False,
    )

STELLAR MASS
Loading cached results for tau = 0.15848931924611134
Loading cached results for tau = 0.3511191734215131
Loading cached results for tau = 0.7778737048694307
Loading cached results for tau = 1.7233109056135083
Loading cached results for tau = 3.817844026370506
Loading cached results for tau = 8.458098281751338
Loading cached results for tau = 18.73817422860384
Loading cached results for tau = 41.51278002752293
Loading cached results for tau = 91.96791985117059
Loading cached results for tau = 203.74685280397102
Loading cached results for tau = 451.3832659768975
Loading cached results for tau = 1000.0


  ax.legend(loc="upper right")


STELLAR METALLICITY
Loading cached results for tau = 0.15848931924611134
Loading cached results for tau = 0.3511191734215131
Loading cached results for tau = 0.7778737048694307
Loading cached results for tau = 1.7233109056135083
Loading cached results for tau = 3.817844026370506
Loading cached results for tau = 8.458098281751338
Loading cached results for tau = 18.73817422860384
Loading cached results for tau = 41.51278002752293
Loading cached results for tau = 91.96791985117059
Loading cached results for tau = 203.74685280397102
Loading cached results for tau = 451.3832659768975
Loading cached results for tau = 1000.0


  ax.legend(loc="upper right")


STAR FORMATION
Loading cached results for tau = 0.15848931924611134
Loading cached results for tau = 0.3511191734215131
Loading cached results for tau = 0.7778737048694307
Loading cached results for tau = 1.7233109056135083
Loading cached results for tau = 3.817844026370506
Loading cached results for tau = 8.458098281751338
Loading cached results for tau = 18.73817422860384
Loading cached results for tau = 41.51278002752293
Loading cached results for tau = 91.96791985117059
Loading cached results for tau = 203.74685280397102
Loading cached results for tau = 451.3832659768975
Loading cached results for tau = 1000.0


  ax.legend(loc="upper right")


In [14]:
# Stop the local Ray instance started with ray.init in the setup cell.
ray.shutdown()