In [1]:
# todo: understand log_prob of evaluate

In [2]:
# Tag identifying this experiment; it is embedded in log, monitor and checkpoint file names.
VERSION = 'sac,treechop_v2,max_camera_turn=10,sample_hum_ratio=0.5'

from importlib import reload
import models_sac
import models_sac_treechop
import utils
# Reload the project modules so that edits made on disk are picked up without
# restarting the kernel. The networks used below are imported from
# models_sac_treechop, so that module has to be reloaded as well; reloading
# only models_sac left ActorNetwork / CriticNetwork stale after an edit.
reload(models_sac)
reload(models_sac_treechop)
reload(utils)

from models_sac_treechop import CriticNetwork, ActorNetwork
from utils import ReplayBuffer, seed_everything, Monitor

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

import numpy as np
import math
import random
import os
from tqdm import tqdm
import logging
import time

import minerl
import gym

import matplotlib.pyplot as plt
import pickle

from IPython.display import clear_output
from IPython import display
import shutil

# logging opens the log file immediately and does not create missing
# directories, so make sure the target folder exists first.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(filename='logs/'+VERSION+'-'+time.strftime("%Y%m%d-%H%M%S")+'.log', 
                    filemode='w', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')

In [3]:
# --- Environment / reproducibility -------------------------------------------
SEED = 584
ENV_NAME = 'MineRLTreechop-v0'
MAX_NUM_FRAMES = 10**9  # upper bound on environment steps in the training loop

# --- Action and observation layout -------------------------------------------
# Order of the model's action vector: the first eight entries are binary keys,
# the last two are the camera components.
ACTIONS = [
    'attack', 'jump', 'forward', 'back', 'left', 'right', 'sprint', 'sneak',
    'camera_hor', 'camera_ver',
]
OBS_DIM = 64 * 64 * 3   # flattened pov frame (no extra observations while OBS_REST is False)
OBS_REST = False        # if there are other observations beyond POV
POV_SCALING = 255
MAX_CAMERA_TURN = 10    # model camera outputs are multiplied by this before being sent to the env

# --- Network architecture ----------------------------------------------------
NUM_FILTERS = 128
LIN_1_DIM = 512
LIN_2_DIM = 256
USE_BN = True
TANH_FACTOR = 1
CAMERA_FACTOR = 1

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 32
BUFFER_SIZE = 100_000   # replay buffer capacity
ACTOR_LR = 1e-5
CRITIC_LR = 1e-5
ALPHA_LR = 1e-5
GAMMA = 0.9             # discount factor
TAU = 0.001             # polyak coefficient for the target critics

# --- Entropy temperature -----------------------------------------------------
INITIAL_ALPHA = 0.1
DYNAMIC_ALPHA = False   # when True, alpha is learned towards TARGET_ENTROPY
TARGET_ENTROPY = -1

# --- Experiment switches -----------------------------------------------------
Q_PREDICTS_REW = False  # when True the critics regress the immediate reward only
CHECKS = True           # monitoring and NaN assertions inside Agent.fit_batch
CHECKS_AGENT = False # calculates Q value of the agent at each step

# --- Human demonstrations ----------------------------------------------------
SAMPLE_HUMAN = True
SAMPLE_HUMAN_RATIO = 0.5  # probability that a training batch comes from human data

# Metric recorder shared by the agent and the training loop. The training loop
# later copies 'monitoring/monitor_<VERSION>.pkl', so the second argument is
# presumably the output directory -- confirm against utils.Monitor.
monitor = Monitor(f'monitor_{VERSION}.pkl', 'monitoring')

In [4]:
# Human demonstration dataset, only loaded when batches may be drawn from it.
# The dataset location can be overridden with the MINERL_DATA_ROOT environment
# variable; the previous hard-coded path remains the default.
if SAMPLE_HUMAN:
    human_data = minerl.data.make(
        ENV_NAME,
        data_dir=os.environ.get('MINERL_DATA_ROOT', '/app/code/minerl-data'),
    )
else:
    human_data = None

### SAC

