In [None]:
from ijcai2022nmmo import CompetitionConfig, scripted, TeamBasedEnv, Team
import nmmo
import numpy as np
import copy
import gym
import matplotlib.pyplot as plt
from gym import spaces 
from nmmo.io import action


class MyTrainEnv(gym.Env):
    """Single-agent gym wrapper that trains team 0 inside a ``TeamBasedEnv``.

    The learning policy controls ``self.teams[0]`` (a ``MyAwesomeTeam``); every
    other team is a scripted opponent built from ``env_config["teams"]``, a
    mapping ``{scripted class name: number of copies}``.

    Observation: the tuple returned by ``MyAwesomeTeam.update_map`` cast to
    float32 -- a (129, 129, 40) stack of map layers and an (88,) vector.
    Action: a flat (80,) float vector forwarded to ``MyAwesomeTeam.act``.
    """

    # An episode is cut off after this many calls to ``step``.
    MAX_EPISODE_STEPS = 1024
    # Agents per team and number of slots per agent in the team's flat vector.
    TEAM_SIZE = 8
    VECTOR_STRIDE = 11

    def __init__(self,env_config:dict):
        self.config = CompetitionConfig()
        self.team_env = TeamBasedEnv(self.config)
        self.env_config = env_config
        # ``np.inf`` rather than the ``np.infty`` alias, which NumPy 2.0 removed.
        self.observation_space = spaces.Tuple((
            spaces.Box(low=-np.inf, high=np.inf, shape=(129, 129, 40), dtype=np.float32),
            spaces.Box(low=-np.inf, high=np.inf, shape=(88,), dtype=np.float32),
        ))
        self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(80,), dtype=np.float32)
        self.t = 0
        self.n_tick = 0
        self.teams = []
        self.obs_by_team = {}

    def reset(self):
        """Start a new episode and return team 0's first observation."""
        self.obs_by_team = self.team_env.reset()
        self.n_tick = 0
        # The step counter must restart as well; otherwise every episode after
        # the first one is reported as done on its very first step.
        self.t = 0

        # Team 0 is the learner, followed by the configured scripted teams.
        self.teams = [MyAwesomeTeam("MyTeam", self.config)]
        for key, team_num in self.env_config["teams"].items():
            if team_num > 0:
                designed_team = getattr(scripted, key)
                self.teams += [designed_team(key+f"-{i}", self.config) for i in range(team_num)]

        for team in self.teams:
            team.reset()

        return self._format_obs(self.teams[0].update_map(self.obs_by_team[0]))

    def step(self,action):
        """Advance the environment by one tick.

        Args:
            action: flat policy output for team 0 (see ``action_space``).

        Returns:
            ``(observation, reward, done, info)`` for team 0.
        """
        my_team = self.teams[0]

        # Team 0 converts the raw policy output into game actions.
        actions = {0: my_team.act(action)}
        # Scripted teams act on their own latest observation. Teams that no
        # longer appear in the observations are skipped.
        for team_index in range(1, len(self.teams)):
            if team_index in self.obs_by_team:
                actions[team_index] = self.teams[team_index].act(self.obs_by_team[team_index])

        step_result = self.team_env.step(actions)
        # NOTE(review): TeamBasedEnv.step is expected to return
        # (observations, rewards, dones, infos); a bare observation dict is
        # accepted too so either API shape works -- confirm against the library.
        self.obs_by_team = step_result[0] if isinstance(step_result, tuple) else step_result

        new_my_obs = self._format_obs(my_team.update_map(self.obs_by_team.get(0, {})))

        self.n_tick += 1
        self.t += 1

        reward, alive = self._compute_reward(my_team)
        done = self.t >= self.MAX_EPISODE_STEPS or not alive
        info = {}
        return new_my_obs, reward, done, info

    @staticmethod
    def _format_obs(obs):
        """Cast every observation part to float32 so it matches ``observation_space``."""
        return tuple(np.asarray(part, dtype=np.float32) for part in obs)

    @staticmethod
    def _cal_rwd(init, seg1, seg2, seg3, prev, curr):
        """Piecewise-linear reward for a statistic that moved from ``prev`` to ``curr``.

        The ranges [init, seg1], [seg1, seg2] and [seg2, seg3] are worth 4, 6
        and 11 points in total; progress inside a range is paid proportionally.
        """
        k = [4 / (seg1 - init), 6 / (seg2 - seg1), 11 / (seg3 - seg2)]
        seg = [seg1, seg2, seg3]

        rwd = 0
        for i in range(3):
            if prev < seg[i]:
                incre = min(curr - prev, seg[i] - prev)
                rwd += k[i] * incre
                prev += incre
                if prev == curr:
                    return rwd
        return rwd

    @staticmethod
    def _find_exploration(exploration_map):
        """Return the larger side (in tiles) of the bounding box of explored cells.

        Returns 0 when nothing has been explored.
        """
        explored = np.asarray(exploration_map) != 0
        if not explored.any():
            return 0
        rows = np.flatnonzero(explored.any(axis=1))
        cols = np.flatnonzero(explored.any(axis=0))
        return int(max(rows[-1] - rows[0], cols[-1] - cols[0]))

    def _compute_reward(self, team):
        """Compute this tick's reward from the statistics tracked by ``team``.

        ``team.update_map`` must have been called for the current tick: it
        snapshots the previous statistics into ``team.last_global_information``
        before refreshing ``team.global_information``.

        Returns:
            ``(reward, alive)`` where ``alive`` is True while at least one of
            the team's agents reports a level above zero.
        """
        curr = team.global_information
        prev = team.last_global_information
        agent_layers = team.global_map["agents"]["layers"]
        agent_vectors = team.global_map["agents"]["vectors"]
        stride = self.VECTOR_STRIDE

        ### hunting and fishing
        rwd_hunting = self._cal_rwd(10, 20, 35, 50, prev["hunting_level"], curr["hunting_level"])
        rwd_fishing = self._cal_rwd(10, 20, 35, 50, prev["fishing_level"], curr["fishing_level"])
        rwd_resources = (rwd_hunting + rwd_fishing) / 2

        ### defeated players
        rwd_kill = self._cal_rwd(0, 1, 3, 6, prev["kill_opponent_num"], curr["kill_opponent_num"])

        ### exploration: best extent reached by any single agent so far
        for idx in range(self.TEAM_SIZE):
            explored = agent_layers[:, :, idx*2+1]
            curr["exploration"] = max(self._find_exploration(explored), curr["exploration"])
        rwd_exploration = self._cal_rwd(0, 32, 64, 127, prev["exploration"], curr["exploration"])
        prev["exploration"] = curr["exploration"]

        ### food and water: small penalty for every agent running low
        rwd_food_and_water = 0
        for idx in range(self.TEAM_SIZE):
            food = agent_vectors[idx*stride+2]
            water = agent_vectors[idx*stride+3]
            if food < 0.3 * curr["hunting_level"]:
                rwd_food_and_water += -0.02
            if water < 0.3 * curr["fishing_level"]:
                rwd_food_and_water += -0.02

        ### equipment: rewarded through the level of the strongest NPC killed
        rwd_equip = 0
        if curr["kill_npc_level"] > prev["kill_npc_level"]:
            rwd_equip = self._cal_rwd(0, 1, 10, 20, prev["kill_npc_level"], curr["kill_npc_level"])

        ### level
        alive = False
        for idx in range(self.TEAM_SIZE):
            level = agent_vectors[idx*stride+0]
            if level > 0:
                alive = True
            curr["level"] = max(curr["level"], level)
        rwd_level = (curr["level"] - prev["level"]) * 21 / 80
        prev["level"] = curr["level"]

        ### alive bonus
        rwd_alive = 0.02 if alive else 0

        reward = (rwd_resources + rwd_kill + rwd_exploration + rwd_food_and_water
                  + rwd_equip + rwd_level + rwd_alive)
        return float(reward), alive
        

