In [43]:
# Imported here because this cell executes before the notebook's main import
# cell; without it a fresh kernel fails with NameError on `torch`.
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {DEVICE}")

device: cuda


## Install and import the required libraries

In [44]:
pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [45]:
!python -m pip install pyyaml==5.1

!pip install ffmpeg-python

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [46]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import cv2
import ffmpeg
from tqdm import tqdm

## Mount Google Drive and specify the data path

In [47]:
# Colab only: make the user's Google Drive available under /content/drive so
# the checkpoint and the input video can be read from it.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [48]:
# Folder that holds the model checkpoint and the input video.
# change your path here
input_folder_path = (
    "/content/drive/MyDrive"            # root of the mounted Drive
    "/Colab Notebooks/data/Project"     # project folder inside the Drive
)

## Define the U-Net architecture and load the model

In [49]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    The 3x3 convolutions use ``padding=1``, so height and width are unchanged;
    only the channel count goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer order determine the state_dict keys
        # ("conv.0.*", "conv.1.*", "conv.3.*", "conv.4.*").
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        features = self.conv(x)
        return features
    
class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Pool first, then convolve; the child order fixes the state_dict keys
        # ("mpconv.1.conv.*").
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Downsample ``x`` by 2 and run it through the double convolution."""
        downsampled = self.mpconv(x)
        return downsampled
    
class Up(nn.Module):
    """Decoder stage: upsample, align with the skip connection, concatenate, convolve.

    Args:
        in_channels: channel count of the concatenated tensor (skip + upsampled).
        out_channels: channel count produced by the stage.
        bilinear: if True use parameter-free bilinear upsampling, otherwise a
            learned transposed convolution over ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        half_channels = in_channels // 2
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(half_channels, half_channels, kernel_size=2, stride=2)
        )

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of ``x2`` and fuse both.

        ``x2`` is the skip-connection tensor; it is placed first in the channel
        concatenation.
        """
        upsampled = self.up(x1)

        # Size gap between the skip tensor and the upsampled tensor (dim 2 = height,
        # dim 3 = width); it is split as evenly as possible between the two sides.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        top, left = gap_h // 2, gap_w // 2

        # F.pad takes the last dimension first: (left, right, top, bottom).
        upsampled = F.pad(upsampled, (left, gap_w - left, top, gap_h - top))

        fused = torch.cat((x2, upsampled), dim=1)
        return self.conv(fused)

class OutConv(nn.Module):
    """Output head: a single 1x1 convolution mapping features to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel_size=1 mixes channels only; spatial size is untouched.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the 1x1-convolved tensor (no activation is applied here)."""
        projected = self.conv(x)
        return projected
    
class UNet(nn.Module):
    """U-Net with four downsampling and four upsampling stages.

    Args:
        n_channels: number of input channels (default 4).
        n_classes: number of output channels / classes (default 3).
        bilinear: forwarded to every ``Up`` stage to select the upsampling mode.

    ``forward`` returns raw logits with ``n_classes`` channels.
    """

    def __init__(self, n_channels=4, n_classes=3, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)

        # Encoder: (in_channels, out_channels) for down1..down4.
        encoder_cfg = ((64, 128), (128, 256), (256, 512), (512, 512))
        for idx, (c_in, c_out) in enumerate(encoder_cfg, start=1):
            setattr(self, f"down{idx}", Down(c_in, c_out))

        # Decoder: in_channels counts the concatenated skip + upsampled channels.
        decoder_cfg = ((1024, 256), (512, 128), (256, 64), (128, 64))
        for idx, (c_in, c_out) in enumerate(decoder_cfg, start=1):
            setattr(self, f"up{idx}", Up(c_in, c_out, bilinear))

        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the head."""
        # Encoder: keep every intermediate feature map for the skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Decoder: start from the deepest map and consume the skips deepest-first.
        out = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, skips.pop())

        logits = self.outc(out)
        return logits

In [50]:
# 4 input channels (3 image channels + 1 depth channel, concatenated in the
# inference loop) and 3 output classes; moved to the selected device.
model = UNet(n_channels=4, n_classes=3).to(DEVICE)

In [51]:
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved on
# a GPU still loads when the notebook falls back to the CPU.
checkpoint = torch.load(
    os.path.join(input_folder_path, "my_checkpoint.pth.tar"),
    map_location=DEVICE,
)
load_result = model.load_state_dict(checkpoint["state_dict"])

# The model is only used for inference below: switch BatchNorm to its stored
# running statistics instead of per-batch statistics (and stop updating them).
model.eval()

# Last expression of the cell: shows which keys matched.
load_result

<All keys matched successfully>

## Load the target video

In [52]:
# Sanity check that the input video is reachable on the mounted Drive.
video_path = os.path.join(input_folder_path, "vdo_3.mp4")  #change the video name here
if os.path.isfile(video_path):
    print('found')
else:
    # Report the missing file instead of staying silent.
    print(f'not found: {video_path}')

found


