In [None]:
# UNet Architecture in PyTorch from scratch

In [None]:
# Colab-only setup: mount Google Drive at /content/drive so the dataset and
# checkpoint paths under ./drive/MyDrive used later in this notebook resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Query the NVIDIA driver for attached GPUs. The captured output of this cell
# reports "command not found" - presumably no GPU runtime was attached when it
# was last run (TODO confirm the runtime type before training).
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch.utils.data.dataset import Dataset
from torchvision.transforms import transforms
import os
from tqdm import tqdm
import time
from matplotlib import pyplot as plt

In [None]:
class DoubleConv(nn.Module):
  """Two consecutive 3x3 convolutions, each followed by a ReLU.

  Args:
    in_channels: Number of channels of the incoming feature map.
    out_channels: Number of channels produced by both convolutions.
    padding: Zero padding applied on every side by each convolution.
      The default of 1 keeps height and width unchanged, which is what the
      rest of this notebook needs: the decoder concatenates encoder outputs
      with upsampled tensors and the loss compares the logits with a mask of
      the input size. Pass 0 for unpadded convolutions; every call then
      trims 4 pixels from both height and width.
  """
  def __init__(self,in_channels,out_channels,padding=1):
    super().__init__()
    # Layer order (and therefore the state_dict keys conv.0 / conv.2) is
    # unchanged; only the padding of the two convolutions is configurable.
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=padding),
        nn.ReLU(),
        nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=padding),
        nn.ReLU()
    )

  def forward(self,x):
    """Apply conv -> ReLU -> conv -> ReLU to ``x`` and return the result."""
    return self.conv(x)

In [None]:
class DownSample(nn.Module):
  """Encoder stage: a DoubleConv followed by 2x2 max pooling.

  Args:
    in_channels: Number of channels of the incoming feature map.
    out_channels: Number of channels produced by the DoubleConv.
  """
  def __init__(self,in_channels,out_channels):
    super().__init__()
    self.conv = DoubleConv(in_channels,out_channels)
    # The keyword is `stride`; the previous `strid` made MaxPool2d raise a
    # TypeError as soon as the module was constructed.
    self.pool = nn.MaxPool2d(kernel_size=2,stride=2)

  def forward(self,x):
    """Return ``(down, p)``.

    ``down`` is the DoubleConv output (used later as the skip connection) and
    ``p`` is the same tensor after max pooling (input of the next stage).
    """
    down = self.conv(x)
    p = self.pool(down)
    return down,p

