In [2]:
import gym
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import collections
import copy
import time
import random
import wandb
import torchvision.transforms as T
# wandb.init(project="RL-Lec-Project", entity="nninept")
print(torch.cuda.is_available())
device = 'cuda'

True


In [3]:
class ReplayBuffer():
    def __init__(self):
        self.buffer = collections.deque(maxlen = 5000)
        self.minibatch_size = 64

    def append(self, state, action, reward, terminal, next_state):
        self.buffer.append([state, action, reward, terminal, next_state])

    def sample(self):
        mini_batch = random.sample(self.buffer, self.minibatch_size)
        #mini_batch.append(self.buffer[-1])
        s_lst, a_lst, r_lst, terminal, s_prime_lst = map(torch.FloatTensor, zip(*mini_batch))
        return s_lst.to(device), np.array(a_lst), r_lst.to(device), terminal.to(device), s_prime_lst.to(device)
    
    def size(self):
        return len(self.buffer)

In [4]:
def init_weights(l):
    if type(l) == nn.Linear:
        nn.init.xavier_normal_(l.weight, gain=0.25)
        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 128).to(device)
        self.fc2 = nn.Linear(128, 128).to(device)
        self.fc3 = nn.Linear(128, 2).to(device)

    def forward(self, x):
        x = F.relu(self.fc1(x.to(device)))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [5]:
class Agent():    
    def __init__(self, learning_rate=0.001):
        self.replay_buffer = ReplayBuffer()
        self.networkA = Net()
        self.networkB = Net()
        self.targetNetA = Net()
        self.targetNetB = Net()
        self.targetNetA.load_state_dict(self.networkA.state_dict())
        self.targetNetB.load_state_dict(self.networkB.state_dict())
        self.optimizerA = optim.RMSprop(self.networkA.parameters(), lr = learning_rate)
        self.optimizerB = optim.RMSprop(self.networkB.parameters(), lr = learning_rate)
        self.env = gym.make('CartPole-v1')
        self.num_replay = 10
        self.last_state = None
        self.discount = 0.98
        self.updateTerm = 100
        self.epsilon = 0.1
        self.tau = 0.1

        self.maxStep = 0
        self.updateA_count = 0
        self.updateB_count = 0
        
    def sample_action(self, state, mode="train"):
        pi = (self.targetNetA(torch.from_numpy(state).float().to(device)) + self.targetNetB(torch.from_numpy(state).float().to(device)))/2
        if mode == "train":
          if np.random.random() < self.epsilon:
              action = np.random.randint(2)
              return action
          else : 
              action = torch.argmax(pi).item()
              return action
        else:
            action = torch.argmax(pi).item()
            return action
        
    def train(self, epi):
        totalStep = 0
        self.last_state = self.env.reset()
        while True:
            # self.env.render()
            totalStep += 1

            action = self.sample_action(self.last_state)
            state, reward, done, info = self.env.step(action)
            # wandb.log({"reward": reward})
            self.replay_buffer.append(self.last_state, action, reward, done, state) 
            if self.replay_buffer.size() >= self.replay_buffer.minibatch_size:
                for _ in range(self.num_replay):
                    if(np.random.random() < 0.5):
                        self.optimize_network(self.networkA, self.targetNetB, "A")
                        self.updateA_count += 1
                    else:
                        self.optimize_network(self.networkB, self.targetNetA, "B")
                        self.updateB_count += 1


            if done:
                break
            self.last_state = state

        if(epi%self.updateTerm == 0):
            # self.soft_update(self.networkA, self.targetNetA, self.tau)
            # self.soft_update(self.networkB, self.targetNetB, self.tau)
            self.targetNetA.load_state_dict(self.networkA.state_dict())
            self.targetNetB.load_state_dict(self.networkB.state_dict())
            print(f'#episode : {epi} \t avg_step : {totalStep/self.updateTerm} \t Update_A_ratio : {self.updateA_count / (self.updateA_count + self.updateB_count)} \t Update_B_ratio : {self.updateB_count / (self.updateA_count + self.updateB_count)}')
            totalStep = 0
            self.updateA_count = 0  
            self.updateB_count = 0
            
        if(self.maxStep < totalStep):
            torch.save({
                "targetA" : self.targetNetA.state_dict(),
                "targetB" : self.targetNetA.state_dict()}, './checkpoint/DDQN_best.pt')
            self.maxStep = totalStep
            print("Save Best")
            
    def optimize_network(self, network, targetNet, networkType):
        states, actions, rewards, terminals, next_states = self.replay_buffer.sample()
        
        v_next_vec = torch.max(targetNet(next_states), dim = -1)[0]*(1-terminals)
        target_vec = rewards + self.discount*v_next_vec
        
        q_mat = network(states)
        batch_indices = np.arange(self.replay_buffer.minibatch_size)
        q_vec = q_mat[batch_indices,actions]
        
        loss = F.mse_loss(q_vec,target_vec.detach())

        if(networkType=="A"):
            self.optimizerA.zero_grad()
            loss.backward()
            self.optimizerA.step()
        else:
            self.optimizerB.zero_grad()
            loss.backward()
            self.optimizerB.step()

    def soft_update(self, local_model, target_model, tau):
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)



In [6]:
# wandb.config = {
#   "learning_rate": 0.001,
#   "epochs": 5000,
#   "batch_size": 64
# }
model = Agent(learning_rate=0.001)

In [25]:
pre = time.time()
for epi in range(5000):
    model.train(epi+1)
post = time.time()
print(post-pre)

Save Best
Save Best
#episode : 100 	 avg_step : 0.12 	 Update_A_ratio : 0.4990639625585023 	 Update_B_ratio : 0.5009360374414976
#episode : 200 	 avg_step : 0.08 	 Update_A_ratio : 0.4975296442687747 	 Update_B_ratio : 0.5024703557312253
#episode : 300 	 avg_step : 0.12 	 Update_A_ratio : 0.4908839779005525 	 Update_B_ratio : 0.5091160220994475
Save Best
Save Best
Save Best
#episode : 400 	 avg_step : 0.09 	 Update_A_ratio : 0.49575551782682514 	 Update_B_ratio : 0.5042444821731749
Save Best
#episode : 500 	 avg_step : 0.11 	 Update_A_ratio : 0.5052316890881914 	 Update_B_ratio : 0.4947683109118087
#episode : 600 	 avg_step : 0.12 	 Update_A_ratio : 0.5019966722129784 	 Update_B_ratio : 0.4980033277870216
#episode : 700 	 avg_step : 0.11 	 Update_A_ratio : 0.49951026119402986 	 Update_B_ratio : 0.5004897388059701
Save Best
Save Best
Save Best
Save Best
#episode : 800 	 avg_step : 0.94 	 Update_A_ratio : 0.5008080808080808 	 Update_B_ratio : 0.4991919191919192
Save Best
Save Best
Save B

KeyboardInterrupt: 

In [9]:
testModel = Agent()

env = gym.make('CartPole-v1') 
state = env.reset()
i = 0

checkpoint = torch.load("./checkpoint/DDQN_best.pt")
testModel.targetNetA.load_state_dict(checkpoint['targetA'])
testModel.targetNetB.load_state_dict(checkpoint['targetB'])

while True:
    env.render()
    action = testModel.sample_action(state,"test")
    state, reward, done, _= env.step(action)
    i+=1
    if(done):
        env.close()
        break