In [1]:
import torch
import numpy as np

In [2]:
from torch import nn, optim as opts
from torch_modules import blocks
from torch_modules.training import Trainer, data, metrics, callbacks
from tensorflow.keras import datasets as dts

In [3]:
(x_train, y_train), (x_test, y_test) = dts.mnist.load_data()

In [4]:
model = nn.Sequential(*[
    blocks.ConvBlock(1, 64, kernel_size=7, padding=3),
    blocks.commons.GlobalAverage(),
    nn.Linear(64, 10),
    nn.Softmax()
])

model.cuda(0)

Sequential(
  (0): ConvBlock(
    (transformer): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (normalizer): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activator): ReLU()
  )
  (1): GlobalAverage(axes=[2, 3], keepdims=False)
  (2): Linear(in_features=64, out_features=10, bias=True)
  (3): Softmax(dim=None)
)

In [5]:
loss = nn.CrossEntropyLoss()

In [6]:
class Data(data.Sequence):
    
    def get_len(self):
        return 10
    
    def get_item(self, idx):
        return np.random.randn(64, 1, 28, 28), np.random.randint(0, 10, size=(64,))

In [7]:
d = Data(x_gpu="cuda:0", y_gpu="cuda:0")

In [8]:
trainer = Trainer(model, opts.Adam(model.parameters()), loss, metrics=[metrics.accuracy])

In [9]:
cbs = [
    callbacks.Tensorboard("logs/test", steps=2)
]

In [10]:
from tqdm import tqdm_notebook as tqdm

trainer.fit(d, d, epochs=10, callbacks=cbs, pbar=tqdm)

epochs: 1/10


HBox(children=(IntProgress(value=0, description='train', max=10, style=ProgressStyle(description_width='initia…

  input = module(input)


loss    2.3021488189697266
accuracy    0.09375
loss    2.302395820617676
accuracy    0.07291666666666667
loss    2.3019590854644774
accuracy    0.0875
loss    2.302225112915039
accuracy    0.08928571428571429
loss    2.3021306726667614
accuracy    0.10243055555555555



HBox(children=(IntProgress(value=0, description='evaluate', max=10, style=ProgressStyle(description_width='ini…


epochs: 2/10


HBox(children=(IntProgress(value=0, description='train', max=10, style=ProgressStyle(description_width='initia…

loss    2.3007750511169434
accuracy    0.140625
loss    2.301407734553019
accuracy    0.14583333333333334
loss    2.3023250579833983
accuracy    0.140625
loss    2.302020754132952
accuracy    0.140625
loss    2.3022089269426136
accuracy    0.13368055555555555



HBox(children=(IntProgress(value=0, description='evaluate', max=10, style=ProgressStyle(description_width='ini…


epochs: 3/10


HBox(children=(IntProgress(value=0, description='train', max=10, style=ProgressStyle(description_width='initia…

loss    2.2990591526031494
accuracy    0.15625
loss    2.302384376525879
accuracy    0.09895833333333333
loss    2.3013820171356203
accuracy    0.096875
loss    2.3010283538273404
accuracy    0.10491071428571429
loss    2.30114836162991
accuracy    0.109375



HBox(children=(IntProgress(value=0, description='evaluate', max=10, style=ProgressStyle(description_width='ini…


epochs: 4/10


HBox(children=(IntProgress(value=0, description='train', max=10, style=ProgressStyle(description_width='initia…

loss    2.3079569339752197
accuracy    0.046875
loss    2.305551846822103
accuracy    0.06770833333333333
loss    2.3050097465515136
accuracy    0.075
loss    2.306152139391218
accuracy    0.0625
loss    2.3057370450761585
accuracy    0.05555555555555555



HBox(children=(IntProgress(value=0, description='evaluate', max=10, style=ProgressStyle(description_width='ini…


epochs: 5/10


HBox(children=(IntProgress(value=0, description='train', max=10, style=ProgressStyle(description_width='initia…

loss    2.3058249950408936
accuracy    0.09375
loss    2.3039015928904214
accuracy    0.10416666666666667
loss    2.3028614997863768
accuracy    0.128125
loss    2.302563258579799
accuracy    0.12723214285714285
loss    2.3030540148417153
accuracy    0.11805555555555555



HBox(children=(IntProgress(value=0, description='evaluate', max=10, style=ProgressStyle(description_width='ini…


epochs: 6/10


HBox(children=(IntProgress(value=0, description='train', max=10, style=ProgressStyle(description_width='initia…

loss    2.3067800998687744
accuracy    0.046875
loss    2.30475385983785
accuracy    0.06770833333333333
loss    2.3035507678985594
accuracy    0.09375
loss    2.3026740550994873
accuracy    0.10044642857142858
loss    2.303250869115194
accuracy    0.09375



HBox(children=(IntProgress(value=0, description='evaluate', max=10, style=ProgressStyle(description_width='ini…


epochs: 7/10


HBox(children=(IntProgress(value=0, description='train', max=10, style=ProgressStyle(description_width='initia…

loss    2.304008960723877
accuracy    0.0625
loss    2.3019065062204995
accuracy    0.11458333333333333
loss    2.3030813217163084
accuracy    0.1
loss    2.3035426821027483
accuracy    0.08705357142857142
loss    2.303857167561849
accuracy    0.08333333333333333



HBox(children=(IntProgress(value=0, description='evaluate', max=10, style=ProgressStyle(description_width='ini…


epochs: 8/10


HBox(children=(IntProgress(value=0, description='train', max=10, style=ProgressStyle(description_width='initia…

loss    2.301762580871582
accuracy    0.09375
loss    2.3027278582255044
accuracy    0.09895833333333333
loss    2.302763509750366
accuracy    0.096875
loss    2.3027816159384593
accuracy    0.08482142857142858
loss    2.3030331399705677
accuracy    0.0920138888888889



HBox(children=(IntProgress(value=0, description='evaluate', max=10, style=ProgressStyle(description_width='ini…


epochs: 9/10


HBox(children=(IntProgress(value=0, description='train', max=10, style=ProgressStyle(description_width='initia…

loss    2.3026790618896484
accuracy    0.046875
loss    2.3020344575246177
accuracy    0.07291666666666667
loss    2.3022915363311767
accuracy    0.078125
loss    2.302025011607579
accuracy    0.08482142857142858
loss    2.302196158303155
accuracy    0.08680555555555555



HBox(children=(IntProgress(value=0, description='evaluate', max=10, style=ProgressStyle(description_width='ini…


epochs: 10/10


HBox(children=(IntProgress(value=0, description='train', max=10, style=ProgressStyle(description_width='initia…

loss    2.3002796173095703
accuracy    0.15625
loss    2.301819642384847
accuracy    0.13541666666666666
loss    2.3030513763427733
accuracy    0.10625
loss    2.303119250706264
accuracy    0.10267857142857142
loss    2.3027315934499106
accuracy    0.10590277777777778



HBox(children=(IntProgress(value=0, description='evaluate', max=10, style=ProgressStyle(description_width='ini…


