In [1]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from torchvision import models, transforms
from utils import FeatureExtractor, prepare_frame_for_inference, get_feature_vector
from utils import imshow_frame, cos, compute_similarity, compute_similarity_index_known

In [2]:
# Path of the corrupted input video (placeholder: set it before running).
video_path = 'path_of_corrupted_video'
video_frames = []

capture = cv2.VideoCapture(video_path)
# Fail loudly here instead of ending up with an empty frame list that only
# breaks later, when the first frame is accessed.
if not capture.isOpened():
    capture.release()
    raise IOError(f"Could not open video: {video_path!r}")

try:
    while True:
        ret, frame = capture.read()
        if not ret:
            # No more frames could be read.
            break
        # OpenCV decodes frames as BGR; keep them as RGB for the rest of the notebook.
        video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
finally:
    # Always free the capture handle, even if reading or conversion fails.
    capture.release()
print("Number of captured frames:", len(video_frames))

Number of captured frames: 114


In [3]:
# Preprocessing pipeline that turns a frame into a tensor the model accepts:
# resize the shorter side to 256, center-crop to 224x224, convert to a tensor,
# then normalize each channel with the given per-channel mean and std.
transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

base_model = models.vgg16(pretrained=True)
model = FeatureExtractor(base_model)
# The model is only used for inference: eval() switches layers such as dropout
# to their deterministic inference behavior so feature vectors are repeatable.
# Assumes FeatureExtractor is a torch.nn.Module (it is already moved with .to()).
model = model.to(device).eval()

In [4]:
# The first captured frame serves as the anchor (reference) frame that all
# other frames are compared against.
# NOTE(review): both helpers come from the project's utils module; presumably
# the first applies `transform` to the frame and the second runs the model on
# `device` -- confirm against utils.
anchor_img = prepare_frame_for_inference(video_frames[0], transform)
anchor_feature_vec = get_feature_vector(anchor_img, device, model)

In [5]:
# Parallel collections describing the frames that are kept:
#   feature_list  -- feature vectors as returned by the model helper
#   feature_array -- the same vectors flattened to 1-D NumPy arrays
#   index         -- positions of the kept frames inside video_frames
feature_list, feature_array, index = [], [], []

In [6]:
# Minimum cosine similarity to the anchor frame, in percent, for a frame to be
# considered part of the original video.
SIMILARITY_THRESHOLD = 75

# Start from empty collections so that re-running this cell does not append
# the same frames a second time.
feature_list, feature_array, index = [], [], []

# Frame 0 is the anchor itself, so the comparison starts at frame 1.
for i in range(1, len(video_frames)):
    img_tensor = prepare_frame_for_inference(video_frames[i], transform)
    feature_vec = get_feature_vector(img_tensor, device, model)
    # Similarity as a whole-number percentage (int() truncates toward zero).
    similarity = int(cos(anchor_feature_vec, feature_vec) * 100)
    if similarity >= SIMILARITY_THRESHOLD:
        feature_list.append(feature_vec)
        # detach().cpu() makes the NumPy conversion work even when the vector
        # lives on the GPU or still tracks gradients; it is a no-op otherwise.
        feature_array.append(feature_vec.detach().cpu().numpy().reshape(-1))
        index.append(i)
print("Number of frames belong to originel video:", len(feature_array))

Number of frames belong to originel video: 96


In [7]:
# The anchor (frame 0) was only used as the reference in the filtering loop,
# so add it to the kept set explicitly. The guard keeps the cell idempotent:
# re-running it must not add the anchor a second time.
if 0 not in index:
    feature_list.append(anchor_feature_vec)
    # detach().cpu() makes the NumPy conversion work even when the vector
    # lives on the GPU or still tracks gradients; it is a no-op otherwise.
    feature_array.append(anchor_feature_vec.detach().cpu().numpy().reshape(-1))
    index.append(0)

In [8]:
# Collect the kept frames in the same order as `index`, so that position k in
# matched_video_frames corresponds to position k in feature_list.
matched_video_frames = [video_frames[frame_idx] for frame_idx in index]
len(matched_video_frames)

97