In [None]:

########################
## Myteam
class MyAwesomeTeam(Team):
    def reset(self):
        """Clear all per-episode state: map memory, tick counter and team statistics."""
        # Starting value of every tracked team statistic.
        initial_information = {
            "hunting_level": 10,
            "fishing_level": 10,
            "kill_npc_level": 0,
            "kill_npc_num": 0,
            "kill_opponent_num": 0,
            "exploration": 0,
            "level": 1,
        }

        self.global_map = {}
        self.n_tick = 0
        # Offset subtracted from observed row/column values to get map indices.
        self.LOCAL_BIAS = 16
        # Current statistics and the snapshot taken before the latest update;
        # each gets its own copy so they can diverge.
        self.global_information = dict(initial_information)
        self.last_global_information = dict(initial_information)
        # Per agent: [id, level] of the entity targeted by its last action.
        self.last_action_informations = [[0, 0] for _ in range(8)]
    
    def obs_normalization(self,agent_layers,agent_vectors,opponents,npcs,terrain):
        """Scale raw observation features and clip everything to [0, 1].

        The inputs are never modified: scaling is done on private float copies,
        so callers may pass arrays they still need (including integer arrays).

        Args:
            agent_layers: (H, W, C) map layers describing the team's own agents.
            agent_vectors: flat vector holding 11 slots for each of the 8 agents.
            opponents: (H, W, 7) opponent layers
                [count, level, damage, food, water, health, frozen].
            npcs: (H, W, 7) NPC layers with the same channel order.
            terrain: (H, W, C) terrain layers.

        Returns:
            ``(obs_layers, agent_vectors)``: all map layers concatenated along
            the last axis, and the normalised agent vector.
        """
        MAX_LEVEL = 80
        MAX_HEALTH = MAX_LEVEL
        MAX_FOOD = 80
        MAX_WATER = 80
        MAX_DAMAGE = 1000
        MAX_HUNTING_LEVEL = 75
        MAX_FISHING_LEVEL = 75
        MAX_KILL_NUM = 20
        MAX_DIS_NUM = 5

        def float_copy(values):
            # Always a copy; integer input is promoted so true division works.
            values = np.asarray(values)
            return values.astype(np.result_type(values.dtype, np.float32), copy=True)

        agent_vectors = float_copy(agent_vectors)
        opponents = float_copy(opponents)
        npcs = float_copy(npcs)

        # Divisor for each of the 11 per-agent slots: level, health, food,
        # water, damage, frozen (left unscaled), hunting level, fishing level,
        # kill_npc_level, kill_npc_num, kill_opponent_num.
        agent_scale = np.array([MAX_LEVEL, MAX_HEALTH, MAX_FOOD, MAX_WATER, MAX_DAMAGE, 1,
                                MAX_HUNTING_LEVEL, MAX_FISHING_LEVEL, MAX_LEVEL,
                                MAX_KILL_NUM, MAX_KILL_NUM], dtype=np.float64)
        agent_vectors[:88] = (agent_vectors[:88].reshape(8, 11) / agent_scale).reshape(-1)

        # Divisor for entity channels 0..5: count, level, damage, food, water,
        # health. Channel 6 (frozen) is left unscaled.
        entity_scale = np.array([MAX_DIS_NUM, MAX_LEVEL, MAX_DAMAGE,
                                 MAX_FOOD, MAX_WATER, MAX_HEALTH], dtype=np.float64)
        opponents[:, :, :6] = opponents[:, :, :6] / entity_scale
        npcs[:, :, :6] = npcs[:, :, :6] / entity_scale

        ##### clip between [0,1]
        agent_layers = np.clip(agent_layers, 0, 1)
        opponents = np.clip(opponents, 0, 1)
        npcs = np.clip(npcs, 0, 1)
        terrain = np.clip(terrain, 0, 1)
        agent_vectors = np.clip(agent_vectors, 0, 1)

        obs_layers = np.concatenate((agent_layers, opponents, npcs, terrain), axis=-1)

        return ( obs_layers, agent_vectors )
            
    
    def update_map(self,observations:dict[int,dict]):
        """Fold this tick's per-agent observations into the team's global map.

        Side effects:
            * ``self.last_global_information`` receives a snapshot of
              ``self.global_information`` before the latter is refreshed.
            * ``self.global_map`` is rebuilt (terrain and per-agent exploration
              are carried over from the previous call).
            * ``self.attack_objs`` becomes a list with one float32 array per
              agent; each row is
              [id, |d_row|, |d_col|, level, damage, food, water, health, frozen]
              for an attackable entity that agent currently sees.
            * ``self.last_action_informations`` is consumed for kill tracking
              and cleared.

        Args:
            observations: mapping from the agent's index inside the team (0..7)
                to its observation; agents without an entry are treated as not
                present this tick.

        Returns:
            The ``(layers, vectors)`` tuple produced by ``obs_normalization``.
        """
        first_step = not self.global_map

        ############# agent layer information
        agent_visible = np.zeros([129,129])
        agent_info = np.zeros([129,129,16])
        agents_dis = np.zeros([129,129])
        ############# agent vector information
        agent_vectors = np.zeros([88])
        ############# opponent / npc layer information
        opponents = np.zeros([129,129,7])
        npcs = np.zeros([129,129,7])
        ############# attackable entities, one list per agent
        self.attack_objs = []

        # Snapshot the previous statistics so deltas can be computed later.
        for k in self.global_information:
            self.last_global_information[k] = self.global_information[k]

        # An entity seen by several teammates must be accumulated into the map
        # layers only once, otherwise counts and averages are inflated.
        counted_ids = set()
        opponent_cells = set()
        npc_cells = set()

        for agent_index in range(8):
            attack_obj = []
            agent_obs = observations.get(agent_index)
            if agent_obs is not None:
                obs_en = agent_obs['Entity']['Continuous']
                # The first entity row is read as the observing agent itself.
                team_id = int( obs_en[0][4] )
                agent_row_index = int( obs_en[0][5] - self.LOCAL_BIAS )
                agent_column_index = int( obs_en[0][6] - self.LOCAL_BIAS )
                # Ids of every attackable entity this agent sees right now.
                visible_ids = set()

                for index in range(min(100, len(obs_en))):
                    entity = obs_en[index]
                    if int(entity[0]) == 0:
                        # A zero in the first column marks the end of the valid rows.
                        break
                    entity_id = int( entity[1] )
                    level = int( entity[3] )
                    pop = int( entity[4] )
                    row_index = int( entity[5] - self.LOCAL_BIAS )
                    column_index = int( entity[6] - self.LOCAL_BIAS )
                    damage = int( entity[7] )     # damage received
                    timealive = int( entity[8] )
                    food = int( entity[9] )
                    water = int( entity[10] )
                    health = int( entity[11] )
                    frozen = int( entity[12] )

                    if index == 0:
                        # 15x15 window around the agent, clipped to the map.
                        rows = slice(max(row_index-7, 0), min(row_index+8, 129))
                        cols = slice(max(column_index-7, 0), min(column_index+8, 129))
                        agent_visible[rows, cols] = 1

                        # Channel 2k: agent k's location. Channel 2k+1: every
                        # cell agent k has seen so far.
                        agent_info[row_index][column_index][agent_index*2+0] = 1
                        explored_channel = agent_index*2+1
                        if not first_step:
                            agent_info[:, :, explored_channel] = \
                                self.global_map["agents"]["layers"][:, :, explored_channel]
                        # Only this agent's own window is added; the shared
                        # team visibility must not leak into its exploration.
                        agent_info[rows, cols, explored_channel] = 1

                        agents_dis[row_index][column_index] = 1

                        base = agent_index*11
                        agent_vectors[base+0] = level
                        agent_vectors[base+1] = health
                        agent_vectors[base+2] = food
                        agent_vectors[base+3] = water
                        agent_vectors[base+4] = damage
                        agent_vectors[base+5] = frozen

                        # Hunting / fishing levels are tracked as the largest
                        # food / water value observed so far.
                        self.global_information["hunting_level"] = max(food, self.global_information["hunting_level"])
                        agent_vectors[base+6] = self.global_information["hunting_level"]
                        self.global_information["fishing_level"] = max(water, self.global_information["fishing_level"])
                        agent_vectors[base+7] = self.global_information["fishing_level"]

                        self.n_tick = timealive
                        continue

                    if pop >= 0 and pop != team_id:
                        # Player from another team (>= 0 so that team 0 also
                        # counts as an opponent when this team is not team 0).
                        layer, cells = opponents, opponent_cells
                    elif pop < 0:
                        layer, cells = npcs, npc_cells
                    else:
                        continue  # teammate

                    visible_ids.add(entity_id)
                    if entity_id not in counted_ids:
                        counted_ids.add(entity_id)
                        # Channels: count, level, damage, food, water, health, frozen.
                        layer[row_index, column_index] += (1, level, damage, food, water, health, frozen)
                        cells.add((row_index, column_index))

                    attack_obj.append([entity_id,
                                       abs(row_index-agent_row_index),
                                       abs(column_index-agent_column_index),
                                       level,
                                       damage,
                                       food,
                                       water,
                                       health,
                                       frozen])

                # Kill heuristic: the entity attacked last tick is no longer
                # visible. Positive ids are treated as players, negative as NPCs.
                last_target, last_level = self.last_action_informations[agent_index]
                if last_target != 0 and last_target not in visible_ids:
                    if last_target > 0:
                        self.global_information["kill_opponent_num"] += 1
                    else:
                        self.global_information["kill_npc_num"] += 1
                        self.global_information["kill_npc_level"] = max(
                            self.global_information["kill_npc_level"], last_level)

            # The last action has been accounted for (or the agent is absent);
            # clear it so the same target is never counted twice.
            self.last_action_informations[agent_index] = [0, 0]

            agent_vectors[agent_index*11+8] = self.global_information["kill_npc_level"]
            agent_vectors[agent_index*11+9] = self.global_information["kill_npc_num"]
            agent_vectors[agent_index*11+10] = self.global_information["kill_opponent_num"]

            self.attack_objs.append(np.array(attack_obj,dtype=np.float32))

        # Turn per-cell sums into truncated per-entity averages, once per cell.
        for layer, cells in ((opponents, opponent_cells), (npcs, npc_cells)):
            for row, col in cells:
                layer[row, col, 1:] = np.trunc(layer[row, col, 1:] / layer[row, col, 0])

        terrain = self._update_terrain(observations, first_step)

        ######## store the raw (un-normalised) map
        self.global_map['terrain'] = terrain
        self.global_map['oppos'] = opponents
        self.global_map['npcs'] = npcs

        agent_visible_expanded = np.expand_dims(agent_visible,axis=-1)
        agents_dis_expanded = np.expand_dims(agents_dis,axis=-1)
        agent_layers = np.concatenate( (agent_info,agent_visible_expanded,agents_dis_expanded),axis=-1 )
        self.global_map['agents'] = {"layers":agent_layers, "vectors":agent_vectors}

        # Normalise copies so the stored map keeps its raw values.
        obs = self.obs_normalization(agent_layers.copy(),
                                     agent_vectors.copy(),
                                     opponents.copy(),
                                     npcs.copy(),
                                     terrain.copy())

        return obs

    def _update_terrain(self, observations, first_step):
        """Return the terrain memory updated with every tile currently in view.

        Channel 0 marks explored tiles and channels 1..5 are set by tile type;
        ``show_map`` labels channels 6 and 7 as the "previous" scrub / forest
        flags.
        """
        if first_step:
            terrain = np.zeros([129,129,8])
        else:
            terrain = self.global_map['terrain']
            # Remember what was known before this tick.
            terrain[terrain[:, :, 3] == 1, 6] = 1
            terrain[terrain[:, :, 4] == 1, 7] = 1

        for agent_index in range(8):
            agent_obs = observations.get(agent_index)
            if agent_obs is None:
                continue
            obs_til = agent_obs['Tile']['Continuous']
            for index in range(min(225, len(obs_til))):
                til_type = int( obs_til[index][1] )
                til_row = int( obs_til[index][2] - self.LOCAL_BIAS )
                til_column = int( obs_til[index][3] - self.LOCAL_BIAS )
                if til_row < 0 or til_row > 128 or til_column < 0 or til_column > 128:
                    continue

                terrain[til_row][til_column][0] = 1
                terrain[til_row][til_column][til_type] = 1

                if til_type == 3 or til_type == 4:
                    # Fresh observation of a scrub/forest tile: drop both
                    # "previous" flags for it.
                    # NOTE(review): the other "latest" channel (3 vs 4) is never
                    # cleared, so a tile that changed type keeps both set --
                    # confirm whether that is intended.
                    terrain[til_row][til_column][6] = 0
                    terrain[til_row][til_column][7] = 0
        return terrain
    
    
    
    
    def show_map(self,characters,agent_attris,attris,agent_index=0,terrain_option=False,ter_attris=[0,1,2,3,4,5,6,7]):
        """Plot selected layers of ``self.global_map``, one figure per layer.

        Args:
            characters: any of ``"agents"``, ``"oppos"``, ``"npcs"``.
            agent_attris: per-agent layers to show for ``"agents"``
                (0 = location, 1 = exploration). The team-wide visibility and
                distribution layers are always shown as well.
            attris: channels (0..6) to show for ``"oppos"`` / ``"npcs"``.
            agent_index: which agent's layers to show; only used for ``"agents"``
                because the opponent and NPC layers are shared by the whole team.
            terrain_option: also plot terrain layers when True.
            ter_attris: terrain channels (0..7) to plot.
        """
        CHARAC = ['distribution','level','damage','food','water','health','is_frozen']
        AGENTS = ["location","exploration",'visible','agents distribution']
        TERRAIN = ["exploration","water","grass","scrub(latest)","forest(latest)",
                   "stone","scrub(previous)","forest(previous)"]

        # Collect (layer, title) pairs first, then draw them in order.
        panels = []
        for character in characters:
            if character == "agents":
                layers = self.global_map[character]["layers"]
                for attri in agent_attris:
                    panels.append((layers[:, :, agent_index*2+attri], character+" "+AGENTS[attri]))
                # The last two channels are team-wide: visibility and distribution.
                panels.append((layers[:, :, -2], AGENTS[-2]))
                panels.append((layers[:, :, -1], AGENTS[-1]))
            else:
                layers = self.global_map[character]
                for attri in attris:
                    # These arrays only have the 7 CHARAC channels, so the
                    # channel is indexed directly (no per-agent offset).
                    panels.append((layers[:, :, attri], character+" "+CHARAC[attri]))

        if terrain_option:
            terrains = self.global_map['terrain']
            for ter_attri in ter_attris:
                panels.append((terrains[:, :, ter_attri], "terrain "+TERRAIN[ter_attri]))

        for figure_id, (layer, title) in enumerate(panels):
            plt.figure(figure_id)
            plt.imshow(layer, cmap=plt.cm.cool)
            plt.colorbar()
            plt.title(title)
            plt.show()
                    
                    
    def act(self, actions):
        """Translate the flat policy output into per-agent game actions.

        The vector is split into one row per agent. Each row starts with
        ``[x, y]`` (movement intent) followed by scoring weights that are dotted
        with the features of every attackable entity the agent currently sees
        (``self.attack_objs``, filled by ``update_map``). Candidates are tried
        from highest to lowest score and the first one within range is attacked.

        Supported row layouts (weights after ``x, y``):
            * 8 weights: |d_row|, |d_col|, level, damage, food, water, health, frozen
            * 7 weights: the same without level
            * 6 weights: |d_row|, |d_col|, damage, food, health, frozen

        Args:
            actions: array-like with 8 * (2 + n_weights) values.

        Returns:
            ``{agent_index: action_dict}`` for the 8 agents. ``action.Move`` /
            ``action.Attack`` are only present when the agent actually moves /
            attacks.

        Raises:
            ValueError: if the row width matches none of the layouts above.
        """
        N_AGENTS = 8
        N_OBJ_FIELDS = 9  # id, |d_row|, |d_col|, level, damage, food, water, health, frozen
        FEATURES_BY_WEIGHT_COUNT = {
            8: [1, 2, 3, 4, 5, 6, 7, 8],
            7: [1, 2, 4, 5, 6, 7, 8],
            6: [1, 2, 4, 5, 7, 8],
        }

        policy_out = np.asarray(actions, dtype=np.float32).reshape((N_AGENTS, -1))
        feature_columns = FEATURES_BY_WEIGHT_COUNT.get(policy_out.shape[1] - 2)
        if feature_columns is None:
            raise ValueError(
                "unsupported action width %d per agent; expected 8, 9 or 10" % policy_out.shape[1])

        # ``update_map`` may not have run yet; then nobody has a target.
        attack_objs = getattr(self, "attack_objs", [])

        team_actions = {}
        for i, a in enumerate(policy_out):
            agent_action = {}

            # ---- movement: follow the dominant axis, stay put for tiny vectors.
            # NOTE(review): indices assume action.Direction is ordered
            # [North, South, East, West] (0-based); the former 1..4 codes with
            # 0 meaning "stay" are shifted down by one -- confirm against nmmo.
            x, y = float(a[0]), float(a[1])
            direction = None
            if np.hypot(x, y) >= 0.1:
                if abs(x) > abs(y):
                    direction = 2 if x > 0 else 3
                elif abs(x) < abs(y):
                    direction = 0 if y > 0 else 1
            if direction is not None:
                agent_action[action.Move] = {action.Direction: direction}

            # ---- attack: forget last tick's target first so that a stale
            # target is never mistaken for a kill later on.
            self.last_action_informations[i] = [0, 0]
            objs = attack_objs[i] if i < len(attack_objs) else ()
            objs = np.asarray(objs, dtype=np.float32).reshape(-1, N_OBJ_FIELDS)
            if len(objs):
                scores = objs[:, feature_columns] @ a[2:]
                # NOTE(review): the distance is measured between the entity's
                # offset and the agent's (x, y) output, as originally written --
                # confirm this is the intended range check.
                dists = np.sqrt(np.sum(np.square(objs[:, 1:3] - a[:2]), axis=1))
                for j in np.argsort(-scores, kind="stable"):
                    dist = dists[j]
                    if dist > 4:
                        continue
                    # Style index by distance band: <=1 -> 0, <=3 -> 1, <=4 -> 2.
                    if dist <= 1:
                        style = 0
                    elif dist <= 3:
                        style = 1
                    else:
                        style = 2
                    target_id = int(objs[j, 0])
                    # NOTE(review): verify whether action.Target expects an
                    # entity id or an index into the observed entity list.
                    agent_action[action.Attack] = {action.Style: style,
                                                   action.Target: target_id}
                    # Remember [id, level] of the target for kill tracking.
                    self.last_action_informations[i] = [target_id, int(objs[j, 3])]
                    break

            team_actions[i] = agent_action

        return team_actions
        


