In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
# import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from PIL import Image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [2]:
class OxfordIIITPetsAugmented(OxfordIIITPet):
    """inherit from torchvision.datasets.OxfordIIITPet"""
    def __init__(self,
                 root: str,
                 split: str,
                 target_types="segmentation",
                 download=True,
                 pre_transform=None,
                 post_transform=None,
                 pre_target_transform=None,
                 post_target_transform=None, 
                 common_transform=None):
        
        super().__init__(
            root=root,
            split=split,
            target_types=target_types,
            download=download,
            transform=pre_transform,
            target_transform=pre_target_transform
        )
        
        self.post_transform = post_transform
        self.post_target_transform = post_target_transform
        self.common_transform = common_transform

    def __getitem__(self, idx):
        input, target = super().__getitem__(idx)
        
        if self.common_transform is not None:
            both = torch.cat([input, target], dim=0)
            both = self.common_transform(both)
            input, target = torch.split(both, 3, dim=0)
            
        if self.post_transform is not None:
            input = self.post_transform(input)
            
        if self.post_target_transform is not None:
            target = self.post_target_transform(target)
            
        return input, target

In [3]:
def tensor_trimap(t):
    x = t * 255
    x = x.to(torch.long)
    x = x - 1
    return x

def get_oxford_dataloaders(root='./OxfordPetsData', train_size=64, test_size=32):
    transform_dict = {
        'pre_transform': transforms.ToTensor(),
        'pre_target_transform': transforms.ToTensor(),
        'common_transform': transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.RandomHorizontalFlip(p=0.5),
        ]),
        'post_transform': transforms.Compose([
            transforms.ColorJitter(contrast=0.3),
        ]),
        'post_target_transform': transforms.Compose([
            transforms.Lambda(tensor_trimap),
        ])
    }


    train_dataset = OxfordIIITPetsAugmented(
        root=root,
        split="trainval",
        target_types="segmentation",
        download=True,  
        **transform_dict
    )

   
    test_dataset = OxfordIIITPetsAugmented(
        root=root,  
        split="test",
        target_types="segmentation",
        download=True,
        **transform_dict
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=train_size,
        shuffle=True,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=test_size,
        shuffle=False,
    )

    return train_dataset, test_dataset, train_loader, test_loader

In [4]:
train_ds, test_ds, train_loader, test_loader = get_oxford_dataloaders(train_size=32, test_size=16)

In [5]:
for images, masks in train_loader:
    print(f"Image batch shape: {images.shape}")  # [batch_size, 3, 224, 224]
    print(f"Mask batch shape: {masks.shape}")  # [batch_size, 1, 224, 224]
    break  

Image batch shape: torch.Size([32, 3, 224, 224])
Mask batch shape: torch.Size([32, 1, 224, 224])


In [5]:
def UNet():
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=3,
        classes=3,
    )
    return model

def dice_coefficient(y_true, y_pred):
    """
    y_true: GT label (B, H, W)
    y_pred: Predicted label (B, H, W)
    """
    
    num_classes = 3  
    dice_scores = []
    smooth = 1e-7  

    for class_id in range(num_classes):
        true_flat = (y_true == class_id).float()
        pred_flat = (y_pred == class_id).float()
        
        intersection = torch.sum(true_flat * pred_flat, dim=[1,2])
        union = torch.sum(true_flat, dim=[1,2]) + torch.sum(pred_flat, dim=[1,2])
        
        dice = (2.0 * intersection + smooth) / (union + smooth)
        dice_scores.append(dice.mean())
    
    return torch.mean(torch.stack(dice_scores))

In [7]:
from focalloss import focal_loss

In [6]:
from torchvision.utils import make_grid
from PIL import Image, ImageDraw, ImageFont
import os

