## Import packages

In [None]:
!pip install -q numpy
!pip install -q matplotlib
!pip install -q mujoco
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy

%env MUJOCO_GL=egl

In [None]:
import os
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt
import mujoco
from copy import deepcopy
import torch.nn.functional as F

import scipy.signal
import time

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence
from torch.optim.lr_scheduler import CosineAnnealingLR

## Define MuJoCo Environment

In [None]:
# MJCF (MuJoCo XML) description of the planar cheetah model. It is parsed by
# HalfCheetahEnv.initialize_simulation via mujoco.MjModel.from_xml_string.
# The leading XML comment lists the joint/actuator ordering of the state space.
xml_string="""<!-- Cheetah Model

    The state space is populated with joints in the order that they are
    defined in this file. The actuators also operate on joints.

    State-Space (name/joint/parameter):
        - rootx     slider      position (m)
        - rootz     slider      position (m)
        - rooty     hinge       angle (rad)
        - bthigh    hinge       angle (rad)
        - bshin     hinge       angle (rad)
        - bfoot     hinge       angle (rad)
        - fthigh    hinge       angle (rad)
        - fshin     hinge       angle (rad)
        - ffoot     hinge       angle (rad)
        - rootx     slider      velocity (m/s)
        - rootz     slider      velocity (m/s)
        - rooty     hinge       angular velocity (rad/s)
        - bthigh    hinge       angular velocity (rad/s)
        - bshin     hinge       angular velocity (rad/s)
        - bfoot     hinge       angular velocity (rad/s)
        - fthigh    hinge       angular velocity (rad/s)
        - fshin     hinge       angular velocity (rad/s)
        - ffoot     hinge       angular velocity (rad/s)

    Actuators (name/actuator/parameter):
        - bthigh    hinge       torque (N m)
        - bshin     hinge       torque (N m)
        - bfoot     hinge       torque (N m)
        - fthigh    hinge       torque (N m)
        - fshin     hinge       torque (N m)
        - ffoot     hinge       torque (N m)

-->
<mujoco model="cheetah">
  <compiler angle="radian" coordinate="local" inertiafromgeom="true" settotalmass="14"/>
  <default>
    <joint armature=".1" damping=".01" limited="true" solimplimit="0 .8 .03" solreflimit=".02 1" stiffness="8"/>
    <geom conaffinity="0" condim="3" contype="1" friction=".4 .1 .1" rgba="0.8 0.6 .4 1" solimp="0.0 0.8 0.01" solref="0.02 1"/>
    <motor ctrllimited="true" ctrlrange="-1 1"/>
  </default>
  <size nstack="300000" nuser_geom="1"/>
  <option gravity="0 0 -9.81" timestep="0.01"/>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 0 .7">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <joint armature="0" axis="1 0 0" damping="0" limited="false" name="rootx" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 0 1" damping="0" limited="false" name="rootz" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 1 0" damping="0" limited="false" name="rooty" pos="0 0 0" stiffness="0" type="hinge"/>
      <geom fromto="-.5 0 0 .5 0 0" name="torso" size="0.046" type="capsule"/>
      <geom axisangle="0 1 0 .87" name="head" pos=".6 0 .1" size="0.046 .15" type="capsule"/>
      <!-- <site name='tip'  pos='.15 0 .11'/>-->
      <body name="bthigh" pos="-.5 0 0">
        <joint axis="0 1 0" damping="6" name="bthigh" pos="0 0 0" range="-.52 1.05" stiffness="240" type="hinge"/>
        <geom axisangle="0 1 0 -3.8" name="bthigh" pos=".1 0 -.13" size="0.046 .145" type="capsule"/>
        <body name="bshin" pos=".16 0 -.25">
          <joint axis="0 1 0" damping="4.5" name="bshin" pos="0 0 0" range="-.785 .785" stiffness="180" type="hinge"/>
          <geom axisangle="0 1 0 -2.03" name="bshin" pos="-.14 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .15" type="capsule"/>
          <body name="bfoot" pos="-.28 0 -.14">
            <joint axis="0 1 0" damping="3" name="bfoot" pos="0 0 0" range="-.4 .785" stiffness="120" type="hinge"/>
            <geom axisangle="0 1 0 -.27" name="bfoot" pos=".03 0 -.097" rgba="0.9 0.6 0.6 1" size="0.046 .094" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="fthigh" pos=".5 0 0">
        <joint axis="0 1 0" damping="4.5" name="fthigh" pos="0 0 0" range="-1 .7" stiffness="180" type="hinge"/>
        <geom axisangle="0 1 0 .52" name="fthigh" pos="-.07 0 -.12" size="0.046 .133" type="capsule"/>
        <body name="fshin" pos="-.14 0 -.24">
          <joint axis="0 1 0" damping="3" name="fshin" pos="0 0 0" range="-1.2 .87" stiffness="120" type="hinge"/>
          <geom axisangle="0 1 0 -.6" name="fshin" pos=".065 0 -.09" rgba="0.9 0.6 0.6 1" size="0.046 .106" type="capsule"/>
          <body name="ffoot" pos=".13 0 -.18">
            <joint axis="0 1 0" damping="1.5" name="ffoot" pos="0 0 0" range="-.5 .5" stiffness="60" type="hinge"/>
            <geom axisangle="0 1 0 -.6" name="ffoot" pos=".045 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .07" type="capsule"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor gear="120" joint="bthigh" name="bthigh"/>
    <motor gear="90" joint="bshin" name="bshin"/>
    <motor gear="60" joint="bfoot" name="bfoot"/>
    <motor gear="120" joint="fthigh" name="fthigh"/>
    <motor gear="60" joint="fshin" name="fshin"/>
    <motor gear="30" joint="ffoot" name="ffoot"/>
  </actuator>
</mujoco>"""

