In [1]:
# todo: understand log_prob of evaluate

In [1]:
VERSION = 'sac,lr=1E-4,gamma=0.9,sq_log_corrected'

from importlib import reload
import models_sac
import utils
reload(models_sac)
reload(utils)

from models_sac import CriticNetwork, ActorNetwork
from utils import ReplayBuffer, seed_everything, Monitor

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

import numpy as np
import math
import os
from tqdm import tqdm
import logging
import time

import minerl
import gym

import matplotlib.pyplot as plt
import pickle

from IPython.display import clear_output
from IPython import display

logging.basicConfig(filename='logs/'+VERSION+'-'+time.strftime("%Y%m%d-%H%M%S")+'.log', 
                    filemode='w', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')

In [2]:
ACTIONS = ['attack', 'jump', 'forward', 'back', 'left', 'right', 'sprint', 'sneak', 'camera_hor', 'camera_ver']

SEED = 584
OBS_DIM = int(64*64*3+1) # pov + compassAngle
BUFFER_SIZE = int(1E5)

MAX_NUM_FRAMES = int(1E9)
BATCH_SIZE = 32

POV_SCALING = 255
COMPASS_SCALING = 180
MAX_CAMERA_TURN = 15
TANH_FACTOR = 1
CAMERA_FACTOR = 1
NUM_FILTERS = 128

ACTOR_LR = 1E-4
CRITIC_LR = 1E-4
ALPHA_LR = 1E-4

GAMMA = 0.9
BUFFER_SIZE = int(1E5)

TAU = 0.001
USE_BN = True

INITIAL_ALPHA = 0.1

monitor = Monitor('monitor_'+VERSION+'.pkl', 'monitoring')

### SAC

In [15]:
class Agent:

    def __init__(self, 
                 num_acts=len(ACTIONS),
                 batch_size=BATCH_SIZE, 
                 gamma=GAMMA,
                 actor_learning_rate=ACTOR_LR, 
                 critic_learning_rate=CRITIC_LR,
                 alpha_lr = ALPHA_LR,
                 tau=TAU,
                 buffer_capacity=BUFFER_SIZE,
                 initial_alpha=INITIAL_ALPHA
                 ):

        self.num_acts = num_acts
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info("Device is: "+str(self.device))
        print("Device is: ", self.device)
        
        self.gamma = gamma
        self.actor_lr = actor_learning_rate
        self.critic_lr = critic_learning_rate
        self.tau = tau
        self.buffer = ReplayBuffer( 
            obs_dim=OBS_DIM,
            size=BUFFER_SIZE,
            act_dim=len(ACTIONS),
            batch_size=BATCH_SIZE
        )

        
        params_nn = {'acts_dim':len(ACTIONS), 'num_filters':NUM_FILTERS, 'use_bn':USE_BN,  
            'pov_scaling':POV_SCALING, 'compass_scaling':COMPASS_SCALING}
        self.policy = ActorNetwork(**params_nn).to(self.device)
        self.q1 = CriticNetwork(**params_nn).to(self.device)
        self.q2 = CriticNetwork(**params_nn).to(self.device)
        self.target_q1 = CriticNetwork(**params_nn).to(self.device)
        self.target_q2 = CriticNetwork(**params_nn).to(self.device)
        self.target_q1.load_state_dict(self.q1.state_dict())
        self.target_q2.load_state_dict(self.q2.state_dict())
        self.target_entropy = -torch.prod(torch.Tensor(num_acts).to(self.device)).item()
        self.alpha = initial_alpha
        self.log_alpha = nn.Parameter( torch.Tensor( [np.log(self.alpha)] ).to(self.device) )
        
        self.q1_optim = optim.Adam(self.q1.parameters(), lr=CRITIC_LR)
        self.q2_optim = optim.Adam(self.q2.parameters(), lr=CRITIC_LR)
        self.policy_optim = optim.Adam(self.policy.parameters(), lr=ACTOR_LR)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=alpha_lr)

    def get_act(self, obs_pov, obs_rest):
        
        self.policy.eval()
        with torch.no_grad():
            act = self.policy.get_action(obs_pov, obs_rest)
        self.policy.train()
        
        return act.cpu().numpy()
        
    def unflatten_obs(self, flat_obs):
        return (flat_obs[:,:-1].reshape(-1,64,64,3), flat_obs[:,-1].reshape(-1,1))
        
    def flatten_obs(self, obs):
        return np.append(obs['pov'].reshape(-1), obs['compassAngle'])

    # Store the transition into the replay buffer
    def store_transition(self, obs, next_obs, act, rew, done):
        obs = self.flatten_obs(obs)
        next_obs = self.flatten_obs(next_obs)
        self.buffer.store(obs=obs, act=act, rew=rew, 
                          next_obs=next_obs, done=done)

    def float_tensor(self, numpy_array):
        return torch.FloatTensor(numpy_array).to(self.device)
        
    def fit_batch(self):
        # Sample frorm buffer
        transitions = self.buffer.sample_batch()
        
        obss_pov, obss_rest = self.unflatten_obs(self.float_tensor(transitions['obs']))
        next_obss_pov, next_obss_rest = self.unflatten_obs(self.float_tensor(transitions['next_obs']))
        acts = self.float_tensor(transitions['acts'])
        rews = self.float_tensor(transitions['rews'])
        dones = self.float_tensor(transitions['dones'])
        
        # Q function loss
        with torch.no_grad():
            next_log_prob, next_action = self.policy.get_log_probs(next_obss_pov, next_obss_rest)
            target_q1_next = self.target_q1(next_obss_pov, next_obss_rest, next_action).view(-1)
            target_q2_next = self.target_q2(next_obss_pov, next_obss_rest, next_action).view(-1)
            min_q_target_hat = torch.min(target_q1_next, target_q2_next) - self.alpha * next_log_prob # UNDERSTAND
            y = rews + (1 - dones) * self.gamma * min_q_target_hat
            monitor.add(['next_log_prob_mean','min_q_target_hat_mean','alpha','y_mean'],
                       [float(next_log_prob.mean()), float(min_q_target_hat.mean()), 
                        float(self.alpha), float(y.mean())])
        
        q1_hat = self.q1( obss_pov, obss_rest , acts ).view(-1)  
        q2_hat = self.q2( obss_pov, obss_rest , acts ).view(-1)  
        q1_loss = F.mse_loss(q1_hat, y) 
        q2_loss = F.mse_loss(q2_hat, y)
        #assert (float(q1_loss)<100) or (float(q2_loss)<100)
        
        # Policy loss
        log_pi, pi = self.policy.get_log_probs(obss_pov, obss_rest)
        
        q1_hat_policy = self.q1(obss_pov, obss_rest, pi)
        q2_hat_policy = self.q2(obss_pov, obss_rest, pi)
        min_q_pi = torch.min(q1_hat_policy, q2_hat_policy)

        policy_loss = ((self.alpha * log_pi) - min_q_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
        
        # Gradient descent
        self.q1_optim.zero_grad()
        q1_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q1.parameters(), 1.0, norm_type=1)
        self.q1_optim.step()
        
        self.q2_optim.zero_grad()
        q2_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q2.parameters(), 1.0, norm_type=1)
        self.q2_optim.step()
        
        self.policy_optim.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0, norm_type=1)
        self.policy_optim.step()
        
        # Alpha parameter tuning      
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.log_alpha.parameters(), 1.0, norm_type=1)
        self.alpha_optim.step()

        self.alpha = self.log_alpha.exp()

        return q1_loss, q2_loss, policy_loss, q1_hat, q2_hat
    
    def update_target_networks(self):
        self.polyak_averaging(self.target_q1, self.q1)
        self.polyak_averaging(self.target_q2, self.q2)
        
    def polyak_averaging(self, target, original):
        for target_param, param in zip(target.parameters(), original.parameters()):
            target_param.data.copy_(self.tau * param.data + target_param.data * (1.0 - self.tau))
            
    def get_env_act(self, model_act):
        '''
        Gets environment act from model act
        '''
        env_act = {act: int(value>0) for act, value in zip(ACTIONS[0:8], model_act[0:8])}
        env_act['camera'] = [model_act[8]*MAX_CAMERA_TURN, model_act[9]*MAX_CAMERA_TURN]
        return env_act

