In [1]:
from policy_net.environment import Environment
from policy_net.learning_agent import LearningAgent
from policy_net.animate import animate_species_ownership_with_static_layers
from policy_net.game import play_game, finish_episode
from policy_net.policy_net import PolicyNet
from nca.constsants import CHANNELS, H, W, NUM_CHANNELS
from nca.nca_model import NCA, get_species_features_tensor,build_channel_mapping_from_species_list
from nca.generate_map import generate_training_world
import torch
import random



16NUMCHANNELS


In [2]:
# Setup

device = 'cuda' if torch.cuda.is_available() else 'cpu' #cuda



policy0 = PolicyNet(num_species=5, h=H, w=W).to(device)
policy1 = PolicyNet(num_species=5, h=H, w=W).to(device)


agent_0_mask = torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=torch.float32)
agent_1_mask = 1.0 - agent_0_mask

agent0 = LearningAgent(agent_id=0, policy_net=policy0,  start_quadrant="top_left", agent_mask=agent_0_mask)
agent1 = LearningAgent(agent_id=1, policy_net=policy1,  start_quadrant="bottom_right", agent_mask=agent_1_mask)
agents = [agent0, agent1]

optimizer0 = torch.optim.Adam(policy0.parameters(), lr=1e-3)
optimizer1 = torch.optim.Adam(policy1.parameters(), lr=1e-3)
optimizers = [optimizer0, optimizer1]


plant_list = agents[0].available_species + agents[1].available_species
# CHANNELS["plants"] = {
#     plant_name: 6 + i
#     for i, plant_name in enumerate(plant_list)
# }
CHANNELS["plants"] = [
    (plant_name, 6 + i)
    for i, plant_name in enumerate(plant_list)
]
sample_species_features = get_species_features_tensor()

