In [None]:
!pip install torch torchvision opencv-python numpy pillow

import os
import cv2
import numpy as np
import torch
from torch.nn import functional as F
import glob
from collections import OrderedDict
from PIL import Image
import torchvision.transforms as transforms
from google.colab import files
from IPython.display import display, HTML
import time

# Define the ESRGAN model architecture
class ResidualDenseBlock(torch.nn.Module):
    """Five densely connected 3x3 convolutions with a scaled residual skip.

    Convolutions 1-4 each receive the block input concatenated with every
    earlier intermediate output and emit ``gc`` channels (LeakyReLU 0.2).
    Convolution 5 maps the full concatenation back to ``nf`` channels; that
    result is scaled by 0.2 and added to the input.

    Args:
        nf: channels of the block input and output.
        gc: channels produced by each intermediate convolution.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        # The attribute names conv1..conv5 form the state_dict keys of
        # pretrained checkpoints, so they are registered explicitly by name.
        for idx in range(4):
            in_channels = nf + idx * gc
            setattr(self, f"conv{idx + 1}",
                    torch.nn.Conv2d(in_channels, gc, kernel_size=3, stride=1, padding=1, bias=bias))
        self.conv5 = torch.nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, stride=1, padding=1, bias=bias)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Grow the list of feature maps; each conv sees all of them so far.
        features = [x]
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(layer(torch.cat(features, dim=1))))
        residual = self.conv5(torch.cat(features, dim=1))
        return x + 0.2 * residual


class RRDB(torch.nn.Module):
    """Residual-in-residual dense block.

    Runs three ResidualDenseBlocks in sequence and adds the chain's output,
    scaled by 0.2, back onto the block input.

    Args:
        nf: feature channels of the input/output.
        gc: growth channels handed to every ResidualDenseBlock.
    """

    def __init__(self, nf, gc=32):
        super().__init__()
        # RDB1..RDB3 are checkpoint key names; register them under exactly
        # those attribute names.
        for idx in (1, 2, 3):
            setattr(self, f"RDB{idx}", ResidualDenseBlock(nf, gc))

    def forward(self, x):
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return x + 0.2 * out


class RRDBNet(torch.nn.Module):
    """ESRGAN generator: conv -> RRDB trunk -> nearest-neighbour upsampling head.

    Attribute names (``conv_first``, ``RRDB_trunk``, ``trunk_conv``,
    ``upconv1``, ``upconv2``, ``HRconv``, ``conv_last``) define the
    state_dict keys, so renaming them breaks loading of pretrained weights.

    Args:
        in_nc: channels of the input image.
        out_nc: channels of the output image (only used when ``upscale``).
        nf: feature channels inside the network.
        nb: number of RRDB blocks in the trunk.
        gc: growth channels of each ResidualDenseBlock.
        scale: upsampling factor; must be 2 or 4 when ``upscale`` is True.
        upscale: when False no upsampling head is built and ``forward``
            returns the ``nf``-channel trunk features at input resolution.

    Raises:
        ValueError: if ``upscale`` is True and ``scale`` is not 2 or 4.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=4, upscale=True):
        super(RRDBNet, self).__init__()
        # forward() can only apply one (x2) or two (x4) nearest-neighbour
        # doubling steps; any other factor used to silently become x2.
        if upscale and scale not in (2, 4):
            raise ValueError(f"RRDBNet supports scale 2 or 4 when upscale=True, got {scale!r}")
        self.scale = scale
        self.upscale = upscale

        self.conv_first = torch.nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = torch.nn.Sequential(*[RRDB(nf, gc) for _ in range(nb)])
        self.trunk_conv = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # Upsampling head: one conv per x2 step, then HR conv and output conv.
        if upscale:
            self.upconv1 = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
            if scale == 4:
                self.upconv2 = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
            self.HRconv = torch.nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
            self.conv_last = torch.nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
            self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        # Long skip connection around the whole RRDB trunk.
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        if not self.upscale:
            # NOTE: these are nf-channel features, not an out_nc image.
            return fea

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        if self.scale == 4:
            fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        return self.conv_last(self.lrelu(self.HRconv(fea)))