def visualize_predictions(model, val_loader, epoch, model_name):
    device = next(model.parameters()).device
    model.eval()
    # model_name = str(model)
    
    with torch.no_grad():
        samples = []
        for i, (images, masks) in enumerate(val_loader):
            if i >= 3:
                break
            images = images.to(device)
            masks = masks.squeeze(1).to(device)
            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)
            
            samples.append((
                images[0:1],  # [1, C, H, W]
                masks[0:1].unsqueeze(1).repeat(1,3,1,1),  # [1, 3, H, W]
                predictions[0:1].unsqueeze(1).repeat(1,3,1,1)  # [1, 3, H, W]
            ))
            
        img_row = torch.cat([s[0] for s in samples], dim=0)  # [3, C, H, W]
        mask_row = torch.cat([s[1] for s in samples], dim=0)  # [3, C, H, W] 
        pred_row = torch.cat([s[2] for s in samples], dim=0)  # [3, C, H, W]
        
        grid = make_grid(
            torch.cat([img_row, mask_row, pred_row], dim=0),
            nrow=3,
            normalize=True,
            pad_value=1,
            padding=2
        )
        
        #grid_np = grid.permute(1, 2, 0).cpu().numpy()
        # plt.imsave(f'./images/epoch_{epoch}.jpg', grid_np)

        grid_np = (grid.permute(1, 2, 0).cpu().numpy() * 255).astype('uint8')
        img = Image.fromarray(grid_np)
        # Image.fromarray(grid_np).resize((224, 224), Image.Resampling.LANCZOS).save(f'./images/epoch_{epoch}.jpg', quality=90)
        # save_dir = os.path.join("./images/", model_name, "/epoch_{epoch}.jpg")
        # Image.fromarray(grid_np).resize((224, 224), Image.Resampling.LANCZOS).save(save_dir, quality=90)

        title_height = 50
        new_img = Image.new('RGB', (img.width, img.height + title_height), 'white')
        draw = ImageDraw.Draw(new_img)
        
        try:
            font = ImageFont.truetype("arial.ttf", 24)
        except:
            font = ImageFont.load_default()
            
        titles = ['Input', 'Ground Truth', 'Prediction']
        width_per_img = img.width // 3
        for i, title in enumerate(titles):
            w = draw.textlength(title, font=font)
            x = i * width_per_img + (width_per_img - w) // 2
            draw.text((x, 10), title, font=font, fill='black')

        save_dir = os.path.join("./images/", model_name)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"epoch_{epoch}.jpg")
            
        new_img.paste(img, (0, title_height))
        new_img.resize((512, 512), Image.Resampling.LANCZOS).save(save_path, quality=90)

        # save_dir = os.path.join("./images/", model_name)
        # os.makedirs(save_dir, exist_ok=True)
        # save_path = os.path.join(save_dir, f"epoch_{epoch}.jpg")
        # Image.fromarray(grid_np).resize((512, 512), Image.Resampling.LANCZOS).save(save_path, quality=90)

In [7]:
from tqdm import tqdm
from torchmetrics import JaccardIndex

jaccard = JaccardIndex(task="multiclass", num_classes=3).to(device)

def train_model(model, train_loader, val_loader, epochs=20, learning_rate=5e-4, model_name = "model"):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.0001)
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=0.01,
        total_iters=epochs
        )
    
    os.makedirs('./checkpoints', exist_ok=True)
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        # try:
        #     visualize_predictions(model, val_loader, epoch)
        # except Exception as e:
        #     print(f"Visualization failed: {str(e)}")
        
        # Training phase
        if epoch % 2 == 0:
            visualize_predictions(model, val_loader, epoch=epoch, model_name=model_name)

        model.train()
        train_loss = 0
        train_dice = 0
        train_acc = 0
        train_miou = 0
        n_batches = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')
        for images, masks in pbar:
            images = images.to(device)
            masks = masks.squeeze(1).to(device)
            
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            preds = torch.argmax(outputs, dim=1)
            dice = dice_coefficient(masks, preds)
            accuracy = (preds == masks).float().mean()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            train_dice += dice.item()
            train_acc += accuracy.item()
            train_miou += jaccard(preds, masks).item()
            n_batches += 1
            
            avg_loss = train_loss / n_batches
            avg_dice = train_dice / n_batches
            avg_acc = train_acc / n_batches
            avg_miou = train_miou / n_batches
            
            pbar.set_postfix({
                'loss': f"{avg_loss:.4f}",
                'miou': f"{avg_miou:.4f}",
                'dice': f"{avg_dice:.4f}",
                'pixel acc': f"{avg_acc:.4f}"
            })

        scheduler.step()
        
        # Validation phase
        model.eval()
        val_loss = 0
        val_dice = 0
        val_acc = 0
        val_miou = 0
        n_batches = 0
        
        # pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Valid]')
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.squeeze(1).to(device)
                
                outputs = model(images)
                loss = criterion(outputs, masks)
                
                preds = torch.argmax(outputs, dim=1)
                dice = dice_coefficient(masks, preds)
                accuracy = (preds == masks).float().mean()
                
                val_loss += loss.item()
                val_dice += dice.item()
                val_acc += accuracy.item()
                val_miou += jaccard(preds, masks).item()
                n_batches += 1
                
        avg_val_loss = val_loss / n_batches
        avg_val_dice = val_dice / n_batches
        avg_val_acc = val_acc / n_batches
        avg_val_miou = val_miou / n_batches
                
                # pbar.set_postfix({
                #     'val_loss': f"{avg_val_loss:.4f}",
                #     'val_dice': f"{avg_val_dice:.4f}",
                #     'val_acc': f"{avg_val_acc:.4f}"
                # })
        # pbar.set_description(f"Val Loss: {avg_val_loss:.4f}, Val dice: {avg_val_dice:.4f}, Val acc: {avg_val_acc:.4f}")
        print(f"[Valid] Loss: {avg_val_loss:.4f}, Dice: {avg_val_dice:.4f}, Pixel acc: {avg_val_acc:.4f}, mIoU: {avg_val_miou:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': avg_val_loss,
                'val_dice': avg_val_dice,
                'val_acc': avg_val_acc,
            }, './checkpoints/best_unet.pth')
    
    # return model

In [42]:
unet = UNet()
train_model(unet, train_loader, test_loader)

Epoch 1/20 [Train]: 100%|██████████| 115/115 [00:48<00:00,  2.36it/s, loss=0.3547, miou=0.6794, dice=0.7740, acc=0.8698]


