In [1]:
import torch
from seta import (
    MLPThinker, GNNThinker,
    System,
    Dynamics,
    Simulator,
    CustomFunctionDataset, OfflineRBFInterpolationDataset,
    Environment,
    Trainer)

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"


In [2]:
# params

T_max= 100
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# DATASET
import pickle
import numpy as np

# Load the RBF interpolation from the file
with open('rbf_interpolation_internodes.pkl', 'rb') as f:
    rbf_loaded = pickle.load(f)


def curve_data(time_tensor: torch.Tensor, temperature: float) -> torch.Tensor:
    pred = []
    for el in time_tensor:
        pred.append(torch.tensor(rbf_loaded(np.array([el.item(),temperature]).reshape(1,-1))).float())
    return torch.cat(pred, dim=0)



num_examples = 80
temp_min = 20.0
temp_max = 40.0

dataset = CustomFunctionDataset(
    T=T_max,
    num_examples=num_examples,
    temp_min=temp_min,
    temp_max=temp_max,
    curve_fn=curve_data
)


dataset = OfflineRBFInterpolationDataset("dataset_offline.npz")



Building Dataset


In [4]:
# SYSTEM

system = System(device=device)


decision_net = GNNThinker(device = device)

dyn = Dynamics()

def STEM_rule(agent, system):
        """
        For a WorkerAgent:
          - increase 'workload' by 0.1 each step
          - if system.temperature > 20, increase 'age' by an extra 0.5
        """
        # agent.state is a WorkerState dataclass with fields (age, workload)
        agent.state.size += 0.1
        

def LEAF_rule(agent, system):
        """
        For a WorkerAgent:
          - increase 'workload' by 0.1 each step
          - if system.temperature > 20, increase 'age' by an extra 0.5
        """
        # agent.state is a WorkerState dataclass with fields (age, workload)
        agent.state.size += 0.1
        


dyn.register_rule("S", STEM_rule)
dyn.register_rule("L", LEAF_rule)


def spawn_node_SAM(system, prediction):
        current_W = system.types.count("S")
        delta = prediction - current_W
        if delta > 0.0:
            n_to_spawn = int(torch.ceil(torch.tensor(delta)).item())
            for _ in range(n_to_spawn):
                system.add_node_SAM()

sim = Simulator(
     T_max=T_max,
     system=system,
     system_dynamic= dyn,
     decision_net=decision_net,
     act_rule=spawn_node_SAM,
     device= device
    )


In [5]:

# ─── Training Configuration ──────────────────────────────────────────
    
epochs = 100
batch_size = 8
learning_rate = 1e-2

output_model_path = "decision_net_GNN.pth"

validation_split = 0.2
patience = 10
curve_interval = 25
num_example_curves = 4

    # 4) Train
trainer = Trainer(
    decision_net=decision_net,
    simulator=sim,
    dataset=dataset,
    device=device,
    num_epochs=epochs,
    batch_size=batch_size,
    lr=learning_rate,
    validation_split=validation_split,
    patience=patience,
    visualize=True,
    curve_interval=curve_interval,
    num_example_curves=num_example_curves,
    output_model_path=output_model_path
)
trainer.train()


Starting training (patience=10 epochs) …


RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x2 and 1x32)

In [None]:
# test 

model_path = "decision_net.pth"
decision_net.load_state_dict(torch.load(model_path, map_location=device))

env_test = Environment(25)
sim.T_max = 60
sim.run(env_test,"train",3, output = "out/")



=== Starting simulation - Mode: train ===
[t=0] Phase 1 (Sense)
[t=0] Phase 2 (Evolve): applying dynamics to all agents.
    Node 0 (S): {'age': 0.0, 'size': 1.1}
    Node 1 (L): {'age': 0.0, 'size': 0.2}
    Node 2 (L): {'age': 0.0, 'size': 0.2}
    Node 3 (SAM): {'age': 0.0}
[t=0] Phase 3 (Think)
[t=0] Phase 4 (Act)


  n_to_spawn = int(torch.ceil(torch.tensor(delta)).item())


[t=1] Phase 1 (Sense)
[t=1] Phase 2 (Evolve): applying dynamics to all agents.
    Node 0 (S): {'age': 0.0, 'size': 1.2000000000000002}
    Node 1 (L): {'age': 0.0, 'size': 0.30000000000000004}
    Node 2 (L): {'age': 0.0, 'size': 0.30000000000000004}
    Node 3 (SAM): {'age': 0.0}
    Node 4 (S): {'age': 0.0, 'size': 1.1}
    Node 5 (L): {'age': 0.0, 'size': 0.2}
    Node 6 (L): {'age': 0.0, 'size': 0.2}
    Node 7 (S): {'age': 0.0, 'size': 1.1}
    Node 8 (L): {'age': 0.0, 'size': 0.2}
    Node 9 (L): {'age': 0.0, 'size': 0.2}
[t=1] Phase 3 (Think)
[t=1] Phase 4 (Act)
[t=2] Phase 1 (Sense)
[t=2] Phase 2 (Evolve): applying dynamics to all agents.
    Node 0 (S): {'age': 0.0, 'size': 1.3000000000000003}
    Node 1 (L): {'age': 0.0, 'size': 0.4}
    Node 2 (L): {'age': 0.0, 'size': 0.4}
    Node 3 (SAM): {'age': 0.0}
    Node 4 (S): {'age': 0.0, 'size': 1.2000000000000002}
    Node 5 (L): {'age': 0.0, 'size': 0.30000000000000004}
    Node 6 (L): {'age': 0.0, 'size': 0.30000000000000004}

tensor([ 2.6193,  2.7493,  2.8807,  3.0132,  3.1464,  3.2806,  3.4155,  3.5509,
         3.6870,  3.8235,  3.9604,  4.0977,  4.2352,  4.3731,  4.5113,  4.6496,
         4.8639,  5.1417,  5.4217,  5.7004,  5.9793,  6.4097,  6.8845,  7.3016,
         7.7631,  8.1718,  8.6334,  9.0683,  9.5893, 10.0571, 10.5855, 11.0599,
        11.5883, 12.0627, 12.5911, 13.0655, 13.5940, 14.0683, 14.5968, 15.0711,
        15.5996, 16.0732, 16.5992, 17.0714, 17.5853, 18.0060, 18.5179, 18.9386,
        19.3594, 19.8713, 20.1355, 20.6127, 20.8636, 21.1145, 21.5866, 21.7428,
        21.8990, 22.0552, 22.4761, 22.6323], grad_fn=<CatBackward0>)