<a href="https://colab.research.google.com/github/kinrz/RAFT-Motion-Blur/blob/main/RAFT_Motion_Blur_GUI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# @title RAFT Motion Blur (GUI)
# @markdown Jalankan sel ini dan tunggu sampai link gradio dibuat.

import os
import subprocess
import sys
from google.colab import output

def install_packages():
    """Install the runtime dependencies this notebook needs.

    Installs the ``gradio`` Python package with pip and the ``ffmpeg``
    command-line tool with apt-get.

    Raises:
        subprocess.CalledProcessError: if either installer exits non-zero.
    """
    # Invoke pip through the running interpreter so the package lands in the
    # environment that will actually import it, not whatever `pip` is on PATH.
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "gradio"], check=True)
    subprocess.run(["apt-get", "install", "-y", "-qq", "ffmpeg"], check=True)

# Import Gradio; if it is missing, install the dependencies and import again.
# NOTE(review): ffmpeg is only installed on this fallback path. When Gradio is
# already importable, ffmpeg is assumed to be present on the system - confirm.
try:
    import gradio as gr
except ImportError:
    install_packages()
    import gradio as gr

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
import cv2
import numpy as np
import shutil
from glob import glob
import gc

# Configuration string for PyTorch's CUDA memory allocator.
# NOTE(review): this is set after `import torch`; presumably it still takes
# effect because no CUDA memory has been allocated at this point - confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_model():
    """Create the pretrained RAFT-Large network together with its preprocessing.

    Returns:
        tuple: ``(model, transforms)`` where ``model`` is the RAFT-Large network
        moved to ``device`` and switched to eval mode, and ``transforms`` is the
        preprocessing callable that belongs to the default weights.
    """
    pretrained = Raft_Large_Weights.DEFAULT
    preprocess = pretrained.transforms()
    # `.eval()` returns the module itself, so the calls can be chained.
    network = raft_large(weights=pretrained, progress=False).to(device).eval()
    return network, preprocess

# Load the network once at module level so every render request reuses it.
model, transforms = load_model()

def get_grid(H, W):
    """Build the identity sampling grid for ``F.grid_sample``.

    Args:
        H: number of rows.
        W: number of columns.

    Returns:
        Tensor of shape ``(1, H, W, 2)`` on ``device`` whose last axis holds
        ``(x, y)`` coordinates spaced evenly over ``[-1, 1]``.
    """
    rows = torch.linspace(-1, 1, H, device=device)
    cols = torch.linspace(-1, 1, W, device=device)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    # x first, then y, and a leading batch axis of size one.
    return torch.stack([grid_x, grid_y], dim=-1)[None]

def calculate_consistency(flow_fwd, flow_bwd):
    """Compute a forward/backward flow agreement mask.

    The backward flow is sampled at the positions the forward flow points to.
    Where the two flows agree they cancel out, so the residual magnitude is
    small and the returned weight is close to 1.

    Args:
        flow_fwd: forward flow, unpacked as ``(N, C, H, W)``; channel 0 is
            scaled by the width and channel 1 by the height.
        flow_bwd: backward flow, sampled with ``F.grid_sample``.

    Returns:
        Tensor of shape ``(N, 1, H, W)`` equal to ``exp(-0.5 * |fwd + warped_bwd|)``.
    """
    _, _, height, width = flow_fwd.shape

    # Convert the forward flow to the [-1, 1] coordinate units of grid_sample.
    # NOTE(review): divides by W/2 and H/2 although the sampling below uses
    # align_corners=True; (W-1)/2 and (H-1)/2 may be the exact factors - confirm.
    displacement = flow_fwd.permute(0, 2, 3, 1).clone()
    displacement[..., 0] /= (width / 2.0)
    displacement[..., 1] /= (height / 2.0)
    lookup = get_grid(height, width) + displacement

    bwd_at_target = F.grid_sample(
        flow_bwd,
        lookup,
        mode='bilinear',
        padding_mode='reflection',
        align_corners=True,
    )
    residual = torch.norm(flow_fwd + bwd_at_target, dim=1, keepdim=True)
    return torch.exp(-0.5 * residual)

