In [1]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam, Adadelta

In [2]:
from willitsurf.models.lenet import LeNet
from willitsurf.dataset import SurfImageDataset
from willitsurf.dataset import train_val_test_datset_split
from willitsurf.dataset import BalancedClassDataloader
from willitsurf.models.lenet import train, validate, test

In [3]:
# Build the full image dataset from the annotations TSV and the raw image
# directory. All paths are relative, so this cell depends on the kernel's
# working directory.
dataset = SurfImageDataset(
    '../assets/data/labels/annotations.tsv',
    '../assets/data/raw',
    relative_path_to_assets='../',
)

In [4]:
# Sanity check: display the size of the loaded dataset as the cell output.
len(dataset)

758

In [5]:
# Partition the full dataset into train / validation / test subsets.
# ("datset" is the helper's actual spelling as imported from willitsurf.dataset.)
# NOTE(review): no random seed is set anywhere in this notebook; if the helper
# splits randomly, the subsets will differ between kernel restarts — confirm.
train_dataset, val_dataset, test_dataset = train_val_test_datset_split(dataset)

In [6]:
# Single source of truth for the batch size so the three loaders cannot drift
# apart when it is tuned.
BATCH_SIZE = 32

# Training batches come from the project's BalancedClassDataloader
# (presumably it rebalances classes per batch — confirm in willitsurf.dataset).
train_dataloader = BalancedClassDataloader(train_dataset, batch_size=BATCH_SIZE)

# Validation and test use the stock PyTorch DataLoader with shuffling enabled.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [10]:
# Instantiate the network and place it on the Apple-silicon GPU backend.
# Module.to() returns the module itself, so construction and device placement
# can be chained into one assignment.
model = LeNet().to('mps')

# Bare expression so the cell displays the architecture summary.
model

LeNet(
  (cn1): Conv2d(3, 6, kernel_size=(10, 10), stride=(2, 2))
  (cn2): Conv2d(6, 16, kernel_size=(10, 10), stride=(4, 4))
  (fc1): Linear(in_features=9440, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=2, bias=True)
)

In [11]:
# Adam over every parameter of the model, with a learning rate of 1e-4.
optimizer = Adam(params=model.parameters(), lr=1e-4)

In [9]:
import logging

# getLogger() with no name returns the root logger. Setting it to DEBUG lowers
# the threshold for every logger that has no level of its own, so debug
# records from any library become eligible for output.
logger = logging.getLogger()
logger.setLevel(level=logging.DEBUG)

In [None]:
# Training schedule: total number of epochs and how often to run validation.
N_EPOCHS = 100
VALIDATE_EVERY = 10

for epoch in range(N_EPOCHS):
    # One pass over the training loader; the project's train() takes the
    # epoch index as its last argument.
    train(model, 'mps', train_dataloader, optimizer, epoch)

    # Validate on the regular cadence AND after the last epoch. With 100
    # epochs and a period of 10 the cadence alone stops at epoch 90, so the
    # final weights would otherwise never be scored.
    is_last_epoch = epoch == N_EPOCHS - 1
    if epoch % VALIDATE_EVERY == 0 or is_last_epoch:
        validate(model, 'mps', val_dataloader)

epoch: 0, 0, 568, 0.0, 0.6990761756896973
epoch: 0, 320, 568, 55.55555555555556, 0.6944805383682251
Accuracy on val dataset: 23.68421052631579
epoch: 1, 0, 568, 0.0, 0.6900649070739746
epoch: 1, 320, 568, 55.55555555555556, 0.7013429403305054
epoch: 2, 0, 568, 0.0, 0.6885457634925842
epoch: 2, 320, 568, 55.55555555555556, 0.6779267191886902
epoch: 3, 0, 568, 0.0, 0.6725304126739502
epoch: 3, 320, 568, 55.55555555555556, 0.6704833507537842
epoch: 4, 0, 568, 0.0, 0.7143497467041016
epoch: 4, 320, 568, 55.55555555555556, 0.6536661386489868
epoch: 5, 0, 568, 0.0, 0.6367033123970032
epoch: 5, 320, 568, 55.55555555555556, 0.6638399362564087
epoch: 6, 0, 568, 0.0, 0.6382762789726257
epoch: 6, 320, 568, 55.55555555555556, 0.6300560235977173
epoch: 7, 0, 568, 0.0, 0.6580138802528381
epoch: 7, 320, 568, 55.55555555555556, 0.6542156338691711
epoch: 8, 0, 568, 0.0, 0.6340775489807129
epoch: 8, 320, 568, 55.55555555555556, 0.615452766418457
epoch: 9, 0, 568, 0.0, 0.5619714260101318
epoch: 9, 320, 5

In [None]:
# Pull a single batch from the test loader and move both of its elements
# (inputs X and labels y) to the MPS device in one step.
X, y = (tensor.to('mps') for tensor in next(iter(test_dataloader)))

In [None]:
# Inspect the raw model outputs for the test batch.
# Call the module itself rather than .forward(): Module.__call__ runs any
# registered hooks, which a direct .forward() call silently skips.
# no_grad() avoids building an autograd graph for this inference-only check.
with torch.no_grad():
    outputs = model(X)

# Bare expression so the cell displays the output tensor.
outputs