<a href="https://colab.research.google.com/github/manushree635/CV/blob/main/unetsegmentation_original.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Importing libraries


In [None]:
import torch
import os
import torch.nn as nn
import torchvision
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
from distutils.file_util import copy_file


UNet Implementation
 


In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes from ``in_channels``
    to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        first_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        second_conv = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        # Layer positions inside the Sequential (0..3) determine the
        # state-dict keys, so the conv/ReLU/conv/ReLU order must be kept.
        self.conv = nn.Sequential(
            first_conv,
            nn.ReLU(inplace=True),
            second_conv,
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the conv-ReLU-conv-ReLU stack to ``x`` and return the result."""
        features = self.conv(x)
        return features
    
    

In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` stages.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        channels: Output channel count of each encoder stage, from the first
            (shallowest) to the last (deepest). Any non-empty iterable of ints
            is accepted. The decoder mirrors these stages in reverse order and
            the bottleneck uses twice the last value.
    """

    def __init__(self, in_channels=3, out_channels=1, channels=(64, 128, 256, 512)):
        super().__init__()
        # Materialise once so generators/tuples/lists all behave the same and
        # the caller's object is never shared with the module.
        channels = list(channels)

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per stage, chaining the channel counts.
        for channel in channels:
            self.downs.append(DoubleConv(in_channels, channel))
            in_channels = channel

        # Decoder: `ups` alternates [transposed conv, DoubleConv] per stage.
        # The DoubleConv takes `channel * 2` inputs because the upsampled
        # tensor is concatenated with the matching skip connection.
        for channel in reversed(channels):
            self.ups.append(
                nn.ConvTranspose2d(channel * 2, channel, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(channel * 2, channel))

        self.bottleneck = DoubleConv(channels[-1], channels[-1] * 2)
        self.final = nn.Conv2d(channels[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the final conv output."""
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # The deepest skip connection is consumed first.
        for stage, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * stage](x)

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip connection; only the spatial dims need to match.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.resize(x, size=skip_connection.shape[2:])

            concat = torch.cat((skip_connection, x), dim=1)
            x = self.ups[2 * stage + 1](concat)

        return self.final(x)


Splitting the Dataset into Training and Validation Sets


In [None]:

# Source folders that are walked recursively to collect the raw files.
trainPath = '/content/train'  # input images
maskPath = '/content/train_masks'  # masks paired with the images by sorted order

In [None]:
def _list_files(top):
    """Return the sorted paths of every file found under ``top``, recursively."""
    return sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(top)
        for name in files
    )


# Both lists are sorted so that position i in one list lines up with
# position i in the other when the files are split further down.
listt = _list_files(trainPath)
listm = _list_files(maskPath)



In [None]:
# Create the ./Data/{Train,Val}/{Images,Masks} folder tree.
# os.makedirs creates the intermediate ./Data and ./Data/<split> folders as
# needed, and exist_ok=True lets this cell be re-run without errors when the
# folders already exist.
for split in ('Train', 'Val'):
    for kind in ('Images', 'Masks'):
        os.makedirs(os.path.join('./Data', split, kind), exist_ok=True)

In [None]:
# Destination folders for the copied images (relative to the working directory).
pathImagesTrain = './Data/Train/Images/'
pathImagesValid = './Data/Val/Images/'


# Destination folders for the copied masks (relative to the working directory).
pathMasksTrain = './Data/Train/Masks/'
pathMasksValid = './Data/Val/Masks/'

In [None]:
# The first 48 image/mask pairs go to the validation folders; every
# remaining pair goes to the training folders.
for idx, image_file in enumerate(listt):
    in_validation = idx < 48
    image_dest = pathImagesValid if in_validation else pathImagesTrain
    mask_dest = pathMasksValid if in_validation else pathMasksTrain

    copy_file(image_file, image_dest)
    copy_file(listm[idx], mask_dest)

Custom Dataset


In [None]:
class CarvanaDataset(Dataset):
    """Dataset of (image, mask) pairs read from two folders.

    Every file name in ``image_dir`` is one sample. Its mask is looked up in
    ``mask_dir`` under the same name with ``.jpg`` replaced by ``_mask.gif``.

    Args:
        image_dir: Folder containing the input images.
        mask_dir: Folder containing the mask files.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns names in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load and return the ``(image, mask)`` pair at ``index``.

        The image is loaded as an RGB array and the mask as a float32
        greyscale array in which the value 255 is replaced by 1. When a
        transform was supplied, both are replaced by its outputs.
        """
        image_name = self.images[index]
        image_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(
            self.mask_dir, image_name.replace('.jpg', '_mask.gif')
        )

        image = np.array(Image.open(image_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert('L'), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            aug = self.transform(image=image, mask=mask)
            image = aug['image']
            mask = aug['mask']

        return image, mask




Loading Dataset

In [None]:
# Hyperparameters
LEARNING_RATE = 1e-5  # learning rate handed to the Adam optimizer
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is available
batch_size = 16  # batch size shared by the training and validation loaders
num_epochs = 7  # number of passes over the training loader

# Size every image and mask is resized to before being fed to the model.
IMAGE_HEIGHT = 572  # 1280 originally
IMAGE_WIDTH = 572  # 1918 originally

# NOTE(review): LOAD_MODEL is not read anywhere in the visible code.
LOAD_MODEL = False
# Absolute locations of the train/validation split folders.
# NOTE(review): the split folders are created under the relative path ./Data,
# so these only match when the working directory is /content - confirm.
train_img_dir = "/content/Data/Train/Images"
train_mask_dir = "/content/Data/Train/Masks"
val_img_dir = "/content/Data/Val/Images"
val_mask_dir="/content/Data/Val/Masks"



def _build_transform(augmentations=()):
    """Compose resize -> ``augmentations`` -> normalize -> tensor conversion.

    Fresh transform objects are created on every call, so the returned
    pipelines never share state with each other.
    """
    return A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            *augmentations,
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )


# Training pipeline: random rotation and flips between resize and normalize.
train_transform = _build_transform(
    [
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
    ]
)

# Validation pipeline: resize and normalize only, no random augmentation.
val_transform = _build_transform()

# Model, loss and optimizer. The model is moved to DEVICE before the
# optimizer collects its parameters.
model = UNet(in_channels=3, out_channels=1)
model = model.to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training data: augmented samples, reshuffled every epoch.
train_ds = CarvanaDataset(train_img_dir, train_mask_dir, train_transform)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# Validation data: no random augmentation, fixed iteration order.
val_ds = CarvanaDataset(val_img_dir, val_mask_dir, val_transform)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)


Training and Validating the Model

In [None]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then write ``state`` to ``filename``.

    Args:
        state: Object to serialize with ``torch.save``.
        filename: Destination path; an existing file is overwritten.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model, optimizer=None):
    """Restore model (and optionally optimizer) state from ``checkpoint``.

    Args:
        checkpoint: Mapping with a ``"state_dict"`` entry holding the model
            weights and, optionally, an ``"optimizer"`` entry holding the
            optimizer state.
        model: Module whose ``load_state_dict`` receives the stored weights.
        optimizer: Optional optimizer to restore. It is only touched when it
            is given and the checkpoint contains an ``"optimizer"`` entry, so
            existing two-argument calls behave exactly as before.

    Raises:
        KeyError: If ``checkpoint`` has no ``"state_dict"`` entry.
    """
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    if optimizer is not None and "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])


def check_accuracy(loader, model, device="cuda"):
    """Print and return pixel accuracy and mean Dice score over ``loader``.

    Predictions are obtained by applying a sigmoid to the model output and
    thresholding at 0.5. The Dice score is computed per batch and averaged
    over ``len(loader)`` batches.

    Args:
        loader: Sized iterable of ``(x, y)`` batches. ``y`` gets a dimension
            inserted at position 1 before it is compared with the predictions.
        model: Module to evaluate. It is switched to eval mode for the
            duration of the call and put back in the mode it was in.
        device: Device the batches are moved to.

    Returns:
        ``(accuracy, dice)`` as floats, with accuracy in ``[0, 1]``. When the
        loader yields no pixels, a notice is printed and ``(0.0, 0.0)`` is
        returned instead of dividing by zero.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # 1e-8 keeps the ratio defined when both masks are empty.
                dice_score += (2 * (preds * y).sum()) / (
                    (preds + y).sum() + 1e-8
                )
    finally:
        # Restore the caller's mode even if a batch raised.
        model.train(was_training)

    if num_pixels == 0:
        print("No samples to evaluate; skipping accuracy check")
        return 0.0, 0.0

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    return float(num_correct) / num_pixels, float(dice_score) / len(loader)

def save_predictions_as_imgs(
    loader, model, folder="/content/Data/saved_images", device="cuda"
):
    """Write thresholded predictions and their targets as PNG files.

    For batch number ``idx`` of ``loader`` two files are written inside
    ``folder``: ``pred_{idx}.png`` with the predictions (sigmoid of the model
    output thresholded at 0.5) and ``{idx}.png`` with the targets.

    Args:
        loader: Iterable of ``(x, y)`` batches. ``y`` gets a dimension
            inserted at position 1 before saving.
        model: Module used for prediction. It is switched to eval mode for
            the duration of the call and put back in the mode it was in.
        folder: Output folder; created if it does not exist yet.
        device: Device the inputs are moved to.
    """
    os.makedirs(folder, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
            torchvision.utils.save_image(
                preds, f"{folder}/pred_{idx}.png"
            )
            # The separator is required here: without it the target image is
            # written next to `folder` instead of inside it.
            torchvision.utils.save_image(y.unsqueeze(1), f"{folder}/{idx}.png")
    finally:
        # Restore the caller's mode even if saving a batch raised.
        model.train(was_training)

In [None]:
# Report the metrics of the model on the validation set before any training.
check_accuracy(val_loader, model, device=DEVICE)
# Gradient scaler used together with autocast in the training step below.
scaler = torch.cuda.amp.GradScaler()

for epoch in range(num_epochs):
    # Wrap the loader in a progress bar that also shows the latest loss.
    loop = tqdm(train_loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        # Cast the targets to float and insert a dimension at position 1
        # (presumably the channel axis, to line up with the single-channel
        # model output - confirm).
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward pass under autocast
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward pass: clear old gradients, backpropagate the scaled loss,
        # step the optimizer through the scaler, then update the scaler
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # show the loss of this batch in the progress bar
        loop.set_postfix(loss=loss.item())



    # Everything below runs once per epoch, after all batches are processed.
    # Save model and optimizer state; save_checkpoint is called without a
    # filename, so each epoch writes to the same default file.
    checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer":optimizer.state_dict(),
        }
    save_checkpoint(checkpoint)

    # check accuracy on the validation set
    check_accuracy(val_loader, model, device=DEVICE)

    # write example predictions for the validation set to a folder
    save_predictions_as_imgs(
            val_loader, model, folder="/content/Data/saved_images", device=DEVICE
        )


  0%|          | 0/315 [00:00<?, ?it/s][A

Got 15266345/15704832 with acc 97.21
Dice score: 0.939386785030365



  0%|          | 0/315 [00:03<?, ?it/s, loss=0.0817][A
  0%|          | 1/315 [00:03<17:10,  3.28s/it, loss=0.0817][A
  0%|          | 1/315 [00:06<17:10,  3.28s/it, loss=0.0865][A
  1%|          | 2/315 [00:06<17:06,  3.28s/it, loss=0.0865][A
  1%|          | 2/315 [00:09<17:06,  3.28s/it, loss=0.0808][A
  1%|          | 3/315 [00:09<17:01,  3.27s/it, loss=0.0808][A
  1%|          | 3/315 [00:13<17:01,  3.27s/it, loss=0.0761][A
  1%|▏         | 4/315 [00:13<16:59,  3.28s/it, loss=0.0761][A
  1%|▏         | 4/315 [00:16<16:59,  3.28s/it, loss=0.103] [A
  2%|▏         | 5/315 [00:16<16:57,  3.28s/it, loss=0.103][A
  2%|▏         | 5/315 [00:19<16:57,  3.28s/it, loss=0.0959][A
  2%|▏         | 6/315 [00:19<16:55,  3.29s/it, loss=0.0959][A
  2%|▏         | 6/315 [00:22<16:55,  3.29s/it, loss=0.0771][A
  2%|▏         | 7/315 [00:22<16:51,  3.28s/it, loss=0.0771][A
  2%|▏         | 7/315 [00:26<16:51,  3.28s/it, loss=0.0848][A
  3%|▎         | 8/315 [00:26<16:45,  3.28s/it, 

=> Saving checkpoint
Got 15361572/15704832 with acc 97.81
Dice score: 0.9503728747367859



  0%|          | 0/315 [00:00<?, ?it/s][A
  0%|          | 0/315 [00:03<?, ?it/s, loss=0.0798][A
  0%|          | 1/315 [00:03<17:17,  3.31s/it, loss=0.0798][A
  0%|          | 1/315 [00:06<17:17,  3.31s/it, loss=0.0751][A
  1%|          | 2/315 [00:06<17:15,  3.31s/it, loss=0.0751][A
  1%|          | 2/315 [00:09<17:15,  3.31s/it, loss=0.0787][A
  1%|          | 3/315 [00:09<17:13,  3.31s/it, loss=0.0787][A
  1%|          | 3/315 [00:13<17:13,  3.31s/it, loss=0.0771][A
  1%|▏         | 4/315 [00:13<17:13,  3.32s/it, loss=0.0771][A
  1%|▏         | 4/315 [00:16<17:13,  3.32s/it, loss=0.0692][A
  2%|▏         | 5/315 [00:16<17:08,  3.32s/it, loss=0.0692][A
  2%|▏         | 5/315 [00:19<17:08,  3.32s/it, loss=0.0625][A
  2%|▏         | 6/315 [00:19<17:05,  3.32s/it, loss=0.0625][A
  2%|▏         | 6/315 [00:23<17:05,  3.32s/it, loss=0.0696][A
  2%|▏         | 7/315 [00:23<17:00,  3.31s/it, loss=0.0696][A
  2%|▏         | 7/315 [00:26<17:00,  3.31s/it, loss=0.0723][A
  3%

=> Saving checkpoint
Got 15384308/15704832 with acc 97.96
Dice score: 0.9556411504745483



  0%|          | 0/315 [00:00<?, ?it/s][A
  0%|          | 0/315 [00:03<?, ?it/s, loss=0.0624][A
  0%|          | 1/315 [00:03<17:06,  3.27s/it, loss=0.0624][A
  0%|          | 1/315 [00:06<17:06,  3.27s/it, loss=0.0674][A
  1%|          | 2/315 [00:06<17:09,  3.29s/it, loss=0.0674][A
  1%|          | 2/315 [00:09<17:09,  3.29s/it, loss=0.0551][A
  1%|          | 3/315 [00:09<17:08,  3.30s/it, loss=0.0551][A
  1%|          | 3/315 [00:13<17:08,  3.30s/it, loss=0.0682][A
  1%|▏         | 4/315 [00:13<17:05,  3.30s/it, loss=0.0682][A
  1%|▏         | 4/315 [00:16<17:05,  3.30s/it, loss=0.0601][A
  2%|▏         | 5/315 [00:16<17:04,  3.30s/it, loss=0.0601][A
  2%|▏         | 5/315 [00:19<17:04,  3.30s/it, loss=0.0546][A
  2%|▏         | 6/315 [00:19<17:02,  3.31s/it, loss=0.0546][A
  2%|▏         | 6/315 [00:23<17:02,  3.31s/it, loss=0.0615][A
  2%|▏         | 7/315 [00:23<16:58,  3.31s/it, loss=0.0615][A
  2%|▏         | 7/315 [00:26<16:58,  3.31s/it, loss=0.078] [A
  3%

=> Saving checkpoint
Got 15409392/15704832 with acc 98.12
Dice score: 0.9592427015304565



  0%|          | 0/315 [00:00<?, ?it/s][A
  0%|          | 0/315 [00:03<?, ?it/s, loss=0.0467][A
  0%|          | 1/315 [00:03<17:17,  3.30s/it, loss=0.0467][A
  0%|          | 1/315 [00:06<17:17,  3.30s/it, loss=0.0595][A
  1%|          | 2/315 [00:06<17:14,  3.30s/it, loss=0.0595][A
  1%|          | 2/315 [00:09<17:14,  3.30s/it, loss=0.0446][A
  1%|          | 3/315 [00:09<17:12,  3.31s/it, loss=0.0446][A
  1%|          | 3/315 [00:13<17:12,  3.31s/it, loss=0.0533][A
  1%|▏         | 4/315 [00:13<17:10,  3.31s/it, loss=0.0533][A
  1%|▏         | 4/315 [00:16<17:10,  3.31s/it, loss=0.0437][A
  2%|▏         | 5/315 [00:16<17:10,  3.32s/it, loss=0.0437][A
  2%|▏         | 5/315 [00:19<17:10,  3.32s/it, loss=0.0584][A
  2%|▏         | 6/315 [00:19<17:06,  3.32s/it, loss=0.0584][A
  2%|▏         | 6/315 [00:23<17:06,  3.32s/it, loss=0.056] [A
  2%|▏         | 7/315 [00:23<17:04,  3.33s/it, loss=0.056][A
  2%|▏         | 7/315 [00:26<17:04,  3.33s/it, loss=0.0502][A
  3%|

=> Saving checkpoint
Got 15463085/15704832 with acc 98.46
Dice score: 0.9655752182006836



  0%|          | 0/315 [00:00<?, ?it/s][A
  0%|          | 0/315 [00:03<?, ?it/s, loss=0.0578][A
  0%|          | 1/315 [00:03<17:03,  3.26s/it, loss=0.0578][A
  0%|          | 1/315 [00:06<17:03,  3.26s/it, loss=0.0415][A
  1%|          | 2/315 [00:06<17:03,  3.27s/it, loss=0.0415][A
  1%|          | 2/315 [00:09<17:03,  3.27s/it, loss=0.0497][A
  1%|          | 3/315 [00:09<17:04,  3.28s/it, loss=0.0497][A
  1%|          | 3/315 [00:13<17:04,  3.28s/it, loss=0.0476][A
  1%|▏         | 4/315 [00:13<17:03,  3.29s/it, loss=0.0476][A
  1%|▏         | 4/315 [00:16<17:03,  3.29s/it, loss=0.0394][A
  2%|▏         | 5/315 [00:16<17:00,  3.29s/it, loss=0.0394][A
  2%|▏         | 5/315 [00:19<17:00,  3.29s/it, loss=0.0411][A
  2%|▏         | 6/315 [00:19<16:57,  3.29s/it, loss=0.0411][A
  2%|▏         | 6/315 [00:23<16:57,  3.29s/it, loss=0.0442][A
  2%|▏         | 7/315 [00:23<16:54,  3.29s/it, loss=0.0442][A
  2%|▏         | 7/315 [00:26<16:54,  3.29s/it, loss=0.0423][A
  3%

=> Saving checkpoint
Got 15474143/15704832 with acc 98.53
Dice score: 0.967836320400238



  0%|          | 0/315 [00:00<?, ?it/s][A
  0%|          | 0/315 [00:03<?, ?it/s, loss=0.0347][A
  0%|          | 1/315 [00:03<17:28,  3.34s/it, loss=0.0347][A
  0%|          | 1/315 [00:06<17:28,  3.34s/it, loss=0.0423][A
  1%|          | 2/315 [00:06<17:21,  3.33s/it, loss=0.0423][A
  1%|          | 2/315 [00:09<17:21,  3.33s/it, loss=0.0417][A
  1%|          | 3/315 [00:09<17:16,  3.32s/it, loss=0.0417][A
  1%|          | 3/315 [00:13<17:16,  3.32s/it, loss=0.0444][A
  1%|▏         | 4/315 [00:13<17:09,  3.31s/it, loss=0.0444][A
  1%|▏         | 4/315 [00:16<17:09,  3.31s/it, loss=0.04]  [A
  2%|▏         | 5/315 [00:16<17:05,  3.31s/it, loss=0.04][A
  2%|▏         | 5/315 [00:19<17:05,  3.31s/it, loss=0.0413][A
  2%|▏         | 6/315 [00:19<17:01,  3.31s/it, loss=0.0413][A
  2%|▏         | 6/315 [00:23<17:01,  3.31s/it, loss=0.0383][A
  2%|▏         | 7/315 [00:23<16:57,  3.30s/it, loss=0.0383][A
  2%|▏         | 7/315 [00:26<16:57,  3.30s/it, loss=0.038] [A
  3%|▎

=> Saving checkpoint
Got 15491492/15704832 with acc 98.64
Dice score: 0.9700021743774414



  0%|          | 0/315 [00:00<?, ?it/s][A
  0%|          | 0/315 [00:03<?, ?it/s, loss=0.035][A
  0%|          | 1/315 [00:03<17:45,  3.39s/it, loss=0.035][A
  0%|          | 1/315 [00:06<17:45,  3.39s/it, loss=0.0351][A
  1%|          | 2/315 [00:06<17:40,  3.39s/it, loss=0.0351][A
  1%|          | 2/315 [00:10<17:40,  3.39s/it, loss=0.0478][A
  1%|          | 3/315 [00:10<17:36,  3.39s/it, loss=0.0478][A
  1%|          | 3/315 [00:13<17:36,  3.39s/it, loss=0.037] [A
  1%|▏         | 4/315 [00:13<17:33,  3.39s/it, loss=0.037][A
  1%|▏         | 4/315 [00:16<17:33,  3.39s/it, loss=0.0441][A
  2%|▏         | 5/315 [00:16<17:30,  3.39s/it, loss=0.0441][A
  2%|▏         | 5/315 [00:20<17:30,  3.39s/it, loss=0.0363][A
  2%|▏         | 6/315 [00:20<17:26,  3.39s/it, loss=0.0363][A
  2%|▏         | 6/315 [00:23<17:26,  3.39s/it, loss=0.0416][A
  2%|▏         | 7/315 [00:23<17:24,  3.39s/it, loss=0.0416][A
  2%|▏         | 7/315 [00:27<17:24,  3.39s/it, loss=0.0382][A
  3%|▎ 

=> Saving checkpoint
Got 15504726/15704832 with acc 98.73
Dice score: 0.9719277620315552
