In [19]:
import glob
import json

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import shutil
import os
from skimage.metrics import structural_similarity as ssim
import cv2

In [20]:
# Load ResNet-34 with its ImageNet weights. Calling `models.resnet34()` with no
# `weights` argument builds a randomly initialised network, so the embeddings
# would carry no learned visual semantics and would differ on every kernel
# restart. (Requires torchvision >= 0.13; older releases use `pretrained=True`.)
model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
# Drop the last child module (the classification layer) so the network returns
# pooled feature maps instead of class logits.
model = torch.nn.Sequential(*(list(model.children())[:-1])).to("cuda")
# Inference mode: freezes batch-norm statistics so outputs are deterministic.
model.eval()

# Preprocessing: resize/crop to 224x224 and normalise with the ImageNet channel
# statistics that the pretrained weights expect.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [21]:
def extract_features(image_paths):
    """Embed every image with the module-level ``model``.

    Args:
        image_paths: Iterable of image file paths that PIL can open.

    Returns:
        numpy.ndarray: One row per input image, shape
        ``(num_images, feature_dim)``. The batch axis is always kept, even for
        a single image. Empty input yields an array of shape ``(0, 0)``.
    """
    image_paths = list(image_paths)
    if not image_paths:
        return np.zeros((0, 0), dtype=np.float32)

    image_tensors = []
    for image_path in image_paths:
        # The context manager releases the file handle promptly. convert("RGB")
        # guards against grayscale / CMYK / RGBA files, which would otherwise
        # fail in the 3-channel Normalize step of `preprocess`.
        with Image.open(image_path) as img:
            image_tensors.append(preprocess(img.convert("RGB")))

    # Follow whichever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    batch = torch.stack(image_tensors).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        features = model(batch)

    # Flatten everything after the batch axis. A bare squeeze() would also
    # remove the batch axis when exactly one image is passed.
    return features.flatten(start_dim=1).cpu().numpy()

def ssim_similarity(frame1, frame2):
    """
    Compare two frames using Structural Similarity Index.

    Args:
        frame1, frame2: BGR colour images (as produced by ``cv2.imread``).

    Returns:
        float: SSIM score (1.0 means identical images)

    Raises:
        ValueError: If either frame is None. ``cv2.imread`` returns None
            instead of raising when a file cannot be read, which would
            otherwise surface as an opaque OpenCV assertion error.
    """
    if frame1 is None or frame2 is None:
        raise ValueError(
            "ssim_similarity received None; cv2.imread returns None when an "
            "image cannot be read"
        )

    # Convert to grayscale
    gray1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY)
    gray2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY)

    # Ensure frames are the same size by resizing the second to match the
    # first (cv2.resize takes the target as (width, height)).
    if gray1.shape != gray2.shape:
        gray2 = cv2.resize(gray2, (gray1.shape[1], gray1.shape[0]))

    # Only the scalar score is needed, so do not request the full SSIM map.
    return ssim(gray1, gray2)

def cal_similarity(image_features):
    """Build the pairwise cosine-similarity matrix of a set of feature vectors.

    Args:
        image_features: Array-like of shape ``(num_frames, feature_dim)``.
            A single 1-D vector is treated as one frame.

    Returns:
        numpy.ndarray: Symmetric ``(num_frames, num_frames)`` float64 matrix.
        The diagonal is left at 0 so that a frame is never reported as its own
        most-similar partner. All-zero feature vectors get similarity 0.
    """
    features = np.asarray(image_features, dtype=np.float64)
    if features.ndim == 1 and features.size > 0:
        features = features[np.newaxis, :]

    num_frames = features.shape[0]
    if num_frames == 0:
        return np.zeros((0, 0))
    if features.ndim > 2:
        features = features.reshape(num_frames, -1)

    # Normalise each row to unit length; zero-norm rows are left as zeros
    # (dividing by 1) so they produce similarity 0 instead of NaN.
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    unit = features / norms

    # One matrix product replaces the per-pair Python loop. Mirroring the
    # strict upper triangle guarantees exact symmetry and a zero diagonal.
    upper = np.triu(unit @ unit.T, k=1)
    return upper + upper.T

def cal_ssim(image_paths):
    """Build the pairwise SSIM matrix for a list of image files.

    Args:
        image_paths: Sequence of image file paths readable by ``cv2.imread``.

    Returns:
        numpy.ndarray: Symmetric ``(num_frames, num_frames)`` matrix of SSIM
        scores with the diagonal left at 0.
    """
    num_frames = len(image_paths)
    similarities = np.zeros((num_frames, num_frames))
    for i, path_i in enumerate(image_paths[:-1]):
        # Decode the row's frame once rather than once per comparison.
        frame_i = cv2.imread(path_i)
        for j, path_j in enumerate(image_paths[i + 1:], start=i + 1):
            sim = ssim_similarity(frame_i, cv2.imread(path_j))
            similarities[i, j] = sim
            similarities[j, i] = sim

    return similarities

def select_frame_to_remove(similarities):
    """Pick which frame of the most similar pair should be dropped.

    The pair is located at the maximum entry of ``similarities``. Of its two
    members, the one whose row sums to the larger total is returned; on a tie
    the second (column) index is returned.

    Args:
        similarities: Square pairwise similarity matrix.

    Returns:
        Index of the frame to remove.
    """
    peak = np.argmax(similarities)
    row, col = np.unravel_index(peak, similarities.shape)
    row_total = similarities[row].sum()
    col_total = similarities[col].sum()
    return row if row_total > col_total else col

