In [None]:
from fractions import Fraction
from itertools import islice
from pathlib import Path
from typing import Any, Optional, Iterable

import av
import av.stream
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionLatentUpscalePipeline
from numpy.typing import NDArray
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch._dynamo
torch._dynamo.config.suppress_errors = True

In [None]:
# Videos live in a `data` folder under the notebook's current working directory.
DATA = Path.cwd().joinpath("data")
SD_VID_PATH = DATA.joinpath("inter4k_222_sd.mp4")  # source video fed to upscale_video
HD_VID_PATH = DATA.joinpath("inter4k_222_hd.mp4")  # upscaled output written by upscale_video

In [None]:
def get_metadata(container: "av.container.InputContainer") -> dict:
    """Collect basic properties of the first video stream of an opened input container.

    Args:
        container: An input container returned by ``av.open(path, "r")``.

    Returns:
        A dict with keys:
            width, height: frame size of the stream.
            duration: length in seconds, or None when neither the stream nor the
                container reports one.
            fps: average frame rate as a float.
            frames: frame count reported by the stream (PyAV reports 0 when unknown).
            pixel_format: pixel format name (e.g. "yuv420p"), or None when unknown.
    """
    video = container.streams.video[0]

    if video.duration is not None:
        duration = float(video.duration * video.time_base)
    elif container.duration is not None:
        # Stream duration is missing for some files; fall back to the container's,
        # which is expressed in AV_TIME_BASE units (microseconds).
        duration = container.duration / 1_000_000
    else:
        duration = None

    return {
        "width": video.width,
        "height": video.height,
        "duration": duration,
        "fps": float(video.average_rate),
        "frames": video.frames,
        # Store the format *name* so it can be used wherever a pix_fmt string is expected.
        "pixel_format": video.format.name if video.format is not None else None,
    }

def get_output_stream(container: "av.container.OutputContainer", fps: float, width: int, height: int, pixel_format: str, **kwargs) -> "av.video.stream.VideoStream":
    """Add an mpeg4 video stream to ``container`` and configure its rate and size.

    Designed to be called with ``**get_metadata(...)``; keys it does not use
    (e.g. duration, frames) are absorbed by ``**kwargs`` and ignored.

    Args:
        container: Output container returned by ``av.open(path, "w")``.
        fps: Frame rate; may be an int, float or Fraction.
        width: Frame width of the encoded stream.
        height: Frame height of the encoded stream.
        pixel_format: Source pixel format; currently unused (see comment below).

    Returns:
        The newly added video stream.
    """
    # PyAV's add_stream takes the rate as an int or Fraction. A float such as
    # 29.97002997 (what get_metadata produces) is turned back into an exact ratio
    # (30000/1001); the cap keeps the denominator within what the mpeg4 time base allows.
    rate = Fraction(fps).limit_denominator(65535)
    stream = container.add_stream("mpeg4", rate=rate)
    stream.width = width
    stream.height = height

    # The FFmpeg mpeg4 encoder only supports yuv420p, so the source pixel_format is
    # not applied here; it would need a different codec to be honoured.
    stream.pix_fmt = "yuv420p"
    return stream

def write_frame_to_stream(container: "av.container.OutputContainer", stream: "av.video.stream.VideoStream", frame: Optional[NDArray] = None):
    """Encode one RGB frame into ``stream`` and mux the resulting packets into ``container``.

    Args:
        container: Output container the packets are muxed into.
        stream: Video stream of ``container`` that encodes the frame.
        frame: RGB frame as a uint8 (H, W, 3) array, or anything ``np.asarray`` turns
            into one (e.g. a PIL image returned by a diffusers pipeline). ``None``
            flushes the encoder so buffered packets are written; call it once after
            the last frame.
    """
    video_frame = None
    if frame is not None:
        # from_ndarray needs a real ndarray; np.asarray is a no-op for arrays and
        # converts PIL images, which the diffusers pipelines return by default.
        video_frame = av.VideoFrame.from_ndarray(np.asarray(frame), format="rgb24")
    for packet in stream.encode(video_frame):
        container.mux(packet)

In [None]:
def get_upscale_pipelines(
    sd_model_id: str = "CompVis/stable-diffusion-v1-4",
    upscaler_model_id: str = "stabilityai/sd-x2-latent-upscaler",
) -> "tuple[StableDiffusionPipeline, StableDiffusionLatentUpscalePipeline]":
    """Load the Stable Diffusion and x2 latent-upscaler pipelines in float16.

    Both pipelines use sequential CPU offload and attention slicing to reduce GPU
    memory usage.

    Args:
        sd_model_id: Hub id (or local path) of the Stable Diffusion checkpoint.
        upscaler_model_id: Hub id (or local path) of the latent upscaler checkpoint.

    Returns:
        ``(pipeline, upscaler)`` — the order ``upscale_video`` unpacks them in.
    """
    # No .to("cuda") before enabling sequential offload: the diffusers docs advise
    # against moving the pipeline to CUDA first, as it leaves almost no memory savings
    # (offload places submodules on the GPU on demand).
    pipeline = StableDiffusionPipeline.from_pretrained(
        sd_model_id,
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    pipeline.enable_sequential_cpu_offload()
    pipeline.enable_attention_slicing()

    upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
        upscaler_model_id,
        torch_dtype=torch.float16,
    )
    upscaler.enable_sequential_cpu_offload()
    upscaler.enable_attention_slicing()

    return pipeline, upscaler

