#Import Dependencies

In [None]:
import os
from pathlib import Path
from collections import deque
from itertools import chain

os.environ["MUJOCO_GL"] = "egl"

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim import Adam

from dm_control import suite
from dm_control.suite.wrappers import pixels

# Directory that relative save/checkpoint paths are resolved against. Inside a
# notebook `__file__` is not defined, so fall back to the working directory.
parent_dir = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()

#Utils

In [None]:
class HPS:
  """Attribute-style view over a (possibly nested) hyperparameter dict.

  Every key becomes an attribute. Values that are themselves dicts are wrapped
  recursively in HPS; all other values (lists, numbers, strings) are stored
  unchanged.
  """

  def __init__(self, hps):
    for name, item in hps.items():
      setattr(self, name, HPS(item) if isinstance(item, dict) else item)

In [None]:
class Crop(nn.Module):
  """Square image cropping used for data augmentation.

  Args:
    hps: object exposing ``output_size``, the side length of the square crop.
  """

  def __init__(self, hps):
    super(Crop, self).__init__()

    self.output_size = hps.output_size

  def random_crop(self, x):
    """Take an independent random ``output_size`` square crop of each image.

    Args:
      x: tensor of shape (B, C, H, W) with H >= output_size and W >= output_size.

    Returns:
      Tensor of shape (B, C, output_size, output_size) with x's dtype/device.

    Raises:
      ValueError: if either spatial dimension is smaller than output_size.
    """
    B, C, H, W = x.shape
    size = self.output_size

    if H < size or W < size:
      raise ValueError(
          f"cannot crop {size}x{size} from input of spatial size {H}x{W}"
      )

    # Offsets are bounded per axis so non-square inputs are handled correctly.
    # Width offsets are drawn first to keep the original RNG consumption order.
    w1 = torch.randint(0, W - size + 1, (B,), device=x.device)
    h1 = torch.randint(0, H - size + 1, (B,), device=x.device)

    cropped = torch.empty((B, C, size, size), dtype=x.dtype, device=x.device)

    # One .tolist() per axis turns the offsets into Python ints up front instead
    # of indexing with 0-d (possibly GPU) tensors inside the loop.
    for i, (w_start, h_start) in enumerate(zip(w1.tolist(), h1.tolist())):
      cropped[i] = x[i, :, h_start:h_start + size, w_start:w_start + size]

    return cropped

  def center_crop(self, x):
    """Return the central ``output_size`` square of the last two dims of x."""
    h, w = x.shape[-2:]
    start_h = (h - self.output_size) // 2
    start_w = (w - self.output_size) // 2
    return x[:, :, start_h:start_h+self.output_size, start_w:start_w+self.output_size]

#Model

##Components

###MLP

In [None]:
class MLP(nn.Module):
  """Two-layer perceptron: Linear -> ReLU -> Linear.

  Args:
    input_dim: size of the last dimension of the input.
    hidden_dim: width of the hidden layer.
    output_dim: size of the last dimension of the output.
  """

  def __init__(self, input_dim, hidden_dim, output_dim):
    super().__init__()
    # Attribute names are part of the state_dict keys; keep them stable.
    self.in_linear = nn.Linear(input_dim, hidden_dim)
    self.in_relu = nn.ReLU(inplace=True)
    self.out_linear = nn.Linear(hidden_dim, output_dim)

  def init_weights(self):
    """Xavier-uniform weights and zero biases for every Linear layer."""
    linears = (m for m in self.modules() if isinstance(m, nn.Linear))
    for layer in linears:
      nn.init.xavier_uniform_(layer.weight)
      if layer.bias is not None:
        nn.init.zeros_(layer.bias)

  def forward(self, x):
    hidden = self.in_relu(self.in_linear(x))
    return self.out_linear(hidden)

###Encoder

