In [None]:
from google.colab import drive

# Make Google Drive visible inside the Colab VM under /content/drive/.
drive.mount(mountpoint="/content/drive/")

Mounted at /content/drive/


In [None]:
!pip install import_ipynb

Collecting import_ipynb
  Downloading import_ipynb-0.1.4-py3-none-any.whl (4.1 kB)
Collecting jedi>=0.16 (from IPython->import_ipynb)
  Downloading jedi-0.19.1-py2.py3-none-any.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jedi, import_ipynb
Successfully installed import_ipynb-0.1.4 jedi-0.19.1


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import v2
import import_ipynb

import os
from os import listdir
from os.path import join

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
!cp U-Net/Unet_torch.ipynb .
from Unet_torch import UNet, conv_block

importing Jupyter notebook from Unet_torch.ipynb


In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

In [None]:
# Project locations, relative to the current working directory.
# NOTE(review): no %cd is visible in this notebook, so these paths only resolve
# if the kernel already sits in the directory that contains U-Net/ — confirm
# before relying on Restart & Run All.
data_dir, ckpt_dir = "U-Net/data", "U-Net/checkpoints"
val_dir = join(data_dir, "test_monochromatic")  # evaluation image/mask pairs

In [None]:
def pad_int(run_id, zfill=4):
  """Return `run_id` as a string left-padded with zeros to `zfill` characters."""
  digits = str(run_id)
  return digits.zfill(zfill)

def get_fname(fnumber, img_bool):
  """Return the file name of sample `fnumber`: its image if `img_bool` is truthy, else its mask."""
  kind = "img" if img_bool else "msk"
  return f"{pad_int(fnumber)}_{kind}.png"

In [None]:
def dice_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    multiclass: bool=True,
    epsilon: float=1e-6,
    reduce_batch_first: bool=True,
):
  """Return the Dice loss ``1 - mean(Dice coefficient)`` between two masks.

  Args:
    input: predicted mask(s), typically probabilities (softmax/sigmoid output);
      must have exactly the same shape as ``target``.
    target: ground-truth mask(s) (e.g. one-hot for the multiclass case).
    multiclass: if True, merge dims 0 and 1 before reducing
      (e.g. ``(N, C, H, W) -> (N*C, H, W)``).
    epsilon: smoothing term added to numerator and denominator.
    reduce_batch_first: when the (possibly flattened) tensor is 3-D, pool the
      leading dim together with the last two dims into a single coefficient
      (the original behaviour). If False, compute one coefficient per
      leading-dim slice and average them.

  Returns:
    A 0-dim tensor holding the loss.

  Raises:
    ValueError: if ``input`` and ``target`` shapes differ. Broadcasting would
      otherwise silently compute a wrong score.
  """
  if input.shape != target.shape:
    raise ValueError(
        f"dice_loss: input shape {tuple(input.shape)} != target shape {tuple(target.shape)}"
    )

  if multiclass:
    input = input.flatten(0, 1)
    target = target.flatten(0, 1)

  # Always reduce over the last two dims; for a 3-D tensor optionally fold the
  # leading dim in as well, so the whole stack yields one pooled coefficient.
  if input.dim() == 3 and reduce_batch_first:
    sum_dim = (-1, -2, -3)
  else:
    sum_dim = (-1, -2)
  inter = 2 * (input * target).sum(dim=sum_dim)
  sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
  # With non-negative inputs, sets_sum == 0 means both masks are empty, so inter is 0
  # as well and the coefficient becomes (0 + eps) / (0 + eps) == 1 (perfect match).
  sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

  dice = (inter + epsilon) / (sets_sum + epsilon)

  return 1 - dice.mean()

