In [1]:
from env_CatchPigs import EnvCatchPigs

import math
import numpy as np

import torch
import torch.nn as nn
import torchvision.models as models

In [2]:
class ReplayBuffer():
    def __init__(self, size, obs_dims):
        """
        size :: int
        obs_dims :: (int, ... , int)
        """
        self.size = size
        self.ctr = 0
        self.obs_buffer = np.zeros((size, *obs_dims))
        self._obs_buffer = np.zeros((size, *obs_dims))
        self.act_buffer = np.zeros(size)
        self.rew_buffer = np.zeros(size)
        self.done_buffer = np.zeros(size)


    def push(self, obs, act, rew, _obs, done):
        idx = self.ctr % self.size
        self.obs_buffer[idx] = obs
        self.act_buffer[idx] = act
        self.rew_buffer[idx] = rew
        self._obs_buffer[idx] = _obs
        self.done_buffer[idx] = done
        self.ctr += 1


    def sample(self, batch_size):
        """
        TO DO: Maybe redesign to return non repeating indxs!!
        """
        hi = min(self.size - 1, self.ctr - 1)
        idxs = np.random.randint(0, hi, batch_size)
        return (self.obs_buffer[idxs], self.act_buffer[idxs], self.rew_buffer[idxs], 
            self._obs_buffer[idxs], self.done_buffer[idxs])

In [3]:
class DQNAgent():
    def __init__(self, obs_dims, act_dim, lr=1e-3, gamma=0.99, replay_buffer_size=10000,
     batch_size=64, epsilon_min=0.01, epsilon_dec=5e-4, target_update_frequency=64):
        
        self.obs_dims = obs_dims
        self.act_dim = act_dim
        self.batch_size = batch_size
        self.gamma = gamma
        self.target_update_frequency = target_update_frequency
        self.learn_ctr = 0
        
        self.buffer = ReplayBuffer(replay_buffer_size, obs_dims)
        self.q_eval = models.mobilenet_v2(num_classes=act_dim)
        self.q_target = models.mobilenet_v2(num_classes=act_dim)
        
        
        self.epsilon = 1
        self.epsilon_min = epsilon_min
        self.epsilon_dec = epsilon_dec
        
        self.optimizer = torch.optim.Adam(self.q_eval.parameters(), lr=lr)
        self.loss_fn = torch.nn.MSELoss()


    def update_target(self):
        if self.learn_ctr % self.target_update_frequency == 0:
            self.q_target.load_state_dict(self.q_eval.state_dict())


    def decrement_epsilon(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon = self.epsilon - self.epsilon_dec


    def choose_action(self, obs):
        if np.random.sample() < self.epsilon:
            return np.random.randint(self.act_dim)
        else:
            self.q_eval.eval()
            obs = torch.tensor(np.expand_dims(obs, axis=0), dtype=torch.float)
            return torch.argmax(self.q_eval(obs)).item()


    def store_transition(self, obs, act, rew, _obs, done):
        self.buffer.push(obs, act, rew, _obs, done)


    def sample_replay_buffer(self):
        return self.buffer.sample(self.batch_size)


    def learn(self):
        if self.buffer.ctr < self.batch_size:
            return
        self.q_eval.train()
        self.optimizer.zero_grad()
        obs, act, rew, _obs, done = self.sample_replay_buffer()
        obs = torch.tensor(obs, dtype=torch.float)
        act = torch.tensor(act, dtype=torch.long)
        rew = torch.tensor(rew, dtype=torch.long)
        _obs = torch.tensor(_obs, dtype=torch.float)
        done = torch.tensor(done, dtype=torch.long)
        idxs = torch.tensor(np.arange(self.batch_size), dtype=torch.long)
        q_pred = self.q_eval(obs)[idxs, act]
        q_next = self.q_target(_obs).max(dim=1)[0]
        q_target = rew + (1 - done) * self.gamma * q_next
        loss = self.loss_fn(q_target, q_pred)
        loss.backward()
        self.optimizer.step()
        self.update_target()
        self.decrement_epsilon()

In [4]:
def swap_axes(arr):
    arr = np.swapaxes(arr, 0, 2)
    arr = np.swapaxes(arr, 1, 2)
    return arr

In [5]:
def train(max_iter, env, agent1, agent2):
    
    obs_list = env.get_obs()
    obs1 = swap_axes(obs_list[0])
    obs2 = swap_axes(obs_list[1])
    for i in range(max_iter):
        act1 = agent1.choose_action(obs1)
        act2 = agent2.choose_action(obs2)
        act_list = [act1, act2]
        # print("iter= ", i, env.agt1_pos, env.agt2_pos, env.pig_pos, env.agt1_ori, env.agt2_ori, 'action', act1, act2)
        env.render()
        rew_list, done = env.step(act_list)
        rew1 = rew_list[0]
        rew2 = rew_list[1]
        # print(rew1)
        _obs_list = env.get_obs()
        _obs1 = swap_axes(_obs_list[0])
        _obs2 = swap_axes(_obs_list[1])
        agent1.store_transition(obs1, act1, rew1, _obs1, done)
        agent2.store_transition(obs2, act2, rew1, _obs2, done)
        obs1 = _obs1
        obs2 = _obs2
        agent1.learn()
        agent2.learn()
        #env.plot_scene()
        if rew1 + rew2 > 0:
            print("iter= ", i)
            print("Goal found!")

        if i % 100 == 0:
            print(f"Iter: {i}, Epsilon:{agent1.epsilon}")

In [6]:
env = EnvCatchPigs(7, False)
max_iter = 1000000

dqn1 = DQNAgent((3, 21, 21), 4)
dqn2 = DQNAgent((3, 21, 21), 4)

train(max_iter, env, dqn1, dqn2)

size of map should be an odd integer no smaller than 7
Iter: 0, Epsilon:1
Iter: 100, Epsilon:0.9810000000000021
Iter: 200, Epsilon:0.9310000000000076
Iter: 300, Epsilon:0.8810000000000131
Iter: 400, Epsilon:0.8310000000000186
Iter: 500, Epsilon:0.7810000000000241
Iter: 600, Epsilon:0.7310000000000296
Iter: 700, Epsilon:0.6810000000000351
Iter: 800, Epsilon:0.6310000000000406
Iter: 900, Epsilon:0.5810000000000461
Iter: 1000, Epsilon:0.5310000000000517
Iter: 1100, Epsilon:0.48100000000005505
Iter: 1200, Epsilon:0.431000000000055
Iter: 1300, Epsilon:0.38100000000005496
Iter: 1400, Epsilon:0.3310000000000549
Iter: 1500, Epsilon:0.2810000000000549
Iter: 1600, Epsilon:0.23100000000005483
Iter: 1700, Epsilon:0.18100000000005478
Iter: 1800, Epsilon:0.13100000000005474
Iter: 1900, Epsilon:0.0810000000000547
Iter: 2000, Epsilon:0.03100000000005465
Iter: 2100, Epsilon:0.009500000000054631
Iter: 2200, Epsilon:0.009500000000054631
Iter: 2300, Epsilon:0.009500000000054631
Iter: 2400, Epsilon:0.00950

KeyboardInterrupt: 