In [1]:
# Import necessary libraries
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision import models
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the 2-class DeepLabV3-ResNet50 architecture and load the trained weights.
# weights_backbone=None skips downloading ImageNet backbone weights that the
# checkpoint overwrites anyway (and lets the notebook run offline).
model = models.segmentation.deeplabv3_resnet50(weights=None, weights_backbone=None, num_classes=2)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load('deeplabv3_rock_detection.pth', map_location=device)
# strict=False tolerates extra keys in the checkpoint (presumably auxiliary-classifier
# weights saved during training -- confirm), but it would also silently skip MISSING
# keys and leave those layers untrained, so refuse to continue in that case.
incompatible_keys = model.load_state_dict(checkpoint, strict=False)
if incompatible_keys.missing_keys:
    raise RuntimeError(f"Checkpoint is missing weights for: {incompatible_keys.missing_keys}")
model = model.to(device)
model.eval()

# Per-frame preprocessing applied before inference.
# NOTE(review): must match the training pipeline (input size, no Normalize) -- confirm.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

In [3]:
def process_frame(frame, model, transform, device, alpha=0.5, color=(0, 0, 255)):
    """Segment one video frame and return it with the predicted mask highlighted.

    Args:
        frame: BGR ``uint8`` image of shape (H, W, 3), as produced by
            ``cv2.VideoCapture.read``.
        model: Segmentation model whose output dict holds logits under ``'out'``.
        transform: Callable turning a PIL RGB image into a (C, h, w) tensor.
        device: Torch device the model lives on.
        alpha: Weight of ``color`` when blending it over predicted pixels
            (0 leaves the frame unchanged, 1 paints solid ``color``).
        color: BGR highlight colour for pixels predicted as a non-zero class.

    Returns:
        A new BGR ``uint8`` frame with the same shape as ``frame``. Pixels whose
        predicted class is non-zero are blended towards ``color``; every other
        pixel is copied unchanged.
    """
    # The transform expects an RGB PIL image; OpenCV frames are BGR.
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    input_image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_image)['out']
    # Take batch item 0 explicitly; .squeeze() could also drop a size-1 spatial dim.
    prediction = torch.argmax(logits[0], dim=0).cpu().numpy().astype(np.uint8)

    # Nearest-neighbour resampling keeps class ids intact when scaling the
    # low-resolution prediction back up to the frame size.
    height, width = frame.shape[:2]
    mask = np.array(
        Image.fromarray(prediction).resize((width, height), resample=Image.NEAREST)
    ) > 0

    # Blend only the predicted pixels. Blending the raw 0/1 class map over the
    # whole frame would merely halve its brightness and leave the mask invisible.
    overlay_frame = frame.copy()
    highlight = np.asarray(color, dtype=np.float32)
    blended = (1.0 - alpha) * frame[mask].astype(np.float32) + alpha * highlight
    overlay_frame[mask] = np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    return overlay_frame

In [4]:
def process_video(input_video_path, output_video_path, model, transform, device, fourcc='XVID'):
    """Run ``process_frame`` on every frame of a video and write the results to a new file.

    Args:
        input_video_path: Path of the video to read.
        output_video_path: Path of the video to write; its container must suit ``fourcc``.
        model: Forwarded unchanged to ``process_frame``.
        transform: Forwarded unchanged to ``process_frame``.
        device: Forwarded unchanged to ``process_frame``.
        fourcc: Four-character codec code for ``cv2.VideoWriter`` (default ``'XVID'``).

    Raises:
        OSError: If the input video cannot be opened or the output writer cannot
            be created (e.g. unsupported codec or unreadable FPS).
    """
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise OSError(f"Cannot open input video: {input_video_path}")

    out = None
    try:
        # Keep FPS as a float: truncating e.g. 29.97 to 29 makes the output drift out of sync.
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # The frame count reported by the backend may be inaccurate (or 0) for some
        # containers, so it only drives the progress bar, not the read loop.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*fourcc), fps, (width, height))
        if not out.isOpened():
            raise OSError(f"Cannot open video writer for: {output_video_path}")

        with tqdm(total=total_frames if total_frames > 0 else None, desc="Processing Video") as pbar:
            # Read until the stream is exhausted rather than trusting total_frames.
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                out.write(process_frame(frame, model, transform, device))
                pbar.update(1)
    finally:
        # Release on every exit path so the output file is finalised and the input freed.
        cap.release()
        if out is not None:
            out.release()

In [5]:
# Source clip to segment, and the destination for the annotated copy
# (an .avi container, matching the XVID codec the writer uses).
input_video_path, output_video_path = 'sample_video.mp4', 'deeplab_v3_output_video.avi'

In [6]:
# Execute inference over the whole input video and write the overlay video.
process_video(
    input_video_path=input_video_path,
    output_video_path=output_video_path,
    model=model,
    transform=transform,
    device=device,
)

Processing Video: 100%|███████████████████████| 901/901 [00:37<00:00, 24.29it/s]
