In [1]:
import logging
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Subset
from tqdm import tqdm
import numpy as np

from unet.dataset import DeadwoodDataset
from unet.dice_score import dice_loss
from unet.unet_model import UNet
from unet.evaluate import evaluate

In [2]:
# Locations of the image tiles and their matching mask tiles on the scratch volume.
images_dir = "/net/scratch/jmoehring/tiles/images"
masks_dir = "/net/scratch/jmoehring/tiles/masks"

# Full dataset; later cells take a subset of it for training and validation.
dataset = DeadwoodDataset(images_dir, masks_dir)

In [3]:
# --- training schedule ---
epochs: int = 5  # number of passes over the training split
batch_size: int = 10  # samples per batch for both data loaders
val_percent: float = 0.1  # fraction of the subset held out for validation

# --- optimizer (RMSprop) ---
learning_rate: float = 1e-5
weight_decay: float = 1e-8
momentum: float = 0.999
gradient_clipping: float = 1.0  # max gradient norm passed to clip_grad_norm_

# --- run options ---
amp: bool = False  # toggles autocast and the gradient scaler
save_checkpoint: bool = True  # NOTE(review): only logged to wandb in the visible cells
img_scale: float = 0.5  # NOTE(review): only logged to wandb in the visible cells

In [4]:
# Indices of the first 1000 samples, used to build a smaller working subset.
subset_sel = list(range(1000))

In [5]:
# Restrict the dataset to the selected indices; samples are still read through `dataset`.
sub_dataset = Subset(dataset=dataset, indices=subset_sel)

In [6]:
# Hold out `val_percent` of the subset for validation; the rest is used for training.
n_total = len(sub_dataset)
n_val = int(n_total * val_percent)
n_train = n_total - n_val

# A fixed seed keeps the train/validation assignment identical across runs.
split_rng = torch.Generator().manual_seed(0)
train_set, val_set = random_split(sub_dataset, [n_train, n_val], generator=split_rng)

In [7]:
# Settings shared by both loaders, defined once so the two cannot drift apart.
loader_args = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_set, shuffle=True, **loader_args)

# Validation data is only scored, never trained on, so it is iterated in a
# fixed order; shuffling it adds nothing and makes runs harder to compare.
val_loader = DataLoader(val_set, shuffle=False, **loader_args)

In [8]:
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device {device}")

# U-Net with three input channels (RGB) and a single output channel,
# converted to channels-last memory format and then moved to the device.
model = UNet(n_channels=3, n_classes=1, bilinear=True).to(
    memory_format=torch.channels_last
)
model = model.to(device=device)

# Loss: binary cross entropy computed directly on the raw logits.
criterion = nn.BCEWithLogitsLoss()

# Optimizer over all model parameters.
optimizer = torch.optim.RMSprop(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    momentum=momentum,
)

# "max" mode: the scheduler is stepped with the validation Dice score,
# where a higher value is better.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", patience=5
)

# Gradient scaler for mixed precision; it is disabled when `amp` is False.
grad_scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled=amp)

INFO: Using device cuda


In [9]:
# Start (or resume) the Weights & Biases run that the training loop logs to.
experiment = wandb.init(
    project="standing-deadwood-unet",
    resume="allow",
    anonymous="must",
)

# Store the hyperparameters of this run alongside its metrics.
run_config = {
    "epochs": epochs,
    "batch_size": batch_size,
    "learning_rate": learning_rate,
    "val_percent": val_percent,
    "save_checkpoint": save_checkpoint,
    "img_scale": img_scale,
    "amp": amp,
}
experiment.config.update(run_config)

