In [4]:
import logging
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Subset
from tqdm import tqdm
import time
from pathlib import Path

from unet.dataset import DeadwoodDataset
from unet.dice_score import dice_loss
from unet.unet_model import UNet
from unet.evaluate import evaluate

In [5]:
# --- input / output directories ---
checkpoint_dir = "/net/scratch/jmoehring/checkpoints"
images_dir = "/net/scratch/jmoehring/tiles/images"
masks_dir = "/net/scratch/jmoehring/tiles/masks"

# --- data handling ---
random_seed: int = 42  # seeds the dataset folds and the train/val split
no_folds: int = 228  # number of folds requested from the dataset
fold: int = 0  # index of the fold used for this run
val_percent: float = 0.1  # fraction of the train/val samples held out for validation
batch_size: int = 16
save_checkpoint: bool = True  # write a checkpoint after every epoch

# --- optimisation ---
epochs: int = 5
amp: bool = False  # toggles autocast and the gradient scaler
learning_rate: float = 1e-5
momentum: float = 0.999
weight_decay: float = 1e-8
gradient_clipping: float = 1.0  # max gradient norm passed to clip_grad_norm_

In [7]:
# Build the tile dataset; fold count and seed are forwarded as keyword options.
fold_options = {"n_folds": no_folds, "random_seed": random_seed}
dataset = DeadwoodDataset(images_dir, masks_dir, **fold_options)

In [None]:
# Fetch the selected fold: one part for training/validation, one held out for testing.
train_val_set, test_set = dataset.get_fold(fold)

# Carve the validation subset out of the train/val part with a seeded generator
# so the split is reproducible across runs.
n_val = int(len(train_val_set) * val_percent)
n_train = len(train_val_set) - n_val
split_generator = torch.Generator().manual_seed(random_seed)
train_set, val_set = random_split(
    train_val_set, [n_train, n_val], generator=split_generator
)

In [None]:
# Keyword arguments shared by the DataLoader instances created from this notebook.
loader_args = dict(
    batch_size=batch_size,
    num_workers=12,
    pin_memory=True,
    shuffle=True,
)

In [None]:
# Training batches are reshuffled every epoch (as configured in loader_args).
train_loader = DataLoader(train_set, **loader_args)

# Evaluation loaders reuse the same settings but iterate in a fixed order:
# shuffling brings no benefit for validation/testing and makes per-batch
# inspection non-reproducible.
eval_loader_args = {**loader_args, "shuffle": False}
val_loader = DataLoader(val_set, **eval_loader_args)
test_loader = DataLoader(test_set, **eval_loader_args)

In [None]:
# Number of batches per epoch for the training and validation loaders.
tuple(map(len, (train_loader, val_loader)))

In [None]:
# Run on CUDA when it is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# U-Net taking 3-channel (RGB) input and producing a single output channel,
# stored in channels-last memory format and moved to the selected device.
model = UNet(n_channels=3, n_classes=1, bilinear=True)
model = model.to(memory_format=torch.channels_last).to(device=device)

# Binary cross entropy computed directly on the raw logits.
criterion = nn.BCEWithLogitsLoss()

# RMSprop optimizer configured from the hyper-parameters in the config cell.
optimizer = torch.optim.RMSprop(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    momentum=momentum,
)

# Scheduler in "max" mode: it expects a metric that should increase.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", patience=5
)

# Gradient scaler for mixed precision; it is a pass-through when amp is False.
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)

In [None]:
# Start a wandb run; the name carries the fold index and a timestamp.
run_name = f"deadwood_fold_{fold}_{time.time()}"
experiment = wandb.init(
    project="standing-deadwood-unet", resume="allow", name=run_name
)

# Record every tunable from the config cell, grouped into data and model settings.
data_config = {
    "images_dir": images_dir,
    "masks_dir": masks_dir,
    "no_folds": no_folds,
    "fold": fold,
    "val_percent": val_percent,
    "random_seed": random_seed,
    "batch_size": batch_size,
    "save_checkpoint": save_checkpoint,
}
model_config = {
    "epochs": epochs,
    "learning_rate": learning_rate,
    "amp": amp,
    "weight_decay": weight_decay,
    "momentum": momentum,
    "gradient_clipping": gradient_clipping,
}
experiment.config.update({"data": data_config, "model": model_config})

In [None]:
# Train for `epochs` epochs. After each epoch: evaluate on the validation
# loader, step the LR scheduler, log metrics to wandb and optionally save a
# checkpoint.
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0

    with tqdm(total=n_train, desc=f"Epoch {epoch}/{epochs}", unit="img") as pbar:
        for images, masks_true in train_loader:
            images = images.to(
                device=device, dtype=torch.float32, memory_format=torch.channels_last
            )
            masks_true = masks_true.to(device=device, dtype=torch.long).squeeze(1)
            # Both loss terms take float targets; convert once per batch.
            masks_target = masks_true.float()

            # Only the forward pass and the loss computation run under autocast;
            # backward and the optimizer step belong outside of it.
            with torch.amp.autocast(
                device.type if device.type != "mps" else "cpu", enabled=amp
            ):
                # Drop the single-channel axis once (no repeated squeeze, which
                # could remove a spatial axis of size 1).
                masks_pred = model(images).squeeze(1)

                # Combined objective: BCE on logits plus dice loss on probabilities.
                loss = criterion(masks_pred, masks_target)
                loss += dice_loss(torch.sigmoid(masks_pred), masks_target)

            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the true
            # gradient norm when amp is enabled (no-op when the scaler is disabled).
            grad_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad.clip_grad_norm_(
                model.parameters(), gradient_clipping
            )
            grad_scaler.step(optimizer)
            grad_scaler.update()

            pbar.update(images.shape[0])
            epoch_loss += loss.item()

    val_dice_score, val_dice_loss = evaluate(model, criterion, val_loader, device, amp)
    # The scheduler was created in "max" mode, so feed it the validation dice
    # score; without this call the learning rate would never be reduced.
    scheduler.step(val_dice_score)
    experiment.log(
        {
            "train loss": epoch_loss / len(train_loader),
            "val loss": val_dice_loss,
            "val dice score": val_dice_score,
            "epoch": epoch,
        }
    )

    if save_checkpoint:
        # checkpoint_dir is a plain string; wrap it in Path before joining with
        # "/" (str / str raises TypeError).
        checkpoint_root = Path(checkpoint_dir)
        checkpoint_root.mkdir(parents=True, exist_ok=True)
        state_dict = model.state_dict()
        state_dict["mask_values"] = dataset.mask_values
        torch.save(
            state_dict, str(checkpoint_root / "checkpoint_epoch{}.pth".format(epoch))
        )
        logging.info(f"Checkpoint {epoch} saved!")