In [4]:
import numpy as np
import cv2
import torch
from time import time
import os
import datetime
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils import data
from image_utils import dense_warp, warp
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Every frame is resized to this (height, width) before stabilization.
height, width = 360, 640
batch_size = 1
# Warp mesh resolution in cells; StabNet predicts (grid_h + 1) x (grid_w + 1) control points.
grid_h, grid_w = 15, 15

In [5]:
def get_warp(net_out, img, out_size=None):
    '''
    Warp `img` with a coarse mesh of control-point offsets predicted by StabNet.

    A regular mesh of control points is laid over the image in normalized
    [-1, 1] coordinates. The predicted offsets are added to it, and the result
    is bilinearly upsampled to a dense per-pixel sampling grid. `img` is then
    resampled with `F.grid_sample`; samples falling outside the image become 0.

    Inputs:
        net_out: tensor of shape (B, mesh_h, mesh_w, 2). The last dim is the
                 (x, y) offset in normalized coordinates. The mesh size is read
                 from this tensor; StabNet produces (grid_h + 1, grid_w + 1).
        img: image batch to warp, shape (B, C, H, W), on the same device as net_out.
        out_size: optional (H_out, W_out) of the result. Defaults to img's
                  spatial size.
    Returns:
        Warped image batch of shape (B, C, H_out, W_out).
    '''
    _, mesh_h, mesh_w, _ = net_out.shape
    if out_size is None:
        out_size = tuple(img.shape[-2:])

    # Identity mesh with indexing='ij': rows follow y (mesh_h points) and
    # columns follow x (mesh_w points).
    ys = torch.linspace(-1, 1, mesh_h, device=net_out.device, dtype=net_out.dtype)
    xs = torch.linspace(-1, 1, mesh_w, device=net_out.device, dtype=net_out.dtype)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    # Shape (1, mesh_h, mesh_w, 2); broadcasts over the batch of offsets.
    src_grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
    new_grid = src_grid + net_out

    # Upsample the control-point mesh to one sample location per output pixel.
    # align_corners=True pins the outer control points to the image corners.
    dense_grid = F.interpolate(new_grid.permute(0, 3, 1, 2), size=out_size,
                               mode='bilinear', align_corners=True)

    # NOTE(review): grid_sample uses align_corners=False while the mesh above is
    # corner-aligned. As a result a zero offset is not an exact identity: there
    # is a slight scaling, and border pixels blend with the zero padding. This
    # is kept as-is because the checkpoint was presumably trained with this
    # warp; confirm before changing.
    return F.grid_sample(img, dense_grid.permute(0, 2, 3, 1),
                         align_corners=False, padding_mode='zeros')