In [9]:
# Stack all kept feature vectors along a new leading dimension and average
# over it, giving one mean feature vector with the shape of a single entry.
general_mean = torch.stack(feature_list).mean(dim=0)
general_mean.shape

torch.Size([1, 4096])

In [10]:
# Compare every kept feature vector with the mean feature vector.
# NOTE(review): the meaning of the third positional argument (False) and the
# ordering of the returned entries are defined in utils.compute_similarity --
# confirm there.
similarity_to_mean = compute_similarity(feature_list, general_mean, False)

In [11]:
# Element [1] of the first entry; it is used below as a position into
# feature_list. NOTE(review): presumably the first entry is the frame least
# similar to the mean, i.e. one end of the video -- confirm against
# utils.compute_similarity.
edge_index = similarity_to_mean[0][1]

In [12]:
# Compare every kept feature vector with the feature vector of the edge frame,
# this time using the helper's default for its third argument.
# NOTE(review): presumably the result is ordered from most to least similar --
# confirm against utils.compute_similarity.
similarity_list = compute_similarity(feature_list, feature_list[edge_index])

In [13]:
# Working copy: the ordering step further down consumes entries from it, and
# copying keeps similarity_list itself unchanged.
similarity_list_temp = similarity_list.copy()

In [14]:
# Will hold the frame positions in their reconstructed playback order.
keep_index = []

In [15]:
# Rebuild the frame order chunk by chunk: take the first `div` entries of the
# current ranking, then re-rank whatever is left against the first remaining
# entry, and repeat until every entry has been placed.
div = 5
while similarity_list_temp:
    # Split off the next chunk; a final chunk shorter than `div` is taken whole.
    chunk = similarity_list_temp[:div]
    similarity_list_temp = similarity_list_temp[div:]
    # Element [1] of each entry is the position of that frame in feature_list.
    keep_index.extend(entry[1] for entry in chunk)
    # Re-rank only while entries remain. Checking the list itself (rather than
    # a flag) avoids an IndexError when the number of entries is an exact
    # multiple of `div` and the last chunk empties the list.
    if similarity_list_temp:
        edge_index = similarity_list_temp[0][1]
        similarity_list_temp = compute_similarity_index_known(feature_list, similarity_list_temp, feature_list[edge_index])

In [16]:
# Output frame size as (width, height), taken from the first kept frame.
height, width = matched_video_frames[0].shape[:2]
size = (width, height)

out = cv2.VideoWriter('videos/restored.mp4', cv2.VideoWriter_fourcc(*'DIVX'), 20, size)
# A VideoWriter that failed to initialise drops every frame written to it, so
# check explicitly instead of silently producing no output file.
if not out.isOpened():
    raise IOError("Could not open 'videos/restored.mp4' for writing "
                  "(does the 'videos' directory exist and is the codec available?)")

try:
    for i in keep_index:
        # Frames were stored as RGB when they were read; OpenCV writes BGR,
        # so swap the channels back before writing.
        bgr_img = cv2.cvtColor(matched_video_frames[i], cv2.COLOR_RGB2BGR)
        out.write(bgr_img)
finally:
    # Release the writer even if a frame fails, so the file gets finalized.
    out.release()

In [17]:
# Same frames in the opposite order: the reconstruction cannot tell which end
# of the sequence is the true start, so both directions are written out.
out_reverse = cv2.VideoWriter('videos/restored_reverse.mp4', cv2.VideoWriter_fourcc(*'DIVX'), 20, size)
# A VideoWriter that failed to initialise drops every frame written to it, so
# check explicitly instead of silently producing no output file.
if not out_reverse.isOpened():
    raise IOError("Could not open 'videos/restored_reverse.mp4' for writing "
                  "(does the 'videos' directory exist and is the codec available?)")

try:
    for i in reversed(keep_index):
        # Frames were stored as RGB when they were read; OpenCV writes BGR,
        # so swap the channels back before writing.
        bgr_img = cv2.cvtColor(matched_video_frames[i], cv2.COLOR_RGB2BGR)
        out_reverse.write(bgr_img)
finally:
    # Release the writer even if a frame fails, so the file gets finalized.
    out_reverse.release()