In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from nb_004a import *

In [None]:
from ipywidgets import IntProgress, HBox, HTML, VBox
from time import time

In [None]:
def format_time(t):
    "Format a duration of `t` seconds as `h:mm:ss`, or as `mm:ss` when the hour part is zero."
    # `int` drops the fractional seconds before the value is split up.
    total_minutes, secs = divmod(int(t), 60)
    hours, mins = divmod(total_minutes, 60)
    if hours == 0: return f'{mins:02d}:{secs:02d}'
    return f'{hours}:{mins:02d}:{secs:02d}'

In [None]:
class ProgressBar():
    """Notebook progress bar (an `IntProgress` next to an `HTML` label) wrapped around an iterable.

    Iterating over the instance yields the items of `gen` unchanged while the
    label shows percentage, item count, elapsed time, estimated remaining time
    and the optional `comment` string.
    """
    # Target minimum number of seconds between two widget refreshes.
    update_every = 0.2

    def __init__(self, gen, display=True, leave=True, parent=None, total=None):
        """Build the widgets for `gen`.

        gen:     iterable to wrap.
        display: show the widget box when iteration starts (forced to False when `parent` is given).
        leave:   keep the widget box open once iteration ends (forced to False when `parent` is given).
        parent:  object with an `add_child` method; when given, this bar registers itself on it.
        total:   number of items; defaults to `len(gen)`. Pass it for iterables without `__len__`.
        """
        self._gen = gen
        self.total = len(gen) if total is None else total
        self.progress,self.text = IntProgress(min=0, max=self.total), HTML()
        self.box = HBox([self.progress, self.text])
        if parent is None: self.leave,self.display = leave,display
        else:
            self.leave,self.display=False,False
            parent.add_child(self)
        self.comment = ''

    def __iter__(self):
        """Yield the items of the wrapped iterable, refreshing the bar after each one.

        Any interruption (an error in the iterable, or the consumer abandoning the
        loop) turns the bar to the 'danger' style and is then re-raised instead of
        being swallowed. The box is closed afterwards unless `leave` is set.
        """
        if self.display:
            display(self.box)
        self.update(0)
        try:
            for i,o in enumerate(self._gen):
                yield o
                self.update(i+1)
        except BaseException:
            # Flag the bar, but let the caller see the failure.
            self.progress.bar_style = 'danger'
            raise
        finally:
            if not self.leave: self.box.close()

    def update(self, val):
        """Record that `val` items are done; `val == 0` (re)starts the clock.

        The widget is only refreshed every `wait_for` items, a stride re-estimated
        from the average time per item so that refreshes happen roughly every
        `update_every` seconds. The final value is always displayed.
        """
        if val == 0:
            self.start_t = self.last_t = time()
            self.pred_t = 0
            self.last_v,self.wait_for = 0,1
            self.update_bar(0)
        elif val >= self.last_v + self.wait_for or val == self.total:
            cur_t = time()
            avg_t = (cur_t - self.start_t) / val
            # The clock may not have advanced yet (coarse timers); keep the current stride then.
            if avg_t > 0: self.wait_for = max(int(self.update_every / avg_t),1)
            self.pred_t = avg_t * self.total
            self.last_v,self.last_t = val,cur_t
            self.update_bar(val)

    def update_bar(self, val):
        "Write the label text and the bar value for `val` completed items."
        elapsed_t = self.last_t - self.start_t
        remaining_t = format_time(self.pred_t - elapsed_t)
        elapsed_t = format_time(elapsed_t)
        end = '' if len(self.comment) == 0 else f' {self.comment}'
        # An empty iterable has nothing left to do: report it as complete instead of dividing by zero.
        pct = 100 * val/self.total if self.total else 100.
        self.text.value = f'{pct:.2f}% [{val}/{self.total} {elapsed_t}<{remaining_t}{end}]'
        self.progress.value = val

