In [1]:
from IPython import display
import argparse
import cv2
import copy
from collections import deque
import gym
import hydra.utils
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os, sys, shutil
import torch
import omegaconf
import time
import torch
from typing import Optional, Sequence, cast
import torch.nn as nn

import mbrl
import mbrl.models as models
import mbrl.planning as planning
import mbrl.util.common as common_util
import mbrl.third_party.pytorch_sac_pranz24 as pytorch_sac_pranz24
import mbrl.util.math as math_util

from mbrl.planning.sac_wrapper import SACAgent, SACAgentSb3
from mbrl.third_party.sb3_sac.sac import SAC
from mbrl.util.plot_and_save_push_data import plot_and_save_training, plot_and_save_push_plots, clear_and_create_dir
from mbrl.util.eval_agent import eval_and_save_vid

import tactile_gym.rl_envs
from tactile_gym.sb3_helpers.params import import_parameters
# from tactile_gym.utils.general_utils import get_orn_diff, quaternion_multiply, get_inverse_quaternion


from stable_baselines3.common.torch_layers import NatureCNN
from tactile_gym.sb3_helpers.custom.custom_torch_layers import CustomCombinedExtractor, ImpalaCNN

from pyvirtualdisplay import Display
# Start a non-visible virtual display of 1400x900 pixels before any environment
# is created. NOTE(review): presumably required so the simulator can render on a
# headless machine -- confirm. The display is never stopped in the visible cells.
_display = Display(visible=False, size=(1400, 900))
_ = _display.start()

# Torch device string shared by the dynamics model, the SAC agent and the
# torch.Generator created later: first CUDA device if available, else CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [2]:
# Load the default parameter sets for the pushing environment.
# NOTE(review): the parameters are loaded from the 'ppo' preset although the
# agent built later in this notebook is SAC; of the three returned values only
# `rl_params` is used by the visible cells -- confirm this is intended.
algo_name = 'ppo'
env_name = 'object_push-v0'
rl_params, algo_params, augmentations = import_parameters(env_name, algo_name)

# Episode length; forwarded to the environment as `max_steps` when it is built.
rl_params["max_ep_len"] = 1000    
# Observation / control / task configuration passed to the environment through
# the `env_modes` dictionary.
rl_params["env_modes"]["observation_mode"] = "tactile_pose_relative_data"
rl_params["env_modes"][ 'control_mode'] = 'TCP_position_control'
rl_params["env_modes"]['movement_mode'] = 'TyRz'
rl_params["env_modes"]['traj_type'] = 'point'
rl_params["env_modes"]['task'] = "goal_pos"
rl_params["env_modes"]['planar_states'] = True
rl_params["env_modes"]['use_contact'] = True
rl_params["env_modes"]['terminate_early']  = True
# NOTE(review): 'terminate_terminate_early' looks like a duplicated-prefix typo
# of the 'terminate_early' key set on the previous line -- confirm that the
# environment actually reads this key before relying on it.
rl_params["env_modes"]['terminate_terminate_early'] = True

# Randomisation switches (only the initial orientation is randomised here).
rl_params["env_modes"]['rand_init_orn'] = True
# rl_params["env_modes"]['rand_init_pos_y'] = True
# rl_params["env_modes"]['rand_obj_mass'] = True

# Reward shaping and termination settings.
rl_params["env_modes"]['additional_reward_settings'] = 'john_guide_off_normal'
rl_params["env_modes"]['terminated_early_penalty'] =  0.0
rl_params["env_modes"]['reached_goal_reward'] = 0.0
rl_params["env_modes"]['max_no_contact_steps'] = 1000
# 180/180 * np.pi evaluates to pi, i.e. a 180 degree limit expressed in radians.
rl_params["env_modes"]['max_tcp_to_obj_orn'] = 180/180 * np.pi
rl_params["env_modes"]['importance_obj_goal_pos'] = 1.0
rl_params["env_modes"]['importance_obj_goal_orn'] = 1.0
rl_params["env_modes"]['importance_tip_obj_orn'] = 1.0
rl_params["env_modes"]["x_speed_ratio"] = 5.0
rl_params["env_modes"]["y_speed_ratio"] = 5.0
rl_params["env_modes"]["Rz_speed_ratio"] = 5.0

# Goal-orientation update behaviour (semantics are defined by the environment).
rl_params["env_modes"]['mpc_goal_orn_update'] = True
rl_params["env_modes"]['goal_orn_update_freq'] = 'every_step'


# Workspace limits: one [lower, upper] row per axis, in the order
# x, y, z, roll, pitch, yaw. The z / roll / pitch rows are pinned to zero.
TCP_lims = np.array(
    [
        [-0.1, 0.4],                              # x lims
        [-0.3, 0.3],                              # y lims
        [-0.0, 0.0],                              # z lims
        [-0.0, 0.0],                              # roll lims
        [-0.0, 0.0],                              # pitch lims
        [-180 * np.pi / 180, 180 * np.pi / 180],  # yaw lims
    ],
    dtype=np.float64,
)

# Goal parameters: candidate goal edges (top, bottom and straight ahead) and the
# rectangle goals are drawn from, derived from a fraction of the TCP limits.
goal_edges = [(0, -1), (0, 1), (1, 0)]
# goal_edges = [(1, 0)]
goal_x_min = 0.0  # alternative: float(TCP_lims[0, 0] * 0.6)
goal_x_max = float(TCP_lims[0, 1] * 0.8)
goal_y_min = float(TCP_lims[1, 0] * 0.6)
goal_y_max = float(TCP_lims[1, 1] * 0.6)
goal_ranges = [goal_x_min, goal_x_max, goal_y_min, goal_y_max]

# Hand everything to the environment as plain Python containers.
rl_params["env_modes"]['goal_ranges'] = goal_ranges
rl_params["env_modes"]['goal_edges'] = goal_edges
rl_params["env_modes"]['tcp_lims'] = TCP_lims.tolist()

# Keyword arguments shared by the training and evaluation environments.
env_kwargs = {
    'show_gui': False,
    'show_tactile': False,
    'states_stacked_len': 1,
    'max_steps': rl_params["max_ep_len"],
    'image_size': rl_params["image_size"],
    'env_modes': rl_params["env_modes"],
}

# training environment
env = gym.make(env_name, **env_kwargs)

