In [1]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam, Adadelta

In [2]:
from willitsurf.models.lenet import LeNet
from willitsurf.dataset import SurfImageDataset
from willitsurf.dataset import train_val_test_datset_split
from willitsurf.dataset import BalancedClassDataloader
from willitsurf.models.lenet import train, validate, test

In [3]:
dataset = SurfImageDataset('../assets/data/labels/annotations.tsv', '../assets/data/raw', relative_path_to_assets='../')

In [4]:
len(dataset)

758

In [5]:
train_dataset, val_dataset, test_dataset = train_val_test_datset_split(dataset)

In [6]:
train_dataloader = BalancedClassDataloader(train_dataset, batch_size=32)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [10]:
model = LeNet()
model.to('mps')

LeNet(
  (cn1): Conv2d(3, 6, kernel_size=(10, 10), stride=(2, 2))
  (cn2): Conv2d(6, 16, kernel_size=(10, 10), stride=(4, 4))
  (fc1): Linear(in_features=9440, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=2, bias=True)
)

In [15]:
optimizer = Adam(model.parameters(), lr=0.001)

In [9]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

In [19]:
for epoch in range(10):
    train(model, 'mps', train_dataloader, optimizer, epoch)
    if epoch % 10 == 0:
        validate(model, 'mps', val_dataloader)

epoch: 0, 0, 568, 0.0, 1.8551587572801509e-06
epoch: 0, 320, 568, 55.55555555555556, 1.121305899687286e-06
Accuracy on val dataset: 84.21052631578948
epoch: 1, 0, 568, 0.0, 8.083840157269151e-07
epoch: 1, 320, 568, 55.55555555555556, 1.4752037031939835e-06
epoch: 2, 0, 568, 0.0, 4.768347992012423e-07
epoch: 2, 320, 568, 55.55555555555556, 1.4901142719736526e-07
epoch: 3, 0, 568, 0.0, 5.997691232551006e-07
epoch: 3, 320, 568, 55.55555555555556, 1.143658664659597e-06
epoch: 4, 0, 568, 0.0, 2.250003035442205e-06
epoch: 4, 320, 568, 55.55555555555556, 1.605590682629554e-06
epoch: 5, 0, 568, 0.0, 5.40163625828427e-07
epoch: 5, 320, 568, 55.55555555555556, 1.799289293558104e-06
epoch: 6, 0, 568, 0.0, 3.112441845587455e-05
epoch: 6, 320, 568, 55.55555555555556, 2.2351237021212e-06
epoch: 7, 0, 568, 0.0, 3.725289587919178e-08
epoch: 7, 320, 568, 55.55555555555556, 3.1676718208473176e-05
epoch: 8, 0, 568, 0.0, 7.413287903546006e-07
epoch: 8, 320, 568, 55.55555555555556, 6.045776899554767e-06
ep

In [20]:
X, y = next(iter(test_dataloader))
X = X.to('mps')
y = y.to('mps')

In [21]:
model.forward(X)

tensor([[ 32.0939, -27.7632],
        [  3.2352,  -3.4434],
        [ 34.7594, -29.2567],
        [ 14.4928, -12.2157],
        [ 35.9465, -29.4924],
        [ 42.2531, -35.2894],
        [ 51.4587, -44.0957],
        [ 28.1328, -24.6779],
        [ 15.5242, -13.1013],
        [ 75.8901, -64.0907],
        [ 26.5152, -22.7055],
        [ 21.0911, -17.6091],
        [ 23.9962, -23.5931],
        [ 19.1428, -16.1193],
        [ -3.1498,   1.2285],
        [ 32.6228, -28.1564],
        [ 11.5873,  -9.6030],
        [ 13.0622, -11.3630],
        [ 25.4478, -21.5551],
        [ 13.3042, -11.6053],
        [  0.7720,  -1.5833],
        [  8.0265,  -6.7527],
        [ 27.2571, -23.0067],
        [ 10.8682,  -9.5984],
        [ 13.8624, -12.7195],
        [ 15.9116, -13.1407],
        [  3.6291,  -4.0108],
        [ -5.6680,   3.2765],
        [ 36.9584, -31.4195],
        [ 29.2270, -25.7876],
        [ 44.0331, -36.5216],
        [  8.7432,  -7.9012]], device='mps:0', grad_fn=<LinearBackward