In [5]:
class Agent:
    """Soft Actor-Critic agent acting on 64x64x3 'pov' frames.

    Owns the policy network, two critics (plus target copies unless
    Q_PREDICTS_REW is set), their Adam optimizers, the replay buffer and,
    optionally, an iterator over human demonstrations that is mixed into the
    training batches.
    """

    def __init__(self, 
                 num_acts=len(ACTIONS),
                 batch_size=BATCH_SIZE, 
                 gamma=GAMMA,
                 actor_learning_rate=ACTOR_LR, 
                 critic_learning_rate=CRITIC_LR,
                 alpha_lr = ALPHA_LR,
                 tau=TAU,
                 initial_alpha=INITIAL_ALPHA,
                 target_entropy=TARGET_ENTROPY,
                 dynamic_alpha = DYNAMIC_ALPHA,
                 sample_human = SAMPLE_HUMAN,
                 human_data = human_data
                 ):
        """Build networks, optimizers, replay buffer and the human-data iterator.

        Args:
            num_acts: length of the model's action vector.
            batch_size: replay batch size; also the maximum length of a human
                demonstration batch.
            gamma: discount factor used in the critic target.
            actor_learning_rate: Adam learning rate for the policy.
            critic_learning_rate: Adam learning rate for both critics.
            alpha_lr: Adam learning rate for log-alpha (only with dynamic_alpha).
            tau: polyak coefficient for the target critics.
            initial_alpha: starting entropy temperature.
            target_entropy: entropy target used when alpha is learned.
            dynamic_alpha: learn alpha instead of keeping it fixed.
            sample_human: draw some training batches from human demonstrations.
            human_data: dataset object providing ``sarsd_iter``; required when
                sample_human is true.

        Raises:
            ValueError: if sample_human is true but human_data is None.
        """
        self.num_acts = num_acts
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info("Device is: "+str(self.device))
        print("Device is: ", self.device)
        
        self.dynamic_alpha = dynamic_alpha
        self.gamma = gamma
        self.actor_lr = actor_learning_rate
        self.critic_lr = critic_learning_rate
        self.alpha_lr = alpha_lr
        self.tau = tau
        # The constructor arguments (not the module-level constants) drive the
        # buffer, networks and optimizers, so a non-default argument takes effect.
        self.buffer = ReplayBuffer( 
            obs_dim=OBS_DIM,
            size=BUFFER_SIZE,
            act_dim=num_acts,
            batch_size=batch_size
        )
     
        params_nn = {'acts_dim':num_acts, 'num_filters':NUM_FILTERS, 
                     'use_bn':USE_BN, 'lin_1_dim': LIN_1_DIM, 'lin_2_dim': LIN_2_DIM}
        self.policy = ActorNetwork(**params_nn).to(self.device)
        self.q1 = CriticNetwork(**params_nn).to(self.device)
        self.q2 = CriticNetwork(**params_nn).to(self.device)
        
        # Target critics are only needed when the critics bootstrap from the next state.
        if not Q_PREDICTS_REW:
            self.target_q1 = CriticNetwork(**params_nn).to(self.device)
            self.target_q2 = CriticNetwork(**params_nn).to(self.device)
            self.target_q1.load_state_dict(self.q1.state_dict())
            self.target_q2.load_state_dict(self.q2.state_dict())
            
        self.alpha = initial_alpha
        if self.dynamic_alpha:
            self.log_alpha = nn.Parameter( torch.Tensor( [np.log(self.alpha)] ).to(self.device) )
        
        self.q1_optim = optim.Adam(self.q1.parameters(), lr=self.critic_lr)
        self.q2_optim = optim.Adam(self.q2.parameters(), lr=self.critic_lr)
        self.policy_optim = optim.Adam(self.policy.parameters(), lr=self.actor_lr)
        if self.dynamic_alpha:
            self.alpha_optim = optim.Adam([self.log_alpha], lr=self.alpha_lr)
        self.target_entropy = target_entropy
        self.sample_human = sample_human
        if self.sample_human:
            if human_data is None:
                raise ValueError("sample_human=True requires a human_data dataset")
            self.human_data_iter = human_data.sarsd_iter( max_sequence_len=batch_size, seed = SEED )

    def get_act(self, obs):
        """Return the policy's action for `obs` as a tensor, without tracking gradients.

        The policy is switched to eval mode for the forward pass and back to
        train mode afterwards.
        """
        self.policy.eval()
        with torch.no_grad():
            act = self.policy.get_action(obs)
        self.policy.train()    
        return act
        
    def unflatten_obs(self, flat_obs):
        """Reshape flat observations back to (-1, 64, 64, 3) frames.

        With OBS_REST the last column is split off and returned separately as
        a (-1, 1) array, giving a (pov, rest) tuple.
        """
        if OBS_REST:
            return (flat_obs[:,:-1].reshape(-1,64,64,3), flat_obs[:,-1].reshape(-1,1))
        else:
            return flat_obs.reshape(-1,64,64,3)
    
        
    def flatten_obs(self, obs):
        """Flatten one observation to a 1-D array for the replay buffer.

        With OBS_REST, `obs` must be the full observation mapping (its 'pov'
        and 'compassAngle' entries are concatenated); otherwise `obs` is the
        pov array itself.
        """
        if OBS_REST:
            return np.append(obs['pov'].reshape(-1), obs['compassAngle'])
        else:
            return obs.reshape(-1)

    def store_transition(self, obs, next_obs, act, rew, done):
        """Flatten both observations and store the transition in the replay buffer."""
        obs = self.flatten_obs(obs)
        next_obs = self.flatten_obs(next_obs)
        self.buffer.store(obs=obs, act=act, rew=rew, 
                          next_obs=next_obs, done=done)

    def float_tensor(self, numpy_array):
        """Convert array-like data to a float tensor on the agent's device."""
        return torch.FloatTensor(numpy_array).to(self.device)
        
    def sample_batch(self):
        """Draw one training batch, either from human data or from the replay buffer.

        With probability SAMPLE_HUMAN_RATIO (and only if sample_human is set)
        the batch is the next chunk of the human demonstration iterator,
        converted to the model's action encoding; otherwise it is sampled
        from the replay buffer.

        Returns:
            Tuple (obss, next_obss, acts, rews, dones) of float tensors.
        """
        if self.sample_human and random.random() < SAMPLE_HUMAN_RATIO:
            obss_h, acts_h, rews_h, next_obss_h, dones_h = next(self.human_data_iter)
            obss = self.float_tensor( obss_h['pov'] )
            # Binary keys: map {0, 1} to {-1, +1} to match the tanh-squashed model outputs.
            acts_ex = np.stack([acts_h[action] for action in ACTIONS[0:8] ], axis = 1)*2-1
            # Camera: inverse of get_env_act's scaling, clipped to the model's [-1, 1] range.
            act_cam = np.clip(acts_h['camera']/MAX_CAMERA_TURN, -1, +1)
            acts = self.float_tensor( np.append( acts_ex, act_cam, axis=1) )
            rews = self.float_tensor( rews_h )
            next_obss = self.float_tensor( next_obss_h['pov'] )
            # NOTE(review): the last `done` entry is dropped here, presumably because the
            # iterator yields one more done flag than rewards -- confirm against the
            # installed minerl version.
            dones = self.float_tensor(dones_h[:-1])
        else:
            transitions = self.buffer.sample_batch()      
            obss = self.unflatten_obs( self.float_tensor(transitions['obs']) )
            next_obss = self.unflatten_obs( self.float_tensor(transitions['next_obs']) )
            acts = self.float_tensor(transitions['acts'])
            rews = self.float_tensor(transitions['rews'])
            dones = self.float_tensor(transitions['dones'])
        return obss, next_obss, acts, rews, dones
    
    def fit_batch(self):
        """Run one SAC update (policy, both critics and optionally alpha) on a sampled batch.

        Returns:
            Tuple (q1_loss, q2_loss, policy_loss, q1_hat, q2_hat); the losses are
            scalar tensors and the q*_hat tensors hold the critics' predictions
            for the sampled (observation, action) pairs.
        """
        # Sample
        obss, next_obss, acts, rews, dones = self.sample_batch()

        # Q function target (no gradients flow through it)
        min_q_target_hat = None
        with torch.no_grad():
            next_log_prob, next_action = self.policy.get_log_probs(next_obss)
            if not Q_PREDICTS_REW:
                target_q1_next = self.target_q1(next_obss, next_action).view(-1)
                target_q2_next = self.target_q2(next_obss, next_action).view(-1)
                min_q_target_hat = torch.min(target_q1_next, target_q2_next) - self.alpha * next_log_prob.view(-1) 
                y = rews + (1 - dones) * self.gamma * min_q_target_hat
            else:
                y = rews
 
        # Q function loss
        q1_hat = self.q1( obss , acts ).view(-1)  
        q2_hat = self.q2( obss , acts ).view(-1)  
        q1_loss = F.mse_loss(q1_hat, y) 
        q2_loss = F.mse_loss(q2_hat, y)
        
        # Policy loss
        log_pi, pi = self.policy.get_log_probs(obss)
        
        q1_hat_policy = self.q1( obss, pi)
        q2_hat_policy = self.q2( obss, pi)
        min_q_pi = torch.min(q1_hat_policy, q2_hat_policy)
        policy_loss = (self.alpha * log_pi - min_q_pi).mean()
        
        if CHECKS:
            y_mean = float( y.mean() )
            # There is no bootstrapped target when the critics predict the reward only.
            if min_q_target_hat is None:
                min_q_target_hat_mean = float('nan')
            else:
                min_q_target_hat_mean = float(min_q_target_hat.mean())
            monitor.add(['next_log_prob_mean', 'min_q_target_hat_mean',
                         'alpha','y_mean', 
                         'log_pi', 'policy_mean_mean','policy_log_std_mean', 'num_rewards_training'],
                        [float(next_log_prob.mean()), min_q_target_hat_mean, 
                         float(self.alpha), y_mean, float(log_pi.mean()),
                         float(self.policy.mean.mean()), float(self.policy.log_std.mean()),
                         float(rews.sum())])
            # NaN is the only value that differs from itself.
            assert y_mean == y_mean, "Y is NaN!"
        
        # Gradient descent.
        # The policy loss back-propagates through q1/q2, so its backward pass must
        # run BEFORE the critic optimizers modify those weights in place;
        # otherwise autograd sees stale saved tensors (a RuntimeError on recent
        # PyTorch versions). The critic gradients it leaves behind are wiped by
        # the critics' zero_grad() calls below.
        self.policy_optim.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0, norm_type=1)
        self.policy_optim.step()

        self.q1_optim.zero_grad()
        q1_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q1.parameters(), 1.0, norm_type=1)
        self.q1_optim.step()
        
        self.q2_optim.zero_grad()
        q2_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q2.parameters(), 1.0, norm_type=1)
        self.q2_optim.step()
        
        if CHECKS:
            # A NaN in any parameter makes the mean of means NaN.
            check_nan_policy = np.array([float(i.mean()) for i in self.policy.parameters()]).mean()
            assert check_nan_policy == check_nan_policy
            check_nan_q1 = np.array([float(i.mean()) for i in self.q1.parameters()]).mean()
            assert check_nan_q1 == check_nan_q1
            check_nan_q2 = np.array([float(i.mean()) for i in self.q2.parameters()]).mean()
            assert check_nan_q2 == check_nan_q2
        
        # Alpha parameter tuning
        if self.dynamic_alpha:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            # Detach so later policy/critic losses treat alpha as a constant
            # instead of back-propagating into log_alpha.
            self.alpha = self.log_alpha.exp().detach()

        return q1_loss, q2_loss, policy_loss, q1_hat, q2_hat
    
    def update_target_networks(self):
        """Polyak-average both target critics towards the online critics.

        Does nothing when Q_PREDICTS_REW is set, because no target critics
        exist in that configuration.
        """
        if Q_PREDICTS_REW:
            return
        self.polyak_averaging(self.target_q1, self.q1)
        self.polyak_averaging(self.target_q2, self.q2)
        
    def polyak_averaging(self, target, original):
        """Update `target` in place: target <- tau * original + (1 - tau) * target.

        Only entries returned by ``parameters()`` are averaged.
        """
        for target_param, param in zip(target.parameters(), original.parameters()):
            target_param.data.copy_(self.tau * param.data + target_param.data * (1.0 - self.tau))
            
    def get_env_act(self, model_act):
        '''
        Gets environment act from model act.

        The first eight entries of `model_act` become binary keys (1 when the
        value is positive, else 0) named after ACTIONS[0:8]; entries 8 and 9
        are multiplied by MAX_CAMERA_TURN and form the two-element 'camera' action.
        '''
        env_act = {act: int(value>0) for act, value in zip(ACTIONS[0:8], model_act[0:8])}
        env_act['camera'] = [model_act[8]*MAX_CAMERA_TURN, model_act[9]*MAX_CAMERA_TURN]
        return env_act

