# Bayesian Actor-Critic algorithms (BAC)
---
In this notebook, we train a Bayesian Actor-Critic agent on the `balance` task of the DeepMind Control Suite's `cartpole` domain.

### 1. Import the Necessary Packages

In [None]:
from dm_control import suite

import cv2
import glob
from PIL import Image
import subprocess
from packaging import version

from datetime import datetime
from collections import deque
from tqdm import trange
from IPython.display import clear_output
import matplotlib.pyplot as plt

from bayesian_ddpg import Agent
from cpprb import ReplayBuffer, PrioritizedReplayBuffer
from utils.logx import EpochLogger

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Report the installed TensorFlow version and refuse to continue on 1.x.
print("TensorFlow version: ", tf.__version__)
tf_release = version.parse(tf.__version__).release
assert tf_release[0] >= 2, "This notebook requires TensorFlow 2.0 or above."

### 2. Instantiate the Environment and Agent

In [None]:
# Hyper-parameters shared by the agent, the replay buffer and the training loop.
BUFFER_SIZE = int(1e5)   # maximum number of transitions kept in the replay buffer
STATE_DIM = (5,)         # shape of one observation as given to the agent and the buffer
ACTION_DIM = 1           # size of one action as given to the agent and the buffer
BATCH_SIZE = 256         # number of transitions sampled per training step

env = suite.load(domain_name='cartpole',
                 task_name='balance')

agent = Agent(state_dim=STATE_DIM,
              action_dim=ACTION_DIM,
              dropout_on_v=0)

logger_kwargs = dict()
logger = EpochLogger(**logger_kwargs)

print('Running on ', agent.device)

# STATE_DIM is already a tuple, so it is passed as the shape directly;
# wrapping it again as `(STATE_DIM,)` produced the nested shape ((5,),).
rb = ReplayBuffer(BUFFER_SIZE, {"obs": {"shape": STATE_DIM},
                                "act": {"shape": ACTION_DIM},
                                "rew": {},
                                "next_obs": {"shape": STATE_DIM},
                                "done": {}})

### 3. Train the Agent with DDPG

In [None]:
n_episodes = 1000   # total number of training episodes
max_t = 1000        # upper bound on environment steps per episode
print_every = 5     # log and save the model every `print_every` episodes
scores_deque = deque(maxlen=print_every)  # scores of the most recent episodes


def flatten_observation(observation):
    """Concatenate every entry of an observation mapping into one array.

    Entries are joined in the mapping's own key order.
    """
    return np.concatenate([observation[key] for key in observation.keys()])


prevScore = 0
for i_episode in trange(1, int(n_episodes) + 1):

    time_step = env.reset()
    state = flatten_observation(time_step.observation)
    score = 0

    for t in range(int(max_t)):
        action = agent.get_action(state)
        time_step = env.step(action)
        reward, done = time_step.reward, time_step.last()
        next_state = flatten_observation(time_step.observation)

        # Save experience / reward on every step. Previously this lived in
        # the `else` branch of the check below, so the buffer stopped
        # receiving new transitions as soon as it held more than BATCH_SIZE
        # samples and the agent only ever trained on its earliest experience.
        rb.add(obs=state,
               act=action,
               next_obs=next_state,
               rew=reward,
               done=done)

        # Learn, if enough samples are available in memory
        if rb.get_stored_size() > BATCH_SIZE:
            data = rb.sample(BATCH_SIZE)
            states = data['obs']; actions = data['act']; rewards = data['rew']
            next_states = data['next_obs']; dones = data['done']

            agent.train(states, actions, next_states, rewards, dones)

        state = next_state
        score += reward

        if done:
            break

    scores_deque.append(score)

    if i_episode % print_every == 0:

        # Log info about epoch
        logger.log_tabular('Episode', i_episode)
        logger.log_tabular('EpScore', score)
        logger.log_tabular('PrevScore', prevScore)
        # `t` is the zero-based index of the last step taken, so the number
        # of steps in the episode is t + 1.
        logger.log_tabular('EpLen (current)', t + 1)

        # Save models
        paths = logger.tf_simple_save(agent)

        prevScore = score
        clear_output(True)
        logger.dump_tabular()

### 4. Watch a Smart Agent!

In [None]:
# Roll out the trained agent and dump one rendered frame per step to frames/.
#
# The previous version of this cell restored PyTorch checkpoints
# (`torch.load`, `load_state_dict`, `.eval()`, `torch.no_grad()`), but `torch`
# is never imported in this notebook, so it failed with a NameError on a
# fresh kernel. The cell now uses the `agent` object trained in section 3.
# TODO(review): to replay a model from disk instead, restore it from the
# output of `logger.tf_simple_save(agent)`; the loader API is not visible here.

# reset frames folder
subprocess.call([ 'rm', '-rf', 'frames'])
subprocess.call([ 'mkdir', '-p', 'frames'])

time_step = env.reset()
state = np.concatenate([time_step.observation[key] for key in list(time_step.observation.keys())])

for t in trange(0, 700):
    # Same action call as the training loop (the old `agent.act` name came
    # from the PyTorch version).
    # NOTE(review): if `get_action` adds exploration noise, switch to the
    # agent's deterministic/evaluation action here - confirm against Agent.
    action = agent.get_action(state)
    time_step = env.step(action)

    image_data = env.physics.render(height=480, width=480, camera_id=0)
    img = Image.fromarray(image_data, 'RGB')
    img.save("frames/frame-%.10d.png" % t)

    state = np.concatenate([time_step.observation[key] for key in list(time_step.observation.keys())])
    clear_output(True)
    if time_step.last():
        break


In [None]:
# Convert frames to video
frame_paths = sorted(glob.glob('frames/*.png'))
if not frame_paths:
    # Without this guard `size` below would be undefined and the cell would
    # fail with an unrelated NameError.
    raise FileNotFoundError("No frames found in 'frames/'; run the rendering cell first.")

img_array = []
for filename in frame_paths:
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread signals an unreadable file by returning None.
        raise IOError("Could not read frame %s" % filename)
    height, width, layers = img.shape
    size = (width, height)
    img_array.append(img)

# 'mp4v' is the fourcc that matches the .mp4 container; the previous 'DIVX'
# tag belongs with .avi output.
out = cv2.VideoWriter('project.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15, size)
try:
    for frame in img_array:
        out.write(frame)
finally:
    # Always finalize the file, even if writing a frame fails.
    out.release()