In [1]:
import glob
import json

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import shutil
import os
from skimage.metrics import structural_similarity as ssim
import cv2

In [2]:
# Feature extractor: ResNet-34 with its last child module (the classification
# layer in torchvision's ResNet) removed, so a forward pass yields pooled
# features instead of class logits.
#
# Pretrained ImageNet weights are loaded explicitly: calling `models.resnet34()`
# with no arguments builds a randomly initialised network, whose features are
# not meaningful for judging visual similarity between frames.
try:
    # Weights enum API (torchvision >= 0.13).
    backbone = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
except AttributeError:
    # Older torchvision releases only know the boolean flag.
    backbone = models.resnet34(pretrained=True)

# Use the GPU when one is present, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(*(list(backbone.children())[:-1])).to(device)
model.eval()

# Preprocessing: resize, centre-crop to 224x224, convert to a tensor and
# normalise each channel with the ImageNet mean / standard deviation.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [3]:
def extract_features(image_paths, batch_size=64):
    """Run every image through the global ``model`` and return one feature row per image.

    Args:
        image_paths: Paths of the image files to embed.
        batch_size: Maximum number of images pushed through the model in a
            single forward pass, so large frame sets do not have to fit on
            the device all at once.

    Returns:
        A 2-D NumPy array of shape ``(len(image_paths), feature_dim)``. The
        batch axis is kept even when only one image is given.

    Raises:
        ValueError: If ``image_paths`` is empty or ``batch_size`` is < 1.
    """
    image_paths = list(image_paths)
    if not image_paths:
        raise ValueError("image_paths must contain at least one image")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Send inputs to whatever device the model lives on rather than
    # hard-coding one here.
    target_device = next(model.parameters()).device

    feature_batches = []
    with torch.no_grad():
        for start in range(0, len(image_paths), batch_size):
            batch = []
            for image_path in image_paths[start:start + batch_size]:
                # The context manager closes the file handle; convert("RGB")
                # guarantees the three channels the normalisation step expects
                # (grayscale / CMYK / RGBA files would otherwise fail).
                with Image.open(image_path) as img:
                    batch.append(preprocess(img.convert("RGB")))
            batch_features = model(torch.stack(batch).to(target_device))
            # flatten(start_dim=1) removes the trailing spatial axes but, unlike
            # squeeze(), never drops the batch axis when the batch has size 1.
            feature_batches.append(batch_features.flatten(start_dim=1).cpu())

    return torch.cat(feature_batches, dim=0).numpy()

def cal_similarity(image_features):
    """Build the pairwise cosine-similarity matrix of a set of feature vectors.

    Args:
        image_features: 2-D array-like of shape ``(num_frames, feature_dim)``,
            one feature vector per row.

    Returns:
        A symmetric ``(num_frames, num_frames)`` float64 array whose entry
        ``[i, j]`` is the cosine similarity between rows ``i`` and ``j``.
        The diagonal is 0 so a frame never counts as similar to itself.
        An empty input yields a ``(0, 0)`` array.

    Raises:
        ValueError: If a non-empty input is not 2-D.
    """
    features = np.asarray(image_features, dtype=np.float64)
    num_frames = len(features)
    if num_frames == 0:
        return np.zeros((0, 0))
    if features.ndim != 2:
        raise ValueError("image_features must be a 2-D array (num_frames, feature_dim)")

    # Normalise each row to unit length; all-zero rows keep a divisor of 1 so
    # they produce a similarity of 0 instead of NaN.
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    unit_features = features / norms

    # A single matrix product replaces one library call per frame pair.
    similarities = unit_features @ unit_features.T
    np.fill_diagonal(similarities, 0.0)
    return similarities

def select_frame_to_remove(similarities):
    """Choose which frame of the most similar pair should be dropped.

    The largest entry of ``similarities`` identifies a pair of frames. Of
    those two, the one whose row sums to the larger total is returned; when
    the totals are equal the second index of the pair is returned.
    """
    peak = np.argmax(similarities)
    first, second = np.unravel_index(peak, similarities.shape)
    first_total, second_total = similarities[first].sum(), similarities[second].sum()
    return first if first_total > second_total else second

def resample_frames(image_paths, target_num_frames):
    """Return a list of exactly ``target_num_frames`` frame paths.

    * Fewer frames than requested: the existing frames are repeated
      cyclically (first, second, ...) and appended until the target is met.
    * More frames than requested: frames are removed one at a time, each time
      dropping a member of the currently most similar pair, as chosen by
      ``select_frame_to_remove``.
    * Exactly the requested number: the frames are returned as they are,
      without running the feature extractor.

    Args:
        image_paths: Paths of the available frames. This list is not modified.
        target_num_frames: Desired number of frames (non-negative).

    Returns:
        A new list of frame paths. It is empty when ``image_paths`` is empty,
        because there is nothing to duplicate.

    Raises:
        ValueError: If ``target_num_frames`` is negative.
    """
    if target_num_frames < 0:
        raise ValueError("target_num_frames must be non-negative")

    # Work on a copy so the caller's list is never appended to or popped from.
    frames = list(image_paths)
    if not frames:
        return frames

    if len(frames) <= target_num_frames:
        num_original = len(frames)
        for i in range(target_num_frames - num_original):
            frames.append(frames[i % num_original])
        return frames

    image_features = extract_features(frames)
    similarities = cal_similarity(image_features)
    while len(frames) > target_num_frames:
        frame_to_remove = select_frame_to_remove(similarities)
        frames.pop(frame_to_remove)
        # Drop the removed frame's row and column so matrix indices keep
        # lining up with positions in `frames`.
        similarities = np.delete(similarities, frame_to_remove, axis=0)
        similarities = np.delete(similarities, frame_to_remove, axis=1)
    return frames

