In [None]:
import numpy as np
import pandas as pd
import os
import math
import random
import matplotlib.pyplot as plt
from collections import defaultdict
import glob
import time
from datetime import datetime

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

In [None]:
# Choose the compute device: the first CUDA GPU when one is available, else the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

if device.type == 'cuda':
    # free cached GPU memory before anything is allocated for training
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    print("Device set to : cpu")

print("============================================================================================")

In [None]:
from ipynb.fs.full.choose_jurisdiction import *
from ipynb.fs.full.step_function import *
from ipynb.fs.full.PPO_algo import *

In [None]:
################################### Training ###################################


####### environment / rollout hyperparameters ######

has_continuous_action_space = True

# Episode length setting; the training loop iterates range(0, max_ep_len + 1).
max_ep_len = 12
# Training stops once the global timestep counter exceeds this value.
max_training_timesteps = 100000

# Reporting cadences, all measured in environment timesteps.
# Note: the print/log intervals should be larger than max_ep_len.
print_freq = 10 * max_ep_len      # print the average reward
log_freq = 2 * max_ep_len         # write the average reward to the CSV log
save_model_freq = 2000            # save the model weights
plot_freq = 1200                  # redraw the reward curves

# Standard deviation of the action distribution (continuous action space only).
action_std = 0.4                  # starting value
action_std_decay_rate = 0.0046    # decay rate passed to decay_action_std
min_action_std = 0.05             # minimum value
action_std_decay_freq = 1000      # decay interval, in timesteps

#####################################################


################ PPO hyperparameters ################

update_timestep = 120       # update the policies every n timesteps
K_epochs = 20               # update the policy for K epochs
eps_clip = 0.2              # clip parameter for PPO
gamma_ = 0.99               # discount factor

lr_actor = 0.0003           # learning rate for the actor network
lr_critic = 0.0003          # learning rate for the critic network

random_seed = 10            # seed for the random generators (0 = do not seed)

#####################################################

env_name = 'HIV Jurisdiction'

print(f"training environment name : {env_name}")

# state space dimensions
obs_dim = 15
state_dim = 120

# action space dimension
action_dim = 9 if has_continuous_action_space else 1
    
###################### logging ######################

#### log files for multiple runs are NOT overwritten

# Create PPO_logs/<env_name>/ in a single call; exist_ok=True removes the
# check-then-create race of a separate os.path.exists() test.
log_dir = "PPO_logs" + '/' + env_name + '/'
os.makedirs(log_dir, exist_ok=True)


#### start numbering from the number of files already in the log directory
run_num = len(next(os.walk(log_dir))[2])


#### create new log file for each run
# The file count alone can point at a log that already exists (for example after
# an older log was deleted), and the file is later opened with "w+". Advance
# run_num until the name is unused so an earlier run is never truncated.
# log_dir already ends with '/', so no extra separator is added here.
while True:
    log_f_name = log_dir + 'PPO_' + env_name + "_log_" + str(run_num) + ".csv"
    if not os.path.exists(log_f_name):
        break
    run_num += 1

print("current logging run number for " + env_name + " : ", run_num)
print("logging at : " + log_f_name)

#####################################################


################### checkpointing ###################

run_num_pretrained = 0      #### change this to prevent overwriting weights in same env_name folder

# Root checkpoint folder first, then the per-environment subfolder inside it.
directory = "PPO_preTrained"
if not os.path.exists(directory):
    os.makedirs(directory)

directory = "{}/{}/".format(directory, env_name)
if not os.path.exists(directory):
    os.makedirs(directory)


# One weights file per cluster agent (clusters 1..8), all sharing the same
# environment name, seed and pretrained-run number in the file name.
(checkpoint_path1, checkpoint_path2, checkpoint_path3, checkpoint_path4,
 checkpoint_path5, checkpoint_path6, checkpoint_path7, checkpoint_path8) = (
    "{}Cluster{}PPO_{}_{}_{}.pth".format(directory, cluster, env_name, random_seed, run_num_pretrained)
    for cluster in range(1, 9)
)

print("save checkpoint path : " + checkpoint_path1)

#####################################################

############# print all hyperparameters #############

# The double space after some colons is deliberate: it reproduces the output of
# print("label : ", value), which inserts a separator after the label's own space.

print("--------------------------------------------------------------------------------------------")

print(f"max training timesteps :  {max_training_timesteps}")
print(f"max timesteps per episode :  {max_ep_len}")

print(f"model saving frequency : {save_model_freq} timesteps")
print(f"log frequency : {log_freq} timesteps")
print(f"printing average reward over episodes in last : {print_freq} timesteps")