In [None]:
def test_model(
    model,
    device,
    image_dir=None,
):
  """Evaluate ``model`` on the validation pairs; print and return the mean loss.

  The per-image loss is cross-entropy + Dice when ``model.num_classes > 1``,
  otherwise BCE-with-logits + Dice. Samples are read as
  ``get_fname(i, img_bool=True)`` (image) and ``get_fname(i, img_bool=False)``
  (mask) for ``i`` in ``0 .. n-1``, where ``n`` is the number of
  ``*_img.png`` files in the directory.

  Args:
    model: segmentation network exposing ``num_classes``; it is switched to
      eval mode and left there.
    device: ``torch.device`` the model lives on.
    image_dir: directory containing the pairs; defaults to the module-level
      ``val_dir``.

  Returns:
    float: the mean validation loss (also printed as ``Validation loss: ...``).

  Raises:
    ValueError: if the directory contains no ``*_img.png`` files.
  """
  if image_dir is None:
    image_dir = val_dir

  model.eval()
  # Count image files explicitly: halving the raw listing goes wrong as soon as
  # the directory holds anything besides the pairs (e.g. .ipynb_checkpoints).
  num_val_images = sum(1 for fname in listdir(image_dir) if fname.endswith("_img.png"))
  if num_val_images == 0:
    raise ValueError(f"no '*_img.png' files found in {image_dir!r}")

  criterion = nn.CrossEntropyLoss() if model.num_classes > 1 else nn.BCEWithLogitsLoss()
  # PIL image -> float32 tensor; scale=True rescales uint8 [0, 255] to [0, 1].
  # Loop-invariant, so it is built once rather than per image.
  torch_tf = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
  ])
  autocast_device = device.type if device.type != "mps" else "cpu"
  val_loss = 0.0

  # Pure evaluation: no_grad stops autograd from recording a graph per image.
  with torch.no_grad():
    for i in range(num_val_images):
      with Image.open(join(image_dir, get_fname(i, img_bool=True))) as pil_image:
        image = torch_tf(pil_image.convert("RGB"))
      with Image.open(join(image_dir, get_fname(i, img_bool=False))) as pil_mask:
        mask_truth = torch_tf(pil_mask.convert("P"))

      # Add the batch dim the model expects.
      image = image.unsqueeze(0).to(device=device)
      # The mask keeps its single channel, which lines up with the batch of one.
      # NOTE(review): scale=True divides by 255 and the long cast truncates, so
      # anything below 255 maps to class 0. This assumes masks are stored as
      # 0/255 (and that 255/255 lands exactly on 1.0); confirm against the dataset.
      mask_truth = mask_truth.to(device=device, dtype=torch.long)

      with torch.autocast(autocast_device):
        mask_pred = model(image)
        if model.num_classes > 1:
          loss = criterion(mask_pred, mask_truth)
          loss += dice_loss(
              F.softmax(mask_pred, dim=1).float(),
              F.one_hot(mask_truth, model.num_classes).permute(0, 3, 1, 2).float(),
              multiclass=True
          )
        else:
          loss = criterion(mask_pred.squeeze(1), mask_truth.float())
          loss += dice_loss(torch.sigmoid(mask_pred.squeeze(1)), mask_truth.float(), multiclass=False)

      val_loss += loss.item()

  val_loss /= num_val_images
  print(f"Validation loss: {val_loss}")
  return val_loss

def main(
    ckpt_name="Unet_ckpt_epoch_20_grad_scaling_scheduler_step_monochromatic_1.pth",
):
  """Load a trained 2-class U-Net checkpoint and report its validation loss.

  Args:
    ckpt_name: file name of the state-dict checkpoint inside ``ckpt_dir``.
  """
  # Build the architecture to evaluate (3-channel input, 2 output classes).
  model = UNet(in_channels=3, num_classes=2).to(device)
  # The file is loaded straight into load_state_dict, i.e. it is a plain state
  # dict, so restrict unpickling to tensors/primitive containers.
  state_dict = torch.load(join(ckpt_dir, ckpt_name), map_location=device, weights_only=True)
  model.load_state_dict(state_dict)

  test_model(model, device=device)


if __name__ == "__main__":
  main()

Validation loss: 0.18756907238137155
