<a href="https://colab.research.google.com/github/kkotsche1/SMP-Binary-Image-Segmentation-Training/blob/main/Unet%2B%2B_Segmentation_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install segmentation-models-pytorch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

import torchvision
import torchvision.transforms as transforms
import torch
import os
import matplotlib.pyplot as plt
import numpy as np
import shutil

import cv2
import zipfile
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import albumentations as A
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from albumentations.pytorch import ToTensorV2 
from torch.utils.data import DataLoader, Dataset
from glob import glob

In [None]:
## This is how you should be formatting the file structure ##

# Code is included to augment the images (e.g. flips and rotation); the augmented copies are saved in the Train_Augment folder.

# Images should be in .jpg format and masks in .png format (these are the extensions load_data() looks for).

# /content/Train
# /content/Train/Images
# /content/Train/Masks

# /content/Train_Augment
# /content/Train_Augment/Images
# /content/Train_Augment/Masks

# /content/Test
# /content/Test/Images
# /content/Test/Masks

In [None]:
#Resizing Images

from tqdm import tqdm

def _resize_images_in_place(root, size=(512, 512)):
    """Resize every image under ``root/<subfolder>/`` to ``size``, overwriting it.

    Args:
        root: Directory whose immediate sub-directories contain image files.
        size: Target ``(width, height)`` passed to ``Image.resize``.

    Files that cannot be processed are reported (name and error) and skipped,
    so one bad file does not abort the whole pass.
    """
    for subdir in os.listdir(root):
        folder = os.path.join(root, subdir)
        if not os.path.isdir(folder):
            # Stray files directly under root are not image folders.
            continue
        for file in tqdm(os.listdir(folder)):
            path = os.path.join(folder, file)
            try:
                with Image.open(path) as image:
                    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the
                    # same filter under its current name.
                    # NOTE(review): masks are resampled with the same filter as
                    # images, which can introduce intermediate grey values on
                    # mask edges - consider Image.NEAREST for the Masks folder.
                    resized = image.resize(size, Image.LANCZOS)
                resized.save(fp=path)
            except Exception as err:
                # Best-effort: report and move on. Unlike a bare ``except`` this
                # still lets KeyboardInterrupt stop the loop.
                print(file, err)


_resize_images_in_place("/content/Train")
_resize_images_in_place("/content/Test")



In [None]:
from tqdm import tqdm
import cv2
import imageio
from albumentations import HorizontalFlip, VerticalFlip, Rotate

def load_data():
    """Return the sorted training image paths and the sorted training mask paths.

    Returns:
        A ``(train_x, train_y)`` tuple: ``train_x`` lists the ``*.jpg`` files in
        ``/content/Train/Images/`` and ``train_y`` the ``*.png`` files in
        ``/content/Train/Masks/``, each in sorted order.
    """
    def _sorted_matches(folder, pattern):
        # Sorted so that images and masks line up by position.
        return sorted(glob(os.path.join(folder, pattern)))

    image_paths = _sorted_matches("/content/Train/Images/", "*.jpg")
    mask_paths = _sorted_matches("/content/Train/Masks/", "*.png")
    return image_paths, mask_paths

