In [1]:
# imports
import asyncio
import json
import neptune.new as neptune
import nest_asyncio
import numpy as np
import os
import time

from collections import defaultdict
from datetime import date
from itertools import product
from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.player.battle_order import ForfeitBattleOrder
from poke_env.player.player import Player
# from poke_env.player.random_player import RandomPlayer
from src.PlayerQLearning import Player as PlayerQLearning


In [2]:
# global configs

debug = True
save_to_json_file = True
use_validation = True
use_neptune = False

nest_asyncio.apply()
np.random.seed(0)

if use_neptune:
    run = neptune.init(project='project', api_token='token')


In [3]:
# our team

OUR_TEAM = """
Pikachu-Original (M) @ Light Ball  
Ability: Static  
EVs: 252 Atk / 4 SpD / 252 Spe  
Jolly Nature  
- Volt Tackle  
- Nuzzle  
- Iron Tail  
- Knock Off  

Charizard @ Life Orb  
Ability: Solar Power  
EVs: 252 SpA / 4 SpD / 252 Spe  
Timid Nature  
IVs: 0 Atk  
- Flamethrower  
- Dragon Pulse  
- Roost  
- Sunny Day  

Blastoise @ White Herb  
Ability: Torrent  
EVs: 4 Atk / 252 SpA / 252 Spe  
Mild Nature  
- Scald  
- Ice Beam  
- Earthquake  
- Shell Smash  

Venusaur @ Black Sludge  
Ability: Chlorophyll  
EVs: 252 SpA / 4 SpD / 252 Spe  
Modest Nature  
IVs: 0 Atk  
- Giga Drain  
- Sludge Bomb  
- Sleep Powder  
- Leech Seed  

Sirfetch’d @ Aguav Berry  
Ability: Steadfast  
EVs: 248 HP / 252 Atk / 8 SpD  
Adamant Nature  
- Close Combat  
- Swords Dance  
- Poison Jab  
- Knock Off  

Tauros (M) @ Assault Vest  
Ability: Intimidate  
EVs: 252 Atk / 4 SpD / 252 Spe  
Jolly Nature  
- Double-Edge  
- Earthquake  
- Megahorn  
- Iron Head  
"""


In [4]:
# opponent's team

OP_TEAM = """
Eevee @ Eviolite  
Ability: Adaptability  
EVs: 252 HP / 252 Atk / 4 SpD  
Adamant Nature  
- Quick Attack  
- Flail  
- Facade  
- Wish  

Vaporeon @ Leftovers  
Ability: Hydration  
EVs: 252 HP / 252 Def / 4 SpA  
Bold Nature  
IVs: 0 Atk  
- Scald  
- Shadow Ball  
- Toxic  
- Wish  

Sylveon @ Aguav Berry  
Ability: Pixilate  
EVs: 252 HP / 252 SpA / 4 SpD  
Modest Nature  
IVs: 0 Atk  
- Hyper Voice  
- Mystical Fire  
- Psyshock  
- Calm Mind  

Jolteon @ Assault Vest  
Ability: Quick Feet  
EVs: 252 SpA / 4 SpD / 252 Spe  
Timid Nature  
IVs: 0 Atk  
- Thunderbolt  
- Hyper Voice  
- Volt Switch  
- Shadow Ball  

Leafeon @ Life Orb  
Ability: Chlorophyll  
EVs: 252 Atk / 4 SpD / 252 Spe  
Adamant Nature  
- Leaf Blade  
- Knock Off  
- X-Scissor  
- Swords Dance  

Umbreon @ Iapapa Berry  
Ability: Inner Focus  
EVs: 252 HP / 4 Atk / 252 SpD  
Careful Nature  
- Foul Play  
- Body Slam  
- Toxic  
- Wish  
"""


In [5]:
N_OUR_MOVE_ACTIONS = 4
N_OUR_SWITCH_ACTIONS = 5
N_OUR_ACTIONS = N_OUR_MOVE_ACTIONS + N_OUR_SWITCH_ACTIONS

ALL_OUR_ACTIONS = np.array(range(0, N_OUR_ACTIONS))

NAME_TO_ID_DICT = {
    "pikachuoriginal": 0,
    "charizard": 1,
    "blastoise": 2,
    "venusaur": 3,
    "sirfetchd": 4,
    "tauros": 5,
    "eevee": 6,
    "vaporeon": 7,
    "sylveon": 8,
    "jolteon": 9,
    "leafeon": 10,
    "umbreon": 11
}

In [6]:
# Max-damage player
class MaxDamagePlayer(Player):
    def choose_move(self, battle):
        if battle.available_moves:
            best_move = max(battle.available_moves, key=lambda move: move.base_power)
            return self.create_order(best_move)
        else:
            return self.choose_random_move(battle)

