In [1]:
%load_ext autoreload
%autoreload 2

import os

# Must be set before the albumentations / torch imports below are executed.
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime" abort that can
# occur when several native libraries each bundle their own OpenMP runtime — confirm it is
# still required in this environment before relying on it.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import albumentations as A
import torch
from torch.utils.data import DataLoader

from sidewalk_widths_extractor import SatelliteDataset, SegModule, Trainer, seed_all
from sidewalk_widths_extractor.utilities import get_device

# Parameters

In [2]:
# --- Batch sizes -----------------------------------------------------------
TRAIN_BATCH_SIZE = 16
VAL_BATCH_SIZE = 4

# --- DataLoader worker settings --------------------------------------------
NUM_WORKERS = 0
PERSISTENT_WORKERS = False
PIN_MEMORY = False

# --- Train/validation split ------------------------------------------------
SPLIT_RATIO = 0.8
RANDOM_SEED = 42

# --- Model / optimizer keyword arguments (passed to SegModule in the setup cell)
network_parms = dict(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,
)
optimizer_params = dict(lr=2e-4, weight_decay=1e-4)

# --- Compute device --------------------------------------------------------
# Resolved through the project's get_device() helper; assign the string "cpu"
# to `device` here instead if the run has to be forced onto the CPU.
device = get_device()
print(device)

cuda:0


# Setup

In [3]:
# Seed through the shared constant so the RNG seeding and the dataset split below
# can never drift apart (previously this was a hard-coded 42).
seed_all(RANDOM_SEED)

# Positional arguments: network name, network kwargs, optimizer name, optimizer kwargs,
# loss name, loss kwargs, device.
# NOTE(review): "wce" with a per-class "weight" list looks like a weighted cross-entropy
# loss over the 2 classes — confirm against SegModule.
module = SegModule(
    "unet", network_parms, "adam", optimizer_params, "wce", {"weight": [0.1, 1.0]}, device
)

# Augmentations: each one is applied independently with probability 0.2.
transform = A.Compose(
    [
        A.HorizontalFlip(p=0.2),
        A.VerticalFlip(p=0.2),
        A.RandomBrightnessContrast(p=0.2, brightness_limit=0.2, contrast_limit=0.2),
    ]
)

# Only a training transform is passed; no validation transform is supplied here.
train_dataset, val_dataset = SatelliteDataset.from_split(
    "data/images/",
    "data/masks/",
    split_ratio=SPLIT_RATIO,
    train_transform=transform,
    random_seed=RANDOM_SEED,
)

# Settings shared by both loaders. DataLoader raises ValueError if
# persistent_workers=True while num_workers == 0, so the flag is only honoured
# when worker processes are actually in use.
loader_kwargs = {
    "pin_memory": PIN_MEMORY,
    "num_workers": NUM_WORKERS,
    "persistent_workers": PERSISTENT_WORKERS and NUM_WORKERS > 0,
}

train_dataloader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    **loader_kwargs,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    **loader_kwargs,
)

# Train

In [4]:
# Run-level settings, named so they are easy to find and change.
MAX_EPOCHS = 50
LOG_COMMENT = "augs"

trainer = Trainer(log_comment=LOG_COMMENT)

# Options forwarded to Trainer.fit alongside the module and the two dataloaders.
fit_options = dict(
    max_epochs=MAX_EPOCHS,
    # save_every_n_epoch=5,  # left disabled
    save_settings=True,
    save_scalars=True,
    save_figures=False,
)

trainer.fit(
    module=module,
    dataloader=train_dataloader,
    validate_dataloader=val_dataloader,
    **fit_options,
)

[1] augs - Training: 100% 1/1 [00:07<00:00,  7.92s/it]
[1] augs - Validating: 100% 1/1 [00:00<00:00,  2.72it/s]
[2] augs - Training: 100% 1/1 [00:00<00:00,  1.13it/s]
[2] augs - Validating: 100% 1/1 [00:00<00:00,  2.16it/s]
[3] augs - Training: 100% 1/1 [00:00<00:00,  1.24it/s]
[3] augs - Validating: 100% 1/1 [00:00<00:00,  2.06it/s]
[4] augs - Training: 100% 1/1 [00:00<00:00,  1.16it/s]
[4] augs - Validating: 100% 1/1 [00:00<00:00,  2.01it/s]
[5] augs - Training: 100% 1/1 [00:00<00:00,  1.29it/s]
[5] augs - Validating: 100% 1/1 [00:00<00:00,  2.01it/s]
[6] augs - Training: 100% 1/1 [00:00<00:00,  1.33it/s]
[6] augs - Validating: 100% 1/1 [00:00<00:00,  2.04it/s]
[7] augs - Training: 100% 1/1 [00:00<00:00,  1.46it/s]
[7] augs - Validating: 100% 1/1 [00:00<00:00,  1.87it/s]
[8] augs - Training: 100% 1/1 [00:00<00:00,  1.29it/s]
[8] augs - Validating: 100% 1/1 [00:00<00:00,  2.08it/s]
[9] augs - Training: 100% 1/1 [00:00<00:00,  1.35it/s]
[9] augs - Validating: 100% 1/1 [00:00<00:00,  2.