### Check for integer-typed parameters

This notebook looks for model parameters that SVI needs as floating-point values but that were incorrectly initialized as integers.

In [1]:
from matplotlib import pyplot as plt
from jax import numpy as jnp
import numpyro
numpyro.set_host_device_count(8)
import jax
import numpy as np

from tfscreen.analysis.hierarchical.growth_model import GrowthModel
import tfscreen

from numpyro.infer.util import initialize_model
from jax import random
import jax.numpy as jnp


# Genotypes to keep: wild type plus a handful of point mutants. Only rows for
# these genotypes are passed to the model (presumably to keep it small while
# debugging).
to_get_list = ["wt", "M42I", "H74A", "K84L", "I64N", "L45P", "I79C", "T68V", "A81C"]

# Load each table, keep only the selected genotypes, and renumber rows 0..n-1.
growth_df = tfscreen.util.read_dataframe("growth.csv")
growth_df_subset = (
    growth_df
    .loc[growth_df["genotype"].isin(to_get_list)]
    .reset_index(drop=True)
)

bind_df = tfscreen.util.read_dataframe("binding.csv")
bind_df_subset = (
    bind_df
    .loc[bind_df["genotype"].isin(to_get_list)]
    .reset_index(drop=True)
)

# Build the model on the subsets. The string options select model components;
# their meanings are defined by GrowthModel itself.
gm = GrowthModel(
    growth_df=growth_df_subset,
    binding_df=bind_df_subset,
    theta="hill",
    condition_growth="hierarchical",
    theta_binding_noise="none",
    theta_growth_noise="none",
    activity="fixed",
)



# Fixed PRNG key so the debug initialization is reproducible.
debug_key = random.PRNGKey(8675309)

# The model kwargs below mirror what is passed to SVI later in this notebook
# (data=gm.data, priors=gm.priors).

print("\n========== DEBUGGING PARAMETERS ==========")
try:
    # Trace gm.jax_model_guide and obtain its initial (unconstrained) values.
    # NOTE(review): this traces the guide as if it were the model; confirm that
    # is the intended target of the check.
    param_info, *_ = initialize_model(
        debug_key,
        model=gm.jax_model_guide,
        model_args=(),
        model_kwargs={"data": gm.data, "priors": gm.priors},
    )

    found_int = False
    print(f"{'PARAMETER NAME':<30} | {'DTYPE':<10} | {'SHAPE'}")
    print("-" * 60)

    # param_info is a namedtuple; some fields are dicts of per-site arrays and
    # others are single arrays. Single arrays are labelled with their field name.
    for field_name, field_value in zip(param_info._fields, param_info):
        if hasattr(field_value, "items"):
            entries = field_value.items()
        else:
            entries = [(field_name, field_value)]

        for name, val in entries:
            print(f"{name:<30} | {str(val.dtype):<10} | {val.shape}")

            # Flag any signed or unsigned integer dtype.
            if jnp.issubdtype(val.dtype, jnp.integer):
                found_int = True
                print(f"   >>> 🚨 FOUND INT PARAMETER: {name}")

    # Report the verdict once, after every field has been scanned.
    print("-" * 60)
    if not found_int:
        print("✅ No integer parameters found in the params dict.")
    else:
        print("❌ CRITICAL: Integer parameters detected. SVI will crash on these.")

except Exception as e:
    # Debug cell: report and show the traceback rather than halting the notebook.
    print(f"Crash during model initialization debug: {e}")
    # If it crashes here, the issue is inside the guide/model trace generation itself
    import traceback
    traceback.print_exc()

print("==========================================\n")


