In [9]:
# Side length (in pixels) of the square depth map predicted by the network.
# Supported values: 224 or 32.
op_size = 32
if op_size not in (224, 32):
    raise ValueError(f'op_size must be 224 or 32, got {op_size!r}')

# Checkpoint stem ('224_model' or '32_model'). It is derived from op_size so
# the two settings cannot drift apart; a mismatch would make load_state_dict
# fail on the size of the last layer.
model_name = f'{op_size}_model'

In [10]:
from math import exp
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
from PIL import Image
import torchvision.transforms as T
import matplotlib.pyplot as plt
import glob
import time
from sklearn.preprocessing import MinMaxScaler
import warnings
import random
import pickle
from skimage.metrics import structural_similarity as ssim

# Load the distilled small vision transformer (DINOv2 ViT-S/14) from torch hub.
dino_s = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# The backbone is only used as a frozen feature extractor in this notebook,
# so put it in evaluation mode explicitly.
dino_s.eval()

# Converts a PIL image into a float tensor that the backbone can consume.
transform = T.Compose([
    T.ToTensor()
])

# Load the pickled scaler object which was fitted with the training data.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only use a scaler.pkl that comes from a trusted source.
try:
    with open('scaler.pkl', 'rb') as file:
        scaler = pickle.load(file)
except OSError as err:
    # Falling back to an unfitted scaler would only postpone the failure to the
    # first scaler.transform() call, so stop here with an actionable message.
    raise RuntimeError(
        "Could not read 'scaler.pkl'; the fitted scaler is required to "
        "normalise the DINO features before inference."
    ) from err

# defining the SSIM-based loss (1 - structural similarity between prediction and ground truth)
def ssim_loss(depth_pred, depth_gt):
    """Return the structural dissimilarity ``1 - SSIM`` of two depth maps.

    Args:
        depth_pred: Predicted depth tensor.
        depth_gt: Ground-truth depth tensor to compare against.

    Returns:
        ``1.0 - ssim(depth_gt, depth_pred)`` as computed by scikit-image.

    Note:
        The inputs are detached and converted to NumPy arrays, so no gradient
        flows through this value.
    """
    # .cpu() is a no-op for CPU tensors and makes the NumPy conversion work
    # for tensors that live on another device.
    gt_np = depth_gt.detach().cpu().numpy()
    pred_np = depth_pred.detach().cpu().numpy()

    # The dynamic range handed to SSIM is taken from the prediction.
    pred_range = pred_np.max() - pred_np.min()
    return 1.0 - ssim(gt_np, pred_np, data_range=pred_range)

class CustomLoss(nn.Module):
    """Mean-squared error scaled by ``exp(ssim_loss)``.

    The SSIM term is turned into a plain Python float before it is used, so it
    only rescales the MSE value; gradients flow through the MSE term alone.
    """

    def __init__(self):
        super().__init__()
        self.mseLoss = nn.MSELoss()

    def forward(self, depth_pred, depth_gt):
        """Return ``MSE(depth_pred, depth_gt) * exp(ssim_loss(depth_pred, depth_gt))``."""
        # Structural dissimilarity, rounded to float32 before exponentiation.
        ssim_term = torch.tensor(ssim_loss(depth_pred, depth_gt), dtype=torch.float32)
        scale = exp(ssim_term.item())

        return self.mseLoss(depth_pred, depth_gt) * scale

# plotting the tensor into an image
def plot_the_tensor(tns):
    """Display a tensor as an image with matplotlib, with the axes hidden.

    Args:
        tns: Tensor holding pixel values; it is cast to ``uint8`` before
            being drawn.
    """
    # Detach from autograd, move to host memory and cast to 8-bit pixels
    # before handing the array to PIL.
    pixels = tns.detach().cpu().numpy().astype('uint8')
    plt.imshow(Image.fromarray(pixels))
    plt.axis('off')  # no ticks or frame around the picture
    plt.show()

# converting the pixel estimated values to corresponding depth values
def depth_conversion(pred_pixel_values, max_pixel_value=255, max_depth_value=5.4):
    """Linearly map predicted pixel intensities to depth values.

    Args:
        pred_pixel_values: Scalar, array or tensor of predicted pixel values.
        max_pixel_value: Pixel value that corresponds to the maximum depth
            (default 255).
        max_depth_value: Depth assigned to ``max_pixel_value`` (default 5.4;
            presumably metres - TODO confirm the unit against the dataset).

    Returns:
        ``pred_pixel_values * max_depth_value / max_pixel_value``, with the
        same type/shape semantics as the input.
    """
    return (pred_pixel_values * max_depth_value) / max_pixel_value