In [None]:
class MasterBar():
    "Outer progress bar: stacks its own `ProgressBar`, an `HTML` text area and the latest child bar in one `VBox`."

    def __init__(self, gen):
        # The main bar is not displayed on its own; it is shown as part of `vbox`.
        main_bar, log = ProgressBar(gen, display=False), HTML()
        self.first_bar, self.text = main_bar, log
        self.vbox = VBox([main_bar.box, log])

    def __iter__(self):
        "Display the stacked widgets, then yield every item of the main bar."
        display(self.vbox)
        yield from self.first_bar

    def add_child(self, child):
        "Remember `child` and show its box below the main bar and the text area."
        self.child = child
        stacked = [self.first_bar.box, self.text]
        stacked.append(child.box)
        self.vbox.children = stacked

    def write(self, line):
        "Append `line`, followed by a `<p>` tag, to the text area."
        self.text.value = self.text.value + (line + '<p>')

In [None]:
# Demo: a master bar over ten outer steps; it is iterated in the following cell.
mb = MasterBar(range(10))

In [None]:
from time import sleep
# Demo: drive the master bar, attaching a fresh child ProgressBar on every outer step.
for j in mb:
    # Passing `parent=mb` makes the new bar register itself through `mb.add_child`.
    for i in ProgressBar(range(0, 100), parent=mb):
        # Stand-in for real work.
        sleep(0.01)
        # The comment is appended to the child bar's label on its next refresh.
        mb.child.comment = str(i)
    # Append one line per outer step to the text area; the value is synthetic, computed from `j`.
    mb.text.value += f'Epoch {j+1}: accuracy: {7.5*j+5}%<p>'

In [None]:
#export
@dataclass
class DeviceDataLoader():
    "Wrap a `DataLoader` so that every batch goes through `to_device` (and then `to_half` when `half` is set)."
    dl: DataLoader
    device: torch.device
    half: bool = False

    def __len__(self):
        "Number of batches, as reported by the wrapped loader."
        return len(self.dl)

    def __iter__(self):
        "Start a new pass over the wrapped loader; the generator in use is kept on `self.gen`."
        batches = (to_device(self.device, batch) for batch in self.dl)
        if self.half: batches = (to_half(batch) for batch in batches)
        self.gen = batches
        return iter(self.gen)

    @classmethod
    def create(cls, *args, device=default_device, **kwargs):
        "Build a `DataLoader` from `*args`/`**kwargs` and wrap it for `device`, with `half=False`."
        loader = DataLoader(*args, **kwargs)
        return cls(loader, device=device, half=False)

# Rebind `nb_002b.DeviceDataLoader` to the class defined in this notebook.
# NOTE(review): presumably so that code resolving the name through `nb_002b` uses this version — confirm.
nb_002b.DeviceDataLoader = DeviceDataLoader

In [None]:
def fit(epochs, model, loss_fn, opt, data, callbacks=None, metrics=None, pbar=None):
    """Train `model` on `data.train_dl` for up to `epochs` epochs, with progress shown through `pbar`.

    After each epoch the model is evaluated on `data.valid_dl` when `data` has one that
    is not None; the averaged results are handed to the callbacks' `on_epoch_end`.
    A truthy return from `on_batch_end` ends the current training pass and a truthy
    return from `on_epoch_end` ends training. `pbar` defaults to a new `MasterBar`.
    """
    handler = CallbackHandler(callbacks)
    handler.on_train_begin()
    if pbar is None: pbar = MasterBar(range(epochs))

    for epoch in pbar:
        model.train()
        handler.on_epoch_begin()

        # Training pass, tracked by a child bar attached to `pbar`.
        for xb,yb in ProgressBar(data.train_dl, parent=pbar):
            xb, yb = handler.on_batch_begin(xb, yb)
            loss,_ = loss_batch(model, xb, yb, loss_fn, opt, handler)
            if handler.on_batch_end(loss): break

        val_metrics = None
        if getattr(data, 'valid_dl', None) is not None:
            model.eval()
            with torch.no_grad():
                batch_results = [loss_batch(model, xb, yb, loss_fn, cb_handler=handler, metrics=metrics)
                                 for xb,yb in ProgressBar(data.valid_dl, parent=pbar)]
                # The last element of every `loss_batch` result is the weight of that batch.
                *per_batch,nums = zip(*batch_results)
            # Weighted average of each remaining element over all validation batches.
            val_metrics = [np.sum(np.multiply(vals,nums)) / np.sum(nums) for vals in per_batch]

        if handler.on_epoch_end(val_metrics): break

    handler.on_train_end()

