In [None]:
import isaacgym
from legged_gym.envs import *
from legged_gym.utils import  get_args, export_policy_as_jit, task_registry, Logger
from isaacgym.torch_utils import *
import torch
import torch.nn as nn
import numpy as np
from rsl_rl.runners import OnPolicyRunner, OnPolicyRunner_MLP

In [None]:
class A():
    """Scratch prototype: a buffer that only holds keyword-defined extra tensors."""

    class Transition():
        """Mutable per-step record; every field starts out as ``None``."""

        # Names of the fixed per-step fields, in the order they are (re)set.
        _CORE_FIELDS = (
            "observations",
            "critic_observations",
            "actions",
            "rewards",
            "dones",
            "values",
            "actions_log_prob",
            "action_mean",
            "action_sigma",
            "hidden_states",
        )

        def _reset_core(self):
            """Set every fixed field back to ``None``."""
            for field in self._CORE_FIELDS:
                setattr(self, field, None)

        def __init__(self, **kwargs):
            """Create an empty transition.

            Only the keyword *names* are used: each one becomes a key of
            ``extra_info`` mapped to ``None``. Each name and the resulting
            dict are printed.
            """
            self._reset_core()

            self.extra_info = {}
            for name in kwargs:
                print(name)
                self.extra_info[name] = None
            print(self.extra_info)

        def clear(self):
            """Reset the fixed fields and every ``extra_info`` entry to ``None``."""
            self._reset_core()

            for name in self.extra_info:
                print(name)
                self.extra_info[name] = None


    def __init__(self, num_envs, num_transitions_per_env, device='cpu', **kwargs):
        """Allocate one zero tensor per keyword whose name ends in ``shape``.

        The tensor is stored in ``extra_info`` under the keyword name with its
        last ``_``-separated part removed (``obs_h_shape`` -> ``obs_h``) and has
        shape ``(num_transitions_per_env, num_envs, *value)``. Other keywords
        are ignored.
        """
        self.device = device
        self.num_envs = num_envs
        self.num_transitions_per_env = num_transitions_per_env
        self.extra_info = {}

        for key, shape in kwargs.items():
            if not key.endswith("shape"):
                continue
            parts = key.rsplit("_", 1)
            print(key, "/", parts, ":", shape)
            self.extra_info[parts[0]] = torch.zeros(
                num_transitions_per_env, num_envs, *shape, device=self.device)

    def get_extra(self):
        """Print name, shape and contents of every extra buffer."""
        for name in self.extra_info:
            buf = self.extra_info[name]
            print(name, buf.shape, buf)

In [None]:
# Prototype buffer: 15 envs x 24 steps, one extra tensor per `*_shape` keyword.
a = A(
    15, 24, device="cpu",
    **{
        "obs_shape": [70],
        "privileged_obs_shape": [70],
        "actions_shape": [12],
        "obs_h_shape": [70 * 10],
        "obs_f_shape": [70],
        "body_vel_shape": [3],
    },
)
# Only the keyword names matter to Transition; the empty lists are placeholders.
b = A.Transition(**{"obs_h": [], "obs_f": [], "body_vel": []})
b.clear()

In [None]:
# Print the name, shape and contents of every extra buffer held by `a`.
a.get_extra()