In [7]:
# SARSA player
class SARSAPlayer(PlayerQLearning):
    def __init__(self, battle_format, team, n0, gamma):
        super().__init__(battle_format=battle_format, team=team)
        self.N = defaultdict(lambda: np.zeros(N_OUR_ACTIONS))
        self.Q = defaultdict(lambda: np.zeros(N_OUR_ACTIONS))
        self.n0 = n0
        self.gamma = gamma
        self.state = None
        self.action = None

    def choose_move(self, battle):
        
        if self.state is not None:
            # observe R, state2 and Take action 2
            reward = self.compute_reward(battle)
            state2 = self.embed_battle(battle)
            action2 = self.choose_action(state2)
            
            #alpha
            self.N[self.state][self.action] += 1
            alpha = 1.0 / self.N[self.state][self.action]
            #alpha = 0.85
            
            #Calculate SARSA
            predict = self.Q[self.state][self.action]
            target = reward + self.gamma * self.Q[state2][action2]
            self.Q[self.state][self.action] = self.Q[self.state][self.action] + alpha * (target - predict)
            
            # S <- S'
            self.state = state2
        else:
            # S first initialization
            self.state = self.embed_battle(battle)
            
        # Choose action
        self.action = self.choose_action(self.state)

        # if the selected action is not possible, perform a random move instead
        if self.action == -1:
            return ForfeitBattleOrder()
        elif self.action < 4 and self.action < len(battle.available_moves) and not battle.force_switch:
            return self.create_order(battle.available_moves[self.action])
        elif 0 <= self.action - 4 < len(battle.available_switches):
            return self.create_order(battle.available_switches[self.action - 4])
        else:
            return self.choose_random_move(battle)

    def _battle_finished_callback(self, battle):
        pass
    
    
    
    ''' Helper functions '''
    #Function to choose the next action
    def choose_action(self, state):
        # epsilon
        # epsilon = 0.9
        epsilon = self.n0 / (self.n0 + np.sum(self.N[state]))
        action=0
        if np.random.uniform(0, 1) < epsilon:
            # Choose a random action
            action = np.random.choice(self.Q[state])
        else:
            # Choose the action of a greedy policy
            action = np.random.choice(np.where(self.Q[state] == self.Q[state].max())[0])
        #Return the action
        return int(action)

    # epsilon-greedy policy
    def pi(self, state):
        epsilon = self.n0 / (self.n0 + np.sum(self.N[state]))
        # let's get the greedy action. Ties must be broken arbitrarily
        greedy_action = np.random.choice(np.where(self.Q[state] == self.Q[state].max())[0])
        action_pick_probability = np.full(N_OUR_ACTIONS, epsilon / N_OUR_ACTIONS)
        action_pick_probability[greedy_action] += 1 - epsilon
        return np.random.choice(ALL_OUR_ACTIONS, p=action_pick_probability)

    # the embed battle is our state
    # 12 factors: our active mon, opponent's active mon, 4 moves base power, 4 moves multipliers, num fainted mons
    @staticmethod
    def embed_battle(battle):
        # -1 indicates that the move does not have a base power
        # or is not available
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            moves_base_power[i] = (
                    move.base_power / 100
            )  # Simple rescaling to facilitate learning
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                )

        # We count how many pokemons have not fainted in each team
        n_fainted_mon_team = (
            len([mon for mon in battle.team.values() if mon.fainted])
        )
        n_fainted_mon_opponent = (
            len([mon for mon in battle.opponent_team.values() if mon.fainted])
        )

        state = list()
        state.append(NAME_TO_ID_DICT[str(battle.active_pokemon).split(' ')[0]])
        state.append(NAME_TO_ID_DICT[str(battle.opponent_active_pokemon).split(' ')[0]])
        for move_base_power in moves_base_power:
            state.append('{0:.2f}'.format(move_base_power))
        for move_dmg_multiplier in moves_dmg_multiplier:
            state.append('{0:.2f}'.format(move_dmg_multiplier))
        state.append(n_fainted_mon_team)
        state.append(n_fainted_mon_opponent)

        return str(state)

    # Computing rewards
    def reward_computing_helper(
            self,
            battle: AbstractBattle,
            *,
            fainted_value: float = 0.0,
            hp_value: float = 0.0,
            number_of_pokemons: int = 6,
            starting_value: float = 0.0,
            status_value: float = 0.0,
            victory_value: float = 1.0
    ) -> float:
        # 1st compute
        if battle not in self._reward_buffer:
            self._reward_buffer[battle] = starting_value
        current_value = 0

        # Verify if pokemon have fainted or have status
        for mon in battle.team.values():
            current_value += mon.current_hp_fraction * hp_value
            if mon.fainted:
                current_value -= fainted_value
            elif mon.status is not None:
                current_value -= status_value

        current_value += (number_of_pokemons - len(battle.team)) * hp_value

        # Verify if opponent pokemon have fainted or have status
        for mon in battle.opponent_team.values():
            current_value -= mon.current_hp_fraction * hp_value
            if mon.fainted:
                current_value += fainted_value
            elif mon.status is not None:
                current_value += status_value

        current_value -= (number_of_pokemons - len(battle.opponent_team)) * hp_value

        # Verify if we won or lost
        if battle.won:
            current_value += victory_value
        elif battle.lost:
            current_value -= victory_value

        # Value to return
        to_return = current_value - self._reward_buffer[battle]
        self._reward_buffer[battle] = current_value
        if use_neptune:
            run[f'N0: {self.n0}, gamma: {self.gamma} reward_buffer'].log(current_value)
            run[f'N0: {self.n0}, gamma: {self.gamma} reward returned'].log(to_return)
        return to_return

    # Calling reward_computing_helper
    def compute_reward(self, battle) -> float:
        return self.reward_computing_helper(battle, fainted_value=2, hp_value=1, victory_value=30)


