In [1]:
import torch
import snntorch
import morphSNN as morph
import Reinforcement as re
import numpy as np
import mujoco
import gymnasium as gym



In [8]:
# Visual sanity check: load the var1 quadruped straight from its MJCF file and
# open MuJoCo's interactive viewer on it.
# NOTE: per the MuJoCo docs, `viewer.launch` is the blocking variant, so this cell
# does not finish until the viewer window is closed (skip it for Run All).
from mujoco import viewer

model = mujoco.MjModel.from_xml_path("quadruped_var1.xml")
data = mujoco.MjData(model)  # fresh simulation state for the model above
viewer.launch(model=model, data=data)

In [8]:

env = gym.make("Ant-v5", xml_file="../quadruped.xml")
model = env.unwrapped.model

# Joint layout report: for every joint, its name plus the first slot it occupies
# in data.qpos (`jnt_qposadr`) and in data.qvel (`jnt_dofadr`).
print("Number of joints:", model.njnt)

joint_addresses = zip(model.jnt_qposadr, model.jnt_dofadr)
for joint_id, (qpos_start, dof_start) in enumerate(joint_addresses):
    joint_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, joint_id)
    print(
        f"joint {joint_id:2d} | name={joint_name!r} "
        f"| qpos index={qpos_start} | dof index={dof_start}"
    )
    

Number of joints: 9
joint  0 | name='root' | qpos index=0 | dof index=0
joint  1 | name='hip_1' | qpos index=7 | dof index=6
joint  2 | name='ankle_1' | qpos index=8 | dof index=7
joint  3 | name='hip_2' | qpos index=9 | dof index=8
joint  4 | name='ankle_2' | qpos index=10 | dof index=9
joint  5 | name='hip_3' | qpos index=11 | dof index=10
joint  6 | name='ankle_3' | qpos index=12 | dof index=11
joint  7 | name='hip_4' | qpos index=13 | dof index=12
joint  8 | name='ankle_4' | qpos index=14 | dof index=13


In [None]:
# Single-morphology PPO run on the var1 quadruped (run name "quad_alif").
# NOTE(review): `var1_morph` is only defined in a later cell; execute that cell
# first (or move this one below it) so this cell works on a fresh kernel.
model = re.train_ppo_with_pose_template(
    run_name="quad_alif",
    pose_generator=None,
    morph_vec=var1_morph,
    xml_path="./quadruped_var1.xml",  # missing comma here was a SyntaxError
    timesteps=2_000_000,
    parallel_envs=4,
    initial_learning_rate=3e-4,
)

In [2]:

# Morphology vectors: two values per leg ("leg", "ankle" per the inline
# comments), legs ordered FR, FL, BL, BR. Units/semantics are whatever
# re.train_ppo_with_pose_template expects — TODO confirm against the XMLs.
var1_morph = [
    0.141421, 0.282843,  # FR leg, FR ankle
    0.141421, 0.282843,  # FL leg, FL ankle
    0.141421, 0.282843,  # BL leg, BL ankle
    0.141421, 0.282843,  # BR leg, BR ankle
]

var2_morph = [
    0.141421, 0.16,  # FR leg, FR ankle
    0.141421, 0.16,  # FL leg, FL ankle
    0.141421, 0.282843,  # BL leg, BL ankle
    0.141421, 0.282843,  # BR leg, BR ankle
]

var3_morph = [
    0.141421, 0.17,  # FR leg, FR ankle
    0.141421, 0.17,  # FL leg, FL ankle
    0.16, 0.15,  # BL leg, BL ankle
    0.16, 0.15,  # BR leg, BR ankle
]

var4_morph = [
    0.16, 0.3,  # FR leg, FR ankle
    0.141421, 0.28,  # FL leg, FL ankle
    0.141421, 0.28,  # BL leg, BL ankle
    0.141421, 0.28,  # BR leg, BR ankle
]

# Every known morphology. The position k in this list is the morph's fixed id,
# and its policy is trained under run name f"quadrl_m{k}" (var1 -> m0, ...,
# var4 -> m3). This matches the quadrl_m0/m1/m2_ppo.zip checkpoints used for
# var1/var2/var3 when collecting diffusion trajectories.
all_morph_specs = [
    ("./quadruped_var1.xml", var1_morph),
    ("./quadruped_var2.xml", var2_morph),
    ("./quadruped_var3.xml", var3_morph),
    ("./quadruped_var4.xml", var4_morph),
]

# Ids (indices into all_morph_specs) to train in this run. var1-var3 are
# already trained, so only var4 (id 3) is selected.
train_morph_ids = [3]
morph_specs = [all_morph_specs[k] for k in train_morph_ids]

