In [None]:
%load_ext autoreload
%autoreload 2

import os
import torch
from torch.utils.data import DataLoader
from sidewalk_widths_extractor import Trainer, seed_all
from sidewalk_widths_extractor.dataset import SateliteDataset
from sidewalk_widths_extractor.modules.test import TestModule
from sidewalk_widths_extractor.utilities import get_device

# Fix the seed once, before any dataset split or model construction happens.
# NOTE(review): presumably seed_all seeds every RNG the project relies on
# (python / numpy / torch) — confirm against sidewalk_widths_extractor.
seed_all(42)

# Setup

In [3]:
# --- Run configuration --------------------------------------------------------
LOG_DIR = "logs//demo"  # log root; later cells build checkpoint paths from it
TRAIN_BATCH_SIZE = 16
VAL_BATCH_SIZE = 4
NUM_WORKERS = 2
PERSISTENT_WORKERS = True
SPLIT_RATIO = 0.8  # fraction of the dataset assigned to the training subset
SPLIT_SEED = 42  # dedicated seed for the train/val partition (see random_split below)

# --- Model --------------------------------------------------------------------
device = get_device()
opt_params = {"lr": 2e-4, "weight_decay": 1e-4}
module = TestModule(opt_params, device)

# --- Data ---------------------------------------------------------------------
dataset = SateliteDataset("data/images/", "data/masks/")
train_size = int(SPLIT_RATIO * len(dataset))
val_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)

# Use an explicit, locally seeded generator instead of the global RNG: the
# partition then stays identical no matter what ran earlier in the kernel
# (model init above, a previous execution of this cell, ...). Without it,
# re-running this cell reshuffles which samples are "validation", which leaks
# validation data into training when resuming from a checkpoint.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# DataLoader raises ValueError for persistent_workers=True with num_workers=0,
# so only request persistent workers when worker processes actually exist.
use_persistent_workers = PERSISTENT_WORKERS and NUM_WORKERS > 0

train_dataloader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=NUM_WORKERS,
    persistent_workers=use_persistent_workers,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,  # keep validation order fixed
    pin_memory=True,
    num_workers=NUM_WORKERS,
    persistent_workers=use_persistent_workers,
)

# Training

In [4]:
# Fresh training run; the trainer is pointed at LOG_DIR for its output.
trainer = Trainer(override_log_dir=LOG_DIR)

# Options for this run, gathered in one place so they are easy to tweak.
# NOTE(review): save_every_n_epoch presumably controls checkpoint frequency —
# confirm against Trainer.fit.
fit_options = dict(
    max_epochs=20,
    save_every_n_epoch=5,
    save_settings=True,
    save_scalars=True,
    save_figures=True,
)

trainer.fit(
    module=module,
    dataloader=train_dataloader,
    validate_dataloader=val_dataloader,
    **fit_options,
)

[1] Training: 100% 1/1 [00:16<00:00, 16.01s/it]
[1] Validating: 100% 1/1 [00:05<00:00,  5.80s/it]
[2] Training: 100% 1/1 [00:00<00:00,  1.08it/s]
[2] Validating: 100% 1/1 [00:00<00:00,  2.42it/s]
[3] Training: 100% 1/1 [00:00<00:00,  1.10it/s]
[3] Validating: 100% 1/1 [00:00<00:00,  2.29it/s]
[4] Training: 100% 1/1 [00:00<00:00,  1.31it/s]
[4] Validating: 100% 1/1 [00:00<00:00,  3.05it/s]
[5] Training: 100% 1/1 [00:00<00:00,  1.55it/s]
[5] Validating: 100% 1/1 [00:00<00:00,  2.33it/s]
[6] Training: 100% 1/1 [00:00<00:00,  1.48it/s]
[6] Validating: 100% 1/1 [00:00<00:00,  1.75it/s]
[7] Training: 100% 1/1 [00:00<00:00,  1.33it/s]
[7] Validating: 100% 1/1 [00:00<00:00,  2.00it/s]
[8] Training: 100% 1/1 [00:00<00:00,  1.09it/s]
[8] Validating: 100% 1/1 [00:00<00:00,  1.92it/s]
[9] Training: 100% 1/1 [00:00<00:00,  1.54it/s]
[9] Validating: 100% 1/1 [00:00<00:00,  1.81it/s]
[10] Training: 100% 1/1 [00:00<00:00,  1.26it/s]
[10] Validating: 100% 1/1 [00:00<00:00,  1.66it/s]
[11] Training: 100

# Resuming

In [6]:
# Network and optimizer state files from the "20" checkpoint directory under LOG_DIR.
path = {
    part: os.path.join(LOG_DIR, f"checkpoints//20//{part}.pth.tar")
    for part in ("network", "optimizer")
}

# Continue training from that checkpoint with a new trainer on the same log dir.
trainer = Trainer(override_log_dir=LOG_DIR)
trainer.fit(
    module,
    train_dataloader,
    val_dataloader,
    checkpoint_path=path,
    max_epochs=5,
    save_scalars=True,
    save_figures=True,
)

[21] Training: 100% 1/1 [00:00<00:00,  1.00it/s]
[21] Validating: 100% 1/1 [00:00<00:00,  2.00it/s]
[22] Training: 100% 1/1 [00:00<00:00,  1.36it/s]
[22] Validating: 100% 1/1 [00:00<00:00,  2.16it/s]
[23] Training: 100% 1/1 [00:00<00:00,  1.16it/s]
[23] Validating: 100% 1/1 [00:00<00:00,  2.46it/s]
[24] Training: 100% 1/1 [00:00<00:00,  1.60it/s]
[24] Validating: 100% 1/1 [00:00<00:00,  2.26it/s]
[25] Training: 100% 1/1 [00:00<00:00,  1.08it/s]
[25] Validating: 100% 1/1 [00:00<00:00,  2.38it/s]


# Validating

In [7]:
# If no training cell has been run in this kernel, create a trainer first:
# trainer = Trainer()

# Network and optimizer state files from the "25" checkpoint directory under LOG_DIR.
path = {
    part: os.path.join(LOG_DIR, f"checkpoints//25//{part}.pth.tar")
    for part in ("network", "optimizer")
}

# Run validation on the validation loader using the checkpointed state.
results = trainer.validate(val_dataloader, module, checkpoint_path=path)
print(results)

Validating: 100% 1/1 [00:00<00:00, 21.27it/s]

{'loss': [tensor(0.6908, device='cuda:0')]}





# Testing

In [8]:
# If no training cell has been run in this kernel, create a trainer first:
# trainer = Trainer()

# Network and optimizer state files from the "25" checkpoint directory under LOG_DIR.
path = {
    part: os.path.join(LOG_DIR, f"checkpoints//25//{part}.pth.tar")
    for part in ("network", "optimizer")
}

# NOTE: this demo has no separate test split; the validation loader is reused here.
results = trainer.test(val_dataloader, module, checkpoint_path=path)
print(results)

Testing: 100% 1/1 [00:00<00:00, 20.00it/s]

{'loss': [tensor(0.6908, device='cuda:0')]}



