In [None]:
# Mount Google Drive into the Colab runtime at /content/drive; the dataset
# paths used later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')
# Shell escape: print the Python version of the runtime.
!python --version


In [None]:
 # Disabled: copy the dataset folder from Drive to the runtime's local disk.
 #!cp -r '/content/drive/MyDrive/SB' '/content/SB'
 # Shell escape: run nvidia-smi to report the GPU attached to this runtime.
 !nvidia-smi

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
from torchvision import transforms


In [None]:
# ---- Run configuration -------------------------------------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs end to end on a CPU-only runtime.
device = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 2
BATCH_SIZE = 8
NUM_WORKERS = 2
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
TRAIN_IMAGE_PATH = '/content/drive/MyDrive/SB/ears/train/'
TRAIN_MASK_PATH = '/content/drive/MyDrive/SB/ears/annotations/segmentation/train/'

TEST_IMAGE_PATH = '/content/drive/MyDrive/SB/ears/test/'
TEST_MASK_PATH = '/content/drive/MyDrive/SB/ears/annotations/segmentation/test/'

# ---- Input pipelines ---------------------------------------------------------
# ToTensor has to come first: Normalize only operates on tensors, while the
# dataset hands the pipeline numpy arrays read with OpenCV.
transform_img = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Masks are label maps, not images: they are converted to tensors but NOT
# mean/std normalised, because the evaluation code compares them directly
# against thresholded 0/1 predictions.
transform_eval = transforms.Compose(
    [transforms.ToTensor()])


