In [None]:
from xvla_wlr.model_legacy import XVLA, XVLAProcessor, Trainer, get_peft_model, Action, Observation

# The policy and its processor are both loaded from the same pretrained identifier.
pretrained_id = "2toINF/X-VLA-SoftFold"
model = XVLA.from_pretrained(pretrained_id)
processor = XVLAProcessor.from_pretrained(pretrained_id, use_fast=True)

Florence2ForConditionalGeneration has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [2]:
# Wrap the loaded model with `get_peft_model`, move the result to the GPU, and
# hand it to the trainer together with the processor.
# NOTE(review): `model` is rebound from itself here, so re-running this cell
# passes the already-wrapped model to `get_peft_model` again.
# TODO: unspecified follow-up left by the original author -- clarify or remove.
model = get_peft_model(model).to(device="cuda")
trainer = Trainer(model, processor)

In [None]:
import os

from datasets_wlr import WLRZhuangEpisodeDataset
from curobo.types.robot import RobotConfig
from xvla_wlr.model_legacy import DATA_DOMAIN_ID


# Source episode file and the domain id under which it is fed to the model.
episode_path = "samples/2026-01-21_demo_clothes/episode_0/data.json"
dataset = WLRZhuangEpisodeDataset(episode_path)
domain_id = DATA_DOMAIN_ID["robomind-agilex"]


from xvla_wlr_experiments.xvla_finetune_piper_v0.dataset import XVLAWLRZhuangEpisodeDataset

# Both arms are built from the same dual-arm URDF and the same base link;
# only the end-effector link differs between the two configurations.
# TODO: unspecified follow-up left by the original author -- clarify or remove.
urdf_path = f"{os.getcwd()}/robots/piper-dualarm/piper-dualarm.urdf"
robot_config_left = RobotConfig.from_basic(
    urdf_path,
    base_link="common_base_link",
    ee_link="left_link8",
)
robot_config_right = RobotConfig.from_basic(
    urdf_path,
    base_link="common_base_link",
    ee_link="right_link8",
)
xvla_dataset = XVLAWLRZhuangEpisodeDataset(
    dataset=dataset,
    robot_config_left=robot_config_left,
    robot_config_right=robot_config_right,
    domain_id=domain_id,
    prefetch=True,
)

kinematics_fused_cu not found, JIT compiling...
geom_cu binary not found, jit compiling...
lbfgs_step_cu not found, JIT compiling...
line_search_cu not found, JIT compiling...
tensor_step_cu not found, jit compiling...


In [31]:
# Show rows 0-2 of column 3 of every end-effector transform in the dataset
# (presumably the translation part of 4x4 homogeneous transforms -- TODO confirm).
ee_transforms = xvla_dataset[:].ee_transform
ee_transforms[..., :3, 3]

tensor([[[ 0.1879,  0.3311,  0.1797],
         [ 0.1792, -0.3918,  0.1828]],

        [[ 0.1879,  0.3311,  0.1797],
         [ 0.1790, -0.3918,  0.1825]],

        [[ 0.1879,  0.3311,  0.1797],
         [ 0.1789, -0.3918,  0.1826]],

        ...,

        [[ 0.1711,  0.2575,  0.1519],
         [ 0.1887, -0.3224,  0.3010]],

        [[ 0.1711,  0.2575,  0.1519],
         [ 0.1870, -0.3211,  0.2559]],

        [[ 0.1915,  0.7250,  0.0823],
         [ 0.1675,  0.0938,  0.1787]]], device='cuda:0')

In [21]:
# Build the left-arm configuration and display rows 0-2 of column 3 of each of
# its fixed transforms (one row of output per transform).
urdf_path = f"{os.getcwd()}/robots/piper-dualarm/piper-dualarm.urdf"
robot_config = RobotConfig.from_basic(
    urdf_path,
    base_link="common_base_link",
    ee_link="left_link8",
)
fixed_transforms = robot_config.kinematics.kinematics_config.fixed_transforms
fixed_transforms[..., :3, 3]

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.3300,  0.0000],
        [ 0.0000,  0.0000,  0.1230],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.2836,  0.0287,  0.0000],
        [-0.2422,  0.0685,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0910,  0.0014],
        [ 0.0000,  0.0000,  0.1350]], device='cuda:0')

In [22]:
import torch.linalg
from curobo.types.robot import RobotConfig