ERROR: Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjmoehring[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Training loop: optimizes BCE-with-logits plus Dice loss on the training
# split, logs every step to wandb, and runs validation five times per epoch.

# Number of optimizer steps between two validation rounds. It depends only on
# values fixed before training starts, so it is computed once up front.
division_step = n_train // (5 * batch_size)

global_step = 0
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0

    with tqdm(total=n_train, desc=f"Epoch {epoch}/{epochs}", unit="img") as pbar:
        for images, masks_true in train_loader:
            images = images.to(
                device=device, dtype=torch.float32, memory_format=torch.channels_last
            )
            masks_true = masks_true.to(device=device, dtype=torch.long).squeeze(1)
            # Both loss terms take a float target; convert once.
            masks_target = masks_true.float()

            # Only the forward pass and the loss run under autocast.
            with torch.amp.autocast(
                device.type if device.type != "mps" else "cpu", enabled=amp
            ):
                # Drop the channel dimension once so prediction and target align.
                masks_pred = model(images).squeeze(1)

                loss = criterion(masks_pred, masks_target)
                loss += dice_loss(
                    F.sigmoid(masks_pred),
                    masks_target,
                    multiclass=False,
                )

            # Backward pass and optimizer step happen outside the autocast region.
            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true
            # gradient norm; this is a no-op while the scaler is disabled.
            grad_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad.clip_grad_norm_(
                model.parameters(), gradient_clipping
            )
            grad_scaler.step(optimizer)
            grad_scaler.update()

            # Read the scalar loss once; each .item() call syncs with the device.
            batch_loss = loss.item()

            pbar.update(images.shape[0])
            global_step += 1
            epoch_loss += batch_loss
            experiment.log(
                {"train loss": batch_loss, "step": global_step, "epoch": epoch}
            )
            pbar.set_postfix(**{"loss (batch)": batch_loss})

            if division_step > 0 and global_step % division_step == 0:
                # Weight and gradient histograms, skipping tensors that
                # contain inf or NaN values.
                histograms = {}
                for tag, value in model.named_parameters():
                    tag = tag.replace("/", ".")
                    if not (torch.isinf(value) | torch.isnan(value)).any():
                        histograms["Weights/" + tag] = wandb.Histogram(
                            value.data.cpu()
                        )
                    # A parameter that received no gradient has grad None.
                    if value.grad is not None and not (
                        torch.isinf(value.grad) | torch.isnan(value.grad)
                    ).any():
                        histograms["Gradients/" + tag] = wandb.Histogram(
                            value.grad.data.cpu()
                        )

                val_score = evaluate(model, val_loader, device, amp)
                # `evaluate` is defined outside this notebook; re-assert training
                # mode in case it leaves the model in eval mode.
                model.train()
                scheduler.step(val_score)

                logging.info("Validation Dice score: {}".format(val_score))

                # Show a sample from the current batch whose mask has at least
                # one positive pixel (the last such sample; index 0 if none).
                relevant_index = 0
                for index, mask_true in enumerate(masks_true):
                    if 1 in mask_true:
                        relevant_index = index
                experiment.log(
                    {
                        "learning rate": optimizer.param_groups[0]["lr"],
                        "validation Dice": val_score,
                        "images": wandb.Image(images[relevant_index].cpu()),
                        "masks": {
                            "true": wandb.Image(
                                masks_true[relevant_index].float().cpu()
                            ),
                            "pred": wandb.Image(
                                F.sigmoid(
                                    masks_pred[relevant_index].detach().float().cpu()
                                )
                            ),
                        },
                        "step": global_step,
                        "epoch": epoch,
                        **histograms,
                    }
                )

Epoch 1/5:  20%|██        | 180/900 [00:07<00:26, 26.97img/s, loss (batch)=1.21]INFO: Validation Dice score: 0.00022963497031014413
Epoch 1/5:  40%|████      | 360/900 [00:16<00:20, 26.74img/s, loss (batch)=1.16]INFO: Validation Dice score: 0.949999988079071
Epoch 1/5:  60%|██████    | 540/900 [00:24<00:13, 26.67img/s, loss (batch)=1.09]INFO: Validation Dice score: 0.949999988079071
Epoch 1/5:  80%|████████  | 720/900 [00:33<00:06, 26.78img/s, loss (batch)=1.02] INFO: Validation Dice score: 0.7800460457801819
Epoch 1/5: 100%|██████████| 900/900 [00:42<00:00, 26.66img/s, loss (batch)=0.713]INFO: Validation Dice score: 0.4724428355693817
Epoch 1/5: 100%|██████████| 900/900 [00:44<00:00, 20.40img/s, loss (batch)=0.713]
Epoch 2/5:  20%|██        | 180/900 [00:06<00:26, 26.71img/s, loss (batch)=0.811]INFO: Validation Dice score: 0.8201643228530884
Epoch 2/5:  40%|████      | 360/900 [00:15<00:20, 26.57img/s, loss (batch)=1.03] INFO: Validation Dice score: 0.949999988079071
Epoch 2/5:  60%|█