# STTN Video Inpainting in Google Colab

This notebook demonstrates how to run STTN (Spatial-Temporal Transformer Networks) for video inpainting in Google Colab.

**Paper**: Learning Joint Spatial-Temporal Transformations for Video Inpainting (ECCV 2020)

**Original Repository**: https://github.com/researchmm/STTN

## 1. Setup Environment

In [None]:
# Install required packages
!pip install torch>=1.8.0 torchvision>=0.9.0
!pip install opencv-python==4.6.0.66
!pip install Pillow>=8.0.0 numpy>=1.19.0 matplotlib>=3.3.0
!pip install tqdm>=4.49.0 imageio>=2.8.0 scipy>=1.6.0

In [None]:
# Report whether PyTorch can see a CUDA device (inference on CPU is very slow)
import torch

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("No GPU available - will use CPU (very slow)")
else:
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU device: {device_name}")
    print(f"CUDA version: {torch.version.cuda}")

## 2. Upload Project Files

Clone the STTN repository from GitHub (Option 1), or upload the project folder as a zip file (Option 2). Run only one of the two options:

In [None]:
# Option 1: Clone from GitHub
!git clone https://github.com/researchmm/STTN.git
%cd STTN

In [None]:
# Option 2: Upload files manually (use this INSTEAD of Option 1, not in addition).
# The zip should contain the STTN project: test_colab.py later imports `core.utils`
# and `model.<name>` relative to the working directory, and section 5 expects
# examples/ and checkpoints/ there as well.
from google.colab import files
import zipfile
import os

# Uncomment to upload a zip file and extract it into the current working directory.
# NOTE: if the archive has a top-level folder (e.g. STTN/), change into that folder
# afterwards so that core/ and model/ are importable.
# uploaded = files.upload()
#
# # Extract the uploaded zip file
# for filename in uploaded.keys():
#     if filename.endswith('.zip'):
#         with zipfile.ZipFile(filename, 'r') as zip_ref:
#             zip_ref.extractall('.')
#         print(f"Extracted {filename}")

## 3. Download Pretrained Model

In [None]:
# Create checkpoints directory
!mkdir -p checkpoints

# Download pretrained model (you may need to install gdown first)
!pip install gdown
!gdown 1ZAMV8547wmZylKRt5qR_tC5VlosXD4Wv -O checkpoints/sttn.pth

# Verify the model file
import os
if os.path.exists('checkpoints/sttn.pth'):
    size = os.path.getsize('checkpoints/sttn.pth') / (1024*1024)
    print(f"✓ Model downloaded successfully! Size: {size:.1f} MB")
else:
    print("✗ Model download failed!")

## 4. Create Colab-Compatible Test Script

In [None]:
# Create the Colab-compatible test script.
# It is written to disk (rather than defined inline) so that `model.<name>` and
# `core.utils` are imported relative to the STTN repository root.
import sys

test_colab_content = '''
# -*- coding: utf-8 -*-
"""Colab-friendly STTN video inpainting entry point (generated by the notebook)."""
import importlib
import os

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# My libs
from core.utils import Stack, ToTorchFormatTensor

# Working resolution (width, height); every frame and mask is resized to it.
w, h = 432, 240
# Stride between reference frames sampled across the whole video.
ref_length = 10
# Half-width of the window of neighbouring frames completed in each step.
neighbor_stride = 5
# Output frame rate used when the input video does not report a usable FPS.
default_fps = 24

_to_tensors = transforms.Compose([
    Stack(),
    ToTorchFormatTensor()])


def get_ref_index(neighbor_ids, length):
    """Return every ref_length-th index in [0, length) that is not in neighbor_ids."""
    neighbors = set(neighbor_ids)
    return [i for i in range(0, length, ref_length) if i not in neighbors]


def read_mask(mpath):
    """Load, binarize and dilate every mask image in mpath, in filename order."""
    # Only regular, non-hidden files are masks; sub-directories or files such as
    # .DS_Store would otherwise make Image.open fail.
    mnames = sorted(
        name for name in os.listdir(mpath)
        if not name.startswith(".") and os.path.isfile(os.path.join(mpath, name)))
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    masks = []
    for name in mnames:
        with Image.open(os.path.join(mpath, name)) as img:
            m = img.resize((w, h), Image.NEAREST)
        m = np.array(m.convert("L"))
        m = np.array(m > 0).astype(np.uint8)
        # Grow the hole slightly so the fill also covers the mask border.
        m = cv2.dilate(m, kernel, iterations=4)
        masks.append(Image.fromarray(m * 255))
    return masks


def read_frame_from_videos(vname):
    """Decode every frame of vname as an RGB PIL image resized to (w, h)."""
    frames = []
    vidcap = cv2.VideoCapture(vname)
    try:
        success, image = vidcap.read()
        while success:
            image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            frames.append(image.resize((w, h)))
            success, image = vidcap.read()
    finally:
        vidcap.release()
    return frames


def read_video_fps(vname, fallback=default_fps):
    """Return the frame rate reported for vname, or fallback if it is missing."""
    vidcap = cv2.VideoCapture(vname)
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
    finally:
        vidcap.release()
    # 0 or NaN means OpenCV could not determine the frame rate.
    if not (fps > 0):
        return fallback
    return fps


def run_video_inpainting(video_path="examples/schoolgirls_orig.mp4",
                         mask_path="examples/schoolgirls",
                         ckpt_path="checkpoints/sttn.pth",
                         model_name="sttn",
                         fps=None):
    """
    Main function for video inpainting - Colab friendly.

    Args:
        video_path: input video file.
        mask_path: directory with one mask image per frame (sorted by filename);
            non-zero pixels mark the region to fill.
        ckpt_path: checkpoint whose generator weights are stored under 'netG'.
        model_name: module under model/ that defines InpaintGenerator.
        fps: output frame rate; None reuses the input video's frame rate
            (default_fps if it cannot be read).

    Returns:
        The output filename (<video name>_result.mp4 in the working directory),
        or None if an input is missing or invalid.
    """
    # Check if files exist
    if not os.path.exists(video_path):
        print(f"Error: Video file {video_path} not found!")
        return None
    if not os.path.exists(mask_path):
        print(f"Error: Mask directory {mask_path} not found!")
        return None
    if not os.path.exists(ckpt_path):
        print(f"Error: Checkpoint file {ckpt_path} not found!")
        return None

    # Read and validate inputs before spending time loading the model.
    frames = read_frame_from_videos(video_path)
    video_length = len(frames)
    print(f"Video length: {video_length} frames")
    if video_length == 0:
        print(f"Error: Could not read any frames from {video_path}!")
        return None

    masks = read_mask(mask_path)
    if len(masks) < video_length:
        print(f"Error: Found {len(masks)} masks in {mask_path} but the video has {video_length} frames!")
        return None
    if len(masks) > video_length:
        print(f"Warning: Using only the first {video_length} of {len(masks)} masks.")
        masks = masks[:video_length]

    # set up models - Colab friendly device selection
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    net = importlib.import_module("model." + model_name)
    model = net.InpaintGenerator().to(device)
    data = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(data["netG"])
    del data
    print("Loading model from: {}".format(ckpt_path))
    model.eval()

    # prepare dataset, encode all frames into deep space
    feats = _to_tensors(frames).unsqueeze(0) * 2 - 1
    frames = [np.array(f).astype(np.uint8) for f in frames]

    binary_masks = [np.expand_dims((np.array(m) != 0).astype(np.uint8), 2) for m in masks]
    masks = _to_tensors(masks).unsqueeze(0)
    feats, masks = feats.to(device), masks.to(device)
    comp_frames = [None] * video_length

    with torch.no_grad():
        feats = model.encoder((feats * (1 - masks).float()).view(video_length, 3, h, w))
        _, c, feat_h, feat_w = feats.size()
        feats = feats.view(1, video_length, c, feat_h, feat_w)
    print("Loaded videos and masks")

    # completing holes by spatial-temporal transformers
    print("Processing frames...")
    for f in range(0, video_length, neighbor_stride):
        neighbor_ids = list(range(max(0, f - neighbor_stride), min(video_length, f + neighbor_stride + 1)))
        ref_ids = get_ref_index(neighbor_ids, video_length)
        with torch.no_grad():
            pred_feat = model.infer(
                feats[0, neighbor_ids + ref_ids, :, :, :], masks[0, neighbor_ids + ref_ids, :, :, :])
            pred_img = torch.tanh(model.decoder(
                pred_feat[:len(neighbor_ids), :, :, :])).detach()
            pred_img = (pred_img + 1) / 2
            pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
        for i, idx in enumerate(neighbor_ids):
            # Keep original pixels outside the mask; use predictions inside it.
            img = np.array(pred_img[i]).astype(
                np.uint8) * binary_masks[idx] + frames[idx] * (1 - binary_masks[idx])
            if comp_frames[idx] is None:
                comp_frames[idx] = img
            else:
                # Overlapping windows: blend with the previous estimate.
                comp_frames[idx] = comp_frames[idx].astype(
                    np.float32) * 0.5 + img.astype(np.float32) * 0.5

        if (f // neighbor_stride) % 10 == 0:
            print(f"Processed {min(f + neighbor_stride, video_length)}/{video_length} frames")

    # Generate output filename
    output_filename = f"{os.path.splitext(os.path.basename(video_path))[0]}_result.mp4"
    out_fps = fps if fps is not None else read_video_fps(video_path)
    print(f"Writing output video: {output_filename}")

    writer = cv2.VideoWriter(output_filename, cv2.VideoWriter_fourcc(*"mp4v"), out_fps, (w, h))
    if not writer.isOpened():
        print(f"Error: Could not open video writer for {output_filename}!")
        return None
    try:
        for f in range(video_length):
            comp = np.array(comp_frames[f]).astype(
                np.uint8) * binary_masks[f] + frames[f] * (1 - binary_masks[f])
            writer.write(cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_RGB2BGR))
    finally:
        writer.release()
    print("✓ Video inpainting completed!")
    print(f"Output saved as: {output_filename}")
    return output_filename
'''

# Write the script to file. UTF-8 is explicit because the script contains
# non-ASCII characters (the check mark in its status messages).
with open('test_colab.py', 'w', encoding='utf-8') as f:
    f.write(test_colab_content)

# Forget any previously imported copy so the next `import test_colab` sees this version.
sys.modules.pop('test_colab', None)

print("✓ Created test_colab.py")

## 5. Run Video Inpainting

In [None]:
# Check that the example inputs, the checkpoint and the STTN modules that
# test_colab.py imports (core.utils, model.sttn) are all present.
import os

required_files = [
    "examples/schoolgirls_orig.mp4",
    "examples/schoolgirls",
    "checkpoints/sttn.pth",
    "model/sttn.py",   # imported by test_colab.py as model.sttn
    "core/utils.py",   # provides Stack / ToTorchFormatTensor
]

print("Checking required files:")
missing = [path for path in required_files if not os.path.exists(path)]
for file in required_files:
    status = "✗" if file in missing else "✓"
    print(f"{status} {file}")
all_exist = not missing

if all_exist:
    print("\n✓ All required files found!")
else:
    print("\n✗ Some files are missing. Please upload them first.")
    # Paths are relative, so a wrong working directory also shows up as missing files.
    print(f"  (current directory: {os.getcwd()})")

In [None]:
# Run video inpainting.
# Reload so that re-running section 4 (which rewrites test_colab.py) takes effect
# instead of a module cached from an earlier import.
import importlib
import test_colab
test_colab = importlib.reload(test_colab)
from test_colab import run_video_inpainting

# Run with default parameters (schoolgirls example)
result = run_video_inpainting()

if result:
    print(f"\n🎉 Success! Output video: {result}")

    # Display video info
    import cv2
    cap = cv2.VideoCapture(result)
    if cap.isOpened():
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"Video info: {frame_count} frames, {fps} FPS, {width}x{height}")
    else:
        print(f"Could not open {result} to read its properties")
    cap.release()
else:
    print("❌ Video inpainting failed!")

## 6. Custom Usage

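You can also run with your own video and mask files. Provide one mask image per video frame. Masks are matched to frames in filename-sorted order, resized to 432×240 (the output resolution), and any non-zero pixel marks a region to be inpainted: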

In [None]:
# Upload your own video and mask files
import os
import shutil
import zipfile
from google.colab import files

print("Upload your video file:")
uploaded_video = files.upload()
for name in uploaded_video:
    print(f"Uploaded video: {name}")

print("\nUpload your mask files (as a zip):")
uploaded_masks = files.upload()

# Extract mask images from the first uploaded zip into custom_masks/.
# The archive's folder structure is flattened: a zipped folder stores entries as
# `folder/00000.png`, but test_colab.py only lists the top level of the mask
# directory. Directories, hidden files and macOS `__MACOSX/` metadata are skipped.
# Writing by basename also keeps `..` entries from escaping custom_masks/.
zip_names = [name for name in uploaded_masks if name.endswith('.zip')]
if not zip_names:
    print("No .zip file uploaded - masks were not extracted.")
else:
    os.makedirs('custom_masks', exist_ok=True)
    with zipfile.ZipFile(zip_names[0], 'r') as zip_ref:
        for info in zip_ref.infolist():
            base = os.path.basename(info.filename)
            if (info.is_dir() or not base or base.startswith('.')
                    or info.filename.startswith('__MACOSX/')):
                continue
            with zip_ref.open(info) as src, open(os.path.join('custom_masks', base), 'wb') as dst:
                shutil.copyfileobj(src, dst)
    print("Extracted masks to custom_masks/")

In [None]:
# Run with custom files.
# By default, use the video uploaded in the previous cell; adjust these paths
# if your files live elsewhere.
try:
    custom_video = next(iter(uploaded_video))
except (NameError, StopIteration):
    # Upload cell was skipped or nothing was uploaded.
    custom_video = "your_video.mp4"  # Replace with your video filename
custom_masks = "custom_masks"    # Replace with your mask folder
print(f"Video: {custom_video} | Masks: {custom_masks}")

result = run_video_inpainting(
    video_path=custom_video,
    mask_path=custom_masks,
    ckpt_path="checkpoints/sttn.pth"
)

if result:
    print(f"\n🎉 Success! Custom output video: {result}")
else:
    print("❌ Custom video inpainting failed!")

## 7. Download Results

In [None]:
# Offer every *_result.mp4 in the working directory for download
import os
from google.colab import files

result_files = [name for name in os.listdir('.') if name.endswith('_result.mp4')]

print("Available result videos:")
for name in result_files:
    size_mb = os.path.getsize(name) / (1024*1024)
    print(f"  {name} ({size_mb:.1f} MB)")

for name in result_files:
    print(f"Downloading {name}...")
    files.download(name)