In [None]:
class Encoder(nn.Module):
    """Convolutional pixel encoder: convs -> flatten -> MLP -> LayerNorm -> tanh.

    Args:
        observation_shape: sequence whose first element is the number of input
            channels.
        hps: object with ``latent_dim`` (conv channels), ``num_layers`` (total
            conv layers including the first, strided one), ``input_shape``
            (spatial side of the input; assumed square, e.g. the crop size),
            ``hidden_dim`` and ``output_dim``.
    """

    def __init__(self, observation_shape, hps):
        super(Encoder, self).__init__()

        self.in_conv = nn.Conv2d(
            observation_shape[0], hps.latent_dim, kernel_size=3, stride=2, padding=1
        )
        self.in_relu = nn.ReLU(inplace=True)

        # Stride-1, padding-1, kernel-3 convs: spatial size is preserved.
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        hps.latent_dim,
                        hps.latent_dim,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                    nn.ReLU(inplace=True),
                )
                for _ in range(hps.num_layers - 1)
            ]
        )

        self.flatten = nn.Flatten()

        # Spatial side after in_conv (kernel 3, stride 2, padding 1) is
        # floor((S + 2 - 3) / 2) + 1 == ceil(S / 2). `S // 2` is only correct
        # for even S and would mis-size the MLP for odd inputs.
        conv_side = (hps.input_shape + 1) // 2

        self.mlp = MLP(
            hps.latent_dim * conv_side ** 2,
            hps.hidden_dim,
            hps.output_dim,
        )

        self.output_layer = nn.Sequential(nn.LayerNorm(hps.output_dim), nn.Tanh())

    def init_weights(self):
        """Xavier-uniform conv/linear weights, zero biases, identity LayerNorm."""
        self.mlp.init_weights()

        nn.init.xavier_uniform_(self.in_conv.weight)
        if self.in_conv.bias is not None:
            nn.init.zeros_(self.in_conv.bias)

        for m in self.layers.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        for m in self.output_layer.modules():
            if isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Map (B, C, S, S) images to (B, output_dim) features in [-1, 1]."""
        x = self.in_conv(x)
        x = self.in_relu(x)

        for layer in self.layers:
            x = layer(x)

        x = self.flatten(x)

        x = self.mlp(x)

        return self.output_layer(x)

##Actor

In [None]:
class Actor(nn.Module):
  """Tanh-squashed Gaussian policy head on top of a shared encoder.

  The encoder is evaluated under ``torch.no_grad()``, so the actor loss only
  trains ``self.mlp``.
  """

  # Bounds on log(std): keep exp() finite and the Gaussian well-conditioned.
  LOG_STD_MIN = -20.0
  LOG_STD_MAX = 2.0

  def __init__(self, encoder, hps):
    super(Actor, self).__init__()

    self.encoder = encoder

    # Outputs the mean and log-variance for each action dimension.
    self.mlp = MLP(hps.input_dim, hps.hidden_dim, 2 * hps.action_dim)

  def init_weights(self):
    self.mlp.init_weights()

  def forward(self, x):
    """Sample an action for each input.

    Returns:
      action: reparameterized (``rsample``) tanh-squashed sample in (-1, 1).
      log_prob: (B, 1) log-density of ``action`` including the tanh correction.
      mean_action: ``tanh(mu)``, the deterministic action, also in [-1, 1].
    """
    with torch.no_grad():
      x = self.encoder(x)

    mu, logvar = self.mlp(x).chunk(2, dim=-1)

    # Clamp in log-std space: an unbounded logvar overflows exp() to inf and
    # turns log_prob into NaN.
    log_std = (0.5 * logvar).clamp(self.LOG_STD_MIN, self.LOG_STD_MAX)
    normal = Normal(mu, log_std.exp())

    x_t = normal.rsample()
    action = torch.tanh(x_t)

    # Change of variables for y = tanh(x); 1e-6 guards log(0) at saturation.
    log_prob = normal.log_prob(x_t) - torch.log(1 - action.pow(2) + 1e-6)
    log_prob = log_prob.sum(1, keepdim=True)

    # The deterministic action must be squashed like the sampled one.
    return action, log_prob, torch.tanh(mu)

##Soft Critic

In [None]:
class SoftCritic(nn.Module):
  """Q-function head: Q(s, a) = mlp(concat(encoder(s), a)).

  Unlike Actor, the encoder is called with gradients enabled, so losses on
  this critic also reach the (possibly shared) encoder.
  """

  def __init__(self, encoder, hps):
    super().__init__()
    self.encoder = encoder
    self.mlp = MLP(hps.input_dim, hps.hidden_dim, 1)

  def init_weights(self):
    """Re-initialize only the MLP head; the encoder is not touched here."""
    self.mlp.init_weights()

  def forward(self, x, a):
    features = self.encoder(x)
    return self.mlp(torch.cat((features, a), dim=-1))

##SAC

In [None]:
class SAC_CURL(nn.Module):
    """SAC agent with a CURL contrastive auxiliary loss on a shared pixel encoder.

    ``self.encoder`` is shared by the actor (called under ``no_grad``) and
    both critics (called with gradients), so it is trained by the critic TD
    loss and the CURL loss. ``self.target_encoder`` is shared by both target
    critics, produces the CURL keys, and is only changed by Polyak averaging.

    Args:
        hps: model hyperparameters (``observation_shape``, ``encoder``,
            ``actor``, ``critic`` sub-configs).
        train_hps: training hyperparameters (gamma, alpha, tau, temp,
            batch_size, update_freq, lr, betas, alpha_lr, alpha_betas).
        action_space: object exposing ``shape``; used for the target entropy.
        device: device on which ``log_alpha`` is allocated.
    """

    def __init__(self, hps, train_hps, action_space, device):
        super(SAC_CURL, self).__init__()

        self.gamma = train_hps.gamma
        self.alpha = train_hps.alpha
        self.tau = train_hps.tau
        self.temp = train_hps.temp
        self.batch_size = train_hps.batch_size
        self.update_freq = train_hps.update_freq
        self.device = device

        self.encoder = Encoder(hps.observation_shape, hps.encoder)
        # Projection applied to the CURL keys before the query/key dot product.
        self.key_w = nn.Linear(hps.encoder.output_dim, hps.encoder.output_dim)
        self.encoder_optim = Adam(
            chain(self.encoder.parameters(), self.key_w.parameters()),
            lr=train_hps.lr,
            betas=tuple(train_hps.betas),
        )

        self.target_encoder = Encoder(hps.observation_shape, hps.encoder)

        # Only the actor head is optimized by the actor loss.
        self.actor = Actor(self.encoder, hps.actor)
        self.actor_optim = Adam(
            self.actor.mlp.parameters(), lr=train_hps.lr, betas=tuple(train_hps.betas)
        )

        self.critic1 = SoftCritic(self.encoder, hps.critic)
        self.critic2 = SoftCritic(self.encoder, hps.critic)
        self.critic1_optim = Adam(
            self.critic1.mlp.parameters(), lr=train_hps.lr, betas=tuple(train_hps.betas)
        )
        self.critic2_optim = Adam(
            self.critic2.mlp.parameters(), lr=train_hps.lr, betas=tuple(train_hps.betas)
        )

        self.target_critic1 = SoftCritic(self.target_encoder, hps.critic)
        self.target_critic2 = SoftCritic(self.target_encoder, hps.critic)

        self.target_entropy = -torch.prod(
            torch.Tensor(action_space.shape).to(device)
        ).item()
        # Start log_alpha at log(alpha) so the temperature stays at the
        # configured value until the optimizer moves it; a zero init would make
        # alpha jump to exp(0) = 1 after the first update.
        self.log_alpha = torch.tensor(
            [float(np.log(train_hps.alpha))], device=device, requires_grad=True
        )
        self.alpha_optim = Adam(
            [self.log_alpha], lr=train_hps.alpha_lr, betas=tuple(train_hps.alpha_betas)
        )

    def init_weights(self, ckpt=None):
        """Initialise fresh weights (copied into the targets) or load ``ckpt``.

        ``ckpt`` is a dict with ``"agent_state_dict"`` and optionally
        ``"log_alpha"``; log_alpha is a plain tensor, so it is not part of
        ``state_dict()`` and must be restored separately. Target networks are
        frozen in both cases.
        """
        if ckpt is None:
            self.encoder.init_weights()
            self.actor.init_weights()
            self.critic1.init_weights()
            self.critic2.init_weights()

            self.target_encoder.load_state_dict(self.encoder.state_dict())
            self.target_critic1.load_state_dict(self.critic1.state_dict())
            self.target_critic2.load_state_dict(self.critic2.state_dict())
        else:
            self.load_state_dict(ckpt["agent_state_dict"])

            saved_log_alpha = ckpt.get("log_alpha")
            if saved_log_alpha is not None:
                with torch.no_grad():
                    self.log_alpha.copy_(
                        torch.as_tensor(
                            saved_log_alpha, dtype=self.log_alpha.dtype
                        ).reshape_as(self.log_alpha)
                    )
                self.alpha = self.log_alpha.exp().detach()

        self._freeze_parameters(self.target_encoder)
        self._freeze_parameters(self.target_critic1)
        self._freeze_parameters(self.target_critic2)

    def _freeze_parameters(self, module):
        for param in module.parameters():
            param.requires_grad = False

    def _soft_update(self, local_model, target_model):
        """Polyak-average: target <- tau * local + (1 - tau) * target."""
        for param, target_param in zip(
            local_model.parameters(), target_model.parameters()
        ):
            target_param.data.copy_(
                self.tau * param.data + (1.0 - self.tau) * target_param.data
            )

    def q_forward(self, x, a):
        """Clipped double-Q estimate from the target critics."""
        q1, q2 = self.target_critic1(x, a), self.target_critic2(x, a)

        return torch.min(q1, q2)

    def select_action(self, x, eval=False):
        """Return the actor's mean action (eval) or a sampled (action, log_prob)."""
        if eval:
            _, _, mean = self.actor(x)
            return mean
        else:
            action, log_prob, _ = self.actor(x)
            return action, log_prob

    def update_parameters(self, crop, buffer, updates):
        """One training step on critics+encoder, actor and temperature.

        Targets are Polyak-updated when ``updates`` is a multiple of
        ``update_freq``.
        """
        obs, action, reward, next_obs, done = buffer.sample(self.batch_size)

        # Two independent crops of the same observations form the CURL
        # query/key pair; next_obs gets its own crop for the TD target.
        obs_q, obs_k = crop.random_crop(obs), crop.random_crop(obs)
        next_obs = crop.random_crop(next_obs)

        mask = 1 - done.unsqueeze(1)
        reward = reward.unsqueeze(1)

        self._update_critic_and_encoder(obs_q, obs_k, action, reward, next_obs, mask)
        self._update_actor_and_alpha(obs_q)

        if updates % self.update_freq == 0:
            # The critics share self.encoder (and the target critics share
            # self.target_encoder), so averaging whole critics would also
            # average the encoder once per critic. Update only the heads here
            # and the encoder exactly once, at rate tau.
            self._soft_update(self.critic1.mlp, self.target_critic1.mlp)
            self._soft_update(self.critic2.mlp, self.target_critic2.mlp)
            self._soft_update(self.encoder, self.target_encoder)

    def _update_critic_and_encoder(self, obs_q, obs_k, action, reward, next_obs, mask):
        """Step both critic heads and encoder/key_w on TD loss + CURL loss."""
        with torch.no_grad():
            next_action, next_log_pi, _ = self.actor(next_obs)
            min_qf_next_target = (
                torch.min(
                    self.target_critic1(next_obs, next_action),
                    self.target_critic2(next_obs, next_action),
                )
                - self.alpha * next_log_pi
            )
            next_q_value = reward + mask * self.gamma * min_qf_next_target

        qf1, qf2 = self.critic1(obs_q, action), self.critic2(obs_q, action)
        qf_loss = F.mse_loss(qf1, next_q_value) + F.mse_loss(qf2, next_q_value)

        z_q = self.encoder(obs_q)
        with torch.no_grad():
            z_k = self.target_encoder(obs_k)

        # NOTE(review): only the keys are L2-normalised, the queries are not —
        # confirm this asymmetry is intended.
        z_k = F.normalize(self.key_w(z_k), dim=-1)

        # Row i's positive is column i (same image, different crop).
        logits = torch.matmul(z_q, z_k.T) / self.temp
        labels = torch.arange(logits.shape[0], device=logits.device)
        curl_loss = F.cross_entropy(logits, labels)

        self.encoder_optim.zero_grad()
        self.critic1_optim.zero_grad()
        self.critic2_optim.zero_grad()
        qf_loss.backward()
        curl_loss.backward()
        self.encoder_optim.step()
        self.critic1_optim.step()
        self.critic2_optim.step()

    def _update_actor_and_alpha(self, obs):
        """Step the actor head, then the entropy temperature log_alpha."""
        pi, log_pi, _ = self.actor(obs)
        min_qf_pi = torch.min(self.critic1(obs, pi), self.critic2(obs, pi))

        actor_loss = (self.alpha * log_pi - min_qf_pi).mean()

        # actor_loss.backward() also leaves gradients on the critic heads and
        # encoder; their optimizers zero them before the next critic step.
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        # Detach so later losses use alpha as a constant rather than building
        # graph back to log_alpha.
        self.alpha = self.log_alpha.exp().detach()