In [None]:
class RolloutStorage2():
    """Rollout buffer that stores the extra per-step tensors twice.

    Extra tensors are kept both in the generic ``extra_info`` dict (driven by
    ``*_shape`` keyword arguments) and in the dedicated ``obs_h`` / ``obs_f`` /
    ``body_vel`` buffers, and are printed while being stored and sampled so the
    two storage paths can be compared.

    Every buffer is a zero-initialised tensor whose two leading dimensions are
    ``(num_transitions_per_env, num_envs)``.
    """

    class Transition():
        """Mutable container for the data of one step; all fields start as ``None``."""

        def __init__(self, **kwargs):
            """Create an empty transition.

            Only the keyword *names* are used: each becomes a key of
            ``extra_info`` mapped to ``None`` (the name is printed).
            """
            self.observations = None
            self.critic_observations = None
            self.actions = None
            self.rewards = None
            self.dones = None
            self.values = None
            self.actions_log_prob = None
            self.action_mean = None
            self.action_sigma = None
            self.hidden_states = None

            self.extra_info = {}
            for key in kwargs:
                print(key)
                self.extra_info[key] = None

            # Dedicated copies of the extra tensors.
            self.obs_h = None
            self.body_vel = None
            self.obs_f = None

        def clear(self):
            """Reset every field and every ``extra_info`` entry to ``None``."""
            self.observations = None
            self.critic_observations = None
            self.actions = None
            self.rewards = None
            self.dones = None
            self.values = None
            self.actions_log_prob = None
            self.action_mean = None
            self.action_sigma = None
            self.hidden_states = None

            for key in self.extra_info:
                print(key)
                self.extra_info[key] = None

            self.obs_h = None
            self.body_vel = None
            self.obs_f = None

    def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape,
                 obs_h_sp, obs_f_sp, body_vel_sp, device='cpu', **kwargs):
        """Allocate all buffers.

        Args:
            num_envs: size of the second buffer dimension.
            num_transitions_per_env: number of steps the buffer can hold.
            obs_shape, actions_shape: trailing shapes of the observation / action buffers.
            privileged_obs_shape: trailing shape of the critic observation buffer;
                if its first element is ``None`` no such buffer is allocated.
            obs_h_sp, obs_f_sp, body_vel_sp: trailing shapes of the dedicated
                ``obs_h`` / ``obs_f`` / ``body_vel`` buffers.
            device: device on which every tensor is created.
            **kwargs: each keyword ending in ``shape`` allocates a tensor in
                ``extra_info`` under the keyword name with its last
                ``_``-separated part removed; other keywords are ignored.
        """
        self.device = device
        self.obs_shape = obs_shape
        self.privileged_obs_shape = privileged_obs_shape
        self.actions_shape = actions_shape

        # Core
        self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
        if privileged_obs_shape[0] is not None:
            self.privileged_observations = torch.zeros(num_transitions_per_env, num_envs, *privileged_obs_shape,
                                                       device=self.device)
        else:
            self.privileged_observations = None
        self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()

        # For PPO
        self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)

        self.num_transitions_per_env = num_transitions_per_env
        self.num_envs = num_envs

        # rnn
        self.saved_hidden_states_a = None
        self.saved_hidden_states_c = None

        # Index of the next step to be written by add_transitions().
        self.step = 0

        # For other info (dream waq)
        self.extra_info = {}
        for key, value in kwargs.items():
            if key.endswith("shape"):
                print(key, "/", key.rsplit("_",1), ":", value)
                self.extra_info[key.rsplit("_",1)[0]] = torch.zeros(num_transitions_per_env, num_envs,
                                                                    *value, device=self.device)

        self.obs_h = torch.zeros(num_transitions_per_env, num_envs, *obs_h_sp, device=self.device)
        self.obs_f = torch.zeros(num_transitions_per_env, num_envs, *obs_f_sp, device=self.device)
        self.body_vel = torch.zeros(num_transitions_per_env, num_envs, *body_vel_sp, device=self.device)

    @staticmethod
    def _dedicated_field(transition, name, legacy_name):
        """Return ``transition.<name>``, falling back to ``transition.<legacy_name>``.

        ``Transition`` declares ``obs_h`` / ``obs_f``; the legacy names
        ``observations_h`` / ``observations_f`` are still accepted so callers
        that set those attributes by hand keep working.

        Raises:
            AttributeError: if neither attribute holds a value.
        """
        value = getattr(transition, name, None)
        if value is None:
            value = getattr(transition, legacy_name, None)
        if value is None:
            raise AttributeError(
                "transition has neither '%s' nor '%s' set" % (name, legacy_name))
        return value

    def add_transitions(self, transition: Transition):
        """Copy one step of data into row ``self.step`` of every buffer, then advance.

        The stored values are printed after copying.

        Raises:
            AssertionError: if the buffer already holds ``num_transitions_per_env`` steps.
            AttributeError: if ``obs_h`` / ``obs_f`` (or their legacy names) are unset.
        """
        if self.step >= self.num_transitions_per_env:
            raise AssertionError("Rollout buffer overflow")
        self.observations[self.step].copy_(transition.observations)
        if self.privileged_observations is not None: self.privileged_observations[self.step].copy_(transition.critic_observations)
        self.actions[self.step].copy_(transition.actions)
        self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
        self.dones[self.step].copy_(transition.dones.view(-1, 1))
        self.values[self.step].copy_(transition.values)
        self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
        self.mu[self.step].copy_(transition.action_mean)
        self.sigma[self.step].copy_(transition.action_sigma)
        self._save_hidden_states(transition.hidden_states)

        for key in self.extra_info:
            self.extra_info[key][self.step].copy_(transition.extra_info[key])
        # Read the fields Transition actually declares (obs_h / obs_f); the
        # previous code read observations_h / observations_f, which Transition
        # never defines.
        self.obs_h[self.step].copy_(self._dedicated_field(transition, "obs_h", "observations_h"))
        self.body_vel[self.step].copy_(transition.body_vel)
        self.obs_f[self.step].copy_(self._dedicated_field(transition, "obs_f", "observations_f"))

        # Debug output: show both storage paths for the step just written.
        print("self.step:", self.step)
        for key in self.extra_info:
            print(key)
            print(self.extra_info[key][self.step])

        print("self.obs_h:\n", self.obs_h[self.step])
        print("self.obs_f:\n", self.obs_f[self.step])
        print("self.body_vel:\n", self.body_vel[self.step])

        self.step += 1

    def _save_hidden_states(self, hidden_states):
        """Store recurrent hidden states for the current step; no-op when there are none."""
        if hidden_states is None or hidden_states==(None, None):
            return
        # make a tuple out of GRU hidden states to match the LSTM format
        hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
        hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)

        # initialize if needed
        if self.saved_hidden_states_a is None:
            self.saved_hidden_states_a = [torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))]
            self.saved_hidden_states_c = [torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))]
        # copy the states
        for i in range(len(hid_a)):
            self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
            self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])


    def clear(self):
        """Rewind the write index; buffer contents are left as they are."""
        self.step = 0

    def compute_returns(self, last_values, gamma, lam):
        """Fill ``self.returns`` and ``self.advantages`` by a backward pass over the buffer.

        Args:
            last_values: value estimate used as the successor of the final stored step.
            gamma: discount factor.
            lam: decay applied to the accumulated advantage (GAE lambda).
        """
        advantage = 0
        for step in reversed(range(self.num_transitions_per_env)):
            if step == self.num_transitions_per_env - 1:
                next_values = last_values
            else:
                next_values = self.values[step + 1]
            # 0 where the step ended an episode, so nothing is bootstrapped across it.
            next_is_not_terminal = 1.0 - self.dones[step].float()
            delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            self.returns[step] = advantage + self.values[step]

        # Compute and normalize the advantages
        self.advantages = self.returns - self.values
        self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

    def get_statistics(self):
        """Return ``(mean trajectory length, mean reward)`` over the buffer.

        The last stored step is treated as the end of a trajectory for the
        purpose of this computation only; ``self.dones`` is not modified.
        """
        # Work on a copy: writing into self.dones would make a later
        # compute_returns() treat the last step as terminal.
        done = self.dones.clone()
        done[-1] = 1
        flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
        done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]))
        trajectory_lengths = (done_indices[1:] - done_indices[:-1])
        return trajectory_lengths.float().mean(), self.rewards.mean()

    def mini_batch_generator(self, num_mini_batches, num_epochs=8):
        """Yield shuffled mini-batches of the flattened buffer.

        One permutation is drawn up front and reused for every epoch. It covers
        ``num_mini_batches * (batch_size // num_mini_batches)`` flat samples, so
        any remainder at the end of the flattened buffer is never sampled.

        Yields 15-tuples: observations, critic observations, actions, values,
        advantages, returns, old log-probs, old mu, old sigma, ``(None, None)``,
        ``None``, obs_h, obs_f, body_vel and a tuple of the ``extra_info``
        batches (in ``extra_info`` insertion order). Intermediate tensors are
        printed as debug output.
        """
        batch_size = self.num_envs * self.num_transitions_per_env
        mini_batch_size = batch_size // num_mini_batches
        indices = torch.randperm(num_mini_batches*mini_batch_size, requires_grad=False, device=self.device)

        observations = self.observations.flatten(0, 1)
        if self.privileged_observations is not None:
            critic_observations = self.privileged_observations.flatten(0, 1)
        else:
            critic_observations = observations

        print("flatten")
        extra_flat = {}
        for key in self.extra_info:
            extra_flat[key] = self.extra_info[key].flatten(0, 1)
            print(key)
            print(extra_flat[key])

        observations_h = self.obs_h.flatten(0, 1)
        observations_f = self.obs_f.flatten(0, 1)
        body_vel = self.body_vel.flatten(0, 1)

        actions = self.actions.flatten(0, 1)
        values = self.values.flatten(0, 1)
        returns = self.returns.flatten(0, 1)
        old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
        advantages = self.advantages.flatten(0, 1)
        old_mu = self.mu.flatten(0, 1)
        old_sigma = self.sigma.flatten(0, 1)

        for epoch in range(num_epochs):
            for i in range(num_mini_batches):
                start = i*mini_batch_size
                end = (i+1)*mini_batch_size
                batch_idx = indices[start:end]

                obs_batch = observations[batch_idx]
                critic_observations_batch = critic_observations[batch_idx]
                actions_batch = actions[batch_idx]
                target_values_batch = values[batch_idx]
                returns_batch = returns[batch_idx]
                old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
                advantages_batch = advantages[batch_idx]
                old_mu_batch = old_mu[batch_idx]
                old_sigma_batch = old_sigma[batch_idx]

                obs_h_batch = observations_h[batch_idx]
                obs_f_batch = observations_f[batch_idx]
                body_vel_batch = body_vel[batch_idx]

                extra_batch = {}
                for key in extra_flat:
                    extra_batch[key] = extra_flat[key][batch_idx]
                print("extra_batch:\n", extra_batch)
                extra_batch_tuple = tuple(extra_batch.values())
                print("extra_batch_tuple:\n", extra_batch_tuple)

                yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, \
                       old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, \
                (None, None), None, obs_h_batch, obs_f_batch, body_vel_batch, extra_batch_tuple


