In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# shared across tasks
from plb.optimizer.optim import Adam
from plb.engine.taichi_env import TaichiEnv
from plb.config.default_config import get_cfg_defaults, CN
# from toolbox.control_soft import setup_finger

import os
import cv2
import numpy as np
import taichi as ti
ti.init(arch=ti.gpu)
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [None]:
# Build the simulation environment from the tool-manipulation YAML config.
from yacs.config import CfgNode  # NOTE(review): not used in the visible cells
from plb.config import load
cfg = load("../envs/tool.yml")  # relative path: resolved against the notebook's working directory
print(cfg)
env = TaichiEnv(cfg, nn=False, loss=False)
env.initialize()
# Snapshot of the freshly initialised scene; later cells reset with env.set_state(**state).
state = env.get_state()

In [None]:
# Reset to the initial scene, place the camera, and render a preview with matplotlib.
env.set_state(**state)
taichi_env = env  # alias; only used below to read action_dim
print(env.renderer.camera_pos)
# Camera position components and rotation; these are the same values update_camera() applies.
env.renderer.camera_pos[0] = 0.5
env.renderer.camera_pos[1] = 2.5
env.renderer.camera_pos[2] = 2.2
env.renderer.camera_rot = (0.8, 0.0)

# Write 0.05 into size index 2 of both primitives (the component choose_tool() sets as "width").
env.primitives.primitives[0].size[None][2] = 0.05
env.primitives.primitives[1].size[None][2] = 0.05
env.render('plt')

# Length of the action vector expected by env.step().
action_dim = taichi_env.primitives.action_dim
print(action_dim)

In [None]:
# Project root, assumed to be two directories above the notebook's working directory.
# The dataset is written below it (see rollout_dir in the generation cell).
cwd = os.getcwd()
root_dir = f"{cwd}/../.."
print(root_dir)

In [None]:
def set_parameters(env: TaichiEnv, yield_stress, E, nu):
    """Fill the simulator's material fields from engineering constants.

    ``yield_stress`` is written unchanged. Young's modulus ``E`` and Poisson's
    ratio ``nu`` are converted to the Lame parameters before being written:

        mu  = E / (2 (1 + nu))
        lam = E nu / ((1 + nu) (1 - 2 nu))
    """
    sim = env.simulator
    sim.yield_stress.fill(yield_stress)
    shear_modulus = E / (2 * (1 + nu))
    lame_lambda = E * nu / ((1 + nu) * (1 - 2 * nu))
    sim.mu.fill(shear_modulus)
    sim.lam.fill(lame_lambda)
def update_camera(env):
    """Point the live renderer at view 1 and register four camera views in ``render_cfg``.

    ``render_cfg`` is defrosted so that the ``camera_pos_k`` / ``camera_rot_k``
    entries (k = 1..4) can be written. It is left defrosted, because no
    ``freeze()`` call is made.
    """
    # (position, rotation) for cameras 1..4, in the order they are registered.
    views = (
        ((0.5, 2.5, 2.2), (0.8, 0.)),
        ((2.4, 2.5, 0.2), (0.8, 1.8)),
        ((-1.9, 2.5, 0.2), (0.8, -1.8)),
        ((0.5, 2.5, -1.8), (0.8, 3.14)),
    )
    live_pos, live_rot = views[0]
    for axis, coord in enumerate(live_pos):
        env.renderer.camera_pos[axis] = coord
    env.renderer.camera_rot = live_rot

    env.render_cfg.defrost()
    for cam_id, (pos, rot) in enumerate(views, start=1):
        setattr(env.render_cfg, f"camera_pos_{cam_id}", pos)
        setattr(env.render_cfg, f"camera_rot_{cam_id}", rot)
    
def update_primitive(env, prim1_list, prim2_list):
    """Write ``prim1_list`` / ``prim2_list`` as the frame-0 state of primitives 0 and 1."""
    for prim_idx, pose in enumerate((prim1_list, prim2_list)):
        env.primitives.primitives[prim_idx].set_state(0, pose)
    
