In [None]:
import torch
import numpy as np

In [2]:
from torch import nn, optim as opts
from torch_modules import blocks
from torch_modules.training import Trainer, data, metrics, callbacks
from tensorflow.keras import datasets as dts

In [3]:
(x_train, y_train), (x_test, y_test) = dts.mnist.load_data()

In [4]:
model = nn.Sequential(*[
    blocks.ConvBlock(1, 64, kernel_size=7, padding=3),
    blocks.commons.GlobalAverage(),
    nn.Linear(64, 10),
    nn.Softmax(dim=-1)
])

model.cuda(0)

Sequential(
  (0): ConvBlock(
    (transformer): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (normalizer): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activator): ReLU()
  )
  (1): GlobalAverage(axes=[2, 3], keepdims=False)
  (2): Linear(in_features=64, out_features=10, bias=True)
  (3): Softmax(dim=-1)
)

In [5]:
loss = nn.CrossEntropyLoss()

In [6]:
class Data(data.StandardDataset):
    
    def get_len(self):
        return 100
    
    def get_item(self, idx):
        return np.random.randn(1, 28, 28), np.random.randint(0, 10)

In [7]:
d = Data()

In [8]:
trainer = Trainer(model, opts.Adam(model.parameters()), loss, metrics=[metrics.accuracy])

In [9]:
cbs = [
    callbacks.Tensorboard("logs/test", steps=2)
]

In [10]:
from tqdm import tqdm_notebook as tqdm

trainer.fit(d, d, epochs=10, callbacks=cbs, pin_memory=True)

epochs: 1/10


HBox(children=(IntProgress(value=0, description='train', max=3, style=ProgressStyle(description_width='initial…

loss    2.300403118133545
accuracy    0.15625
loss    2.301056702931722
accuracy    0.13541666666666666



HBox(children=(IntProgress(value=0, description='evaluate', style=ProgressStyle(description_width='initial')),…


epochs: 2/10


HBox(children=(IntProgress(value=0, description='train', max=3, style=ProgressStyle(description_width='initial…

loss    2.300450086593628
accuracy    0.109375



HBox(children=(IntProgress(value=0, description='evaluate', style=ProgressStyle(description_width='initial')),…


epochs: 3/10


HBox(children=(IntProgress(value=0, description='train', max=3, style=ProgressStyle(description_width='initial…

loss    2.303591251373291
accuracy    0.0625
loss    2.3034211794535318
accuracy    0.08333333333333333



HBox(children=(IntProgress(value=0, description='evaluate', style=ProgressStyle(description_width='initial')),…


epochs: 4/10


HBox(children=(IntProgress(value=0, description='train', max=3, style=ProgressStyle(description_width='initial…

loss    2.30343759059906
accuracy    0.046875



HBox(children=(IntProgress(value=0, description='evaluate', style=ProgressStyle(description_width='initial')),…


epochs: 5/10


HBox(children=(IntProgress(value=0, description='train', max=3, style=ProgressStyle(description_width='initial…

loss    2.299429416656494
accuracy    0.09375
loss    2.301255146662394
accuracy    0.11458333333333333



HBox(children=(IntProgress(value=0, description='evaluate', style=ProgressStyle(description_width='initial')),…


epochs: 6/10


HBox(children=(IntProgress(value=0, description='train', max=3, style=ProgressStyle(description_width='initial…

loss    2.3040478229522705
accuracy    0.140625



HBox(children=(IntProgress(value=0, description='evaluate', style=ProgressStyle(description_width='initial')),…


epochs: 7/10


HBox(children=(IntProgress(value=0, description='train', max=3, style=ProgressStyle(description_width='initial…

loss    2.304626703262329
accuracy    0.125
loss    2.304081122080485
accuracy    0.08333333333333333



HBox(children=(IntProgress(value=0, description='evaluate', style=ProgressStyle(description_width='initial')),…


epochs: 8/10


HBox(children=(IntProgress(value=0, description='train', max=3, style=ProgressStyle(description_width='initial…

loss    2.302835702896118
accuracy    0.109375



HBox(children=(IntProgress(value=0, description='evaluate', style=ProgressStyle(description_width='initial')),…


epochs: 9/10


HBox(children=(IntProgress(value=0, description='train', max=3, style=ProgressStyle(description_width='initial…

loss    2.302379608154297
accuracy    0.15625
loss    2.3042122522989907
accuracy    0.08333333333333333



HBox(children=(IntProgress(value=0, description='evaluate', style=ProgressStyle(description_width='initial')),…


epochs: 10/10


HBox(children=(IntProgress(value=0, description='train', max=3, style=ProgressStyle(description_width='initial…

loss    2.303207039833069
accuracy    0.109375



HBox(children=(IntProgress(value=0, description='evaluate', style=ProgressStyle(description_width='initial')),…