# Single source of truth for every random number generator in this notebook.
seed = 0
env.seed(seed)
# Seed torch's global RNG as well, so that network initialisation and any
# sampling that does not go through `generator` is reproducible across runs.
torch.manual_seed(seed)
# Previously seeded with a literal 0; tie it to `seed` so changing the seed above
# really changes (and reproduces) the replay-buffer sampling too.
rng = np.random.default_rng(seed=seed)
# Dedicated torch generator handed to the model environment.
generator = torch.Generator(device=device)
generator.manual_seed(seed)

# Shapes used to size the replay buffers and the models. The last axis of the
# observation space is taken as the per-step observation width; the stacked
# shapes multiply it by the environment's history length.
obs_shape = env.observation_space.shape
act_shape = env.action_space.shape
history_len = env.states_stacked_len
step_obs_shape = (obs_shape[-1], )
step_act_shape = act_shape
stacked_obs_shape = (obs_shape[-1] * history_len, )
stacked_act_shape = (history_len * act_shape[-1], )
goal_shape = env.get_goal_obs().shape

# evaluation environment: a deep copy of the training kwargs (so the training
# `env_modes` dictionary is not mutated) with evaluation mode switched on.
num_eval_episodes = 7
eval_env_kwargs = copy.deepcopy(env_kwargs)
eval_env_kwargs["env_modes"]['eval_mode'] = True
eval_env_kwargs["env_modes"]['eval_num'] = num_eval_episodes
eval_env = gym.make(env_name, **eval_env_kwargs)

# All results of this run are written below <current working directory>/trained_mbpo.
# NOTE(review): judging by its name, clear_and_create_dir wipes any existing
# directory, so re-running this cell would discard earlier results -- confirm.
work_dir = os.path.join(os.getcwd(), 'trained_mbpo')
clear_and_create_dir(work_dir)

pybullet build time: Mar  8 2021 17:26:24


argv[0]=
Loaded EGL 1.5 after reload.
GL_VENDOR=NVIDIA Corporation
GL_RENDERER=NVIDIA GeForce RTX 3090/PCIe/SSE2
GL_VERSION=4.6.0 NVIDIA 495.29.05
GL_SHADING_LANGUAGE_VERSION=4.60 NVIDIA
Version = 4.6.0 NVIDIA 495.29.05
Vendor = NVIDIA Corporation
Renderer = NVIDIA GeForce RTX 3090/PCIe/SSE2
ven = NVIDIA Corporation
ven = NVIDIA Corporation
argv[0]=
Loaded EGL 1.5 after reload.
GL_VENDOR=NVIDIA Corporation
GL_RENDERER=NVIDIA GeForce RTX 3090/PCIe/SSE2
GL_VERSION=4.6.0 NVIDIA 495.29.05
GL_SHADING_LANGUAGE_VERSION=4.60 NVIDIA
Version = 4.6.0 NVIDIA 495.29.05
Vendor = NVIDIA Corporation
Renderer = NVIDIA GeForce RTX 3090/PCIe/SSE2
ven = NVIDIA Corporation
ven = NVIDIA Corporation


In [3]:
# def quaternions_product(quaternion0, quaternion1):
#     x0, y0, z0, w0 = torch.tensor_split(quaternion0, 4, dim=-1)
#     x1, y1, z1, w1 = torch.tensor_split(quaternion1, 4, dim=-1)
#     return torch.concat(
#         (x1*w0 + y1*z0 - z1*y0 + w1*x0,
#          -x1*z0 + y1*w0 + z1*x0 + w1*y0,
#          x1*y0 - y1*x0 + z1*w0 + w1*z0,
#          -x1*x0 - y1*y0 - z1*z0 + w1*w0),
#         axis=-1)

# def get_inverse_quaternion_batched(orn_batched):
#     conj_orn = orn_batched.clone()
#     conj_orn[:, :3] = -conj_orn[:, :3] 
#     norm_orn = torch.linalg.norm(orn_batched, axis=-1).unsqueeze(1)

#     return conj_orn/norm_orn

# def get_orn_diff_batched(orn_1, orn_2):
#     """
#     Calculate the difference between two orientation quaternions expressed in
#     the same frame of reference.
#     """
#     return quaternions_product(orn_1, get_inverse_quaternion_batched(orn_2))


# env.update_goal_orn()
# _, cur_obj_orn_workframe = env.get_obj_pos_workframe()
# goal_orn = env.goal_orn_workframe

# # Batch goals and obj orn
# obj_orn_batch = torch.tensor(cur_obj_orn_workframe).repeat(2, 1).to(torch.float32)
# goal_orn_batch = torch.tensor(goal_orn).repeat(2, 1).to(torch.float32)

# print(env.get_goal_aware_tactile_pose_relative_obs()[-2:])
# print(get_orn_diff(cur_obj_orn_workframe, goal_orn))
# print(get_orn_diff_batched(obj_orn_batch, goal_orn_batch))

# # print(quaternion_multiply(cur_obj_orn_workframe, goal_orn))
# # print(quat_multiply(obj_orn_batch, goal_orn_batch))

# # print(get_inverse_quaternion(np.array(cur_obj_orn_workframe)))
# # print(get_inverse_quaternion_batched(obj_orn_batch))

In [4]:
# Size of the experiment: number of trials and environment steps collected before
# training starts.
num_trials = 1000
initial_buffer_size = 10000
# Steps per trial, read from the environment's private `_max_steps` attribute.
# NOTE(review): presumably equal to rl_params["max_ep_len"] passed as `max_steps` -- confirm.
trial_length = env._max_steps
# Capacity of the real-data replay buffer: room for every step of every trial.
buffer_size = num_trials * trial_length

