In [2]:
import os
# Configure JAX to use CPU for easier debugging
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp
import math
import time
import sys
import copy
import functools
from typing import Dict, Any

import flax.linen as nn
from ml_collections import config_dict
from clu import parameter_overview
from etils import epath

# Add the md4 module to the path
sys.path.append('/mnt/workspace/md4')

from md4.configs.md4 import molecular_finetune, molecular
from md4.models import utils as model_utils
from md4 import train
from md4 import partial_load_utils

print(f"JAX devices: {jax.devices()}")
print(f"JAX default backend: {jax.default_backend()}")
print(f"Number of devices: {len(jax.devices())}")
print("=" * 60)

  from .autonotebook import tqdm as notebook_tqdm


JAX devices: [CpuDevice(id=0)]
JAX default backend: cpu
Number of devices: 1


In [3]:
# Load both configurations for comparison
new_config = molecular_finetune.get_config()
old_config = molecular.get_config()

print("🔍 CONFIGURATION COMPARISON")
print("=" * 60)

def compare_configs(old_cfg, new_cfg, prefix=""):
    """Compare two configurations and highlight differences."""
    differences = []
    
    for key in set(list(old_cfg.keys()) + list(new_cfg.keys())):
        old_val = getattr(old_cfg, key, "<MISSING>")
        new_val = getattr(new_cfg, key, "<MISSING>")
        
        if old_val != new_val:
            differences.append((key, old_val, new_val))
    
    return differences

differences = compare_configs(old_config, new_config)

print("Key differences between old and new configs:")
for key, old_val, new_val in differences:
    print(f"  {key}: {old_val} → {new_val}")

print(f"\nOLD CONFIG (molecular.py):")
print(f"  Model type: {old_config.model_type}")
print(f"  Dataset: {old_config.dataset}")
print(f"  Vocab size: {old_config.vocab_size}")
print(f"  Feature dim: {old_config.feature_dim}")
print(f"  Fingerprint dim: {old_config.fingerprint_dim}")
print(f"  Frozen mode: {old_config.get('frozen', False)}")

print(f"\nNEW CONFIG (molecular_finetune.py):")
print(f"  Model type: {new_config.model_type}")
print(f"  Dataset: {new_config.dataset}")
print(f"  Vocab size: {new_config.vocab_size}")
print(f"  Feature dim: {new_config.feature_dim}")
print(f"  Fingerprint dim: {new_config.fingerprint_dim}")
print(f"  Frozen mode: {new_config.get('frozen', False)}")
print(f"  Partial load: {new_config.get('partial_load', False)}")

print("=" * 60)

🔍 CONFIGURATION COMPARISON
Key differences between old and new configs:
  warmup_steps: 2000 → 1000
  frozen: <MISSING> → True
  num_eval_steps: 1000 → 100
  eval_every_steps: 20000 → 2000
  checkpoint_every_steps: 20000 → 2000
  partial_load: <MISSING> → True
  checkpoint_keep_period: 200000 → 6000
  learning_rate: 0.0003 → 1e-05
  learning_rate_schedule: cosine → constant
  frozen_paths: <MISSING> → []
  old_config: <MISSING> → md4/configs/md4/molecular.py
  num_train_steps: 1500000 → 100000
  dataset: pubchem_large → msg_finetune
  adapter_init_paths: <MISSING> → ['fp_adapter']
  unfrozen_paths: <MISSING> → ['fp_adapter']
  weight_decay: 1e-06 → 0.0
  fingerprint_adapter: <MISSING> → True
  raw_fingerprint_dim: <MISSING> → 4096

OLD CONFIG (molecular.py):
  Model type: md4
  Dataset: pubchem_large
  Vocab size: 1024
  Feature dim: 64
  Fingerprint dim: 2048
  Frozen mode: False

NEW CONFIG (molecular_finetune.py):
  Model type: md4
  Dataset: msg_finetune
  Vocab size: 1024
  Featur

In [4]:
# Create models for both configurations
print("🏗️  MODEL CREATION")
print("=" * 60)

old_model = model_utils.get_model(old_config)
new_model = model_utils.get_model(new_config)

print(f"Old model: {type(old_model)}")
print(f"New model: {type(new_model)}")

# Set up common parameters
per_device_batch_size = 4  # Small batch for testing
data_shape = (new_config.max_length,)
num_train_steps = 1000000

schedule_fn = functools.partial(
    train.get_learning_rate,
    base_learning_rate=new_config.learning_rate,
    num_steps=num_train_steps,
    warmup_steps=new_config.warmup_steps,
    schedule_type=new_config.learning_rate_schedule,
)

print(f"Data shape: {data_shape}")
print(f"Per device batch size: {per_device_batch_size}")
print("=" * 60)

🏗️  MODEL CREATION
Old model: <class 'md4.models.diffusion.md4.MD4'>
New model: <class 'md4.models.diffusion.md4.MD4'>
Data shape: (128,)
Per device batch size: 4


In [6]:
new_rng = jax.random.PRNGKey(42)  # Same seed for reproducibility
import md4.state_utils as state_utils
new_model, new_optimizer, new_train_state, new_metrics_class = state_utils.create_train_state(
    new_config,
    new_rng,
    input_shape=(per_device_batch_size,) + data_shape,
    schedule_fn=schedule_fn,
)
print(f"✅ New FROZEN train state created (step: {new_train_state.step})")

✅ New FROZEN train state created (step: 0)


In [7]:
new_train_state.params["cond_embeddings"]

{'cond_dense_0': {'bias': Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
  'kernel': Array([[-2.0289680e-02,  2.5764512e-02, -5.0052800e-03, ...,
           3.9162129e-02,  4.2294681e-02,  2.7065177e-02],
         [ 2.0283964e-05,  1.2533314e-02,  4.0390089e-02, ...,
           8.4358966e-03,  1.7612189e-02, -2.7960740e-02],
         [ 2.8920980e-02, -3.3317164e-02,  1.8416533e-02, ...,
          -5.2914713e-03,  4.3843381e-02, -7.5638820e-03],
         ...,
         [-3.1251512e-02, -1.4581830e-02,  1.9599354e-02, ...,
          -2.8813062e-02, -1.7215077e-02, -3.7434134e-03],
         [-7.7669602e-03, -8.6089112e-03, -2.3589877e-02, ...,
          -1.2109401e-02,  6.1565200e-03,  3.5236843e-03],
         [ 2.2222603e-02,  3.1736728e-02, -3.0510605e-03, ...,
           4.2397384e-02, -1.5349303e-02, -3.1204224e-02]], dtype=float32)},
 'cond_dense_1': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [5]:
new_workdir = "/mnt/workspace/md4/finetune_frozen_expt"
new_checkpoint_manager = checkpoint_utils._get_checkpoint_manager(new_config, new_workdir, create=False)

print(f"New checkpoint directory: {epath.Path(new_workdir) / 'checkpoints'}")
print(f"Latest new checkpoint step: {new_checkpoint_manager.latest_step()}")

# Load the new checkpoint
new_checkpointed_state = {"train_state": copy.deepcopy(new_train_state)}
new_checkpointed_state = new_checkpoint_manager.restore(
    new_checkpoint_manager.latest_step(), 
    items=new_checkpointed_state,
)

new_checkpointed_state_2 = {"train_state": copy.deepcopy(new_train_state)}
new_checkpointed_state_2 = new_checkpoint_manager.restore(
    86000, 
    items=new_checkpointed_state_2,
)



New checkpoint directory: /mnt/workspace/md4/finetune_frozen_expt/checkpoints
Latest new checkpoint step: 20000


ValueError: User-provided restore item and on-disk value metadata tree structures do not match: {'params': {'fp_adapter': {'fingerprint_adapter_out': Diff(lhs={'bias': Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 'kernel': Array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], dtype=float32)}, rhs=None)}}, 'ema_params': {'fp_adapter': {'fingerprint_adapter_out': Diff(lhs={'bias': Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 'kernel': Array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], dtype=float32)}, rhs=None)}}, 'opt_state': [None, {'inner_states': {'train': {'inner_state': [{'mu': {'fp_adapter': {'fingerprint_adapter_out': Diff(lhs={'bias': Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 'kernel': Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)}, rhs=None)}}, 'nu': {'fp_adapter': {'fingerprint_adapter_out': Diff(lhs={'bias': Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 'kernel': Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)}, rhs=None)}}}, None, None]}}}]}

In [25]:
# Create train states for both models
print("🚂 TRAIN STATE CREATION")
print("=" * 60)

# Create old train state (regular)
old_rng = jax.random.PRNGKey(42)
old_model, old_optimizer, old_train_state, old_metrics_class = train.create_train_state(
    old_config,
    old_rng,
    input_shape=(per_device_batch_size,) + data_shape,
    schedule_fn=schedule_fn,
)

print(f"✅ Old train state created (step: {old_train_state.step})")
print(f"   Parameters keys: {list(old_train_state.params.keys())}")

# Create new train state (frozen)
if new_config.get('frozen', False):
    new_model, new_optimizer, new_train_state, new_metrics_class = train.create_frozen_train_state(
        new_config,
        new_rng,
        input_shape=(per_device_batch_size,) + data_shape,
        schedule_fn=schedule_fn,
    )
    print(f"✅ New FROZEN train state created (step: {new_train_state.step})")
