<a href="https://colab.research.google.com/github/deveshdatwani/unet/blob/main/Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

# Pick the compute device once for the whole notebook: the first CUDA GPU when
# PyTorch can see one, otherwise the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f'DEVICE IS {DEVICE}')

DEVICE IS cuda:0


# Library Imports

In [2]:
# Notebook-wide imports (Colab environment).
# NOTE(review): several names are imported twice (torch, pyplot as plt, TF via
# the next cell) and some are not used by any visible cell (pd, cv2, datasets,
# read_image/iomode, Adam, RMSprop, Variable, MSELoss) — consider pruning.
from google.colab import drive
from torchvision import transforms
from torch.utils.data import DataLoader
# Mount Google Drive so the dataset directory under /content/drive/MyDrive is readable.
drive.mount('/content/drive')
import torch
from torch.utils.data import Dataset
from torchvision import datasets
# NOTE: ToTensor is later shadowed by the notebook's own dict-based ToTensor class.
from torchvision.transforms import ToTensor, Resize, RandomRotation
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
import pandas as pd
import os
from matplotlib import pyplot as plt
import cv2
from torchvision.io import read_image, ImageReadMode as iomode
from tqdm import tqdm
from torch import nn
from torch.optim import SGD, Adam, RMSprop
from torch.autograd import Variable
from torch.nn import MSELoss
import numpy as np
from PIL import Image


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Model

In [3]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor, Resize, InterpolationMode, Normalize

