<a href="https://colab.research.google.com/github/detektor777/colab_list_video/blob/main/flavr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ##**Install** { display-mode: "form" }
%%capture

import os
import subprocess
import importlib.util
from google.colab import files
import torch

# Remember across re-runs of this cell which checkpoint variant was fetched
# last; on the first run the name does not exist yet and becomes None.
downloaded_model_type = globals().get('downloaded_model_type')

model_type = "2x"  #@param ["2x", "4x", "8x"]

# Report the framework state of the runtime before anything is installed.
torch_version_line = f"Current PyTorch version: {torch.__version__}"
print(torch_version_line)
cuda_line = f"CUDA available: {torch.cuda.is_available()}"
print(cuda_line)

import sys  # local to this step: path of the interpreter that runs the kernel

print("Installing FLAVR dependencies...")


def _pip_install(label, requirement):
    """Install *requirement* with pip and report a failure under *label*.

    pip is started as ``<kernel python> -m pip`` so the package lands in the
    environment this notebook imports from, not in whichever ``pip`` happens
    to be first on PATH.

    Returns the ``subprocess.CompletedProcess`` of the pip call; a failed
    install is printed but does not raise.
    """
    print(f"Installing {label}...")
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", requirement],
        capture_output=True, text=True,
    )
    if result.returncode != 0:
        print(f"Error installing {label}:")
        print(result.stderr)
    return result


# Pinned versions, installed one by one so a failure names the package.
result_numpy = _pip_install("numpy", "numpy==1.23.5")
result_opencv = _pip_install("opencv-python", "opencv-python==4.8.0.76")
result_av = _pip_install("PyAV", "av==10.0.0")

# A package installed after interpreter start-up can be missed by the import
# system's cached directory listings; refresh them before probing for PyAV.
importlib.invalidate_caches()
if importlib.util.find_spec("av") is not None:
    import av
    print(f"PyAV successfully installed, version: {av.__version__}")
else:
    print("Error: PyAV is not installed after installation attempt.")

# Fetch the FLAVR sources once; an existing checkout is left untouched.
if os.path.exists("/content/FLAVR"):
    print("FLAVR repository already exists.")
else:
    print("Cloning FLAVR repository...")
    clone_cmd = ["git", "clone", "https://github.com/tarun005/FLAVR.git", "/content/FLAVR"]
    result_git = subprocess.run(clone_cmd, capture_output=True, text=True)
    if result_git.returncode == 0:
        print("Repository successfully cloned.")
    else:
        # A failed clone is only reported; the cell keeps running.
        print("Error cloning repository:")
        print(result_git.stderr)

# Google Drive file IDs of the pretrained FLAVR checkpoints, keyed by the
# interpolation variant selected in the form above.
model_urls = {
    "2x": "1IZe-39ZuXy3OheGJC-fT3shZocGYuNdH",
    "4x": "1xoZqWJdIOjSaE2DtH4ifXKlRwFySm5Gq",
    "8x": "1DlXgNANDGLZEYOCMvQ5T1cAqkW90FiPt"
}
model_path = f"/content/FLAVR_{model_type}.pth"

# Decide whether the checkpoint has to be fetched. The flag is assigned on
# every path, so the check below never sees an undefined name when the model
# is already present.
model_on_disk = os.path.exists(model_path)
download_required = True
if model_on_disk and downloaded_model_type == model_type:
    print(f"Model {model_type} is already downloaded and matches the selected type: {model_path}")
    download_required = False
elif model_on_disk:
    # A file with this name exists but was not recorded as this variant
    # in the current session: discard it and fetch a fresh copy.
    print(f"Model {downloaded_model_type} already exists, but {model_type} is selected. Replacing model...")
    os.remove(model_path)
else:
    print(f"Model {model_type} is missing. Download required.")

if download_required:
    print(f"Downloading FLAVR model {model_type}...")
    subprocess.run(["pip", "install", "gdown"], capture_output=True, text=True)
    result_gdown = subprocess.run(["gdown", f"https://drive.google.com/uc?id={model_urls[model_type]}", "-O", model_path], capture_output=True, text=True)

    if result_gdown.returncode == 0 and os.path.exists(model_path):
        # This is module scope, so a plain assignment already updates the
        # notebook-wide name; a `global` statement placed after earlier uses
        # of the name is rejected by the compiler.
        downloaded_model_type = model_type
        print(f"Model {model_type} successfully downloaded to {model_path}")
    else:
        # Automatic download failed: fall back to a manual upload.
        print(f"Error: Failed to automatically download model {model_type}.")
        print(result_gdown.stderr)
        print("Please download the model manually:")
        print(f"1. Use the link: https://drive.google.com/file/d/{model_urls[model_type]}/view")
        print(f"2. Download the file and upload it below with the name FLAVR_{model_type}.pth")
        uploaded = files.upload()
        for uploaded_file_name in uploaded:
            if uploaded_file_name.endswith(".pth"):
                # The first uploaded .pth file becomes the checkpoint.
                os.rename(uploaded_file_name, model_path)
                downloaded_model_type = model_type
                print(f"Model saved as {model_path}")
                break
        else:
            # Loop finished without `break`: nothing usable was uploaded.
            print("Error: Uploaded file is not a .pth model.")