else:
    new_model, new_optimizer, new_train_state, new_metrics_class = train.create_train_state(
        new_config,
        new_rng,
        input_shape=(per_device_batch_size,) + data_shape,
        schedule_fn=schedule_fn,
    )
    print(f"✅ New regular train state created (step: {new_train_state.step})")

print(f"   Parameters keys: {list(new_train_state.params.keys())}")

# Compare parameter structures
def get_param_structure(params, prefix=""):
    """Get a nested dictionary showing parameter shapes."""
    structure = {}
    for key, value in params.items():
        if isinstance(value, dict):
            structure[key] = get_param_structure(value, f"{prefix}.{key}")
        else:
            structure[key] = value.shape
    return structure

old_structure = get_param_structure(old_train_state.params)
new_structure = get_param_structure(new_train_state.params)

print(f"\nParameter structure comparison:")
print(f"Old model parameters: {list(old_structure.keys())}")
print(f"New model parameters: {list(new_structure.keys())}")

# Find differences
old_only = set(old_structure.keys()) - set(new_structure.keys())
new_only = set(new_structure.keys()) - set(old_structure.keys())
common = set(old_structure.keys()) & set(new_structure.keys())

if old_only:
    print(f"❌ Parameters only in old model: {old_only}")
if new_only:
    print(f"✨ Parameters only in new model: {new_only}")
print(f"✅ Common parameters: {common}")

print("=" * 60)

🚂 TRAIN STATE CREATION
✅ Old train state created (step: 0)
   Parameters keys: ['classifier', 'cond_embeddings']
✅ New FROZEN train state created (step: 0)
   Parameters keys: ['classifier', 'cond_embeddings', 'fp_adapter']

Parameter structure comparison:
Old model parameters: ['classifier', 'cond_embeddings']
New model parameters: ['classifier', 'cond_embeddings', 'fp_adapter']
✨ Parameters only in new model: {'fp_adapter'}
✅ Common parameters: {'cond_embeddings', 'classifier'}


In [26]:
old_workdir = "/mnt/workspace/md4/large_expt"
old_checkpoint_manager = checkpoint_utils._get_checkpoint_manager(old_config, old_workdir, create=False)



In [28]:
partially_loaded_checkpoint, _ = train.partial_load_utils.partial_load_checkpoint(
    old_config,
    new_train_state,
    train_iter=None,
    checkpoint_manager=old_checkpoint_manager,
    create_train_state_fn=train.create_train_state,
    schedule_fn=schedule_fn,
    per_device_batch_size= per_device_batch_size,
    data_shape=data_shape,
)

In [34]:
partially_loaded_checkpoint.params["cond_embeddings"]["cond_dense_out"]["bias"]

array([ 0.10222556,  0.05280137,  0.01212622,  0.12334919, -0.09347872,
       -0.0856513 ,  0.08631224, -0.01410997,  0.16532548,  0.10793076,
        0.00095248, -0.15173772,  0.11472066, -0.03458977, -0.08238201,
       -0.04790057, -0.0243195 , -0.14598721, -0.25267345, -0.00888992,
        0.07931393, -0.12354071, -0.14782493, -0.00068826, -0.12568152,
       -0.16664398, -0.09426703,  0.04617818, -0.12923817,  0.00388334,
       -0.18511422,  0.0141763 ,  0.19037528, -0.20642547,  0.11950786,
       -0.06452266,  0.18228224, -0.04121022,  0.05384567,  0.04185708,
        0.02210524,  0.07651379, -0.12801877, -0.06195638, -0.1300139 ,
       -0.060221  ,  0.24796437,  0.15395756, -0.1924481 , -0.00986998,
        0.09546501,  0.01254566,  0.00191896, -0.09120514, -0.06756873,
        0.03463075,  0.02213649,  0.00371821, -0.02300274,  0.09934283,
       -0.00430454,  0.06706794, -0.04506921,  0.22055061], dtype=float32)

In [35]:
old_checkpointed_state = {"train_state": old_train_state}
old_checkpointed_state = old_checkpoint_manager.restore(
    old_checkpoint_manager.latest_step(), 
    items=old_checkpointed_state,
)

In [41]:
partially_loaded_checkpoint.params.keys()

dict_keys(['classifier', 'cond_embeddings', 'fp_adapter'])

In [17]:
# Load old checkpoint
print("📂 LOADING OLD CHECKPOINT")
print("=" * 60)

old_workdir = "/mnt/workspace/md4/large_expt"
old_checkpoint_manager = checkpoint_utils._get_checkpoint_manager(old_config, old_workdir, create=False)

print(f"Old checkpoint directory: {epath.Path(old_workdir) / 'checkpoints'}")
print(f"Latest old checkpoint step: {old_checkpoint_manager.latest_step()}")

if old_checkpoint_manager.latest_step() is not None:
    # Load the old checkpoint
    old_checkpointed_state = {"train_state": old_train_state}
    old_checkpointed_state = old_checkpoint_manager.restore(
        old_checkpoint_manager.latest_step(), 
        items=old_checkpointed_state,
    )
    
    loaded_old_train_state = old_checkpointed_state["train_state"]
    
    print(f"✅ Old checkpoint loaded successfully")
    print(f"   Loaded step: {loaded_old_train_state.step}")
    print(f"   Parameter keys: {list(loaded_old_train_state.params.keys())}")
    
    # Get parameter overview
    old_overview = parameter_overview.get_parameter_overview(loaded_old_train_state.params)
    print(f"   Total parameters: {old_overview}")
    
else:
    print("❌ No old checkpoint found!")
    loaded_old_train_state = None

print("=" * 60)



📂 LOADING OLD CHECKPOINT
Old checkpoint directory: /mnt/workspace/md4/large_expt/checkpoints
Latest old checkpoint step: 1440000
✅ Old checkpoint loaded successfully
   Loaded step: 1440000
   Parameter keys: ['classifier', 'cond_embeddings']
   Total parameters: +---------------------------------------------------------------------+--------------+---------+-----------+-----------+--------+
| Name                                                                | Shape        | Dtype   | Size      | Mean      | Std    |
+---------------------------------------------------------------------+--------------+---------+-----------+-----------+--------+
| classifier/CondEmbedding_0/Dense_0/bias                             | (64,)        | float32 | 64        | -0.136    | 0.209  |
| classifier/CondEmbedding_0/Dense_0/kernel                           | (256, 64)    | float32 | 16,384    | 0.0184    | 0.0984 |
| classifier/CondEmbedding_0/dense0/bias                              | (256,)       |

In [18]:
# Load new checkpoint (if it exists)
print("📂 LOADING NEW CHECKPOINT")
print("=" * 60)

new_workdir = "/mnt/workspace/md4/finetune_frozen_expt"
new_checkpoint_manager = checkpoint_utils._get_checkpoint_manager(new_config, new_workdir, create=False)

print(f"New checkpoint directory: {epath.Path(new_workdir) / 'checkpoints'}")
print(f"Latest new checkpoint step: {new_checkpoint_manager.latest_step()}")

loaded_new_train_state = None
if new_checkpoint_manager.latest_step() is not None:
    # Load the new checkpoint
    new_checkpointed_state = {"train_state": new_train_state}
    new_checkpointed_state = new_checkpoint_manager.restore(
        new_checkpoint_manager.latest_step(), 
        items=new_checkpointed_state,
    )
    
    loaded_new_train_state = new_checkpointed_state["train_state"]
    
    print(f"✅ New checkpoint loaded successfully")
    print(f"   Loaded step: {loaded_new_train_state.step}")
    print(f"   Parameter keys: {list(loaded_new_train_state.params.keys())}")
    
    # Get parameter overview
    new_overview = parameter_overview.get_parameter_overview(loaded_new_train_state.params)
    print(f"   Total parameters: {new_overview}")
    
else:
    print("ℹ️  No new checkpoint found (expected for fresh setup)")

print("=" * 60)



📂 LOADING NEW CHECKPOINT
New checkpoint directory: /mnt/workspace/md4/finetune_frozen_expt/checkpoints
Latest new checkpoint step: 10000
✅ New checkpoint loaded successfully
   Loaded step: 10000
   Parameter keys: ['classifier', 'cond_embeddings', 'fp_adapter']
   Total parameters: +---------------------------------------------------------------------+--------------+---------+------------+-----------+--------+
| Name                                                                | Shape        | Dtype   | Size       | Mean      | Std    |
+---------------------------------------------------------------------+--------------+---------+------------+-----------+--------+
| classifier/CondEmbedding_0/Dense_0/bias                             | (64,)        | float32 | 64         | -0.136    | 0.209  |
| classifier/CondEmbedding_0/Dense_0/kernel                           | (256, 64)    | float32 | 16,384     | 0.0184    | 0.0984 |
| classifier/CondEmbedding_0/dense0/bias                     

In [19]:
# Simulate the partial loading process
print("🔄 SIMULATING PARTIAL LOADING")
print("=" * 60)

