In [1]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam, Adadelta
from torchvision.transforms.v2 import Resize

In [2]:
from willitsurf.models.lenet import LeNet
from willitsurf.dataset import SurfImageDataset
from willitsurf.dataset import train_val_test_datset_split
from willitsurf.dataset import BalancedClassDataloader
from willitsurf.models.lenet import train, validate, test

In [3]:
dataset = SurfImageDataset(
    '../assets/data/labels/annotations.tsv',
    '../assets/data/raw', 
    relative_path_to_assets='../',
    transform=Resize(200),
)

In [4]:
len(dataset)

758

In [5]:
dataset[0][0].shape

torch.Size([3, 200, 1066])

In [6]:
train_dataset, val_dataset, test_dataset = train_val_test_datset_split(dataset)

In [7]:
train_dataloader = BalancedClassDataloader(train_dataset, batch_size=32)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [8]:
model = LeNet()
model.to('mps')

LeNet(
  (cn1): Conv2d(3, 6, kernel_size=(10, 10), stride=(4, 4))
  (cn2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=10240, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=2, bias=True)
)

In [9]:
optimizer = Adam(model.parameters(), lr=0.001)

In [11]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

In [10]:
for epoch in range(100):
    train(model, 'mps', train_dataloader, optimizer, epoch)
    if epoch % 10 == 0:
        validate(model, 'mps', val_dataloader)

epoch: 0, 0, 568, 0.0, 0.6904773712158203
epoch: 0, 320, 568, 55.55555555555556, 0.6834925413131714
Accuracy on val dataset: 15.789473684210526
epoch: 1, 0, 568, 0.0, 0.6849852800369263
epoch: 1, 320, 568, 55.55555555555556, 0.6724197864532471
epoch: 2, 0, 568, 0.0, 0.6814824342727661
epoch: 2, 320, 568, 55.55555555555556, 0.6237773895263672
epoch: 3, 0, 568, 0.0, 0.587654709815979
epoch: 3, 320, 568, 55.55555555555556, 0.5944782495498657
epoch: 4, 0, 568, 0.0, 0.6149712204933167
epoch: 4, 320, 568, 55.55555555555556, 0.6467022895812988
epoch: 5, 0, 568, 0.0, 0.45602717995643616
epoch: 5, 320, 568, 55.55555555555556, 0.5524084568023682
epoch: 6, 0, 568, 0.0, 0.5888446569442749
epoch: 6, 320, 568, 55.55555555555556, 0.526513934135437
epoch: 7, 0, 568, 0.0, 0.5483881235122681
epoch: 7, 320, 568, 55.55555555555556, 0.4931451678276062
epoch: 8, 0, 568, 0.0, 0.3945896029472351
epoch: 8, 320, 568, 55.55555555555556, 0.4005483388900757
epoch: 9, 0, 568, 0.0, 0.43283337354660034
epoch: 9, 320,

In [11]:
X, y = next(iter(test_dataloader))
X = X.to('mps')
y = y.to('mps')

In [12]:
model.forward(X)

tensor([[ -1.7248,   3.0128],
        [-12.8518,  13.7227],
        [  6.8511,  -3.9940],
        [  7.4099,  -6.0442],
        [ 22.4058, -18.0322],
        [ -1.6234,   3.2425],
        [ 23.2176, -19.3941],
        [ 11.3070,  -8.9602],
        [  3.8526,  -2.7298],
        [  3.1140,  -2.0899],
        [ 10.3920,  -8.3694],
        [ 28.9557, -24.2062],
        [ 13.1967, -10.8638],
        [ 16.8589, -13.3209],
        [ 15.3258, -13.0239],
        [  4.3153,  -3.1848],
        [  1.8365,  -0.7154],
        [  7.8264,  -5.8749],
        [  7.9789,  -6.4574],
        [  1.9457,  -0.5193],
        [  1.9691,  -0.7194],
        [ 15.2175, -12.9828],
        [ 14.1074, -11.3453],
        [ 10.6913,  -8.4517],
        [ 15.7758, -12.5869],
        [ 11.5489,  -9.0912],
        [ 11.1761,  -8.5423],
        [ 15.6956, -12.9392],
        [  5.5490,  -4.4158],
        [  1.0301,   1.5098],
        [  0.9002,   0.3427],
        [ -2.7730,   4.7912]], device='mps:0', grad_fn=<LinearBackward