In [None]:
import os
# Expose only GPU index 1 to this process. This is the first cell of the
# notebook, so the variable is set before jax is imported further down.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
import orbax.checkpoint as ocp
from flax import traverse_util
import pathlib

# Absolute location of the checkpoint's params directory.
ckpt_path = pathlib.Path(
    "checkpoints/pi0_fql_libero_lora_finetune/exp8_plz/29999/params"
).resolve()

with ocp.PyTreeCheckpointer() as ckptr:
    # Look at the stored metadata first to see which top-level keys exist.
    metadata = ckptr.metadata(ckpt_path)
    print("metadata keys:", metadata.tree.keys())

    # Example for a checkpoint with several top-level keys:
    # item = {k: metadata[k] for k in metadata.tree.keys()}
    # params = ckptr.restore(ckpt_path, ocp.args.PyTreeRestore(item=item))

    # Example for a single tree: restore everything in one call.
    params = ckptr.restore(ckpt_path)

# One line per leaf: flattened key tuple, Python type, then shape and dtype
# (None when the leaf has no such attribute).
flat_params = traverse_util.flatten_dict(params)
for k, v in flat_params.items():
    print(k, type(v), getattr(v, "shape", None), getattr(v, "dtype", None))

In [None]:
import orbax.checkpoint as ocp
from flax import traverse_util
import pathlib

import pprint

# Absolute location of the checkpoint's params directory.
ckpt_path = pathlib.Path(
    "checkpoints/pi0_fql_libero_lora_finetune/exp8_plz/29999/params"
).resolve()

with ocp.PyTreeCheckpointer() as ckptr:
    params = ckptr.restore(ckpt_path)

# Dump the whole restored tree as a dict (only sensible for a small model!).
pprint.pprint(params)

In [None]:
from collections.abc import Sequence
import dataclasses
import logging
import pathlib
from typing import Any

import jax.numpy as jnp

import src.openpi.models.model as _model
import src.openpi.policies.policy as _policy
import src.openpi.shared.download as download
from src.openpi.training import checkpoints as _checkpoints
from src.openpi.training import config as _config
import src.openpi.transforms as transforms


import abc
from collections.abc import Sequence
import dataclasses
import enum
import logging
import pathlib
from typing import Generic, TypeVar

import augmax
from flax import nnx
from flax import struct
from flax import traverse_util
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp

from src.openpi.shared import image_tools
import src.openpi.shared.array_typing as at

In [None]:
from src.openpi.policies.policy_config import create_trained_policy_fql
from src.openpi.training.config import get_config

# Same params directory literal that the inspection cells above restore from.
checkpoint_dir = "checkpoints/pi0_fql_libero_lora_finetune/exp8_plz/29999/params"
# Training config looked up by name; later cells use train_config.model.create
# and pass train_config to the data loader.
train_config = get_config("pi0_fql_libero_lora_finetune")

In [None]:
# Transform group constructed with no arguments.
# NOTE(review): repack_transforms is not referenced again in the visible cells.
repack_transforms = transforms.Group()
# checkpoint_dir is rebound to whatever maybe_download returns for the
# stringified path. NOTE(review): presumably a local path — confirm.
checkpoint_dir = download.maybe_download(str(checkpoint_dir))
logging.info("Loading model...")

In [None]:
# Absolute path of the params directory, plus the restore settings used by the
# restore cell below.
params_path = pathlib.Path(checkpoint_dir).resolve()
restore_type = jax.Array
sharding = None

# Fail early with a clear message if the checkpoint directory is missing.
if not params_path.exists():
    raise FileNotFoundError(f"Model params not found at: {params_path}")
    
# No sharding was supplied: build a one-axis mesh named "x" over every device
# jax reports and pair it with an empty PartitionSpec (the same construction is
# called `replicated_sharding` in the training-setup cell of this notebook).
if restore_type is jax.Array and sharding is None:
    mesh = jax.sharding.Mesh(jax.devices(), ("x",))
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

