# Data

In [None]:
import collections
import torch
import torchvision
import torchvision.transforms as transforms


# Settings shared by both data loaders.
bs = 32
n_workers = 4

# Per-image preprocessing: convert to a tensor, then normalize each of the
# three channels with mean 0.5 and std 0.5.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split (downloaded into ./data if missing), shuffled.
trainset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    transform=data_transform,
    download=True,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=bs,
    shuffle=True,
    num_workers=n_workers,
)

# CIFAR-10 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    transform=data_transform,
    download=True,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=bs,
    shuffle=False,
    num_workers=n_workers,
)

# Loaders keyed by stage name; note the test split is what gets registered
# under the "valid" key.
loaders = collections.OrderedDict([
    ("train", trainloader),
    ("valid", testloader),
])

# Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small LeNet-style CNN for 3-channel 32x32 images (e.g. CIFAR-10).

    Two (5x5 conv -> ReLU -> 2x2 max-pool) stages are followed by three fully
    connected layers. ``fc1`` expects a 16 x 5 x 5 feature map per sample,
    which is what the conv stack yields for 32x32 inputs
    (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        # A single pooling module is reused after both conv layers.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, num_classes) with unnormalized logits.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce the
                16 x 5 x 5 feature map that ``fc1`` expects.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension fixed. Using
        # view(-1, 400) here would let the batch dimension silently absorb a
        # feature-size mismatch (wrong input resolution) instead of failing.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Model, criterion, optimizer

In [None]:
# Network, loss and optimizer (Adam with its default hyperparameters).
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Alternative schedulers:
# scheduler = None  # for OneCycle usage
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 8], gamma=0.3)

# The scheduler callback in this notebook is configured with
# reduce_metric="precision01", a metric where higher is better, so plateau
# detection must run in "max" mode. With the default mode="min" the LR would
# be halved while precision is still improving.
# NOTE(review): assumes the callback passes that metric to scheduler.step();
# confirm against the installed catalyst version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=2, verbose=True)

# Callbacks

In [None]:
import collections
from catalyst.dl.callbacks import (
    LossCallback, 
    Logger, TensorboardLogger,
    OptimizerCallback, SchedulerCallback, CheckpointCallback, 
    PrecisionCallback, OneCycleLR)

# Callback configuration -- the only tricky part of the notebook.
# n_epochs is the number of training epochs; logdir is the directory handed
# to the runner as `logdir`.
n_epochs = 10
logdir = "./logs/cifar_simple_notebook"

# Callbacks are registered under string keys, in the order listed here.
callbacks = collections.OrderedDict([
    ("loss", LossCallback()),
    ("optimizer", OptimizerCallback()),
    # presumably precision@k for k in (1, 3, 5); "precision01" below would
    # then be the k=1 entry -- confirm against catalyst's PrecisionCallback.
    ("precision", PrecisionCallback(precision_args=[1, 3, 5])),

    # OneCycle custom scheduler callback (alternative; pairs with
    # `scheduler = None`):
    # ("scheduler", OneCycleLR(
    #     cycle_len=n_epochs,
    #     div=3, cut_div=4, momentum_range=(0.95, 0.85))),

    # PyTorch scheduler callback, configured to use the "precision01" metric.
    ("scheduler", SchedulerCallback(reduce_metric="precision01")),

    ("saver", CheckpointCallback()),
    ("logger", Logger()),
    ("tflogger", TensorboardLogger()),
])

# Train

In [None]:
from catalyst.dl.runner import SupervisedModelRunner

# Bundle the model, loss, optimizer and LR scheduler into a catalyst runner.
runner = SupervisedModelRunner(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
)

# Train for n_epochs using the configured loaders and callbacks; logdir is
# forwarded to the runner unchanged.
runner.train(
    loaders=loaders,
    callbacks=callbacks,
    logdir=logdir,
    epochs=n_epochs,
    verbose=True,
)

In [None]:
# To view the training graphs, run in a shell: `tensorboard --logdir=./logs/cifar_simple_notebook`