In [1]:
import sys
from pathlib import Path

# add parent folder to the path
module_path = str(Path.cwd().parents[0])
if module_path not in sys.path:
    sys.path.append(module_path)

___

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader # load batches to the network

from feito.trainer import BasecallerTrainer as Trainer
from feito.models import SimpleNet, Rodan
from feito.loss_functions import ctc_label_smoothing_loss
from feito.dataloader import DatasetONT

___
PARAMS

In [3]:
EPOCHS=10
BATCH_SIZE=16

### Training network
Model, loss function and optimizer

In [4]:
# network to use
model=SimpleNet()
# model=Rodan()
model_output_len = model.output_len # another way to obtain the output of the model https://github.com/biodlab/RODAN/blob/029f7d5eb31b11b53537f13164bfedee0c0786e4/model.py#L317
loss_fn = nn.CTCLoss() #ctc_label_smoothing 
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-3)


Load dataset

In [5]:
# dataset
dataset_train = DatasetONT(recfile="../data/subsample_train.hdf5", output_network_len=model_output_len) #simplenet_output_len)
dataset_val   = DatasetONT(recfile="../data/subsample_val.hdf5", output_network_len=model_output_len) #simplenet_output_len)

dataloader_train = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)

Train

In [6]:

trainer=Trainer(
    model=SimpleNet(),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    train_loader=dataloader_train,
    validation_loader=dataloader_val,
    criterion=loss_fn,
    optimizer=optimizer,
    callbacks=[]
)

In [7]:
trainer.fit(epochs=EPOCHS)

Epoch: 1 | batch: 12/13: 100%|████████████████████████████████████████████████| [00:04, loss=2.5319]
Epoch: 2 | batch: 12/13: 100%|████████████████████████████████████████████████| [00:03, loss=2.6133]
Epoch: 3 | batch: 12/13: 100%|████████████████████████████████████████████████| [00:04, loss=2.6180]
Epoch: 4 | batch: 12/13: 100%|████████████████████████████████████████████████| [00:04, loss=2.5693]
Epoch: 5 | batch: 12/13: 100%|████████████████████████████████████████████████| [00:04, loss=2.5465]
Epoch: 6 | batch: 12/13: 100%|████████████████████████████████████████████████| [00:04, loss=2.5285]
Epoch: 7 | batch: 12/13: 100%|████████████████████████████████████████████████| [00:04, loss=2.4814]
Epoch: 8 | batch: 12/13: 100%|████████████████████████████████████████████████| [00:04, loss=2.3735]
Epoch: 9 | batch: 12/13: 100%|████████████████████████████████████████████████| [00:04, loss=2.6132]
Epoch: 10 | batch: 12/13: 100%|██████████████████████████████████████████████████| [00:04, 

Done!