PARAMETER NAME                 | DTYPE      | SHAPE
------------------------------------------------------------
theta_logit_low_hyper_loc      | float32    | ()
theta_logit_low_hyper_scale    | float32    | ()
theta_logit_delta_hyper_loc    | float32    | ()
theta_logit_delta_hyper_scale  | float32    | ()
theta_log_hill_K_hyper_loc     | float32    | ()
theta_log_hill_K_hyper_scale   | float32    | ()
theta_log_hill_n_hyper_loc     | float32    | ()
theta_log_hill_n_hyper_scale   | float32    | ()
theta_logit_low_offset         | float32    | (1, 9)
theta_logit_delta_offset       | float32    | (1, 9)
theta_log_hill_K_offset        | float32    | (1, 9)
theta_log_hill_n_offset        | float32    | (1, 9)
condition_growth_k_hyper_loc   | float32    | ()
condition_growth_k_hyper_scale | float32    | ()
condition_growth_m_hyper_loc   | float32    | ()
condition_growth_m_hyper_scale | float32    | ()
condition_growth_k_offset      | float32    | (8,)
condition_growth_m_offset      | fl

In [2]:
# --- DEBUG BLOCK START ---
# Build a throwaway SVI object around the model and guide, then pull its
# initial parameters so they can be scanned for integer dtypes below.
print("\n🔍 STARTING DEEP SVI PARAMETER INSPECTION")
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

# Build an SVI object (the optimizer is only needed to construct it).
debug_optim = Adam(0.001)
debug_svi = SVI(gm.jax_model,
                gm.jax_model_guide,
                debug_optim,
                loss=Trace_ELBO())

# Initialize with the same kwargs the model is normally called with.
debug_state = debug_svi.init(jax.random.PRNGKey(0), data=gm.data, priors=gm.priors)

# Extract parameters directly from the SVI state.
debug_params = debug_svi.get_params(debug_state)

# Recursive search for integers (handles nested Flax dicts)
def find_ints(tree, path=""):
    
    if hasattr(tree, 'dtype'):
        # It's an array/tensor
        if "int" in str(tree.dtype):
            print(f"üö® FOUND INT PARAMETER! | Path: {path} | Dtype: {tree.dtype} | Shape: {tree.shape}")
            return True
        else:
            print(f"‚úÖ Float Param: {path} ({tree.dtype})")
            return False
            
    elif isinstance(tree, dict):
        found = False
        for k, v in tree.items():
            if find_ints(v, path=f"{path}.{k}" if path else k):
                found = True
        return found
        
    else:
        # Check for Flax FrozenDict or other containers
        if hasattr(tree, 'items'): 
            found = False
            for k, v in tree.items():
                if find_ints(v, path=f"{path}.{k}" if path else k):
                    found = True
            return found
        return False

print("Scanning SVI params for integers...")
found_any = find_ints(debug_params)

# find_ints prints each offending leaf as it goes; this is just the verdict.
if found_any:
    print("🔥 Found the culprit above.")
else:
    print("✅ Params look clean.")




üîç STARTING DEEP SVI PARAMETER INSPECTION
Scanning SVI params for integers...
‚úÖ Float Param: condition_growth_k_hyper_loc_loc (float32)
‚úÖ Float Param: condition_growth_k_hyper_loc_scale (float32)
‚úÖ Float Param: condition_growth_k_hyper_scale_loc (float32)
‚úÖ Float Param: condition_growth_k_hyper_scale_scale (float32)
‚úÖ Float Param: condition_growth_k_offset_locs (float32)
‚úÖ Float Param: condition_growth_k_offset_scales (float32)
‚úÖ Float Param: condition_growth_m_hyper_loc_loc (float32)
‚úÖ Float Param: condition_growth_m_hyper_loc_scale (float32)
‚úÖ Float Param: condition_growth_m_hyper_scale_loc (float32)
‚úÖ Float Param: condition_growth_m_hyper_scale_scale (float32)
‚úÖ Float Param: condition_growth_m_offset_locs (float32)
‚úÖ Float Param: condition_growth_m_offset_scales (float32)
‚úÖ Float Param: dk_geno_hyper_loc_loc (float32)
‚úÖ Float Param: dk_geno_hyper_loc_scale (float32)
‚úÖ Float Param: dk_geno_hyper_scale_loc (float32)
‚úÖ Float Param: dk_geno_hyper_scale