class DoubleConv(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    With stride 1 and padding 1 the spatial size is unchanged. The first conv
    maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. The convs carry no bias because the BatchNorm that
    follows each one has its own affine shift.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, 3, 1, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.conv(x)

class UNet(nn.Module):
    """U-Net encoder/decoder that outputs a per-pixel sigmoid map.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        features: encoder stage widths, shallowest first. The decoder mirrors
            them and the bottleneck widens to ``features[-1] * 2``.
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super().__init__()
        # ups is registered before downs, exactly as before, so parameter
        # order and state_dict keys stay compatible with saved checkpoints.
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: each stage moves the running channel count to its width.
        channels = in_channels
        for width in features:
            self.downs.append(DoubleConv(channels, width))
            channels = width

        # Decoder: (transpose-conv upsample, DoubleConv) pairs, deepest first.
        for width in reversed(features):
            self.ups.extend([
                nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2),
                DoubleConv(width * 2, width),
            ])

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Encode, bottleneck, decode with skip connections; return sigmoid output.

        When an upsampled tensor's shape differs from its skip tensor (e.g.
        after max-pooling an odd-sized map) it is bilinearly resized to the
        skip's spatial size before concatenation.
        """
        skips = []
        for encoder in self.downs:
            x = encoder(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        decoder_modules = list(self.ups)
        stages = zip(decoder_modules[0::2], decoder_modules[1::2], reversed(skips))
        for upsample, decode, skip in stages:
            x = upsample(x)
            if x.shape != skip.shape:
                x = TF.resize(x, size=skip.shape[2:], interpolation=InterpolationMode.BILINEAR, antialias=True)
            x = decode(torch.cat((skip, x), dim=1))

        return self.sigmoid(self.final_conv(x))

# Dataset

In [4]:
class ResizeSample(object):
    """Resize the image and mask tensors of a ``{'image', 'mask'}`` sample.

    The image uses antialiased bilinear interpolation. The mask uses
    nearest-neighbour interpolation so its values stay within the original
    label set instead of being blended into fractional values at edges.

    Args:
        image_size: target (height, width) for ``sample['image']``.
        mask_size: target (height, width) for ``sample['mask']``.
    """

    def __init__(self, image_size=(572, 572), mask_size=(572, 572)):
        # Stored as lists, matching the previous attribute type.
        self.image_size = list(image_size)
        self.mask_size = list(mask_size)

    def __call__(self, sample):
        """Return a new sample dict with both entries resized."""
        image = Resize(self.image_size, interpolation=InterpolationMode.BILINEAR, antialias=True)(sample['image'])
        # Nearest-neighbour keeps mask labels crisp; bilinear + antialias
        # would smear label boundaries into intermediate values.
        mask = Resize(self.mask_size, interpolation=InterpolationMode.NEAREST)(sample['mask'])
        return {'image': image, 'mask': mask}


class ToTensor(object):
    """Convert a numpy ``{'image', 'mask'}`` sample into torch tensors.

    NOTE: this shadows ``torchvision.transforms.ToTensor`` imported earlier;
    the dataset pipeline relies on this dict-based version.

    The image array's axes are reordered with ``transpose((2, 0, 1))``,
    i.e. (H, W, C) -> (C, H, W). The mask gains a leading axis via
    ``unsqueeze(0)`` (a 2-D (H, W) mask becomes (1, H, W)). No dtype
    conversion or value scaling happens here.
    """

    def __call__(self, sample):
        image_array = sample['image']
        mask_array = sample['mask']
        image_tensor = torch.from_numpy(image_array.transpose((2, 0, 1)))
        mask_tensor = torch.from_numpy(mask_array).unsqueeze(0)
        return {'image': image_tensor, 'mask': mask_tensor}


class NormalizeImage():
    """Channel-wise normalize ``sample['image']``; pass ``sample['mask']`` through.

    Args:
        mean: per-channel means handed to ``Normalize``.
        std: per-channel standard deviations handed to ``Normalize``.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']
        # The mask is returned unchanged: it holds labels, not intensities.
        normalizer = Normalize(self.mean, self.std)
        return {'image': normalizer(image), 'mask': mask}

class Rotate(object):
    """Rotate a sample's image and mask by one shared, randomly chosen angle.

    The angle is drawn uniformly from ``angles`` with NumPy's global RNG.

    Args:
        degrees: only used to build ``self.rotator`` (a ``RandomRotation``),
            which ``__call__`` does not use; kept for backward compatibility.
        angles: candidate rotation angles in degrees. The default equals the
            list previously hard-coded in ``__call__``.

    Raises:
        ValueError: if ``angles`` is empty.
    """

    def __init__(self, degrees=(0,180), angles=(-30.0, -15.0, 0.0, 15.0, 30.0)):
        if len(angles) == 0:
            raise ValueError("angles must contain at least one value")
        self.degrees = degrees
        self.rotator = RandomRotation(degrees)
        self.angles = list(angles)

    def __call__(self, sample):
        """Return a new ``{'image', 'mask'}`` dict, both rotated by the same angle."""
        # One angle for both tensors keeps the mask aligned with the image.
        angle = float(np.random.choice(self.angles))
        return {
            'image': TF.rotate(sample['image'], angle),
            'mask': TF.rotate(sample['mask'], angle),
        }



class Caravan(Dataset):
    """Image/mask segmentation dataset read from a directory tree.

    Expects images in ``root_dir/train`` and, for each image, a mask at
    ``root_dir/train_masks/<image stem>_mask.gif``. Items are
    ``{'image', 'mask'}`` dicts passed through the custom ``ToTensor`` and
    ``ResizeSample`` transforms.

    Args:
        root_dir: dataset root directory.
        json: unused; kept for backward compatibility.
    """

    def __init__(self, root_dir, json=None):
        self.root_dir = root_dir
        self.train_images_dir = os.path.join(root_dir, "train")
        # Sorted so the index -> file mapping is reproducible; os.listdir
        # returns entries in arbitrary order.
        self.train_images = sorted(os.listdir(self.train_images_dir))
        self.sample_transform = transforms.Compose([ToTensor(), ResizeSample()])

    def __len__(self):
        # Use the cached listing so len() always agrees with __getitem__.
        return len(self.train_images)

    def transform(self, sample):
        """Apply the dataset's sample transform pipeline to ``sample``."""
        return self.sample_transform(sample)

    def get_mask_address(self, image_address):
        """Map an image path to its mask path under ``root_dir/train_masks``."""
        stem = os.path.splitext(os.path.basename(image_address))[0]
        return os.path.join(self.root_dir, 'train_masks', stem + '_mask.gif')

    def __getitem__(self, idx):
        """Load image ``idx`` (float32, divided by 255) and its float32 mask."""
        image_address = os.path.join(self.train_images_dir, self.train_images[idx])
        mask_address = self.get_mask_address(image_address)
        # Context managers release the file handles once pixels are copied.
        with Image.open(image_address) as image_file:
            image = np.array(image_file, dtype=np.float32) / 255.0
        with Image.open(mask_address) as mask_file:
            mask = np.array(mask_file, dtype=np.float32)

        return self.sample_transform({'image': image, 'mask': mask})

# Loss Function

In [5]:
class DiceLoss(nn.Module):
    """Soft Dice loss ``1 - (2*sum(P*T) + smooth) / (sum(P + T) + smooth)``.

    Sums run over every element of the batch, giving one global Dice score
    rather than a per-sample mean. ``pred`` is expected to hold probabilities
    (such as ``UNet``'s sigmoid output) and ``target`` the matching mask.

    Args:
        smooth: additive smoothing on numerator and denominator. It keeps the
            ratio defined and makes an all-zero prediction of an all-zero mask
            score a loss of ~0 instead of 1.
    """

    def __init__(self, smooth=0.01):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target, epsilon=1e-8):
        """Return the scalar Dice loss between ``pred`` and ``target``.

        ``epsilon`` only guards against division by zero when ``smooth`` is 0.
        """
        intersection = torch.sum(pred * target)
        total = torch.sum(pred + target)
        dice_coeff = (2 * intersection + self.smooth) / (total + self.smooth + epsilon)
        return 1 - dice_coeff

# Dataloader

In [6]:
class Loader(object):
    """Factory for the shuffling, two-worker ``DataLoader`` used by this notebook."""

    def __call__(self, dataset, batch_size):
        """Wrap ``dataset`` in a ``DataLoader`` with shuffling and 2 workers."""
        loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=2)
        return DataLoader(dataset, **loader_options)

