In [1]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam, Adadelta
from torchvision.transforms.v2 import Resize

In [2]:
from willitsurf.models.lenet import LeNet
from willitsurf.dataset import SurfImageDataset
from willitsurf.dataset import train_val_test_datset_split
from willitsurf.dataset import BalancedClassDataloader
from willitsurf.models.lenet import train, validate, test

In [3]:
dataset = SurfImageDataset(
    '../assets/data/labels/annotations.tsv',
    '../assets/data/raw', 
    relative_path_to_assets='../',
    transform=Resize(200),
)

In [4]:
len(dataset)

758

In [5]:
dataset[0][0].shape

torch.Size([3, 200, 1066])

In [6]:
train_dataset, val_dataset, test_dataset = train_val_test_datset_split(dataset)

In [7]:
train_dataloader = BalancedClassDataloader(train_dataset, batch_size=32)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [8]:
model = LeNet()
model.to('mps')

LeNet(
  (cn1): Conv2d(3, 6, kernel_size=(10, 10), stride=(4, 4))
  (cn2): Conv2d(6, 12, kernel_size=(5, 5), stride=(2, 2))
  (fc1): Linear(in_features=1920, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=2, bias=True)
)

In [28]:
optimizer = Adam(model.parameters(), lr=0.00005)

In [10]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

In [29]:
for epoch in range(100):
    train(model, 'mps', train_dataloader, optimizer, epoch)
    if epoch % 10 == 0:
        validate(model, 'mps', val_dataloader)

epoch: 0, 0, 568, 0.0, 0.0007094840984791517
epoch: 0, 320, 568, 55.55555555555556, 0.0006337517406791449
Accuracy on val dataset: 86.84210526315789
epoch: 1, 0, 568, 0.0, 0.0011226281058043242
epoch: 1, 320, 568, 55.55555555555556, 0.0025632623583078384
epoch: 2, 0, 568, 0.0, 0.0009355052607133985
epoch: 2, 320, 568, 55.55555555555556, 0.002515029162168503
epoch: 3, 0, 568, 0.0, 0.004057032056152821
epoch: 3, 320, 568, 55.55555555555556, 0.0009851958602666855
epoch: 4, 0, 568, 0.0, 0.0019280905835330486
epoch: 4, 320, 568, 55.55555555555556, 0.0003696655039675534
epoch: 5, 0, 568, 0.0, 0.0006957760779187083
epoch: 5, 320, 568, 55.55555555555556, 0.0001564409030834213
epoch: 6, 0, 568, 0.0, 0.0007148830336518586
epoch: 6, 320, 568, 55.55555555555556, 0.0008061200496740639
epoch: 7, 0, 568, 0.0, 0.0008608687785454094
epoch: 7, 320, 568, 55.55555555555556, 0.0009443983435630798
epoch: 8, 0, 568, 0.0, 0.0007508236449211836
epoch: 8, 320, 568, 55.55555555555556, 0.0003867443883791566
epoch

In [30]:
test(model, 'mps', test_dataloader)

Accuracy on test dataset: 83.33333333333333


In [31]:
X, y = next(iter(test_dataloader))
X = X.to('mps')
y = y.to('mps')

In [32]:
model.forward(X)

tensor([[ -8.1466,   8.9173],
        [  1.9645,  -2.2048],
        [  4.1598,  -4.5028],
        [ -0.3498,   0.4950],
        [ 12.6378, -13.6227],
        [ 19.4657, -21.1892],
        [ -0.4987,   0.6159],
        [ -4.3594,   4.6279],
        [ 18.1183, -19.8352],
        [  4.5584,  -4.8938],
        [ 18.5750, -20.1342],
        [ 11.5522, -12.5082],
        [ 10.4821, -11.2992],
        [ 11.1571, -11.8921],
        [  8.0924,  -8.6562],
        [  4.8536,  -5.2436],
        [  7.9530,  -8.5912],
        [  2.4535,  -2.5424],
        [ 11.5965, -12.4407],
        [ 13.2311, -14.3281],
        [  0.2089,  -0.2241],
        [  0.7670,  -0.7743],
        [  9.4879, -10.2881],
        [  5.7095,  -6.1207],
        [ -1.7584,   2.0618],
        [ -3.0185,   3.4551],
        [  8.0404,  -8.7289],
        [  6.7572,  -7.3140],
        [ 14.1230, -15.2156],
        [-10.1204,  10.9344],
        [  4.8826,  -5.3193],
        [ 10.0126, -10.6957]], device='mps:0', grad_fn=<LinearBackward

In [33]:
y

tensor([1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
        0, 1, 0, 1, 0, 0, 0, 0], device='mps:0')