[Valid] Loss: 0.3002, Dice: 0.8010, Acc: 0.8910, mIoU: 0.7083


Epoch 2/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.10it/s, loss=0.2355, miou=0.7488, dice=0.8337, acc=0.9104]


[Valid] Loss: 0.2315, Dice: 0.8422, Acc: 0.9136, mIoU: 0.7582


Epoch 3/20 [Train]: 100%|██████████| 115/115 [00:53<00:00,  2.13it/s, loss=0.2144, miou=0.7635, dice=0.8448, acc=0.9174]


[Valid] Loss: 0.2349, Dice: 0.8376, Acc: 0.9122, mIoU: 0.7528


Epoch 4/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.12it/s, loss=0.2053, miou=0.7706, dice=0.8500, acc=0.9206]


[Valid] Loss: 0.2212, Dice: 0.8516, Acc: 0.9183, mIoU: 0.7704


Epoch 5/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.11it/s, loss=0.1918, miou=0.7793, dice=0.8567, acc=0.9247]


[Valid] Loss: 0.2263, Dice: 0.8499, Acc: 0.9142, mIoU: 0.7672


Epoch 6/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.12it/s, loss=0.1742, miou=0.7925, dice=0.8659, acc=0.9310]


[Valid] Loss: 0.2122, Dice: 0.8513, Acc: 0.9211, mIoU: 0.7709


Epoch 7/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.13it/s, loss=0.1720, miou=0.7941, dice=0.8668, acc=0.9316]


[Valid] Loss: 0.2687, Dice: 0.8261, Acc: 0.9025, mIoU: 0.7393


Epoch 8/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.11it/s, loss=0.1803, miou=0.7891, dice=0.8637, acc=0.9290]


[Valid] Loss: 0.2315, Dice: 0.8444, Acc: 0.9120, mIoU: 0.7611


Epoch 9/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.10it/s, loss=0.1620, miou=0.8012, dice=0.8718, acc=0.9349]


[Valid] Loss: 0.2101, Dice: 0.8611, Acc: 0.9211, mIoU: 0.7821


Epoch 10/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.12it/s, loss=0.1539, miou=0.8080, dice=0.8767, acc=0.9379]


[Valid] Loss: 0.2108, Dice: 0.8617, Acc: 0.9235, mIoU: 0.7841


Epoch 11/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.12it/s, loss=0.1518, miou=0.8102, dice=0.8782, acc=0.9386]


[Valid] Loss: 0.2154, Dice: 0.8619, Acc: 0.9245, mIoU: 0.7847


Epoch 12/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.11it/s, loss=0.1541, miou=0.8090, dice=0.8773, acc=0.9379]


[Valid] Loss: 0.2069, Dice: 0.8647, Acc: 0.9256, mIoU: 0.7889


Epoch 13/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.12it/s, loss=0.1581, miou=0.8051, dice=0.8745, acc=0.9361]


[Valid] Loss: 0.2246, Dice: 0.8544, Acc: 0.9190, mIoU: 0.7743


Epoch 14/20 [Train]: 100%|██████████| 115/115 [00:53<00:00,  2.14it/s, loss=0.1458, miou=0.8144, dice=0.8812, acc=0.9405]


[Valid] Loss: 0.2069, Dice: 0.8641, Acc: 0.9251, mIoU: 0.7884


Epoch 15/20 [Train]: 100%|██████████| 115/115 [00:53<00:00,  2.13it/s, loss=0.1313, miou=0.8269, dice=0.8900, acc=0.9457]


[Valid] Loss: 0.1941, Dice: 0.8732, Acc: 0.9309, mIoU: 0.8016


Epoch 16/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.12it/s, loss=0.1397, miou=0.8219, dice=0.8866, acc=0.9430]


[Valid] Loss: 0.2330, Dice: 0.8572, Acc: 0.9190, mIoU: 0.7770


Epoch 17/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.12it/s, loss=0.1355, miou=0.8239, dice=0.8880, acc=0.9442]


[Valid] Loss: 0.2129, Dice: 0.8633, Acc: 0.9259, mIoU: 0.7865


Epoch 18/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.12it/s, loss=0.1371, miou=0.8240, dice=0.8881, acc=0.9439]


[Valid] Loss: 0.2684, Dice: 0.8377, Acc: 0.9103, mIoU: 0.7547


Epoch 19/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.12it/s, loss=0.1663, miou=0.8006, dice=0.8713, acc=0.9338]


[Valid] Loss: 0.2362, Dice: 0.8523, Acc: 0.9158, mIoU: 0.7712


Epoch 20/20 [Train]: 100%|██████████| 115/115 [00:54<00:00,  2.12it/s, loss=0.1382, miou=0.8215, dice=0.8859, acc=0.9432]


[Valid] Loss: 0.2225, Dice: 0.8639, Acc: 0.9230, mIoU: 0.7865


In [33]:
checkpoint = torch.load('./checkpoints/best_unet.pth')
unet = UNet()
unet.load_state_dict(checkpoint['model_state_dict'])

