In [None]:
import torch
import torch.nn.functional as F

import wandb

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
from ignite.metrics import Accuracy

from gechebnet.data.dataloader import get_datalist_mnist, get_datalist_rotated_mnist, get_dataloader
from gechebnet.data.dataset import download_mnist, download_rotated_mnist
from gechebnet.graph.graph import GraphData
from gechebnet.model.model import get_model
from gechebnet.utils import prepare_batch, track_loss, track_metrics

In [None]:
# authenticate this session with Weights & Biases before any run is started
wandb.login()

In [None]:
# start a Weights & Biases run under the "gechebnet" project
wandb.init(project="gechebnet")

In [None]:
# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# download mnist and rotated mnist (presumably under `data_path` — confirm);
# the returned values are later handed to the datalist builders
data_path = "data"
mnist_processed_path = download_mnist(data_path)
rotated_mnist_processed_path = download_rotated_mnist(data_path)

In [None]:
# create the graph embedding data
# eps and xi are combined into the `lambdas` triple passed to GraphData
eps, xi = 0.25, 0.01
# nx-by-ny grid (grid_size) with nz layers (num_layers)
nx, ny, nz = 28, 28, 6
# single graph description shared by every dataset built in this notebook
graph_data = GraphData(
    grid_size=(nx, ny),
    num_layers=nz,
    static_compression=("edge", 0.5),
    self_loop=True,
    weight_threshold=0.3,
    sigma=1.0,
    lambdas=(xi / eps, xi, 1.0),
)

In [None]:
# training batches: MNIST train split, reshuffled by the loader
train_mnist_loader = get_dataloader(
    get_datalist_mnist(graph_data, mnist_processed_path, train=True),
    batch_size=16,
    shuffle=True,
)

In [None]:
# get test dataloader from mnist
# Evaluation does not need shuffling: the accuracy metric does not depend on
# sample order, so a fixed order avoids a needless reshuffle on every
# evaluation pass and keeps those passes reproducible.
test_mnist_loader = get_dataloader(
    get_datalist_mnist(graph_data, mnist_processed_path, train=False),
    batch_size=16,
    shuffle=False,
)

In [None]:
# get test dataloader from rotated mnist
# Evaluation does not need shuffling: the accuracy metric does not depend on
# sample order, so a fixed order avoids a needless reshuffle on every
# evaluation pass and keeps those passes reproducible.
test_rotated_mnist_loader = get_dataloader(
    get_datalist_rotated_mnist(graph_data, rotated_mnist_processed_path, train=False),
    batch_size=16,
    shuffle=False,
)

In [None]:
# name and keyword configuration handed to get_model
model_name = "chebnet"
model_params = dict(
    K=10,
    num_layers=2,
    input_dim=1,
    output_dim=10,
    hidden_dim=10,
)


# build the network on the selected device and optimise it with Adam
# using the optimiser's default hyperparameters
model = get_model(model_name, model_params, device)
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
# NOTE(review): F.nll_loss expects log-probabilities, so the model presumably
# ends in a log_softmax — confirm against get_model
loss_fn = F.nll_loss
# one accuracy metric per evaluation set; the keys name the reported metrics
mnist_metrics = dict(mnist_acc=Accuracy())
rot_mnist_metrics = dict(rot_mnist_acc=Accuracy())

In [None]:
# ignite engines: one trainer plus one evaluator per test set, all sharing
# the same model, device and batch-preparation function
trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_fn,
    device=device,
    prepare_batch=prepare_batch,
)
mnist_evaluator = create_supervised_evaluator(
    model,
    metrics=mnist_metrics,
    device=device,
    prepare_batch=prepare_batch,
)
rot_mnist_evaluator = create_supervised_evaluator(
    model,
    metrics=rot_mnist_metrics,
    device=device,
    prepare_batch=prepare_batch,
)

# transient progress bars (persist=False) for each engine
ProgressBar(persist=False, desc="Training").attach(trainer)
ProgressBar(persist=False, desc="Evaluation").attach(mnist_evaluator)
ProgressBar(persist=False, desc="Evaluation").attach(rot_mnist_evaluator)

In [None]:
# tracking hooks on the trainer: track_loss fires after every iteration,
# track_metrics fires at the end of every epoch, once per evaluator/loader pair
_ = trainer.add_event_handler(Events.ITERATION_COMPLETED, track_loss)
_ = trainer.add_event_handler(
    Events.EPOCH_COMPLETED,
    track_metrics,
    mnist_evaluator,
    test_mnist_loader,
    "test mnist",
)
# NOTE(review): the tag above contains a space while the one below uses
# underscores — confirm the mismatch is intended
_ = trainer.add_event_handler(
    Events.EPOCH_COMPLETED,
    track_metrics,
    rot_mnist_evaluator,
    test_rotated_mnist_loader,
    "test_rotated_mnist",
)

In [None]:
# checkpointing: keep a single model checkpoint (n_saved=1), scored by the
# "mnist_acc" metric of the MNIST evaluator
models_path = "models"
eval_to_save = dict(model=model)
best_handler = Checkpoint(
    eval_to_save,
    DiskSaver(models_path, create_dir=True, require_empty=False),
    filename_prefix=f"best-{model_name}",
    score_name="mnist_acc",
    score_function=lambda evaluator: evaluator.state.metrics["mnist_acc"],
    n_saved=1,
    # take the step used in the checkpoint from the trainer, not the evaluator
    global_step_transform=global_step_from_engine(trainer),
)
# scoring happens each time an MNIST evaluation run completes
_ = mnist_evaluator.add_event_handler(Events.COMPLETED, best_handler)

In [None]:
# launch training on the MNIST training loader for a fixed number of epochs;
# the handlers attached to the trainer and evaluator fire during this run
max_epochs = 20
trainer.run(data=train_mnist_loader, max_epochs=max_epochs)