In [1]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

In [2]:
from IPython.display import display, HTML
# "Open in Colab" badge linking to this notebook on GitHub.
colab_badge_html = """
<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
"""
display(HTML(colab_badge_html))

In [3]:
# Set to True when running on Google Colab: the next cell then prints the
# PyTorch/CUDA setup and installs the notebook's dependencies.
using_colab = False

In [4]:
# Colab-only setup: report the runtime's PyTorch/Torchvision/CUDA status, then
# install the notebook's dependencies.
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    # `{sys.executable} -m pip` installs into the running kernel's interpreter
    # rather than whichever `pip` is first on PATH.
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/co-tracker.git'
    
    # NOTE(review): plain `mkdir` errors if `images` already exists, and no later
    # cell visibly uses `images` — confirm it is still needed.
    # NOTE(review): later cells read '../assets/...' and '../checkpoints/...', which
    # the pip install above does not provide — confirm how Colab obtains them.
    !mkdir images

In [5]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from torchvision.io import read_video
from cotracker.utils.visualizer import Visualizer
from IPython.display import HTML

In [6]:
# Load the demo clip as a single float batch for the tracker.
video = (
    read_video('../assets/apple.mp4')[0]  # decoded frames only; audio and metadata are dropped
    .permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
    .unsqueeze(0)  # add a leading batch axis -> (1, T, C, H, W)
    .float()  # cast to float32; pixel values are not rescaled
)

In [7]:
from cotracker.predictor import CoTrackerPredictor

# Load the CoTracker checkpoint (path is relative to this notebook's directory).
model = CoTrackerPredictor(
    checkpoint='../checkpoints/cotracker_stride_4_wind_8.pth'
)

# Run on the GPU when one is available. The video is moved together with the
# model so later calls see their inputs and the weights on the same device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
video = video.to(device)

# Tracking manually selected points

In [13]:
# Manually chosen query points, one per row, starting on frames 0, 10, 20 and 30.
# Presumably each row is (frame_index, x, y) in pixels of the original video —
# TODO confirm against the CoTrackerPredictor docs.
# The tensor is created on `video`'s device rather than hard-coding `.cuda()`,
# so it always matches the video and the cell also runs without a GPU.
queries = torch.tensor([
    [0., 400., 350.],
    [10., 600., 500.],
    [20., 750., 600.],
    [30., 900., 200.]
], device=video.device)
# TODO: visualize the query points on their frames before tracking.

In [9]:
# Track the manually selected points; queries[None] adds the batch axis.
# The second output (presumably per-point visibility) is kept under a real name:
# assigning to `__` would stop IPython from updating its `_`/`__`/`___`
# output-history variables for the rest of the session.
pred_tracks, pred_visibility = model(video, queries=queries[None])
vis = Visualizer(
    save_dir='./videos',
    linewidth=6,
    mode='cool',
    tracks_leave_trace=-1
)
vis.visualize(
    video=video,
    tracks=pred_tracks,
    filename='queries');

Moviepy - Building video ./videos/queries_pred_track.mp4.
Moviepy - Writing video ./videos/queries_pred_track.mp4



                                                                                                                                                                                                                                                                                                                                                               

Moviepy - Done !
Moviepy - video ready ./videos/queries_pred_track.mp4
Video saved to ./videos/queries_pred_track.mp4


In [10]:
HTML("""
    <video width="640" height="480" autoplay loop controls>
        <source src="./videos/queries_pred_track.mp4" type="video/mp4">
    </video>
""")

# Points on a regular grid

### Tracking forward from the first frame

In [11]:
# Resolution of the regular query grid, passed as the model's `grid_size`
# argument (presumably points per side — TODO confirm in CoTrackerPredictor).
grid_size = 40

In [12]:
# Track a regular grid of points starting on the first frame. The second output
# (presumably per-point visibility) is kept under a real name: assigning to `__`
# would stop IPython from updating its `_`/`__`/`___` output-history variables.
pred_tracks, pred_visibility = model(video, grid_size=grid_size)
vis = Visualizer(
    save_dir='./videos',
    pad_value=100,
    linewidth=3,
)
vis.visualize(
    video=video,
    tracks=pred_tracks,
    filename='grid');

Moviepy - Building video ./videos/grid_pred_track.mp4.
Moviepy - Writing video ./videos/grid_pred_track.mp4



                                                                                                                                                                                                                                                                                                                                                               

Moviepy - Done !
Moviepy - video ready ./videos/grid_pred_track.mp4
Video saved to ./videos/grid_pred_track.mp4


In [14]:
HTML("""
    <video width="640" height="480" autoplay loop controls>
        <source src="./videos/grid_pred_track.mp4" type="video/mp4">
    </video>
""")

### Tracking forward from frame 30

In [15]:
# Same grid resolution as before; the grid is now placed on frame 30
# (passed to the model as `grid_query_frame`).
grid_size, grid_query_frame = 40, 30

In [17]:
# Track a grid placed on frame `grid_query_frame`. The second output is kept under a
# real name: assigning to `__` would freeze IPython's `_`/`__`/`___` output history.
pred_tracks, pred_visibility = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame)

In [18]:
# Reuse the Visualizer configured in the forward-grid cell; `query_frame` is
# the frame the grid was placed on.
vis.visualize(
    video=video,
    tracks=pred_tracks,
    query_frame=grid_query_frame,
    filename='grid_query_30',
);

Moviepy - Building video ./videos/grid_query_30_pred_track.mp4.
Moviepy - Writing video ./videos/grid_query_30_pred_track.mp4



                                                                                                                                                                                                                                                                                                                                                               

Moviepy - Done !
Moviepy - video ready ./videos/grid_query_30_pred_track.mp4
Video saved to ./videos/grid_query_30_pred_track.mp4


