In [1]:
import os
import random
import time

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import optim
%matplotlib inline

In [2]:
#Network
class Policy_net(nn.Module):
    """Deterministic actor: maps a batch of states to bounded continuous actions.

    Layers: BatchNorm -> Linear(state_dim, 64) -> ReLU -> BatchNorm ->
    Linear(64, 32) -> ReLU -> BatchNorm -> Linear(32, action_dim) -> Tanh,
    with the Tanh output multiplied by ``action_bound`` so every action lies
    in [-action_bound, action_bound].

    Args:
        state_dim: number of input features per state.
        action_dim: number of action dimensions produced.
        action_bound: scale applied to the Tanh output (default 2.0).

    Note: the BatchNorm layers need a batch of more than one sample when the
    module is in train mode; single-state inference should use ``eval()``.
    """

    def __init__(self, state_dim, action_dim, action_bound=2.0):
        super(Policy_net, self).__init__()
        self.action_bound = action_bound
        # Attribute names are kept stable so saved state_dict keys still load.
        self.bn1 = nn.BatchNorm1d(num_features=state_dim)
        self.linear1 = nn.Linear(in_features=state_dim, out_features=64)
        self.activation1 = nn.ReLU()
        self.bn2 = nn.BatchNorm1d(num_features=64)
        self.linear2 = nn.Linear(in_features=64, out_features=32)
        self.activation2 = nn.ReLU()
        self.bn3 = nn.BatchNorm1d(num_features=32)
        self.linear3 = nn.Linear(in_features=32, out_features=action_dim)
        self.activation3 = nn.Tanh()

    def forward(self, x):
        """Return actions of shape (batch, action_dim) for states ``x`` of shape (batch, state_dim)."""
        x = self.activation1(self.linear1(self.bn1(x)))
        x = self.activation2(self.linear2(self.bn2(x)))
        return self.activation3(self.linear3(self.bn3(x))) * self.action_bound
    
    
class Value_net(nn.Module):
    """Critic: estimates Q(state, action) for a batch of state/action pairs.

    State branch: BatchNorm -> Linear(state_dim, 128) -> ReLU -> BatchNorm ->
    Linear(128, 32) -> ReLU -> BatchNorm.
    Action branch: Linear(action_dim, 32) -> ReLU.
    The two feature vectors are concatenated along dim 1 and mapped to one
    scalar per sample by ``linear4``.

    Args:
        state_dim: number of input features per state.
        action_dim: number of action dimensions.
    """

    def __init__(self, state_dim, action_dim):
        super(Value_net, self).__init__()
        state_width = 32
        action_width = 32
        # Attribute names/creation order are kept so state_dict keys and
        # seeded initialisation are unchanged.
        self.bn1 = nn.BatchNorm1d(num_features=state_dim)
        self.linear1 = nn.Linear(in_features=state_dim, out_features=128)
        self.activation1 = nn.ReLU()
        self.bn2 = nn.BatchNorm1d(num_features=128)
        self.linear2 = nn.Linear(in_features=128, out_features=state_width)
        self.activation2 = nn.ReLU()
        self.bn3 = nn.BatchNorm1d(num_features=state_width)

        self.linear3 = nn.Linear(in_features=action_dim, out_features=action_width)
        self.activation3 = nn.ReLU()
        # Input width must equal the concatenated state + action feature widths.
        self.linear4 = nn.Linear(in_features=state_width + action_width, out_features=1)

    def forward(self, state, action):
        """Return Q-values of shape (batch, 1) for 2-D ``state`` and ``action`` batches."""
        state = self.activation1(self.linear1(self.bn1(state)))
        state = self.bn3(self.activation2(self.linear2(self.bn2(state))))
        action = self.activation3(self.linear3(action))
        return self.linear4(torch.cat((state, action), 1))

