In [1]:
import logging
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader, random_split, Subset
from tqdm import tqdm

from unet.dataset import DeadwoodDataset
from unet.dice_score import dice_loss
from unet.unet_model import UNet
from unet.evaluate import evaluate

In [2]:
# Locations passed to DeadwoodDataset: first the images directory, then the
# masks directory (argument meaning inferred from the path names -- confirm
# against unet.dataset).
image_dir = "/net/scratch/jmoehring/tiles/images"
mask_dir = "/net/scratch/jmoehring/tiles/masks"
dataset = DeadwoodDataset(image_dir, mask_dir)

In [3]:
# Hyperparameters and run options, grouped by what they control.

# -- training schedule --
epochs: int = 5
batch_size: int = 5
val_percent: float = 0.1  # fraction of the dataset held out for validation

# -- optimizer --
learning_rate: float = 1e-5
weight_decay: float = 1e-8
momentum: float = 0.999
gradient_clipping: float = 1.0  # max gradient norm passed to clip_grad_norm_

# -- run options --
amp: bool = False  # enables autocast and the gradient scaler when True
save_checkpoint: bool = True
img_scale: float = 0.5

In [12]:
# Seed every RNG this notebook can draw from, not only Python's `random`:
# weight initialisation and DataLoader shuffling use torch's global generator,
# so seeding `random` alone does not make a run reproducible.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

In [13]:
# Restrict the dataset to the indices listed in `subset_sel`.
# NOTE(review): `subset_sel` is not defined in any cell visible in this
# notebook, so this cell fails on a fresh kernel -- the cell that built it
# (presumably using the seeded `random` module) needs to be restored. TODO confirm.
sub_dataset = Subset(dataset, subset_sel)

In [14]:
# Hold out `val_percent` of the subset for validation; the remainder is used
# for training. A dedicated generator with a fixed seed keeps the split stable
# across runs.
total = len(sub_dataset)
n_val = int(total * val_percent)
n_train = total - n_val

split_rng = torch.Generator().manual_seed(0)
train_set, val_set = random_split(
    sub_dataset, lengths=[n_train, n_val], generator=split_rng
)

In [15]:
# Settings shared by both loaders.
loader_args = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

# The training loop iterates over `train_loader`, so it has to be created here;
# leaving it commented out makes the notebook fail on a fresh kernel.
train_loader = DataLoader(train_set, shuffle=True, **loader_args)

# Validation is not shuffled: a fixed batch order keeps the validation score
# comparable between evaluation rounds.
val_loader = DataLoader(val_set, shuffle=False, **loader_args)

In [8]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
logging.info(f"Using device {device}")

# U-Net with three input channels (RGB) and a single output channel, stored in
# channels-last memory format and moved to the selected device.
model = UNet(n_channels=3, n_classes=1, bilinear=True).to(
    memory_format=torch.channels_last
)
model.to(device=device)

# Binary cross entropy on raw logits (the model output is not passed through
# a sigmoid before this loss).
criterion = nn.BCEWithLogitsLoss()

# RMSprop driven by the hyperparameters from the config cell.
rmsprop_kwargs = dict(
    lr=learning_rate, weight_decay=weight_decay, momentum=momentum
)
optimizer = torch.optim.RMSprop(model.parameters(), **rmsprop_kwargs)

# mode="max": the scheduler is stepped with a score that should increase.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", patience=5
)

# Gradient scaler for mixed precision; disabled when `amp` is False.
grad_scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled=amp)

In [16]:
# Start (or resume) the Weights & Biases run used for all logging below.
experiment = wandb.init(
    project="standing-deadwood-unet", resume="allow", anonymous="must"
)

# Record every hyperparameter from the config cell, including the optimizer
# settings (weight decay, momentum, gradient clipping) that affect training
# and were previously missing from the run config.
experiment.config.update(
    dict(
        epochs=epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        val_percent=val_percent,
        save_checkpoint=save_checkpoint,
        img_scale=img_scale,
        amp=amp,
        weight_decay=weight_decay,
        momentum=momentum,
        gradient_clipping=gradient_clipping,
    )
)

In [None]:
# Train for `epochs` passes over `train_loader`. Every optimizer step logs the
# batch loss to W&B; about five times per epoch the loop also logs parameter
# and gradient histograms, runs validation, steps the LR scheduler and uploads
# one example image with its true and predicted mask.
global_step = 0

# Number of optimizer steps between validation rounds. It depends only on
# n_train and batch_size, so it is computed once instead of on every step.
division_step = n_train // (5 * batch_size)

# The original code maps the "mps" device type to "cpu" for autocast.
autocast_device = device.type if device.type != "mps" else "cpu"

