In [None]:
from ijcai2022nmmo import CompetitionConfig, scripted, TeamBasedEnv, Team
import nmmo
import numpy as np
import copy
import gym
import matplotlib.pyplot as plt
from gym import spaces 
from nmmo.io import action

class MyTrainEnv(gym.Env):
    def __init__(self,env_config:dict):
        """Create the wrapped team environment and declare the Gym spaces.

        Args:
            env_config: Environment settings, stored as ``self.env_config``.
                ``reset`` reads ``env_config["teams"]`` from it.
        """
        self.config = CompetitionConfig()
        self.team_env = TeamBasedEnv(self.config)
        self.env_config = env_config
        # ``np.inf`` is used instead of the ``np.infty`` alias, which no longer
        # exists in NumPy 2.0; the value is identical on older NumPy versions.
        # Observation: a (129, 129, 40) stack of map layers plus an 88-element
        # vector, matching what MyAwesomeTeam.update_map returns.
        self.observation_space = spaces.Tuple((
            spaces.Box(low=-np.inf,high=np.inf,shape=(129,129,40),dtype=np.float32),
            spaces.Box(low=-np.inf,high=np.inf,shape=(88,),dtype=np.float32),
        ))
        # Action: 80 unbounded floats, reshaped to (8, 10) by MyAwesomeTeam.act.
        self.action_space = spaces.Box(low=-np.inf,high=np.inf,shape=(80,),dtype=np.float32)
        
    def reset(self):
        """Start a new episode and return the first observation of team 0.

        Resets the wrapped environment, rebuilds the list of teams (the
        learning team first, then the scripted teams requested in
        ``env_config["teams"]``), resets every team and converts team 0's raw
        observation with ``update_map``.
        """
        self.obs_by_team = self.team_env.reset()
        self.n_tick = 0

        # Index 0 is always the learning team (population id 0).
        self.teams = [MyAwesomeTeam("MyTeam",self.config)]

        # Each key names a class in `scripted`; the value is how many copies to add.
        for key, team_num in self.env_config["teams"].items():
            if team_num <= 0:
                continue
            designed_team = getattr(scripted,key)
            self.teams.extend([designed_team(key+f"-{i}",self.config) for i in range(team_num)])

        for team in self.teams:
            team.reset()

        return self.teams[0].update_map(self.obs_by_team[0])
        
    def step(self,action):
        """Advance the game by one tick and score the learning team.

        Args:
            action: raw policy output for team 0; handed to ``self.teams[0].act``.

        Returns:
            ``(observation, reward, done, info)``. When team 0 no longer has an
            observation the episode ends with an all-zero observation and a
            reward of 0; otherwise ``done`` becomes True once 1024 ticks have
            been played.
        """

        def cal_rwd(init, seg1, seg2, seg3, prev, curr):
            # Piecewise-linear reward for moving from `prev` to `curr`:
            # the three segments are worth 4, 6 and 11 points in total.
            slopes = (4 / (seg1 - init), 6 / (seg2 - seg1), 11 / (seg3 - seg2))
            total = 0
            for bound, slope in zip((seg1, seg2, seg3), slopes):
                if prev >= bound:
                    continue
                gain = min(curr - prev, bound - prev)
                total += slope * gain
                prev += gain
                if prev == curr:
                    break
            return total

        def find_exploration(exploration_map):
            # Largest row or column extent of the cells marked as explored.
            if exploration_map.sum() == 0:
                return 0
            rows, cols = np.where(exploration_map==1)
            return max( rows.max()-rows.min(), cols.max()-cols.min() )

        # Team 0 acts on the policy output, every other team on its own observation.
        actions = {0: self.teams[0].act(action)}
        for team_index in range(1,16):
            team_obs = self.obs_by_team.get(team_index)
            actions[team_index] = self.teams[team_index].act(team_obs) if team_obs else {}

        self.obs_by_team_all = self.team_env.step(actions)
        self.obs_by_team = self.obs_by_team_all[0]

        obs_to_myteam = self.obs_by_team.get(0)
        if obs_to_myteam is None:
            # No observation for team 0 any more: finish the episode.
            return (np.zeros((129,129,40)),np.zeros(88)), 0, True, {}

        my_team = self.teams[0]
        new_my_obs = my_team.update_map(obs_to_myteam)

        self.n_tick += 1
        done = self.n_tick >= 1024
        info = {}

        # update_map() has just rebuilt these, so the aliases are taken afterwards.
        before = my_team.last_global_information
        now = my_team.global_information
        agent_layers = my_team.global_map["agents"]["layers"]
        agent_vectors = my_team.global_map['agents']['vectors']

        ### hunting and fishing (up to 21 points each, averaged)
        rwd_hunting = cal_rwd(10, 20, 35, 50, before["hunting_level"], now["hunting_level"])
        rwd_fishing = cal_rwd(10, 20, 35, 50, before["fishing_level"], now["fishing_level"])
        rwd_resources = (rwd_hunting + rwd_fishing) / 2

        ### defeated players
        rwd_kill = cal_rwd(0, 1, 3, 6, before["kill_opponent_num"], now["kill_opponent_num"])
        print("kill_num_prev",before["kill_opponent_num"])
        print("kill_num_curr",now["kill_opponent_num"])

        ### exploration: best extent over the eight per-agent exploration layers
        for idx in range(8):
            now["exploration"] = max(find_exploration(agent_layers[:,:,idx*2+1]), now["exploration"])
        rwd_exploration = cal_rwd(0, 32, 64, 127, before["exploration"], now["exploration"])

        ### food and water: small penalty per agent whose supply is low
        rwd_food_and_water = 0
        for idx in range(8):
            food = agent_vectors[idx*11+2]
            water = agent_vectors[idx*11+3]
            if food < 0.3 * now["hunting_level"]:
                rwd_food_and_water += -0.02
            if water < 0.3 * now["fishing_level"]:
                rwd_food_and_water += -0.02

        ### equipment: only rewarded when the best killed-NPC level went up
        rwd_equip = 0
        if now['kill_npc_level'] > before['kill_npc_level']:
            rwd_equip = cal_rwd(0, 1, 10, 20, before["kill_npc_level"], now["kill_npc_level"])

        ### level
        alive = False
        for idx in range(8):
            level = agent_vectors[idx*11+0]
            if level > 0:
                alive = True
            now['level'] = max(now['level'], level)
        rwd_level = 0
        rwd_level += (now['level'] - before['level']) * 21 / 80
        before['level'] = now['level']

        ### alive: bonus while at least one agent reports a positive level
        rwd_alive = 0.02 if alive else 0

        reward = rwd_resources + rwd_kill + rwd_exploration + rwd_food_and_water + rwd_equip + rwd_level + rwd_alive
        print("---------------------------------")
        print( "resources: ",rwd_resources)
        print( "kill: ",rwd_kill)
        print( "exploration: ",rwd_exploration)
        print( "food_and_water: ",rwd_food_and_water)
        print( "equip: ",rwd_equip)
        print( "level: ",rwd_level)
        print( "alive: ",rwd_alive)
        print("---------------------------------")

        return new_my_obs, reward, done, info
        
        