### Instantiate

In [6]:
# Building the environment is slow (about a minute of wall time in the recorded
# output), so an `env` already present in the kernel namespace is reused
# instead of being rebuilt when this cell is re-run.
if 'env' not in locals():
    %time env = gym.make(ENV_NAME)

# Seed with the fixed SEED so runs are repeatable; exactly what gets seeded is
# defined by utils.seed_everything (not shown here).
seed_everything(seed=SEED, env=env)

CPU times: user 255 ms, sys: 328 ms, total: 584 ms
Wall time: 1min 12s


In [7]:
agent = Agent()

# Trainable parameter counts. Only q1 is counted and then doubled, because q1
# and q2 are built from the same constructor arguments.
actor_param_count = sum(p.numel() for p in agent.policy.parameters() if p.requires_grad)
critic_param_count = 2 * sum(p.numel() for p in agent.q1.parameters() if p.requires_grad)

message_1 = f'Number of trainable parameters actor: {actor_param_count}'
message_2 = f'Number of trainable parameters critic: {critic_param_count}'
message_3 = 'Number of trainable parameters alpha: Unknown'

for message in (message_1, message_2, message_3):
    print(message)

# Only the actor and critic counts go to the log file.
for message in (message_1, message_2):
    logging.info(message)

Device is:  cuda
Number of trainable parameters actor: 3642004
Number of trainable parameters critic: 7284482
Number of trainable parameters alpha: Unknown