In [3]:
class StabNet(nn.Module):
    '''
    Mesh-warp regressor used for video stabilization.

    Architecture:
      * the convolutional part of an ImageNet-pretrained VGG19, with the first
        conv widened from 3 to 9 input channels and the final max-pool removed;
      * global average pooling;
      * an MLP that outputs one 2-D offset per control point of the
        (grid_h + 1) x (grid_w + 1) mesh consumed by `get_warp`.

    Input:  (B, 9, H, W). In this notebook channels 0-2 are the current frame
            and channels 3-8 are six grayscale history frames.
    Output: (B, grid_h + 1, grid_w + 1, 2) mesh offsets.

    Args:
        trainable_layers: controls how many trailing VGG19 parameter tensors
            stay trainable. See the NOTE in __init__ for the exact semantics.
    '''

    def __init__(self, trainable_layers = 10):
        super().__init__()
        # Pretrained VGG19 backbone (torchvision fetches/caches the ImageNet weights).
        vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')

        # Pretrained conv1 kernel for 3-channel input: torch.Size([64, 3, 3, 3]).
        rgb_weights = vgg19.features[0].weight.clone()
        # The 6 extra input channels start from the channel-averaged kernel:
        # torch.Size([64, 6, 3, 3]).
        extra_weights = rgb_weights.mean(dim=1, keepdim=True).repeat(1, 6, 1, 1)
        # Replace conv1 with a 9-input-channel version.
        # NOTE: bias=False discards the pretrained conv1 bias. Changing it would
        # also change the state_dict keys that saved checkpoints are loaded against.
        vgg19.features[0] = nn.Conv2d(9, 64, kernel_size=3, stride=1, padding=1, bias=False)
        vgg19.features[0].weight = nn.Parameter(torch.cat((rgb_weights, extra_weights), dim=1))

        # Freeze all but the trailing VGG19 parameter tensors.
        # NOTE(review): this indexes parameter *tensors* (weights and biases, not
        # layers) over the whole VGG19, including the classifier that is discarded
        # below. The cutoff (total - trainable_layers + 1) leaves only
        # trainable_layers - 1 tensors trainable. With the default of 10, the only
        # trainable encoder tensors are the last conv's weight and bias plus the
        # bias of the conv before it: 2,360,320 parameters, the count printed when
        # this notebook ran. Kept unchanged to reproduce the training setup;
        # confirm intent before changing.
        vgg_params = list(vgg19.parameters())
        first_trainable = len(vgg_params) - trainable_layers + 1
        for idx, param in enumerate(vgg_params):
            param.requires_grad = idx >= first_trainable

        # Convolutional feature extractor without its final max-pool layer.
        self.encoder = nn.Sequential(*vgg19.features[:-1])
        # MLP: pooled 512-d feature -> 2 offsets per mesh control point.
        self.regressor = nn.Sequential(nn.Linear(512, 2048),
                                       nn.ReLU(),
                                       nn.Linear(2048, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 512),
                                       nn.ReLU(),
                                       nn.Linear(512, (grid_h + 1) * (grid_w + 1) * 2))

        encoder_trainable = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)
        regressor_trainable = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)
        print("Total Trainable encoder Parameters: ", encoder_trainable)
        print("Total Trainable regressor Parameters: ", regressor_trainable)
        print("Total Trainable parameters:", regressor_trainable + encoder_trainable)

    def forward(self, x_tensor):
        '''
        Args:
            x_tensor: (B, 9, H, W) input (current frame + history; see the class doc).
        Returns:
            (B, grid_h + 1, grid_w + 1, 2) mesh offsets for `get_warp`.
        '''
        features = self.encoder(x_tensor)
        # Global average pooling over the spatial dims -> (B, 512).
        pooled = features.mean(dim=(2, 3))
        offsets = self.regressor(pooled)
        return offsets.view(x_tensor.size(0), grid_h + 1, grid_w + 1, 2)

In [10]:
ckpt_dir = './ckpts/original/'
stabnet = StabNet().to(device).eval()


def _ckpt_timestamp(filename):
    '''Return the save time encoded in a name like "stabnet_2023-10-26_13-42-14.pth".

    Both the date and the time of day are parsed, so checkpoints saved on
    different days are ordered correctly.
    '''
    stem = os.path.splitext(filename)[0]
    date_and_time = '_'.join(stem.split('_')[-2:])
    return datetime.datetime.strptime(date_and_time, "%Y-%m-%d_%H-%M-%S")


# Only consider checkpoint files, so stray files in the folder do not break parsing.
ckpts = [name for name in os.listdir(ckpt_dir) if name.endswith(('.pth', '.pt'))]
if ckpts:
    # Most recent checkpoint by its encoded date + time.
    latest = os.path.join(ckpt_dir, max(ckpts, key=_ckpt_timestamp))
    # map_location lets a checkpoint saved on another device load onto `device`.
    # torch.load unpickles the file, so only load checkpoints you trust.
    state = torch.load(latest, map_location=device)
    stabnet.load_state_dict(state['model'])
    print('loaded weights',latest)
else:
    print('no checkpoint found in', ckpt_dir, '- StabNet keeps its initial weights')

Total Trainable encoder Parameters:  2360320
Total Trainable regressor Parameters:  3936256
Total Trainable parameters: 6296576
loaded weights ./ckpts/original/stabnet_2023-10-26_13-42-14.pth


In [7]:
# NOTE(review): absolute, machine-specific path; consider a configurable data dir.
path = 'E:/Datasets/DeepStab_Dataset/unstable/2.avi'
cap = cv2.VideoCapture(path)
if not cap.isOpened():
    raise FileNotFoundError(f'could not open video: {path}')
frames = []
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # cv2.resize takes (width, height); frames stay BGR uint8 as decoded.
        frames.append(cv2.resize(frame, (width, height)))
finally:
    # Always free the decoder, even if reading or resizing fails.
    cap.release()
if not frames:
    raise ValueError(f'no frames decoded from {path}')
# Shape (num_frames, height, width, 3), uint8.
frames = np.stack(frames)