In [None]:
from ijcai2022nmmo import CompetitionConfig, scripted, TeamBasedEnv
import nmmo
import numpy as np
import matplotlib.pyplot as plt



# Smoke test: build the environment and one MyAwesomeTeam, then feed team 0's
# first observation through update_map and inspect the result.
Config = CompetitionConfig()
env = TeamBasedEnv(Config)
myteam = MyAwesomeTeam('myteam',Config)
myteam.reset()
obs = env.reset()
# Observations are keyed by team index; 0 is the team being inspected.
obs_team = obs[0]
new_obs = myteam.update_map(obs_team)
print(new_obs)
# print(new_obs[0].shape)
# print(new_obs[1].shape)

# Plot the agent, opponent, NPC and terrain layers stored by update_map.
myteam.show_map(["agents","oppos","npcs"],[0,1],[0,1,2,3,4,5,6],agent_index=0,terrain_option=True,ter_attris=[0,1,2,3,4,5,6,7])

print( len(myteam.attack_objs) )  # attack_objs -> list[ np.ndarray(float32) ], one entry per agent (8 agents)

In [None]:
# Inspect the stored terrain channels for map row 2 (first axis of the terrain array).
myteam.global_map["terrain"][2]

In [None]:
import gym
from gym import spaces
import numpy as np