In [8]:
# Q-learning validation player
class SARSAValidationPlayer(PlayerQLearning):
    def __init__(self, battle_format, team, Q):
        super().__init__(battle_format=battle_format, team=team)
        self.Q = Q

    def choose_move(self, battle):
        state = self.embed_battle(battle)
        # let's get the greedy action. Ties must be broken arbitrarily
        if state in self.Q.keys():
            action = np.random.choice(np.where(self.Q[state] == self.Q[state].max())[0])
        else:
            return self.choose_random_move(battle)

        # if the selected action is not possible, perform a random move instead
        if action == -1:
            return ForfeitBattleOrder()
        elif action < 4 and action < len(battle.available_moves) and not battle.force_switch:
            return self.create_order(battle.available_moves[action])
        elif 0 <= action - 4 < len(battle.available_switches):
            return self.create_order(battle.available_switches[action - 4])
        else:
            return self.choose_random_move(battle)

    def _battle_finished_callback(self, battle):
        pass

    # the embed battle is our state
    # 12 factors: our active mon, opponent's active mon, 4 moves base power, 4 moves multipliers, remaining mons
    @staticmethod
    def embed_battle(battle):
        # -1 indicates that the move does not have a base power
        # or is not available
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            moves_base_power[i] = (
                    move.base_power / 100
            )  # Simple rescaling to facilitate learning
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                )

        # We count how many pokemons have not fainted in each team
        remaining_mon_team = (
            len([mon for mon in battle.team.values() if mon.fainted])
        )
        remaining_mon_opponent = (
            len([mon for mon in battle.opponent_team.values() if mon.fainted])
        )

        state = list()
        state.append(NAME_TO_ID_DICT[str(battle.active_pokemon).split(' ')[0]])
        state.append(NAME_TO_ID_DICT[str(battle.opponent_active_pokemon).split(' ')[0]])
        for move_base_power in moves_base_power:
            state.append('{0:.2f}'.format(move_base_power))
        for move_dmg_multiplier in moves_dmg_multiplier:
            state.append('{0:.2f}'.format(move_dmg_multiplier))
        state.append(remaining_mon_team)
        state.append(remaining_mon_opponent)

        return str(state)


In [9]:
# global parameters

# possible values for num_battles (number of episodes)
n_battles_array = [10000]
# exploration schedule from MC, i. e., epsilon(t) = N0 / (N0 + N(S(t)))
n0_array = [0.1, 0.2, 0.3]
# possible values for gamma (discount factor)
gamma_array = [0.1, 0.2, 0.3]

list_of_params = [
    {
        'n_battles': n_battles,
        'n0': n0,
        'gamma': gamma
    } for n_battles, n0, gamma in product(n_battles_array, n0_array, gamma_array)
]


In [10]:
# main (let's battle!)


# helper function: save to json file
def save_to_json(path, params, name, value):
    today_s = str(date.today())
    n_battle_s = str(params['n_battles'])
    n0_s = str(round(params['n0'], 2))
    gamma_s = str(round(params['gamma'], 2))
    winning_percentage_s = str(round((params['player'].n_won_battles / params['n_battles']) * 100, 2))
    if not os.path.exists(path):
        os.makedirs(path)
    filename = path + "/" + name + "_QLearning_" + today_s + "_n_battles_" + n_battle_s + "_N0_" + n0_s + "_gamma_" + gamma_s + "_wining_" + winning_percentage_s + ".json "
    file = open(filename, "w")
    value_dict = dict()
    for key in value:
        value_dict[key] = value[key].tolist()
    json.dump(value_dict, file)
    file.close()