In [8]:
# Load agent
# agent.policy.load_state_dict(torch.load('trained_models/sac_policy'+VERSION+'.pkl'))
#agent.critic.load_state_dict(torch.load('trained_models/ddpg_critic_'+VERSION+'.pkl'))

### Training loop

In [9]:
%time obs = env.reset()
assert (agent.unflatten_obs( np.array([agent.flatten_obs(obs['pov'])]) )[0].astype(int) == obs['pov']).sum() == 3*64*64
trajectory_count = 1
losses_agent = []
losses_critic = []
scores = []
current_score = []
current_scores = []
max_parameters = []
plotting_interval = 100
score = 0
logging.info('Environment reset')

q1_values_mean=[]
q2_values_mean=[]
losses_agent=[]
losses_q1=[]
losses_q2=[]

CPU times: user 13.5 ms, sys: 4.61 ms, total: 18.1 ms
Wall time: 11.4 s


In [10]:
def check_relus_are_alive():
    """Report every non-linearity of the agent's networks whose stored output sums to zero.

    For q1, q2 and the policy (in that order) the attributes non_lin_1,
    non_lin_2 and non_lin_3 are inspected; each one whose sum is exactly 0 is
    logged at INFO level and printed as '<network> has <name> died'.
    The attributes are presumably activations cached by the last forward pass
    -- confirm in the network definitions.
    """
    for network in (agent.q1, agent.q2, agent.policy):
        for layer_name in ('non_lin_1', 'non_lin_2', 'non_lin_3'):
            activations = getattr(network, layer_name)
            if float(activations.sum()) != 0:
                continue
            text = str(network) + ' has ' + layer_name + ' died'
            logging.info(text)
            print(text)
                       