if loaded_old_train_state is not None:
    print("Performing manual partial loading simulation...")
    
    # Get the loaded parameters from old checkpoint
    loaded_params = loaded_old_train_state.params
    loaded_ema_params = getattr(loaded_old_train_state, "ema_params", None)
    
    print(f"Loaded old parameters: {list(loaded_params.keys())}")
    
    # Create a fresh new train state (to simulate partial loading)
    fresh_new_train_state = new_train_state
    
    # Copy compatible parameters manually (simulating partial_load_utils logic)
    merged_params = dict(fresh_new_train_state.params)
    merged_ema_params = (
        dict(fresh_new_train_state.ema_params)
        if hasattr(fresh_new_train_state, "ema_params") and fresh_new_train_state.ema_params is not None
        else None
    )
    
    params_copied = []
    params_skipped = []
    params_new = []
    
    # Copy parameters that exist in both models
    for key in loaded_params:
        if key in merged_params:
            merged_params[key] = loaded_params[key]
            params_copied.append(key)
        else:
            params_skipped.append(key)
    
    # Copy EMA parameters if they exist
    if merged_ema_params is not None and loaded_ema_params is not None:
        for key in loaded_ema_params:
            if key in merged_ema_params:
                merged_ema_params[key] = loaded_ema_params[key]
    
    # Find parameters that are new in the current model
    for key in merged_params:
        if key not in loaded_params:
            params_new.append(key)
    
    # Create the merged train state
    merged_train_state = fresh_new_train_state.replace(
        params=merged_params,
        ema_params=merged_ema_params,
    )
    
    print(f"✅ Parameters copied from old model: {params_copied}")
    print(f"❌ Parameters skipped (not in new model): {params_skipped}")
    print(f"✨ Parameters kept as new (not in old model): {params_new}")
    
else:
    print("❌ Cannot simulate partial loading - no old checkpoint available")
    merged_train_state = None

print("=" * 60)

🔄 SIMULATING PARTIAL LOADING
Performing manual partial loading simulation...
Loaded old parameters: ['classifier', 'cond_embeddings']
✅ Parameters copied from old model: ['classifier', 'cond_embeddings']
❌ Parameters skipped (not in new model): []
✨ Parameters kept as new (not in old model): ['fp_adapter']


In [20]:
# Detailed parameter comparison
print("🔍 DETAILED PARAMETER COMPARISON")
print("=" * 60)

def analyze_parameter_tree(params, name):
    """Analyze a parameter tree and return statistics."""
    print(f"\n{name} Parameter Analysis:")
    
    def get_stats(arr):
        return {
            'shape': arr.shape,
            'mean': float(jnp.mean(arr)),
            'std': float(jnp.std(arr)),
            'min': float(jnp.min(arr)),
            'max': float(jnp.max(arr)),
            'num_params': int(jnp.prod(jnp.array(arr.shape)))
        }
    
    total_params = 0
    for key, value in params.items():
        print(f"  {key}:")
        if isinstance(value, dict):
            for subkey, subvalue in value.items():
                if hasattr(subvalue, 'shape'):  # It's an array
                    stats = get_stats(subvalue)
                    total_params += stats['num_params']
                    print(f"    {subkey}: {stats}")
        elif hasattr(value, 'shape'):  # It's an array
            stats = get_stats(value)
            total_params += stats['num_params']
            print(f"    {stats}")
    
    print(f"  Total parameters: {total_params:,}")
    return total_params

if loaded_old_train_state is not None:
    # Analyze original fresh new state
    print("1️⃣ FRESH NEW MODEL (before partial loading)")
    fresh_total = analyze_parameter_tree(new_train_state.params, "Fresh New")
    
    # Analyze loaded old state
    print("\n2️⃣ LOADED OLD MODEL")
    old_total = analyze_parameter_tree(loaded_old_train_state.params, "Loaded Old")
    
    if merged_train_state is not None:
        # Analyze merged state
        print("\n3️⃣ MERGED MODEL (after partial loading)")
        merged_total = analyze_parameter_tree(merged_train_state.params, "Merged")

print("=" * 60)

🔍 DETAILED PARAMETER COMPARISON
1️⃣ FRESH NEW MODEL (before partial loading)

Fresh New Parameter Analysis:
  classifier:
  cond_embeddings:
  fp_adapter:
  Total parameters: 0

2️⃣ LOADED OLD MODEL

Loaded Old Parameter Analysis:
  classifier:
  cond_embeddings:
  Total parameters: 0

3️⃣ MERGED MODEL (after partial loading)

Merged Parameter Analysis:
  classifier:
  cond_embeddings:
  fp_adapter:
  Total parameters: 0


In [21]:
# Verify parameter copying accuracy
print("✅ PARAMETER COPYING VERIFICATION")
print("=" * 60)

def compare_param_trees(tree1, tree2, name1, name2, tolerance=1e-15):
    """Compare two parameter trees for exact matches."""
    print(f"\nComparing {name1} vs {name2}:")
    
    def compare_arrays(arr1, arr2, path=""):
        if arr1.shape != arr2.shape:
            print(f"  ❌ Shape mismatch at {path}: {arr1.shape} vs {arr2.shape}")
            return False
        
        if jnp.allclose(arr1, arr2, rtol=tolerance, atol=tolerance):
            print(f"  ✅ Exact match at {path}: shape {arr1.shape}")
            return True
        else:
            max_diff = float(jnp.max(jnp.abs(arr1 - arr2)))
            mean_diff = float(jnp.mean(jnp.abs(arr1 - arr2)))
            print(f"  ⚠️  Difference at {path}: max={max_diff:.2e}, mean={mean_diff:.2e}")
            return False
    
    all_match = True
    
    # Compare common keys
    common_keys = set(tree1.keys()) & set(tree2.keys())
    for key in common_keys:
        val1, val2 = tree1[key], tree2[key]
        
        if isinstance(val1, dict) and isinstance(val2, dict):
            # Recursively compare sub-dictionaries
            sub_match = compare_param_trees(val1, val2, f"{name1}.{key}", f"{name2}.{key}", tolerance)
            all_match = all_match and sub_match
        elif hasattr(val1, 'shape') and hasattr(val2, 'shape'):
            # Compare arrays
            match = compare_arrays(val1, val2, key)
            all_match = all_match and match
    
    # Report keys that exist in only one tree
    only_in_1 = set(tree1.keys()) - set(tree2.keys())
    only_in_2 = set(tree2.keys()) - set(tree1.keys())
    
    if only_in_1:
        print(f"  📝 Only in {name1}: {only_in_1}")
    if only_in_2:
        print(f"  📝 Only in {name2}: {only_in_2}")
    
    return all_match

if loaded_old_train_state is not None and merged_train_state is not None:
    # Verify that copied parameters are identical
    print("🔍 Checking if copied parameters are identical:")
    
    # Compare parameters that should have been copied
    for param_name in params_copied:
        if param_name in loaded_old_train_state.params and param_name in merged_train_state.params:
            params_match = compare_param_trees(
                loaded_old_train_state.params[param_name],
                merged_train_state.params[param_name],
                f"old.{param_name}",
                f"merged.{param_name}"
            )
            
            if params_match:
                print(f"✅ {param_name}: Parameters copied correctly")
            else:
                print(f"❌ {param_name}: Parameters may not have copied correctly")
    
    # Verify that new parameters are different from any old parameters
    print(f"\n🔍 Checking new parameters are properly initialized:")
    for param_name in params_new:
        if param_name in merged_train_state.params:
            param_vals = merged_train_state.params[param_name]
            
            # Check if it's properly initialized (not all zeros, not NaN)
            def check_initialization(arr, path=""):
                if hasattr(arr, 'shape'):
                    mean_val = float(jnp.mean(arr))
                    std_val = float(jnp.std(arr))
                    has_nans = bool(jnp.any(jnp.isnan(arr)))
                    all_zeros = bool(jnp.all(arr == 0))
                    
                    print(f"    {path}: mean={mean_val:.6f}, std={std_val:.6f}, "
                          f"has_nans={has_nans}, all_zeros={all_zeros}")
                    
                    return not (has_nans or all_zeros or std_val < 1e-8)
                return True
            
            print(f"  {param_name}:")
            if isinstance(param_vals, dict):
                init_ok = True
                for subkey, subval in param_vals.items():
                    init_ok = init_ok and check_initialization(subval, subkey)
            else:
                init_ok = check_initialization(param_vals)
            
            if init_ok:
                print(f"    ✅ Properly initialized")
            else:
                print(f"    ⚠️  May have initialization issues")

else:
    print("❌ Cannot verify parameter copying - missing required states")

print("=" * 60)

✅ PARAMETER COPYING VERIFICATION
🔍 Checking if copied parameters are identical:

Comparing old.classifier vs merged.classifier:

Comparing old.classifier.Embed_0 vs merged.classifier.Embed_0:
  ✅ Exact match at embedding: shape (1025, 64)

Comparing old.classifier.Transformer_0 vs merged.classifier.Transformer_0:

Comparing old.classifier.Transformer_0.TransformerBlock_4 vs merged.classifier.Transformer_0.TransformerBlock_4:

Comparing old.classifier.Transformer_0.TransformerBlock_4.attention vs merged.classifier.Transformer_0.TransformerBlock_4.attention:

Comparing old.classifier.Transformer_0.TransformerBlock_4.attention.wv vs merged.classifier.Transformer_0.TransformerBlock_4.attention.wv:
  ✅ Exact match at kernel: shape (1024, 1024)

Comparing old.classifier.Transformer_0.TransformerBlock_4.attention.wo vs merged.classifier.Transformer_0.TransformerBlock_4.attention.wo:
  ✅ Exact match at kernel: shape (1024, 1024)

Comparing old.classifier.Transformer_0.TransformerBlock_4.attent

In [24]:
# Test frozen behavior and optimizer masks
print("🧊 FROZEN BEHAVIOR ANALYSIS")
print("=" * 60)