# cfg_dict = {
#     # dynamics model configuration
#     "dynamics_model": {
#         "_target_": "mbrl.models.GaussianMLP",
#         "device": device,
#         "num_layers": 3,
#         "ensemble_size": 5,
#         "hid_size": 200,
#         "in_size": "???",
#         "out_size": "???",
#         "deterministic": False,
#         "propagation_method": "fixed_model",
#         "learn_logvar_bounds": False,
#         # can also configure activation function for GaussianMLP
#         # "activation_fn_cfg": {
#         #     "_target_": "torch.nn.LeakyReLU",
#         #     "negative_slope": 0.01
#         # }
#         "activation_fn_cfg": {
#             "_target_": "torch.nn.SiLU",
#         }
#     },
#     # options for training the dynamics model
#     "algorithm": {
#         "learned_rewards": False,
#         "target_is_delta": True,
#         "normalize": True,
#         "target_normalize": True,
#         "dataset_size": buffer_size,
#         "initial_dataset_size": initial_buffer_size,
#         "using_history_of_obs": True,
#         "sac_samples_action": True,
#         "real_data_ratio": 0.0,
#     },
#     # these are experiment specific options
#     "overrides": {
#         "trial_length": trial_length,
#         "epoch_length": trial_length,
#         "num_steps": num_trials * trial_length,
#         "patience": 25,
#         "num_epochs_train_model": 25,
#         "model_lr": 0.0005,
#         "model_wd": 0.0001,
#         "model_batch_size": 32,
#         "validation_ratio": 0.0,
#         "freq_train_model": 500,
#         "effective_model_rollouts_per_step": 200,
#         "rollout_schedule": [10, 50, 1, 20],
#         "num_sac_updates_per_step": 20,
#         "sac_updates_every_steps": 1,
#         "num_epochs_to_retain_sac_buffer": 1,
#         "sac_batch_size": 512,
#     }
# }

# agent_cfg_dict = {
#     "_target_": "mbrl.third_party.pytorch_sac_pranz24.sac.SAC",
#     "num_inputs": "???",
#     "action_space": {
#         "_target_": "gym.spaces.Box",
#         "low": "???",
#         "high": "???",
#         "shape": "???",
#     },

#     "args": {
#         "gamma": 0.99,
#         "tau": 0.005,
#         "alpha": 0.2,
#         "policy": "Gaussian",
#         "target_update_interval": 1,
#         "automatic_entropy_tuning": False,
#         "target_entropy": 0.1,
#         "hidden_size": 256,
#         "device": device,
#         "lr": 0.0001,
#     }
# }

# Experiment configuration, converted to an OmegaConf object further down.
# "???" is OmegaConf's marker for a mandatory value that is still missing;
# NOTE(review): in_size / out_size are presumably filled in by
# common_util.create_one_dim_tr_model from the observation/action shapes -- confirm.
cfg_dict = {
    # dynamics model configuration
    "dynamics_model": {
        "_target_": "mbrl.models.GaussianMLP",
        "device": device,
        "num_layers": 3,
        "ensemble_size": 5,
        "hid_size": 200,
        "in_size": "???",
        "out_size": "???",
        "deterministic": False,
        "propagation_method": "fixed_model",
        "learn_logvar_bounds": False,
        # can also configure activation function for GaussianMLP
        # "activation_fn_cfg": {
        #     "_target_": "torch.nn.LeakyReLU",
        #     "negative_slope": 0.01
        # }
        "activation_fn_cfg": {
            "_target_": "torch.nn.SiLU",
        }
    },
    # options for training the dynamics model
    "algorithm": {
        "learned_rewards": False,
        "target_is_delta": True,
        "normalize": True,
        "target_normalize": True,
        "dataset_size": buffer_size,
        "initial_dataset_size": initial_buffer_size,
        "using_history_of_obs": True,
        # whether the agent samples its action during model rollouts
        "sac_samples_action": True,
        # probability that a SAC update draws from real data instead of model rollouts
        "real_data_ratio": 0.0,
    },
    # these are experiment specific options
    "overrides": {
        "trial_length": trial_length,
        "epoch_length": trial_length,
        "num_steps": num_trials * trial_length,
        "patience": 5,
        # NOTE(review): None presumably leaves the number of model-training epochs
        # unbounded so that `patience` decides when to stop -- confirm against the trainer.
        "num_epochs_train_model": None,
        "model_lr": 0.001,
        "model_wd": 0.00005,
        "model_batch_size": 32,
        "validation_ratio": 0.0,
        # the dynamics model is retrained every `freq_train_model` environment steps
        "freq_train_model": 500,
        # multiplied by freq_train_model to obtain the rollout batch size
        "effective_model_rollouts_per_step": 100,
        # arguments of math_util.truncated_linear (together with the epoch number),
        # which yields the model rollout length
        "rollout_schedule": [20, 100, 1, 20],
        "num_sac_updates_per_step": 10,
        "sac_updates_every_steps": 1,
        "num_epochs_to_retain_sac_buffer": 1,
        "sac_batch_size": 512,
    }
}

# Configuration of the SAC agent, instantiated through hydra from `_target_`.
# "???" is OmegaConf's marker for a mandatory value that is still missing;
# NOTE(review): these entries are presumably filled in from the environment by
# planning.complete_agent_cfg before instantiation -- confirm.
agent_cfg_dict = {
    "_target_": "mbrl.third_party.pytorch_sac_pranz24.sac.SAC",
    "num_inputs": "???",
    "action_space": {
        "_target_": "gym.spaces.Box",
        "low": "???",
        "high": "???",
        "shape": "???",
    },

    "args": {
        "gamma": 0.99,
        "tau": 0.005,
        "alpha": 0.05,
        "policy": "Gaussian",
        "target_update_interval": 1,
        "automatic_entropy_tuning": False,
        # NOTE(review): presumably ignored while automatic_entropy_tuning is False -- confirm.
        "target_entropy": 0.1,
        "hidden_size": 256,
        "device": device,
        "lr": 0.0003,
    }
}

In [5]:
# sac_params = {
#     # === net arch ===
#     "policy_kwargs": {
#         "features_extractor_class": CustomCombinedExtractor,
#         "features_extractor_kwargs": {
#             'cnn_base':NatureCNN,
#             # 'cnn_base':ImpalaCNN,
#             'cnn_output_dim':256,
#             'mlp_extractor_net_arch':[64, 64],
#         },
#         "net_arch": dict(pi=[256, 256], qf=[256, 256]),
#         "activation_fn": nn.Tanh,
#     },

#     # ==== rl params ====
#     "learning_rate": 1e-4,
#     "buffer_size": int(1e6),
#     "learning_starts": 1e4,
#     "batch_size": 512,
#     "tau": 0.005,
#     "gamma": 0.99,
#     "train_freq": 1,
#     "gradient_steps": 1,
#     "action_noise": None,
#     "optimize_memory_usage":False,
#     "ent_coef": "auto",
#     "target_update_interval": 1,
#     "target_entropy": "auto",
#     "use_sde": False,
#     "sde_sample_freq": -1,
#     "use_sde_at_warmup": False,
# }