In [None]:
# Every restored leaf is requested as `restore_type` with this dtype.
dtype = jnp.bfloat16

with ocp.PyTreeCheckpointer() as ckptr:
    metadata = ckptr.metadata(params_path)
    # Restore only these three top-level subtrees of the checkpoint.
    item = {
        name: metadata[name]
        for name in ("actor_params", "critic_params", "critic_target_params")
    }
    params = ckptr.restore(
        params_path,
        ocp.args.PyTreeRestore(
            item=item,
            # One ArrayRestoreArgs per leaf of `item`, all with the same settings.
            restore_args=jax.tree.map(
                lambda _: ocp.ArrayRestoreArgs(
                    sharding=sharding, restore_type=restore_type, dtype=dtype
                ),
                item,
            ),
        ),
    )

In [None]:
# Despite its name, flat_params ends up holding *nested* dicts: each subtree is
# flattened only long enough to drop a trailing "value" key, then rebuilt.
flat_params = {}
for k in ("actor_params", "critic_params", "critic_target_params"):
    flat = traverse_util.flatten_dict(params[k])
    # Strip the last path component only when every leaf path ends in "value".
    if not any(kp[-1] != "value" for kp in flat):
        flat = {kp[:-1]: v for kp, v in flat.items()}
    flat_params[k] = traverse_util.unflatten_dict(flat)


In [None]:
flat_params["actor_params"]["actor"]

In [None]:
# Rebind `params` to a tree keyed by "actor" / "critic" / "critic_target",
# taking each subtree from under its "<name>_params" wrapper.
params = {
    name: flat_params[f"{name}_params"][name]
    for name in ("actor", "critic", "critic_target")
}


In [None]:
params.keys()

In [None]:
train_config.model.create

In [None]:
# Run the model factory through nnx.eval_shape with a fixed PRNG key.
# NOTE(review): presumably yields a shape-only model without allocating real weights — confirm.
model = nnx.eval_shape(train_config.model.create, jax.random.key(0))

In [None]:
# Separate the model into graphdef and state; the two are recombined with
# nnx.merge(graphdef, state) further down.
graphdef, state = nnx.split(model)

In [None]:
state.to_pure_dict()['critic']['cross_attn_in']['cross_attn']['rngs']

In [None]:
# Plain-dict view of the model state, used below as the reference tree that the
# restored checkpoint params are compared against.
expected_tree = state.to_pure_dict()

In [None]:
state.to_pure_dict()['critic']['img_embed']['proj']

In [None]:
expected_tree['critic']['cross_attn_in']['cross_attn']['rngs']['default']['key']

In [None]:
params['critic']['cross_attn_in']['cross_attn']['rngs']['default']['key']

In [None]:
def print_shape_mismatches(params, expected_tree, path=None):
    """Print where ``params`` disagrees with ``expected_tree`` in shape or type.

    The two trees are walked in parallel. Only keys present in both dicts are
    visited; keys that exist in ``params`` but not in ``expected_tree`` are
    silently ignored.

    Args:
        params: Nested dict (or a leaf) to check.
        expected_tree: Nested dict (or a leaf) that provides the reference.
        path: Keys walked so far; used for the printed location. Callers
            normally leave this as None.

    Returns:
        None. Results are printed, one line per mismatch.
    """
    if path is None:
        path = []

    # Both sides are dicts: recurse into the keys they share.
    if isinstance(params, dict) and isinstance(expected_tree, dict):
        for key, value in params.items():
            if key in expected_tree:
                print_shape_mismatches(value, expected_tree[key], path + [key])
        return

    location = '.'.join(map(str, path))

    # Both sides are array-like: compare shapes.
    if hasattr(params, 'shape') and hasattr(expected_tree, 'shape'):
        if params.shape != expected_tree.shape:
            print(f"Shape mismatch at {location}: "
                  f"params shape={params.shape}, expected shape={expected_tree.shape}")
        return

    # The types differ, or only one side is array-like.
    if type(params) != type(expected_tree):
        print(f"Type mismatch at {location}: "
              f"params type={type(params)}, expected type={type(expected_tree)}")

    # Same-typed leaves without a shape attribute are not compared.