if new_config.get('frozen', False):
    print("Testing frozen train state behavior...")
    
    # Let's examine the optimizer state to see what's frozen
    print(f"New optimizer: {new_optimizer}")
    
    # Check if we can access the freeze mask from the optimizer
    # The freeze mask should be in the optimizer chain
    try:
        # Initialize optimizer state to examine structure
        dummy_params = new_train_state.params
        opt_state = new_optimizer.init(dummy_params)
        
        print(f"\nOptimizer state structure: {jax.tree_util.tree_map(lambda x: type(x).__name__, opt_state)}")
        
        # Try to find the freeze mask - look specifically for MaskedState
        print("\nLooking for freeze information in optimizer...")
        
        def find_freeze_info(tree, path=""):
            if hasattr(tree, '__class__') and 'Masked' in str(tree.__class__):
                print(f"  🧊 Found MaskedState at {path}: {tree}")
                if hasattr(tree, 'inner_state'):
                    print(f"    Inner state: {tree.inner_state}")
            
            if isinstance(tree, (tuple, list)):
                for i, item in enumerate(tree):
                    find_freeze_info(item, f"{path}[{i}]")
            elif isinstance(tree, dict):
                for key, value in tree.items():
                    find_freeze_info(value, f"{path}.{key}" if path else key)
            elif hasattr(tree, '__dict__'):
                for attr_name in dir(tree):
                    if not attr_name.startswith('_') and attr_name in ['inner_state', 'mask']:
                        try:
                            attr_value = getattr(tree, attr_name)
                            print(f"  Found {attr_name} at {path}: {type(attr_value)}")
                        except:
                            pass
        
        find_freeze_info(opt_state)
        
    except Exception as e:
        print(f"Error examining optimizer state: {e}")
    
    # Test what happens when we try to update frozen parameters
    print("\n🧪 Testing parameter update behavior...")
    
    # Create some dummy gradients
    dummy_grads = jax.tree_util.tree_map(
        lambda x: jnp.ones_like(x) * 0.01,  # Small gradient
        new_train_state.params
    )
    
    # Apply one optimizer step
    try:
        updates, new_opt_state = new_optimizer.update(
            dummy_grads, 
            new_train_state.opt_state, 
            new_train_state.params
        )
        
        print("✅ Optimizer update successful")
        
        # Check which parameters would actually be updated
        print("Parameter update analysis:")
        
        def analyze_updates(grads_tree, updates_tree, path=""):
            """Recursively analyze parameter updates."""
            if isinstance(grads_tree, dict) and isinstance(updates_tree, dict):
                for key in grads_tree.keys():
                    if key in updates_tree:
                        new_path = f"{path}.{key}" if path else key
                        analyze_updates(grads_tree[key], updates_tree[key], new_path)
            elif hasattr(grads_tree, 'shape') and hasattr(updates_tree, 'shape'):
                # These are actual arrays
                grad_norm = float(jnp.linalg.norm(jnp.ravel(grads_tree)))
                update_norm = float(jnp.linalg.norm(jnp.ravel(updates_tree)))
                ratio = update_norm / grad_norm if grad_norm > 0 else 0
                
                if update_norm < 1e-10:
                    status = "🧊 FROZEN"
                elif ratio < 1e-6:
                    status = "🔒 MOSTLY FROZEN"
                else:
                    status = "🔥 ACTIVE"
                
                print(f"  {path}: grad_norm={grad_norm:.6f}, update_norm={update_norm:.6f}, ratio={ratio:.6f} {status}")
        
        analyze_updates(dummy_grads, updates)
        
        # Also check the high-level parameter groups
        print(f"\nHigh-level parameter group analysis:")
        for param_group in ['classifier', 'cond_embeddings', 'fp_adapter']:
            if param_group in dummy_grads and param_group in updates:
                # Calculate total norms for each parameter group
                grad_leaves = jax.tree_util.tree_leaves(dummy_grads[param_group])
                update_leaves = jax.tree_util.tree_leaves(updates[param_group])
                
                total_grad_norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in grad_leaves)))
                total_update_norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in update_leaves)))
                ratio = total_update_norm / total_grad_norm if total_grad_norm > 0 else 0
                
                if total_update_norm < 1e-8:
                    status = "🧊 FROZEN"
                elif ratio < 1e-4:
                    status = "🔒 MOSTLY FROZEN"
                else:
                    status = "🔥 ACTIVE"
                
                print(f"  {param_group}: total_grad={total_grad_norm:.6f}, total_update={total_update_norm:.6f}, ratio={ratio:.6f} {status}")
        
    except Exception as e:
        print(f"❌ Error during optimizer update: {e}")
        import traceback
        traceback.print_exc()

else:
    print("New model is not in frozen mode - skipping frozen behavior analysis")

print("=" * 60)

🧊 FROZEN BEHAVIOR ANALYSIS
Testing frozen train state behavior...
New optimizer: GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7e7124d4c180>, update=<function chain.<locals>.update_fn at 0x7e7124d4c220>)