model = NCA(NUM_CHANNELS, sample_species_features.shape[1]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

checkpoint = torch.load("nca.pth",map_location=torch.device('cuda'))
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

model.eval()



NCA(
  (model): Sequential(
    (0): Conv2d(33, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [11]:

import math
from collections import Counter

def compute_entropy(actions):
    if not actions:
        return 0.0
    counter = Counter(actions)
    total = sum(counter.values())
    return -sum((count / total) * math.log2(count / total) for count in counter.values())


torch.autograd.set_detect_anomaly(True)

for episode in range(200):
    print(episode)
    grid = generate_training_world(H, W, seed_plants=False).to(device)
    ownership_grid = torch.full((1, H, W), fill_value=-1, dtype=torch.long, device=device)
    elevation_static = grid[:, CHANNELS["elevation"]].clone().detach()
    shade_static = grid[:, CHANNELS["shade"]].clone().detach()
    soil_static = {
        idx: grid[:, idx].clone().detach()
        for idx in CHANNELS["soil"].values()
    }
    for agent in agents:
        agent.randomize_species(5)
        agent.species_used = []

    if random.random() < .5:
        print("Agent 0 is player 1, Agent 1 is player 2")
        agents[0].start_quadrant = "top_left"
        agents[0].player_number = 1
        agents[1].start_quadrant = "bottom_right"
        agents[1].player_number = 2
        species_list = agents[0].available_species + agents[1].available_species
    else:
        print("Agent 1 is player 1, Agent 0 is player 2")
        agents[1].start_quadrant = "top_left"
        agents[1].player_number = 1
        agents[0].start_quadrant = "bottom_right"
        agents[0].player_number = 2
        species_list = agents[1].available_species + agents[0].available_species

    species_features = get_species_features_tensor(species_list=species_list)
    build_channel_mapping_from_species_list(species_list)   
    env = Environment(grid, model, ownership_grid, agents, elevation_static, soil_static, shade_static, species_features, species_list, steps_per_turn=5)
    print(CHANNELS)
    scores, actions = play_game(env, species_features, 10 ,False)

    print(scores)
    for agent in agents: 

            # Filter agent's past actions
        flattened_actions = [action for turn in actions for action in turn]

        # Extract actions for the current agent
        agent_actions = [
            (species_id, row // 4, col // 4)
            for agent_id, species_id, row, col in flattened_actions
            if agent_id == agent.agent_id
        ]

        # # Compute diversity reward based on action entropy
        # entropy = compute_entropy(agent_actions)
        # diversity_bonus = entropy * 5  # Scale as needed
        diversity_bonus = len(agent.species_used) * 400
        #agent.diversity_reward = diversity_bonus
        reward = scores[agent.agent_id]
        if (agent.training_stage >= 4):   
            agent.rewards[-1] += (reward) * 10
  
        agent.check_curriculum()
        agent.log_and_reset_loss()
        

    grid.detach()
    for i in range(len(agents)):
        finish_episode(agents[i], optimizers[i], grid)
        

0
Agent 0 is player 1, Agent 1 is player 2
{'soil': {'loam': 0, 'clay': 1, 'sand': 2, 'peat': 3}, 'elevation': 4, 'shade': 5, 'plants': [('grass_2', 6), ('tree_1', 7), ('tree_0', 8), ('tree_4', 9), ('grass_1', 10), ('tree_4', 11), ('grass_0', 12), ('tree_0', 13), ('grass_2', 14), ('grass_1', 15)]}
[(0, 1, 18, 26), (1, 6, 54, 46)]
[(0, 0, 18, 22), (1, 6, 10, 18)]
[(0, 0, 22, 26), (1, 6, 6, 22)]
[(0, 1, 18, 26), (1, 6, 50, 58)]
[(0, 3, 54, 14), (1, 6, 42, 62)]
[(0, 1, 22, 26), (1, 6, 54, 58)]
[(0, 3, 22, 46), (1, 6, 46, 58)]
[(0, 1, 18, 22), (1, 6, 38, 62)]
[(0, 0, 46, 34), (1, 5, 42, 54)]
[(0, 0, 22, 2), (1, 6, 58, 54)]
{0: 12.0, 1: 986.0}
Agent 0 Total: 2000 Quad Pen: -2000 Species Pen: 0  Suit Rew: 1000 Diversity  Rew 3000
Agent 1 Total: 4400 Quad Pen: 0 Species Pen: 0  Suit Rew: 2400 Diversity  Rew 2000
1
Agent 1 is player 1, Agent 0 is player 2
{'soil': {'loam': 0, 'clay': 1, 'sand': 2, 'peat': 3}, 'elevation': 4, 'shade': 5, 'plants': [('grass_2', 6), ('shrub_1', 7), ('shrub_2', 8)

In [7]:
grid = generate_training_world(H, W, seed_plants=False).to(device)
ownership_grid = torch.full((1, H, W), fill_value=-1, dtype=torch.long, device=device)
elevation_static = grid[:, CHANNELS["elevation"]].clone().detach()
shade_static = grid[:, CHANNELS["shade"]].clone().detach()

soil_static = {
    idx: grid[:, idx].clone().detach()
    for idx in CHANNELS["soil"].values()
}

agent0Species = ["grass_0", "shrub_0", "tree_0", "grass_1","shrub_2"]
agent1Species = ["grass_1", "shrub_1", "tree_1", "grass_2","tree_2"]

agent0 = LearningAgent(agent_id=0, policy_net=policy0, available_species=agent0Species, start_quadrant="top_left", agent_mask=agent_0_mask)
agent1 = LearningAgent(agent_id=1, policy_net=policy1, available_species=agent1Species, start_quadrant="bottom_right", agent_mask=agent_1_mask)
agents = [agent0, agent1]

agents[0].player_number = 1
agents[1].player_number = 2


species_list = agents[0].available_species + agents[1].available_species
species_features = get_species_features_tensor(species_list=species_list)   
build_channel_mapping_from_species_list(species_list)   
print(species_features.shape)

env = Environment(grid, model, ownership_grid, agents, elevation_static, soil_static, shade_static, species_features, species_list, steps_per_turn=5)


scores = play_game(env, species_features, max_turns=15)
print(scores)
ani = animate_species_ownership_with_static_layers(env, elevation_static, soil_static)
from IPython.display import HTML
HTML(ani.to_jshtml())


torch.Size([10, 17])
[(0, 3, 26, 30), (1, 5, 46, 58)]
[(0, 0, 26, 30), (1, 5, 62, 42)]
[(0, 3, 26, 30), (1, 8, 50, 50)]
[(0, 3, 22, 22), (1, 5, 50, 50)]
[(0, 0, 22, 26), (1, 7, 50, 54)]
[(0, 0, 18, 22), (1, 6, 54, 54)]
[(0, 0, 22, 26), (1, 8, 50, 54)]
[(0, 0, 22, 30), (1, 8, 46, 62)]
[(0, 0, 26, 26), (1, 7, 54, 54)]
[(0, 4, 26, 26), (1, 9, 58, 62)]
[(0, 0, 22, 30), (1, 8, 34, 54)]
[(0, 0, 26, 30), (1, 6, 42, 62)]
[(0, 3, 18, 22), (1, 8, 38, 46)]
[(0, 2, 22, 22), (1, 5, 62, 54)]
[(0, 3, 26, 26), (1, 5, 50, 54)]
({0: 673.0, 1: 56.0}, [[(0, 3, 26, 30), (1, 5, 46, 58)], [(0, 0, 26, 30), (1, 5, 62, 42)], [(0, 3, 26, 30), (1, 8, 50, 50)], [(0, 3, 22, 22), (1, 5, 50, 50)], [(0, 0, 22, 26), (1, 7, 50, 54)], [(0, 0, 18, 22), (1, 6, 54, 54)], [(0, 0, 22, 26), (1, 8, 50, 54)], [(0, 0, 22, 30), (1, 8, 46, 62)], [(0, 0, 26, 26), (1, 7, 54, 54)], [(0, 4, 26, 26), (1, 9, 58, 62)], [(0, 0, 22, 30), (1, 8, 34, 54)], [(0, 0, 26, 30), (1, 6, 42, 62)], [(0, 3, 18, 22), (1, 8, 38, 46)], [(0, 2, 22, 22), (1

In [5]:

# After training loop, or at a checkpoint
torch.save({
    'model_state_dict': policy1.state_dict(),
    'optimizer_state_dict': optimizer1.state_dict()
}, "policy_net.pth")

In [4]:
import torch

# Set model to evaluation mode
model = PolicyNet(5, H, W)
checkpoint = torch.load("policy_net.pth", map_location="cpu")  # Force loading on CPU
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # ← disables dropout, noise, etc.


PolicyNet(
  (conv1): Conv2d(17, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (location_head): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (avgpool): AvgPool2d(kernel_size=(16, 16), stride=(16, 16), padding=0)
  (env_proj): Linear(in_features=32, out_features=64, bias=True)
  (species_proj): Linear(in_features=17, out_features=64, bias=True)
  (classifier): Linear(in_features=64, out_features=1, bias=True)
)

In [7]:
import torch

# Set model to evaluation mode
model = PolicyNet(3, H, W)
checkpoint = torch.load("policy_net.pth", map_location="cpu")  # Force loading on CPU
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # ← disables dropout, noise, etc.

# Sample input matching model expectations
dummy_input = torch.zeros(1, NUM_CHANNELS + 1, H, W)

torch.onnx.export(
    model,
    dummy_input,
    "policy_net.onnx",
    input_names=["input"],
    output_names=["x1", "x2", "x3", "x4", "species_logits", "location_logits"],
    opset_version=13,
    do_constant_folding=False
)