visualize_predictions(unet, test_loader, epoch=0)

In [7]:
def focal_loss(alpha=0.25, gamma=1.2):
    def loss(y_true, y_pred):
        epsilon = 1e-7
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
        ce = -y_true * tf.math.log(y_pred)
        weights = tf.pow(1 - y_pred, gamma) * y_true
        weights = weights * alpha
        fl = weights * ce
        return tf.reduce_mean(tf.reduce_sum(fl, axis=-1))  
    return loss

class F1Score(tf.keras.metrics.Metric):
    def __init__(self, name="f1_score", **kwargs):
        super(F1Score, self).__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)

    def result(self):
        precision = self.precision.result()
        recall = self.recall.result()
        return 2 * (precision * recall) / (precision + recall + 1e-7)  

    def reset_states(self):
        self.precision.reset_states()
        self.recall.reset_states()

In [75]:
# Transformer Encoder UNet
# def TEUNet():
#     model = smp.UnetPlusPlus(  
#         encoder_name="mit_b2",
#         encoder_weights="imagenet",
#         in_channels=3,
#         classes=3
#     )
#     return model

def UNet50():
    model = smp.Unet(
        encoder_name="resnet50",
        encoder_weights="imagenet",
        in_channels=3,
        classes=3,
    )
    return model

def DenseUNet():
   model = smp.UnetPlusPlus(
       encoder_name="densenet121",  #densenet169, densenet201, densenet161
       encoder_weights="imagenet",
       in_channels=3,
       classes=3
   )
   return model

def FPN():
    model = smp.FPN(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=3,
        classes=3
    )
    return model

def DeepLabV3():
    model = smp.DeepLabV3(
        encoder_name="resnet34", 
        encoder_weights="imagenet",
        in_channels=3,
        classes=3
    )
    return model

In [74]:
fpn = FPN()
train_model(fpn, train_loader, test_loader, epochs=20, learning_rate=1e-4, model_name="fpn")

Epoch 1/20 [Train]: 100%|██████████| 115/115 [00:40<00:00,  2.87it/s, loss=0.5033, miou=0.5982, dice=0.6739, acc=0.8318]


[Valid] Loss: 0.2572, Dice: 0.8185, Acc: 0.9028, mIoU: 0.7330


Epoch 2/20 [Train]: 100%|██████████| 115/115 [00:46<00:00,  2.47it/s, loss=0.2609, miou=0.7179, dice=0.7983, acc=0.8992]


[Valid] Loss: 0.2308, Dice: 0.8356, Acc: 0.9152, mIoU: 0.7546


Epoch 3/20 [Train]: 100%|██████████| 115/115 [00:44<00:00,  2.56it/s, loss=0.2164, miou=0.7540, dice=0.8316, acc=0.9155]


[Valid] Loss: 0.2217, Dice: 0.8470, Acc: 0.9191, mIoU: 0.7680


Epoch 4/20 [Train]: 100%|██████████| 115/115 [00:44<00:00,  2.57it/s, loss=0.1902, miou=0.7743, dice=0.8484, acc=0.9245]


[Valid] Loss: 0.2161, Dice: 0.8576, Acc: 0.9235, mIoU: 0.7805


Epoch 5/20 [Train]: 100%|██████████| 115/115 [00:56<00:00,  2.03it/s, loss=0.1781, miou=0.7856, dice=0.8579, acc=0.9291]


[Valid] Loss: 0.2073, Dice: 0.8598, Acc: 0.9248, mIoU: 0.7835


Epoch 6/20 [Train]: 100%|██████████| 115/115 [00:59<00:00,  1.94it/s, loss=0.1637, miou=0.7971, dice=0.8662, acc=0.9341]


[Valid] Loss: 0.2091, Dice: 0.8622, Acc: 0.9254, mIoU: 0.7861


Epoch 7/20 [Train]: 100%|██████████| 115/115 [00:49<00:00,  2.31it/s, loss=0.1584, miou=0.8022, dice=0.8700, acc=0.9360]


[Valid] Loss: 0.2064, Dice: 0.8654, Acc: 0.9266, mIoU: 0.7900


Epoch 8/20 [Train]: 100%|██████████| 115/115 [00:48<00:00,  2.37it/s, loss=0.1499, miou=0.8100, dice=0.8758, acc=0.9391]


[Valid] Loss: 0.2109, Dice: 0.8648, Acc: 0.9267, mIoU: 0.7896


Epoch 9/20 [Train]: 100%|██████████| 115/115 [00:45<00:00,  2.52it/s, loss=0.1431, miou=0.8157, dice=0.8801, acc=0.9415]


[Valid] Loss: 0.2076, Dice: 0.8665, Acc: 0.9285, mIoU: 0.7920


Epoch 10/20 [Train]: 100%|██████████| 115/115 [00:49<00:00,  2.31it/s, loss=0.1360, miou=0.8228, dice=0.8852, acc=0.9442]


[Valid] Loss: 0.2077, Dice: 0.8679, Acc: 0.9294, mIoU: 0.7942