def compute_conservative_reach_radius(robot: RobotConfig):
    """Sum the lengths of all fixed-transform offsets of a robot.

    Reads ``robot.kinematics.kinematics_config.fixed_transforms``, takes rows
    0-2 of column 3 of every transform, measures the Euclidean norm of each of
    those 3-vectors, and adds the norms together. As the name suggests, the
    result is meant as a conservative (over-)estimate of how far the chain
    can reach.

    Args:
        robot: Configuration object exposing
            ``kinematics.kinematics_config.fixed_transforms``.

    Returns:
        A zero-dimensional tensor holding the summed offset lengths.
    """
    fixed_transforms = robot.kinematics.kinematics_config.fixed_transforms
    offsets = fixed_transforms[..., :3, 3]
    offset_lengths = torch.linalg.norm(offsets, dim=-1)
    return offset_lengths.sum()
        
# Evaluate the reach estimate for the left arm of the dual-arm URDF.
left_arm_config = RobotConfig.from_basic(
    f"{os.getcwd()}/robots/piper-dualarm/piper-dualarm.urdf",
    base_link="common_base_link",
    ee_link="left_link8",
)
compute_conservative_reach_radius(left_arm_config)

tensor(1.2158, device='cuda:0')

In [6]:
import torch
# Select the "high" float32 matmul precision mode and turn on compiled
# autograd before compiling the trainer's fit step.
torch.set_float32_matmul_precision("high")
torch._dynamo.config.compiled_autograd = True
fit = torch.compile(trainer.fit)

timestep_current = 0
num_timesteps_per_episode = 4
num_timesteps_per_action = 2


# Walk the dataset in windows of `num_timesteps_per_episode` timesteps; stop as
# soon as the window's end index would reach or pass the end of the dataset.
while timestep_current + num_timesteps_per_episode < len(xvla_dataset):
    window = slice(timestep_current, timestep_current + num_timesteps_per_episode)
    observation = xvla_dataset[window]

    action = Action.from_observation(observation, num_steps=num_timesteps_per_action)

    # Drop the first action and trim the observations to the same length so
    # both sequences handed to `fit` line up one-to-one.
    action_next = action[1:]
    observation_current = observation[:len(action_next)]

    loss = fit(observation=observation_current, action=action_next)

    # Advance by however many observations were just consumed.
    timestep_current += len(observation_current)

    # Report the latest loss whenever the running timestep hits a multiple of 100.
    if timestep_current % 100 == 0:
        print("TODO episode", timestep_current, loss)


TODO episode 100 tensor(65.3732, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
TODO episode 200 tensor(70.1740, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
TODO episode 300 tensor(19.0454, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
TODO episode 400 tensor(26.9113, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
TODO episode 500 tensor(11.5945, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
TODO episode 600 tensor(20.9806, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
TODO episode 700 tensor(5.1244, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
TODO episode 800 tensor(20.1720, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
TODO episode 900 tensor(58.7672, device='cuda:0', grad_fn=<CompiledFunctionBackward>)
TODO episode 1000 tensor(119.7160, device='cuda:0', grad_fn=<CompiledFunctionBackward>)


In [None]:
from xvla_wlr_experiments.xvla_finetune_piper_v0.experiment import main

# NOTE(review): the same placeholder path is used both to load and to save checkpoints.
checkpoint_path = "todo-checkpoint"
main(
    ["samples/2026-01-21_demo_clothes/episode_0/data.json"],
    num_iterations=10,
    checkpoint_load_path=checkpoint_path,
    checkpoint_save_path=checkpoint_path,
)

  0%|          | 0/1.0 [00:00<?, ?it/s]



  0%|          | 0/1.0 [00:00<?, ?it/s]

Checkpoint at iteration 0: todo-checkpoint


  0%|          | 0/1.0 [00:00<?, ?it/s]

Checkpoint at iteration 1: todo-checkpoint


  0%|          | 0/1.0 [00:00<?, ?it/s]

Checkpoint at iteration 2: todo-checkpoint


  0%|          | 0/1.0 [00:00<?, ?it/s]

Checkpoint at iteration 3: todo-checkpoint


  0%|          | 0/1.0 [00:00<?, ?it/s]

Checkpoint at iteration 4: todo-checkpoint


  0%|          | 0/1.0 [00:00<?, ?it/s]

Checkpoint at iteration 5: todo-checkpoint


  0%|          | 0/1.0 [00:00<?, ?it/s]

Checkpoint at iteration 6: todo-checkpoint


  0%|          | 0/1.0 [00:00<?, ?it/s]

Checkpoint at iteration 7: todo-checkpoint


  0%|          | 0/1.0 [00:00<?, ?it/s]

Checkpoint at iteration 8: todo-checkpoint


  0%|          | 0/1.0 [00:00<?, ?it/s]

Checkpoint at iteration 9: todo-checkpoint


: 