### Data Load

In [None]:
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image  
import os
from torchvision.transforms import functional as TF
import random

# transform 정의
class ImageMaskTransform:

    def __init__(self):
        self.resize = transforms.Resize((384, 384))
        self.rotation = transforms.RandomRotation(degrees=30)
        self.h_flip = transforms.RandomHorizontalFlip(p=0.5)
        self.v_flip = transforms.RandomVerticalFlip(p=0.5)
        self.affine = transforms.RandomAffine(degrees=0, scale=(0.8, 1.2), translate=(0.1, 0.1))
        self.color_jitter = transforms.ColorJitter(brightness=0.4)
        self.crop = transforms.RandomResizedCrop(size=(384, 384))

    def __call__(self, image, mask):
        image = self.resize(image)
        mask = self.resize(mask)

        angle = random.uniform(-5, 5)
        image = TF.rotate(image, angle)
        mask = TF.rotate(mask, angle)

        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)

        angle, translations, scale, shear = self.affine.get_params(self.affine.degrees, self.affine.translate, self.affine.scale, self.affine.shear, image.size)
        image = TF.affine(image, angle, translations, scale, shear)
        mask = TF.affine(mask, angle, translations, scale, shear)

        image = self.color_jitter(image)

        i, j, h, w = transforms.RandomResizedCrop.get_params(image, scale=(0.8, 1.0), ratio=(0.75, 1.33))
        image = TF.resized_crop(image, i, j, h, w, size=(384, 384))
        mask = TF.resized_crop(mask, i, j, h, w, size=(384, 384))

        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)

        return image, mask

class SegmentationDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        image = Image.open(img_path).convert("RGB")  
        mask = Image.open(mask_path).convert("L")  
        
        if self.transform:
            image, mask = self.transform(image, mask)

        return image, mask

# load, split dataset with paper setting
def load_and_split_dataset(image_dir, mask_dir, batch_size=10, name = None):

    valid_extensions = ('.tif', '.tiff', '.jpg', '.jpeg', '.png')
    image_paths = sorted([os.path.join(image_dir, fname) for fname in os.listdir(image_dir) if fname.lower().endswith(valid_extensions)])
    mask_paths = sorted([os.path.join(mask_dir, fname) for fname in os.listdir(mask_dir) if fname.lower().endswith(valid_extensions)])

    dataset = SegmentationDataset(image_paths, mask_paths, transform=ImageMaskTransform())

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    print(name)
    print("entire dataset length: ", len(train_dataset) + len(test_dataset))
    
    val_size = int(0.25 * train_size)
    train_size = train_size - val_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
    # dataset count check
    print("entire train dataset length: ", len(train_dataset))
    print("entire val dataset length: ", len(val_dataset))
    print("entire test dataset length: ", len(test_dataset))
    print()

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

# Load datasets
cvc_train_loader, cvc_val_loader, cvc_test_loader = load_and_split_dataset(
    image_dir='CVC-ClinicDB/train', 
    mask_dir='CVC-ClinicDB/masks',
    name = "cvc clinic DB"
)

kvasirseg_train_loader, kvasirseg_val_loader, kvasirseg_test_loader = load_and_split_dataset(
    image_dir='/Kvasir-SEG/images', 
    mask_dir='/Kvasir-SEG/masks',
    name = 'Kvasir SEG'
)

kvasirinst_train_loader, kvasirinst_val_loader, kvasirinst_test_loader = load_and_split_dataset(
    image_dir='kvasir-instrument/images/images', 
    mask_dir='kvasir-instrument/mask',
    name = 'Kvasir instrument'
)

isic_train_loader, isic_val_loader, isic_test_loader = load_and_split_dataset(
    image_dir='ISIC-2017_Training_Data', 
    mask_dir='/ISIC-2017_Training_GroundTruth',
    name = 'ISIC 2017'
)

cvc clinic DB
entire dataset length:  612
entire train dataset length:  367
entire val dataset length:  122
entire test dataset length:  123

Kvasir SEG
entire dataset length:  1000
entire train dataset length:  600
entire val dataset length:  200
entire test dataset length:  200

Kvasir instrument
entire dataset length:  590
entire train dataset length:  354
entire val dataset length:  118
entire test dataset length:  118

ISIC 2017
entire dataset length:  2000
entire train dataset length:  1200
entire val dataset length:  400
entire test dataset length:  400



### Unet

In [25]:
import torch
import torch.nn as nn