In [None]:
class RolloutStorage():
    """Rollout buffer with a generic dict of extra per-step tensors.

    Every buffer is a zero-initialised tensor whose two leading dimensions are
    ``(num_transitions_per_env, num_envs)``. Extra tensors are declared through
    ``*_shape`` keyword arguments and live in ``extra_info``.
    """

    class Transition():
        """Mutable container for the data of one step; all fields start as ``None``."""

        def __init__(self, **kwargs):
            """Create an empty transition.

            Only the keyword *names* are used: each becomes a key of
            ``extra_info`` mapped to ``None`` (the name is printed).
            """
            self.observations = None
            self.critic_observations = None
            self.actions = None
            self.rewards = None
            self.dones = None
            self.values = None
            self.actions_log_prob = None
            self.action_mean = None
            self.action_sigma = None
            self.hidden_states = None

            self.extra_info = {}
            for key in kwargs:
                print(key)
                self.extra_info[key] = None

        def clear(self):
            """Reset every field and every ``extra_info`` entry to ``None``."""
            self.observations = None
            self.critic_observations = None
            self.actions = None
            self.rewards = None
            self.dones = None
            self.values = None
            self.actions_log_prob = None
            self.action_mean = None
            self.action_sigma = None
            self.hidden_states = None

            for key in self.extra_info:
                self.extra_info[key] = None

    def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape,
                 device='cpu', **kwargs):
        """Allocate all buffers.

        Args:
            num_envs: size of the second buffer dimension.
            num_transitions_per_env: number of steps the buffer can hold.
            obs_shape, actions_shape: trailing shapes of the observation / action buffers.
            privileged_obs_shape: trailing shape of the critic observation buffer;
                if its first element is ``None`` no such buffer is allocated.
            device: device on which every tensor is created.
            **kwargs: each keyword ending in ``shape`` allocates a tensor in
                ``extra_info`` under the keyword name with its last
                ``_``-separated part removed (``obs_h_shape`` -> ``obs_h``);
                other keywords are ignored.
        """
        self.device = device
        self.obs_shape = obs_shape
        self.privileged_obs_shape = privileged_obs_shape
        self.actions_shape = actions_shape

        # Core
        self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
        if privileged_obs_shape[0] is not None:
            self.privileged_observations = torch.zeros(num_transitions_per_env, num_envs, *privileged_obs_shape,
                                                       device=self.device)
        else:
            self.privileged_observations = None
        self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()

        # For PPO
        self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
        self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
        self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)

        self.num_transitions_per_env = num_transitions_per_env
        self.num_envs = num_envs

        # rnn
        self.saved_hidden_states_a = None
        self.saved_hidden_states_c = None

        # Index of the next step to be written by add_transitions().
        self.step = 0

        # For other info (dream waq)
        self.extra_info = {}
        for key, value in kwargs.items():
            if key.endswith("shape"):
                print(key, "/", key.rsplit("_",1), ":", value)
                self.extra_info[key.rsplit("_",1)[0]] = torch.zeros(num_transitions_per_env, num_envs,
                                                                    *value, device=self.device)

    def add_transitions(self, transition: Transition):
        """Copy one step of data into row ``self.step`` of every buffer, then advance.

        ``transition.extra_info`` must hold a tensor for every key of
        ``self.extra_info``.

        Raises:
            AssertionError: if the buffer already holds ``num_transitions_per_env`` steps.
        """
        if self.step >= self.num_transitions_per_env:
            raise AssertionError("Rollout buffer overflow")
        self.observations[self.step].copy_(transition.observations)
        if self.privileged_observations is not None: self.privileged_observations[self.step].copy_(transition.critic_observations)
        self.actions[self.step].copy_(transition.actions)
        self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
        self.dones[self.step].copy_(transition.dones.view(-1, 1))
        self.values[self.step].copy_(transition.values)
        self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
        self.mu[self.step].copy_(transition.action_mean)
        self.sigma[self.step].copy_(transition.action_sigma)
        self._save_hidden_states(transition.hidden_states)

        for key in self.extra_info:
            self.extra_info[key][self.step].copy_(transition.extra_info[key])

        self.step += 1

    def _save_hidden_states(self, hidden_states):
        """Store recurrent hidden states for the current step; no-op when there are none."""
        if hidden_states is None or hidden_states==(None, None):
            return
        # make a tuple out of GRU hidden states to match the LSTM format
        hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
        hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)

        # initialize if needed
        if self.saved_hidden_states_a is None:
            self.saved_hidden_states_a = [torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))]
            self.saved_hidden_states_c = [torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))]
        # copy the states
        for i in range(len(hid_a)):
            self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
            self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])


    def clear(self):
        """Rewind the write index; buffer contents are left as they are."""
        self.step = 0

    def compute_returns(self, last_values, gamma, lam):
        """Fill ``self.returns`` and ``self.advantages`` by a backward pass over the buffer.

        Args:
            last_values: value estimate used as the successor of the final stored step.
            gamma: discount factor.
            lam: decay applied to the accumulated advantage (GAE lambda).
        """
        advantage = 0
        for step in reversed(range(self.num_transitions_per_env)):
            if step == self.num_transitions_per_env - 1:
                next_values = last_values
            else:
                next_values = self.values[step + 1]
            # 0 where the step ended an episode, so nothing is bootstrapped across it.
            next_is_not_terminal = 1.0 - self.dones[step].float()
            delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            self.returns[step] = advantage + self.values[step]

        # Compute and normalize the advantages
        self.advantages = self.returns - self.values
        self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

    def get_statistics(self):
        """Return ``(mean trajectory length, mean reward)`` over the buffer.

        The last stored step is treated as the end of a trajectory for the
        purpose of this computation only; ``self.dones`` is not modified.
        """
        # Work on a copy: writing into self.dones would make a later
        # compute_returns() treat the last step as terminal.
        done = self.dones.clone()
        done[-1] = 1
        flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
        done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]))
        trajectory_lengths = (done_indices[1:] - done_indices[:-1])
        return trajectory_lengths.float().mean(), self.rewards.mean()

    def mini_batch_generator(self, num_mini_batches, num_epochs=8):
        """Yield shuffled mini-batches of the flattened buffer.

        One permutation is drawn up front and reused for every epoch. It covers
        ``num_mini_batches * (batch_size // num_mini_batches)`` flat samples, so
        any remainder at the end of the flattened buffer is never sampled.

        Yields 12-tuples: observations, critic observations, actions, values,
        advantages, returns, old log-probs, old mu, old sigma, ``(None, None)``,
        ``None`` and a tuple of the ``extra_info`` batches (in ``extra_info``
        insertion order). Intermediate tensors are printed as debug output.
        """
        batch_size = self.num_envs * self.num_transitions_per_env
        mini_batch_size = batch_size // num_mini_batches
        indices = torch.randperm(num_mini_batches*mini_batch_size, requires_grad=False, device=self.device)

        observations = self.observations.flatten(0, 1)
        if self.privileged_observations is not None:
            critic_observations = self.privileged_observations.flatten(0, 1)
        else:
            critic_observations = observations

        print("flatten")
        extra_flat = {}
        for key in self.extra_info:
            extra_flat[key] = self.extra_info[key].flatten(0, 1)
            print(key)
            print(extra_flat[key])

        actions = self.actions.flatten(0, 1)
        values = self.values.flatten(0, 1)
        returns = self.returns.flatten(0, 1)
        old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
        advantages = self.advantages.flatten(0, 1)
        old_mu = self.mu.flatten(0, 1)
        old_sigma = self.sigma.flatten(0, 1)

        for epoch in range(num_epochs):
            for i in range(num_mini_batches):
                start = i*mini_batch_size
                end = (i+1)*mini_batch_size
                batch_idx = indices[start:end]

                obs_batch = observations[batch_idx]
                critic_observations_batch = critic_observations[batch_idx]
                actions_batch = actions[batch_idx]
                target_values_batch = values[batch_idx]
                returns_batch = returns[batch_idx]
                old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
                advantages_batch = advantages[batch_idx]
                old_mu_batch = old_mu[batch_idx]
                old_sigma_batch = old_sigma[batch_idx]

                extra_batch = {}
                for key in extra_flat:
                    extra_batch[key] = extra_flat[key][batch_idx]
                print("extra_batch:\n", extra_batch)
                extra_batch_tuple = tuple(extra_batch.values())
                print("extra_batch_tuple:\n", extra_batch_tuple)

                yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, \
                       old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, \
                (None, None), None, extra_batch_tuple


