In [1]:
import os
import re
import glob

import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Subset

from torchvision import transforms
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader

from torchinfo import summary

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

device(type='cuda')

### Dataset

In [3]:
class CAGDataset(Dataset) :
    def __init__(self, img_path, mask_path, transform = None) :
        self.img_path = img_path
        self.mask_path = mask_path
        self.transform = transform
        
    def __len__(self) :
        return len(self.img_path)
    
    def __getitem__(self, index) :
        img = Image.open(self.img_path[index]).convert('RGB')
        mask = Image.open(self.mask_path[index]).convert('L')
        
        if self.transform :
            img = self.transform(img)
        
        mask_transform = v2.Compose([
            v2.ToTensor(),
            v2.Resize((224, 224)),
        ])
        mask = mask_transform(mask)
        
        return img, mask

### U-Net++

In [4]:
class ConvBlock(nn.Module) :
    def __init__(self, in_channels, out_channels) :
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),
        )

    def forward(self, x) :
        return self.block(x)

class UNetPP(nn.Module) :
    def __init__(self, in_channels, out_channels, base_channels = 64, deep_supervision = False):
        super().__init__()
        self.deep_supervision = deep_supervision

        # encoder
        self.conv00 = ConvBlock(in_channels, base_channels)
        self.pool0 = nn.MaxPool2d(2)
        self.conv10 = ConvBlock(base_channels, base_channels * 2)
        self.pool1 = nn.MaxPool2d(2)
        self.conv20 = ConvBlock(base_channels * 2, base_channels * 4)
        self.pool2 = nn.MaxPool2d(2)
        self.conv30 = ConvBlock(base_channels * 4, base_channels * 8)
        self.pool3 = nn.MaxPool2d(2)
        self.conv40 = ConvBlock(base_channels * 8, base_channels * 16)

        # decoder - nested blocks
        self.up01 = ConvBlock(base_channels + base_channels * 2, base_channels)
        self.up11 = ConvBlock(base_channels * 2 + base_channels * 4, base_channels * 2)
        self.up21 = ConvBlock(base_channels * 4 + base_channels * 8, base_channels * 4)
        self.up31 = ConvBlock(base_channels * 8 + base_channels * 16, base_channels * 8)

        self.up02 = ConvBlock(base_channels * 2 + base_channels + base_channels, base_channels)
        self.up12 = ConvBlock(base_channels * 4 + base_channels * 2 + base_channels * 2, base_channels * 2)
        self.up22 = ConvBlock(base_channels * 8 + base_channels * 4 + base_channels * 4, base_channels * 4)

        self.up03 = ConvBlock(base_channels * 2 + base_channels + base_channels + base_channels, base_channels)
        self.up13 = ConvBlock(base_channels * 4 + base_channels * 2 + base_channels * 2 + base_channels * 2, base_channels * 2)

        self.up04 = ConvBlock(base_channels * 2 + base_channels + base_channels + base_channels + base_channels, base_channels)

        # output layer(s)
        if self.deep_supervision :
            self.final1 = nn.Conv2d(base_channels, out_channels, kernel_size = 1)
            self.final2 = nn.Conv2d(base_channels, out_channels, kernel_size = 1)
            self.final3 = nn.Conv2d(base_channels, out_channels, kernel_size = 1)
            self.final4 = nn.Conv2d(base_channels, out_channels, kernel_size = 1)
        else:
            self.final = nn.Conv2d(base_channels, out_channels, kernel_size = 1)

    def forward(self, x) :
        x00 = self.conv00(x)
        x10 = self.conv10(self.pool0(x00))
        x20 = self.conv20(self.pool1(x10))
        x30 = self.conv30(self.pool2(x20))
        x40 = self.conv40(self.pool3(x30))

        x01 = self.up01(torch.cat([x00, F.interpolate(x10, scale_factor = 2, mode = 'bilinear', align_corners = True)], dim = 1))
        x11 = self.up11(torch.cat([x10, F.interpolate(x20, scale_factor = 2, mode = 'bilinear', align_corners = True)], dim = 1))
        x21 = self.up21(torch.cat([x20, F.interpolate(x30, scale_factor = 2, mode = 'bilinear', align_corners = True)], dim = 1))
        x31 = self.up31(torch.cat([x30, F.interpolate(x40, scale_factor = 2, mode = 'bilinear', align_corners = True)], dim = 1))

        x02 = self.up02(torch.cat([x00, x01, F.interpolate(x11, scale_factor = 2, mode = 'bilinear', align_corners = True)], dim = 1))
        x12 = self.up12(torch.cat([x10, x11, F.interpolate(x21, scale_factor = 2, mode = 'bilinear', align_corners = True)], dim = 1))
        x22 = self.up22(torch.cat([x20, x21, F.interpolate(x31, scale_factor = 2, mode = 'bilinear', align_corners = True)], dim = 1))

        x03 = self.up03(torch.cat([x00, x01, x02, F.interpolate(x12, scale_factor = 2, mode = 'bilinear', align_corners = True)], dim = 1))
        x13 = self.up13(torch.cat([x10, x11, x12, F.interpolate(x22, scale_factor = 2, mode = 'bilinear', align_corners = True)], dim = 1))

        x04 = self.up04(torch.cat([x00, x01, x02, x03, F.interpolate(x13, scale_factor = 2, mode = 'bilinear', align_corners = True)], dim = 1))

        if self.deep_supervision:
            return [
                self.final1(x01),
                self.final2(x02),
                self.final3(x03),
                self.final4(x04),
            ]
        else:
            return self.final(x04)

