In [1]:
from showdown.engine.helpers import normalize_name as norm_name, calculate_stats
from showdown.engine.find_state_instructions import get_all_state_instructions, get_state_instructions_from_move, get_effective_speed
from showdown.engine.damage_calculator import calculate_damage
from showdown.engine.objects import Pokemon, State, StateMutator, Side
from showdown.engine.special_effects.abilities.on_switch_in import ability_on_switch_in
import json, config
from collections import defaultdict
from copy import deepcopy, copy
import numpy as np
import random, pickle
from multiprocessing import Pool, cpu_count, Queue, Process, Manager, SimpleQueue
import sys, importlib
#from thefuzz import process

import time
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data as data

# %matplotlib inline
from IPython.display import set_matplotlib_formats
from matplotlib.colors import to_rgba
%load_ext line_profiler
np.set_printoptions(precision=3, suppress=True)
config.damage_calc_type = 'average'

In [21]:
with open('pokedex.json', 'r') as fp:
    pokedex = json.load(fp)

with open('moves.json', 'r') as fp:
    all_moves = json.load(fp)
    
with open('data/gen8ou_teams.json', 'r') as fp:
    teams = json.load(fp)
    
curated_teams = []
for team in teams['teams']:
    if 'ou' in team['folder'].lower() or 'overused' in team['folder'].lower():
        curated_teams.append(team)
teams['teams'] = curated_teams

with open('items.json', 'r') as fp:
    items = json.load(fp)
abilities = set()
ability_dict = {}
for x in pokedex:
    for a in list(pokedex[x]['abilities'].values()):
        abilities.add(norm_name(a))
abilities.remove('')
for i, a in enumerate(sorted(list(abilities))):
    ability_dict[a] = i
with open('items.json', 'r') as fp:
    items = json.load(fp)
item_dict = {}
for item in items:
    if 'english' in item['name']:
        name = item['name']['english']
    else:
        name = item['name']
    item_dict[norm_name(name)] = item['id']
item_dict[None] = 1039
status_dict = {'slp': 0, 'brn': 1, 'frz': 2, 'par':3,'psn':4,'tox':5}
field_dict = {'electricterrain':0, 'grassyterrain':1, 'mistyterrain':2, 'psychicterrain':3}
weather_dict = {'raindance': 0, 'sunnyday': 1, 'sandstorm': 2, 'hail':3}
volatile_status_dict = {}
volatile_statuses = set()
for x in all_moves:
    if 'volatileStatus' in all_moves[x]:
        volatile_statuses.add(all_moves[x]['volatileStatus'])
    if 'self' in all_moves[x] and 'volatileStatus' in all_moves[x]['self']:
        volatile_statuses.add(all_moves[x]['self']['volatileStatus'])
    if all_moves[x]['secondary'] != False and 'volatileStatus' in all_moves[x]['secondary']:
        volatile_statuses.add(all_moves[x]['secondary']['volatileStatus'])
    if 'onTryMove' in all_moves[x]:
        volatile_statuses.add(x)
    if 'onTryHit' in all_moves[x]:
        volatile_statuses.add(x)
    if all_moves[x]['secondary'] != False and 'onHit' in all_moves[x]['secondary']:
        volatile_statuses.add(x)
for i, x in enumerate(list(volatile_statuses)):
    volatile_status_dict[x] = i

type_dict = {}
type_set = set()
for x in pokedex:
    for i in pokedex[x]['types']:
        type_set.add(norm_name(i))
type_set.remove('bird')
for i, a in enumerate(list(type_set)):
    type_dict[a] = i
all_move_list = set(x for x in all_moves)

id_2_pokemon = {}
for x in pokedex:
    if x not in ['slowkinggalar', 'zapdosgalar', 'corviknightgmax', 'laprasgmax']:
        id_2_pokemon[pokedex[x]['num']] = x
    
id_2_ability = {}
for x in ability_dict:
    id_2_ability[ability_dict[x]] = x
    
id_2_item = {}
for x in item_dict:
    id_2_item[item_dict[x]] = x
    
id_2_move = {}
for x in all_moves:
    id_2_move[all_moves[x]['num']] = x
    
id_2_status = {}
for x in status_dict:
    id_2_status[status_dict[x]] = x

id_2_weather = {}
for x in weather_dict:
    id_2_weather[weather_dict[x]] = x
    
id_2_field = {}
for x in field_dict:
    id_2_field[field_dict[x]] = x

In [3]:
def rename_baseStats(baseStats):
    return {'hp': baseStats['hp'], 'attack': baseStats['atk'], 'defense': baseStats['def'], 
           'special-attack': baseStats['spa'], 'special-defense': baseStats['spd'], 'speed': baseStats['spe']}
def max_pp(pp):
    return np.floor(pp * 1.6)

def create_mutator(team):

    poke_moves = []
    with open(team) as f:
        for line in f:
            line = line.strip().split()
            if line[0] == '===':
                pokemon = []
                marker = 0
            if len(line) == 0:
                stats = calculate_stats(rename_baseStats(pokedex[name]['baseStats']),
                                100, ivs=ivs, evs=evs, nature=nature)   


                mon = Pokemon(
                                identifier = name,
                                level=100,
                                types=types,
                                hp = stats['hp'],
                                maxhp = stats['hp'],
                                ability = ability,
                                item = item,
                                attack=stats['attack'],
                                defense=stats['defense'],
                                special_attack=stats['special-attack'],
                                special_defense=stats['special-defense'],
                                speed=stats['speed'],
                                attack_boost=0,
                                defense_boost=0,
                                special_attack_boost=0,
                                special_defense_boost=0,
                                speed_boost=0,
                                accuracy_boost=0,
                                evasion_boost=0,
                                status=None,
                                volatile_status = set(),
                                moves= poke_moves)
                pokemon.append(mon)
                marker = 0
                item = None
                name = None
                nature = 'hardy'
                ability = 'illuminate'
                evs = [0,0,0,0,0,0]
                ivs = [31,31,31,31,31,31]
                poke_moves = []
                stat_map = {'HP': 0, 'Atk': 1, 'Def': 2, 'SpA': 3, 'SpD': 4, 'Spe': 5}
                continue
            if marker == 0:
                if '@' in line:
                    at = line.index('@')
                    name = norm_name(' '.join(line[:at]))
                    item = norm_name(' '.join(line[at + 1:]))
                else:
                    name = norm_name(' '.join(line))
                types = [norm_name(x) for x in pokedex[name]['types']]
                marker = 1
            if line[0] == 'Ability:':
                ability = norm_name(' '.join(line[1:]))
            if line[0] == 'EVs:':
                line.append('/')
                while '/' in line:
                    sep = line.index('/')
                    stat = line[sep - 1]
                    value = int(line[sep - 2])
                    evs[stat_map[stat]] = value
                    line = line[sep + 1:]
                continue
            if line[0] == 'IVs:':
                line.append('/')
                while '/' in line:
                    sep = line.index('/')
                    stat = line[sep - 1]
                    value = int(line[sep - 2])
                    ivs[stat_map[stat]] = value
                    line = line[sep + 1:]
                continue
            if 'Nature' == line[-1]:
                nature = norm_name(line[0])
            if line[0] == '-':
                move_name = norm_name(' '.join(line[1:]))
                poke_moves.append({'id': move_name, 'disabled': False, 'current_pp': max_pp(all_moves[move_name]['pp'])})
        stats = calculate_stats(rename_baseStats(pokedex[name]['baseStats']),
                                100, ivs=ivs, evs=evs, nature=nature)   


        mon = Pokemon(
                        identifier = name,
                        level=100,
                        types=types,
                        hp = stats['hp'],
                        maxhp = stats['hp'],
                        ability = ability,
                        item = item,
                        attack=stats['attack'],
                        defense=stats['defense'],
                        special_attack=stats['special-attack'],
                        special_defense=stats['special-defense'],
                        speed=stats['speed'],
                        attack_boost=0,
                        defense_boost=0,
                        special_attack_boost=0,
                        special_defense_boost=0,
                        speed_boost=0,
                        accuracy_boost=0,
                        evasion_boost=0,
                        status=None,
                        volatile_status = set(),
                        moves= poke_moves)
        pokemon.append(mon)
        reserve = {x.id: x for x in pokemon[1:]}
        side = Side(
            active = pokemon[0],
            reserve = reserve,
            wish = (0, 0),
            side_conditions = defaultdict(int),
            ls_count=0,
            ref_count=0,
            av_count=0
            )
        sides.append(side)
    state = State(
            user=sides[0],
            opponent=sides[1],
            weather=None,
            field=None,
            trick_room=False,
            weather_count=0,
            terrain_count=0
        )
    return StateMutator(state)

