In [2]:
from data.build_loader import build_dataloader

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
data_dir = "../cv-project/MEDICAL/MEDICAL-DATASET-001/Segmentation/"
dataloader = build_dataloader(data_dir)

In [4]:
data_iter = iter(dataloader["train"])
images, targets = next(data_iter)

In [5]:
images.shape, targets.shape #plt로 찍어봐도딤

(torch.Size([4, 3, 224, 224]), torch.Size([4, 224, 224]))

In [6]:
from models.u_net import UNet

In [7]:
unet = UNet(num_classes=4)

In [8]:
unet

UNet(
  (encoder): Encoder(
    (conv_block1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (conv_block2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (conv_block3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=

In [9]:
import torch

In [10]:
model = UNet(num_classes=4)
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(f"input shape: {x.shape}")
print(f"output shape: {out.shape}")

input shape: torch.Size([1, 3, 224, 224])
output shape: torch.Size([1, 4, 224, 224])


In [11]:
import torch.nn as nn
import torch.nn.functional as F

class UNet_metric(): 
    def __init__(self, num_classes):
        self.num_classes = num_classes
        
    def __call__(self, pred, target):
        onehot_pred = F.one_hot(torch.argmax(pred, dim=1), num_classes=self.num_classes).permute(0, 3, 1, 2)
        onehot_target = F.one_hot(target, num_classes=self.num_classes).permute(0, 3, 1, 2)
        dice_loss = self._get_dice_loss(onehot_pred, onehot_target)
        dice_coefficient = self._get_batch_dice_coefficient(onehot_pred, onehot_target)
        return dice_loss, dice_coefficient
    
    def _get_dice_coeffient(self, pred, target):
        set_inter = torch.dot(pred.reshape(-1).float(), target.reshape(-1).float())
        set_sum = pred.sum() + target.sum()
        if set_sum.item() == 0:
            set_sum = 2 * set_inter
        dice_coeff = (2 * set_inter) / (set_sum + 1e-9)
        return dice_coeff
    
    def _get_multiclass_dice_coefficient(self, pred, target):
        dice = 0
        for class_index in range(1, self.num_classes):
            dice += self._get_dice_coeffient(pred[class_index], target[class_index])
        return dice / (self.num_classes - 1)
    
    def _get_batch_dice_coefficient(self, pred, target):
        num_batch = pred.shape[0]
        dice = 0
        for batch_index in range(num_batch):
            dice += self._get_multiclass_dice_coefficient(pred[batch_index], target[batch_index])
        return dice / num_batch
    
    def _get_dice_loss(self, pred, target):
        return 1 - self._get_batch_dice_coefficient(pred, target)

In [12]:
criterion = UNet_metric(num_classes=4)

In [13]:
pred = model(images)
loss, dice_coef = criterion(pred, targets)

In [14]:
print(pred)

tensor([[[[-5.8602e-01, -7.4348e-01, -6.5285e-01,  ..., -7.0046e-01,
           -3.5558e-01, -5.8511e-01],
          [-4.0531e-01, -8.2851e-01, -5.5545e-01,  ..., -4.2105e-01,
           -4.9174e-01, -8.9817e-01],
          [-2.5451e-02, -2.5596e-01, -1.8742e-01,  ...,  1.2177e-02,
           -1.0068e+00, -5.4857e-01],
          ...,
          [-8.4110e-01, -1.1120e+00, -2.7378e-01,  ...,  3.9139e-01,
           -5.0562e-01, -1.1774e+00],
          [-8.0285e-01, -5.4841e-01, -9.6537e-01,  ...,  3.9413e-01,
           -1.5202e+00, -5.6104e-01],
          [-9.3878e-01, -5.1638e-01, -8.6002e-01,  ..., -2.0323e-01,
           -4.6141e-02, -1.2665e+00]],

         [[ 4.8472e-01,  4.2805e-01,  7.7120e-01,  ...,  1.0628e+00,
            1.2923e+00,  1.1487e+00],
          [ 6.5095e-01, -3.1714e-01,  4.4317e-01,  ...,  1.6094e+00,
            1.0750e+00,  9.1753e-01],
          [ 9.6838e-01, -9.7852e-02,  1.0824e+00,  ...,  1.1854e+00,
            1.0457e+00,  2.0669e+00],
          ...,
     

In [15]:
import torch
x = torch.randn(4,3,224,224)
tgt = torch.randint(0,4,(4,224,224)).long()
out = unet(x)

In [16]:
loss, coef = criterion(out, tgt)

In [17]:
loss

tensor(0.7554)

In [21]:
loss.requires_grad_(True)

tensor(0.7554, requires_grad=True)

In [22]:
loss.backward()
#! 왜 안돼!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

In [20]:
torch.__version__

'1.12.1'