In [None]:
import sys
sys.path.append("../")

In [None]:
import torch
import numpy as np
import random
import wandb
import torch.profiler as tpf

In [None]:
from kdmc.parser import parse_args
from kdmc.data.core import create_dataloaders, get_datasets
from kdmc.train.core import create_model, create_scheduler, get_trainer

# Train

In [None]:
# Detect when NaN appears. Autograd anomaly mode checks backward results for NaN and
# reports the forward op responsible. It adds overhead to every step, including the
# profiled epoch below, so leave it off for timing-sensitive runs.
torch.autograd.set_detect_anomaly(mode=True)

In [None]:
# Experiment configuration. Every list element needs its own trailing comma: adjacent
# string literals are silently concatenated by Python, e.g. '--profile' '--root_path'
# would become the single token '--profile--root_path'.
# NOTE(review): assumes --profile is a boolean flag (takes no value) — confirm in kdmc.parser.
args = parse_args([
    '--id', 'test', '--loss', 'std', '--dataset', 'sbasic', '--arch', 'resnet',
    '--batch_size', '512', '--seed', '0', '--grad_clip', '1e3', '--profile',
    '--root_path', '../../'])

# Experiment tracking. joinpath is used on root_path, so parse_args is assumed to
# return it as a pathlib.Path.
wandb.init(project=f"kdmc_{args.dataset}", name=args.id, dir=args.root_path.joinpath("wandb"))
wandb.config.update(args)

# Seed every RNG the pipeline may draw from (torch, numpy, stdlib random).
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

In [None]:
# Data
print('==> Preparing data..')
train_and_test = get_datasets(args)
trainset, testset = train_and_test
trainloader, testloader = create_dataloaders(args, *train_and_test)
# Copy the dataset's time_samples onto args. Presumably it is read later by
# create_model / the trainer — TODO confirm in kdmc. The `.dataset` hop suggests
# trainset wraps an underlying dataset (e.g. a Subset) — verify.
args.time_samples = trainset.dataset.time_samples

In [None]:
# Model
print('==> Building model..')
net = create_model(args)

# SGD with momentum and weight decay. The scheduler factory also returns schd_updt,
# which is handed to get_trainer (presumably it controls how often the scheduler
# steps — confirm in kdmc.train).
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=5e-4,
)
steps_per_epoch = len(trainloader)
scheduler, schd_updt = create_scheduler(optimizer, args, args.n_epochs, steps_per_epoch)

In [None]:
start_epoch = 1

trainer = get_trainer(args, net, trainloader, testloader, optimizer, scheduler, schd_updt, args.save_freq)

# Ask for CUDA tracing only when a GPU is present, so the cell also runs on CPU-only machines.
profiler_activities = [tpf.ProfilerActivity.CPU]
if torch.cuda.is_available():
    profiler_activities.append(tpf.ProfilerActivity.CUDA)

for epoch in range(start_epoch, args.n_epochs + 1):
    if epoch == start_epoch:
        # Profile only the first epoch. This keeps the trace (shapes, memory, stacks)
        # small and keeps profiler overhead off the remaining epochs.
        with tpf.profile(
            activities=profiler_activities,
            on_trace_ready=tpf.tensorboard_trace_handler(dir_name=str(args.root_path.joinpath('profiler'))),
            record_shapes=True,  # record shapes of operator inputs
            profile_memory=True,  # record tensor memory allocation
            with_stack=True,  # record stack traces of where ops are created
        ) as prof:
            trainer.train(epoch, profiler=prof)
    else:
        # Remaining epochs train normally, without the profiler.
        # NOTE(review): assumes trainer.train treats profiler=None as "no profiling" — confirm in kdmc.train.
        trainer.train(epoch, profiler=None)

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [None]:
%tensorboard --logdir ../../profiler

In [None]:
# Model: rebuild the architecture and restore the last checkpoint for this run
# (presumably written by the trainer above according to save_freq — confirm).
print('==> Building model..')
net2 = create_model(args)
ckpt_path = args.root_path.joinpath(f"checkpoint/{args.dataset}/{args.arch}/{args.id}/{args.seed}/ckpt_last.pth")
# map_location places tensors directly on args.device, so a checkpoint saved on a
# different device (e.g. a GPU run) still loads here.
ckpt = torch.load(ckpt_path, map_location=args.device)
net2.load_state_dict(ckpt["net"])
net2.to(args.device)
net2.eval();  # trailing ';' suppresses the (large) module repr

In [None]:
# Evaluation-only trainer around the reloaded network (no optimizer / scheduler).
# Then stub out get_geometric_metrics on both trainers so it returns an empty dict
# (presumably so test() skips those metrics — confirm in kdmc.train).
trainer2 = get_trainer(args, net2, trainloader, testloader, None, None, None, args.save_freq)
for _trainer in (trainer, trainer2):
    _trainer.get_geometric_metrics = lambda: {}

In [None]:
# Evaluate the in-memory network trained above, via its trainer.
res = trainer.test()

In [None]:
# Evaluate the network reloaded from ckpt_last.pth.
res2 = trainer2.test()

In [None]:
# Display the in-memory network's test results.
res

In [None]:
# Checkpoint round-trip check: the reloaded network must reproduce the in-memory
# network's 'clean' test results exactly. torch.equal also requires identical shapes,
# whereas `==` followed by `.all()` broadcasts mismatched shapes and raises on plain
# Python scalars. The assert makes Restart & Run All fail loudly on a mismatch.
clean_match = torch.equal(torch.as_tensor(res2['clean']), torch.as_tensor(res['clean']))
assert clean_match, "reloaded checkpoint does not reproduce the trained model's 'clean' results"
clean_match