Epoch 11/20 [Train]: 100%|██████████| 115/115 [00:47<00:00,  2.40it/s, loss=0.1313, miou=0.8272, dice=0.8887, acc=0.9459]


[Valid] Loss: 0.2120, Dice: 0.8682, Acc: 0.9294, mIoU: 0.7945


Epoch 12/20 [Train]: 100%|██████████| 115/115 [00:50<00:00,  2.28it/s, loss=0.1268, miou=0.8315, dice=0.8914, acc=0.9475]


[Valid] Loss: 0.2180, Dice: 0.8662, Acc: 0.9287, mIoU: 0.7915


Epoch 13/20 [Train]: 100%|██████████| 115/115 [00:48<00:00,  2.38it/s, loss=0.1234, miou=0.8352, dice=0.8946, acc=0.9488]


[Valid] Loss: 0.2130, Dice: 0.8696, Acc: 0.9305, mIoU: 0.7964


Epoch 14/20 [Train]: 100%|██████████| 115/115 [00:49<00:00,  2.32it/s, loss=0.1186, miou=0.8398, dice=0.8977, acc=0.9506]


[Valid] Loss: 0.2158, Dice: 0.8709, Acc: 0.9302, mIoU: 0.7976


Epoch 15/20 [Train]: 100%|██████████| 115/115 [00:46<00:00,  2.48it/s, loss=0.1156, miou=0.8432, dice=0.9004, acc=0.9518]


[Valid] Loss: 0.2205, Dice: 0.8702, Acc: 0.9304, mIoU: 0.7972


Epoch 16/20 [Train]: 100%|██████████| 115/115 [00:46<00:00,  2.45it/s, loss=0.1147, miou=0.8447, dice=0.9012, acc=0.9522]


[Valid] Loss: 0.2340, Dice: 0.8622, Acc: 0.9267, mIoU: 0.7864


Epoch 17/20 [Train]: 100%|██████████| 115/115 [00:47<00:00,  2.45it/s, loss=0.1333, miou=0.8321, dice=0.8925, acc=0.9467]


[Valid] Loss: 0.2623, Dice: 0.8375, Acc: 0.9132, mIoU: 0.7580


Epoch 18/20 [Train]: 100%|██████████| 115/115 [00:46<00:00,  2.45it/s, loss=0.1382, miou=0.8222, dice=0.8850, acc=0.9435]


[Valid] Loss: 0.2233, Dice: 0.8604, Acc: 0.9249, mIoU: 0.7843


Epoch 19/20 [Train]: 100%|██████████| 115/115 [00:47<00:00,  2.43it/s, loss=0.1201, miou=0.8388, dice=0.8968, acc=0.9500]


[Valid] Loss: 0.2146, Dice: 0.8708, Acc: 0.9304, mIoU: 0.7978


Epoch 20/20 [Train]: 100%|██████████| 115/115 [00:47<00:00,  2.44it/s, loss=0.1110, miou=0.8480, dice=0.9039, acc=0.9535]


[Valid] Loss: 0.2212, Dice: 0.8678, Acc: 0.9300, mIoU: 0.7941


In [80]:
unet50 = UNet50()
train_model(unet50, train_loader, test_loader, epochs=20, learning_rate=1e-4, model_name="unet50")

Epoch 1/20 [Train]: 100%|██████████| 115/115 [01:11<00:00,  1.62it/s, loss=0.6610, miou=0.6247, dice=0.7241, pixel acc=0.8077]


[Valid] Loss: 0.4102, Dice: 0.8459, Pixel acc: 0.9115, mIoU: 0.7627, LR: 0.000095


Epoch 2/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.47it/s, loss=0.3392, miou=0.7699, dice=0.8485, pixel acc=0.9184]


[Valid] Loss: 0.2783, Dice: 0.8601, Pixel acc: 0.9245, mIoU: 0.7825, LR: 0.000090


Epoch 3/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.48it/s, loss=0.2478, miou=0.7926, dice=0.8657, pixel acc=0.9297]


[Valid] Loss: 0.2416, Dice: 0.8660, Pixel acc: 0.9265, mIoU: 0.7895, LR: 0.000085


Epoch 4/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.48it/s, loss=0.2068, miou=0.8062, dice=0.8751, pixel acc=0.9361]


[Valid] Loss: 0.2243, Dice: 0.8727, Pixel acc: 0.9299, mIoU: 0.7992, LR: 0.000080


Epoch 5/20 [Train]: 100%|██████████| 115/115 [01:22<00:00,  1.40it/s, loss=0.1814, miou=0.8157, dice=0.8818, pixel acc=0.9407]


[Valid] Loss: 0.2083, Dice: 0.8714, Pixel acc: 0.9311, mIoU: 0.7977, LR: 0.000075


Epoch 6/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.47it/s, loss=0.1634, miou=0.8239, dice=0.8875, pixel acc=0.9441]


[Valid] Loss: 0.1982, Dice: 0.8747, Pixel acc: 0.9327, mIoU: 0.8021, LR: 0.000070


Epoch 7/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.49it/s, loss=0.1501, miou=0.8311, dice=0.8928, pixel acc=0.9469]


