In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from game.constsants import SOIL_TYPES, PLANT_TYPES, CHANNELS, NUM_CHANNELS, PLANT_RULES, PLANT_GROUPS
from game.suitability import compute_suitability
from game.generate_map import generate_training_world
from game.nca_model import NCA


H, W = 64, 64
device = 'cuda' if torch.cuda.is_available() else 'cpu' #cuda
print("using device " + device)
# --- Grid Initialization ---
grid = torch.zeros(1, NUM_CHANNELS, H, W).to(device)
grid = torch.zeros(1, NUM_CHANNELS, H, W).to(device)

grid = generate_training_world(H, W).to(device)


model = NCA(NUM_CHANNELS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)



using device cuda


In [2]:
#### TRAIN ####

from game.nca_model import compute_loss, rollout_training_step


for epoch in range(200):
    grid = generate_training_world(H, W).to(device)
    grid.requires_grad = True

    out = rollout_training_step(model, grid, steps=120)
    loss = compute_loss(out, epoch)
    print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"[{epoch}] Loss: {loss.item():.4f}")

tensor(-0.0345, device='cuda:0', grad_fn=<SubBackward0>)
[0] Loss: -0.0345
tensor(-0.1034, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.1296, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.2156, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.1999, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.1989, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.2861, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.2870, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.3190, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.3849, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.4130, device='cuda:0', grad_fn=<SubBackward0>)
[10] Loss: -0.4130
tensor(-0.3812, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.4846, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.6506, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.4710, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.7477, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-0.9094, device='cuda:0', grad_fn=<SubBackwa

In [5]:
#### VISUALIZE ####

from game.animate import animate_full_ecosystem

test_grid = generate_training_world(H, W).to(device)
elevation_static = test_grid[:, CHANNELS["elevation"]].clone().detach()
soil_static = {
    idx: test_grid[:, idx].clone().detach()
    for idx in CHANNELS["soil"].values()
}

ani = animate_full_ecosystem(model, test_grid.clone(), elevation_static, soil_static, steps=120)
from IPython.display import HTML
HTML(ani.to_jshtml())

In [6]:
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, "nca.pth")

In [5]:
import torch
from game.nca_model import NCA
from game.constsants import NUM_CHANNELS, H, W

# --- Load model ---
model = NCA(NUM_CHANNELS)
checkpoint = torch.load("nca.pth", map_location="cpu")  # Force loading on CPU
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
model = model.cpu()  # Ensure model is fully on CPU

# --- Create dummy input ---
dummy_input = torch.randn(1, NUM_CHANNELS, H, W)  # CPU tensor by default

# --- Export to ONNX ---
torch.onnx.export(
    model,
    dummy_input,
    "nca_model.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=17,
    dynamic_axes={
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "height", 3: "width"}
    }
)

print("Exported successfully to nca_model.onnx")


Exported successfully to nca_model.onnx
