# Y-net with new convolution block for depth estimation

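### This notebook contains code to train and test Y-net, a depth estimation model with separate encoders for the RGB image and its image gradient, extended with a new multi-scale (1x1/3x3/5x5) convolutional block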
Done By:
Chandravaran Kunjeti
Saikumar Dande

In [None]:
!pip install albumentations==0.4.6

In [None]:
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
# force_remount=True refreshes the mount even if Drive is already mounted in this runtime.
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [None]:
# Work from the project folder on Drive so the relative paths used below (Datasets/..., Models/...)
# resolve; the local DataLoader / metrics modules imported next presumably live here too.
%cd /content/drive/MyDrive/Neural\ Network\ Project

/content/drive/MyDrive/Neural Network Project


In [None]:
from DataLoader import TransposeDepthInput, NYUDataset, save_checkpoint, get_loaders, save_predictions_as_imgs
from metrics import ScaleInvariantLoss, threeshold_percentage, rmse_linear, rmse_log, abs_relative_difference, squared_relative_difference

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class ResidualBlock(nn.Module):
    """Two stacked Conv -> BatchNorm -> ReLU stages with 'same' padding.

    Despite the name, no identity/skip addition is performed: the output is just
    the result of the two conv stages. Spatial size is preserved and the channel
    count becomes ``out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
        kernel_size: square kernel size used by both convolutions (default 3).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ResidualBlock, self).__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            # Conv bias is redundant here because BatchNorm applies its own shift.
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size, padding='same', bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Same attribute name and Sequential indices, so saved state_dicts still load.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class DoubleConv(nn.Module):
    """Multi-scale conv block: parallel 1x1 / 3x3 / 5x5 convs fused by a 3x3 conv.

    Each parallel branch maps ``in_channels -> out_channels``. Their outputs are
    concatenated (``3 * out_channels`` channels), batch-normalised and rectified,
    then a 3x3 conv squeezes them back to ``out_channels``, followed by
    BatchNorm + ReLU. All convs use 'same' padding, so spatial size is unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        fused_channels = 3 * out_channels

        # Parallel branches with growing receptive field. They are created in this
        # order so parameter initialisation and state_dict keys match checkpoints.
        self.conv1_1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding='same', bias=False)
        self.conv3_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same', bias=False)
        self.conv5_1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding='same', bias=False)
        # Fusion conv applied to the concatenated branch outputs.
        self.conv3_2 = nn.Conv2d(fused_channels, out_channels, kernel_size=3, padding='same', bias=False)

        self.batchNorm1 = nn.BatchNorm2d(fused_channels)
        self.batchNorm2 = nn.BatchNorm2d(out_channels)

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        branches = (self.conv1_1, self.conv3_1, self.conv5_1)
        multi_scale = torch.cat([branch(x) for branch in branches], dim=1)
        fused = self.conv3_2(self.relu1(self.batchNorm1(multi_scale)))
        return self.relu2(self.batchNorm2(fused))

class DownConv(nn.Module):
    """Encoder branch: one DoubleConv + 2x2 max-pool per entry in ``features``.

    At each level a DoubleConv sets the channel width. A kernel-5 ResidualBlock then
    turns that level's activations into a skip tensor. After ``forward`` the skip
    tensors are left on ``self.skip_connections``, deepest level first (the order
    UpConv indexes them in). ``forward`` returns the pooled output of the last level.
    """

    def __init__(self, in_channels, features=[64, 128, 256, 512]):
        super(DownConv, self).__init__()
        self.downs = nn.ModuleList()
        self.residualBlocks = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET: each level consumes the previous level's width.
        channels_in = in_channels
        for width in features:
            self.downs.append(DoubleConv(channels_in, width))
            self.residualBlocks.append(ResidualBlock(width, width, kernel_size=5))
            channels_in = width

    def forward(self, x):
        skips = []
        for down, refine in zip(self.downs, self.residualBlocks):
            x = down(x)
            # Take the skip before pooling so it keeps this level's resolution;
            # prepend so the deepest level ends up first.
            skips.insert(0, refine(x))
            x = self.pool(x)

        # Stored on the module for the decoder to read after this call.
        self.skip_connections = skips
        return x

class UpConv(nn.Module):
    """Decoder: upsample, fuse with both encoders' skips, halve the width.

    ``self.ups`` alternates [ConvTranspose2d, DoubleConv] per level. Skip tensors are
    read from ``downConvs1.skip_connections`` and ``downConvs2.skip_connections``, so
    both encoders must have run ``forward`` on the current batch before this module.

    Note: ``out_channels`` is accepted but unused here (YNET applies its own
    final_conv). Assigning the encoders as attributes also registers them as
    submodules, so they appear under ``upConvs.downConvs*`` in state_dict as well.
    """

    def __init__(self, out_channels, downConvs1, downConvs2, features=[512, 256, 128, 64]):
        super(UpConv, self).__init__()
        self.ups = nn.ModuleList()
        self.downConvs1 = downConvs1
        self.downConvs2 = downConvs2

        # Up part of UNET
        for width in features:
            upsample = nn.ConvTranspose2d(width, width, kernel_size=2, stride=2)
            # Input = skip from encoder 1 + skip from encoder 2 + upsampled x.
            fuse = DoubleConv(width * 3, width // 2)
            self.ups.extend([upsample, fuse])

    def forward(self, x):
        upsamplers = self.ups[0::2]
        fusers = self.ups[1::2]
        for level, (upsample, fuse) in enumerate(zip(upsamplers, fusers)):
            x = upsample(x)
            skip1 = self.downConvs1.skip_connections[level]
            skip2 = self.downConvs2.skip_connections[level]

            # Pooling odd sizes drops a row/column; resize back to the skip's size.
            if x.shape != skip1.shape:
                x = TF.resize(x, size=skip1.shape[2:])

            x = fuse(torch.cat((skip1, skip2, x), dim=1))
        return x


# In this network, the RGB image and a separate 2-channel gradient map (x/y Sobel
# responses, as computed in Save_Predictions below) go through separate encoders.
class YNET(nn.Module):
    """Two-encoder ("Y"-shaped) U-Net producing a dense ``out_channels`` map.

    Args:
        in_channels1: channels of the first input (``image``), default 3.
        in_channels2: channels of the second input (``gradient``), default 2.
        out_channels: channels of the output map, default 1.
        features: encoder widths, shallowest level first.

    ``forward(image, gradient)`` needs both inputs at the same spatial size (their
    deepest features are concatenated). It returns a map of that spatial size;
    test() checks this for 120x160.
    """

    def __init__(
            self, in_channels1=3, in_channels2=2, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(YNET, self).__init__()
        # Unused, but kept so the module tree / printed summary stays the same.
        self.ups = nn.ModuleList()
        self.downConvs1 = DownConv(in_channels1, features)
        self.downConvs2 = DownConv(in_channels2, features)
        deepest = features[-1]
        # The bottleneck fuses both encoders' deepest outputs (2 * deepest channels).
        self.bottleneck = DoubleConv(deepest * 2, deepest)
        self.upConvs = UpConv(out_channels, self.downConvs1, self.downConvs2, reversed(features))
        # The last decoder level emits features[0] // 2 channels.
        self.final_conv = nn.Conv2d(features[0] // 2, out_channels, kernel_size=1)

    def forward(self, image, gradient):
        # Both encoders must run first: UpConv reads their stored skip connections.
        encoded = torch.cat((self.downConvs1(image), self.downConvs2(gradient)), dim=1)
        decoded = self.upConvs(self.bottleneck(encoded))
        return self.final_conv(decoded)


In [None]:
def test():
    """Shape smoke test: YNET must map a 120x160 input batch to a 120x160 output."""
    n, h, w = 3, 120, 160
    image = torch.randn((n, 3, h, w))
    gradient = torch.randn((n, 2, h, w))
    net = YNET(in_channels1=3, in_channels2=2, out_channels=1)
    out = net(image, gradient)

    print("Input shape\t:", image.shape)
    print("Gradient shape\t:", gradient.shape)
    print("Output shape\t:", out.shape)

    # Only spatial dims are compared; the channel counts legitimately differ.
    assert out.shape[2:] == image.shape[2:]


test()

Input shape	: torch.Size([3, 3, 120, 160])
Gradient shape	: torch.Size([3, 2, 120, 160])
Output shape	: torch.Size([3, 1, 120, 160])


In [None]:
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image

IMAGE_HEIGHT = 120
IMAGE_WIDTH = 160

# Same preprocessing as main(): resize to the network input size and convert to a CHW
# float tensor. `transforms` is imported in this cell so it does not depend on a later
# cell having been run first (Restart & Run All would otherwise raise NameError).
rgb_data_transforms_1 = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.ToTensor(),
])

def Save_Predictions(model, file_names, image_dir, depth_dir, save_dir):
    """Save input / ground-truth / prediction figures for the given image names.

    For each name in ``file_names``:
      - the RGB image is read from ``image_dir`` and the depth map from ``depth_dir``
        (paths are built by plain concatenation, so both directories must end with a
        separator);
      - the image is preprocessed with ``rgb_data_transforms_1`` and a 2-channel Sobel
        gradient (x, y) is computed from it;
      - a 3-panel figure is written to ``save_dir/<image name>``.

    The model is put in eval mode while predicting. Its previous train/eval mode is
    restored afterwards, even if an error occurs.

    Raises:
        FileNotFoundError: if an image or its depth map cannot be read.
    """
    was_training = model.training
    model.eval()
    try:
        for image_name in file_names:
            image_path = image_dir + image_name
            depth_path = depth_dir + image_name

            # Load the image and depth
            bgr = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
            if bgr is None or depth is None:
                missing = image_path if bgr is None else depth_path
                raise FileNotFoundError(f"Could not read {missing}")

            rgb = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            image = rgb_data_transforms_1(rgb)  # (3, H, W) float tensor

            # Find the gradient (x and y Sobel responses of the grayscale image)
            hwc = np.moveaxis(image.numpy(), [0, 1, 2], [2, 0, 1])
            # NOTE(review): the data is RGB here but converted with BGR2GRAY (R/B weights
            # swapped). Kept as-is so inputs match the training pipeline — confirm against
            # the gradient computation in DataLoader before changing.
            gray = cv2.cvtColor(hwc, cv2.COLOR_BGR2GRAY)
            gx = cv2.Sobel(gray, ddepth=cv2.CV_32F, dx=1, dy=0, ksize=3)
            gy = cv2.Sobel(gray, ddepth=cv2.CV_32F, dx=0, dy=1, ksize=3)
            gradient = torch.from_numpy(np.stack([gx, gy]))

            # Predict the output for a batch of one
            with torch.no_grad():
                predicted = model(image.unsqueeze(0).to(device=DEVICE),
                                  gradient.unsqueeze(0).to(device=DEVICE))

            # HWC image for display, derived from the tensor so it follows
            # IMAGE_HEIGHT / IMAGE_WIDTH instead of a hard-coded 120x160 buffer.
            input_image = image.permute(1, 2, 0).numpy()
            predicted = predicted[0, 0].cpu().numpy()

            fig, axes = plt.subplots(1, 3, figsize=(14, 6))
            panels = [
                ('Input image', input_image, None),
                ('Ground truth', depth, 'gist_gray'),
                ('Ynet + New Block predicted', predicted, 'gist_gray'),
            ]
            for ax, (title, data, cmap) in zip(axes, panels):
                ax.set_title(title)
                ax.imshow(data, cmap=cmap)
            fig.savefig(f'{save_dir}/{image_name}')
            plt.close(fig)
    finally:
        model.train(was_training)

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import math
import torchvision
import torchvision.transforms as transforms

# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 100
NUM_WORKERS = 16
PIN_MEMORY = True
# When True, main() restores weights and loss history from MODEL_LOAD_PATH.
LOAD_MODEL = True
TRAIN_IMG_DIR = "Datasets/Train/images/"
TRAIN_DEPTH_DIR = "Datasets/Train/depths/"
VAL_IMG_DIR = "Datasets/Validation/images/"
VAL_DEPTH_DIR = "Datasets/Validation/depths/"
TEST_IMG_DIR = "Datasets/Test/images/"
TEST_DEPTH_DIR = "Datasets/Test/depths/"

IMAGE_HEIGHT = 120
IMAGE_WIDTH = 160

MODEL_NAME = 'New_Model'
MODEL_SAVE_DIR = "Models/New_Model/checkpoint/"
MODEL_LOAD_PATH = "Models/New_Model/checkpoint/" + MODEL_NAME + "_80.pth.tar"
VALIDATION_IMAGES_SAVE_DIR = "Models/New_Model/validation_outputs/"

TRAIN_SAVE_PATH = "Models/New_Model/predictions/Train/"

# Input cast used as `tensor.type(dtype)` during training/validation. A plain dtype keeps the
# tensor on the device it was already moved to, so this also works when DEVICE == "cpu"
# (the legacy torch.cuda.FloatTensor type forces a CUDA copy and fails on CPU-only runtimes).
dtype = torch.float32

def train_unet(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        loader: iterable of ``(data, gradient, targets)`` batches.
        model: network called as ``model(data, gradient)``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        scaler: unused (placeholder for an AMP GradScaler); kept for call compatibility.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Ensure BatchNorm uses batch statistics whatever mode the caller left the model in.
    model.train()

    train_loss = 0.0
    num_batches = 0
    for data, gradient, targets in loader:
        data = data.to(device=DEVICE).type(dtype)
        gradient = gradient.to(device=DEVICE).type(dtype)
        targets = targets.to(device=DEVICE)

        # forward
        predictions = model(data, gradient)
        loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_unet: loader produced no batches")
    return train_loss / num_batches

def _min_max_normalise(tensor):
    """Shift and scale ``tensor`` into [0, 1] for saving as an image (constant input -> zeros)."""
    shifted = tensor - torch.min(tensor)
    peak = torch.max(shifted)
    return shifted / peak if peak > 0 else shifted


def validate_unet(loader, model, loss_fn, epoch, train_loss, save_folder):
    """Evaluate on ``loader``, print one row of averaged metrics and return the mean loss.

    For each batch:
      - computes the loss and the depth metrics imported from ``metrics``;
      - saves the min-max normalised predictions and targets as
        ``<save_folder>/pred_<batch>.png`` and ``<save_folder>/<batch>.png``.

    Metrics are averaged over batches, not samples. The model's train/eval mode is
    restored on exit.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    validation_loss = 0.0
    delta1_accuracy = delta2_accuracy = delta3_accuracy = 0.0
    rmse_linear_loss = rmse_log_loss = 0.0
    abs_relative_difference_loss = squared_relative_difference_loss = 0.0
    num_batches = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch_idx, (data, gradient, targets) in enumerate(loader):
                data = data.to(device=DEVICE).type(dtype)
                gradient = gradient.to(device=DEVICE).type(dtype)
                targets = targets.to(device=DEVICE)

                predictions = model(data, gradient)
                validation_loss += loss_fn(predictions, targets).item()

                # Error metrics
                delta1_accuracy += threeshold_percentage(predictions, targets, 1.25)
                delta2_accuracy += threeshold_percentage(predictions, targets, 1.25*1.25)
                delta3_accuracy += threeshold_percentage(predictions, targets, 1.25*1.25*1.25)
                rmse_linear_loss += rmse_linear(predictions, targets)
                rmse_log_loss += rmse_log(predictions, targets)
                abs_relative_difference_loss += abs_relative_difference(predictions, targets)
                squared_relative_difference_loss += squared_relative_difference(predictions, targets)
                num_batches += 1

                # Saving output depths (normalised per batch for viewing; not in place)
                torchvision.utils.save_image(_min_max_normalise(predictions), f"{save_folder}/pred_{batch_idx}.png")
                torchvision.utils.save_image(_min_max_normalise(targets), f"{save_folder}/{batch_idx}.png")

        if num_batches == 0:
            raise ValueError("validate_unet: loader produced no batches")

        averages = [
            total / num_batches
            for total in (
                validation_loss, delta1_accuracy, delta2_accuracy, delta3_accuracy,
                rmse_linear_loss, rmse_log_loss, abs_relative_difference_loss,
                squared_relative_difference_loss,
            )
        ]

        print('Epoch: {}    {:.4f}      {:.4f}      {:.4f}      {:.4f}      {:.4f}      {:.4f}      {:.4f}      {:.4f}      {:.4f}'.format(
            epoch, train_loss, *averages))
        return averages[0]
    finally:
        model.train(was_training)

def main():
    """Train YNET, validating every epoch and checkpointing every 10 epochs.

    When ``LOAD_MODEL`` is True, training resumes at the next epoch after restoring from
    ``MODEL_LOAD_PATH``:
      - the weights;
      - the loss history;
      - the optimizer state and epoch, when present.
    Otherwise training starts at epoch 1.
    """
    rgb_data_transforms = transforms.Compose([
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.ToTensor(),
    ])

    depth_data_transforms = transforms.Compose([
        TransposeDepthInput(),
    ])

    train_loader, val_loader, test_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_DEPTH_DIR,
        VAL_IMG_DIR,
        VAL_DEPTH_DIR,
        TEST_IMG_DIR,
        TEST_DEPTH_DIR,
        BATCH_SIZE,
        rgb_data_transforms,
        depth_data_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    model = YNET(in_channels1=3, in_channels2=2, out_channels=1).to(DEVICE)
    loss_fn = ScaleInvariantLoss
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scaler = None  # AMP is disabled; train_unet does not use it

    train_losses, validation_losses = [], []
    start_epoch = 1

    if LOAD_MODEL:
        print("=> Loading Chekpoint")
        # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
        checkpoint = torch.load(MODEL_LOAD_PATH, map_location=DEVICE)
        model.load_state_dict(checkpoint["state_dict"])
        train_losses = checkpoint["train_losses"]
        validation_losses = checkpoint["validation_losses"]
        if "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
        # Older checkpoints carry no "epoch". One loss is appended per epoch and the history
        # is carried across resumes, so its length is the last completed epoch.
        start_epoch = checkpoint.get("epoch", len(train_losses)) + 1
        print("=> Checkpoint Loaded")

    print("********* Training the New Model **************")
    print("Epochs:     Train_loss  Val_loss    Delta_1     Delta_2     Delta_3    rmse_lin    rmse_log    abs_rel.  square_relative")
    print("Paper Val:                          (0.618)     (0.891)     (0.969)     (0.871)     (0.283)     (0.228)     (0.223)")

    file_names = ['31.png', '69.png', '1020.png', '1021.png']

    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        train_loss = train_unet(train_loader, model, optimizer, loss_fn, scaler)
        validation_loss = validate_unet(val_loader, model, loss_fn, epoch, train_loss, save_folder=VALIDATION_IMAGES_SAVE_DIR)

        train_losses.append(train_loss)
        validation_losses.append(validation_loss)

        Save_Predictions(model, file_names, TRAIN_IMG_DIR, TRAIN_DEPTH_DIR, TRAIN_SAVE_PATH)

        # Checkpoint every 10 epochs and always after the final epoch.
        if epoch % 10 == 0 or epoch == NUM_EPOCHS:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
                "train_losses": train_losses,
                "validation_losses": validation_losses,
            }
            save_path = MODEL_SAVE_DIR + MODEL_NAME + '_' + str(epoch) + '.pth.tar'
            save_checkpoint(checkpoint, save_path)

    print()

In [None]:
import warnings

# Silence warnings only for this training run instead of for the rest of the session.
# A module-level filterwarnings("ignore") would also hide warnings from every later cell.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    main()

********* Training the New Model **************
Epochs:     Train_loss  Val_loss    Delta_1     Delta_2     Delta_3    rmse_lin    rmse_log    abs_rel.  square_relative
Paper Val:                          (0.618)     (0.891)     (0.969)     (0.871)     (0.283)     (0.228)     (0.223)
Epoch: 1    0.3090      0.2550      0.1471      0.3852      0.6712      1.3462      0.3974      0.4675      0.7237
Epoch: 2    0.2140      0.1999      0.2205      0.5164      0.7900      1.1949      0.2937      0.4236      0.6265
Epoch: 3    0.1920      0.1247      0.4514      0.7878      0.9449      0.8878      0.1513      0.3481      0.5098
Epoch: 4    0.1770      0.1363      0.4037      0.7335      0.9232      0.9622      0.1735      0.3522      0.5030
Epoch: 5    0.1686      0.1023      0.5908      0.8755      0.9672      0.7389      0.1148      0.3278      0.5131
Epoch: 6    0.1586      0.1510      0.3520      0.6815      0.8916      1.0154      0.2056      0.3579      0.5005
Epoch: 7    0.1542      0

In [None]:
import warnings

# Resumed run: with LOAD_MODEL = True, main() reloads MODEL_LOAD_PATH and continues training.
# Warnings are silenced only for the duration of this run, not for the whole session.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    main()

=> Loading Chekpoint
=> Checkpoint Loaded
********* Training the New Model **************
Epochs:     Train_loss  Val_loss    Delta_1     Delta_2     Delta_3    rmse_lin    rmse_log    abs_rel.  square_relative
Paper Val:                          (0.618)     (0.891)     (0.969)     (0.871)     (0.283)     (0.228)     (0.223)
Epoch: 41    0.0483      0.1004      0.6336      0.8887      0.9643      0.7086      0.1109      0.3252      0.5460
Epoch: 42    0.0467      0.1004      0.6387      0.8946      0.9653      0.6932      0.1107      0.3077      0.5111
Epoch: 43    0.0468      0.1015      0.6019      0.8778      0.9663      0.7251      0.1148      0.3160      0.5674
Epoch: 44    0.0436      0.0923      0.6378      0.9035      0.9732      0.6765      0.1026      0.3005      0.4839
Epoch: 45    0.0421      0.0979      0.6518      0.9008      0.9689      0.6851      0.1071      0.3105      0.5422
Epoch: 46    0.0417      0.0983      0.6240      0.8891      0.9689      0.6981      0.1099  

KeyboardInterrupt: ignored

In [None]:
import warnings

# Resumed run: with LOAD_MODEL = True, main() reloads MODEL_LOAD_PATH and continues training.
# Warnings are silenced only for the duration of this run, not for the whole session.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    main()

=> Loading Chekpoint
=> Checkpoint Loaded
********* Training the New Model **************
Epochs:     Train_loss  Val_loss    Delta_1     Delta_2     Delta_3    rmse_lin    rmse_log    abs_rel.  square_relative
Paper Val:                          (0.618)     (0.891)     (0.969)     (0.871)     (0.283)     (0.228)     (0.223)
Epoch: 81    0.0229      0.0913      0.6617      0.9096      0.9746      0.6610      0.1005      0.3112      0.5548
Epoch: 82    0.0225      0.0914      0.6489      0.9074      0.9755      0.6612      0.1016      0.3026      0.5225
Epoch: 83    0.0229      0.0910      0.6582      0.9104      0.9751      0.6597      0.1006      0.3064      0.5403
Epoch: 84    0.0217      0.0929      0.6670      0.9038      0.9723      0.6644      0.1021      0.3129      0.5405
Epoch: 85    0.0220      0.0906      0.6530      0.9118      0.9759      0.6566      0.1006      0.3048      0.5216
Epoch: 86    0.0215      0.0942      0.6600      0.9041      0.9727      0.6726      0.1039  

### **Testing**

In [None]:
# Re-enter the project folder so the relative Models/... and Datasets/... paths below resolve
# (and the local Ynet / DataLoader / metrics modules presumably found here can be imported).
%cd /content/drive/MyDrive/Neural\ Network\ Project

/content/drive/MyDrive/Neural Network Project


In [None]:
import os
import cv2
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import torch
from Ynet import YNET
from DataLoader import TransposeDepthInput, NYUDataset, save_checkpoint, get_loaders, save_predictions_as_imgs
from metrics import ScaleInvariantLoss, threeshold_percentage, rmse_linear, rmse_log, abs_relative_difference, squared_relative_difference

In [None]:
IMAGE_HEIGHT, IMAGE_WIDTH = 120, 160

# Inference preprocessing, identical to the one used in main(): resize to the network
# input size, then convert the PIL image to a CHW float tensor.
rgb_data_transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.ToTensor(),
    ]
)

In [None]:
import matplotlib.pyplot as plt

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): these paths point at the baseline Ynet checkpoint/predictions, and YNET is
# imported from Ynet.py rather than being the New_Model trained above — confirm this section
# is meant to evaluate the baseline.
YNET_MODEL_PATH = "Models/Ynet/checkpoint/Ynet_model_100.pth.tar"
TRAIN_SAVE_PATH = "Models/Ynet/predictions/Train/"
VAL_SAVE_PATH = "Models/Ynet/predictions/Validation/"
TEST_SAVE_PATH = "Models/Ynet/predictions/Test/"

TRAIN_IMG_DIR = "Datasets/Train/images/"
TRAIN_DEPTH_DIR = "Datasets/Train/depths/"
VAL_IMG_DIR = "Datasets/Validation/images/"
VAL_DEPTH_DIR = "Datasets/Validation/depths/"
TEST_IMG_DIR = "Datasets/Test/images/"
TEST_DEPTH_DIR = "Datasets/Test/depths/"

model = YNET(in_channels1=3, in_channels2=2, out_channels=1).to(DEVICE)

# Loading the Ynet model; map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
checkpoint = torch.load(YNET_MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["state_dict"])

In [None]:
def Save_Predictions(image_dir, depth_dir, save_dir):
    """Save input / ground-truth / prediction figures for every file in ``image_dir``.

    Uses the module-level ``model``, ``rgb_data_transforms`` and ``DEVICE``. For each file,
    in sorted order:
      - the image is paired with the depth map of the same name in ``depth_dir`` (paths
        are built by plain concatenation, so both directories must end with a separator);
      - a 2-channel Sobel gradient (x, y) is computed from the preprocessed image;
      - a 3-panel figure is written to ``save_dir/<image name>``.

    The model's train/eval mode is restored afterwards.

    NOTE: this shadows the earlier ``Save_Predictions(model, file_names, image_dir,
    depth_dir, save_dir)`` that ``main()`` calls; running ``main()`` after this cell
    would raise a TypeError.

    Raises:
        FileNotFoundError: if an image or its depth map cannot be read.
    """
    was_training = model.training
    model.eval()
    try:
        for image_name in sorted(os.listdir(image_dir)):
            image_path = image_dir + image_name
            depth_path = depth_dir + image_name

            # Load the image and depth
            bgr = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
            if bgr is None or depth is None:
                missing = image_path if bgr is None else depth_path
                raise FileNotFoundError(f"Could not read {missing}")

            rgb = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            image = rgb_data_transforms(rgb)  # (3, H, W) float tensor

            # Find the gradient (x and y Sobel responses of the grayscale image)
            hwc = np.moveaxis(image.numpy(), [0, 1, 2], [2, 0, 1])
            # NOTE(review): the data is RGB here but converted with BGR2GRAY (R/B weights
            # swapped). Kept as-is so inputs match the training pipeline — confirm against
            # DataLoader before changing.
            gray = cv2.cvtColor(hwc, cv2.COLOR_BGR2GRAY)
            gx = cv2.Sobel(gray, ddepth=cv2.CV_32F, dx=1, dy=0, ksize=3)
            gy = cv2.Sobel(gray, ddepth=cv2.CV_32F, dx=0, dy=1, ksize=3)
            gradient = torch.from_numpy(np.stack([gx, gy]))

            # Predict the output for a batch of one
            with torch.no_grad():
                predicted = model(image.unsqueeze(0).to(device=DEVICE),
                                  gradient.unsqueeze(0).to(device=DEVICE))

            # HWC image for display, derived from the tensor rather than a fixed 120x160 buffer.
            input_image = image.permute(1, 2, 0).numpy()
            predicted = predicted[0, 0].cpu().numpy()

            fig, axes = plt.subplots(1, 3, figsize=(14, 6))
            panels = [
                ('Input image', input_image, None),
                ('Ground truth', depth, 'gist_gray'),
                ('Ynet predicted', predicted, 'gist_gray'),
            ]
            for ax, (title, data, cmap) in zip(axes, panels):
                ax.set_title(title)
                ax.imshow(data, cmap=cmap)
            fig.savefig(f'{save_dir}/{image_name}')
            plt.close(fig)
    finally:
        model.train(was_training)

In [None]:
# Validation split: one comparison figure per image.
Save_Predictions(image_dir=VAL_IMG_DIR, depth_dir=VAL_DEPTH_DIR, save_dir=VAL_SAVE_PATH)

In [None]:
# Training split: one comparison figure per image.
Save_Predictions(image_dir=TRAIN_IMG_DIR, depth_dir=TRAIN_DEPTH_DIR, save_dir=TRAIN_SAVE_PATH)

In [None]:
# Test split: one comparison figure per image.
Save_Predictions(image_dir=TEST_IMG_DIR, depth_dir=TEST_DEPTH_DIR, save_dir=TEST_SAVE_PATH)

### **Time taken**

In [None]:
import time
# Average wall-clock time per image for the full inference pipeline: read, preprocess,
# Sobel gradient and forward pass, over training images 0.png .. (num_images - 1).png.
model.eval()
num_images = 50
start_time = time.perf_counter()
for i in range(num_images):
    image_path = TRAIN_IMG_DIR + str(i) + '.png'
    bgr = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if bgr is None:
        raise FileNotFoundError(f"Could not read {image_path}")
    image = rgb_data_transforms(Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))

    # Find the gradient (same preprocessing as Save_Predictions)
    gray = np.moveaxis(image.numpy(), [0, 1, 2], [2, 0, 1])
    gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)
    gx = cv2.Sobel(gray, ddepth=cv2.CV_32F, dx=1, dy=0, ksize=3)
    gy = cv2.Sobel(gray, ddepth=cv2.CV_32F, dx=0, dy=1, ksize=3)
    gradient = torch.from_numpy(np.stack([gx, gy]))

    # Predict the output for a batch of one
    image = torch.unsqueeze(image, 0).to(device=DEVICE)
    gradient = torch.unsqueeze(gradient, 0).to(device=DEVICE)
    with torch.no_grad():
        predicted = model(image, gradient)

if DEVICE == "cuda":
    # CUDA kernels run asynchronously; wait for the last forward pass before stopping the clock.
    torch.cuda.synchronize()
end_time = time.perf_counter()
model.train()
print('Time taken:', (end_time-start_time)/num_images)

Time taken: 0.11543568611145019


### **Model summary**

In [None]:
# NOTE(review): `model` here is the YNET imported from Ynet.py in the Testing section. Its
# printed blocks (DoubleConv with a `conv` Sequential) differ from the DoubleConv class
# defined at the top of this notebook — confirm which architecture is intended.
model_summary = repr(model)
print(model_summary)

YNET(
  (ups): ModuleList()
  (downConvs1): DownConv(
    (downs): ModuleList(
      (0): DoubleConv(
        (conv): Sequential(
          (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
      (1): DoubleConv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
          (4): BatchNorm2d(128, eps=1e-05, mome

In [None]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameter counts and return their total.

    Only parameters with ``requires_grad`` set are listed. ``named_parameters`` yields
    each shared tensor once, so submodules registered under two names are not
    double-counted.
    """
    rows = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    table = PrettyTable(["Modules", "Parameters"])
    for name, count in rows:
        table.add_row([name, count])
    total_params = sum(count for _, count in rows)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params


count_parameters(model)

+-------------------------------------------+------------+
|                  Modules                  | Parameters |
+-------------------------------------------+------------+
|      downConvs1.downs.0.conv.0.weight     |    1728    |
|      downConvs1.downs.0.conv.1.weight     |     64     |
|       downConvs1.downs.0.conv.1.bias      |     64     |
|      downConvs1.downs.0.conv.3.weight     |   36864    |
|      downConvs1.downs.0.conv.4.weight     |     64     |
|       downConvs1.downs.0.conv.4.bias      |     64     |
|      downConvs1.downs.1.conv.0.weight     |   73728    |
|      downConvs1.downs.1.conv.1.weight     |    128     |
|       downConvs1.downs.1.conv.1.bias      |    128     |
|      downConvs1.downs.1.conv.3.weight     |   147456   |
|      downConvs1.downs.1.conv.4.weight     |    128     |
|       downConvs1.downs.1.conv.4.bias      |    128     |
|      downConvs1.downs.2.conv.0.weight     |   294912   |
|      downConvs1.downs.2.conv.1.weight     |    256    

58156705