# Make things disappear with XMem and FGT

Sources:
- [ECCV 2022] XMem: Long-Term Video Object Segmentation with an Atkinson-Shiffrin Memory Model: https://github.com/hkchengrex/XMem
- [ECCV 2022] Flow-Guided Transformer for Video Inpainting: https://github.com/hitachinsk/fgt

In [4]:
try:
    import torch
    import torchvision
except ImportError:
    !pip install torch==1.10.1
    !pip install torchvision==0.11.2

In [None]:
# Shell command (IPython `!` syntax): print the NVIDIA driver/GPU status of this runtime.
!nvidia-smi

# Pick the torch device string used by later cells: 'cuda' when a GPU is
# visible to torch, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Tell the user which path was taken.
if device == 'cpu':
    print('CUDA not available. Please connect to a GPU instance if possible.')
else:
    print('Using GPU')

## (a) Load data
### Download video from YouTube and split into frames
- Source: https://huggingface.co/spaces/YiYiXu/it-happened-one-frame-2

In [None]:
import os
from os.path import exists as path_exists

In [None]:
# Download helper.py (the module `vid2frames` is imported from below) unless a
# copy already exists in the working directory.
if not path_exists('helper.py'):
    !wget https://raw.githubusercontent.com/machinelearnear/make-things-disappear-with-XMem-and-FGT/main/helper.py

In [None]:
from helper import vid2frames

In [None]:
# Source clip: Trump leaves Argentinean president alone on stage at G20.
youtube_url = 'https://youtu.be/KOnfiFOCwH0'

# Local path of the clip: "videos/<last path segment of the URL>.mp4".
video_name = f'videos/{youtube_url.rsplit("/", 1)[-1]}.mp4'

In [None]:
# Run the helper only when the expected local video file is missing.
# NOTE(review): `skip_frames` and `path_frames` are bound only on this branch;
# they stay undefined when the video file already exists.
# NOTE(review): presumably `vid2frames` also writes the clip to `video_name` —
# confirm in helper.py.
if not path_exists(video_name):
    skip_frames, path_frames = vid2frames(youtube_url)

### Preview the video 

In [None]:
from IPython.display import HTML
from base64 import b64encode

# Embed the clip in the notebook as a base64 data URL. The file is read inside
# a context manager so the handle is closed as soon as the bytes are loaded.
with open(video_name, 'rb') as video_file:
    data_url = "data:video/mp4;base64," + b64encode(video_file.read()).decode()

# Last expression of the cell: the notebook renders it as an inline player.
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

## (b) Segment first-frame

In [2]:
# Import `transformers`, installing it on the fly when it is missing and then
# importing it again so the name is bound either way.
try:
    import transformers
except ImportError:
    !pip install transformers
    import transformers

In [11]:
import numpy as np

In [5]:
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
from PIL import Image
import requests

# Sample image fetched over HTTP and opened with PIL.
# NOTE(review): this segments a COCO validation image, not a frame of the video
# loaded above — confirm that this is the intended input for this section.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Pre-processing and segmentation model, both loaded from the same checkpoint name.
feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-small")
model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")

# Pre-process the image into PyTorch tensors; `inputs` is reused by the
# post-processing cell below.
inputs = feature_extractor(images=image, return_tensors="pt")

# Forward pass without gradient tracking.
with torch.no_grad():
    outputs = model(**inputs)

# logits are of shape (batch_size, num_labels, height, width)
logits = outputs.logits

In [13]:
# Colour lookup table: row i is the RGB colour painted for class id i
# (the post-processing cell indexes this array with the predicted class ids).
palette = np.array(
    [
        (0, 0, 0),        # 0
        (192, 0, 0),      # 1
        (0, 192, 0),      # 2
        (192, 192, 0),    # 3
        (0, 0, 192),      # 4
        (192, 0, 192),    # 5
        (0, 192, 192),    # 6
        (192, 192, 192),  # 7
        (128, 0, 0),      # 8
        (255, 0, 0),      # 9
        (128, 192, 0),    # 10
        (255, 192, 0),    # 11
        (128, 0, 192),    # 12
        (255, 0, 192),    # 13
        (128, 192, 192),  # 14
        (255, 192, 192),  # 15
        (0, 128, 0),      # 16
        (192, 128, 0),    # 17
        (0, 255, 0),      # 18
        (192, 255, 0),    # 19
        (0, 128, 192),    # 20
    ],
    dtype=np.uint8,
)

# Class names, one per palette row (21 entries, "background" first).
# NOTE(review): presumably ordered by the model's class ids — confirm against
# the checkpoint's label mapping.
labels = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
    "car", "cat", "chair", "cow", "diningtable", "dog", "horse",
    "motorbike", "person", "pottedplant", "sheep", "sofa", "train",
    "tvmonitor",
]

In [14]:
# Rebuild a displayable uint8 image from the model input: drop the batch axis,
# move channels last, reverse the channel order and scale by 255.
# NOTE(review): assumes `pixel_values` lie in [0, 1] with reversed channel
# order — confirm against the feature extractor's configuration.
resized = (inputs["pixel_values"].numpy().squeeze().transpose(1, 2, 0)[..., ::-1] * 255).astype(np.uint8)

# Class predictions for each pixel.
classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8)