def create_mutator_from_files(file1, file2):
    teams = [file1, file2]
    sides = []
    for team in teams:
        item = None
        name = None
        nature = 'hardy'
        ability = 'illuminate'
        evs = [0,0,0,0,0,0]
        ivs = [31,31,31,31,31,31]
        poke_moves = []
        stat_map = {'HP': 0, 'Atk': 1, 'Def': 2, 'SpA': 3, 'SpD': 4, 'Spe': 5}
        with open(team) as f:
            pokemon = []
            marker = 0
            for line in f:
                line = line.strip().split()
                if len(line) == 0:
                    stats = calculate_stats(rename_baseStats(pokedex[name]['baseStats']),
                                    100, ivs=ivs, evs=evs, nature=nature)   


                    mon = Pokemon(
                                    identifier = name,
                                    level=100,
                                    types=types,
                                    hp = stats['hp'],
                                    maxhp = stats['hp'],
                                    ability = ability,
                                    item = item,
                                    attack=stats['attack'],
                                    defense=stats['defense'],
                                    special_attack=stats['special-attack'],
                                    special_defense=stats['special-defense'],
                                    speed=stats['speed'],
                                    attack_boost=0,
                                    defense_boost=0,
                                    special_attack_boost=0,
                                    special_defense_boost=0,
                                    speed_boost=0,
                                    accuracy_boost=0,
                                    evasion_boost=0,
                                    status=None,
                                    volatile_status = set(),
                                    moves= poke_moves)
                    pokemon.append(mon)
                    marker = 0
                    item = None
                    name = None
                    nature = 'hardy'
                    ability = 'illuminate'
                    evs = [0,0,0,0,0,0]
                    ivs = [31,31,31,31,31,31]
                    poke_moves = []
                    stat_map = {'HP': 0, 'Atk': 1, 'Def': 2, 'SpA': 3, 'SpD': 4, 'Spe': 5}
                    continue
                if marker == 0:
                    if '@' in line:
                        at = line.index('@')
                        name = norm_name(' '.join(line[:at]))
                        item = norm_name(' '.join(line[at + 1:]))
                    else:
                        name = norm_name(' '.join(line))
                    types = [norm_name(x) for x in pokedex[name]['types']]
                    marker = 1
                if line[0] == 'Ability:':
                    ability = norm_name(' '.join(line[1:]))
                if line[0] == 'EVs:':
                    line.append('/')
                    while '/' in line:
                        sep = line.index('/')
                        stat = line[sep - 1]
                        value = int(line[sep - 2])
                        evs[stat_map[stat]] = value
                        line = line[sep + 1:]
                    continue
                if line[0] == 'IVs:':
                    line.append('/')
                    while '/' in line:
                        sep = line.index('/')
                        stat = line[sep - 1]
                        value = int(line[sep - 2])
                        ivs[stat_map[stat]] = value
                        line = line[sep + 1:]
                    continue
                if 'Nature' == line[-1]:
                    nature = norm_name(line[0])
                if line[0] == '-':
                    move_name = norm_name(' '.join(line[1:]))
                    poke_moves.append({'id': move_name, 'disabled': False, 'current_pp': max_pp(all_moves[move_name]['pp'])})
            stats = calculate_stats(rename_baseStats(pokedex[name]['baseStats']),
                                    100, ivs=ivs, evs=evs, nature=nature)   


            mon = Pokemon(
                            identifier = name,
                            level=100,
                            types=types,
                            hp = stats['hp'],
                            maxhp = stats['hp'],
                            ability = ability,
                            item = item,
                            attack=stats['attack'],
                            defense=stats['defense'],
                            special_attack=stats['special-attack'],
                            special_defense=stats['special-defense'],
                            speed=stats['speed'],
                            attack_boost=0,
                            defense_boost=0,
                            special_attack_boost=0,
                            special_defense_boost=0,
                            speed_boost=0,
                            accuracy_boost=0,
                            evasion_boost=0,
                            status=None,
                            volatile_status = set(),
                            moves= poke_moves)
            pokemon.append(mon)
            reserve = {x.id: x for x in pokemon[1:]}
            side = Side(
                active = pokemon[0],
                reserve = reserve,
                wish = (0, 0),
                side_conditions = defaultdict(int),
                ls_count=0,
                ref_count=0,
                av_count=0
                )
            sides.append(side)
    state = State(
            user=sides[0],
            opponent=sides[1],
            weather=None,
            field=None,
            trick_room=False,
            weather_count=0,
            terrain_count=0
        )
    return StateMutator(state)

def display_battle(state_log, move_log, instruction_log, battle_num):
    battle_len = len(state_log[battle_num])
    for i in range(battle_len - 1):
        print('TURN ' + str(i+1))
        print('')
        display_turn(state_log, move_log, instruction_log, battle_num, i)
        print('')
        print('-----------------------------------------------------------------------------------------------------')

def display_turn(state_log, move_log, instruction_log, battle_num, turn_num):
    instructions = instruction_log[battle_num][turn_num]
    moves = move_log[battle_num][turn_num]
    display_side(state_log, move_log, battle_num, turn_num)
    print('')
    print('-----------------------------------------------------------------------------------------------------')
    print('My move: ' + moves[0])
    print('My expected value ' + str(moves[2]))
    print('Opp move: ' + moves[1])
    print('Opponent expected value ' + str(moves[3]))
    print('Instructions:' )
    print(instructions)
    print('-----------------------------------------------------------------------------------------------------')
    print('')
    display_side(state_log, move_log, battle_num, turn_num + 1)
    
def display_side(state_log, move_log,battle_num, turn_num):
    turn = state_log[battle_num][turn_num]
    side_conds = 'My side conditions- '
    for i in turn.self.side_conditions:
        if turn.self.side_conditions[i] != 0:
            side_conds += i +': ' + str(turn.self.side_conditions[i])
    print(side_conds)
    print('')
    display_mon(turn.self.active)
    print('')
    print('Weather: ', turn.weather)
    print('')
    display_mon(turn.opponent.active)
    print('')
    opp_side_conds = 'Opponent side conditions- '
    for i in turn.opponent.side_conditions:
        if turn.opponent.side_conditions[i] != 0:
            opp_side_conds += i +': ' + str(turn.opponent.side_conditions[i])
    print(opp_side_conds)

def display_mon(mon):
    info = '     ' + mon.id + ', ' + str(mon.item)
    if mon.attack_boost != 0:
        info += ', att_b: ' + str(mon.attack_boost)
    if mon.defense_boost != 0:
        info += ', def_b: ' + str(mon.defense_boost)
    if mon.special_attack_boost != 0:
        info += ', spa_b: ' + str(mon.special_attack_boost)
    if mon.special_defense_boost != 0:
        info += ', spdef_b: ' + str(mon.special_defense_boost)
    if mon.speed_boost != 0:
        info += ', spd_b: ' + str(mon.speed_boost)
    if mon.accuracy_boost != 0:
        info += ', acc_b: ' + str(mon.accuracy_boost)
    if mon.evasion_boost != 0:
        info += ', eva_b: ' + str(mon.evasion_boost)
    print(info)
    do_health(mon.hp, mon.maxhp, mon.status, 35)

def do_health(health, maxHealth, status,healthDashes):
    dashConvert = int(maxHealth/healthDashes)            # Get the number to divide by to convert health to dashes (being 10)
    currentDashes = int(health/dashConvert)              # Convert health to dash count: 80/10 => 8 dashes
    remainingHealth = healthDashes - currentDashes       # Get the health remaining to fill as space => 12 spaces

    healthDisplay = '-' * currentDashes                  # Convert 8 to 8 dashes as a string:   "--------"
    remainingDisplay = ' ' * remainingHealth             # Convert 12 to 12 spaces as a string: "            "
    percent = str(int(np.ceil(health/maxHealth*100) ))+ "%"     # Get the percent as a whole number:   40%

    print("|" + healthDisplay + remainingDisplay + "|")  # Print out textbased healthbar
    if status == None:
        print("        " + percent + ', ' + str(int(health)) + '/' + str(maxHealth))   
    else:
        print("        " + percent + ', ' + str(int(health)) + '/' + str(maxHealth) + ', '+ status)  

def battle_turn(new_battle, my_move, opp_move, instruction_log):
        transpose_instructions = get_all_state_instructions(new_battle, my_move,opp_move)
        percentage_list = [x.percentage for x in transpose_instructions]
        instruction_list = [x.instructions for x in transpose_instructions]
#         print(transpose_instructions)
        instruction = random.choices(instruction_list, weights=percentage_list, k=1)[0]
#         print('')
#         print(instruction)
        new_battle.apply(instruction)
        instruction_log[-1].append(instruction)
        
        if 'switch ' !=  my_move[:7]:
            for x in new_battle.state.self.active.moves:
                if x['id'] == my_move:
                    if new_battle.state.opponent.active.ability != 'Pressure':
                        x['current_pp'] -= 1
                    else:
                        x['current_pp'] -= 2
                    if x['current_pp'] <= 0:
                        x['disabled'] = True
        else:
            for x in new_battle.state.self.reserve.values():
                x.attack_boost = 0
                x.defense_boost = 0
                x.special_attack_boost = 0
                x.special_defense_boost = 0
                x.speed_boost = 0
                x.evasion_boost = 0
                x.accuracy_boost = 0
                if hasattr(x, 'volatileStatus'):
                    x.volatileStatus = set()

        if 'switch ' !=  opp_move[:7]:
            for x in new_battle.state.opponent.active.moves:
                if x['id'] == opp_move:
                    if new_battle.state.self.active.ability != 'Pressure':
                        x['current_pp'] -= 1
                    else:
                        x['current_pp'] -= 2
                    if x['current_pp'] <= 0:
                        x['disabled'] = True

        else:
            for x in new_battle.state.self.reserve.values():
                x.attack_boost = 0
                x.defense_boost = 0
                x.special_attack_boost = 0
                x.special_defense_boost = 0
                x.speed_boost = 0                
                x.evasion_boost = 0
                x.accuracy_boost = 0  
                if hasattr(x, 'volatileStatus'):
                    x.volatileStatus = set()
                    
        return new_battle, instruction_log