# defining the feed-forward network
class DepthEstimationModel(nn.Module):
    """Fully connected head that maps flattened features to a square depth map.

    Args:
        op_size: Side length of the predicted map; the output has shape
            ``(batch, op_size, op_size)``.
        in_features: Number of values per sample after flattening. Defaults
            to ``257 * 384`` (presumably the CLS token plus 256 patch tokens
            of width 384 produced by the DINOv2 backbone).
        hidden_features: Width of the two hidden layers (default 2048).
    """

    def __init__(self, op_size, in_features=257 * 384, hidden_features=1024 * 2):
        super().__init__()
        # Keep the output size on the instance so forward() does not depend on
        # a module-level variable of the same name.
        self.op_size = op_size
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.Softplus(),
            nn.Linear(hidden_features, hidden_features),
            nn.Softplus(),
            nn.Linear(hidden_features, op_size * op_size),
            nn.Softplus()
        )

    def forward(self, x):
        """Return a ``(batch, op_size, op_size)`` prediction for the features ``x``."""
        x = x.reshape(x.size(0), -1)  # flatten everything except the batch axis
        x = self.fc_layers(x)  # pass through the fully connected layers
        return x.view(x.size(0), self.op_size, self.op_size)

# Create the feed-forward network and restore its trained weights.
model = DepthEstimationModel(op_size)
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a machine without one.
model.load_state_dict(torch.load(f'{model_name}.pt', map_location='cpu'))
# .eval() switches the module to inference behaviour. It does not freeze the
# weights; gradient tracking is turned off separately with torch.no_grad().
model.eval()

Using cache found in /home/akhkr1/.cache/torch/hub/facebookresearch_dinov2_main


DepthEstimationModel(
  (fc_layers): Sequential(
    (0): Linear(in_features=98688, out_features=2048, bias=True)
    (1): Softplus(beta=1, threshold=20)
    (2): Linear(in_features=2048, out_features=2048, bias=True)
    (3): Softplus(beta=1, threshold=20)
    (4): Linear(in_features=2048, out_features=1024, bias=True)
    (5): Softplus(beta=1, threshold=20)
  )
)

In [14]:
# Running totals (in seconds) accumulated over all calls to infer_without_target.
total, total_resize, total_concat_and_norm, total_dino, total_network = 0,0,0,0,0
def infer_without_target(image_path):
    """Run one image through the full pipeline and accumulate stage latencies.

    The elapsed time of each stage is added to the module-level counters
    ``total``, ``total_resize``, ``total_dino``, ``total_concat_and_norm`` and
    ``total_network``.

    Args:
        image_path: Path of the RGB input image.

    Returns:
        The predicted depth map with the batch dimension removed.
    """
    global total, total_resize, total_concat_and_norm, total_dino, total_network
    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring short intervals.
    start = time.perf_counter()

    # Image.open is lazy, so this interval covers decoding as well as resizing.
    start_resize = time.perf_counter()
    rgb_image = Image.open(image_path).resize((224,224))
    end_resize = time.perf_counter()

    rgb = transform(rgb_image).unsqueeze(0)

    start_dino = time.perf_counter()
    with torch.no_grad():
        dino_features = dino_s.forward_features(rgb)

    patch_tokens = dino_features['x_norm_patchtokens']
    cls_tokens = dino_features['x_norm_clstoken']
    end_dino = time.perf_counter()

    start_concat_and_norm = time.perf_counter()
    # concatenating the cls and patch tokens
    concat = torch.cat((cls_tokens.unsqueeze(0), patch_tokens),dim=1).squeeze(0)

    # normalising the features
    concat_norm = scaler.transform(concat)
    end_concat_and_norm = time.perf_counter()

    # converting the features to a tensor
    concat_norm = torch.tensor(concat_norm,dtype=torch.float32).unsqueeze(0)

    start_network = time.perf_counter()
    # feeding the features to the network for inference; no_grad keeps autograd
    # bookkeeping out of the measured latency
    with torch.no_grad():
        predicted_depth = model(concat_norm).squeeze(0)
    end_network = time.perf_counter()

    end = time.perf_counter()
    total += end - start
    total_resize += end_resize - start_resize
    total_dino += end_dino - start_dino
    total_concat_and_norm += end_concat_and_norm - start_concat_and_norm
    total_network += end_network - start_network

    return predicted_depth

# NOTE(review): absolute, user-specific path - adjust it when running elsewhere.
input_files = sorted(glob.glob('/cs/home/akhkr1/Documents/test_data/test_input/*.jpg'))

# Benchmark on files 20..49 of the sorted listing. Slicing (instead of indexing
# with range(20, 50)) avoids an IndexError when fewer than 50 images exist.
benchmark_files = input_files[20:50]
if not benchmark_files:
    raise RuntimeError(
        f'Need more than 20 input images to run the benchmark, found {len(input_files)}.'
    )

for image_path in benchmark_files:
    infer_without_target(image_path)

# Average over the number of images actually processed.
n_images = len(benchmark_files)

with open('latency_logs.txt', 'a') as file:
    file.write(f'The following are the mean latencies by {model_name} for {n_images} images from test dataset:\n')
    file.write(f'Resizing the input image: {total_resize/n_images}\n')
    file.write(f'Dino processing the input image: {total_dino/n_images}\n')
    file.write(f'Concatenating and normalising the dino features: {total_concat_and_norm/n_images}\n')
    file.write(f'Dense network processing the input image: {total_network/n_images}\n')
    file.write(f'Total inference process: {total/n_images}\n')