def apply_vector_blur_complex(img_tensor, flow_tensor, samples, strength, expansion, mask=None):
    """Blur an image along its per-pixel flow vectors.

    The image is resampled ``samples`` times at offsets spread evenly over
    ``[-0.5, 0.5] * strength`` of the flow vector, and the results are averaged.

    Args:
        img_tensor: image batch, unpacked as ``(N, C, H, W)``.
        flow_tensor: flow in pixel units, channels-first, with channel 0 being
            the horizontal and channel 1 the vertical component.
        samples: number of taps along the flow vector; must be at least 1.
            A single tap samples at offset 0, i.e. produces no blur.
        strength: multiplier applied to the flow length.
        expansion: when greater than 0, the flow field is Gaussian-blurred with
            ``sigma=expansion`` and kernel size ``2 * int(expansion) + 1``
            before sampling.
        mask: optional blend weight; where given, the result is
            ``blurred * mask + img_tensor * (1 - mask)``.

    Returns:
        Tensor with the same shape as ``img_tensor``.

    Raises:
        ValueError: if ``samples`` is smaller than 1.
    """
    samples = int(samples)  # `range` below needs an int even if a float count is passed
    if samples < 1:
        raise ValueError(f"samples must be at least 1, got {samples}")

    N, C, H, W = img_tensor.shape
    base_grid = get_grid(H, W)
    accumulated_img = torch.zeros_like(img_tensor)

    processed_flow = flow_tensor
    if expansion > 0:
        k_size = int(expansion) * 2 + 1
        processed_flow = TF.gaussian_blur(processed_flow, kernel_size=k_size, sigma=expansion)

    # Clone before the in-place scaling so the caller's flow tensor is untouched.
    flow_norm = processed_flow.permute(0, 2, 3, 1).clone()
    flow_norm[..., 0] /= (W / 2.0)
    flow_norm[..., 1] /= (H / 2.0)

    for i in range(samples):
        # With a single sample there is no span to spread over: tap the centre
        # instead of dividing by zero.
        t = (i / (samples - 1)) - 0.5 if samples > 1 else 0.0
        sampling_grid = base_grid + flow_norm * (t * strength)
        accumulated_img += F.grid_sample(
            img_tensor,
            sampling_grid,
            mode='bilinear',
            padding_mode='reflection',
            align_corners=True,
        )

    blurred = accumulated_img / samples

    if mask is None:
        return blurred
    return (blurred * mask) + (img_tensor * (1 - mask))

def pad_to_8(tensor):
    """Pad the last two dimensions up to the next multiple of 8.

    Padding is added on the bottom and right edges only.

    Args:
        tensor: tensor whose last two dimensions are height and width.

    Returns:
        tuple: ``(padded, height, width)`` where ``height`` and ``width`` are
        the sizes before padding. When no padding is needed the input tensor
        itself is returned.
    """
    height, width = tensor.shape[-2:]
    # `-n % 8` is the distance from n up to the next multiple of 8 (0 if aligned).
    extra_rows = -height % 8
    extra_cols = -width % 8
    if extra_rows == 0 and extra_cols == 0:
        return tensor, height, width
    padded = F.pad(tensor, (0, extra_cols, 0, extra_rows))
    return padded, height, width

def _detect_fps(input_video, default=30):
    """Return the frame rate ffmpeg reports for ``input_video`` (as text), or ``default``."""
    import re  # function-scope import: `re` is not imported at file level

    try:
        # Without an output file ffmpeg only prints the stream description to
        # stderr and exits non-zero, so the return code is deliberately ignored.
        probe = subprocess.run(
            ["ffmpeg", "-hide_banner", "-i", input_video],
            capture_output=True,
            text=True,
            errors="replace",
        )
    except OSError:
        return default
    found = re.search(r",\s*(\d+(?:\.\d+)?)\s*fps", probe.stderr)
    return found.group(1) if found else default