# model = SAC(
#     rl_params["policy"],
#     env,
#     **sac_params,
#     verbose=1,
#     device=device,
# )

# agent = SACAgentSb3(model)

In [6]:
# Build the SAC agent: wrap the dictionary in an OmegaConf config, let
# complete_agent_cfg update it from the environment, instantiate it through hydra
# and wrap the result in SACAgent. `cast` only informs the type checker.
agent_cfg = omegaconf.OmegaConf.create(agent_cfg_dict)
planning.complete_agent_cfg(env, agent_cfg)
agent = SACAgent(
    cast(pytorch_sac_pranz24.SAC, hydra.utils.instantiate(agent_cfg))
)

# Build the dynamics model from the experiment config and wrap it in a model
# environment. No termination or reward function is supplied here.
# NOTE(review): with both set to None, rewards and dones presumably come from
# ModelEnvPushing itself -- confirm.
cfg = omegaconf.OmegaConf.create(cfg_dict)
dynamics_model = common_util.create_one_dim_tr_model(cfg, obs_shape, act_shape)
model_env = models.ModelEnvPushing(env, dynamics_model, termination_fn=None, reward_fn=None, generator=generator)

  if OmegaConf.is_none(config):


In [7]:
# Show the architecture of the dynamics model as a sanity check.
print(dynamics_model)

OneDTransitionRewardModel(
  (model): GaussianMLP(
    (hidden_layers): Sequential(
      (0): Sequential(
        (0): EnsembleLinearLayer(num_members=5, in_size=10, out_size=200, bias=True)
        (1): SiLU()
      )
      (1): Sequential(
        (0): EnsembleLinearLayer(num_members=5, in_size=200, out_size=200, bias=True)
        (1): SiLU()
      )
      (2): Sequential(
        (0): EnsembleLinearLayer(num_members=5, in_size=200, out_size=200, bias=True)
        (1): SiLU()
      )
    )
    (mean_and_logvar): EnsembleLinearLayer(num_members=5, in_size=200, out_size=16, bias=True)
  )
)


In [8]:
# Show the policy and critic networks of the SAC agent as a sanity check.
print(agent.sac_agent.policy)
print(agent.sac_agent.critic)

GaussianPolicy(
  (linear1): Linear(in_features=8, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=256, bias=True)
  (mean_linear): Linear(in_features=256, out_features=2, bias=True)
  (log_std_linear): Linear(in_features=256, out_features=2, bias=True)
)
QNetwork(
  (linear1): Linear(in_features=10, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=256, bias=True)
  (linear3): Linear(in_features=256, out_features=1, bias=True)
  (linear4): Linear(in_features=10, out_features=256, bias=True)
  (linear5): Linear(in_features=256, out_features=256, bias=True)
  (linear6): Linear(in_features=256, out_features=1, bias=True)
)


In [9]:
# env.update_goal_orn()
# obs = env.get_observation()
# goal_aware_obs = env.get_goal_aware_tactile_pose_relative_obs()
# goal_pos_workframe = env.goal_pos_workframe

# # Batch goals and obj orn
# obs_batch = np.repeat(obs, repeats=2, axis=0)
# goal_pos_workframe_batch = np.repeat(goal_pos_workframe[None, :], repeats=2, axis=0)

# model_env.reset_batch_goals(2)
# model_env.update_step_data(obs_batch)
# agent_obs = model_env.get_agent_obs(obs_batch)
# agent_obs_goal = model_env.get_agent_obs(obs_batch, goal_pos_workframe_batch)

# # print(obs_batch)
# # Should be the same 
# print(goal_aware_obs)
# print(agent_obs)
# print(agent_obs_goal)

In [10]:
# -------------- Create the initial dataset --------------
# Replay buffer for real environment transitions: stacked observations/actions
# as inputs, a single-step next observation, plus the goal of each transition.
replay_buffer = common_util.create_replay_buffer(
    cfg,
    stacked_obs_shape,
    stacked_act_shape,
    rng=rng,
    next_obs_shape=(obs_shape[-1], ),
    goal_shape=goal_shape,
)

# Fill the buffer with `initial_buffer_size` steps collected by the freshly
# created (still untrained) agent, sampling its actions.
common_util.rollout_agent_trajectories(
    env,
    initial_buffer_size,
    agent,
    {"sample": True, "batched": False},
    replay_buffer=replay_buffer,
    stacking=cfg.algorithm.using_history_of_obs,
    store_goals=True,
)

# Prints the requested number of steps, not the number actually stored.
print("Data collected: ", initial_buffer_size)

Data collected:  10000


In [11]:
# evaluation_result_directory = os.path.join(work_dir, 'random_exploration_agent')
# clear_and_create_dir(evaluation_result_directory)

# avg_reward = eval_and_save_vid(
#     env,
#     agent,
#     n_eval_episodes=10,
#     trial_length=trial_length,
#     save_and_plot_flag=True,
#     save_vid=True,
#     render=True,
#     data_directory=evaluation_result_directory,
#     print_ep_reward=True,
#     agent_kwargs={"sample": True, "batched": False},
# )

In [12]:
# data = replay_buffer.get_all()
# print(data.obs.shape)
# print(data.act.shape)
# print(data.next_obs.shape)

In [13]:
# print(env.goal_orn_workframe)
# init_obs = env.get_observation()
# obs = init_obs[-1]
# batch_size = 3
# model_env.reset_batch_goals(batch_size)
# model_env.update_goal_orn(torch.from_numpy(obs).repeat(batch_size, 1).to(device))
# print(model_env.goal_orn_workframe_batch)