def _normalize_state_dict_key(key):
    """Map a checkpoint key onto this file's (original ESRGAN) RRDBNet names.

    Strips a ``module.`` prefix (from nn.DataParallel) and renames
    BasicSR / Real-ESRGAN style keys (``body.N.rdbK``, ``conv_body``,
    ``conv_up1``, ``conv_up2``, ``conv_hr``) to the legacy names
    (``RRDB_trunk.N.RDBK``, ``trunk_conv``, ``upconv1``, ``upconv2``,
    ``HRconv``). Keys already in legacy form are returned unchanged.
    """
    if key.startswith('module.'):
        key = key[len('module.'):]
    for new_prefix, legacy_prefix in (
        ('body.', 'RRDB_trunk.'),
        ('conv_body.', 'trunk_conv.'),
        ('conv_up1.', 'upconv1.'),
        ('conv_up2.', 'upconv2.'),
        ('conv_hr.', 'HRconv.'),
    ):
        if key.startswith(new_prefix):
            key = legacy_prefix + key[len(new_prefix):]
            break
    for idx in (1, 2, 3):
        key = key.replace(f'.rdb{idx}.', f'.RDB{idx}.')
    return key


def load_model(model_path):
    """Build a 4x RRDBNet and load pretrained weights from ``model_path``.

    Accepts plain state dicts and BasicSR / Real-ESRGAN checkpoints whose
    weights are wrapped under ``params_ema`` (preferred) or ``params``.
    Key names are normalised to this file's RRDBNet attribute names before
    a strict load.

    Args:
        model_path: path to a ``.pth`` checkpoint.

    Returns:
        The model on CPU in eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` when keys or shapes do not
            match the architecture.
    """
    model = RRDBNet(scale=4)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # BasicSR-style checkpoints wrap the actual weights one level deeper.
    for wrapper_key in ('params_ema', 'params'):
        if wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
            break

    new_state_dict = OrderedDict(
        (_normalize_state_dict_key(k), v) for k, v in state_dict.items()
    )

    model.load_state_dict(new_state_dict, strict=True)
    model.eval()
    return model


def enhance_image(model, img, device):
    """Run a single uint8 OpenCV image through the super-resolution model.

    Args:
        model: callable taking a (1, 3, H, W) float32 RGB tensor in [0, 1];
            e.g. the RRDBNet returned by ``load_model``.
        img: uint8 image in OpenCV layout: (H, W, 3) BGR, (H, W, 4) BGRA
            (alpha is dropped), or (H, W) / (H, W, 1) grayscale (replicated
            to three channels).
        device: torch device the input tensor is moved to.

    Returns:
        uint8 image in OpenCV layout (BGR when the model outputs 3 channels).

    Raises:
        ValueError: if ``img`` has a channel count other than 1, 3 or 4.
    """
    # Normalise the input to 3-channel BGR.
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    channels = img.shape[2]
    if channels == 1:
        img = np.repeat(img, 3, axis=2)
    elif channels == 4:
        img = img[:, :, :3]
    elif channels != 3:
        raise ValueError(f"enhance_image expects 1, 3 or 4 channels, got {channels}")

    # BGR -> RGB; the fancy index also yields a contiguous copy, which
    # torch.from_numpy requires (it rejects negative strides).
    img = img[:, :, [2, 1, 0]].astype(np.float32) / 255.

    # HWC -> NCHW with a batch of one.
    tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(tensor)

    # squeeze(0) drops only the batch axis; a bare squeeze() would also drop
    # a height or width axis of size 1 and break the permute.
    output = output.squeeze(0).float().cpu().clamp_(0, 1).permute(1, 2, 0).numpy()
    output = (output * 255.0).round().astype(np.uint8)

    # RGB -> BGR for OpenCV.
    if output.shape[2] == 3:
        output = output[:, :, [2, 1, 0]]

    return output


def _format_eta(seconds):
    """Format a duration in seconds as HH:MM:SS (negative values clamp to 0)."""
    total = max(0, int(seconds))
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


