Hi team, thank you for sharing this fantastic work. I initialize the cluster with jax.distributed.initialize()
and run the following command:
#!/bin/bash
#SBATCH --job-name=octo_train        # Job name
#SBATCH --nodes=2                    # Number of nodes
#SBATCH --ntasks=16                  # Total number of tasks
#SBATCH --ntasks-per-node=8          # Tasks per node
#SBATCH --nodelist=compute-permanent-node-493,compute-permanent-node-580
#SBATCH --gpus-per-node=8            # Request 8 GPUs per node (adjust as needed)
#SBATCH --time=12:00:00              # Time limit hrs:min:sec
srun python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
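For completeness, here is roughly how the setup looks on my side. As far as I understand, jax.distributed.initialize() auto-detects the coordinator address, process count, and process index from the SLURM environment, so I call it without arguments; the prints are just my own sanity checks, not part of finetune.py:

import jax

# On SLURM, jax.distributed.initialize() with no arguments picks up the
# coordinator address, process count, and process index from the SLURM
# environment variables.
jax.distributed.initialize()

# Sanity check (my own debugging, not from the repo): with 2 nodes x 8 tasks
# I expect 16 processes, and device_count() should see every GPU in the job.
print(f"process {jax.process_index()}/{jax.process_count()}, "
      f"local devices: {jax.local_device_count()}, "
      f"global devices: {jax.device_count()}")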
Running this fails with what I suspect is a data-loader issue: AssertionError: horizon must be <= max_horizon (it's 256 vs 10), which suggests the global batch is not being split across processes. Have you run into a similar issue, or has training only been done on TPUs so far?
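For what it's worth, my working hypothesis in code form (the batch size of 256 and the sharding call are my guesses about what should happen, not the actual finetune.py internals):

import jax

# Suspected failure mode: the global batch of 256 is fed to every process
# instead of being divided by jax.process_count() (16 here), so the batch
# dimension (256) lands where a horizon dimension (<= 10) is expected.
global_batch_size = 256
assert global_batch_size % jax.process_count() == 0
per_host_batch_size = global_batch_size // jax.process_count()  # 16 per host

# Each host would then read only its own slice of the data, e.g. with
# tf.data: dataset.shard(num_shards=jax.process_count(), index=jax.process_index())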