In [14]:
def rollout_model_and_populate_sac_buffer(
    model_env: models.ModelEnvPushing,
    replay_buffer: mbrl.util.ReplayBuffer,
    agent: SACAgent,
    sac_buffer: mbrl.util.ReplayBuffer,
    sac_samples_action: bool,
    rollout_horizon: int,
    batch_size: int,
):
    """Roll the model environment forward and store the imagined transitions.

    A batch of ``batch_size`` transitions is sampled from ``replay_buffer`` and its
    observations/actions are used as starting points. For ``rollout_horizon``
    steps the agent picks an action for every rollout, the model environment
    predicts the next observation, reward and done flag, and the resulting
    transitions (expressed in the agent's observation format) are appended to
    ``sac_buffer``. Rollouts that were flagged done on an earlier step are no
    longer stored.

    Args:
        model_env: model environment that is reset and stepped.
        replay_buffer: source of the starting transitions.
        agent: agent queried for actions (batched).
        sac_buffer: buffer receiving the imagined transitions.
        sac_samples_action: forwarded to ``agent.act`` as ``sample``.
        rollout_horizon: number of model steps per rollout.
        batch_size: number of rollouts started in parallel.
    """
    start_batch = replay_buffer.sample(batch_size)
    # The goals stored with the sampled transitions are unpacked but not used:
    # fresh goals are sampled for the rollouts below.
    initial_obs, initial_act, _, _, _, _stored_goals = cast(
        mbrl.types.TransitionBatch, start_batch
    ).astuple()
    obs_width = initial_obs.shape[-1]
    act_width = initial_act.shape[-1]

    model_state = model_env.reset(
        initial_obs_batch=cast(np.ndarray, initial_obs),
        return_as_np=True,
    )
    finished = np.zeros(initial_obs.shape[0], dtype=bool)
    current_obs = initial_obs
    action_history = initial_act
    model_env.reset_batch_goals(batch_size, sample_goals=True)
    model_env.update_step_data(current_obs)

    for _ in range(rollout_horizon):
        agent_obs = model_env.get_agent_obs(current_obs)
        action = agent.act(
            agent_obs, action=None, sample=sac_samples_action, batched=True
        )

        # Slide the action window: append the new action, keep the newest entries.
        action_history = np.concatenate(
            [action_history, action], axis=action.ndim - 1
        )[:, -act_width:]

        pred_next_obs, pred_rewards, pred_dones, model_state = model_env.step(
            action_history, model_state, sample=True
        )

        # Slide the observation window in the same way.
        next_stacked_obs = np.concatenate(
            [current_obs, pred_next_obs], axis=pred_next_obs.ndim - 1
        )[:, -obs_width:]
        agent_next_obs = model_env.get_agent_obs(next_stacked_obs)

        # Only rollouts that had not finished before this step are stored.
        still_running = ~finished
        sac_buffer.add_batch(
            agent_obs[still_running],
            action[still_running],
            agent_next_obs[still_running],
            pred_rewards[still_running, 0],
            pred_dones[still_running, 0],
        )

        current_obs = next_stacked_obs
        finished |= pred_dones.squeeze()


def maybe_replace_sac_buffer(
    sac_buffer: Optional[mbrl.util.ReplayBuffer],
    obs_shape: Sequence[int],
    act_shape: Sequence[int],
    new_capacity: int,
    seed: int,
) -> mbrl.util.ReplayBuffer:
    """Return a SAC buffer of capacity ``new_capacity``, reusing the old one if possible.

    Args:
        sac_buffer: current buffer, or ``None`` if none exists yet.
        obs_shape: observation shape of the buffer entries.
        act_shape: action shape of the buffer entries.
        new_capacity: required capacity.
        seed: passed to ``np.random.default_rng`` when the first buffer is
            created; ignored afterwards because the old buffer's ``rng`` is reused.

    Returns:
        ``sac_buffer`` itself when its capacity already matches, otherwise a new
        buffer holding everything stored in the old one.
    """
    # Nothing to do: a buffer exists and already has the requested capacity.
    if sac_buffer is not None and new_capacity == sac_buffer.capacity:
        return sac_buffer

    creating_first_buffer = sac_buffer is None
    if creating_first_buffer:
        buffer_rng = np.random.default_rng(seed=seed)
    else:
        buffer_rng = sac_buffer.rng

    resized_buffer = mbrl.util.ReplayBuffer(
        new_capacity, obs_shape, act_shape, rng=buffer_rng
    )
    if not creating_first_buffer:
        # Carry the stored transitions over to the resized buffer.
        obs, action, next_obs, reward, done, _ = sac_buffer.get_all().astuple()
        resized_buffer.add_batch(obs, action, next_obs, reward, done)
    return resized_buffer


# rollout_batch_size = (
#     cfg.overrides.effective_model_rollouts_per_step * cfg.overrides.freq_train_model
# )
# trains_per_epoch = int(
#     np.ceil(cfg.overrides.epoch_length / cfg.overrides.freq_train_model)
# )
# rollout_length = int(
#     math_util.truncated_linear(
#         *(cfg.overrides.rollout_schedule + [1])
#     )
# )
# sac_buffer = None
# sac_buffer_capacity = rollout_length * rollout_batch_size * trains_per_epoch
# sac_buffer_capacity *= cfg.overrides.num_epochs_to_retain_sac_buffer
# sac_buffer = maybe_replace_sac_buffer(
#     sac_buffer, stacked_obs_shape, act_shape, sac_buffer_capacity, rng
# )

# # Batch all rollouts for the next freq_train_model steps together
# rollout_model_and_populate_sac_buffer(
#     model_env,
#     replay_buffer,
#     agent,
#     sac_buffer,
#     cfg.algorithm.sac_samples_action,
#     rollout_length,
#     rollout_batch_size,
# )

In [15]:
# # obs = torch.randn(1, 8).to(device)
# obs = np.random.randn(32, 8)
# print(agent.sac_agent.sample_action(obs, batched=True, deterministic=True))

In [16]:
def replay_to_sac_buffer(replay_buffer, sac_obs_shape, sac_act_shape, rng):
    """Copy every stored transition of ``replay_buffer`` into a new SAC-style buffer.

    The source buffer keeps a stacked observation, a stacked action and a
    single-step next observation per transition. The copy stores the stacked
    observation unchanged, only the most recent action, and a stacked next
    observation obtained by dropping the oldest frame of the observation and
    appending the next observation.

    Args:
        replay_buffer: source buffer; its ``num_stored`` sets the new capacity.
        sac_obs_shape: observation shape of the new buffer.
        sac_act_shape: action shape of the new buffer; its last entry is the
            number of trailing action columns that are kept.
        rng: random generator handed to the new buffer.

    Returns:
        The newly created and filled buffer.
    """
    sac_copy = mbrl.util.ReplayBuffer(
        replay_buffer.num_stored, sac_obs_shape, sac_act_shape, rng=rng
    )
    stacked_obs, stacked_act, step_next_obs, reward, done, _ = (
        replay_buffer.get_all().astuple()
    )

    # Shift the observation window by one frame to obtain the stacked next obs.
    frame_width = step_next_obs.shape[-1]
    older_frames = stacked_obs[:, frame_width:]
    stacked_next_obs = np.concatenate([older_frames, step_next_obs], axis=1)

    # Keep only the newest action of the stacked action history.
    act_width = sac_act_shape[-1]
    latest_act = stacked_act[:, -act_width:]

    sac_copy.add_batch(stacked_obs, latest_act, stacked_next_obs, reward, done)
    return sac_copy