### Instantiate

In [9]:
if 'environment_name' not in locals():
    environment_name = 'MineRLTreechop-v0'
    %time env = gym.make(environment_name)

seed_everything(seed=SEED, env=env)

In [16]:
agent = Agent()

# number of trainable parameters of model
message_1 = 'Number of trainable parameters actor:' + str(sum(p.numel() for p in agent.policy.parameters()  if p.requires_grad))
message_2 = 'Number of trainable parameters critic:'+str(sum(p.numel() for p in agent.q1.parameters()  if p.requires_grad)*2)
message_3 = 'Number of trainable parameters alpha:'+' Unknown'
print(message_1)
print(message_2)
print(message_3)

logging.info(message_1)
logging.info(message_2)

Device is:  cuda
Number of trainable parameters actor:1106388
Number of trainable parameters critic:2212866
Number of trainable parameters alpha: Unknown


In [None]:
# Load agent
# agent.actor.load_state_dict(torch.load('trained_models/ddpg_actor_'+VERSION+'.pkl'))
# agent.critic.load_state_dict(torch.load('trained_models/ddpg_critic_'+VERSION+'.pkl'))

### Training loop

In [17]:
%time obs = env.reset()
assert (agent.unflatten_obs( np.array([agent.flatten_obs(obs)]) )[0][0].astype(int) == obs['pov']).sum() == 3*64*64
trajectory_count = 1
losses_agent = []
losses_critic = []
scores = []
current_score = []
current_scores = []
max_parameters = []
plotting_interval = 100
score = 0
logging.info('Environment reset')