# Training Loop

In [None]:
# TRAINING LOOP
class Trainer(object):
    """Train a segmentation model on the ``Caravan`` dataset with Dice loss.

    Args:
        model: the network to optimise (required). It is used on whatever
            device its parameters already live on.
        epochs: number of passes over the dataset.
        batch_size: samples per optimisation step.
        path: dataset root passed to ``Caravan``. When ``None`` the
            notebook-global ``PATH`` is used (previous behaviour).
        checkpoint_path: file the best-epoch checkpoint is written to.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_every: print the windowed average loss every ``log_every`` batches.

    Raises:
        ValueError: if ``model`` is ``None``.
    """

    def __init__(self, model=None, epochs=16, batch_size=2, path=None,
                 checkpoint_path='model.pt', lr=0.001, momentum=0.99, log_every=5):
        if model is None:
            raise ValueError("Trainer requires a model")
        self.model = model
        self.epochs = epochs
        self.batch_size = batch_size
        self.checkpoint_path = checkpoint_path
        self.log_every = log_every
        # Honour the `path` argument; fall back to the global only when omitted.
        self.dataset = Caravan(path if path is not None else PATH)
        self.dataloader = Loader()
        self.data_loader = self.dataloader(dataset=self.dataset, batch_size=self.batch_size)
        self.optim = SGD(params=model.parameters(), lr=lr, momentum=momentum)
        self.criterian = DiceLoss()
        # Best epoch-mean loss seen so far; a checkpoint is saved when beaten.
        self.checkpoint_loss = float('inf')

    def _train_one_epoch(self, epoch_number, device):
        """Run one optimisation pass over the data and return the mean batch loss."""
        epoch_loss = 0.0
        num_batches = 0
        window_loss = 0.0
        window_batches = 0

        for i, data in enumerate(self.data_loader):
            instance, mask = data['image'], data['mask']
            # L2-normalise each pixel across dim=1 (the channel axis of the
            # batched CHW images); inference must apply the same preprocessing.
            instance = torch.nn.functional.normalize(instance)
            instance = instance.to(device=device)
            mask = mask.to(device=device)

            self.optim.zero_grad()
            prediction = self.model(instance)
            loss = self.criterian(prediction, mask)
            loss.backward()
            self.optim.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1
            window_loss += batch_loss
            window_batches += 1

            if i % self.log_every == 0:
                # Average over the batches actually accumulated since the last
                # log line (only one batch on the very first line).
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                     epoch_number, i * self.batch_size, len(self.dataset),
                     100. * i / len(self.data_loader), window_loss / window_batches))
                window_loss = 0.0
                window_batches = 0

        return epoch_loss / num_batches if num_batches else float('inf')

    def __call__(self):
        """Train for ``self.epochs`` epochs, checkpointing on improved epoch loss."""
        device = next(self.model.parameters()).device
        print(f'DEVICE IS {device}')
        # Ensure BatchNorm uses batch statistics and updates running stats.
        self.model.train()

        for epoch_number in tqdm(range(self.epochs)):
            epoch_loss = self._train_one_epoch(epoch_number, device)

            # Compare whole-epoch means so the saved checkpoint really is the
            # best epoch seen, not a partial running sum.
            if epoch_loss < self.checkpoint_loss:
                torch.save({
                            'epoch': epoch_number,
                            'model_state_dict': self.model.state_dict(),
                            'optimizer_state_dict': self.optim.state_dict(),
                            'loss': epoch_loss,
                            }, self.checkpoint_path)
                self.checkpoint_loss = epoch_loss