In [None]:
print_shape_mismatches(params, expected_tree)

In [None]:
def auto_fix_shape(params, expected_tree, path=None):
    """Return a copy of ``params`` with leaves coerced to the expected shapes.

    Dicts are rebuilt keeping only the keys that also exist in
    ``expected_tree``. For array-like leaves whose shape differs from the
    expected one:

    * if the expected shape is ``()``, the first element of the flattened
      array is returned;
    * otherwise the leaf is passed to ``np.reshape`` with the expected shape.

    A leaf that cannot be fixed (a failing reshape, or an empty array where a
    scalar is expected) is reported with a "Cannot auto-fix" line and returned
    unchanged rather than raising.

    Args:
        params: Nested dict (or a leaf) to fix.
        expected_tree: Nested dict (or a leaf) that provides the target shapes.
        path: Keys walked so far; used for the printed location. Callers
            normally leave this as None.

    Returns:
        The fixed tree, or the fixed / original leaf.
    """
    import numpy as np

    if path is None:
        path = []

    if isinstance(params, dict) and isinstance(expected_tree, dict):
        return {k: auto_fix_shape(params[k], expected_tree[k], path + [k])
                for k in params if k in expected_tree}

    # Leaves without a shape on both sides, or already matching, pass through.
    if not (hasattr(params, 'shape') and hasattr(expected_tree, 'shape')):
        return params
    if params.shape == expected_tree.shape:
        return params

    location = '/'.join(map(str, path))

    # A scalar is expected: take the first element.
    if expected_tree.shape == ():
        try:
            value = params.flatten()[0]
        except Exception:
            # Fallback: convert to a numpy array first. An empty array still
            # has no first element, so report that instead of raising.
            try:
                value = np.array(params).flatten()[0]
            except Exception as e:
                print(f"Cannot auto-fix {location}: {e}")
                return params
        print(f"Auto-fixing {location}: {params.shape} -> {expected_tree.shape} (using first value)")
        return value

    # Otherwise reshape to the expected shape when the element counts allow it.
    try:
        params_reshaped = np.reshape(params, expected_tree.shape)
    except Exception as e:
        print(f"Cannot auto-fix {location}: {e}")
        return params
    print(f"Auto-fixing {location}: {params.shape} -> {expected_tree.shape}")
    return params_reshaped

In [None]:
auto_fix_shape(params, expected_tree)['critic']['cross_attn_in']['cross_attn']['rngs']['default']['key']

In [None]:
# Intersect the model's pure-dict state with the restored params.
# NOTE(review): which side's leaf values intersect_trees keeps is not visible
# here — confirm before relying on params_rm as "checkpoint values only".
params_rm = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)

In [None]:
params_rm['critic']['cross_attn_in']['cross_attn']['rngs']

In [None]:
params_rm['critic']['img_embed']['proj']

In [None]:
def make_none_tree(tree):
    """Rebuild the nested dicts of ``tree`` while keeping every leaf as-is.

    Despite the name, no leaf is replaced by None: each non-dict value is
    returned unchanged, so the result is a new dict structure that shares its
    leaves with the input.
    """
    if not isinstance(tree, dict):
        return tree  # leaf: hand back the original value, not None
    rebuilt = {}
    for key, subtree in tree.items():
        rebuilt[key] = make_none_tree(subtree)
    return rebuilt