def save_files(env, rollout_dir, i):
    """Reset the output directory of rollout ``i`` and write its camera parameters.

    Creates ``{rollout_dir}/{i:03d}`` if needed and deletes the regular files
    (and symlinks) already in it, i.e. leftovers from an earlier run with the
    same index. Subdirectories are left untouched. Then writes
    ``cam_params.npy`` holding a dict with keys ``cam1_ext`` .. ``cam4_ext``
    (``env.renderer.get_ext`` applied to ``render_cfg.camera_rot_k`` /
    ``camera_pos_k``) and ``intrinsic`` (``env.renderer.get_int()``).

    ``np.save`` stores the dict as a 0-d object array, so read it back with
    ``np.load(path, allow_pickle=True).item()``.
    """
    episode_dir = os.path.join(rollout_dir, f"{i:03d}")
    os.makedirs(episode_dir, exist_ok=True)
    # os.scandir instead of glob: no dependency on a module imported in a later
    # cell, and glob metacharacters ([, ], *) in the path are taken literally.
    with os.scandir(episode_dir) as entries:
        for entry in entries:
            if entry.is_file() or entry.is_symlink():
                os.remove(entry.path)

    # Compute everything before opening the file so a renderer error cannot
    # leave an empty cam_params.npy behind.
    cfg = env.render_cfg
    cam_params = {
        f"cam{k}_ext": env.renderer.get_ext(
            getattr(cfg, f"camera_rot_{k}"), np.array(getattr(cfg, f"camera_pos_{k}"))
        )
        for k in range(1, 5)
    }
    cam_params["intrinsic"] = env.renderer.get_int()
    with open(os.path.join(episode_dir, "cam_params.npy"), "wb") as fh:
        np.save(fh, cam_params)
        
from transforms3d.quaternions import axangle2quat
from transforms3d.quaternions import mat2quat
from transforms3d.axangles import axangle2mat
def random_rotate(mid_point, gripper1_pos, gripper2_pos):
    """Rotate both finger positions by one random rotation about ``mid_point``.

    The rotation is ``tilt @ yaw``:
      * yaw  -- angle ~ U(0, pi) about the axis (0, 1, 0);
      * tilt -- angle ~ U(-pi/2, pi/2) about the yawed (1, 0, 0) axis.
    The yaw angle is drawn from ``np.random`` before the tilt angle.

    Only ``mid_point[:3]`` is used. Returns
    ``(rotated gripper1_pos, rotated gripper2_pos, mat2quat(rotation))``.
    """
    centre = mid_point[:3]

    yaw = axangle2mat(np.array([0, 1, 0]), np.random.uniform(0, np.pi), is_normalized=True)
    tilt_axis = yaw @ np.array([1, 0, 0])
    tilt_axis = tilt_axis / np.linalg.norm(tilt_axis)
    tilt = axangle2mat(tilt_axis, np.random.uniform(-np.pi/2, np.pi/2), is_normalized=True)
    rotation = tilt @ yaw

    def _about_centre(point):
        # Algebraically centre + rotation @ (point - centre).
        offset = point - centre
        return point - (offset - rotation @ offset)

    return _about_centre(gripper1_pos), _about_centre(gripper2_pos), mat2quat(rotation)

def random_pose(angle=False, random_rotate_flag=False):
    """Sample start poses (3 position entries + 4 quaternion entries) for the two fingers.

    The fingers sit 0.25 on either side of a centre point. The centre is
    (0.5, 0.14, 0.5) plus per-axis N(0, 0.03) noise clipped to [-0.1, 0.1].

    Args:
        angle: if True, the finger pair is turned by ``rot_noise ~ U(0, pi)``
            within the plane of position components 0 and 2; otherwise
            ``rot_noise`` is 0.
        random_rotate_flag: if True, both positions are additionally passed
            through ``random_rotate`` and its quaternion is used; otherwise the
            quaternion is [1, 0, 0, 0].

    Returns:
        ``(pose1, pose2, rot_noise)`` where each pose is position followed by
        the shared quaternion.
    """
    radius = 0.25
    base_centre = np.array([0.5, 0.14, 0.5])
    # Three scalar draws, in order, so the np.random stream matches one randn() per axis.
    jitter = np.array([np.random.randn() for _ in range(3)]) * 0.03
    centre = base_centre + np.clip(jitter, -0.1, 0.1)

    rot_noise = np.random.uniform(0, np.pi) if angle else 0

    # Finger 1 is offset from the centre by this vector; finger 2 by its negation.
    offset = np.array([-radius * np.cos(rot_noise), 0.0, radius * np.sin(rot_noise)])
    pos1 = centre + offset
    pos2 = centre - offset

    if random_rotate_flag:
        pos1, pos2, quat = random_rotate(centre, pos1, pos2)
    else:
        quat = np.array([1, 0, 0, 0])
    return np.concatenate([pos1, quat]), np.concatenate([pos2, quat]), rot_noise