# Scratch check of how a tuple of Box spaces prints. The list holds three
# references to the same Box object, not three independent spaces.
observe = [spaces.Box(low=-1.0,high=1.0,shape=(20,20),dtype=np.float32)]*3
print(tuple(observe))


In [None]:
import ray 
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print
import tqdm
import json5
import ray.tune as tune 
import gym
import gym.spaces as spaces
import numpy as np
from ray.tune.registry import register_env

######### load config 
# Parse config.json with json5; the "algo_config" entry is printed below and
# later passed to Tune as the trainer configuration.
with open("config.json") as json_file:
    all_config = json5.load(json_file)
    
# all_config["lr"] = tune.grid_search([0.001,0.005])
# all_config["env"] = "MyTrainEnv"
    
# Show the algorithm settings that were loaded.
print(pretty_print(all_config["algo_config"]))


In [None]:
##### register env    
def trainenv_creator(env_config):
    """Factory handed to RLlib: build a MyTrainEnv from the trainer's ``env_config``."""
    return MyTrainEnv(env_config)

register_env("MyTrainEnv",trainenv_creator)
# trainer = ppo.PPOTrainer(env="MyTrainEnv",config=all_config["algo_config"] )# config to pass to env class

# Only the dict passed as ``config`` reaches Tune, so the environment name and
# the learning-rate grid search have to live in that dict. They used to be set
# on ``all_config`` itself and were silently ignored. A shallow copy keeps the
# loaded configuration untouched, so this cell can be re-run safely.
tune_config = dict(all_config["algo_config"])
tune_config["env"] = "MyTrainEnv"
tune_config["lr"] = tune.grid_search([0.001,0.0005])

tune.run('PPO',
         config=tune_config,
         stop={"timesteps_total":40000}
        )

In [None]:
import matplotlib.pyplot as plt
# Cumulative experience curve: exp[0] is 0 and each further entry adds the
# increment computed for levels 2..99; the running totals are floored.
increments = [
    np.floor(level - 1 + 300 * (2 ** ((level - 1) / 7.0))) / 4.0
    for level in range(2, 99 + 1)
]
exp = np.floor(np.concatenate(([0.0], np.cumsum(increments))))
plt.plot(exp)
plt.show()
plt.show()