In [None]:
class HalfCheetahEnv():
  """HalfCheetah environment driven directly through the `mujoco` bindings.

  The model is built from the module-level `xml_string`. Observations are the
  generalized positions without the root x coordinate followed by all
  generalized velocities. The reward is
  `forward_reward_weight * x_velocity - ctrl_cost_weight * sum(action**2)`.
  Episodes never terminate on their own (`terminated` is always False).

  Args:
    frame_skip: number of physics steps simulated per call to `step`.
    forward_reward_weight: multiplier on the forward x velocity.
    ctrl_cost_weight: multiplier on the squared action norm.
    reset_noise_scale: scale of the random perturbation applied by `reset`.
  """

  def __init__(
      self,
      frame_skip=5,
      forward_reward_weight=1.0,
      ctrl_cost_weight=0.1,
      reset_noise_scale=0.1
      ):

    self.frame_skip = frame_skip
    self.forward_reward_weight = forward_reward_weight
    self.ctrl_cost_weight = ctrl_cost_weight
    self.reset_noise_scale = reset_noise_scale

    self.initialize_simulation()
    # Nominal state captured right after the model is reset; `reset` perturbs it.
    self.init_qpos = self.data.qpos.ravel().copy()
    self.init_qvel = self.data.qvel.ravel().copy()
    # Wall-clock duration covered by one environment step.
    self.dt = self.model.opt.timestep * self.frame_skip

    self.observation_dim = 17
    self.action_dim = 6
    self.action_limit = 1.

  def initialize_simulation(self):
    """Create the MuJoCo model, its data and an offscreen renderer."""
    self.model = mujoco.MjModel.from_xml_string(xml_string)
    self.data = mujoco.MjData(self.model)
    mujoco.mj_resetData(self.model, self.data)
    self.renderer = mujoco.Renderer(self.model)

  def reset_simulation(self):
    """Reset the simulation data to the model defaults."""
    mujoco.mj_resetData(self.model, self.data)

  def step_mujoco_simulation(self, ctrl, n_frames):
    """Apply `ctrl` to the actuators and advance the physics by `n_frames` steps.

    The render scene is intentionally not touched here; `render` refreshes it
    on demand, so stepping without rendering does no rendering work.
    """
    self.data.ctrl[:] = ctrl
    mujoco.mj_step(self.model, self.data, nstep=n_frames)

  def set_state(self, qpos, qvel):
    """Overwrite positions/velocities and recompute derived quantities."""
    self.data.qpos[:] = np.copy(qpos)
    self.data.qvel[:] = np.copy(qvel)
    if self.model.na == 0:
      # With no activation states `data.act` has length 0, so this writes nothing.
      self.data.act[:] = None
    mujoco.mj_forward(self.model, self.data)

  def sample_action(self):
    """Return a uniformly random action in [-action_limit, action_limit)."""
    return (2.*np.random.uniform(size=(self.action_dim,)) - 1)*self.action_limit

  def step(self, action):
    """Simulate one environment step.

    Returns:
      (observation, reward, terminated, info) where `terminated` is always
      False and `info` holds the x position/velocity and both reward terms.
    """
    x_position_before = self.data.qpos[0]
    self.step_mujoco_simulation(action, self.frame_skip)
    x_position_after = self.data.qpos[0]
    # Finite-difference velocity over the whole (frame-skipped) step.
    x_velocity = (x_position_after - x_position_before) / self.dt

    # Rewards
    ctrl_cost = self.ctrl_cost_weight * np.sum(np.square(action))
    forward_reward = self.forward_reward_weight * x_velocity
    observation = self.get_obs()
    reward = forward_reward - ctrl_cost
    terminated = False
    info = {
        "x_position": x_position_after,
        "x_velocity": x_velocity,
        "reward_run": forward_reward,
        "reward_ctrl": -ctrl_cost,
    }
    return observation, reward, terminated, info

  def get_obs(self):
    """Return `qpos[1:]` concatenated with `qvel` as a flat array."""
    position = self.data.qpos.flat.copy()
    velocity = self.data.qvel.flat.copy()
    # Drop the root x coordinate (index 0 of qpos).
    position = position[1:]

    observation = np.concatenate((position, velocity)).ravel()
    return observation

  def render(self, camera=0):
    """Render the current simulation state and return the renderer's image.

    The scene is refreshed from the live simulation data on every call, so
    the frame also reflects a preceding `reset` or `set_state`, not only the
    last `step`.

    Args:
      camera: camera passed to `Renderer.update_scene`; 0 is the default used
        by this environment.
    """
    self.renderer.update_scene(self.data, camera)
    return self.renderer.render()

  def reset(self):
    """Reset to the nominal state plus random noise and return the observation.

    Positions receive uniform noise in [-reset_noise_scale, reset_noise_scale];
    velocities receive Gaussian noise scaled by `reset_noise_scale`.
    """
    self.reset_simulation()
    noise_low = -self.reset_noise_scale
    noise_high = self.reset_noise_scale
    qpos = self.init_qpos + np.random.uniform(
        low=noise_low, high=noise_high, size=self.model.nq
    )
    qvel = (
        self.init_qvel
        + self.reset_noise_scale * np.random.standard_normal(self.model.nv)
    )
    self.set_state(qpos, qvel)
    observation = self.get_obs()
    return observation

In [None]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

## Define the Buffer Class

