In [96]:
class Environment:
    class Hunter:
        def __init__(self, height, width):
            self.height = height
            self.width = width

        def initialize_state(self):
            state = np.random.rand(2)
            environment = np.array([self.height, self.width])
            state = state * environment
            return state
        
    class Prey:
        def __init__(self, height, width):
            self.height = height
            self.width = width

        def initialize_state(self):
            state = np.random.rand(2)
            environment = np.array([self.height, self.width])
            state = state * environment
            return state

        def move_away_from_hunters(self, prey_state, hunter_state1, hunter_state2, hunter_state3):
            hunters_g = (hunter_state1 + hunter_state2 + hunter_state3) / 3
            g_to_prey_state_vector = prey_state - hunters_g
            g_to_prey_state_vector_norm = g_to_prey_state_vector / np.linalg.norm(g_to_prey_state_vector)
            new_prey_state = prey_state + g_to_prey_state_vector_norm * 0.5
            return new_prey_state
    
    def __init__(self):
        # self.done = False
        self.max_episode_steps = 10
        self.reward1 = 1
        self.reward2 = -0.1
        self.reward3 = -0.5
        self.area_threshold = 0.1
        self.prey = self.Prey(1,1)
        self.hunter1 = self.Hunter(1,1)
        self.hunter2 = self.Hunter(1,1)
        self.hunter3 = self.Hunter(1,1)
        
        self.hunter_state1 = None
        self.hunter_state2 = None
        self.hunter_state3 = None
        self.prey_state = None
        
        self.initialize_state()
        
    def initialize_state(self):
        self.hunter_state1 = self.hunter1.initialize_state()
        self.hunter_state2 = self.hunter2.initialize_state()
        self.hunter_state3 = self.hunter3.initialize_state()
        self.prey_state = self.prey.initialize_state()
        
        return self.hunter_state1, self.hunter_state2, self.hunter_state3, self.prey_state

    def sigmoid(self, x):
        return np.exp(x) / (np.exp(x) + 1)
    
    def step(self, action1, action2, action3):
        reward = 0
        done = False
        
        self.prey_state = self.prey.move_away_from_hunters(self.prey_state, self.hunter_state1, self.hunter_state2, self.hunter_state3)
        self.hunter_state1 = self.hunter_state1 + np.array([self.sigmoid(action1[0]) * np.cos(np.pi * self.sigmoid(action1[1]))])
        self.hunter_state2 = self.hunter_state2 + np.array([self.sigmoid(action2[0]) * np.cos(np.pi * self.sigmoid(action2[1]))])
        self.hunter_state3 = self.hunter_state3 + np.array([self.sigmoid(action3[0]) * np.cos(np.pi * self.sigmoid(action3[1]))])
        
        hunter2_vector_from_hunter1 = self.hunter_state2 - self.hunter_state1
        hunter3_vector_from_hunter1 = self.hunter_state3 - self.hunter_state1
        prey_vector_from_hunter1 = self.prey_state - self.hunter_state1
        
        cross_h1_h2_p = np.cross(hunter2_vector_from_hunter1, prey_vector_from_hunter1)
        cross_h1_h2_h3 = np.cross(hunter2_vector_from_hunter1, hunter3_vector_from_hunter1)

        area_h1_h2_p = np.linalg.norm(cross_h1_h2_p)
        area_h1_h2_h3 = np.linalg.norm(cross_h1_h2_h3)

        cross_h1_h2_p_sign = np.sign(cross_h1_h2_p)
        cross_h1_h2_h3_sign = np.sign(cross_h1_h2_h3)

        cos_h1_h2_p = np.dot(hunter2_vector_from_hunter1, prey_vector_from_hunter1) / (np.linalg.norm(hunter2_vector_from_hunter1) * np.linalg.norm(prey_vector_from_hunter1))
        cos_h1_h2_h3 = np.dot(hunter2_vector_from_hunter1, hunter3_vector_from_hunter1) / (np.linalg.norm(hunter2_vector_from_hunter1) * np.linalg.norm(hunter3_vector_from_hunter1))
        
        if((area_h1_h2_p < area_h1_h2_h3) and (cross_h1_h2_p_sign == cross_h1_h2_h3_sign) and (cos_h1_h2_p > cos_h1_h2_h3)):
            if(area_h1_h2_h3 < self.area_threshold):
              reward = self.reward1
              done = True
            else:
              reward = self.reward2            
        else:
            reward = self.reward3
        
        return self.prey_state, self.hunter_state1, self.hunter_state2, self.hunter_state3, reward, done