def flip_state(state):
    state_flipped = deepcopy(state)
    tmp = state_flipped.self
    state_flipped.self = state_flipped.opponent
    state_flipped.opponent = tmp
    return state_flipped

def battle(p1, p2, mutator, state_log, move_log, instruction_log):
    new_battle = deepcopy(mutator)
    while new_battle.state.battle_is_finished() == False:
        flipped_state = flip_state(new_battle.state)
        state_log[-1].append(deepcopy(new_battle.state))
        my_move = p1.move(new_battle.state)
        opp_move = p2.move(flipped_state)
        move_log[-1].append((my_move, opp_move))
        new_battle, instruction_log = battle_turn(new_battle, my_move, opp_move, instruction_log)
    state_log[-1].append(deepcopy(new_battle.state))
    reward = new_battle.state.battle_is_finished()
    if reward == -1:
        reward = 0
    return state_log, move_log, instruction_log, reward

def import_teams(team_nums):
    sides = []
    for team_id in team_nums:
        team = teams['teams'][team_id]
        pokemon = []
        random.shuffle(team['pokemon'])
        stat_map = {'hp': 0, 'atk': 1, 'def': 2, 'spa': 3, 'spd': 4, 'spe': 5}
        for mon in team['pokemon']:
            name = norm_name(mon['name'])  
            if name == 'gastrodoneast':
                name = 'gastrodon'
            if name == 'ditto':
                return 1+'e'
            poke_moves = []
            evs = [0,0,0,0,0,0]
            ivs = [31,31,31,31,31,31]
            nature = None
            if 'nature' in mon:
                nature = norm_name(mon['nature'])
            if 'ivs' in mon:
                for stat in mon['ivs']:
                    ivs[stat_map[stat]] = mon['ivs'][stat]
            if 'evs' in mon:
                for stat in mon['evs']:
                    evs[stat_map[stat]] = mon['evs'][stat]
            stats = calculate_stats(rename_baseStats(pokedex[name]['baseStats']),
                                        100, ivs=ivs, evs=evs, nature=nature)  
            ability = norm_name(mon['ability'])
#             ability = process.extractOne(norm_name(mon['ability']), abilities)
#             if ability[1] >= 0.8:
#                 ability = ability[0]
#             else:
#                 ability = 1 + 'e'
            if ability not in abilities:
                return 1 + 'e'
            for move in mon['moves']:
                move = norm_name(move)
                if move == 'uturnthunderpunch' or move == 'uturnearthquake':
                    move = 'uturn'
#                 move = process.extractOne(move, all_move_list)
#                 if move[1] >= 0.8:
#                     move = move[0]
#                 else:
#                     move = 1 + 'e'
                if move not in all_move_list:
                    return 1 + 'e'
                poke_moves.append({'id': move, 'disabled': False, 'current_pp': max_pp(all_moves[move]['pp'])})
            item = None
            if 'item' in mon:
                item = norm_name(mon['item'])
                if '/' in item:
                    item = item[:item.find('/')]
            pokemon.append(Pokemon(
                identifier = name,
                level=100,
                types=[norm_name(x) for x in pokedex[name]['types']],
                hp = stats['hp'],
                maxhp = stats['hp'],
                ability = norm_name(mon['ability']),
                item = item,
                attack=stats['attack'],
                defense=stats['defense'],
                special_attack=stats['special-attack'],
                special_defense=stats['special-defense'],
                speed=stats['speed'],
                attack_boost=0,
                defense_boost=0,
                special_attack_boost=0,
                special_defense_boost=0,
                speed_boost=0,
                accuracy_boost=0,
                evasion_boost=0,
                status=None,
                volatile_status = set(),
                moves= poke_moves))
        if len(pokemon) != 6:
            return 1+'e'
        reserve = {x.id: x for x in pokemon[1:]}
        side = Side(
            active = pokemon[0],
            reserve = reserve,
            wish = (0, 0),
            side_conditions = defaultdict(int),
            ls_count=0,
            ref_count=0,
            av_count=0
            )
        sides.append(side)
    state = State(
                user=sides[0],
                opponent=sides[1],
                weather=None,
                field=None,
                trick_room=False,
                weather_count=0,
                terrain_count=0
            )
    return StateMutator(state)

def score_diff(mutator):
    my_score, opp_score = 0, 0
    if mutator.state.self.active.hp != 0:
        my_score += 1
    for mon in mutator.state.self.reserve:
        if mutator.state.self.reserve[mon].hp != 0:
            my_score += 1
    if mutator.state.opponent.active.hp != 0:
        opp_score += 1
    for mon in mutator.state.opponent.reserve:
        if mutator.state.opponent.reserve[mon].hp != 0:
            opp_score += 1
    if my_score != 0 and opp_score != 0:
        if my_score < opp_score:
            my_score = 0
        else:
            opp_score = 0
    return np.abs(my_score - opp_score)

def convert_mon(mon, active):
    name = [pokedex[mon.id]['num']]
    ability = [ability_dict[mon.ability]]
    status = [0] * 6
    if mon.status:
        status[status_dict[mon.status]] = 1
    item = [item_dict[mon.item]]
    types = [0] * len(type_dict)
    for a in pokedex[mon.id]['types']:
        types[type_dict[norm_name(a)]] = 1
    moves = [all_moves[x['id']]['num'] for x in mon.moves]
    avg_stats = calculate_stats(rename_baseStats(pokedex[mon.id]['baseStats']),
                                    100, ivs=(31,31,31,31,31,31), evs=(0,0,0,0,0,0), nature='hardy')
    stats = [mon.hp / mon.maxhp - 0.5, mon.hp/avg_stats['hp'] - 1, mon.attack/avg_stats['attack'] - 1, mon.defense/avg_stats['defense'] - 1,
             mon.special_attack/avg_stats['special-attack'] - 1, mon.special_defense/avg_stats['special-defense'] - 1, mon.speed/avg_stats['speed'] - 1]
    stat_boosts = [mon.attack_boost/6, mon.defense_boost/6, mon.special_attack_boost/6, 
                                mon.special_defense_boost/6, mon.speed_boost/6, mon.accuracy_boost/6, mon.evasion_boost/6]
    move_pp = [x['current_pp'] / (1.6*all_moves[x['id']]['pp']) for x in mon.moves]
    move_disabled = [x['disabled']*1 for x in mon.moves]
    status = [0] * 6
    if mon.status:
        status[status_dict[mon.status]] = 1
    if active:
        trapped_counter = [mon.part_trapped_counter]
        sub_hp = [mon.substitute_hp]
        volatile_status = [0]*len(volatile_status_dict)
        if mon.volatile_status:
            for vs in mon.volatile_status:
                volatile_status[volatile_status_dict[vs]] = 1
        return torch.Tensor(name+ability+item+types+moves+status+stats+stat_boosts+move_pp+move_disabled+sub_hp+trapped_counter+volatile_status)
    return torch.Tensor(name+ability+item+types+moves+status+stats+stat_boosts+move_pp+move_disabled)

def convert_side(side):
    result = []
    team = [side.active]
    sorted_reserve = [side.reserve[x] for x in sorted(side.reserve)]
    team.extend(sorted_reserve)
    for mon in team:
        mon.moves.sort(key=lambda move: move['id'])
    result.append(convert_mon(team[0], True))
    for mon in team[1:]:
        result.append(convert_mon(mon, False))
    result.append(torch.Tensor([side.wish[0], side.wish[1], side.side_conditions['stealthrock'],
                  side.side_conditions['spikes'],side.side_conditions['toxicspikes']
                   , side.ls_count, side.ref_count, side.av_count]))
    return torch.cat(result)
    
def convert_state(mutator):
    side1 = mutator.state.self
    side2 = mutator.state.opponent
    sides = [torch.cat([convert_side(side1), convert_side(side2)])]
    weather = [0] * 4
    if mutator.state.weather:
        weather[weather_dict[mutator.state.weather]] = 1
    field = [0] * 4
    if mutator.state.field:
        field[field_dict[mutator.state.field]] = 1
    tr = [0]
    if mutator.state.trick_room:
        tr = [1]
    sides.append(torch.Tensor(weather + field + tr + [mutator.state.weather_count] + [mutator.state.terrain_count]))
    return torch.unsqueeze(torch.cat(sides), 0)
    
