In [1]:
import numpy as np
import robosuite as suite
from robosuite.environments.manipulation.empty import Empty
from scipy import interpolate
from robosuite.utils.mjmod import DynamicsModder

In [2]:
# Number of control steps per episode: passed to suite.make(horizon=...) and used
# by plan() as the number of spline samples.
horizon = 200

In [3]:
def plan(start_pose, middle_pose, end_pose, horizon):
    """Sample a clamped cubic spline that passes through three poses.

    The spline has knots at t = 0, horizon // 2 and horizon, taking the values
    start_pose, middle_pose and end_pose respectively. The 'clamped' boundary
    condition pins the first derivative to zero at both ends.

    Returns the spline evaluated at the integer steps 0 .. horizon - 1 (the
    final knot at t = horizon is not sampled), one row per step.
    """
    knot_times = [0, horizon // 2, horizon]
    knot_poses = [start_pose, middle_pose, end_pose]
    spline = interpolate.CubicSpline(knot_times, knot_poses, axis=0, bc_type='clamped')
    sample_times = range(horizon)
    return spline(sample_times)
def difference(traj):
    """Turn a sequence of absolute waypoints into incremental steps.

    Returns a list whose first entry is the first waypoint itself and whose
    k-th entry (k >= 1) is waypoint[k] - waypoint[k - 1].
    """
    waypoints = np.array(traj)
    # Pair every waypoint with its successor and take the element-wise delta.
    deltas = [later - earlier for earlier, later in zip(waypoints[:-1], waypoints[1:])]
    return [waypoints[0]] + deltas

In [4]:
# Register the custom Empty environment class with robosuite; presumably this is
# what lets suite.make("Empty", ...) below resolve it by name -- TODO confirm.
suite.environments.base.register_env(Empty)

In [5]:
# Start from robosuite's default JOINT_POSITION controller settings, then
# override the damping ratio and ramp ratio used for this experiment.
controller_config = suite.load_controller_config(default_controller="JOINT_POSITION")
controller_config.update(damping_ratio=0.1, ramp_ratio=0.1)
print(controller_config)

{'type': 'JOINT_POSITION', 'input_max': 1, 'input_min': -1, 'output_max': 0.05, 'output_min': -0.05, 'kp': 50, 'damping_ratio': 0.1, 'impedance_mode': 'fixed', 'kp_limits': [0, 300], 'damping_ratio_limits': [0, 10], 'qpos_limits': None, 'interpolation': None, 'ramp_ratio': 0.1}


In [6]:
# Keyword arguments for the registered "Empty" environment, collected in one
# place so each setting can be annotated.
env_options = dict(
    robots="IIWA",                          # robot model to load
    gripper_types="ClothGripper",           # gripper model attached to the arm
    controller_configs=controller_config,   # JOINT_POSITION config built above
    has_renderer=True,                      # on-screen rendering
    render_camera="sideview",               # camera used for the on-screen view
    has_offscreen_renderer=False,           # no off-screen rendering
    render_collision_mesh=True,             # draw collision geometry in the viewer
    control_freq=20,                        # 20 hz control for applied actions
    horizon=horizon,                        # each episode terminates after `horizon` steps
    use_object_obs=False,                   # no observations needed
    use_camera_obs=False,                   # no observations needed
)
env = suite.make("Empty", **env_options)

Creating window glfw


In [7]:
import tqdm
import random

num_dataset_size = 1000
num_cloth_joints = 11

# Names of the seven arm joints whose positions drive the controller.
arm_joint_names = ["robot0_joint_" + str(joint) for joint in range(1, 8)]
# Joint-space offset commanded in every episode: +0.4 on robot0_joint_2 only.
target_offset = np.array([0.4 if joint == 2 else 0.0 for joint in range(1, 8)])

all_geom_positions = []
all_parameters = []
for _ in tqdm.tqdm(range(num_dataset_size)):
    done = False
    obs = env.reset()

    # Sample the cloth dynamics for this episode. These values are the
    # regression targets, so they must also be the values the simulation uses.
    modder = DynamicsModder(sim=env.sim)
    damping = random.random() * 20
    stiffness = random.random() * 20
    all_parameters.append([damping, stiffness])

    for joint_idx in range(num_cloth_joints):
        cloth_joint = "gripper0_joint" + str(joint_idx)
        modder.mod(cloth_joint, "damping", damping)
        modder.mod(cloth_joint, "stiffness", stiffness)
    # DynamicsModder requires update() after mod() calls for the changes to
    # be propagated to the simulation.
    modder.update()

    # Resolve the cloth collision geoms once per episode (env.reset() may
    # rebuild the model, so ids are not cached across episodes).
    cloth_geom_ids = [
        env.sim.model.geom_name2id("gripper0_g{}_col".format(geom_idx))
        for geom_idx in range(num_cloth_joints)
    ]

    # NOTE: the commanded motion is the fixed joint-space offset above;
    # plan()/difference() are not used by this loop.
    start_jpos = np.array([env.sim.data.get_joint_qpos(name) for name in arm_joint_names])
    final_jpos = start_jpos + target_offset

    geom_positions = []
    while not done:
        jpos = np.array([env.sim.data.get_joint_qpos(name) for name in arm_joint_names])
        relative_jpos = final_jpos - jpos

        # Record the x and z coordinate of every cloth geom before stepping.
        for geom_id in cloth_geom_ids:
            pos = env.sim.data.geom_xpos[geom_id]
            geom_positions.append(pos[0])
            geom_positions.append(pos[2])

        obs, reward, done, info = env.step(relative_jpos)
        env.render()
    all_geom_positions.append(geom_positions)

all_geom_positions = np.array(all_geom_positions)
all_parameters = np.array(all_parameters)

  0%|          | 0/1000 [00:00<?, ?it/s]

Creating window glfw
[0.  0.4 0.  0.  0.  0.  0. ]
[ 4.22156352e-08  3.98741441e-01  1.85883932e-07  9.23438577e-05
  1.80471756e-07 -1.12452197e-04  4.73409564e-09]
[ 7.73250573e-08  3.95263295e-01  3.00613658e-07  2.60917503e-04
  2.58977074e-07 -2.46674772e-04  6.43504086e-09]
[ 1.03568159e-07  3.89869922e-01  3.35824928e-07  4.20636249e-04
  2.11135381e-07 -2.73397422e-04  9.40147993e-09]
[ 1.26776440e-07  3.82834058e-01  3.30615876e-07  5.25756797e-04
  8.98555928e-08 -2.74729354e-04  1.34136406e-08]
[ 1.47979615e-07  3.74402135e-01  3.09992495e-07  5.66094184e-04
 -4.97697673e-08 -2.75458339e-04  1.79273695e-08]
[ 1.60207126e-07  3.64801180e-01  2.78113017e-07  5.70303655e-04
 -1.42034649e-07 -2.75919775e-04  2.20816081e-08]
[ 1.70317603e-07  3.54237383e-01  2.47874548e-07  5.72755589e-04
 -2.11262610e-07 -2.76487712e-04  2.62393506e-08]
[ 1.83854130e-07  3.42890441e-01  2.28702921e-07  5.77273664e-04
 -2.78033647e-07 -2.77369719e-04  3.06798935e-08]
[ 2.01145290e-07  3.30921596e

[ 6.56350591e-07  3.71917816e-03 -3.27339172e-06  7.66056503e-04
 -3.65528935e-06 -2.89517686e-04  3.34894984e-07]
[ 6.58693847e-07  3.71872769e-03 -3.29379508e-06  7.66326884e-04
 -3.66301375e-06 -2.89650252e-04  3.36121474e-07]
[ 6.61188968e-07  3.71827425e-03 -3.31382262e-06  7.66609530e-04
 -3.67086993e-06 -2.89795421e-04  3.37368841e-07]
[ 6.63821658e-07  3.71781816e-03 -3.33351753e-06  7.66902949e-04
 -3.67885942e-06 -2.89951633e-04  3.38636193e-07]
[ 6.66574230e-07  3.71735982e-03 -3.35293185e-06  7.67205346e-04
 -3.68698201e-06 -2.90117015e-04  3.39922339e-07]
[ 6.69427955e-07  3.71689965e-03 -3.37211989e-06  7.67514854e-04
 -3.69523616e-06 -2.90289619e-04  3.41225950e-07]
[ 6.72364672e-07  3.71643804e-03 -3.39113368e-06  7.67829686e-04
 -3.70361939e-06 -2.90467581e-04  3.42545685e-07]
[ 6.75367812e-07  3.71597535e-03 -3.41002020e-06  7.68148230e-04
 -3.71212860e-06 -2.90649224e-04  3.43880268e-07]
[ 6.78422960e-07  3.71551190e-03 -3.42881984e-06  7.68469102e-04
 -3.72076036e-0

[ 9.06050810e-07  3.68455695e-03 -4.70531161e-06  7.88141361e-04
 -4.46652231e-06 -3.01108915e-04  4.51233084e-07]
[ 9.09624912e-07  3.68409734e-03 -4.72450815e-06  7.88413635e-04
 -4.47914875e-06 -3.01240452e-04  4.52948138e-07]
[ 9.13201602e-07  3.68363781e-03 -4.74370530e-06  7.88685445e-04
 -4.49180006e-06 -3.01371558e-04  4.54665196e-07]
[ 9.16780808e-07  3.68317835e-03 -4.76290295e-06  7.88956797e-04
 -4.50447572e-06 -3.01502239e-04  4.56384207e-07]
[ 9.20362459e-07  3.68271897e-03 -4.78210101e-06  7.89227695e-04
 -4.51717520e-06 -3.01632499e-04  4.58105119e-07]
[ 9.23946486e-07  3.68225966e-03 -4.80129938e-06  7.89498142e-04
 -4.52989800e-06 -3.01762344e-04  4.59827879e-07]
[ 9.27532820e-07  3.68180043e-03 -4.82049796e-06  7.89768144e-04
 -4.54264360e-06 -3.01891779e-04  4.61552440e-07]
[ 9.31121395e-07  3.68134128e-03 -4.83969667e-06  7.90037704e-04
 -4.55541154e-06 -3.02020810e-04  4.63278752e-07]
[ 9.34712148e-07  3.68088220e-03 -4.85889542e-06  7.90306827e-04
 -4.56820132e-0

  0%|          | 1/1000 [00:04<1:08:23,  4.11s/it]

[ 1.04681444e-06  3.66668859e-03 -5.45364210e-06  7.98452236e-04
 -4.97332429e-06 -3.05962207e-04  5.19201610e-07]
[ 1.05044835e-06  3.66623196e-03 -5.47280206e-06  7.98709171e-04
 -4.98661183e-06 -3.06080204e-04  5.20964107e-07]
[ 1.05408298e-06  3.66577541e-03 -5.49195978e-06  7.98965772e-04
 -4.99990991e-06 -3.06197923e-04  5.22727184e-07]
[ 1.05771828e-06  3.66531894e-03 -5.51111521e-06  7.99222041e-04
 -5.01321831e-06 -3.06315367e-04  5.24490818e-07]
[ 1.06135424e-06  3.66486254e-03 -5.53026829e-06  7.99477982e-04
 -5.02653679e-06 -3.06432538e-04  5.26254986e-07]
[ 1.06499082e-06  3.66440622e-03 -5.54941898e-06  7.99733595e-04
 -5.03986511e-06 -3.06549440e-04  5.28019665e-07]
[ 1.06862799e-06  3.66394998e-03 -5.56856722e-06  7.99988883e-04
 -5.05320305e-06 -3.06666073e-04  5.29784832e-07]
[ 1.07226573e-06  3.66349382e-03 -5.58771297e-06  8.00243848e-04
 -5.06655039e-06 -3.06782442e-04  5.31550466e-07]
Creating window glfw
[0.  0.4 0.  0.  0.  0.  0. ]
[-5.61307626e-09  3.98740340e

[ 6.00410526e-07  3.80436573e-03 -2.90341608e-06  8.16254284e-04
 -2.92095290e-06 -3.42389395e-04  3.11831208e-08]
[ 6.02653348e-07  3.80391615e-03 -2.92279653e-06  8.16485938e-04
 -2.92423742e-06 -3.42484327e-04  3.00530345e-08]
[ 6.04961413e-07  3.80346350e-03 -2.94213596e-06  8.16730509e-04
 -2.92754746e-06 -3.42592533e-04  2.89440108e-08]
[ 6.07336525e-07  3.80300772e-03 -2.96143408e-06  8.16988171e-04
 -2.93088819e-06 -3.42714191e-04  2.78566531e-08]
[ 6.09776489e-07  3.80254900e-03 -2.98069343e-06  8.17258065e-04
 -2.93427002e-06 -3.42848402e-04  2.67906651e-08]
[ 6.12276604e-07  3.80208765e-03 -2.99991830e-06  8.17538709e-04
 -2.93770592e-06 -3.42993619e-04  2.57451634e-08]
[ 6.14830824e-07  3.80162408e-03 -3.01911377e-06  8.17828317e-04
 -2.94120944e-06 -3.43147980e-04  2.47189221e-08]
[ 6.17432609e-07  3.80115870e-03 -3.03828509e-06  8.18125029e-04
 -2.94479338e-06 -3.43309545e-04  2.37105578e-08]
[ 6.20075511e-07  3.80069188e-03 -3.05743728e-06  8.18427061e-04
 -2.94846903e-0

[ 8.28097423e-07  3.76892219e-03 -4.35146994e-06  8.37597511e-04
 -3.39498042e-06 -3.53237516e-04 -2.69658678e-08]
[ 8.31378410e-07  3.76845742e-03 -4.37035599e-06  8.37857518e-04
 -3.40333829e-06 -3.53358530e-04 -2.75426809e-08]
[ 8.34662928e-07  3.76799273e-03 -4.38923686e-06  8.38117060e-04
 -3.41172629e-06 -3.53479110e-04 -2.81168848e-08]
[ 8.37950898e-07  3.76752811e-03 -4.40811254e-06  8.38376139e-04
 -3.42014384e-06 -3.53599260e-04 -2.86885296e-08]
[ 8.41242244e-07  3.76706357e-03 -4.42698299e-06  8.38634763e-04
 -3.42859034e-06 -3.53718986e-04 -2.92576643e-08]
[ 8.44536889e-07  3.76659911e-03 -4.44584820e-06  8.38892933e-04
 -3.43706525e-06 -3.53838293e-04 -2.98243369e-08]
[ 8.47834761e-07  3.76613472e-03 -4.46470813e-06  8.39150656e-04
 -3.44556800e-06 -3.53957188e-04 -3.03885945e-08]
[ 8.51135788e-07  3.76567040e-03 -4.48356277e-06  8.39407936e-04
 -3.45409804e-06 -3.54075674e-04 -3.09504830e-08]
[ 8.54439899e-07  3.76520617e-03 -4.50241208e-06  8.39664776e-04
 -3.46265487e-0

  0%|          | 2/1000 [00:08<1:09:03,  4.15s/it]

[ 9.51333673e-07  3.75177653e-03 -5.04663185e-06  8.46938622e-04
 -3.72041842e-06 -3.57462897e-04 -4.68978435e-08]
[ 9.54704116e-07  3.75131459e-03 -5.06531217e-06  8.47183928e-04
 -3.72957888e-06 -3.57570860e-04 -4.74045200e-08]
[ 9.58076075e-07  3.75085273e-03 -5.08398662e-06  8.47428895e-04
 -3.73875411e-06 -3.57678540e-04 -4.79098840e-08]
[ 9.61449513e-07  3.75039095e-03 -5.10265517e-06  8.47673526e-04
 -3.74794383e-06 -3.57785938e-04 -4.84139587e-08]
[ 9.64824395e-07  3.74992924e-03 -5.12131781e-06  8.47917824e-04
 -3.75714776e-06 -3.57893058e-04 -4.89167668e-08]
[ 9.68200688e-07  3.74946761e-03 -5.13997454e-06  8.48161791e-04
 -3.76636564e-06 -3.57999902e-04 -4.94183306e-08]
[ 9.71578357e-07  3.74900606e-03 -5.15862533e-06  8.48405429e-04
 -3.77559720e-06 -3.58106474e-04 -4.99186721e-08]
[ 9.74957369e-07  3.74854459e-03 -5.17727019e-06  8.48648740e-04
 -3.78484220e-06 -3.58212775e-04 -5.04178129e-08]
[ 9.78337693e-07  3.74808319e-03 -5.19590909e-06  8.48891728e-04
 -3.79410037e-0

[ 1.62616028e-07  3.77471834e-03 -1.16835486e-06  1.11572180e-03
 -2.78655881e-06 -6.83152395e-04 -2.88292642e-07]
[ 1.62810981e-07  3.77428140e-03 -1.17954342e-06  1.11585951e-03
 -2.79029587e-06 -6.83152313e-04 -2.91371421e-07]
[ 1.63137162e-07  3.77384678e-03 -1.19035168e-06  1.11598607e-03
 -2.79395285e-06 -6.83140417e-04 -2.94446760e-07]
[ 1.63464778e-07  3.77341198e-03 -1.20114333e-06  1.11611252e-03
 -2.79768183e-06 -6.83128213e-04 -2.97515831e-07]
[ 1.63709723e-07  3.77297535e-03 -1.21215412e-06  1.11624604e-03
 -2.80158453e-06 -6.83123166e-04 -3.00576543e-07]
[ 1.63823683e-07  3.77253594e-03 -1.22351971e-06  1.11639075e-03
 -2.80572272e-06 -6.83129578e-04 -3.03627494e-07]
[ 1.63785115e-07  3.77209333e-03 -1.23530092e-06  1.11654853e-03
 -2.81012762e-06 -6.83149383e-04 -3.06667837e-07]
[ 1.63591502e-07  3.77164744e-03 -1.24750536e-06  1.11671961e-03
 -2.81480813e-06 -6.83182836e-04 -3.09697164e-07]
[ 1.63252966e-07  3.77119845e-03 -1.26010529e-06  1.11690317e-03
 -2.81975773e-0

[ 1.33497805e-07  3.74003798e-03 -2.12437283e-06  1.13043574e-03
 -3.35237767e-06 -6.87246851e-04 -5.03073856e-07]
[ 1.33362594e-07  3.73958166e-03 -2.13612505e-06  1.13061346e-03
 -3.36159447e-06 -6.87283618e-04 -5.05733001e-07]
[ 1.33233589e-07  3.73912541e-03 -2.14785708e-06  1.13079072e-03
 -3.37083373e-06 -6.87319941e-04 -5.08389363e-07]
[ 1.33110705e-07  3.73866924e-03 -2.15956917e-06  1.13096752e-03
 -3.38009500e-06 -6.87355827e-04 -5.11042976e-07]
[ 1.32993860e-07  3.73821313e-03 -2.17126155e-06  1.13114385e-03
 -3.38937781e-06 -6.87391283e-04 -5.13693874e-07]
[ 1.32882973e-07  3.73775710e-03 -2.18293445e-06  1.13131974e-03
 -3.39868172e-06 -6.87426313e-04 -5.16342090e-07]
[ 1.32777964e-07  3.73730113e-03 -2.19458809e-06  1.13149518e-03
 -3.40800628e-06 -6.87460926e-04 -5.18987655e-07]
[ 1.32678755e-07  3.73684524e-03 -2.20622270e-06  1.13167017e-03
 -3.41735107e-06 -6.87495125e-04 -5.21630602e-07]
[ 1.32585270e-07  3.73638942e-03 -2.21783848e-06  1.13184473e-03
 -3.42671567e-0

  0%|          | 3/1000 [00:12<1:09:42,  4.20s/it]

[ 1.33089921e-07  3.71685832e-03 -2.70201891e-06  1.13899056e-03
 -3.84338754e-06 -6.88661785e-04 -6.35708769e-07]
Creating window glfw
[0.  0.4 0.  0.  0.  0.  0. ]
[-5.68226161e-08  3.98740158e-01 -1.39383630e-07  9.81928085e-05
  9.35953908e-08 -1.23122028e-04  1.12905861e-08]
[-9.00288013e-08  3.95259477e-01 -2.38943613e-07  2.78034047e-04
  1.32306469e-07 -2.78051692e-04  1.50846572e-08]
[-9.38390813e-08  3.89863421e-01 -2.82675801e-07  4.50386758e-04
  8.91648350e-08 -3.18079989e-04  1.67638313e-08]
[-7.84915182e-08  3.82825231e-01 -2.92237160e-07  5.67951576e-04
 -1.13706627e-08 -3.19604108e-04  1.73930215e-08]
[-5.74923258e-08  3.74391468e-01 -2.92287524e-07  6.17361913e-04
 -1.27973412e-07 -3.20308958e-04  1.77694765e-08]
[-4.59193291e-08  3.64788349e-01 -3.10832729e-07  6.23087815e-04
 -2.10295476e-07 -3.20730954e-04  1.90368622e-08]
[-4.14057492e-08  3.54223243e-01 -3.43719350e-07  6.25385933e-04
 -2.69675150e-07 -3.21209471e-04  2.08967361e-08]
[-3.79086033e-08  3.42876111e

[ 4.90298792e-07  3.76259844e-03 -2.38883090e-06  8.13311041e-04
 -2.99659226e-06 -3.33163925e-04 -4.96180081e-08]
[ 4.92073274e-07  3.76214827e-03 -2.40679167e-06  8.13556867e-04
 -3.00218116e-06 -3.33274762e-04 -5.11590505e-08]
[ 4.93849149e-07  3.76169493e-03 -2.42491070e-06  8.13815882e-04
 -3.00797140e-06 -3.33399150e-04 -5.26848215e-08]
[ 4.95628101e-07  3.76123859e-03 -2.44317623e-06  8.14087227e-04
 -3.01396116e-06 -3.33536191e-04 -5.41954073e-08]
[ 4.97412226e-07  3.76077958e-03 -2.46156841e-06  8.14369414e-04
 -3.02014269e-06 -3.33684334e-04 -5.56912528e-08]
[ 4.99203729e-07  3.76031831e-03 -2.48006343e-06  8.14660648e-04
 -3.02650466e-06 -3.33841707e-04 -5.71730220e-08]
[ 5.01004706e-07  3.75985519e-03 -2.49863657e-06  8.14959055e-04
 -3.03303407e-06 -3.34006360e-04 -5.86414906e-08]
[ 5.02817018e-07  3.75939062e-03 -2.51726425e-06  8.15262840e-04
 -3.03971753e-06 -3.34176423e-04 -6.00974686e-08]
[ 5.04642218e-07  3.75892496e-03 -2.53592524e-06  8.15570380e-04
 -3.04654218e-0

  0%|          | 3/1000 [00:16<1:29:14,  5.37s/it]

[ 6.57966799e-07  3.72685179e-03 -3.77278719e-06  8.35077853e-04
 -3.67730209e-06 -3.44412966e-04 -1.46213051e-07]





SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# Have some prior distribution of magazine parameters
# Model predicts parameters distribution from actions and movement
# Each cycle
# Sample parameters and choose actions to minimize entropy of predictions

In [None]:
# Use DIAYN for trajectories

In [None]:
# Sanity check: all_geom_positions holds one flattened x/z geom trace per episode,
# all_parameters one [damping, stiffness] pair per episode, so the first dimensions
# should match.
all_geom_positions.shape, all_parameters.shape

In [None]:
# Reload a previously saved dataset from disk. np.load opens the path itself in
# binary mode (.npy is a binary format, so a text-mode handle cannot be used);
# its second positional parameter is mmap_mode, not an output array, so the
# result must be captured from the return value.
all_geom_positions = np.load("all_geom_positions.npy")
all_parameters = np.load("all_parameters.npy")

In [None]:
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

# Fraction of episodes used for training; the remainder is held out for validation.
train_split = 0.8
num_episodes = len(all_geom_positions)
train_length = int(train_split * num_episodes)
val_length = num_episodes - train_length

# Pair each flattened geom trace with its [damping, stiffness] target.
geom_tensor = torch.from_numpy(all_geom_positions)
param_tensor = torch.from_numpy(all_parameters)
all_dataset = TensorDataset(geom_tensor, param_tensor)
train_dataset, val_dataset = torch.utils.data.random_split(all_dataset, [train_length, val_length])

# Both loaders share batch size and pinned memory; only training is shuffled.
loader_options = dict(batch_size=32, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)

In [None]:
# Fully connected regressor.
# Input width 4400 = 200 steps x 11 cloth geoms x 2 coordinates (x, z) per episode.
# Output width 2 * 2: the training loop splits it into 2 means and 2 log-variances.
mlp_layers = [
    torch.nn.Linear(4400, 2048),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(2048, 1024),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 2 * 2),
]
model = torch.nn.Sequential(*mlp_layers).to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [None]:
loss_fn = torch.nn.GaussianNLLLoss()


def _batch_nll(position, target_params):
    """Gaussian negative log-likelihood of one batch under the current model.

    The model output is split along dim 1 into chunks of width 2: the first is
    used as the predicted mean, the second as the predicted log-variance
    (exponentiated before being handed to the loss).
    """
    position = position.to("cuda").float()
    target_params = target_params.to("cuda").float()
    pred_params_mu, pred_params_logvar = torch.split(model(position), 2, dim=1)
    return loss_fn(pred_params_mu, target_params, torch.exp(pred_params_logvar))


for epoch in range(100):
    # Training pass (dropout active).
    model.train()
    train_losses = []
    for position, target_params in train_loader:
        loss = _batch_nll(position, target_params)
        train_losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validation pass (dropout disabled, no gradient tracking).
    model.eval()
    with torch.no_grad():
        val_losses = [
            _batch_nll(position, target_params).item()
            for position, target_params in val_loader
        ]

    print("Epoch: {}, Train Loss: {}, Val Loss: {}".format(epoch, np.mean(train_losses), np.mean(val_losses)))

In [None]:
# Collect the predicted means and the true parameters over the validation set.
predictions = []
ground_truth = []
# Put the model in eval mode explicitly so dropout is disabled even if the
# training cell was interrupted while the model was still in train mode.
model.eval()
with torch.no_grad():
    for position, target_params in val_loader:
        # First chunk of the output is the predicted mean; the log-variance is unused here.
        prediction, _logvar = torch.split(model(position.to("cuda").float()), 2, dim=1)
        predictions.append(prediction.cpu().numpy())
        ground_truth.append(target_params.cpu().numpy())
predictions = np.concatenate(predictions, axis=0)
ground_truth = np.concatenate(ground_truth, axis=0)

In [None]:
from matplotlib import pyplot as plt
# Column 0 of the targets is damping (targets are stored as [damping, stiffness]).
fig, ax = plt.subplots()
ax.scatter(ground_truth[:, 0], predictions[:, 0])
ax.set(
    title="Damping: predicted vs. ground truth (validation set)",
    xlabel="True damping",
    ylabel="Predicted damping (mean)",
)
plt.show()

In [None]:
from matplotlib import pyplot as plt
# Column 1 of the targets is stiffness (targets are stored as [damping, stiffness]).
fig, ax = plt.subplots()
ax.scatter(ground_truth[:, 1], predictions[:, 1])
ax.set(
    title="Stiffness: predicted vs. ground truth (validation set)",
    xlabel="True stiffness",
    ylabel="Predicted stiffness (mean)",
)
plt.show()