In [1]:
import torch
from game.constsants import NUM_CHANNELS
from game.nca_model import NCA

device = 'cuda' if torch.cuda.is_available() else 'cpu' #cuda

model = NCA(NUM_CHANNELS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

checkpoint = torch.load("nca.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


In [2]:
from game.constsants import CHANNELS, H , W
from game.nca_model import update_shade
import torch

class Environment:
    def __init__(self, grid, model, ownership_grid, agents, elevation_static, soil_static, steps_per_turn=10):
        self.grid = grid
        self.model = model
        self.ownership_grid = ownership_grid
        self.agents = agents
        self.steps_per_turn = steps_per_turn
        self.elevation_static = elevation_static
        self.soil_static = soil_static
        self.current_turn = 0
        self.frames = []
        self.grid_frames = []
        self.quadrant_masks = torch.zeros((len(self.agents), H, W), device=self.grid.device)

        for agent in self.agents:
            if agent.start_quadrant == "top_left":
                self.quadrant_masks[agent.agent_id, :H//2, :W//2] = 1.0
            elif agent.start_quadrant == "bottom_right":
                self.quadrant_masks[agent.agent_id, H//2:, W//2:] = 1.0
            elif agent.start_quadrant == "top_right":
                self.quadrant_masks[agent.agent_id, :H//2, W//2:] = 1.0
            elif agent.start_quadrant == "bottom_left":
                self.quadrant_masks[agent.agent_id, H//2:, :W//2] = 1.0


    def get_frames(self):
        return self.frames

    def step(self, actions):
        """Actions is a list of (agent_id, species_name, row, col)"""
        for agent_id, species_name, row, col in actions:
            plant_idx = CHANNELS["plants"][species_name]
            self.grid[0, plant_idx, row, col] = 1
            self.ownership_grid[0, row, col] = agent_id

        # Run growth
        for _ in range(self.steps_per_turn):
            with torch.no_grad():
                self.grid = self.model(self.grid)
                update_shade(self.grid)

                self.grid[:, CHANNELS["elevation"]] = self.elevation_static
                for idx, value in self.soil_static.items():
                    self.grid[:, idx] = value
                
                self.frames.append(self.ownership_grid[0].detach().cpu().clone())
                self.grid_frames.append(self.grid[0].detach().cpu().clone()) 
                self.update_ownership()

        # Restore static channels
        self.current_turn += 1

    def update_ownership(self):
        """Optional: Transfer ownership as plants expand."""
        for species_name, idx in CHANNELS["plants"].items():
            mask = self.grid[0, idx] > 0.1
            owner = [agent.agent_id for agent in self.agents if species_name in agent.available_species]
            if owner:
                self.ownership_grid[0][mask] = owner[0]  # Assume one agent per species for now

    def get_scores(self):
        """Returns a dict of agent_id: controlled area."""
        scores = {agent.agent_id: 0 for agent in self.agents}
        for agent_id in scores.keys():
            scores[agent_id] = (self.ownership_grid[0] == agent_id).float().sum().item()
        return scores


In [None]:
class LearningAgent:
    def __init__(self, agent_id, policy_net, available_species, start_quadrant, steps_per_turn=10):
        self.agent_id = agent_id
        self.policy_net = policy_net
        self.available_species = available_species
        self.start_quadrant = start_quadrant  # <--- NEW
        self.saved_log_probs = []
        self.rewards = []
        self.save_interval = steps_per_turn
        self.step_counter = 0



    def choose_action(self, grid):
        grid = grid.detach()

        species_logits, location_logits = self.policy_net(grid)
        print(species_logits)
        species_dist = torch.distributions.Categorical(logits=species_logits)
        location_dist = torch.distributions.Categorical(logits=location_logits)

        species_idx = species_dist.sample()
        location_idx = location_dist.sample()

        log_prob = species_dist.log_prob(species_idx) + location_dist.log_prob(location_idx)

        self.saved_log_probs.append(log_prob)

        B, HW = location_logits.shape
        pooled_H = pooled_W = int(HW ** 0.5)

        flat_idx = location_idx.item()
        pooled_row = flat_idx // pooled_W
        pooled_col = flat_idx % pooled_W

        scale = 4  # 64/16 = 4
        row = pooled_row * scale + scale // 2  # center of block
        col = pooled_col * scale + scale // 2  # center of block

        species_name = self.available_species[species_idx.item()]
        action = (self.agent_id, species_name, row, col)
        print(action)
        return action
    
    def is_legal_move(self, row, col, ownership_grid, quadrant_mask):
        in_quadrant = quadrant_mask[row, col] > 0.5
        owns_cell = (ownership_grid[0, row, col] == self.agent_id)
        return in_quadrant or owns_cell

In [45]:
def play_game(environment, max_turns=8):
    for _ in range(max_turns):
        actions = []
        state = environment.grid.clone()
        ownership_grid = environment.ownership_grid  # <-- pull this once
        quadrant_masks = environment.quadrant_masks  # <-- pull this once

        for agent in environment.agents:
            # Build agent-specific grid with quadrant mask attached
            agent_quadrant = quadrant_masks[agent.agent_id].unsqueeze(0)  # shape [1, H, W]
            augmented_grid = torch.cat([state, agent_quadrant.unsqueeze(0)], dim=1)  # add to channels
            print(augmented_grid[0, 3, 5, 5].item() )
            print(augmented_grid[0, 3, 61, 5].item() )
            print(augmented_grid[0, 3, 61, 61].item() )
            print(augmented_grid[0, 3, 5, 61].item() )
            print(augmented_grid.shape)
            action = agent.choose_action(augmented_grid)
            agent_id, species_name, row, col = action  # <-- UNPACK action now

            # Add initial reward if needed
            agent.rewards.append(0.0)

            # Check legality
            quadrant_mask = quadrant_masks[agent.agent_id]  # just their own mask
            if not agent.is_legal_move(row, col, ownership_grid, quadrant_mask):
                penalty = 2000.0
                print(f"penalty applied to {agent} ")
                agent.rewards[-1] -= penalty

            actions.append(action)

        environment.step(actions)

    scores = environment.get_scores()
    return scores




NameError: name 'augmented_grid' is not defined

In [8]:
# Setup
from game.constsants import H, W
from game.generate_map import generate_training_world

policy0 = PolicyNet(num_species=len(["grass_0", "shrub_0", "tree_0"]), h=H, w=W).to(device)
policy1 = PolicyNet(num_species=len(["grass_1", "shrub_1", "tree_1"]), h=H, w=W).to(device)

agent0 = LearningAgent(agent_id=0, policy_net=policy0, available_species=["grass_0", "shrub_0", "tree_0"], start_quadrant="top_left")
agent1 = LearningAgent(agent_id=1, policy_net=policy1, available_species=["grass_1", "shrub_1", "tree_1"], start_quadrant="bottom_right")
agents = [agent0, agent1]

optimizer0 = torch.optim.Adam(policy0.parameters(), lr=1e-3)
optimizer1 = torch.optim.Adam(policy1.parameters(), lr=1e-3)
optimizers = [optimizer0, optimizer1]



In [9]:
import random
torch.autograd.set_detect_anomaly(True)

for episode in range(200):
    grid = generate_training_world(H, W, seed_plants=False).to(device)
    ownership_grid = torch.full((1, H, W), fill_value=-1, dtype=torch.long, device=device)
    elevation_static = grid[:, CHANNELS["elevation"]].clone().detach()
    soil_static = {
        idx: grid[:, idx].clone().detach()
        for idx in CHANNELS["soil"].values()
    }
    
    if random.random() < .5:
        agents[0].start_quadrant = "top_left"
        agents[1].start_quadrant = "bottom_right"
    else:
        agents[1].start_quadrant = "top_left"
        agents[0].start_quadrant = "bottom_right"
        
    env = Environment(grid, model, ownership_grid, agents, elevation_static, soil_static, steps_per_turn=5)

    scores = play_game(env,15)

    print(scores)
    for agent in agents: 
        reward = scores[agent.agent_id]
        agent.rewards.append(reward)

    grid.detach()
    for i in range(len(agents)):
        finish_episode(agents[i], optimizers[i], grid)
        

penalty applied to <__main__.LearningAgent object at 0x0000024F0B8744A0> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0D8F0B90> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0B8744A0> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0D8F0B90> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0B8744A0> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0B8744A0> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0D8F0B90> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0B8744A0> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0D8F0B90> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0B8744A0> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0D8F0B90> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0B8744A0> 
penalty applied to <__main__.LearningAgent object at 0x0000024F0D8F0B90> 
penalty applied to <__main__.LearningA

In [5]:
import torch.nn as nn
import torch

from game.constsants import NUM_CHANNELS

class PolicyNet(nn.Module):
    def __init__(self, num_species, h, w):
        super().__init__()
        self.h = h
        self.w = w

        self.conv = nn.Sequential(
            nn.Conv2d(NUM_CHANNELS+1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.species_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # average pool spatial dimensions
            nn.Flatten(),             # [B, 32]
            nn.Linear(32, num_species)
        )

        self.location_head = nn.Conv2d(32, 1, 1)  # 1x1 conv to predict location logits per (h/4, w/4)

    def forward(self, grid):
        x = self.conv(grid)

        species_logits = self.species_head(x)  # [B, num_species]

        location_logits = self.location_head(x)  # [B, 1, H/4, W/4]
        location_logits = location_logits.flatten(1)  # [B, (H/4)*(W/4)]
        
        species_logits = species_logits + 0.1 * torch.randn_like(species_logits)
        location_logits = location_logits + 0.1 * torch.randn_like(location_logits)
        return species_logits, location_logits


In [6]:
def finish_episode(agent, optimizer, grid, gamma=1.0):
    returns = []
    R = 0.0
    for r in reversed(agent.rewards):
        R = r + gamma * R
        returns.insert(0, R)

    returns = torch.tensor(returns, device='cuda')

    # Normalize
    if returns.std() > 1e-6:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    else:
        returns = returns - returns.mean()  #
        
    optimizer.zero_grad()
    total_loss = torch.tensor(0.0, device='cuda')

    losses=[]
    for log_prob, R in zip(agent.saved_log_probs, returns):
        losses.append(-log_prob * R)


    total_loss = torch.stack(losses).sum()
    total_loss.backward()
    optimizer.step()
    # Clean up
    agent.rewards.clear()
    agent.saved_log_probs.clear()


In [7]:
import matplotlib.pyplot as plt
from matplotlib import animation
import torch
import numpy as np

def animate_species_ownership_with_static_layers(environment, elevation_static, soil_static, steps=None):
    from game.constsants import CHANNELS, H, W

    if steps is None:
        steps = len(environment.frames)

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    plt.tight_layout()

    ims = []

    # Main animated species × ownership
    im_species_owner = axs[0].imshow(torch.zeros((H, W, 3)), animated=True)
    axs[0].axis('off')
    axs[0].set_title("Species + Ownership")

    # Static soil type map
    soil_layers = torch.stack(list(soil_static.values()))  # [4, 1, H, W]
    soil_layers = soil_layers.squeeze(1)  # [4, H, W]

    soil_np = soil_layers.argmax(0).cpu().numpy()  # [H, W]

    axs[1].imshow(soil_np, cmap='Set3')
    axs[1].axis('off')
    axs[1].set_title("Soil Type")

    # Static elevation map
    elevation_np = elevation_static.squeeze(0).cpu().numpy()  # [H, W]
    axs[2].imshow(elevation_np, cmap='terrain', vmin=0, vmax=1)
    axs[2].axis('off')
    axs[2].set_title("Elevation")

    ims.append(im_species_owner)

    # Prepare coloring
    plant_names = list(CHANNELS["plants"].keys())
    plant_channels = list(CHANNELS["plants"].values())

    colors = plt.cm.tab10(np.linspace(0,1,len(plant_channels)))[:, :3]  # RGB only
    owner_tints = np.array([
        [1.0, 1.0, 1.0],    # Unclaimed (-1)
        [1.2, 1.0, 1.0],    # Agent 0
        [1.0, 1.2, 1.0],    # Agent 1
        [1.0, 1.0, 1.2],    # Agent 2
        [1.2, 1.2, 1.0],    # Agent 3
        # Extend if more agents
    ])

    def update(i):
        ownership = environment.frames[i].numpy()    # [H, W]
        plants = environment.grid_frames[i].numpy()  # [C, H, W]

        # Pick dominant species at each location
        plant_values = plants[plant_channels]  # [num_species, H, W]
        dominant_species = plant_values.argmax(0)  # [H, W]
        dominant_strength = plant_values.max(0)    # [H, W], how strong the best species is

        # Build base color purely from ownership
        base_color = np.zeros((H, W, 3), dtype=np.float32)

        # Red for agent 0 (ownership == 0)
        base_color[ownership == 0, 0] = 1.0  # Red channel

        # Blue for agent 1 (ownership == 1)
        base_color[ownership == 1, 2] = 1.0  # Blue channel

        # Optionally, you can add more owner colors if needed.

        # Multiply the color by species presence
        intensity = dominant_strength / (dominant_strength.max() + 1e-8)  # normalize
        intensity = np.clip(intensity, 0, 1)

        combined = base_color * intensity[..., None]  # apply intensity per pixel

        im_species_owner.set_data(combined)
        axs[0].set_title(f"Species + Ownership - Step {i}")

        return ims


    ani = animation.FuncAnimation(fig, update, frames=steps, interval=300, blit=False)
    plt.close()
    return ani


In [21]:
print(f"episode {episode}")
grid = generate_training_world(H, W, seed_plants=False).to(device)
ownership_grid = torch.full((1, H, W), fill_value=-1, dtype=torch.long, device=device)
elevation_static = grid[:, CHANNELS["elevation"]].clone().detach()
soil_static = {
    idx: grid[:, idx].clone().detach()
    for idx in CHANNELS["soil"].values()
}
env = Environment(grid, model, ownership_grid, agents, elevation_static, soil_static, steps_per_turn=5)

scores = play_game(env, max_turns=15)
print(scores)
ani = animate_species_ownership_with_static_layers(env, elevation_static, soil_static)
from IPython.display import HTML
HTML(ani.to_jshtml())


episode 199
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
penalty applied to <__main__.LearningAgent object at 0x0000024F3E639340> 
{0: 1115.0, 1: 1956.0}


In [42]:
import torch

dummy_input = torch.randn(1, 16, H, W).cuda()

# model = PolicyNet(num_species=3, h=H, w=W).cuda()
# Load your trained weights if needed here
# model.load_state_dict(torch.load(...))
policy0.eval()
torch.onnx.export(
    policy0, 
    dummy_input, 
    "policy_net.onnx", 
    input_names=["input"],
    output_names=["species_logits", "location_logits"],
    dynamic_axes={"input": {0: "batch_size"}, "species_logits": {0: "batch_size"}, "location_logits": {0: "batch_size"}},
    opset_version=17
)


In [17]:
import onnxruntime as ort
import torch
import numpy as np

class PolicyNetONNX:
    def __init__(self, model_path, device='cuda'):
        self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        self.device = device

    def __call__(self, grid):
        if grid.is_cuda:
            input_numpy = grid.detach().cpu().numpy()
        else:
            input_numpy = grid.detach().numpy()

        outputs = self.session.run(None, {"input": input_numpy})

        species_logits = torch.from_numpy(outputs[0]).to(self.device)
        location_logits = torch.from_numpy(outputs[1]).to(self.device)

        # Match your original PyTorch model behavior (optional)
        species_logits = species_logits + 0.1 * torch.randn_like(species_logits)
        location_logits = location_logits + 0.1 * torch.randn_like(location_logits)

        return species_logits, location_logits


In [25]:
from game.constsants import H, W
from game.generate_map import generate_training_world

policy0 = PolicyNet(num_species=len(["grass_0", "shrub_0", "tree_0"]), h=H, w=W).to(device)
policy1 = PolicyNetONNX(model_path="policy_net.onnx")
agent0 = LearningAgent(agent_id=0, policy_net=policy0, available_species=["grass_0", "shrub_0", "tree_0"], start_quadrant="top_left")
agent1 = LearningAgent(agent_id=1, policy_net=policy1, available_species=["grass_1", "shrub_1", "tree_1"], start_quadrant="bottom_right")
agents = [agent0, agent1]


In [46]:
print(f"episode {episode}")
grid = generate_training_world(H, W, seed_plants=False).to(device)
ownership_grid = torch.full((1, H, W), fill_value=-1, dtype=torch.long, device=device)
elevation_static = grid[:, CHANNELS["elevation"]].clone().detach()
soil_static = {
    idx: grid[:, idx].clone().detach()
    for idx in CHANNELS["soil"].values()
}
env = Environment(grid, model, ownership_grid, agents, elevation_static, soil_static, steps_per_turn=5)

scores = play_game(env, max_turns=15)
print(scores)
ani = animate_species_ownership_with_static_layers(env, elevation_static, soil_static)
from IPython.display import HTML
HTML(ani.to_jshtml())


episode 199
0.0
0.0
0.0
0.0
torch.Size([1, 16, 64, 64])
tensor([[ 0.0720, -0.2179,  0.0655]], device='cuda:0', grad_fn=<AddBackward0>)
(0, 'grass_0', 34, 62)
penalty applied to <__main__.LearningAgent object at 0x0000024F1286A0C0> 
0.0
0.0
0.0
0.0
torch.Size([1, 16, 64, 64])
tensor([[-0.0914, -0.6903,  0.2806]], device='cuda:0')
(1, 'grass_1', 54, 34)
0.0
0.0
0.0
0.0
torch.Size([1, 16, 64, 64])
tensor([[ 0.0646, -0.1506, -0.2358]], device='cuda:0', grad_fn=<AddBackward0>)
(0, 'tree_0', 10, 46)
penalty applied to <__main__.LearningAgent object at 0x0000024F1286A0C0> 
0.0
0.0
0.0
0.0
torch.Size([1, 16, 64, 64])
tensor([[-0.1614, -0.6486,  0.0281]], device='cuda:0')
(1, 'tree_1', 50, 34)
0.0
0.0
0.0
0.0
torch.Size([1, 16, 64, 64])
tensor([[ 0.0994, -0.0199, -0.0391]], device='cuda:0', grad_fn=<AddBackward0>)
(0, 'tree_0', 6, 34)
penalty applied to <__main__.LearningAgent object at 0x0000024F1286A0C0> 
0.0
0.0
0.0
0.0
torch.Size([1, 16, 64, 64])
tensor([[-0.4340, -0.5315,  0.3715]], device