[Valid] Loss: 0.2014, Dice: 0.8738, Pixel acc: 0.9306, mIoU: 0.7998, LR: 0.000065


Epoch 8/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.48it/s, loss=0.1431, miou=0.8339, dice=0.8944, pixel acc=0.9480]


[Valid] Loss: 0.1961, Dice: 0.8737, Pixel acc: 0.9317, mIoU: 0.8005, LR: 0.000060


Epoch 9/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.49it/s, loss=0.1360, miou=0.8380, dice=0.8976, pixel acc=0.9495]


[Valid] Loss: 0.1927, Dice: 0.8752, Pixel acc: 0.9329, mIoU: 0.8026, LR: 0.000055


Epoch 10/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.48it/s, loss=0.1251, miou=0.8470, dice=0.9037, pixel acc=0.9529]


[Valid] Loss: 0.1932, Dice: 0.8768, Pixel acc: 0.9327, mIoU: 0.8046, LR: 0.000051


Epoch 11/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.49it/s, loss=0.1172, miou=0.8537, dice=0.9085, pixel acc=0.9554]


[Valid] Loss: 0.1904, Dice: 0.8774, Pixel acc: 0.9333, mIoU: 0.8057, LR: 0.000046


Epoch 12/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.48it/s, loss=0.1113, miou=0.8591, dice=0.9118, pixel acc=0.9572]


[Valid] Loss: 0.1919, Dice: 0.8756, Pixel acc: 0.9340, mIoU: 0.8037, LR: 0.000041


Epoch 13/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.47it/s, loss=0.1058, miou=0.8640, dice=0.9156, pixel acc=0.9590]


[Valid] Loss: 0.1939, Dice: 0.8744, Pixel acc: 0.9330, mIoU: 0.8017, LR: 0.000036


Epoch 14/20 [Train]: 100%|██████████| 115/115 [01:20<00:00,  1.44it/s, loss=0.1016, miou=0.8681, dice=0.9185, pixel acc=0.9603]


[Valid] Loss: 0.1947, Dice: 0.8769, Pixel acc: 0.9332, mIoU: 0.8047, LR: 0.000031


Epoch 15/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.48it/s, loss=0.0990, miou=0.8710, dice=0.9203, pixel acc=0.9613]


[Valid] Loss: 0.2001, Dice: 0.8745, Pixel acc: 0.9328, mIoU: 0.8021, LR: 0.000026


Epoch 16/20 [Train]: 100%|██████████| 115/115 [01:15<00:00,  1.53it/s, loss=0.0951, miou=0.8751, dice=0.9230, pixel acc=0.9627]


[Valid] Loss: 0.1965, Dice: 0.8779, Pixel acc: 0.9339, mIoU: 0.8064, LR: 0.000021


Epoch 17/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.49it/s, loss=0.0912, miou=0.8792, dice=0.9258, pixel acc=0.9641]


[Valid] Loss: 0.2024, Dice: 0.8763, Pixel acc: 0.9334, mIoU: 0.8043, LR: 0.000016


Epoch 18/20 [Train]: 100%|██████████| 115/115 [01:13<00:00,  1.55it/s, loss=0.0887, miou=0.8820, dice=0.9277, pixel acc=0.9651]


[Valid] Loss: 0.2022, Dice: 0.8771, Pixel acc: 0.9338, mIoU: 0.8057, LR: 0.000011


Epoch 19/20 [Train]: 100%|██████████| 115/115 [01:16<00:00,  1.49it/s, loss=0.0861, miou=0.8853, dice=0.9294, pixel acc=0.9661]


[Valid] Loss: 0.2049, Dice: 0.8758, Pixel acc: 0.9334, mIoU: 0.8038, LR: 0.000006


Epoch 20/20 [Train]: 100%|██████████| 115/115 [01:14<00:00,  1.54it/s, loss=0.0843, miou=0.8873, dice=0.9311, pixel acc=0.9668]


[Valid] Loss: 0.2038, Dice: 0.8763, Pixel acc: 0.9338, mIoU: 0.8045, LR: 0.000001


In [81]:
deeplab = DeepLabV3()
train_model(deeplab, train_loader, test_loader, epochs=20, learning_rate=5e-4, model_name="deeplab")

Epoch 1/20 [Train]: 100%|██████████| 115/115 [01:13<00:00,  1.56it/s, loss=0.3381, miou=0.6539, dice=0.7398, pixel acc=0.8675]


[Valid] Loss: 0.3385, Dice: 0.7720, Pixel acc: 0.8723, mIoU: 0.6779, LR: 0.000475


Epoch 2/20 [Train]: 100%|██████████| 115/115 [01:17<00:00,  1.48it/s, loss=0.2285, miou=0.7479, dice=0.8298, pixel acc=0.9113]


[Valid] Loss: 0.2307, Dice: 0.8355, Pixel acc: 0.9112, mIoU: 0.7519, LR: 0.000450


Epoch 3/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.47it/s, loss=0.2070, miou=0.7632, dice=0.8424, pixel acc=0.9182]