# Colour every pixel by its class id with a single NumPy lookup: indexing the
# (num_classes, 3) palette with the (H, W) id map yields an (H, W, 3) uint8
# image, the same result as filling it pixel by pixel in a Python loop.
colored = palette[classes]

# Resize predictions to input size (not original size).
colored = Image.fromarray(colored)
colored = colored.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)

# Keep everything that is not background: 255 where the class id is non-zero.
mask = (classes != 0) * 255
mask = Image.fromarray(mask.astype(np.uint8)).convert("RGB")
mask = mask.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)

# Blend with the input image (40% mask, 60% image).
resized = Image.fromarray(resized)
highlighted = Image.blend(resized, mask, 0.4)

### Preview first-frame annotation
The first frame mask is a PNG with a color palette.

In [None]:
import IPython.display
# Show the first-frame mask inline (width=400).
# NOTE(review): no visible cell writes 'masks/0.png'; presumably it has to be
# provided or saved beforehand — confirm.
IPython.display.Image('masks/0.png', width=400)

### Convert the mask to a numpy array
Note that the object IDs should be consecutive and start from `1` (`0` represents the background).
If they are not, see `inference.data.mask_mapper` and `eval.py` in the XMem repository for how to remap them with `MaskMapper`.

In [None]:
import numpy as np
from PIL import Image

In [None]:
# Load the first-frame annotation as an array of per-pixel object IDs.
mask = np.array(Image.open('masks/0.png'))

# Distinct IDs found in the mask, printed for a quick sanity check.
object_ids = np.unique(mask)
print(object_ids)

# One of the IDs is the background (0), so it is not counted as an object.
num_objects = len(object_ids) - 1

## (c) Long-Term Video Object Segmentation with an Atkinson-Shiffrin Memory Model
- Source: https://colab.research.google.com/drive/1RXK5QsUo2-CnOiy5AOSjoZggPVHOPh1m?usp=sharing#scrollTo=MWGdN7XCSYSm

### Get the XMem code and install prerequisites

In [None]:
# Clone the XMem repository and install its requirements, but only when no
# `XMem` path exists yet (so re-running the cell skips both steps).
if not path_exists('XMem'):
    !git clone https://github.com/hkchengrex/XMem.git
    !pip install -r XMem/requirements.txt

### Download the pre-trained model

In [None]:
# Download the XMem v1.0 checkpoint into XMem/saves/ unless it is already there.
if not path_exists('XMem/saves/XMem.pth'):
    !wget -P ./XMem/saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth

### Basic setup

In [None]:
import os
from os import path
from argparse import ArgumentParser
import shutil

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image

from XMem.inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset
from XMem.inference.data.mask_mapper import MaskMapper
from XMem.model.network import XMem
from XMem.inference.inference_core import InferenceCore

from XMem.progressbar import progressbar

# Inference only: switch gradient tracking off globally.
torch.set_grad_enabled(False)

# Default configuration; the same dict is handed to the network here and to
# the inference core in the propagation cell.
config = dict(
    top_k=30,
    mem_every=5,
    deep_update_every=-1,
    enable_long_term=True,
    enable_long_term_count_usage=True,
    num_prototypes=128,
    min_mid_term_frames=5,
    max_mid_term_frames=10,
    max_long_term_elements=10000,
)

# Build the model from the downloaded checkpoint, then put it in eval mode and
# move it to the device chosen at the top of the notebook.
network = XMem(config, './XMem/saves/XMem.pth')
network = network.eval().to(device)

### Propagate frame-by-frame

In [None]:
import cv2
from inference.interact.interactive_utils import image_to_torch, index_numpy_to_one_hot_torch, torch_prob_to_numpy_mask, overlay_davis

# Release cached CUDA memory before starting propagation.
torch.cuda.empty_cache()

# Inference state for this video; the labels are the consecutive object IDs
# 1..num_objects found in the first-frame mask.
processor = InferenceCore(network, config=config)
processor.set_all_labels(range(1, num_objects+1)) # consecutive labels
cap = cv2.VideoCapture(video_name)

# You can change these two numbers
frames_to_propagate = 200
visualize_every = 20

current_frame_index = 0

try:
    with torch.cuda.amp.autocast(enabled=True):
        while cap.isOpened():
            # load frame-by-frame
            _, frame = cap.read()
            # Stop at the end of the video, or once the index exceeds the
            # limit (so at most frames_to_propagate + 1 frames are processed).
            if frame is None or current_frame_index > frames_to_propagate:
                break

            # convert numpy array to pytorch tensor format
            frame_torch, _ = image_to_torch(frame, device=device)
            if current_frame_index == 0:
                # initialize with the mask
                mask_torch = index_numpy_to_one_hot_torch(mask, num_objects+1).to(device)
                # the background mask is not fed into the model
                prediction = processor.step(frame_torch, mask_torch[1:])
            else:
                # propagate only
                prediction = processor.step(frame_torch)

            # argmax, convert to numpy
            prediction = torch_prob_to_numpy_mask(prediction)

            if current_frame_index % visualize_every == 0:
                # NOTE(review): `frame` comes straight from OpenCV; if
                # `overlay_davis` does not reorder channels, the preview shows
                # swapped red/blue — confirm.
                visualization = overlay_davis(frame, prediction)
                display(Image.fromarray(visualization))

            current_frame_index += 1
finally:
    # Always give the video handle back, even if a step above raises.
    cap.release()