#Environment

##Action Repeat Wrapper

In [None]:
class ActionRepeatWrapper:
  """Apply each action ``num_repeats`` times and accumulate discounted reward.

  Repetition stops early when the wrapped env reports the last step. Unknown
  attributes are forwarded to the wrapped env.
  """

  def __init__(self, env, num_repeats):
    self._env = env
    self._num_repeats = num_repeats

  def __getattr__(self, name):
    return getattr(self._env, name)

  def step(self, action):
    total_reward = 0.0
    running_discount = 1.0

    for _ in range(self._num_repeats):
      ts = self._env.step(action)
      # A None reward counts as zero.
      step_reward = ts.reward or 0.0
      total_reward = total_reward + step_reward * running_discount
      running_discount = running_discount * ts.discount
      if ts.last():
        break

    return ts._replace(reward=total_reward, discount=running_discount)

  def reset(self):
    return self._env.reset()

##Frame Stack Wrapper

In [None]:
class FrameStackWrapper:
  """Stack the last ``num_frames`` pixel observations along the channel axis.

  Each frame is read from ``observation[pixels_key]`` and transposed with
  ``transpose(2, 0, 1)`` before stacking. ``step`` and ``reset`` return a
  plain tuple ``(stacked_obs, reward, is_last, discount)`` rather than a
  time step. Unknown attributes are forwarded to the wrapped env.
  """

  def __init__(self, env, num_frames, pixels_key='pixels'):
    self._env = env
    self._num_frames = num_frames
    self._pixels_key = pixels_key

    self.frames = deque([], maxlen=num_frames)

  def __getattr__(self, name):
    return getattr(self._env, name)

  def _get_obs(self, time_step):
    frame = time_step.observation[self._pixels_key].transpose(2, 0, 1)

    # Right after a reset the history is empty: pad it with copies of the
    # first frame; otherwise push just the new frame.
    copies = self._num_frames if not self.frames else 1
    self.frames.extend([frame] * copies)

    return np.concatenate(list(self.frames), axis=0)

  def _as_tuple(self, time_step):
    return (
        self._get_obs(time_step),
        time_step.reward,
        time_step.last(),
        time_step.discount,
    )

  def step(self, action):
    return self._as_tuple(self._env.step(action))

  def reset(self):
    self.frames.clear()
    return self._as_tuple(self._env.reset())