def align_params_to_state(state_dict, params_dict):
    """Build a tree shaped like ``state_dict``, filled from ``params_dict`` where possible.

    For every key of ``state_dict``:

    * ``'rngs'`` subtrees always come from ``state_dict`` (structure rebuilt
      via ``make_none_tree``, leaves unchanged);
    * keys also present in ``params_dict`` are aligned recursively;
    * all other keys keep the ``state_dict`` value.

    At a leaf, the ``params_dict`` value wins unless it is None.

    Args:
        state_dict: Reference tree (nested dicts or a leaf).
        params_dict: Tree to take values from. May be None or a non-dict even
            where ``state_dict`` is a dict; it then contributes nothing below
            that point instead of raising.

    Returns:
        The aligned tree, or the chosen leaf.
    """
    if not isinstance(state_dict, dict):
        return params_dict if params_dict is not None else state_dict

    # A missing or non-dict params subtree has no keys to offer; without this
    # guard the membership test below raises TypeError for None.
    if not isinstance(params_dict, dict):
        params_dict = {}

    out = {}
    for k, v in state_dict.items():
        if k == 'rngs':
            out[k] = make_none_tree(v)  # keep the structure; leaves are the state_dict values
        elif k in params_dict:
            out[k] = align_params_to_state(v, params_dict[k])
        else:
            out[k] = v
    return out

In [None]:
params_rm['critic']['cross_attn_in']['cross_attn']['rngs'] == params['critic']['cross_attn_in']['cross_attn']['rngs']

In [None]:
params_rm['critic']['cross_attn_in']['cross_attn']['rngs']

In [None]:
params['critic']['cross_attn_in']['cross_attn']['rngs']

In [None]:
# NOTE(review): `==` on dicts compares the leaves with `==` and takes the truth
# value of each result; for multi-element array leaves that is likely to raise
# or be ambiguous — confirm. The per-leaf comparison cell below is the safer check.
params_rm == params

In [None]:
from functools import reduce
import operator

def get_keypaths(d, prefix=()):
    """Yield the key tuple of every non-dict leaf in ``d``, depth-first.

    ``prefix`` is prepended to every yielded tuple. Empty dicts yield nothing,
    and a non-dict ``d`` yields ``prefix`` itself.
    """
    pending = [(prefix, d)]
    while pending:
        path, node = pending.pop()
        if not isinstance(node, dict):
            yield path
            continue
        # Push children in reverse so they are popped in insertion order.
        children = [(path + (key,), child) for key, child in node.items()]
        pending.extend(reversed(children))

# Keypath sets of params (the restored checkpoint tree) and of params_rm
# (its intersection with the model state).
params_kp    = {*get_keypaths(params)}
params_rm_kp = {*get_keypaths(params_rm)}

# Keypaths present in params but absent from params_rm.
only_in_params    = params_kp.difference(params_rm_kp)
# Keypaths present in params_rm but absent from params.
only_in_params_rm = params_rm_kp.difference(params_kp)

print("=== params에만 있는 keypaths ===")
for kp in sorted(only_in_params):
    print(kp)

print("=== params_rm에만 있는 keypaths ===")
for kp in sorted(only_in_params_rm):
    print(kp)

In [None]:
import numpy as np
from functools import reduce
import operator

def get_by_path(d, keypath):
    """Return the value reached by indexing ``d`` with each key of ``keypath`` in turn.

    An empty ``keypath`` returns ``d`` itself.
    """
    node = d
    for key in keypath:
        node = node[key]
    return node

# Reuses the keypath sets computed in the previous cell.
common_kp = params_kp & params_rm_kp

MAX_SAMPLES = 4000  # maximum number of elements compared per leaf