# test
# agent.actor.x_relu_1[agent.actor.x_relu_1!=0] = 0

In [None]:
# Checkpoints and monitor copies are written into these folders; create them
# up front so the first save cannot fail on a missing directory.
os.makedirs('trained_models', exist_ok=True)
os.makedirs('monitoring', exist_ok=True)

for frame_idx in tqdm(range(MAX_NUM_FRAMES), desc='frame'):
    # Act: add a batch axis to the current frame and query the policy.
    obs_tensor = agent.float_tensor(obs['pov'][np.newaxis].astype(np.float32))
    model_act = agent.get_act(obs_tensor)
    model_act_np = model_act.cpu().numpy()[0] # act returned by the model in numpy
    env_act = agent.get_env_act(model_act_np)
    next_obs, rew, done, info = env.step(env_act)
    
    # Monitor behaviour of agent for checks
    if CHECKS_AGENT:
        with torch.no_grad():
            new_q = agent.q1(obs_tensor, model_act)
            monitor.add('q_agent', float(new_q))
    monitor.add([ 'camera_hor_agent','rew_agent'], 
                [env_act['camera'][1], rew])
    
    # Store the transition in the replay buffer of the agent
    agent.store_transition(obs=obs['pov'], next_obs=next_obs['pov'],
                           act=model_act_np, done=done, rew=rew)
    
    # Prepare for next step and store scores
    obs = next_obs
    score += rew
    monitor.add(f'score_{trajectory_count}', score)
    # Record plain Python floats rather than a row of the (possibly CUDA) action
    # tensor, so the monitor holds no device tensors.
    monitor.add(ACTIONS, model_act_np.tolist())
    monitor.add(['agent_mean', 'agent_log_std'], [float(agent.policy.mean.mean()), float(agent.policy.log_std.mean())])
    
    if done:
        # Remember the finished episode before the counters are reset.
        last_score = score
        finished_trajectory = trajectory_count

        obs = env.reset()
        score = 0
        trajectory_count += 1
        
        # Save model and log
        torch.save(agent.policy.state_dict(), 'trained_models/sac_policy'+VERSION+'.pkl')
        torch.save(agent.q1.state_dict(), 'trained_models/sac_q1'+VERSION+'.pkl')
        torch.save(agent.q2.state_dict(), 'trained_models/sac_q2'+VERSION+'.pkl')
        logging.info(f'Trajectory {finished_trajectory} done, with final score {last_score}')
        
    # TRAIN
    if len(agent.buffer) >= agent.batch_size:
        q1_loss, q2_loss, policy_loss, q1_hat, q2_hat = agent.fit_batch()
        # Without this the target critics stay frozen at their initial weights and
        # TAU has no effect. Target critics only exist when Q_PREDICTS_REW is off.
        if not Q_PREDICTS_REW:
            agent.update_target_networks()
        q1_hat_mean = float(q1_hat.mean())
        q2_hat_mean = float(q2_hat.mean())
        monitor.add(['q1_hat.mean', 'q2_hat.mean', 'loss_agent', 'loss_q1', 'loss_q2'],
                    [q1_hat_mean, q2_hat_mean, float(policy_loss),
                     float(q1_loss), float(q2_loss) ])

    # Periodically persist the monitor and keep a second copy of the file.
    if (frame_idx+1) % plotting_interval == 0:
        monitor.save()
        shutil.copyfile(f'monitoring/monitor_{VERSION}.pkl', f'monitoring/monitor_{VERSION}_copy.pkl')
        

