In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from lerobot.datasets.lerobot_dataset import LeRobotDataset

# --- Input data -------------------------------------------------------------
# Parquet table holding the per-frame rows (actions, robot state, rewards, ...).
parquet_path = "/home/joy4mj/Feel2Grasp/train.parquet"
# Arguments forwarded to LeRobotDataset when the dataset is opened.
repo_id = "mjkim00/Feel2Grasp"
revision = "main"
video_backend = "pyav"

# --- Encoder ----------------------------------------------------------------
encoder_ckpt_path = "./ae_out/encoder.pt"
# Number of frames encoded per forward pass.
batch_size = 256
# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- Output -----------------------------------------------------------------
out_path = "replay_buffer_iql_normalized.npz"
# Standard deviations below this value are replaced by 1.0 during normalization.
EPS = 1e-8

# load encoder
class ConvEncoder(nn.Module):
    """Convolutional image encoder that maps an image batch to flat latent vectors.

    The feature extractor is four 4x4, stride-2, padding-1 convolutions
    (3 -> 32 -> 64 -> 128 -> 256 channels), each followed by an in-place ReLU,
    then adaptive average pooling to a 4x4 map. A linear layer projects the
    flattened 256*4*4 features to ``latent_dim``.

    Args:
        latent_dim: Size of the output latent vector.
    """

    def __init__(self, latent_dim: int):
        super().__init__()
        channels = (3, 32, 64, 128, 256)
        stages = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU(inplace=True))
        stages.append(nn.AdaptiveAvgPool2d((4, 4)))
        # Layer order fixes the Sequential indices (convs at 0, 2, 4, 6), which
        # are part of the state_dict keys a saved checkpoint must match.
        self.net = nn.Sequential(*stages)
        self.fc = nn.Linear(channels[-1] * 4 * 4, latent_dim)

    def forward(self, x):
        """Encode a batch ``x`` of 3-channel images into ``(batch, latent_dim)`` latents."""
        pooled = self.net(x)
        return self.fc(pooled.flatten(start_dim=1))

# Read the encoder checkpoint onto the CPU. If the file does not exist, fall
# back to placeholder settings with an empty state dict so the cell still runs.
try:
    ckpt = torch.load(encoder_ckpt_path, map_location="cpu")
except FileNotFoundError:
    print(f"Error: Encoder checkpoint not found at {encoder_ckpt_path}")
    ckpt = dict(latent_dim=64, resize_hw=[128, 128], encoder_state_dict={})
    print("Using dummy encoder settings. Please ensure real encoder is loaded.")

encoder = ConvEncoder(ckpt["latent_dim"])
# An empty state dict (the fallback case) is skipped, which leaves the encoder
# with its freshly initialised weights.
# NOTE(review): latents produced in that case come from an untrained encoder.
if ckpt["encoder_state_dict"]:
    encoder.load_state_dict(ckpt["encoder_state_dict"])

# Inference only: move to the target device, switch to eval mode and disable
# gradients on every parameter.
encoder.to(device)
encoder.eval()
encoder.requires_grad_(False)

# Target (height, width) that images are resized to before encoding.
resize_hw = tuple(ckpt["resize_hw"])

#  Parquet load + sanity check

df = pd.read_parquet(parquet_path)

# Columns this pipeline reads; fail early with a clear message if any is absent.
need_cols = [
    "action", "observation.state", "episode_index", "frame_index", "index",
    "left_image_circle", "right_image_circle", "circle_reward"
]
missing_cols = [c for c in need_cols if c not in df.columns]
if missing_cols:
    raise ValueError(f"parquet is missing required columns: {missing_cols}")

# Order rows by (episode, frame) so that consecutive rows are consecutive timesteps.
df = df.sort_values(["episode_index", "frame_index"]).reset_index(drop=True)
N = len(df)
print("N =", N)
if N == 0:
    raise ValueError(f"parquet at {parquet_path} contains no rows")

# check for duplicated (episode_index, frame_index): a repeated key would make
# the transition (row k -> row k+1) ambiguous.
dup_mask = df.duplicated(subset=["episode_index", "frame_index"], keep=False)
if dup_mask.any():
    dup_examples = (
        df.loc[dup_mask, ["episode_index", "frame_index"]]
        .drop_duplicates()
        .head(5)
        .to_numpy()
        .tolist()
    )
    raise ValueError(
        f"Found {int(dup_mask.sum())} rows with duplicated (episode_index, frame_index); "
        f"first duplicated keys: {dup_examples}"
    )

ep = df["episode_index"].to_numpy()
fr = df["frame_index"].to_numpy()
global_idx = df["index"].to_numpy()

# Per-row arrays, all cast to float32.
obs_state = np.stack(df["observation.state"].to_list()).astype(np.float32)
actions   = np.stack(df["action"].to_list()).astype(np.float32)
lr01      = df[["left_image_circle", "right_image_circle"]].to_numpy().astype(np.float32)
rewards   = df["circle_reward"].to_numpy().astype(np.float32)