def get_obs(env, n_particles, t=0):
    """Return an evenly strided subsample of particle positions and velocities.

    Args:
        env: environment whose ``simulator`` provides ``get_x(t)`` and ``get_v(t)``.
        n_particles: target sample size. Every ``len(x) // n_particles``-th
            particle is kept, so the result has at least ``n_particles`` rows
            whenever that many particles exist (callers slice ``[:n]`` for an
            exact count). With fewer particles, all of them are returned.
        t: simulator step to read.

    Returns:
        ``(x_sub, v_sub)``: positions and velocities sliced with the same stride.

    Raises:
        ValueError: if ``n_particles`` is not positive.
    """
    if n_particles <= 0:
        raise ValueError(f"n_particles must be positive, got {n_particles}")
    x = env.simulator.get_x(t)
    v = env.simulator.get_v(t)
    # With fewer particles than requested, len(x) // n_particles is 0 and a
    # zero slice step raises ValueError; clamp so every particle is returned.
    step_size = max(1, len(x) // n_particles)
    return x[::step_size], v[::step_size]

def choose_tool(env, width):
    """Write ``width`` into size component 2 of primitives 0 and 1."""
    for prim_idx in (0, 1):
        env.primitives.primitives[prim_idx].size[None][2] = width

In [None]:
import glob
import subprocess
from datetime import datetime

# ---------------------------------------------------------------------------
# Dataset-generation settings
# ---------------------------------------------------------------------------
task_name = 'gripper_tool'
suffix = 'mix'
n_rollouts = 50          # episodes written under rollout_dir
n_grips = 3              # squeeze-and-release cycles per episode
rate = 0.01              # per-step displacement of each finger
len_per_grip = 20        # squeeze phase length (zero-padded to this many steps)
len_per_grip_back = 10   # release phase length
action_scale = 0.02      # displacements are divided by this to form actions (presumably the env's step scaling — TODO confirm)
n_obs_particles = 300    # particles stored per frame in *_gtp.npy
n_cams = 4               # camera views registered by update_camera()
zero_pad = np.array([0, 0, 0])

time_now = datetime.now().strftime("%d-%b-%Y-%H:%M:%S.%f")
rollout_dir = f"{root_dir}/dataset/{task_name}_{suffix}_{time_now}"


def grip_actions(x_direction, delta_g, speed, n_close, n_open, scale):
    """Build the action sequence of one squeeze-and-release grip.

    Each action is 12 values: [finger-1 move, 0, 0, 0, finger-2 move, 0, 0, 0],
    with moves divided by ``scale`` (presumably primitive 0 / primitive 1 in
    env.step's layout — TODO confirm). During the squeeze, finger 1 moves
    ``speed`` along ``x_direction`` and finger 2 the same amount against it,
    so the gap shrinks by ``2 * speed`` per step. The squeeze runs until
    ``delta_g`` has been consumed or ``n_close`` steps are used, and is then
    zero-padded to exactly ``n_close`` steps. Finally, ``n_open`` steps move
    the fingers apart again.

    Returns an array of shape (n_close + n_open, 12).
    """
    step = speed * x_direction
    squeeze = np.concatenate([step / scale, zero_pad, -step / scale, zero_pad])
    release = np.concatenate([-step / scale, zero_pad, step / scale, zero_pad])

    actions = []
    remaining = delta_g
    while remaining > 0 and len(actions) < n_close:
        actions.append(squeeze)
        remaining -= 2 * speed
    actions.extend([np.zeros_like(squeeze)] * (n_close - len(actions)))
    actions.extend([release] * n_open)
    return np.stack(actions)


def record_step(env, episode_dir, frame_idx, tool_size):
    """Save the current frame of ``env`` into ``episode_dir``.

    Files written (NNN = ``frame_idx`` zero-padded to 3 digits):
      NNN_rgb_{cam}.png   one image per camera (cam = 0..n_cams-1)
      NNN_depth_prim.npy  1-D object array: the entries of the rendered depth
                          output, the states of primitives 0 and 1, tool_size
      NNN_gtp.npy         positions of the subsampled particles
    """
    positions = get_obs(env, n_obs_particles)[0][:n_obs_particles]
    primitive_state = [env.primitives.primitives[0].get_state(0),
                       env.primitives.primitives[1].get_state(0)]
    img = env.render_multi(mode='rgb_array', spp=3)
    rgb, depth = img[0], img[1]

    for cam in range(n_cams):
        # Reverse the channel axis (RGB -> BGR) because cv2.imwrite expects BGR.
        cv2.imwrite(os.path.join(episode_dir, f"{frame_idx:03d}_rgb_{cam}.png"), rgb[cam][..., ::-1])

    # The payload mixes depth data, pose vectors and a scalar, so it is ragged.
    # np.save on the plain list raises on NumPy >= 1.24. Build the 1-D object
    # array element-wise instead; load with np.load(..., allow_pickle=True).
    payload = list(depth) + primitive_state + [tool_size]
    packed = np.empty(len(payload), dtype=object)
    for slot, item in enumerate(payload):
        packed[slot] = item
    np.save(os.path.join(episode_dir, f"{frame_idx:03d}_depth_prim.npy"), packed)
    np.save(os.path.join(episode_dir, f"{frame_idx:03d}_gtp.npy"), positions)


def encode_video(episode_dir, episode_idx):
    """Best-effort: encode the camera-0 frames of an episode into vidNNN.mp4 with ffmpeg."""
    cmd = [
        "ffmpeg", "-y",
        "-i", os.path.join(episode_dir, "%03d_rgb_0.png"),
        "-c:v", "libx264", "-vf", "fps=25", "-pix_fmt", "yuv420p",
        os.path.join(episode_dir, f"vid{episode_idx:03d}.mp4"),
    ]
    # Argument list (no shell), so paths with spaces or metacharacters are safe.
    try:
        result = subprocess.run(cmd, check=False)
    except FileNotFoundError:
        print(f"ffmpeg not found; skipping video for {episode_dir}")
        return
    if result.returncode != 0:
        print(f"ffmpeg exited with code {result.returncode} for {episode_dir}")


# ---------------------------------------------------------------------------
# Generate episodes
# ---------------------------------------------------------------------------
for i in range(n_rollouts):
    print(f"+++++++++++++++++++{i}+++++++++++++++++++++")
    episode_dir = os.path.join(rollout_dir, f"{i:03d}")

    env.set_state(**state)
    update_camera(env)
    set_parameters(env, yield_stress=200, E=5e3, nu=0.2)
    update_primitive(env, [0.3, 0.4, 0.5, 1, 0, 0, 0], [0.7, 0.4, 0.5, 1, 0, 0, 0])
    # Same friction on both fingers.
    env.primitives.primitives[0].friction[None] = 100.
    env.primitives.primitives[1].friction[None] = 100.
    save_files(env, rollout_dir, i)
    os.makedirs(episode_dir, exist_ok=True)

    true_idx = 0  # frame counter across all grips of this episode
    for k in range(n_grips):
        print(k)
        prim1, prim2, _ = random_pose(angle=False, random_rotate_flag=True)
        update_primitive(env, prim1, prim2)

        # Total amount by which this grip closes the gap between the fingers.
        delta_g = np.random.uniform(0.27, 0.35)
        mid_point = (prim1[:3] + prim2[:3]) / 2
        x_direction = mid_point - prim1[:3]
        x_direction = x_direction / np.linalg.norm(x_direction)
        actions = grip_actions(x_direction, delta_g, rate, len_per_grip, len_per_grip_back, action_scale)

        tool_size = np.random.choice([0.05, 0.1])
        choose_tool(env, width=tool_size)
        for act in tqdm(actions, total=actions.shape[0]):
            env.step(act)
            record_step(env, episode_dir, true_idx, tool_size)
            true_idx += 1

        print(true_idx)

    encode_video(episode_dir, i)