# obs, action, next_obs, *_ = replay_buffer.get_all().astuple()
# print(obs.shape)
# print(next_obs.shape)
# print(action.shape)
# new_buffer = replay_to_sac_buffer(replay_buffer, stacked_obs_shape, act_shape, rng)
# new_obs, new_action, new_next_obs, *_ = new_buffer.get_all().astuple()
# print(new_obs.shape)
# print(new_next_obs.shape)
# print(new_action.shape)

# print(obs[0])
# print(new_obs[0])
# print(action[0])
# print(new_action[0])
# print(next_obs[0])
# print(new_next_obs[0])

In [17]:
# def train_model_and_save_model_and_data(
#     model,
#     model_trainer,
#     cfg,
#     replay_buffer,
#     work_dir = None,
#     callback = None,    
# ):
#     dynamics_model.update_normalizer(replay_buffer.get_all())  # update normalizer stats            
#     dataset_train, dataset_val = common_util.get_basic_buffer_iterators(
#         replay_buffer,
#         batch_size=cfg.overrides.model_batch_size,
#         val_ratio=cfg.overrides.validation_ratio,
#         ensemble_size=len(dynamics_model),
#         shuffle_each_epoch=True,
#         bootstrap_permutes=False,  # build bootstrap dataset using sampling with replacement
#     )

#     model_trainer.train(
#         dataset_train, 
#         dataset_val=dataset_val, 
#         num_epochs=cfg.overrides.num_epochs_train_model, 
#         patience=cfg.overrides.patience, 
#         callback=callback,
#         silent=True)

In [18]:
def plot_and_save(y_data, x_data=None, title=None, xlabel=None, ylabel=None,
                  save_dir=None, filename="losses.png"):
    """Plot up to five series on a 3x2 grid of axes and save the figure.

    Series ``k`` of ``y_data`` is drawn in panel ``k`` (row-major order); the
    sixth panel is always left empty. If fewer than five series are given, only
    those are drawn.

    Args:
        y_data: sequence of series to plot.
        x_data: optional sequence with the x values of each plotted series.
            ``None`` or an empty sequence plots every series against its index.
        title: optional sequence with one title per plotted series.
        xlabel: optional sequence with one x-axis label per plotted series.
        ylabel: optional sequence with one y-axis label per plotted series.
        save_dir: directory the figure is written to. Defaults to the
            notebook-level ``work_dir``.
        filename: name of the saved file inside ``save_dir``.

    Raises:
        IndexError: if ``x_data``, ``title``, ``xlabel`` or ``ylabel`` is given
            but has fewer entries than the number of plotted series.
    """
    def _provided(values):
        # `if values:` raises ValueError for multi-element numpy arrays, so test
        # for None / emptiness explicitly.
        return values is not None and len(values) > 0

    n_rows, n_cols, max_panels = 3, 2, 5
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(14, 10))

    try:
        # Never index past the end of y_data when fewer than five series exist.
        for panel in range(min(max_panels, len(y_data))):
            row, col = divmod(panel, n_cols)
            axis = ax[row, col]

            if _provided(x_data):
                axis.plot(x_data[panel], y_data[panel])
            else:
                axis.plot(y_data[panel])

            if _provided(title):
                axis.set_title(title[panel])
            if _provided(xlabel):
                axis.set_xlabel(xlabel[panel])
            if _provided(ylabel):
                axis.set_ylabel(ylabel[panel])

        out_dir = work_dir if save_dir is None else save_dir
        fig.savefig(os.path.join(out_dir, filename))
    finally:
        # Always release the figure, even if plotting or saving failed, so that
        # repeated calls from the training loop do not accumulate open figures.
        plt.close(fig)

# Directory for per-trial training results and the column names of the per-step
# push data rows assembled in the training loop.
training_result_directory = os.path.join(work_dir, "training_result")
data_columns = ['trial','trial_steps', 'time_steps', 'tcp_x','tcp_y','tcp_z','contact_x', 'contact_y', 'contact_z', 'tcp_Rz', 'contact_Rz', 'goal_x', 'goal_y', 'goal_Rz', 'rewards', 'contact', 'dones']

# Save losses 
# Each history starts with a 0.0 placeholder entry; the model losses are
# appended by train_callback, the SAC losses by the training loop.
train_losses = [0.0]
val_scores = [0.0]
policy_losses = [0.0]
qf1_losses = [0.0]
qf2_losses = [0.0]

def train_callback(_model, _total_calls, _epoch, tr_loss, val_score, _best_val):
    """Record the dynamics-model losses reported by the model trainer.

    Appends ``tr_loss`` to the notebook-level ``train_losses`` list and the mean
    of ``val_score`` to ``val_scores``. The remaining arguments are part of the
    callback signature and are ignored.

    Args:
        tr_loss: training loss of the finished epoch.
        val_score: validation score per ensemble member, or ``None`` if the
            trainer did not evaluate; a NaN is recorded in that case so both
            lists keep the same length.
    """
    train_losses.append(tr_loss)
    if val_score is None:
        # No score reported: keep val_scores aligned with train_losses.
        val_scores.append(float("nan"))
    else:
        val_scores.append(val_score.mean().item())   # this returns val score per ensemble model

# ---------------------------------------------------------
# --------------------- Training Loop ---------------------
# Number of model rollouts started each time the dynamics model is retrained:
# the per-step rollout budget times the number of steps between retrainings.
rollout_batch_size = (
    cfg.overrides.effective_model_rollouts_per_step * cfg.overrides.freq_train_model
)
# Number of model retrainings per epoch (rounded up); used to size the SAC buffer.
trains_per_epoch = int(
    np.ceil(cfg.overrides.epoch_length / cfg.overrides.freq_train_model)
)
updates_made = 0
env_steps = 0
# NOTE(review): model_env is constructed here a second time with the same
# arguments as where it is first created earlier in the notebook -- one of the
# two constructions is presumably redundant.
model_env = models.ModelEnvPushing(
    env, dynamics_model, termination_fn=None, reward_fn=None, generator=generator
)
model_trainer = models.ModelTrainer(
    dynamics_model,
    optim_lr=cfg.overrides.model_lr,
    weight_decay=cfg.overrides.model_wd,
)