In [None]:

########################
## Myteam
class MyAwesomeTeam(Team):
    def reset(self):
        # reset some states
        self.global_map = {}
        self.n_tick = 0 
        self.LOCAL_BIAS = 16
        self.global_information={ 
                                  "level": 1,
                                  "hunting_level":10, 
                                  "fishing_level":10,
                                  "kill_npc_num":0,
                                  "kill_npc_level":0,
                                  "kill_opponent_num":0,
                                  "exploration":0
                                  }
        
        self.last_global_information={ "hunting_level":10, 
                               "fishing_level":10,
                               "kill_npc_level":0,
                               "kill_npc_num":0,
                               "kill_opponent_num":0,
                               "exploration":0,
                               "level": 1}
        
        self.agents_eval = np.zeros((8,6),dtype=np.float32)
        self.agents_eval[:,0] = 1 # level
        self.agents_eval[:,1] = 10 # hunting level
        self.agents_eval[:,2] = 10 # fishing level
        self.agents_eval[:,3] = 0 # kill_npc_num
        self.agents_eval[:,4] = 0 # kill_npc_level
        self.agents_eval[:,5] = 0 # kill_opponents_num 
        
        self.last_action_informations= np.zeros((8,2),dtype=np.int32)  # id-level
        self.agents_pos = np.zeros((8, 2))  # agent position
    
    def obs_normalization(self,agent_layers,agent_vectors,opponents,npcs,terrain):
        MAX_LEVEL = 80
        MAX_HEALTH = MAX_LEVEL
        MAX_FOOD = 80 
        MAX_WATER = 80
        MAX_DAMAGE = 1000
        MAX_HUNTING_LEVEL = 75 
        MAX_FISHING_LEVEL = 75 
        MAX_KILL_NUM = 20
        
        # normalization for the agents
        for agent_index in range(8):
            # level
            agent_vectors[agent_index*11+0] /= MAX_LEVEL
            # health
            agent_vectors[agent_index*11+1] /= MAX_HEALTH
            # food
            agent_vectors[agent_index*11+2] /= MAX_FOOD
            # water 
            agent_vectors[agent_index*11+3] /= MAX_WATER
            # DAMAGE 
            agent_vectors[agent_index*11+4] /= MAX_DAMAGE
            # HUNTING_LEVEL
            agent_vectors[agent_index*11+6] /= MAX_HUNTING_LEVEL
            # FISHING LEVEL 
            agent_vectors[agent_index*11+7] /= MAX_FISHING_LEVEL
            # KILL_NPC_LEVEL
            agent_vectors[agent_index*11+8] /= MAX_LEVEL
            # kill_npc_num
            agent_vectors[agent_index*11+9] /= MAX_KILL_NUM
            # KILL_OPPONENT_NUM
            agent_vectors[agent_index*11+10] /= MAX_KILL_NUM
        
        MAX_DIS_NUM = 5
        # normalization for the opponents AND NPCS 
        # distribution 
        opponents[:,:,0] /= MAX_DIS_NUM 
        npcs[:,:,0] /= MAX_DIS_NUM
        # level 
        opponents[:,:,1] /= MAX_LEVEL
        npcs[:,:,1] /= MAX_LEVEL
        # damage
        opponents[:,:,2] /= MAX_DAMAGE
        npcs[:,:,2] /= MAX_DAMAGE
        # food 
        opponents[:,:,3] /= MAX_FOOD
        npcs[:,:,3] /= MAX_FOOD
        # water 
        opponents[:,:,4] /= MAX_WATER
        npcs[:,:,4] /= MAX_WATER
        # health
        opponents[:,:,5] /= MAX_HEALTH
        npcs[:,:,5] /= MAX_HEALTH
        
        ##### clip between [0,1]
        agent_layers = np.clip(agent_layers,0,1)
        opponents = np.clip(opponents,0,1)
        npcs = np.clip(npcs,0,1)
        terrain = np.clip(terrain,0,1)
        agent_vectors = np.clip(agent_vectors,0,1)
        
        # concatenate
        obs_layers = np.concatenate( (agent_layers,opponents,npcs,terrain),axis=-1 )
        
        return ( obs_layers, agent_vectors )
            
    
    def update_map(self,observations:dict[int,dict]):
        ############# agent layer information
        agent_visible = np.zeros([129,129])
        agent_info = np.zeros([129,129,16])
        agents_dis = np.zeros([129,129])
        ############# agent vector information
        agent_vectors = np.zeros([88])
        ############# opponent layer information
        opponents=np.zeros([129,129,7])
        ############# npc layer infromation
        npcs=np.zeros([129,129,7])
        ############# record the attack objects 
        self.attack_objs = []
        for k in self.global_information:
            self.last_global_information[k] = self.global_information[k]
        for agent_index in range(8):
            attack_obj = []
            last_object_list = []
            if observations.get(agent_index)!=None:
                ######## obttain the obs from the env