# Overwrite /content/FLAVR/interpolate.py with the notebook's own version of
# the script. It adds --batch_size and --output_fps options, warms the model
# up, runs batched inference under autocast and converts the intermediate
# .avi result to .mp4 with ffmpeg.
with open('/content/FLAVR/interpolate.py', 'w') as f:
    f.write('''import os
import torch
import cv2
import time
import sys
import subprocess
import torchvision
import numpy as np
import tqdm
from torchvision.io import read_video
from dataset.transforms import ToTensorVideo, Resize
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--input_video", type=str, required=True, help="Path/URL to input video")
parser.add_argument("--youtube-dl", type=str, help="Path to youtube_dl", default=".local/bin/youtube-dl")
parser.add_argument("--factor", type=int, required=True, choices=[2, 4, 8], help="Interpolation factor: 2x/4x/8x")
parser.add_argument("--codec", type=str, help="Video codec", default="mpeg4")
parser.add_argument("--load_model", required=True, type=str, help="Path to saved model")
parser.add_argument("--up_mode", type=str, help="Upscale mode", default="transpose")
parser.add_argument("--output_ext", type=str, help="Output video format", default=".avi")
parser.add_argument("--input_ext", type=str, help="Input video format", default=".mp4")
parser.add_argument("--downscale", type=float, help="Downscale for memory saving", default=1)
parser.add_argument("--output_fps", type=int, help="Target FPS", default=30)
parser.add_argument("--batch_size", type=int, help="Batch size", default=1)
parser.add_argument("--is_folder", action="store_true")
args = parser.parse_args()

input_video = args.input_video
input_ext = args.input_ext

from os import path

if not args.is_folder and not path.exists(input_video):
    print("Invalid input file path!")
    exit()

if args.is_folder and not path.exists(input_video):
    print("Invalid input folder path!")
    exit()

if args.output_ext != ".avi":
    print("Only .avi output is supported for now. Use ffmpeg to convert to mp4 or other formats.")

if input_video.endswith("/"):
    video_name = input_video.split("/")[-2].split(input_ext)[0]
else:
    video_name = input_video.split("/")[-1].split(input_ext)[0]

output_video = os.path.join(video_name + f"_{args.factor}x" + str(args.output_ext))

n_outputs = args.factor - 1
model_name = "unet_18"
nbr_frame = 4
joinType = "concat"

if input_video.startswith("http"):
    assert args.youtube_dl is not None
    youtube_dl_path = args.youtube_dl
    cmd = f"{youtube_dl_path} -i -o video.mp4 {input_video}"
    os.system(cmd)
    input_video = "video.mp4"
    output_video = "video" + str(args.output_ext)

def loadModel(model, checkpoint):
    saved_state_dict = torch.load(checkpoint)['state_dict']
    saved_state_dict = {k.partition("module.")[-1]:v for k,v in saved_state_dict.items()}
    model.load_state_dict(saved_state_dict)

checkpoint = args.load_model
from model.FLAVR_arch import UNet_3D_3D

model = UNet_3D_3D(model_name.lower(), n_inputs=4, n_outputs=n_outputs, joinType=joinType, upmode=args.up_mode)
loadModel(model, checkpoint)
torch.backends.cudnn.benchmark = True
model = model.cuda()

print("Warming up model...")
with torch.no_grad(), torch.cuda.amp.autocast():
    dummy_input = [torch.randn(args.batch_size, 3, 424, 568).cuda() for _ in range(nbr_frame)]
    model(dummy_input)
print("Warmup complete.")

def write_video_cv2(frames, video_name, fps, sizes):
    out = cv2.VideoWriter(video_name, cv2.CAP_OPENCV_MJPEG, cv2.VideoWriter_fourcc('M','J','P','G'), fps, sizes)
    for frame in frames:
        out.write(frame)
    out.release()

def make_image(img):
    q_im = img.data.mul(255.).clamp(0, 255).round()
    im = q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    return im

def files_to_videoTensor(path, downscale=1.):
    from PIL import Image
    files = sorted(os.listdir(path))
    print(f"Files found: {len(files)}")
    # Read from the folder that was passed in, not from the global input path.
    images = [torch.Tensor(np.asarray(Image.open(os.path.join(path, f)))).type(torch.uint8) for f in files]
    print(f"Image size: {images[0].shape}")
    videoTensor = torch.stack(images)
    return videoTensor

def video_to_tensor(video):
    videoTensor, _, md = read_video(video, pts_unit='sec')
    fps = md["video_fps"]
    print(f"Video FPS: {fps}")
    return videoTensor

def video_transform(videoTensor, downscale=1):
    T, H, W = videoTensor.size(0), videoTensor.size(1), videoTensor.size(2)
    downscale = int(downscale * 8)
    resizes = 8 * (H // downscale), 8 * (W // downscale)
    transforms = torchvision.transforms.Compose([ToTensorVideo(), Resize(resizes)])
    videoTensor = transforms(videoTensor)
    print(f"Resizing to {resizes[0]}x{resizes[1]}")
    return videoTensor, resizes

if args.is_folder:
    videoTensor = files_to_videoTensor(input_video, args.downscale)
else:
    videoTensor = video_to_tensor(input_video)

# Each interpolation window needs nbr_frame consecutive frames.
if len(videoTensor) < nbr_frame:
    sys.exit(f"Need at least {nbr_frame} frames to interpolate, got {len(videoTensor)}.")

idxs = torch.Tensor(range(len(videoTensor))).type(torch.long).view(1, -1).unfold(1, size=nbr_frame, step=1).squeeze(0)
videoTensor, resizes = video_transform(videoTensor, args.downscale)

frames = torch.unbind(videoTensor, 1)
n_inputs = len(frames)
width = n_outputs + 1

outputs = []
outputs.append(frames[idxs[0][1]])

model = model.eval()

batch_size = args.batch_size

for i in tqdm.tqdm(range(0, len(idxs), batch_size)):
    batch_idxs = idxs[i:i + batch_size]
    inputs = []
    for idx_set in batch_idxs:
        input_set = [frames[idx_].cuda().unsqueeze(0) for idx_ in idx_set]
        inputs.append(input_set)

    batch_inputs = []
    for j in range(nbr_frame):
        batch_inputs.append(torch.cat([inputs[k][j] for k in range(len(inputs))], dim=0))

    with torch.no_grad(), torch.cuda.amp.autocast():
        outputFrame = model(batch_inputs)
    torch.cuda.synchronize()

    temp_outputs = []
    for batch_idx in range(len(batch_idxs)):
        for output_idx in range(len(outputFrame)):
            temp_outputs.append(outputFrame[output_idx][batch_idx].to('cpu', non_blocking=True))
        temp_outputs.append(batch_inputs[2][batch_idx].squeeze(0).to('cpu', non_blocking=True))
    outputs.extend(temp_outputs)

new_video = [make_image(im_) for im_ in outputs]
write_video_cv2(new_video, output_video, args.output_fps, (resizes[1], resizes[0]))

# Convert to mp4 without a shell so names with spaces or dots survive; the
# intermediate file is removed only after ffmpeg has succeeded.
mp4_output = os.path.splitext(output_video)[0] + ".mp4"
print("Saving to", mp4_output)
subprocess.run(["ffmpeg", "-hide_banner", "-loglevel", "warning", "-y", "-i", output_video, mp4_output], check=True)
os.remove(output_video)
''')

print("File /content/FLAVR/interpolate.py has been updated.")


In [None]:
#@title ##**Select Video File** { display-mode: "form" }
import os
import ipywidgets as widgets
from IPython.display import display, clear_output
from google.colab import files
from google.colab import drive

# Where the input video comes from; chosen through the Colab form.
upload_option = "Upload from PC"  #@param ["Upload from PC", "Load from Google Drive Root", "Load from Google Drive"]

# Path of the chosen video. It stays None until an upload finishes or a
# picker button is clicked; later cells read this name.
file_name = None
# The picker button clicked most recently (set by the click handlers below).
last_selected_button = None

def reset_button_colors(buttons):
    """Clear the highlight colour of every button in *buttons*."""
    for widget in buttons:
        # None puts the button back on its default styling.
        widget.style.button_color = None

def _find_drive_videos(root, extensions, recursive):
    """Return the sorted video paths under *root*, relative to *root*.

    Only files whose lower-cased extension is listed in *extensions* are
    kept. With *recursive* false just the files directly inside *root* are
    considered; otherwise every sub-folder is walked as well.
    """
    if not recursive:
        return sorted(
            name for name in os.listdir(root)
            if os.path.isfile(os.path.join(root, name))
            and os.path.splitext(name)[1].lower() in extensions
        )
    found = []
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            if os.path.splitext(name)[1].lower() in extensions:
                found.append(os.path.relpath(os.path.join(dirpath, name), root))
    return sorted(found)


def _show_file_picker(root, names, out, button_list):
    """Display one button per entry of *names*; a click selects that file.

    The click handler stores the full path in the notebook-level `file_name`,
    records the clicked button in `last_selected_button`, and colours the
    button green when the path exists and red otherwise. Created buttons are
    appended to *button_list*; the handler prints into the widget *out*.
    """
    def on_button_clicked(b):
        global file_name, last_selected_button
        with out:
            clear_output()
            reset_button_colors(button_list)
            file_name = os.path.join(root, b.description)
            b.style.button_color = 'green' if os.path.exists(file_name) else 'red'
            last_selected_button = b
            print(f"Selected file: {file_name}")

    for name in names:
        button = widgets.Button(description=name, layout=widgets.Layout(width='500px', overflow='hidden', text_overflow='ellipsis'))
        button.on_click(on_button_clicked)
        button_list.append(button)

    display(widgets.VBox(button_list), out)