In [8]:
# uint8 BGR (N, H, W, 3) -> float32 (N, 3, H, W) in [0, 1].
# Casting to float32 before dividing avoids the float64 copy that
# `frames / 255.0` would allocate (about 2.5 GB for this 447-frame clip).
frames_t = torch.from_numpy(frames).permute(0, 3, 1, 2).float().div_(255.0)
frames_t.shape

torch.Size([447, 3, 360, 640])

In [9]:
num_frames, _, h, w = frames_t.shape
# Stabilized frames are written back in place, so later history uses warped frames.
warped_frames = frames_t.clone()
# History frames fed to the network: idx-1, idx-2, idx-4, ..., idx-32.
history_offsets = torch.tensor([2 ** i for i in range(6)])
# First stabilized frame. The largest offset is 32, so any idx >= 32 has a full
# history; 33 is the value this notebook was run with.
start_idx = 33
processed = 0

cv2.namedWindow('window', cv2.WINDOW_NORMAL)
start = time()
try:
    with torch.no_grad():
        for idx in range(start_idx, num_frames):
            # Grayscale (channel-mean) history, shape (6, 1, H, W).
            history = warped_frames[idx - history_offsets].mean(dim=1, keepdim=True)
            curr = warped_frames[idx:idx + 1]
            # Network input (1, 9, H, W): current frame + 6 history channels.
            net_in = torch.cat([curr, history.permute(1, 0, 2, 3)], dim=1).to(device)
            transform = stabnet(net_in)
            warped = get_warp(transform, curr.to(device))
            warped_frames[idx:idx + 1] = warped.cpu()
            processed += 1

            # Round (not truncate) to uint8 so float round-off does not lose a level.
            img = warped_frames[idx].permute(1, 2, 0).numpy() * 255.0
            cv2.imshow('window', np.clip(np.rint(img), 0, 255).astype(np.uint8))
            if cv2.waitKey(1) & 0xFF == ord(' '):
                break
finally:
    # Close the preview even if the loop is interrupted.
    cv2.destroyAllWindows()
total = time() - start
# Average over the frames actually stabilized (includes display time).
speed = total / max(processed, 1)
print(f'speed: {speed} seconds per frame')

KeyboardInterrupt: 

In [17]:
from time import sleep

os.makedirs('./results', exist_ok=True)
frame_h, frame_w = warped_frames.shape[-2:]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# VideoWriter's frameSize is (width, height) and must match every written frame.
# Frames of a different size are not encoded, and OpenCV typically does not raise.
out = cv2.VideoWriter('./results/2.avi', fourcc, 30.0, (int(frame_w), int(frame_h)))
if not out.isOpened():
    raise RuntimeError('could not open ./results/2.avi for writing')
cv2.namedWindow('window', cv2.WINDOW_NORMAL)
try:
    for idx in range(num_frames):
        img = warped_frames[idx].permute(1, 2, 0).numpy()
        # Round (not truncate) back to uint8 BGR.
        img = np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)
        out.write(img)
        cv2.imshow('window', img)
        sleep(1 / 30)
        if cv2.waitKey(1) & 0xFF == ord(' '):
            break
finally:
    # Finalize the file and close the preview even on interrupt.
    cv2.destroyAllWindows()
    out.release()

In [None]:
# Imported here too so this cell does not depend on another cell having run.
from time import sleep

cv2.namedWindow('window', cv2.WINDOW_NORMAL)
try:
    for idx in range(num_frames):
        img = warped_frames[idx].permute(1, 2, 0).numpy()
        # Round (not truncate) so float round-off does not show up as 1-level diffs.
        img = np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)
        # Per-pixel |stabilized - original| to visualise what the warp changed.
        diff = cv2.absdiff(img, frames[idx])
        cv2.imshow('window', diff)
        sleep(1 / 30)
        if cv2.waitKey(1) & 0xFF == ord(' '):
            break
finally:
    cv2.destroyAllWindows()

In [18]:
from metrics import metric

# Score a stabilized video against the unstable source clip. The returned tuple
# (cropping, distortion, stability, pixel) is the cell output.
# NOTE(review): this evaluates './results/Regular_2.avi', not './results/2.avi'
# that the VideoWriter cell produces. Confirm which result is meant.
unstable_video = 'E:/Datasets/DeepStab_Dataset/unstable/2.avi'
stabilized_video = './results/Regular_2.avi'
metric(unstable_video, stabilized_video)

Frame: 446/447
cropping score:1.000	distortion score:0.989	stability:0.639	pixel:0.997


(1.0, 0.98863894, 0.6389072784994596, 0.99749401723966)