In [12]:
# training_code_ppo.py

import os
import random
import numpy as np
from collections import deque
from scipy.spatial import cKDTree
import csv

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

"""
Overview:

PPO Agent architecture:
input: agent x, agent y, agent sugar, message x, message y,
relu activation
"""

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

params = {
            'max_sugar': 5,
            'growth_rate': 1,
            'sugar_peak_frequency': 0.05,
            'sugar_peak_spread': 4,
            'job_center_duration': (20, 50),
            'vision_range': 1,
            'message_expiry': 15,
            'max_messages': 5,
        }

class PPOAgent(nn.Module):
    def __init__(self, state_size, action_size):
        super(PPOAgent, self).__init__()
        # Common network layers
        self.fc1 = nn.Linear(state_size, 128)
        self.relu1 = nn.ReLU()

        # Policy network
        self.fc_policy = nn.Linear(128, 64)
        self.relu_policy = nn.ReLU()
        self.policy_head = nn.Linear(64, action_size)
        self.softmax = nn.Softmax(dim=-1)

        # Value network
        self.fc_value = nn.Linear(128, 64)
        self.relu_value = nn.ReLU()
        self.value_head = nn.Linear(64, 1)

        # Initialize weights
        for layer in [self.fc1, self.fc_policy, self.fc_value]:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        nn.init.xavier_uniform_(self.policy_head.weight)
        nn.init.zeros_(self.policy_head.bias)
        nn.init.xavier_uniform_(self.value_head.weight)
        nn.init.zeros_(self.value_head.bias)

        self.to(device)

    def forward(self, x):
        x = self.relu1(self.fc1(x))

        # Policy network forward
        policy_x = self.relu_policy(self.fc_policy(x))
        action_logits = self.policy_head(policy_x)

        # Value network forward
        value_x = self.relu_value(self.fc_value(x))
        state_value = self.value_head(value_x)

        return action_logits, state_value