In [5]:
ex = UNetPP(in_channels = 3, out_channels = 1)
# summary(ex, input_size = (1, 3, 224, 224), device='cpu')

ex(torch.randn(1, 3, 224, 224)).shape

torch.Size([1, 1, 224, 224])

### DataLoader

In [6]:
img_path0 = "/project/image/ARCADE"
mask_path0 = "/project/mask/0521"

In [7]:
# 자연 정렬 함수
def natural_key(text):
    return [int(t) if t.isdigit() else t.lower() for t in re.split('(\d+)', text)]

img_path = glob.glob(os.path.join(img_path0, "*.png"))
img_path = sorted(img_path, key = natural_key)
mask_path = glob.glob(os.path.join(mask_path0, "*.png"))
mask_path = sorted(mask_path, key = natural_key)

In [8]:
CAG_transform = v2.Compose([
    v2.ToTensor(),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    v2.Resize((224, 224)),
])



In [9]:
CAG_dataset = CAGDataset(img_path, mask_path, transform = CAG_transform)
train_dataset = Subset(CAG_dataset, range(0, 70))
valid_dataset = Subset(CAG_dataset, range(70, 90))
test_dataset = Subset(CAG_dataset, range(90, 100))

batch_size = 8

train_loader = DataLoader(train_dataset, batch_size, shuffle = True)
valid_loader = DataLoader(valid_dataset, batch_size, shuffle = False)
test_loader = DataLoader(test_dataset, batch_size, shuffle = False)

In [10]:
# for image, mask in train_loader :
#     print(f"image shape : {image.shape}\nmask shape : {mask.shape}")
    
#     img = image[0].squeeze(0)
#     msk = mask[0]
    
#     img_np = img.permute(1, 2, 0).cpu().numpy()
#     msk_np = msk.permute(1, 2, 0).cpu().numpy()
    
#     plt.figure(dpi = 128)
#     plt.subplot(121)
#     plt.imshow(img_np, cmap = "gray")
#     plt.subplot(122)
#     plt.imshow(msk_np, cmap = "gray")
#     plt.show()
#     break

### Train

In [11]:
model = UNetPP(in_channels = 3, out_channels = 1).to(device)

n_epochs = 3
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr = 1e-6)

In [12]:
train_list = []
valid_list = []

