In [None]:
import os
import shutil
import sys
from pathlib import Path

import torch

sys.path.append('./src')

from model import build_resnet
from data import build_dataloader

import smdebug.pytorch as smd
from smdebug.core.reduction_config import ReductionConfig
from smdebug.core.save_config import SaveConfig
from smdebug.core.collection import CollectionKeys
from smdebug.core.config_constants import DEFAULT_CONFIG_FILE_PATH
from smdebug.core.config_constants import DEFAULT_CONFIG_FILE_PATH

In [None]:
# Root of the ImageNet copy in S3. Override with the IMAGENET_S3_ROOT environment
# variable to point at a different bucket; the default keeps the original location.
IMAGENET_S3_ROOT = os.environ.get(
    "IMAGENET_S3_ROOT", "s3://jbsnyder-sagemaker-us-east/data/imagenet"
).rstrip("/")
train_data_src = f"{IMAGENET_S3_ROOT}/train"
val_data_src = f"{IMAGENET_S3_ROOT}/val"

In [None]:
# Training hyperparameters, kept in one place so they are easy to find and tune.
SEED = 0
BATCH_SIZE = 64
LEARNING_RATE = 0.004

# Seed torch's RNGs (CPU and CUDA) before building the model so weight initialisation
# is reproducible. Any torch-based shuffling inside build_dataloader presumably
# benefits too; verify against src/data.py.
torch.manual_seed(SEED)

model = build_resnet(resnet_version=50, num_classes=1000)
model.to("cuda")
train_dataloader = build_dataloader(train_data_src, batch_size=BATCH_SIZE, num_workers=0, train=True)
val_dataloader = build_dataloader(val_data_src, batch_size=BATCH_SIZE, num_workers=0, train=False)
opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
loss_func = torch.nn.CrossEntropyLoss()

In [None]:
out_dir = './smdebugger'
if out_dir:
    # Start from an empty directory so this run's tensors are not mixed with a previous run's.
    shutil.rmtree(out_dir, ignore_errors=True)
    assert not Path(out_dir).exists()

# Without an smdebug JSON config at DEFAULT_CONFIG_FILE_PATH, configure the hook locally;
# otherwise build it from that file (presumably supplied by a managed training job).
if not Path(DEFAULT_CONFIG_FILE_PATH).exists():
    hook = smd.Hook(
        out_dir=out_dir,
        export_tensorboard=True,
        # Reduce saved tensors to their mean and L2 norm.
        reduction_config=ReductionConfig(reductions=['mean'], norms=['l2']),
        # Save every 25 steps.
        save_config=SaveConfig(save_interval=25),
        include_regex=None,
        include_collections=[
            CollectionKeys.LOSSES,
            CollectionKeys.GRADIENTS,
            CollectionKeys.WEIGHTS,
        ],
        save_all=False,
        include_workers="one",
    )
else:
    hook = smd.Hook.create_from_json_file()

# Attach the hook to the network and to the loss module.
hook.register_module(model)
hook.register_loss(loss_func)

In [None]:
def train_step(batch, model, opt, scaler=None):
    """Run one float16-autocast optimisation step on a single batch.

    Args:
        batch: ``(x, y, idx)`` tuple from the dataloader; ``idx`` is unused here.
        model: network to train; must live on CUDA because inputs are moved there.
        opt: optimizer over ``model``'s parameters.
        scaler: optional ``torch.cuda.amp.GradScaler`` (or ``torch.amp.GradScaler("cuda")``).
            float16 autocast can underflow small gradients to zero; a scaler applies
            dynamic loss scaling to prevent that. With a scaler, gradients seen by backward
            hooks (e.g. the smdebug hook) are multiplied by the current loss scale.
            ``None`` (default) keeps the original unscaled behaviour.

    Returns:
        The batch loss as a detached scalar tensor.
    """
    opt.zero_grad()
    x, y, _idx = batch
    x = x.to("cuda")
    y = y.to("cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        pred = model(x)
        loss = loss_func(pred, y)
    if scaler is None:
        loss.backward()
        opt.step()
    else:
        scaler.scale(loss).backward()
        # scaler.step unscales gradients and skips the update if any are inf/NaN.
        scaler.step(opt)
        scaler.update()
    # Detach so callers that keep the return value do not hold onto autograd history.
    return loss.detach()

def train_epoch(dataloader, model, optimizer=None, scaler=None, log_every=10):
    """Train ``model`` for one pass over ``dataloader``, printing the loss periodically.

    Args:
        dataloader: iterable of ``(x, y, idx)`` batches.
        model: network to train; put into training mode before the loop.
        optimizer: optimizer to step. Defaults to the notebook-level ``opt``, so
            existing ``train_epoch(dataloader, model)`` calls behave as before.
        scaler: optional GradScaler, forwarded to ``train_step`` only when given.
        log_every: print the loss every ``log_every`` batches (default 10).
    """
    # Fall back to the global optimizer for backward compatibility, but prefer passing it.
    if optimizer is None:
        optimizer = opt
    step_kwargs = {} if scaler is None else {"scaler": scaler}
    # Guarantee training-mode layer behaviour even if the model was left in eval mode.
    model.train()
    for batch_idx, batch in enumerate(dataloader):
        loss = train_step(batch, model, optimizer, **step_kwargs)
        if batch_idx % log_every == 0:
            print(f"Step: {batch_idx} \t Loss {float(loss)}")

def val_step(batch, model):
    """Score one validation batch.

    Args:
        batch: ``(x, y, idx)`` tuple from the dataloader; ``idx`` is unused here.
        model: network to evaluate; must live on CUDA because inputs are moved there.

    Returns:
        A float32 CPU tensor holding 1.0 where ``argmax(preds, 1)`` equals the label
        and 0.0 elsewhere.
    """
    x, y, _idx = batch
    x = x.to("cuda")
    y = y.to("cuda")
    # No parameters are updated during validation, so skip building the autograd graph.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        preds = model(x)
        # The value is unused, but the call is kept because loss_func is registered with
        # the smdebug hook in an earlier cell; presumably that is how validation loss gets
        # recorded. Verify before removing.
        loss_func(preds, y)
    correct = y == torch.argmax(preds, 1)
    return correct.float().cpu()

def val_epoch(dataloader, model):
    """Evaluate ``model`` over every batch of ``dataloader``.

    The model is switched to eval mode for the pass, so layers such as BatchNorm/Dropout
    use inference behaviour. Its previous train/eval mode is restored afterwards.

    Returns:
        A 1-D float tensor of per-sample correctness (1.0 / 0.0); its mean is the
        accuracy. An empty dataloader yields an empty tensor.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            per_batch = [val_step(batch, model) for batch in dataloader]
    finally:
        model.train(was_training)
    if not per_batch:
        return torch.empty(0)
    # torch.cat (not torch.stack): the final batch may be smaller than the others.
    return torch.cat(per_batch).reshape(-1)

In [None]:
# Tag steps with the smdebug mode so training and validation tensors are kept apart.
hook.set_mode(smd.modes.TRAIN)
train_epoch(train_dataloader, model)
hook.set_mode(smd.modes.EVAL)
val_correct = val_epoch(val_dataloader, model)
# Show the summary accuracy rather than the full per-sample tensor.
val_accuracy = val_correct.mean().item()
val_accuracy