In [None]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions stored as float32 tensors.

    `_size` is the number of valid rows and `_pointer` is the row that the
    next `add_batch` call writes first. Once the buffer is full, new data
    overwrites the oldest rows.

    Args:
        obs_dim: width of an observation row.
        action_dim: width of an action row.
        buffer_size: maximum number of stored transitions.
        device: torch device on which all storage tensors live.
    """

    def __init__(self, obs_dim, action_dim, buffer_size, device="cpu"):
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0

        self._obses = torch.zeros((buffer_size, obs_dim), dtype=torch.float32, device=device)
        self._actions = torch.zeros((buffer_size, action_dim), dtype=torch.float32, device=device)
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_obses = torch.zeros((buffer_size, obs_dim), dtype=torch.float32, device=device)
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy `data` into a float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def load_dataset(self, dataset):
        """Replace the buffer contents with an offline dataset.

        Args:
            dataset: mapping with "observations", "actions", "rewards",
                "next_observations" and "terminals"; rewards and terminals
                are 1-D and get a trailing axis added here.

        Raises:
            ValueError: if the dataset has more rows than the buffer capacity.
        """
        n_transitions = dataset["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                f"Dataset has {n_transitions} transitions but the buffer holds only {self._buffer_size}"
            )
        self._obses[:n_transitions] = self._to_tensor(dataset["observations"])
        self._actions[:n_transitions] = self._to_tensor(dataset["actions"])
        self._rewards[:n_transitions] = self._to_tensor(dataset["rewards"][..., None])
        self._next_obses[:n_transitions] = self._to_tensor(dataset["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(dataset["terminals"][..., None])
        # The dataset occupies rows [0, n); keep the counters inside capacity so
        # a repeated load can never make `_size` exceed the storage.
        self._size = n_transitions
        self._pointer = n_transitions % self._buffer_size
        print(f"Dataset size: {n_transitions}")

    def add_batch(self, observations, next_observations, actions, rewards, terminals):
        """Append a batch of transitions, wrapping around when the end is reached.

        All arguments are array-likes whose first axis is the batch axis;
        `rewards` and `terminals` must already have shape (batch, 1).

        Raises:
            ValueError: if the batch is larger than the buffer capacity.
        """
        batch_size = len(terminals)
        if batch_size == 0:
            return
        if batch_size > self._buffer_size:
            raise ValueError(
                f"Batch of {batch_size} transitions does not fit in a buffer of size {self._buffer_size}"
            )

        # Destination rows; the modulo handles wrap-around in a single write.
        slots = (self._pointer + torch.arange(batch_size, device=self._device)) % self._buffer_size
        self._obses[slots] = self._to_tensor(np.asarray(observations))
        self._next_obses[slots] = self._to_tensor(np.asarray(next_observations))
        self._actions[slots] = self._to_tensor(np.asarray(actions))
        self._rewards[slots] = self._to_tensor(np.asarray(rewards))
        self._dones[slots] = self._to_tensor(np.asarray(terminals))

        self._pointer = (self._pointer + batch_size) % self._buffer_size
        self._size = min(self._size + batch_size, self._buffer_size)

    def sample(self, batch_size):
        """Draw `batch_size` transitions uniformly (with replacement) from all valid rows.

        Returns:
            [states, actions, rewards, next_states, dones] tensors.

        Raises:
            ValueError: if the buffer is empty.
        """
        if self._size == 0:
            raise ValueError("Cannot sample from an empty buffer")
        # Every valid row is eligible, including rows past the write pointer
        # after the buffer has wrapped around.
        indices = np.random.randint(0, self._size, size=batch_size)
        states = self._obses[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_obses[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def sample_all(self, batch_size):
        """Yield every valid transition exactly once, in shuffled mini-batches.

        The final batch may be smaller than `batch_size`.
        """
        indices = np.arange(self._size)
        np.random.shuffle(indices)
        for batch_start in range(0, self._size, batch_size):
            batch_indices = indices[batch_start:batch_start + batch_size]

            states = self._obses[batch_indices]
            actions = self._actions[batch_indices]
            rewards = self._rewards[batch_indices]
            next_states = self._next_obses[batch_indices]
            dones = self._dones[batch_indices]
            yield [states, actions, rewards, next_states, dones]

    def normalize_states(self, eps = 1e-3, mean=None, std=None):
        """Standardize stored observations and next observations in place.

        Args:
            eps: added to the computed standard deviation to avoid division
                by zero; not applied to a caller-supplied `std`.
            mean: optional per-dimension mean to use instead of the buffer's.
            std: optional per-dimension std to use instead of the buffer's.

        Returns:
            (mean, std) as flat numpy arrays, as actually used.

        Raises:
            ValueError: if the buffer is empty.
        """
        if self._size == 0:
            raise ValueError("Cannot normalize an empty buffer")

        # Statistics come from the filled rows only; the zero-initialised,
        # not-yet-written rows would otherwise bias both mean and std.
        valid_obses = self._obses[:self._size]
        if mean is None:
            mean = valid_obses.mean(0, keepdim=True)
        else:
            mean = torch.as_tensor(mean, dtype=torch.float32, device=self._device).reshape(1, -1)
        if std is None:
            std = valid_obses.std(0, keepdim=True) + eps
        else:
            std = torch.as_tensor(std, dtype=torch.float32, device=self._device).reshape(1, -1)

        self._obses[:self._size] = (valid_obses - mean) / std
        self._next_obses[:self._size] = (self._next_obses[:self._size] - mean) / std
        return mean.cpu().numpy().flatten(), std.cpu().numpy().flatten()

In [None]:
import pickle

def get_dataset(env, num_trajs=100, max_ep_len = 1000, file_name='expert_dataset.pickle'):
  """Load an offline dataset from disk, or collect one with random actions.

  Args:
    env: environment exposing `reset()`, `step(action)` and `sample_action()`.
      Only used when `file_name` is None.
    num_trajs: number of episodes' worth of steps to collect.
    max_ep_len: step limit after which an episode is cut and the env is reset.
    file_name: path of a pickled dataset to load; pass None to collect a
      fresh dataset from `env` using uniformly random actions.

  Returns:
    The unpickled object when `file_name` is given. Otherwise a dict with
    float32 arrays "observations", "actions", "next_observations", "rewards",
    a boolean array "terminals", and "ravg" (the return of each completed
    episode).
  """
  if file_name is not None:
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(file_name, 'rb') as f:
      dataset = pickle.load(f)
    # Datasets saved without per-episode returns are still usable; only the
    # summary line is skipped for them.
    if 'ravg' in dataset:
      print(f"Average return of Offline Dataset: {np.mean(dataset['ravg'])}")
    return dataset

  actions = []
  observations = []
  next_observations = []
  rewards = []
  terminals = []
  ravg = []

  o, ep_ret, ep_len = env.reset(), 0, 0
  for t in range(num_trajs * max_ep_len):
    a = env.sample_action()

    o2, r, d, _ = env.step(a)
    ep_ret += r
    ep_len += 1
    # Hitting the step limit is a time-out, not a true terminal state.
    d = False if ep_len==max_ep_len else d

    observations.append(o)
    actions.append(a)
    next_observations.append(o2)
    rewards.append(r)
    terminals.append(d)

    o = o2
    if d or (ep_len == max_ep_len):
        ravg.append(ep_ret)
        o, ep_ret, ep_len = env.reset(), 0, 0
  average_return = np.mean(ravg) if ravg else float("nan")
  print(f"Average Return of Offline Dataset: {average_return}")

  return {
      "observations": np.array(observations).astype(np.float32),
      "actions": np.array(actions).astype(np.float32),
      "next_observations": np.array(next_observations).astype(np.float32),
      "rewards": np.array(rewards).astype(np.float32),
      "terminals": np.array(terminals).astype(np.bool_),
      # Included so a dataset pickled from this dict loads through the branch above.
      "ravg": np.array(ravg, dtype=np.float32),
  }

# Build the simulator and fetch the offline dataset with get_dataset's defaults.
env = HalfCheetahEnv()
dataset = get_dataset(env)

# Problem dimensions reused by the buffers and networks below.
obs_dim, act_dim, act_lim = env.observation_dim, env.action_dim, env.action_limit

# Replay buffer holding the offline data on the selected device.
replay_buffer = ReplayBuffer(obs_dim, act_dim, buffer_size=2000000, device=device)
replay_buffer.load_dataset(dataset)

## Define the Dynamics Model Network

In [None]:
def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an `nn.Sequential`.

    Consecutive entries of `sizes` define one `nn.Linear` each. Every linear
    layer is followed by a fresh `activation()` instance, except the last,
    which is followed by `output_activation()`.
    """
    last_layer = len(sizes) - 2
    modules = []
    for layer_idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(fan_in, fan_out))
        nonlinearity = output_activation if layer_idx == last_layer else activation
        modules.append(nonlinearity())
    return nn.Sequential(*modules)