def augment_data(images, masks, save_path, augment=True):
    """Write every image/mask pair, plus optional augmented copies, to ``save_path``.

    For each pair the original is saved with suffix ``_0``. When ``augment`` is
    true, a horizontal flip (``_1``), a vertical flip (``_2``) and a random
    rotation (``_3``) are saved as well. All outputs are PNG files placed in
    ``<save_path>/Images`` and ``<save_path>/Masks``.

    Args:
        images: Sequence of image file paths.
        masks: Sequence of mask file paths, aligned with ``images`` by position.
        save_path: Output root; the ``Images`` and ``Masks`` sub-folders are
            created if missing.
        augment: When false only the unmodified pair is written.

    Raises:
        ValueError: If ``images`` and ``masks`` differ in length.
        FileNotFoundError: If an image cannot be read by OpenCV.
    """
    if len(images) != len(masks):
        # zip() would silently drop the surplus and hide a broken pairing.
        raise ValueError(
            f"Got {len(images)} images but {len(masks)} masks; "
            "every image needs exactly one mask."
        )

    image_dir = os.path.join(save_path, "Images")
    mask_dir = os.path.join(save_path, "Masks")
    # cv2.imwrite does not create folders and only returns False on failure.
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    if augment:
        augmenters = [
            HorizontalFlip(p=1.0),
            VerticalFlip(p=1.0),
            Rotate(limit=360, p=1.0),
        ]
    else:
        augmenters = []

    for image_file, mask_file in tqdm(zip(images, masks), total=len(images)):
        # splitext keeps dots inside the stem, so "a.1.jpg" and "a.2.jpg" no
        # longer both collapse to "a" and overwrite each other.
        name_x = os.path.splitext(os.path.basename(image_file))[0]
        name_y = os.path.splitext(os.path.basename(mask_file))[0]

        # The second argument of imread is a read flag, not a colour-conversion
        # code. IMREAD_COLOR yields a 3-channel BGR array, which is also the
        # channel order cv2.imwrite expects below.
        image = cv2.imread(image_file, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_file}")
        mask = imageio.imread(mask_file)

        # Index 0 is always the untouched pair; augmented copies follow.
        pairs = [(image, mask)]
        for aug in augmenters:
            augmented = aug(image=image, mask=mask)
            pairs.append((augmented["image"], augmented["mask"]))

        for index, (out_image, out_mask) in enumerate(pairs):
            cv2.imwrite(os.path.join(image_dir, f"{name_x}_{index}.png"), out_image)
            cv2.imwrite(os.path.join(mask_dir, f"{name_y}_{index}.png"), out_mask)

# Collect the original training pairs, then write them together with their
# augmented copies into the Train_Augment folder.
train_x, train_y = load_data()
augment_data(
    images=train_x,
    masks=train_y,
    save_path="/content/Train_Augment/",
    augment=True,
)

In [None]:
class ROOTDIR:
    """Namespace holding the image and mask folders used for training and testing."""

    # Augmented training set (images, masks).
    train, train_mask = "/content/Train_Augment/Images", "/content/Train_Augment/Masks"
    # Test set (images, masks).
    test, test_mask = "/content/Test/Images", "/content/Test/Masks"

In [None]:
# List the image folders, shuffle the file names, and derive the matching mask
# name for every image (same stem, ".png" extension).
train_img_lst = os.listdir(ROOTDIR.train) # "./train"
train_mask_lst = os.listdir(ROOTDIR.train_mask) # "./train_masks"

sorted_test_img_lst = sorted(os.listdir(ROOTDIR.test))
sorted_train_img_lst = sorted(train_img_lst)

# A dedicated, seeded generator makes the shuffle (and therefore the train /
# validation split built from it) reproducible. The global NumPy seed is only
# set in a later cell, so np.random.permutation here would differ on each run.
_split_rng = np.random.default_rng(123)


def _mask_name(image_name):
    """Return the mask file name for an image: same stem, ``.png`` extension."""
    # splitext copes with any image extension (.jpg, .jpeg, .JPG, .png, ...).
    return os.path.splitext(image_name)[0] + ".png"


permuted_test_img_lst = _split_rng.permutation(np.array(sorted_test_img_lst))
permuted_test_mask_lst = [_mask_name(x) for x in permuted_test_img_lst]

permuted_train_img_lst = _split_rng.permutation(np.array(sorted_train_img_lst))
permuted_train_mask_lst = [_mask_name(x) for x in permuted_train_img_lst]

In [None]:
# Seed NumPy and PyTorch so runs are repeatable.
seed = 123
np.random.seed(seed)
torch.manual_seed(seed)

class CFG:
    """Training configuration shared by the cells below."""

    # Fall back to the CPU when no GPU is available instead of failing on the
    # first ``.to("cuda")`` call.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Fraction of the shuffled training files held out for validation.
    split_pct = 0.1
    # NOTE(review): learning_rate and epochs are not read by the optimizer or
    # training-loop cells visible in this notebook (they hard-code 1e-4 and 100).
    learning_rate = 5e-4
    batch_size = 32
    epochs = 4