if upload_option == "Upload from PC":
    print("Please upload a video file.")
    uploaded = files.upload()
    if uploaded:
        # Only the first uploaded file is used.
        file_name = list(uploaded.keys())[0]
    else:
        print("No file uploaded.")
        file_name = None

elif upload_option in ("Load from Google Drive Root", "Load from Google Drive"):
    drive.mount('/content/drive')
    root_dir = '/content/drive/MyDrive/'
    video_extensions = ['.mp4', '.mkv', '.avi', '.mov']

    # Both Drive options share the same picker and differ only in whether
    # sub-folders are searched and in the wording of their messages.
    search_subfolders = upload_option == "Load from Google Drive"
    files_list = _find_drive_videos(root_dir, video_extensions, search_subfolders)

    if not files_list:
        if search_subfolders:
            print("No video files found in Google Drive or its subfolders.")
        else:
            print("No video files found in Google Drive root.")
        file_name = None
    else:
        if search_subfolders:
            print("Select a video file from Google Drive (including subfolders):")
        else:
            print("Select a video file from Google Drive root:")

        output = widgets.Output()
        buttons = []
        _show_file_picker(root_dir, files_list, output, buttons)

# For the Drive options this runs before any button can have been clicked,
# so it reports the state at the moment the cell finishes executing.
if file_name:
    print(f"Video file path set to: {file_name}")
else:
    print("Video file path not set. Please select a file.")

In [None]:
#@title ##**Config** { display-mode: "form" }
import os
from google.colab import files
import shutil
from google.colab import drive

interpolation_factor = 2  #@param {type:"slider", min:2, max:8, step:2}
segment_duration = 3  #@param {type:"slider", min:1, max:60, step:1}
batch_size = 1  #@param {type:"slider", min:1, max:16, step:1}

# The slider can also produce 6, but interpolate.py only accepts 2, 4 or 8
# for --factor, so flag the problem here instead of at the first segment.
if interpolation_factor not in (2, 4, 8):
    print(f"Warning: interpolation factor {interpolation_factor} is not supported; choose 2, 4 or 8.")

#path
output_folder = "google_drive" #@param ["google_drive","root"]

# Base folder of the segment directories. This mirrors the locations the
# Run cell uses, so the clear options below act on the folders that will
# actually be processed.
if output_folder == "google_drive":
    if not os.path.exists('/content/drive'):
        drive.mount('/content/drive')
    segments_root = '/content/drive/MyDrive'
else:
    segments_root = '/content'

#clear

clear_input_folder = False #@param {type:"boolean"}
flavr_input_folder = os.path.join(segments_root, 'temp_segments')

if clear_input_folder:
    # Drop previously split input segments and start from an empty folder.
    if os.path.isdir(flavr_input_folder):
        shutil.rmtree(flavr_input_folder)
    os.makedirs(flavr_input_folder)

clear_output_folder = False #@param {type:"boolean"}
flavr_output_folder = os.path.join(segments_root, 'output_segments')

if clear_output_folder:
    # Drop previously processed segments and start from an empty folder.
    if os.path.isdir(flavr_output_folder):
        shutil.rmtree(flavr_output_folder)
    os.makedirs(flavr_output_folder)





In [None]:
#@title ##**Run** { display-mode: "form" }
import subprocess
import os
import shutil
from google.colab import files
import cv2
import gc
import sys
import time
import logging
from google.colab import drive
import json

if 'file_name' not in globals() or file_name is None or not os.path.exists(file_name):
    print("Error: Video file not selected or does not exist. Run the 'Select video file' cell.")