q1_values_mean=[]
q2_values_mean=[]
losses_agent=[]
losses_q1=[]
losses_q2=[]

CPU times: user 3.02 ms, sys: 5.42 ms, total: 8.44 ms
Wall time: 7.45 s


KeyError: 'compassAngle'

In [None]:
def check_relus_are_alive():
    for network in [agent.q1, agent.q2, agent.policy]:
        if float(network.non_lin_1.sum()) == 0:
            text = str(network) + ' has non_lin_1 died'
            logging.info(text)
            print(text)
        if float(network.non_lin_2.sum()) == 0:
            text = str(network) + ' has non_lin_2 died'
            logging.info(text)
            print(text)
        if float(network.non_lin_3.sum()) == 0:
            text = str(network) + ' has non_lin_3 died'
            logging.info(text)
            print(text)
            
    
from typing import List    
def plot_stats(
        frame_idx: int, 
        scores: List[float],
        current_score: List[float],
        losses: List[float],
    ):
    """Plot the training progresses."""
    clear_output(True)
    plt.figure(figsize=(20, 5))
    plt.subplot(131)
    plt.title('q_values')
    plt.plot(scores[-10000:])
    plt.subplot(132)
    plt.title('loss')
    plt.plot(losses)
    plt.subplot(133)
    plt.title('current_score')
    plt.plot(current_score)
    plt.show()
            
# test
# agent.actor.x_relu_1[agent.actor.x_relu_1!=0] = 0

In [None]:
for frame_idx in tqdm(range(MAX_NUM_FRAMES), desc='frame'):
    obs_pov = agent.float_tensor([obs['pov'].astype(float)])
    obs_rest = agent.float_tensor([[obs['compassAngle']]])
    model_act = agent.get_act(obs_pov, obs_rest)
    env_act = agent.get_env_act(model_act[0])
    next_obs, rew, done, info = env.step(env_act)

    # Store the transition in the replay buffer of the agent
    agent.store_transition(obs=obs, next_obs=next_obs,
                               act=model_act, done=done, rew=rew)
    
    # Prepare for next step and store scores
    obs = next_obs
    score += rew
    
    monitor.add(f'score_{trajectory_count}', score)
    #current_score.append(score)

    # if episode ends
    if done:
        obs = env.reset()
        current_scores.append(current_score)
        scores.append(score)
        pickle.dump(current_scores, open('scores/current_score_'+VERSION+'.pkl', 'wb'))
        pickle.dump(scores, open('scores/scores_'+VERSION+'.pkl', 'wb'))
        current_score = []
        score = 0
        torch.save(agent.policy.state_dict(), 'trained_models/sac_policy'+VERSION+'.pkl')
        torch.save(agent.q1.state_dict(), 'trained_models/sac_q1'+VERSION+'.pkl')
        torch.save(agent.q2.state_dict(), 'trained_models/sac_q2'+VERSION+'.pkl')
        logging.info(f'Trajectory {len(current_scores)} done, with final score {current_scores[-1][-1]}')
        trajectory_count += 1

    # TRAIN
    if len(agent.buffer) >= agent.batch_size:
        q1_loss, q2_loss, policy_loss, q1_hat, q2_hat = agent.fit_batch()
        agent.update_target_networks()
        q1_hat_mean = float(q1_hat.mean())
        q2_hat_mean = float(q2_hat.mean())
        
        assert (q1_hat_mean==q1_hat_mean) or (q2_hat_mean==q2_hat_mean), "At least one q function returns NaN!"
        
        q1_values_mean.append(q1_hat_mean)
        q2_values_mean.append(q2_hat_mean)
        
        losses_agent.append(float(policy_loss))
        losses_q1.append(float(q1_loss))
        losses_q2.append(float(q2_loss))

    if (frame_idx+1) % plotting_interval == 0:
        
        pickle.dump([q1_values_mean, q2_values_mean], open('logs/q_values_'+VERSION+'.pkl','wb'))
        check_relus_are_alive()
        
        plot_stats(frame_idx, 
                   np.array([q1_values_mean, q2_values_mean]).T, 
                   current_score,
                   np.array([losses_agent, losses_q1, losses_q2]).T)
        monitor.plot_all()
        monitor.save()
        
        

### Observe agent

In [None]:
obs = env.reset()
net_reward = 0
actions = []

In [None]:
for i in range(200):
    #import pdb; pdb.set_trace()
    obs_pov = agent.float_tensor([obs['pov'].astype(float)])
    obs_rest = agent.float_tensor([[obs['compassAngle']]])
    model_act = agent.get_act(obs_pov,obs_rest)
    env_act = get_env_act(model_act[0])
    next_obs, rew, done, info = env.step(env_act)

    # Prepare for next step and store scores
    obs = next_obs
    score += rew
    current_score.append(score)
    
#     if i%10==0:
    plt.imshow(env.render(mode='rgb_array')) 
    display.display(plt.gcf())
    clear_output(wait=True)
    net_reward += rew
    actions.append((env_act, net_reward))