In [115]:
from typing import Optional
from numpy.typing import NDArray

from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from diffusers import (
    StableDiffusionLatentUpscalePipeline
)
from torchvision.transforms.v2 import ToImage, ToDtype, Compose
import imageio.v3 as iio

from diffusers.utils.logging import disable_progress_bar
# Turn off diffusers' own progress bars for the whole session.
# NOTE(review): presumably done so they do not interleave with the tqdm bar
# that upscale_video draws per frame — confirm.
disable_progress_bar()

In [116]:
# Input/output locations, resolved against the directory the kernel runs in.
DATA = Path.cwd().joinpath("data")
SD_VID_PATH = DATA.joinpath("inter4k_222_sd.mp4")  # low-resolution source clip
HD_VID_PATH = DATA.joinpath("inter4k_222_hd.mp4")  # where the upscaled clip is written

In [143]:
def get_upscale_pipelines():
    """Load the x2 latent upscaler and place it on the GPU.

    The ``stabilityai/sd-x2-latent-upscaler`` weights are loaded in float16
    from safetensors files, the pipeline's per-call progress bar is disabled,
    and the pipeline is moved to the ``"cuda"`` device.

    Returns:
        The ready-to-call ``StableDiffusionLatentUpscalePipeline``.
    """
    pipeline = StableDiffusionLatentUpscalePipeline.from_pretrained(
        "stabilityai/sd-x2-latent-upscaler",
        torch_dtype = torch.float16,
        use_safetensors = True,
    )
    # The caller already reports per-frame progress, so silence the
    # per-denoising-step bar of the pipeline itself.
    pipeline.set_progress_bar_config(disable = True)
    return pipeline.to("cuda")

def get_upscaled_frame(frame: NDArray, generator: torch.Generator, upscaler: StableDiffusionLatentUpscalePipeline):
    """Upscale a single frame and return the result as a CPU tensor.

    Args:
        frame: The low-resolution image handed to the pipeline as ``image``.
            Note that ``upscale_video`` passes the float32 tensor produced by
            its ``to_float32`` transform here.
        generator: Random generator forwarded to the pipeline.
        upscaler: The pipeline to call.

    Returns:
        The first image of the pipeline output (requested with
        ``output_type = "pt"``), moved to the CPU.
    """
    # Blank prompt and guidance_scale 0: the call is driven by the image only.
    pipeline_kwargs = dict(
        prompt = " ",
        image = frame,
        num_inference_steps = 20,
        guidance_scale = 0,
        generator = generator,
        output_type = "pt",
    )
    result = upscaler(**pipeline_kwargs)
    first_image = result.images[0]
    return first_image.cpu()

def upscale_video(src_vid: Path, tgt_vid: Path, random_seed: int, limit_frames: Optional[int] = None):
    """Upscale a video frame by frame and encode the result as H.264.

    Every frame read from ``src_vid`` is converted to a scaled float32 image
    tensor, passed through the latent upscaler, converted back to uint8 and
    appended to ``tgt_vid``. The output stream uses the source's frame rate.

    Args:
        src_vid: Video to read (decoded with imageio's ``pyav`` plugin).
        tgt_vid: Video file to create.
        random_seed: Seed of the CUDA generator shared by all frames.
        limit_frames: Maximum number of frames to process. When ``None`` the
            whole video is processed and the frame count used for the progress
            bar is estimated as ``fps * duration`` from the source metadata.
    """
    metadata = iio.immeta(src_vid, plugin = "pyav")
    fps = metadata["fps"]

    if limit_frames is None and metadata.get("duration") is not None:
        # fps * duration is a float (e.g. 300.0); round it to a real frame
        # count so the progress bar total and the loop bound are integers.
        limit_frames = int(round(fps * metadata["duration"]))
    # If the duration is not reported, limit_frames stays None: every frame is
    # processed and the progress bar simply has no total.

    # Load the model before the output container is opened, so a failed load
    # does not leave an empty target file behind.
    upscaler = get_upscale_pipelines()
    generator = torch.Generator("cuda").manual_seed(random_seed)
    to_float32 = Compose([ToImage(), ToDtype(torch.float32, scale = True)])
    to_uint8 = Compose([ToImage(), ToDtype(torch.uint8, scale = True)])

    with iio.imopen(tgt_vid, "w", plugin = "pyav") as upscaled_vid:
        upscaled_vid.init_video_stream("h264", fps = fps)

        frames = iio.imiter(src_vid, plugin = "pyav")
        for i, frame in tqdm(enumerate(frames), total = limit_frames):
            if limit_frames is not None and i >= limit_frames:
                break

            low_res_frame = to_float32(frame)
            high_res_frame = get_upscaled_frame(low_res_frame, generator, upscaler)
            high_res_frame = to_uint8(high_res_frame).numpy()
            # Move the leading (channel) axis to the end before writing.
            # NOTE(review): assumes write_frame expects channel-last frames — confirm.
            upscaled_vid.write_frame(high_res_frame.transpose(1, 2, 0))

In [145]:
# Upscale the full clip (limit_frames left at its default) with a fixed seed.
upscale_video(src_vid = SD_VID_PATH, tgt_vid = HD_VID_PATH, random_seed = 42)

100%|██████████| 300/300.0 [05:40<00:00,  1.14s/it]


In [146]:
# Read back the metadata of the written file to check codec, fps and duration.
iio.immeta(HD_VID_PATH, plugin = "pyav")

{'video_format': 'yuv420p',
 'codec': 'h264',
 'long_codec': 'H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10',
 'profile': 'High',
 'fps': 60.0,
 'duration': 5.0,
 'major_brand': 'isom',
 'minor_version': '512',
 'compatible_brands': 'isomiso2avc1mp41',
 'encoder': 'Lavf60.16.100',
 'language': 'und',
 'handler_name': 'VideoHandler',
 'vendor_id': '[0][0][0][0]'}

In [None]:
# print(f"Loaded video {SD_VID_PATH} as ndarray")
# print(f"It's a {sd_vid.shape[1]}x{sd_vid.shape[2]} video which is", end = " ")
# print(f"{sd_vid_metadata['duration']}s long, at {sd_vid_metadata['fps']}fps, thus totalling {sd_vid.shape[0]} frames")