**Load Model**

In [None]:
import torch
from model import MattingNetwork

In [None]:
# Build the 'mobilenetv3' variant of the matting network.
model = MattingNetwork(variant='mobilenetv3')
# Switch to evaluation mode and move the network to the GPU before
# the weights are loaded.
model = model.eval().cuda()

# Read the checkpoint from disk, then copy its weights into the network.
# Kept as the last expression so the notebook still displays the result.
state_dict = torch.load('checkpoints/rvm_mobilenetv3.pth')
model.load_state_dict(state_dict)

**Inference on video**

In [None]:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

# Frame source and sink for the manual inference loop.
reader = VideoReader('videos/footage-1.mp4', transform=ToTensor())
writer = VideoWriter('videos/output.mp4', frame_rate=30)

# Constant RGB background, shaped (3, 1, 1) so it broadcasts over each frame.
bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background.
# Four recurrent-state slots, all empty before the first frame.
rec = [None] * 4                                       # Initial recurrent states.

# Always close the writer, even if the loop raises part-way through, so the
# output file is not left open or unfinished.
# NOTE(review): relies on VideoWriter exposing close() to flush and finalise
# the output; confirm against inference_utils.
try:
    with torch.no_grad():  # Inference only; no gradients are needed.
        for src in DataLoader(reader):
            # The model returns foreground, alpha, and the updated recurrent
            # states, which are fed back in on the next iteration.
            fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio=0.25)  # Cycle the recurrent states.
            # Alpha-blend the predicted foreground over the constant background.
            writer.write(fgr * pha + bgr * (1 - pha))
finally:
    writer.close()

In [None]:
from inference import convert_video

In [None]:
from inference import convert_video

# Run the whole conversion through the high-level helper. Each option is
# described on the line above it.
convert_video(
    # The loaded model, can be on any device (cpu or cuda).
    model,
    # A video file or an image sequence directory.
    input_source='videos/footage-1.mp4',
    # [Optional] Resize the input (also the output).
    input_resize=(1920, 1080),
    # [Optional] If None, make downsampled max size be 512px.
    downsample_ratio=0.25,
    # Choose "video" or "png_sequence".
    output_type='video',
    # File path if video; directory path if png sequence.
    output_composition='videos/output.mp4',
    # [Optional] Output the raw alpha prediction.
    output_alpha="pha.mp4",
    # [Optional] Output the raw foreground prediction.
    output_foreground="fgr.mp4",
    # Output video mbps. Not needed for png sequence.
    output_video_mbps=4,
    # Process n frames at once for better parallelism.
    seq_chunk=12,
    # Only for image sequence input. Reader threads.
    num_workers=1,
    # Print conversion progress.
    progress=True
)