In [None]:
# Hold out the first CFG.split_pct of the shuffled training files for
# validation and train on the remainder.
# NOTE(review): the split is taken over the augmented files, so augmented
# copies of one source image can land on both sides of the split - confirm
# this is intended.
_val_img_count = int(CFG.split_pct * len(permuted_train_img_lst))
_val_mask_count = int(CFG.split_pct * len(permuted_train_mask_lst))

val_images_list = permuted_train_img_lst[:_val_img_count]
val_masks_list = permuted_train_mask_lst[:_val_mask_count]

train_images_list = permuted_train_img_lst[_val_img_count:]
train_masks_list = permuted_train_mask_lst[_val_mask_count:]

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Image / mask pairs read from disk for binary segmentation.

    The base class is spelled out in full because this class reuses the name
    ``Dataset``; ``class Dataset(Dataset)`` would inherit from the previous
    definition of itself every time the cell is re-run.

    Args:
        img_list: Image file names, relative to the image directory.
        mask_list: Mask file names, aligned with ``img_list`` by position.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
        img_dir: Image directory; ``None`` means ``ROOTDIR.train``.
        mask_dir: Mask directory; ``None`` means ``ROOTDIR.train_mask``.

    Raises:
        ValueError: If ``img_list`` and ``mask_list`` differ in length.
    """

    def __init__(self, img_list, mask_list, transform=None, img_dir=None, mask_dir=None):
        if len(img_list) != len(mask_list):
            raise ValueError(
                f"Got {len(img_list)} images but {len(mask_list)} masks; "
                "every image needs exactly one mask."
            )
        self.img_list = img_list
        self.mask_list = mask_list
        self.transform = transform
        self.img_dir = img_dir
        self.mask_dir = mask_dir

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        """Return ``(img, mask)`` for ``index``.

        With a transform, ``img`` and ``mask`` are whatever the transform
        produced, and the mask gets a leading dimension of size 1 via
        ``torch.unsqueeze`` (so the transform must yield a tensor mask).
        Without one, both are NumPy arrays.
        """
        # Directories are resolved on access so the defaults follow ROOTDIR.
        img_dir = ROOTDIR.train if self.img_dir is None else self.img_dir
        mask_dir = ROOTDIR.train_mask if self.mask_dir is None else self.mask_dir
        img_path = os.path.join(img_dir, self.img_list[index])
        mask_path = os.path.join(mask_dir, self.mask_list[index])

        # Force three channels so grayscale or RGBA files match the 3-channel
        # normalisation and model input; the context managers close the files
        # as soon as the pixels are copied out.
        with Image.open(img_path) as img_file:
            img = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            # Single-channel mask scaled from 0..255 to 0..1.
            mask = np.array(mask_file.convert("L")) / 255.0

        if self.transform:
            augmentation = self.transform(image=img, mask=mask)
            img = augmentation["image"]
            mask = torch.unsqueeze(augmentation["mask"], 0)

        return img, mask

In [None]:
# Shared preprocessing constants: network input size and per-channel
# normalisation statistics.
_INPUT_SIZE = (224, 224)
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, random rotation and brightness/contrast jitter,
# then normalise and convert to a tensor.
train_transform = A.Compose(
    [
        A.Resize(*_INPUT_SIZE),
        A.Rotate(limit=75, p=0.9, border_mode=cv2.BORDER_REFLECT),
        A.RandomBrightnessContrast(
            brightness_limit=0.1,
            contrast_limit=0.1,
            brightness_by_max=True,
            always_apply=False,
            p=0.5,
        ),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]
)

# Evaluation pipeline: the same resize / normalise / to-tensor steps without
# any random augmentation.
test_transform = A.Compose(
    [
        A.Resize(*_INPUT_SIZE),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]
)

train_dataset = Dataset(train_images_list, train_masks_list, transform=train_transform)
val_dataset = Dataset(val_images_list, val_masks_list, transform=test_transform)

# Only the training loader shuffles; validation keeps a fixed order.
train_dataloader = DataLoader(
    train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=2
)
val_dataloader = DataLoader(
    val_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=2
)

In [None]:
def train_model(model,dataloader,criterion,optimizer):
    """Run one training epoch and return the mean loss per sample.

    Args:
        model: Network to train; switched to training mode.
        dataloader: Iterable yielding ``(image_batch, mask_batch)`` pairs.
        criterion: Loss called as ``criterion(prediction, mask)``. It is
            assumed to return the mean over the batch (the default for the
            ``nn.BCEWithLogitsLoss()`` this notebook uses).
        optimizer: Optimizer stepped once per batch.

    Returns:
        The loss averaged over all samples seen in the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    sample_count = 0
    for img_mask in tqdm(dataloader):
        img = img_mask[0].float().to(CFG.device)
        mask = img_mask[1].float().to(CFG.device)

        optimizer.zero_grad()
        y_pred = model(img)
        loss = criterion(y_pred, mask)
        loss.backward()
        optimizer.step()

        # Weight each batch by its real size so the (possibly shorter) last
        # batch is counted correctly and the result is a per-sample mean
        # rather than a value inflated by the nominal batch size.
        batch_len = img.size(0)
        loss_sum += loss.item() * batch_len
        sample_count += batch_len

    if sample_count == 0:
        raise ValueError("dataloader yielded no samples")
    return loss_sum / sample_count