In [97]:
from torch import nn

def reparameterize(means, log_stds):
    stds = log_stds.exp()
    noises = torch.randn_like(means)
    us = means + noises * stds
    actions = torch.tanh(us)
    log_pis = calculate_log_pi(log_stds, noises, actions)
    return actions, log_pis

def calculate_log_pi(log_stds, noises, actions):
    gaussian_log_probs = (-0.5 * noises.pow(2) - log_stds).sum(dim=-1, keepdim=True) - 0.5* math.log(2 * math.pi) * log_stds.size(-1)
    log_pis = gaussian_log_probs - torch.log(1 - actions.pow(2) + 1e-6).sum(dim=-1, keepdim=True)
    return log_pis

class SACActor(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        
        self.net=nn.Sequential(
            nn.Linear(state_shape[0],256),
            nn.ReLU(inplace=True),
            nn.Linear(256,256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 2*action_shape[0])
        )
        
    def forward(self, states):
        return torch.tanh(self.net(states).chunk(2,dim=-1)[0])
    
    def sample(self, states):
        means, log_stds = self.net(states).chunk(2,dim=-1)
        return reparameterize(means, log_stds.clamp_(-20,2))
    
class SACCritic(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        
        self.net1 = nn.Sequential(
            nn.Linear(state_shape[0] + action_shape[0], 256),
            nn.ReLU(inplace=True),
            nn.Linear(256,256),
            nn.ReLU(inplace=True),
            nn.Linear(256,1)
        )
        
        self.net2 = nn.Sequential(
            nn.Linear(state_shape[0] + action_shape[0], 256),
            nn.ReLU(inplace=True),
            nn.Linear(256,256),
            nn.ReLU(inplace=True),
            nn.Linear(256,1)
        )
        
    def forward(self, states, actions):
        x = torch.cat([states, actions], dim=-1)
        return self.net1(x), self.net2(x)

In [98]:
class ReplayBuffer:
    def __init__(self, buffer_size, state_shape, action_shape, device):
        self._p = 0
        self._n = 0
        self.buffer_size = buffer_size
        
        self.states = torch.empty((buffer_size, *state_shape), dtype=torch.float, device=device)
        self.actions = torch.empty((buffer_size, *action_shape), dtype=torch.float, device=device)
        self.rewards = torch.empty((buffer_size, 1), dtype=torch.float, device=device)
        self.dones = torch.empty((buffer_size, 1), dtype=torch.float, device=device)
        self.next_states = torch.empty((buffer_size, *state_shape), dtype=torch.float, device=device)
        
    def append(self, state, action, reward, done, next_state):
        self.states[self._p].copy_(torch.from_numpy(state))
        self.actions[self._p].copy_(torch.from_numpy(action))
        self.rewards[self._p] = float(reward)
        self.dones[self._p] = float(done)
        self.next_states[self._p].copy_(torch.from_numpy(next_state))
        
        self._p = (self._p + 1) % self.buffer_size
        self._n = min(self._n + 1, self.buffer_size)
        
    def sample(self, batch_size):
        idxes = np.random.randint(low=0, high=self._n, size=batch_size)
        return(
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.next_states[idxes]
        )

In [99]:
from abc import ABC, abstractmethod

class Algorithm(ABC):
    def explore(self, state, actor):
        state = torch.tensor(state, dtype=torch.float, device=self.device).unsqueeze_(0)
        with torch.no_grad():
            action, log_pi = actor.sample(state)
        return action.cpu().numpy()[0], log_pi.item()
    
    def exploit(self, state, actor):
        state = torch.tensor(state, dtype=torch.float, device=self.device).unsqueeze_(0)
        with torch.no_grad():
            action = actor(state)
        return action.cpu().numpy()[0]

In [100]:
import math, torch
import numpy as np

env = Environment()

class SAC(Algorithm):
    def __init__(self, state_shape, action_shape, device=torch.device('cuda'), seed=0, batch_size=256, gamma=0.99, lr_actor=3e-4, lr_critic=3e-4, replay_size=10**6, start_steps=10**4, tau=5e-3, alpha=0.2, reward_scale=1.0):
        super().__init__()
        
        np.random.seed(seed)
        torch.manual_seed(seed)
        
        self.buffer1 = ReplayBuffer(
            buffer_size=replay_size,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device
        )
        
        self.buffer2 = ReplayBuffer(
            buffer_size=replay_size,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device
        )
        
        self.buffer3 = ReplayBuffer(
            buffer_size=replay_size,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device
        )
        
        self.actor1 = SACActor(
            state_shape=state_shape,
            action_shape=action_shape
        ).to(device)
        
        self.critic1 = SACCritic(
            state_shape=state_shape,
            action_shape=action_shape
        ).to(device)
        
        self.critic_target1 = SACCritic(
            state_shape=state_shape,
            action_shape=action_shape
        ).to(device).eval()
        
        self.actor2 = SACActor(
            state_shape=state_shape,
            action_shape=action_shape
        ).to(device)
        
        self.critic2 = SACCritic(
            state_shape=state_shape,
            action_shape=action_shape
        ).to(device)
        
        self.critic_target2 = SACCritic(
            state_shape=state_shape,
            action_shape=action_shape
        ).to(device).eval()
        
        self.actor3 = SACActor(
            state_shape=state_shape,
            action_shape=action_shape
        ).to(device)
        
        self.critic3 = SACCritic(
            state_shape=state_shape,
            action_shape=action_shape
        ).to(device)
        
        self.critic_target3 = SACCritic(
            state_shape=state_shape,
            action_shape=action_shape
        ).to(device).eval()
        
        self.critic_target1.load_state_dict(self.critic1.state_dict())
        for param in self.critic_target1.parameters():
            param.requires_grad = False
            
        self.critic_target2.load_state_dict(self.critic2.state_dict())
        for param in self.critic_target2.parameters():
            param.requires_grad = False
            
        self.critic_target3.load_state_dict(self.critic3.state_dict())
        for param in self.critic_target3.parameters():
            param.requires_grad = False
        
        self.optim_actor1 = torch.optim.Adam(self.actor1.parameters(), lr=lr_actor)
        self.optim_actor2 = torch.optim.Adam(self.actor1.parameters(), lr=lr_actor)
        self.optim_actor3 = torch.optim.Adam(self.actor1.parameters(), lr=lr_actor)
        
        self.optim_critic1 = torch.optim.Adam(self.critic1.parameters(), lr=lr_critic)
        self.optim_critic2 = torch.optim.Adam(self.critic2.parameters(), lr=lr_critic)
        self.optim_critic3 = torch.optim.Adam(self.critic3.parameters(), lr=lr_critic)
        
        self.learning_steps = 0
        self.batch_size = batch_size
        self.device = device
        self.gamma = gamma
        self.start_steps = start_steps
        self.tau = tau
        self.alpha = alpha
        self.reward_scale = reward_scale
        
    def is_update(self, steps):
        return steps >= max(self.start_steps, self.batch_size)
    
    def step(self, env, hunter_state1, hunter_state2, hunter_state3, t, steps):
        t += 1
        
        if steps <= self.start_steps:
            r = np.random.rand(3)
            theta = 2 * np.pi * np.random.rand(3)
            action = np.array([r * np.cos(theta), r * np.sin(theta)])
            action = action.T
            action_1 = action[0] 
            action_2 = action[1] 
            action_3 = action[2] 
        else:
            action_1,_ = self.explore(hunter_state1, self.actor1)
            action_2,_ = self.explore(hunter_state2, self.actor2)
            action_3,_ = self.explore(hunter_state3, self.actor3)
        
        next_prey_state, next_hunter_state1, next_hunter_state2, next_hunter_state3, reward, done = env.step(action_1, action_2, action_3)
        
        if t == env.max_episode_steps:
            done_masked = False
        else:
            done_masked = done
            
        self.buffer1.append(hunter_state1, action_1, reward, done_masked, next_hunter_state1)
        self.buffer2.append(hunter_state2, action_2, reward, done_masked, next_hunter_state2)
        self.buffer3.append(hunter_state3, action_3, reward, done_masked, next_hunter_state3)
        
        if done:
            t = 0
            next_hunter_state1, next_hunter_state2, next_hunter_state3, next_prey_state = env.initialize_state()
        
        return next_hunter_state1, next_hunter_state2, next_hunter_state3, next_prey_state, t
    
    def update(self):
        self.learning_steps += 1
        
        hunter_states1, actions1, rewards1, dones1, next_hunter_states1 = self.buffer1.sample(self.batch_size)
        hunter_states2, actions2, rewards2, dones2, next_hunter_states2 = self.buffer2.sample(self.batch_size)
        hunter_states3, actions3, rewards3, dones3, next_hunter_states3 = self.buffer3.sample(self.batch_size)
        
        self.update_critic(hunter_states1, actions1, rewards1, dones1, next_hunter_states1, self.critic1, self.actor1, self.critic_target1, self.optim_critic1)
        self.update_critic(hunter_states2, actions2, rewards2, dones2, next_hunter_states2, self.critic2, self.actor2, self.critic_target2, self.optim_critic2)
        self.update_critic(hunter_states3, actions3, rewards3, dones3, next_hunter_states3, self.critic3, self.actor3, self.critic_target3, self.optim_critic3)

        self.update_actor(hunter_states1, self.critic1, self.actor1, self.optim_actor1)
        self.update_actor(hunter_states2, self.critic2, self.actor2, self.optim_actor2)
        self.update_actor(hunter_states3, self.critic3, self.actor3, self.optim_actor3)
        self.update_target()
        
    def update_critic(self, states, actions, rewards, dones, next_states, critic, actor, critic_target, optim_critic):
        curr_qs1, curr_qs2 = critic(states, actions)
        
        with torch.no_grad():
            next_actions, log_pis = actor.sample(next_states)
            next_qs1, next_qs2 = critic_target(next_states, next_actions)
            next_qs = torch.min(next_qs1, next_qs2) - self.alpha * log_pis
        target_qs = rewards * self.reward_scale + (1.0 - dones) * self.gamma * next_qs
        
        loss_critic1 = (curr_qs1 - target_qs).pow_(2).mean()
        loss_critic2 = (curr_qs2 - target_qs).pow_(2).mean()
        
        optim_critic.zero_grad()
        (loss_critic1 + loss_critic2).backward(retain_graph=False)
        optim_critic.step()
    
    def update_actor(self, states, critic, actor, optim_actor):
        actions, log_pis = actor.sample(states)
        qs1, qs2 = critic(states, actions)
        loss_actor = (self.alpha * log_pis - torch.min(qs1, qs2)).mean()
        
        optim_actor.zero_grad()
        loss_actor.backward(retain_graph=False)
        optim_actor.step()
        
    def update_target(self):
        for t, s in zip(self.critic_target1.parameters(), self.critic1.parameters()):
            t.data.mul_(1.0 - self.tau)
            t.data.add_(self.tau * s.data)
        
        for t, s in zip(self.critic_target2.parameters(), self.critic2.parameters()):
            t.data.mul_(1.0 - self.tau)
            t.data.add_(self.tau * s.data)
        
        for t, s in zip(self.critic_target3.parameters(), self.critic3.parameters()):
            t.data.mul_(1.0 - self.tau)
            t.data.add_(self.tau * s.data)

In [101]:
import matplotlib.pyplot as plt
from time import time
from datetime import timedelta

class Trainer:
    def __init__(self, env, env_test, algo, seed=0, num_steps=10**6, eval_interval=10**4, num_eval_episodes=3):
        self.env = env
        self.env_test = env_test
        self.algo = algo
        
        # self.env.seed(seed)
        # self.env_test.seed(2**32 - seed)
        
        self.returns = {'step': [], 'return': []}
        
        self.num_steps = num_steps
        self.eval_interval = eval_interval
        self.num_eval_episodes = num_eval_episodes
        self.prey_state = env.prey_state
        
    def train(self):
        self.start_time = time()
        
        t = 0
        
        hunter_state1, hunter_state2, hunter_state3, prey_state = self.env.initialize_state()
        
        for steps in range(1, self.num_steps + 1):
          hunter_state1, hunter_state2, hunter_state3, prey_state, t = self.algo.step(self.env, hunter_state1, hunter_state2, hunter_state3, t, steps)
            
          if(self.algo.is_update(steps)):
              self.algo.update()
                
          if(steps % self.eval_interval == 0):
              self.evaluate(steps)

          if(steps % env.max_episode_steps == 0):
              hunter_state1, hunter_state2, hunter_state3, prey_state = self.env.initialize_state()

          if(steps % 1000 == 0):
              print(steps)
              print(prey_state)
                
    def evaluate(self, steps):
        returns = []
        
        for _ in range(self.num_eval_episodes):
            hunter_state1, hunter_state2, hunter_state3, self.prey_state = self.env_test.initialize_state()
            done = False
            episode_return = 0.0
            s = 0
            print('New loop ', done)
            
            while(not done):
                if(s == env.max_episode_steps):
                  break
                s += 1
                action_hunter1 = self.algo.exploit(hunter_state1, self.algo.actor1)
                action_hunter2 = self.algo.exploit(hunter_state2, self.algo.actor2)
                action_hunter3 = self.algo.exploit(hunter_state3, self.algo.actor3)
                hunter_state1, hunter_state2, hunter_state3, prey_state, reward, done = self.env_test.step(action_hunter1, action_hunter2, action_hunter3)
                episode_return += reward
                print(reward, done, hunter_state1, hunter_state2, hunter_state3, prey_state)
            
            returns.append(episode_return)
        
        mean_return = np.mean(returns)
        self.returns['step'].append(steps)
        self.returns['return'].append(mean_return)
        
        print(f'Num steps: {steps:<6} '
             f'Return: {mean_return:<5.1f} '
             f'Time: {self.time}')
        
    def visualize(self):
        return
    
    def plot(self):
        fig = plt.figure(figsize=(8,6))
        plt.plot(self.returns['step'], self.returns['return'])
        plt.xlabel('Steps', fontsize=24)
        plt.ylabel('Return', fontsize24)
        plt.tick_params(labelsize=18)
        plt.tight_layout()

    def time(self):
      return str(timedelta(seconds=int(time() - self.start_time)))

In [102]:
SEED = 0
REWARD_SCALE = 5.0
NUM_STEPS = 10 ** 7
EVAL_INTERVAL = 10 ** 4

algo = SAC(
    state_shape=(2,),
    action_shape=(2,),
    # seed=SEED,
    reward_scale=REWARD_SCALE
)

trainer = Trainer(
    env=Environment(),
    env_test=Environment(),
    algo=algo,
    # seed=SEED,
    num_steps=NUM_STEPS,
    eval_interval=EVAL_INTERVAL
)

In [None]:
trainer.train()

1000
[0.61459379 0.50514041]
2000
[0.03573138 0.03654005]
3000
[0.55539599 0.58235771]
4000
[0.38126914 0.1855671 ]
5000
[0.16071907 0.29694245]
6000
[0.00328134 0.44954318]
7000
[0.26959508 0.34348858]
8000
[0.44377657 0.89358   ]
9000
[0.353671   0.76579535]
New loop  False
-0.2 False [-0.05872571  1.21627746] [0.25634118 0.39815086] [0.16930184 0.30427325] [0.08894684 0.84768977]
-0.2 False [-0.21504499  1.69121355] [0.28916709 0.43097677] [0.16565641 0.30062781] [0.10514716 0.86389009]
-0.2 False [-0.37873757  2.16365903] [0.33280391 0.47461359] [0.16310609 0.2980775 ] [0.12119707 0.87994   ]
-0.2 False [-0.54908737  2.63374514] [0.38650847 0.52831815] [0.16217587 0.29714728] [0.13714021 0.89588314]
-0.2 False [-0.72568582  3.10151964] [0.44899662 0.5908063 ] [0.16370346 0.29867487] [0.15304354 0.91178647]
-0.2 False [-0.90824754  3.5669992 ] [0.52129903 0.66310871] [0.16761888 0.30259029] [0.16901227 0.9277552 ]
-0.2 False [-1.09664405  4.03014794] [0.60334796 0.74515763] [0.17361