for morph_id in train_morph_ids:
    xml_path, morph_vec = all_morph_specs[morph_id]
    # Name the run after the morph's fixed id rather than its position in the
    # selection, so changing which morphs are selected can never reuse (and
    # overwrite) another morph's run name. The old `i + 3` offset broke as soon
    # as any other entry was enabled.
    run_name = f"quadrl_m{morph_id}"

    print(f"\n=== Training policy for morph {morph_id} ===")
    print("  xml_path:", xml_path)
    print("  morph_vec:", morph_vec)

    model = re.train_ppo_with_pose_template(
        run_name=run_name,
        pose_generator=None,
        morph_vec=morph_vec,
        xml_path=xml_path,
        timesteps=2_000_000,
        parallel_envs=4,
        initial_learning_rate=3e-4,
    )





=== Training policy for morph 0 ===
  xml_path: ./quadruped_var4.xml
  morph_vec: [0.16, 0.3, 0.141421, 0.28, 0.141421, 0.28, 0.141421, 0.28]
Logging to ./logs_quadrl_m3
Using cuda device





🚀 Training PPO with pose templates for quadrl_m3 ...
[env 2] ep_len=  20 | R_mean=-0.680 | fwd=-0.486 | vel=-1.512 | imitat= 0.000 | alive= 0.100 | energy_p= 0.028 | fail= 0.750 | var=unknown
[env 0] ep_len=  30 | R_mean=-0.204 | fwd=-0.626 | vel=-2.037 | imitat= 0.000 | alive= 0.100 | energy_p= 0.032 | fail= 0.500 | var=unknown
[env 1] ep_len=  39 | R_mean=-0.016 | fwd=-0.085 | vel=-0.394 | imitat= 0.000 | alive= 0.100 | energy_p= 0.030 | fail= 0.385 | var=unknown
[env 2] ep_len=  20 | R_mean=-0.163 | fwd=-0.373 | vel=-1.217 | imitat= 0.000 | alive= 0.100 | energy_p= 0.030 | fail= 0.750 | var=unknown
[env 0] ep_len=  20 | R_mean=-0.172 | fwd=-0.443 | vel=-1.366 | imitat= 0.000 | alive= 0.100 | energy_p= 0.032 | fail= 0.750 | var=unknown
[env 1] ep_len=  20 | R_mean=-0.231 | fwd=-0.639 | vel=-2.058 | imitat= 0.000 | alive= 0.100 | energy_p= 0.030 | fail= 0.750 | var=unknown
[env 3] ep_len=  65 | R_mean=-0.062 | fwd=-0.177 | vel=-0.622 | imitat= 0.000 | alive= 0.100 | energy_p= 0.02

  logger.warn(


Video recording complete.
[env 3] ep_len=  26 | R_mean=-0.124 | fwd=-0.336 | vel=-1.036 | imitat= 0.000 | alive= 0.100 | energy_p= 0.028 | fail= 0.577 | var=unknown
[env 2] ep_len=  33 | R_mean=-0.106 | fwd=-0.271 | vel=-0.954 | imitat= 0.000 | alive= 0.100 | energy_p= 0.029 | fail= 0.455 | var=unknown
[env 0] ep_len=  57 | R_mean=-0.040 | fwd=-0.073 | vel=-0.291 | imitat= 0.000 | alive= 0.100 | energy_p= 0.028 | fail= 0.263 | var=unknown
[env 3] ep_len=  41 | R_mean=-0.047 | fwd=-0.076 | vel=-0.303 | imitat= 0.000 | alive= 0.100 | energy_p= 0.029 | fail= 0.366 | var=unknown
[env 0] ep_len=  24 | R_mean=-0.168 | fwd=-0.486 | vel=-1.543 | imitat= 0.000 | alive= 0.100 | energy_p= 0.029 | fail= 0.625 | var=unknown
[env 1] ep_len= 127 | R_mean=-0.063 | fwd=-0.192 | vel=-0.683 | imitat= 0.000 | alive= 0.100 | energy_p= 0.029 | fail= 0.118 | var=unknown
[env 2] ep_len=  52 | R_mean=-0.050 | fwd=-0.099 | vel=-0.407 | imitat= 0.000 | alive= 0.100 | energy_p= 0.030 | fail= 0.288 | var=unknown
[

In [2]:
##### Train Diffusion ####
# Pipeline for this cell:
#   1. roll out each trained PPO policy on its own morphology (only for
#      per-morph trajectory files that do not exist yet),
#   2. merge the per-morph files into one dataset,
#   3. train the diffusion model on the merged dataset.

import os

import numpy as np
from Reinforcement import collect_quadruped_morph_trajectories
from Diffusion import train_diffusion
from Diffusion import QuadrupedMorphDiffusionDataset


# Morphology vectors: two values per leg ("leg", "ankle"), legs ordered
# FR, FL, BL, BR per the inline comments.
# NOTE(review): var1_morph here (0.2/0.25 front, 0.141421/0.3 rear) differs from
# the var1_morph in the PPO training cell (0.141421/0.282843 on every leg).
# Confirm which one matches quadruped_var1.xml before re-collecting var1 data.
var1_morph = [
    0.2, 0.25,  # FR leg, FR ankle
    0.2, 0.25,  # FL leg, FL ankle
    0.141421, 0.3,  # BL leg, BL ankle
    0.141421, 0.3,  # BR leg, BR ankle
]

var2_morph = [
    0.141421, 0.16,  # FR leg, FR ankle
    0.141421, 0.16,  # FL leg, FL ankle
    0.141421, 0.282843,  # BL leg, BL ankle
    0.141421, 0.282843,  # BR leg, BR ankle
]

var3_morph = [
    0.141421, 0.17,  # FR leg, FR ankle
    0.141421, 0.17,  # FL leg, FL ankle
    0.16, 0.15,  # BL leg, BL ankle
    0.16, 0.15,  # BR leg, BR ankle
]

var4_morph = [
    0.16, 0.3,  # FR leg, FR ankle
    0.141421, 0.28,  # FL leg, FL ankle
    0.141421, 0.28,  # BL leg, BL ankle
    0.141421, 0.28,  # BR leg, BR ankle
]

morph_specs = [
    ("./quadruped_var1.xml", var1_morph),
    ("./quadruped_var2.xml", var2_morph),
    ("./quadruped_var3.xml", var3_morph),
    ("./quadruped_var4.xml", var4_morph),
]


# PPO checkpoints for var1..var3 (var{k+1} was trained as run "quadrl_m{k}").
# var4 has no checkpoint listed here, so it is not part of this dataset.
ppo_path1 = "quadrl_m0_ppo.zip"
ppo_path2 = "quadrl_m1_ppo.zip"
ppo_path3 = "quadrl_m2_ppo.zip"
ppo_paths = [ppo_path1, ppo_path2, ppo_path3]

# One trajectory file per morph: quadruped_var1_trajectories.npz, ...
traj_files = [f"quadruped_var{k + 1}_trajectories.npz" for k in range(len(ppo_paths))]

# Collect rollouts only for morphs whose trajectory file is missing, so a
# Restart & Run All reuses existing data instead of re-simulating. This
# replaces three copy-pasted, commented-out collection calls. zip() stops after
# the three morphs that have checkpoints.
for ppo_path, morph_spec, traj_path in zip(ppo_paths, morph_specs, traj_files):
    if os.path.exists(traj_path):
        continue
    collect_quadruped_morph_trajectories(
        ppo_path=ppo_path,
        morph_specs=[morph_spec],
        episodes_per_morph=50,
        max_steps_per_ep=1000,
        out_path=traj_path,
    )

# Merge the per-morph files into one dataset. merge_traj_files writes it to
# quadruped_morph_trajectories.npz (see the "Saved merged dataset to ..."
# line in this cell's output).
re.merge_traj_files(traj_files)

traj_file = "quadruped_morph_trajectories.npz"  # produced by re.merge_traj_files above

# Observation columns the diffusion model predicts. These are 8 columns,
# matching target_dim=8 in the dataset log. Presumably they are the 8 hinge-joint
# angles of the Ant-style observation (obs[0] = torso z, obs[1:5] = torso
# quaternion). TODO confirm against the custom quadruped observation layout.
imitation_obs_indices = list(range(5, 13))

# Train the diffusion model: condition on 10 past steps, predict 200 future steps.
diff_model, loss_history = train_diffusion(
    traj_file=traj_file,
    past_len=10,
    future_len=200,
    num_epochs=12,
    batch_size=128,
    lr=1e-4,
    imitation_obs_indices=imitation_obs_indices,
    save_path="quadruped_morph_diffusion_weights.pt",
)

Merged shapes:
  obs  : (71297, 97)
  act  : (71297, 8)
  rew  : (71297,)
  morph: (71297, 8)
Saved merged dataset to quadruped_morph_trajectories.npz
Using device: cuda
[QuadrupedDataset] obs_dim=97, morph_dim=8
[QuadrupedDataset] target_dim=8
[QuadrupedDataset] total samples possible=71087

=== Diffusion Model Summary ===
obs_dim      : 8
cond_dim     : 88
future_len   : 200
n_steps      : 100
time_dim     : 256
beta range   : (9.999999747378752e-05, 0.019999999552965164)
alpha range  : (0.9800000190734863, 0.9998999834060669)
alpha_bar[0] : 0.9998999834060669
alpha_bar[-1]: 0.3635632395744324
Parameters   : total=2,179,336  trainable=2,179,336

--- UNet Channels ---
down1: Conv1d(352, 256, kernel_size=(3,), stride=(1,), padding=(1,))
down2: Conv1d(256, 512, kernel_size=(4,), stride=(2,), padding=(1,))
mid  : Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
up1  : ConvTranspose1d(512, 256, kernel_size=(4,), stride=(2,), padding=(1,))
up2  : Conv1d(256, 8, kernel_size=(3,