if __name__ == "__main__":
    EPOCHS = 10
    BATCH_SIZE = 2
    PATH = '/content/drive/MyDrive/Caravan'
    CHECKPOINT_PATH = 'model.pt'

    model = UNet()
    model.to(device=DEVICE)
    trainer = Trainer(model=model, batch_size=BATCH_SIZE, epochs=EPOCHS, path=PATH,
                      checkpoint_path=CHECKPOINT_PATH)
    trainer()
    

DEVICE IS cuda:0


  0%|          | 0/10 [00:00<?, ?it/s]

# Inference

In [None]:
# Rebuild the dataset (PATH is set by the training cell) and draw one
# shuffled sample with batch size 1 for a visual check.
CHECKPOIN_PATH = 'model.pt'  # (sic) the checkpoint-loading cell reads this exact name
dataloader = Loader()
dataset = Caravan(PATH)
data_loader = dataloader(dataset=dataset, batch_size=1)
sample = next(iter(data_loader))

In [None]:
# Load the checkpoint written by the training loop, then switch to evaluation
# mode so BatchNorm uses its running statistics instead of statistics of the
# single inference image.
model = UNet()
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only runtime;
# the inference sample from the loader is also on the CPU.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(CHECKPOIN_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval();

In [None]:
# Split the batch-of-one sample dict into its image and mask tensors.
instance, mask = sample['image'], sample['mask']

In [None]:
# Apply the same per-pixel L2 normalisation the training loop applies to its
# inputs, and skip autograd bookkeeping since nothing is back-propagated here.
with torch.no_grad():
    pred = model(torch.nn.functional.normalize(instance))

In [None]:
# Predicted probability map for the first image in the batch. Index the single
# output channel (UNet() defaults to out_channels=1) so imshow gets a 2-D array.
plt.imshow(pred[0, 0].detach().cpu().numpy(), cmap='gray')
plt.title('Predicted mask')
plt.show()

In [None]:
# Ground-truth mask for the first sample in the batch. The mask carries one
# leading channel (added by ToTensor), so index it to give imshow a 2-D array.
plt.imshow(mask[0, 0].detach().cpu().numpy(), cmap='gray')
plt.title('Ground-truth mask')
plt.show()

In [None]:
pred.size()