In [3]:
#Agent
class Agent():
    """DDPG-style agent: learning/target actor-critic pairs plus a replay buffer.

    Args:
        state_dim: size of the observation vector fed to both networks.
        action_range: list of ``[low, high]`` pairs, one per action dimension.
        seed: seed for ``random``, ``numpy`` and ``torch``; also used in the
            checkpoint file name written by ``save``.
    """

    def __init__(self, state_dim, action_range, seed):
        # Seed every RNG *before* building the networks so weight
        # initialisation is reproducible; `random` drives replay sampling.
        self.seed = seed
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        self.action_range = action_range
        self.action_dim = len(action_range)

        self.target_policy_net = Policy_net(state_dim, self.action_dim)
        self.target_value_net = Value_net(state_dim, self.action_dim)
        self.learning_policy_net = Policy_net(state_dim, self.action_dim)
        self.learning_value_net = Value_net(state_dim, self.action_dim)

        # Target networks start as exact copies of the learning networks.
        self.target_policy_net.load_state_dict(self.learning_policy_net.state_dict())
        self.target_value_net.load_state_dict(self.learning_value_net.state_dict())

        self.gamma = 0.99                # discount factor
        self.actor_learning_rate = 0.0001
        self.critic_learning_rate = 0.0001
        self.batch_size = 64
        self.replay_buffer = []
        self.buffer_size = 50000
        self.buffer_index = 0            # next slot to overwrite once the buffer is full
        self.target_update = 0.001       # soft-update coefficient (tau) for the target networks
        self.random_action_prob = 0.2    # chance get_noise_action returns a uniform random action

        self.policy_optimizer = optim.Adam(self.learning_policy_net.parameters(), lr=self.actor_learning_rate)
        self.value_optimizer = optim.Adam(self.learning_value_net.parameters(), lr=self.critic_learning_rate)
        self.criterion = nn.MSELoss()
        self.plot_reward = []

        self.learning_policy_net.zero_grad()
        self.learning_value_net.zero_grad()

    @staticmethod
    def _to_tensor(x):
        """Convert ``x`` to a float32 tensor.

        Tensors are passed through (cast to float32 only if needed) so autograd
        history is preserved; lists/arrays are stacked by numpy first, which is
        much faster than building a tensor from a list of ndarrays.
        """
        if isinstance(x, torch.Tensor):
            return x.float()
        return torch.as_tensor(np.asarray(x, dtype=np.float32))

    def get_target_action(self, x):
        """Target-actor actions for the state batch ``x``, detached from the graph."""
        return self.target_policy_net(self._to_tensor(x)).detach()

    def get_learning_action(self, x):
        """Learning-actor actions for the state batch ``x`` (graph kept for the actor update)."""
        return self.learning_policy_net(self._to_tensor(x))

    def get_target_action_value(self, state, action):
        """Target-critic Q(state, action), detached from the graph."""
        return self.target_value_net(self._to_tensor(state), self._to_tensor(action)).detach()

    def get_learning_action_value(self, state, action):
        """Learning-critic Q(state, action); gradients flow into both arguments when they are tensors."""
        return self.learning_value_net(self._to_tensor(state), self._to_tensor(action))

    def cal_target_loss(self, batch):
        """Critic loss: MSE between Q(s, a) and the bootstrapped target for a batch of
        ``(state, action, reward, next_state)`` transitions."""
        state_batch = [e[0] for e in batch]
        action_batch = [e[1] for e in batch]
        reward_batch = [[e[2]] for e in batch]
        next_state_batch = [e[3] for e in batch]

        target_value = self.get_target_value(reward_batch, next_state_batch)
        learning_value = self.get_learning_action_value(state_batch, action_batch)
        # MSELoss(input, target): prediction first, detached target second.
        return self.criterion(learning_value, target_value)

    def target_net_update(self):
        """Soft-update both target networks: theta' <- (1 - tau) * theta' + tau * theta."""
        tau = self.target_update
        pairs = ((self.target_policy_net, self.learning_policy_net),
                 (self.target_value_net, self.learning_value_net))
        with torch.no_grad():
            for target_net, learning_net in pairs:
                for t_param, l_param in zip(target_net.parameters(), learning_net.parameters()):
                    t_param.copy_(t_param * (1 - tau) + tau * l_param)

    def get_target_value(self, reward, next_state):
        """Bootstrapped critic target ``r + gamma * Q'(s', mu'(s'))``.

        No terminal mask is applied: the transitions unpacked in
        ``cal_target_loss`` carry no ``done`` flag.
        """
        action = self.get_target_action(next_state)
        target_action_value = self.get_target_action_value(next_state, action)
        return self._to_tensor(reward) + self.gamma * target_action_value

    def get_batch(self):
        """Sample ``batch_size`` transitions uniformly without replacement."""
        return random.sample(self.replay_buffer, self.batch_size)

    def set_network_train(self):
        """Put all four networks in train mode (BatchNorm uses batch statistics)."""
        self.learning_policy_net.train()
        self.learning_value_net.train()
        self.target_policy_net.train()
        self.target_value_net.train()

    def train(self):
        """Run one critic step, one actor step, and a soft target update.

        Returns:
            ``(critic_loss, actor_loss)`` as detached scalar tensors, so callers
            can accumulate them without keeping each step's autograd graph alive.
        """
        self.set_network_train()
        batch = self.get_batch()

        # Critic step.
        self.value_optimizer.zero_grad()
        target_loss = self.cal_target_loss(batch)
        target_loss.backward()
        self.value_optimizer.step()

        # Actor step: maximise Q(s, mu(s)) by minimising its negative mean.
        self.policy_optimizer.zero_grad()
        state_batch = [e[0] for e in batch]
        d_action = self.get_learning_action(state_batch)
        action_value = self.get_learning_action_value(state_batch, d_action)
        mean_action_value = -torch.mean(action_value)
        mean_action_value.backward()
        self.policy_optimizer.step()

        # The actor backward pass also deposits gradients in the critic; clear them.
        self.learning_policy_net.zero_grad()
        self.learning_value_net.zero_grad()

        self.target_net_update()

        return target_loss.detach(), mean_action_value.detach()

    def get_noise_action(self, state):
        """Exploration action for a batch of states, as a numpy array (batch, action_dim).

        With probability ``random_action_prob`` each action dimension is drawn
        uniformly from its ``[low, high)`` range in ``action_range``; otherwise
        the learning actor's deterministic action is returned.
        """
        with torch.no_grad():
            mean = self.get_learning_action(state)
        if np.random.rand(1) < self.random_action_prob:
            low = np.array([r[0] for r in self.action_range], dtype=np.float64)
            high = np.array([r[1] for r in self.action_range], dtype=np.float64)
            return np.random.uniform(low, high, size=(len(mean), self.action_dim))
        return mean.numpy()

    def buffer_update(self, replay):
        """Insert a transition into the fixed-size ring buffer, overwriting the oldest once full."""
        if len(self.replay_buffer) < self.buffer_size:
            self.replay_buffer.append(replay)
        else:
            self.replay_buffer[self.buffer_index] = replay
        self.buffer_index = (self.buffer_index + 1) % self.buffer_size

    def set_reward_plot(self, reward_sum):
        """Record one evaluation-episode return for later plotting."""
        self.plot_reward.append(reward_sum)

    def set_network_exploration(self):
        """Put all four networks in eval mode (BatchNorm uses running statistics)."""
        self.learning_policy_net.eval()
        self.learning_value_net.eval()
        self.target_policy_net.eval()
        self.target_value_net.eval()

    def save(self, directory):
        """Write all four state dicts to ``<directory>/pendulum_model_seed_<seed>.pth``,
        creating ``directory`` if it does not exist."""
        os.makedirs(directory, exist_ok=True)
        checkpoint = {
            'learning_policy_net_state_dict': self.learning_policy_net.state_dict(),
            'learning_value_net_state_dict': self.learning_value_net.state_dict(),
            'target_policy_net_state_dict': self.target_policy_net.state_dict(),
            'target_value_net_state_dict': self.target_value_net.state_dict(),
        }
        torch.save(checkpoint, os.path.join(directory, 'pendulum_model_seed_{}.pth'.format(self.seed)))

    def load(self, file_name):
        """Restore all four networks from a checkpoint file written by ``save``.

        Args:
            file_name: path to the ``.pth`` checkpoint file.
        """
        checkpoint = torch.load(file_name)

        self.learning_policy_net.load_state_dict(checkpoint['learning_policy_net_state_dict'])
        self.learning_value_net.load_state_dict(checkpoint['learning_value_net_state_dict'])
        self.target_policy_net.load_state_dict(checkpoint['target_policy_net_state_dict'])
        self.target_value_net.load_state_dict(checkpoint['target_value_net_state_dict'])