class SugarscapeEnvironmentPPO:
    def __init__(self, width, height, num_agents, seed=None, num_episodes=200, params = None):
        self.width = width
        self.height = height
        self.num_agents = num_agents
        self.seed = seed

        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)

        self.params = params
        self.num_episodes = num_episodes
        self.state_size = 5 + (2 * self.params['vision_range'] + 1) ** 2 + (3 * self.params['max_messages'])
        self.action_size = 5  # Up, Down, Left, Right, Stay
        self.agent = PPOAgent(self.state_size, self.action_size).to(device)
        self.optimizer = optim.Adam(self.agent.parameters(), lr=1e-4)
        self.gamma = 0.99
        self.eps_clip = 0.2  # PPO clipping parameter
        self.K_epochs = 4  # Number of epochs for updating policy
        self.batch_size = 64

        # Storage for PPO
        self.memory = []

    def create_initial_sugar_peaks(self, num_peaks=2):
        self.job_centers = []
        for _ in range(num_peaks):
            self.create_job_center()
        self.update_sugar_landscape()

    def create_job_center(self):
        x, y = np.random.randint(0, self.width), np.random.randint(0, self.height)
        duration = np.random.randint(*self.params['job_center_duration'])
        self.job_centers.append({
            'x': x, 'y': y,
            'duration': duration,
            'max_sugar': self.params['max_sugar']
        })

    def update_sugar_landscape(self):
        self.sugar = np.zeros((self.height, self.width))
        for center in self.job_centers:
            x_grid, y_grid = np.meshgrid(np.arange(self.width), np.arange(self.height))
            distance = np.sqrt((x_grid - center['x']) ** 2 + (y_grid - center['y']) ** 2)
            sugar_level = center['max_sugar'] * np.exp(-distance ** 2 / (2 * self.params['sugar_peak_spread'] ** 2))
            self.sugar += sugar_level
        self.sugar = np.clip(self.sugar, 0, self.params['max_sugar'])
        self.sugar = np.round(self.sugar).astype(int)

    def initialize_agents(self):
        agents = []
        available_positions = set((x, y) for x in range(self.width) for y in range(self.height))
        for i in range(self.num_agents):
            if not available_positions:
                break
            x, y = available_positions.pop()
            agents.append(self.create_agent(i, x, y))
        return agents

    def create_agent(self, id, x, y):
        return {
            'id': id,
            'x': x,
            'y': y,
            'sugar': np.random.randint(20, 50),
            'metabolism': np.random.randint(1, 3),
            'vision': self.params['vision_range'],
            'messages': deque(maxlen=self.params['max_messages']),
            'destination': None,
            'memory': [],
            'path': [],
            'age': 0
        }

    def reset_environment(self):
        self.timestep = 0
        self.job_centers = []
        self.sugar = np.zeros((self.height, self.width), dtype=int)
        self.create_initial_sugar_peaks()
        self.agents = self.initialize_agents()
        self.agent_positions = set((agent['x'], agent['y']) for agent in self.agents)
        self.dead_agents = []
        self.memory = []

    def run_training(self, total_episodes=None, max_timesteps=1000):
        # Create checkpoints directory if it doesn't exist
        os.makedirs('checkpoints', exist_ok=True)

        # Open CSV file for writing
        with open('training_log.csv', mode='w', newline='') as csv_file:
            fieldnames = ['Episode', 'TotalReward', 'FinalAgentPopulation']
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            writer.writeheader()

            if total_episodes is None:
                total_episodes = self.num_episodes

            for episode in range(total_episodes):
                self.reset_environment()
                total_reward = 0  # Initialize total reward for the episode

                for t in range(max_timesteps):
                    timestep_reward = self.step()  # Modify step() to return total rewards
                    total_reward += timestep_reward

                    # Update policy every batch_size steps
                    if len(self.memory) >= self.batch_size:
                        self.train_ppo()
                        self.memory = []

                    self.timestep += 1  # Increment timestep here

                # Log data for the episode
                final_agent_population = len(self.agents)
                log_entry = {
                    'Episode': episode + 1,
                    'TotalReward': total_reward,
                    'FinalAgentPopulation': final_agent_population
                }
                writer.writerow(log_entry)

                print(f"Episode {episode + 1} completed. Total Reward: {total_reward:.2f}")

                if (episode + 1) % 10 == 0:
                    print(f"Completed Episode: {episode + 1}")
                    # Save model every 100 episodes into checkpoints folder
                    torch.save(self.agent.state_dict(), f'checkpoints/ppo_agent_episode_{episode + 1}.pth')

        print("\nTraining completed and log saved to 'training_log.csv'.")

    def get_state(self, agent):
        x, y = agent['x'], agent['y']
        sugar = agent['sugar'] / 100  # Normalize sugar level
        metabolism = agent['metabolism'] / 5  # Normalize metabolism
        vision = agent['vision'] / 5  # Normalize vision

        # Extract sugar levels within vision range
        vision_range = agent['vision']
        y_min = max(0, y - vision_range)
        y_max = min(self.height, y + vision_range + 1)
        x_min = max(0, x - vision_range)
        x_max = min(self.width, x + vision_range + 1)
        sugar_map = self.sugar[y_min:y_max, x_min:x_max]

        # Pad the sugar map to a fixed size
        expected_size = (2 * vision_range + 1, 2 * vision_range + 1)
        pad_h = expected_size[0] - sugar_map.shape[0]
        pad_w = expected_size[1] - sugar_map.shape[1]

        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        padded_sugar_map = np.pad(sugar_map, ((pad_top, pad_bottom), (pad_left, pad_right)), mode='constant', constant_values=0)
        sugar_map_flat = padded_sugar_map.flatten() / self.params['max_sugar']

        # Encode messages
        N = self.params['max_messages']
        messages = list(agent['messages'])[-N:]
        message_features = []
        for msg in messages:
            # Normalize message coordinates relative to grid size
            msg_x = msg['x'] / self.width
            msg_y = msg['y'] / self.height
            msg_sugar = msg['sugar_amount'] / self.params['max_sugar']
            message_features.extend([msg_x, msg_y, msg_sugar])
        # Pad remaining messages with zeros if fewer than N
        while len(message_features) < 3 * N:
            message_features.extend([0.0, 0.0, 0.0])

        state = np.concatenate(([x / self.width, y / self.height, sugar, metabolism, vision], sugar_map_flat, message_features))
        return state

    def select_action(self, state, valid_actions):
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action_logits, _ = self.agent(state_tensor)
        action_probs = torch.softmax(action_logits, dim=-1).squeeze().cpu().numpy()

        # Check for NaNs
        if np.isnan(action_probs).any():
            print("Error: action_probs contains NaN values.")
            action_probs = np.ones(self.action_size) / self.action_size

    # Mask invalid actions
    mask = np.zeros(self.action_size, dtype=bool)
    mask[valid_actions] = True
    masked_probs = action_probs * mask

    total_prob = masked_probs.sum()
    if total_prob == 0:
        print("Warning: Sum of masked_probs is zero. Assigning equal probabilities to valid actions.")
        masked_probs = mask.astype(float)
        masked_probs /= masked_probs.sum()
    else:
        masked_probs /= total_prob

    action = np.random.choice(self.action_size, p=masked_probs)
    return action, action_probs


    def get_valid_actions(self, agent):
        actions = []
        x, y = agent['x'], agent['y']
        possible_moves = {
            0: (x, y - 1),  # Up
            1: (x, y + 1),  # Down
            2: (x - 1, y),  # Left
            3: (x + 1, y),  # Right
            4: (x, y)       # Stay
        }
        for action, (nx, ny) in possible_moves.items():
            if 0 <= nx < self.width and 0 <= ny < self.height:
                if (nx, ny) not in self.agent_positions or (nx, ny) == (x, y):
                    actions.append(action)
        return actions

    def move_agent(self, agent, action):
        x, y = agent['x'], agent['y']
        possible_moves = {
            0: (x, y - 1),  # Up
            1: (x, y + 1),  # Down
            2: (x - 1, y),  # Left
            3: (x + 1, y),  # Right
            4: (x, y)       # Stay
        }
        nx, ny = possible_moves[action]
        if (0 <= nx < self.width and 0 <= ny < self.height and
                ((nx, ny) not in self.agent_positions or (nx, ny) == (x, y))):
            self.agent_positions.remove((x, y))
            agent['x'], agent['y'] = nx, ny
            agent['path'].append((agent['x'], agent['y']))
            self.agent_positions.add((nx, ny))

    def collect_sugar_and_update_agent(self, agent):
        collected_sugar = self.sugar[agent['y'], agent['x']]
        agent['sugar'] += collected_sugar
        self.sugar[agent['y'], agent['x']] = 0
        agent['sugar'] -= agent['metabolism']
        agent['age'] += 1

    def broadcast_messages(self):
        if not self.agents:
            return  # No agents to broadcast

        positions = np.array([[agent['x'], agent['y']] for agent in self.agents])
        tree = cKDTree(positions)

        for i, agent in enumerate(self.agents):
            # Identify visible sugar peaks
            visible_sugar = self.get_visible_sugar(agent)
            sugar_locations = np.argwhere(visible_sugar > 0)
            messages = []
            for loc in sugar_locations:
                msg_x = agent['x'] + loc[1] - agent['vision']
                msg_y = agent['y'] + loc[0] - agent['vision']
                # Ensure message coordinates are within grid
                msg_x = int(np.clip(msg_x, 0, self.width - 1))
                msg_y = int(np.clip(msg_y, 0, self.height - 1))
                msg = {
                    'sender_id': agent['id'],
                    'timestep': self.timestep,
                    'sugar_amount': self.sugar[msg_y, msg_x],
                    'x': msg_x,
                    'y': msg_y
                }
                messages.append(msg)

            # Broadcast to neighbors within broadcast_radius
            radius = 5  # Fixed broadcast radius
            neighbors = tree.query_ball_point([agent['x'], agent['y']], radius)
            for neighbor_idx in neighbors:
                if neighbor_idx != i:
                    for msg in messages:
                        self.agents[neighbor_idx]['messages'].append(msg)

    def get_visible_sugar(self, agent):
        x, y = agent['x'], agent['y']
        vision = agent['vision']
        y_min = max(0, y - vision)
        y_max = min(self.height, y + vision + 1)
        x_min = max(0, x - vision)
        x_max = min(self.width, x + vision + 1)
        visible_area = self.sugar[y_min:y_max, x_min:x_max]
        return visible_area

    def step(self):
        # Update job centers and sugar landscape
        for center in self.job_centers:
            center['duration'] -= 1
        self.job_centers = [center for center in self.job_centers if center['duration'] > 0]
        if np.random.random() < self.params['sugar_peak_frequency']:
            self.create_job_center()
        self.update_sugar_landscape()

        # Broadcast messages
        self.broadcast_messages()

        total_rewards = 0  # Initialize total rewards for this timestep

        # For each agent, select action and collect experience
        for agent in self.agents:
            state = self.get_state(agent)
            valid_actions = self.get_valid_actions(agent)
            if not valid_actions:
                continue  # Skip if no valid actions
            action, action_probs = self.select_action(state, valid_actions)
            prev_sugar = agent['sugar']
            self.move_agent(agent, action)
            self.collect_sugar_and_update_agent(agent)
            next_state = self.get_state(agent)
            reward = (agent['sugar'] - prev_sugar) / 10.0  # Normalize reward
            total_rewards += reward  # Accumulate total rewards
            done = agent['sugar'] <= 0

            # Store experience
            agent['memory'].append({
                'state': state,
                'action': action,
                'action_prob': action_probs[action],
                'reward': reward,
                'next_state': next_state,
                'done': done
            })

            # Add to global memory
            self.memory.append(agent['memory'][-1])

        # Handle agent death
        alive_agents = []
        for agent in self.agents:
            if agent['sugar'] <= 0:
                self.agent_positions.remove((agent['x'], agent['y']))
            else:
                alive_agents.append(agent)
        self.agents = alive_agents

        # Replenish agents
        self.replenish_agents()

        return total_rewards  # Return total rewards collected in this timestep

    def train_ppo(self):
        # Convert memory to tensors
        states = torch.FloatTensor(np.array([m['state'] for m in self.memory])).to(device)
        actions = torch.LongTensor(np.array([m['action'] for m in self.memory])).unsqueeze(1).to(device)
        old_action_probs = torch.FloatTensor(np.array([m['action_prob'] for m in self.memory])).unsqueeze(1).to(device)
        rewards = [m['reward'] for m in self.memory]
        dones = [m['done'] for m in self.memory]

        # Compute discounted rewards
        discounted_rewards = []
        R = 0
        for reward, done in zip(reversed(rewards), reversed(dones)):
            if done:
                R = 0
            R = reward + self.gamma * R
            discounted_rewards.insert(0, R)
        discounted_rewards = torch.FloatTensor(discounted_rewards).unsqueeze(1).to(device)

        # Normalize rewards
        discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-7)

        # Optimize policy for K epochs
        for _ in range(self.K_epochs):
            # Get action probabilities and state values
            action_probs, state_values = self.agent(states)
            action_probs = action_probs.gather(1, actions)
            state_values = state_values

            # Calculate ratios
            ratios = action_probs / old_action_probs

            # Calculate advantages
            advantages = discounted_rewards - state_values.detach()

            # Compute surrogate losses
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            # Policy loss
            policy_loss = -torch.min(surr1, surr2).mean()

            # Value loss
            value_loss = nn.MSELoss()(state_values, discounted_rewards)

            # Total loss
            loss = policy_loss + 0.5 * value_loss

            # Optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()


    def replenish_agents(self):
        while len(self.agents) < self.num_agents:
            x, y = random.randint(0, self.width - 1), random.randint(0, self.height - 1)
            if (x, y) not in self.agent_positions:
                agent_id = max([agent['id'] for agent in self.agents] + [0]) + 1
                new_agent = self.create_agent(agent_id, x, y)
                self.agent_positions.add((x, y))
                self.agents.append(new_agent)

