In [1]:
import torch
import torch.nn as nn
import torch.optim as opts
import torch.utils.data as data

In [None]:
from plasma.modules import *
from plasma.training import trainers, metrics, callbacks
from tensorflow.keras.datasets import mnist

In [4]:
# Fetch MNIST as two (images, labels) pairs: the training split and the test split.
train_split, test_split = mnist.load_data()
xtr, ytr = train_split
xt, yt = test_split

# Show the shape of every array as a quick sanity check.
tuple(arr.shape for arr in (xtr, ytr, xt, yt))

((60000, 28, 28), (60000,), (10000, 28, 28), (10000,))

In [5]:
# Convert the images to float tensors and divide by 255
# (presumably mapping 0-255 pixel intensities into [0, 1] -- confirm against the data).
xtr = torch.tensor(xtr, dtype=torch.float) / 255
xt = torch.tensor(xt, dtype=torch.float) / 255

# Store the class labels as int64 explicitly. torch.tensor() would otherwise keep
# the NumPy array's own integer dtype, while nn.CrossEntropyLoss expects
# class-index targets of dtype long.
ytr = torch.tensor(ytr, dtype=torch.long)
yt = torch.tensor(yt, dtype=torch.long)

In [6]:
# Conv2d wants a channel axis, so insert a singleton dimension at position 1
# (N, 28, 28) -> (N, 1, 28, 28) before pairing the images with their labels.
train = data.TensorDataset(xtr.unsqueeze(1), ytr)
test = data.TensorDataset(xt.unsqueeze(1), yt)

In [7]:
def _conv_bn_relu(in_channels, out_channels):
    """Return the layers of one 3x3 conv -> batch norm -> in-place ReLU unit.

    Padding of 1 with a 3x3 kernel keeps the spatial size unchanged.
    """
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]


# Small CNN for single-channel images: three stages of two conv units each,
# with a 2x2 max-pool between stages, followed by GlobalAverage (from
# plasma.modules; presumably global average pooling down to 128 features --
# confirm) and a linear layer producing 10 class scores.
model = nn.Sequential(
    # Stage 1
    *_conv_bn_relu(1, 16),
    *_conv_bn_relu(16, 32),

    # Stage 2
    nn.MaxPool2d(2, 2),
    *_conv_bn_relu(32, 32),
    *_conv_bn_relu(32, 64),

    # Stage 3
    nn.MaxPool2d(2, 2),
    *_conv_bn_relu(64, 64),
    *_conv_bn_relu(64, 128),

    # Classifier head
    GlobalAverage(),
    nn.Linear(128, 10),
)

# Move the parameters to the first CUDA device; as the last expression of the
# cell this also displays the module summary.
model.cuda(0)

Sequential(
  (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (7): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): ReLU(inplace=True)
  (10): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): ReLU(inplace=True)
  (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (14): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (

In [8]:
# SGD over every model parameter, using Nesterov momentum.
opt = opts.SGD(
    params=model.parameters(),
    lr=1e-2,
    momentum=0.9,
    nesterov=True,
)

In [9]:
# Multi-class cross-entropy on raw logits, averaged over the batch
# (the default reduction, spelled out here).
loss = nn.CrossEntropyLoss(reduction="mean")

In [10]:
# Bundle model, optimizer and loss into the plasma trainer, tracking accuracy.
# Inputs and targets are both directed to the GPU that the model was moved to.
trainer = trainers.Trainer(
    model,
    opt,
    loss,
    metrics=[metrics.acc_fn()],
    x_device="cuda:0",
    y_device="cuda:0",
)

In [11]:
# Callbacks handed to trainer.fit(). The original cell contained only the
# dangling attribute access `callbacks.`, which is a SyntaxError and stops the
# notebook from running top to bottom, so the list is left empty.
# TODO: add the intended entries from plasma.training.callbacks here.
cbs = []

In [12]:
# Run training in batches of 128; the logged output shows a 'train' pass followed
# by an 'evaluate' pass in each epoch (presumably over `train` and `test` respectively).
trainer.fit(
    train,
    test,
    batch_size=128,
    callbacks=cbs,
)

epoch 1


HBox(children=(FloatProgress(value=0.0, description='train', max=468.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='evaluate', max=79.0, style=ProgressStyle(description_widt…


epoch 2


HBox(children=(FloatProgress(value=0.0, description='train', max=468.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='evaluate', max=79.0, style=ProgressStyle(description_widt…


epoch 3


HBox(children=(FloatProgress(value=0.0, description='train', max=468.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='evaluate', max=79.0, style=ProgressStyle(description_widt…


epoch 4


HBox(children=(FloatProgress(value=0.0, description='train', max=468.0, style=ProgressStyle(description_width=…




KeyboardInterrupt: 