def resample_frames(image_paths, target_num_frames):
    """Reduce or pad a list of frame paths to exactly ``target_num_frames``.

    When there are too many frames, the most redundant one (as chosen by
    ``select_frame_to_remove`` on the cosine-similarity matrix) is dropped
    repeatedly. When there are too few, the existing frames are repeated
    cyclically and appended at the end.

    Args:
        image_paths: Sequence of frame file paths. It is not modified.
        target_num_frames: Desired number of frames (non-negative).

    Returns:
        list: A new list of paths. Its length is ``target_num_frames`` unless
        ``image_paths`` is empty, in which case an empty list is returned
        because there is nothing to duplicate.

    Raises:
        ValueError: If ``target_num_frames`` is negative.
    """
    if target_num_frames < 0:
        raise ValueError("target_num_frames must be non-negative")

    # Work on a copy so the caller's list is left untouched.
    selected = list(image_paths)
    if not selected:
        return selected

    if len(selected) > target_num_frames:
        # Features are only needed when something has to be removed, so the
        # CNN is skipped entirely for short clips.
        image_features = extract_features(selected)
        similarities = cal_similarity(image_features)
        # Alternative metric: similarities = cal_ssim(selected)

        while len(selected) > target_num_frames:
            frame_to_remove = select_frame_to_remove(similarities)
            selected.pop(frame_to_remove)
            # Keep the matrix aligned with `selected` by deleting the same
            # row and column.
            similarities = np.delete(similarities, frame_to_remove, axis=0)
            similarities = np.delete(similarities, frame_to_remove, axis=1)
    elif len(selected) < target_num_frames:
        num_originals = len(selected)
        padding = [
            selected[k % num_originals]
            for k in range(target_num_frames - num_originals)
        ]
        selected.extend(padding)

    return selected

# images = [
#     'D:/frames/1293860/3.jpg',
#     'D:/frames/1293860/4.jpg',
#     'D:/frames/1293860/10.jpg',
#     'D:/frames/1293860/11.jpg',
#     'D:/frames/1293860/12.jpg',
#     'D:/frames/1293860/15.jpg',
#     'D:/frames/1293860/16.jpg',
#     'D:/frames/1293860/18.jpg',
# ]
# resampled_images = resample_frames(images, 8)

In [22]:
def get_all_dirs(root_dir='D:/frames'):
    """List the sub-directories of ``root_dir`` in sorted order.

    Plain files in ``root_dir`` (for example OS-generated thumbnail caches)
    are skipped so callers only receive real frame directories.

    Args:
        root_dir: Directory whose immediate children are listed.

    Returns:
        list[str]: Sorted directory names (not full paths).
    """
    return sorted(
        entry for entry in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, entry))
    )

def update_progress(prog, path='frame-resampling-progress.json'):
    """Persist the progress list as JSON.

    The data is written to a sibling temporary file and then moved over the
    target with ``os.replace``, so an interruption mid-write cannot leave a
    truncated progress file behind.

    Args:
        prog: JSON-serialisable progress data.
        path: Destination file.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as file:
        json.dump(prog, file)
    os.replace(tmp_path, path)

def get_progress(path='frame-resampling-progress.json'):
    """Load previously saved progress.

    Args:
        path: JSON file to read.

    Returns:
        The decoded JSON content, or an empty list when the file does not
        exist yet.
    """
    # Open directly and handle the missing-file case, rather than checking
    # for existence first and racing with a concurrent delete.
    try:
        with open(path, 'r') as file:
            return json.load(file)
    except FileNotFoundError:
        return []

In [23]:
all_dirs = get_all_dirs()
progress = get_progress()
# Membership tests against a set keep this filter linear; `progress` itself
# stays a list so it can still be appended to and serialised as JSON.
completed_dirs = set(progress)
remain_dirs = [frame_dir for frame_dir in all_dirs if frame_dir not in completed_dirs]
output_dir = 'D:/frames-resampled'

In [24]:
FRAMES_ROOT = 'D:/frames'
TARGET_NUM_FRAMES = 8


def _frame_sort_key(path):
    """Sort key that orders numerically named frames by value (3 before 10)."""
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        return (0, int(stem), stem)
    except ValueError:
        # Non-numeric names go after the numeric ones, in alphabetical order.
        return (1, 0, stem)


for video_id in remain_dirs:
    source_dir = os.path.join(FRAMES_ROOT, video_id)
    # glob gives no ordering guarantee, and plain string order would put
    # 10.jpg before 3.jpg. The output files are renumbered by position, so the
    # input has to be put in frame order first.
    frame_paths = sorted(glob.glob(f"{source_dir}/*.jpg"), key=_frame_sort_key)
    if not frame_paths:
        print(f"Skipping {source_dir}: no .jpg frames found")
        continue
    frame_paths = resample_frames(frame_paths, TARGET_NUM_FRAMES)

    # Create the destination once per video rather than once per frame.
    target_dir = os.path.join(output_dir, video_id)
    os.makedirs(target_dir, exist_ok=True)
    for frame_idx, frame_path in enumerate(frame_paths):
        shutil.copy(frame_path, os.path.join(target_dir, f"{frame_idx}.jpg"))

    # Progress tracking is currently disabled. When re-enabling it, record
    # `video_id` (the directory name that `remain_dirs` is filtered on), not
    # the output path:
    # progress.append(video_id)
    # update_progress(progress)