##Create Environment

In [None]:
def create_environment(domain_name, task_name, action_repeat, frame_stack, image_size):
  """Build a pixel-only dm_control environment.

  Wrapping order (innermost first): action repeat -> pixel rendering from
  camera 0 at ``image_size`` x ``image_size`` -> frame stacking.
  """
  render_kwargs = dict(height=image_size, width=image_size, camera_id=0)

  base = suite.load(domain_name=domain_name, task_name=task_name)
  repeated = ActionRepeatWrapper(base, action_repeat)
  rendered = pixels.Wrapper(repeated, pixels_only=True, render_kwargs=render_kwargs)

  return FrameStackWrapper(rendered, num_frames=frame_stack)

#Training Scheme

##Replay Buffer

In [None]:
class ReplayBuffer:
    """Fixed-capacity circular buffer of transitions backed by NumPy arrays.

    Observations are stored as uint8 and returned by ``sample`` as float
    tensors divided by 255. Once the buffer is full, new transitions overwrite
    the oldest ones.

    Args:
        capacity: maximum number of stored transitions.
        observation_shape: shape of a single observation.
        action_shape: shape of a single action.
        device: device on which sampled tensors are created.
    """

    def __init__(self, capacity, observation_shape, action_shape, device):
        self._capacity = capacity
        self._observation_shape = observation_shape
        self._action_shape = action_shape
        self.device = device

        self.obs_buf = np.zeros((capacity, *observation_shape), dtype=np.uint8)
        self.next_obs_buf = np.zeros((capacity, *observation_shape), dtype=np.uint8)

        self.action_buf = np.zeros((capacity, *action_shape), dtype=np.float32)
        self.reward_buf = np.zeros(capacity, dtype=np.float32)
        self.done_buf = np.zeros(capacity, dtype=np.float32)

        self.ptr = 0  # next slot to write
        self.size = 0  # number of valid transitions

    def load(self, path):
        """Restore state written by ``save``.

        The file is a pickle containing NumPy arrays; only load trusted files.
        """
        # weights_only=False: the payload holds NumPy arrays, which the
        # restricted weights-only unpickler (the default in recent PyTorch)
        # refuses to load.
        ckpt = torch.load(path, weights_only=False)

        self._capacity = ckpt["capacity"]
        self._observation_shape = ckpt["observation_shape"]
        self._action_shape = ckpt["action_shape"]

        self.obs_buf = ckpt["obs_buf"]
        self.next_obs_buf = ckpt["next_obs_buf"]
        self.action_buf = ckpt["action_buf"]
        self.reward_buf = ckpt["reward_buf"]
        self.done_buf = ckpt["done_buf"]

        self.ptr = ckpt["ptr"]
        self.size = ckpt["size"]

    def save(self, save_path):
        """Serialize the full buffer state to ``save_path`` with torch.save."""
        torch.save(
            {
                "capacity": self._capacity,
                "observation_shape": self._observation_shape,
                "action_shape": self._action_shape,
                "ptr": self.ptr,
                "size": self.size,
                "obs_buf": self.obs_buf,
                "next_obs_buf": self.next_obs_buf,
                "action_buf": self.action_buf,
                "reward_buf": self.reward_buf,
                "done_buf": self.done_buf,
            },
            save_path,
            pickle_protocol=4,
        )

    def add(self, obs, action, reward, next_obs, done):
        """Store one transition, overwriting the oldest once full."""
        self.obs_buf[self.ptr] = obs
        self.next_obs_buf[self.ptr] = next_obs
        self.action_buf[self.ptr] = action
        self.reward_buf[self.ptr] = reward
        self.done_buf[self.ptr] = done

        self.ptr = (self.ptr + 1) % self._capacity
        # Count stored transitions directly: ptr wraps to 0 when the buffer
        # fills, so it cannot be used to derive the size.
        self.size = min(self.size + 1, self._capacity)

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions uniformly with replacement.

        Returns:
            (obs, action, reward, next_obs, done) tensors on ``self.device``;
            observations are float in [0, 1].

        Raises:
            ValueError: if the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")

        idxs = np.random.randint(0, self.size, size=batch_size)

        obs = torch.as_tensor(self.obs_buf[idxs], device=self.device).float() / 255.0
        next_obs = (
            torch.as_tensor(self.next_obs_buf[idxs], device=self.device).float() / 255.0
        )

        action = torch.as_tensor(self.action_buf[idxs], device=self.device)
        reward = torch.as_tensor(self.reward_buf[idxs], device=self.device)
        done = torch.as_tensor(self.done_buf[idxs], device=self.device)

        return obs, action, reward, next_obs, done