# Conv Layer는 여러 번 쓰이므로 하나의 class로 정의
class conv_layer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
        super(conv_layer, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class U_Net(nn.Module):
    def __init__(self):
        super(U_Net, self).__init__()

        # Contracting path
        self.enc1_1 = conv_layer(in_channels=3, out_channels=64)
        self.enc1_2 = conv_layer(in_channels=64, out_channels=64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2_1 = conv_layer(in_channels=64, out_channels=128)
        self.enc2_2 = conv_layer(in_channels=128, out_channels=128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3_1 = conv_layer(in_channels=128, out_channels=256)
        self.enc3_2 = conv_layer(in_channels=256, out_channels=256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4_1 = conv_layer(in_channels=256, out_channels=512)
        self.enc4_2 = conv_layer(in_channels=512, out_channels=512)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.enc5_1 = conv_layer(in_channels=512, out_channels=1024)

        # Expansive path
        self.dec5_1 = conv_layer(in_channels=1024, out_channels=512)
        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec4_2 = conv_layer(in_channels=2 * 512, out_channels=512)
        self.dec4_1 = conv_layer(in_channels=512, out_channels=256)
        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec3_2 = conv_layer(in_channels=2 * 256, out_channels=256)
        self.dec3_1 = conv_layer(in_channels=256, out_channels=128)
        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec2_2 = conv_layer(in_channels=2 * 128, out_channels=128)
        self.dec2_1 = conv_layer(in_channels=128, out_channels=64)
        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec1_2 = conv_layer(in_channels=2 * 64, out_channels=64)
        self.dec1_1 = conv_layer(in_channels=64, out_channels=64)
        self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)

        dec5_1 = self.dec5_1(enc5_1)

        unpool4 = self.unpool4(dec5_1)
        cat4 = torch.cat((unpool4, enc4_2), dim=1) # Skip Connection
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)

        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.fc(dec1_1)

        return x

### Experiment Setting

In [26]:
import torch.optim as optim
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameter setting
lr = 1e-4
patience = 10
criterion = nn.BCEWithLogitsLoss()
num_epochs = 500

#train - early stopping + get best val loss
def train_model(model, train_loader, val_loader, num_epochs=num_epochs):
    best_loss = float('inf')  
    epochs_no_improve = 0  
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(num_epochs):
        model.train()  
        train_loss = 0.0
        
        for images, masks in train_loader:

            images = images.to(device)
            masks = masks.to(device)
 
            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, masks)
            
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
        
        train_loss = train_loss / len(train_loader.dataset)
        val_loss = validate_model(model, val_loader)
        
        print(f"Epoch [{epoch+1}/{num_epochs}] - Training Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            print('best loss: ',best_loss)
            best_model_state = model.state_dict()
            epochs_no_improve = 0 
        else:
            epochs_no_improve += 1  
            
        if epochs_no_improve > patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break  

        if best_model_state is not None:
            model.load_state_dict(best_model_state)

def validate_model(model, val_loader):
    model.eval() 
    val_loss = 0.0
    with torch.no_grad():  
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item() * images.size(0)
    
    val_loss = val_loss / len(val_loader.dataset)
    return val_loss


In [27]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()

### Metrics

In [28]:
from sklearn.metrics import roc_auc_score

# batch 당 iou 정의, threshold = 0.5
def calculate_batch_iou(predictions, targets, threshold=0.5, eps = 1e-8):

    batch_size = predictions.size(0)
    predictions = predictions.float()  
    
    predictions = torch.sigmoid(predictions)
    
    pred_class = (predictions >= threshold).float()
    
    pred_class = pred_class.view(batch_size, -1)
    target_class = targets.view(batch_size, -1).float()
    
    intersection = (pred_class * target_class).sum(dim=1)
    union = pred_class.sum(dim=1) + target_class.sum(dim=1) - intersection
    
    iou = (intersection +eps) / (union+eps) 
    
    return iou.mean().item()

#AUROC 정의 - sklearn package 사용
def compute_auroc(pred_mask, true_mask):
    
    pred_flat = pred_mask.view(-1).cpu().numpy()
    true_flat = true_mask.view(-1).cpu().numpy().astype(int)
    auroc = roc_auc_score(true_flat, pred_flat)

    return auroc

### Inference

In [29]:
import torch
import numpy as np

def inference(model, test_loader):
    model.eval()  
    auroc_score = 0.0
    total_samples = 0
    class_iou = np.zeros(1)
    threshold = 0.5
    total_batch = 0

    with torch.no_grad():  
        for images, targets in test_loader:

            images = images.to(device)
            targets_iou = targets.to(device)

            targets = targets_iou.squeeze(1).long()
            y_pred = model(images)

            batch_iou = calculate_batch_iou(y_pred, targets_iou, threshold=threshold)
            class_iou += batch_iou 
            total_batch +=1
            
            batch_auroc = compute_auroc(y_pred, targets)   
            auroc_score += batch_auroc * images.size(0) 
            total_samples += images.size(0)

    avg_auroc = auroc_score / total_samples
    avg_iou = class_iou / total_batch

    print(f"Average IoU: {avg_iou[0]:.4f}")
    print(f"Average AUROC: {avg_auroc:.4f}")
    


In [30]:
model = U_Net()
model = model.to(device)
train_model(model, cvc_train_loader, cvc_val_loader)

Epoch [1/500] - Training Loss: 0.5879 - Val Loss: 0.4753
best loss:  0.47529339301781576
Epoch [2/500] - Training Loss: 0.4430 - Val Loss: 0.4143
best loss:  0.41433037695337516
Epoch [3/500] - Training Loss: 0.4136 - Val Loss: 0.4104
best loss:  0.41040347247827247
Epoch [4/500] - Training Loss: 0.3904 - Val Loss: 0.3951
best loss:  0.3951438640961882
Epoch [5/500] - Training Loss: 0.3762 - Val Loss: 0.3754
best loss:  0.3754367100410774
Epoch [6/500] - Training Loss: 0.3647 - Val Loss: 0.3571
best loss:  0.3571133618472052
Epoch [7/500] - Training Loss: 0.3504 - Val Loss: 0.3583
Epoch [8/500] - Training Loss: 0.3384 - Val Loss: 0.3388
best loss:  0.3387653744611584
Epoch [9/500] - Training Loss: 0.3308 - Val Loss: 0.3349
best loss:  0.3348921869621902
Epoch [10/500] - Training Loss: 0.3164 - Val Loss: 0.3146
best loss:  0.31462047334577214
Epoch [11/500] - Training Loss: 0.3084 - Val Loss: 0.3406
Epoch [12/500] - Training Loss: 0.2958 - Val Loss: 0.3179
Epoch [13/500] - Training Loss

In [31]:
inference(model, cvc_test_loader)

Average IoU: 0.7232
Average AUROC: 0.9811


### experiment U-NET backbone: vgg16

In [32]:
import segmentation_models_pytorch as smp

vgg16_unet_model = smp.Unet(
    encoder_name="vgg16",     
    encoder_weights="imagenet",    
    in_channels=3,                 
    classes=1                    
)
vgg16_unet_model.to(device)
train_model(vgg16_unet_model, cvc_train_loader, cvc_val_loader)

Epoch [1/500] - Training Loss: 0.5575 - Val Loss: 0.4140
best loss:  0.4140317166437868
Epoch [2/500] - Training Loss: 0.3789 - Val Loss: 0.3313
best loss:  0.3313110568484322
Epoch [3/500] - Training Loss: 0.2884 - Val Loss: 0.2521
best loss:  0.2521477896170538
Epoch [4/500] - Training Loss: 0.2266 - Val Loss: 0.2572
Epoch [5/500] - Training Loss: 0.2063 - Val Loss: 0.2106
best loss:  0.2106130267264413
Epoch [6/500] - Training Loss: 0.1844 - Val Loss: 0.1795
best loss:  0.17952794593865753
Epoch [7/500] - Training Loss: 0.1673 - Val Loss: 0.1721
best loss:  0.17212115020536986
Epoch [8/500] - Training Loss: 0.1515 - Val Loss: 0.1637
best loss:  0.16368669955456844
Epoch [9/500] - Training Loss: 0.1414 - Val Loss: 0.1530
best loss:  0.15296817399927828
Epoch [10/500] - Training Loss: 0.1274 - Val Loss: 0.1447
best loss:  0.14471675650995286
Epoch [11/500] - Training Loss: 0.1202 - Val Loss: 0.1327
best loss:  0.1326880401275197
Epoch [12/500] - Training Loss: 0.1118 - Val Loss: 0.150

In [33]:
inference(vgg16_unet_model, cvc_test_loader)

Average IoU: 0.8361
Average AUROC: 0.9913


In [34]:
model = U_Net()
model = model.to(device)
train_model(model, kvasirinst_train_loader, kvasirinst_val_loader)

In [35]:
inference(model, kvasirinst_test_loader)

In [36]:
import segmentation_models_pytorch as smp

vgg16_unet_model = smp.Unet(
    encoder_name="vgg16",     
    encoder_weights="imagenet",    
    in_channels=3,                 
    classes=1                    
)
vgg16_unet_model.to(device)
train_model(vgg16_unet_model, kvasirinst_train_loader, kvasirinst_val_loader)

In [37]:
inference(vgg16_unet_model, kvasirinst_test_loader)

In [38]:
model = U_Net()
model = model.to(device)
train_model(model, kvasirseg_train_loader, kvasirseg_val_loader)

Epoch [1/500] - Training Loss: 0.6055 - Val Loss: 1.0405
best loss:  1.040546452999115
Epoch [2/500] - Training Loss: 0.5024 - Val Loss: 0.5007
best loss:  0.5006719529628754
Epoch [3/500] - Training Loss: 0.4781 - Val Loss: 0.4776
best loss:  0.4776441127061844
Epoch [4/500] - Training Loss: 0.4666 - Val Loss: 0.4640
best loss:  0.46398368030786513
Epoch [5/500] - Training Loss: 0.4483 - Val Loss: 0.4483
best loss:  0.4482656195759773
Epoch [6/500] - Training Loss: 0.4332 - Val Loss: 0.4350
best loss:  0.43503053337335584
Epoch [7/500] - Training Loss: 0.4256 - Val Loss: 0.4304
best loss:  0.4304482087492943
Epoch [8/500] - Training Loss: 0.4247 - Val Loss: 0.4221
best loss:  0.42210519015789033
Epoch [9/500] - Training Loss: 0.4096 - Val Loss: 0.4389
Epoch [10/500] - Training Loss: 0.4056 - Val Loss: 0.5081
Epoch [11/500] - Training Loss: 0.4000 - Val Loss: 0.3993
best loss:  0.3992871791124344
Epoch [12/500] - Training Loss: 0.3849 - Val Loss: 0.4041
Epoch [13/500] - Training Loss: 

In [39]:
inference(model, kvasirseg_test_loader)

Average IoU: 0.6995
Average AUROC: 0.9697


In [40]:
import segmentation_models_pytorch as smp

vgg16_unet_model = smp.Unet(
    encoder_name="vgg16",     
    encoder_weights="imagenet",    
    in_channels=3,                 
    classes=1                    
)
vgg16_unet_model.to(device)
train_model(vgg16_unet_model, kvasirseg_train_loader, kvasirseg_val_loader)

Epoch [1/500] - Training Loss: 0.3928 - Val Loss: 0.2873
best loss:  0.2872691288590431
Epoch [2/500] - Training Loss: 0.2629 - Val Loss: 0.2458
best loss:  0.24583308771252632
Epoch [3/500] - Training Loss: 0.2158 - Val Loss: 0.2030
best loss:  0.20301557704806328
Epoch [4/500] - Training Loss: 0.1893 - Val Loss: 0.1997
best loss:  0.19974569380283355
Epoch [5/500] - Training Loss: 0.1749 - Val Loss: 0.1481
best loss:  0.14809590466320516
Epoch [6/500] - Training Loss: 0.1495 - Val Loss: 0.1681
Epoch [7/500] - Training Loss: 0.1579 - Val Loss: 0.1568
Epoch [8/500] - Training Loss: 0.1361 - Val Loss: 0.1558
Epoch [9/500] - Training Loss: 0.1312 - Val Loss: 0.1394
best loss:  0.13944493941962718
Epoch [10/500] - Training Loss: 0.1218 - Val Loss: 0.1553
Epoch [11/500] - Training Loss: 0.1243 - Val Loss: 0.1336
best loss:  0.13364214599132537
Epoch [12/500] - Training Loss: 0.1129 - Val Loss: 0.1304
best loss:  0.13044996820390226
Epoch [13/500] - Training Loss: 0.1099 - Val Loss: 0.1520


In [41]:
inference(vgg16_unet_model, kvasirseg_test_loader)

Average IoU: 0.8014
Average AUROC: 0.9904


In [42]:
model = U_Net()
model = model.to(device)
train_model(model, isic_train_loader, isic_val_loader)

KeyboardInterrupt: 

In [125]:
inference(model, isic_test_loader)

In [126]:
import segmentation_models_pytorch as smp

vgg16_unet_isic_model = smp.Unet(
    encoder_name="vgg16",     
    encoder_weights="imagenet",    
    in_channels=3,                 
    classes=1                    
)

In [127]:
vgg16_unet_isic_model.to(device)
train_model(vgg16_unet_isic_model, isic_train_loader, isic_val_loader)

In [128]:
inference(vgg16_unet_isic_model, isic_test_loader)