In [1]:
import os
import shutil

# 1. Clean up any old runs to avoid conflicts
if os.path.exists("neural-robot-dynamics"):
    shutil.rmtree("neural-robot-dynamics")

# 2. Clone YOUR specific repository
!git clone https://github.com/bhargavee13678/neural-robot-dynamics.git

%cd neural-robot-dynamics

# 3. Switch to your experiment branch
!git checkout exp1

# 4. Install Dependencies
!pip install -r requirements.txt
!pip install warp-lang==1.8.0
!pip install rl_games
!pip install wandb

Cloning into 'neural-robot-dynamics'...
remote: Enumerating objects: 608, done.[K
remote: Counting objects: 100% (174/174), done.[K
remote: Compressing objects: 100% (113/113), done.[K
remote: Total 608 (delta 91), reused 118 (delta 59), pack-reused 434 (from 1)[K
Receiving objects: 100% (608/608), 21.43 MiB | 16.64 MiB/s, done.
Resolving deltas: 100% (178/178), done.
Filtering content: 100% (11/11), 202.03 MiB | 51.73 MiB/s, done.
/content/neural-robot-dynamics
Branch 'exp1' set up to track remote branch 'exp1' from 'origin'.
Switched to a new branch 'exp1'
Collecting pyglet==2.1.6 (from -r requirements.txt (line 2))
  Downloading pyglet-2.1.6-py3-none-any.whl.metadata (7.7 kB)
Collecting ipdb (from -r requirements.txt (line 3))
  Downloading ipdb-0.13.13-py3-none-any.whl.metadata (14 kB)
Collecting h5py==3.11.0 (from -r requirements.txt (line 4))
  Downloading h5py-3.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Collecting pyyaml==6.0.2 (from -

Completing the Jamba changes (patching models/jamba.py, algorithms/vanilla_trainer.py and train/train.py)

In [2]:
# ---------------------------------------------------------
# 1. FIX MODELS/JAMBA.PY (Adds evaluate, init_rnn, and dimension fixes)
# ---------------------------------------------------------
jamba_content = r'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention restricted by a lower-triangular (causal) mask.

    Args:
        d_model: Feature width of the input and output. Must be divisible by
            ``n_head`` because the features are split evenly across heads.
        n_head: Number of attention heads.
        max_len: Side length of the precomputed mask, i.e. the longest
            sequence ``forward`` accepts.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """
    def __init__(self, d_model, n_head, max_len=2048):
        super().__init__()
        # Fail early with a readable message; otherwise the mismatch only shows
        # up as a tensor ``view`` error inside forward().
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        self.c_attn = nn.Linear(d_model, 3 * d_model)  # joint q/k/v projection
        self.c_proj = nn.Linear(d_model, d_model)
        self.n_head = n_head
        self.d_model = d_model
        # Ones on and below the diagonal: position t may attend to positions <= t.
        self.register_buffer("bias", torch.tril(torch.ones(max_len, max_len)).view(1, 1, max_len, max_len))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C) and return the same shape.

        Raises:
            ValueError: If the sequence length T exceeds the mask size.
        """
        B, T, C = x.size()
        # Read the limit from the buffer so no extra attribute is required.
        max_len = self.bias.size(-1)
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds the attention mask size {max_len}")
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.d_model, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Masked (future) positions get -inf so softmax gives them zero weight.
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v
        # Merge the heads back: (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

class MambaLayer(nn.Module):
    """Sequence mixer: linear expansion, depthwise 1-D convolution, SiLU, linear projection.

    The convolution has kernel size 4 with padding 3 and its output is cut back
    to the input length, so each step only sees itself and the three steps
    before it.
    """
    def __init__(self, d_model):
        super().__init__()
        inner_dim = d_model * 2
        self.in_proj = nn.Linear(d_model, inner_dim)
        # groups == channels makes the convolution depthwise (one filter per channel).
        self.conv1d = nn.Conv1d(
            in_channels=inner_dim,
            out_channels=inner_dim,
            kernel_size=4,
            groups=inner_dim,
            padding=3,
        )
        self.out_proj = nn.Linear(inner_dim, d_model)
        self.act = nn.SiLU()

    def forward(self, x):
        """Mix ``x`` of shape (B, L, D) along the sequence axis; returns (B, L, D)."""
        _, seq_len, _ = x.shape
        # Conv1d expects (batch, channels, length).
        channels_first = self.in_proj(x).transpose(1, 2)
        # Padding adds 3 extra outputs at the end; dropping them keeps the layer causal.
        convolved = self.conv1d(channels_first)[:, :, :seq_len]
        channels_last = convolved.transpose(1, 2)
        return self.out_proj(self.act(channels_last))

class MLP(nn.Module):
    """Feed-forward network: d_model -> 4 * d_model -> d_model with a GELU in between."""
    def __init__(self, d_model):
        super().__init__()
        hidden_dim = 4 * d_model
        expand = nn.Linear(d_model, hidden_dim)
        activation = nn.GELU()
        contract = nn.Linear(hidden_dim, d_model)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        return self.net(x)

class JambaBlock(nn.Module):
    """Pre-norm residual block: a sequence mixer followed by an MLP.

    Args:
        d_model: Feature width.
        layer_idx: Position of the block in the stack (accepted but not used here).
        use_attention: When true the mixer is a 4-head ``CausalSelfAttention``,
            otherwise a ``MambaLayer``.
    """
    def __init__(self, d_model, layer_idx, use_attention=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        if use_attention:
            self.mixer = CausalSelfAttention(d_model, n_head=4)
        else:
            self.mixer = MambaLayer(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model)

    def forward(self, x):
        """Return ``x`` after the mixer and MLP residual updates."""
        after_mixer = x + self.mixer(self.norm1(x))
        return after_mixer + self.mlp(self.norm2(after_mixer))

class JambaModel(nn.Module):
    """Stack of ``JambaBlock`` layers between a linear embedding and a linear head.

    Block ``i`` is built with ``use_attention=True`` when ``(i + 1) % 8 == 0``
    (the 8th, 16th, ... block) and ``use_attention=False`` otherwise.

    Args:
        input_dim: Size of the last dimension of the input features.
        d_model: Hidden width used by every block.
        n_layers: Number of ``JambaBlock`` layers.
        vocab_size: Accepted for interface compatibility; not used.
        output_dim: Size of the last dimension of the prediction. Defaults to
            ``input_dim``, which is the original behaviour.
    """
    def __init__(self, input_dim, d_model, n_layers, vocab_size=None, output_dim=None):
        super().__init__()
        if output_dim is None:
            output_dim = input_dim
        self.embed = nn.Linear(input_dim, d_model)
        self.layers = nn.ModuleList([])
        # Attributes read by the trainer; the statistics are stored but not
        # applied anywhere inside this class.
        self.normalize_output = False
        self.input_rms = None
        self.output_rms = None
        self.is_rnn = False
        attn_interval = 8
        for i in range(n_layers):
            self.layers.append(JambaBlock(d_model, i, (i + 1) % attn_interval == 0))
        self.norm_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, output_dim, bias=False)

    def set_input_rms(self, input_rms):
        """Store the input statistics object."""
        self.input_rms = input_rms

    def set_output_rms(self, output_rms):
        """Store the output statistics object."""
        self.output_rms = output_rms

    def init_rnn(self, batch_size):
        """No-op: this model keeps no recurrent state."""
        pass

    def evaluate(self, x):
        """Alias of ``forward``."""
        return self.forward(x)

    def forward(self, x):
        """Run the model on a tensor or on a dict holding the input tensor.

        Args:
            x: A tensor of shape (B, D) or (B, T, D), or a dict. For a dict the
                tensor is taken from ``'net_input'``, else ``'states'``, else
                the first value.

        Returns:
            A tensor with the same leading dimensions as the input and the
            head's output size as last dimension.

        Raises:
            ValueError: If ``x`` is an empty dict.
        """
        if isinstance(x, dict):
            if not x:
                raise ValueError("JambaModel.forward received an empty input dict")
            for key in ('net_input', 'states'):
                if key in x:
                    x = x[key]
                    break
            else:
                x = next(iter(x.values()))

        # The blocks expect a sequence axis; add a length-1 axis for 2-D input
        # and strip it again before returning.
        is_2d = (x.dim() == 2)
        if is_2d:
            x = x.unsqueeze(1)

        x = self.embed(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        x = self.head(x)

        if is_2d:
            x = x.squeeze(1)
        return x
'''
with open("/content/neural-robot-dynamics/models/jamba.py", "w") as f: f.write(jamba_content)

# ---------------------------------------------------------
# 2. FIX ALGORITHMS/VANILLA_TRAINER.PY (Adds input_dim dict check)
# ---------------------------------------------------------
trainer_content = r'''
import sys, os, time, shutil, yaml, numpy as np, torch, warp as wp
from typing import Optional
from torch.utils.data import DataLoader
from torch.nn.utils.clip_grad import clip_grad_norm_
from tqdm import tqdm
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')))

from envs.neural_environment import NeuralEnvironment
from models.models import ModelMixedInput
from models.jamba import JambaModel
from utils.datasets import BatchTransitionDataset, collate_fn_BatchTransitionDataset
from utils.evaluator import NeuralSimEvaluator
from utils.python_utils import set_random_seed, print_info, print_ok, print_white, print_warning, format_dict
from utils.torch_utils import num_params_torch_model, grad_norm
from utils.running_mean_std import RunningMeanStd
from utils.time_report import TimeReport, TimeProfiler
from utils.logger import Logger

class VanillaTrainer:
    """Trains and evaluates a neural dynamics model on batched transition datasets.

    The model is built from ``cfg['network']`` (a ``JambaModel`` when the config
    has a ``'jamba'`` key, otherwise a ``ModelMixedInput``) or restored from a
    checkpoint, and is installed into the environment's neural integrator.
    """

    def __init__(self, neural_env, cfg, model_checkpoint_path=None, device='cuda:0', novelty=None, wandb_project=None, wandb_name=None):
        """Build the model, datasets, optional training state and the evaluator.

        Args:
            neural_env: Environment exposing ``integrator_neural``, ``frame_dt``,
                ``robot_name`` and ``num_envs``.
            cfg: Configuration dict with ``'algorithm'``, ``'network'``, ``'cli'``
                (and ``'inputs'`` for ``ModelMixedInput``) sections.
            model_checkpoint_path: Optional file written by ``save_model``.
            device: Torch device string.
            novelty: Forwarded to ``ModelMixedInput``.
            wandb_project: When set, Weights & Biases logging is initialised.
            wandb_name: Run name passed to Weights & Biases.
        """
        self.novelty = novelty
        self.wandb_project = wandb_project
        self.wandb_name = wandb_name
        self.seed = cfg['algorithm'].get('seed', 0)
        self.device = device
        set_random_seed(self.seed)
        self.neural_env = neural_env
        self.neural_integrator = neural_env.integrator_neural

        if model_checkpoint_path is None:
            input_sample = self.neural_integrator.get_neural_model_inputs()
            # The integrator may hand back either a dict of tensors or a tensor.
            if isinstance(input_sample, dict):
                input_dim = input_sample.get('states', list(input_sample.values())[0]).shape[-1]
            else:
                input_dim = input_sample.shape[-1]

            if 'jamba' in cfg['network']:
                print(f"Initializing Jamba Model with Input Dim: {input_dim}")
                self.neural_model = JambaModel(input_dim, cfg['network'].get('d_model', 128), cfg['network'].get('n_layers', 4))
                self.neural_model.to(self.device)
            else:
                self.neural_model = ModelMixedInput(input_sample, self.neural_integrator.prediction_dim, cfg['inputs'], cfg['network'], device=self.device, novelty=self.novelty)
        else:
            # NOTE(security): save_model() pickles the whole module, so loading
            # needs full unpickling (weights_only=False). Recent PyTorch versions
            # default to weights_only=True and would reject such a file. Only
            # load checkpoints from a trusted source.
            checkpoint = torch.load(model_checkpoint_path, map_location=self.device, weights_only=False)
            self.neural_model = checkpoint[0]
            self.neural_model.to(self.device)

        self.neural_integrator.set_neural_model(self.neural_model)
        self.batch_size = int(cfg['algorithm']['batch_size'])
        self.dataset_max_capacity = cfg['algorithm']['dataset'].get('max_capacity', 100000000)
        self.num_data_workers = cfg['algorithm']['dataset'].get('num_data_workers', 4)
        self.get_datasets(cfg['algorithm']['dataset'].get('train_dataset_path'), cfg['algorithm']['dataset'].get('valid_datasets'))

        if cfg['cli']['train']:
            self.num_epochs = int(cfg['algorithm']['num_epochs'])
            self.num_iters_per_epoch = int(cfg['algorithm'].get('num_iters_per_epoch', -1))
            self.optimizer = torch.optim.Adam(self.neural_model.parameters(), lr=float(cfg['algorithm']['optimizer']['lr_start']))
            self.log_dir = cfg['cli']['logdir']
            os.makedirs(self.log_dir, exist_ok=True)
            self.model_log_dir = os.path.join(self.log_dir, 'nn')
            os.makedirs(self.model_log_dir, exist_ok=True)
            self.logger = Logger()
            if self.wandb_project: self.logger.init_wandb(self.wandb_project, self.wandb_name)
            self.save_interval = cfg['cli'].get("save_interval", 50)
            # train() reads this attribute; it was previously never assigned.
            self.eval_interval = cfg['cli'].get("eval_interval", 1)

            if cfg['algorithm'].get("compute_dataset_statistics", True):
                print('Computing dataset statistics...')
                self.compute_dataset_statistics(self.train_dataset)
                self.neural_model.set_input_rms(self.dataset_rms)
                self.neural_model.set_output_rms(self.dataset_rms['target'])

        # eval() reads these attributes; they were previously never assigned.
        self.eval_mode = cfg['algorithm']['eval'].get('mode', 'sampler')
        self.num_eval_rollouts = cfg['algorithm']['eval'].get("num_rollouts", self.neural_env.num_envs)
        self.eval_render = cfg['cli'].get('render', False)
        self.eval_passive = cfg['algorithm']['eval'].get('passive', True)

        self.evaluator = NeuralSimEvaluator(self.neural_env, eval_horizon=cfg['algorithm']['eval'].get("rollout_horizon", 5), device=self.device)

    def get_datasets(self, train_path, valid_cfg):
        """Create the training dataset and one validation dataset per entry of ``valid_cfg``."""
        self.train_dataset = BatchTransitionDataset(self.batch_size, train_path, self.dataset_max_capacity, self.device)
        self.valid_datasets = {k: BatchTransitionDataset(self.batch_size, v, device=self.device) for k, v in valid_cfg.items()}
        # The datasets were given the configured batch size above, so the
        # DataLoader batch size is reset to 1.
        # NOTE(review): presumably each dataset item is already a full batch;
        # confirm against BatchTransitionDataset.
        self.batch_size = 1
        self.collate_fn = collate_fn_BatchTransitionDataset

    def compute_dataset_statistics(self, dataset):
        """Accumulate a ``RunningMeanStd`` per key of the preprocessed data into ``self.dataset_rms``."""
        loader = DataLoader(dataset, batch_size=max(512, self.batch_size), collate_fn=self.collate_fn)
        self.dataset_rms = {}
        for data in loader:
            data = self.preprocess_data_batch(data)
            for k in data.keys():
                # shape[2:] skips the first two axes (assumed batch and time; see the update flags).
                if k not in self.dataset_rms: self.dataset_rms[k] = RunningMeanStd(shape=data[k].shape[2:], device=self.device)
                self.dataset_rms[k].update(data[k], batch_dim=True, time_dim=True)

    @torch.no_grad()
    def preprocess_data_batch(self, data):
        """Move a batch to the device and add ``'contact_masks'`` and ``'target'`` entries (in place)."""
        for k, v in data.items():
            if isinstance(v, dict):
                for sk, sv in v.items(): data[k][sk] = sv.to(self.device)
            else: data[k] = v.to(self.device)
        data['contact_masks'] = self.neural_integrator.get_contact_masks(data['contact_depths'], data['contact_thicknesses'])
        self.neural_integrator.process_neural_model_inputs(data)
        data['target'] = self.neural_integrator.convert_next_states_to_prediction(data['states'], data['next_states'], self.neural_env.frame_dt)
        return data

    def compute_loss(self, data, train):
        """Return ``(loss, {})`` where loss is the MSE between the model output and ``data['target']``."""
        pred = self.neural_model(data)
        loss = torch.nn.MSELoss()(pred, data['target'])
        return loss, {}

    def one_epoch(self, train, dataloader, dataloader_iter, num_batches, shuffle=False, info=None):
        """Run ``num_batches`` batches in train or eval mode.

        Returns:
            ``(mean_loss, {}, {})`` where ``mean_loss`` is the summed loss
            divided by ``num_batches``.
        """
        if train: self.neural_model.train()
        else: self.neural_model.eval()
        sum_loss = 0
        with torch.set_grad_enabled(train):
            for _ in tqdm(range(num_batches)):
                try: data = next(dataloader_iter)
                except StopIteration:
                    # Iterator exhausted: optionally reshuffle and start over.
                    if shuffle: self.train_dataset.shuffle()
                    dataloader_iter = iter(dataloader)
                    data = next(dataloader_iter)
                data = self.preprocess_data_batch(data)
                if train: self.optimizer.zero_grad()
                loss, _ = self.compute_loss(data, train)
                if train:
                    loss.backward()
                    self.optimizer.step()
                # Detach so the running sum does not keep every batch's autograd graph alive.
                sum_loss += loss.detach()
        return sum_loss / num_batches, {}, {}

    def train(self):
        """Run the training loop: train (from epoch 1 on), validate, and periodically evaluate."""
        train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn, shuffle=True, drop_last=True)
        train_iter = iter(train_loader)
        num_batches = len(train_loader) if self.num_iters_per_epoch == -1 else self.num_iters_per_epoch

        valid_loaders = {k: DataLoader(v, batch_size=self.batch_size, collate_fn=self.collate_fn) for k,v in self.valid_datasets.items()}
        valid_iters = {k: iter(v) for k,v in valid_loaders.items()}

        self.best_eval_error = np.inf
        for epoch in range(self.num_epochs):
            # Epoch 0 only validates/evaluates the untrained model.
            if epoch > 0: self.one_epoch(True, train_loader, train_iter, num_batches, shuffle=True)
            for k, v in valid_loaders.items(): self.one_epoch(False, v, valid_iters[k], min(50, len(v)), info=k)
            if (epoch + 1) % self.eval_interval == 0: self.eval(epoch)

    @torch.no_grad()
    def eval(self, epoch):
        """Evaluate rollouts with the neural model and save the model when the error improves."""
        self.neural_model.eval()
        print('Evaluating...')
        error, _, stats = self.evaluator.evaluate_action_mode(self.num_eval_rollouts, 'rollout', 'neural', self.eval_mode, self.eval_render, self.eval_passive)
        print(f"Eval Error: {stats['overall']['error(MSE)']}")

        # Without this the trained model was never written to disk.
        if stats['overall']['error(MSE)'] < self.best_eval_error:
            self.best_eval_error = stats['overall']['error(MSE)']
            self.save_model('best_eval_model')

    def save_model(self, filename='best_model'):
        """Pickle ``[model, robot_name]`` to ``<model_log_dir>/<filename>.pt``."""
        torch.save([self.neural_model, self.neural_env.robot_name], os.path.join(self.model_log_dir, f'{filename}.pt'))
'''
with open("/content/neural-robot-dynamics/algorithms/vanilla_trainer.py", "w") as f: f.write(trainer_content)

# ---------------------------------------------------------
# 3. FIX TRAIN/TRAIN.PY (Adds 'jamba' config check)
# ---------------------------------------------------------
train_main_content = r'''
import sys, os, yaml, warp as wp
base_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
sys.path.append(base_dir)
wp.config.verify_cuda = True

from arguments import get_parser
from utils.python_utils import get_time_stamp, set_random_seed, solve_argv_conflict, handle_cfg_overrides
from algorithms.vanilla_trainer import VanillaTrainer
from algorithms.sequence_model_trainer import SequenceModelTrainer
from envs.neural_environment import NeuralEnvironment

def add_additional_params(parser):
    """Register the extra training options on ``parser`` and return it.

    Adds ``--cfg-overrides``, ``--novelty``, ``--sample-sequence-length``,
    ``--wandb-project`` and ``--wandb-name``.
    """
    extra_options = (
        ('--cfg-overrides', dict(default="", type=str)),
        ('--novelty', dict(default=None, choices=['mamba', 'unroll'], type=str)),
        ('--sample-sequence-length', dict(default=None, type=int)),
        ('--wandb-project', dict(default=None, type=str)),
        ('--wandb-name', dict(default=None, type=str)),
    )
    for flag, options in extra_options:
        parser.add_argument(flag, **options)
    return parser

if __name__ == '__main__':
    # Default arguments; anything given on the real command line is appended after them.
    args_list = ['--cfg', './cfg/Ant/transformer.yaml']
    solve_argv_conflict(args_list)
    parser = get_parser()
    parser = add_additional_params(parser)
    args = parser.parse_args(args_list + sys.argv[1:])

    with open(args.cfg, 'r') as f: cfg = yaml.load(f, Loader=yaml.SafeLoader)
    handle_cfg_overrides(args.cfg_overrides, cfg)

    # Command-line values take precedence over the YAML file.
    if args.num_envs: cfg['env']['num_envs'] = args.num_envs
    if args.sample_sequence_length: cfg['algorithm']['sample_sequence_length'] = args.sample_sequence_length
    cfg['env']['render'] = args.render
    cfg['algorithm']['seed'] = args.seed if args.seed is not None else 0
    set_random_seed(cfg['algorithm']['seed'])

    args.train = not args.test
    # Expose the parsed arguments to the trainer under cfg['cli'].
    cfg["cli"] = vars(args)

    neural_env = NeuralEnvironment(**cfg['env'], device=args.device)
    algorithm_name = cfg['algorithm'].get('name', 'VanillaTrainer')

    trainer_kwargs = dict(
        model_checkpoint_path=args.checkpoint,
        cfg=cfg,
        device=args.device,
        novelty=args.novelty,
        wandb_project=args.wandb_project,
        wandb_name=args.wandb_name,
    )

    if algorithm_name == 'SequenceModelTrainer':
        # FIX: Allow Jamba in config check
        if 'transformer' in cfg['network'] or 'jamba' in cfg['network']:
            assert cfg['env']['neural_integrator_cfg']['name'] == 'TransformerNeuralIntegrator', \
                "transformer/jamba networks require the TransformerNeuralIntegrator"

        algo = SequenceModelTrainer(neural_env, **trainer_kwargs)
    elif algorithm_name == 'VanillaTrainer':
        # 'VanillaTrainer' is the default name above and the class is imported,
        # yet this case used to fall through to NotImplementedError.
        algo = VanillaTrainer(neural_env, **trainer_kwargs)
    else:
        raise NotImplementedError(f"Unsupported algorithm name: {algorithm_name!r}")

    if args.train: algo.train()
    else: algo.test()
'''
with open("/content/neural-robot-dynamics/train/train.py", "w") as f: f.write(train_main_content)

# The literal had been mis-decoded (UTF-8 bytes read as Mac Roman); restore the check-mark emoji.
print("✅ ALL FILES PATCHED SUCCESSFULLY.")

✅ ALL FILES PATCHED SUCCESSFULLY.


Training the final model

In [3]:
!grep "def evaluate" /content/neural-robot-dynamics/models/jamba.py

    def evaluate(self, x): return self.forward(x)


In [9]:
# 2. Generate Dataset
# We generate a smaller dataset for demonstration purposes.
import os
import shutil
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

%cd generate

# Define paths
drive_data_dir = '/content/drive/MyDrive/neural-robot-dynamics/data/datasets/Cartpole/'
local_data_dir = '../data/datasets/Cartpole/'
train_filename = 'trajectory_len-100_train.hdf5'
valid_filename = 'trajectory_len-100_valid.hdf5'

os.makedirs(local_data_dir, exist_ok=True)
os.makedirs(drive_data_dir, exist_ok=True)

# Check if data exists in Drive
if os.path.exists(os.path.join(drive_data_dir, train_filename)) and os.path.exists(os.path.join(drive_data_dir, valid_filename)):
    print("Loading datasets from Google Drive...")
    shutil.copy(os.path.join(drive_data_dir, train_filename), local_data_dir)
    shutil.copy(os.path.join(drive_data_dir, valid_filename), local_data_dir)
else:
    print("Generating datasets...")
    # Generate Training Data
    !python generate_dataset_contact_free.py --env-name Cartpole --num-transitions 10000 --dataset-dir ../data/datasets/ --dataset-name trajectory_len-100_train.hdf5 --trajectory-length 100 --num-envs 64 --seed 0

    # Generate Validation Data
    !python generate_dataset_contact_free.py --env-name Cartpole --num-transitions 2000 --dataset-dir ../data/datasets/ --dataset-name trajectory_len-100_valid.hdf5 --trajectory-length 100 --num-envs 64 --seed 10

    print("Saving datasets to Google Drive...")
    shutil.copy(os.path.join(local_data_dir, train_filename), drive_data_dir)
    shutil.copy(os.path.join(local_data_dir, valid_filename), drive_data_dir)

%cd ..

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[Errno 2] No such file or directory: 'generate'
/content/neural-robot-dynamics/train
Loading datasets from Google Drive...
/content/neural-robot-dynamics


In [18]:
%%writefile /content/neural-robot-dynamics/algorithms/vanilla_trainer.py
import sys, os, time, shutil, yaml, numpy as np, torch, warp as wp
from typing import Optional
from torch.utils.data import DataLoader
from torch.nn.utils.clip_grad import clip_grad_norm_
from tqdm import tqdm
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')))

from envs.neural_environment import NeuralEnvironment
from models.models import ModelMixedInput
from models.jamba import JambaModel
from utils.datasets import BatchTransitionDataset, collate_fn_BatchTransitionDataset
from utils.evaluator import NeuralSimEvaluator
from utils.python_utils import set_random_seed, print_info, print_ok, print_white, print_warning, format_dict
from utils.torch_utils import num_params_torch_model, grad_norm
from utils.running_mean_std import RunningMeanStd
from utils.time_report import TimeReport, TimeProfiler
from utils.logger import Logger

class VanillaTrainer:
    """Trains and evaluates a neural dynamics model on batched transition datasets.

    The model is built from ``cfg['network']`` (a ``JambaModel`` when the config
    has a ``'jamba'`` key, otherwise a ``ModelMixedInput``) or restored from a
    checkpoint, and is installed into the environment's neural integrator.
    """

    def __init__(self, neural_env, cfg, model_checkpoint_path=None, device='cuda:0', novelty=None, wandb_project=None, wandb_name=None):
        """Build the model, datasets, optional training state and the evaluator.

        Args:
            neural_env: Environment exposing ``integrator_neural``, ``frame_dt``,
                ``robot_name`` and ``num_envs``.
            cfg: Configuration dict with ``'algorithm'`` and ``'network'``
                sections, an optional ``'cli'`` section, and ``'inputs'`` for
                ``ModelMixedInput``.
            model_checkpoint_path: Optional file written by ``save_model``.
            device: Torch device string.
            novelty: Forwarded to ``ModelMixedInput``.
            wandb_project: When set, Weights & Biases logging is initialised.
            wandb_name: Run name passed to Weights & Biases.
        """
        self.novelty = novelty
        self.wandb_project = wandb_project
        self.wandb_name = wandb_name
        self.seed = cfg['algorithm'].get('seed', 0)
        self.device = device
        set_random_seed(self.seed)
        self.neural_env = neural_env
        self.neural_integrator = neural_env.integrator_neural

        # --- 1. Model Initialization ---
        if model_checkpoint_path is None:
            input_sample = self.neural_integrator.get_neural_model_inputs()

            # Helper to get input dimension from dict or tensor
            if isinstance(input_sample, dict):
                if 'states' in input_sample:
                    input_dim = input_sample['states'].shape[-1]
                else:
                    input_dim = list(input_sample.values())[0].shape[-1]
            else:
                input_dim = input_sample.shape[-1]

            if 'jamba' in cfg['network']:
                print(f"Initializing Jamba Model with Input Dim: {input_dim}")
                self.neural_model = JambaModel(input_dim, cfg['network'].get('d_model', 128), cfg['network'].get('n_layers', 4))
                self.neural_model.to(self.device)
            else:
                self.neural_model = ModelMixedInput(input_sample, self.neural_integrator.prediction_dim, cfg['inputs'], cfg['network'], device=self.device, novelty=self.novelty)
        else:
            # NOTE(security): save_model() pickles the whole module, so loading
            # needs full unpickling (weights_only=False). Recent PyTorch versions
            # default to weights_only=True and would reject such a file. Only
            # load checkpoints from a trusted source.
            checkpoint = torch.load(model_checkpoint_path, map_location=self.device, weights_only=False)
            self.neural_model = checkpoint[0]
            self.neural_model.to(self.device)

        self.neural_integrator.set_neural_model(self.neural_model)

        # --- 2. Dataset Setup ---
        self.batch_size = int(cfg['algorithm']['batch_size'])
        self.dataset_max_capacity = cfg['algorithm']['dataset'].get('max_capacity', 100000000)
        self.num_data_workers = cfg['algorithm']['dataset'].get('num_data_workers', 4)

        # Initialize placeholders before calling get_datasets
        self.train_dataset = None
        self.valid_datasets = {}
        self.collate_fn = None

        self.get_datasets(cfg['algorithm']['dataset'].get('train_dataset_path'), cfg['algorithm']['dataset'].get('valid_datasets'))

        # --- 3. Training Params ---
        # Default initialization
        self.lr_schedule = 'constant'
        self.lr_start = 1e-3
        self.lr_end = 0.0

        if cfg.get('cli', {}).get('train', False):
            self.num_epochs = int(cfg['algorithm']['num_epochs'])
            self.num_iters_per_epoch = int(cfg['algorithm'].get('num_iters_per_epoch', -1))

            # Load Learning Rate Params
            self.lr_start = float(cfg['algorithm']['optimizer']['lr_start'])
            self.lr_end = float(cfg['algorithm']['optimizer'].get('lr_end', 0.))
            # Fall back to the 'constant' default set above when the config omits the schedule.
            self.lr_schedule = cfg['algorithm']['optimizer'].get('lr_schedule', self.lr_schedule)

            self.optimizer = torch.optim.Adam(self.neural_model.parameters(), lr=self.lr_start)

            # Logging
            self.log_dir = cfg['cli']['logdir']
            os.makedirs(self.log_dir, exist_ok=True)
            self.model_log_dir = os.path.join(self.log_dir, 'nn')
            os.makedirs(self.model_log_dir, exist_ok=True)
            self.logger = Logger()
            self.summary_log_dir = os.path.join(self.log_dir, 'summaries')
            os.makedirs(self.summary_log_dir, exist_ok=True)
            self.logger.init_tensorboard(self.summary_log_dir)

            if self.wandb_project:
                self.logger.init_wandb(self.wandb_project, self.wandb_name)

            self.save_interval = cfg['cli'].get("save_interval", 50)
            self.eval_interval = cfg['cli'].get("eval_interval", 1)
            self.log_interval = cfg['cli'].get("log_interval", 1)

            # Gradient clipping params
            self.truncate_grad = cfg['algorithm'].get('truncate_grad', False)
            self.grad_norm = cfg['algorithm'].get('grad_norm', 1.0)

            # Compute Statistics
            if cfg['algorithm'].get("compute_dataset_statistics", True):
                print('Computing dataset statistics...')
                self.compute_dataset_statistics(self.train_dataset)
                if hasattr(self.neural_model, 'set_input_rms'):
                    self.neural_model.set_input_rms(self.dataset_rms)
                    self.neural_model.set_output_rms(self.dataset_rms['target'])

            # Create (or truncate) the log files; the `with` block closes them.
            for valid_dataset_name in self.valid_datasets.keys():
                with open(os.path.join(self.model_log_dir, f'saved_best_valid_{valid_dataset_name}_model_epochs.txt'), 'w'):
                    pass
            with open(os.path.join(self.model_log_dir, "saved_best_eval_model_epochs.txt"), 'w'):
                pass

        # --- 4. Evaluator Setup ---
        self.eval_mode = cfg['algorithm']['eval'].get('mode', 'sampler')
        self.num_eval_rollouts = cfg['algorithm']['eval'].get("num_rollouts", self.neural_env.num_envs)
        self.eval_render = cfg.get('cli', {}).get('render', False)
        self.eval_passive = cfg['algorithm']['eval'].get('passive', True)
        self.eval_horizon = cfg['algorithm']['eval'].get("rollout_horizon", 5)
        self.eval_dataset_path = cfg['algorithm']['eval'].get('dataset_path', None)

        self.evaluator = NeuralSimEvaluator(
            self.neural_env,
            # The dataset path is only handed over when evaluating in 'dataset' mode.
            hdf5_dataset_path=self.eval_dataset_path if self.eval_mode == 'dataset' else None,
            eval_horizon=self.eval_horizon,
            device=self.device
        )

    def get_datasets(self, train_path, valid_cfg):
        """Create the training dataset and one validation dataset per entry of ``valid_cfg``."""
        self.train_dataset = BatchTransitionDataset(self.batch_size, train_path, self.dataset_max_capacity, self.device)
        self.valid_datasets = {k: BatchTransitionDataset(self.batch_size, v, device=self.device) for k, v in valid_cfg.items()}
        # The datasets were given the configured batch size above, so the
        # DataLoader batch size is reset to 1.
        # NOTE(review): presumably each dataset item is already a full batch;
        # confirm against BatchTransitionDataset.
        self.batch_size = 1
        self.collate_fn = collate_fn_BatchTransitionDataset

    def compute_dataset_statistics(self, dataset):
        """Accumulate a ``RunningMeanStd`` per key of the preprocessed data into ``self.dataset_rms``."""
        loader = DataLoader(dataset, batch_size=max(512, self.batch_size), collate_fn=self.collate_fn)
        self.dataset_rms = {}
        for data in loader:
            data = self.preprocess_data_batch(data)
            for k in data.keys():
                # shape[2:] skips the first two axes (assumed batch and time; see the update flags).
                if k not in self.dataset_rms: self.dataset_rms[k] = RunningMeanStd(shape=data[k].shape[2:], device=self.device)
                self.dataset_rms[k].update(data[k], batch_dim=True, time_dim=True)

    def get_scheduled_learning_rate(self, iteration, total_iterations):
        """Return the learning rate for ``iteration`` out of ``total_iterations``.

        ``'constant'`` returns ``lr_start``; ``'linear'`` interpolates from
        ``lr_start`` to ``lr_end``; ``'cosine'`` follows half a cosine wave
        between them. Any other schedule name falls back to ``lr_start``.
        """
        if self.lr_schedule == 'constant': return self.lr_start
        elif self.lr_schedule == 'linear':
            ratio = iteration / total_iterations
            return self.lr_start * (1.0 - ratio) + self.lr_end * ratio
        elif self.lr_schedule == 'cosine':
            decay_ratio = iteration / total_iterations
            coeff = 0.5 * (1.0 + np.cos(np.pi * decay_ratio))
            return self.lr_end + coeff * (self.lr_start - self.lr_end)
        return self.lr_start

    @torch.no_grad()
    def preprocess_data_batch(self, data):
        """Move a batch to the device and add ``'contact_masks'`` and ``'target'`` entries (in place)."""
        for k, v in data.items():
            if isinstance(v, dict):
                for sk, sv in v.items(): data[k][sk] = sv.to(self.device)
            else: data[k] = v.to(self.device)
        data['contact_masks'] = self.neural_integrator.get_contact_masks(data['contact_depths'], data['contact_thicknesses'])
        self.neural_integrator.process_neural_model_inputs(data)
        data['target'] = self.neural_integrator.convert_next_states_to_prediction(data['states'], data['next_states'], self.neural_env.frame_dt)
        return data

    def compute_loss(self, data, train):
        """Return ``(loss, {})`` with a per-dimension weighted MSE against ``data['target']``.

        When the model has ``normalize_output`` set, each output dimension is
        weighted by ``1 / sqrt(output_rms.var + 1e-5)``; otherwise all weights are one.
        """
        pred = self.neural_model(data)

        # Determine weights
        if hasattr(self.neural_model, 'normalize_output') and self.neural_model.normalize_output:
            loss_weights = 1. / torch.sqrt(self.neural_model.output_rms.var + 1e-5)
        else:
            loss_weights = torch.ones(pred.shape[-1], device=pred.device)

        loss = torch.nn.MSELoss()(pred * loss_weights, data['target'] * loss_weights)
        return loss, {}

    def one_epoch(self, train, dataloader, dataloader_iter, num_batches, shuffle=False, info=None):
        """Run ``num_batches`` batches in train or eval mode.

        Returns:
            ``(mean_loss, {}, {})`` where ``mean_loss`` is the summed loss
            divided by ``num_batches``.
        """
        if train: self.neural_model.train()
        else: self.neural_model.eval()
        sum_loss = 0
        with torch.set_grad_enabled(train):
            for _ in tqdm(range(num_batches)):
                try: data = next(dataloader_iter)
                except StopIteration:
                    # Iterator exhausted: optionally reshuffle and start over.
                    if shuffle: self.train_dataset.shuffle()
                    dataloader_iter = iter(dataloader)
                    data = next(dataloader_iter)
                data = self.preprocess_data_batch(data)
                if train: self.optimizer.zero_grad()
                loss, _ = self.compute_loss(data, train)
                if train:
                    loss.backward()
                    # Clip gradients
                    if self.truncate_grad:
                        clip_grad_norm_(self.neural_model.parameters(), self.grad_norm)
                    self.optimizer.step()
                # Detach so the running sum does not keep every batch's autograd graph alive.
                sum_loss += loss.detach()
        return sum_loss / num_batches, {}, {}

    def train(self):
        """Run the training loop: set the LR, train (from epoch 1 on), validate, evaluate, flush logs."""
        train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn, shuffle=True, drop_last=True)
        train_iter = iter(train_loader)
        num_batches = len(train_loader) if self.num_iters_per_epoch == -1 else self.num_iters_per_epoch

        valid_loaders = {k: DataLoader(v, batch_size=self.batch_size, collate_fn=self.collate_fn) for k,v in self.valid_datasets.items()}
        valid_iters = {k: iter(v) for k,v in valid_loaders.items()}

        self.best_eval_error = np.inf

        self.time_report = TimeReport(cuda_synchronize = False)
        self.time_report.add_timers(['epoch', 'other', 'dataloader', 'compute_loss', 'backward', 'eval'])

        for epoch in range(self.num_epochs):
            self.logger.init_epoch(epoch)
            # Update LR (scheduled per epoch, not per batch)
            self.lr = self.get_scheduled_learning_rate(epoch, self.num_epochs)
            for param_group in self.optimizer.param_groups: param_group['lr'] = self.lr

            # Epoch 0 only validates/evaluates the untrained model.
            if epoch > 0: self.one_epoch(True, train_loader, train_iter, num_batches, shuffle=True)
            for k, v in valid_loaders.items(): self.one_epoch(False, v, valid_iters[k], min(50, len(v)), info=k)

            if (epoch + 1) % self.eval_interval == 0: self.eval(epoch)

            # Flush logs
            self.logger.flush()

        # --- FIXED: Call finish() OUTSIDE the loop ---
        self.logger.finish()

    @torch.no_grad()
    def eval(self, epoch):
        """Evaluate rollouts with the neural model and save the model when the error improves."""
        self.neural_model.eval()
        print('Evaluating...')
        error, _, stats = self.evaluator.evaluate_action_mode(
            num_traj=self.num_eval_rollouts,
            eval_mode='rollout',
            env_mode='neural',
            trajectory_source=self.eval_mode,
            render=self.eval_render,
            passive=self.eval_passive
        )
        print(f"Eval Error: {stats['overall']['error(MSE)']}")

        if stats['overall']['error(MSE)'] < self.best_eval_error:
            self.best_eval_error = stats['overall']['error(MSE)']
            self.save_model('best_eval_model')

    def save_model(self, filename='best_model'):
        """Pickle ``[model, robot_name]`` to ``<model_log_dir>/<filename>.pt``."""
        torch.save([self.neural_model, self.neural_env.robot_name], os.path.join(self.model_log_dir, f'{filename}.pt'))

Overwriting /content/neural-robot-dynamics/algorithms/vanilla_trainer.py


In [19]:
%%writefile /content/neural-robot-dynamics/models/jamba.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention restricted by a lower-triangular (causal) mask.

    Args:
        d_model: Feature width of the input and output. Must be divisible by
            ``n_head`` because the features are split evenly across heads.
        n_head: Number of attention heads.
        max_len: Side length of the precomputed mask, i.e. the longest
            sequence ``forward`` accepts.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """
    def __init__(self, d_model, n_head, max_len=2048):
        super().__init__()
        # Fail early with a readable message; otherwise the mismatch only shows
        # up as a tensor ``view`` error inside forward().
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        self.c_attn = nn.Linear(d_model, 3 * d_model)  # joint q/k/v projection
        self.c_proj = nn.Linear(d_model, d_model)
        self.n_head = n_head
        self.d_model = d_model
        # Ones on and below the diagonal: position t may attend to positions <= t.
        self.register_buffer("bias", torch.tril(torch.ones(max_len, max_len)).view(1, 1, max_len, max_len))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C) and return the same shape.

        Raises:
            ValueError: If the sequence length T exceeds the mask size.
        """
        B, T, C = x.size()
        # Read the limit from the buffer so no extra attribute is required.
        max_len = self.bias.size(-1)
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds the attention mask size {max_len}")
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.d_model, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Masked (future) positions get -inf so softmax gives them zero weight.
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v
        # Merge the heads back: (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

class MambaLayer(nn.Module):
    """Simplified Mamba-style mixer: projection, causal depthwise convolution, SiLU, projection.

    Only the convolutional path is implemented here; there is no state-space
    scan or gating branch in this class.

    Args:
        d_model: Feature size of the input and output.
        d_conv: Kernel size of the depthwise convolution. Left padding of
            ``d_conv - 1`` keeps the layer causal. Defaults to 4.

    Raises:
        ValueError: If ``d_conv`` is smaller than 1.
    """

    def __init__(self, d_model, d_conv=4):
        super().__init__()
        if d_conv < 1:
            raise ValueError(f"d_conv must be >= 1, got {d_conv}")
        d_inner = d_model * 2
        self.in_proj = nn.Linear(d_model, d_inner)
        # groups == channels makes this a depthwise (per-channel) convolution.
        # Padding is tied to the kernel size so that, after trimming in
        # forward(), output step t only sees input steps <= t.
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            groups=d_inner,
            padding=d_conv - 1,
        )
        self.out_proj = nn.Linear(d_inner, d_model)
        self.act = nn.SiLU()

    def forward(self, x):
        """Mix information along the sequence axis.

        Args:
            x: 3-D tensor unpacked as (batch, length, d_model).

        Returns:
            Tensor of shape (batch, length, d_model).
        """
        _, L, _ = x.shape
        x_proj = self.in_proj(x)
        # Conv1d expects channels before the sequence axis.
        x_conv = x_proj.transpose(1, 2)
        # Symmetric padding yields extra trailing steps; keep only the first L
        # so the convolution stays causal.
        x_conv = self.conv1d(x_conv)[:, :, :L]
        x_conv = x_conv.transpose(1, 2)
        return self.out_proj(self.act(x_conv))

class MLP(nn.Module):
    """Two-layer feed-forward network: expand to ``4 * d_model``, GELU, project back to ``d_model``."""

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = 4 * d_model
        expand = nn.Linear(d_model, hidden_dim)
        activation = nn.GELU()
        contract = nn.Linear(hidden_dim, d_model)
        # Kept as a Sequential named ``net`` so parameter keys stay net.0.* / net.2.*.
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        return self.net(x)

class JambaBlock(nn.Module):
    """Pre-norm residual block: a sequence mixer followed by an MLP.

    Args:
        d_model: Feature size carried through the block.
        layer_idx: Position of the block in the stack; accepted but not used here.
        use_attention: When true the mixer is a 4-head CausalSelfAttention,
            otherwise a MambaLayer.
    """

    def __init__(self, d_model, layer_idx, use_attention=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        if use_attention:
            self.mixer = CausalSelfAttention(d_model, n_head=4)
        else:
            self.mixer = MambaLayer(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model)

    def forward(self, x):
        """Return ``x`` after the mixer and MLP residual updates."""
        # First residual branch: normalize, then mix.
        mixed = self.mixer(self.norm1(x))
        hidden = x + mixed
        # Second residual branch: normalize, then feed-forward.
        fed_forward = self.mlp(self.norm2(hidden))
        return hidden + fed_forward

class JambaModel(nn.Module):
    """Linear embedding, a stack of JambaBlocks, a final LayerNorm and a linear head.

    Every ``attn_interval``-th block (counting from 1) uses causal
    self-attention; all other blocks use the convolutional mixer. The head
    projects back to ``input_dim``, so output and input feature sizes match.

    Args:
        input_dim: Feature size of the input (and of the output).
        d_model: Hidden width of the blocks.
        n_layers: Number of blocks in the stack.
        vocab_size: Accepted for interface compatibility; not used by this class.
        attn_interval: Spacing of attention blocks. With the default of 8, a
            model with fewer than 8 layers contains no attention block.

    Raises:
        ValueError: If ``attn_interval`` is smaller than 1.
    """

    def __init__(self, input_dim, d_model, n_layers, vocab_size=None, attn_interval=8):
        super().__init__()
        if attn_interval < 1:
            raise ValueError(f"attn_interval must be >= 1, got {attn_interval}")
        self.embed = nn.Linear(input_dim, d_model)
        self.layers = nn.ModuleList([])
        # The attributes below are only stored here; nothing in this class reads
        # them. NOTE(review): presumably they are consumed by the surrounding
        # trainer/integrator code -- confirm, and confirm whether input/output
        # normalization is expected to be applied inside forward().
        self.normalize_output = False
        self.input_rms = None
        self.output_rms = None
        self.is_rnn = False
        for i in range(n_layers):
            use_attention = (i + 1) % attn_interval == 0
            self.layers.append(JambaBlock(d_model, i, use_attention))
        self.norm_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, input_dim, bias=False)

    def set_input_rms(self, input_rms):
        """Store the input normalization statistics object."""
        self.input_rms = input_rms

    def set_output_rms(self, output_rms):
        """Store the output normalization statistics object."""
        self.output_rms = output_rms

    def init_rnn(self, batch_size):
        """No-op hook; this model keeps no recurrent state."""
        pass

    def evaluate(self, x):
        """Run ``forward`` and reduce a sequence output to its final step.

        Args:
            x: Anything accepted by ``forward``.

        Returns:
            ``out[:, -1, :]`` when the forward output is 3-D, otherwise the
            forward output unchanged.
        """
        out = self.forward(x)
        # Only the last prediction of a sequence is returned.
        if out.dim() == 3:
            return out[:, -1, :]
        return out

    def forward(self, x):
        """Map inputs to predictions of the same feature size.

        Args:
            x: A tensor, or a dict of tensors. For a dict, the ``'net_input'``
                entry is used if present, else ``'states'``, else the first
                value in iteration order.

        Returns:
            Tensor with the same leading dimensions as the selected input; a
            2-D input yields a 2-D output.

        Raises:
            ValueError: If ``x`` is an empty dict.
        """
        if isinstance(x, dict):
            if not x:
                raise ValueError("JambaModel.forward received an empty input dict")
            if 'net_input' in x:
                x = x['net_input']
            elif 'states' in x:
                x = x['states']
            else:
                x = next(iter(x.values()))

        # The blocks operate on sequences, so a 2-D input is temporarily
        # given a length-1 sequence axis.
        is_2d = (x.dim() == 2)
        if is_2d:
            x = x.unsqueeze(1)

        x = self.embed(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        x = self.head(x)

        # Drop the temporary sequence axis again.
        if is_2d:
            x = x.squeeze(1)
        return x

Overwriting /content/neural-robot-dynamics/models/jamba.py


In [20]:
# 1. Create Config
import yaml
import os
%cd /content/neural-robot-dynamics/train
with open('cfg/Cartpole/transformer.yaml', 'r') as f: cfg = yaml.safe_load(f)
cfg['algorithm']['dataset']['train_dataset_path'] = '../data/datasets/Cartpole/trajectory_len-100_train.hdf5'
cfg['algorithm']['dataset']['valid_datasets']['exp_trajectory'] = '../data/datasets/Cartpole/trajectory_len-100_valid.hdf5'
cfg['network'] = {'jamba': True, 'd_model': 128, 'n_layers': 4, 'vocab_size': 100}
cfg['algorithm']['num_epochs'] = 100
cfg['algorithm']['num_iters_per_epoch'] = 100
with open('jamba_colab_config.yaml', 'w') as f: yaml.dump(cfg, f)

# 2. Run
!python train.py --cfg jamba_colab_config.yaml --novelty mamba --logdir ../data/logs/jamba --wandb-project neural-robot-dynamics --wandb-name jamba_run

/content/neural-robot-dynamics/train
Warp 1.8.0 initialized:
   CUDA Toolkit 12.8, Driver 12.4
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA A100-SXM4-40GB" (40 GiB, sm_80, mempool enabled)
   Kernel cache:
     /root/.cache/warp/1.8.0
2025-12-06 22:42:30.425884: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-06 22:42:30.444760: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765060950.466259    6378 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765060950.472768    6378 cuda_blas.cc:1407] Unable to

In [31]:
import os

# Evaluation cell: runs three evaluation scripts against the saved Jamba
# checkpoint. All relative paths below resolve against the train/ directory
# selected by the %cd line.

# 1. Ensure we are in the correct directory
%cd /content/neural-robot-dynamics/train

# 2. Checkpoint passed to every evaluation script below (relative to train/)
model_path = "../data/logs/jamba/nn/best_eval_model.pt"

print(f" USING MODEL AT: {model_path}")
print("-" * 50)

# --- 1. LONG-HORIZON PASSIVE MOTION EVALUATION ---
# 2048 environments / 2048 rollouts with a rollout horizon of 100 steps, seed 500.
print("\n Running Passive Motion Evaluation (Horizon 100)...")
!python ../eval/eval_passive/eval_passive_motion.py \
    --env-name Cartpole \
    --model-path "{model_path}" \
    --env-mode neural \
    --num-envs 2048 \
    --num-rollouts 2048 \
    --rollout-horizon 100 \
    --seed 500 \
    --wandb-project neural-robot-dynamics \
    --wandb-name jamba_passive_eval

# --- 2. RL POLICY EVALUATION ---
# Plays back the pretrained Cartpole PPO policy in the neural environment.
# Paths are made absolute before being interpolated into the shell command.
print("\n Running RL Policy Evaluation...")
rl_cfg = os.path.abspath('../eval/eval_rl/cfg/Cartpole/cartpole.yaml')
playback = os.path.abspath('../pretrained_models/RL_policies/Cartpole/0/nn/CartpolePPO.pth')

!python ../eval/eval_rl/run_rl.py \
    --rl-cfg "{rl_cfg}" \
    --playback "{playback}" \
    --num-envs 2048 \
    --num-games 2048 \
    --env-mode neural \
    --wandb-project neural-robot-dynamics \
    --wandb-name jamba_rl_eval \
    --nerd-model-path "{model_path}"

# --- 3. INFERENCE SPEED (FPS) EVALUATION ---
# Same environment count and horizon as the passive-motion run above.
print("\nRunning FPS Evaluation...")
!python ../eval/eval_fps/eval_fps.py \
    --env-name Cartpole \
    --num-envs 2048 \
    --rollout-horizon 100 \
    --env-mode neural \
    --model-path "{model_path}"

/content/neural-robot-dynamics/train
 USING MODEL AT: ../data/logs/jamba/nn/best_eval_model.pt
--------------------------------------------------

 Running Passive Motion Evaluation (Horizon 100)...
Warp 1.8.0 initialized:
   CUDA Toolkit 12.8, Driver 12.4
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA A100-SXM4-40GB" (40 GiB, sm_80, mempool enabled)
   Kernel cache:
     /root/.cache/warp/1.8.0
Number of Model Parameters:  799104
[96m [NeuralEnvironment] Creating abstract contact environment: Cartpole. [0m
Creating 2048 environments: 100% 2048/2048 [00:06<00:00, 308.98it/s]
Module warp.sim.integrator_featherstone 18b3327 load on device 'cuda:0' took 2.70 ms  (cached)
Module envs.abstract_contact_environment 8e8d790 load on device 'cuda:0' took 0.36 ms  (cached)
Module integrators.integrator_neural ee402cd load on device 'cuda:0' took 0.49 ms  (cached)
[96m [NeuralEnvironment] Created a Neural Integrator. [0m
Sampling state transitions:   0% 0/1 [00:00<?, ?it/s]Mo

In [30]:
import os
import subprocess
import pandas as pd
import re
from IPython.display import display

# --- CONFIGURATION ---
# 1. Label and checkpoint of the model under evaluation.
#    model_path is relative, so it depends on the kernel's current working directory.
model_name = "Jamba"
model_path = "../data/logs/jamba/nn/best_eval_model.pt"

# 2. RL Configuration: rl_games config and pretrained PPO policy, resolved to
#    absolute paths against the current working directory.
rl_cfg = os.path.abspath('../eval/eval_rl/cfg/Cartpole/cartpole.yaml')
playback_path = os.path.abspath('../pretrained_models/RL_policies/Cartpole/0/nn/CartpolePPO.pth')

print(f" Target Model: {model_path}")

# --- HELPER FUNCTION ---
# A complete decimal number: optional sign, digits with optional fraction (or a
# bare fraction), optional exponent. Stricter than a character class such as
# [-\d.]+, which also accepts "-" or "12.5." and truncates "1.5e-3" to "1.5".
_REWARD_NUMBER = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'


def _parse_reward(output):
    """Return the first reward reported in ``output`` as a float, or None if absent."""
    # "Mean Reward:" is tried first; "av reward:" is the fallback label.
    for label in ('Mean Reward', 'av reward'):
        match = re.search(label + r':\s*(' + _REWARD_NUMBER + r')', output)
        if match:
            return float(match.group(1))
    return None


def run_rl_eval(label, model_arg):
    """Run the RL evaluation script in neural mode and return the reported reward.

    The command line is built from the module-level ``rl_cfg`` and
    ``playback_path`` values. The script's stdout is searched for
    ``Mean Reward: <number>`` and, failing that, ``av reward: <number>``.

    Args:
        label: Display name of the model; also used (lower-cased) in the wandb run name.
        model_arg: Path passed to ``--nerd-model-path``.

    Returns:
        The parsed reward as a float, or None if the script exits with a
        non-zero status or no reward could be found in its output.
    """
    cmd = [
        "python", "../eval/eval_rl/run_rl.py",
        "--rl-cfg", rl_cfg,
        "--playback", playback_path,
        "--num-envs", "2048",
        "--num-games", "2048",
        "--wandb-project", "neural-robot-dynamics",
        "--wandb-name", f"eval_{label.lower()}",
        "--env-mode", "neural",
        "--nerd-model-path", model_arg,
    ]

    print(f"\n Running Evaluation for {label}...")
    try:
        # check=True turns a non-zero exit status into CalledProcessError.
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f" Error running evaluation for {label}")
        print(e.stderr)
        return None

    output = result.stdout
    reward = _parse_reward(output)
    if reward is None:
        print(" Could not parse reward from output.")
        print("Tail of output:", output[-500:])  # Show last 500 chars for debug
        return None

    print(f" Score: {reward}")
    return reward

# --- EXECUTION ---
results = []

# 1. Evaluate the Jamba model; run_rl_eval returns None when the run or the
#    reward parsing fails, in which case no row is recorded.
jamba_reward = run_rl_eval(model_name, model_path)

if jamba_reward is not None:
    results.append({
        'Model': model_name,
        'Reward': jamba_reward,
        'Note': 'Neural Simulation'
    })

# 2. Ground-truth baseline (physics engine) is NOT run by this cell.
# TODO: add a baseline run once the correct run_rl.py flag for ground-truth
# physics has been confirmed; until then, say so instead of claiming to run it.
print("\n Skipping Ground Truth Baseline (Physics Engine): not implemented in this cell.")

# --- DISPLAY RESULTS ---
# Fixed columns keep the table well-formed even when no evaluation succeeded.
df = pd.DataFrame(results, columns=['Model', 'Reward', 'Note'])
print("\nFinal Results:")
display(df)

 Target Model: ../data/logs/jamba/nn/best_eval_model.pt

 Running Evaluation for Jamba...
 Score: -359.89834106415583

 Running Ground Truth Baseline (Physics Engine)...

Final Results:


Unnamed: 0,Model,Reward,Note
0,Jamba,-359.898341,Neural Simulation