In [None]:
# Training run: each iteration is one exploration/training episode followed by
# one deterministic evaluation episode; losses and evaluation reward are printed.
NUM_ITERATIONS = 5000
MAX_EPISODE_STEPS = 1000   # step cap per episode; the env's `done` flag may end it sooner
WARMUP_STEPS = 1000        # transitions collected before the first gradient update
SAVE_EVERY = 100           # iterations between checkpoints
MODEL_DIR = './model_pendulum'

env = gym.make('Pendulum-v0')
seed = 0
env.seed(seed)             # seed the environment as well as the RNGs seeded by Agent
action_range = [[min_action, max_action]
                for min_action, max_action in zip(env.action_space.low, env.action_space.high)]
state_dim = env.observation_space.shape[0]
step = 0

agent = Agent(state_dim, action_range, seed)
num_iteration = NUM_ITERATIONS
os.makedirs(MODEL_DIR, exist_ok=True)  # torch.save does not create missing directories

for iteration in range(num_iteration):
    reward_sum = 0
    value_loss_sum = 0
    policy_loss_sum = 0

    # Exploration episode: store transitions and train every step after warm-up.
    state = env.reset()
    for t in range(MAX_EPISODE_STEPS):
        agent.set_network_exploration()  # agent.train() switches the networks back to train mode
        action = agent.get_noise_action([state]).squeeze(0)

        obs, reward, done, info = env.step(action)
        agent.buffer_update((state, action, reward, obs))

        step += 1
        if step > WARMUP_STEPS:
            value, policy = agent.train()
            # float() yields plain numbers, so no autograd graph is kept alive by the sums.
            value_loss_sum += float(value)
            policy_loss_sum += float(policy)
        state = obs
        if done:
            break

    # Evaluation episode: deterministic policy, no exploration, no training.
    state = env.reset()
    agent.set_network_exploration()
    for t in range(MAX_EPISODE_STEPS):
        with torch.no_grad():
            action = agent.get_learning_action([state]).squeeze(0)
        obs, reward, done, info = env.step(action.numpy())
        reward_sum += reward
        state = obs
        if done:
            break

    agent.set_reward_plot(reward_sum)
    if iteration % SAVE_EVERY == 0:
        agent.save(MODEL_DIR)
    print("value loss : {} ".format(value_loss_sum))
    print("policy loss : {} ".format(policy_loss_sum))
    print("reward at itr {} : {}".format(iteration, reward_sum))
    print("step : {}".format(step))