if obs_state.ndim != 2 or obs_state.shape[1] != 6:
    raise ValueError(f"observation.state expected (N,6) but got {obs_state.shape}")

# Reward Rescaling: -1/+10 -> -0.1/+1.0
rewards = rewards / 10.0

# LeRobotDataset load
# Instantiate the dataset for the configured repo / revision / video backend
# and report how many items it exposes.
ds = LeRobotDataset(
    repo_id,
    revision=revision,
    video_backend=video_backend,
)
print(f"len(ds) = {len(ds)}")

def get_sample(ds, i: int):
    """Return item ``i`` of ``ds``, coercing the index to a built-in ``int`` first."""
    position = int(i)
    return ds[position]

def get_key(sample, key):
    """Return ``sample[key]``.

    Raises:
        KeyError: If ``key`` is not in ``sample``; the message lists up to the
            first 30 keys that are available.
    """
    if key not in sample:
        raise KeyError(f"Key {key} not found in sample keys: {list(sample.keys())[:30]} ...")
    return sample[key]

# index mapping
# Decide whether the parquet "index" column can be used directly as an index
# into `ds`, or whether an explicit (episode, frame) -> ds index map is needed.
check_k = min(200, N)
n_ds = len(ds)

# Cheap guard over ALL rows: any negative or out-of-range value means the
# column cannot be a direct ds index. Without it, such a value would raise
# inside ds[i] (or wrap around) instead of triggering the fallback map.
ok_direct = N == 0 or bool(global_idx.min() >= 0 and global_idx.max() < n_ds)

if ok_direct:
    # Spot-check up to 200 evenly spaced rows: the sample fetched at the
    # parquet index must carry the same (episode_index, frame_index).
    for k in np.linspace(0, N-1, check_k, dtype=int):
        i = int(global_idx[k])
        samp = get_sample(ds, i)
        ep_ds = int(get_key(samp, "episode_index"))
        fr_ds = int(get_key(samp, "frame_index"))
        if ep_ds != int(ep[k]) or fr_ds != int(fr[k]):
            ok_direct = False
            break

if not ok_direct:
    print("[WARN] parquet.index != ds index mapping. Building (episode,frame)->ds_index map...")
    # Full scan of the dataset to record where each (episode, frame) lives.
    map_epfr_to_dsidx = {}
    for i in range(n_ds):
        samp = get_sample(ds, i)
        ep_i = int(get_key(samp, "episode_index"))
        fr_i = int(get_key(samp, "frame_index"))
        map_epfr_to_dsidx[(ep_i, fr_i)] = i

    ds_indices = np.empty((N,), dtype=np.int64)
    for k in range(N):
        key = (int(ep[k]), int(fr[k]))
        if key not in map_epfr_to_dsidx:
            raise KeyError(f"Cannot find ds index for (episode,frame)={key}")
        ds_indices[k] = map_epfr_to_dsidx[key]
else:
    ds_indices = global_idx.astype(np.int64)


# front image -> encoder -> z(64) 
def _front_image_to_unit_float(img):
    """Convert one front-camera image to a float tensor scaled to roughly [0, 1].

    NumPy arrays are wrapped as tensors; a 3-D array whose last axis has size 3
    is additionally permuted from HWC to CHW. uint8 data is divided by 255, and
    so is any other data whose maximum exceeds 1.5.
    """
    if isinstance(img, np.ndarray):
        tensor = torch.from_numpy(img)
        if img.ndim == 3 and img.shape[-1] == 3:
            tensor = tensor.permute(2, 0, 1)
    else:
        tensor = img

    was_uint8 = tensor.dtype == torch.uint8
    tensor = tensor.float()
    # Non-uint8 inputs are rescaled only when their values look like 0..255.
    if was_uint8 or tensor.max() > 1.5:
        tensor = tensor / 255.0
    return tensor


@torch.no_grad()
def encode_front_batch(ds, idx_batch):
    """Encode the "observation.images.front" image of each dataset index.

    Args:
        ds: Dataset indexed through ``get_sample``.
        idx_batch: Iterable of dataset indices to encode.

    Returns:
        float32 NumPy array holding the encoder output, one row per index.
    """
    frames = [
        _front_image_to_unit_float(
            get_key(get_sample(ds, int(i)), "observation.images.front")
        )
        for i in idx_batch
    ]

    batch = torch.stack(frames, dim=0).to(device)
    # Resize to the resolution stored alongside the encoder checkpoint.
    batch = F.interpolate(batch, size=resize_hw, mode="bilinear", align_corners=False)
    latents = encoder(batch)
    return latents.detach().cpu().numpy().astype(np.float32)

# Latent matrix, one row per parquet row, filled batch by batch.
Z = np.empty((N, ckpt["latent_dim"]), dtype=np.float32)

for batch_no, start in enumerate(range(0, N, batch_size)):
    end = min(start + batch_size, N)
    Z[start:end] = encode_front_batch(ds, ds_indices[start:end])
    # Progress line on every 20th batch (batches 0, 20, 40, ...).
    if batch_no % 20 == 0:
        print(f"encoded {end}/{N}")