frame:   0%|          | 1187/1000000000 [07:00<69007:47:28,  4.03it/s] 

In [None]:
# Leftover post-mortem debugging cell (opens the debugger on the last
# exception); remove it before sharing the notebook.
%debug

In [None]:
# Display the monitor object for ad-hoc inspection.
monitor

In [None]:
# Scratch check: pull one human demonstration batch and convert it with the
# same action encoding as Agent.sample_batch, then show the total reward in it.
self = agent  # alias so the lines below read like the method body
obss_h, acts_h, rews_h, next_obss_h, dones_h = next(self.human_data_iter)
obss = self.float_tensor( obss_h['pov'] )
# Binary keys: map {0, 1} to {-1, +1} (all acts except the camera ones).
acts_ex = np.stack([acts_h[action] for action in ACTIONS[0:8] ], axis = 1)*2 -1
# Camera actions are scaled by MAX_CAMERA_TURN and clipped to [-1, 1];
# POV_SCALING is unrelated to camera angles.
act_cam = np.clip(acts_h['camera']/MAX_CAMERA_TURN, -1, +1)
acts = self.float_tensor( np.append( acts_ex, act_cam, axis=1) )
rews = self.float_tensor( rews_h )
next_obss = self.float_tensor( next_obss_h['pov'] )
# NOTE(review): Agent.sample_batch drops the last done flag (dones_h[:-1]);
# the full array is kept here so its length can be compared with rews.
dones = self.float_tensor(dones_h)
rews.sum()

