In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from nca.constsants import SOIL_TYPES, CHANNELS, NUM_CHANNELS, PLANT_RULES, PLANT_GROUPS
from nca.suitability import compute_suitability
from nca.generate_map import generate_training_world
from nca.nca_model import NCA
import random


H, W = 64, 64
device = 'cuda' if torch.cuda.is_available() else 'cpu' #cuda
print("using device " + device)
# --- Grid Initialization ---

grid = generate_training_world(H, W,seed_plants=False).to(device)


model = NCA(NUM_CHANNELS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)



using device cuda


In [None]:
#### TRAIN ####

from nca.nca_model import compute_loss, rollout_training_step, randomize_species_order, get_species_features_tensor, build_channel_mapping_from_species_list
import time
for epoch in range(200):
    species_list = random.sample(list(PLANT_RULES.keys()), k=10)
    species_features = get_species_features_tensor(species_list=species_list)
    build_channel_mapping_from_species_list(species_list)  
    print(CHANNELS)
    start = time.time()
    grid = generate_training_world(H, W, seed_plants=True, seed_smart=False).to(device)
    grid.requires_grad = True
    grid, species_features = rollout_training_step(model, grid, species_features, steps=120)
    loss = compute_loss(grid, species_features, epoch)
    print(loss)
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()

    if epoch % 10 == 0:
        print(f"[{epoch}] Loss: {loss.item():.4f}")

tensor(0.0888, device='cuda:0', grad_fn=<SubBackward0>)
[0] Loss: 0.0888
tensor(0.0902, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0904, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0898, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0916, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0923, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0897, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0885, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0895, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0911, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0887, device='cuda:0', grad_fn=<SubBackward0>)
[10] Loss: 0.0887
tensor(0.0917, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0909, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0877, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0903, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0891, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0898, device='cuda:0', grad_fn=<SubBackward0>)
tensor(0.0904

In [8]:
import numpy as np

def log_channel_sums(grid: torch.Tensor):
    assert grid.ndim == 4, "Expected shape [1, channels, height, width]"
    channels = grid.shape[1]
    for c in range(channels):
        channel_sum = grid[0, c].sum().item()
        print(f"Channel {c} sum: {channel_sum}")


log_channel_sums(test_grid)

Channel 0 sum: 2104.0
Channel 1 sum: 757.0
Channel 2 sum: 846.0
Channel 3 sum: 389.0
Channel 4 sum: 379.2241516113281
Channel 5 sum: 0.0
Channel 6 sum: 3.0
Channel 7 sum: 3.0
Channel 8 sum: 3.0
Channel 9 sum: 3.0
Channel 10 sum: 3.0
Channel 11 sum: 3.0
Channel 12 sum: 3.0
Channel 13 sum: 3.0
Channel 14 sum: 3.0
Channel 15 sum: 3.0


In [3]:
#### VISUALIZE ####

from nca.animate import animate_full_ecosystem
model.eval()
randomize_species_order()
test_grid = generate_training_world(H, W, seed_plants=True).to(device)
elevation_static = test_grid[:, CHANNELS["elevation"]].clone().detach()
shade_static = test_grid[:, CHANNELS["shade"]].clone().detach()
soil_static = {
    idx: test_grid[:, idx].clone().detach()
    for idx in CHANNELS["soil"].values()
}

ani = animate_full_ecosystem(model, test_grid.clone(), elevation_static, soil_static, shade_static, steps=120)
from IPython.display import HTML
HTML(ani.to_jshtml())

In [4]:
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, "nca.pth")

In [6]:
checkpoint = torch.load("nca.pth", map_location="cpu")  # Force loading on CPU
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

NCA(
  (model): Sequential(
    (0): Conv2d(16, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [10]:
import torch
from nca.nca_model import NCA
from nca.constsants import NUM_CHANNELS, H, W

# --- Load model ---
model = NCA(NUM_CHANNELS)
checkpoint = torch.load("nca.pth", map_location="cpu")  # Force loading on CPU
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
model = model.cpu()  # Ensure model is fully on CPU

# --- Create dummy input ---
dummy_grid = torch.randn(1, NUM_CHANNELS, H, W)  # CPU tensor by default
dummy_species = torch.tensor([
    [0.3, 0.9, 0.4, 0.2] + [1, 0, 0, 1] + [1, 1, 0] + [0, 1, 0]  # Example feature vector
    for _ in PLANT_RULES
], dtype=torch.float32)  # shape: [num_species, 14]

# --- Export to ONNX ---
torch.onnx.export(
    model,
    (dummy_grid, dummy_species),
    "nca_model.onnx",
    input_names=["grid", "species_features"],
    output_names=["output"],
    opset_version=17,
    dynamic_axes={
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "height", 3: "width"}
    }
)

print("Exported successfully to nca_model.onnx")




Exported successfully to nca_model.onnx
