### Libraries and imports

In [None]:
!pip -q install mediapy
!pip -q install skvideo

In [None]:
import skvideo.io
import mediapy as media
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import cv2
import numpy as np
import os

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)

device: cpu


In [None]:
# Step up one directory (presumably to the repository root — TODO confirm) so
# that the relative `lib`, checkpoint and `video/` paths used below resolve.
# NOTE: this cell is not idempotent — every re-run climbs one more level, so
# restart the kernel before running the notebook again.
os.chdir('../')

In [None]:
from lib.model.RT_distr_v2_conv1x1 import RT_MonoDepth_Mk2

### Load models

In [None]:
# Target size handed to T.Resize for every frame before inference.
SIZE = (256, 256)
# (mean, std) pair, one value per channel, unpacked into T.Normalize.
stats = ((0.5,) * 3, (0.5,) * 3)
# Registry of model variants keyed by an integer id. Each entry holds the
# network instance, the checkpoint file it is restored from, and a short name
# that becomes part of the output video file name.
MODELS = {
    0: dict(arch=RT_MonoDepth_Mk2(decode_distr=False),
            checkpoint='RT_MonoDepth_Mk2.PELoss.pt',
            name='conv_without_decoding_distribution'),
    1: dict(arch=RT_MonoDepth_Mk2(decode_distr=True),
            checkpoint='RT_MonoDepth_Mk2.PELoss_DD.pt',
            name='conv_with_decoding_distribution'),
}

### Video testing scripts

In [None]:
# Input video that is run through the model.
source = 'video/test.mp4'
# Output path *prefix*: the model name (plus `_smooth` for smoothed results)
# and the `.mp4` extension are appended when a video is saved.
target = 'video/result'

In [None]:
def test(video_path: str, model_num: int, show: bool):
    """Run a pretrained model over every frame of a video and save the result.

    The checkpoint registered under ``MODELS[model_num]`` is loaded, each
    decoded frame is resized to ``SIZE``, normalised with ``stats`` and passed
    through the network, and the stacked predictions are written as an 8-bit
    video to ``f'{target}_{name}.mp4'``.

    Args:
        video_path: Path of the input video (opened with ``cv2.VideoCapture``).
        model_num: Key into ``MODELS`` selecting the architecture and checkpoint.
        show: When true, the saved video is read back and displayed inline.

    Returns:
        A list with one CPU tensor per decoded frame (the squeezed model output).

    Raises:
        IOError: If ``video_path`` cannot be opened.
        RuntimeError: If no frame could be decoded from the video.
    """
    spec = MODELS[model_num]
    model = spec['arch']
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(spec['checkpoint'], map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()

    # Deterministic preprocessing only. The random ColorJitter that used to sit
    # here is a training-time augmentation; at inference it made the output
    # non-reproducible and perturbed every frame differently (temporal flicker).
    # NOTE(review): OpenCV decodes frames as BGR — confirm the channel order the
    # model was trained with.
    transform = T.Compose([T.ToTensor(),
                           T.Resize(SIZE),
                           T.Normalize(*stats)])

    res_frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f'cannot open video: {video_path}')
        # The reported frame count is only used to size the loop / progress bar;
        # the actual end of the stream is detected from the read() status flag.
        stop = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for _ in tqdm(range(stop), desc='processing video'):
            # Frames are read sequentially, so no per-frame seek is required.
            ok, frame = cap.read()
            if not ok:
                break
            # The input must live on the same device as the model.
            batch = transform(frame).unsqueeze(0).to(device)
            with torch.no_grad():
                result = model(batch)
            # Keep results on the CPU so they can be converted to NumPy later.
            res_frames.append(result.squeeze().cpu())
    finally:
        # Always free the decoder handle, even when inference fails.
        cap.release()

    if not res_frames:
        raise RuntimeError(f'no frames could be decoded from: {video_path}')

    save_path = f'{target}_{spec["name"]}.mp4'
    # Assumes predictions lie in [0, 1] — TODO confirm. Clamping and scaling by
    # 255 avoids the uint8 wrap-around of the former `* 256` (1.0 -> 256 -> 0).
    # Stacking also lets the output size follow the prediction shape instead of
    # a hard-coded 256x256 buffer.
    out_video = (torch.stack(res_frames).clamp(0, 1) * 255).round().to(torch.uint8).numpy()
    skvideo.io.vwrite(save_path, out_video)

    if show:
        video = media.read_video(save_path)
        media.show_video(video, fps=25)

    return res_frames

### Model with distribution in decoder

In [None]:
# Run model 1 (the variant with the distribution in the decoder) on the source
# clip; the per-frame predictions are kept for the smoothing step further down.
frms_2 = test(source, model_num=1, show=True)

processing video:   0%|          | 0/393 [00:00<?, ?it/s]

0
This browser does not support the video tag.


### Smooth Function

In [None]:
def smooth(ws: int, model_num: int, frames: list, show: bool):
    """Temporally smooth per-frame predictions with a trailing moving average.

    Every frame from index ``ws`` onward is replaced by the mean of itself and
    its ``ws`` predecessors, so the result contains ``len(frames) - ws`` frames.
    The smoothed sequence is written as an 8-bit video to
    ``f'{target}_{name}_smooth.mp4'``.

    Args:
        ws: Number of preceding frames averaged with the current one
            (the window holds ``ws + 1`` frames).
        model_num: Key into ``MODELS``; only its ``'name'`` is used, for the
            output file name.
        frames: Per-frame predictions (tensors of identical shape).
        show: When true, the saved video is read back and displayed inline.

    Raises:
        ValueError: If ``ws`` is negative or ``frames`` has ``ws`` or fewer
            entries, in which case no output frame can be formed.
    """
    if ws < 0:
        raise ValueError(f'ws must be non-negative, got {ws}')
    if len(frames) <= ws:
        raise ValueError(
            f'need more than {ws} frames to smooth with ws={ws}, got {len(frames)}')

    window = ws + 1
    # sum() starts from 0 and adds in list order: the ws predecessors first,
    # then the current frame.
    res_frames = [sum(frames[i - ws:i + 1]) / window
                  for i in range(ws, len(frames))]

    save_path = f'{target}_{MODELS[model_num]["name"]}_smooth.mp4'
    # Assumes predictions lie in [0, 1] — TODO confirm. Clamping and scaling by
    # 255 avoids the uint8 wrap-around of the former `* 256` (1.0 -> 256 -> 0);
    # .cpu() keeps the conversion valid for tensors that live on the GPU.
    out_video = (torch.stack(res_frames).clamp(0, 1) * 255).round().to(torch.uint8).cpu().numpy()
    skvideo.io.vwrite(save_path, out_video)

    if show:
        video = media.read_video(save_path)
        media.show_video(video, fps=25)

In [None]:
# Smooth model 1's predictions: each frame is averaged with its 2 predecessors.
smooth(2, 1, frms_2, show=True)

0
This browser does not support the video tag.