def _encode_video(frame_pattern, input_video, output_video, fps, quality):
    """Encode the rendered frames plus the source audio; return True on success."""
    head = [
        "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
        "-r", str(fps), "-i", frame_pattern,
        "-i", input_video,
        # The trailing "?" makes the audio mapping optional, so clips without
        # an audio track still encode.
        "-map", "0:v", "-map", "1:a?", "-c:a", "copy",
    ]
    tail = ["-pix_fmt", "yuv420p", output_video]
    encoders = (
        ["-c:v", "h264_nvenc", "-profile:v", "high", "-level", "4.1", "-preset", "p6",
         "-tune", "hq", "-rc", "constqp", "-qp", str(quality)],
        # Software fallback, tried only when the NVENC encode fails.
        ["-c:v", "libx264", "-crf", str(quality)],
    )
    for codec_args in encoders:
        if subprocess.run(head + codec_args + tail).returncode == 0:
            return True
    return False


def process_raft(input_video, blur_direction, tail_expansion, blur_strength, num_samples, flow_iterations, vram_safe_mode, safe_resolution, consistency_check, video_quality, progress=gr.Progress()):
    """Render a motion-blurred copy of a video using RAFT optical flow.

    Frames are extracted with ffmpeg, each frame is blurred along the flow
    towards its neighbouring frame, and the result is re-encoded with the
    source's audio (if any) into ``result.mp4``.

    Args:
        input_video: path of the source video, or ``None``.
        blur_direction: ``"Forward"`` to use the next frame as flow target; any
            other value uses the previous frame.
        tail_expansion: flow-field blur amount passed to
            ``apply_vector_blur_complex`` as ``expansion``.
        blur_strength: multiplier on the flow length.
        num_samples: number of blur taps per frame.
        flow_iterations: number of RAFT refinement updates.
        vram_safe_mode: when true, flow is estimated on a downscaled copy.
        safe_resolution: longest-side limit used when ``vram_safe_mode`` is on.
        consistency_check: when true, blend with a forward/backward
            consistency mask.
        video_quality: constant quantiser (``-qp``) for the NVENC encode; reused
            as ``-crf`` if the software fallback is needed.
        progress: Gradio progress tracker.

    Returns:
        Path of the encoded video, or ``None`` when ``input_video`` is ``None``.

    Raises:
        gr.Error: if no frames can be extracted, a frame cannot be read, or
            encoding fails.
    """
    if input_video is None:
        return None

    # These values are used as counts / integer CLI options; coerce defensively
    # in case the UI delivers them as floats.
    num_samples = int(num_samples)
    flow_iterations = int(flow_iterations)
    video_quality = int(video_quality)

    temp_dir = "temp_frames_gradio"
    blur_dir = "blur_frames_gradio"
    output_video = "result.mp4"

    for folder in (temp_dir, blur_dir):
        if os.path.exists(folder):
            shutil.rmtree(folder)
        os.makedirs(folder)
    # Drop any previous result so a failed run can never return a stale file.
    if os.path.exists(output_video):
        os.remove(output_video)

    progress(0.05, desc="Extracting Frames...")
    # Argument list instead of a shell string: the upload path may contain
    # spaces or shell metacharacters.
    subprocess.run([
        "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
        "-i", input_video, "-pix_fmt", "rgb24", f"{temp_dir}/%08d.png",
    ])
    frame_files = sorted(glob(f"{temp_dir}/*.png"))
    total_frames = len(frame_files)
    if total_frames == 0:
        raise gr.Error("No frames could be extracted from the input video.")

    limit_res = safe_resolution if vram_safe_mode else 0

    for i, curr_path in enumerate(frame_files):
        progress((i / total_frames), desc=f"Rendering Frame {i}/{total_frames}")

        # The frame at the end of the chosen direction has no neighbour to
        # compute flow against, so it is copied through unblurred.
        neighbour = i + 1 if blur_direction == "Forward" else i - 1
        if not 0 <= neighbour < total_frames:
            shutil.copy(curr_path, f"{blur_dir}/{i:08d}.png")
            continue
        target_path = frame_files[neighbour]

        frame_bgr = cv2.imread(curr_path)
        target_bgr = cv2.imread(target_path)
        if frame_bgr is None or target_bgr is None:
            raise gr.Error(f"Could not read extracted frame {i}.")
        img1_orig = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        img2_orig = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = img1_orig.shape[:2]

        calc_h, calc_w = orig_h, orig_w
        scale_factor = 1.0

        if limit_res > 0 and (orig_w > limit_res or orig_h > limit_res):
            # Shrink so the longer side equals the limit.
            scale_factor = limit_res / max(orig_w, orig_h)
            calc_w, calc_h = int(orig_w * scale_factor), int(orig_h * scale_factor)
            img1_small = cv2.resize(img1_orig, (calc_w, calc_h), interpolation=cv2.INTER_LINEAR)
            img2_small = cv2.resize(img2_orig, (calc_w, calc_h), interpolation=cv2.INTER_LINEAR)
        else:
            img1_small, img2_small = img1_orig, img2_orig

        img1_t_full = TF.to_tensor(img1_orig).unsqueeze(0).to(device)
        img1_t_calc = TF.to_tensor(img1_small).unsqueeze(0).to(device)
        img2_t_calc = TF.to_tensor(img2_small).unsqueeze(0).to(device)

        img1_batch, img2_batch = transforms(img1_t_calc, img2_t_calc)
        img1_pad, pad_h, pad_w = pad_to_8(img1_batch)
        img2_pad, _, _ = pad_to_8(img2_batch)

        with torch.no_grad():
            # The model returns one prediction per update; keep the last one.
            flow_main = model(img1_pad, img2_pad, num_flow_updates=flow_iterations)[-1]

            consistency_mask = None
            if consistency_check:
                flow_rev = model(img2_pad, img1_pad, num_flow_updates=flow_iterations)[-1]
                mask_pad = calculate_consistency(flow_main, flow_rev)
                consistency_mask = mask_pad[:, :, :pad_h, :pad_w]
                del flow_rev

            del img1_batch, img2_batch, img1_pad, img2_pad

        # Crop away the padding added for the model.
        flow_small = flow_main[:, :, :pad_h, :pad_w]

        if scale_factor != 1.0:
            flow_final = F.interpolate(flow_small, size=(orig_h, orig_w), mode='bilinear', align_corners=False)
            # Flow is in pixels of the downscaled frame. Rescale each axis by
            # its real size ratio, which can differ slightly from
            # 1 / scale_factor because calc_w / calc_h were truncated to ints.
            flow_final[:, 0] *= orig_w / calc_w
            flow_final[:, 1] *= orig_h / calc_h
            if consistency_mask is not None:
                consistency_mask = F.interpolate(consistency_mask, size=(orig_h, orig_w), mode='bilinear', align_corners=False)
        else:
            flow_final = flow_small

        blurred_tensor = apply_vector_blur_complex(img1_t_full, flow_final, num_samples, blur_strength, tail_expansion, consistency_mask)

        res_img = (blurred_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        res_img = cv2.cvtColor(res_img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(f"{blur_dir}/{i:08d}.png", res_img)

        # Release per-frame tensors before the next iteration allocates its own.
        del flow_main, flow_small, flow_final, blurred_tensor, img1_t_full, img1_t_calc, img2_t_calc
        if consistency_mask is not None:
            del consistency_mask
        torch.cuda.empty_cache()

    progress(0.98, desc="Encoding Video...")
    orig_fps = _detect_fps(input_video)
    if not _encode_video(f"{blur_dir}/%08d.png", input_video, output_video, orig_fps, video_quality):
        raise gr.Error("Video encoding failed; see the ffmpeg output in the notebook log.")

    return output_video

# Custom CSS passed to gr.Blocks: hides the footer, applies a dark background,
# and styles the primary button, the radio-option labels and the page heading
# in purple.
rft_css = """
footer {visibility: hidden}
.gradio-container {background-color: #1a1a1a !important;}
.contain {background-color: #1a1a1a !important;}
div[data-testid="block-label"] {color: #cfcfcf !important; font-weight: bold;}
label span {color: #cfcfcf !important;}

button.primary {
    background: linear-gradient(90deg, #7c3aed, #a855f7) !important;
    color: white !important;
    border: none !important;
    box-shadow: 0 0 15px rgba(139, 92, 246, 0.6) !important;
    transition: all 0.3s ease;
}
button.primary:hover {
    box-shadow: 0 0 25px rgba(139, 92, 246, 0.9) !important;
    transform: scale(1.02);
}

fieldset label {
    background-color: transparent !important;
    background-image: none !important;
    border: 2px solid #8b5cf6 !important;
    border-radius: 8px !important;
    transition: all 0.3s ease;
    margin-right: 5px;
}

fieldset label.selected {
    background-color: rgba(139, 92, 246, 0.1) !important;
    border-color: #8b5cf6 !important;
    box-shadow: 0 0 10px #8b5cf6, inset 0 0 5px rgba(139, 92, 246, 0.3) !important;
    color: #ffffff !important;
    font-weight: normal !important;
}

h1 {
    color: #a855f7;
    text-shadow: 0 0 10px rgba(139, 92, 246, 0.5);
    font-family: 'Courier New', monospace;
}
"""

# Gradio theme: Base theme with purple/slate hues and the Inter font, followed
# by explicit colour overrides for backgrounds, text, buttons and sliders.
rft_theme = gr.themes.Base(
    primary_hue=gr.themes.colors.purple,
    neutral_hue=gr.themes.colors.slate,
    font=[gr.themes.GoogleFont('Inter'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(
    body_background_fill='#161618',
    body_text_color='#eaeaea',
    block_background_fill='#25252b',
    block_border_width='0px',
    block_label_background_fill='#25252b',
    block_title_text_color='#ffffff',
    input_background_fill='#1a1a1f',
    button_primary_background_fill='#7c3aed',
    button_primary_background_fill_hover='#6d28d9',
    slider_color='#8b5cf6'
)

# Build the UI: inputs and settings in the left column, result preview on the right.
with gr.Blocks(theme=rft_theme, css=rft_css, title="RAFT Motion Blur") as demo:
    with gr.Row():
        gr.Markdown("# RAFT Motion Blur")

    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video", sources=["upload"])
            # Main blur controls; each maps to the process_raft parameter of the same name.
            with gr.Accordion("Settings", open=True):
                blur_direction = gr.Radio(["Forward", "Backward"], value="Backward", label="Blur Direction", info="Forward: Hitung dari frame depan. Backward: Hitung dari frame belakang.")
                tail_expansion = gr.Slider(0, 100, value=40, step=5, label="Tail Expansion", info="Feathering tepi blur (0= Tajam, 30+= Halus).")
                blur_strength = gr.Slider(0.1, 3.0, value=1.0, step=0.05, label="Blur Strength", info="Panjang blur.")
                num_samples = gr.Slider(16, 128, value=128, step=8, label="Blur Samples", info="Kualitas/Kehalusan blur.")
                flow_iterations = gr.Slider(10, 40, value=16, step=2, label="Flow Iterations", info="Akurasi frame tracking (Higher = Slower).")

            # Optional controls, collapsed by default.
            with gr.Accordion("Advanced", open=False):
                with gr.Row():
                    vram_safe_mode = gr.Checkbox(value=False, label="Safe Mode", info="Jika video resolusi tinggi sering gagal render.")
                    consistency_check = gr.Checkbox(value=False, label="Consistency Check", info="Experimental.")
                safe_resolution = gr.Slider(960, 1280, value=1280, step=64, label="Safe Mode Limit", info="Resolusi baca engine jika Safe Mode ON.")
                # NOTE(review): labelled "CRF", but process_raft hands this value to
                # the NVENC encoder as a constant quantiser (`-qp`) - confirm the label.
                video_quality = gr.Slider(15, 30, value=19, step=1, label="Video Quality (CRF)", info="Semakin kecil semakin jernih.")

            btn_run = gr.Button("RENDER", variant="primary", size="lg")

        with gr.Column():
            output_video = gr.Video(label="Result Preview", interactive=False)
            gr.Markdown("> **Tip:** jika gagal render. Nyalakan *Safe Mode*.")

    # The order of `inputs` must match the positional parameters of process_raft.
    btn_run.click(
        fn=process_raft,
        inputs=[input_video, blur_direction, tail_expansion, blur_strength, num_samples, flow_iterations, vram_safe_mode, safe_resolution, consistency_check, video_quality],
        outputs=output_video
    )

# Enable the request queue and start the app with a public share link.
demo.queue().launch(share=True, inline=False, debug=False, show_api=False)