def standardize(mutator):
    mut_copy = deepcopy(mutator)
    mut_copy.state.self.active.moves.sort(key=lambda move: move['id'])
    mut_copy.state.opponent.active.moves.sort(key=lambda move: move['id'])
    for mon in mut_copy.state.self.reserve:
        mut_copy.state.self.reserve[mon].moves.sort(key=lambda move: move['id'])
    for mon in mut_copy.state.opponent.reserve:
        mut_copy.state.opponent.reserve[mon].moves.sort(key=lambda move: move['id'])
    mut_copy.state.self.reserve = {x: mut_copy.state.self.reserve[x] for x in sorted(mut_copy.state.self.reserve)}
    mut_copy.state.opponent.reserve = {x: mut_copy.state.opponent.reserve[x] for x in sorted(mut_copy.state.opponent.reserve)}
    return mut_copy

def state_to_mon(mon_info, active):

    pokemon = id_2_pokemon[mon_info[0].item()]
    types = [norm_name(x) for x in pokedex[pokemon]['types']]
    ability = id_2_ability[mon_info[1].item()]
    item = id_2_item[mon_info[2].item()]
    moves = []
    for i in range(4):
        moves.append({'id' : id_2_move[mon_info[3+i].item()], 'disabled' : bool(mon_info[31+i].item()),
     'current_pp': round(mon_info[27+i].item() * max_pp(all_moves[id_2_move[mon_info[3+i].item()]]['pp']))})
    status = None
    for i in range(6):
        if mon_info[7 + i].item() == 1:
            status = id_2_status(i)
    stats = {}
    avg_stats = calculate_stats(rename_baseStats(pokedex[pokemon]['baseStats']),
                                    100, ivs=(31,31,31,31,31,31), evs=(0,0,0,0,0,0), nature='hardy')
    stats['maxhp'] = round((mon_info[14].item() + 1) * avg_stats['hp'])
    stats['hp'] = round((mon_info[13].item() + 0.5) * stats['maxhp'])
    stats['attack'] = round((mon_info[15].item() + 1) * avg_stats['attack'])
    stats['defense'] = round((mon_info[16].item() + 1) * avg_stats['defense'])
    stats['special-attack'] = round((mon_info[17].item() + 1) * avg_stats['special-attack'])
    stats['special-defense'] = round((mon_info[18].item() + 1) * avg_stats['special-defense'])
    stats['speed'] = round((mon_info[19].item() + 1) * avg_stats['speed'])
    boosts = {}
    for i, stat in enumerate(['attack', 'defense', 'special-attack', 'special-defense', 'speed', 'accuracy', 'evasion']):
        boosts[stat] = int(mon_info[20+i].item() * 6)
    sub_hp = 0
    part_trapped_counter = 0
    volatile_status = set()
    if active:
        sub_hp = int(mon_info[35].item())
        part_trapped_counter = int(mon_info[36].item())
        for i in range(6):
            if mon_info[37 + i].item() == 1:
                volatile_status.add(id_2_status(i))

    mon = Pokemon(
            identifier = pokemon,
            level=100,
            types=types,
            hp = stats['hp'],
            maxhp = stats['maxhp'],
            ability = ability,
            item = item,
            attack=stats['attack'],
            defense=stats['defense'],
            special_attack=stats['special-attack'],
            special_defense=stats['special-defense'],
            speed=stats['speed'],
            attack_boost=boosts['attack'],
            defense_boost=boosts['defense'],
            special_attack_boost=boosts['special-attack'],
            special_defense_boost=boosts['special-defense'],
            speed_boost=boosts['speed'],
            accuracy_boost=boosts['accuracy'],
            evasion_boost=boosts['evasion'],
            status=status,
            volatile_status = volatile_status,
            substitute_hp = sub_hp,
            trapped_counter = part_trapped_counter,
            moves= moves)
    return mon

def detailed_state(state_log, battle_num, turn_num):
    print('TURN ' + str(turn_num))
    turn = state_log[battle_num][turn_num - 1]
    print('')
    side_conds = 'My side conditions- '
    for i in turn.self.side_conditions:
        if turn.self.side_conditions[i] != 0:
            side_conds += i +': ' + str(turn.self.side_conditions[i])
    print(side_conds)
    print('')
    display_mon(turn.self.active)
    print('')
    for mon in turn.self.reserve.values():
        display_mon(mon)
        print('')
    print('-----------------------------------------------------------------------------------------------------') 
    print('Weather: ', turn.weather)
    print('-----------------------------------------------------------------------------------------------------') 
    print('')
    display_mon(turn.opponent.active)
    print('')
    for mon in turn.opponent.reserve.values():
        display_mon(mon)
        print('')
    opp_side_conds = 'Opponent side conditions- '
    for i in turn.opponent.side_conditions:
        if turn.opponent.side_conditions[i] != 0:
            opp_side_conds += i +': ' + str(turn.opponent.side_conditions[i])
    print(opp_side_conds)
    print('') 
    
def state_to_side(state):
    active = state_to_mon(state[:43], True)
    reserve = []
    for i in range(5):
        reserve.append(state_to_mon(state[35*i + 43:35*(i+1) + 43], False))
    reserve = {x.id: x for x in reserve}
    wish = (int(state[218].item()), int(state[219].item()))
    side_conds = [int(state[220 +i ].item()) for i in range(3)]
    side_conditions = defaultdict(int)
    if side_conds[0] != 0:
        side_conditions['stealthrock'] = side_conds[0]
    if side_conds[1] != 0:
        side_conditions['spikes'] = side_conds[1]
    if side_conds[2] != 0:
        side_conditions['toxicspikes'] = side_conds[2]
    screens = [int(state[223 +i ].item()) for i in range(3)]
    side = Side(
        active = active,
        reserve = reserve,
        wish = wish,
        side_conditions = side_conditions,
        ls_count=screens[0],
        ref_count=screens[1],
        av_count=screens[2]
        )    
    return side
def state_to_mut(state):
    state = state[0]
    my_side = state_to_side(state[:226])
    opp_side = state_to_side(state[226:452])
    weather = None
    for i in range(4):
        if state[452 + i].item() == 1:
            weather = id_2_weather(i)
    field = None
    for i in range(4):
        if state[456 + i].item() == 1:
            field = id_2_field(i)
    trick_room = bool(state[460].item())
    mut_state = State(
            user=my_side,
            opponent=opp_side,
            weather=weather,
            field=field,
            trick_room=trick_room,
            weather_count=int(state[461].item()),
            terrain_count=int(state[462].item())
            )
    return StateMutator(mut_state)

In [4]:
# importlib.reload(sys.modules['showdown.engine.damage_calculator'])
# importlib.reload(sys.modules['showdown.engine.instruction_generator'])
# importlib.reload(sys.modules['showdown.engine.objects'])
# importlib.reload(sys.modules['showdown.engine.special_effects.abilities.on_switch_in'])
# importlib.reload(sys.modules['showdown.engine.find_state_instructions'])
# importlib.reload(sys.modules['showdown.engine.special_effects.abilities.modify_attack_against'])

In [5]:
class MCTSPlayer():
    def __init__(self, model=None):
        self.visited = set()
        self.p = defaultdict(lambda: defaultdict(int))
        self.q = defaultdict(lambda: defaultdict(int))
        self.n = defaultdict(lambda: defaultdict(int))
        self.visited = set()
        self.move_log = []
        self.instruction_log = []
        self.state_log = []
        self.n_log = []
        self.bug = []
        self.bug2 = []

    def flip_state(self, state):
        state_flipped = state
        tmp = state_flipped.self
        state_flipped.self = state_flipped.opponent
        state_flipped.opponent = tmp
        return state_flipped
        
    def battle_turn(self, new_battle, my_move, opp_move):
        transpose_instructions = get_all_state_instructions(new_battle, my_move,opp_move)
        percentage_list = [x.percentage for x in transpose_instructions]
        instruction_list = [x.instructions for x in transpose_instructions]
#         print(transpose_instructions)
        instruction = random.choices(instruction_list, weights=percentage_list, k=1)[0]
        new_battle.apply(instruction)
                    
        return new_battle, instruction
    
    def select_random_move(self, new_battle):
        return np.random.choice(new_battle.state.get_self_options())
    
    def heuristic(self, new_battle, move, u):
        my_speed = get_effective_speed(new_battle.state, new_battle.state.self)
        opp_speed = get_effective_speed(new_battle.state, new_battle.state.opponent)
        if (move.split()[0] != 'switch' and all_moves[move].get('heal') and 
            new_battle.state.self.active.hp > 0.9* new_battle.state.self.active.maxhp and my_speed > opp_speed):
            return -float('inf')
        if (move.split()[0] != 'switch' and 'sideCondition' in all_moves[move] and 
           all_moves[move]['sideCondition'] in ['stealthrock', 'stickyweb'] and move in new_battle.state.opponent.side_conditions
           and my_speed > opp_speed):
            return -float('inf')
        
        if (move.split()[0] != 'switch' and move in ['defog', 'rapidspin'] and 
            (len(new_battle.state.self.side_conditions) == 0 and 'lightscreen' not in new_battle.state.opponent.side_conditions
             and 'reflect' not in new_battle.state.opponent.side_conditions and 'auroraveil' not in new_battle.state.opponent.side_conditions)
             and my_speed > opp_speed):
            return -float('inf')