In [53]:
video_path = os.path.join(input_folder_path, "vdo_3.mp4")  #change the video name here
vidcap = cv2.VideoCapture(video_path)
# VideoCapture does not raise on a bad path; without this check the inference
# loop would silently process zero frames.
if not vidcap.isOpened():
    raise IOError(f"Could not open video: {video_path}")

# Frame count reported by the container; used as the progress-bar total.
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

# Every frame is resized to this size before inference and encoding.
target_img_w = 720
target_img_h = 480

## Load MiDaS to create depth maps

In [54]:
# Depth-estimation network from the MiDaS hub repo, moved to DEVICE and put in
# inference mode (both calls return the module itself).
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large").to(DEVICE).eval()

# Pre-processing transforms published alongside the models; the DPT variant is
# the one paired with the DPT_Large model loaded above.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform

Using cache found in /root/.cache/torch/hub/intel-isl_MiDaS_master
Using cache found in /root/.cache/torch/hub/intel-isl_MiDaS_master


In [55]:
def get_depth_map(img, midas, transform, device):
    """Predict a depth map for ``img`` and resize it to the image's height/width.

    Args:
        img: image array; only ``img.shape[:2]`` (height, width) is read here,
            the array itself is passed to ``transform``.
        midas: callable depth model. Its output gets a channel axis inserted at
            dim 1 before resizing, so a (batch, H, W) output is presumably
            expected - confirm against the model used.
        transform: callable turning ``img`` into the model's input tensor.
        device: device the input tensor is moved to.

    Returns:
        NumPy array with the prediction, resized with nearest-exact
        interpolation and with all size-1 dimensions squeezed out.
    """
    batch = transform(img).to(device)
    target_size = img.shape[:2]

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        raw = midas(batch)
        resized = F.interpolate(
            raw.unsqueeze(1),
            size=target_size,
            mode="nearest-exact",
        )
        depth = resized.squeeze()

    return depth.cpu().numpy()

## Get the video output 

In [56]:
# RGB colour per class id (row index == class id). RGB order because the raw
# frames piped to ffmpeg below are declared as rgb24.
class_colors = np.array(
    [[61, 11, 81], [69, 142, 139], [250, 230, 85]], dtype=np.uint8)
detected_class = 1  # class id whose presence triggers the "Detected" label

# Rewind the capture so re-running this cell processes the video from frame 0
# instead of finding the stream already exhausted.
vidcap.set(cv2.CAP_PROP_POS_FRAMES, 0)

# Raw frames carry no timing information and ffmpeg would assume 25 fps;
# reuse the source frame rate so the output keeps the original duration.
fps = vidcap.get(cv2.CAP_PROP_FPS)
if not fps > 0:  # also catches NaN when the backend cannot report it
    fps = 25

# Output Stream
out_stream = (
    ffmpeg
    .input('pipe:', format='rawvideo', pix_fmt='rgb24',
           s='{}x{}'.format(target_img_w, target_img_h), framerate=fps)
    # yuv420p is the widely playable pixel format for H.264/mp4; rgb24 is not
    # supported by the default mp4 video encoder.
    .output('output3.mp4', pix_fmt='yuv420p')
    .overwrite_output()
    .run_async(pipe_stdin=True)
)

# NOTE(review): one full-resolution mask is kept per frame; for long videos
# this list grows large - drop it if nothing downstream needs it.
predictions = []
try:
    with tqdm(total=frame_count, position=0, leave=True) as pbar:
        success, image = vidcap.read()
        while success:
            image = cv2.resize(
                image, (target_img_w, target_img_h),
                interpolation=cv2.INTER_LINEAR)

            # Generate the depth map using the MiDaS model
            depth_map = get_depth_map(image, midas, transform, DEVICE)

            # Convert the image and depth map to PyTorch tensors: image scaled to
            # [0, 1] as (1, 3, H, W), depth as (1, 1, H, W), concatenated on channels
            image_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
            depth_map_tensor = torch.tensor(depth_map, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            image_tensor = torch.cat([image_tensor, depth_map_tensor], dim=1)
            image_tensor = image_tensor.to(DEVICE)

            # Use the U-Net model for segmentation; no_grad avoids building an
            # autograd graph for every frame.
            with torch.no_grad():
                logits = model(image_tensor)
            pred_masks = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

            # Visualize the segmentation results: palette lookup gives an
            # (H, W, 3) uint8 frame in a single step.
            out_image = class_colors[pred_masks]

            # Label the frame only when the class is actually present, and draw
            # after colouring so the text is not painted over.
            if (pred_masks == detected_class).any():
                cv2.putText(
                    out_image, "Detected", (50, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255),
                    2)  # 7th positional argument of putText is the thickness

            out_stream.stdin.write(out_image.tobytes())
            predictions.append(pred_masks)
            pbar.update()
            success, image = vidcap.read()
finally:
    # Always signal end-of-input, then wait so output3.mp4 is fully written
    # before any later cell reads it.
    out_stream.stdin.close()
    out_stream.wait()


100%|██████████| 455/455 [02:58<00:00,  2.55it/s]
