In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

In [2]:
# Define the UNet model class
class UNet(nn.Module):
    """U-Net style encoder/decoder producing a single-channel logit map.

    Input is an (N, 3, H, W) tensor; output is (N, 1, H, W) raw logits (no
    sigmoid is applied here -- callers apply ``torch.sigmoid`` themselves).
    H and W should be divisible by 16 so the four max-pool stages divide evenly;
    other sizes still run because ``crop_and_concat`` pads/crops to the skip size.

    Args:
        double_upsample: When True (default), reproduce the original forward
            pass, which upsamples x2 both *before and after* every decoder
            block. That extra upsample makes each decoder output twice the size
            of its skip connection, so ``crop_and_concat`` centre-crops it
            (negative padding): the outer ring of the decoder features is
            discarded and the remainder is spatially misaligned with the skip.
            It stays the default so checkpoints trained with the original
            forward keep behaving identically (NOTE(review): presumably this
            includes ``unet_model.pth`` -- confirm against the training code).
            Pass False for the conventional single-upsample decoder; parameter
            shapes and state_dict keys are identical, but a model should be
            trained and evaluated with the same setting.
    """

    def __init__(self, double_upsample: bool = True):
        super().__init__()
        self.double_upsample = double_upsample

        # Contracting (encoder) path: channel count doubles at each level.
        self.encoder1 = self.contracting_block(3, 64)
        self.encoder2 = self.contracting_block(64, 128)
        self.encoder3 = self.contracting_block(128, 256)
        self.encoder4 = self.contracting_block(256, 512)
        self.encoder5 = self.contracting_block(512, 1024)

        # Expansive (decoder) path. Every stage after the first consumes the
        # concatenation of the previous decoder output and the matching encoder
        # output, which is why the input channel counts are doubled.
        self.decoder1 = self.expansive_block(1024, 512)
        self.decoder2 = self.expansive_block(1024, 256)  # decoder1 (512) + encoder4 (512)
        self.decoder3 = self.expansive_block(512, 128)   # decoder2 (256) + encoder3 (256)
        self.decoder4 = self.expansive_block(256, 64)    # decoder3 (128) + encoder2 (128)
        self.final_layer = nn.Conv2d(128, 1, kernel_size=1)  # decoder4 (64) + encoder1 (64)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """(3x3 Conv -> BatchNorm -> ReLU) twice; padding=1 preserves H and W.

        The Sequential layout (indices 0-5) defines the state_dict keys, so it
        must stay in this order for saved checkpoints to load.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def contracting_block(self, in_channels, out_channels):
        """Encoder block: two same-padded 3x3 convs, each followed by BatchNorm and ReLU."""
        return self._double_conv(in_channels, out_channels)

    def expansive_block(self, in_channels, out_channels):
        """Decoder block: same layer stack as the encoder block."""
        return self._double_conv(in_channels, out_channels)

    def crop_and_concat(self, upsampled, bypass):
        """Pad (or, with negative amounts, centre-crop) ``upsampled`` to ``bypass``'s
        spatial size, then concatenate the two along the channel dimension."""
        diff_y = bypass.size(2) - upsampled.size(2)
        diff_x = bypass.size(3) - upsampled.size(3)
        upsampled = F.pad(upsampled, (diff_x // 2, diff_x - diff_x // 2,
                                      diff_y // 2, diff_y - diff_y // 2))
        return torch.cat((upsampled, bypass), dim=1)

    @staticmethod
    def _upsample(x):
        """Bilinear x2 upsampling (align_corners=False is the PyTorch default; passed explicitly)."""
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

    def _decode(self, decoder, x, skip):
        """Upsample ``x``, run ``decoder``, optionally upsample again (legacy), and fuse with ``skip``."""
        out = decoder(self._upsample(x))
        if self.double_upsample:
            # Legacy behaviour: the second upsample is centre-cropped away in crop_and_concat.
            out = self._upsample(out)
        return self.crop_and_concat(out, skip)

    def forward(self, x):
        # Encoder path: each level halves H and W via 2x2 max pooling.
        e1 = self.encoder1(x)
        e2 = self.encoder2(F.max_pool2d(e1, kernel_size=2, stride=2))
        e3 = self.encoder3(F.max_pool2d(e2, kernel_size=2, stride=2))
        e4 = self.encoder4(F.max_pool2d(e3, kernel_size=2, stride=2))
        e5 = self.encoder5(F.max_pool2d(e4, kernel_size=2, stride=2))

        # Decoder path with skip connections back to the encoder.
        d1 = self._decode(self.decoder1, e5, e4)
        d2 = self._decode(self.decoder2, d1, e3)
        d3 = self._decode(self.decoder3, d2, e2)
        d4 = self._decode(self.decoder4, d3, e1)

        # 1x1 conv to a single-channel logit map.
        return self.final_layer(d4)

model = UNet()  # Instantiate the UNet model (legacy forward, matching the original code)

In [3]:
# Import necessary modules
import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms

# Inference preprocessing: numpy HxWxC array -> PIL image -> resized to 256x256
# -> float tensor in [0, 1] -> (x - 0.5) / 0.5, i.e. roughly [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        # One mean/std value, applied to every channel.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Define the function to perform inference on a single image
def inference_on_image(model, image_path, transform, device, threshold=0.5):
    """Segment one image with ``model`` and return it with the predicted mask overlaid.

    Side effects: prints the raw logit range and the unique values of the
    thresholded mask, and writes two files to the current working directory,
    both resized (nearest-neighbour) to the input image size:
    ``raw_output_image.jpg`` (sigmoid probabilities scaled to 0-255) and
    ``binary_output_image.jpg`` (thresholded mask as 0/255).

    Args:
        model: network mapping the transformed (1, C, h, w) input to a logit map;
            presumably a single-channel (1, 1, h, w) output so ``squeeze()``
            yields a 2-D mask -- TODO confirm for other models. Should already
            be in eval mode and on ``device``.
        image_path: path readable by ``cv2.imread``.
        transform: callable turning an RGB uint8 HxWx3 array into a CHW tensor.
        device: torch device the input tensor is moved to.
        threshold: probability cut-off for the binary mask (default 0.5).

    Returns:
        RGB uint8 array at the original resolution, with mask pixels set to (0, 255, 0).

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``image_path``.
    """
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    height, width = image_bgr.shape[:2]
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    input_image = transform(image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        logits = model(input_image)
        raw_output = logits.cpu().numpy().squeeze()
        print(f"Raw model outputs range: {raw_output.min()} to {raw_output.max()}")
        # Convert to numpy once and reuse for every visualisation below.
        probabilities = torch.sigmoid(logits).cpu().numpy().squeeze()

    # Probability map as an 8-bit image at the original size.
    raw_output_image = (probabilities * 255).astype(np.uint8)
    raw_output_image = cv2.resize(raw_output_image, (width, height), interpolation=cv2.INTER_NEAREST)
    cv2.imwrite('raw_output_image.jpg', raw_output_image)

    binary_output = (probabilities > threshold).astype(np.uint8)
    print(f"Unique values in the binary mask: {np.unique(binary_output)}")

    # Resize the 0/1 mask once; nearest-neighbour keeps it strictly binary.
    output_resized = cv2.resize(binary_output, (width, height), interpolation=cv2.INTER_NEAREST)
    cv2.imwrite('binary_output_image.jpg', output_resized * 255)

    # Superimpose the result on the original (RGB) image.
    superimposed_image = image.copy()
    superimposed_image[output_resized == 1] = [0, 255, 0]  # Highlight drivable area with green color
    return superimposed_image

# Perform inference and save the result.
# Pick the device first so the checkpoint can be mapped onto it: without
# map_location, a checkpoint saved from a GPU fails to load on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (newer PyTorch versions also accept weights_only=True to restrict this).
model.load_state_dict(torch.load('unet_model.pth', map_location=device))
model.to(device)
model.eval()  # BatchNorm must use running statistics at inference time
input_image_path = 'road_image_191.png'  # Replace with your image path
result_image = inference_on_image(model, input_image_path, transform, device)

# Save the superimposed image (result_image is an RGB array, which PIL expects).
Image.fromarray(result_image).save('superimposed_image.jpg')
print("Inference on image completed and saved!")


Raw model outputs range: -7.473902702331543 to 0.8656378388404846
Unique values in the binary mask: [0 1]
Inference on image completed and saved!


In [4]:
# Inference on a Video
def inference_on_video(model, video_path, transform, device,
                       output_path='output_video.avi', fps=None, threshold=0.5):
    """Overlay the model's predicted mask on every frame of a video and save the result.

    Each frame is converted BGR->RGB, passed through ``transform`` and ``model``.
    Pixels where sigmoid(logits) > ``threshold`` are painted (0, 255, 0) on a
    copy of the original BGR frame, and that copy is written to ``output_path``
    with the XVID codec. Progress is shown with tqdm.

    Args:
        model: segmentation network, already in eval mode and on ``device``;
            presumably single-channel output so ``squeeze()`` yields a 2-D mask.
        video_path: input video readable by ``cv2.VideoCapture``.
        transform: callable turning an RGB uint8 HxWx3 array into a CHW tensor.
        device: torch device the input tensors are moved to.
        output_path: destination file (default keeps the original hardcoded name).
        fps: output frame rate. None uses the source video's FPS, falling back
            to 20.0 when OpenCV cannot report it.
        threshold: probability cut-off for the mask (default 0.5).

    Raises:
        OSError: if the input video or the output writer cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {video_path}")

        # Named properties instead of the magic indices 3 / 4.
        frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                      int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        if fps is None:
            # Match the source frame rate so the output plays at the original
            # speed; some backends report 0, hence the fallback.
            source_fps = cap.get(cv2.CAP_PROP_FPS)
            fps = source_fps if source_fps > 0 else 20.0
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
        if not out.isOpened():
            raise OSError(f"Could not open video writer for: {output_path}")

        # The frame count may be unavailable (<= 0); show an open-ended bar then.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        with tqdm(total=total_frames if total_frames > 0 else None,
                  desc="Processing Video", unit="frame") as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break  # end of stream (or read failure)
                height, width = frame.shape[:2]

                # Preprocess the frame (model expects RGB) and add a batch dimension.
                input_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                input_frame = transform(input_frame).unsqueeze(0).to(device)

                with torch.no_grad():
                    probabilities = torch.sigmoid(model(input_frame)).cpu().numpy().squeeze()
                mask = (probabilities > threshold).astype(np.uint8)

                # Nearest-neighbour keeps the resized mask strictly 0/1.
                mask_resized = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

                superimposed_frame = frame.copy()
                superimposed_frame[mask_resized == 1] = [0, 255, 0]  # Highlight drivable area with green
                out.write(superimposed_frame)
                pbar.update(1)
    finally:
        # Release handles even if inference raises part-way through.
        cap.release()
        if out is not None:
            out.release()

    print("Inference on video completed and saved!")

# Run the video pipeline on a sample clip; the result is written by inference_on_video.
input_video_path = 'video_road_driving.mp4'  # Replace with your video path
inference_on_video(model, input_video_path, transform=transform, device=device)

Processing Video: 100%|████████████████████| 391/391 [00:08<00:00, 44.08frame/s]

Inference on video completed and saved!



