In [None]:
# Importing the libraries
import os
import torch
import numpy as np
from PIL import Image

from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from models.resnet_depth_unet import ResnetDepthUnet
from utils.dataloader import TraversabilityDataset

import matplotlib.pyplot as plt
import cv2 as cv

# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Create the parameters object
class Object(object):
    """Empty attribute container; the settings below are attached to an instance."""
    pass

params = Object()
# dataset parameters
# The dataset root can be overridden with the WAYFAST_DATA_PATH environment
# variable so the notebook is not tied to a single machine; when the variable
# is unset the original location is used.
params.data_path        = os.environ.get('WAYFAST_DATA_PATH',
                                         r'C:/Users/deeks/Documents/WayFAST/myfile/data')
params.csv_path         = os.path.join(params.data_path, 'data.csv')
params.preproc          = True  # Vertical flip augmentation
# Presumably the statistics used to normalise the depth channel -- TODO confirm
# against TraversabilityDataset.
params.depth_mean       = 3.5235
params.depth_std        = 10.6645

# model parameters
params.model            = ResnetDepthUnet
params.batch_size       = 1
params.pretrained       = True
# NOTE(review): the name suggests a checkpoint path, but a boolean is stored
# here -- confirm what the code that loads the network expects.
params.load_network_path = True
# NOTE(review): presumably (width, height) -- confirm against the model/dataset.
params.input_size       = (424, 240)
params.output_size      = (424, 240)
params.output_channels  = 1
params.bottleneck_dim   = 256