In [19]:
HTML("""
    <video width="640" height="480" autoplay loop controls>
        <source src="./videos/grid_query_30_pred_track.mp4" type="video/mp4">
    </video>
""")

### Tracking forward **and backward** from frame 30

In [20]:
# Same grid and starting frame as the previous section; only the tracking
# direction changes below.
grid_size, grid_query_frame = 40, 30

In [21]:
# Track the frame-`grid_query_frame` grid both forward and backward in time. The second
# output is kept under a real name: assigning to `__` would freeze IPython's output history.
pred_tracks, pred_visibility = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame, backward_tracking=True)

NameError: name 'load_video' is not defined

In [22]:
# Render the forward+backward tracks with the same Visualizer settings as the
# forward-grid cell.
vis.visualize(
    video=video,
    tracks=pred_tracks,
    query_frame=grid_query_frame,
    filename='grid_query_30_backward',
);

Moviepy - Building video ./videos/grid_query_30_backward_pred_track.mp4.
Moviepy - Writing video ./videos/grid_query_30_backward_pred_track.mp4



                                                                                                                                                                                                                                                                                                                                                               

Moviepy - Done !
Moviepy - video ready ./videos/grid_query_30_backward_pred_track.mp4
Video saved to ./videos/grid_query_30_backward_pred_track.mp4


In [23]:
HTML("""
    <video width="640" height="480" autoplay loop controls>
        <source src="./videos/grid_query_30_backward_pred_track.mp4" type="video/mp4">
    </video>
""")

# Regular grid + Segmentation mask

In [24]:
# Finer grid than before (120 vs 40); it is combined with the segmentation mask below.
grid_size = 120

In [25]:
# Load the apple's segmentation mask and add two leading singleton axes.
# NOTE(review): assumes a single-channel (H, W) PNG, giving (1, 1, H, W); an
# RGB(A) mask would produce a 5-D tensor instead — confirm.
input_mask = '../assets/apple_mask.png'
segm_mask = (
    torch.from_numpy(np.array(Image.open(input_mask)))
    .unsqueeze(0)
    .unsqueeze(0)
)

In [26]:
# Track a regular grid together with the segmentation mask (presumably only grid
# points inside the mask are kept — TODO confirm). The second output is kept under a
# real name: assigning to `__` would freeze IPython's `_`/`__`/`___` output history.
pred_tracks, pred_visibility = model(video, grid_size=grid_size, segm_mask=segm_mask)
vis = Visualizer(
    save_dir='./videos',
    pad_value=100,
    linewidth=2,
)
vis.visualize(
    video=video,
    tracks=pred_tracks,
    filename='segm_grid');

Moviepy - Building video ./videos/segm_grid_pred_track.mp4.
Moviepy - Writing video ./videos/segm_grid_pred_track.mp4



                                                                                                                                                                                                                                                                                                                                                               

Moviepy - Done !
Moviepy - video ready ./videos/segm_grid_pred_track.mp4
Video saved to ./videos/segm_grid_pred_track.mp4


In [27]:
HTML("""
    <video width="640" height="480" autoplay loop controls>
        <source src="./videos/segm_grid_pred_track.mp4" type="video/mp4">
    </video>
""")

# Dense Tracks

### Tracking forward **and backward** from frame 25

In [35]:
# Layout is (batch, frames, channels, height, width): the loading cell permutes
# the decoded frames to channels-first and adds a leading batch axis.
video.shape

torch.Size([1, 48, 3, 719, 1282])

In [28]:
import torch.nn.functional as F

# Dense tracking is expensive, so every frame is first downsampled to 100x180
# (height x width). video[0] drops the batch axis, so F.interpolate resizes the
# frames as a batch of (C, H, W) images; [None] restores the batch axis afterwards.
# The result stays on `video`'s device instead of being forced onto CUDA. This
# matches wherever the video and model live and also works on CPU-only machines.
video_interp = F.interpolate(video[0], [100, 180], mode="bilinear")[None]

In [29]:
# Same (batch, frames, channels) layout as `video`, with height x width reduced to 100x180.
video_interp.shape

torch.Size([1, 48, 3, 100, 180])

In [30]:
# Dense tracks (no queries and no grid_size) seeded on frame 25 and tracked both
# forward and backward. The frame is stored in `grid_query_frame` because the
# visualization cell passes that variable as `query_frame`; a literal here would
# leave it at the value 30 set by the earlier grid sections.
# The second output is kept under a real name: assigning to `__` would freeze
# IPython's `_`/`__`/`___` output history.
grid_query_frame = 25
pred_tracks, pred_visibility = model(video_interp, grid_query_frame=grid_query_frame, backward_tracking=True)


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [02:08<00:00, 14.24s/it]


In [33]:
# Render the dense tracks on the downsampled video and embed the result.
# `query_frame` must be the frame the dense tracks were seeded on in the
# tracking cell above.
vis = Visualizer(
    save_dir='./videos',
    pad_value=20,
    linewidth=1,
    mode='optical_flow'
)
vis.visualize(
    video=video_interp,
    tracks=pred_tracks,
    query_frame=grid_query_frame,
    filename='dense');

HTML("""
    <video width="320" height="240" autoplay loop controls>
        <source src="./videos/dense_pred_track.mp4" type="video/mp4">
    </video>
""")

Moviepy - Building video ./videos/dense_pred_track.mp4.
Moviepy - Writing video ./videos/dense_pred_track.mp4



                                                                                                                                                                                                                                                                                                                                                               

Moviepy - Done !
Moviepy - video ready ./videos/dense_pred_track.mp4
Video saved to ./videos/dense_pred_track.mp4