#         active_moves = [x['id'] for x in new_battle.state.self.active.moves]
#         if (('voltswitch' in active_moves or 'uturn' in active_moves) and 
#             my_speed > opp_speed and move.split()[0] == 'switch') :
#             return -float('inf')
        return u
    
    def select_move(self, new_battle, s, tuple_s):
        if np.random.rand() < 0.05:
            return np.random.choice(new_battle.state.get_self_options())
        zero_arms = []
        for a in new_battle.state.get_self_options():
            if a not in self.n[tuple_s]:
                zero_arms.append(a)
        if len(zero_arms) > 0:
            return random.choice(zero_arms)
        max_u, best_a = -float("inf"), -1
        for a in new_battle.state.get_self_options():
#             print('q: ', self.q[tuple_s][a])
#             print('p: ', self.p[tuple_s][a])
#             print('n1: ', sum(self.n[tuple_s].values()))
#             print('n2: ', self.n[tuple_s][a])
            u = self.q[tuple_s][a] + 2*np.sqrt(2 * np.log(sum(self.n[tuple_s].values())))/(1+self.n[tuple_s][a])
            #u = self.q[tuple_s][a] + 1*self.p[tuple_s][a]*np.sqrt(sum(self.n[tuple_s].values()))/(1+self.n[tuple_s][a])
#             print('u: ', u)
            u = self.heuristic(new_battle, a, u)
            if u>max_u:
                max_u = u
                best_a = a
        if best_a == -1:
            return self.select_random_move(new_battle)
        return best_a
    
    def playout(self, new_battle, tuple_s, ins):
        all_ins = ins
        while True:
            if new_battle.state.battle_is_finished():
                result = new_battle.state.battle_is_finished()
                new_battle.reverse(ins)
                return result
            else:
                my_move = self.select_random_move(new_battle)
                flipped_battle = StateMutator(self.flip_state(new_battle.state))
                opp_move = self.select_random_move(flipped_battle)
            
                new_battle = StateMutator(self.flip_state(flipped_battle.state))
                new_battle, new_ins = self.battle_turn(new_battle, my_move, opp_move)
                ins += new_ins
        
    def search(self, mutator, s, tuple_s, opp_player, ins, move_list):
        new_battle = mutator
#         s = convert_state(new_battle)
#         tuple_s = tuple(s.tolist()[0])
       # #self.bug.append(new_battle.state.battle_is_finished())
        if new_battle.state.battle_is_finished(): 
            new_battle.reverse(ins)
            return new_battle.state.battle_is_finished()
        if tuple_s not in self.visited:
            self.visited.add(tuple_s)
            return self.playout(new_battle, tuple_s, ins)
        
        #self.bug.append(new_battle.state)
        my_move = self.select_move(new_battle, s, tuple_s)       
        
        flipped_battle = StateMutator(self.flip_state(new_battle.state))
        flipped_mat = convert_state(flipped_battle)
        tuple_flipped = tuple(flipped_mat.tolist()[0])
        opp_move = opp_player.select_move(flipped_battle, flipped_mat, tuple_flipped)
        
        new_battle = StateMutator(self.flip_state(flipped_battle.state))
        #print(my_move, opp_move)
        sp, new_ins = self.battle_turn(new_battle, my_move, opp_move)
        sp_s = convert_state(sp)
        sp_curr = tuple(sp_s.tolist()[0])
        v = self.search(sp, sp_s, sp_curr, opp_player, ins + new_ins, move_list + [(my_move, opp_move)])
        self.q[tuple_s][my_move] = (self.n[tuple_s][my_move]*self.q[tuple_s][my_move] + v)/(self.n[tuple_s][my_move]+1)
        self.n[tuple_s][my_move] += 1
        
        opp_player.q[tuple_flipped][opp_move] = (opp_player.n[tuple_flipped][opp_move]*opp_player.q[tuple_flipped][opp_move] - v)/(opp_player.n[tuple_flipped][opp_move]+1)
        opp_player.n[tuple_flipped][opp_move] += 1  
        return v
    
    def pi(self, tuple_s, options):
        result = []
        total = sum(self.n[tuple_s].values())
        if total == 0:
            return [1/len(options)] * len(options)
        for a in options:
            #print(a)
            result.append(self.n[tuple_s][a]/total)
        return result
    
    def executeEpisode(self, mutator, opp_player, team_nums, evaluate=True, num_sims=1, k=0):
        examples = []
        current_mut = standardize(mutator)
        turn = 1
        self.state_log.append([])
        self.move_log.append([])
        self.instruction_log.append([])
        self.n_log.append([])
        opp_player.n_log.append([])
        self.state_log[-1].append(deepcopy(current_mut.state))
        while True:
            #print(turn)
            if turn % 100 == 0:
                print('reached: ', turn)
            self.p = defaultdict(lambda: defaultdict(int))
            self.q = defaultdict(lambda: defaultdict(int))
            self.n = defaultdict(lambda: defaultdict(int))
            self.visited = set()
            all_options = current_mut.state.get_all_self_options()
            curr_mat = convert_state(current_mut)
            tuple_curr = tuple(curr_mat.tolist()[0])
            curr_options = current_mut.state.get_self_options()
            for i in range(num_sims):
#                print(i)
                self.search(current_mut, curr_mat,tuple_curr, opp_player, [], [])
            if len(curr_options) == 1 and curr_options[0] == 'splash':
                my_a = 'splash'
                self.n_log[-1].append(['splash', 1])
                my_exp_val = 'N/A'
                #print('splash')
            else:
                pi = self.pi(tuple_curr, all_options)
                sort_idx = np.argsort(np.array(pi))[::-1]
                all_sorted_options = [all_options[i] for i in sort_idx]
                my_a = all_options[np.argmax(pi)]
                self.n_log[-1].append([all_sorted_options, sorted(pi, reverse=True)])
                my_exp_val = self.q[tuple_curr][my_a]

            flipped_battle = StateMutator(self.flip_state(current_mut.state))
            #print(flipped_battle.state.self.active)
            flipped_mat = convert_state(flipped_battle)
            tuple_flipped = tuple(flipped_mat.tolist()[0])
            all_opp_options = flipped_battle.state.get_all_self_options()
            opp_curr_options = flipped_battle.state.get_self_options()
            if len(opp_curr_options) == 1 and opp_curr_options[0] == 'splash':
                opp_a = 'splash'
                opp_player.n_log[-1].append(['splash', 1])
                opp_exp_val = 'N/A'
                
            else:
                opp_pi = opp_player.pi(tuple_flipped, all_opp_options)
                opp_sort_idx = np.argsort(np.array(opp_pi))[::-1]
                all_sorted_opp_options = [all_opp_options[i] for i in opp_sort_idx]
                opp_a = all_opp_options[np.argmax(opp_pi)]
                opp_player.n_log[-1].append([all_sorted_opp_options, sorted(opp_pi, reverse=True)])
                opp_exp_val = opp_player.q[tuple_flipped][opp_a]
            
            current_mut = StateMutator(self.flip_state(flipped_battle.state))
            current_mut, ins = self.battle_turn(current_mut, my_a, opp_a)
            #print('real moves: ', (my_a, opp_a))
            self.move_log[-1].append((my_a, opp_a, my_exp_val, opp_exp_val))
            self.instruction_log[-1].append(ins)
            self.state_log[-1].append(deepcopy(current_mut.state))
            turn += 1
#           
#             print('')
            if my_a != 'splash' and opp_a != 'splash':
                examples.append([curr_mat, pi, 0])
                examples.append([flipped_mat, opp_pi, 1])
                
            reward = current_mut.state.battle_is_finished()
            if evaluate and reward:
                print(reward)
                #queue.put(reward)
                return
            if reward:
                for x in examples:
                    if x[2] == 0:
                        x[2] = reward
                    else:
                        x[2] = -reward
                #return examples
                print(turn)
                #queue.put(examples)
                team1, team2 = team_nums
                with open("battle_data/data_" + str(team1) + '_' + str(team2) + '_' + str(k)+"_6v6.pkl", "wb") as fp:   #Pickling
                    pickle.dump(examples, fp)
                return
            
#importlib.reload(sys.modules['showdown.engine.helpers'])