[Valid] Loss: 0.2326, Dice: 0.8369, Pixel acc: 0.9105, mIoU: 0.7529, LR: 0.000426


Epoch 4/20 [Train]: 100%|██████████| 115/115 [01:21<00:00,  1.41it/s, loss=0.1897, miou=0.7779, dice=0.8539, pixel acc=0.9247]


[Valid] Loss: 0.2314, Dice: 0.8433, Pixel acc: 0.9128, mIoU: 0.7605, LR: 0.000401


Epoch 5/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.1760, miou=0.7889, dice=0.8618, pixel acc=0.9297]


[Valid] Loss: 0.2415, Dice: 0.8390, Pixel acc: 0.9105, mIoU: 0.7559, LR: 0.000376


Epoch 6/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.1638, miou=0.7982, dice=0.8681, pixel acc=0.9338]


[Valid] Loss: 0.2106, Dice: 0.8531, Pixel acc: 0.9207, mIoU: 0.7740, LR: 0.000352


Epoch 7/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.1508, miou=0.8097, dice=0.8768, pixel acc=0.9387]


[Valid] Loss: 0.2005, Dice: 0.8633, Pixel acc: 0.9270, mIoU: 0.7887, LR: 0.000327


Epoch 8/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.1396, miou=0.8185, dice=0.8826, pixel acc=0.9425]


[Valid] Loss: 0.1971, Dice: 0.8638, Pixel acc: 0.9274, mIoU: 0.7890, LR: 0.000302


Epoch 9/20 [Train]: 100%|██████████| 115/115 [01:19<00:00,  1.45it/s, loss=0.1337, miou=0.8249, dice=0.8871, pixel acc=0.9448]


[Valid] Loss: 0.2057, Dice: 0.8646, Pixel acc: 0.9267, mIoU: 0.7896, LR: 0.000277


Epoch 10/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.1283, miou=0.8300, dice=0.8913, pixel acc=0.9468]


[Valid] Loss: 0.1984, Dice: 0.8661, Pixel acc: 0.9284, mIoU: 0.7916, LR: 0.000253


Epoch 11/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.1226, miou=0.8353, dice=0.8951, pixel acc=0.9488]


[Valid] Loss: 0.1994, Dice: 0.8667, Pixel acc: 0.9293, mIoU: 0.7929, LR: 0.000228


Epoch 12/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.1154, miou=0.8425, dice=0.9001, pixel acc=0.9515]


[Valid] Loss: 0.2024, Dice: 0.8690, Pixel acc: 0.9305, mIoU: 0.7958, LR: 0.000203


Epoch 13/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.1117, miou=0.8474, dice=0.9036, pixel acc=0.9532]


[Valid] Loss: 0.2070, Dice: 0.8693, Pixel acc: 0.9293, mIoU: 0.7966, LR: 0.000178


Epoch 14/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.1063, miou=0.8525, dice=0.9073, pixel acc=0.9551]


[Valid] Loss: 0.2055, Dice: 0.8700, Pixel acc: 0.9300, mIoU: 0.7968, LR: 0.000154


Epoch 15/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.1016, miou=0.8580, dice=0.9109, pixel acc=0.9570]


[Valid] Loss: 0.2018, Dice: 0.8732, Pixel acc: 0.9314, mIoU: 0.8017, LR: 0.000129


Epoch 16/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.0976, miou=0.8629, dice=0.9141, pixel acc=0.9587]


[Valid] Loss: 0.2063, Dice: 0.8712, Pixel acc: 0.9310, mIoU: 0.7996, LR: 0.000104


Epoch 17/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.0947, miou=0.8666, dice=0.9166, pixel acc=0.9599]


[Valid] Loss: 0.2096, Dice: 0.8731, Pixel acc: 0.9316, mIoU: 0.8014, LR: 0.000079


Epoch 18/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.0910, miou=0.8709, dice=0.9196, pixel acc=0.9614]


[Valid] Loss: 0.2122, Dice: 0.8717, Pixel acc: 0.9315, mIoU: 0.8000, LR: 0.000055


Epoch 19/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.0882, miou=0.8744, dice=0.9220, pixel acc=0.9626]


[Valid] Loss: 0.2146, Dice: 0.8714, Pixel acc: 0.9316, mIoU: 0.7995, LR: 0.000030


Epoch 20/20 [Train]: 100%|██████████| 115/115 [01:18<00:00,  1.46it/s, loss=0.0860, miou=0.8771, dice=0.9237, pixel acc=0.9635]


[Valid] Loss: 0.2181, Dice: 0.8716, Pixel acc: 0.9315, mIoU: 0.7998, LR: 0.000005


In [8]:
def DeepLabV3_50():
    model = smp.DeepLabV3(
        encoder_name="resnet50", 
        encoder_weights="imagenet",
        in_channels=3,
        classes=3
    )
    return model

deeplab50 = DeepLabV3_50()
train_model(deeplab50, train_loader, test_loader, epochs=15, learning_rate=1e-4, model_name="deeplab50")

Epoch 1/15 [Train]: 100%|██████████| 115/115 [18:25<00:00,  9.61s/it, loss=0.3270, miou=0.6624, dice=0.7447, pixel acc=0.8739]