In [None]:
#export
@dataclass
class Learner():
    "Bundle `data`, a `model`, an optimizer factory and a loss function, and train them with `fit`."
    data: DataBunch
    model: nn.Module
    opt_fn: Callable = optim.SGD
    loss_fn: Callable = F.cross_entropy
    metrics: Collection[Callable] = None
    true_wd: bool = False

    def __post_init__(self):
        # Move the model to `data.device` as soon as the learner is created.
        self.model = self.model.to(self.data.device)

    def fit(self, epochs, lr, wd=0., callbacks=None):
        """Train for `epochs` epochs at learning rate `lr` with weight decay `wd`.

        A new optimizer (`self.opt`) and a new `Recorder` (`self.recorder`) are created
        on every call; the recorder runs ahead of any user-supplied `callbacks`.
        """
        optimizer = self.opt_fn(self.model.parameters(), lr)
        self.opt = OptimWrapper(optimizer, wd=wd, true_wd=self.true_wd)
        pbar = MasterBar(range(epochs))
        self.recorder = Recorder(self.opt, self.data.train_dl, pbar)
        callbacks = [self.recorder] + ([] if callbacks is None else callbacks)
        fit(epochs, self.model, self.loss_fn, self.opt, self.data, callbacks=callbacks, metrics=self.metrics, pbar=pbar)

    def lr_find(self, start_lr=1e-5, end_lr=10, num_it=100):
        "Run `fit` with an `LRFinder` callback for enough epochs to cover `num_it` iterations."
        finder = LRFinder(self, start_lr, end_lr, num_it)
        # Smallest whole number of epochs containing at least `num_it` batches.
        n_epochs = int(np.ceil(num_it/len(self.data.train_dl)))
        self.fit(n_epochs, start_lr, callbacks=[finder])