else:
    print(f"Increasing FPS for video: {file_name}")
    print(f"Input file size: {os.path.getsize(file_name) / (1024*1024):.2f} MB")

    # Start each run with an empty log. basicConfig does nothing once the
    # root logger already has handlers (for example after an earlier run of
    # this cell), so force=True replaces them and filemode='w' truncates the
    # file instead of deleting it underneath a handler that is still open.
    log_file = '/content/processing_log.txt'
    logging.basicConfig(filename=log_file, filemode='w', level=logging.INFO,
                        format='%(asctime)s - %(message)s', force=True)

    # Segment folders live on Drive or on the local runtime disk, depending
    # on the output location chosen in the Config cell.
    if output_folder == "google_drive":
        drive.mount('/content/drive')
        temp_dir = "/content/drive/MyDrive/temp_segments"
        output_segments_dir = "/content/drive/MyDrive/output_segments"
    else:
        temp_dir = "/content/temp_segments"
        output_segments_dir = "/content/output_segments"

    def count_segments(directory, attempts=20, delay=3):
        """Return how many ``.mp4`` files *directory* contains.

        A missing directory counts as 0. Any other OS-level error while
        listing is retried up to *attempts* times, sleeping *delay* seconds
        between tries; if every attempt fails the error is printed and 0 is
        returned.

        The directory is listed directly rather than checked with
        ``os.path.exists`` first, because that check reports False on an I/O
        error and would make a temporarily unreadable folder look empty
        without ever retrying.
        """
        for attempt in range(1, attempts + 1):
            try:
                names = os.listdir(directory)
            except FileNotFoundError:
                return 0
            except OSError as e:
                if attempt < attempts:
                    time.sleep(delay)
                    continue
                print(f"Failed to count segments in {directory} after {attempts} attempts: {str(e)}")
                return 0
            return sum(1 for name in names if name.endswith(".mp4"))
        # Only reached when attempts < 1.
        return 0

    os.makedirs(temp_dir, exist_ok=True)
    os.makedirs(output_segments_dir, exist_ok=True)

    # Report what an earlier, possibly interrupted, run left behind.
    existing_input_segments = count_segments(temp_dir)
    existing_output_segments = count_segments(output_segments_dir)
    print(f"Segments to process (temp_segments): {existing_input_segments}")
    print(f"Processed segments (output_segments): {existing_output_segments}")

    # Read the container bit rate with ffprobe and fall back to 2 Mbit/s when
    # ffprobe does not print a plain number. The target bit rate is scaled by
    # the interpolation factor.
    cmd_bitrate = ["ffprobe", "-v", "error", "-show_entries", "format=bit_rate", "-of", "default=noprint_wrappers=1:nokey=1", file_name]
    result_bitrate = subprocess.run(cmd_bitrate, capture_output=True, text=True)
    reported_bitrate = result_bitrate.stdout.strip()
    input_bitrate = int(reported_bitrate) if reported_bitrate.isdigit() else 2000000
    output_bitrate = input_bitrate * interpolation_factor
    logging.info(f"Input video bitrate: {input_bitrate} bps, Output bitrate for all segments: {output_bitrate} bps")

    cap = cv2.VideoCapture(file_name)
    input_fps = cap.get(cv2.CAP_PROP_FPS)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Release before validating so the capture is closed on the error path too.
    cap.release()
    if input_fps <= 0:
        # Without a frame rate the duration below would divide by zero.
        raise ValueError(f"Could not read a valid frame rate from {file_name}; is it a readable video file?")
    input_duration = frame_count / input_fps
    output_fps = input_fps * interpolation_factor
    expected_frame_count = frame_count * interpolation_factor
    print(f"Input FPS: {input_fps}, Frames: {frame_count}, Duration: {input_duration:.2f} sec")
    print(f"Target FPS: {output_fps}, Expected frames: {expected_frame_count}, Expected duration: {input_duration:.2f} sec")

    base_file_name = os.path.basename(file_name)
    output_file_name = base_file_name.rsplit('.', 1)[0] + f'_{interpolation_factor}x.mp4'
    # Anything other than "google_drive" is written to the runtime's /content.
    save_path = '/content/drive/MyDrive/' if output_folder == "google_drive" else '/content/'
    output_video = os.path.join(save_path, output_file_name)
    temp_output_video = os.path.join(save_path, "temp_" + output_file_name)

    model_path = f"/content/FLAVR_{model_type}.pth"
    # The checkpoint variant is chosen in the Install cell and the factor in
    # the Config cell; interpolate.py builds the network from the factor, so
    # point out a disagreement before any segment is processed.
    if model_type != f"{interpolation_factor}x":
        print(f"Warning: model type {model_type} does not match interpolation factor {interpolation_factor}; the checkpoint may fail to load.")
    if not os.path.exists(model_path):
        print(f"Error: model not found at {model_path}. Run the 'Install FLAVR dependencies' cell.")
    elif not os.path.exists("/content/FLAVR/interpolate.py"):
        print("Error: interpolate.py script not found. Make sure the FLAVR repository is downloaded.")
    else:
        # Timing totals are kept next to the input segments so an interrupted
        # run can resume its averages. They only feed the progress display,
        # so an unreadable or corrupt file restarts them from zero instead of
        # aborting the run.
        stats_file = os.path.join(temp_dir, "processing_stats.json")
        stats = {}
        if os.path.exists(stats_file):
            try:
                with open(stats_file, 'r') as f:
                    stats = json.load(f)
            except (OSError, json.JSONDecodeError) as e:
                print(f"Warning: could not read {stats_file} ({e}); timing statistics restart from zero.")
                stats = {}
            if not isinstance(stats, dict):
                stats = {}
        total_flavr_time = stats.get('total_flavr_time', 0)
        total_batches = stats.get('total_batches', 0)
        processed_duration = stats.get('processed_duration', 0)

        # Split the source only when no input segments are present yet; an
        # existing set of segments is reused as it is.
        if existing_input_segments == 0:
            print(f"Splitting video into segments of {segment_duration} seconds...")
            segment_pattern = os.path.join(temp_dir, "segment_%03d.mp4")
            # Stream copy (no re-encode) into numbered segment files.
            cmd_split = ["ffmpeg", "-i", file_name, "-c", "copy", "-map", "0"]
            cmd_split += ["-segment_time", str(segment_duration), "-f", "segment"]
            cmd_split += ["-reset_timestamps", "1", "-y", segment_pattern]
            result_split = subprocess.run(cmd_split, capture_output=True, text=True)
            split_failed = result_split.returncode != 0
            if split_failed:
                print("Error splitting video:")
                print(result_split.stderr)
                raise subprocess.CalledProcessError(result_split.returncode, cmd_split)
        segments = [
            name for name in os.listdir(temp_dir)
            if name.startswith("segment_") and name.endswith(".mp4")
        ]
        segments.sort()
        print(f"Total segments to process: {len(segments)}")

        processed_segments = []
        start_time = time.time()
        # Look in the output folder for segments finished by an earlier run
        # with the same factor and report the highest index found.
        wanted_suffix = f"_{interpolation_factor}x.mp4"
        output_files = [
            name for name in os.listdir(output_segments_dir)
            if name.startswith("final_segment_") and name.endswith(wanted_suffix)
        ]
        processed_indices = set()
        last_index = -1
        if not output_files:
            print("No processed segments found in output_segments_dir")
        else:
            # In "final_segment_<index>_<factor>x.mp4" the third "_"-separated
            # field is the segment index.
            name_fields = [name.split("_") for name in output_files]
            segment_indices = [
                int(fields[2]) for fields in name_fields
                if len(fields) >= 3 and fields[2].isdigit()
            ]
            if not segment_indices:
                print("No valid processed segments found in output_segments_dir")
            else:
                last_index = max(segment_indices)
                processed_indices = set(segment_indices)
                print(f"Last processed segment: index {last_index}, file final_segment_{last_index}_{interpolation_factor}x.mp4")

        # Queue every input segment that is still on disk and has no finished
        # output file yet; names whose index cannot be parsed are skipped.
        segments_to_process = []
        for segment_name in segments:
            try:
                seg_index = int(segment_name.split("_")[1].split(".")[0])
            except (IndexError, ValueError):
                continue
            output_file = os.path.join(output_segments_dir, f"final_segment_{seg_index}_{interpolation_factor}x.mp4")
            already_done = os.path.exists(output_file)
            still_present = os.path.exists(os.path.join(temp_dir, segment_name))
            if still_present and not already_done:
                segments_to_process.append((seg_index, segment_name))
        # Tuples sort by index first, so segments are handled in order.
        segments_to_process.sort()
        if not segments_to_process:
            print("No segments to process: all segments already have output files or are missing in temp_dir")

        for i, segment in segments_to_process:
            # Working files for this segment. Intermediates live on the
            # runtime disk; only the finished segment goes to the output folder.
            segment_path = os.path.join(temp_dir, segment)
            segment_stem = segment.rsplit('.', 1)[0]
            final_segment = os.path.join(output_segments_dir, f"final_segment_{i}_{interpolation_factor}x.mp4")
            temp_output = f"/content/processed_segment_{i}_{interpolation_factor}x.mp4"
            flavr_output = f"/content/{segment_stem}_{interpolation_factor}x.mp4"
            first_frame_video = f"/content/first_frame_{i}.mp4"
            first_frame_dup = f"/content/first_frame_dup_{i}.mp4"
            second_frame_video = f"/content/second_frame_{i}.mp4"
            last_frame_video = f"/content/last_frame_{i}.mp4"
            last_frame_dup = f"/content/last_frame_dup_{i}.mp4"

            # Remove leftovers of an earlier attempt at this segment.
            stale_candidates = (
                temp_output, flavr_output, first_frame_video, first_frame_dup,
                second_frame_video, last_frame_video, last_frame_dup,
            )
            for stale_path in stale_candidates:
                if os.path.exists(stale_path):
                    os.remove(stale_path)

            # Probe the segment: frame rate and frame count determine the
            # output rate and the number of frames expected after interpolation.
            cap = cv2.VideoCapture(segment_path)
            seg_fps = cap.get(cv2.CAP_PROP_FPS)
            seg_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cap.release()
            if seg_fps <= 0:
                # Without a frame rate the durations below would divide by zero.
                raise ValueError(f"Could not read a valid frame rate from segment {segment_path}")
            seg_duration = seg_frame_count / seg_fps
            seg_output_fps = seg_fps * interpolation_factor
            frame_duration = 1 / seg_output_fps
            expected_seg_frame_count = seg_frame_count * interpolation_factor

            logging.info(f"Segment {i} input: FPS {seg_fps}, Frames {seg_frame_count}, Duration {seg_duration:.2f} sec")
            logging.info(f"Segment {i} expected: FPS {seg_output_fps}, Frames {expected_seg_frame_count}, Duration {seg_duration:.2f} sec")

            # Encode the first, second and last source frames as one-frame
            # clips at the output rate; they are concatenated around the
            # FLAVR result later. The filter strings are raw so the "\,"
            # reaches ffmpeg unchanged without an invalid-escape warning.
            single_frame_args = [
                "-frames:v", "1", "-r", str(int(seg_output_fps)), "-c:v", "libx264", "-c:a", "aac",
                "-t", str(frame_duration), "-y",
            ]
            cmd_extract_first = ["ffmpeg", "-i", segment_path, *single_frame_args, first_frame_video]
            cmd_extract_second = [
                "ffmpeg", "-i", segment_path, "-vf", r"select='eq(n\,1)'",
                *single_frame_args, second_frame_video
            ]
            cmd_extract_last = [
                "ffmpeg", "-i", segment_path, "-vf", rf"select='eq(n\,{seg_frame_count-1})'",
                *single_frame_args, last_frame_video
            ]
            for frame_label, cmd_extract in (
                ("first", cmd_extract_first),
                ("second", cmd_extract_second),
                ("last", cmd_extract_last),
            ):
                result_extract = subprocess.run(cmd_extract, capture_output=True, text=True)
                if result_extract.returncode != 0:
                    # Record the reason in the log; processing continues as before.
                    logging.warning(f"Segment {i}: extracting {frame_label} frame failed: {result_extract.stderr}")

            # The first and last clips are used twice in the concat list.
            shutil.copy(first_frame_video, first_frame_dup)
            shutil.copy(last_frame_video, last_frame_dup)

            # Run the interpolation script on this segment in a child process.
            cmd_flavr = [
                "python", "/content/FLAVR/interpolate.py",
                "--input_video", str(segment_path),
                "--factor", str(interpolation_factor),
                "--load_model", str(model_path),
                "--output_fps", str(int(seg_output_fps)),
                "--batch_size", str(batch_size)
            ]

            flavr_start_time = time.time()
            result_flavr = subprocess.run(cmd_flavr, capture_output=True, text=True)
            flavr_time = time.time() - flavr_start_time
            total_flavr_time += flavr_time

            # interpolate.py slides a 4-frame window over the segment, so
            # there are (frames - 3) windows, processed batch_size at a time.
            # The frame count probed above is reused; the file is not
            # re-opened just to count frames again.
            groups = max(1, seg_frame_count - 4 + 1)
            batches = (groups + batch_size - 1) // batch_size
            total_batches += batches
            processed_duration += seg_duration

            # Persist the running totals so a resumed run keeps its averages.
            with open(stats_file, 'w') as f:
                json.dump({
                    'total_flavr_time': total_flavr_time,
                    'total_batches': total_batches,
                    'processed_duration': processed_duration
                }, f)

            # One-line progress report, redrawn in place with a carriage
            # return. Progress is the segment index relative to all input segments.
            done_count = i + 1
            elapsed_time = time.time() - start_time
            progress = done_count / len(segments)
            remaining_time = 0
            if i > 0 and total_flavr_time > 0:
                # Estimate from the average FLAVR time per segment so far.
                avg_flavr_time_per_segment = total_flavr_time / done_count
                remaining_segments = len(segments) - done_count
                remaining_time = avg_flavr_time_per_segment * remaining_segments
            bar_length = 30
            filled = int(bar_length * progress)
            bar = '=' * filled + '-' * (bar_length - filled)
            elapsed_min, elapsed_sec = divmod(elapsed_time, 60)
            elapsed_str = f"{int(elapsed_min):02d}:{int(elapsed_sec):02d}"
            remaining_min, remaining_sec = divmod(remaining_time, 60)
            remaining_str = f"{int(remaining_min):02d}:{int(remaining_sec):02d}"

            avg_stats = "Avg time/batch: N/A, Avg time/sec: N/A, Avg time/segment: N/A"
            if total_batches > 0 and processed_duration > 0 and i >= 0:
                avg_time_per_batch = total_flavr_time / total_batches
                avg_time_per_second = total_flavr_time / processed_duration
                avg_time_per_segment = total_flavr_time / done_count
                avg_stats = f"Avg time/batch (size={batch_size}): {avg_time_per_batch:.2f}s, Avg time/sec: {avg_time_per_second:.2f}s, Avg time/segment: {avg_time_per_segment:.2f}s"

            status_line = f"\rProcessing segment {i+1}/{len(segments)}: [{bar}] {progress:.1%} | Elapsed: {elapsed_str} | Remaining: {remaining_str} | {avg_stats}"
            sys.stdout.write(status_line)
            sys.stdout.flush()

            # Turn the FLAVR result into the final segment: stop on a failed
            # or missing result, otherwise pad it with the boundary-frame
            # clips and re-encode everything into the output folder.
            if result_flavr.returncode != 0:
                print(f"\nError processing segment {segment}:")
                print(result_flavr.stderr)
                raise subprocess.CalledProcessError(result_flavr.returncode, cmd_flavr)
            elif not os.path.exists(flavr_output):
                print(f"\nError: FLAVR file not created: {flavr_output}")
                raise FileNotFoundError(f"FLAVR output not created: {flavr_output}")
            else:
                def move_with_retry(src, dst):
                    """Move *src* to *dst*; return True on success, False otherwise.

                    With the Google Drive output option the move is retried up
                    to 20 times, 3 seconds apart, on OS-level errors. Only
                    OSError is retried, so an interrupt or a programming error
                    propagates immediately. With local output a failure raises.
                    """
                    if output_folder == "google_drive":
                        for attempt in range(20):
                            try:
                                shutil.move(src, dst)
                                return True
                            except OSError:
                                if attempt < 19:
                                    time.sleep(3)
                                    continue
                                print(f"\nFailed to move {src} after 20 attempts")
                                return False
                    else:
                        shutil.move(src, dst)
                        return True

                if move_with_retry(flavr_output, temp_output):
                    cap = cv2.VideoCapture(temp_output)
                    flavr_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    flavr_duration = flavr_frame_count / seg_output_fps
                    cap.release()
                    logging.info(f"Segment {i} after FLAVR: FPS {seg_output_fps}, Frames {flavr_frame_count}, Duration {flavr_duration:.2f} sec")
                    logging.info(f"Segment {i} FLAVR frame difference (actual - expected): {flavr_frame_count - expected_seg_frame_count}")

                    concat_list_temp = f"/content/concat_list_{i}.txt"
                    def write_concat_list():
                        """Write the ffmpeg concat list for this segment.

                        Order: first frame (twice), second frame, FLAVR
                        output, last frame (twice). Retries up to 20 times, 3
                        seconds apart, on OS-level errors; returns True on
                        success and False when every attempt failed.
                        """
                        for attempt in range(20):
                            try:
                                with open(concat_list_temp, "w") as f:
                                    f.write(f"file '{first_frame_video}'\n")
                                    f.write(f"file '{first_frame_dup}'\n")
                                    f.write(f"file '{second_frame_video}'\n")
                                    f.write(f"file '{temp_output}'\n")
                                    f.write(f"file '{last_frame_dup}'\n")
                                    f.write(f"file '{last_frame_video}'\n")
                                return True
                            except OSError:
                                if attempt < 19:
                                    time.sleep(3)
                                    continue
                                print(f"\nFailed to write concat list after 20 attempts")
                                return False

                    if write_concat_list():
                        # Re-encode the concatenated pieces at the target
                        # bit rate and frame rate into the output folder.
                        cmd_concat_segment = [
                            "ffmpeg", "-f", "concat", "-safe", "0", "-i", concat_list_temp,
                            "-c:v", "libx264", "-b:v", f"{output_bitrate}", "-c:a", "aac", "-r", str(int(seg_output_fps)), "-y", final_segment
                        ]
                        result_concat_seg = subprocess.run(cmd_concat_segment, capture_output=True, text=True)
                        if result_concat_seg.returncode != 0:
                            print(f"\nError concatenating segment {segment}:")
                            print(result_concat_seg.stderr)
                            raise subprocess.CalledProcessError(result_concat_seg.returncode, cmd_concat_segment)

                        cap = cv2.VideoCapture(final_segment)
                        final_seg_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                        final_seg_duration = final_seg_frame_count / seg_output_fps
                        cap.release()
                        logging.info(f"Segment {i} final: FPS {seg_output_fps}, Frames {final_seg_frame_count}, Duration {final_seg_duration:.2f} sec")
                        logging.info(f"Segment {i} final frame difference (final - expected): {final_seg_frame_count - expected_seg_frame_count}")
                        logging.info(f"Segment {i} duration difference (final - input): {final_seg_duration - seg_duration:.2f} sec")
                        processed_segments.append(final_segment)

                        # Intermediates are removed only after a successful concat.
                        for f in [temp_output, first_frame_video, first_frame_dup, second_frame_video, last_frame_video, last_frame_dup, concat_list_temp]:
                            if os.path.exists(f):
                                os.remove(f)

            gc.collect()

        print(f"\nSegment processing complete. Processed segments saved in {output_segments_dir}")

