In [2]:
from models import FixedTargetEGNCA
from argparse import Namespace
import datetime
import torch
import time

# * format the date and time in a readable format (embedded in the model name below)
now = datetime.datetime.now()
date_time = now.strftime('%Y-%m-%d@%H-%M-%S')

# * fix the torch RNG seed and the default floating point dtype
torch.manual_seed(42)
torch.set_default_dtype(torch.float32)
pool_size = 64

# * create args
args = Namespace()
args.save_to = 'logs'
args.graph = 'line'
args.size = 8
# * model name combines graph type, graph size and the timestamp computed above
args.model_name = f'test-{args.graph}{args.size}-{date_time}'
args.scale = 1.0
args.coords_dim = 3
args.hidden_dim = 8
args.message_dim = 16
args.has_attention = True
# * fall back to the CPU when no CUDA device is available instead of failing
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.pool_size = pool_size
args.batch_size = 2
args.epochs = 25_000
args.min_steps = 16
args.max_steps = 32
args.info_rate = 250
# * NOTE(review): lr / b1 / b2 / wd look like optimizer settings and
# * patience_sch / factor_sch like LR-scheduler settings -- confirm in models.py
args.lr = 1e-3
args.b1 = 0.9
args.b2 = 0.999
args.wd = 1e-5
args.patience_sch = 500
args.factor_sch = 0.5
args.density_rand_edge = 1.0
args.reset_at = 1_000

# * create and train model
fixed_target_egnca = FixedTargetEGNCA(args)
# * perf_counter is monotonic, so the measured duration is not affected by
# * system clock adjustments that happen while training runs
tik = time.perf_counter()
fixed_target_egnca.train_model(
    train_verbose=False,
    view_random_graphs=True,
)
tok = time.perf_counter()
print('train time: %d (s)' % (tok - tik))

[models.py] target_coords.shape: torch.Size([8, 3]), target_edges.shape: torch.Size([2, 14])
[models.py] training new model -- creating starting seed
rand_edges.shape: torch.Size([2, 56]), rand_edges:
tensor([[ 7,  6,  3,  6,  5,  4,  6,  7,  5,  1,  5,  3,  7,  2,  7,  6,  7,  2,
          7,  5,  5,  4,  4,  7,  4,  6,  6,  3, 15, 14, 11, 14, 13, 12, 14, 15,
         13,  9, 13, 11, 15, 10, 15, 14, 15, 10, 15, 13, 13, 12, 12, 15, 12, 14,
         14, 11],
        [ 0,  3,  1,  4,  2,  1,  2,  4,  1,  0,  4,  2,  1,  0,  3,  5,  2,  1,
          5,  3,  0,  2,  0,  6,  3,  1,  0,  0,  8, 11,  9, 12, 10,  9, 10, 12,
          9,  8, 12, 10,  9,  8, 11, 13, 10,  9, 13, 11,  8, 10,  8, 14, 11,  9,
          8,  8]], device='cuda:0')
edges.shape: torch.Size([2, 28]), edges:
tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  8,  9,  9, 10,
         10, 11, 11, 12, 12, 13, 13, 14, 14, 15],
        [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  9,  8, 10,  9,
      

KeyboardInterrupt: 

In [None]:
from utils.visualize import create_evolve_figure
from models import FixedTargetEGNCA
from argparse import Namespace
import torch
import json

import plotly

torch.manual_seed(42)

# * run directory holding args.txt and the checkpoint to visualize
path = 'logs/test-line16-2024-09-25@20-48-13'
best = 1  # NOTE(review): not read anywhere in this cell -- confirm whether another cell uses it

# * rebuild the args namespace from the JSON stored in the run directory
with open(f'{path}/args.txt', 'r') as f:
    args = json.load(f)
args = Namespace(**args)

# * the stored device may not exist on this machine (e.g. a 'cuda' run opened
# * on a CPU-only host); fall back to the CPU so the checkpoint can still be loaded
if str(args.device).startswith('cuda') and not torch.cuda.is_available():
    args.device = 'cpu'

# * instantiate the model and restore the checkpoint weights onto args.device
model = FixedTargetEGNCA(args, new_model=False)
state_dict = torch.load(f'{path}/best-22192.pt', weights_only=True, map_location=args.device)
model.load_state_dict(state_dict)

# * build the figure from the restored model
fig = create_evolve_figure(
    model,
    num_steps=500,
    frame_duration=50,
    show_edges=True,
)

# * enable plotly's offline notebook mode before displaying the figure
plotly.offline.init_notebook_mode()
fig.show()