In [None]:
# Leftover post-mortem debugging cell (opens the debugger on the last
# exception); remove it before sharing the notebook.
%debug

### Observe agent

In [None]:
# Start a fresh episode and clear the bookkeeping used by the rollout cells below.
obs = env.reset()
score = 0
net_reward = 0
actions, current_score = [], []

In [None]:
# Roll the current policy forward for up to 200 steps, rendering every frame.
for i in range(200):
    # Add a batch axis to the current frame and query the policy.
    obs_pov = agent.float_tensor(obs['pov'][np.newaxis].astype(np.float32))
    model_act = agent.get_act(obs_pov).cpu().numpy()
    env_act = agent.get_env_act(model_act[0])
    next_obs, rew, done, info = env.step(env_act)

    # Prepare for next step and store scores
    obs = next_obs
    score += rew
    current_score.append(score)
    net_reward += rew
    actions.append((env_act, net_reward))

    # Clear the figure first so images do not pile up on the same axes.
    plt.clf()
    plt.imshow(env.render(mode='rgb_array')) 
    display.display(plt.gcf())
    clear_output(wait=True)

    # Stop once the episode is over instead of stepping a finished environment.
    if done:
        break

In [None]:
# Display the last action dictionary sent to the environment.
env_act

In [None]:
# Display the recorded (env_act, cumulative reward) pairs. The list can be long.
actions

In [None]:
# Take one hand-written step (jump + forward), turning the camera in proportion
# to the compass angle when the observation provides one.
for i in range(1):
    # This notebook is configured for pov-only observations (OBS_REST is False),
    # so 'compassAngle' may be absent; fall back to no camera turn instead of
    # raising KeyError.
    compass_angle = obs['compassAngle'] if 'compassAngle' in obs else None
    camera_turn = 0.03*compass_angle if compass_angle is not None else 0.0
    env_act= {'attack':0, 
               'jump':1, 
               'forward':1, 
               'back':0, 
               'left':0, 
               'right':0, 
               'sprint':0, 
               'sneak':0, 
               'camera':[0,  camera_turn]}
    next_obs, rew, done, info = env.step(env_act)
    next_compass_angle = next_obs['compassAngle'] if 'compassAngle' in next_obs else None
    
    # Render the new frame; clear the figure first so images do not pile up.
    plt.clf()
    plt.imshow(env.render(mode='rgb_array'))     
    display.display(plt.gcf())
    clear_output(wait=True)
    print('compassAngle:', compass_angle)
    print('next compassAngle:', next_compass_angle)

    # Prepare for next step and store scores
    obs = next_obs
    score += rew
    current_score.append(score)
    
    net_reward += rew
    actions.append((env_act, net_reward))

In [None]:
# Leftover post-mortem debugging cell (opens the debugger on the last
# exception); remove it before sharing the notebook.
%debug