In [2]:
from data.build_loader import build_dataloader

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
data_dir = "../cv-project/MEDICAL/MEDICAL-DATASET-001/Segmentation/"
dataloader = build_dataloader(data_dir)

In [4]:
data_iter = iter(dataloader["train"])
images, targets = next(data_iter)

In [6]:
images.shape, targets.shape #plt로 찍어봐도딤

(torch.Size([4, 3, 224, 224]), torch.Size([4, 224, 224]))

In [5]:
from models.u_net import UNet

In [8]:
unet = UNet(num_classes=4)

In [9]:
unet

UNet(
  (encoder): Encoder(
    (conv_block1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (conv_block2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (conv_block3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=

In [11]:
import torch

In [12]:
model = UNet(num_classes=4)
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(f"input shape: {x.shape}")
print(f"output shape: {out.shape}")

input shape: torch.Size([1, 3, 224, 224])
output shape: torch.Size([1, 4, 224, 224])


In [13]:
import torch.nn as nn
import torch.nn.functional as F

class UNet_metric(): 
    def __init__(self, num_classes):
        self.num_classes = num_classes
        
    def __call__(self, pred, target):
        onehot_pred = F.one_hot(torch.argmax(pred, dim=1), num_classes=self.num_classes).permute(0, 3, 1, 2)
        onehot_target = F.one_hot(target, num_classes=self.num_classes).permute(0, 3, 1, 2)
        dice_loss = self._get_dice_loss(onehot_pred, onehot_target)
        dice_coefficient = self._get_batch_dice_coefficient(onehot_pred, onehot_target)
        return dice_loss, dice_coefficient
    
    def _get_dice_coeffient(self, pred, target):
        set_inter = torch.dot(pred.reshape(-1).float(), target.reshape(-1).float())
        set_sum = pred.sum() + target.sum()
        if set_sum.item() == 0:
            set_sum = 2 * set_inter
        dice_coeff = (2 * set_inter) / (set_sum + 1e-9)
        return dice_coeff
    
    def _get_multiclass_dice_coefficient(self, pred, target):
        dice = 0
        for class_index in range(1, self.num_classes):
            dice += self._get_dice_coeffient(pred[class_index], target[class_index])
        return dice / (self.num_classes - 1)
    
    def _get_batch_dice_coefficient(self, pred, target):
        num_batch = pred.shape[0]
        dice = 0
        for batch_index in range(num_batch):
            dice += self._get_multiclass_dice_coefficient(pred[batch_index], target[batch_index])
        return dice / num_batch
    
    def _get_dice_loss(self, pred, target):
        return 1 - self._get_batch_dice_coefficient(pred, target)

In [14]:
criterion = UNet_metric(num_classes=4)

In [17]:
pred = model(images)
loss, dice_coef = criterion(pred, targets)

In [18]:
print(pred)

tensor([[[[ 4.5513e-02, -5.2327e-02, -1.4683e-02,  ..., -2.5909e-01,
           -1.3941e-01, -2.6723e-01],
          [ 3.4005e-01,  6.4137e-01,  1.8682e-01,  ...,  5.8093e-01,
            2.2707e-02, -1.8473e-02],
          [ 5.9321e-01,  2.7806e-01,  2.5347e-01,  ...,  1.4982e-01,
            4.6449e-01,  1.1967e-01],
          ...,
          [ 3.6349e-02,  1.7516e-01,  4.3432e-01,  ...,  6.6158e-01,
            2.3300e-01,  1.0300e-01],
          [-1.2670e-01,  5.9507e-01,  2.9414e-01,  ..., -1.4381e-02,
            3.4664e-01,  2.1866e-01],
          [ 4.3710e-01,  5.2479e-01,  2.0057e-01,  ...,  9.1162e-01,
            2.6887e-01,  9.6241e-03]],

         [[-4.1189e-01, -8.4309e-01,  1.1556e-01,  ..., -6.1461e-01,
           -8.1351e-01, -5.3761e-01],
          [-1.2688e-01, -1.5124e-01, -7.1890e-01,  ..., -4.3473e-01,
           -2.1733e-01, -5.7132e-01],
          [-2.5008e-01, -5.0638e-02, -5.5854e-02,  ..., -5.7177e-01,
           -8.9847e-01, -4.1027e-01],
          ...,
     

In [19]:
loss

tensor(0.9790)