def val_model(model,dataloader,criterion,optimizer):
    """Evaluate the model on ``dataloader`` and return the mean loss per sample.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        dataloader: Iterable yielding ``(image_batch, mask_batch)`` pairs.
        criterion: Loss called as ``criterion(prediction, mask)``; assumed to
            return the mean over the batch, as in ``train_model``.
        optimizer: Unused; kept so the signature matches ``train_model``.

    Returns:
        The loss averaged over all samples in ``dataloader``.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    sample_count = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for img_mask in tqdm(dataloader):
            img = img_mask[0].float().to(CFG.device)
            mask = img_mask[1].float().to(CFG.device)
            loss = criterion(model(img), mask)

            # Same per-sample weighting as in train_model, so both losses are
            # reported on the same scale.
            batch_len = img.size(0)
            loss_sum += loss.item() * batch_len
            sample_count += batch_len

    if sample_count == 0:
        raise ValueError("dataloader yielded no samples")
    return loss_sum / sample_count

In [None]:
# Build the segmentation network, then the loss, metric and optimizer objects.
model = smp.UnetPlusPlus(
    encoder_name="timm-mobilenetv3_large_100",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,
    decoder_attention_type="scse",
)
# The training and validation functions move every batch to CFG.device, so the
# model's parameters must live on that device too; otherwise the first forward
# pass fails with a device mismatch.
model.to(CFG.device)

# NOTE(review): `loss` and `metrics` are not used by the training cells visible
# in this notebook; they are kept in case other cells rely on them.
loss = smp.utils.losses.DiceLoss()

criterion = nn.BCEWithLogitsLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5),]

# A single optimizer: the earlier duplicate Adam instance was overwritten
# immediately and never used.
# NOTE(review): the learning rate is hard-coded to 1e-4 while CFG.learning_rate
# is 5e-4 - confirm which value is intended.
optimizer = optim.Adam(model.parameters(), lr = 1e-4)

In [None]:
# Loss histories. Each starts with a large sentinel so the first epoch always
# counts as an improvement; an entry is appended only when validation improves.
train_loss_lst = [999]
val_loss_lst = [999]

# NOTE(review): the loop runs 100 epochs regardless of CFG.epochs (4), and the
# checkpoint name says 256x256 while the transforms resize to 224x224 - confirm
# both are intended.
for epoch in range(100):
    train_loss = train_model(
        model=model,
        dataloader=train_dataloader,
        criterion=criterion,
        optimizer=optimizer,
    )
    val_loss = val_model(
        model=model,
        dataloader=val_dataloader,
        criterion=criterion,
        optimizer=optimizer,
    )
    print(val_loss)
    print ("Train Loss: ", train_loss_lst)
    print("Val Loss: ", val_loss_lst)

    # Nothing to record unless this epoch beat the best validation loss so far.
    if not val_loss < val_loss_lst[-1]:
        continue

    train_loss_lst.append(train_loss)
    val_loss_lst.append(val_loss)

    print("MODEL IMPROVED! ;)")
    torch.save(model.state_dict(), f"/content/drive/MyDrive/seg_unet++_mobinetv3_size256x256_adam_{val_loss:.4f}.pth")