Optimizer state structure: (EmptyState(), (ScaleByAdamState(count='ArrayImpl', mu={'classifier': {'CondEmbedding_0': {'Dense_0': {'bias': 'ArrayImpl', 'kernel': 'ArrayImpl'}, 'dense0': {'bias': 'ArrayImpl', 'kernel': 'ArrayImpl'}}, 'Embed_0': {'embedding': 'ArrayImpl'}, 'Transformer_0': {'Dense_0': {'bias': 'ArrayImpl', 'kernel': 'ArrayImpl'}, 'Dense_1': {'bias': 'ArrayImpl', 'kernel': 'ArrayImpl'}, 'Dense_2': {'kernel': 'ArrayImpl'}, 'TransformerBlock_0': {'Dense_0': {'bias': 'ArrayImpl', 'kernel': 'ArrayImpl'}, 'attention': {'wk': {'kernel': 'ArrayImpl'}, 'wo': {'kernel': 'ArrayImpl'}, 'wq': {'kernel': 'ArrayImpl'}, 'wv': {'kernel': 'ArrayImpl'}}, 'feed_forward': {'w1': {'kernel': 'ArrayImpl'}, 'w2': {'kernel': 'ArrayImpl'}, 'w3': {'kernel': 'ArrayImpl'}

In [None]:
# Test the actual partial_load_utils function
print("🔧 TESTING ACTUAL PARTIAL LOADING FUNCTION")
print("=" * 60)

if loaded_old_train_state is not None:
    try:
        # Test if we should use partial loading according to the utility
        should_partial_load = partial_load_utils.should_use_partial_loading(new_config)
        print(f"Should use partial loading according to config: {should_partial_load}")
        
        # Test the actual partial loading function if available
        if hasattr(partial_load_utils, 'partial_load_checkpoint'):
            print("\nTesting partial_load_checkpoint function...")
            
            # We'll need to create a mock checkpoint manager for this test
            # since partial_load_checkpoint expects to load from a checkpoint manager
            
            # For now, let's just test the logic by examining the function
            import inspect
            source = inspect.getsource(partial_load_utils.partial_load_checkpoint)
            print("Function signature analysis:")
            sig = inspect.signature(partial_load_utils.partial_load_checkpoint)
            print(f"Parameters: {list(sig.parameters.keys())}")
            
    except Exception as e:
        print(f"Error testing partial loading function: {e}")

else:
    print("Cannot test partial loading function - no old checkpoint available")

print("=" * 60)

In [25]:
# Examine the freeze mask from train.py
print("🔍 FREEZE MASK ANALYSIS")
print("=" * 60)

if new_config.get('frozen', False):
    print("Analyzing the freeze mask logic from create_frozen_train_state...")
    
    # Let's recreate the freeze mask logic to understand what should be frozen
    from flax import traverse_util
    
    def _should_freeze(path, v) -> bool:
        """Recreate the freeze logic from train.py"""
        # Based on the train.py code, this function should return True for frozen params
        if "fp_adapter" in path:
            return False  # fp_adapter should NOT be frozen (can be trained)
        return True  # Everything else should be frozen
    
    # Create the freeze mask
    mask = traverse_util.path_aware_map(_should_freeze, new_train_state.params)
    
    # The train.py code also manually sets some parameters to False (trainable)
    if 'cond_embeddings' in mask and 'cond_dense_0' in mask['cond_embeddings']:
        mask['cond_embeddings']['cond_dense_0']['kernel'] = False
        mask['cond_embeddings']['cond_dense_0']['bias'] = False
    
    print("Freeze mask analysis:")
    
    def print_mask(mask_tree, params_tree, path="", indent=0):
        """Print the freeze mask with parameter info."""
        prefix = "  " * indent
        
        if isinstance(mask_tree, dict) and isinstance(params_tree, dict):
            for key in mask_tree.keys():
                if key in params_tree:
                    new_path = f"{path}.{key}" if path else key
                    print_mask(mask_tree[key], params_tree[key], new_path, indent)
        elif isinstance(mask_tree, bool):
            # This is a leaf node with a boolean mask value
            if hasattr(params_tree, 'shape'):
                param_count = int(jnp.prod(jnp.array(params_tree.shape)))
                status = "🧊 FROZEN" if mask_tree else "🔥 TRAINABLE"
                print(f"{prefix}{path}: {status} (shape: {params_tree.shape}, params: {param_count:,})")
    
    print_mask(mask, new_train_state.params)
    
    # Count frozen vs trainable parameters
    def count_params(mask_tree, params_tree):
        """Count frozen vs trainable parameters."""
        frozen_count = 0
        trainable_count = 0
        
        def count_recursive(mask_val, param_val):
            nonlocal frozen_count, trainable_count
            
            if isinstance(mask_val, dict) and isinstance(param_val, dict):
                for key in mask_val.keys():
                    if key in param_val:
                        count_recursive(mask_val[key], param_val[key])
            elif isinstance(mask_val, bool) and hasattr(param_val, 'shape'):
                param_count = int(jnp.prod(jnp.array(param_val.shape)))
                if mask_val:
                    frozen_count += param_count
                else:
                    trainable_count += param_count
        
        count_recursive(mask_tree, params_tree)
        return frozen_count, trainable_count
    
    frozen_params, trainable_params = count_params(mask, new_train_state.params)
    total_params = frozen_params + trainable_params
    
    print(f"\nParameter count summary:")
    print(f"  🧊 Frozen parameters: {frozen_params:,} ({frozen_params/total_params*100:.1f}%)")
    print(f"  🔥 Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.1f}%)")
    print(f"  📊 Total parameters: {total_params:,}")
    
    # Verify this matches what we expect for finetuning
    print(f"\nExpected trainable components:")
    print(f"  - fp_adapter: Should be trainable (new fingerprint processing layers)")
    print(f"  - cond_embeddings.cond_dense_0: Should be trainable (first conditioning layer)")
    print(f"  - Everything else: Should be frozen (pretrained weights)")

else:
    print("Model is not in frozen mode - no freeze mask to analyze")

print("=" * 60)

🔍 FREEZE MASK ANALYSIS
Analyzing the freeze mask logic from create_frozen_train_state...
Freeze mask analysis:
classifier.CondEmbedding_0.Dense_0.bias: 🧊 FROZEN (shape: (64,), params: 64)
classifier.CondEmbedding_0.Dense_0.kernel: 🧊 FROZEN (shape: (256, 64), params: 16,384)
classifier.CondEmbedding_0.dense0.bias: 🧊 FROZEN (shape: (256,), params: 256)
classifier.CondEmbedding_0.dense0.kernel: 🧊 FROZEN (shape: (128, 256), params: 32,768)
classifier.Embed_0.embedding: 🧊 FROZEN (shape: (1025, 64), params: 65,600)
classifier.Transformer_0.Dense_0.bias: 🧊 FROZEN (shape: (1024,), params: 1,024)
classifier.Transformer_0.Dense_0.kernel: 🧊 FROZEN (shape: (64, 1024), params: 65,536)
classifier.Transformer_0.Dense_1.bias: 🧊 FROZEN (shape: (2048,), params: 2,048)
classifier.Transformer_0.Dense_1.kernel: 🧊 FROZEN (shape: (64, 2048), params: 131,072)
classifier.Transformer_0.Dense_2.kernel: 🧊 FROZEN (shape: (1024, 1024), params: 1,048,576)
classifier.Transformer_0.TransformerBlock_0.Dense_0.bias: 🧊 F

In [27]:
# Debug the actual freeze mask in the optimizer
print("🔧 DEBUGGING FREEZE MASK IN OPTIMIZER")
print("=" * 60)

# Import optax for freeze functionality
import optax

if new_config.get('frozen', False) and new_train_state is not None:
    print("Examining the actual freeze mask used by the optimizer...")
    
    # Try to extract the freeze mask from the optimizer state
    opt_state = new_train_state.opt_state
    
    # The freeze mask should be in the MaskedState part of the optimizer
    print(f"Optimizer state structure: {type(opt_state)}")
    
    if hasattr(opt_state, '__iter__') and len(opt_state) > 0:
        print(f"Optimizer state has {len(opt_state)} components:")
        for i, component in enumerate(opt_state):
            print(f"  Component {i}: {type(component)}")
            
            # Look for MaskedState which should contain the freeze mask
            if hasattr(component, '__class__') and 'Masked' in str(component.__class__):
                print(f"    🧊 Found MaskedState at component {i}")
                
                # Try to access the mask
                if hasattr(component, 'mask'):
                    print(f"    Mask type: {type(component.mask)}")
                    
                    # Print a summary of the mask
                    if component.mask is not None:
                        def print_mask_summary(mask_tree, path=""):
                            """Print a summary of the freeze mask."""
                            if isinstance(mask_tree, dict):
                                for key, value in mask_tree.items():
                                    new_path = f"{path}.{key}" if path else key
                                    print_mask_summary(value, new_path)
                            elif isinstance(mask_tree, bool):
                                status = "🧊 FROZEN" if mask_tree else "🔥 TRAINABLE"
                                print(f"      {path}: {status}")
                        
                        print("    Actual freeze mask in optimizer:")
                        print_mask_summary(component.mask)
                    else:
                        print("    ⚠️  Mask is None!")
                else:
                    print("    ⚠️  No mask attribute found in MaskedState")
    
    # Also let's manually recreate what should happen and compare
    print(f"\n🔍 MANUAL FREEZE MASK RECREATION")
    print("Recreating the freeze logic step by step...")
    
    # Step 1: Create the initial mask using path_aware_map
    from flax import traverse_util
    
    def debug_should_freeze(path, v):
        """Debug version of _should_freeze"""
        path_str = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
        
        if "fp_adapter" in path_str:
            print(f"    {path_str}: fp_adapter found -> TRAINABLE (False)")
            return False
        else:
            print(f"    {path_str}: no fp_adapter -> FROZEN (True)")
            return True
    
    print("  Creating initial mask with path_aware_map:")
    mask = traverse_util.path_aware_map(debug_should_freeze, new_train_state.params)
    
    # Step 2: Manual overrides
    print("  \n  Applying manual overrides:")
    if 'cond_embeddings' in mask and 'cond_dense_0' in mask['cond_embeddings']:
        print("    Setting cond_embeddings.cond_dense_0.kernel to TRAINABLE")
        print("    Setting cond_embeddings.cond_dense_0.bias to TRAINABLE")
        mask['cond_embeddings']['cond_dense_0']['kernel'] = False
        mask['cond_embeddings']['cond_dense_0']['bias'] = False
    
    # Step 3: Create the freezer
    print("  \n  Creating freezer with optax.transforms.freeze(mask)")
    freezer = optax.transforms.freeze(mask)
    
    # Step 4: Test the freezer
    print("  \n  Testing freezer behavior:")
    dummy_params = new_train_state.params
    dummy_grads = jax.tree_util.tree_map(lambda x: jnp.ones_like(x) * 0.01, dummy_params)
    
    # Apply freezer to gradients
    freezer_state = freezer.init(dummy_params)
    frozen_updates, _ = freezer.update(dummy_grads, freezer_state, dummy_params)
    
    # Check which updates are actually frozen
    print("    Freeze results:")
    for param_group in ['classifier', 'cond_embeddings', 'fp_adapter']:
        if param_group in frozen_updates:
            update_leaves = jax.tree_util.tree_leaves(frozen_updates[param_group])
            total_update_norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in update_leaves)))
            
            if total_update_norm < 1e-10:
                status = "🧊 FROZEN"
            else:
                status = "🔥 TRAINABLE"
            
            print(f"      {param_group}: update_norm={total_update_norm:.6f} {status}")

print("=" * 60)

🔧 DEBUGGING FREEZE MASK IN OPTIMIZER
Examining the actual freeze mask used by the optimizer...
Optimizer state structure: <class 'tuple'>
Optimizer state has 3 components:
  Component 0: <class 'optax._src.base.EmptyState'>
  Component 1: <class 'tuple'>
  Component 2: <class 'optax.transforms._masking.MaskedState'>
    🧊 Found MaskedState at component 2
    ⚠️  No mask attribute found in MaskedState

🔍 MANUAL FREEZE MASK RECREATION
Recreating the freeze logic step by step...
  Creating initial mask with path_aware_map:
    classifier.CondEmbedding_0.Dense_0.bias: no fp_adapter -> FROZEN (True)
    classifier.CondEmbedding_0.Dense_0.kernel: no fp_adapter -> FROZEN (True)
    classifier.CondEmbedding_0.dense0.bias: no fp_adapter -> FROZEN (True)
    classifier.CondEmbedding_0.dense0.kernel: no fp_adapter -> FROZEN (True)
    classifier.Embed_0.embedding: no fp_adapter -> FROZEN (True)
    classifier.Transformer_0.Dense_0.bias: no fp_adapter -> FROZEN (True)
    classifier.Transformer_0.

In [28]:
# Focused fp_adapter freeze diagnostic
print("🎯 FP_ADAPTER FREEZE DIAGNOSTIC")
print("=" * 60)

if new_config.get('frozen', False) and new_train_state is not None:
    import optax
    from flax import traverse_util
    
    # Step 1: Check if fp_adapter exists in parameters
    print("1️⃣ Checking fp_adapter existence:")
    if 'fp_adapter' in new_train_state.params:
        print("✅ fp_adapter found in parameters")
        fp_params = new_train_state.params['fp_adapter']
        print(f"   fp_adapter structure: {list(fp_params.keys())}")
    else:
        print("❌ fp_adapter NOT found in parameters!")
        print(f"   Available parameter groups: {list(new_train_state.params.keys())}")
    
    # Step 2: Test the _should_freeze function logic
    print("\n2️⃣ Testing _should_freeze logic:")
    
    def test_should_freeze(path_str):
        if "fp_adapter" in path_str:
            return False  # Should be trainable
        return True  # Should be frozen
    
    test_paths = [
        "fp_adapter.fingerprint_adapter_dense.kernel",
        "fp_adapter.fingerprint_adapter_out.bias", 
        "classifier.Embed_0.embedding",
        "cond_embeddings.cond_dense_0.kernel"
    ]
    
    for path in test_paths:
        should_freeze = test_should_freeze(path)
        status = "🧊 FROZEN" if should_freeze else "🔥 TRAINABLE"
        print(f"   {path}: {status}")
    
    # Step 3: Create actual freeze mask and test fp_adapter specifically
    print("\n3️⃣ Creating and testing actual freeze mask:")
    
    def actual_should_freeze(path, v):
        path_str = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
        return "fp_adapter" not in path_str  # True = frozen, False = trainable
    
    mask = traverse_util.path_aware_map(actual_should_freeze, new_train_state.params)
    
    # Override cond_dense_0
    if 'cond_embeddings' in mask and 'cond_dense_0' in mask['cond_embeddings']:
        mask['cond_embeddings']['cond_dense_0']['kernel'] = False
        mask['cond_embeddings']['cond_dense_0']['bias'] = False
    
    # Check fp_adapter mask values
    if 'fp_adapter' in mask:
        print("   fp_adapter mask values:")
        def print_fp_mask(tree, path=""):
            if isinstance(tree, dict):
                for k, v in tree.items():
                    print_fp_mask(v, f"{path}.{k}" if path else k)
            else:
                status = "🧊 FROZEN" if tree else "🔥 TRAINABLE"
                print(f"     {path}: {status}")
        
        print_fp_mask(mask['fp_adapter'], "fp_adapter")
    
    # Step 4: Test the freeze behavior directly
    print("\n4️⃣ Testing freeze behavior with dummy gradients:")
    
    freezer = optax.transforms.freeze(mask)
    dummy_grads = jax.tree_util.tree_map(lambda x: jnp.ones_like(x) * 0.1, new_train_state.params)
    
    freezer_state = freezer.init(new_train_state.params)
    frozen_updates, _ = freezer.update(dummy_grads, freezer_state, new_train_state.params)
    
    # Check fp_adapter specifically
    if 'fp_adapter' in frozen_updates:
        fp_update_leaves = jax.tree_util.tree_leaves(frozen_updates['fp_adapter'])
        fp_grad_leaves = jax.tree_util.tree_leaves(dummy_grads['fp_adapter'])
        
        total_grad_norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in fp_grad_leaves)))
        total_update_norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in fp_update_leaves)))
        
        print(f"   fp_adapter: grad_norm={total_grad_norm:.6f}, update_norm={total_update_norm:.6f}")
        
        if total_update_norm < 1e-10:
            print("   ❌ fp_adapter is FROZEN (updates are zero)")
        elif total_update_norm > 0.01:
            print("   ✅ fp_adapter is TRAINABLE (updates are non-zero)")
        else:
            print("   ⚠️  fp_adapter updates are very small")
        
        # Check individual parameters
        print("   Individual fp_adapter parameter updates:")
        for key in frozen_updates['fp_adapter']:
            if isinstance(frozen_updates['fp_adapter'][key], dict):
                for subkey in frozen_updates['fp_adapter'][key]:
                    update_val = frozen_updates['fp_adapter'][key][subkey]
                    norm = float(jnp.linalg.norm(jnp.ravel(update_val)))
                    status = "🧊 FROZEN" if norm < 1e-10 else "🔥 TRAINABLE"
                    print(f"     {key}.{subkey}: norm={norm:.6f} {status}")

