In [None]:
from gechebnet.model.chebnet import ChebNet


In [None]:
# Reference ChebNet: Chebyshev order K, nx3 layers, 1 input channel, 10 classes,
# hidden width 16, "max" edge reduction.
K, nx3 = 7, 6
input_dim, output_dim, hidden_dim = 1, 10, 16
edge_red = "max"
net = ChebNet(K, nx3, input_dim, output_dim, hidden_dim, edge_red)

In [None]:
# Capacity search: widen the hidden layers one step at a time until the model's
# `capacity` reaches `num_param` (presumably a parameter budget — TODO confirm what
# ChebNet.capacity counts).
# Invariant: `model` is always built with the current `hidden_dim`, so after the loop
# the two describe the same network.
# NOTE: this overwrites the global `hidden_dim` that was used to build `net` above.
num_param = 10000
hidden_dim = 1
model = ChebNet(K, nx3, input_dim, output_dim, hidden_dim, edge_red)
while model.capacity < num_param:
    hidden_dim += 1
    # Build with exactly `hidden_dim` (not `hidden_dim + 1`) so no width is skipped
    # and the variable matches the model that is kept.
    model = ChebNet(K, nx3, input_dim, output_dim, hidden_dim, edge_red)

In [None]:
def weight_field(graph_data, node_idx, ax=None):
    """
    Gather the weighted neighborhood of a node as a point cloud and optionally plot it.

    Args:
        graph_data: graph container whose ``node_pos`` is indexed as
            ``node_pos[neighbors, k]`` for k = 0, 1, 2 (three coordinates per node).
        node_idx: index of the node whose neighborhood is inspected.
        ax: optional axes whose ``scatter`` accepts three coordinate sequences
            (presumably a matplotlib 3D axes — assumption). When None, nothing is drawn.

    Returns:
        tuple ``(point_cloud, im)``. ``point_cloud`` is a float tensor of shape
        ``(num_neighbors, 4)`` holding the three coordinates in columns 0-2 and the
        neighbor weight in column 3. ``im`` is the object returned by ``ax.scatter``,
        or None when ``ax`` is None.
    """
    # NOTE(review): assumes get_neighbors returns one weight per neighbor.
    neighbors, weights = get_neighbors(graph_data, node_idx)

    point_cloud = torch.zeros(len(neighbors), 4)
    for dim in range(3):
        point_cloud[:, dim] = torch.as_tensor(graph_data.node_pos[neighbors, dim])
    point_cloud[:, 3] = torch.as_tensor(weights)

    im = None
    if ax is not None:
        im = ax.scatter(
            graph_data.node_pos[neighbors, 0],
            graph_data.node_pos[neighbors, 1],
            graph_data.node_pos[neighbors, 2],
            c=weights,
            s=50,
            alpha=0.5,
        )
    return point_cloud, im

In [None]:
# Capacity reported by the ChebNet built in the first model cell (hidden_dim=16);
# compare with the hand-derived formula below.
net.capacity

In [None]:
def chebnet_capacity(K, input_dim, output_dim, hidden_dim):
    """Hand-derived capacity formula for the ChebNet configuration used in this notebook.

    Wrapped in a function so it is evaluated for an explicit ``hidden_dim``, not
    whatever the global ``hidden_dim`` holds when the cell runs. The terms are copied
    verbatim from the original expression; TODO confirm they match ChebNet's layers.
    """
    return (
        10 * hidden_dim
        + 4 * (hidden_dim * hidden_dim * K + hidden_dim)
        + input_dim * hidden_dim * K
        + hidden_dim
        + hidden_dim * output_dim * K
        + output_dim
    )


# `net` was built with hidden_dim=16, but the capacity-search cell has since
# overwritten the global `hidden_dim`, so pass the original width explicitly.
# Display (formula, reported capacity) side by side for comparison.
net_hidden_dim = 16
chebnet_capacity(K, input_dim, output_dim, net_hidden_dim), net.capacity

In [None]:
# Authenticate with Weights & Biases so the run started below can log.
wandb.login()

In [None]:
# Start a wandb run in the "gechebnet" project (the tracking handlers attached later
# presumably log to it).
wandb.init(project="gechebnet")

In [None]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Download MNIST and rotated MNIST under a common root and keep their processed paths.
# Tuple elements are evaluated left to right, so MNIST is fetched first.
data_path = "data"
mnist_processed_path, rotated_mnist_processed_path = (
    download_mnist(data_path),
    download_rotated_mnist(data_path),
)