# Initialize and run training
if __name__ == "__main__":
    seed = 42  # Set a seed for reproducibility
    TOTAL_EPISODES = 600
    MAX_TIMESTEPS = 500

    env = SugarscapeEnvironmentPPO(width=15, height=15, num_agents=100, seed=seed, num_episodes=TOTAL_EPISODES, params=params)
    env.run_training(total_episodes=TOTAL_EPISODES, max_timesteps=MAX_TIMESTEPS)

    # Save the final model
    torch.save(env.agent.state_dict(), 'ppo_agent_final.pth')

    print("\nTraining completed and model saved.")


AttributeError: 'SugarscapeEnvironmentPPO' object has no attribute 'step'

In [7]:
# simulation_with_pygame_ppo.py

import os
import sys
import random
import numpy as np
from collections import deque
from scipy.spatial import cKDTree
import re

import pygame
from pygame.locals import QUIT

import torch
import torch.nn as nn
import torch.optim as optim

# Set up device (ensure this matches your training code)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the PPOAgent class
class PPOAgent(nn.Module):
    def __init__(self, state_size, action_size):
        super(PPOAgent, self).__init__()
        # Common network layers
        self.fc1 = nn.Linear(state_size, 128)
        self.relu1 = nn.ReLU()

        # Policy network
        self.fc_policy = nn.Linear(128, 64)
        self.relu_policy = nn.ReLU()
        self.policy_head = nn.Linear(64, action_size)
        self.softmax = nn.Softmax(dim=-1)

        # Value network
        self.fc_value = nn.Linear(128, 64)
        self.relu_value = nn.ReLU()
        self.value_head = nn.Linear(64, 1)

        # Initialize weights
        for layer in [self.fc1, self.fc_policy, self.fc_value]:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        nn.init.xavier_uniform_(self.policy_head.weight)
        nn.init.zeros_(self.policy_head.bias)
        nn.init.xavier_uniform_(self.value_head.weight)
        nn.init.zeros_(self.value_head.bias)

        self.to(device)

    def forward(self, x):
        x = self.relu1(self.fc1(x))

        # Policy network forward
        policy_x = self.relu_policy(self.fc_policy(x))
        action_logits = self.policy_head(policy_x)
        action_probs = self.softmax(action_logits)

        # Value network forward
        value_x = self.relu_value(self.fc_value(x))
        state_value = self.value_head(value_x)

        return action_probs, state_value