class Swish(nn.Module):
    """Swish activation: multiplies the input by its own sigmoid."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class EnsembleModel(nn.Module):
    """Ensemble of independent MLPs predicting a Gaussian over (delta-obs, reward).

    Each member maps a concatenated (observation, action) row to
    `2 * (obs_dim + reward_dim)` numbers: the first half is the predicted
    mean, the second half the log-variance, which is softly bounded by the
    learnable `min_logvar` / `max_logvar`.

    Args:
        obs_dim: observation width.
        act_dim: action width.
        hidden_sizes: hidden layer widths of each member network.
        activation: hidden activation class.
        output_activation: activation class applied after the last layer.
        reward_dim: reward width.
        ensemble_size: number of member networks.
        num_elite: number of members listed in `elite_model_idxes`.
        decay_weights: optional per-layer L2 coefficients for `get_decay_loss`.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation=Swish, output_activation=nn.Identity, reward_dim=1, ensemble_size=7, num_elite=5, decay_weights=None):
        super().__init__()

        self.out_dim = obs_dim + reward_dim

        # Members live in a plain list (callers index `ensemble_models[i]`);
        # add_module registers each one so parameters()/to()/state_dict() see it.
        self.ensemble_models = [mlp([obs_dim + act_dim] + list(hidden_sizes) + [self.out_dim * 2], activation, output_activation) for _ in range(ensemble_size)]
        for i, member in enumerate(self.ensemble_models):
            self.add_module("model_{}".format(i), member)

        self.obs_dim = obs_dim
        self.action_dim = act_dim
        self.num_elite = num_elite
        self.ensemble_size = ensemble_size
        self.decay_weights = decay_weights
        self.elite_model_idxes = torch.tensor([i for i in range(num_elite)])
        # Assigning an nn.Parameter attribute registers it; the parameters
        # follow the module when the caller moves it with `.to(...)`.
        self.max_logvar = nn.Parameter(torch.ones((1, self.out_dim)).float() / 2, requires_grad=True)
        self.min_logvar = nn.Parameter(-torch.ones((1, self.out_dim)).float() * 10, requires_grad=True)

    def predict(self, input):
        """Run every ensemble member on `input`.

        Args:
            input: (obs, action) rows. A 1-D input is treated as a single
                row, a 2-D input is fed to every member, and a 3-D input is
                split along axis 0 with one slice per member. Non-tensor
                inputs are converted to float32 on the model's device.

        Returns:
            (mean, logvar), each stacked along a leading member axis.

        Raises:
            ValueError: if `input` has more than three dimensions.
        """
        # convert input to tensors
        if not isinstance(input, torch.Tensor):
            input = torch.as_tensor(np.asarray(input), dtype=torch.float32, device=self.max_logvar.device)
        if input.dim() == 1:
            input = input.unsqueeze(0)

        # predict
        if input.dim() == 3:
            model_outputs = [net(ip) for ip, net in zip(torch.unbind(input), self.ensemble_models)]
        elif input.dim() == 2:
            model_outputs = [net(input) for net in self.ensemble_models]
        else:
            raise ValueError(f"predict expects a 1-D, 2-D or 3-D input, got shape {tuple(input.shape)}")
        predictions = torch.stack(model_outputs)

        mean = predictions[:, :, :self.out_dim]
        logvar = predictions[:, :, self.out_dim:]
        # Smoothly clamp the log-variance into [min_logvar, max_logvar].
        logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
        logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)

        return mean, logvar

    def get_decay_loss(self):
        """Return the weighted L2 penalty over the linear layers of all members.

        `decay_weights[k]` scales the squared weights of the k-th linear layer
        of every member. Returns a zero scalar when no decay weights are set.
        """
        if self.decay_weights is None:
            return torch.zeros((), device=self.max_logvar.device)

        decay_losses = []
        for model_net in self.ensemble_models:
            # nn.Sequential exposes no `.weights`; collect the linear layers' weight matrices.
            layer_weights = [layer.weight for layer in model_net if isinstance(layer, nn.Linear)]
            decay_losses.extend(
                decay_weight * torch.sum(torch.square(weight))
                for decay_weight, weight in zip(self.decay_weights, layer_weights)
            )
        if not decay_losses:
            return torch.zeros((), device=self.max_logvar.device)
        return torch.sum(torch.stack(decay_losses))

## Ensemble Dynamic Model

In [None]:
# Build the ensemble dynamics model and its optimizer.
ensemble_dynamic_model = EnsembleModel(obs_dim=obs_dim, act_dim=act_dim, hidden_sizes=[200, 200, 200, 200]).to(device)
dynamics_lr = 1e-3
ensemble_dynamic_model_optimizer = Adam(ensemble_dynamic_model.parameters(), dynamics_lr)

# Standardize the stored observations, then draw one training batch.
data_mean, data_std = replay_buffer.normalize_states()
batch = replay_buffer.sample(256)
batch = [b.to(device) for b in batch]
observations, actions, rewards, next_observations, dones = batch
# Regression target: observation change concatenated with the reward.
delta_observations = next_observations - observations
groundtruths = torch.cat((delta_observations, rewards), dim=-1)

model_input = torch.cat([observations, actions], dim=-1).to(device)
predictions = ensemble_dynamic_model.predict(model_input)
pred_means, pred_logvars = predictions
# Per-member MSE (averaged over batch and output dims), summed over members.
# NOTE(review): pred_logvars does not enter the loss; only the means are fitted.
train_mse_losses = torch.mean(torch.pow(pred_means - groundtruths, 2), dim=(1, 2))
train_mse_loss = torch.sum(train_mse_losses)
# Built out of place so `train_mse_loss` keeps the pure MSE value; an in-place
# `+=` on an alias would silently overwrite it.
logvar_bound_loss = 0.01 * torch.sum(ensemble_dynamic_model.max_logvar) - 0.01 * torch.sum(ensemble_dynamic_model.min_logvar)
train_transition_loss = train_mse_loss + logvar_bound_loss

print(pred_means.shape, pred_logvars.shape, train_mse_losses.shape)

Dynamic Model Update

In [None]:
# One gradient step on the dynamics model: clear stale gradients,
# backpropagate the transition loss, then apply the Adam update.
ensemble_dynamic_model_optimizer.zero_grad()
train_transition_loss.backward()
ensemble_dynamic_model_optimizer.step()