def process_video(input_video, output_video, model, device='cpu', scale_factor=4):
    """Enhance every frame of ``input_video`` and write the result to ``output_video``.

    Args:
        input_video: path/source accepted by ``cv2.VideoCapture``.
        output_video: destination path; written with the ``mp4v`` codec.
        model: super-resolution model passed to ``enhance_image``.
        device: torch device used for inference.
        scale_factor: upscaling factor of ``model``; determines the output
            frame size and must match what the model actually produces.

    Returns:
        ``output_video`` on success. Returns None when the input cannot be
        opened, the writer cannot be opened, or an enhanced frame does not
        have the expected size (OpenCV would silently drop such frames).
    """
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        print(f"Error: could not open input video: {input_video}")
        return None

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # May be 0 or negative for streams whose container lacks a frame count.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        out_size = (width * scale_factor, height * scale_factor)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video, fourcc, fps, out_size)
        if not out.isOpened():
            print(f"Error: could not open output video for writing: {output_video}")
            return None

        print(f"Processing video: {input_video}")
        print(f"Output video: {output_video}")
        print(f"Total frames: {total_frames}")
        print(f"FPS: {fps}")
        print(f"Original resolution: {width}x{height}")
        print(f"Enhanced resolution: {width*scale_factor}x{height*scale_factor}")

        start_time = time.time()
        frame_count = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            enhanced_frame = enhance_image(model, frame, device)

            # VideoWriter silently discards frames whose size differs from
            # the one it was opened with, producing an empty/short video.
            frame_size = (enhanced_frame.shape[1], enhanced_frame.shape[0])
            if frame_size != out_size:
                print(f"\nError: enhanced frame size {frame_size[0]}x{frame_size[1]} "
                      f"does not match expected {out_size[0]}x{out_size[1]}; "
                      f"check scale_factor against the model.")
                return None

            out.write(enhanced_frame)
            frame_count += 1

            elapsed_time = time.time() - start_time
            speed = frame_count / elapsed_time if elapsed_time > 0 else 0
            if total_frames > 0:
                remaining_frames = max(total_frames - frame_count, 0)
                eta = _format_eta(remaining_frames / speed if speed > 0 else 0)
                progress = f"{frame_count}/{total_frames} ({frame_count/total_frames*100:.2f}%)"
            else:
                # Unknown length: no percentage or ETA (avoids division by zero).
                eta = "unknown"
                progress = f"{frame_count}/?"

            print(f"Processing: {progress} | "
                  f"Speed: {speed:.2f} fps | "
                  f"ETA: {eta}", end='\r')
    finally:
        # Release handles even if enhancement raises mid-video.
        cap.release()
        if out is not None:
            out.release()

    elapsed_time = time.time() - start_time
    print(f"\nVideo processing complete! Total time: {elapsed_time:.2f} seconds")
    return output_video

# Updated main Colab interface to use the uploaded files
def main_colab_uploaded(input_video_path="demo.mp4", model_path="RRDB_PSNR_x4.pth"):
    """Enhance an uploaded video with ESRGAN and trigger a Colab download.

    The enhanced video is written to ``enhanced_<input name>.mp4`` in the
    current working directory.

    Args:
        input_video_path: local path of the video to enhance (this entry
            point expects an uploaded file, not a URL or image pattern).
        model_path: path to the ESRGAN x4 checkpoint.

    Returns:
        The output video path on success, otherwise None.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Fail fast before the (slow) model load: cv2.VideoCapture does not
    # raise on a missing file.
    if not os.path.isfile(input_video_path):
        print(f"Error: input video not found: {input_video_path}")
        return None

    print(f"Using model file: {model_path}")
    try:
        model = load_model(model_path)
        model = model.to(device)
        print("Model loaded successfully.")
    except Exception as e:
        # Top-level boundary: report and bail out instead of crashing the cell.
        print(f"Error loading model: {e}")
        return None

    input_filename = os.path.splitext(os.path.basename(input_video_path))[0]
    output_video = f"enhanced_{input_filename}.mp4"

    # process_video prints its own "Processing video: ..." header.
    processed_file = process_video(input_video_path, output_video, model, device)

    # Also confirm the file exists, in case the writer never produced it.
    if not processed_file or not os.path.isfile(output_video):
        print("Video processing failed.")
        return None

    print("Processing complete! Downloading the enhanced video...")
    files.download(output_video)
    return output_video

# Run the Colab pipeline on the video and weights uploaded to /content.
INPUT_VIDEO_PATH = "/content/animation144p_input.mp4"
MODEL_WEIGHTS_PATH = "/content/RRDB_ESRGAN_x4.pth"
main_colab_uploaded(input_video_path=INPUT_VIDEO_PATH, model_path=MODEL_WEIGHTS_PATH)

Using device: cuda
Using model file: /content/RRDB_ESRGAN_x4.pth
Model loaded successfully.
Processing video: /content/animation144p_input.mp4
Processing video: /content/animation144p_input.mp4
Output video: enhanced_animation144p_input.mp4
Total frames: 252
FPS: 25.0
Original resolution: 256x144
Enhanced resolution: 1024x576
Processing: 252/252 (100.00%) | Speed: 2.77 fps | ETA: 00:00:00
Video processing complete! Total time: 91.05 seconds
Processing complete! Downloading the enhanced video...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

'enhanced_animation144p_input.mp4'