In [None]:
#@title ##**Create video** { display-mode: "form" }
import subprocess
import os
import cv2
import logging
import re
import time  # time.sleep() is used by the concat-list retry loop below

log_file = '/content/processing_log.txt'
logging.basicConfig(filename=log_file, level=logging.INFO,
                   format='%(asctime)s - %(message)s')

# Every name this cell reads from earlier cells. interpolation_factor and
# temp_dir are listed because they are used below (bitrate target and
# concat-list location); without the check a missing one would surface as a
# NameError after the segments have already been scanned.
required_globals = (
    'file_name', 'output_video', 'temp_output_video', 'output_fps',
    'output_segments_dir', 'expected_frame_count', 'input_duration',
    'interpolation_factor', 'temp_dir',
)


def probe_video(path):
    """Return (frame_count, fps, duration_sec) for *path* as reported by OpenCV.

    The duration is frame_count / fps, or 0 when OpenCV reports a
    non-positive FPS. The capture is always released, even if a query raises.
    """
    capture = cv2.VideoCapture(path)
    try:
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = capture.get(cv2.CAP_PROP_FPS)
    finally:
        capture.release()
    duration = frame_count / fps if fps > 0 else 0
    return frame_count, fps, duration


if any(name not in globals() for name in required_globals):
    print("Error: Required variables not defined. Run the 'Segment Processing' cell first.")