In [None]:
def init_storage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape,
                 **kwargs):
    """Build a ``RolloutStorage`` and a matching ``RolloutStorage.Transition``.

    Every keyword ending in ``shape`` contributes one extra field to the
    transition, named after the keyword with its last ``_``-separated part
    removed. All keywords are forwarded unchanged to ``RolloutStorage``.

    Returns:
        ``(storage, transition)``.
    """
    extra = {}
    for key in kwargs:
        if not key.endswith("shape"):
            continue
        print(key, ":", kwargs[key])
        field = key.rsplit("_", 1)[0]
        extra[field] = None
    print("extra:", extra)

    # Storage first, then transition: both constructors print while building.
    storage = RolloutStorage(
        num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape,
        **kwargs)
    transition = RolloutStorage.Transition(**extra)
    return storage, transition

In [None]:
# Dimensions of the synthetic rollout used in the cells below.
ob_dim = 50
ac_dim = 12
num_transitions_per_env = 40
num_envs = 24
# Trailing shapes of the extra tensors, derived from ob_dim.
obs_h_shape = [ob_dim * 10]
obs_f_shape = [ob_dim]
body_vel_shape = [3]
storage, transition = init_storage(
    num_envs, num_transitions_per_env, [ob_dim], [ob_dim], [ac_dim],
    obs_h_shape=obs_h_shape,
    obs_f_shape=obs_f_shape,
    body_vel_shape=body_vel_shape,
)