# states(72) 
# Per-row state: encoder latent, then the left/right image-circle columns,
# then the 6-D observation.state.
states = np.hstack((Z, lr01, obs_state)).astype(np.float32)
if states.shape[1] != 72:
    raise ValueError(f"states dim expected 72 but got {states.shape}")

# done / next_state
# Row k continues into row k+1 only when both rows belong to the same episode
# and their frame indices are consecutive; every other row ends a trajectory.
cont = (ep[:-1] == ep[1:]) & (fr[:-1] + 1 == fr[1:])
terminals = np.ones((N,), dtype=np.float32)  # the last row is always terminal
terminals[:-1] = np.logical_not(cont).astype(np.float32)

# Next state is the following row; the last row has none and repeats itself.
next_states = np.concatenate((states[1:], states[-1:]), axis=0)

# Terminal rows must not leak the first state of the following episode.
terminal_idx = np.flatnonzero(terminals > 0.5)
next_states[terminal_idx] = states[terminal_idx]

# state / action normalization
def _column_stats(values):
    """Per-column mean and std (kept 2-D); stds below EPS are replaced by 1.0."""
    mean = values.mean(axis=0, keepdims=True)
    std = values.std(axis=0, keepdims=True)
    return mean, np.where(std < EPS, 1.0, std)


# 1. State normalization: next states reuse the statistics of the states.
obs_mean, obs_std = _column_stats(states)
normalized_states = (states - obs_mean) / obs_std
next_states_norm = (next_states - obs_mean) / obs_std

# 2. Action normalization
act_mean, act_std = _column_stats(actions)
normalized_actions = (actions - act_mean) / act_std

# save normalized replay buffer
# Everything written to the archive, keyed by the name it is stored under.
replay_arrays = {
    # normalized transitions
    "observations": normalized_states.astype(np.float32),
    "actions": normalized_actions.astype(np.float32),
    "next_observations": next_states_norm.astype(np.float32),
    # rescaled rewards, terminal flags and bookkeeping indices
    "rewards": rewards.astype(np.float32),
    "terminals": terminals.astype(np.float32),
    "episode_index": ep.astype(np.int32),
    "frame_index": fr.astype(np.int32),
    "ds_index": ds_indices.astype(np.int64),
    # statistics needed to undo / re-apply the normalization
    "obs_mean": obs_mean.astype(np.float32),
    "obs_std": obs_std.astype(np.float32),
    "act_mean": act_mean.astype(np.float32),
    "act_std": act_std.astype(np.float32),
}
np.savez_compressed(out_path, **replay_arrays)

print("Saved:", out_path)
shape_report = (
    "S", states.shape,
    "A", actions.shape,
    "R", rewards.shape,
    "S'", next_states.shape,
    "D", terminals.shape,
)
print("Shapes:", *shape_report)

N = 53936
len(ds) = 53936
encoded 256/53936
encoded 5376/53936
encoded 10496/53936
encoded 15616/53936
encoded 20736/53936
encoded 25856/53936
encoded 30976/53936
encoded 36096/53936
encoded 41216/53936
encoded 46336/53936
encoded 51456/53936
Saved: replay_buffer_iql_normalized.npz
Shapes: S (53936, 72) A (53936, 6) R (53936,) S' (53936, 72) D (53936,)


In [None]:
import numpy as np

# Validate the replay buffer written by the export cell (same file name as
# `out_path` there), not a stale archive from an earlier run. The context
# manager closes the underlying file once the arrays are read.
with np.load("replay_buffer_iql_normalized.npz") as d:
    ep = d["episode_index"]
    fr = d["frame_index"]
    done = d["terminals"]

# check episode/timestep alignment: a non-terminal row must be followed by the
# next frame of the same episode.
idx = np.where(done[:-1] < 0.5)[0]
assert np.all(ep[idx] == ep[idx+1]), "Found done=0 but episode changes!"
assert np.all(fr[idx+1] == fr[idx] + 1), "Found done=0 but frame_index not consecutive!"

# done==1: a terminal row must NOT be followed by a valid continuation.
idx = np.where(done[:-1] > 0.5)[0]
bad = np.where((ep[idx] == ep[idx+1]) & (fr[idx+1] == fr[idx] + 1))[0]
assert len(bad) == 0, "Found done=1 but next is still a valid continuation!"

print("OK: episode/timestep alignment is consistent.")


OK: episode/timestep alignment is consistent.


In [None]:
import numpy as np
# Read the rewards and close the archive straight away.
with np.load("replay_buffer_iql_normalized.npz") as d:
    r = d["rewards"]

# Rows whose rescaled reward exceeds 0.9 are counted as successes.
success_mask = r > 0.9
print("N:", len(r))
print("reward min/max:", r.min(), r.max())
print("count success:", success_mask.sum())
print("success ratio:", success_mask.mean())

N: 53936
reward min/max: -0.1 1.0
count success: 5932
success ratio: 0.10998220112726194
