In [None]:
# cell 1
# --- Input files -----------------------------------------------------------
# RGB image to run depth inference on.
input_image_path = 'test_input.jpg'
# Ground-truth depth image. It is only needed by the "infer with target image"
# cell; when no target is available, leave this as is and run the
# "infer without target image" cell instead.
target_image_path = 'test_target.png'

# --- Model selection -------------------------------------------------------
# Side length of the predicted depth map: 224 or 32 (must match model_name).
op_size = 224
# Checkpoint name: '224_model' or '32_model' (must match op_size).
model_name = '224_model'

## Run this cell in all cases

In [None]:
# cell 2
from math import exp
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as T
import matplotlib.pyplot as plt
import time
from sklearn.preprocessing import MinMaxScaler
import pickle
from skimage.metrics import structural_similarity as ssim

# loading the distilled small visual transformer (DINOv2 ViT-S/14) from torch hub
dino_s = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# the backbone is only used as a fixed feature extractor in this notebook,
# so put it in inference mode explicitly
dino_s.eval()

# PIL image -> float tensor; no further normalisation is applied here
transform = T.Compose([
    T.ToTensor()
])

# Unfitted placeholder; replaced below by the scaler that was fitted on the
# training data. If loading fails, this placeholder stays in place and
# scaler.transform(...) in the inference cells will raise.
scaler = MinMaxScaler()
# loading the pickle file of the scaler object which is fitted with the training data
# NOTE: pickle.load can execute arbitrary code -- only load a scaler.pkl you trust.
try:
    with open('scaler.pkl', 'rb') as file:
        scaler = pickle.load(file)
except OSError as err:
    # Keep the notebook running (as before), but say what went wrong and what
    # the consequence is instead of printing a bare 'error'.
    print(f"error: could not load 'scaler.pkl' ({err}); "
          "the scaler is unfitted and inference will fail")

# defining the structural-similarity (SSIM) based loss term
def ssim_loss(depth_pred, depth_gt):
    """Return ``1 - SSIM`` between a predicted and a ground-truth depth map.

    Args:
        depth_pred: predicted depth tensor.
        depth_gt: ground-truth depth tensor of the same shape.

    Returns:
        ``1.0 - ssim(depth_gt, depth_pred)`` as a NumPy scalar. Both inputs are
        detached first, so no gradient flows through this term.
    """
    # .cpu() is a no-op for CPU tensors and avoids a failure for tensors that
    # live on another device (same convention as plot_the_tensor).
    depth_gt_np = depth_gt.detach().cpu().numpy()
    depth_pred_np = depth_pred.detach().cpu().numpy()
    # NOTE(review): data_range is taken from the prediction only; a constant
    # prediction gives data_range == 0 -- confirm this is intended.
    pred_range = depth_pred_np.max() - depth_pred_np.min()
    similarity = ssim(depth_gt_np, depth_pred_np, data_range=pred_range)
    return 1.0 - similarity

class CustomLoss(nn.Module):
    """MSE loss multiplied by ``exp(1 - SSIM)`` of the two depth maps."""

    def __init__(self):
        super().__init__()
        self.mseLoss = torch.nn.MSELoss()

    def forward(self, depth_pred, depth_gt):
        """Compute the combined loss for a prediction / ground-truth pair.

        Args:
            depth_pred: predicted depth tensor.
            depth_gt: ground-truth depth tensor of the same shape.

        Returns:
            The MSE between the inputs scaled by ``exp(ssim_loss(...))``.
        """
        # The SSIM term is rounded to float32 before exponentiation; math.exp
        # then turns the 0-dim tensor into a plain Python float.
        ssim_term = torch.tensor(ssim_loss(depth_pred, depth_gt), dtype=torch.float32)
        scale = exp(ssim_term)
        return self.mseLoss(depth_pred, depth_gt) * scale

# plotting the tensor into an image
def plot_the_tensor(tns):
    """Show a tensor of pixel intensities as an image with matplotlib.

    Args:
        tns: tensor of pixel values; it is detached, moved to the CPU and
            converted to 8-bit before plotting.
    """
    tns_np = tns.detach().cpu().numpy()
    # Clamp to the 8-bit range before casting: values outside [0, 255] would
    # otherwise wrap around (or be undefined) in the float -> uint8 conversion.
    tns_np = np.clip(tns_np, 0, 255).astype('uint8')
    tns_img = Image.fromarray(tns_np)
    plt.imshow(tns_img)
    plt.axis('off')  # Hide the axes
    plt.show()

# converting the pixel estimated values to corresponding depth values
def depth_conversion(pred_pixel_values, max_pixel_value=255, max_depth_value=5.4):
    """Linearly rescale predicted pixel values to depth values.

    Args:
        pred_pixel_values: number, NumPy array or tensor of predicted pixel
            intensities.
        max_pixel_value: pixel intensity that maps to ``max_depth_value``
            (default 255).
        max_depth_value: depth assigned to ``max_pixel_value`` (default 5.4;
            units are not stated in this notebook -- presumably metres, confirm).

    Returns:
        ``pred_pixel_values * max_depth_value / max_pixel_value``, with the
        same container type as the input.

    Raises:
        ValueError: if ``max_pixel_value`` is zero.
    """
    if max_pixel_value == 0:
        raise ValueError('max_pixel_value must be non-zero')
    return (pred_pixel_values * max_depth_value) / max_pixel_value