print("=" * 60)

🎯 FP_ADAPTER FREEZE DIAGNOSTIC
1️⃣ Checking fp_adapter existence:
✅ fp_adapter found in parameters
   fp_adapter structure: ['fingerprint_adapter_dense', 'fingerprint_adapter_out']

2️⃣ Testing _should_freeze logic:
   fp_adapter.fingerprint_adapter_dense.kernel: 🔥 TRAINABLE
   fp_adapter.fingerprint_adapter_out.bias: 🔥 TRAINABLE
   classifier.Embed_0.embedding: 🧊 FROZEN
   cond_embeddings.cond_dense_0.kernel: 🧊 FROZEN

3️⃣ Creating and testing actual freeze mask:
   fp_adapter mask values:
     fp_adapter.fingerprint_adapter_dense.bias: 🔥 TRAINABLE
     fp_adapter.fingerprint_adapter_dense.kernel: 🔥 TRAINABLE
     fp_adapter.fingerprint_adapter_out.bias: 🔥 TRAINABLE
     fp_adapter.fingerprint_adapter_out.kernel: 🔥 TRAINABLE

4️⃣ Testing freeze behavior with dummy gradients:
   fp_adapter: grad_norm=579.336609, update_norm=579.336609
   ✅ fp_adapter is TRAINABLE (updates are non-zero)
   Individual fp_adapter parameter updates:
     fingerprint_adapter_dense.bias: norm=6.400000 🔥 TRAI

In [29]:
# Verify the original frozen behavior analysis
print("🔍 VERIFYING ORIGINAL FROZEN ANALYSIS")
print("=" * 60)

if new_config.get('frozen', False) and new_train_state is not None:
    # Recreate the exact test from the frozen behavior analysis
    print("Recreating the exact test from the original frozen behavior analysis...")
    
    # Create dummy gradients (same as original)
    dummy_grads = jax.tree_util.tree_map(
        lambda x: jnp.ones_like(x) * 0.01,  # Small gradient
        new_train_state.params
    )
    
    # Apply optimizer update (same as original)
    try:
        updates, new_opt_state = new_optimizer.update(
            dummy_grads, 
            new_train_state.opt_state, 
            new_train_state.params
        )
        
        print("✅ Optimizer update successful")
        
        # High-level parameter group analysis (same as original)
        print(f"\nHigh-level parameter group analysis:")
        for param_group in ['classifier', 'cond_embeddings', 'fp_adapter']:
            if param_group in dummy_grads and param_group in updates:
                # Calculate total norms for each parameter group
                grad_leaves = jax.tree_util.tree_leaves(dummy_grads[param_group])
                update_leaves = jax.tree_util.tree_leaves(updates[param_group])
                
                total_grad_norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in grad_leaves)))
                total_update_norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in update_leaves)))
                ratio = total_update_norm / total_grad_norm if total_grad_norm > 0 else 0
                
                if total_update_norm < 1e-8:
                    status = "🧊 FROZEN"
                elif ratio < 1e-4:
                    status = "🔒 MOSTLY FROZEN"
                else:
                    status = "🔥 ACTIVE"
                
                print(f"  {param_group}: total_grad={total_grad_norm:.6f}, total_update={total_update_norm:.6f}, ratio={ratio:.6f} {status}")
                
                # Additional debugging for fp_adapter
                if param_group == 'fp_adapter':
                    print(f"    🔍 fp_adapter detailed analysis:")
                    print(f"      Gradient magnitude: {total_grad_norm:.6f}")
                    print(f"      Update magnitude: {total_update_norm:.6f}")
                    print(f"      Ratio (update/grad): {ratio:.6f}")
                    print(f"      Number of gradient leaves: {len(grad_leaves)}")
                    print(f"      Number of update leaves: {len(update_leaves)}")
                    
                    # Check individual leaves
                    for i, (grad_leaf, update_leaf) in enumerate(zip(grad_leaves, update_leaves)):
                        leaf_grad_norm = float(jnp.linalg.norm(jnp.ravel(grad_leaf)))
                        leaf_update_norm = float(jnp.linalg.norm(jnp.ravel(update_leaf)))
                        leaf_ratio = leaf_update_norm / leaf_grad_norm if leaf_grad_norm > 0 else 0
                        print(f"        Leaf {i}: grad={leaf_grad_norm:.6f}, update={leaf_update_norm:.6f}, ratio={leaf_ratio:.6f}")
        
    except Exception as e:
        print(f"❌ Error during optimizer update: {e}")
        import traceback
        traceback.print_exc()

print("=" * 60)

🔍 VERIFYING ORIGINAL FROZEN ANALYSIS
Recreating the exact test from the original frozen behavior analysis...
✅ Optimizer update successful

High-level parameter group analysis:
  classifier: total_grad=126.608932, total_update=0.000000, ratio=0.000000 🧊 FROZEN
  cond_embeddings: total_grad=14.972701, total_update=0.000000, ratio=0.000000 🧊 FROZEN
  fp_adapter: total_grad=57.933586, total_update=0.000000, ratio=0.000000 🧊 FROZEN
    🔍 fp_adapter detailed analysis:
      Gradient magnitude: 57.933586
      Update magnitude: 0.000000
      Ratio (update/grad): 0.000000
      Number of gradient leaves: 4
      Number of update leaves: 4
        Leaf 0: grad=0.640000, update=0.000000, ratio=0.000000
        Leaf 1: grad=40.960011, update=0.000000, ratio=0.000000
        Leaf 2: grad=0.640000, update=0.000000, ratio=0.000000
        Leaf 3: grad=40.960011, update=0.000000, ratio=0.000000


In [30]:
# Test optimizer chain order issue
print("🔧 OPTIMIZER CHAIN ORDER INVESTIGATION")
print("=" * 60)

