In [None]:
from saida_gym.starcraft_multi.marineVsZealot import MarineVsZealot
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import numpy as np
import time
import random
import math
import sys

In [None]:
EPS = 0.5        # exploration weight mixed into the actor's policy (Actor.pi); annealed every episode
LAMBDA = 0.8     # lambda of the TD(lambda) critic targets
BATCH_SIZE = 10  # NOTE(review): not referenced by any training code in this notebook
GAMMA = 0.99     # discount factor
SEED = 0         # seed for torch / numpy / random so initialisation and sampling are reproducible

# Seed every RNG library imported above. The StarCraft environment itself is not seeded here.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class Actor(nn.Module):
    """Decentralised policy network.

    Maps one agent's observation (``obs_dim`` features, 11 by default) to a
    probability distribution over its ``n_actions`` discrete actions (15 by
    default). The distribution is mixed with a uniform distribution for
    epsilon-exploration.
    """

    def __init__(self, obs_dim=11, hidden_dim=256, n_actions=15, lr=0.0005):
        super().__init__()
        self.n_actions = n_actions
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_pi = nn.Linear(hidden_dim, n_actions)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def pi(self, x, eps=None):
        """Return action probabilities of shape (batch, n_actions).

        Args:
            x: (batch, obs_dim) float tensor of observations.
            eps: weight of the uniform exploration component. ``None`` reads
                the global ``EPS`` at call time (the training loop anneals it).
        """
        if eps is None:
            eps = EPS
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Every action keeps at least eps / n_actions probability (exploration).
        return (1 - eps) * F.softmax(self.fc_pi(x), dim=1) + eps / self.n_actions

    def train(self, loss=True, retain_graph=True):
        """Take one optimizer step on ``loss``, or toggle training mode if given a bool.

        ``nn.Module.eval()`` calls ``self.train(False)``, and ``nn.Module.train()``
        is the standard mode toggle. A bool argument therefore keeps that meaning
        instead of being treated as a loss tensor, which would crash.

        Args:
            loss: scalar loss tensor to minimise, or a bool training-mode flag.
            retain_graph: forwarded to ``loss.backward``.
        """
        if isinstance(loss, bool):
            return super().train(loss)
        self.optimizer.zero_grad()
        loss.backward(retain_graph=retain_graph)
        self.optimizer.step()
    