[Valid] Loss: 0.2307, Dice: 0.8308, Pixel acc: 0.9123, mIoU: 0.7497, LR: 0.000093


Epoch 2/15 [Train]: 100%|██████████| 115/115 [18:11<00:00,  9.49s/it, loss=0.2039, miou=0.7676, dice=0.8441, pixel acc=0.9206]


[Valid] Loss: 0.2033, Dice: 0.8536, Pixel acc: 0.9226, mIoU: 0.7765, LR: 0.000087


Epoch 3/15 [Train]: 100%|██████████| 115/115 [17:51<00:00,  9.32s/it, loss=0.1736, miou=0.7912, dice=0.8624, pixel acc=0.9310]


[Valid] Loss: 0.1917, Dice: 0.8595, Pixel acc: 0.9265, mIoU: 0.7848, LR: 0.000080


Epoch 4/15 [Train]: 100%|██████████| 115/115 [17:53<00:00,  9.34s/it, loss=0.1573, miou=0.8051, dice=0.8724, pixel acc=0.9367]


[Valid] Loss: 0.1858, Dice: 0.8672, Pixel acc: 0.9292, mIoU: 0.7937, LR: 0.000074


Epoch 5/15 [Train]: 100%|██████████| 115/115 [17:29<00:00,  9.13s/it, loss=0.1437, miou=0.8169, dice=0.8808, pixel acc=0.9416]


[Valid] Loss: 0.1854, Dice: 0.8681, Pixel acc: 0.9305, mIoU: 0.7960, LR: 0.000067


Epoch 6/15 [Train]: 100%|██████████| 115/115 [18:00<00:00,  9.39s/it, loss=0.1335, miou=0.8265, dice=0.8877, pixel acc=0.9454]


[Valid] Loss: 0.1847, Dice: 0.8703, Pixel acc: 0.9317, mIoU: 0.7981, LR: 0.000060


Epoch 7/15 [Train]: 100%|██████████| 115/115 [17:51<00:00,  9.32s/it, loss=0.1250, miou=0.8350, dice=0.8940, pixel acc=0.9486]


[Valid] Loss: 0.1861, Dice: 0.8709, Pixel acc: 0.9315, mIoU: 0.7987, LR: 0.000054


Epoch 8/15 [Train]: 100%|██████████| 115/115 [17:31<00:00,  9.14s/it, loss=0.1183, miou=0.8417, dice=0.8986, pixel acc=0.9511]


[Valid] Loss: 0.1861, Dice: 0.8725, Pixel acc: 0.9316, mIoU: 0.8007, LR: 0.000047


Epoch 9/15 [Train]: 100%|██████████| 115/115 [17:36<00:00,  9.19s/it, loss=0.1119, miou=0.8486, dice=0.9039, pixel acc=0.9536]


[Valid] Loss: 0.1858, Dice: 0.8736, Pixel acc: 0.9328, mIoU: 0.8022, LR: 0.000041


Epoch 10/15 [Train]: 100%|██████████| 115/115 [17:43<00:00,  9.25s/it, loss=0.1073, miou=0.8540, dice=0.9071, pixel acc=0.9555]


[Valid] Loss: 0.1898, Dice: 0.8726, Pixel acc: 0.9324, mIoU: 0.8012, LR: 0.000034


Epoch 11/15 [Train]: 100%|██████████| 115/115 [17:51<00:00,  9.32s/it, loss=0.1033, miou=0.8585, dice=0.9108, pixel acc=0.9571]


[Valid] Loss: 0.1896, Dice: 0.8735, Pixel acc: 0.9331, mIoU: 0.8026, LR: 0.000027


Epoch 12/15 [Train]: 100%|██████████| 115/115 [18:02<00:00,  9.41s/it, loss=0.0994, miou=0.8630, dice=0.9136, pixel acc=0.9586]


[Valid] Loss: 0.1908, Dice: 0.8745, Pixel acc: 0.9331, mIoU: 0.8035, LR: 0.000021


Epoch 13/15 [Train]: 100%|██████████| 115/115 [17:37<00:00,  9.20s/it, loss=0.0955, miou=0.8676, dice=0.9169, pixel acc=0.9602]


[Valid] Loss: 0.1920, Dice: 0.8745, Pixel acc: 0.9330, mIoU: 0.8035, LR: 0.000014


Epoch 14/15 [Train]: 100%|██████████| 115/115 [18:06<00:00,  9.45s/it, loss=0.0925, miou=0.8715, dice=0.9193, pixel acc=0.9615]


[Valid] Loss: 0.1950, Dice: 0.8733, Pixel acc: 0.9327, mIoU: 0.8018, LR: 0.000008


Epoch 15/15 [Train]: 100%|██████████| 115/115 [17:54<00:00,  9.34s/it, loss=0.0906, miou=0.8739, dice=0.9211, pixel acc=0.9623]


[Valid] Loss: 0.1939, Dice: 0.8747, Pixel acc: 0.9333, mIoU: 0.8038, LR: 0.000001