In [None]:
class UpSample(nn.Module):
  """Decoder stage: transposed-conv upsampling, skip concatenation, DoubleConv.

  Args:
    in_channels: Channels of the tensor coming from the deeper stage. The
      transposed convolution halves this, and the skip tensor is expected to
      supply the other half so the concatenation has ``in_channels`` channels.
    out_channels: Channels produced by the final DoubleConv.
  """
  def __init__(self,in_channels,out_channels):
    super().__init__()
    self.up = nn.ConvTranspose2d(in_channels,in_channels // 2,kernel_size=2,stride=2)
    self.conv = DoubleConv(in_channels,out_channels)

  def forward(self,x1,x2):
    """Upsample ``x1``, align it with the skip tensor ``x2`` and fuse them.

    Args:
      x1: Feature map from the deeper stage.
      x2: Skip-connection feature map from the matching encoder stage.

    Returns:
      The DoubleConv output of the channel-wise concatenation ``[x1, x2]``.
    """
    x1 = self.up(x1)

    # Pooling floors odd sizes (e.g. 143 -> 71 -> 142 after upsampling) and
    # unpadded convolutions shrink maps, so x1 can be a few pixels off from
    # x2. torch.cat needs identical height/width, so pad x1 symmetrically up
    # to x2's size; a negative amount crops instead.
    diff_h = x2.shape[-2] - x1.shape[-2]
    diff_w = x2.shape[-1] - x1.shape[-1]
    if diff_h != 0 or diff_w != 0:
      x1 = nn.functional.pad(
          x1,
          [diff_w // 2,diff_w - diff_w // 2,diff_h // 2,diff_h - diff_h // 2]
      )

    x = torch.cat([x1,x2],1)

    return self.conv(x)

In [None]:
class UNet(nn.Module):
  """U-Net: four encoder stages, a bottleneck, four decoder stages and a 1x1 head.

  Args:
    in_channels: Number of channels of the input image.
    num_classes: Number of output channels (one logit map per class).
  """
  def __init__(self,in_channels,num_classes):
    super().__init__()
    self.down_conv1 = DownSample(in_channels,64)
    self.down_conv2 = DownSample(64,128)
    self.down_conv3 = DownSample(128,256)
    self.down_conv4 = DownSample(256,512)

    # The bottleneck only convolves; it must return a single tensor. Using
    # DownSample here returned a (features, pooled) tuple that the first
    # decoder stage cannot upsample.
    self.bottle_neck = DoubleConv(512,1024)

    self.up_conv1 = UpSample(1024,512)
    self.up_conv2 = UpSample(512,256)
    self.up_conv3 = UpSample(256,128)
    self.up_conv4 = UpSample(128,64)

    # 1x1 convolution mapping the 64 decoder channels to per-class logits.
    self.out = nn.Conv2d(64,out_channels=num_classes, kernel_size=1)

  def forward(self,x):
    """Run the encoder, bottleneck and decoder; return the raw logits."""
    # x1..x4 are the pre-pooling skip tensors, s1..s4 the pooled tensors.
    x1,s1 = self.down_conv1(x)
    x2,s2 = self.down_conv2(s1)
    x3,s3 = self.down_conv3(s2)
    x4,s4 = self.down_conv4(s3)

    x5 = self.bottle_neck(s4)

    # Each decoder stage fuses with the encoder output of the same depth.
    y1 = self.up_conv1(x5,x4)
    y2 = self.up_conv2(y1,x3)
    y3 = self.up_conv3(y2,x2)  # was self.up_conv, which does not exist
    y4 = self.up_conv4(y3,x1)

    return self.out(y4)

In [None]:
class CarvanaDataset(Dataset):
  """Image/mask pairs read from sub-folders of ``root_path``.

  Training data is read from ``train/`` and ``train_masks/``; with
  ``test=True`` it is read from ``test/`` and ``test_masks/``. Images and
  masks are paired by position after sorting the file paths of each folder.

  Args:
    root_path: Directory that contains the four sub-folders above.
    test: Select the test split instead of the training split.
  """
  def __init__(self,root_path,test=False):
    self.root_path = root_path
    self.images,self.masks = self._collect_paths(root_path,test)

    # Both the image and its mask go through the same resize + tensor conversion.
    self.transform = transforms.Compose([
        transforms.Resize((572,572)),
        transforms.ToTensor()
    ])

  @staticmethod
  def _collect_paths(root_path,test=False):
    """Return ``(image_paths, mask_paths)`` for the requested split, sorted.

    The directory that is listed is also the directory used to build each
    path (previously the training images were listed from ``train/`` but
    prefixed with ``/test/``, producing paths to the wrong files).
    """
    split = "test" if test else "train"
    image_dir = os.path.join(root_path,split)
    mask_dir = os.path.join(root_path,split + "_masks")
    images = sorted(os.path.join(image_dir,filename) for filename in os.listdir(image_dir))
    masks = sorted(os.path.join(mask_dir,filename) for filename in os.listdir(mask_dir))
    return images,masks

  def __getitem__(self,idx):
    """Return ``(image, mask)`` for index ``idx`` as transformed tensors."""
    image = Image.open(self.images[idx]).convert('RGB')
    mask = Image.open(self.masks[idx]).convert('L')  # single-channel mask
    return self.transform(image),self.transform(mask)

  def __len__(self):
    """Number of images in the selected split."""
    return len(self.images)

In [None]:
def train(model,train_loader,optimizer,criterion,start_epoch,num_epochs,MODEL_SAVE_PATH,device,val_loader=None):
  """Train ``model`` and checkpoint it whenever the validation loss improves.

  Args:
    model: Network to optimise; it is switched between train and eval mode.
    train_loader: Iterable of ``(image, mask)`` batches used for optimisation.
    optimizer: Optimizer already bound to ``model``'s parameters.
    criterion: Loss called as ``criterion(output, mask)``.
    start_epoch: Zero-based index of the first epoch to run.
    num_epochs: Exclusive upper bound of the epoch range.
    MODEL_SAVE_PATH: Prefix to which ``"model.pth"`` is appended to form the
      checkpoint file name (a directory path ending in a separator).
    device: Device the batches are moved to.
    val_loader: Iterable of ``(image, mask)`` validation batches. When omitted
      the notebook-level ``val_loader`` variable is used, which is what this
      function implicitly relied on before the parameter existed.

  Raises:
    ValueError: If no validation loader is passed and none is defined globally.
  """
  if val_loader is None:
    val_loader = globals().get("val_loader")
    if val_loader is None:
      raise ValueError("train() needs a validation loader: pass val_loader=... or define a global `val_loader`.")

  checkpoint_file = MODEL_SAVE_PATH + "model.pth"
  checkpoint_dir = os.path.dirname(checkpoint_file)
  if checkpoint_dir:
    # torch.save does not create missing directories.
    os.makedirs(checkpoint_dir,exist_ok=True)

  # Must live outside the epoch loop: resetting it every epoch made every
  # epoch count as an improvement and overwrite the best checkpoint.
  # NOTE(review): the best loss of a previous run is not restored on resume.
  best_loss = float('inf')

  for epoch in tqdm(range(start_epoch,num_epochs)):
    model.train()
    train_running_loss = 0.0
    for idx,img_mask in enumerate(tqdm(train_loader)):
      img = img_mask[0].float().to(device)
      mask = img_mask[1].float().to(device)
      optimizer.zero_grad()

      output = model(img)

      loss = criterion(output,mask)
      train_running_loss += loss.item()
      loss.backward()

      optimizer.step()
    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    train_loss = train_running_loss / max(len(train_loader),1)

    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
      for idx,img_mask in enumerate(tqdm(val_loader)):
        img = img_mask[0].float().to(device)
        mask = img_mask[1].float().to(device)

        output = model(img)

        loss = criterion(output,mask)

        val_running_loss += loss.item()
    val_loss = val_running_loss / max(len(val_loader),1)

    print("-" * 30)
    print(f"EPOCH:{epoch + 1}, Train Loss:{train_loss:.4f}, Val Loss:{val_loss:.4f}")
    print("-" * 30)

    if(val_loss < best_loss):
      best_loss = val_loss
      torch.save({
          # Number of completed epochs, i.e. the zero-based index of the next
          # epoch to run when resuming with range(start_epoch, num_epochs).
          # (The literal string "epoch" used to be stored here.)
          "epoch":epoch + 1,
          "model_state_dict":model.state_dict(),
          "optimizer_state_dict":optimizer.state_dict(),
          "loss":val_loss
      },checkpoint_file)
      print(f"Model saved at EPOCH {epoch+1} with loss {best_loss:.4f}")
    else:
      print(f"Skipping save at epoch {epoch+1} with loss did not improve")



In [None]:
def load_checkpoint(model,optimizer,CHECKPOINT_PATH,map_location="cpu"):
  """Restore model and optimizer state from a checkpoint and return the resume epoch.

  Args:
    model: Module whose parameters are overwritten with the saved weights.
    optimizer: Optimizer whose state is overwritten with the saved state.
    CHECKPOINT_PATH: Path of the checkpoint file. A directory is also accepted,
      in which case ``model.pth`` inside it is loaded.
    map_location: Forwarded to ``torch.load``. Loading onto the CPU first lets
      a checkpoint written on a GPU machine be resumed without one; the
      ``load_state_dict`` calls copy the values into the existing tensors.

  Returns:
    The zero-based epoch index at which training should continue.
  """
  if os.path.isdir(CHECKPOINT_PATH):
    CHECKPOINT_PATH = os.path.join(CHECKPOINT_PATH,"model.pth")

  checkpoint = torch.load(CHECKPOINT_PATH,map_location=map_location)
  # The checkpoint is a dict: index it (calling it raised a TypeError).
  model.load_state_dict(checkpoint["model_state_dict"])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

  start_epoch = checkpoint['epoch']
  if not isinstance(start_epoch,int):
    # Checkpoints that stored a non-integer placeholder under "epoch" cannot
    # tell us where to resume, so start counting from the first epoch.
    start_epoch = 0

  print(f"Resuming from epoch {start_epoch}, Best Loss was :{checkpoint['loss']:.4f}")

  return start_epoch

In [None]:
def show_image_grid(model,root_path,device):
  """Plot every test image, its ground-truth mask and the predicted mask.

  The figure has three rows (images, original masks, predicted masks) and one
  column per sample of ``CarvanaDataset(root_path, test=True)``.

  Args:
    model: Segmentation network returning logits for a batched image tensor.
    root_path: Dataset root passed to ``CarvanaDataset``.
    device: Device on which inference is run.
  """
  image_dataset = CarvanaDataset(root_path,test=True)
  images = []
  original_masks = []
  predicted_masks = []

  # Inference only: no gradients needed; restore the caller's mode afterwards.
  was_training = model.training
  model.eval()
  with torch.no_grad():
    for img,original_mask in image_dataset:
      batch = img.float().to(device).unsqueeze(0)

      logits = model(batch).squeeze(0).cpu()
      # The model is trained with BCEWithLogitsLoss, so a logit above 0 is a
      # foreground probability above 0.5; binarise instead of clamping logits.
      pred_mask = (logits > 0).float().permute(1,2,0)

      # Channels-last layout for matplotlib.
      images.append(batch.squeeze(0).cpu().permute(1,2,0))
      original_masks.append(original_mask.cpu().permute(1,2,0))
      predicted_masks.append(pred_mask)
  model.train(was_training)

  # Row-major order: all images, then all original masks, then all predictions.
  panels = images + original_masks + predicted_masks

  fig = plt.figure()
  for position,panel in enumerate(panels,start=1):
    # Figure.add_subplot is the correct name (add_subplots does not exist).
    fig.add_subplot(3,len(image_dataset),position)
    plt.imshow(panel,cmap="gray")
  plt.show()

In [None]:
def single_image_inference(model,image_path,device):
  """Run the model on one image file and show the image next to its predicted mask.

  Args:
    model: Segmentation network returning logits for a batched image tensor.
    image_path: Path of the image file to segment.
    device: Device on which inference is run.
  """
  transform = transforms.Compose([
    transforms.Resize((572,572)),
    transforms.ToTensor()
  ])

  # Convert to RGB like CarvanaDataset does, so grayscale or RGBA files also
  # produce the three channels the model is built for.
  img = transform(Image.open(image_path).convert('RGB')).float().to(device)
  img = img.unsqueeze(0)

  # Inference only: no gradients needed; restore the caller's mode afterwards.
  was_training = model.training
  model.eval()
  with torch.no_grad():
    logits = model(img)
  model.train(was_training)

  # Channels-last layout for matplotlib.
  img = img.squeeze(0).cpu().permute(1,2,0)

  # The model is trained with BCEWithLogitsLoss, so a logit above 0 is a
  # foreground probability above 0.5; binarise instead of clamping logits.
  pred_mask = (logits.squeeze(0).cpu() > 0).float().permute(1,2,0)

  fig = plt.figure()
  for position,panel in enumerate((img,pred_mask),start=1):
    fig.add_subplot(1,2,position)
    plt.imshow(panel,cmap="gray")
  plt.show()

In [None]:
# ---- Hyper-parameters and paths -------------------------------------------
LEARNING_RATE = 3e-4 #0.0003
BATCH_SIZE = 32
EPOCHS = 2
DATA_PATH = "./drive/MyDrive/unet-datasets/"
MODEL_SAVE_PATH ="./drive/MyDrive/unet-datasets/model-checkpoints/"

SINGLE_PATH_IMAGE = "./drive/MyDrive/unet-datasets/29bb3ece3180_11.jpg"

# Call is_available(): the bare function object is always truthy. The former
# chained assignment also overwrote `torch.device` with a string, and
# `torch.cude` was a typo.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Data -------------------------------------------------------------------
full_dataset = CarvanaDataset(DATA_PATH)

# Seeded generator so the train/validation split is reproducible.
g = torch.Generator().manual_seed(42)
# Derive the validation size from the remainder: int(0.8*n) + int(0.2*n) can
# be smaller than n, and random_split requires the lengths to sum to n.
train_size = int(0.8*len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset,val_dataset = random_split(full_dataset,[train_size,val_size],generator=g)

train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
# Validation only averages the loss, so shuffling is unnecessary.
val_loader = DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=False)

# ---- Model, optimizer, loss -------------------------------------------------
model = UNet(in_channels=3,num_classes=1).to(device)
optimizer = optim.AdamW(model.parameters(),lr=LEARNING_RATE)
# Expects raw logits from the model; the sigmoid is applied inside the loss.
criterion = nn.BCEWithLogitsLoss()

In [None]:

# Resume from the saved checkpoint when one exists, otherwise start at epoch 0.
CHECKPOINT_FILE = MODEL_SAVE_PATH + "model.pth"
if os.path.exists(CHECKPOINT_FILE):
  # Pass the checkpoint file itself; MODEL_SAVE_PATH alone is only its directory.
  start_epoch = load_checkpoint(model,optimizer,CHECKPOINT_FILE)
else:
  start_epoch = 0

train(model,train_loader,optimizer,criterion,start_epoch,EPOCHS,MODEL_SAVE_PATH,device)