for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0

    with tqdm(total=n_train, desc=f"Epoch {epoch}/{epochs}", unit="img") as pbar:
        for images, masks_true in train_loader:
            images = images.to(
                device=device, dtype=torch.float32, memory_format=torch.channels_last
            )
            masks_true = masks_true.to(device=device, dtype=torch.long).squeeze(1)

            # Only the forward pass and the loss belong inside autocast; the
            # backward pass and optimizer step run outside of it.
            with torch.amp.autocast(autocast_device, enabled=amp):
                # One squeeze removes the channel dimension; a second one is
                # redundant and would wrongly drop a height of size 1.
                masks_pred = model(images).squeeze(1)
                target = masks_true.float()

                # Combined objective: BCE on logits plus Dice on probabilities.
                loss = criterion(masks_pred, target)
                loss += dice_loss(F.sigmoid(masks_pred), target, multiclass=False)

            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the true
            # gradient norm, not the loss-scaled one (no-op when amp is off).
            grad_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad.clip_grad_norm_(
                model.parameters(), gradient_clipping
            )
            grad_scaler.step(optimizer)
            grad_scaler.update()

            # Read the scalar once; each .item() forces a device sync.
            batch_loss = loss.item()
            pbar.update(images.shape[0])
            global_step += 1
            epoch_loss += batch_loss
            experiment.log(
                {"train loss": batch_loss, "step": global_step, "epoch": epoch}
            )
            pbar.set_postfix(**{"loss (batch)": batch_loss})

            if division_step > 0 and global_step % division_step == 0:
                # Histograms of weights and gradients, skipping tensors that
                # contain inf/NaN and parameters that have no gradient.
                histograms = {}
                for tag, value in model.named_parameters():
                    tag = tag.replace("/", ".")
                    if torch.isfinite(value).all():
                        histograms["Weights/" + tag] = wandb.Histogram(
                            value.data.cpu()
                        )
                    if value.grad is not None and torch.isfinite(value.grad).all():
                        histograms["Gradients/" + tag] = wandb.Histogram(
                            value.grad.data.cpu()
                        )

                val_score = evaluate(model, val_loader, device, amp)
                # `evaluate` is defined outside this notebook; re-assert
                # training mode in case it leaves the model in eval mode.
                model.train()
                scheduler.step(val_score)

                logging.info("Validation Dice score: {}".format(val_score))

                # Prefer an example whose true mask contains a pixel equal to
                # 1 (the last such sample in the batch); default to index 0.
                relevant_index = 0
                for index, mask_true in enumerate(masks_true):
                    if 1 in mask_true:
                        relevant_index = index

                experiment.log(
                    {
                        "learning rate": optimizer.param_groups[0]["lr"],
                        "validation Dice": val_score,
                        "images": wandb.Image(images[relevant_index].cpu()),
                        "masks": {
                            "true": wandb.Image(
                                masks_true[relevant_index].float().cpu()
                            ),
                            "pred": wandb.Image(
                                F.sigmoid(
                                    masks_pred[relevant_index].detach().float().cpu()
                                )
                            ),
                        },
                        "step": global_step,
                        "epoch": epoch,
                        **histograms,
                    }
                )

In [None]:
# Persist the trained weights. torch.save does not create missing parent
# directories, so make sure the target folder exists first -- otherwise a
# finished training run would fail at the very last step.
checkpoint_path = Path("data/models/unet_20122023")
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [10]:
# Rebuild the same architecture and load the saved weights for inference.
inf_model = UNet(n_channels=3, n_classes=1, bilinear=True)
inf_model = inf_model.to(memory_format=torch.channels_last)
inf_model = inf_model.to(device=device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
inf_model.load_state_dict(
    torch.load("data/models/unet_20122023", map_location=device)
)

<All keys matched successfully>

In [20]:
# Inspect the architecture: evaluating the bound method without calling it
# displays its repr, which includes the full module tree of the model.
inf_model.parameters

<bound method Module.parameters of UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1

In [17]:
# Run the reloaded model over the validation set and log, for every sample
# whose true mask contains at least one pixel equal to 1, the input image, the
# true mask and the predicted mask thresholded at 0.5.

# Eval mode makes the BatchNorm layers use their stored running statistics
# instead of per-batch statistics (and stops them from being updated).
inf_model.eval()

# no_grad avoids building autograd graphs that are never used during inference.
with torch.no_grad():
    for images, masks_true in val_loader:
        images = images.to(
            device=device, dtype=torch.float32, memory_format=torch.channels_last
        )
        masks_true = masks_true.to(device=device, dtype=torch.long).squeeze(1)
        masks_pred = inf_model(images).squeeze(1)

        for index, mask in enumerate(masks_true):
            if 1 in mask:
                experiment.log(
                    {
                        "true_image": wandb.Image(images[index].cpu()),
                        "true_mask": wandb.Image(masks_true[index].float().cpu()),
                        "pred_mask": wandb.Image(
                            (F.sigmoid(masks_pred[index]) > 0.5).float().cpu()
                        ),
                    }
                )
        torch.cuda.empty_cache()