# images = [
#     'D:/frames/1293860/3.jpg',
#     'D:/frames/1293860/4.jpg',
#     'D:/frames/1293860/10.jpg',
#     'D:/frames/1293860/11.jpg',
#     'D:/frames/1293860/12.jpg',
#     'D:/frames/1293860/15.jpg',
#     'D:/frames/1293860/16.jpg',
#     'D:/frames/1293860/18.jpg',
# ]
# resampled_images = resample_frames(images, 8)

In [4]:
def get_all_dirs(root_dir='D:/frames'):
    """Return the sorted names of the sub-directories directly under ``root_dir``.

    Plain files that happen to sit in ``root_dir`` are skipped, so every
    returned name really is a directory.

    Args:
        root_dir: Folder that holds one sub-directory of frames per video.

    Returns:
        Directory names (not full paths) in ascending string order.
    """
    with os.scandir(root_dir) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())

def update_progress(prog):
    """Persist the progress record to ``frame-resampling-progress.json``.

    The data is first written to a temporary sibling file and then moved over
    the real one with ``os.replace``. If the run is interrupted while writing,
    the previous progress file is left intact instead of being truncated.

    Args:
        prog: JSON-serialisable progress record.
    """
    progress_file = 'frame-resampling-progress.json'
    temp_file = progress_file + '.tmp'
    with open(temp_file, 'w') as file:
        json.dump(prog, file)
    os.replace(temp_file, progress_file)

def get_progress():
    """Load the saved progress record, or an empty list if none exists yet.

    Reads ``frame-resampling-progress.json`` from the current working
    directory and returns its decoded JSON content.
    """
    progress_file = 'frame-resampling-progress.json'
    if os.path.exists(progress_file):
        with open(progress_file, 'r') as handle:
            return json.load(handle)
    return []

In [5]:
# Work list for this run: every frame directory that is not yet recorded in
# the progress file.
all_dirs = get_all_dirs()
progress = get_progress()
# Membership is tested against a set so the filter stays linear in the number
# of directories; `progress` itself remains a list because the processing loop
# appends to it and writes it back to disk.
completed = set(progress)
remain_dirs = [frame_dir for frame_dir in all_dirs if frame_dir not in completed]
output_dir = 'D:/frames-resampled'

In [6]:
def _frame_sort_key(frame_path):
    """Order frames by numeric file stem (3.jpg before 10.jpg); other names sort last, alphabetically."""
    stem = os.path.splitext(os.path.basename(frame_path))[0]
    return (0, int(stem), '') if stem.isdecimal() else (1, 0, stem)


# Resample every remaining video to 8 frames, copy the chosen frames to
# `output_dir/<video_id>/0.jpg ...` and record the video as done.
for i, video_id in enumerate(remain_dirs):
    source_dir = os.path.join('D:/frames', video_id)
    print(f'Resampling {source_dir}. Progress: {i}/{len(remain_dirs)}')

    # glob returns names in directory-listing order, which is not numeric
    # ("10.jpg" can come before "3.jpg"). Sorting first keeps the renumbered
    # output frames in source order.
    # NOTE(review): assumes frame files are named by frame index - confirm.
    frame_paths = sorted(glob.glob(f"{source_dir}/*.jpg"), key=_frame_sort_key)
    if not frame_paths:
        # Nothing to resample; leave the directory unrecorded so it is picked
        # up again once frames exist.
        print(f'Skipping {source_dir}: no .jpg frames found')
        continue
    frame_paths = resample_frames(frame_paths, 8)

    # Create the destination once per video rather than once per frame.
    target_dir = os.path.join(output_dir, video_id)
    os.makedirs(target_dir, exist_ok=True)
    for j, frame_path in enumerate(frame_paths):
        shutil.copy(frame_path, os.path.join(target_dir, f"{j}.jpg"))

    # Persist after every video so an interrupted run can resume.
    progress.append(video_id)
    update_progress(progress)

    if i % 10000 == 0:
        torch.cuda.empty_cache()

Resampling D:/frames\389690. Progress: 0/18919
Resampling D:/frames\389700. Progress: 1/18919
Resampling D:/frames\389710. Progress: 2/18919
Resampling D:/frames\389730. Progress: 3/18919
Resampling D:/frames\389740. Progress: 4/18919
Resampling D:/frames\389810. Progress: 5/18919
Resampling D:/frames\389870. Progress: 6/18919
Resampling D:/frames\389900. Progress: 7/18919
Resampling D:/frames\389940. Progress: 8/18919
Resampling D:/frames\389970. Progress: 9/18919
Resampling D:/frames\389980. Progress: 10/18919
Resampling D:/frames\390040. Progress: 11/18919
Resampling D:/frames\390090. Progress: 12/18919
Resampling D:/frames\390100. Progress: 13/18919
Resampling D:/frames\390200. Progress: 14/18919
Resampling D:/frames\390210. Progress: 15/18919
Resampling D:/frames\390220. Progress: 16/18919
Resampling D:/frames\390290. Progress: 17/18919
Resampling D:/frames\390330. Progress: 18/18919
Resampling D:/frames\390460. Progress: 19/18919
Resampling D:/frames\390510. Progress: 20/18919
Re