else:
    print("Merging processed segments...")

    def get_segment_number(filename):
        """Return the integer following 'segment_' in *filename*.

        Files without such a number get +inf so they sort after all
        numbered segments.
        """
        match = re.search(r'segment_(\d+)', filename)
        return int(match.group(1)) if match else float('inf')

    # Collect every .mp4 in the segments directory, ordered by segment number
    # (numeric order, so segment_10 comes after segment_9).
    segment_files = [f for f in os.listdir(output_segments_dir) if f.endswith(".mp4")]
    processed_segments = sorted(
        [os.path.join(output_segments_dir, f) for f in segment_files],
        key=lambda path: get_segment_number(os.path.basename(path))
    )
    if not processed_segments:
        print(f"Error: No processed segments found in {output_segments_dir}")
        raise FileNotFoundError("No processed segments available")

    total_segment_frames = 0
    total_segment_duration = 0
    print(f"Found {len(processed_segments)} segments to merge")
    logging.info(f"Merging {len(processed_segments)} segments")

    # Per-segment statistics, accumulated so the merged result can be compared
    # against the sum of its parts in the log.
    for i, seg in enumerate(processed_segments):
        seg_frame_count, seg_fps, seg_duration = probe_video(seg)
        total_segment_frames += seg_frame_count
        total_segment_duration += seg_duration
        logging.info(f"Segment {i} for merging: File {os.path.basename(seg)}, Frames {seg_frame_count}, FPS {seg_fps:.2f}, Duration {seg_duration:.2f} sec")
        print(f"Segment {i}: File {os.path.basename(seg)}, Frames {seg_frame_count}, Duration {seg_duration:.2f} sec")

    logging.info(f"Total before merging: Frames {total_segment_frames}, Duration {total_segment_duration:.2f} sec")
    print(f"Total before merging: Frames {total_segment_frames}, Duration {total_segment_duration:.2f} sec")

    # Target bitrate = input container bitrate scaled by the interpolation
    # factor; falls back to 2 Mbit/s when ffprobe prints no plain integer.
    cmd_bitrate = ["ffprobe", "-v", "error", "-show_entries", "format=bit_rate", "-of", "default=noprint_wrappers=1:nokey=1", file_name]
    result_bitrate = subprocess.run(cmd_bitrate, capture_output=True, text=True)
    bitrate_text = result_bitrate.stdout.strip()
    input_bitrate = int(bitrate_text) if bitrate_text.isdigit() else 2000000
    output_bitrate = input_bitrate * interpolation_factor
    print(f"Input bitrate: {input_bitrate / 1000:.0f} kbit/s, Target bitrate: {output_bitrate / 1000:.0f} kbit/s")
    logging.info(f"Input bitrate: {input_bitrate / 1000:.0f} kbit/s, Target bitrate: {output_bitrate / 1000:.0f} kbit/s")

    concat_list = os.path.join(temp_dir, "concat_list.txt")

    def write_final_concat_list():
        """Write the ffmpeg concat-demuxer list for all processed segments.

        Retries up to 20 times, sleeping 3 s between attempts, when the file
        cannot be written (presumably to ride out transient filesystem
        errors - confirm). Only OSError is retried so that KeyboardInterrupt
        and programming errors are not swallowed.

        Returns:
            True when the list was written, False after the last failed attempt.
        """
        max_attempts = 20
        for attempt in range(1, max_attempts + 1):
            try:
                with open(concat_list, "w", encoding="utf-8") as f:
                    for segment_path in processed_segments:
                        # Inside a single-quoted concat entry a literal quote
                        # must be written as '\'' or ffmpeg misparses the path.
                        escaped = segment_path.replace("'", "'\\''")
                        f.write(f"file '{escaped}'\n")
                return True
            except OSError as exc:
                if attempt < max_attempts:
                    time.sleep(3)
                    continue
                print("Failed to write final concat list after 20 attempts")
                logging.error(f"Failed to write final concat list after 20 attempts: {exc}")
                return False
        return False

    if write_final_concat_list():
        # Pass 1: join the segments into one video-only file (-an/-sn drop
        # audio and subtitles; they are taken from the source in pass 2).
        cmd_concat = [
            "ffmpeg", "-f", "concat", "-safe", "0", "-i", concat_list,
            "-c:v", "libx264", "-vsync", "0", "-copyts",
            "-b:v", f"{output_bitrate}", "-an", "-sn", "-y", temp_output_video
        ]
        result_concat = subprocess.run(cmd_concat, capture_output=True, text=True)
        if result_concat.returncode != 0:
            print("Error merging segments:")
            print(result_concat.stderr)
            logging.error(f"Error merging segments: {result_concat.stderr}")
            raise subprocess.CalledProcessError(result_concat.returncode, cmd_concat)

        temp_frame_count, temp_fps, temp_duration = probe_video(temp_output_video)
        logging.info(f"Temp merged video: Frames {temp_frame_count}, FPS {temp_fps:.2f}, Duration {temp_duration:.2f} sec")
        logging.info(f"Temp merged frame difference (actual - expected): {temp_frame_count - expected_frame_count}")
        logging.info(f"Temp merged duration difference (actual - input): {temp_duration - input_duration:.2f} sec")
        print(f"Temp merged video: Frames {temp_frame_count}, Diff from expected {temp_frame_count - expected_frame_count}, Duration {temp_duration:.2f} sec")

        # Pass 2: take video from the merged file and audio/subtitles from the
        # original input; the trailing '?' makes those two maps optional so
        # inputs without audio or subtitles still work.
        # NOTE(review): the video is encoded with libx264 a second time here;
        # "-c:v copy" would avoid the extra generation - confirm it keeps the
        # timestamps intact before switching.
        cmd_final = [
            "ffmpeg", "-i", temp_output_video, "-i", file_name,
            "-map", "0:v", "-map", "1:a?", "-map", "1:s?",
            "-c:v", "libx264", "-vsync", "0", "-copyts",
            "-b:v", f"{output_bitrate}",
            "-c:a", "copy", "-c:s", "copy", "-y", output_video
        ]
        result_final = subprocess.run(cmd_final, capture_output=True, text=True)
        if result_final.returncode != 0:
            print("Error in final merge:")
            print(result_final.stderr)
            logging.error(f"Error in final merge: {result_final.stderr}")
            raise subprocess.CalledProcessError(result_final.returncode, cmd_final)

        final_frame_count, final_fps, final_duration = probe_video(output_video)

        # ffprobe's format=duration is the container duration of the output
        # file; it is reported below under the "Audio Duration" label.
        cmd_audio_duration = [
            "ffprobe", "-v", "error", "-show_entries", "format=duration",
            "-of", "default=noprint_wrappers=1:nokey=1", output_video
        ]
        result_audio_duration = subprocess.run(cmd_audio_duration, capture_output=True, text=True)
        try:
            final_audio_duration = float(result_audio_duration.stdout.strip())
        except ValueError:
            # Empty output or a non-numeric value such as "N/A".
            final_audio_duration = 0

        logging.info(f"Final video: Frames {final_frame_count}, FPS {final_fps:.2f}, Video Duration {final_duration:.2f} sec, Audio Duration {final_audio_duration:.2f} sec")
        logging.info(f"Final frame difference (final - expected): {final_frame_count - expected_frame_count}")
        logging.info(f"Final video duration difference (final - input): {final_duration - input_duration:.2f} sec")
        logging.info(f"Final audio-video duration difference: {final_audio_duration - final_duration:.2f} sec")
        print(f"Final video: Frames {final_frame_count}, Diff from expected {final_frame_count - expected_frame_count}, Video Duration {final_duration:.2f} sec, Audio Duration {final_audio_duration:.2f} sec")
        print(f"Duration differences: Video vs Input {final_duration - input_duration:.2f} sec, Audio vs Video {final_audio_duration - final_duration:.2f} sec")