##Set Seeds

In [None]:
def set_seeds(hps):
  """Seed NumPy and PyTorch (CPU and CUDA) and request deterministic kernels.

  Args:
    hps: object exposing an integer ``seed``.

  CUBLAS_WORKSPACE_CONFIG is set for deterministic cuBLAS behaviour (see the
  PyTorch reproducibility notes). ``warn_only=True`` makes ops without a
  deterministic implementation warn instead of raising.
  """
  os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

  seed = hps.seed
  for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seeder(seed)

  cudnn = torch.backends.cudnn
  cudnn.deterministic = True
  cudnn.benchmark = False
  torch.use_deterministic_algorithms(True, warn_only=True)

##Training Loop

In [None]:
def _save_checkpoint(path, agent, total_numsteps, num_episodes, updates):
    """Write the agent weights, entropy temperature and loop counters to ``path``."""
    ckpt = {
        "agent_state_dict": agent.state_dict(),
        "total_numsteps": total_numsteps,
        "num_episodes": num_episodes,
        "updates": updates,
    }
    # log_alpha is a plain tensor rather than a registered parameter, so it is
    # not part of state_dict(); store it explicitly so it survives a resume.
    log_alpha = getattr(agent, "log_alpha", None)
    if log_alpha is not None:
        ckpt["log_alpha"] = log_alpha.detach().cpu()
    torch.save(ckpt, path)