for kp in sorted(common_kp):
    v1 = get_by_path(params,    kp)
    v2 = get_by_path(params_rm, kp)

    # Only array-like leaves get a numeric comparison.
    if hasattr(v1, 'shape') and hasattr(v2, 'shape'):
        # Report a shape disagreement and move on.
        if v1.shape != v2.shape:
            print(f"[SHAPE MISMATCH] {kp}: {v1.shape} vs {v2.shape}")
            continue

        # Convert both leaves to flat float32 numpy arrays.
        try:
            a1 = np.array(v1, dtype=np.float32).ravel()
            a2 = np.array(v2, dtype=np.float32).ravel()
        except Exception:
            print(f"[SKIPPED] {kp}: cannot convert to float32")
            continue

        # Cap the number of compared elements.
        n = min(a1.size, MAX_SAMPLES)
        if n == 0:
            # Zero-size leaf: nothing to compare, and max() on an empty array
            # would raise ValueError.
            continue
        # Use the first n elements; swap in the commented line for a random sample.
        idx = np.arange(n)
        # idx = np.random.choice(a1.size, n, replace=False)

        sample1 = a1[idx]
        sample2 = a2[idx]

        # Absolute element-wise difference over the sample.
        diff = np.abs(sample1 - sample2)
        max_diff  = diff.max()
        mean_diff = diff.mean()

        if max_diff > 1e-6:  # tolerance; adjust as needed
            print(f"[VALUE MISMATCH] {kp}: max={max_diff:.3e}, mean={mean_diff:.3e}")
    else:
        # Non-array leaves: plain inequality check.
        if v1 != v2:
            print(f"[TYPE MISMATCH] {kp}: {v1!r} vs {v2!r}")

In [None]:
def get_keypaths(d, prefix=()):
    """Yield the keypath of every leaf in a parameter tree as a tuple."""
    if not isinstance(d, dict):
        # Leaf reached: the accumulated prefix is its keypath.
        yield prefix
        return
    for key in d:
        yield from get_keypaths(d[key], prefix + (key,))

# Assumes params_rm and expected_tree are nested dicts.
# params_keypaths comes from params_rm (the intersected checkpoint tree) and
# expected_keypaths from the model's pure-dict state.
params_keypaths = {*get_keypaths(params_rm)}
expected_keypaths = {*get_keypaths(expected_tree)}

print("=== 모델에만 있고 체크포인트에 없는 keypaths ===")
for kp in sorted(expected_keypaths.difference(params_keypaths)):
    print(kp)

print("=== 체크포인트에만 있고 모델에 없는 keypaths ===")
for kp in sorted(params_keypaths.difference(expected_keypaths)):
    print(kp)

In [None]:
state.replace_by_pure_dict(params)
state.to_pure_dict()['critic']['cross_attn_in']['cross_attn']['rngs']

In [None]:
# Load params_rm into the existing `state` object (the return value is not used).
state.replace_by_pure_dict(params_rm)
# NOTE(review): this binds state2 to the *same* object as `state`, not a copy,
# and the cells below compare it against `state1`, which is never assigned in
# the visible cells — TODO take an explicit snapshot after each replace.
state2 = state

In [None]:
# NOTE(review): `state1` is not assigned anywhere in the visible cells, so on a
# fresh kernel this line raises NameError; `state2` is an alias of `state`.
state1 == state2

In [None]:
state1.to_pure_dict() == state2.to_pure_dict()

In [None]:
state1.to_pure_dict()['critic']['cross_attn_in']['cross_attn']['rngs']

In [None]:
state2.to_pure_dict()['critic']['cross_attn_in']['cross_attn']['rngs']

In [None]:
# Recombine the graphdef from nnx.split with `state`, which by this point has
# had params and then params_rm loaded into it by the cells above.
pi_model = nnx.merge(graphdef, state)

In [None]:
nnx.state(pi_model).to_pure_dict()

In [None]:
nnx.state(pi_model).to_pure_dict()['actor']['PaliGemma']['llm']['layers']['attn']['attn_vec_einsum']['lora_a'] == params['actor']['PaliGemma']['llm']['layers']['attn']['attn_vec_einsum']['lora_a']

