# Nemotron-VLA: Minimal Training & Inference

**Vision:** NVIDIA RADIO | **Language:** NVIDIA Nemotron Nano 9B v2 | **Action:** Diffusion Policy

Dataset: [`keivalya/nemotron-vla-metaworld`](https://huggingface.co/datasets/keivalya/nemotron-vla-metaworld) | Model: [`keivalya/nemotron-vla`](https://huggingface.co/keivalya/nemotron-vla)

> Helper functions: [`helpers.py`](./helpers.py)


## 1. Setup


In [None]:
import os
os.environ["MUJOCO_GL"] = "egl"

!pip install -q torch torchvision transformers accelerate \
    mujoco gymnasium metaworld \
    imageio[ffmpeg] imageio-ffmpeg \
    matplotlib datasets pyarrow Pillow \
    open_clip_torch timm tqdm

# Nemotron Mamba dependencies
!pip install -q causal-conv1d mamba-ssm

import torch
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")


## 2. Load Dataset from HuggingFace

Downloads the expert demonstration dataset and extracts images, states, actions, and instructions.


In [None]:
from datasets import load_dataset
from PIL import Image
import io
import numpy as np
import os

print("Downloading dataset from HuggingFace...")
ds = load_dataset("keivalya/nemotron-vla-metaworld", split="train")
print(f"Total transitions: {len(ds)}")

# Extract per-transition arrays. Each row's "image" field is decoded from
# bytes with PIL -- assumes the column stores encoded image bytes rather than
# a decoded datasets.Image feature (TODO confirm against the dataset schema).
print("\nExtracting data...")
images, states, actions, instructions = [], [], [], []
for i, row in enumerate(ds):
    # np.array copies the pixels, so the PIL image can be closed right away.
    with Image.open(io.BytesIO(row["image"])) as img:
        images.append(np.array(img))
    states.append(np.array(row["state"], dtype=np.float32))
    actions.append(np.array(row["action"], dtype=np.float32))
    instructions.append(row["instruction"])
    if (i + 1) % 50000 == 0:
        print(f"  {i+1}/{len(ds)}")

images = np.stack(images)
states = np.stack(states)
actions = np.stack(actions)
# Sorted rather than raw set order: iteration order of a set of strings
# depends on hash randomization, so this keeps the task order (and the
# encoding log in the next step) reproducible across runs.
unique_instructions = sorted(set(instructions))

print(f"\nLoaded: {images.shape[0]} transitions, {len(unique_instructions)} tasks")
print(f"   Images:  {images.shape}")
print(f"   States:  {states.shape}")
print(f"   Actions: {actions.shape}")

## 3. Precompute Embeddings

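Load RADIO → extract vision features → unload, then load Nemotron → extract text embeddings → unload.
Because the two encoders are never on the GPU at the same time, peak GPU memory stays under 40 GB. The results are cached in `embeddings.npz`, so re-running this cell skips both encoders.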


In [None]:
import os
import torch, numpy as np
from helpers import load_radio, extract_radio_features, unload

EMB_PATH = "embeddings.npz"

# Try the on-disk cache first; fall through to recomputation when it is
# missing or does not line up with the dataset loaded in Step 2.
radio_features = nemotron_embeddings = None
if os.path.exists(EMB_PATH):
    print("Embeddings already cached, loading...")
    # The NpzFile keeps the archive open; the context manager closes it once
    # the arrays below have been read into memory.
    with np.load(EMB_PATH, allow_pickle=True) as emb:
        radio_features = emb["radio_features"]
        nemotron_embeddings = emb["nemotron_embeddings"]
        radio_dim = int(emb["radio_dim"])
        nemotron_dim = int(emb["nemotron_dim"])
    # A cache built from a different version of the dataset would silently
    # misalign features with states/actions, so rebuild it instead.
    if len(radio_features) != len(images) or len(nemotron_embeddings) != len(instructions):
        print(f"Cached embeddings ({len(radio_features)} rows) do not match "
              f"dataset ({len(images)} rows); recomputing...")
        radio_features = nemotron_embeddings = None

if radio_features is None:
    # RADIO vision features
    print("\nRADIO Vision Features")
    radio_model, radio_dim = load_radio("cuda")
    radio_features = extract_radio_features(radio_model, images, "cuda", batch_size=64)
    # Also drop the notebook's own reference, so RADIO can be freed before
    # Nemotron is loaded regardless of what `unload` does internally.
    unload(radio_model)
    del radio_model
    torch.cuda.empty_cache()

    # Nemotron text embeddings: one forward pass per unique instruction,
    # then broadcast to every transition via a lookup table.
    print("\nNemotron Text Embeddings")
    from helpers import load_nemotron, extract_nemotron_embedding
    nem_model, tok, nemotron_dim = load_nemotron("cuda")

    instr_to_emb = {}
    for i, instr in enumerate(unique_instructions):
        print(f"  [{i+1}/{len(unique_instructions)}] \"{instr}\"")
        instr_to_emb[instr] = extract_nemotron_embedding(nem_model, tok, instr, "cuda")
    # Same as the inference cell: delete the global names too.
    unload(nem_model)
    del nem_model, tok
    torch.cuda.empty_cache()

    nemotron_embeddings = np.stack([instr_to_emb[ins] for ins in instructions])

    np.savez_compressed(EMB_PATH, radio_features=radio_features, nemotron_embeddings=nemotron_embeddings,
                        radio_dim=radio_dim, nemotron_dim=nemotron_dim)
    print(f"\nSaved embeddings: {EMB_PATH}")

print(f"\nRadio features:     {radio_features.shape} (dim={radio_dim})")
print(f"Nemotron embeddings: {nemotron_embeddings.shape} (dim={nemotron_dim})")

## 4. Train Nemotron-VLA

Only the lightweight fusion module and diffusion head (~0.8M params) are trained. The vision *(NVIDIA RADIO)* and language *(NVIDIA Nemotron Nano 9B v2)* encoders stay frozen; the policy consumes the embeddings precomputed in Step 3.


In [None]:
from helpers import NemotronVLA, VLADataset, train

# Input/output widths come straight from the arrays built in Step 2.
state_dim = states.shape[1]
action_dim = actions.shape[1]

# The encoders are not part of this model: it consumes the precomputed
# RADIO features and Nemotron embeddings from Step 3.
dataset = VLADataset(radio_features, nemotron_embeddings, states, actions)
print(f"Dataset: {len(dataset)} samples")

model = NemotronVLA(
    radio_dim=radio_dim, nemotron_dim=nemotron_dim,
    state_dim=state_dim, action_dim=action_dim,
    d_model=256, n_heads=4, diffusion_T=20,
)

# Report all parameters and, separately, only those that receive gradients,
# so the "trainable" figure stays accurate if any submodule is ever frozen.
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters:     {total:,}")
print(f"Trainable parameters: {trainable:,}")

# `train` writes its checkpoint to save_path; the inference cell reloads it.
losses = train(model, dataset, epochs=80, batch_size=128, lr=3e-4,
               device="cuda", save_path="nemotron_vla.pt")

In [None]:
# (optional) Plot training loss
import matplotlib.pyplot as plt

# `losses` is the sequence returned by `train` (presumably one value per
# epoch). A log y-axis keeps late-training improvements visible.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(losses, color="#76b900", linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Nemotron-VLA Training")
ax.set_yscale("log")
ax.grid(alpha=0.3)
fig.tight_layout()
plt.show()

## 5. Inference

Edit the `TASKS` list below to choose MetaWorld tasks and their instructions. For each task, the model generates actions and saves a rollout video (displayed in Step 6).


In [None]:
import torch, numpy as np, os
from helpers import load_radio, load_nemotron, extract_nemotron_embedding, unload, NemotronVLA, run_inference

device = "cuda"

# (MetaWorld env name, natural-language instruction) pairs to evaluate.
TASKS = [
    ("push-v3",         "push the object to the goal"),
    ("door-open-v3",    "open the door"),
    ("drawer-close-v3", "close the drawer"),
    ("window-open-v3",  "open the window"),
    ("button-press-v3", "press the button"),
    ("faucet-open-v3",  "open the faucet"),
    ("reach-v3",        "reach to the target"),
    ("pick-place-v3",   "pick and place the object"),
]

# Pre-encode instructions (Nemotron loaded once, then freed) so the 9B model
# is never resident alongside RADIO and the policy.
print("Encoding instructions with Nemotron...")
nem_model, tok, _ = load_nemotron(device)
cached = {}
for _env_name, instr in TASKS:
    if instr in cached:  # tasks sharing an instruction reuse one embedding
        continue
    print(f"  → \"{instr}\"")
    cached[instr] = extract_nemotron_embedding(nem_model, tok, instr, device)
unload(nem_model); del nem_model, tok; torch.cuda.empty_cache()

# ── Load VLA + RADIO ──
# SECURITY: weights_only=False un-pickles arbitrary Python objects, which can
# execute code on load. Only point this at checkpoints you produced yourself
# (Step 4) or otherwise fully trust.
ckpt = torch.load("nemotron_vla.pt", map_location=device, weights_only=False)
vla = NemotronVLA(**ckpt["config"]).to(device)
vla.load_state_dict(ckpt["model_state_dict"])
vla.eval()
print(f"\nVLA loaded (epoch {ckpt['epoch']}, loss {ckpt['loss']:.6f})")

radio, _ = load_radio(device)

# ── Run all tasks ──
# Videos are collected from inference/ by the next cell.
os.makedirs("inference", exist_ok=True)
print(f"\n{'━'*50}")
for env_name, instr in TASKS:
    run_inference(vla, radio, cached[instr], env_name, instr, device=device)


## 6. View Results


In [None]:
import glob
import html
import os
from IPython.display import HTML, display
from base64 import b64encode

# Embed every rendered rollout as an inline, looping HTML5 video.
video_files = sorted(glob.glob("inference/*.mp4"))
if not video_files:
    print("No videos found in inference/ -- run the inference cell first.")

for vf in video_files:
    with open(vf, "rb") as f:
        b64 = b64encode(f.read()).decode()
    # Strip only the trailing extension (str.replace would also remove a
    # ".mp4" occurring mid-name), and escape the title before it goes into
    # HTML markup.
    name = os.path.splitext(os.path.basename(vf))[0].replace("_", " ")
    display(HTML(f'<h4>{html.escape(name)}</h4><video width="360" controls loop><source src="data:video/mp4;base64,{b64}" type="video/mp4"></video>'))