# defining the feed-forward network
class DepthEstimationModel(nn.Module):
    """Fully connected network mapping flattened features to a square depth map.

    Args:
        op_size: side length of the predicted depth map; the last layer has
            ``op_size * op_size`` outputs.
        in_features: length of the flattened input per sample
            (default ``257 * 384``).
        hidden_features: width of the two hidden layers (default ``1024 * 2``).
    """

    def __init__(self, op_size, in_features=257 * 384, hidden_features=1024 * 2):
        super().__init__()
        # Stored so that forward() does not depend on a notebook-level global
        # that may differ from the value this instance was built with.
        self.op_size = op_size
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.Softplus(),
            nn.Linear(hidden_features, hidden_features),
            nn.Softplus(),
            nn.Linear(hidden_features, op_size * op_size),
            nn.Softplus()
        )

    def forward(self, x):
        """Return a ``(batch, op_size, op_size)`` tensor for input ``x``.

        ``x`` may have any shape whose per-sample element count equals
        ``in_features``; it is flattened per sample first.
        """
        # reshape (rather than view) also accepts non-contiguous inputs
        x = x.reshape(x.size(0), -1)
        x = self.fc_layers(x)  # Passing through the fully connected layers
        return x.view(x.size(0), self.op_size, self.op_size)

# creating an instance of the feed-forward network and loading the state_dict
model = DepthEstimationModel(op_size)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the notebook runs inference on CPU tensors throughout.
model.load_state_dict(torch.load(f'{model_name}.pt', map_location='cpu'))
# .eval() switches the module to inference mode (it affects layers such as
# dropout / batch-norm); it does NOT freeze the weights -- gradient tracking
# is disabled separately with torch.no_grad().
model.eval()

## Execute the cell below to infer without a target image

In [None]:
# cell 3
def infer_without_target(image_path):
    """Predict a depth map for one RGB image and display the result.

    Prints the inference time, the predicted pixel values and the converted
    depth estimates, then shows the predicted depth image and the resized
    input image. Returns nothing.

    Args:
        image_path: path of the input image file.
    """
    start = time.perf_counter()
    # Force three channels so RGBA / grayscale / palette files work as well.
    rgb_image = Image.open(image_path).convert('RGB').resize((224,224))
    rgb = transform(rgb_image).unsqueeze(0)
    with torch.no_grad():
        dino_features = dino_s.forward_features(rgb)

    patch_tokens = dino_features['x_norm_patchtokens']
    cls_tokens = dino_features['x_norm_clstoken']

    # concatenating the cls and patch tokens along the token axis;
    # unsqueeze(1) inserts that axis after the batch axis
    concat = torch.cat((cls_tokens.unsqueeze(1), patch_tokens),dim=1).squeeze(0)

    # normalising the features (the scaler works on NumPy arrays)
    concat_norm = scaler.transform(concat.cpu().numpy())
    # converting the features to a tensor
    concat_norm = torch.tensor(concat_norm,dtype=torch.float32).unsqueeze(0)
    # feeding the features to the network for inference; no gradients needed
    with torch.no_grad():
        predicted_depth = model(concat_norm).squeeze(0)
    total = time.perf_counter() - start
    print('Time taken for inference is: ',total)
    print('Predicted pixel values: ',predicted_depth)
    print('Predicted depth estimates: ',depth_conversion(predicted_depth))
    plot_the_tensor(predicted_depth)
    display(rgb_image)

infer_without_target(input_image_path)

## Execute the cell below to infer with a target image

In [None]:
# cell 4
def infer_with_target(rgb_path, target_path, op_size):
    """Predict a depth map for one RGB image and compare it with a target.

    Prints the prediction, the ground truth, the inference time and the
    CustomLoss value, then shows the predicted depth image, the ground-truth
    depth image and the resized input image. Returns nothing.

    Args:
        rgb_path: path of the input RGB image file.
        target_path: path of the ground-truth depth image file.
        op_size: side length the ground-truth image is resized to; it must
            match the output size the loaded model was built with.
    """
    criterion = CustomLoss()
    start = time.perf_counter()

    # Force three channels so RGBA / grayscale / palette files work as well.
    rgb_image = Image.open(rgb_path).convert('RGB').resize((224,224))
    gt_depth_image = Image.open(target_path).convert('L').resize((op_size,op_size))

    rgb = transform(rgb_image).unsqueeze(0)

    with torch.no_grad():
        dino_features = dino_s.forward_features(rgb)

    patch_tokens = dino_features['x_norm_patchtokens']
    cls_tokens = dino_features['x_norm_clstoken']
    # unsqueeze(1) inserts the token axis after the batch axis before joining
    concat = torch.cat((cls_tokens.unsqueeze(1), patch_tokens),dim=1).squeeze(0)

    # the scaler works on NumPy arrays
    concat_norm = scaler.transform(concat.cpu().numpy())

    concat_norm = torch.tensor(concat_norm,dtype=torch.float32).unsqueeze(0)

    # pure inference: no gradients needed
    with torch.no_grad():
        predicted_depth = model(concat_norm).squeeze(0)

    total_time = time.perf_counter() - start

    gt_np = np.array(gt_depth_image).astype(np.float32)
    gt_tensor = torch.tensor(gt_np)

    print(predicted_depth.shape)
    print('predicted pixel values: ',predicted_depth.detach().numpy().astype('uint8'))
    print('Predicted depth estimates: ',depth_conversion(predicted_depth))
    print('ground truth: ',gt_tensor.detach().numpy().astype('uint8'))

    loss = criterion(predicted_depth, gt_tensor)

    print('Time taken: ',total_time)
    print('Loss: ',loss)
    print('Predicted depth image:')
    plot_the_tensor(predicted_depth)
    print('Ground-truth depth:')
    plot_the_tensor(gt_tensor)
    print('Input image:')
    display(rgb_image)


infer_with_target(input_image_path, target_image_path, op_size)