In [6]:
class PokeZero2(nn.Module):
    def __init__(self, num_hidden1=512, num_hidden2 = 512, num_hidden3=512, num_hidden4=512, 
                 num_hidden5=512, num_hidden6=512, num_hidden7=512, num_hidden8=512, embed_dim = 10, act_fn = 'relu'):
        super().__init__()
        # Initialize the modules we need to build the network
        self.embed_dim = embed_dim
        #self.embed_mon = nn.Embedding(num_embeddings = 898, embedding_dim = embed_dim)
        self.embed_ability = nn.Embedding(num_embeddings = 269, embedding_dim = embed_dim)
        self.embed_item = nn.Embedding(num_embeddings = 1278, embedding_dim = embed_dim)
        self.embed_moves = nn.Embedding(num_embeddings = 1000, embedding_dim = embed_dim)
        if act_fn == 'relu':
            self.act_fn = nn.ReLU()
        else:
            self.act_fn = nn.Tanh()
        self.linear1 = nn.Linear(813 + embed_dim * 72, num_hidden1)
        self.linear2 = nn.Linear(num_hidden1, num_hidden2)
        self.linear3 = nn.Linear(num_hidden2, num_hidden3)
        self.linear4 = nn.Linear(num_hidden3, num_hidden4)
        self.linear5 = nn.Linear(num_hidden4, 9)
        self.linear6 = nn.Linear(813 + embed_dim * 72, num_hidden5)
        self.linear7 = nn.Linear(num_hidden5, num_hidden6)
        self.linear8 = nn.Linear(num_hidden6, num_hidden7)
        self.linear9 = nn.Linear(num_hidden7, num_hidden8)
        self.linear10 = nn.Linear(num_hidden8, 1)
        self.linear11 = nn.Linear(num_hidden8, 1)
        self.linear12 = nn.Linear(num_hidden8, 1)

    
    def encode_mon(self, state, active):
        batch_size = state.shape[0]
        #emb_mon = self.embed_mon(state[:, 0].long())
        emb_ability = self.embed_ability(state[:, 1].long())
        emb_item = self.embed_item(state[:, 2].long())
        types = state[:, 3:21]
        embed_moves = self.embed_moves(state[:, 21:25].long()).reshape(batch_size, 4*self.embed_dim)
        status = state[:, 25:31]
        stats = state[:, 31:38]
        stat_boosts = state[:, 38:45]
        move_info = state[:, 45:53]
        if active:
            volatile = state[:, 53:170]
            return torch.cat((types, emb_ability, emb_item, embed_moves, status, stats, stat_boosts, move_info, volatile), 1)
        return torch.cat((types, emb_ability, emb_item, embed_moves, status, stats, stat_boosts, move_info), 1)
    
    def encode(self, state):
        all_mons = [self.encode_mon(state[:, :170], True)]
        for i in range(5):
            all_mons.append(self.encode_mon(state[:, 53*i + 170:53*(i+1) + 170], False))
        all_mons.append(state[:, 435:443])
        
        all_mons.append(self.encode_mon(state[:, 443:613], True))
        for i in range(5):
            all_mons.append(self.encode_mon(state[:, 53*i + 613:53*(i+1) + 613], False))    
        all_mons.append(state[:, 878:])
        return torch.cat(all_mons, 1)
    
    def forward(self, state):
        # Perform the calculation of the model to determine the prediction
        x = self.encode(state)
        
        p = self.linear1(x)
        p = self.act_fn(p)
        p = self.linear2(p)
        p = self.act_fn(p)
        p = self.linear3(p)
        p = self.act_fn(p)
        p = self.linear4(p)
        p = self.act_fn(p)
        p = self.linear5(p)
        
        v = self.linear6(x)
        v = self.act_fn(v)
        v = self.linear7(v)
        v = self.act_fn(v)
        v = self.linear8(v)
        v = self.act_fn(v)
        v = self.linear9(v)
        v = self.act_fn(v)
        v = self.linear10(v)
        return p, v

In [19]:
class MCTSNNPlayer():
    def __init__(self, model=None, model_dir = None):
        if model:
            model = PokeZero2()
            model.load_state_dict(torch.load(model_dir,  map_location=torch.device('cpu')))
            self.pz = model
        else:
            self.pz = PokeZero2()
        self.trainset = []
        self.visited = set()
        self.p = defaultdict(lambda: defaultdict(int))
        self.q = defaultdict(lambda: defaultdict(int))
        self.n = defaultdict(lambda: defaultdict(int))
        self.softmax = nn.Softmax(dim=0)
        self.move_log = []
        self.instruction_log = []
        self.n_log =[]
        self.state_log = []
        self.bug = []

    def flip_state(self, state):
        state_flipped = state
        tmp = state_flipped.self
        state_flipped.self = state_flipped.opponent
        state_flipped.opponent = tmp
        return state_flipped
        
    def battle_turn(self, new_battle, my_move, opp_move):
        transpose_instructions = get_all_state_instructions(new_battle, my_move,opp_move)
        percentage_list = [x.percentage for x in transpose_instructions]
        instruction_list = [x.instructions for x in transpose_instructions]
        instruction = random.choices(instruction_list, weights=percentage_list, k=1)[0]
        new_battle.apply(instruction)
                    
        return new_battle, instruction
    
    def select_random_move(self, new_battle):
        return np.random.choice(new_battle.state.get_self_options())

    def heuristic(self, new_battle, move, u):
        my_speed = get_effective_speed(new_battle.state, new_battle.state.self)
        opp_speed = get_effective_speed(new_battle.state, new_battle.state.opponent)
        if (move.split()[0] != 'switch' and all_moves[move].get('heal') and 
            new_battle.state.self.active.hp > 0.9* new_battle.state.self.active.maxhp and my_speed > opp_speed):
            return -float('inf')
        if (move.split()[0] != 'switch' and 'sideCondition' in all_moves[move] and 
           all_moves[move]['sideCondition'] in ['stealthrock', 'stickyweb'] and move in new_battle.state.opponent.side_conditions
           and my_speed > opp_speed):
            return -float('inf')

        if (move.split()[0] != 'switch' and move in ['defog', 'rapidspin'] and 
            (len(new_battle.state.self.side_conditions) == 0 and 'lightscreen' not in new_battle.state.opponent.side_conditions
             and 'reflect' not in new_battle.state.opponent.side_conditions and 'auroraveil' not in new_battle.state.opponent.side_conditions)
             and my_speed > opp_speed):
            return -float('inf')
    #         active_moves = [x['id'] for x in new_battle.state.self.active.moves]
    #         if (('voltswitch' in active_moves or 'uturn' in active_moves) and 
    #             my_speed > opp_speed and move.split()[0] == 'switch') :
    #             return -float('inf')
        return u
    
    def PUCT(self, qs, ps, ns, total, new_battle, a, d):
        u = qs[a] + d*ps[a]*np.sqrt(total)/(1+ns[a])
        u = self.heuristic(new_battle, a, u)
        return u

    def select_move(self, new_battle, s, tuple_s, c, d):
        if np.random.rand() < 0.05:
            return self.select_random_move(new_battle)
        self.n_forced = {}
        ns = self.n[tuple_s]
        ps = self.p[tuple_s]
        qs = self.q[tuple_s]
        total = sum(ns.values())
        for a in new_battle.state.get_self_options():
            self.n_forced[a] = (c*ps[a] *total)**(0.5)    
        zero_arms = []
        for a in new_battle.state.get_self_options():
            if a not in ns or ns[a] < self.n_forced[a]:
                zero_arms.append(a)
        if len(zero_arms) > 0:
            return random.choice(zero_arms)
        max_u, best_a = -float("inf"), -1
        for a in new_battle.state.get_self_options():
            u = self.PUCT(qs, ps, ns,total, new_battle, a, d)
            if u>max_u:
                max_u = u
                best_a = a
        if best_a == -1:
            return self.select_random_move(new_battle)
        return best_a
        
    def search(self, mutator, s, tuple_s, opp_player, ins, move_list, c, d):
        new_battle = mutator
        if new_battle.state.battle_is_finished(): 
            new_battle.reverse(ins)
            return new_battle.state.battle_is_finished()
        if tuple_s not in self.visited:
            self.visited.add(tuple_s)
            with torch.no_grad():
                p, v = self.pz.forward(s)
            p = self.softmax(p[0])
            if len(self.p) == 0 and not self.playout_cap:
                p = 0.75*p + 0.25* np.random.dirichlet((1, 1, 1, 1, 1, 1, 1, 1, 1), 1)[0]
            for i, a in enumerate(new_battle.state.get_all_self_options()):
                self.p[tuple_s][a] = p[i].item()
                
            flipped_battle = StateMutator(self.flip_state(new_battle.state))
            flipped_mat = convert_state(flipped_battle)
            tuple_flipped = tuple(flipped_mat.tolist()[0])
            with torch.no_grad():
                opp_p, _ = opp_player.pz.forward(flipped_mat)
            opp_p = opp_player.softmax(opp_p[0])
            if len(opp_player.p) == 0:
                opp_p = 0.75*opp_p + 0.25* np.random.dirichlet((1, 1, 1, 1, 1, 1, 1, 1, 1), 1)[0]
            for i, a in enumerate(flipped_battle.state.get_all_self_options()):
                opp_player.p[tuple_flipped][a] = opp_p[i].item()
            new_battle = StateMutator(self.flip_state(flipped_battle.state))
            new_battle.reverse(ins)
            return v[0].item()

        my_move = self.select_move(new_battle, s, tuple_s, c, d)       
        
        flipped_battle = StateMutator(self.flip_state(new_battle.state))
        flipped_mat = convert_state(flipped_battle)
        tuple_flipped = tuple(flipped_mat.tolist()[0])
        opp_move = opp_player.select_move(flipped_battle, flipped_mat, tuple_flipped, c, d)
        
        new_battle = StateMutator(self.flip_state(flipped_battle.state))
        #print(my_move, opp_move)
        sp, new_ins = self.battle_turn(new_battle, my_move, opp_move)
        sp_s = convert_state(sp)
        sp_curr = tuple(sp_s.tolist()[0])
        v = self.search(sp, sp_s, sp_curr, opp_player, ins + new_ins, move_list + [(my_move, opp_move)], c, d)
        self.q[tuple_s][my_move] = (self.n[tuple_s][my_move]*self.q[tuple_s][my_move] + v)/(self.n[tuple_s][my_move]+1)
        self.n[tuple_s][my_move] += 1
        
        opp_player.q[tuple_flipped][opp_move] = (opp_player.n[tuple_flipped][opp_move]*opp_player.q[tuple_flipped][opp_move] - v)/(opp_player.n[tuple_flipped][opp_move]+1)
        opp_player.n[tuple_flipped][opp_move] += 1  
        return v
    
    def max_puct(self,qs, ps, ns, total, new_battle, d):
        pucts = []
        for a in new_battle.state.get_self_options():
            pucts.append([a, self.PUCT(qs, ps, ns, total, new_battle, a, d)])
        pucts.sort(key=lambda x: x[1], reverse=True)
        maxpuct_a = pucts[0][0]
        return maxpuct_a
    
    def pi(self, new_battle, tuple_s, options,curr_options, c, d):
        result = []
        qs = self.q[tuple_s]
        ns = self.n[tuple_s]
        ps = self.p[tuple_s]
        total = sum(ns.values())
        if total == 0:
            return [1/len(options)] * len(options)
        best_a = self.max_puct(qs, ps, ns, total, new_battle, d)
        for a in curr_options:
            forced = (c*ps[a] *total)**(0.5)
            while forced > 0 and a != best_a and ns[a] > 0:
                ns[a] -= 1
                total -= 1
                if self.max_puct(qs, ps, ns, total, new_battle, d) != best_a:
                    ns[a] += 1
                    break
                forced -= 1