# Define the SugarscapeEnvironmentPPO class
class SugarscapeEnvironmentPPO:
    def __init__(self, width, height, num_agents, seed=None, params = None):
        self.width = width
        self.height = height
        self.num_agents = num_agents
        self.seed = seed

        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)

        self.params = params

        self.state_size = 5 + (2 * self.params['vision_range'] + 1) ** 2 + (3 * self.params['max_messages'])
        self.action_size = 5  # Up, Down, Left, Right, Stay
        self.agent_model = PPOAgent(self.state_size, self.action_size).to(device)

        # Load the trained model
        model_path = self.get_latest_model_path()
        if model_path:
            print(f"Loading model from {model_path}")
            self.agent_model.load_state_dict(torch.load(model_path, map_location=device))
            self.agent_model.eval()
            self.agent_model.to(device)
        else:
            raise FileNotFoundError("No trained model found. Please train the model first.")

    def create_initial_sugar_peaks(self, num_peaks=2):
        self.job_centers = []
        for _ in range(num_peaks):
            self.create_job_center()
        self.update_sugar_landscape()

    def create_job_center(self):
        x, y = np.random.randint(0, self.width), np.random.randint(0, self.height)
        duration = np.random.randint(*self.params['job_center_duration'])
        self.job_centers.append({
            'x': x, 'y': y,
            'duration': duration,
            'max_sugar': self.params['max_sugar']
        })

    def update_sugar_landscape(self):
        self.sugar = np.zeros((self.height, self.width))
        for center in self.job_centers:
            x_grid, y_grid = np.meshgrid(np.arange(self.width), np.arange(self.height))
            distance = np.sqrt((x_grid - center['x']) ** 2 + (y_grid - center['y']) ** 2)
            sugar_level = center['max_sugar'] * np.exp(-distance ** 2 / (2 * self.params['sugar_peak_spread'] ** 2))
            self.sugar += sugar_level
        self.sugar = np.clip(self.sugar, 0, self.params['max_sugar'])
        self.sugar = np.round(self.sugar).astype(int)

    def initialize_agents(self):
        agents = []
        available_positions = set((x, y) for x in range(self.width) for y in range(self.height))
        for i in range(self.num_agents):
            if not available_positions:
                break
            x, y = available_positions.pop()
            agents.append(self.create_agent(i, x, y))
        return agents

    def create_agent(self, id, x, y):
        return {
            'id': id,
            'x': x,
            'y': y,
            'sugar': np.random.randint(20, 50),
            'metabolism': np.random.randint(1, 3),
            'vision': self.params['vision_range'],
            'messages': deque(maxlen=self.params['max_messages']),
            'destination': None,
            'memory': [],
            'path': [],
            'age': 0
        }

    def reset_environment(self):
        self.timestep = 0
        self.job_centers = []
        self.sugar = np.zeros((self.height, self.width), dtype=int)
        self.create_initial_sugar_peaks()
        self.agents = self.initialize_agents()
        self.agent_positions = set((agent['x'], agent['y']) for agent in self.agents)
        self.dead_agents = []

    def get_latest_model_path(self):
        # Check for final model
        final_model = 'ppo_agent_final.pth'
        if os.path.exists(final_model):
            return final_model
        # If final model doesn't exist, find the latest checkpoint
        checkpoint_pattern = r'ppo_agent_episode_(\d+)\.pth'
        checkpoints = [f for f in os.listdir('checkpoints') if re.match(checkpoint_pattern, f)]
        if not checkpoints:
            return None
        # Extract episode numbers and find the latest
        episodes = [int(re.findall(checkpoint_pattern, f)[0]) for f in checkpoints]
        latest_episode = max(episodes)
        latest_checkpoint = f'checkpoints/ppo_agent_episode_{latest_episode}.pth'
        return latest_checkpoint

    def run_simulation(self, max_timesteps=100):
        self.reset_environment()
        # Pygame initialization
        pygame.init()
        self.cell_size = 20  # Size of each grid cell in pixels
        self.screen_width = self.width * self.cell_size
        self.screen_height = self.height * self.cell_size
        self.screen = pygame.display.set_mode((self.screen_width, self.screen_height))
        pygame.display.set_caption("Sugarscape Simulation with PPO")
        self.clock = pygame.time.Clock()

        for t in range(max_timesteps):
            self.step_simulation()
            self.draw_environment()
            pygame.display.flip()
            self.clock.tick(10)  # Adjust the speed as needed

            # Handle Pygame events
            for event in pygame.event.get():
                if event.type == QUIT:
                    pygame.quit()
                    sys.exit()

        pygame.quit()

    def step_simulation(self):
        # Update job centers and sugar landscape
        for center in self.job_centers:
            center['duration'] -= 1
        self.job_centers = [center for center in self.job_centers if center['duration'] > 0]
        if np.random.random() < self.params['sugar_peak_frequency']:
            self.create_job_center()
        self.update_sugar_landscape()

        # Broadcast messages
        self.broadcast_messages()

        # For each agent, select action
        for agent in self.agents:
            state = self.get_state(agent)
            valid_actions = self.get_valid_actions(agent)
            if not valid_actions:
                continue  # Skip if no valid actions
            action = self.select_action(state, valid_actions)
            self.move_agent(agent, action)
            self.collect_sugar_and_update_agent(agent)

        # Handle agent death
        alive_agents = []
        for agent in self.agents:
            if agent['sugar'] <= 0:
                self.agent_positions.remove((agent['x'], agent['y']))
            else:
                alive_agents.append(agent)
        self.agents = alive_agents

        # Replenish agents
        self.replenish_agents_simulation()

    def replenish_agents_simulation(self):
        # Do not replenish agents during simulation to observe agent behavior
        pass

    def draw_environment(self):
        # Clear the screen
        self.screen.fill((255, 255, 255))  # White background
        for y in range(self.height):
            for x in range(self.width):
                sugar_amount = self.sugar[y, x]
                if sugar_amount > 0:
                    color_intensity = int(255 * sugar_amount / self.params['max_sugar'])
                    color = (255, 255 - color_intensity, 255 - color_intensity)
                    rect = pygame.Rect(x * self.cell_size, y * self.cell_size, self.cell_size, self.cell_size)
                    pygame.draw.rect(self.screen, color, rect)

        # Draw agents
        for agent in self.agents:
            x, y = agent['x'], agent['y']
            rect = pygame.Rect(x * self.cell_size, y * self.cell_size, self.cell_size, self.cell_size)
            pygame.draw.rect(self.screen, (0, 0, 255), rect)  # Blue color for agents

    # Rest of the methods (get_state, select_action, move_agent, etc.) are similar to training code

    def get_state(self, agent):
        # Same as in training code
        x, y = agent['x'], agent['y']
        sugar = agent['sugar'] / 100  # Normalize sugar level
        metabolism = agent['metabolism'] / 5  # Normalize metabolism
        vision = agent['vision'] / 5  # Normalize vision

        # Extract sugar levels within vision range
        vision_range = agent['vision']
        y_min = max(0, y - vision_range)
        y_max = min(self.height, y + vision_range + 1)
        x_min = max(0, x - vision_range)
        x_max = min(self.width, x + vision_range + 1)
        sugar_map = self.sugar[y_min:y_max, x_min:x_max]

        # Pad the sugar map to a fixed size
        expected_size = (2 * vision_range + 1, 2 * vision_range + 1)
        pad_h = expected_size[0] - sugar_map.shape[0]
        pad_w = expected_size[1] - sugar_map.shape[1]

        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        padded_sugar_map = np.pad(sugar_map, ((pad_top, pad_bottom), (pad_left, pad_right)), mode='constant', constant_values=0)
        sugar_map_flat = padded_sugar_map.flatten() / self.params['max_sugar']

        # Encode messages
        N = self.params['max_messages']
        messages = list(agent['messages'])[-N:]
        message_features = []
        for msg in messages:
            # Normalize message coordinates relative to grid size
            msg_x = msg['x'] / self.width
            msg_y = msg['y'] / self.height
            msg_sugar = msg['sugar_amount'] / self.params['max_sugar']
            message_features.extend([msg_x, msg_y, msg_sugar])
        # Pad remaining messages with zeros if fewer than N
        while len(message_features) < 3 * N:
            message_features.extend([0.0, 0.0, 0.0])

        state = np.concatenate(([x / self.width, y / self.height, sugar, metabolism, vision], sugar_map_flat, message_features))
        return state

    def select_action(self, state, valid_actions):
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action_probs, _ = self.agent_model(state_tensor)
        action_probs = action_probs.squeeze().cpu().numpy()

        # Mask invalid actions
        mask = np.zeros(self.action_size, dtype=bool)
        mask[valid_actions] = True
        masked_probs = action_probs * mask
        masked_probs /= masked_probs.sum()

        action = np.random.choice(self.action_size, p=masked_probs)
        return action

    def get_valid_actions(self, agent):
        # Same as in training code
        actions = []
        x, y = agent['x'], agent['y']
        possible_moves = {
            0: (x, y - 1),  # Up
            1: (x, y + 1),  # Down
            2: (x - 1, y),  # Left
            3: (x + 1, y),  # Right
            4: (x, y)       # Stay
        }
        for action, (nx, ny) in possible_moves.items():
            if 0 <= nx < self.width and 0 <= ny < self.height:
                if (nx, ny) not in self.agent_positions or (nx, ny) == (x, y):
                    actions.append(action)
        return actions

    def move_agent(self, agent, action):
        # Same as in training code
        x, y = agent['x'], agent['y']
        possible_moves = {
            0: (x, y - 1),  # Up
            1: (x, y + 1),  # Down
            2: (x - 1, y),  # Left
            3: (x + 1, y),  # Right
            4: (x, y)       # Stay
        }
        nx, ny = possible_moves[action]
        if (0 <= nx < self.width and 0 <= ny < self.height and
                ((nx, ny) not in self.agent_positions or (nx, ny) == (x, y))):
            self.agent_positions.remove((x, y))
            agent['x'], agent['y'] = nx, ny
            agent['path'].append((agent['x'], agent['y']))
            self.agent_positions.add((nx, ny))

    def collect_sugar_and_update_agent(self, agent):
        # Same as in training code
        collected_sugar = self.sugar[agent['y'], agent['x']]
        agent['sugar'] += collected_sugar
        self.sugar[agent['y'], agent['x']] = 0
        agent['sugar'] -= agent['metabolism']
        agent['age'] += 1

    def broadcast_messages(self):
        # Same as in training code
        if not self.agents:
            return  # No agents to broadcast

        positions = np.array([[agent['x'], agent['y']] for agent in self.agents])
        tree = cKDTree(positions)

        for i, agent in enumerate(self.agents):
            # Identify visible sugar peaks
            visible_sugar = self.get_visible_sugar(agent)
            sugar_locations = np.argwhere(visible_sugar > 0)
            messages = []
            for loc in sugar_locations:
                msg_x = agent['x'] + loc[1] - agent['vision']
                msg_y = agent['y'] + loc[0] - agent['vision']
                # Ensure message coordinates are within grid
                msg_x = int(np.clip(msg_x, 0, self.width - 1))
                msg_y = int(np.clip(msg_y, 0, self.height - 1))
                msg = {
                    'sender_id': agent['id'],
                    'timestep': self.timestep,
                    'sugar_amount': self.sugar[msg_y, msg_x],
                    'x': msg_x,
                    'y': msg_y
                }
                messages.append(msg)

            # Broadcast to neighbors within broadcast_radius
            radius = 5  # Fixed broadcast radius
            neighbors = tree.query_ball_point([agent['x'], agent['y']], radius)
            for neighbor_idx in neighbors:
                if neighbor_idx != i:
                    for msg in messages:
                        self.agents[neighbor_idx]['messages'].append(msg)

    def get_visible_sugar(self, agent):
        # Same as in training code
        x, y = agent['x'], agent['y']
        vision = agent['vision']
        y_min = max(0, y - vision)
        y_max = min(self.height, y + vision + 1)
        x_min = max(0, x - vision)
        x_max = min(self.width, x + vision + 1)
        visible_area = self.sugar[y_min:y_max, x_min:x_max]
        return visible_area

# Run the simulation
if __name__ == "__main__":
    seed = 42  # Set a seed for reproducibility
    try:
        env = SugarscapeEnvironmentPPO(width=30, height=30, num_agents=400, seed=seed, params=params)
        MAX_TIMESTEPS = 1000  # Set the number of timesteps for the simulation
        env.run_simulation(max_timesteps=MAX_TIMESTEPS)
    except FileNotFoundError as e:
        print(e)


Loading model from checkpoints/ppo_agent_episode_10.pth


SystemExit: 