# Loop state. The reward / step histories start with a 0 placeholder entry that
# is skipped (`[1:]`) when the curves are saved and plotted.
best_eval_reward = -np.inf
epoch = 0
sac_buffer = None
all_train_rewards = [0]
all_eval_rewards = [0]
total_steps_train = [0]
total_steps_eval = [0]
goal_reached = [0]
trial_push_result = []
trial = 0
while env_steps < cfg.overrides.num_steps:
    rollout_length = int(
        math_util.truncated_linear(
            *(cfg.overrides.rollout_schedule + [epoch + 1])
        )
    )
    sac_buffer_capacity = rollout_length * rollout_batch_size * trains_per_epoch
    sac_buffer_capacity *= cfg.overrides.num_epochs_to_retain_sac_buffer
    sac_buffer = maybe_replace_sac_buffer(
        sac_buffer, stacked_obs_shape, act_shape, sac_buffer_capacity, rng
    )
    obs, done = None, False
    for steps_epoch in range(cfg.overrides.epoch_length):
        if steps_epoch == 0 or done:
            if done:
                # save goal reached data during training
                if env.single_goal_reached:
                    goal_reached.append(trial_reward)
                else:
                    goal_reached.append(0)

                # Save data to csv and plot
                all_train_rewards.append(trial_reward)
                total_steps_train.append(steps_trial + total_steps_train[-1])
                trial_time = time.time() - start_trial_time

                # Save data to csv and plot
                # trial_push_result = np.array(trial_push_result)
                # plot_and_save_training(env, trial_push_result, trial, data_columns, training_result_directory)
                trial_push_result = []

                # Save and plot training curve 
                training_result = np.stack((total_steps_train[1:], all_train_rewards[1:]), axis=-1)
                pd.DataFrame(training_result).to_csv(os.path.join(work_dir, "{}_result.csv".format("train_curve")))
                fig, ax = plt.subplots(figsize=(12, 6))
                ax.plot(total_steps_train[1:], all_train_rewards[1:], 'bs-', total_steps_train[1:], goal_reached[1:], 'rs')
                ax.set_xlabel("Samples")
                ax.set_ylabel("Trial reward")
                fig.savefig(os.path.join(work_dir, "output_train.png"))        
                plt.close(fig)

                trial += 1

            obs, done = env.reset(), False
            stacked_act = deque(np.zeros((env.states_stacked_len, *env.action_space.shape)), maxlen=env.states_stacked_len)
            trial_reward = 0.0
            trial_pb_steps = 0.0
            steps_trial = 0
            start_trial_time = time.time()

            (tcp_pos_workframe, 
            tcp_rpy_workframe,
            cur_obj_pos_workframe, 
            cur_obj_rpy_workframe) = env.get_obs_workframe()
            trial_push_result.append(np.hstack([trial, 
                                                steps_trial, 
                                                trial_pb_steps,
                                                tcp_pos_workframe, 
                                                cur_obj_pos_workframe, 
                                                tcp_rpy_workframe[2],
                                                cur_obj_rpy_workframe[2],
                                                env.goal_pos_workframe[0:2], 
                                                env.goal_rpy_workframe[2],
                                                trial_reward, 
                                                False,
                                                done]))
                        

        # --- Doing env step and adding to model dataset ---
        next_obs, reward, done, info = common_util.step_env_and_add_to_buffer_stacked(
            env, obs, agent, {}, replay_buffer, stacked_action=stacked_act, store_goals=True,
        )

        # --------------- Model Training -----------------
        if (env_steps + 1) % cfg.overrides.freq_train_model == 0:
            dynamics_model.update_normalizer(replay_buffer.get_all())  # update normalizer stats            
            dataset_train, dataset_val = common_util.get_basic_buffer_iterators(
                replay_buffer,
                batch_size=cfg.overrides.model_batch_size,
                val_ratio=cfg.overrides.validation_ratio,
                ensemble_size=len(dynamics_model),
                shuffle_each_epoch=True,
                bootstrap_permutes=False,  # build bootstrap dataset using sampling with replacement
            )
            # print("Training model...")
            model_trainer.train(
                dataset_train, 
                dataset_val=dataset_val, 
                num_epochs=cfg.overrides.num_epochs_train_model, 
                patience=cfg.overrides.patience, 
                callback=train_callback,
                silent=True)
            # print("Model training done.")
            
            
            # --------- Rollout new model and store imagined trajectories --------
            # Batch all rollouts for the next freq_train_model steps together
            rollout_model_and_populate_sac_buffer(
                model_env,
                replay_buffer,
                agent,
                sac_buffer,
                cfg.algorithm.sac_samples_action,
                rollout_length,
                rollout_batch_size,
            )

        # --------------- Agent Training -----------------
        # print("Training agent...")
        for _ in range(cfg.overrides.num_sac_updates_per_step):
            use_real_data = rng.random() < cfg.algorithm.real_data_ratio
            replay_sac_buffer = replay_to_sac_buffer(replay_buffer, stacked_obs_shape, act_shape, rng)
            which_buffer = replay_sac_buffer if use_real_data else sac_buffer
            if (env_steps + 1) % cfg.overrides.sac_updates_every_steps != 0 or len(
                which_buffer
            ) < cfg.overrides.sac_batch_size:
                
                qf1_loss = 0
                qf2_loss = 0
                policy_loss = 0
                # print("Agent training skipped. Buffer not big enough")
                break  # only update every once in a while
            (
                qf1_loss,
                qf2_loss,
                policy_loss,
                _,
                _,
            ) = agent.sac_agent.update_parameters(
                which_buffer,
                cfg.overrides.sac_batch_size,
                updates_made,
                logger=None,
                reverse_mask=True,
            )
            # qf2_loss = 0
            # (
            #     qf1_loss,
            #     policy_loss,
            #     _,
            #     _,
            # ) = agent.sac_agent.train(
            #     1,
            #     which_buffer,
            #     cfg.overrides.sac_batch_size,
            # )
            updates_made += 1
        # print("Agent training done.")

        qf1_losses.append(qf1_loss)
        qf2_losses.append(qf2_loss)
        policy_losses.append(policy_loss)

        # ------ Epoch ended (evaluate and save model) ------
        # Every `cfg.overrides.epoch_length` environment steps: evaluate the agent,
        # write the evaluation curve to CSV and PNG in `work_dir`, print a progress
        # line, and checkpoint the SAC agent when the evaluation reward improves.
        if (env_steps + 1) % cfg.overrides.epoch_length == 0:
            avg_reward = eval_and_save_vid(eval_env, agent, n_eval_episodes=num_eval_episodes)
            all_eval_rewards.append(avg_reward)
            total_steps_eval.append(total_steps_train[-1])

            # Index 0 of both histories is left out of the saved curve and the plot.
            # NOTE(review): presumably a placeholder entry added before training — confirm.
            eval_result = np.stack((total_steps_eval[1:], all_eval_rewards[1:]), axis=-1)
            pd.DataFrame(eval_result).to_csv(os.path.join(work_dir, "eval_curve_result.csv"))
            fig, ax = plt.subplots(figsize=(12, 6))
            ax.plot(total_steps_eval[1:], all_eval_rewards[1:], 'bs-')
            ax.set_xlabel("Samples")
            ax.set_ylabel("Eval reward")
            fig.savefig(os.path.join(work_dir, "output_eval.png"))
            # Close the figure so one is not kept open per epoch.
            plt.close(fig)

            # Adjacent f-strings are concatenated verbatim, so every field needs its
            # own ". " separator; without it the log read "Steps: 999Average reward: ...".
            print(
                f"Epoch: {epoch}. "
                f"SAC buffer size: {len(sac_buffer)}. "
                f"Real buffer size: {len(replay_buffer)}. "
                f"Rollout length: {rollout_length}. "
                f"Steps: {env_steps}. "
                f"Average reward: {avg_reward} "
            )

            # Keep only the best-scoring agent on disk.
            if avg_reward > best_eval_reward:
                best_eval_reward = avg_reward
                agent.sac_agent.save_checkpoint(
                    ckpt_path=os.path.join(work_dir, "sac.pth")
                )
                print("Saved best model")
            epoch += 1

        # Advance the step counters and carry the new observation forward.
        env_steps += 1
        obs = next_obs
        trial_reward += reward
        trial_pb_steps += info["num_of_pb_steps"]
        steps_trial += 1

        # Save data for plotting training performances: one flat row per step,
        # built from the work-frame poses reported by the environment.
        (
            tcp_pos_workframe,
            tcp_rpy_workframe,
            cur_obj_pos_workframe,
            cur_obj_rpy_workframe,
        ) = env.get_obs_workframe()
        # Accumulated step count scaled by the environment's `_sim_time_step`.
        elapsed_sim_time = trial_pb_steps * env._sim_time_step
        # Column order must stay fixed; only index 2 of each rpy vector is kept
        # (presumably yaw — confirm).
        step_record = np.hstack(
            [
                trial,
                steps_trial,
                elapsed_sim_time,
                tcp_pos_workframe,
                cur_obj_pos_workframe,
                tcp_rpy_workframe[2],
                cur_obj_rpy_workframe[2],
                env.goal_pos_workframe[0:2],
                env.goal_rpy_workframe[2],
                trial_reward,
                info["tip_in_contact"],
                done,
            ]
        )
        trial_push_result.append(step_record)
        
    # Plot losses at the end of each epoch.
    # Each y-axis label is paired with its series in one place, then split back
    # into the parallel lists that `plot_and_save` is called with.
    curves = (
        ("Train loss", train_losses),
        ("Val score", val_scores),
        ("QF1 loss", qf1_losses),
        ("QF2 loss", qf2_losses),
        ("Policy loss", policy_losses),
    )
    data = [series for _label, series in curves]
    ylabels = [label for label, _series in curves]
    plot_and_save(data, ylabel=ylabels)