In [None]:
#@title ##**Compare videos (optional)** { display-mode: "form" }
from IPython.display import display, HTML
import os
import base64

# The two clips shown side by side: the uploaded source and the final output.
original_video_path = file_name
processed_video_path = output_video

# Stop before any encoding work if either file is absent.
original_found = os.path.exists(original_video_path)
if not original_found:
    raise ValueError(f"Оригинальное видео не найдено по пути: {original_video_path}")
processed_found = os.path.exists(processed_video_path)
if not processed_found:
    raise ValueError(f"Обработанное видео не найдено по пути: {processed_video_path}")

# Report both file sizes in mebibytes (bytes / 1024 / 1024).
bytes_per_mib = 1024 * 1024
original_size = os.stat(original_video_path).st_size / bytes_per_mib
processed_size = os.stat(processed_video_path).st_size / bytes_per_mib
print(f"Размер оригинального видео: {original_size:.2f} МБ")
print(f"Размер обработанного видео: {processed_size:.2f} МБ")


def video_to_base64(video_path):
    """Return the contents of the file at *video_path* as base64 text.

    The whole file is read into memory in binary mode before it is encoded.
    """
    with open(video_path, "rb") as stream:
        encoded = base64.b64encode(stream.read())
    return encoded.decode('utf-8')

# Both files are embedded in the page as base64 data URIs, so each video is
# held fully in memory (once as bytes, once as text) and again inside html_code.
original_base64 = video_to_base64(original_video_path)
processed_base64 = video_to_base64(processed_video_path)

# Two <video> players plus one shared Play/Pause button. The script mirrors
# play/pause between the players and re-aligns currentTime when they drift by
# more than 0.5 s. Literal JavaScript braces are doubled because this is an
# f-string.
#
# playBoth(): each play() rejection is logged and then rethrown. Previously the
# per-video handlers swallowed the rejection, so Promise.all always resolved:
# the button switched to "Pause" even when playback had failed and the final
# .catch could never run. On failure both players are now paused again so the
# button and the players stay in the same state.
html_code = f"""
<div style="display: flex; justify-content: center; flex-direction: column; align-items: center;">
    <div style="display: flex; justify-content: center;">
        <div style="margin-right: 10px;">
            <video id="originalVideo" width="400" controls preload="auto">
                <source src="data:video/mp4;base64,{original_base64}" type="video/mp4">
                Ваш браузер не поддерживает видео.
            </video>
            <p>Оригинальное видео</p>
        </div>
        <div>
            <video id="processedVideo" width="400" controls preload="auto">
                <source src="data:video/mp4;base64,{processed_base64}" type="video/mp4">
                Ваш браузер не поддерживает видео.
            </video>
            <p>Обработанное видео</p>
        </div>
    </div>
    <button id="playPauseBtn" style="margin-top: 10px; padding: 10px 20px; font-size: 16px;">Play</button>
</div>
<script>
(function() {{
    var originalVideo = document.getElementById("originalVideo");
    var processedVideo = document.getElementById("processedVideo");
    var playPauseBtn = document.getElementById("playPauseBtn");
    var isPlaying = false;

    playPauseBtn.disabled = false;

    function playBoth() {{
        Promise.all([
            originalVideo.play().catch(function(error) {{
                console.log("Ошибка воспроизведения оригинального видео:", error);
                throw error;
            }}),
            processedVideo.play().catch(function(error) {{
                console.log("Ошибка воспроизведения обработанного видео:", error);
                throw error;
            }})
        ]).then(function() {{
            playPauseBtn.textContent = "Pause";
            isPlaying = true;
        }}).catch(function(error) {{
            console.log("Не удалось воспроизвести видео:", error);
            pauseBoth();
        }});
    }}

    function pauseBoth() {{
        originalVideo.pause();
        processedVideo.pause();
        playPauseBtn.textContent = "Play";
        isPlaying = false;
    }}

    playPauseBtn.addEventListener("click", function() {{
        if (isPlaying) {{
            pauseBoth();
        }} else {{
            playBoth();
        }}
    }});

    originalVideo.addEventListener("play", function() {{
        if (processedVideo.paused) processedVideo.play();
        playPauseBtn.textContent = "Pause";
        isPlaying = true;
    }});
    processedVideo.addEventListener("play", function() {{
        if (originalVideo.paused) originalVideo.play();
        playPauseBtn.textContent = "Pause";
        isPlaying = true;
    }});
    originalVideo.addEventListener("pause", function() {{
        if (!processedVideo.paused) processedVideo.pause();
        playPauseBtn.textContent = "Play";
        isPlaying = false;
    }});
    processedVideo.addEventListener("pause", function() {{
        if (!originalVideo.paused) originalVideo.pause();
        playPauseBtn.textContent = "Play";
        isPlaying = false;
    }});

    originalVideo.addEventListener("timeupdate", function() {{
        if (Math.abs(originalVideo.currentTime - processedVideo.currentTime) > 0.5) {{
            processedVideo.currentTime = originalVideo.currentTime;
        }}
    }});
    processedVideo.addEventListener("timeupdate", function() {{
        if (Math.abs(processedVideo.currentTime - originalVideo.currentTime) > 0.5) {{
            originalVideo.currentTime = processedVideo.currentTime;
        }}
    }});

    console.log("Скрипт синхронизации видео инициализирован");
}})();
</script>
"""

