Merge pull request #105 from caio-freitas/104-normalize-data
Normalize Data and use 6D Rotation
caio-freitas authored Apr 1, 2024
2 parents 38a3a43 + 2c39ea3 commit 3003b1b
Showing 17 changed files with 207 additions and 129 deletions.
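The headline changes are (1) converting quaternion orientations to the continuous 6D rotation representation via diffusion_policy's RotationTransformer(from_rep="quaternion", to_rep="rotation_6d"), and (2) min-max normalizing graph node features and targets before diffusion. For reference, here is a minimal standalone sketch of the quaternion-to-6D conversion; it assumes a (w, x, y, z) quaternion convention and keeps the first two rows of the rotation matrix, whereas the actual transformer is backed by pytorch3d, so conventions may differ in detail.

```python
# Minimal sketch of the quaternion -> 6D rotation conversion this commit adopts.
# Standalone re-implementation for illustration only; the repository uses
# diffusion_policy's RotationTransformer (pytorch3d-backed) instead.
import torch

def quaternion_to_rotation_6d(quat: torch.Tensor) -> torch.Tensor:
    """Convert (..., 4) unit quaternions (w, x, y, z) to the (..., 6) representation."""
    quat = quat / quat.norm(dim=-1, keepdim=True)
    w, x, y, z = quat.unbind(dim=-1)
    # First two rows of the rotation matrix corresponding to the quaternion.
    row0 = torch.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)], dim=-1)
    row1 = torch.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)], dim=-1)
    # The 6D representation keeps only these two rows; the third is recoverable via a
    # cross product, which makes the mapping continuous (Zhou et al., 2019).
    return torch.cat([row0, row1], dim=-1)

identity = torch.tensor([1.0, 0.0, 0.0, 0.0])
print(quaternion_to_rotation_6d(identity))  # tensor([1., 0., 0., 0., 1., 0.])
```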
2 changes: 2 additions & 0 deletions conda_environment.yaml
@@ -9,6 +9,7 @@ dependencies:
- pip
- hydra-core
- torchvision==0.15.1
- cudatoolkit=10.1
- pip:
- pytorchvideo
- gymnasium==0.28.1
@@ -20,6 +21,7 @@ dependencies:
- tqdm
- pybullet
- torch==2.0.0
- pytorch3d==0.3.0
- torch-geometric
- -e git+https://github.com/ARISE-Initiative/robosuite.git@v1.4.1#egg=robosuite
- mujoco-py
6 changes: 3 additions & 3 deletions imitation/config/eval.yaml
@@ -7,13 +7,13 @@ render: True
output_video: True

num_episodes: 20
max_steps: 500
max_steps: 1000
output_dir: ./outputs


pred_horizon: 16
obs_horizon: 4
action_horizon: 8
obs_horizon: 2
action_horizon: 16


env_runner: ${task.env_runner}
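For context on the horizon changes above: these settings follow the usual receding-horizon pattern of diffusion-policy-style controllers, where the policy conditions on the last obs_horizon observations, predicts pred_horizon actions, and executes the first action_horizon of them before re-planning. A rough, illustrative rollout loop under that assumption (policy and env below are placeholders, not this repository's classes):

```python
# Illustrative receding-horizon rollout with pred_horizon=16, obs_horizon=2,
# action_horizon=16 as in this eval config. `policy` and `env` are placeholders.
from collections import deque

def rollout(env, policy, max_steps=1000, obs_horizon=2, action_horizon=16):
    obs = env.reset()
    obs_buffer = deque([obs] * obs_horizon, maxlen=obs_horizon)
    steps = 0
    while steps < max_steps:
        # Predict a full pred_horizon-long action sequence from the recent observations...
        action_sequence = policy.predict(list(obs_buffer))
        # ...but execute only the first `action_horizon` actions before re-planning.
        for action in action_sequence[:action_horizon]:
            obs, reward, done, info = env.step(action)  # placeholder env API
            obs_buffer.append(obs)
            steps += 1
            if done or steps >= max_steps:
                return
```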
3 changes: 2 additions & 1 deletion imitation/config/policy/argd_policy.yaml
@@ -3,7 +3,7 @@ dataset: ${task.dataset}
node_feature_dim: 2
num_edge_types: 2 # robot joints, object-robot
lr: 0.00005
ckpt_path: ./weights/${task.task_name}_${task.dataset_type}_en_graph_diffusion_policy.pt
ckpt_path: ./weights/${task.task_name}_${task.dataset_type}_${task.control_mode}_en_argd_policy.pt
device: cuda
denoising_network:
_target_: imitation.model.graph_diffusion.ConditionalARGDenoising
@@ -14,3 +14,4 @@ denoising_network:
num_edge_types: ${policy.num_edge_types}
num_layers: 3
hidden_dim: 128
use_normalization: True
9 changes: 5 additions & 4 deletions imitation/config/policy/graph_ddpm_policy.yaml
@@ -1,12 +1,12 @@
_target_: imitation.policy.graph_ddpm_policy.GraphConditionalDDPMPolicy
obs_dim: 7
obs_dim: 9
action_dim: 9
node_feature_dim: 2
node_feature_dim: 1 # from [joint_val, node_flag]
num_edge_types: 2 # robot joints, object-robot
pred_horizon: ${pred_horizon}
obs_horizon: ${obs_horizon}
action_horizon: ${action_horizon}
num_diffusion_iters: 200
num_diffusion_iters: 100
dataset: ${task.dataset}
denoising_network:
_target_: imitation.model.graph_diffusion.ConditionalGraphNoisePred
@@ -18,5 +18,6 @@ denoising_network:
num_layers: 3
hidden_dim: 128
ckpt_path: ./weights/diffusion_graph_policy_${task.task_name}_${task.dataset_type}_${task.control_mode}_${policy.num_diffusion_iters}iters.pt
lr: 5e-4
lr: 1e-4
batch_size: 32
use_normalization: True
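The new use_normalization flag presumably tells the policy to pass node features and targets through the dataset's normalize_data / unnormalize_data helpers (added below in robomimic_graph_dataset.py) before denoising and after sampling. A hedged sketch of how a policy might wire that up; only normalize_data / unnormalize_data come from this diff, and the policy-side method names are illustrative:

```python
# Hedged sketch: how a policy could honor `use_normalization` by delegating to the
# dataset helpers added in this commit. Method names on the policy side are made up.
import torch

class PolicyNormalizationMixin:
    def __init__(self, dataset, use_normalization: bool = True):
        self.dataset = dataset
        self.use_normalization = use_normalization

    def prepare_targets(self, y: torch.Tensor) -> torch.Tensor:
        # scale targets to [-1, 1] before adding diffusion noise
        if self.use_normalization:
            return self.dataset.normalize_data(y, stats_key="y", batch_size=y.shape[0])
        return y

    def postprocess_predictions(self, y_hat: torch.Tensor) -> torch.Tensor:
        # map the denoised output back to the original units
        if self.use_normalization:
            return self.dataset.unnormalize_data(y_hat, stats_key="y", batch_size=y_hat.shape[0])
        return y_hat
```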
2 changes: 2 additions & 0 deletions imitation/config/task/lift_graph.yaml
@@ -40,6 +40,7 @@ env_runner:
controller_config:
interpolation: "linear"
ramp_ratio: 0.2
base_link_shift: [-0.56, 0, 0.912]
dataset:
_target_: imitation.dataset.robomimic_graph_dataset.RobomimicGraphDataset
dataset_path: ${task.dataset_path}
@@ -49,3 +50,4 @@ dataset:
object_state_sizes: *object_state_sizes
object_state_keys: *object_state_keys
control_mode: ${task.control_mode}
base_link_shift: [-0.56, 0, 0.912]
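base_link_shift appears to be the translation from the robot's base frame to the world frame, so that joint positions computed from joint angles line up with object poses; the dataset code below applies it as node_pos[:, :3] += torch.tensor(self.BASE_LINK_SHIFT). A tiny illustration (joint values below are made up; only the [-0.56, 0, 0.912] offset comes from this config):

```python
# Tiny illustration of base_link_shift: joint positions in the robot base frame are
# translated into the world frame before being combined with object poses.
import torch

base_link_shift = torch.tensor([-0.56, 0.0, 0.912])
# (num_joints, 3) joint positions in the robot base frame (made-up values)
joint_positions_base = torch.tensor([[0.0, 0.0, 0.333],
                                     [0.0, 0.0, 0.649]])
joint_positions_world = joint_positions_base + base_link_shift
print(joint_positions_world)
# tensor([[-0.5600,  0.0000,  1.2450],
#         [-0.5600,  0.0000,  1.5610]])
```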
4 changes: 2 additions & 2 deletions imitation/config/task/lift_lowdim.yaml
@@ -5,7 +5,7 @@ dataset_path: &dataset_path ./data/${task.task_name}/${task.dataset_type}/low_di

max_steps: 500

obs_dim: 33
obs_dim: 35 # 33 + 2 (due to quaternion -> 6D rotation)
action_dim: 9


@@ -23,7 +23,7 @@ env_runner:
output_video: ${output_video}
env:
_target_: imitation.env.robomimic_lowdim_wrapper.RobomimicLowdimWrapper
max_steps: ${task.max_steps}
max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'}
task: "Lift"
robots: ["Panda"]
output_video: ${output_video}
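The new max_steps values use OmegaConf's resolver syntax to pick a step budget per dataset type. For ${eval:'...'} to resolve, an eval resolver must be registered before the config is composed; that registration isn't part of this diff, so the snippet below is an assumed setup showing how the interpolation behaves rather than the project's actual bootstrap code:

```python
# Assumed setup for the ${eval:'...'} interpolation used in these task configs.
# OmegaConf only resolves resolver names that have been registered; a common
# pattern is to map "eval" to Python's eval at program startup.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval, replace=True)

cfg = OmegaConf.create("""
dataset_type: mh
max_steps: ${eval:'500 if "${dataset_type}" == "mh" else 400'}
""")
print(cfg.max_steps)  # 500 for "mh", 400 for any other dataset_type
```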
4 changes: 2 additions & 2 deletions imitation/config/task/square_lowdim.yaml
@@ -5,7 +5,7 @@ dataset_path: &dataset_path ./data/${task.task_name}/${task.dataset_type}/low_di

max_steps: 1000

obs_dim: 37
obs_dim: 39 # 37 + 2 (due to quaternion -> 6D rotation)
action_dim: 9


@@ -21,7 +21,7 @@ env_runner:
output_video: ${output_video}
env:
_target_: imitation.env.robomimic_lowdim_wrapper.RobomimicLowdimWrapper
max_steps: ${task.max_steps}
max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'}
task: "NutAssemblySquare"
robots: ["Panda"]
output_video: ${output_video}
4 changes: 2 additions & 2 deletions imitation/config/task/transport_lowdim.yaml
@@ -5,7 +5,7 @@ dataset_path: &dataset_path ./data/${task.task_name}/${task.dataset_type}/low_di

max_steps: 1000

obs_dim: 87
obs_dim: 91 # 87 + 2 (due to quaternion -> 6D rotation) * 2 (robots)
action_dim: 18


@@ -35,7 +35,7 @@ env_runner:
output_video: ${output_video}
env:
_target_: imitation.env.robomimic_lowdim_wrapper.RobomimicLowdimWrapper
max_steps: ${task.max_steps}
max_steps: ${eval:'1100 if "${task.dataset_type}" == "mh" else 700'}
task: "TwoArmTransport"
robots: *robots
output_video: ${output_video}
9 changes: 5 additions & 4 deletions imitation/config/train.yaml
@@ -9,7 +9,7 @@ render: False
output_video: True

pred_horizon: 16
obs_horizon: 4
obs_horizon: 2
action_horizon: 8
# Training parameters
num_epochs: 2000
@@ -23,17 +23,18 @@ agent:

# Disable evaluation
# eval_params: "disabled"
max_steps: 500
max_steps: 1500

# Evaluation during training
eval_params:
eval_every: 20 # evaluate every 50 epochs
eval_every: 50 # evaluate every 50 epochs
val_every: 1
task: ${task}
policy: ${policy}
render: False
output_video: True

num_episodes: 2
num_episodes: 50
max_steps: ${max_steps}
output_dir: ./outputs

71 changes: 62 additions & 9 deletions imitation/dataset/robomimic_graph_dataset.py
@@ -7,10 +7,11 @@
import numpy as np
import torch
from tqdm import tqdm

from typing import List, Dict

from functools import lru_cache

from diffusion_policy.model.common.rotation_transformer import RotationTransformer

from imitation.utils.generic import calculate_panda_joints_positions

log = logging.getLogger(__name__)
@@ -24,7 +25,8 @@ def __init__(self,
pred_horizon=1,
obs_horizon=1,
node_feature_dim = 2, # joint value and node type flag
control_mode="JOINT_VELOCITY"):
control_mode="JOINT_VELOCITY",
base_link_shift=[0.0, 0.0, 0.0]):
self.control_mode : str = control_mode
self.node_feature_dim : int = node_feature_dim
self.action_keys : List = action_keys
@@ -35,6 +37,7 @@ def __init__(self,
self.num_objects : int = len(object_state_keys)
self._processed_dir : str = dataset_path.replace(".hdf5", f"_{self.control_mode}_processed_{self.obs_horizon}_{self.pred_horizon}")

self.BASE_LINK_SHIFT : List = base_link_shift
self.ROBOT_NODE_TYPE : int = 1
self.OBJECT_NODE_TYPE : int = -1

@@ -47,12 +50,22 @@ def __init__(self,
self.dataset_keys.remove("mask")
except:
pass
self.rotation_transformer = RotationTransformer(
from_rep="quaternion",
to_rep="rotation_6d"
)

super().__init__(root=self._processed_dir, transform=None, pre_transform=None, pre_filter=None, log=True)
self.stats = {}
self.stats["y"] = self.get_data_stats("y")
self.stats["x"] = self.get_data_stats("x")


self.constant_stats = {
"y": torch.tensor([False, False, False, True, True, True, True, True, True]), # mask rotations for robot and object nodes
"x": torch.tensor([False, True]) # node type flag is constant
}


@property
def processed_file_names(self):
'''
@@ -70,20 +83,26 @@ def _get_object_feats(self, num_objects, node_feature_dim, OBJECT_NODE_TYPE, T):
return obj_state_tensor

def _get_object_pos(self, data, t):
obj_state_tensor = torch.zeros((self.num_objects, 7)) # 3 for position, 4 for quaternion
obj_state_tensor = torch.zeros((self.num_objects, 9)) # 3 for position, 6 for 6D rotation

for object, object_state_items in enumerate(self.object_state_keys.values()):
i = 0
for object_state in object_state_items:
obj_state_tensor[object,i:i + self.object_state_sizes[object_state]] = torch.from_numpy(data["object"][t][i:i + self.object_state_sizes[object_state]])
if "quat" in object_state:
assert self.object_state_sizes[object_state] == 4
rot = self.rotation_transformer.forward(torch.tensor(data["object"][t][i:i + self.object_state_sizes[object_state]]))
obj_state_tensor[object,i:i + 6] = rot
else:
obj_state_tensor[object,i:i + self.object_state_sizes[object_state]] = torch.from_numpy(data["object"][t][i:i + self.object_state_sizes[object_state]])
i += self.object_state_sizes[object_state]

return obj_state_tensor

def _get_node_pos(self, data, t):
node_pos = []
node_pos.append(calculate_panda_joints_positions([*data["robot0_joint_pos"][t], *data["robot0_gripper_qpos"][t]]))
node_pos = torch.cat(node_pos)
node_pos = calculate_panda_joints_positions([*data["robot0_joint_pos"][t], *data["robot0_gripper_qpos"][t]])
node_pos[:,:3] += torch.tensor(self.BASE_LINK_SHIFT)
# use rotation transformer to convert quaternion to 6d rotation
node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1)
obj_pos_tensor = self._get_object_pos(data, t)
node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0)
return node_pos
@@ -225,6 +244,40 @@ def get_data_stats(self, key):
"min": torch.min(data, dim=1).values,
"max": torch.max(data, dim=1).values
}

def normalize_data(self, data, stats_key, batch_size=1):
# avoid division by zero by skipping normalization
with torch.no_grad():
# duplicate stats for each batch
data = data.clone().to(dtype=torch.float64)
stats = self.stats[stats_key].copy()
stats["min"] = stats["min"].repeat(batch_size, 1)
stats["max"] = stats["max"].repeat(batch_size, 1)
to_normalize = ~self.constant_stats[stats_key]
constant_stats = stats["max"] == stats["min"]
stats["min"][constant_stats] = -1
stats["max"][constant_stats] = 1
for t in range(data.shape[1]):
data[:,t,to_normalize] = (data[:,t,to_normalize] - stats['min'][:,to_normalize]) / (stats['max'][:,to_normalize] - stats['min'][:,to_normalize])
data[:,t,to_normalize] = data[:,t,to_normalize] * 2 - 1
return data

def unnormalize_data(self, data, stats_key, batch_size=1):
# avoid division by zero by skipping normalization
with torch.no_grad():
stats = self.stats[stats_key].copy()
# duplicate stats for each batch
stats["min"] = stats["min"].repeat(batch_size, 1)
stats["max"] = stats["max"].repeat(batch_size, 1)
data = data.clone().to(dtype=torch.float64)
to_normalize = ~self.constant_stats[stats_key]
constant_stats = stats["max"] == stats["min"]
stats["min"][constant_stats] = -1
stats["max"][constant_stats] = 1
for t in range(data.shape[1]):
data[:,t,to_normalize] = (data[:,t,to_normalize] + 1) / 2
data[:,t,to_normalize] = data[:,t,to_normalize] * (stats['max'][:,to_normalize] - stats['min'][:,to_normalize]) + stats['min'][:,to_normalize]
return data

class MultiRobotGraphDataset(RobomimicGraphDataset):
'''
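The normalize_data / unnormalize_data helpers added above min-max scale selected feature dimensions to [-1, 1], leave dimensions flagged in constant_stats untouched (the node-type flag in x, the rotation components in y), and avoid division by zero by treating features that are constant across the dataset as spanning [-1, 1]. A standalone sketch of the same scheme, not the dataset code itself (shapes and stats below are made up):

```python
# Standalone illustration of the normalization scheme: scale non-constant feature
# dimensions to [-1, 1] and skip dimensions flagged as constant, mirroring
# normalize_data/unnormalize_data above. Values and stats are made up.
import torch

def normalize(data, stats_min, stats_max, constant_mask):
    # data: (..., F); stats_min/stats_max: (F,); constant_mask: (F,) bool
    data = data.clone().to(torch.float64)
    to_norm = ~constant_mask
    degenerate = stats_max == stats_min
    lo = torch.where(degenerate, torch.full_like(stats_min, -1.0), stats_min)
    hi = torch.where(degenerate, torch.full_like(stats_max, 1.0), stats_max)
    scaled = (data[..., to_norm] - lo[to_norm]) / (hi[to_norm] - lo[to_norm])
    data[..., to_norm] = scaled * 2.0 - 1.0   # map [min, max] -> [-1, 1]
    return data

def unnormalize(data, stats_min, stats_max, constant_mask):
    data = data.clone().to(torch.float64)
    to_norm = ~constant_mask
    degenerate = stats_max == stats_min
    lo = torch.where(degenerate, torch.full_like(stats_min, -1.0), stats_min)
    hi = torch.where(degenerate, torch.full_like(stats_max, 1.0), stats_max)
    halved = (data[..., to_norm] + 1.0) / 2.0  # map [-1, 1] -> [0, 1]
    data[..., to_norm] = halved * (hi[to_norm] - lo[to_norm]) + lo[to_norm]
    return data

# toy round-trip check on a 9-dim "y" vector: the first 3 dims (positions) get
# normalized, the last 6 (rotation components) are skipped, like constant_stats["y"]
constant_mask = torch.tensor([False] * 3 + [True] * 6)
stats_min = torch.tensor([-0.2, -0.2, 0.8] + [0.0] * 6)
stats_max = torch.tensor([0.2, 0.2, 1.2] + [0.0] * 6)
y = torch.rand(2, 16, 9) * 0.4 - 0.2
y_back = unnormalize(normalize(y, stats_min, stats_max, constant_mask),
                     stats_min, stats_max, constant_mask)
assert torch.allclose(y.double(), y_back)
```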
17 changes: 15 additions & 2 deletions imitation/dataset/robomimic_lowdim_dataset.py
@@ -8,6 +8,9 @@
from diffusion_policy.model.common.normalizer import LinearNormalizer


from diffusion_policy.model.common.rotation_transformer import RotationTransformer


log = logging.getLogger(__name__)

class RobomimicLowdimDataset(torch.utils.data.Dataset):
@@ -29,6 +32,10 @@ def __init__(self,
self.dataset_path = dataset_path
self.dataset_root = h5py.File(dataset_path, 'r')
self.dataset_keys = list(self.dataset_root["data"].keys())
self.rotation_transformer = RotationTransformer(
from_rep="quaternion",
to_rep="rotation_6d"
)
try:
self.dataset_keys.remove("mask")
except:
@@ -86,10 +93,16 @@ def create_sample_indices(self):
self.indices.append(idx_global + idx)
data_obs_keys = []
for obs_key in self.obs_keys:
data_obs_keys.append(self.dataset_root[f"data/{key}/obs/{obs_key}"][idx - self.obs_horizon:idx, :])
obs = self.dataset_root[f"data/{key}/obs/{obs_key}"][idx - self.obs_horizon:idx, :]
if "quat" in obs_key:
obs = self.rotation_transformer.forward(obs)
data_obs_keys.append(obs)
data_action_keys = []
for action_key in self.action_keys:
data_action_keys.append(self.dataset_root[f"data/{key}/obs/{action_key}"][idx:idx+self.pred_horizon, :])
action = self.dataset_root[f"data/{key}/obs/{action_key}"][idx:idx+self.pred_horizon, :]
if "quat" in action_key:
action = self.rotation_transformer.forward(action)
data_action_keys.append(action)
data_obs_keys = np.concatenate(data_obs_keys, axis=-1)
data_action_keys = np.concatenate(data_action_keys, axis=-1)
self.data_at_indices.append({
Expand Down
(Diffs for the remaining 6 of the 17 changed files are not shown.)