#                 print(observations[agent_index])
                obs_en = observations[agent_index]['Entity']['Continuous']
                ######## team pop 
                team_id = int( obs_en[0][4] )
                ########  agent row | column index 
                agent_row_index = int( obs_en[0][5]- self.LOCAL_BIAS )
                agent_column_index = int( obs_en[0][6] - self.LOCAL_BIAS )
                self.agents_pos[agent_index] = np.array([agent_row_index, agent_column_index])
                ######## opponents, npcs location
                opponents_loc = []
                npcs_loc = []
                
                ######## information for npcs killing and opponents killing
                
                ######## 
                for index in range(100):
                    info = int(obs_en[index][0])
                    if info==0:
                        break
                    entity_id = int(obs_en[index][1]) # entity_id
                    level = int( obs_en[index][3] )   # agent level 
                    pop = int( obs_en[index][4])      # agent pop 
                    row_index = int( obs_en[index][5]- self.LOCAL_BIAS )  # row index 
                    column_index = int( obs_en[index][6] - self.LOCAL_BIAS ) # column index 
                    damage = int( obs_en[index][7] )  # damage received 
                    timealive = int( obs_en[index][8] ) # timealive
                    food = int( obs_en[index][9] )  # food 
                    water = int( obs_en[index][10] ) # water
                    health = int( obs_en[index][11] ) # health 
                    frozen = int( obs_en[index][12] ) # frozen 

                    if index==0:
                        # 表示是该team的agent
                        ### update layers 
                        ######### 更新 agent_visible
                        agent_visible[int(max(row_index-7,0)):int(min(row_index+8,129)),
                                      int(max(column_index-7,0)):int(min(column_index+8,129))] = 1
                        ######### update agent info
                        ################################   alive
                        agent_info[row_index][column_index][agent_index*2+0]=1
                        ################################   exploration
                        if self.global_map=={}: # first step
                            agent_info[:,:,agent_index*2+1]=agent_visible
                        else:
                            agent_info[:,:,agent_index*2+1] = self.global_map["agents"]["layers"][:,:,agent_index*2+1]
                            agent_info[int(max(row_index-7,0)):int(min(row_index+8,129)),
                                      int(max(column_index-7,0)):int(min(column_index+8,129)),
                                      agent_index*2+1] = 1
                            
                        ######### agent 更新 agent dis
                        agents_dis[row_index][column_index]=1 # locations 
                        
                        ### update agent vectors 
                        ######### agent 更新 agent level
                        agent_vectors[agent_index*11+0]=level
                        self.agents_eval[agent_index,0] = max(self.agents_eval[agent_index,0],level)
                        ######### agent 更新 agent health
                        agent_vectors[agent_index*11+1]=health
                        ######### agent 更新 agent food
                        agent_vectors[agent_index*11+2]=food                        
                        ######### agent 更新 agent water
                        agent_vectors[agent_index*11+3]=water
                        ######### agent 更新 agent damage
                        agent_vectors[agent_index*11+4]=damage                        
                        ######### agent 更新 agent frozen
                        agent_vectors[agent_index*11+5]=frozen  
                        
                        ######### agent 更新 hunting level
                        self.agents_eval[agent_index,1] = max(food,self.agents_eval[agent_index,1])
                        agent_vectors[agent_index*11+6] = self.agents_eval[agent_index,1]
                        
                        ######### agent 更新 fishing level
                        self.agents_eval[agent_index,2] = max(water,self.agents_eval[agent_index,2])
                        agent_vectors[agent_index*11+7] = self.agents_eval[agent_index,2]

                        ######### agent 更新 timealive
                        self.n_tick = timealive

                        #############################################################
                    else:
                        if pop>0 and pop!=team_id: # opponent
                            ######## opponent distribution
                            opponents[row_index][column_index][0]+=1
                            ######## level
                            opponents[row_index][column_index][1]+=level
                            ######## damage
                            opponents[row_index][column_index][2]+=damage
                            ######## food
                            opponents[row_index][column_index][3]+=food
                            ######## water
                            opponents[row_index][column_index][4]+=water                                
                             ######## health
                            opponents[row_index][column_index][5]+=health
                            ######## frozen
                            opponents[row_index][column_index][6]+=frozen

                            ######## killing opponents
                            if self.last_action_informations[agent_index][0]>0:
                                last_object_list.append(entity_id)
                            
                            if [row_index,column_index] not in opponents_loc:
                                opponents_loc.append( [row_index,column_index] )
                            if max(abs(row_index-agent_row_index),abs(column_index-agent_column_index))<=4:
                                attack_obj.append([entity_id,
                                                   abs(row_index-agent_row_index),
                                                   abs(column_index-agent_column_index),
                                                  level,
                                                  damage,
                                                  food,
                                                  water,
                                                  health,
                                                  frozen,
                                                   index])
                        elif pop<0: # npcs 
                            ######## npc distribution
                            npcs[row_index][column_index][0]+=1
                            ######## level
                            npcs[row_index][column_index][1]+=level
                            ######## damage
                            npcs[row_index][column_index][2]+=damage
                            ######## food
                            npcs[row_index][column_index][3]+=food
                            ######## water
                            npcs[row_index][column_index][4]+=water                                
                            ######## health
                            npcs[row_index][column_index][5]+=health
                            ######## frozen
                            npcs[row_index][column_index][6]+=frozen      
                            
                            ############ killing npcs 
                            if self.last_action_informations[agent_index][0]<0:
                                last_object_list.append(entity_id)
                                
                            if [row_index,column_index] not in npcs_loc:
                                npcs_loc.append([row_index,column_index])
                                
                            if max(abs(row_index-agent_row_index), abs(column_index-agent_column_index))<=4:
                                attack_obj.append([entity_id,
                                                   abs(row_index-agent_row_index),
                                                   abs(column_index-agent_column_index),
                                                  level,
                                                  damage,
                                                  food,
                                                  water,
                                                  health,
                                                  frozen,
                                                   index])
                
                for loc in opponents_loc:
                    for i in range(1,7):
                        opponents[loc[0]][loc[1]][i] = int( opponents[loc[0]][loc[1]][i]/opponents[loc[0]][loc[1]][0])
                for loc in npcs_loc:
                    for i in range(1,7):
                        npcs[loc[0]][loc[1]][i] = int(npcs[loc[0]][loc[1]][i]/npcs[loc[0]][loc[1]][0])
            ### update the killing information:
            if last_object_list != []:
                if self.last_action_informations[agent_index][0]>0 and \
                   self.last_action_informations[agent_index][0] not in last_object_list :
                    self.agents_eval[agent_index,5]+=1 
                elif self.last_action_informations[agent_index][0]<0 and \
                   self.last_action_informations[agent_index][0] not in last_object_list :
                    self.agents_eval[agent_index,3]+=1
                    self.agents_eval[agent_index,4] = max(self.agents_eval[agent_index,4],
                                                                   self.last_action_informations[agent_index][1])
            
            ####### update the killing information:
            ### kill_npc level 
            agent_vectors[agent_index*11+8] = self.agents_eval[agent_index,4]
            ### kill_npc_num
            agent_vectors[agent_index*11+9] = self.agents_eval[agent_index,3]
            ### kill_opponent_num 
            agent_vectors[agent_index*11+10] = self.agents_eval[agent_index,5]
            
            self.attack_objs.append(np.array(attack_obj,dtype=np.float32))
        ##################### terrain update
        if self.global_map=={}: # first step
            terrain = np.zeros([129,129,8])  
            for agent_index in range(8):
                if observations.get(agent_index)!=None:
                    obs_til = observations[agent_index]['Tile']['Continuous']
                    for index in range(225):
                        til_type = int( obs_til[index][1] )
                        til_row = int( obs_til[index][2] -self.LOCAL_BIAS )
                        til_column = int( obs_til[index][3] -self.LOCAL_BIAS ) 
                        if til_row<0 or til_row > 128 or til_column<0 or til_column>128:
                            continue 

                        terrain[til_row][til_column][0]= 1
                        terrain[til_row][til_column][til_type]=1
        else:
            terrain = self.global_map['terrain']
            for i in range(129):
                for j in range(129):
                    if terrain[i][j][3]==1:
                        terrain[i][j][6]=1
                    if terrain[i][j][4]==1:
                        terrain[i][j][7]=1       
            for agent_index in range(8):     
                if observations.get(agent_index)!=None:
                    obs_til = observations[agent_index]['Tile']['Continuous']
                    for index in range(225):
                        til_type = int( obs_til[index][1] )
                        til_row = int( obs_til[index][2] -self.LOCAL_BIAS )
                        til_column = int( obs_til[index][3] -self.LOCAL_BIAS ) 
                        if til_row<0 or til_row > 128 or til_column<0 or til_column>128:
                            continue 

                        terrain[til_row][til_column][0]= 1
                        terrain[til_row][til_column][til_type]=1

                        if til_type==3 or til_type==4:
                            terrain[til_row][til_column][6]=0
                            terrain[til_row][til_column][6]=0
        ######## update the global information
        self.global_information={ 
                                  "level": 1,
                                  "hunting_level":10, 
                                  "fishing_level":10,
                                  "kill_npc_num":0,
                                  "kill_npc_level":0,
                                  "kill_opponent_num":0,
                                  "exploration":0
                                  }        
        
        self.global_information["level"]=np.max(self.agents_eval[:,0])
        self.global_information["hunting_level"]=np.max(self.agents_eval[:,1])
        self.global_information["fishing_level"]=np.max(self.agents_eval[:,2])
        self.global_information["kill_npc_num"]=np.max(self.agents_eval[:,3])
        self.global_information["kill_npc_level"]=np.max(self.agents_eval[:,4])
        self.global_information["kill_opponent_num"]=np.max(self.agents_eval[:,5])
        
        ######## concatenate                        
        self.global_map['terrain']=terrain
        self.global_map['oppos']=opponents
        self.global_map['npcs']=npcs
        
        agent_visible_expanded = np.expand_dims(agent_visible,axis=-1)
        agents_dis_expanded = np.expand_dims(agents_dis,axis=-1)
        agent_layers = np.concatenate( (agent_info,agent_visible_expanded,agents_dis_expanded),axis=-1 )
        self.global_map['agents']={"layers":agent_layers, "vectors":agent_vectors}

        obs = self.obs_normalization(copy.deepcopy(agent_layers),
                                     copy.deepcopy(agent_vectors),
                                     copy.deepcopy(opponents),
                                     copy.deepcopy(npcs),
                                     copy.deepcopy(terrain))
        
        return obs
    
    def show_map(self,characters,agent_attris,attris,agent_index=0,terrain_option=False,ter_attris=[0,1,2,3,4,5,6,7]):
        """Plot selected layers of ``self.global_map`` with matplotlib.

        Args:
            characters: layer groups to draw; any of "agents", "oppos", "npcs".
            agent_attris: per-agent channels drawn for "agents"
                (0 = location, 1 = exploration of agent ``agent_index``).
            attris: channels drawn for "oppos" / "npcs" (0-6, named in CHARAC).
            agent_index: which agent's channel pair is shown for "agents".
                It has no effect on "oppos" / "npcs".
            terrain_option: when true, also draw the terrain channels in ``ter_attris``.
            ter_attris: terrain channels to draw (0-7, named in TERRAIN).
                The default list is only read, never modified.
        """

        CHARAC = ['distribution','level','damage','food','water','health','is_frozen']
        AGENTS = ["location","exploration",'visible','agents distribution']
        TERRAIN = ["exploration","water","grass","scrub(latest)","forest(latest)",
                   "stone","scrub(previous)","forest(previous)"]
        i = 0
        for character in characters:
            if character == "agents":
                agent_layers = self.global_map[character]["layers"]
                for attri in agent_attris:
                    plt.figure(i)
                    i+=1
                    ######## per-agent channel pair
                    plt.imshow(agent_layers[:,:,agent_index*2+attri],cmap=plt.cm.cool)
                    plt.colorbar()
                    plt.title(character+" "+AGENTS[attri])
                    plt.show()

                # The last two channels are shared by the whole team.
                plt.figure(i)
                i+=1
                plt.imshow(agent_layers[:,:,-2],cmap=plt.cm.cool)
                plt.colorbar()
                plt.title(AGENTS[-2])

                plt.figure(i)
                i+=1
                plt.imshow(agent_layers[:,:,-1],cmap=plt.cm.cool)
                plt.colorbar()
                plt.title(AGENTS[-1])
            else:
                for attri in attris:
                    plt.figure(i)
                    i+=1
                    # The opponent / NPC stacks built by update_map have exactly
                    # 7 channels and are not split per agent, so the channel is
                    # `attri` itself. The old `agent_index*8+attri` offset ran
                    # out of bounds for any agent_index > 0.
                    plt.imshow(self.global_map[character][:,:,attri],cmap=plt.cm.cool)
                    plt.colorbar()
                    plt.title(character+" "+CHARAC[attri])
                    plt.show()
        if terrain_option:
            terrains = self.global_map['terrain']
            for ter_attri in ter_attris:
                plt.figure(i)
                i+=1

                plt.imshow(terrains[:,:,ter_attri],cmap=plt.cm.cool)
                plt.colorbar()
                plt.title("terrain "+TERRAIN[ter_attri])
                    
                    
    def act(self, actions: dict[int, dict]) -> dict[int, dict]:
        """Turn the policy output into one action dictionary per agent.

        Despite the annotation, ``actions`` is used as an array: it is reshaped
        to (8, 10), one row per agent, laid out as
        [x, y, x_w, y_w, level_w, damage_w, food_w, water_w, health_w, frozen_w].
        The first two values choose the move direction; the other eight are
        weights applied to columns 1-8 of each attack candidate in
        ``self.attack_objs[i]``
        ([id, x, y, level, damage, food, water, health, frozen, idx]).

        Also records the chosen target's id and level in
        ``self.last_action_informations`` (zeros when nothing is attacked).

        Returns:
            ``{agent index: action dict}``. An agent's dict holds an
            ``action.Attack`` entry only when it has attack candidates, and an
            ``action.Move`` entry only when the move is not suppressed.
        """
        # Move codes follow [North, South, East, West] = [0, 1, 2, 3]. For each
        # code: the agents_pos column and the value at which the move is dropped.
        blocking_edge = {0: (0, 0), 1: (0, 128), 3: (1, 0), 2: (1, 128)}

        actions_output = {}
        for i, a in enumerate(actions.reshape((8, 10))):
            # The dominant component decides the direction; ties fall back to north.
            x_part, y_part = a[0], a[1]
            if np.abs(x_part) > np.abs(y_part):
                move = 2 if x_part > 0 else 3   # east / west
            elif np.abs(x_part) < np.abs(y_part) and y_part < 0:
                move = 1                        # south
            else:
                move = 0                        # north

            central_dist = np.linalg.norm(a[:2])
            pos_column, edge_value = blocking_edge[move]
            at_edge = self.agents_pos[i][pos_column] == edge_value

            agent_action = {}

            # Attack the candidate with the highest weighted score, if there is one.
            candidates = self.attack_objs[i]
            if candidates.size != 0:
                score = np.sum(a[2:] * candidates[:,1:9], 1)
                max_idx = np.argsort(score)[::-1][0]
                dist = np.max(candidates[:,1:3], 1)[max_idx]
                # The attack style depends on the distance to the target.
                if dist <= 1:
                    style = 0
                elif dist <= 3:
                    style = 1
                elif dist <= 4:
                    style = 2
                else:
                    style = None
                agent_action[action.Attack] = {
                    action.Style: style,
                    action.Target: int(candidates[max_idx][9]),
                }
                self.last_action_informations[i,0] = int(candidates[max_idx][0])
                self.last_action_informations[i,1] = int(candidates[max_idx][3])
            else:
                self.last_action_informations[i] = 0

            # No move for a near-zero direction vector or at the matching map edge.
            if not (central_dist <= 0.1 or at_edge):
                agent_action[action.Move] = {action.Direction: move}

            actions_output[i] = agent_action

        return actions_output