In [None]:
# Build the storage and the transition directly, without init_storage().
# The action shape uses ac_dim (not a literal 12) so it stays in sync with the
# tensors generated for `actions`, `action_mean` and `action_sigma`.
storage = RolloutStorage(num_envs, num_transitions_per_env, [ob_dim], [ob_dim], [ac_dim], device="cpu",
                         obs_h_shape=obs_h_shape, obs_f_shape=obs_f_shape, body_vel_shape=body_vel_shape)

# Only the keys are used by Transition; they must match the storage's extra_info keys.
obd = {"obs_h": None, "obs_f": None, "body_vel": None}
print(obd.keys())

transition = RolloutStorage.Transition(**obd)


In [None]:
# Fill the whole buffer with random data, one transition per step. The loop
# length follows num_transitions_per_env so it can neither overflow nor
# under-fill the storage if that value is changed.
for i in range(num_transitions_per_env):
    observations = torch.rand(num_envs, ob_dim)
    critic_observations = observations
    actions = torch.rand(num_envs, ac_dim)
    rewards = torch.rand(num_envs)
    dones = torch.zeros(num_envs, dtype=torch.bool)
    values = torch.rand(num_envs, 1)
    actions_log_prob = torch.rand(num_envs)
    action_mean = torch.rand(num_envs, ac_dim)
    action_sigma = torch.rand(num_envs, ac_dim)

    observations_h = torch.rand(num_envs, ob_dim*10)
    observations_f = torch.rand(num_envs, ob_dim)
    body_vel = torch.rand(num_envs, 3)

    transition.observations = observations
    transition.critic_observations = critic_observations
    transition.actions = actions
    transition.rewards = rewards
    transition.dones = dones
    transition.values = values
    transition.actions_log_prob = actions_log_prob
    transition.action_mean = action_mean
    transition.action_sigma = action_sigma

    # Fill only the extra fields this transition declares; unknown keys stay None.
    step_extras = {"obs_h": observations_h, "obs_f": observations_f, "body_vel": body_vel}
    for key in transition.extra_info:
        if key in step_extras:
            transition.extra_info[key] = step_extras[key]

    storage.add_transitions(transition)
    transition.clear()