In [None]:
def get_lora_keypaths(tree, prefix=()):
    """Collect the "/"-joined keypaths of leaves whose path mentions "lora".

    The nested dict is walked recursively. A leaf is reported when the
    substring "lora" (case-insensitive) occurs anywhere in its joined keypath,
    so a match on a parent key selects every leaf below it.

    Args:
        tree: Nested dict (or a leaf).
        prefix: Keys walked so far. Keys may be of any type; they are
            converted with ``str`` before joining, so integer keys do not
            raise.

    Returns:
        List of matching keypath strings, in traversal order.
    """
    if not isinstance(tree, dict):
        full_key = "/".join(map(str, prefix))
        return [full_key] if "lora" in full_key.lower() else []

    lora_keys = []
    for k, v in tree.items():
        lora_keys.extend(get_lora_keypaths(v, prefix + (k,)))
    return lora_keys

# Example usage
# params = model.params, or with torch: model.state_dict()
# List every keypath in the merged model's pure-dict state that mentions "lora".
lora_keys = get_lora_keypaths(nnx.state(pi_model).to_pure_dict())
print("▶︎ LoRA 파라미터 keypaths:")
for k in lora_keys:
    print(k)

In [None]:
rng = jax.random.key(0)

In [None]:
import src.openpi.training.data_loader as _data_loader
# Loader limited to one batch (num_batches=1), shuffled.
# NOTE(review): `sharding` here is the NamedSharding built with an empty
# PartitionSpec in the restore-setup cell, whereas the training-setup cell at
# the end of the notebook passes a data-axis sharding — confirm which is intended.
data_loader = _data_loader.create_data_loader(train_config, sharding=sharding, shuffle=True, num_batches=1)

In [None]:
# NOTE(review): only the PRNG key is passed and the data_loader created in the
# previous cell is never read; presumably an observation batch is still
# missing from this call — TODO confirm sample_actions' signature and pass it.
pi_model.sample_actions(rng, )

In [None]:
import os
# NOTE(review): earlier cells have already imported jax and called
# jax.devices(), so changing CUDA_VISIBLE_DEVICES here likely has no effect on
# the running kernel — presumably a kernel restart is needed; confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [None]:
import pandas as pd

# NOTE(review): absolute, machine-specific path — consider a configurable data
# directory.
path = "/ssd1/openpi_official/datasets/libero_fql/data/chunk-000/episode_000000.parquet"
df = pd.read_parquet(path)

# Column names and dtypes. DataFrame.info() prints its report itself and
# returns None, so it is not wrapped in print() (that would add a stray "None").
df.info()

# First five rows.
print(df.head())

In [None]:
df['next_state'][0]

In [None]:
df['state'][1]

In [None]:
df['terminal']

In [None]:
df['actions'][100]

In [None]:
df['reward']

In [None]:
from src.openpi.training.config import get_config
import jax
import src.openpi.training.data_loader as _data_loader
import src.openpi.training.checkpoints as _checkpoints
import src.openpi.training.sharding as sharding


config = get_config("pi0_fql_libero_lora_finetune")

In [None]:
# Refuse a batch size that is not a multiple of the device count.
if config.batch_size % jax.device_count():
    raise ValueError(f"Batch size {config.batch_size} must be divisible by number of devices.")

# Seeded root key, split into a training key and an init key.
rng = jax.random.key(config.seed)
train_rng, init_rng = jax.random.split(rng)

# `sharding` here is the module imported in the previous cell
# (src.openpi.training.sharding); it rebinds the name used earlier in the
# notebook for a NamedSharding instance.
mesh = sharding.make_mesh(config.fsdp_devices)
data_sharding = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec(sharding.DATA_AXIS),
)
replicated_sharding = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec(),
)

# NOTE(review): config.overwrite is forwarded as-is; confirm this cannot clear
# an existing checkpoint directory when run from an inspection notebook.
checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
    config.checkpoint_dir,
    keep_period=config.keep_period,
    overwrite=config.overwrite,
    resume=config.resume,
)

# Shuffled loader sharded along the data axis; pull one batch to inspect.
data_loader = _data_loader.create_data_loader(
    config,
    sharding=data_sharding,
    shuffle=True,
)
data_iter = iter(data_loader)
batch = next(data_iter)