In [None]:
# Load the test images from the data folder
def load_images_from_folder(folder):
    """Read every decodable image found directly inside ``folder``.

    Entries are visited in sorted filename order so the returned list is
    reproducible across runs and platforms (``os.listdir`` makes no ordering
    guarantee). Sub-directories are skipped, and files for which
    ``cv.imread`` returns ``None`` (not decodable as an image) are left out.

    Args:
        folder: Path of the directory to scan; it is not searched recursively.

    Returns:
        list: The arrays returned by ``cv.imread(path)``, in sorted filename
        order. Empty when the folder holds no readable image.

    Raises:
        OSError: If ``folder`` cannot be listed (e.g. it does not exist).
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        if not os.path.isfile(path):
            # A directory can never be an image; skip the decode attempt.
            continue
        img = cv.imread(path)
        if img is not None:
            images.append(img)
    return images

In [None]:
# Transform the images and create a dataloader
def transform_images(images, augment=False, shuffle=False):
    """Build the DataLoader used to run the model on the test data.

    The loader wraps a ``TraversabilityDataset`` constructed from the
    module-level ``params`` and the colour transform assembled here.

    By default the pipeline is deterministic: no random colour jitter is
    applied and batches are not shuffled, so repeated runs see identical
    inputs in the same order and the batch index used by downstream code maps
    to a fixed sample. Pass ``augment=True`` and/or ``shuffle=True`` to get the
    randomised behaviour back.

    Args:
        images: Not used by this function; kept so existing callers that pass
            a list of images keep working.
        augment: When True, insert ``ColorJitter`` (brightness, contrast,
            saturation and hue all 0.1) before tensor conversion.
        shuffle: Forwarded to ``DataLoader(shuffle=...)``.

    Returns:
        torch.utils.data.DataLoader: Loader over the dataset with
        ``batch_size=params.batch_size`` and two worker processes.
    """
    steps = [transforms.ToPILImage()]
    if augment:
        # Random photometric augmentation belongs to training; it is opt-in
        # here so that inference results are reproducible.
        steps.append(
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
        )
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform = transforms.Compose(steps)

    dataset = TraversabilityDataset(params, transform)
    test_loader = DataLoader(
        dataset,
        batch_size=params.batch_size,
        shuffle=shuffle,
        num_workers=2,
    )
    return test_loader


In [None]:
# Perform inference on the input images
def inference_on_images(model, test_loader, output_dir):
    """Run ``model`` over ``test_loader``, saving and displaying each prediction.

    For every batch, the predicted map of the first sample (``pred[0, 0]``) is
    scaled by 255, clipped to ``[0, 255]`` and written as an 8-bit grayscale
    PNG to ``<output_dir>/image_<i>_predicted.png`` (``i`` is the batch index).
    The same map is then shown next to the first colour input of the batch.

    Args:
        model: Object providing ``eval()`` and callable as
            ``model(color_img, depth_img)``; its output is indexed as
            ``pred[0, 0, :, :]`` to obtain a 2-D map.
        test_loader: Iterable of batches; each batch is an iterable of tensors
            whose first two items are the colour and depth inputs.
        output_dir: Directory receiving the PNG files; created (including
            parents) when missing.

    Returns:
        None
    """
    os.makedirs(output_dir, exist_ok=True)

    # Used only to undo the input normalisation for display.
    # NOTE(review): assumes the colour tensors were normalised with these
    # statistics (the ones used in transform_images) -- confirm if the
    # transform changes.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    model.eval()
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            # Every tensor of the batch is moved to the device as float32;
            # only the first two (colour, depth) are fed to the model.
            color_img, depth_img, *_ = (
                item.to(device).type(torch.float32) for item in data
            )

            pred = model(color_img, depth_img)

            # Save the predicted map as an 8-bit grayscale image.
            pred_map = pred[0, 0, :, :].detach().cpu().numpy()
            pred_u8 = np.clip(np.rint(255.0 * pred_map), 0, 255).astype(np.uint8)
            Image.fromarray(pred_u8).save(
                os.path.join(output_dir, f"image_{i}_predicted.png")
            )

            # Undo the normalisation so the input is viewable, then go from
            # channel-first to the channel-last layout imshow expects.
            color = (color_img[0].detach().cpu() * std + mean).clamp(0.0, 1.0)
            color = color.permute(1, 2, 0).numpy()

            # Display the input and the prediction beside each other.
            fig, (ax_input, ax_pred) = plt.subplots(1, 2, figsize=(14, 14))
            ax_input.imshow(color)
            ax_input.set_title('Input image')
            ax_pred.imshow(pred_u8, vmin=0, vmax=255)
            ax_pred.set_title('Prediction')
            plt.show(block=False)
            plt.pause(1)
            # Release the figure; otherwise one figure per batch stays alive.
            plt.close(fig)

In [None]:
# Read TIF images from data using cv2 and convert to numpy arrays

def read_images(data_path):
    """Load the TIFF images found directly inside ``data_path`` as grayscale.

    Files are selected by extension (``.tif`` or ``.tiff``, case-insensitive)
    and visited in sorted filename order so the result is reproducible
    (``os.listdir`` makes no ordering guarantee). Files for which
    ``cv.imread`` returns ``None`` are skipped, so the returned list never
    contains ``None``.

    Args:
        data_path: Directory to scan; it is not searched recursively.

    Returns:
        list: Arrays returned by ``cv.imread(path, cv.IMREAD_GRAYSCALE)``, in
        sorted filename order.

    Raises:
        OSError: If ``data_path`` cannot be listed.
    """
    images = []
    for filename in sorted(os.listdir(data_path)):
        if not filename.lower().endswith(('.tif', '.tiff')):
            continue
        # NOTE(review): IMREAD_GRAYSCALE yields single-channel 8-bit data; if
        # these TIFFs hold higher-precision values, cv.IMREAD_UNCHANGED may be
        # what is wanted -- confirm against how the files were written.
        img = cv.imread(os.path.join(data_path, filename), cv.IMREAD_GRAYSCALE)
        if img is None:
            # Unreadable or corrupt file: leave it out rather than pass None
            # on to the plotting code.
            continue
        images.append(img)
    return images

In [None]:
# Print first 5 images

def print_images(images, count=5):
    """Display the first ``count`` images with a gray colormap, one figure each.

    Fewer than ``count`` images (including none at all) is fine: only the
    images that exist are shown.

    Args:
        images: Sliceable sequence of images accepted by ``plt.imshow``.
        count: Maximum number of images to show; defaults to 5. Negative
            values are treated as 0.

    Returns:
        None
    """
    # Slicing instead of indexing range(count) avoids an IndexError when the
    # sequence is shorter than ``count``.
    for img in images[:max(0, count)]:
        plt.imshow(img, cmap='gray')
        plt.show()

In [None]:
# Main function

def main():
    """Read the images in the ``mu`` sub-folder of the dataset and preview them.

    The folder is derived from ``params.data_path`` rather than repeating the
    absolute path, so the dataset location is configured in one place.
    """
    data_path = os.path.join(params.data_path, 'mu')
    images = read_images(data_path)
    print_images(images)

if __name__ == '__main__':
    main()