# Rewind the write index; the stored data stays available for sampling.
storage.clear()

In [None]:
# Draw mini-batches from the filled storage and print the action and extra batches.
num_mini_batches = 4
num_learning_epochs = 4
generator = storage.mini_batch_generator(num_mini_batches, num_learning_epochs)
cnt = 0
for batch in generator:
    # Each item is a 12-tuple whose last element holds the extra batches.
    (obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch,
     returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch,
     hid_states_batch, masks_batch, extra_batches) = batch
    obs_h_batch, obs_f_batch, body_vel_batch = extra_batches
    print("cnt:", cnt)
    print("actions_batch:", actions_batch.shape, actions_batch)
    print("obs_h_batch:", obs_h_batch.shape, obs_h_batch)
    print("obs_f_batch:", obs_f_batch.shape, obs_f_batch)
    print("body_vel_batch:", body_vel_batch.shape, body_vel_batch)
    cnt += 1
    

In [None]:
# GRU test
GRU_INPUT_SIZE, GRU_HIDDEN_SIZE, GRU_NUM_LAYERS = 8, 20, 1
GRU_SEQ_LEN, GRU_BATCH_SIZE = 5, 2
rnn = nn.GRU(input_size=GRU_INPUT_SIZE, hidden_size=GRU_HIDDEN_SIZE, num_layers=GRU_NUM_LAYERS)
# obs dimensions: sequence length, batch size, input size
obs = torch.randn(GRU_SEQ_LEN, GRU_BATCH_SIZE, GRU_INPUT_SIZE)
# hidden state dimensions: num_layers, batch size, hidden_size
h0 = torch.randn(GRU_NUM_LAYERS, GRU_BATCH_SIZE, GRU_HIDDEN_SIZE)
output, hn = rnn(obs, h0)
print("obs:\n", obs)
print("output:", output.shape, "\n", output)
print("hn:", hn.shape, "\n", hn)

In [None]:
# flatten(0, 1) merges the first two dimensions: (2, 2, 2) -> (4, 2).
t = torch.arange(1, 9).reshape(2, 2, 2)
c = torch.flatten(t, start_dim=0, end_dim=1)
print(t.shape)
print(c.shape)
print(t)
print(c)