def train(action_spec, env, agent, buffer, crop, train_hps, device, ckpt=None):
    """Run the interaction/training loop until ``train_hps.total_steps`` env steps.

    Args:
        action_spec: env action spec (``minimum``, ``maximum``, ``shape``, ``dtype``).
        env: environment whose ``reset``/``step`` return
            ``(obs, reward, is_last, discount)``.
        agent: agent exposing ``init_weights``, ``select_action`` and
            ``update_parameters``.
        buffer: replay buffer exposing ``add``, ``size`` and ``sample``.
        crop: cropper used to crop observations before acting.
        train_hps: training hyperparameters. ``save_every`` (optional, default
            10000) sets the checkpoint interval in updates.
        device: torch device for acting and checkpoint loading.
        ckpt: optional checkpoint path relative to ``parent_dir`` to resume from.

    Returns:
        The trained agent (the same object that was passed in).

    Raises:
        FileNotFoundError: if ``ckpt`` is given but does not exist.
    """
    save_dir = os.path.join(parent_dir, train_hps.save_dir)
    os.makedirs(save_dir, exist_ok=True)
    save_every = getattr(train_hps, "save_every", 10000)

    if ckpt is not None:
        ckpt_path = os.path.join(parent_dir, ckpt)
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"{ckpt_path} does not exist")

        ckpt = torch.load(ckpt_path, map_location=device)
        agent.init_weights(ckpt)

        # The saved counter already counts completed steps; resume from it.
        total_numsteps = ckpt["total_numsteps"]
        num_episodes = ckpt["num_episodes"]
        updates = ckpt["updates"]
    else:
        agent.init_weights()
        total_numsteps = 0
        num_episodes = 0
        updates = 0

    agent.to(device)

    while total_numsteps < train_hps.total_steps:
        num_episodes += 1

        episode_reward = 0
        episode_steps = 0
        state, _, done, _ = env.reset()

        while not done:
            if total_numsteps < train_hps.warmup_steps:
                action = np.random.uniform(
                    action_spec.minimum, action_spec.maximum, size=action_spec.shape
                ).astype(action_spec.dtype)
            else:
                with torch.no_grad():
                    _state = (
                        torch.as_tensor(state, device=device).unsqueeze(0).float()
                        / 255.0
                    )
                    _state = crop.random_crop(_state)
                    action, _ = agent.select_action(_state)
                    # Drop the batch dim so the action matches action_spec.shape.
                    action = action[0].cpu().numpy().astype(action_spec.dtype)

            next_state, reward, done, discount = env.step(action)

            # `done` is also True when the episode hits its time limit; only a
            # zero discount marks a true terminal state. Storing truncation as
            # terminal would wrongly stop bootstrapping at every episode end.
            terminal = float(discount is not None and discount == 0.0)
            buffer.add(state, action, reward, next_state, terminal)

            state = next_state
            episode_reward += reward
            total_numsteps += 1
            episode_steps += 1

            if buffer.size > train_hps.start_training_steps:
                updates += 1
                agent.update_parameters(crop, buffer, updates)

                if updates % save_every == 0:
                    _save_checkpoint(
                        f"{save_dir}/checkpoint_{updates}.pt",
                        agent,
                        total_numsteps,
                        num_episodes,
                        updates,
                    )
                    # buffer.save(os.path.join(save_dir, 'buffer.pt'))

            if total_numsteps >= train_hps.total_steps:
                break

        print(
            f"Episode: {num_episodes}, Reward: {episode_reward}, Steps: {episode_steps}, Total Steps: {total_numsteps}"
        )

    _save_checkpoint(
        f"{save_dir}/final.pt", agent, total_numsteps, num_episodes, updates
    )

    return agent

