# Fine-tuning SmolVLA on AIRoA MoMa

[LeRobot](https://github.com/huggingface/lerobot) を使い、
[SmolVLA](https://huggingface.co/lerobot/smolvla_base) を
[airoa-org/airoa-moma](https://huggingface.co/datasets/airoa-org/airoa-moma) データセットで Fine-tuning する。

| | |
|---|---|
| **モデル** | SmolVLA (450M) — Vision-Language-Action モデル |
| **データセット** | AIRoA MoMa — 23K エピソード / 9.4M フレーム（Toyota HSR） |
| **フレームワーク** | LeRobot v0.4 |

### 前提条件
- Google Colab で **GPU ランタイム**（T4 以上）を選択済み
- `HF_TOKEN` を Colab Secrets に登録済み — [トークン取得](https://huggingface.co/settings/tokens)
- `WANDB_API_KEY` を Colab Secrets に登録済み — [キー取得](https://wandb.ai/authorize)

## 1. インストール

In [None]:
import shutil
import subprocess
import sys

# Install LeRobot with the SmolVLA extras into the interpreter running this notebook.
pip_install = [sys.executable, "-m", "pip", "install", "-q", "lerobot[smolvla]"]
subprocess.check_call(pip_install)

import lerobot  # imported only after the install above has completed

# Locate the console entry point; the training cell falls back to `python -m` if it is absent.
train_cli = shutil.which("lerobot-train")
print(f"lerobot {lerobot.__version__}")
print("lerobot-train:", train_cli or "(未検出、python -m で代替)")

In [None]:
import re
import subprocess
import sys

# Matches the "Version: x.y.z" line of `pip show` output.
_VERSION_RE = re.compile(r"^Version:\s*(.+)$", re.M)


def _ver(pkg):
    """Return the installed version of *pkg* as reported by `pip show`, or None.

    None is returned when pip reports an error (package not installed) or when
    the output contains no Version line.
    """
    try:
        report = subprocess.check_output(
            [sys.executable, "-m", "pip", "show", pkg], text=True
        )
    except subprocess.CalledProcessError:
        return None
    found = _VERSION_RE.search(report)
    return found.group(1).strip() if found else None


def _major(version):
    """Return the leading integer component of *version*, or 0 if it is missing."""
    return int(version.split(".")[0]) if version else 0


versions = {pkg: _ver(pkg) for pkg in ("lerobot", "transformers", "huggingface_hub")}
for pkg, version in versions.items():
    print(f"  {pkg}: {version}")

tf_major = _major(versions["transformers"])
hub_major = _major(versions["huggingface_hub"])

# SmolVLA requires transformers 4.x. Downgrading transformers also pulls
# huggingface_hub back below 1.0 through its own pin, so hub is only pinned
# directly when transformers is already on 4.x (or missing).
pins = []
if tf_major >= 5:
    pins.append("transformers>=4.57.1,<5.0.0")
elif hub_major >= 1:
    pins.append("huggingface_hub>=0.24.0,<1.0.0")

for spec in pins:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "-q", spec])
needs_restart = bool(pins)

print("\n⚠️ 依存関係を更新しました。ランタイムを再起動してください。" if needs_restart
      else "\n✅ 依存関係OK")

## 2. 設定

In [None]:
import os
from google.colab import userdata


def _require_secret(name):
    """Return the Colab secret *name*, failing with a setup hint if unavailable.

    The value is validated *before* it is written to os.environ: assigning None
    there raises an unhelpful TypeError, and userdata.get itself raises when the
    secret is missing or the notebook has not been granted access.

    Raises:
        RuntimeError: the secret is missing, inaccessible, or empty.
    """
    hint = f"Colab Secrets に {name} を登録してください"
    try:
        value = userdata.get(name)
    except (userdata.SecretNotFoundError, userdata.NotebookAccessError) as err:
        raise RuntimeError(hint) from err
    if not value:
        raise RuntimeError(hint)
    return value


os.environ["HF_TOKEN"] = _require_secret("HF_TOKEN")
# Older huggingface_hub versions and some tools read this alias instead of HF_TOKEN.
os.environ["HUGGINGFACE_HUB_TOKEN"] = os.environ["HF_TOKEN"]
os.environ["WANDB_API_KEY"] = _require_secret("WANDB_API_KEY")
print("✅ トークン読み込み完了")

In [None]:
# ===== Training parameters =====

# Source dataset and base policy checkpoint
DATASET_REPO_ID = "airoa-org/airoa-moma"
DATASET_REVISION = "main"
POLICY_PATH = "lerobot/smolvla_base"

# Weights & Biases destination
WANDB_ENTITY = "ken05-matuo-llm-88_llm_2025_suzuki"
WANDB_PROJECT = "icra2026-vla"

# Optimisation schedule
STEPS = 147_233
BATCH_SIZE = 64
SAVE_FREQ = 50_000

# Whether to push the trained policy to the HF Hub, and its visibility/repo base name
PUSH_TO_HUB = True
POLICY_PRIVATE = True
POLICY_REPO_ID_BASE = "ICRA-2026-RAMEN/smolvla-airoa-moma"

# Base names for the job and its output directory
BASE_JOB_NAME = "smolvla_airoa_moma"
BASE_OUTPUT_DIR = "./outputs/smolvla_airoa_moma"

# Map the dataset's two camera streams onto SmolVLA's camera slots.
# SmolVLA is set up for three cameras, so EMPTY_CAMERAS asks for one dummy slot.
RENAME_MAP = {
    "observation.image.hand": "observation.images.camera1",
    "observation.image.head": "observation.images.camera2",
}
EMPTY_CAMERAS = 1

print("\n".join([
    f"Dataset: {DATASET_REPO_ID}",
    f"Policy:  {POLICY_PATH}",
    f"Steps:   {STEPS}, Batch: {BATCH_SIZE}",
]))

## 3. データセット

In [None]:
from huggingface_hub import hf_hub_download, snapshot_download
import json
import os

DATASET_ROOT = "/content/lerobot_datasets/airoa-moma"

# Repository coordinates shared by both Hub calls below.
_hub_args = dict(
    repo_id=DATASET_REPO_ID,
    repo_type="dataset",
    revision=DATASET_REVISION,
    token=os.environ.get("HF_TOKEN"),
)

# Fetch only meta/info.json first so the dataset size is reported before the bulk download.
info_path = hf_hub_download(filename="meta/info.json", **_hub_args)
with open(info_path) as fh:
    info = json.load(fh)

n_episodes, n_frames, n_tasks = info["total_episodes"], info["total_frames"], info["total_tasks"]
print(f"Episodes: {n_episodes:,}  Frames: {n_frames:,}  Tasks: {n_tasks}")
data_mb = info.get("data_files_size_in_mb", "?")
video_mb = info.get("video_files_size_in_mb", "?")
print(f"Size: data {data_mb} MB + video {video_mb} MB")

# Mirror metadata, tabular data and videos into DATASET_ROOT.
print(f"\n{DATASET_ROOT} にダウンロード中 ...")
snapshot_download(
    local_dir=DATASET_ROOT,
    allow_patterns=["meta/**", "data/**", "videos/**"],
    **_hub_args,
)
print("✅ 完了")

In [None]:
import json, os, glob
import pyarrow.parquet as pq

# airoa-moma stores the action as sub-features (action.absolute, action.relative, ...).
# LeRobot requires a single "action" key, so action.absolute [8D] is renamed to "action".
# Features SmolVLA does not use are dropped to speed up dataset loading.
# The cell is idempotent: a rerun (e.g. after an interrupted run) only rewrites
# what is still unconverted.

ACTION_SOURCE = "action.absolute"
KEEP_FEATURES = {
    "action", "observation.state",
    "observation.image.hand", "observation.image.head",
    "episode_index", "frame_index", "timestamp",
    "next.done", "index", "task_index",
}


def _rename_and_prune(mapping):
    """Rename ACTION_SOURCE to "action" and drop keys not in KEEP_FEATURES, in place.

    Applied to both the info.json ``features`` dict and the top-level stats.json
    dict so the two stay consistent. Returns the list of dropped keys.
    """
    if ACTION_SOURCE in mapping:
        mapping["action"] = mapping.pop(ACTION_SOURCE)
    dropped_keys = [k for k in mapping if k not in KEEP_FEATURES]
    for k in dropped_keys:
        del mapping[k]
    return dropped_keys


# --- info.json ---
info_path = os.path.join(DATASET_ROOT, "meta", "info.json")
with open(info_path) as f:
    info = json.load(f)
features = info["features"]  # same object as info["features"]; edited in place
dropped = _rename_and_prune(features)
with open(info_path, "w") as f:
    json.dump(info, f, indent=4)
print(f"info.json: {len(features)} 特徴量を保持、{len(dropped)} 件を除去")

# --- stats.json (optional) ---
stats_path = os.path.join(DATASET_ROOT, "meta", "stats.json")
if os.path.exists(stats_path):
    with open(stats_path) as f:
        stats = json.load(f)
    _rename_and_prune(stats)
    with open(stats_path, "w") as f:
        json.dump(stats, f, indent=4)

# --- Parquet files ---
# Video features are stored as separate video files, not as Parquet columns.
keep_parquet = {k for k, spec in features.items() if spec.get("dtype") != "video"}
parquet_dir = os.path.join(DATASET_ROOT, "data")
parquet_files = sorted(glob.glob(os.path.join(parquet_dir, "**", "*.parquet"), recursive=True))
if not parquet_files:
    raise FileNotFoundError(f"Parquet ファイルが見つかりません: {parquet_dir}")


def _needs_update(path):
    """Return True if the Parquet file at *path* does not have exactly the kept columns.

    Only the schema is read, not the data.
    """
    return set(pq.read_schema(path).names) != keep_parquet


# Check every file, not just the first. Files are rewritten in sorted order, so an
# interrupted earlier run leaves the first files converted and the rest untouched.
stale = [pf for pf in parquet_files if _needs_update(pf)]
if not stale:
    print("Parquet ファイルは整合済み。スキップ。")
else:
    print(f"{len(stale)} 個の Parquet ファイルを更新中 ...")
    for pf in stale:
        table = pq.read_table(pf)
        if ACTION_SOURCE in table.column_names:
            table = table.rename_columns(
                ["action" if c == ACTION_SOURCE else c for c in table.column_names]
            )
        # Keep only the wanted columns, preserving their original order.
        table = table.select([c for c in table.column_names if c in keep_parquet])
        pq.write_table(table, pf)
    still_stale = [pf for pf in stale if _needs_update(pf)]
    assert not still_stale, f"Parquet 更新後も不整合: {still_stale[:3]}"
    print("✅ 完了")

print(f"\n特徴量: {sorted(features)}")
print(f"Action shape: {features['action']['shape']}")

## 4. 学習

In [None]:
import shutil, subprocess, shlex, os, sys, json
from datetime import datetime

# Derive a unique job name / output dir / Hub repo for this run from the launch time.
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
JOB_NAME    = f"{BASE_JOB_NAME}_{ts}"
OUTPUT_DIR  = f"{BASE_OUTPUT_DIR}_{ts}"
POLICY_REPO = f"{POLICY_REPO_ID_BASE}-{ts}"

# Make sure the preprocessing cell has rewritten meta/info.json.
with open(os.path.join(DATASET_ROOT, "meta", "info.json")) as f:
    _info = json.load(f)
assert "action" in _info["features"], "先に前処理セルを実行してください"
assert not any(k.startswith("action.") for k in _info["features"]), "前処理が不完全です"

# Scale the LR schedule to STEPS.
# SmolVLA defaults: warmup=1000, decay=30000 (ratio 1:30).
# Decay over the whole run and keep warmup at the default 1:30 ratio (minimum 100).
WARMUP_STEPS = max(100, STEPS // 30)
DECAY_STEPS  = STEPS
print(f"LR schedule: warmup {WARMUP_STEPS} steps → cosine decay {DECAY_STEPS} steps")

# Build the command: prefer the console script, fall back to `python -m`.
exe = shutil.which("lerobot-train")
base = [exe] if exe else [sys.executable, "-m", "lerobot.scripts.lerobot_train"]

cmd = base + [
    f"--dataset.repo_id={DATASET_REPO_ID}",
    f"--dataset.revision={DATASET_REVISION}",
    f"--dataset.root={DATASET_ROOT}",
    f"--policy.path={POLICY_PATH}",
    f"--output_dir={OUTPUT_DIR}",
    f"--job_name={JOB_NAME}",
    "--policy.device=cuda",
    f"--steps={STEPS}",
    f"--batch_size={BATCH_SIZE}",
    f"--save_freq={SAVE_FREQ}",
    f"--policy.scheduler_warmup_steps={WARMUP_STEPS}",
    f"--policy.scheduler_decay_steps={DECAY_STEPS}",
    "--wandb.enable=true",
    f"--wandb.entity={WANDB_ENTITY}",
    f"--wandb.project={WANDB_PROJECT}",
    f"--rename_map={json.dumps(RENAME_MAP)}",
    f"--policy.empty_cameras={EMPTY_CAMERAS}",
    # Always set explicitly: the policy config defaults to pushing, and LeRobot
    # rejects a run that pushes without policy.repo_id when PUSH_TO_HUB is False.
    f"--policy.push_to_hub={str(PUSH_TO_HUB).lower()}",
]
if PUSH_TO_HUB:
    cmd += [f"--policy.repo_id={POLICY_REPO}",
            f"--policy.private={str(POLICY_PRIVATE).lower()}"]

print(f"Job:    {JOB_NAME}")
print(f"Output: {OUTPUT_DIR}")
if PUSH_TO_HUB:
    print(f"Hub:    {POLICY_REPO}")
print(f"\n{' '.join(shlex.quote(x) for x in cmd)}\n")

# Run training with live output (stderr merged into stdout, unbuffered child).
env = os.environ.copy()
env["HUGGINGFACE_HUB_TOKEN"] = env.get("HF_TOKEN", "")
env["PYTHONUNBUFFERED"] = "1"

proc = subprocess.Popen(
    cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
    text=True, errors="replace", env=env, bufsize=1,
)
try:
    for line in proc.stdout:
        print(line, end="", flush=True)
    proc.wait()
finally:
    # If the loop exits early (e.g. the cell is stopped -> KeyboardInterrupt),
    # stop the trainer so it does not keep holding the GPU in the background.
    if proc.poll() is None:
        proc.terminate()
        try:
            proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
    proc.stdout.close()

print(f"\n{'=' * 50}")
if proc.returncode == 0:
    print("✅ 学習完了!")
else:
    raise RuntimeError(f"学習失敗 (return code {proc.returncode})")