In [None]:
import ray 
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print
import tqdm
import json5
import ray.tune as tune 
import gym
import gym.spaces as spaces
import numpy as np
from ray.tune.registry import register_env

##### register env    
def trainenv_creator(env_config):
    """Environment factory: build a MyTrainEnv from the given ``env_config``."""
    train_env = MyTrainEnv(env_config)
    return train_env

# ray.shutdown()
# ray.init(num_cpus=3,ignore_reinit_error=True,object_store_memory=3*1024*1024*1024)

import copy

# Deep copy: DEFAULT_CONFIG is a module-level dict, and a shallow `.copy()` would
# share its nested dicts (e.g. "model"), so the overrides below would silently
# change RLlib's defaults for the rest of the process.
algo_config = copy.deepcopy(ppo.DEFAULT_CONFIG)

######### load config
with open("config.json") as json_file:
    all_config = json5.load(json_file)

# Overlay the file's "algo_config" section. Dict values are merged key by key
# into the existing sub-dict; everything else replaces the default.
for key,value in all_config["algo_config"].items():
    if isinstance(value,dict):
        for sub_key,sub_value in value.items():
            algo_config[key][sub_key] = sub_value
    else:
        algo_config[key] = value

algo_config["env_config"] = all_config["env_config"]
algo_config["model"]["fcnet_activation"] = None
# The exploratory cells below use `env`; uncomment this line to create it.
# env = MyTrainEnv(algo_config["env_config"])

