In [1]:
import numpy as np

In [2]:
class Agent:
    """A single prey or predator that moves on a 2-D screen.

    The agent stores its position as a 2-element ``float32`` array
    ``[x, y]``. ``health``, ``isHit``, ``move``, ``previous_position``,
    ``same_position``, ``current_step`` and ``action`` are initialised in
    ``__init__`` but are not updated by any method of this class.
    """

    def __init__(self, agent_name, agent_index):
        """Create an agent that has not been placed yet.

        Args:
            agent_name: label stored in ``self.agent`` (e.g. 'prey').
            agent_index: identifier stored in ``self.index``.
        """
        self.index = agent_index
        self.agent = agent_name
        self.health = None
        self.isHit = False
        self.move = True
        # distance covered along one axis by a single movement action
        self.movement_speed = 1.00
        self.previous_position = np.array([0, 0], dtype=np.float32)
        # stays None until agent_reset() is called
        self.current_position = None
        self.same_position = False
        self.current_step = 0
        self.action = None

    def agent_action(self, action):
        """Placeholder; currently does nothing with ``action``."""
        pass

    def agent_reset(self, width, height):
        """Place the agent at a uniformly random position inside the screen.

        A fixed padding keeps the spawn point away from every edge: x is
        drawn from ``[padding, width - padding)`` and y from
        ``[padding, height - padding)``.

        Args:
            width: horizontal extent used for the x coordinate.
            height: vertical extent used for the y coordinate.
        """
        padding = 30
        spawn_x = np.random.uniform(padding, width - padding)
        # y must be bounded by the height, not the width, otherwise agents
        # can spawn outside a non-square screen
        spawn_y = np.random.uniform(padding, height - padding)
        self.current_position = np.array([spawn_x, spawn_y], dtype=np.float32)

    def step_update(self, action, range_x, range_y):
        """Move one step according to ``action`` and clamp to the play area.

        Actions: 0 decreases x, 1 increases x, 2 decreases y, 3 increases y,
        each by ``movement_speed``. Any other value leaves the position
        unchanged (it is still clamped).

        Args:
            action: movement code as described above.
            range_x: largest allowed x; x is clipped to ``[0, range_x]``.
            range_y: largest allowed y; y is clipped to ``[0, range_y]``.
        """
        if action == 0:
            self.current_position[0] -= self.movement_speed
        elif action == 1:
            self.current_position[0] += self.movement_speed
        elif action == 2:
            self.current_position[1] -= self.movement_speed
        elif action == 3:
            self.current_position[1] += self.movement_speed

        self.current_position[0] = np.clip(self.current_position[0], 0, range_x)
        self.current_position[1] = np.clip(self.current_position[1], 0, range_y)

In [3]:

from gymnasium.spaces import Discrete, Box, MultiDiscrete
from gymnasium import Env
import numpy as np
import pygame