if new_config.get('frozen', False) and new_train_state is not None:
    import optax
    from flax import traverse_util
    
    # Recreate the freeze mask
    def _should_freeze(path, v):
        path_str = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
        return "fp_adapter" not in path_str
    
    mask = traverse_util.path_aware_map(_should_freeze, new_train_state.params)
    mask["cond_embeddings"]["cond_dense_0"]["kernel"] = False
    mask["cond_embeddings"]["cond_dense_0"]["bias"] = False
    
    print("🧪 Testing different optimizer chain orders:")
    
    # Test 1: Current order (clip -> adamw -> freeze)
    print("\n1️⃣ Current order: clip -> adamw -> freeze")
    
    optimizer_current = optax.chain(
        optax.clip(1.0),  # Use a fixed clip value for testing
        optax.adamw(0.001, b1=0.9, b2=0.999, weight_decay=0.01),
        optax.transforms.freeze(mask),
    )
    
    # Test 2: Alternative order (freeze -> clip -> adamw) 
    print("2️⃣ Alternative order: freeze -> clip -> adamw")
    
    optimizer_alt = optax.chain(
        optax.transforms.freeze(mask),
        optax.clip(1.0),
        optax.adamw(0.001, b1=0.9, b2=0.999, weight_decay=0.01),
    )
    
    # Test both optimizers
    dummy_params = new_train_state.params
    dummy_grads = jax.tree_util.tree_map(lambda x: jnp.ones_like(x) * 0.01, dummy_params)
    
    for i, (name, optimizer) in enumerate([("Current", optimizer_current), ("Alternative", optimizer_alt)], 1):
        print(f"\n  {i}️⃣ Testing {name} optimizer:")
        
        opt_state = optimizer.init(dummy_params)
        updates, _ = optimizer.update(dummy_grads, opt_state, dummy_params)
        
        # Check fp_adapter updates
        if 'fp_adapter' in updates:
            fp_update_leaves = jax.tree_util.tree_leaves(updates['fp_adapter'])
            total_update_norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in fp_update_leaves)))
            
            if total_update_norm < 1e-10:
                print(f"    ❌ fp_adapter: FROZEN (update_norm={total_update_norm:.6f})")
            else:
                print(f"    ✅ fp_adapter: TRAINABLE (update_norm={total_update_norm:.6f})")
        
        # Check other parameter groups
        for param_group in ['classifier', 'cond_embeddings']:
            if param_group in updates:
                leaves = jax.tree_util.tree_leaves(updates[param_group])
                norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)))
                
                if norm < 1e-10:
                    print(f"    ✅ {param_group}: FROZEN (update_norm={norm:.6f})")
                else:
                    print(f"    ❌ {param_group}: NOT FROZEN (update_norm={norm:.6f})")
    
    # Test 3: Check if there's a mask inversion issue
    print(f"\n3️⃣ Checking for mask inversion issue:")
    print("Optax freeze documentation says:")
    print("- mask=True means FROZEN (no updates)")
    print("- mask=False means TRAINABLE (allow updates)")
    
    print(f"\nOur mask values:")
    print(f"  fp_adapter mask: {mask.get('fp_adapter', 'NOT FOUND')}")
    if 'fp_adapter' in mask:
        for key in mask['fp_adapter']:
            if isinstance(mask['fp_adapter'][key], dict):
                for subkey in mask['fp_adapter'][key]:
                    val = mask['fp_adapter'][key][subkey]
                    status = "FROZEN" if val else "TRAINABLE"
                    print(f"    fp_adapter.{key}.{subkey}: {val} ({status})")
    
    print(f"\n  classifier.Embed_0.embedding mask: {mask['classifier']['Embed_0']['embedding']} (should be True=FROZEN)")
    print(f"  cond_embeddings.cond_dense_0.kernel mask: {mask['cond_embeddings']['cond_dense_0']['kernel']} (should be False=TRAINABLE)")

print("=" * 60)

🔧 OPTIMIZER CHAIN ORDER INVESTIGATION
🧪 Testing different optimizer chain orders:

1️⃣ Current order: clip -> adamw -> freeze
2️⃣ Alternative order: freeze -> clip -> adamw

  1️⃣ Testing Current optimizer:
    ✅ fp_adapter: TRAINABLE (update_norm=5.793282)
    ✅ classifier: FROZEN (update_norm=0.000000)
    ❌ cond_embeddings: NOT FROZEN (update_norm=1.448497)

  2️⃣ Testing Alternative optimizer:
    ✅ fp_adapter: TRAINABLE (update_norm=5.793282)
    ❌ classifier: NOT FROZEN (update_norm=0.002546)
    ❌ cond_embeddings: NOT FROZEN (update_norm=1.448497)

3️⃣ Checking for mask inversion issue:
Optax freeze documentation says:
- mask=True means FROZEN (no updates)
- mask=False means TRAINABLE (allow updates)

Our mask values:
  fp_adapter mask: {'fingerprint_adapter_dense': {'bias': False, 'kernel': False}, 'fingerprint_adapter_out': {'bias': False, 'kernel': False}}
    fp_adapter.fingerprint_adapter_dense.bias: False (TRAINABLE)
    fp_adapter.fingerprint_adapter_dense.kernel: False (

In [31]:
# Compare with actual train state optimizer
print("🔍 ACTUAL TRAIN STATE OPTIMIZER ANALYSIS")
print("=" * 60)

if new_config.get('frozen', False) and new_train_state is not None:
    print("Comparing test optimizers vs actual train state optimizer...")
    
    # Test the actual optimizer from the train state
    print("\n🎯 Testing actual train state optimizer:")
    
    dummy_grads = jax.tree_util.tree_map(lambda x: jnp.ones_like(x) * 0.01, new_train_state.params)
    
    try:
        updates, _ = new_optimizer.update(
            dummy_grads, 
            new_train_state.opt_state, 
            new_train_state.params
        )
        
        print("✅ Optimizer update successful")
        
        # Check each parameter group
        for param_group in ['classifier', 'cond_embeddings', 'fp_adapter']:
            if param_group in updates:
                leaves = jax.tree_util.tree_leaves(updates[param_group])
                norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)))
                
                if norm < 1e-10:
                    status = "🧊 FROZEN"
                elif norm < 1e-6:
                    status = "🔒 MOSTLY FROZEN"
                else:
                    status = "🔥 TRAINABLE"
                
                print(f"  {param_group}: update_norm={norm:.6f} {status}")
                
    except Exception as e:
        print(f"❌ Error with actual optimizer: {e}")
    
    # Check the configuration used to create the train state
    print(f"\n📋 Configuration check:")
    print(f"  frozen: {new_config.get('frozen', False)}")
    print(f"  clip: {new_config.get('clip', 'not set')}")
    print(f"  learning_rate: {new_config.get('learning_rate', 'not set')}")
    print(f"  b2: {new_config.get('b2', 'not set')}")
    print(f"  weight_decay: {new_config.get('weight_decay', 'not set')}")
    
    # Try to compare the optimizer structure
    print(f"\n🔧 Optimizer structure comparison:")
    print(f"  Actual optimizer type: {type(new_optimizer)}")
    
    # Test if we can recreate the exact same optimizer
    from md4 import train
    
    print(f"\n🧪 Recreation test:")
    try:
        # Use the same parameters as the create_frozen_train_state function
        schedule_fn_test = functools.partial(
            train.get_learning_rate,
            base_learning_rate=new_config.learning_rate,
            num_steps=1000000,
            warmup_steps=new_config.warmup_steps,
            schedule_type=new_config.learning_rate_schedule,
        )
        
        # This should match what's in create_frozen_train_state exactly
        test_rng = jax.random.PRNGKey(42)
        _, test_optimizer, test_train_state, _ = train.create_frozen_train_state(
            new_config,
            test_rng,
            input_shape=(4,) + (new_config.max_length,),  # Small batch size
            schedule_fn=schedule_fn_test,
        )
        
        # Test this recreated optimizer
        test_updates, _ = test_optimizer.update(
            dummy_grads, 
            test_train_state.opt_state, 
            test_train_state.params
        )
        
        # Check fp_adapter in recreated optimizer
        if 'fp_adapter' in test_updates:
            leaves = jax.tree_util.tree_leaves(test_updates['fp_adapter'])
            norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)))
            
            if norm < 1e-10:
                print(f"  Recreated optimizer: fp_adapter FROZEN (norm={norm:.6f})")
            else:
                print(f"  Recreated optimizer: fp_adapter TRAINABLE (norm={norm:.6f})")
        
    except Exception as e:
        print(f"  ❌ Error recreating optimizer: {e}")
        import traceback
        traceback.print_exc()

print("=" * 60)

🔍 ACTUAL TRAIN STATE OPTIMIZER ANALYSIS
Comparing test optimizers vs actual train state optimizer...

🎯 Testing actual train state optimizer:
✅ Optimizer update successful
  classifier: update_norm=0.000000 🧊 FROZEN
  cond_embeddings: update_norm=0.000000 🧊 FROZEN
  fp_adapter: update_norm=0.000000 🧊 FROZEN

📋 Configuration check:
  frozen: True
  clip: 0.0
  learning_rate: 1e-05
  b2: 0.999
  weight_decay: 1e-06

🔧 Optimizer structure comparison:
  Actual optimizer type: <class 'optax._src.base.GradientTransformationExtraArgs'>

🧪 Recreation test:
  Recreated optimizer: fp_adapter FROZEN (norm=0.000000)