print("--------------------------------------------------------------------------------------------")

print(f"observation space dimension :  {obs_dim}")
print(f"state space dimension :  {state_dim}")
print(f"action space dimension :  {action_dim}")

print("--------------------------------------------------------------------------------------------")

if not has_continuous_action_space:
    print("Initializing a discrete action space policy")
else:
    print("Initializing a continuous action space policy")
    print("--------------------------------------------------------------------------------------------")
    print(f"starting std of action distribution :  {action_std}")
    print(f"decay rate of std of action distribution :  {action_std_decay_rate}")
    print(f"minimum std of action distribution :  {min_action_std}")
    print(f"decay frequency of std of action distribution : {action_std_decay_freq} timesteps")

print("--------------------------------------------------------------------------------------------")

print(f"PPO update frequency : {update_timestep} timesteps")
print(f"PPO K epochs :  {K_epochs}")
print(f"PPO epsilon clip :  {eps_clip}")
print(f"discount factor (gamma_) :  {gamma_}")

print("--------------------------------------------------------------------------------------------")

print(f"optimizer learning rate actor :  {lr_actor}")
print(f"optimizer learning rate critic :  {lr_critic}")

# Seed every random number generator this notebook imports so a run can be
# repeated. A random_seed of 0 is falsy and leaves all generators unseeded.
if random_seed:
    # Imported under a private alias so the stdlib module is used even if a
    # star-import above rebound the name `random`.
    import random as _stdlib_random

    print("--------------------------------------------------------------------------------------------")
    print("setting random seed to ", random_seed)
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    # Python's own generator was previously left unseeded.
    _stdlib_random.seed(random_seed)

#####################################################

print("============================================================================================")

In [None]:
################# training procedure ################

# One independent PPO agent per cluster (1..8), all constructed in order with
# identical hyperparameters.
(ppo_agent1, ppo_agent2, ppo_agent3, ppo_agent4,
 ppo_agent5, ppo_agent6, ppo_agent7, ppo_agent8) = (
    PPO(obs_dim, state_dim, action_dim, lr_actor, lr_critic, gamma_, K_epochs,
        eps_clip, has_continuous_action_space, action_std)
    for _ in range(8)
)

# Reference point for the elapsed-time report.
# NOTE: datetime.now() returns local time, although the message below says GMT.
start_time = datetime.now().replace(microsecond=0)
print("Started training at (GMT) : ", start_time)

print("============================================================================================")


# CSV log file; the handle stays open for the training loop, which closes it.
log_f = open(log_f_name,"w+")
log_f.write('episode,timestep,reward\n')

# Per-agent history of the averaged rewards (appended at every print interval).
(rew_list1, rew_list2, rew_list3, rew_list4,
 rew_list5, rew_list6, rew_list7, rew_list8) = ([] for _ in range(8))

# Per-agent history of the total reward of each finished episode.
(ep_rew_list1, ep_rew_list2, ep_rew_list3, ep_rew_list4,
 ep_rew_list5, ep_rew_list6, ep_rew_list7, ep_rew_list8) = ([] for _ in range(8))

# Per-agent reward accumulated since the last print.
print_running_reward1 = print_running_reward2 = print_running_reward3 = print_running_reward4 = 0
print_running_reward5 = print_running_reward6 = print_running_reward7 = print_running_reward8 = 0

# Episodes finished since the last print.
print_running_episodes = 0

# Reward / episode accumulators since the last log write.
log_running_reward = 0
log_running_episodes = 0

# Global counters.
time_step = 0
i_episode = 0


# ESS_SOC_max = 1500
# T = 24
# training loop
#
# Gather the per-cluster objects into tuples so each per-agent step below is a
# single loop instead of eight copy-pasted statements. The tuples alias the
# existing ppo_agentN / rew_listN / ep_rew_listN objects, so those names keep
# receiving every update made here.
agents = (ppo_agent1, ppo_agent2, ppo_agent3, ppo_agent4,
          ppo_agent5, ppo_agent6, ppo_agent7, ppo_agent8)
rew_lists = (rew_list1, rew_list2, rew_list3, rew_list4,
             rew_list5, rew_list6, rew_list7, rew_list8)
ep_rew_lists = (ep_rew_list1, ep_rew_list2, ep_rew_list3, ep_rew_list4,
                ep_rew_list5, ep_rew_list6, ep_rew_list7, ep_rew_list8)
checkpoint_paths = (checkpoint_path1, checkpoint_path2, checkpoint_path3, checkpoint_path4,
                    checkpoint_path5, checkpoint_path6, checkpoint_path7, checkpoint_path8)
n_agents = len(agents)