In [4]:
class GameEnv(Env):
    """Multi-agent prey/predator environment drawn in a pygame window.

    Usage implied by the methods below: optionally call
    ``set_agent_number``, then ``agent_init`` to build the agents, then
    ``reset`` to place them, and finally ``step`` repeatedly.

    NOTE(review): ``observation_space`` is declared as a 4-element ``Box``
    while ``reset`` and ``step`` return per-agent lists; confirm the
    intended observation format before training against this environment.
    """

    def __init__(self, screen_width=400, screen_height=400, render_mode='human'):
        """Set up spaces, agent counts and the pygame window.

        Args:
            screen_width: window width in pixels.
            screen_height: window height in pixels.
            render_mode: ``render`` only draws when this is ``'human'``.
        """
        super(GameEnv, self).__init__()

        # defining the screen dimension for render purpose
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.render_mode = render_mode

        # defining the observation space
        self.observation_space = Box(
            low=np.array([0, 0, 0, 0], dtype=np.float32),
            high=np.array(
                [self.screen_width, self.screen_height, self.screen_width, self.screen_height],
                dtype=np.float32),
            dtype=np.float32)

        # setting the total number of agents; the agent lists stay empty
        # until agent_init() is called from outside
        self.number_of_prey = 1
        self.number_of_predator = 1
        self.prey_agents = []
        self.predator_agents = []
        # total is prey + predators (previously prey was counted twice)
        self.number_of_agents = self.number_of_prey + self.number_of_predator

        # defining the action space based on total number of predator and prey
        self.action_space = Discrete(10 * self.number_of_agents)

        # setting the total number of obstacles
        self.total_obstacles = None

        # keeping a counter to save the total steps
        self.total_steps = 0

        # initializing the pygame
        # TODO: the window is created here regardless of render_mode; the
        # original note says it should be initialized in the render function
        pygame.init()

        # setting the screen size
        self.screen = pygame.display.set_mode((self.screen_width, self.screen_height))
        pygame.display.set_caption('Multi Agent Environment(simple)')
        self.clock = pygame.time.Clock()

        # initializing the font
        pygame.font.init()
        self.font = pygame.font.Font(None, 36)

    # this function returns the value of the action as a 2-digit string
    # if the action_space.sample() gives a 1-digit number
    # * if the number is 3 it will return '03'
    # * if the number is 14 then it will return '14'
    def expand_action_digit(self, action):
        """Return ``action`` as a string, zero-padded when it is below 10."""
        if action < 10:
            return f'0{action}'
        # the original computed str(action) here but never returned it
        return str(action)

    # this method will initialize the number of agents
    # ! this must be called from outside
    def agent_init(self):
        """Create fresh ``Agent`` objects for the configured counts.

        Replaces ``self.prey_agents`` and ``self.predator_agents`` with
        ``number_of_prey`` / ``number_of_predator`` new agents, indexed
        from 0 within each group.
        """
        self.prey_agents = [Agent('prey', i) for i in range(self.number_of_prey)]
        self.predator_agents = [Agent('predator', i) for i in range(self.number_of_predator)]

    # this function is used to explicitly set the number of agents
    # ! this needs to be called from outside
    def set_agent_number(self, prey_number, predator_number):
        """Set how many prey and predators ``agent_init`` will create.

        Also refreshes ``number_of_agents`` and the action space, which are
        derived from the two counts and would otherwise keep the values
        computed in ``__init__``.
        """
        self.number_of_predator = predator_number
        self.number_of_prey = prey_number
        self.number_of_agents = self.number_of_prey + self.number_of_predator

        # Discrete needs a positive size, so only rebuild when there is
        # at least one agent
        if self.number_of_agents > 0:
            self.action_space = Discrete(10 * self.number_of_agents)

    # the usual reset function
    def reset(self, seed=0):
        """Re-place every agent at a random position and zero the step count.

        Args:
            seed: returned unchanged as the second element; it is not used
                to seed any random generator here.

        Returns:
            ``(observation, seed)`` where ``observation`` is a list with one
            ``[index, name, position]`` entry per agent, prey first.
        """
        self.total_steps = 0
        observation = []

        for agent in self.prey_agents + self.predator_agents:
            agent.agent_reset(width=self.screen_width, height=self.screen_height)
            # copy so later in-place moves do not alter this observation
            observation.append([agent.index, agent.agent, agent.current_position.copy()])

        return observation, seed

    def _apply_actions(self, agents, actions, range_x, range_y):
        """Move each agent by its paired action and return their observations."""
        observation = []
        for agent, agent_action in zip(agents, actions):
            agent.step_update(action=agent_action, range_x=range_x, range_y=range_y)
            # copy so later in-place moves do not alter this observation
            observation.append({'index': agent.index,
                                'name': agent.agent,
                                'position': agent.current_position.copy()})
        return observation

    # the step function
    # this function is called for every timestep
    # this function updates the actions or states of agents in the env
    def step(self, action):
        """Advance the environment by one timestep.

        Args:
            action: a pair ``(prey_actions, predator_actions)``; each is a
                sequence of per-agent action codes, matched to the agents
                in order.

        Returns:
            ``(observation, reward, done, truncated, info)``. ``observation``
            is a list of ``{'index', 'name', 'position'}`` dicts, prey first.
            ``reward`` is always ``0.0`` and ``truncated`` always ``False``;
            ``done`` becomes ``True`` only when the window is closed.
        """
        reward = 0.00
        truncated = False
        info = {}

        # the event queue is drained in full; only a QUIT event matters here
        window_closed = False
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                window_closed = True

        prey_actions, predator_actions = action

        # upper clamp for agent positions along each axis
        max_x = self.screen_width - 10
        max_y = self.screen_height - 10

        observation = self._apply_actions(self.prey_agents, prey_actions, max_x, max_y)
        observation += self._apply_actions(self.predator_agents, predator_actions, max_x, max_y)

        self.total_steps += 1

        if window_closed:
            # shut pygame down and skip drawing: rendering after quit()
            # would operate on a display that no longer exists
            pygame.quit()
        else:
            self.render()

        return observation, reward, window_closed, truncated, info

    def render(self):
        """Draw all agents when ``render_mode`` is ``'human'``.

        Prey are blue circles and predators red circles, radius 10, on a
        white background. Does nothing for any other render mode.
        """
        if self.render_mode != 'human':
            return

        screen = self.screen
        agent_radius = 10

        # clear screen
        screen.fill((255, 255, 255))

        for prey in self.prey_agents:
            pos_x, pos_y = prey.current_position
            pygame.draw.circle(screen, (0, 0, 255), (int(pos_x), int(pos_y)), agent_radius)

        for predator in self.predator_agents:
            pos_x, pos_y = predator.current_position
            pygame.draw.circle(screen, (255, 0, 0), (int(pos_x), int(pos_y)), agent_radius)

        pygame.display.update()

    def close(self):
        """Shut pygame down."""
        pygame.quit()


In [5]:
# Build the environment with its default 400x400 screen; the constructor
# initialises pygame and opens the display window.
env = GameEnv()


In [None]:
# Shut pygame down through GameEnv.close().
# NOTE(review): this cell sits before the simulation-loop cell; if the cells
# are run top to bottom, the loop would call env.step() after pygame has
# already quit. Confirm this cell is meant to be run only after the loop.
env.close()

In [None]:

# Configure the population, create the agents and place them on the screen.
done = False
number_of_prey = 2
number_of_predator = 3

env.set_agent_number(prey_number=number_of_prey, predator_number=number_of_predator)
env.agent_init()
env.reset()

# Drive the environment with random actions until step() reports done.
while not done:
    # one sampled action per agent: all prey first, then all predators
    prey_action = [env.action_space.sample() for _ in range(number_of_prey)]
    predator_action = [env.action_space.sample() for _ in range(number_of_predator)]
    action = [prey_action, predator_action]

    obs, reward, done, _, _ = env.step(action)
    print(obs)