class Critic(nn.Module):
    """Centralised action-value critic.

    Takes an ``in_dim``-feature input (51 by default). The training code in this
    notebook builds it as [one-hot actions of the other agents, global state,
    one-hot agent id]. The critic returns one Q value per action of the queried
    agent (``n_actions``, 15 by default).
    """

    def __init__(self, in_dim=51, hidden_dim=256, n_actions=15, lr=0.0005):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.q = nn.Linear(hidden_dim, n_actions)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def Q(self, x):
        """Return Q values with shape (..., n_actions) for input (..., in_dim)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.q(x)

    def train(self, loss=True, retain_graph=True):
        """Take one optimizer step on ``loss``, or toggle training mode if given a bool.

        ``nn.Module.eval()`` calls ``self.train(False)``, so a bool argument keeps
        the standard ``nn.Module.train(mode)`` meaning instead of crashing on
        ``False.backward()``.

        Args:
            loss: scalar loss tensor to minimise, or a bool training-mode flag.
            retain_graph: forwarded to ``loss.backward``.
        """
        if isinstance(loss, bool):
            return super().train(loss)
        self.optimizer.zero_grad()
        loss.backward(retain_graph=retain_graph)
        self.optimizer.step()
    
class Make_data():
    """Convert SAIDA observations into network inputs and rewards, and buffer one episode.

    Attributes:
        enemy_hp: lowest enemy hp + shield observed so far this episode
            (starts at 160.0 and is reset when ``make_reward`` sees ``done``).
        my_hp: lowest summed hp of own units observed so far this episode
            (starts at 120.0 and is reset when ``make_reward`` sees ``done``).
        agents_state_list, global_state_list, action_list, reward_list, done_list:
            per-step buffers. ``store_data`` appends to them, and
            ``return_training_data`` hands them out and replaces them with
            fresh empty lists.
    """

    def __init__(self):
        self.enemy_hp = 160.0
        self.my_hp = 120.0
        self._reset_buffers()

    def _reset_buffers(self):
        """Point every episode buffer at a new empty list."""
        self.agents_state_list = []
        self.global_state_list = []
        self.action_list = []
        self.reward_list = []
        self.done_list = []

    def feature_scale(self, hp, pos_x, pos_y, velo_x, velo_y, is_en):
        """Scale raw unit attributes into a feature list.

        An own unit yields 5 features: [hp/10, x/1000, y/1000, vx, vy].
        The enemy yields 3 features: [hp/30, x/1000, y/1000]; its velocity is ignored.
        """
        position = [pos_x / 1000.0, pos_y / 1000.0]
        if is_en:
            return [hp / 30.0] + position
        return [hp / 10.0] + position + [velo_x, velo_y]

    def make_agent_state(self, obs):
        """Build per-agent observations as a float tensor with one row per own unit.

        Each row is [own 5 features, enemy 3 features, 3-slot one-hot of the
        unit's position in ``obs.my_unit``]. Only ``obs.en_unit[0]`` is used
        as the enemy.
        """
        enemy = obs.en_unit[0]
        enemy_features = self.feature_scale(enemy.hp + enemy.shield, enemy.pos_x, enemy.pos_y,
                                            enemy.velocity_x, enemy.velocity_y, is_en=True)
        rows = []
        for idx, unit in enumerate(obs.my_unit):
            own_features = self.feature_scale(unit.hp, unit.pos_x, unit.pos_y,
                                              unit.velocity_x, unit.velocity_y, is_en=False)
            identity = [0, 0, 0]
            identity[idx] = 1
            rows.append(own_features + enemy_features + identity)
        return torch.tensor(rows, dtype=torch.float)

    def make_global_state(self, obs):
        """Build the global state: enemy features followed by each own unit's features (1-D float tensor)."""
        enemy = obs.en_unit[0]
        features = self.feature_scale(enemy.hp + enemy.shield, enemy.pos_x, enemy.pos_y,
                                      enemy.velocity_x, enemy.velocity_y, is_en=True)
        for unit in obs.my_unit:
            features.extend(self.feature_scale(unit.hp, unit.pos_x, unit.pos_y,
                                               unit.velocity_x, unit.velocity_y, is_en=False))
        return torch.tensor(features, dtype=torch.float)

    def make_reward(self, obs, done):
        """Return a shaped team reward from hp changes since the lowest values seen so far.

        -1.0 if own units lost hp (whether or not the enemy was also hit),
        0.65 if only the enemy lost hp + shield, and 0.0 otherwise.
        The running minima are reset to their initial values when ``done`` is true.
        """
        enemy = obs.en_unit[0]
        en_hp = enemy.hp + enemy.shield
        agents_hp = 0.0
        for unit in obs.my_unit:
            agents_hp += unit.hp

        hit_enemy = self.enemy_hp > en_hp
        if hit_enemy:
            self.enemy_hp = en_hp

        if self.my_hp > agents_hp:
            # We took damage (with or without hitting the enemy).
            self.my_hp = agents_hp
            reward = -1.0
        elif hit_enemy:
            # We only dealt damage.
            reward = 0.65
        else:
            # Neither dealt nor took damage.
            reward = 0.0

        if done:
            self.enemy_hp = 160.0
            self.my_hp = 120.0
        return reward

    def store_data(self, agents_state, global_state, actions, reward, done):
        """Append one transition to the episode buffers."""
        for buffer, item in ((self.agents_state_list, agents_state),
                             (self.global_state_list, global_state),
                             (self.action_list, actions),
                             (self.reward_list, reward),
                             (self.done_list, done)):
            buffer.append(item)

    def store_terminal_state(self, agents_state, global_state):
        """Append only the state parts of a final observation (no action/reward/done)."""
        self.agents_state_list.append(agents_state)
        self.global_state_list.append(global_state)

    def return_training_data(self):
        """Hand over the buffered episode and start new empty buffers.

        Returns:
            [agents_state_list, global_state_list, action_list, reward_list, done_list]
        """
        episode = [self.agents_state_list, self.global_state_list,
                   self.action_list, self.reward_list, self.done_list]
        self._reset_buffers()
        return episode

In [None]:
# SAIDA MarineVsZealot environment. The feature code in this notebook reads only
# obs.en_unit[0] as the enemy and one-hot encodes own units into 3 slots, so it
# assumes three controllable units and a single enemy. NOTE(review): confirm
# against the environment's scenario definition.
env = MarineVsZealot(
    frames_per_step=5,
    action_type=0,
    move_angle=30,
    move_dist=3,
    verbose=0,
    local_speed=0,
    no_gui=False,
    auto_kill=True,
)


In [None]:
# Module-level singletons shared by `training` and the rollout loop.
Data = Make_data()   # observation -> tensor conversion, reward shaping, episode buffer
actor = Actor()      # decentralised policy (constructed before the critic, as originally)
critic = Critic()    # centralised Q critic

In [None]:
def TD_lambda(Q, a, r, gamma=None, lam=None):
    """Compute episodic TD(lambda) targets for agent 0 in a single backward loop.

    Uses the recursive form of the lambda-return

        G_t = r_t + gamma * ((1 - lam) * Q(s_{t+1}, u_{t+1}) + lam * G_{t+1}),

    where the value after the last step is 0 (terminal). As a result, n-step
    returns that would run past the end of the episode collapse to the
    Monte-Carlo return, as the lambda-return requires.

    Args:
        Q: (T, n_actions) critic outputs for agent 0 at each step.
        a: (T, n_agents) long tensor of joint actions. Column 0 (agent 0's
            action) selects the bootstrap value Q(s_{t+1}, u_{t+1}).
        r: (T, 1) float tensor of rewards.
        gamma: discount factor; ``None`` reads the global ``GAMMA`` at call time.
        lam: trace-decay lambda; ``None`` reads the global ``LAMBDA`` at call time.

    Returns:
        (T, 1) float tensor of targets with no autograd history (targets are constants).
    """
    gamma = GAMMA if gamma is None else gamma
    lam = LAMBDA if lam is None else lam
    T = r.shape[0]
    targets = torch.zeros_like(r, dtype=torch.float)
    if T == 0:
        return targets

    Q = Q.detach()
    # SARSA-style bootstrap value of the next step; 0 after the terminal step.
    next_q = torch.cat([Q[1:].gather(1, a[1:, [0]]), Q.new_zeros(1, 1)], dim=0)

    g = torch.zeros(1, dtype=targets.dtype)  # lambda-return of step t+1 (0 past the end)
    for t in reversed(range(T)):
        g = r[t] + gamma * ((1 - lam) * next_q[t] + lam * g)
        targets[t] = g
    return targets

def _critic_input(s, a, agent, n_actions=15):
    """Build critic inputs for ``agent``: [others' one-hot actions, global state, agent one-hot]."""
    n_steps, n_agents = a.shape
    others = [j for j in range(n_agents) if j != agent]
    others_one_hot = F.one_hot(a[:, others], num_classes=n_actions).flatten(start_dim=1).float()
    agent_ids = torch.full((n_steps,), agent, dtype=torch.long)
    agent_one_hot = F.one_hot(agent_ids, num_classes=n_agents).float()
    return torch.cat([others_one_hot, s, agent_one_hot], dim=1)


def _update_critic(s, a, r):
    """Fit the critic (queried from agent 0's view) to TD(lambda) targets, one sample per step."""
    q_input = _critic_input(s, a, agent=0)
    with torch.no_grad():
        targets = TD_lambda(critic.Q(q_input), a, r)  # (T, 1) constant targets
    for t in reversed(range(q_input.shape[0])):
        q_taken = critic.Q(q_input[t])[a[t, 0]]
        critic.train((targets[t, 0] - q_taken).pow(2))


def _update_actor(agent, agent_state, s, a):
    """Take per-sample COMA policy-gradient steps for one agent."""
    q_input = _critic_input(s, a, agent)
    taken = a[:, [agent]]  # (T, 1) actions this agent actually took
    with torch.no_grad():
        q = critic.Q(q_input)     # (T, n_actions)
        pi = actor.pi(agent_state)  # (T, n_actions)
        # COMA counterfactual advantage: Q(s, u) - sum_u' pi(u') * Q(s, (u_-a, u'))
        advantage = q.gather(1, taken) - (pi * q).sum(dim=1, keepdim=True)
    for i in range(agent_state.shape[0]):
        # Re-run the forward pass for every sample. actor.train() updates the
        # weights in place, so an earlier graph can no longer be back-propagated
        # (PyTorch raises "modified by an inplace operation").
        pi_i = actor.pi(agent_state[i:i + 1])
        actor_loss = -torch.log(pi_i.gather(1, taken[i:i + 1])) * advantage[i:i + 1]
        actor.train(actor_loss[0, 0])


def training(episode):
    """Run one COMA update on a finished episode: critic first, then each agent's actor.

    Args:
        episode: the list returned by ``Make_data.return_training_data()``:
            [agents_state_list, global_state_list, action_list, reward_list, done_list].
            As collected by the rollout loop, each agents_state is a (n_agents, 11)
            tensor, each global_state is an 18-feature tensor, and each action entry
            is an (n_agents,) long tensor. Three agents are assumed by the critic's
            input size.

    Uses the module-level ``actor`` and ``critic``.
    """
    agent_state_list, global_state_list, action_list, reward_list, _ = episode

    s = torch.stack(global_state_list)                               # (T, 18)
    a = torch.stack(action_list).long()                              # (T, n_agents)
    r = torch.tensor(reward_list, dtype=torch.float).reshape(-1, 1)  # (T, 1)
    agent_states = torch.stack(agent_state_list)                     # (T, n_agents, 11)

    _update_critic(s, a, r)
    for agent in range(a.shape[1]):
        _update_actor(agent, agent_states[:, agent, :], s, a)

In [None]:
LOG_EVERY = 20    # episodes per reported reward mean
reward_sum = 0.0  # reward accumulated since the last report
reward_list = []  # history of reported per-episode reward means (e.g. for plotting)

try:
    for ep in range(100000):
        if EPS > 0.1:
            EPS -= 0.0001  # anneal exploration every episode, stopping at about 0.1
        observation = env.reset()
        while True:
            agents_state = Data.make_agent_state(observation)
            global_state = Data.make_global_state(observation)
            # Acting needs no gradients; training() recomputes the policy from stored states.
            with torch.no_grad():
                pi = actor.pi(agents_state)
            actions = Categorical(pi).sample()
            observation, _, done, _ = env.step(actions.numpy())
            reward = Data.make_reward(observation, done)
            reward_sum += reward
            Data.store_data(agents_state, global_state, actions, reward, done)
            if done:
                break

        training(Data.return_training_data())

        if ep % LOG_EVERY == LOG_EVERY - 1:
            print('Episode %d'%ep,', Reward mean : %f'%(reward_sum/LOG_EVERY))
            reward_list.append(reward_sum / LOG_EVERY)
            reward_sum = 0.0
finally:
    # Release the StarCraft client even if training is interrupted or raises.
    env.close()