In [6]:
import os
import pandas as pd 
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from torch.utils.data import Dataset
from PIL import Image
import glob

from torchvision import models
import tqdm
from torchvision.transforms import ToTensor
import time
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision.transforms import Resize, Compose, ToPILImage, ToTensor
import pickle
import math
import torchvision.transforms.functional as TF
#from efficientnet_pytorch import EfficientNet

#from kornia.filters import SpatialGradient

import random
from torchvision.transforms import RandomCrop

In [7]:
def conv_relu_block(in_channel, out_channel, kernel, padding):
    """Build a ``Conv2d -> ReLU`` pair wrapped in ``nn.Sequential``.

    Args:
        in_channel: number of input channels of the convolution.
        out_channel: number of output channels of the convolution.
        kernel: kernel size, passed straight to ``nn.Conv2d(kernel_size=...)``.
        padding: zero padding, passed straight to ``nn.Conv2d(padding=...)``.

    Returns:
        ``nn.Sequential`` whose child ``0`` is the convolution and child ``1``
        the ReLU; these child indices appear in saved state_dict keys.
    """
    layers = [
        nn.Conv2d(in_channel, out_channel, kernel_size=kernel, padding=padding),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

In [8]:
class vanilla_unet_full(nn.Module):
    """U-Net-style decoder on a ResNet-18 encoder, conditioned on a SAM mask.

    ``forward`` stacks its inputs as ``[x, SAM, x]`` on the channel axis and
    mixes them with a 3->3 conv so the result fits the 3-channel ResNet stem
    (presumably ``x`` and ``SAM`` are single-channel -- TODO confirm). Four
    skip-connected decoder stages bring the features back up, and a 1x1 conv
    produces ``n_class`` channels that are bilinearly resized to the spatial
    size of the input.

    Args:
        n_class: number of output channels.
        pretrained: forwarded to ``models.resnet18``. Defaults to ``True``;
            pass ``False`` when a full checkpoint is loaded right after
            construction, so the ImageNet weights are not downloaded for
            nothing.

    NOTE: submodule attribute names are the keys of saved checkpoints, so
    every submodule below -- including ones ``forward`` does not use -- must
    be kept for ``load_state_dict`` to succeed.
    """

    def __init__(self, n_class, pretrained=True):
        super().__init__()
        # 3 -> 3 channel mixer applied to the [x, SAM, x] stack.
        self.input_1 = conv_relu_block(3, 3, 3, 1)

        self.base_model = models.resnet18(pretrained=pretrained)
        self.base_layers = list(self.base_model.children())

        # Encoder stages (channel count in the comment, as implied by the
        # matching U*_conv input width) paired with their decoder layers.
        self.l0 = nn.Sequential(*self.base_layers[:3])      # 64 ch
        self.U0_conv = conv_relu_block(64, 64, 1, 0)
        # padding=0 trims 1 px per side; the final resize restores the size.
        self.conv_up0 = conv_relu_block(64 + 256, 128, 3, 0)

        self.l1 = nn.Sequential(*self.base_layers[3:5])     # 64 ch
        self.U1_conv = conv_relu_block(64, 64, 1, 0)
        self.conv_up1 = conv_relu_block(64 + 256, 256, 3, 1)

        self.l2 = self.base_layers[5]                       # 128 ch
        self.U2_conv = conv_relu_block(128, 128, 1, 0)
        self.conv_up2 = conv_relu_block(128 + 512, 256, 3, 1)

        self.l3 = self.base_layers[6]                       # 256 ch
        self.U3_conv = conv_relu_block(256, 256, 1, 0)
        self.conv_up3 = conv_relu_block(256 + 512, 512, 3, 1)

        self.l4 = self.base_layers[7]                       # 512 ch
        self.U4_conv = conv_relu_block(512, 512, 1, 0)

        # Not used by forward(); kept only so existing checkpoints still load.
        self.conv_up4 = conv_relu_block(64 + 128, 64, 3, 1)

        self.out4 = nn.Conv2d(128, n_class, 1)

        # Not used by forward() and parameter-free; kept for attribute
        # compatibility with code that may reference it.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    @staticmethod
    def _decode_stage(x, skip, lateral, fuse):
        """One decoder stage: resize ``x`` to ``skip``'s spatial size, concat it
        (first) with ``lateral(skip)`` on channels, then apply ``fuse``."""
        x = F.interpolate(x, size=tuple(skip.shape[-2:]), mode='bilinear', align_corners=True)
        return fuse(torch.cat([x, lateral(skip)], dim=1))

    def forward(self, x, SAM):
        """Predict an ``n_class``-channel map at the input resolution.

        Args:
            x: 4-D image batch ``(N, C_x, H, W)``.
            SAM: 4-D mask batch with the same ``N, H, W`` as ``x``; the
                ``[x, SAM, x]`` stack must have 3 channels in total.

        Returns:
            ``(N, n_class, H, W)`` tensor. No final activation is applied, so
            values may be negative.
        """
        out_size = tuple(x.shape[-2:])
        x = self.input_1(torch.cat([x, SAM, x], dim=1))

        block0 = self.l0(x)
        block1 = self.l1(block0)
        block2 = self.l2(block1)
        block3 = self.l3(block2)
        block4 = self.l4(block3)

        # Upsample targets come from the skip tensors instead of being
        # hard-coded, so any input size works (for a 2200x1550 input they are
        # 138x97, 275x194, 550x388 and 1100x775).
        y = self.U4_conv(block4)
        y = self._decode_stage(y, block3, self.U3_conv, self.conv_up3)
        y = self._decode_stage(y, block2, self.U2_conv, self.conv_up2)
        y = self._decode_stage(y, block1, self.U1_conv, self.conv_up1)
        y = self._decode_stage(y, block0, self.U0_conv, self.conv_up0)

        return F.interpolate(self.out4(y), size=out_size, mode='bilinear', align_corners=True)


In [27]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [1]:
# Evaluate every (model class, checkpoint) pair on the validation images and
# print the mean of each depth metric over all images.
#
# NOTE(review): `compute_errors(gt, pred)` is not defined in the cells shown;
# it is assumed to return
# (silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3) in that order
# (the order unpacked below) -- TODO confirm it is defined before this cell.
model_list = [vanilla_unet_full]
model_path_list = ['../../nkono/IVC_MDE/unet_sam.pt']
assert len(model_list) == len(model_path_list), "need one checkpoint per model class"

sam = True  # feed SAM masks to the model as a second input

in_path = "../../jkoh/inputs/"
y_path = '../../jkoh/depth_annotations/'
sam_dir = "sam_outputs/val_mask/"

# TODO confirm: index of the model output channel that holds depth
# (the model is constructed with 5 output channels below).
PRED_CHANNEL = 0
# TODO confirm: divisor converting stored annotation values to depth units.
GT_DEPTH_SCALE = 1.0
# Ground-truth pixels <= MIN_DEPTH are treated as invalid, and predictions are
# clipped to >= MIN_DEPTH so the log-based metrics stay finite.
MIN_DEPTH = 1e-3

# Sorted so the image order (and per-image metric indices) is reproducible.
dir_list = sorted(os.listdir(in_path))
d_paths = [(in_path + v, y_path + v) for v in dir_list]
sam_paths = [sam_dir + v for v in dir_list] if sam else None
transform = transforms.Compose([transforms.ToTensor()])
num_samples = len(d_paths)  # metrics are stored per image, not per model

for j, (model_cls, model_path) in enumerate(zip(model_list, model_path_list)):
    silog = np.zeros(num_samples, np.float32)
    log10 = np.zeros(num_samples, np.float32)
    rms = np.zeros(num_samples, np.float32)
    log_rms = np.zeros(num_samples, np.float32)
    abs_rel = np.zeros(num_samples, np.float32)
    sq_rel = np.zeros(num_samples, np.float32)
    d1 = np.zeros(num_samples, np.float32)
    d2 = np.zeros(num_samples, np.float32)
    d3 = np.zeros(num_samples, np.float32)

    model = model_cls(5)
    # map_location lets a checkpoint saved on GPU load on whatever `device` is.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    # eval() makes BatchNorm use its running statistics instead of batch stats.
    model.eval()

    with torch.no_grad():
        for i, (image_path, gt_path) in enumerate(d_paths):
            # `with` closes each file handle once its pixels have been read.
            with Image.open(image_path) as img:
                image = transform(img).unsqueeze(0).to(device)
            with Image.open(gt_path) as gt_img:
                gt_depth = np.asarray(gt_img, dtype=np.float32) / GT_DEPTH_SCALE

            if sam:
                with Image.open(sam_paths[i]) as sam_img:
                    sam_output = transform(sam_img).unsqueeze(0).to(device)
                outputs = model(image, sam_output)
            else:
                outputs = model(image)

            pred_depth = outputs[0, PRED_CHANNEL].cpu().numpy()
            assert pred_depth.shape == gt_depth.shape, (
                f"{gt_path}: prediction {pred_depth.shape} vs ground truth {gt_depth.shape}")

            valid = gt_depth > MIN_DEPTH
            (silog[i], log10[i], abs_rel[i], sq_rel[i], rms[i], log_rms[i],
             d1[i], d2[i], d3[i]) = compute_errors(
                gt_depth[valid], np.maximum(pred_depth[valid], MIN_DEPTH))

    print(model_path)
    print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
                                                                                 'sq_rel', 'log_rms', 'd1', 'd2', 'd3'))
    print("{:7.4f}, {:7.4f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}".format(
        silog.mean(), abs_rel.mean(), log10.mean(), rms.mean(), sq_rel.mean(), log_rms.mean(), d1.mean(), d2.mean(),
        d3.mean()))

NameError: name 'vanilla_unet_full' is not defined

In [34]:
# Show the evaluated model classes directly instead of relying on the loop
# variable `j` leaking out of the evaluation cell (hidden state).
model_list

__main__.vanilla_unet_full