display(HTML(html_code))

In [None]:
#@title ##**Compare frames (optional)** { display-mode: "form" }
import cv2
import os
import numpy as np
from PIL import Image as PILImage
from PIL import ImageDraw, ImageFont
from IPython.display import display, Image

max_frames_to_show = 3 #@param {type:"slider", min:1, max:50, step:1}

# Default size of one thumbnail in the comparison strip.
THUMB_WIDTH, THUMB_HEIGHT = 640, 360
# The JPEG encoder refuses images wider or taller than this many pixels, so
# the strip (one thumbnail per interpolated frame) must stay within it.
JPEG_MAX_DIMENSION = 65500


def load_label_font(size):
    """Return DejaVuSans at *size*, or PIL's built-in font if it cannot be loaded."""
    try:
        return ImageFont.truetype("DejaVuSans.ttf", size)
    except OSError:
        # truetype() raises OSError when the font file cannot be opened.
        return ImageFont.load_default()


def read_labelled_frames(capture, count, size, font, kind):
    """Read up to *count* frames from *capture* as labelled RGB thumbnails.

    Each frame is converted from BGR to RGB, resized to *size* (width, height)
    and stamped with "#<index>" near its bottom-left corner. Reading stops at
    the first failed read; *kind* only names the video in that error message.

    Returns:
        A list of PIL images, possibly shorter than *count*.
    """
    width, height = size
    label_y = max(0, height - 40)  # keep the label inside very small thumbnails
    frames = []
    for frame_num in range(count):
        ok, frame = capture.read()
        if not ok:
            print(f"Error reading {kind} frame {frame_num}")
            break
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        thumbnail = PILImage.fromarray(rgb).resize((width, height))
        ImageDraw.Draw(thumbnail).text((10, label_y), f"#{frame_num}", fill="white", font=font)
        frames.append(thumbnail)
    return frames


if 'file_name' not in globals() or file_name is None or not os.path.exists(file_name):
    print("Error: Original video file is not selected or does not exist.")
elif 'output_video' not in globals() or not os.path.exists(output_video):
    print("Error: Interpolated video file not found. Run the 'Create video' cell first.")
elif 'interpolation_factor' not in globals():
    # Used below to map interpolated frame indices back to original frames.
    print("Error: interpolation_factor is not defined. Run the earlier processing cells first.")
else:
    cap_orig = cv2.VideoCapture(file_name)
    cap_interp = cv2.VideoCapture(output_video)
    captures_opened = cap_orig.isOpened() and cap_interp.isOpened()

    orig_frames = []
    interp_frames = []
    max_interp_frames = 0
    frame_width, frame_height = THUMB_WIDTH, THUMB_HEIGHT

    try:
        if captures_opened:
            print(f"Building frame comparison strip for original ({file_name}) and interpolated ({output_video}) videos:")
            print(f"Interpolation factor: {interpolation_factor}")

            orig_frame_count = int(cap_orig.get(cv2.CAP_PROP_FRAME_COUNT))
            interp_frame_count = int(cap_interp.get(cv2.CAP_PROP_FRAME_COUNT))

            # Show the first N original frames and the interpolated frames
            # that span them (N * factor), capped by what each file contains.
            max_orig_frames = min(max_frames_to_show, orig_frame_count)
            max_interp_frames = min(max_orig_frames * interpolation_factor, interp_frame_count)

            # Shrink the thumbnails when a full-size strip would be too wide
            # to save as JPEG (e.g. 50 frames at 4x = 200 * 640 px).
            if max_interp_frames * frame_width > JPEG_MAX_DIMENSION:
                frame_width = max(1, JPEG_MAX_DIMENSION // max_interp_frames)
                frame_height = max(1, frame_width * THUMB_HEIGHT // THUMB_WIDTH)
                print(f"Note: using {frame_width}x{frame_height} thumbnails so the strip fits the JPEG size limit.")

            font = load_label_font(30)
            thumb_size = (frame_width, frame_height)
            orig_frames = read_labelled_frames(cap_orig, max_orig_frames, thumb_size, font, "original")
            interp_frames = read_labelled_frames(cap_interp, max_interp_frames, thumb_size, font, "interpolated")
    finally:
        # Release both captures on every path, including a failed open.
        cap_orig.release()
        cap_interp.release()

    if not captures_opened:
        print("Error: Could not open one or both video files.")
    else:
        # Placeholder for strip positions that have no original frame.
        blank_frame = PILImage.new('RGB', (frame_width, frame_height), color='black')
        ImageDraw.Draw(blank_frame).text((10, max(0, frame_height - 40)), "#N/A", fill="white", font=font)

        # Top row: original frame k sits above interpolated frame
        # k * interpolation_factor; every other position is a placeholder.
        orig_strip_frames = []
        for i in range(max_interp_frames):
            orig_index, remainder = divmod(i, interpolation_factor)
            if remainder == 0 and orig_index < len(orig_frames):
                orig_strip_frames.append(orig_frames[orig_index])
            else:
                orig_strip_frames.append(blank_frame.copy())

        num_frames = min(len(orig_strip_frames), len(interp_frames))
        if num_frames == 0:
            print("Error creating comparison image: no frames could be read.")
        else:
            try:
                # Two rows: originals on top, interpolated frames below.
                comparison_strip = PILImage.new('RGB', (frame_width * num_frames, frame_height * 2))
                for i in range(num_frames):
                    x_offset = i * frame_width
                    comparison_strip.paste(orig_strip_frames[i], (x_offset, 0))
                    comparison_strip.paste(interp_frames[i], (x_offset, frame_height))

                draw = ImageDraw.Draw(comparison_strip)
                title_font = load_label_font(20)
                draw.text((10, 10), "Original video", fill="white", font=title_font)
                draw.text((10, frame_height + 10), "Interpolated video", fill="white", font=title_font)

                strip_path = "/content/comparison_strip.jpg"
                comparison_strip.save(strip_path)

                print("\nOriginal frames (top) vs Interpolated frames (bottom):")
                display(Image(filename=strip_path))

            except Exception as e:
                # Best-effort preview: report the problem instead of aborting the cell.
                print(f"Error creating comparison image: {str(e)}")

        print("\nFrame comparison strip completed.")

In [None]:
#@title ##**Download** { display-mode: "form" }
import os
from google.colab import files

# Offer the final video for download through the browser. The globals check
# covers the case where no earlier cell has defined output_video, which would
# otherwise raise NameError. The messages say "Interpolated" because this
# notebook produces a frame-interpolated video, not a stabilized one.
if 'output_video' in globals() and os.path.exists(output_video):
    print(f"Interpolated video saved at: {output_video}")
    files.download(output_video)
else:
    print("Error: Interpolated video was not created. Run the 'Create video' cell first.")