In [32]:
# Test the clip=0.0 issue
print("🐛 CLIP=0.0 BUG INVESTIGATION")
print("=" * 60)

if new_config.get('frozen', False) and new_train_state is not None:
    import optax
    from flax import traverse_util
    
    # Recreate the exact freeze mask
    def _should_freeze(path, v):
        path_str = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
        return "fp_adapter" not in path_str
    
    mask = traverse_util.path_aware_map(_should_freeze, new_train_state.params)
    mask["cond_embeddings"]["cond_dense_0"]["kernel"] = False
    mask["cond_embeddings"]["cond_dense_0"]["bias"] = False
    
    freezer = optax.transforms.freeze(mask)
    
    # Test different clip configurations
    clip_value = new_config.get('clip', 0.0)
    print(f"Configuration clip value: {clip_value}")
    
    print(f"\n🧪 Testing different clip configurations:")
    
    # Test 1: With clip=0.0 (identity)
    print(f"\n1️⃣ With clip=0.0 (using optax.identity()):")
    optimizer_identity = optax.chain(
        optax.identity(),
        optax.adamw(1e-5, b1=0.9, b2=0.999, weight_decay=1e-6),
        freezer,
    )
    
    # Test 2: Without any clipping
    print(f"2️⃣ Without any clipping at all:")
    optimizer_no_clip = optax.chain(
        optax.adamw(1e-5, b1=0.9, b2=0.999, weight_decay=1e-6),
        freezer,
    )
    
    # Test 3: With actual clipping (clip > 0)
    print(f"3️⃣ With actual clipping (clip=1.0):")
    optimizer_with_clip = optax.chain(
        optax.clip(1.0),
        optax.adamw(1e-5, b1=0.9, b2=0.999, weight_decay=1e-6),
        freezer,
    )
    
    # Test all three
    dummy_params = new_train_state.params
    dummy_grads = jax.tree_util.tree_map(lambda x: jnp.ones_like(x) * 0.01, dummy_params)
    
    for name, optimizer in [
        ("identity", optimizer_identity),
        ("no_clip", optimizer_no_clip), 
        ("with_clip", optimizer_with_clip)
    ]:
        print(f"\n  Testing {name} optimizer:")
        
        try:
            opt_state = optimizer.init(dummy_params)
            updates, _ = optimizer.update(dummy_grads, opt_state, dummy_params)
            
            # Check fp_adapter
            if 'fp_adapter' in updates:
                leaves = jax.tree_util.tree_leaves(updates['fp_adapter'])
                norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)))
                
                if norm < 1e-10:
                    print(f"    ❌ fp_adapter: FROZEN (norm={norm:.6f})")
                else:
                    print(f"    ✅ fp_adapter: TRAINABLE (norm={norm:.6f})")
            
        except Exception as e:
            print(f"    ❌ Error: {e}")
    
    # Test 4: Check if the issue is with the mask itself
    print(f"\n4️⃣ Testing freeze mask directly:")
    
    # Test freezer alone
    freezer_state = freezer.init(dummy_params)
    frozen_updates_direct, _ = freezer.update(dummy_grads, freezer_state, dummy_params)
    
    if 'fp_adapter' in frozen_updates_direct:
        leaves = jax.tree_util.tree_leaves(frozen_updates_direct['fp_adapter'])
        norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)))
        
        if norm < 1e-10:
            print(f"  ❌ Direct freezer: fp_adapter FROZEN (norm={norm:.6f})")
        else:
            print(f"  ✅ Direct freezer: fp_adapter TRAINABLE (norm={norm:.6f})")
    
    # Investigate the mask values more carefully
    print(f"\n5️⃣ Detailed mask investigation:")
    
    def print_detailed_mask(mask_tree, path=""):
        if isinstance(mask_tree, dict):
            for key, value in mask_tree.items():
                print_detailed_mask(value, f"{path}.{key}" if path else key)
        else:
            print(f"  {path}: {mask_tree} ({'FROZEN' if mask_tree else 'TRAINABLE'})")
    
    print("  Complete freeze mask:")
    print_detailed_mask(mask)

print("=" * 60)

🐛 CLIP=0.0 BUG INVESTIGATION
Configuration clip value: 0.0

🧪 Testing different clip configurations:

1️⃣ With clip=0.0 (using optax.identity()):
2️⃣ Without any clipping at all:
3️⃣ With actual clipping (clip=1.0):

  Testing identity optimizer:
    ✅ fp_adapter: TRAINABLE (norm=0.057933)

  Testing no_clip optimizer:
    ✅ fp_adapter: TRAINABLE (norm=0.057933)

  Testing with_clip optimizer:
    ✅ fp_adapter: TRAINABLE (norm=0.057933)

4️⃣ Testing freeze mask directly:
  ✅ Direct freezer: fp_adapter TRAINABLE (norm=57.933586)

5️⃣ Detailed mask investigation:
  Complete freeze mask:
  classifier.CondEmbedding_0.Dense_0.bias: True (FROZEN)
  classifier.CondEmbedding_0.Dense_0.kernel: True (FROZEN)
  classifier.CondEmbedding_0.dense0.bias: True (FROZEN)
  classifier.CondEmbedding_0.dense0.kernel: True (FROZEN)
  classifier.Embed_0.embedding: True (FROZEN)
  classifier.Transformer_0.Dense_0.bias: True (FROZEN)
  classifier.Transformer_0.Dense_0.kernel: True (FROZEN)
  classifier.Transfo

In [None]:
# Investigate potential issues in train.py
print("🐛 POTENTIAL ISSUES INVESTIGATION")
print("=" * 60)

print("Analyzing the actual train.py implementation...")

# Check if the _should_freeze function is properly implemented
print("\n1️⃣ FREEZE FUNCTION ANALYSIS")
print("Checking the _should_freeze function implementation...")

# Let's look at what the function should do according to the source
print("Expected behavior:")
print("- fp_adapter parameters: should be TRAINABLE (return False)")
print("- All other parameters: should be FROZEN (return True)")
print("- cond_embeddings.cond_dense_0: manually set to TRAINABLE")

# Simulate the freeze logic
def simulate_should_freeze(path):
    """Simulate the _should_freeze function from train.py"""
    if "fp_adapter" in path:
        return False  # Trainable
    return True  # Frozen

# Test it on our parameter paths
test_paths = [
    "classifier.Embed_0.embedding",
    "classifier.Transformer_0.Dense_0.kernel", 
    "cond_embeddings.cond_dense_0.kernel",
    "cond_embeddings.cond_dense_1.kernel",
    "fp_adapter.fingerprint_adapter_dense.kernel",
    "fp_adapter.fingerprint_adapter_out.bias"
]

print("\nTesting freeze logic:")
for path in test_paths:
    should_freeze = simulate_should_freeze(path)
    # Special case for cond_dense_0 (manually set to trainable)
    if "cond_embeddings" in path and "cond_dense_0" in path:
        should_freeze = False  # Manually overridden
    
    status = "🧊 FROZEN" if should_freeze else "🔥 TRAINABLE"
    print(f"  {path}: {status}")

print("\n2️⃣ POTENTIAL ISSUES IDENTIFIED")

issues_found = []

# Check if the model has the expected parameter groups
if new_train_state is not None:
    param_keys = list(new_train_state.params.keys())
    
    if 'fp_adapter' not in param_keys:
        issues_found.append("❌ fp_adapter not found in model parameters")
    else:
        print("✅ fp_adapter found in model parameters")
    
    if 'cond_embeddings' not in param_keys:
        issues_found.append("❌ cond_embeddings not found in model parameters")
    else:
        print("✅ cond_embeddings found in model parameters")
        
        # Check for cond_dense_0
        if ('cond_embeddings' in new_train_state.params and 
            'cond_dense_0' in new_train_state.params['cond_embeddings']):
            print("✅ cond_dense_0 found in cond_embeddings")
        else:
            issues_found.append("❌ cond_dense_0 not found in cond_embeddings")

# Check for potential configuration issues
print("\n3️⃣ CONFIGURATION COMPATIBILITY")

if 'frozen' in new_config and new_config.frozen:
    print("✅ Model is configured for frozen training")
    
    if 'fingerprint_dim' in new_config and new_config.fingerprint_dim > 0:
        print("✅ Fingerprint conditioning is enabled")
    else:
        issues_found.append("⚠️  Fingerprint conditioning may not be properly configured")
        
    if 'partial_load' in new_config and new_config.partial_load:
        print("✅ Partial loading is enabled")
    else:
        issues_found.append("⚠️  Partial loading is not enabled in config")
else:
    issues_found.append("⚠️  Model is not configured for frozen training")

print("\n4️⃣ SUMMARY OF FINDINGS")

if not issues_found:
    print("✅ No critical issues found in the configuration!")
    print("The partial loading and frozen training setup appears to be correct.")
else:
    print("Issues that need attention:")
    for issue in issues_found:
        print(f"  {issue}")

print("\n5️⃣ RECOMMENDATIONS FOR DEBUGGING")
print("If you're experiencing issues with partial loading or frozen training:")
print("1. Verify that old checkpoints contain compatible parameter structures")
print("2. Check that the fingerprint dimensions match between old and new models")
print("3. Ensure that the fp_adapter layers are present in the new model architecture")
print("4. Verify that only the intended parameters are being updated during training")
print("5. Monitor training loss to ensure that trainable parameters are learning")

print("=" * 60)