In [34]:
import torch
from unet import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from pathlib import Path
import os
from torch.utils.data import DataLoader
import torch.nn.functional as F
# from utils.dice_score import multiclass_dice_coeff, dice_coeff
import numpy as np
from torch import Tensor

In [2]:
print(torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

True


In [3]:
model = UNet(n_channels=3, n_classes=9, bilinear=False)
model = model.to(memory_format=torch.channels_last)

In [4]:
state_dict = torch.load('../checkpoints/checkpoint_epoch5.pth', map_location=device)
del state_dict['mask_values']
model.load_state_dict(state_dict)

<All keys matched successfully>

In [5]:
model.to(device=device)

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [14]:
dir_img = Path('./imgs/')
dir_mask = Path('./masks/')
img_scale = 0.5
try:
    dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
except (AssertionError, RuntimeError, IndexError):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)

In [15]:
batch_size = 8
loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
dataset_loader = DataLoader(dataset, shuffle=True, **loader_args)

In [55]:
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    print(input.shape, target.shape, '===222')
    # Average of Dice coefficient for all batches, or for a single mask
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

    inter = 2 * (input * target).sum(dim=sum_dim)
    sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

    dice = (inter + epsilon) / (sets_sum + epsilon)
    print(min(inter, epsilon), '=====================')
    return dice.mean()


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    print(input.shape, target.shape, '===111')
    # Average of Dice coefficient for all classes
    return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)


In [56]:
amp = True
def evaluate(net, dataloader, device, amp):
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = {}
    dice_score['a'] = 0


    # iterate over the validation set
    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        for batch in dataloader:
            image, mask_true = batch['image'], batch['mask']

            # move images and labels to correct device and type
            image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            mask_true = mask_true.to(device=device, dtype=torch.long)

            # predict the mask
            mask_pred = net(image)

            if net.n_classes == 1:
                assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
            else:
                assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
                # convert to one-hot format
                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                for label in range(net.n_classes):
                    if (label in dice_score):
                        dice_score[label] +=  multiclass_dice_coeff(mask_pred[:, label], mask_true[:, label], reduce_batch_first=False)
                    else:
                        dice_score[label] = multiclass_dice_coeff(mask_pred[:, label], mask_true[:, label], reduce_batch_first=False)
                # compute the Dice score, ignoring background
                dice_score['a'] += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)
            break
    net.train()
    for label in dice_score:
        dice_score[label] /= max(num_val_batches, 1)
    return dice_score

val_score = evaluate(model, dataset_loader, device, amp)

AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "C:\Users\Liuyonghui\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\torch\utils\data\_utils\worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "C:\Users\Liuyonghui\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\torch\utils\data\_utils\fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\Liuyonghui\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\torch\utils\data\_utils\fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "D:\project\model\data\utils\data_loading.py", line 99, in __getitem__
    assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
AssertionError: Either no image or multiple images found for the ID irreg_basemap_56478: []


In [9]:
print(val_score)

{'a': tensor(0.0001, device='cuda:0'), 0: tensor(0.0001, device='cuda:0'), 1: tensor(1.8779e-05, device='cuda:0'), 2: tensor(0.0001, device='cuda:0'), 3: tensor(3.7300e-12, device='cuda:0'), 4: tensor(0.0001, device='cuda:0'), 5: tensor(0.0001, device='cuda:0'), 6: tensor(0.0001, device='cuda:0'), 7: tensor(0.0001, device='cuda:0'), 8: tensor(0.0001, device='cuda:0')}