def get_upscaled_frame(frame: NDArray, prompt: str, generator: torch.Generator, sd_encoder: "StableDiffusionPipeline", upscaler: "StableDiffusionLatentUpscalePipeline", num_inference_steps: int = 20, guidance_scale: float = 0) -> NDArray:
    """Upscale a single RGB video frame with the latent upscaler.

    Args:
        frame: RGB frame as produced by ``VideoFrame.to_ndarray(format="rgb24")``,
            i.e. a uint8 (H, W, 3) array.
        prompt: Text prompt conditioning the upscaler.
        generator: Torch generator used for the upscaler's sampling noise.
        sd_encoder: Unused; kept so existing callers do not break. It is a
            text-to-image pipeline with no ``image`` input, so passing the frame to
            it (as this function used to) silently ignored the frame and produced
            latents from the prompt alone.
        upscaler: The x2 latent upscaler pipeline.
        num_inference_steps: Denoising steps for the upscaler.
        guidance_scale: Classifier-free guidance scale for the upscaler.

    Returns:
        The upscaled frame as a uint8 (H', W', 3) array.
    """
    # The latent upscaler accepts an image tensor in [-1, 1] and VAE-encodes any
    # 3-channel input itself, so the frame goes in directly as a (1, 3, H, W) batch.
    image = torch.from_numpy(np.ascontiguousarray(frame)).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1.0
    result = upscaler(
        prompt=prompt,
        image=image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    # Default output_type is PIL; convert to an array for av.VideoFrame.from_ndarray.
    return np.asarray(result.images[0])


def upscale_video(src_vid: Path, tgt_vid: Path, prompt: str, random_seed: int, limit_frames: Optional[int] = None):
    """Upscale the first video stream of ``src_vid`` frame by frame into ``tgt_vid``.

    Args:
        src_vid: Path of the input video.
        tgt_vid: Path of the output video (mpeg4-encoded).
        prompt: Text prompt passed to the upscaler for every frame.
        random_seed: Seed of the CUDA generator shared across all frames.
        limit_frames: Maximum number of frames to process; ``None`` processes them all.
    """
    generator = torch.Generator("cuda").manual_seed(random_seed)
    # Load models before opening the output so a failed load leaves no empty file behind.
    sd_encoder, upscaler = get_upscale_pipelines()

    with av.open(str(src_vid), "r") as src_container, av.open(str(tgt_vid), "w") as tgt_container:
        src_container.streams.video[0].thread_type = "AUTO"
        metadata = get_metadata(src_container)
        tgt_stream = None

        total = limit_frames if limit_frames is not None else (metadata["frames"] or None)
        # islice takes exactly `limit_frames` frames (the old `i > limit_frames` check took one extra).
        frames = islice(src_container.decode(video=0), limit_frames)
        for src_frame in tqdm(frames, total=total, desc="upscaling"):
            frame = src_frame.to_ndarray(format="rgb24")
            upscaled = np.asarray(get_upscaled_frame(frame, prompt, generator, sd_encoder, upscaler))

            if tgt_stream is None:
                # Size the output from the first upscaled frame, not the source metadata:
                # PyAV reformats frames that don't match the stream, which would silently
                # scale every upscaled frame back down to the source resolution.
                height, width = upscaled.shape[:2]
                tgt_stream = get_output_stream(tgt_container, **{**metadata, "width": width, "height": height})

            write_frame_to_stream(tgt_container, tgt_stream, upscaled)

        if tgt_stream is not None:
            # Flush frames still buffered in the encoder.
            write_frame_to_stream(tgt_container, tgt_stream)

In [None]:
upscale_video(SD_VID_PATH, HD_VID_PATH, prompt="nightscape", random_seed=42, limit_frames=1)  # quick smoke run on a single frame budget

In [None]:
# NOTE(review): stale cell — `sd_vid` and `sd_vid_metadata` are not defined anywhere in
# this notebook, so these lines would raise NameError if uncommented. Either rebuild
# them from get_metadata(...) or delete the cell.
# print(f"Loaded video {SD_VID_PATH} as ndarray")
# print(f"It's a {sd_vid.shape[1]}x{sd_vid.shape[2]} video which is", end = " ")
# print(f"{sd_vid_metadata['duration']}s long, at {sd_vid_metadata['fps']}fps, thus totalling {sd_vid.shape[0]} frames")