In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from unet.dataset import DeadwoodDataset
from unet.dice_score import dice_loss
from unet.unet_model import UNet

In [None]:
# Full dataset built from the image directory and the mask directory
# (first argument: images, second argument: masks).
dataset = DeadwoodDataset("data/train_batch/images", "data/train_batch/masks")

In [None]:
# Training configuration: everything a reader might want to tune lives here.

# --- schedule / batching ---
epochs: int = 5
batch_size: int = 10
val_percent: float = 0.1  # fraction of the dataset held out for validation

# --- optimizer (RMSprop) ---
learning_rate: float = 1e-5
weight_decay: float = 1e-8
momentum: float = 0.999
gradient_clipping: float = 1.0  # max gradient norm passed to clip_grad_norm_

# --- mixed precision ---
amp: bool = False  # toggles both autocast and the GradScaler

# --- not referenced by the cells shown here ---
# NOTE(review): presumably consumed by checkpointing / image-rescaling code
# elsewhere; confirm they are still needed.
save_checkpoint: bool = True
img_scale: float = 0.5

In [None]:
# Hold out `val_percent` of the samples for validation; everything else is
# used for training. int() truncates, so the validation share is rounded down.
n_val = int(val_percent * len(dataset))
n_train = len(dataset) - n_val

# A generator with a fixed seed makes the partition identical on every run.
train_set, val_set = random_split(
    dataset,
    lengths=[n_train, n_val],
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# DataLoader settings shared by the training and validation loaders.
loader_args = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

# Train on the training split only. Wrapping the full `dataset` here would
# feed the held-out validation samples to the optimizer (data leakage) and
# make the number of images per epoch disagree with `n_train`.
train_loader = DataLoader(train_set, shuffle=True, **loader_args)

# Validation loader over the held-out split.
val_loader = DataLoader(val_set, shuffle=True, **loader_args)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model with three input channels (RGB) and a single output class,
# matching the binary BCE + Dice loss used in the training loop
model = UNet(n_channels=3, n_classes=1, bilinear=True)
# The training loop moves every batch to `device` in channels_last layout,
# so the model's parameters must live on the same device (and use the same
# memory format); otherwise the forward pass fails on a CUDA machine.
model = model.to(device=device, memory_format=torch.channels_last)

# loss function (binary cross entropy on raw logits)
criterion = nn.BCEWithLogitsLoss()

# optimizer (created after the model has been moved to its final device)
optimizer = torch.optim.RMSprop(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum
)
# "max" mode: the learning rate is reduced when the monitored metric stops
# increasing. NOTE(review): presumably stepped with a validation score
# elsewhere; it is not stepped in the training loop shown in this notebook.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=5)
# Gradient scaler for mixed precision; constructed disabled when `amp` is False.
grad_scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled=amp)

In [None]:
# Training loop: one pass over `train_loader` per epoch, optimising
# BCE-with-logits plus a Dice loss term on the predicted mask.
global_step = 0
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0

    # Size the progress bar from the loader itself so the total always matches
    # the number of images that are actually iterated.
    with tqdm(
        total=len(train_loader.dataset), desc=f"Epoch {epoch}/{epochs}", unit="img"
    ) as pbar:
        for images, masks_true in train_loader:
            images = images.to(
                device=device, dtype=torch.float32, memory_format=torch.channels_last
            )
            masks_true = masks_true.to(device=device, dtype=torch.long).squeeze(1)

            # Only the forward pass and the loss computation run under autocast.
            # autocast is given "cpu" when the device type is "mps".
            with torch.amp.autocast(
                device.type if device.type != "mps" else "cpu", enabled=amp
            ):
                # Drop the channel axis exactly once.
                # NOTE(review): assumes the model returns (N, 1, H, W) for
                # n_classes=1 - confirm against UNet.
                masks_pred = model(images).squeeze(1)

                loss = criterion(masks_pred, masks_true.float())
                loss += dice_loss(
                    F.sigmoid(masks_pred),
                    masks_true.float(),
                    multiclass=False,
                )

            # Backward pass and optimizer step happen outside the autocast region.
            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the true
            # gradient norm; this returns immediately when the scaler is disabled.
            grad_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
            grad_scaler.step(optimizer)
            grad_scaler.update()

            # Bookkeeping: read the scalar loss once per step.
            batch_loss = loss.item()
            pbar.update(images.shape[0])
            global_step += 1
            epoch_loss += batch_loss
            pbar.set_postfix(**{"loss (batch)": batch_loss})