#         for a in curr_options:
#             if self.n[tuple_s][a] == 1:
#                 self.n[tuple_s][a] = 0
        total = sum(ns.values())
        for a in options:
            result.append(ns[a]/total)
        return result
    
    def executeEpisode(self, mutator, opp_player, team_nums, evaluate=True, num_full_sims=400, 
                       capped_sims = 100, generation=0, k=0, c=2, d=1, quit=False):
        np.random.seed()
        examples = []
        current_mut = standardize(mutator)
        turn = 1
        self.state_log.append([])
        self.move_log.append([])
        self.n_log.append([])
        self.instruction_log.append([])
        opp_player.n_log.append([])
        self.state_log[-1].append(deepcopy(current_mut.state))
        play_out = False
        if np.random.rand() < 0.1 or quit == False:
            play_out = True
        while True:
            print(turn)
            if turn % 100 == 0:
                print('reached: ', turn)
            self.p = defaultdict(lambda: defaultdict(int))
            self.q = defaultdict(lambda: defaultdict(int))
            self.n = defaultdict(lambda: defaultdict(int))
            self.playout_cap = True
            if np.random.rand() < 0.25:
                self.playout_cap = False            
            all_options = current_mut.state.get_all_self_options()
            curr_mat = convert_state(current_mut)
            curr_size = curr_mat.size()[1]
            if curr_size != 897:
                print(current_mut.state)
                print(curr_size)
            tuple_curr = tuple(curr_mat.tolist()[0])
            curr_options = current_mut.state.get_self_options()
            if self.playout_cap and not evaluate:
                for i in range(capped_sims):
                    self.search(current_mut, curr_mat,tuple_curr, opp_player, [], [], c, d)
            else:
                for i in range(num_full_sims):
                    self.search(current_mut, curr_mat,tuple_curr, opp_player, [], [], c, d)
            if len(curr_options) == 1 and curr_options[0] == 'splash':
                my_a = 'splash'
                self.n_log[-1].append(['splash', 1])
                my_exp_val = 'N/A'
            else:
                pi = self.pi(current_mut, tuple_curr, all_options, curr_options, c, d)
                sort_idx = np.argsort(np.array(pi))[::-1]
                all_sorted_options = [all_options[i] for i in sort_idx]
                if evaluate:
                    my_a = all_options[np.argmax(pi)]
                else:
                    my_a = all_options[np.random.choice(range(len(pi)), p=pi)]
                self.n_log[-1].append([all_sorted_options, sorted(pi, reverse=True)])
                my_exp_val = self.q[tuple_curr][my_a]

            
            flipped_battle = StateMutator(self.flip_state(current_mut.state))
            flipped_mat = convert_state(flipped_battle)
            tuple_flipped = tuple(flipped_mat.tolist()[0])
            all_opp_options = flipped_battle.state.get_all_self_options()
            opp_curr_options = flipped_battle.state.get_self_options()
            if len(opp_curr_options) == 1 and opp_curr_options[0] == 'splash':
                opp_a = 'splash'
                opp_player.n_log[-1].append(['splash', 1])
                opp_exp_val = 'N/A'
            else:
                opp_pi = opp_player.pi(flipped_battle, tuple_flipped, all_opp_options, opp_curr_options, c, d)
                opp_sort_idx = np.argsort(np.array(opp_pi))[::-1]
                all_sorted_opp_options = [all_opp_options[i] for i in opp_sort_idx]
                if evaluate:
                    opp_a = all_opp_options[np.argmax(opp_pi)]
                else:
                    opp_a = all_opp_options[np.random.choice(range(len(opp_pi)), p=opp_pi)]
                opp_player.n_log[-1].append([all_sorted_opp_options, sorted(opp_pi, reverse=True)])
                opp_exp_val = opp_player.q[tuple_flipped][opp_a]

            if my_a != 'splash' and opp_a != 'splash' and not self.playout_cap:
                examples.append([curr_mat, pi, [0, self.q[tuple_curr][my_a]], 0, turn])
                examples.append([flipped_mat, opp_pi, [1, opp_player.q[tuple_flipped][opp_a]], 0, turn])
            
            current_mut = StateMutator(self.flip_state(flipped_battle.state))
            current_mut, ins = self.battle_turn(current_mut, my_a, opp_a)
            self.move_log[-1].append((my_a, opp_a, my_exp_val, opp_exp_val))
            self.instruction_log[-1].append(ins)
            self.state_log[-1].append(deepcopy(current_mut.state))
            turn += 1

            if not play_out and my_exp_val != 'N/A' and my_exp_val < -0.9:
                reward = -1
            elif not play_out and opp_exp_val != 'N/A' and opp_exp_val < -0.9:
                reward = 1
            else:
                reward = current_mut.state.battle_is_finished()
            if evaluate and reward:
                print(reward)
                return
            if reward:
                score = score_diff(current_mut)
                for x in examples:
                    x[3] = score
                    x[4] = turn - x[4]
                    if x[2][0] == 0:
                        x[2] = (1- generation/20)* reward + generation/20 * x[2][1]
                    else:
                        x[2] = -1* (1- generation/20)* reward + generation/20 * x[2][1]
                #return examples
                print(turn)
                team1, team2 = team_nums
                with open("battle_data/data_" + str(team1) + '_' + str(team2) + '_' 
                          + str(k)+"_gen_" + str(generation) + ".pkl", "wb") as fp:   #Pickling
                    pickle.dump(examples, fp)
                return

In [8]:
works = False
while not works:
    try:
        team_nums = np.random.randint(len(teams['teams']), size=2)
        mut = import_teams(team_nums)
        works = True
    except:
        pass
print(mut.state.self.active.id)
print(mut.state.self.reserve.keys())
print(mut.state.opponent.active.id)
print(mut.state.opponent.reserve.keys())

slowking
dict_keys(['terrakion', 'clefable', 'hippowdon', 'dragapult', 'skarmory'])
blissey
dict_keys(['slowking', 'skarmory', 'tornadustherian', 'zamazentacrowned', 'hippowdon'])


In [9]:
p1 = MCTSNNPlayer(model=True, model_dir='model_heur.dict')
p2 = MCTSNNPlayer(model=True, model_dir='model_heur.dict')

In [17]:
%%time
p1.executeEpisode(mut, p2, team_nums, evaluate=True, num_full_sims=300, capped_sims = 80, generation=0, k=0, quit=False)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17


KeyboardInterrupt: 

In [24]:
%%time
p1 = MCTSNNPlayer(model=True, model_dir='model_heur.dict')
p2 = MCTSNNPlayer(model=True, model_dir='model_heur.dict')
processes = []
for k in range(15):
    works = False
    while not works:
        try:
            team_nums = np.random.randint(len(teams['teams']), size=2)
            mut = import_teams(team_nums)
            works = True
        except:
            pass
    p = Process(target=p1.executeEpisode, args=(mut, p2, team_nums, False, 500, 
                       100, 0, k, 2, 1, False,))  # Passing the list
    p.start()
    processes.append(p)
for p in processes:
    p.join()

1
1
1
11

1
11

1
1
1
1
1
1
1
2
2
2
2
2
2
2
2
2
22

2
2
3
3
3
3
3
3
3
3
3
3
3
4
4
4
4
4
4
4
4
4
5
5
5
5
5
5
6
6
6
6
2
2
7
7
3
3
3
8
8
4
4
4
4
9
5
9
5
5
5
5
5
10
5
6
10
6
6
6
6
11
6
7
7
11
7
7
3
8
7
7
8
8
7
4
4
9
8
9
9
8
5
10
9
10
10
6
11
10
6
11
7
6
7
12
12
11
7
8
12
7
8
13
12
13
12
9
13
9
14
13
13
5
14
15
8
6
6
15
16
9
7
7
16
17
8
10
8
8
11
18
8
9
11
9
10
9
8
9
14
12
10
10
10
10
11
9
12
10
15
11
13
11
12
14
14
11
13
11
9
12
12
12
12
14
15
15
10
13
13
13
17
19
11
16
14
14
14
20
15
15
17
15
10
11
16
16
16
18
11
16
12
17
17
13
17
1419