Evaluate on the validation data (here, one pass over the replay buffer)

In [None]:
# Accumulate, per ensemble member, the sum over batches of the mean squared error.
eval_mse_total_losses=np.zeros((ensemble_dynamic_model.ensemble_size,))
# Evaluation needs no gradients; no_grad avoids building an autograd graph per batch.
with torch.no_grad():
  for eval_batch in replay_buffer.sample_all(256):
    eval_batch = [b.to(device) for b in eval_batch]
    eval_observations, eval_actions, eval_rewards, eval_next_observations, eval_dones = eval_batch
    eval_delta_observations = eval_next_observations - eval_observations
    eval_groundtruths = torch.cat((eval_delta_observations, eval_rewards), dim=-1)
    eval_model_input = torch.cat([eval_observations, eval_actions], dim=-1).to(device)
    eval_predictions = ensemble_dynamic_model.predict(eval_model_input)
    eval_pred_means, eval_pred_logvars = eval_predictions
    eval_mse_losses = torch.mean(torch.pow(eval_pred_means - eval_groundtruths, 2), dim=(1, 2)).to('cpu').numpy()
    eval_mse_total_losses += eval_mse_losses
print(eval_mse_total_losses)

Save the model with the best evaluation loss

In [None]:
# Best evaluation loss seen so far per ensemble member (starts at a huge sentinel)
# and a private copy of the weights that achieved it.
best_snapshot_losses = np.full((ensemble_dynamic_model.ensemble_size,), 1e10)
model_best_snapshots = [
    deepcopy(member.state_dict()) for member in ensemble_dynamic_model.ensemble_models
]

# Keep a member's new weights only when its loss improved by more than 1% (relative).
updated = False
for i, (current_loss, best_loss) in enumerate(zip(eval_mse_total_losses, best_snapshot_losses)):
  improvement = (best_loss - current_loss) / best_loss
  if improvement <= 0.01:
    continue
  best_snapshot_losses[i] = current_loss
  model_best_snapshots[i] = deepcopy(ensemble_dynamic_model.ensemble_models[i].state_dict())
  updated = True

Load the best model parameters

In [None]:
# Restore every ensemble member to its best recorded weights.
for member, snapshot in zip(ensemble_dynamic_model.ensemble_models, model_best_snapshots):
  member.load_state_dict(snapshot)

## Define the Policy and Q-Function Networks

In [None]:
# Clamp range for the policy's predicted log standard deviation.
LOG_STD_MAX = 2
LOG_STD_MIN = -20
MEAN_MIN = -9.0
MEAN_MAX = 9.0


class SquashedGaussianMLPActor(nn.Module):
    """Gaussian policy whose samples are squashed by tanh and scaled by `act_limit`.

    Args:
        obs_dim: observation width.
        act_dim: action width.
        hidden_sizes: widths of the shared trunk's layers.
        activation: activation class used after every trunk layer.
        act_limit: scale applied to the tanh-squashed action.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

    def _pre_squash_distribution(self, obs):
        """Return the Normal over pre-tanh actions for `obs`."""
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return Normal(mu, torch.exp(log_std))

    @staticmethod
    def _squashed_log_prob(pi_distribution, pre_tanh_action):
        """Log-density of tanh(pre_tanh_action), summed over the action axis.

        Both terms reduce over the last axis, so batched and unbatched inputs
        are handled alike. The second term is the tanh change-of-variables
        correction written as 2*(log 2 - u - softplus(-2u)).
        """
        logp = pi_distribution.log_prob(pre_tanh_action).sum(axis=-1)
        correction = (2*(np.log(2) - pre_tanh_action - F.softplus(-2*pre_tanh_action))).sum(axis=-1)
        return logp - correction

    def log_prob(self, obs, actions):
        """Return the squashed-Gaussian log-probability of `actions` given `obs`.

        NOTE(review): `actions` is used as the pre-tanh value, and the result
        is additionally summed over its last remaining axis (the batch axis
        for 2-D inputs), so a batch yields one scalar. Both are kept as in the
        original; confirm against callers.
        """
        pi_distribution = self._pre_squash_distribution(obs)
        log_prob = self._squashed_log_prob(pi_distribution, actions)
        return log_prob.sum(-1)

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Sample (or take the mean) action for `obs`.

        Args:
            obs: observation tensor, batched or a single 1-D observation.
            deterministic: use the distribution mean instead of a sample.
            with_logprob: also return the log-probability of the action.

        Returns:
            (action, logp): the action is tanh-squashed and scaled by
            `act_limit`; `logp` is None when `with_logprob` is False.
        """
        pi_distribution = self._pre_squash_distribution(obs)
        if deterministic:
            pi_action = pi_distribution.mean
        else:
            # Reparameterized sample so gradients flow into the policy.
            pi_action = pi_distribution.rsample()

        if with_logprob:
            logp_pi = self._squashed_log_prob(pi_distribution, pi_action)
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        return pi_action, logp_pi

