In [1]:
#export
import os
from collections import ChainMap

from visdom import Visdom

from loop.callbacks import Callback, Order

In [2]:
#export
class VisdomDashboard(Callback):
    """Stream training progress to a Visdom server.

    During training this callback plots:
      * the running batch loss of every phase (one window per phase name),
        sampled every ``batch_freq`` batches;
      * optionally, the ``lr`` of every optimizer parameter group on batches
        where ``phase.grad`` is true;
      * at the end of each epoch, every entry of ``phase.last_metrics``,
        grouped into one window per metric with one line per phase.

    Parameters:
        show_opt_params: If True, also plot each optimizer param group's ``lr``.
        batch_freq: Plot the batch loss only when ``phase.batch_index`` is a
            multiple of this value. Must be a positive integer.
        **visdom_conf: Connection settings read in ``training_started``:
            ``server`` (default ``'0.0.0.0'``), ``port`` (default ``9090``),
            ``username`` and ``password``.

    Raises:
        ValueError: If ``batch_freq`` is smaller than 1.
    """

    order = Order.Logging()

    def __init__(self, show_opt_params: bool=False, batch_freq: int=1, **visdom_conf):
        if batch_freq < 1:
            # `x % 0` in `batch_ended` would otherwise raise ZeroDivisionError
            # in the middle of training; fail fast with a clear message instead.
            raise ValueError(f'batch_freq must be a positive integer, got {batch_freq!r}')
        self.show_opt_params = show_opt_params
        self.batch_freq = batch_freq
        self.visdom_conf = visdom_conf
        self.vis = None  # Visdom client; created in `training_started`

    def training_started(self, **kwargs):
        """Open the Visdom client connection.

        Explicit ``visdom_conf`` entries take precedence. ``username`` and
        ``password`` fall back to the ``VISDOM_USERNAME`` / ``VISDOM_PASSWORD``
        environment variables so credentials need not be hard-coded.
        """
        username = self.visdom_conf.get('username', os.environ.get('VISDOM_USERNAME'))
        password = self.visdom_conf.get('password', os.environ.get('VISDOM_PASSWORD'))
        server = self.visdom_conf.get('server', '0.0.0.0')
        port = int(self.visdom_conf.get('port', 9090))
        self.vis = Visdom(server=server, port=port, username=username, password=password)

    def batch_ended(self, phase, **kwargs):
        """Append the batch loss (and optionally learning rates) to the plots.

        Assumes ``training_started`` has already run, so that ``self.vis`` is set.
        """
        x = phase.batch_index

        if phase.grad and self.show_opt_params:
            # NOTE(review): `self.group` is not assigned in this class; presumably
            # the owning callbacks group sets it and exposes the optimizer — confirm.
            opt = self.group.opt
            for i, params in enumerate(opt.param_groups):
                title = f'LR [group: {i}]'
                self.vis.line(
                    X=[x], Y=[params['lr']], win=title, name='lr',
                    opts=dict(title=title), update='append')

        if x % self.batch_freq == 0:
            self.vis.line(
                X=[x], Y=[phase.batch_loss], win=phase.name, name='batch_loss', 
                opts=dict(title=f'Running Batch Loss [{phase.name}]'), update='append')

    def epoch_ended(self, phases, epoch, **kwargs):
        """Plot each phase's epoch-level metrics, one window per metric.

        Metric keys are expected to look like ``<phase>_<metric>`` (for example
        ``train_loss`` or ``valid_acc``). Lines from different phases share the
        window named after the metric. Extra keyword arguments are accepted and
        ignored, for consistency with the other hooks.
        """
        metrics = dict(ChainMap(*[phase.last_metrics for phase in phases]))
        for name, value in metrics.items():
            # Split on the first underscore only, so metric names that contain
            # underscores themselves (e.g. `valid_top_5`) still unpack cleanly.
            phase_name, metric_name = name.split('_', 1)
            self.vis.line(
                X=[epoch], Y=[value], win=metric_name, name=phase_name,
                opts=dict(title=metric_name), update='append')

In [3]:
import numpy as np
import torch
from torch.nn.functional import cross_entropy
from torch.utils.data import DataLoader
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from loop.callbacks import Average
from loop.modules import fc_network
from loop.metrics import accuracy
from loop.schedule import ScheduleCallback
from loop.training import Loop, Phase
from loop.testing import get_mnist
from loop.utils import from_torch

# Demo: train a small fully-connected MNIST classifier while streaming
# batch losses, epoch metrics and per-group learning rates to Visdom.
SEED = 1            # fixes weight init and batch shuffling so reruns are comparable
BATCH_SIZE = 1024
EPOCHS = 3

torch.manual_seed(SEED)
np.random.seed(SEED)

trn_ds, val_ds = get_mnist(flat=True)

phases = Phase.make_train_valid(trn_ds, val_ds, bs=BATCH_SIZE)

# 784 inputs = flattened 28x28 MNIST images (flat=True); presumably one hidden
# layer of 100 units followed by 10 output classes — confirm against fc_network.
net = fc_network(784, [100, 10])

# Separate param groups so each part of the network gets its own lr/momentum;
# VisdomDashboard(show_opt_params=True) plots one LR window per group.
opt = SGD(params=[
  {'params': net[0].parameters(), 'lr': 3e-2, 'momentum': 0.98},
  {'params': net[1].parameters(), 'lr': 1e-1, 'momentum': 0.95}
], weight_decay=0.01)

cbs = [
    Average(accuracy, alias='acc'),
    VisdomDashboard(batch_freq=1, show_opt_params=True),
    # T_0 = number of training batches, i.e. the cosine cycle restarts once per
    # epoch (assuming ScheduleCallback steps the scheduler per batch).
    ScheduleCallback(CosineAnnealingWarmRestarts(opt, T_0=len(phases['train'])))
]

# NOTE: `loop` here names the Loop instance, not the `loop` package.
loop = Loop(net, opt=opt, cbs=cbs, loss_fn=cross_entropy)
loop.train(phases, epochs=EPOCHS)

Setting up a new session...


Epoch:    1 | train_loss=0.5493, train_acc=0.8336, valid_loss=0.4107, valid_acc=0.9042
Epoch:    2 | train_loss=0.3587, train_acc=0.9106, valid_loss=0.3242, valid_acc=0.9258
Epoch:    3 | train_loss=0.2733, train_acc=0.9284, valid_loss=0.2774, valid_acc=0.9389