for epoch in range(n_epochs) :
        model.train()
        train_loss = 0

        for image, mask in tqdm(train_loader, desc = f"Train - Epoch {epoch + 1}") :
                image = image.to(device)
                mask = mask.to(device)
                
                output = model(image)
                loss = criterion(output, mask)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item() * image.size(0)
        
        train_loss /= len(train_loader.dataset)
        train_list.append(train_loss)

        model.eval()
        valid_loss = 0
        with torch.no_grad() :
                for image, mask in tqdm(valid_loader, desc = f"Valid - Epoch {epoch + 1}") :
                        image = image.to(device)
                        mask = mask.to(device)
                        
                        output = model(image)
                        loss = criterion(output, mask)
                        
                        valid_loss += loss.item() * image.size(0)
        
        valid_loss /= len(valid_loader.dataset)
        valid_list.append(valid_loss)
        
        print(f'Epoch {epoch + 1} : train_loss = {train_loss:.4f}, val_loss = {valid_loss:.4f}')  

Train - Epoch 1:   0%|          | 0/9 [00:00<?, ?it/s]


IndexError: list index out of range

In [None]:
plt.figure(figsize = (12, 6))

width = range(1, n_epochs + 1)
plt.plot(width, train_list, label = "Train Loss")
plt.plot(width, valid_list, label = "Valid Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

### Test

In [19]:
def compute_metrics(preds, masks, smooth = 1e-6) :
    preds_flat = preds.view(preds.size(0), -1)
    masks_flat = masks.view(masks.size(0), -1)
    intersection = (preds_flat * masks_flat).sum(1)
    union = preds_flat.sum(1) + masks_flat.sum(1) - intersection
    iou = (intersection + smooth) / (union + smooth)
    jaccard_distance = 1 - iou
    dice = (2 * intersection + smooth) / (preds_flat.sum(1) + masks_flat.sum(1) + smooth)
    return iou.mean().item(), jaccard_distance.mean().item(), dice.mean().item()

In [None]:
model.eval()
Dice_scores = []
IoU_scores = []
Jaccard_distances = []

last_image = None
last_pred = None
last_mask = None

with torch.no_grad() :
    for image, mask in tqdm(test_loader, desc = "Test") :
        image = image.to(device)
        mask = mask.to(device)
        
        output = torch.sigmoid(model(image))
        pred = (output > 0.5).float()
        
        iou, jdist, dice = compute_metrics(pred, mask)
        IoU_scores.append(iou)
        Jaccard_distances.append(jdist)
        Dice_scores.append(dice)
        
        last_image = image[-1].cpu()
        last_pred = pred[-1].cpu()
        last_mask = mask[-1].cpu()
        

### Visualize

In [None]:
print(f"last_image : {last_image.shape}")
print(f"last_mask : {last_mask.shape}")
print(f"last_pred : {last_pred.shape}")

In [None]:
plt.figure(figsize=(12, 6))

plt.subplot(131)
plt.imshow(last_image.permute(1, 2, 0).numpy(), cmap = 'gray')
plt.title("Input Image")

plt.subplot(132)
plt.imshow(last_mask.squeeze().numpy(), cmap = 'gray')
plt.title("Ground Truth Mask")

plt.subplot(133)
plt.imshow(last_pred.squeeze().numpy(), cmap = 'gray')
plt.title("Predicted Mask")

plt.tight_layout()
plt.show()

In [None]:
plt.figure(figsize = (12, 6))
plt.boxplot([Dice_scores, IoU_scores, Jaccard_distances],
            labels=['Dice', 'IoU', 'Jaccard Distance'])

plt.title('Metric')
plt.ylabel('Score')
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
plt.figure(figsize=(12, 6))

plt.subplot(131)
plt.hist(Dice_scores, bins = 20, color = 'skyblue')
plt.title('Dice Score Distribution')
plt.xlabel('Dice Score')
plt.ylabel('Frequency')

plt.subplot(132)
plt.hist(IoU_scores, bins = 20, color = 'lightgreen')
plt.title('IoU Score Distribution')
plt.xlabel('IoU Score')

plt.subplot(133)
plt.hist(Jaccard_distances, bins = 20, color = 'salmon')
plt.title('Jaccard Distance Distribution')
plt.xlabel('Jaccard Distance')

plt.tight_layout()
plt.show()
