# Bayesian Actor-Critic algorithms (BAC)
---
In this notebook, we train a Bayesian Actor-Critic agent on the `balance` task of the `Cartpole` domain from the DeepMind Control Suite.

### 1. Import the Necessary Packages

In [None]:
import os

from dm_control import suite

from collections import deque
from tqdm import trange
from IPython.display import clear_output

from datetime import datetime
from packaging import version

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Show which TensorFlow build the kernel picked up, then refuse to continue on
# a 1.x install: the rest of the notebook uses TF2-only APIs.
print("TensorFlow version: ", tf.__version__)
tf_release = version.parse(tf.__version__).release
assert tf_release[0] >= 2, \
    "This notebook requires TensorFlow 2.0 or above."

### 2. Instantiate the Environment and Agent

In [None]:
from bayesian_ddpg import Agent
from cpprb import ReplayBuffer, PrioritizedReplayBuffer

In [None]:
# --- Hyperparameters ---------------------------------------------------------
BUFFER_SIZE = int(1e5)  # capacity of the replay buffer, in transitions
STATE_DIM = (5,)        # shape of the flattened observation vector
ACTION_DIM = 1          # number of action components
BATCH_SIZE = 256        # transitions drawn per training update

# --- Environment -------------------------------------------------------------
env = suite.load(domain_name='cartpole',
                 task_name='balance')

# --- Agent -------------------------------------------------------------------
# NOTE(review): dropout_on_v=0 presumably disables dropout on the value
# network -- confirm against bayesian_ddpg.Agent.
agent = Agent(state_dim=STATE_DIM,
              action_dim=ACTION_DIM,
              dropout_on_v=0)

print('Running on ', agent.device)

# --- Replay buffer -----------------------------------------------------------
# STATE_DIM is already a shape tuple, so it is passed as-is. Wrapping it in
# another tuple would produce the nested shape ((5,),), which only works if the
# library happens to flatten its shape argument.
rb = ReplayBuffer(BUFFER_SIZE, {"obs": {"shape": STATE_DIM},
                                "act": {"shape": ACTION_DIM},
                                "rew": {},
                                "next_obs": {"shape": STATE_DIM},
                                "done": {}})

### 3. Train the Agent with Bayesian DDPG

In [None]:
# Clear any logs from previous runs
!rm -rf logs 

log_dir="logs/"
summary_writer = tf.summary.create_file_writer(
  log_dir + "scalar/" + datetime.now().strftime("%Y%m%d-%H%M%S"))

# Checkpoint-saver
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(actor=agent.pi,
                                 critic=agent.critic,
                                 actor_optim=agent.pi_optim,
                                 critic_optim=agent.critic_optim)

%load_ext tensorboard
%tensorboard --logdir {log_dir}

In [None]:
n_episodes = 1000  # number of training episodes
max_t = 1000       # hard cap on environment steps per episode
save_every = 2     # checkpoint period, in episodes
scores_deque = deque(maxlen=save_every)  # returns of the most recent episodes


def flatten_observation(time_step):
    """Concatenate every entry of a time step's observation dict into one 1-D array."""
    return np.concatenate(list(time_step.observation.values()))


# Counts gradient updates across ALL episodes. It is used as the TensorBoard
# step for the losses so that later episodes extend the curve instead of
# overwriting the points logged by earlier ones.
total_updates = 0

for i_episode in trange(1, int(n_episodes) + 1):

    time_step = env.reset()
    state = flatten_observation(time_step)
    score = 0

    for t in range(int(max_t)):
        action = agent.get_action(state)
        time_step = env.step(action)
        # NOTE(review): `done` is True whenever the episode ends; if that
        # includes a time-limit cutoff, confirm agent.train should treat it
        # as a terminal state.
        reward, done = time_step.reward, time_step.last()
        next_state = flatten_observation(time_step)

        # Save experience / reward. Every transition is stored, also once
        # learning has started, so the buffer keeps tracking the current policy.
        rb.add(obs=state,
               act=action,
               next_obs=next_state,
               rew=reward,
               done=done)

        # Learn, if enough samples are available in memory
        if rb.get_stored_size() > BATCH_SIZE:
            data = rb.sample(BATCH_SIZE)

            actor_loss, critic_loss, _ = agent.train(data['obs'],
                                                     data['act'],
                                                     data['next_obs'],
                                                     data['rew'],
                                                     data['done'])
            total_updates += 1
            with summary_writer.as_default():
                tf.summary.scalar(name="actor_loss",
                                  data=actor_loss,
                                  step=total_updates)
                tf.summary.scalar(name="critic_loss",
                                  data=critic_loss,
                                  step=total_updates)

        state = next_state
        score += reward

        if done:
            break

    scores_deque.append(score)
    with summary_writer.as_default():
        tf.summary.scalar(name="EpRet",
                          data=score,
                          step=i_episode)

    if i_episode % save_every == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

# Final save, skipped when the last episode was already checkpointed in the loop
if int(n_episodes) % save_every != 0:
    checkpoint.save(file_prefix=checkpoint_prefix)

### 4. Watch a Smart Agent!

In [None]:
import cv2
import glob
from PIL import Image
import subprocess
from packaging import version

# Restore the newest weights written by the training loop. The agent is a
# TensorFlow model saved through `checkpoint`, so it is reloaded the same way.
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
if latest_ckpt is None:
    raise FileNotFoundError(
        "No checkpoint found in %r; run the training cell first." % checkpoint_dir)
checkpoint.restore(latest_ckpt)

# reset frames folder
subprocess.call([ 'rm', '-rf', 'frames'])
subprocess.call([ 'mkdir', '-p', 'frames'])

time_step = env.reset()
state = np.concatenate(list(time_step.observation.values()))

# Roll the policy out for at most 700 steps, saving one rendered frame per step.
# NOTE(review): agent.get_action may add exploration noise -- confirm whether
# bayesian_ddpg.Agent offers a deterministic/evaluation mode for this rollout.
for t in trange(0, 700):
    action = agent.get_action(state)
    time_step = env.step(action)

    image_data = env.physics.render(height=480, width=480, camera_id=0)
    img = Image.fromarray(image_data, 'RGB')
    img.save("frames/frame-%.10d.png" % t)

    state = np.concatenate(list(time_step.observation.values()))
    clear_output(True)
    if time_step.last():
        break


In [None]:
# Convert frames to video
frame_paths = sorted(glob.glob('frames/*.png'))
if not frame_paths:
    raise FileNotFoundError("No frames found in 'frames/'; run the rollout cell first.")

# The video size is taken from the first frame.
first_frame = cv2.imread(frame_paths[0])
if first_frame is None:
    raise IOError("Could not read frame %r" % frame_paths[0])
height, width = first_frame.shape[:2]
size = (width, height)

# 'mp4v' is the fourcc that matches the .mp4 container.
out = cv2.VideoWriter('project.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15, size)
if not out.isOpened():
    raise RuntimeError("Could not open 'project.mp4' for writing")

try:
    # Stream frames one at a time instead of holding them all in memory.
    for frame_path in frame_paths:
        frame = cv2.imread(frame_path)
        if frame is None:  # cv2.imread signals failure by returning None
            raise IOError("Could not read frame %r" % frame_path)
        out.write(frame)
finally:
    # Always finalize the file, even if a frame fails to load.
    out.release()