In [None]:
class DoubleConv(nn.Module):
  """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

  Both convolutions use stride 1 and padding 1, so the spatial size of the
  input is preserved; only the channel count changes (to ``out_channels``).
  The convolutions carry no bias term.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    stages = []
    # First stage maps in_channels -> out_channels, second keeps out_channels.
    for stage_in in (in_channels, out_channels):
      stages.extend([
          nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1,
                    padding=1, bias=False),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=True),
      ])
    self.conv = nn.Sequential(*stages)

  def forward(self, x):
    """Apply both conv stages to ``x`` and return the result."""
    out = self.conv(x)
    return out


class UNET(nn.Module):
  """U-Net style encoder/decoder built from DoubleConv blocks.

  The encoder applies one DoubleConv per entry of ``features`` and pools
  after each; the decoder mirrors it with a transposed convolution followed
  by a DoubleConv on the concatenation with the matching encoder output.
  A final 1x1 convolution maps to ``out_channels``. No activation is applied
  to the output, so the returned values are raw logits.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels of the output tensor.
    features: channel widths of the encoder stages, shallowest first.
  """

  # The default is a tuple rather than a list so it is not a shared mutable
  # default argument; callers may still pass a list.
  def __init__(self, in_channels=3, out_channels=1,
               features=(64, 128, 256, 512)):
    super(UNET, self).__init__()
    features = list(features)

    self.ups = nn.ModuleList()
    self.downs = nn.ModuleList()
    # NOTE(review): pooling and up-convolution use a factor of 3 rather than
    # the factor of 2 of the original U-Net; kept as is - confirm intended.
    self.pool = nn.MaxPool2d(kernel_size=3, stride=3)

    # Down part of UNET: each stage widens the channels to `feature`.
    for feature in features:
      self.downs.append(DoubleConv(in_channels, feature))
      in_channels = feature

    # Up part of UNET: `ups` alternates [upsample, DoubleConv, upsample, ...].
    for feature in reversed(features):
      self.ups.append(
          nn.ConvTranspose2d(
              feature*2,
              feature,
              kernel_size=3,
              stride=3
          )
      )
      # feature*2 input channels: upsampled map concatenated with the skip.
      self.ups.append(DoubleConv(feature*2, feature))

    # Bottom layer between encoder and decoder.
    self.bottleneck = DoubleConv(features[-1], features[-1]*2)
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

  def forward(self, x):
    """Run the network on ``x`` and return the per-pixel logits.

    The output has the same spatial size as the input, because every
    upsampled map is resized to its skip connection's size when they differ.
    """
    skip_connections = []

    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)
    # Deepest skip connection first, matching the decoder order.
    skip_connections = skip_connections[::-1]

    # Two modules per decoder stage: up (ConvTranspose) and DoubleConv.
    for idx in range(0, len(self.ups), 2):
      x = self.ups[idx](x)
      # Divided by 2 because the loop advances two modules at a time.
      skip_connection = skip_connections[idx//2]

      # Pooling floors the size, so the upsampled map can be smaller than
      # the skip connection; resize to the skip's spatial dims (H, W).
      if x.shape != skip_connection.shape:
        x = transforms.functional.resize(x, size=skip_connection.shape[2:])

      concat_skip = torch.cat((skip_connection, x), dim=1)
      x = self.ups[idx+1](concat_skip)

    return self.final_conv(x)


# Smoke test: push a random single-channel batch through a fresh UNET and
# print the output shape followed by the input shape.
x = torch.randn(3, 1, 480, 360)
model = UNET(in_channels=1, out_channels=1)
preds = model(x)
print(preds.shape, x.shape, sep="\n")



In [None]:
import os
import cv2
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class EarDataset(Dataset):
  """Image / segmentation-mask pairs read from two parallel directories.

  A mask is looked up under ``mask_dir`` using the same file name as the
  image in ``image_dir``.

  Args:
    image_dir: directory containing the input images.
    mask_dir: directory containing the masks (same file names as the images).
    transform_img: optional callable applied to the image array.
    transform_eval: optional callable applied to the mask array.
  """

  def __init__(self, image_dir, mask_dir, transform_img=None, transform_eval=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform_img = transform_img
    self.transform_eval = transform_eval
    # os.listdir order is arbitrary; sort so sample indices are reproducible.
    self.images = sorted(os.listdir(image_dir))

  def __len__(self):
    """Number of image files found in ``image_dir``."""
    return len(self.images)

  def __getitem__(self, item):
    """Return the ``(image, mask)`` pair at index ``item``.

    Raises:
      OSError: if OpenCV cannot read the image or its mask.
    """
    img_path = os.path.join(self.image_dir, self.images[item])
    mask_path = os.path.join(self.mask_dir, self.images[item])
    image = cv2.imread(img_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
      raise OSError(f"Could not read image file: {img_path}")
    if mask is None:
      raise OSError(f"Could not read mask file: {mask_path}")

    # Work in float32 so a later ToTensor does not rescale the labels by
    # 1/255 (it only rescales uint8 input); foreground 255 becomes 1.0.
    mask = mask.astype(np.float32)
    mask[mask == 255.0] = 1.0

    # Transforms are optional; without one the raw array is returned.
    if self.transform_img is not None:
      image = self.transform_img(image)
    if self.transform_eval is not None:
      mask = self.transform_eval(mask)
    return image, mask

In [None]:
from tqdm import tqdm

def train_model(loader, model, optimizer, loss_fn, scaler, device):
  """Train ``model`` for one pass over ``loader``.

  Args:
    loader: iterable yielding ``(inputs, targets)`` batches.
    model: network to optimise.
    optimizer: optimizer holding ``model``'s parameters.
    loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar loss.
    scaler: accepted for interface compatibility; not used, since this loop
      does not run mixed-precision training.
    device: device the batches are moved to.
  """
  # Make sure batch-norm layers are in training mode, even if an earlier
  # evaluation pass left the model in eval mode.
  model.train()
  loop = tqdm(loader)

  for inputs, targets in loop:
    inputs = inputs.to(device)
    targets = targets.float().to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize; the loss compares the model's
    # predictions (not the raw inputs) with the targets.
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()

    # update tqdm loop
    loop.set_postfix(loss=loss.item())

def check_accuracy(test_loader, model, device):
  """Compute and print the pixel accuracy of ``model`` over ``test_loader``.

  Model outputs are passed through a sigmoid and thresholded at 0.5; a pixel
  counts as correct when the thresholded prediction equals the target.

  Args:
    test_loader: iterable yielding ``(inputs, targets)`` batches.
    model: network to evaluate.
    device: device the batches are moved to.

  Returns:
    The fraction of correctly predicted pixels, or 0.0 if the loader yields
    no pixels.
  """
  correct = 0
  num_pixels = 0
  # Remember the current mode so evaluation does not leave the model in
  # eval mode for a subsequent training epoch.
  was_training = model.training
  model.eval()

  try:
    with torch.no_grad():
      for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)
        preds = torch.sigmoid(model(x))
        preds = (preds > 0.5).float()
        correct += (preds == y).sum().item()
        num_pixels += torch.numel(preds)
  finally:
    model.train(was_training)

  # Guard against an empty loader instead of dividing by zero.
  accuracy = correct / num_pixels if num_pixels else 0.0
  print(f"Got {correct}/{num_pixels} with acc {accuracy:.2f}")
  return accuracy


In [None]:
from torch.utils.data import DataLoader
def get_loaders(train_img_path,
                train_mask_path,
                test_img_path,
                test_mask_path,
                batch_size,
                transform_img,
                transform_eval,
                num_workers):
  """Build the training and test DataLoaders over EarDataset.

  Args:
    train_img_path: directory with the training images.
    train_mask_path: directory with the training masks.
    test_img_path: directory with the test images.
    test_mask_path: directory with the test masks.
    batch_size: batch size used by both loaders.
    transform_img: transform applied to images in both datasets.
    transform_eval: transform applied to masks in both datasets.
    num_workers: number of worker processes used by both loaders.

  Returns:
    A ``(train_loader, test_loader)`` tuple.
  """
  train_ds = EarDataset(
      image_dir=train_img_path,
      mask_dir=train_mask_path,
      transform_img=transform_img,
      transform_eval=transform_eval
  )

  # Reshuffle the training samples every epoch so SGD does not see the
  # batches in the same fixed order each time.
  train_loader = DataLoader(
      train_ds,
      batch_size=batch_size,
      num_workers=num_workers,
      shuffle=True,
  )

  test_ds = EarDataset(
      image_dir=test_img_path,
      mask_dir=test_mask_path,
      transform_img=transform_img,
      transform_eval=transform_eval
  )

  # Evaluation keeps the dataset order (no shuffling).
  test_loader = DataLoader(
      test_ds,
      batch_size=batch_size,
      num_workers=num_workers,
  )
  return train_loader, test_loader

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from torch import optim
# Build the model, loss and optimizer.
model = UNET(in_channels=3, out_channels=1).to(device)
# The model emits a single logit per pixel (out_channels=1), i.e. a binary
# segmentation. CrossEntropyLoss takes a softmax over the channel dimension,
# which is constant for a single channel, so it cannot train this model.
# BCEWithLogitsLoss applies the sigmoid internally, matching the
# sigmoid + 0.5 threshold used in check_accuracy.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

train_loader, test_loader = get_loaders(
    TRAIN_IMAGE_PATH,
    TRAIN_MASK_PATH,
    TEST_IMAGE_PATH,
    TEST_MASK_PATH,
    BATCH_SIZE,
    transform_img,
    transform_eval,
    NUM_WORKERS
)
# Passed through to train_model, which takes a scaler argument.
scaler = torch.cuda.amp.GradScaler()

for epoch in range(EPOCHS):
  train_model(train_loader, model, optimizer, loss_fn, scaler, device)

  # Snapshot of model and optimizer state; built every epoch but only
  # persisted once the save_checkpoint call below is enabled.
  checkpoint = {
      "state_dict": model.state_dict(),
      "optimizer": optimizer.state_dict(),
  }
  #save_checkpoint(checkpoint)

  # Report pixel accuracy on the test set after each epoch.
  check_accuracy(test_loader, model, device)

  #save_prediction_as_imgs(train_loader, model)