In [None]:
# Graph embedding of the (nx1, nx2) image grid with nx3 layers.
# The anisotropy weights are derived from eps and xi: (xi / eps, xi, 1.0).
eps = 0.25
xi = 0.01
nx1, nx2, nx3 = 28, 28, 6
graph_data = GraphData(
    grid_size=(nx1, nx2),
    num_layers=nx3,
    static_compression=("edge", 0.5),
    self_loop=True,
    weight_threshold=0.3,
    sigma=1.0,
    lambdas=(xi / eps, xi, 1.0),
)

In [None]:
# Shuffled training batches of 16 from the MNIST train split.
train_mnist_loader = get_dataloader(
    get_datalist_mnist(graph_data, mnist_processed_path, train=True),
    shuffle=True,
    batch_size=16,
)

In [None]:
# Test dataloader for the MNIST test split.
# No shuffling: accuracy does not depend on sample order, and a fixed order keeps
# evaluation runs reproducible.
test_mnist_loader = get_dataloader(
    get_datalist_mnist(graph_data, mnist_processed_path, train=False),
    batch_size=16,
    shuffle=False,
)

In [None]:
# Test dataloader for the rotated-MNIST test split.
# No shuffling: accuracy does not depend on sample order, and a fixed order keeps
# evaluation runs reproducible.
test_rotated_mnist_loader = get_dataloader(
    get_datalist_rotated_mnist(graph_data, rotated_mnist_processed_path, train=False),
    batch_size=16,
    shuffle=False,
)

In [None]:
# Build the classifier through the project factory and pair it with an Adam optimizer
# using default hyperparameters. Note: this replaces the `model` from the capacity search.
model_name = "chebnet"
model_params = dict(K=10, num_layers=2, input_dim=1, output_dim=10, hidden_dim=10)

model = get_model(model_name, model_params, device)
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# Negative log-likelihood loss (F.nll_loss expects log-probabilities, so this assumes
# the model ends in a log-softmax — TODO confirm). Each evaluator gets its own
# Accuracy instance.
loss_fn = F.nll_loss
mnist_metrics = dict(mnist_acc=Accuracy())
rot_mnist_metrics = dict(rot_mnist_acc=Accuracy())

In [None]:
# Ignite engines: one trainer plus one evaluator per test set. All of them share
# `model`, `device` and `prepare_batch`, and each gets a transient progress bar.
trainer = create_supervised_trainer(model, optimizer, loss_fn, device, prepare_batch=prepare_batch)
mnist_evaluator = create_supervised_evaluator(model, mnist_metrics, device, prepare_batch=prepare_batch)
rot_mnist_evaluator = create_supervised_evaluator(model, rot_mnist_metrics, device, prepare_batch=prepare_batch)

for _engine, _desc in (
    (trainer, "Training"),
    (mnist_evaluator, "Evaluation"),
    (rot_mnist_evaluator, "Evaluation"),
):
    ProgressBar(persist=False, desc=_desc).attach(_engine)

In [None]:
# Track training with wandb: `track_loss` after every iteration, and `track_metrics`
# after every epoch for each (evaluator, loader, label) triple. Presumably
# `track_metrics` runs the evaluator on the loader and logs under the label — verify.
# Both labels use the same snake_case form so the two test streams are named consistently.
_ = trainer.add_event_handler(Events.ITERATION_COMPLETED, track_loss)
_ = trainer.add_event_handler(
    Events.EPOCH_COMPLETED, track_metrics, mnist_evaluator, test_mnist_loader, "test_mnist"
)
_ = trainer.add_event_handler(
    Events.EPOCH_COMPLETED, track_metrics, rot_mnist_evaluator, test_rotated_mnist_loader, "test_rotated_mnist"
)

In [None]:
# Keep only the single best checkpoint of `model`, ranked by the evaluator's "mnist_acc".
# The handler fires each time `mnist_evaluator` completes; the step number in the
# file name is taken from the trainer via global_step_from_engine.
models_path = "models"
eval_to_save = {"model": model}
best_handler = Checkpoint(
    to_save=eval_to_save,
    save_handler=DiskSaver(models_path, require_empty=False, create_dir=True),
    filename_prefix=f"best-{model_name}",
    score_name="mnist_acc",
    score_function=lambda engine: engine.state.metrics["mnist_acc"],
    global_step_transform=global_step_from_engine(trainer),
    n_saved=1,
)
_ = mnist_evaluator.add_event_handler(Events.COMPLETED, best_handler)

In [None]:
# Train for `max_epochs` epochs on MNIST. The best checkpoint is written along the way
# by `best_handler`, which is attached to the MNIST evaluator.
max_epochs = 20
trainer.run(data=train_mnist_loader, max_epochs=max_epochs)