class MLPQFunction(nn.Module):
    """State-action value network: an MLP over the concatenated (obs, act) input."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        layer_sizes = [obs_dim + act_dim] + list(hidden_sizes) + [1]
        self.q = mlp(layer_sizes, activation)

    def forward(self, obs, act):
        """Return Q(obs, act) with the trailing singleton axis removed."""
        obs_act = torch.cat([obs, act], dim=-1)
        q_value = self.q(obs_act)
        # Dropping the last axis keeps the shape aligned with per-sample targets.
        return q_value.squeeze(-1)

Define Network and Optimizer

In [None]:
# Architecture shared by the policy and both critics.
hidden_sizes = [256, 256, 256]
activation = nn.ReLU

# Soft actor-critic hyper-parameters.
soft_update_tau = 5e-3
policy_lr = 1e-4
qf_lr = 3e-4
target_entropy = -act_dim

# Twin critics, their target copies, then the actor. The construction order is
# kept so the random initialisation sequence is unchanged.
qf1, qf2 = (MLPQFunction(obs_dim, act_dim, hidden_sizes, activation).to(device) for _ in range(2))
target_qf1, target_qf2 = (deepcopy(critic).to(device) for critic in (qf1, qf2))
policy = SquashedGaussianMLPActor(obs_dim, act_dim, hidden_sizes, activation, act_lim).to(device)
# Learnable log of the entropy temperature.
log_alpha = torch.zeros(1, requires_grad=True, device=device)

# One Adam optimizer per trainable component; the temperature uses the policy learning rate.
alpha_optimizer = Adam([log_alpha], lr=policy_lr)
policy_optimizer = Adam(policy.parameters(), lr=policy_lr)
qf1_optimizer, qf2_optimizer = (Adam(critic.parameters(), lr=qf_lr) for critic in (qf1, qf2))

# Model-rollout settings.
rollout_freq = 2
rollout_batch_size = 1000
rollout_length = 5

# Separate buffer that stores the transitions generated by the learned model.
model_buffer = ReplayBuffer(obs_dim, act_dim, buffer_size=1000000, device=device)

# Start the imagined rollouts from observations sampled out of the offline data.
init_transitions = replay_buffer.sample(rollout_batch_size)
# rollout
observations = init_transitions[0]
# Rollouts only generate data, so no gradients and no action log-probabilities are needed.
with torch.no_grad():
    for _ in range(rollout_length):
        actions, _ = policy(observations, with_logprob=False)
        model_input = torch.cat([observations, actions], dim=-1).to(device)
        pred_diff_means, pred_diff_logvars = ensemble_dynamic_model.predict(model_input)
        observations = observations.cpu().numpy()
        actions = actions.cpu().numpy()
        ensemble_model_stds = pred_diff_logvars.exp().sqrt().cpu().numpy()
        pred_diff_means = pred_diff_means.cpu().numpy()
        # Sample from every member's predicted Gaussian.
        pred_diff_means = pred_diff_means + np.random.normal(size=pred_diff_means.shape) * ensemble_model_stds

        # For each row, keep the prediction of one randomly chosen elite member.
        num_models, batch_size, _ = pred_diff_means.shape
        model_idxes = np.random.choice(ensemble_dynamic_model.elite_model_idxes, size=batch_size)
        batch_idxes = np.arange(0, batch_size)
        pred_diff_samples = pred_diff_means[model_idxes, batch_idxes]

        # The last column is the reward; the others are the observation change.
        next_observations, rewards = pred_diff_samples[:, :-1] + observations, pred_diff_samples[:, [-1]]
        # Uncertainty penalty: largest predicted-std norm across ensemble members.
        penalty = np.amax(np.linalg.norm(ensemble_model_stds, axis=2), axis=0)
        penalty = np.expand_dims(penalty, 1)
        rewards = rewards - 5e-1 * penalty

        terminals = np.full((batch_size,1),False)
        model_buffer.add_batch(observations, next_observations, actions, rewards, terminals)
        observations = torch.tensor(next_observations, dtype=torch.float32, device=device)
    
# Share of each training batch that is drawn from the model buffer.
mixing_ratio = 0.1
batch_size = 256

# The replay buffer supplies the (1 - mixing_ratio) part, rounded down;
# the model buffer supplies whatever remains.
replay_batch_size = int(batch_size * (1 - mixing_ratio))
model_batch_size = batch_size - replay_batch_size

Batch sampling

In [None]:
# Mix real (offline) and model-generated transitions into one training batch.
replay_batch = replay_buffer.sample(replay_batch_size)
model_batch = model_buffer.sample(model_batch_size)
observations, actions, rewards, next_observations, dones = [torch.concat([r_b, m_b]) for r_b, m_b in zip(replay_batch, model_batch)]

discount = 0.99

# Temperature loss; the policy's log-probability is treated as a constant here.
new_actions, log_pi = policy(observations)
alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
# Detached: the temperature is trained only through alpha_loss, so the policy
# and critic losses must not push gradients into log_alpha.
alpha = log_alpha.exp().detach()

q1_predicted = qf1(observations, actions)
q2_predicted = qf2(observations, actions)

# The TD target is a fixed regression target, so compute it without autograd.
with torch.no_grad():
    new_next_actions, next_log_pi = policy(next_observations)
    target_q_values = torch.min(target_qf1(next_observations, new_next_actions),target_qf2(next_observations, new_next_actions))
    target_q_values = target_q_values - alpha * next_log_pi
    target_q_values = target_q_values.unsqueeze(-1)

    td_target = rewards + (1.0 - dones) * discount * target_q_values
    td_target = td_target.squeeze(-1)

qf1_loss = F.mse_loss(q1_predicted, td_target)
qf2_loss = F.mse_loss(q2_predicted, td_target)
qf_loss = qf1_loss + qf2_loss

# Policy loss uses the clipped double-Q value of freshly sampled actions.
q_new_actions = torch.min(qf1(observations, new_actions), qf2(observations, new_actions))
policy_loss = (alpha * log_pi - q_new_actions).mean()

alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()

policy_optimizer.zero_grad()
policy_loss.backward()
policy_optimizer.step()

# policy_loss.backward() also left gradients on the critics; clear them first.
qf1_optimizer.zero_grad()
qf2_optimizer.zero_grad()
qf_loss.backward()
qf1_optimizer.step()
qf2_optimizer.step()

def soft_update(target, source, tau):
    """Move each parameter of `target` a fraction `tau` of the way towards `source`.

    Parameters are paired in iteration order and updated in place as
    (1 - tau) * target + tau * source.
    """
    param_pairs = zip(target.parameters(), source.parameters())
    for tgt, src in param_pairs:
        blended = (1 - tau) * tgt.data + tau * src.data
        tgt.data.copy_(blended)

# Let both target critics track their online counterparts.
for target_critic, online_critic in ((target_qf1, qf1), (target_qf2, qf2)):
    soft_update(target_critic, online_critic, soft_update_tau)

## Full Implementation of Model-Based Offline Policy Optimization (MOPO)

In [None]:
def mopo(env_fn, max_iterations4dynamic_model=10000, max_total_steps=50000, buffer_size=1000000,
         dynamics_lr = 1e-3, batch_size=256, hidden_sizes = [256, 256, 256], activation = nn.ReLU,
         soft_update_tau = 5e-3, policy_lr=1e-4, qf_lr=3e-4, rollout_freq=2,
         rollout_batch_size=1000, rollout_length=5, mixing_ratio = 0.1, discount = 0.99,
         reward_penalty_coef=5e-1):
  """Train a policy with Model-based Offline Policy Optimization (MOPO).

  The function runs three stages:

  1. Load the offline dataset into a replay buffer and normalize its states.
  2. Fit an ensemble dynamics model on that buffer, keeping the best
     per-member snapshot seen during periodic evaluation.
  3. Run SAC updates on batches that mix dataset transitions with short
     model rollouts whose rewards are penalized by the ensemble's predicted
     standard deviation.

  It relies on names defined elsewhere in the notebook: ``device``,
  ``get_dataset``, ``ReplayBuffer``, ``EnsembleModel``, ``MLPQFunction``,
  ``SquashedGaussianMLPActor`` and ``soft_update``.

  Args:
    env_fn: zero-argument callable returning the environment. The environment
      must expose ``observation_dim``, ``action_dim``, ``action_limit``,
      ``reset()`` and ``step(action)`` (returning four values).
    max_iterations4dynamic_model: number of gradient steps for the dynamics ensemble.
    max_total_steps: number of SAC update iterations.
    buffer_size: capacity passed to both the dataset buffer and the model buffer.
    dynamics_lr: Adam learning rate for the dynamics ensemble.
    batch_size: mini-batch size for dynamics training and total size of a SAC batch.
    hidden_sizes: hidden layer sizes passed to the Q-functions and the policy.
    activation: activation class passed to the Q-functions and the policy.
    soft_update_tau: interpolation weight for the target critic update.
    policy_lr: Adam learning rate for the policy and for ``log_alpha``.
    qf_lr: Adam learning rate for both critics.
    rollout_freq: model rollouts are generated every ``rollout_freq`` SAC iterations.
    rollout_batch_size: number of start states sampled for each rollout round.
    rollout_length: number of model steps per rollout.
    mixing_ratio: fraction of each SAC batch drawn from the model buffer; the
      dataset share is ``int(batch_size * (1 - mixing_ratio))``.
    discount: discount factor in the TD target.
    reward_penalty_coef: weight of the uncertainty penalty subtracted from
      model-predicted rewards (the default matches the previously hard-coded 0.5).

  Returns:
    The trained policy network.
  """
  env = env_fn()
  dataset = get_dataset(env, file_name="expert_dataset.pickle")

  obs_dim = env.observation_dim
  act_dim = env.action_dim
  act_lim = env.action_limit
  replay_buffer = ReplayBuffer(obs_dim, act_dim, buffer_size, device)
  replay_buffer.load_dataset(dataset)
  # The statistics are kept to normalize raw environment observations at test time.
  data_mean, data_std = replay_buffer.normalize_states()

  # ---------------------------------------------------------------------------
  # Stage 1: fit the ensemble dynamics model on the offline dataset.
  # ---------------------------------------------------------------------------
  ensemble_dynamic_model = EnsembleModel(obs_dim=obs_dim, act_dim=act_dim, hidden_sizes=[200, 200, 200, 200]).to(device)

  ensemble_dynamic_model_optimizer = Adam(ensemble_dynamic_model.parameters(), dynamics_lr)
  ensemble_size = ensemble_dynamic_model.ensemble_size
  best_snapshot_losses = np.full((ensemble_size,), 1e10)
  model_best_snapshots = [deepcopy(ensemble_dynamic_model.ensemble_models[idx].state_dict()) for idx in range(ensemble_size)]

  dynamics_eval_interval = 5000
  last_dynamics_iteration = max_iterations4dynamic_model - 1

  for t in range(max_iterations4dynamic_model):
    batch = replay_buffer.sample(batch_size)
    observations, actions, rewards, next_observations, _ = [b.to(device) for b in batch]
    # The model is trained to predict the state change and the reward.
    delta_observations = next_observations - observations
    groundtruths = torch.cat((delta_observations, rewards), dim=-1)

    model_input = torch.cat([observations, actions], dim=-1).to(device)
    pred_means, _ = ensemble_dynamic_model.predict(model_input)
    # One MSE value per ensemble member (mean over axes 1 and 2), summed over members.
    train_mse_losses = torch.mean(torch.pow(pred_means - groundtruths, 2), dim=(1, 2))
    train_mse_loss = torch.sum(train_mse_losses)
    # NOTE(review): the predicted log-variances are not part of this loss; only
    # the max/min log-variance bounds receive the regularizer below. The rollout
    # penalty later depends on those log-variances — confirm this is intended.
    train_transition_loss = train_mse_loss + (
        0.01 * torch.sum(ensemble_dynamic_model.max_logvar) - 0.01 * torch.sum(ensemble_dynamic_model.min_logvar)
    )

    ensemble_dynamic_model_optimizer.zero_grad()
    train_transition_loss.backward()
    ensemble_dynamic_model_optimizer.step()

    # Evaluate periodically AND on the final iteration. Without the final
    # evaluation, the snapshot restore after this loop would throw away every
    # update made since the last multiple of the evaluation interval.
    if (t % dynamics_eval_interval) == 0 or t == last_dynamics_iteration:
      eval_mse_total_losses = np.zeros((ensemble_size,))
      # Evaluation needs no gradients; skip building autograd graphs.
      with torch.no_grad():
        for eval_batch in replay_buffer.sample_all(batch_size):
          eval_observations, eval_actions, eval_rewards, eval_next_observations, _ = [b.to(device) for b in eval_batch]
          eval_delta_observations = eval_next_observations - eval_observations
          eval_groundtruths = torch.cat((eval_delta_observations, eval_rewards), dim=-1)
          eval_model_input = torch.cat([eval_observations, eval_actions], dim=-1).to(device)
          eval_pred_means, _ = ensemble_dynamic_model.predict(eval_model_input)
          eval_mse_losses = torch.mean(torch.pow(eval_pred_means - eval_groundtruths, 2), dim=(1, 2)).to('cpu').detach().numpy()
          eval_mse_total_losses += eval_mse_losses

      updated = False
      for i in range(len(eval_mse_total_losses)):
        current_loss = eval_mse_total_losses[i]
        best_loss = best_snapshot_losses[i]
        # A member's snapshot is replaced only on a relative improvement above 1%.
        improvement = (best_loss - current_loss) / best_loss
        if improvement > 0.01:
          best_snapshot_losses[i] = current_loss
          model_best_snapshots[i] = deepcopy(ensemble_dynamic_model.ensemble_models[i].state_dict())
          updated = True
          print(f'{i}th model is updated!')
      if updated:
        print(f'[{t}]Dynamic model evaluation: {eval_mse_total_losses}')

  # Roll every ensemble member back to its best recorded snapshot.
  for i in range(ensemble_size):
    ensemble_dynamic_model.ensemble_models[i].load_state_dict(model_best_snapshots[i])

  # ---------------------------------------------------------------------------
  # Stage 2: SAC on a mix of dataset transitions and penalized model rollouts.
  # ---------------------------------------------------------------------------
  target_entropy = - act_dim

  qf1 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation).to(device)
  qf2 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation).to(device)
  target_qf1 = deepcopy(qf1).to(device)
  target_qf2 = deepcopy(qf2).to(device)
  policy = SquashedGaussianMLPActor(obs_dim, act_dim, hidden_sizes, activation, act_lim).to(device)
  log_alpha = torch.zeros(1, requires_grad=True, device=device)

  alpha_optimizer = Adam([log_alpha], lr=policy_lr)
  policy_optimizer = Adam(policy.parameters(), lr=policy_lr)
  qf1_optimizer = Adam(qf1.parameters(), lr=qf_lr)
  qf2_optimizer = Adam(qf2.parameters(), lr=qf_lr)

  model_buffer = ReplayBuffer(obs_dim, act_dim, buffer_size, device)

  print('Offline RL start')
  replay_batch_size = int(batch_size*(1-mixing_ratio))
  model_batch_size = batch_size - replay_batch_size

  policy_eval_interval = 4000

  for t in range(max_total_steps):
    if (t % rollout_freq) == 0:
      init_transitions = replay_buffer.sample(rollout_batch_size)
      # Rollouts start from dataset states (first field of the sampled batch).
      observations = init_transitions[0]
      # Rollout data is stored as plain arrays, so no autograd graph is needed.
      with torch.no_grad():
        for _step in range(rollout_length):
          actions, _ = policy(observations)
          model_input = torch.cat([observations, actions], dim=-1).to(device)
          pred_diff_means, pred_diff_logvars = ensemble_dynamic_model.predict(model_input)
          observations = observations.detach().cpu().numpy()
          actions = actions.detach().cpu().numpy()
          ensemble_model_stds = pred_diff_logvars.exp().sqrt().detach().cpu().numpy()
          pred_diff_means = pred_diff_means.detach().cpu().numpy()
          # Sample from each member's predicted Gaussian.
          pred_diff_means = pred_diff_means + np.random.normal(size=pred_diff_means.shape) * ensemble_model_stds

          # Predictions are indexed (member, sample, feature). A separate local
          # name is used so the batch_size argument is not overwritten.
          _, num_rollouts, _ = pred_diff_means.shape
          # Pick one elite member at random for every rollout sample.
          model_idxes = np.random.choice(ensemble_dynamic_model.elite_model_idxes, size=num_rollouts)
          batch_idxes = np.arange(0, num_rollouts)
          pred_diff_samples = pred_diff_means[model_idxes, batch_idxes]

          # The last feature is the reward; the rest is the state change.
          next_observations, rewards = pred_diff_samples[:, :-1] + observations, pred_diff_samples[:, [-1]]
          # Uncertainty penalty: largest per-member norm of the predicted std.
          penalty = np.amax(np.linalg.norm(ensemble_model_stds, axis=2), axis=0)
          penalty = np.expand_dims(penalty, 1)
          rewards = rewards - reward_penalty_coef * penalty

          # Model rollouts are never marked terminal.
          terminals = np.full((num_rollouts, 1), False)
          model_buffer.add_batch(observations, next_observations, actions, rewards, terminals)
          observations = torch.tensor(next_observations, dtype=torch.float32, device=device)

    replay_batch = replay_buffer.sample(replay_batch_size)
    model_batch = model_buffer.sample(model_batch_size)

    observations, actions, rewards, next_observations, dones = [torch.concat([r_b, m_b]) for r_b, m_b in zip(replay_batch, model_batch)]

    # --- Temperature loss ---
    new_actions, log_pi = policy(observations)
    alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
    # alpha is only a weight below; detaching keeps the actor's backward pass
    # from depositing gradient on log_alpha after its optimizer step.
    alpha = log_alpha.exp().detach()

    # --- Critic loss ---
    q1_predicted = qf1(observations, actions)
    q2_predicted = qf2(observations, actions)

    new_next_actions, next_log_pi = policy(next_observations)
    target_q_values = torch.min(target_qf1(next_observations, new_next_actions),target_qf2(next_observations, new_next_actions))
    target_q_values = target_q_values - alpha * next_log_pi
    # NOTE(review): the unsqueeze/squeeze pair suggests rewards and dones have a
    # trailing singleton dimension while the critics return 1-D tensors — confirm.
    target_q_values = target_q_values.unsqueeze(-1)

    td_target = rewards + (1.0 - dones) * discount * target_q_values
    td_target = td_target.squeeze(-1)

    qf1_loss = F.mse_loss(q1_predicted, td_target.detach())
    qf2_loss = F.mse_loss(q2_predicted, td_target.detach())
    qf_loss = qf1_loss + qf2_loss

    # --- Actor loss ---
    q_new_actions = torch.min(
        qf1(observations, new_actions),
        qf2(observations, new_actions),
    )
    policy_loss = (alpha * log_pi - q_new_actions).mean()

    # --- Optimizer steps ---
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()

    policy_optimizer.zero_grad()
    policy_loss.backward()
    policy_optimizer.step()

    # Zeroing here also discards the critic gradients left by policy_loss.backward().
    qf1_optimizer.zero_grad()
    qf2_optimizer.zero_grad()
    # The critic graph is backpropagated exactly once; it does not need to be retained.
    qf_loss.backward()
    qf1_optimizer.step()
    qf2_optimizer.step()

    soft_update(target_qf1, qf1, soft_update_tau)
    soft_update(target_qf2, qf2, soft_update_tau)

    if (t % policy_eval_interval) == 0:
      log_dict = dict(qf1_loss=qf1_loss.item(),
                      qf2_loss=qf2_loss.item(),
                      alpha_loss=alpha_loss.item(),
                      policy_loss=policy_loss.item())

      for keys, values in log_dict.items():
        print(f'{keys}:{values:8.2f}',end=", ")

      # Deterministic evaluation in the real environment: 10 episodes of 1000 steps.
      avg_ret = []
      for _episode in range(10):
        obs = env.reset()
        ret = 0
        for _t in range(1000):
          # The policy was trained on normalized states, so normalize raw observations.
          obs = (obs - data_mean)/data_std
          with torch.no_grad():
            obs = torch.as_tensor(obs, dtype=torch.float32,device=device)
            action, _ = policy(obs, deterministic=True, with_logprob=False)
            action = action.to('cpu').numpy()
          # NOTE(review): the termination flag is ignored, so every episode runs
          # the full 1000 steps — confirm the environment never ends early.
          obs, reward, _terminated, _info = env.step(action)
          ret += reward
        avg_ret.append(ret)
      print(f'Test Return:{np.mean(avg_ret):8.2f}')

  return policy

In [None]:
# Run the whole MOPO pipeline on HalfCheetah with the default hyperparameters
# and keep the policy it returns.
policy = mopo(env_fn=HalfCheetahEnv)