In [1]:
!nvidia-smi

import torch

if torch.cuda.is_available():
  print('Using GPU')
  device = 'cuda'
else:
  print('CUDA not available. Please connect to a GPU instance if possible.')
  device = 'cpu'

Wed Mar 20 13:01:39 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-16GB           Off | 00000000:00:04.0 Off |                    0 |
| N/A   43C    P0              24W / 300W |      0MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Get our code and install prerequisites

In [2]:
!git clone https://github.com/hkchengrex/XMem.git
%cd XMem
!pip install opencv-python
!pip install -U numpy
!pip install -r requirements.txt

Cloning into 'XMem'...
remote: Enumerating objects: 608, done.[K
remote: Counting objects: 100% (336/336), done.[K
remote: Compressing objects: 100% (136/136), done.[K
remote: Total 608 (delta 240), reused 217 (delta 199), pack-reused 272[K
Receiving objects: 100% (608/608), 269.38 KiB | 848.00 KiB/s, done.
Resolving deltas: 100% (352/352), done.
/content/XMem
Collecting numpy
  Downloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m70.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.25.2
    Uninstalling numpy-1.25.2:
      Successfully uninstalled numpy-1.25.2
Successfully installed numpy-1.26.4
Collecting git+https://github.com/cheind/py-thin-plate-spline (from -r requirements.txt (line 4))
  Cloning https://github.com/cheind/py-thin-plate-spline to /tmp/pip-re

# Download the pretrained model



In [3]:
!wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth

--2024-03-20 13:03:09--  https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/511262077/ea2968ee-04ab-4dee-8596-03319e8c7e9f?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240320%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240320T130309Z&X-Amz-Expires=300&X-Amz-Signature=e41168ef6712387fd044bd66ca077ca7c7d5a348f0c0da45e5f8633cd80b86cb&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=511262077&response-content-disposition=attachment%3B%20filename%3DXMem.pth&response-content-type=application%2Foctet-stream [following]
--2024-03-20 13:03:09--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/511262077/ea2968ee-04ab-4dee-8596-03319e8c7e9f?X-Amz-Algorithm=AWS4-HMAC

# Basic setup

In [4]:
import os
from os import path
from argparse import ArgumentParser
import shutil

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image

from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset
from inference.data.mask_mapper import MaskMapper
from model.network import XMem
from inference.inference_core import InferenceCore

from progressbar import progressbar

torch.set_grad_enabled(False)

# default configuration
config = {
    'top_k': 30,
    'mem_every': 5,
    'deep_update_every': -1,
    'enable_long_term': True,
    'enable_long_term_count_usage': True,
    'num_prototypes': 128,
    'min_mid_term_frames': 5,
    'max_mid_term_frames': 10,
    'max_long_term_elements': 10000,
}

network = XMem(config, './saves/XMem.pth').eval().to(device)

Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False


Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 380MB/s]
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 328MB/s]


# Test data

Unzip the input.zip (zip with original video and masks) in the XMem folder before proceeding

# Propagate frame-by-frame

In [22]:
import os

def generate_overlay_video(mask_name, video_name):
    mask = np.array(Image.open(mask_name))
    mask[mask == 255] = 1
    print(np.unique(mask))
    num_objects = len(np.unique(mask)) - 1

    torch.cuda.empty_cache()

    processor = InferenceCore(network, config=config)
    processor.set_all_labels(range(1, num_objects+1)) # consecutive labels
    cap = cv2.VideoCapture(video_name)

    # You can change these two numbers
    frames_to_propagate = 200
    visualize_every = 1

    current_frame_index = 0

    # Video writer settings
    folder_name = os.path.join('results',video_name.split('.')[0])
    os.makedirs(folder_name, exist_ok=True)

    output_fps = 30
    output_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    output_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    output_video_name = os.path.join(folder_name,  "out-" + video_name)
    output_video_writer = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), output_fps, (output_width, output_height))

    with torch.cuda.amp.autocast(enabled=True):
        while (cap.isOpened()):
            # load frame-by-frame
            ret, frame = cap.read()
            if not ret or current_frame_index > frames_to_propagate:
                break

            # convert numpy array to pytorch tensor format
            frame_torch, _ = image_to_torch(frame, device=device)
            if current_frame_index == 0:
                # initialize with the mask
                mask_torch = index_numpy_to_one_hot_torch(mask, num_objects+1).to(device)
                # the background mask is not fed into the model
                prediction = processor.step(frame_torch, mask_torch[1:])
            else:
                # propagate only
                prediction = processor.step(frame_torch)

            # argmax, convert to numpy
            prediction = torch_prob_to_numpy_mask(prediction)

            if current_frame_index % visualize_every == 0:
                visualization = overlay_davis(frame, prediction)

                # Save mask
                name_png = f'out-{mask_name.split("-")[0]}-{str(current_frame_index+1).zfill(3)}.png'
                mask_output_name = os.path.join(folder_name, name_png)
                # Convert only values equal to 1 to 255 and keep other values unchanged
                cv2.imwrite(mask_output_name, (mask == 1).astype(np.uint8) * 255)

                # Write frame to output video
                output_video_writer.write(cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR))
                # Display the frame (optional)
                # display(Image.fromarray(visualization))

            current_frame_index += 1

    # Release video capture and writer objects
    cap.release()
    output_video_writer.release()

    return output_video_name

In [23]:
for seq in ['bag', 'bear', 'book', 'swan', 'rhino', 'camel']:
  output_video_path = generate_overlay_video(f'{seq}-001.png', f'{seq}.mp4')

[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]


In [19]:
# Run on terminal to zip results folder
#zip -r /content/XMem/results.zip /content/XMem/results

UsageError: Line magic function `%zip` not found.


In [24]:
# Download results
from google.colab import files
files.download("/content/XMem/results.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>