In [1]:
%matplotlib inline
import os
import subprocess

import numpy as np
import matplotlib.pyplot as plt
import torch

from trainers import SegmentationTask, SegmentationDataModule
import pytorch_lightning as pl

from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

In [2]:
# Input rasters, kept as two parallel lists (presumably image i pairs with
# mask i -- confirm against SegmentationDataModule).
image_fns = ["data/imagery/16_pre_imagery_cropped.tif"]
mask_fns = ["data/masks/16_pre_imagery_cropped_mask_buffered.tif"]

In [3]:
# Build the data module that feeds patches to the trainer.
# NOTE(review): every split ("train", "valid", "test") is given the same file
# lists, so validation/test metrics are not computed on held-out rasters --
# confirm this is intentional.
dm = SegmentationDataModule(
    image_fns={split: image_fns for split in ("train", "valid", "test")},
    mask_fns={split: mask_fns for split in ("train", "valid", "test")},
    batch_size=24,
    patch_size=512,
    num_workers=6,
    batches_per_epoch=256,
)

In [4]:
# Model and optimisation settings gathered in one mapping, then unpacked into
# the Lightning task in the same keyword order.
task_config = dict(
    segmentation_model="unet",
    encoder_name="resnet18",
    encoder_weights="imagenet",  # use None for random weight init
    loss="ce",
    learning_rate=0.001,
    learning_rate_schedule_patience=6,
    optimizer="adamw",
    weight_decay=0.01,
)
task = SegmentationTask(**task_config)

In [8]:
# List the CUDA devices visible to this kernel by name. An empty list means
# PyTorch sees no GPU, in which case the device-name lookup is never called.
[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]

[]

In [15]:
# Print the installed PyTorch version so it is recorded with the run.
print(torch.__version__)

1.10.2


In [14]:
# Print whether PyTorch reports CUDA as available in this kernel.
print(torch.cuda.is_available())

False


In [5]:
# Output locations: TensorBoard logs go under `log_dir`, checkpoints and the
# trainer's default root go under a per-experiment folder in `output_dir`.
experiment_name = "unet-resnet18-imagenet-lr_0.001"
log_dir = "output/logs/"
output_dir = "output/runs/"
experiment_dir = os.path.join(output_dir, experiment_name)

tb_logger = pl_loggers.TensorBoardLogger(log_dir, name=experiment_name)

# Both callbacks monitor "val_loss". The checkpoint callback keeps the 12 best
# checkpoints plus the most recent one.
checkpoint_callback = ModelCheckpoint(
    dirpath=experiment_dir,
    monitor="val_loss",
    save_top_k=12,
    save_last=True,
)

# NOTE(review): patience (18) is larger than max_epochs (15) set below. If
# validation runs once per epoch (the Lightning default), early stopping can
# never trigger before training ends -- confirm which value is intended.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=18,
    min_delta=0.0,
)

trainer_args = dict(
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=tb_logger,
    default_root_dir=experiment_dir,
    max_epochs=15,
)

trainer = pl.Trainer(**trainer_args)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Start training: fit `task` with the data supplied by the `dm` data module.
trainer.fit(model=task, datamodule=dm)


  | Name                | Type                   | Params
---------------------------------------------------------------
0 | model               | Unet                   | 14.3 M
1 | loss                | CrossEntropyLoss       | 0     
2 | train_augmentations | AugmentationSequential | 0     
3 | train_metrics       | MetricCollection       | 0     
4 | val_metrics         | MetricCollection       | 0     
5 | test_metrics        | MetricCollection       | 0     
6 | loss1               | CrossEntropyLoss       | 0     
7 | loss2               | TverskyLoss            | 0     
---------------------------------------------------------------
14.3 M    Trainable params
0         Non-trainable params
14.3 M    Total params
57.315    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]