# Per-agent reward accumulated since the last print, seeded from the existing counters.
print_running_rewards = [print_running_reward1, print_running_reward2,
                         print_running_reward3, print_running_reward4,
                         print_running_reward5, print_running_reward6,
                         print_running_reward7, print_running_reward8]

try:
    while time_step <= max_training_timesteps:

        state = initial_state(data_array_cluster, prep_values)

        # reward collected by each agent during the current episode
        current_ep_rewards = [0] * n_agents

        # NOTE(review): this runs up to max_ep_len + 1 steps per episode unless
        # step() reports done earlier -- confirm the extra step is intended.
        for t in range(0, max_ep_len+1):

            # joint state: the per-cluster entries state[1..8] stacked and flattened
            full_state = np.vstack([state[k] for k in range(1, n_agents + 1)]).flatten()

            # select action with policy: each agent acts on its own cluster's state
            actions = [agent.select_action(state[k].flatten())
                       for k, agent in enumerate(agents, start=1)]

            # step() returns the next state, one reward per agent, and the done flag
            state, *rewards, done = step(state, *actions)

            # saving joint state, the other agents' actions, reward and is_terminals
            for i, agent in enumerate(agents):
                others = [a for j, a in enumerate(actions) if j != i]
                agent.buffer.full_states.append(torch.FloatTensor(full_state).to(device))
                agent.buffer.actions_others.append(torch.FloatTensor(np.hstack(others)).to(device))
                agent.buffer.rewards.append(rewards[i])
                agent.buffer.is_terminals.append(done)
                current_ep_rewards[i] += rewards[i]

            time_step += 1

            # update PPO agents
            if time_step % update_timestep == 0:
                for agent in agents:
                    agent.update()

            # if continuous action space; then decay action std of output action distribution
            if has_continuous_action_space and time_step % action_std_decay_freq == 0:
                for agent in agents:
                    agent.decay_action_std(action_std_decay_rate, min_action_std)

            # log in logging file; skipped (instead of dividing by zero) when no
            # episode has finished since the previous log entry
            if time_step % log_freq == 0 and log_running_episodes > 0:

                # log average reward till last episode
                log_avg_reward = round(log_running_reward / log_running_episodes, 4)

                log_f.write('{},{},{}\n'.format(i_episode, time_step, log_avg_reward))
                log_f.flush()

                log_running_reward = 0
                log_running_episodes = 0

            # printing average reward; same zero-episode guard as the log above
            if time_step % print_freq == 0 and print_running_episodes > 0:

                # print average reward till last episode
                print_avg_rewards = [round(total / print_running_episodes, 2)
                                     for total in print_running_rewards]

                for idx, avg in enumerate(print_avg_rewards, start=1):
                    print("Agent{} => Episode : {} \t\t Timestep : {} \t\t Average Reward : {}".format(idx, i_episode, time_step, avg))

                for history, avg in zip(rew_lists, print_avg_rewards):
                    history.append(avg)

                print_running_rewards = [0] * n_agents
                print_running_episodes = 0

            if time_step % plot_freq == 0:
                fig,ax = plt.subplots(2,4,sharex=False, sharey=False, figsize=(20,7))
                fig.tight_layout(h_pad=3, w_pad=1)

                # ax.flat walks the 2x4 grid row by row, i.e. clusters 1..8 in order
                for idx, (axis, history) in enumerate(zip(ax.flat, rew_lists), start=1):
                    axis.plot(range(len(history)), history)
                    axis.set_title('Rewards for Cluster {}'.format(idx), pad=12)

                plt.show()
                # release the figure so repeated plotting does not accumulate open figures
                plt.close(fig)

            # save model weights
            if time_step % save_model_freq == 0:
                print("--------------------------------------------------------------------------------------------")
                print("saving model at : " + checkpoint_path1)
                for agent, path in zip(agents, checkpoint_paths):
                    agent.save(path)

                print("model saved")
                print("Elapsed Time  : ", datetime.now().replace(microsecond=0) - start_time)
                print("--------------------------------------------------------------------------------------------")

            # break; if the episode is over
            if done:
                break

        # fold the finished episode into the running totals and per-episode histories
        for i in range(n_agents):
            print_running_rewards[i] += current_ep_rewards[i]
            ep_rew_lists[i].append(current_ep_rewards[i])

        print_running_episodes += 1

        # NOTE(review): the CSV log tracks agent 1's episode reward only --
        # confirm that is the intended metric.
        log_running_reward += current_ep_rewards[0]
        log_running_episodes += 1

        i_episode += 1
finally:
    # close the log file even when training raises or is interrupted
    log_f.close()