# RealESRGAN Video Inference with Fine-Tuned Model on Google Colab (T4 GPU)

This notebook demonstrates how to use a fine-tuned RealESRGAN model to perform video super-resolution in Google Colab with a T4 GPU. The script reads a video frame by frame, enhances each frame with the fine-tuned RealESRGAN model, and writes the upscaled video with the original audio track copied over.

## Prerequisites
- Ensure you have a Google Colab environment with a T4 GPU assigned (select Runtime > Change runtime type > T4 GPU).
- Upload a video file (e.g., `test1.mp4`) to a `video/` folder in the Colab working directory (the script reads `video/test1.mp4` by default), or clone the Git repository containing the video.
- Obtain the fine-tuned model weights (`net_g_5000.pth`) from the Git repository (stored as a Git LFS file). Clone the repository and ensure Git LFS is set up to download the model.
- Install Git LFS in Colab if needed: `!apt-get install -y git-lfs && git lfs install`.
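If the repository was cloned before Git LFS was set up, `net_g_5000.pth` is a small text pointer file instead of the actual weights, and loading it fails with a confusing error. A minimal sketch of a check, assuming the standard Git LFS pointer format whose first line is `version https://git-lfs.github.com/spec/v1`; `is_lfs_pointer` is a hypothetical helper, and the model path is the same placeholder used in `main()`.

```python
from pathlib import Path

# First line of a Git LFS pointer file (Git LFS pointer format, spec v1).
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path, max_pointer_size=1024):
    """Return True if `path` looks like an un-pulled Git LFS pointer file.

    Hypothetical helper: pointer files are tiny text files (well under 1 KB),
    while real model weights are many megabytes.
    """
    p = Path(path)
    if p.stat().st_size > max_pointer_size:
        return False
    with p.open("rb") as f:
        return f.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX


if __name__ == "__main__":
    model_path = "your-repo/model/net_g_5000.pth"  # placeholder path from main()
    if not Path(model_path).is_file():
        print(f"{model_path} not found; check the clone and the path")
    elif is_lfs_pointer(model_path):
        print("Model file is an LFS pointer; run `git lfs pull` in the repository")
    else:
        print("Model weights look downloaded")
```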

## Steps
1. Install required dependencies and Git LFS.
2. Clone the Git repository containing the fine-tuned model (`net_g_5000.pth`).
3. Patch `basicsr` for compatibility with newer torchvision releases (`dependency-fix.sh`).
4. Set up paths for the input video and model weights.
5. Run the inference script to process the video.
6. Download the output video from the specified output directory.
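The paths and settings mentioned in the steps above are hardcoded locals inside `main()`. A minimal sketch of collecting them in one place, assuming a hypothetical `InferenceConfig` dataclass (not part of Real-ESRGAN or BasicSR) with the same defaults as `main()`; it derives the output file name with the same `<output_dir>/<video name>_<suffix>.mp4` rule.

```python
import os
from dataclasses import dataclass


@dataclass
class InferenceConfig:
    """Hypothetical container for the values hardcoded in main()."""
    input_path: str = "video/test1.mp4"
    model_path: str = "your-repo/model/net_g_5000.pth"
    output_dir: str = "output"
    outscale: float = 4
    suffix: str = "out"
    tile: int = 1000    # 0 disables tiling in RealESRGANer
    fp32: bool = False  # False -> FP16 (half=True)

    def __post_init__(self):
        # Reject obviously invalid values early, before any model is loaded.
        if self.outscale <= 0:
            raise ValueError("outscale must be positive")
        if self.tile < 0:
            raise ValueError("tile must be 0 (no tiling) or a positive tile size")

    @property
    def video_save_path(self):
        # Same naming rule as main(): <output_dir>/<video name>_<suffix>.mp4
        video_name = os.path.splitext(os.path.basename(self.input_path))[0]
        return os.path.join(self.output_dir, f"{video_name}_{self.suffix}.mp4")


cfg = InferenceConfig(input_path="video/clip.mov", tile=512)
print(cfg.video_save_path)  # output/clip_out.mp4 on Linux
```

`main()` could take such an object and read `cfg.tile`, `cfg.outscale`, and so on in place of its locals; the validation only catches obviously invalid values.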

## Notes
- The script uses FP16 precision by default to optimize performance on T4 GPU.
- If you encounter CUDA out-of-memory errors, reduce the `tile` size (default: 1000).
- The output video resolution is scaled by the `outscale` factor (default: 4x).
- For `.flv` videos, the script converts them to `.mp4` before processing.
- The fine-tuned model (`net_g_5000.pth`) is specific to this implementation and differs from the original RealESRGAN_x4plus model.
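A rough sketch of how `tile` and `outscale` interact: `RealESRGANer` splits each input frame into `ceil(width / tile) x ceil(height / tile)` tiles (each padded by `tile_pad` pixels), the model always upsamples by its native 4x, and `outscale` only controls the final resize. `plan_upscale` is a hypothetical helper that reproduces this size arithmetic, the raw `bgr24` bytes piped through FFmpeg per frame, and the `VideoWriter` 4K warning check.

```python
import math
from dataclasses import dataclass


@dataclass
class UpscalePlan:
    out_width: int
    out_height: int
    tiles_x: int
    tiles_y: int
    in_frame_bytes: int   # bgr24 bytes read per input frame
    out_frame_bytes: int  # bgr24 bytes written per output frame
    exceeds_4k: bool      # same check as VideoWriter (out_height > 2160)


def plan_upscale(width, height, outscale=4, tile=1000):
    """Hypothetical helper mirroring the size arithmetic used in the notebook."""
    out_width, out_height = int(width * outscale), int(height * outscale)
    if tile > 0:
        tiles_x, tiles_y = math.ceil(width / tile), math.ceil(height / tile)
    else:  # tile=0: the whole frame is processed in one pass
        tiles_x = tiles_y = 1
    return UpscalePlan(
        out_width=out_width,
        out_height=out_height,
        tiles_x=tiles_x,
        tiles_y=tiles_y,
        in_frame_bytes=width * height * 3,
        out_frame_bytes=out_width * out_height * 3,
        exceeds_4k=out_height > 2160,
    )


print(plan_upscale(1280, 720))
print(plan_upscale(1280, 720, outscale=2, tile=512))
```

With the notebook defaults, a 720p input becomes 5120x2880 and triggers the writer's warning. Lowering `outscale` shrinks the output frames and the encoding cost, but GPU memory per tile depends on `tile`, not on `outscale`.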

In [None]:
# Install dependencies (ffmpeg-python provides the `import ffmpeg` bindings used by the reader/writer)
!pip install -q basicsr facexlib gfpgan numpy opencv-python Pillow torch torchvision tqdm realesrgan ffmpeg-python
# Install Git LFS
!apt-get update && apt-get install -y git-lfs
!git lfs install

In [None]:
# Clone the Git repository containing the fine-tuned model (replace with your repo URL)
!git clone https://github.com/your-username/your-repo.git
# Pull LFS files (model weights)
!cd your-repo && git lfs pull

In [None]:
%%writefile dependency-fix.sh
#!/bin/bash
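# Fix the torchvision import in basicsr/data/degradations.py: newer torchvision
# releases no longer provide transforms.functional_tensor, so import
# rgb_to_grayscale from transforms.functional instead.
# Locate the installed package with pip instead of hardcoding the Python version
# (e.g. python3.11), which changes when Colab upgrades its runtime.
BASICSR_DIR="$(pip show basicsr | awk '/^Location:/ {print $2}')/basicsr"
sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/' "$BASICSR_DIR/data/degradations.py"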

In [None]:
!chmod +x dependency-fix.sh
!./dependency-fix.sh

In [None]:
import cv2
import os
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from tqdm import tqdm
import ffmpeg
import mimetypes
import numpy as np

class VideoReader:
    def __init__(self, video_path, ffmpeg_bin='ffmpeg'):
        self.ffmpeg_bin = ffmpeg_bin
        meta = self.get_video_meta_info(video_path)
        self.width = meta['width']
        self.height = meta['height']
        self.fps = meta['fps']
        self.audio = meta['audio']
        self.nb_frames = meta['nb_frames']
        self.stream_reader = (
            ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24', loglevel='error')
            .run_async(pipe_stdin=True, pipe_stdout=True, cmd=ffmpeg_bin)
        )
        self.idx = 0

    def get_video_meta_info(self, video_path):
        probe = ffmpeg.probe(video_path)
        video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
        has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
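        stream = video_streams[0]
        # ffprobe reports frame rates as fraction strings such as "30000/1001".
        # Parse them explicitly instead of eval()-ing external data, and fall back
        # to r_frame_rate when avg_frame_rate is unknown ("0/0").
        num, den = (float(x) for x in stream['avg_frame_rate'].split('/'))
        if num == 0 or den == 0:
            num, den = (float(x) for x in stream['r_frame_rate'].split('/'))
        return {
            'width': stream['width'],
            'height': stream['height'],
            'fps': num / den,
            'audio': ffmpeg.input(video_path).audio if has_audio else None,
            'nb_frames': int(stream['nb_frames'])
        }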

    def get_frame(self):
        if self.idx >= self.nb_frames:
            return None
        img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3)
        if not img_bytes:
            return None
        img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
        self.idx += 1
        return img

    def get_resolution(self):
        return self.height, self.width

    def get_fps(self):
        return self.fps

    def get_audio(self):
        return self.audio

    def __len__(self):
        return self.nb_frames

    def close(self):
        self.stream_reader.stdin.close()
        self.stream_reader.wait()

class VideoWriter:
    def __init__(self, video_save_path, audio, height, width, fps, outscale, ffmpeg_bin='ffmpeg'):
        self.ffmpeg_bin = ffmpeg_bin
        out_width, out_height = int(width * outscale), int(height * outscale)
        if out_height > 2160:
            print('Warning: Output video exceeds 4K resolution, which may be slow due to I/O. Consider reducing outscale.')
        input_args = {
            'format': 'rawvideo',
            'pix_fmt': 'bgr24',
            's': f'{out_width}x{out_height}',
            'framerate': fps
        }
        output_args = {
            'pix_fmt': 'yuv420p',
            'vcodec': 'libx264',
            'loglevel': 'error'
        }
        if audio is not None:
            output_args['acodec'] = 'copy'
            self.stream_writer = (
                ffmpeg.input('pipe:', **input_args)
                .output(audio, video_save_path, **output_args)
                .overwrite_output()
                .run_async(pipe_stdin=True, pipe_stdout=True, cmd=ffmpeg_bin)
            )
        else:
            self.stream_writer = (
                ffmpeg.input('pipe:', **input_args)
                .output(video_save_path, **output_args)
                .overwrite_output()
                .run_async(pipe_stdin=True, pipe_stdout=True, cmd=ffmpeg_bin)
            )

    def write_frame(self, frame):
        frame = frame.astype(np.uint8).tobytes()
        self.stream_writer.stdin.write(frame)

    def close(self):
        self.stream_writer.stdin.close()
        self.stream_writer.wait()

def main():
    # Hardcoded parameters
    input_path = 'video/test1.mp4'
    output_dir = 'output'
    model_name = 'RealESRGAN_finetuned'
    model_path = 'your-repo/model/net_g_5000.pth'
    outscale = 4
    suffix = 'out'
    tile = 1000
    ffmpeg_bin = 'ffmpeg'
    fp32 = False  # Use FP16 by default

    # Validate input and model path
    input_path = input_path.rstrip('/').rstrip('\\')
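    # guess_type() returns (None, None) for unknown extensions, so check for None first
    mime_type = mimetypes.guess_type(input_path)[0]
    if not os.path.isfile(input_path) or mime_type is None or not mime_type.startswith('video'):
        raise ValueError(f'Input {input_path} must be an existing video file')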
    if not os.path.isfile(model_path):
        raise ValueError(f'Model path {model_path} does not exist')

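    # Convert .flv to .mp4 if necessary (stream copy, no re-encode)
    if input_path.endswith('.flv'):
        mp4_path = os.path.splitext(input_path)[0] + '.mp4'
        # ffmpeg-python passes the arguments as a list, so paths with spaces are safe;
        # overwrite_output() adds -y so an existing .mp4 does not block on a prompt.
        ffmpeg.input(input_path).output(mp4_path, codec='copy').overwrite_output().run(cmd=ffmpeg_bin, quiet=True)
        input_path = mp4_path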

    # Initialize model
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    upsampler = RealESRGANer(
        scale=4,
        model_path=model_path,
        model=model,
        tile=tile,
        tile_pad=10,
        pre_pad=0,
        half=not fp32
    )

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    video_name = os.path.splitext(os.path.basename(input_path))[0]
    video_save_path = os.path.join(output_dir, f'{video_name}_{suffix}.mp4')

    # Process video
    reader = VideoReader(input_path, ffmpeg_bin)
    audio = reader.get_audio()
    height, width = reader.get_resolution()
    fps = reader.get_fps()
    writer = VideoWriter(video_save_path, audio, height, width, fps, outscale, ffmpeg_bin)

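    pbar = tqdm(total=len(reader), unit='frame', desc=f'Processing {video_name}')
    out_size = (int(width * outscale), int(height * outscale))  # (width, height) for cv2.resize
    while True:
        img = reader.get_frame()
        if img is None:
            break
        try:
            output, _ = upsampler.enhance(img, outscale=outscale)
        except RuntimeError as error:
            print(f'Error processing frame {reader.idx}: {error}')
            print('Try reducing tile size if you encounter CUDA out of memory.')
            torch.cuda.empty_cache()
            # Write a plain bicubic upscale instead of dropping the frame, so the
            # output keeps its frame count and stays in sync with the copied audio.
            output = cv2.resize(img, out_size, interpolation=cv2.INTER_CUBIC)
        writer.write_frame(output)
        pbar.update(1)
    pbar.close()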

    reader.close()
    writer.close()
    print(f'Saved: {video_save_path}')

if __name__ == '__main__':
    main()

## Usage Instructions

1. **Set Up Files**:
   - Upload your input video (e.g., `test1.mp4`) to the `video/` directory in the Colab working directory, or include it in your Git repository.
   - Clone the Git repository containing the fine-tuned model weights (`net_g_5000.pth`) by running the provided cell. Ensure Git LFS is installed to download the model.

2. **Modify Paths**:
   - Update `input_path` in the `main` function to point to your video file (e.g., `video/test1.mp4`).
   - Update `model_path` to point to the fine-tuned model (e.g., `your-repo/model/net_g_5000.pth`).
   - Optionally, adjust `output_dir`, `outscale`, or `tile` as needed.

3. **Run the Notebook**:
   - Execute all cells in order.
   - Monitor the progress bar for frame processing.

4. **Download Output**:
   - The enhanced video will be saved in the `output/` directory (e.g., `output/test1_out.mp4`).
   - Download the video from Colab's file explorer, or run `from google.colab import files; files.download('output/test1_out.mp4')` in a cell.
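Before downloading, the result can be sanity-checked against the input with `ffprobe`, which ships with FFmpeg on Colab. A minimal sketch: `probe`, `first_stream`, and `check_output` are hypothetical helpers, and the flags are standard ffprobe options that print stream and container information as JSON (the same structure `ffmpeg.probe` returns). A frame-count mismatch usually points to frames that were dropped during processing.

```python
import json
import subprocess


def probe(path, ffprobe_bin="ffprobe"):
    """Run ffprobe and return its JSON description of `path`."""
    result = subprocess.run(
        [ffprobe_bin, "-v", "error", "-show_streams", "-show_format", "-of", "json", path],
        capture_output=True, text=True, check=True,
    )
    return json.loads(result.stdout)


def first_stream(info, codec_type):
    """Return the first stream of the given type ("video"/"audio"), or None."""
    return next((s for s in info["streams"] if s["codec_type"] == codec_type), None)


def check_output(in_info, out_info, outscale=4):
    """Compare input and output probes; return a list of problems (empty if none)."""
    vin, vout = first_stream(in_info, "video"), first_stream(out_info, "video")
    if vout is None:
        return ["output has no video stream"]
    problems = []
    expected = (int(vin["width"] * outscale), int(vin["height"] * outscale))
    if (vout["width"], vout["height"]) != expected:
        problems.append(f"size {vout['width']}x{vout['height']}, expected {expected[0]}x{expected[1]}")
    # nb_frames is a string in ffprobe's JSON and may be absent for some containers.
    if "nb_frames" in vin and "nb_frames" in vout and int(vin["nb_frames"]) != int(vout["nb_frames"]):
        problems.append(f"{vout['nb_frames']} frames, expected {vin['nb_frames']}")
    if first_stream(in_info, "audio") and not first_stream(out_info, "audio"):
        problems.append("audio track missing")
    return problems


# In Colab, with the default paths from main():
# print(check_output(probe("video/test1.mp4"), probe("output/test1_out.mp4")) or "OK")
```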

## Troubleshooting
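- **CUDA Out of Memory**: Reduce `tile` (e.g., to 512 or 256). GPU memory per tile is set by `tile`, not by `outscale`: the model always upsamples by its native 4x, and `outscale` only resizes the result afterwards.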
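- **FFmpeg Errors**: Colab ships with the FFmpeg binary; verify it with `!ffmpeg -version`. The `import ffmpeg` line needs the `ffmpeg-python` package (not the unrelated `ffmpeg` package); if the import fails, run `!pip install ffmpeg-python`.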
- **Model Not Found**: Verify the Git repository is cloned, Git LFS is set up, and the model path is correct.
- **Slow Processing**: High `outscale` or large video resolutions may increase processing time. Consider reducing `outscale` or video resolution.
- **Git LFS Issues**: Ensure Git LFS is installed and the repository is properly configured. Run `git lfs pull` in the repository directory to download the model.
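If `VideoReader` fails while reading metadata, common causes are a missing `nb_frames` field (some containers, such as Matroska/WebM, may not store a frame count) or an `avg_frame_rate` of `"0/0"`. A minimal sketch of a more tolerant parser, assuming the ffprobe JSON layout returned by `ffmpeg.probe`; `parse_rate` and `parse_video_meta` are hypothetical helpers, and the frame count estimated from the duration can be off by a frame or two.

```python
from fractions import Fraction


def parse_rate(rate):
    """Parse an ffprobe rate string such as "30000/1001"; return 0.0 for "0/0"."""
    num, _, den = rate.partition("/")
    den = den or "1"
    return 0.0 if int(den) == 0 else float(Fraction(int(num), int(den)))


def parse_video_meta(probe):
    """Hypothetical tolerant replacement for VideoReader.get_video_meta_info."""
    video = next(s for s in probe["streams"] if s["codec_type"] == "video")
    # Prefer the average rate; fall back to r_frame_rate when it is "0/0".
    fps = parse_rate(video.get("avg_frame_rate", "0/0")) or parse_rate(video.get("r_frame_rate", "0/0"))
    if "nb_frames" in video:
        nb_frames = int(video["nb_frames"])
    else:
        # Estimate from duration: the stream's if present, else the container's.
        duration = float(video.get("duration") or probe.get("format", {}).get("duration", 0))
        nb_frames = round(duration * fps)
    return {
        "width": video["width"],
        "height": video["height"],
        "fps": fps,
        "nb_frames": nb_frames,
        "has_audio": any(s["codec_type"] == "audio" for s in probe["streams"]),
    }
```

Because `get_frame` also stops when the pipe returns no data, an overestimated frame count is harmless; an underestimate would cut the video short.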