#HPS

##Model HPS

In [None]:
# Network hyperparameters. Keep the derived sizes consistent:
# - observation_shape = [channels per frame * frame_stack, image_size, image_size]
#   (9 = 3 * 3 with frame_stack = 3, image_size = 100 in train_config);
# - encoder.input_shape = side of the random/center crop (train_config output_size);
# - critic.input_dim = encoder.output_dim + actor.action_dim (state features
#   are concatenated with the action).
model_config = dict(
    observation_shape=[9, 100, 100],
    encoder=dict(
        input_shape=84,
        latent_dim=32,
        num_layers=4,
        input_dim=56448,  # 32 * 42 * 42; not read by Encoder in this file
        hidden_dim=1024,
        output_dim=50,
    ),
    actor=dict(
        input_dim=50,
        hidden_dim=1024,
        action_dim=6,
    ),
    critic=dict(
        input_dim=56,
        hidden_dim=1024,
    ),
)

##Train HPS

In [None]:
# Training hyperparameters.
train_config = dict(
    seed=42,
    # Environment.
    domain_name="cheetah",
    task_name="run",
    action_repeat=4,
    frame_stack=3,
    image_size=100,
    # Replay and augmentation.
    buffer_capacity=100_000,
    output_size=84,  # crop side length
    # SAC / CURL.
    gamma=0.99,
    tau=0.01,  # Polyak rate for target networks
    alpha=0.2,  # initial entropy temperature
    temp=0.1,  # CURL logit temperature
    lr=2e-4,
    alpha_lr=1e-4,
    betas=[0.9, 0.999],
    alpha_betas=[0.5, 0.999],
    batch_size=512,
    update_freq=2,  # target networks are updated every `update_freq` updates
    # Schedule.
    warmup_steps=10_000,  # env steps taken with uniformly random actions
    start_training_steps=10_000,  # buffer size that must be exceeded before updates
    total_steps=1_000_000,
    save_dir="checkpoints/Cheetah",
)

#Train

In [None]:
# Pick the compute device and wrap the config dicts for attribute access.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_hps, train_hps = HPS(model_config), HPS(train_config)

In [None]:
set_seeds(hps=train_hps)  # seed before the environment and agent are built below

In [None]:
env = create_environment(
    train_hps.domain_name,
    train_hps.task_name,
    train_hps.action_repeat,
    train_hps.frame_stack,
    train_hps.image_size,
)
action_spec = env.action_spec()
agent = SAC_CURL(model_hps, train_hps, action_spec, device)
# ReplayBuffer's signature is (capacity, observation_shape, action_shape, device);
# it has no save_dir parameter, so device must be the fourth argument.
buffer = ReplayBuffer(
    train_hps.buffer_capacity, model_hps.observation_shape, action_spec.shape, device
)
crop = Crop(train_hps)
ckpt = None  # set to a checkpoint path (relative to parent_dir) to resume

In [None]:
agent = train(action_spec, env, agent, buffer, crop, train_hps, device, ckpt=ckpt)