value loss : 0 
policy loss : 0 
reward at itr 0 : -1114.3846111880937
step : 200
value loss : 0 
policy loss : 0 
reward at itr 1 : -1345.7940663199881
step : 400
value loss : 0 
policy loss : 0 
reward at itr 2 : -1139.288008504123
step : 600
value loss : 0 
policy loss : 0 
reward at itr 3 : -1073.0910288906045
step : 800
value loss : 0 
policy loss : 0 
reward at itr 4 : -1202.4749522330512
step : 1000
value loss : 7921.4716796875 
policy loss : 8.66147518157959 
reward at itr 5 : -1653.974907981836
step : 1200
value loss : 6778.5400390625 
policy loss : 62.3460807800293 
reward at itr 6 : -1251.3148477961672
step : 1400
value loss : 6278.92333984375 
policy loss : 123.60250091552734 
reward at itr 7 : -1305.4472224922977
step : 1600
value loss : 6072.46533203125 
policy loss : 191.70538330078125 
reward at itr 8 : -861.6808054384652
step : 1800
value loss : 5571.83447265625 
policy loss : 269.5920715332031 
reward at itr 9 : -1838.9716799303635
step : 2000
value loss : 5006.476562

In [None]:
# Learning curve: evaluation-episode return recorded once per training iteration.
fig, ax = plt.subplots()
ax.plot(agent.plot_reward)
ax.set(title="Evaluation reward per iteration (Pendulum)", xlabel="iteration", ylabel="episode reward")
plt.show()