In [45]:
from typing import Any, Optional
from numpy.typing import NDArray

from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from diffusers import StableDiffusionUpscalePipeline
from torch.utils.data import DataLoader

import av
import imageio.v3 as iio
from PIL import Image

import torch._dynamo
torch._dynamo.config.suppress_errors = True

In [46]:
DATA = Path.cwd() / "data"
SD_VID_PATH = DATA / "inter4k_222_sd.mp4"
HD_VID_PATH = DATA / "inter4k_222_hd.mp4"

sd_vid = iio.imread(SD_VID_PATH, plugin = "pyav")
sd_vid_metadata = iio.immeta(SD_VID_PATH, plugin = "pyav")
fps = sd_vid_metadata["fps"]

print(f"Loaded video {SD_VID_PATH} as ndarray")
print(f"It's a {sd_vid.shape[1]}x{sd_vid.shape[2]} video which is", end = " ")
print(f"{sd_vid_metadata['duration']}s long, at {sd_vid_metadata['fps']}fps, thus equalling {sd_vid.shape[0]} frames")

Loaded video /home/video-upscaling/data/inter4k_222_sd.mp4 as ndarray
It's a 480x640 video which is 5.0s long, at 60.0fps, thus equalling 300 frames


In [64]:
def upscale_video_4x(video: NDArray, prompt: str, limit_frames: Optional[int] = None, random_seed: int = 42, ) -> NDArray:
    """
    Parameters
    ----------
    video: NDArray
        numpy array of shape (frames, height, width, channels)
    prompt: str
        text prompt given to the model
    limit_frames: int, optional
        for debugging purposes, only upscale first limit_frames frames of the video
    random_seed: int, optional
        for reproducibility
    """

    model_id = "stabilityai/stable-diffusion-x4-upscaler"
    generator = torch.Generator("cuda").manual_seed(random_seed)
    upscale = StableDiffusionUpscalePipeline.from_pretrained(
        model_id,
        torch_dtype = torch.float16,
        use_safetensors = True,
        # generator = generator,
        # output_type = "ndarray",
        # num_inference_steps = 20,
        # num_images_per_prompt = 1,
    )
    # upscale.unet.set_attn_processor(AttnProcessor2_0())
    upscale = upscale.to("cuda")
    # upscale.unet = torch.compile(upscale.unet, mode="reduce-overhead", fullgraph=True)
    upscale.enable_sequential_cpu_offload()
    upscale.enable_attention_slicing()

    if limit_frames is None:
        limit_frames = video.shape[0]
    print(f"limiting frames to: {limit_frames}")
    
    upscaled_video = list()
    for idx, frame in enumerate(video):
        print(f"Upscaling Frame: {idx}")
        if idx > limit_frames:
            break
        frame = Image.fromarray(frame)
        upscaled_frame = upscale(
            prompt = prompt, 
            image = frame, 
            num_inference_steps = 20,
            generator = generator,
            ).images[0]
        upscaled_video.append(np.array(upscaled_frame))
    return np.stack(upscaled_video)

In [65]:
hd_vid = upscale_video_4x(sd_vid, "cityscape at night", limit_frames = 20)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

limiting frames to: 20
Upscaling Frame: 0


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 1


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 2


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 3


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 4


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 5


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 6


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 7


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 8


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 9


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 10


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 11


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 12


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 13


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 14


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 15


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 16


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 17


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 18


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 19


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 20


  0%|          | 0/20 [00:00<?, ?it/s]

Upscaling Frame: 21


In [66]:
hd_vid.shape

(21, 1920, 2560, 3)

In [75]:
# TODO: Encode video using h264 at 60fps
iio.imwrite(HD_VID_PATH, hd_vid)
print(f"Upscaled video {HD_VID_PATH}")
print(f"It's a {hd_vid.shape[1]}x{hd_vid.shape[2]} video, with {hd_vid.shape[0]} frames at {sd_vid_metadata['fps']}fps, thus totalling {hd_vid.shape[0] / sd_vid_metadata['fps'] :.2f}s")

Upscaled video /home/video-upscaling/data/inter4k_222_hd.mp4
It's a 1920x2560 video, with 21 frames at 60.0fps, thus totalling 0.35s