Epoch: 0. SAC buffer size: 100000. Real buffer size: 11000. Rollout length: 1. Steps: 999Average reward: -533.196698152695 
Saving models to /home/qt21590/Documents/Projects/tactile_gym_mbrl/mbrl-lib/notebooks/trained_mbpo/sac.pth
Saved best model
Epoch: 1. SAC buffer size: 100000. Real buffer size: 12000. Rollout length: 1. Steps: 1999Average reward: -382.02271662797244 
Saving models to /home/qt21590/Documents/Projects/tactile_gym_mbrl/mbrl-lib/notebooks/trained_mbpo/sac.pth
Saved best model
Epoch: 2. SAC buffer size: 100000. Real buffer size: 13000. Rollout length: 1. Steps: 2999Average reward: -602.4292650830384 
Epoch: 3. SAC buffer size: 100000. Real buffer size: 14000. Rollout length: 1. Steps: 3999Average reward: -775.4466397417972 
Epoch: 4. SAC buffer size: 100000. Real buffer size: 15000. Rollout length: 1. Steps: 4999Average reward: -390.95800131377524 
Epoch: 5. SAC buffer size: 100000. Real buffer size: 16000. Rollout length: 1. Steps: 5999Average reward: -709.09773946305

ValueError: Expected parameter loc (Tensor of shape (512, 2)) of distribution Normal(loc: torch.Size([512, 2]), scale: torch.Size([512, 2])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        ...,
        [nan, nan],
        [nan, nan],
        [nan, nan]], device='cuda:0')

In [19]:
# Re-run a single SAC update outside the training loop.
# Relies on `agent`, `which_buffer`, `cfg` and `updates_made` still being in the
# kernel from the training cell; nothing in this cell defines them.
# The last two values returned by the update are discarded.
qf1_loss, qf2_loss, policy_loss, _, _ = agent.sac_agent.update_parameters(
    which_buffer,
    cfg.overrides.sac_batch_size,
    updates_made,
    logger=None,
    reverse_mask=True,
)

ValueError: Expected parameter loc (Tensor of shape (512, 2)) of distribution Normal(loc: torch.Size([512, 2]), scale: torch.Size([512, 2])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        ...,
        [nan, nan],
        [nan, nan],
        [nan, nan]], device='cuda:0')