17
12
13
13
18
20
18
18
13
15
14
12
21
19
19
18
14
15
15
13
21
20
20
19
15
16
21
22
21
20
23
22
22
18
21
14
23
19
22
14
19
24
22
15
16
16
20
23
16
17
21
14
17
24
16
17
22
18
15
24
16
25
23
18
16
17
23
17
20
24
15
17
17
18
21
24
25
25
18
18
19
23
22
26
25
26
20
19
23
24
26
19
21
24
25
20


Process Process-122:
Traceback (most recent call last):
  File "/home/mingu600/anaconda3/envs/pokezero/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/mingu600/anaconda3/envs/pokezero/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-19-709b0e852bc2>", line 212, in executeEpisode
    self.search(current_mut, curr_mat,tuple_curr, opp_player, [], [], c, d)
  File "<ipython-input-19-709b0e852bc2>", line 132, in search
    v = self.search(sp, sp_s, sp_curr, opp_player, ins + new_ins, move_list + [(my_move, opp_move)], c, d)
  File "<ipython-input-19-709b0e852bc2>", line 132, in search
    v = self.search(sp, sp_s, sp_curr, opp_player, ins + new_ins, move_list + [(my_move, opp_move)], c, d)
  File "<ipython-input-19-709b0e852bc2>", line 129, in search
    sp, new_ins = self.battle_turn(new_battle, my_move, opp_move)
  File "<ipython-input-19-709b0e852bc2>", line 29, 

26
18
20
22
25
25
26
21
19
26
23
2621

18
27
16
22
27
19
22
19
27
28
23
17
27
28
20
20
24
29
28
21
21
25
29
22
27
26
22
27
30
23
27
27
23
28
24
3124

28
28
28
28
29
25
18
25
32
29
29
30
20
33
26
19
30
29
23
31
31
32
24
30
32
25
33
24
26
29
34
25
35
30
26
34
27
27
30
20
26
36
27
35
31
21
31
27
28
33
29
36
21
22
32
34
37
30
31
29
23
22
33
35
38
32
31
24
30
30
39
34
40
28
37
28
32
41
29
38
29
28
42
30
39
30
29
43
31
40
31
33
36
25
23
31
34
26
37
32
35
35
36
27
38
33
36
28
31
37
34
39
37
29
32
40
44
38
32
30
41
30
33
41
32
45
31
42
31
24
42
32
43
32
33
43
33
33
44
35
44
34
34
34
45
35
45
33
3935

33
46
34
36
40
46
35
25
32
37
47
36
38
33
26
48
34
36
49
27
35
35
50
37
46
28
34
34
36
47
36
36
47
35
29
35
37
37
37
48
38
36
36
30
41
38
38
49
37
39
31
42
38
40
43
32
38
41
51
39
48
44
33
37
40
45
37
34
41
46
38
38
39
50
42
47
39
51
40
39
39
43
48
40
40
44
40
52
49
42
41
41
45
50
43
42
42
39
44
43
40
45
49
44
52
4150

41
46
35
51
45
52
53
42
46
53
36
51
43
46
47
54
37
44
52
41
47
43
55
48
45
38
4

Process Process-124:
Traceback (most recent call last):
  File "/home/mingu600/anaconda3/envs/pokezero/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/mingu600/anaconda3/envs/pokezero/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-19-709b0e852bc2>", line 212, in executeEpisode
    self.search(current_mut, curr_mat,tuple_curr, opp_player, [], [], c, d)
  File "<ipython-input-19-709b0e852bc2>", line 132, in search
    v = self.search(sp, sp_s, sp_curr, opp_player, ins + new_ins, move_list + [(my_move, opp_move)], c, d)
  File "<ipython-input-19-709b0e852bc2>", line 132, in search
    v = self.search(sp, sp_s, sp_curr, opp_player, ins + new_ins, move_list + [(my_move, opp_move)], c, d)
  File "<ipython-input-19-709b0e852bc2>", line 129, in search
    sp, new_ins = self.battle_turn(new_battle, my_move, opp_move)
  File "<ipython-input-19-709b0e852bc2>", line 29, 

76
77
78
90
79
80
82
83
84
81
91
92
93
85
86
94
87
95
88
96
89
97
90
98
91
99
100
reached:  100
92
101
93
102
94
103
95
104
96
105
97
106
107
108
109
110
111
98
112
113
114
99
115
100
reached:  100
116
101
117
102
118
119
120
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
CPU times: user 2.86 s, sys: 971 ms, total: 3.83 s
Wall time: 35min 50s


In [66]:
display_battle(p1.state_log, p1.move_log, p1.instruction_log, 0)

TURN 1

My side conditions- 

     excadrill, leftovers
|------------------------------------|
        100%, 361/361

Weather:  None

     clefable, leftovers
|-----------------------------------|
        100%, 393/393

Opponent side conditions- 

-----------------------------------------------------------------------------------------------------
My move: earthquake
My expected value 0.6367051053260054
Opp move: switch swampert
Opponent expected value -0.9235751032829285
Instructions:
[('switch', 'opponent', 'clefable', 'swampert'), ('decrement_pp', 'self', 'earthquake', 1), ('damage', 'opponent', 329.0), ('heal', 'opponent', 25)]
-----------------------------------------------------------------------------------------------------

My side conditions- 

     excadrill, leftovers
|------------------------------------|
        100%, 361/361

Weather:  None

     swampert, leftovers
|---------                          |
        25%, 100/404

Opponent side conditions- 

------------------

In [70]:
game = 0
turn = 2
print(p1.n_log[game][turn-1][0])
print(p1.n_log[game][turn-1][1])
print('')
print(p2.n_log[game][turn-1][0])
print(p2.n_log[game][turn-1][1])

['earthquake', 'ironhead', 'switch volcarona', 'switch clefable', 'switch tyranitar', 'rockblast', 'switch dragapult', 'switch dracozolt', 'swordsdance']
[0.4166666666666667, 0.2222222222222222, 0.1111111111111111, 0.08333333333333333, 0.08333333333333333, 0.08333333333333333, 0.0, 0.0, 0.0]

['switch clefable', 'switch slowbro', 'switch dragapult', 'toxic', 'stealthrock', 'earthquake', 'switch nidoking', 'switch mandibuzz', 'flipturn']
[0.48484848484848486, 0.15151515151515152, 0.09090909090909091, 0.09090909090909091, 0.09090909090909091, 0.09090909090909091, 0.0, 0.0, 0.0]


In [26]:
#detailed_state(p1.state_log, 0, 53)

In [46]:
%lprun -f p1.search p1.executeEpisode(mut, p2, team_nums, evaluate=True, num_full_sims=30, capped_sims = 30, generation=0, k=0, quit=False)

1


In [25]:
#display_battle(p1.state_log, p1.move_log, p1.instruction_log, 0)

In [55]:
import torch.utils.data as data
class BattleDataset(data.Dataset):
    def __init__(self, size):
        """
        Inputs:
            size - Number of data points we want to generate
            std - Standard deviation of the noise (see generate_continuous_xor function)
        """
        super().__init__()
        self.size = size
        self.load_data()

    def load_data(self):
        data = []
        for i in range(1, self.size + 1):
            with open("battle_data/data_" + str(i) + "_6v6.pkl", "rb") as fp:   # Unpickling
                b = pickle.load(fp)
            data.extend(b)
        state = [x[0] for x in data]
        label = [(torch.Tensor(x[1]), x[2]) for x in data]
        self.data = state
        self.label = label

    def __len__(self):
        # Number of data point we have. Alternatively self.data.shape[0], or self.label.shape[0]
        return len(self.data)

    def __getitem__(self, idx):
        # Return the idx-th data point of the dataset
        # If we have multiple things to return (data point and label), we can return them as tuple
        data_point = self.data[idx]
        data_label = self.label[idx]
        return data_point, data_label

In [56]:
dataset = BattleDataset(size=3)
print("Size of dataset:", len(dataset))
print("Data point 0:", dataset[0])

Size of dataset: 2088
Data point 0: (tensor([[ 4.8500e+02,  6.0000e+01,  2.3400e+02,  4.1400e+02,  4.3000e+02,
          4.6300e+02,  2.6900e+02,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  5.0000e-01,  4.3344e-02,
         -2.2222e-01,  0.0000e+00,  3.3108e-01,  0.0000e+00,  2.6316e-01,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,
          1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  3.6000e+01,  1.1800e+02,
          2.3400e+02,  2.8200e+02,  5.8500e+02,  1.3500e+02,  4.4600e+02,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  5.0000e-01,  1.9033e-01, -9.0909e-02,  2.2527e-01,
          0.0000e+00,  2.0833e-01,  6.4103e-03,  0.0000e+00,  0.0000e+00,
 