In [9]:
# ========== Setup & Imports ==========
import os, shutil
from PIL import Image, ImageSequence
import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
from google.colab import files
# Make sure you're in /content
%cd /content
!rm -rf GFPGAN
!git clone https://github.com/TencentARC/GFPGAN.git
%cd GFPGAN

# Reinstall environment
!pip install numpy==1.24.4
!pip install basicsr facexlib realesrgan
!pip install -r requirements.txt
!python setup.py develop

# Patch deprecated torchvision import
!sed -i 's|from torchvision.transforms.functional_tensor import rgb_to_grayscale|from torchvision.transforms.functional import rgb_to_grayscale|' /usr/local/lib/python3.11/dist-packages/basicsr/data/degradations.py

# Download pre-trained model
!mkdir -p experiments/pretrained_models
!wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P experiments/pretrained_models


/content
Cloning into 'GFPGAN'...
remote: Enumerating objects: 527, done.[K
remote: Counting objects: 100% (253/253), done.[K
remote: Compressing objects: 100% (59/59), done.[K
remote: Total 527 (delta 211), reused 194 (delta 194), pack-reused 274 (from 2)[K
Receiving objects: 100% (527/527), 5.38 MiB | 19.04 MiB/s, done.
Resolving deltas: 100% (282/282), done.
/content/GFPGAN
/usr/local/lib/python3.11/dist-packages/setuptools/__init__.py:94: _DeprecatedInstaller: setuptools.installer and fetch_build_eggs are deprecated.
!!

        ********************************************************************************
        Requirements should be satisfied by a PEP 517 installer.
        If you are using pip, you can try `pip install --use-pep517`.
        ********************************************************************************

!!
  dist.fetch_build_eggs(dist.setup_requires)
running develop
!!

        ***************************************************************************

In [10]:
gif_upload = 'inputs/upload'
shutil.rmtree(gif_upload, ignore_errors=True)
os.makedirs(gif_upload, exist_ok=True)

uploaded = files.upload()
gif_file = list(uploaded.keys())[0]

img = Image.open(gif_file)
frame_paths = []
for i, frame in enumerate(ImageSequence.Iterator(img)):
    path = os.path.join(gif_upload, f"frame_{i:04d}.png")
    frame.convert("RGB").save(path)
    frame_paths.append(path)
print(f"Extracted {len(frame_paths)} frames")



Saving sticker2.gif to sticker2.gif
Extracted 44 frames


In [11]:
!rm -rf results
!python inference_gfpgan.py -i {gif_upload} -o results -v 1.3 -s 2 --bg_upsampler realesrgan
restored_dir = 'results/restored_imgs'
restored_paths = sorted([os.path.join(restored_dir, f)
                         for f in os.listdir(restored_dir) if f.endswith('.png')])

Downloading: "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth" to /content/GFPGAN/gfpgan/weights/detection_Resnet50_Final.pth

100% 104M/104M [00:00<00:00, 221MB/s] 
Downloading: "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth" to /content/GFPGAN/gfpgan/weights/parsing_parsenet.pth

100% 81.4M/81.4M [00:00<00:00, 197MB/s]
Processing frame_0000.png ...
	Tile 1/1
Processing frame_0001.png ...
	Tile 1/1
Processing frame_0002.png ...
	Tile 1/1
Processing frame_0003.png ...
	Tile 1/1
Processing frame_0004.png ...
	Tile 1/1
Processing frame_0005.png ...
	Tile 1/1
Processing frame_0006.png ...
	Tile 1/1
Processing frame_0007.png ...
	Tile 1/1
Processing frame_0008.png ...
	Tile 1/1
Processing frame_0009.png ...
	Tile 1/1
Processing frame_0010.png ...
	Tile 1/1
Processing frame_0011.png ...
	Tile 1/1
Processing frame_0012.png ...
	Tile 1/1
Processing frame_0013.png ...
	Tile 1/1
Processing frame_0014.png ...
	Tile 1/1

In [20]:
import torch, gc

# Manually delete variables and clear cache
del raft  # if already partially loaded
torch.cuda.empty_cache()
gc.collect()

# Try to clear memory from other models
if 'gfpgan_model' in globals():
    del gfpgan_model
torch.cuda.empty_cache()
gc.collect()


0

In [None]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

In [22]:
raft = raft_large(weights=weights).to(device).eval()
import os
import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
from PIL import Image

# ========= RAFT Optical Flow Setup ==========
device = "cuda" if torch.cuda.is_available() else "cpu"
weights = Raft_Large_Weights.DEFAULT
transforms = weights.transforms()
raft = raft_large(weights=weights).to(device).eval()

# ========= Load frames with resizing ==========
def load_frame(path, resize=(384, 384)):
    img = Image.open(path).convert("RGB").resize(resize)
    return TF.to_tensor(img).unsqueeze(0)  # shape: [1, 3, H, W]

# Load resized image tensors
tensors = [load_frame(p) for p in restored_paths]
flows = []

# Optional environment debug info
print(f"Using GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
print(f"PYTORCH_CUDA_ALLOC_CONF = {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")

# ========= Optical Flow per-frame (batch size = 1) ==========
for i in range(len(tensors) - 1):
    img1, img2 = tensors[i].to(device), tensors[i + 1].to(device)

    # Apply the required transforms (normalize, resize internally, etc.)
    img1p, img2p = transforms(img1, img2)

    # Inference with RAFT
    with torch.no_grad():  # very important to prevent memory leak
        flow_iters = raft(img1p, img2p)
        flow = flow_iters[-1]  # get final refined flow

    flows.append(flow.cpu())  # move to CPU to reduce GPU memory pressure

print(f"✓ Optical flow estimated for {len(flows)} frame pairs")



Using GPU: Tesla T4
PYTORCH_CUDA_ALLOC_CONF = expandable_segments:True
✓ Optical flow estimated for 43 frame pairs


In [24]:

import os, shutil
from PIL import Image, ImageSequence
import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
from google.colab import files
# ========== iv. Warping & v. Temporal Loss ==========
warped = []
losses = []
criterion = torch.nn.MSELoss()

for i, flow in enumerate(flows):
    img = tensors[i].to(device)
    target = tensors[i + 1].to(device)
    b, c, h, w = img.shape

    # Move flow to device (fix for the error)
    flow = flow.to(device)

    # Build normalized sampling grid
    grid_y, grid_x = torch.meshgrid(
        torch.arange(h, device=device), torch.arange(w, device=device), indexing='ij'
    )
    grid = torch.stack((grid_x, grid_y), 0).float().unsqueeze(0)  # shape [1, 2, H, W]

    # Apply flow to grid
    vgrid = grid + flow[0]  # [1, 2, H, W]

    # Normalize to [-1, 1] for grid_sample
    vgrid[:, 0] = 2.0 * vgrid[:, 0] / (w - 1) - 1.0
    vgrid[:, 1] = 2.0 * vgrid[:, 1] / (h - 1) - 1.0
    vgrid = vgrid.permute(0, 2, 3, 1)  # [1, H, W, 2]

    # Warp image
    img_warp = F.grid_sample(img, vgrid, align_corners=True)

    # Save warped frame and temporal loss
    warped.append(img_warp)
    loss = criterion(img_warp, target)
    losses.append(loss)

print(f"✓ Temporal losses across frames:\n{[round(l.item(), 4) for l in losses]}")


✓ Temporal losses across frames:
[0.0151, 0.0085, 0.0178, 0.0119, 0.0034, 0.0059, 0.0079, 0.0057, 0.0079, 0.0037, 0.0099, 0.0125, 0.0046, 0.0162, 0.0047, 0.0038, 0.002, 0.0052, 0.0063, 0.0086, 0.0106, 0.0084, 0.0143, 0.0091, 0.0086, 0.0094, 0.021, 0.0084, 0.0105, 0.0262, 0.0262, 0.0366, 0.0472, 0.0294, 0.0356, 0.0276, 0.0107, 0.0079, 0.0112, 0.0118, 0.007, 0.009, 0.0015]


In [None]:
# ========== vi. Re-save or Blend Frames (optional smoothing) ==========

# Here we save enhanced gif as an output.

restored_images = [Image.open(p).convert("RGB") for p in restored_paths]
output = '/content/outputs/enhanced_output.gif'
os.makedirs(os.path.dirname(output), exist_ok=True)
restored_images[0].save(output, save_all=True,
                       append_images=restored_images[1:],
                       duration=100, loop=0)
print(f"Saved enhanced GIF → {output}")

Saved enhanced GIF → /content/outputs/enhanced_output.gif


In [15]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [None]:
### Evaluation of Model Performance   #######

In [26]:
from PIL import Image
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import os

input_dir = "inputs/upload"  # Original extracted frames
restored_dir = "results/restored_imgs"  # Enhanced frames

input_frames = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(".png")])
restored_frames = sorted([os.path.join(restored_dir, f) for f in os.listdir(restored_dir) if f.endswith(".png")])

psnr_scores, ssim_scores = [], []

for input_fp, restored_fp in zip(input_frames, restored_frames):
    gt = Image.open(input_fp).convert("RGB").resize((384, 384))
    restored = Image.open(restored_fp).convert("RGB").resize((384, 384))

    gt_np = np.array(gt)
    restored_np = np.array(restored)

    psnr_scores.append(psnr(gt_np, restored_np))
    ssim_scores.append(ssim(gt_np, restored_np, channel_axis=2))

print(f"\n✅ PSNR (avg): {np.mean(psnr_scores):.2f} dB")
print(f"✅ SSIM (avg): {np.mean(ssim_scores):.4f}")



✅ PSNR (avg): 33.20 dB
✅ SSIM (avg): 0.8566


In [27]:
print(f"\n✅ Average Temporal Loss: {np.mean([l.item() for l in losses]):.4f}")



✅ Average Temporal Loss: 0.0128


In [28]:
!pip install lpips
import lpips
import torch

loss_fn = lpips.LPIPS(net='alex').to(device)
lpips_scores = []

for input_fp, restored_fp in zip(input_frames, restored_frames):
    img0 = TF.to_tensor(Image.open(input_fp).convert("RGB").resize((384, 384))).unsqueeze(0).to(device)
    img1 = TF.to_tensor(Image.open(restored_fp).convert("RGB").resize((384, 384))).unsqueeze(0).to(device)

    d = loss_fn(img0, img1)
    lpips_scores.append(d.item())

print(f"\n✅ LPIPS (avg): {np.mean(lpips_scores):.4f} (lower is better)")


Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lpips
Successfully installed lpips-0.1.4
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:01<00:00, 159MB/s]


Loading model from: /usr/local/lib/python3.11/dist-packages/lpips/weights/v0.1/alex.pth

✅ LPIPS (avg): 0.1379 (lower is better)


In [29]:
print(f"\n✅ LPIPS (avg): {np.mean(lpips_scores):.4f} (lower is better)")


✅ LPIPS (avg): 0.1379 (lower is better)