In [None]:
# Start a new episode.
# NOTE(review): `env` is not created anywhere visible in this file (the
# `env = MyTrainEnv(...)` line in the previous cell is commented out).
env.reset()

In [None]:
# Plot team 0's global map: agent 0's location and exploration layers, all seven
# opponent and NPC channels, and all eight terrain channels.
env.teams[0].show_map(["agents","oppos","npcs"],[0,1],[0,1,2,3,4,5,6],agent_index=0,terrain_option=True,ter_attris=[0,1,2,3,4,5,6,7])

In [None]:
# Advance the environment by one tick using a randomly sampled action vector.
env.step( env.action_space.sample() )

In [None]:
# Show element 3 of what the last team_env.step() call returned
# (element 0 is the per-team observations; presumably 3 is the info dict — TODO confirm).
print(env.obs_by_team_all[3])

In [None]:
# Plot the same layers again, now reflecting the state after the step above.
env.teams[0].show_map(["agents","oppos","npcs"],[0,1],[0,1,2,3,4,5,6],agent_index=0,terrain_option=True,ter_attris=[0,1,2,3,4,5,6,7])

In [None]:
from ray.tune.logger import pretty_print

# Make the environment available to RLlib under the name "MyTrainEnv".
register_env("MyTrainEnv", trainenv_creator)