# let's battle!
async def lets_battle():
    for params in list_of_params:
        # training
        start = time.time()
        if use_neptune:
            run['params'] = params
        params['player'] = SARSAPlayer(battle_format="gen8ou", team=OUR_TEAM, n0=params['n0'], gamma=params['gamma'])
        params['opponent'] = MaxDamagePlayer(battle_format="gen8ou", team=OP_TEAM)
        await params['player'].battle_against(opponent=params['opponent'], n_battles=params['n_battles'])
        if debug:
            print("training: num battles (episodes)=%d, N0=%f, gamma=%f, wins=%d, winning percentage=%f, total time=%s seconds" %
                  (
                      params['n_battles'],
                      round(params['n0'], 2),
                      round(params['gamma'], 2),
                      params['player'].n_won_battles,
                      round((params['player'].n_won_battles / params['n_battles']) * 100, 2),
                      round(time.time() - start, 2)
                  ))
        if save_to_json_file:
            # save Q to json file
            save_to_json("./sarsa_dump10000", params, "Q", params['player'].Q)

        # validation (play 1/3 of the battles using Q-learned table)
        start = time.time()
        if True:
            params['validation_player'] = SARSAValidationPlayer(battle_format="gen8ou", team=OUR_TEAM, Q=params['player'].Q)
            n_battles = int(params['n_battles'] / 3)
            await params['validation_player'].battle_against(opponent=params['opponent'], n_battles=n_battles)
            if debug:
                print("validation: num battles (episodes)=%d, N0=%f, gamma=%f, wins=%d, winning percentage=%f, total time=%s seconds" %
                      (
                          n_battles,
                          round(params['n0'], 2),
                          round(params['gamma'], 2),
                          params['validation_player'].n_won_battles,
                          round((params['validation_player'].n_won_battles / n_battles) * 100, 2),
                          round(time.time() - start, 2)
                      ))


In [11]:
loop = asyncio.get_event_loop()
loop.run_until_complete(loop.create_task(lets_battle()))

training: num battles (episodes)=10000, N0=0.100000, gamma=0.100000, wins=2891, winning percentage=28.910000, total time=467.91 seconds
validation: num battles (episodes)=3333, N0=0.100000, gamma=0.100000, wins=1590, winning percentage=47.700000, total time=296.86 seconds
training: num battles (episodes)=10000, N0=0.100000, gamma=0.200000, wins=2497, winning percentage=24.970000, total time=442.9 seconds
validation: num battles (episodes)=3333, N0=0.100000, gamma=0.200000, wins=1403, winning percentage=42.090000, total time=290.93 seconds
training: num battles (episodes)=10000, N0=0.100000, gamma=0.300000, wins=2232, winning percentage=22.320000, total time=402.59 seconds
validation: num battles (episodes)=3333, N0=0.100000, gamma=0.300000, wins=1615, winning percentage=48.450000, total time=292.13 seconds
training: num battles (episodes)=10000, N0=0.200000, gamma=0.100000, wins=2833, winning percentage=28.330000, total time=488.79 seconds
validation: num battles (episodes)=3333, N0=0.



training: num battles (episodes)=10000, N0=0.200000, gamma=0.300000, wins=2225, winning percentage=22.250000, total time=416.46 seconds
validation: num battles (episodes)=3333, N0=0.200000, gamma=0.300000, wins=1654, winning percentage=49.620000, total time=300.32 seconds
training: num battles (episodes)=10000, N0=0.300000, gamma=0.100000, wins=2877, winning percentage=28.770000, total time=523.66 seconds
validation: num battles (episodes)=3333, N0=0.300000, gamma=0.100000, wins=1579, winning percentage=47.370000, total time=323.59 seconds




training: num battles (episodes)=10000, N0=0.300000, gamma=0.200000, wins=2666, winning percentage=26.660000, total time=593.79 seconds
validation: num battles (episodes)=3333, N0=0.300000, gamma=0.200000, wins=1521, winning percentage=45.630000, total time=340.28 seconds
training: num battles (episodes)=10000, N0=0.300000, gamma=0.300000, wins=2125, winning percentage=21.250000, total time=451.45 seconds
validation: num battles (episodes)=3333, N0=0.300000, gamma=0.300000, wins=1551, winning percentage=46.530000, total time=380.58 seconds