In [None]:
#export
@dataclass
class Recorder(Callback):
    """Callback that records training statistics and can plot them.

    Kept per batch: learning rate (`lrs`), momentum (`moms`) and smoothed loss
    (`losses`). Kept per epoch: number of batches (`nb_batches`), validation
    loss (`val_losses`) and any further validation metrics (`metrics`).
    When `pbar` is given, the current loss is shown on its child bar and one
    line per epoch is written to it.
    """
    opt: torch.optim
    train_dl: DeviceDataLoader = None
    pbar: MasterBar = None

    def on_train_begin(self, **kwargs):
        "Reset every recorded series."
        self.losses,self.val_losses,self.lrs,self.moms,self.metrics,self.nb_batches = [],[],[],[],[],[]

    def on_batch_begin(self, **kwargs):
        "Store the optimizer's current learning rate and momentum."
        self.lrs.append(self.opt.lr)
        self.moms.append(self.opt.mom)

    def on_backward_begin(self, smooth_loss, **kwargs):
        "Store the smoothed loss and show it as the comment of the child progress bar, if there is one."
        #We record the loss here before any other callback has a chance to modify it.
        self.losses.append(smooth_loss)
        if self.pbar is not None and hasattr(self.pbar,'child'):
            self.pbar.child.comment = f'{smooth_loss:.4f}'

    def on_epoch_end(self, epoch, num_batch, smooth_loss, last_metrics, **kwargs):
        """Store the epoch's statistics and report them on `pbar` when one is set.

        `last_metrics[0]` is recorded as the validation loss and any further entries
        as metrics; `last_metrics` may be None when there was no validation pass.
        """
        self.nb_batches.append(num_batch)
        if last_metrics is not None:
            self.val_losses.append(last_metrics[0])
            if len(last_metrics) > 1: self.metrics.append(last_metrics[1:])
            line = f'{epoch}, {smooth_loss}, {last_metrics}'
        else: line = f'{epoch}, {smooth_loss}'
        # `pbar` is optional (it defaults to None), so only report when it exists.
        if self.pbar is not None: self.pbar.write(line)

    def plot_lr(self, show_moms=False):
        "Plot the recorded learning rates per iteration, with the momentums next to them when `show_moms` is set."
        # Use this recorder's own data rather than whatever a global `learn` happens to hold.
        iterations = list(range(len(self.lrs)))
        if show_moms:
            _, axs = plt.subplots(1,2, figsize=(12,4))
            axs[0].plot(iterations, self.lrs)
            axs[1].plot(iterations, self.moms)
        else: plt.plot(iterations, self.lrs)

    def plot(self, skip_start=10, skip_end=5):
        "Plot loss against learning rate on a log x-axis, dropping the first `skip_start` and last `skip_end` points."
        # A slice ending at `-0` would be empty, hence the separate branch for skip_end == 0.
        lrs = self.lrs[skip_start:-skip_end] if skip_end > 0 else self.lrs[skip_start:]
        losses = self.losses[skip_start:-skip_end] if skip_end > 0 else self.losses[skip_start:]
        _, ax = plt.subplots(1,1)
        ax.plot(lrs, losses)
        ax.set_xscale('log')

    def plot_losses(self):
        "Plot the training losses per iteration and the validation losses at the end of each epoch."
        _, ax = plt.subplots(1,1)
        iterations = list(range(len(self.losses)))
        ax.plot(iterations, self.losses)
        # Cumulative batch counts give the iteration index at which each epoch ended.
        val_iter = np.array(self.nb_batches).cumsum()
        ax.plot(val_iter, self.val_losses)

    def plot_metrics(self):
        "Plot every recorded validation metric on its own axis, against the iteration at which each epoch ended."
        assert len(self.metrics) != 0, "There is no metrics to plot."
        n_metrics = len(self.metrics[0])
        _, axes = plt.subplots(n_metrics,1,figsize=(6, 4*n_metrics))
        val_iter = np.array(self.nb_batches).cumsum()
        # With a single metric `plt.subplots` returns one Axes object instead of an array.
        axes = axes.flatten() if n_metrics != 1 else [axes]
        for i, ax in enumerate(axes):
            values = [met[i] for met in self.metrics]
            ax.plot(val_iter, values)

In [None]:
# Dataset location: `PATH` is the `cifar10` folder under `DATA_PATH`.
DATA_PATH = Path('data')
PATH = DATA_PATH/'cifar10'

# Three-element mean and std tensors handed to `normalize_tfm`
# (presumably per-RGB-channel CIFAR-10 statistics — confirm).
data_mean,data_std = map(tensor, ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261]))
# NOTE(review): `cifar_norm` is not referenced by any of the cells visible in this section.
cifar_norm = normalize_tfm(mean=data_mean,std=data_std)

# Transform list passed as `train_tfm` when the DataBunch is created.
tfms = [flip_lr_tfm(p=0.5),
        pad_tfm(padding=4),
        crop_tfm(size=32, row_pct=(0,1.), col_pct=(0,1.))]

# Batch size passed to `DataBunch.create`.
bs = 64

In [None]:
# Both datasets are built with classes=['airplane','dog']; the validation set is read from the 'test' folder.
train_ds = FilesDataset.from_folder(PATH/'train', classes=['airplane','dog'])
valid_ds = FilesDataset.from_folder(PATH/'test', classes=['airplane','dog'])
# `tfms` is given for the training set only; the validation set gets an empty transform list.
data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=tfms, valid_tfm=[], num_workers=4)
# Cell output: lengths of the training and validation loaders.
len(data.train_dl), len(data.valid_dl)

In [None]:
# Darknet built with [1, 2, 2, 2, 2], num_classes=2 and nf=16 (see `Darknet` for what these arguments mean).
model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)
# Learner with its defaults for everything else: opt_fn=optim.SGD, loss_fn=F.cross_entropy, no metrics.
learn = Learner(data, model)

In [None]:
# Train for 5 epochs at learning rate 0.01 (`Learner.fit(epochs, lr)`).
learn.fit(5,0.01)