# NOTE: these overrides only change `algo_config`. The trainer below is built
# from its own inline settings and never receives `algo_config`.
algo_config.update({
    "lr": 0.001,
    "ignore_worker_failures": True,
    "recreate_failed_workers": True,
    "preprocessor_pref": None,
    "train_batch_size": 128,
})

# Network: one convolution over the whole 129x129 map, no plain FC stack,
# followed by two 100-unit post-FC layers.
model_options = {
    "fcnet_hiddens": [],
    "fcnet_activation": None,
    "conv_filters": [[16, [129, 129], 1]],
    "conv_activation": "relu",
    "post_fcnet_hiddens": [100, 100],
    "post_fcnet_activation": "relu",
}
trainer_options = {
    "framework": "torch",
    "train_batch_size": 128,
    "lr": 0.001,
    "env_config": all_config["env_config"],
    "model": model_options,
}
trainer = ppo.PPOTrainer(env="MyTrainEnv", config=trainer_options)

# Alternatives that were tried and left disabled:
#   ppo.PPOTrainer(env="MyTrainEnv", config=algo_config)
#   tune.run('PPO', config=algo_config, stop={"timesteps_total": 4000})

# Train for 200 iterations and write a checkpoint on every 20th (0, 20, ..., 180).
for i in range(200):
    result = trainer.train()
    print(pretty_print(result))
    if i % 20:
        continue
    checkpoint = trainer.save()
    print("checkpoint saved at", checkpoint)

In [None]:
# Take the model from the policy returned by trainer.get_policy() and print it.
model = trainer.get_policy().model
print(model)

In [None]:
import numpy as np

# Scratch check of how the coordinates of matching cells come back:
# one array of row indices and one array of column indices.
a = np.array([
    [2, 2, 3, 7, 5],
    [6, 3, 6, 7, 2],
    [3, 2, 7, 6, 9],
])
is_seven = a == 7
x, y = np.nonzero(is_seven)

In [None]:
# Row indices on the first line, column indices on the second.
print(x, y, sep="\n")