Copyright (c) Meta Platforms, Inc. and affiliates.

<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# CoTracker: It is Better to Track Together
This is a demo for <a href="https://co-tracker.github.io/">CoTracker</a>, a model that can track any point in a video.

<img src="https://www.robots.ox.ac.uk/~nikita/storage/cotracker/bmx-bumps.gif" alt="Logo" width="50%">

Don't forget to turn on GPU support if you're running this demo in Colab. 

**Runtime** -> **Change runtime type** -> **Hardware accelerator** -> **GPU**

Let's install dependencies for Colab:

In [None]:
# !git clone https://github.com/facebookresearch/co-tracker
# %cd co-tracker
# !pip install -e .
# !pip install opencv-python einops timm matplotlib moviepy flow_vis
# !mkdir checkpoints
# %cd checkpoints
# !wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth

In [None]:
# Presumably the notebook lives in the repo's notebooks/ folder; move to the repo
# root so the relative './assets', './checkpoints' and './videos' paths resolve.
# NOTE(review): the later sections also start with `%cd ..`; running several
# sections in one kernel moves up one directory per section and breaks those
# relative paths.
%cd ..
import os
import torch

from base64 import b64encode
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from IPython.display import HTML
import torch.nn.functional as F

# NOTE(review): hard-codes GPU index 2. On machines with fewer GPUs (e.g. the
# Colab runtime mentioned above) this likely hides every device, so the
# `torch.cuda.is_available()` checks below would fall back to CPU. Presumably it
# must also be set before CUDA is first initialised in the kernel — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

Read the demo video and convert it to a float tensor of shape `(1, T, C, H, W)`:

In [None]:
# Load the demo clip and reorder it into a (1, T, C, H, W) float32 tensor.
# Alternatives: './assets/output.mp4' as input, or downscale the frames with
# F.interpolate(video[0], scale_factor=0.6, mode='bilinear', align_corners=True)[None]
video = read_video_from_path("./assets/breakdance.mp4")
video = torch.from_numpy(video).permute(0, 3, 1, 2).unsqueeze(0).float()
print(video.shape)

In [None]:
def show_video(video_path):
    """Return an HTML <video> element that plays the MP4 file at ``video_path``.

    The file is inlined as a base64 ``data:`` URL, so the notebook can play it
    without serving the file separately.
    """
    # Open read-only and close the handle deterministically. "r+b" also
    # requested write access, which fails on read-only files and leaked the
    # handle because it was never closed.
    with open(video_path, "rb") as video_file:
        video_bytes = video_file.read()
    video_url = f"data:video/mp4;base64,{b64encode(video_bytes).decode()}"
    return HTML(f"""<video width="640" height="480" autoplay loop controls><source src="{video_url}"></video>""")


show_video("./assets/breakdance.mp4")

Import CoTrackerPredictor and create an instance of it. We'll use this object to estimate tracks:

In [None]:
from cotracker.predictor import CoTrackerPredictor

# The predictor built here estimates the point tracks in the cells below.
# NOTE(review): the (commented-out) setup cell downloads the checkpoint to
# ./checkpoints/cotracker_stride_4_wind_8.pth, not into a cotracker_pretrain/
# subfolder — adjust one of the two paths if loading fails.
model = CoTrackerPredictor(
    checkpoint='./checkpoints/cotracker_pretrain/cotracker_stride_4_wind_8.pth',
)

In [None]:
# Use the GPU when one is visible; otherwise model and clip stay on the CPU.
if torch.cuda.is_available():
    video = video.cuda()
    model = model.cuda()

Track points sampled on a regular 50\*50 grid (`grid_size=50`) on the first frame, keeping only the points that fall inside the segmentation mask:

In [None]:
import numpy as np
from PIL import Image
import cv2

# Temporal subsampling step: keep every `fps`-th frame (1 keeps all frames).
# Despite its name this is a frame stride, not a frames-per-second rate.
fps = 1

input_mask = './assets/breakdance.png'
segm_mask = np.array(Image.open(input_mask))
_, T, C, H, W = video.shape
vidLen = video.shape[1]
# torch.range is deprecated; arange with the exclusive end vidLen selects the
# same frame indices 0, fps, 2*fps, ... <= vidLen - 1.
idx = torch.arange(0, vidLen, fps).long()
video = video[:, idx]
if segm_mask.ndim == 3:
    # Collapse a colour mask to binary using the colour channels only (an alpha
    # channel would mark every opaque pixel as foreground), and store it as
    # uint8 because cv2.resize rejects boolean arrays.
    segm_mask = (segm_mask[..., :3].mean(axis=-1) > 0).astype(np.uint8)
# Match the mask to the video resolution; nearest keeps it binary.
segm_mask = cv2.resize(segm_mask, (W, H), interpolation=cv2.INTER_NEAREST)
pred_tracks, pred_visibility = model(
    video,
    grid_size=50,
    backward_tracking=False,
    segm_mask=torch.from_numpy(segm_mask)[None, None],
)

Visualize and save the result: 

In [None]:
# Draw the predicted tracks over the clip and save the result under ./videos,
# named after the mask file (directory and extension stripped).
vis = Visualizer(tracks_leave_trace=10, pad_value=0, save_dir='./videos')
vis.visualize(
    video=video,
    visibility=pred_visibility,
    tracks=pred_tracks,
    filename=input_mask.split('/')[-1].split('.')[0],
)

In [None]:
# The visualisation above was saved as '<mask stem>_pred_track.mp4'
# (e.g. breakdance_pred_track.mp4), not as teaser_pred_track.mp4, so build the
# same name here to show the video that was just rendered.
show_video(f"./videos/{input_mask.split('/')[-1].split('.')[0]}_pred_track.mp4")

## SpatialTracker Visualization

In [None]:
# ---------- import the basic packages ------------
# NOTE(review): this is another `%cd ..`; if an earlier section already ran in
# the same kernel, this moves one directory too far up and the relative
# './assets' paths below stop resolving.
%cd ..
import os
import torch

import cv2
from PIL import Image
# Temporal subsampling step used below when selecting frame indices: keep every
# `fps`-th frame. It is a frame stride, not a frames-per-second rate.
fps = 1
from base64 import b64encode
# Visualizer / read_video_from_path are re-imported from spatracker_v1 further
# down in this cell, which shadows these cotracker versions from then on.
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from IPython.display import HTML
import numpy as np
import torch.nn.functional as F

# NOTE(review): hard-coded GPU index 2 — likely hides all devices on machines
# with fewer GPUs; presumably must be set before CUDA initialisation — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# ---------- read the video ------------
# Load the clip as a (1, T, C, H, W) float tensor. scale_factor=1.0 keeps the
# native resolution; change it to resize every frame.
video = torch.from_numpy(read_video_from_path("./assets/fan.mp4"))
video = video.permute(0, 3, 1, 2).unsqueeze(0).float()
video = F.interpolate(
    video[0], scale_factor=1.0, mode='bilinear', align_corners=True
).unsqueeze(0)
_, T, C, H, W = video.shape

def show_video(video_path):
    """Return an HTML <video> element that plays the MP4 file at ``video_path``.

    The file is inlined as a base64 ``data:`` URL so the notebook can play it
    without serving the file separately.
    """
    # Read-only open inside a context manager: no leaked handle, and no
    # PermissionError on read-only files (which "r+b" would raise).
    with open(video_path, "rb") as video_file:
        video_bytes = video_file.read()
    video_url = f"data:video/mp4;base64,{b64encode(video_bytes).decode()}"
    return HTML(f"""<video width="640" height="480" autoplay loop controls><source src="{video_url}"></video>""")


# This call is not the last expression of the cell, so its HTML would be
# discarded; display it explicitly.
display(show_video("./assets/fan.mp4"))

# ---------- run the spatialtracker ------------
from spatracker_v1.predictor import CoTrackerPredictor
from spatracker_v1.zoeDepth.models.builder import build_model
from spatracker_v1.zoeDepth.utils.config import get_config
from spatracker_v1.utils.visualizer import Visualizer, read_video_from_path

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Temporal subsampling: keep every `fps`-th frame. torch.range is deprecated;
# arange with the exclusive end vidLen selects the same integer indices.
vidLen = video.shape[1]
idx = torch.arange(0, vidLen, fps).long()
video = video[:, idx]

# init the monocular depth perception
# (alternative config: get_config("zoedepth", "infer", config_version="kitti"))
conf = get_config("zoedepth_nk", "infer")
model_zoe_nk = build_model(conf).to(DEVICE)
model_zoe_nk.eval()

# NOTE(review): machine-specific absolute checkpoint path; adjust locally.
model = CoTrackerPredictor(
    checkpoint='/home/xyx/home/codes/co_tracker/checkpoints/spv1/model_cotracker_199375.pth'
)

# Move the tracker and the clip to the GPU only when one is available, so the
# cell also runs on CPU-only machines.
if torch.cuda.is_available():
    model = model.cuda()
    video = video.cuda()

input_mask = './assets/fan.png'
segm_mask = np.array(Image.open(input_mask))
if segm_mask.ndim == 3:
    # Binarise on the colour channels only (ignore alpha) and use uint8, a
    # dtype cv2.resize accepts.
    segm_mask = (segm_mask[..., :3].mean(axis=-1) > 0).astype(np.uint8)
segm_mask = cv2.resize(segm_mask, (W, H), interpolation=cv2.INTER_NEAREST)
pred_tracks, pred_visibility = model(
    video,
    grid_size=50,
    backward_tracking=False,
    depth_predictor=model_zoe_nk,
    segm_mask=torch.from_numpy(segm_mask)[None, None],
)
vis = Visualizer(save_dir='./videos', pad_value=0, grayscale=True, tracks_leave_trace=5, fps=15)
# Only the first two track channels are drawn; the third (used as depth in the
# point-cloud section) is dropped for the 2D overlay.
vis.visualize(video=video, tracks=pred_tracks[..., :2], visibility=pred_visibility, filename='teaser')
# Not the last expression of the cell, so display the HTML explicitly.
display(show_video("./videos/teaser_pred_track.mp4"))

# ---------- visualize the 4D point cloud ------------
import plotly.graph_objects as go
from ipywidgets import interact, IntSlider
import ipywidgets as widgets

xyzt = pred_tracks[0].cpu().numpy()   # T x N x 3
# Take the frame count from the tracks: T was read before the temporal
# subsampling, so it only matches when fps == 1.
num_frames = xyzt.shape[0]
# Pinhole intrinsics: focal length W on both axes, principal point at the
# image centre.
intr = np.array([[W, 0.0, W//2],
                 [0.0, W, H//2],
                 [0.0, 0.0, 1.0]])

# Replace the third channel by 1 to get homogeneous pixel coordinates (u, v, 1)
# and map them through K^-1; the third component stays 1 after this.
xyztVis = xyzt.copy()
xyztVis[..., 2] = 1.0
xyztVis = np.linalg.inv(intr)[None] @ xyztVis.reshape(-1, 3, 1)  # (T*N) x 3 x 1
xyztVis = xyztVis.reshape(num_frames, -1, 3)  # T x N x 3
# Put the third track channel (depth) on the z axis.
# NOTE(review): only z is multiplied by depth, so x/y remain normalised image
# coordinates rather than camera-space positions (a full back-projection would
# scale all three). The fixed x/y axis ranges below match this convention —
# confirm it is intended.
xyztVis[..., 2] *= xyzt[..., 2]

layout = go.Layout(
    scene=dict(
        xaxis=dict(range=[-1.5, 1.5], autorange=False),  # fix the x-axis range, disable autoscaling
        yaxis=dict(range=[-1.5, 1.5], autorange=False),  # fix the y-axis range, disable autoscaling
        zaxis=dict(range=[-0.5, 30], autorange=False),  # fix the z-axis range, disable autoscaling
        aspectmode='manual',
        aspectratio=dict(x=1, y=1, z=1),
    )
)

fig = go.FigureWidget()
fig.add_scatter3d(
    x=xyztVis[0, :, 0],
    y=xyztVis[0, :, 1],
    z=xyztVis[0, :, 2],
    mode='markers',
    marker=dict(size=1, color='blue'),
)

slider = IntSlider(min=0, max=num_frames - 1, step=1, value=0)


def update(frame):
    """Show the points of frame ``frame`` in the 3D scatter."""
    # Group the x/y/z changes so the widget redraws once per slider move.
    with fig.batch_update():
        fig.data[0].x = xyztVis[frame, :, 0]
        fig.data[0].y = xyztVis[frame, :, 1]
        fig.data[0].z = xyztVis[frame, :, 2]


fig.layout = layout
widgets.interact(update, frame=slider)
display(fig, slider)

## SpatialTracker Final

In [None]:
# ---------- import the basic packages ------------
# NOTE(review): yet another `%cd ..`; running this after the earlier sections in
# the same kernel moves one directory too far up and breaks './assets' paths.
%cd ..
import os
import torch

from PIL import Image

from base64 import b64encode
# Shadowed later in this cell by the spatracker1 Visualizer /
# read_video_from_path imports.
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from IPython.display import HTML
import numpy as np
import torch.nn.functional as F
# Temporal subsampling step (a frame stride, not frames per second); the float
# value is turned into integer frame indices with `.long()` further down.
fps = 1.
# NOTE(review): hard-coded GPU index 2 — likely hides all devices on machines
# with fewer GPUs; presumably must be set before CUDA initialisation — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# ---------- read the video ------------
# Load the clip as a (1, T, C, H, W) float tensor and shrink every frame to 80%
# of its size (bilinear) to reduce compute.
video = torch.from_numpy(read_video_from_path("./assets/cheetan.mp4"))
video = video.permute(0, 3, 1, 2).unsqueeze(0).float()
video = F.interpolate(
    video[0], scale_factor=0.8, mode='bilinear', align_corners=True
).unsqueeze(0)
_, T, C, H, W = video.shape

def show_video(video_path):
    """Return an HTML <video> element that plays the MP4 file at ``video_path``.

    The file is inlined as a base64 ``data:`` URL so the notebook can play it
    without serving the file separately.
    """
    # Read-only open inside a context manager: no leaked handle, and no
    # PermissionError on read-only files (which "r+b" would raise).
    with open(video_path, "rb") as video_file:
        video_bytes = video_file.read()
    video_url = f"data:video/mp4;base64,{b64encode(video_bytes).decode()}"
    return HTML(f"""<video width="640" height="480" autoplay loop controls><source src="{video_url}"></video>""")


# This call is not the last expression of the cell, so its HTML would be
# discarded; display it explicitly.
display(show_video("./assets/cheetan.mp4"))

# ---------- run the spatialtracker ------------
from spatracker1.predictor import CoTrackerPredictor
from spatracker1.zoeDepth.models.builder import build_model
from spatracker1.zoeDepth.utils.config import get_config
from spatracker1.utils.visualizer import Visualizer, read_video_from_path

# init the monocular depth perception
# (alternative config: get_config("zoedepth", "infer", config_version="kitti"))
conf = get_config("zoedepth_nk", "infer")
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
model_zoe_nk = build_model(conf).to(DEVICE)
model_zoe_nk.eval()

model = CoTrackerPredictor(
    checkpoint='./checkpoints/spv1_noise_new/model_cotracker_061875.pth'
    # alternative: './checkpoints/SpatialTracker/model_cotracker_199375.pth'
)

# Move the tracker and the clip to the GPU only when one is available; an
# unconditional video.cuda() would crash on CPU-only machines.
if torch.cuda.is_available():
    model = model.cuda()
    video = video.cuda()
    
import cv2

input_mask = './assets/cheetan.png'
segm_mask = np.array(Image.open(input_mask))
if segm_mask.ndim == 3:
    # Binarise on the colour channels only (an alpha channel would mark every
    # opaque pixel as foreground) and store as uint8, the same binary mask
    # format the other sections pass to the tracker.
    segm_mask = (segm_mask[..., :3].mean(axis=-1) > 0).astype(np.uint8)
# Match the mask to the (resized) video resolution; nearest keeps it binary.
segm_mask = cv2.resize(segm_mask, (W, H), interpolation=cv2.INTER_NEAREST)

# Keep every `fps`-th frame. torch.range is deprecated; for integer-valued fps,
# arange with the exclusive end vidLen selects the same indices.
vidLen = video.shape[1]
idx = torch.arange(0, vidLen, fps).long()
video = video[:, idx]
pred_tracks, pred_visibility = model(video, grid_size=60, backward_tracking=False,
                                     depth_predictor=model_zoe_nk, grid_query_frame=0,
                                     segm_mask=torch.from_numpy(segm_mask)[None, None], add_new_pts=False)

vis = Visualizer(save_dir='./videos', pad_value=0, tracks_leave_trace=10)
# Only the first two track channels are drawn in the 2D overlay.
vis.visualize(video=video, tracks=pred_tracks[..., :2], visibility=pred_visibility, filename='teaser')
# Not the last expression of the cell, so display the HTML explicitly.
display(show_video("./videos/teaser_pred_track.mp4"))

# ---------- visualize the 4D point cloud ------------
import plotly.graph_objects as go
from ipywidgets import interact, IntSlider
import ipywidgets as widgets

xyzt = pred_tracks[0].cpu().numpy()   # T x N x 3
# Take the frame count from the tracks: T was read before the temporal
# subsampling, so it only matches when fps == 1.
num_frames = xyzt.shape[0]
# Pinhole intrinsics: focal length W on both axes, principal point at the
# image centre.
intr = np.array([[W, 0.0, W//2],
                 [0.0, W, H//2],
                 [0.0, 0.0, 1.0]])

# Replace the third channel by 1 to get homogeneous pixel coordinates (u, v, 1)
# and map them through K^-1; the third component stays 1 after this.
xyztVis = xyzt.copy()
xyztVis[..., 2] = 1.0
xyztVis = np.linalg.inv(intr)[None] @ xyztVis.reshape(-1, 3, 1)  # (T*N) x 3 x 1
xyztVis = xyztVis.reshape(num_frames, -1, 3)  # T x N x 3
# Put the third track channel (depth) on the z axis.
# NOTE(review): only z is multiplied by depth, so x/y remain normalised image
# coordinates rather than camera-space positions (a full back-projection would
# scale all three). The fixed x/y axis ranges below match this convention —
# confirm it is intended, since these points are also saved to disk below.
xyztVis[..., 2] *= xyzt[..., 2]

layout = go.Layout(
    scene=dict(
        xaxis=dict(range=[-1.5, 1.5], autorange=False),  # fix the x-axis range, disable autoscaling
        yaxis=dict(range=[-1.5, 1.5], autorange=False),  # fix the y-axis range, disable autoscaling
        zaxis=dict(range=[-0.5, 12], autorange=False),  # fix the z-axis range, disable autoscaling
        aspectmode='manual',
        aspectratio=dict(x=1, y=1, z=1),
    )
)

fig = go.FigureWidget()
fig.add_scatter3d(
    x=xyztVis[0, :, 0],
    y=xyztVis[0, :, 1],
    z=xyztVis[0, :, 2],
    mode='markers',
    marker=dict(size=1, color='blue'),
)

# Sample every track's colour from its frame with bilinear interpolation.
pred_tracks2d = pred_tracks[0][:, :, :2]  # T x N x 2, assumed pixel (x, y)
video2d = video[0]  # T C H W
H1, W1 = video2d.shape[-2:]
# grid_sample takes coordinates in [-1, 1]. With align_corners=True, -1 and 1
# are the centres of the first and last pixels, so pixel x maps to
# 2 * x / (W1 - 1) - 1 (likewise for y with H1).
pred_tracks2dNm = pred_tracks2d.clone()
pred_tracks2dNm[..., 0] = 2 * pred_tracks2dNm[..., 0] / (W1 - 1) - 1
pred_tracks2dNm[..., 1] = 2 * pred_tracks2dNm[..., 1] / (H1 - 1) - 1
color_interp = torch.nn.functional.grid_sample(video2d, pred_tracks2dNm[:, :, None, :], align_corners=True)
# T x C x N x 1 -> T x N x C (same value range as the video frames)
color_interp = color_interp[:, :, :, 0].permute(0, 2, 1).cpu().numpy()

# T x N x (3 + C): plotted position followed by the sampled colour.
colored_pts = np.concatenate([xyztVis, color_interp], axis=-1)

slider = IntSlider(min=0, max=num_frames - 1, step=1, value=0)


def update(frame):
    """Show the points of frame ``frame`` in the 3D scatter."""
    # Group the x/y/z changes so the widget redraws once per slider move.
    with fig.batch_update():
        fig.data[0].x = xyztVis[frame, :, 0]
        fig.data[0].y = xyztVis[frame, :, 1]
        fig.data[0].z = xyztVis[frame, :, 2]


fig.layout = layout
widgets.interact(update, frame=slider)
display(fig, slider)

# Save the coloured point cloud once; its first three channels already hold
# xyztVis, so a separate save of xyztVis to the same path was redundant (it
# was immediately overwritten).
np.save('./assets/cheetan.npy', colored_pts)

In [None]:
# Query a 30 x 30 grid of points sampled on frame 20 instead of frame 0.
grid_size, grid_query_frame = 30, 20

In [None]:
# Track the grid points sampled on frame `grid_query_frame`.
pred_tracks, pred_visibility = model(
    video,
    grid_query_frame=grid_query_frame,
    grid_size=grid_size,
)

In [None]:
# Render the tracks, marking `grid_query_frame` as the frame the queries start on.
vis = Visualizer(pad_value=100, save_dir='./videos')
vis.visualize(video=video, tracks=pred_tracks, visibility=pred_visibility,
              query_frame=grid_query_frame, filename='grid_query_20')

Note that tracking now starts only from points sampled on frame `grid_query_frame` (20), in the middle of the video, rather than on the first frame as in the first example:

In [None]:
# Play the forward-only tracks queried on frame 20.
show_video(video_path="./videos/grid_query_20_pred_track.mp4")

### Tracking forward **and backward** from the frame number x

CoTracker is an online algorithm and tracks points only in one direction. However, we can also run it backward from the queried point to track in both directions: 

In [None]:
# Same 30 x 30 grid on frame 20, reused for the bidirectional run below.
grid_size, grid_query_frame = 30, 20

Let's activate backward tracking:

In [None]:
# Enable backward tracking so the points queried on `grid_query_frame` are also
# followed towards the start of the video, then save the visualisation.
pred_tracks, pred_visibility = model(
    video,
    backward_tracking=True,
    grid_query_frame=grid_query_frame,
    grid_size=grid_size,
)
vis.visualize(
    video=video,
    tracks=pred_tracks,
    visibility=pred_visibility,
    query_frame=grid_query_frame,
    filename='grid_query_20_backward',
)

As you can see, points queried in the middle of the video are now also tracked backward to the first frame:

In [None]:
# Play the bidirectional tracks.
show_video(video_path="./videos/grid_query_20_backward_pred_track.mp4")

## Regular grid + Segmentation mask

Let's now sample points on a grid and filter them with a segmentation mask.
Because only the points inside the mask are tracked, this uses less GPU memory, which lets us sample points densely on the object.

In [None]:
import numpy as np
from PIL import Image
# Side length of the query grid. Points outside the segmentation mask are
# dropped, so a much denser grid than in the previous examples is affordable.
grid_size = 100

In [None]:
input_mask = './assets/apple_mask.png'
# Read the first-frame segmentation mask into a NumPy array; the context
# manager closes the image file once the pixels have been copied.
with Image.open(input_mask) as mask_image:
    segm_mask = np.array(mask_image)

That's a segmentation mask for the first frame:

In [None]:
# `plt` was never imported in this notebook, so this cell raised NameError.
# matplotlib is already one of the demo's installed dependencies.
import matplotlib.pyplot as plt

# Overlay the mask on the first frame: both are scaled to [0, 1] (assuming
# 0-255 values) and multiplied, so pixels outside the mask appear black.
plt.imshow((segm_mask[..., None] / 255. * video[0, 0].permute(1, 2, 0).cpu().numpy() / 255.))

In [None]:
# Track only the grid points that fall inside the mask, then save the result.
pred_tracks, pred_visibility = model(
    video,
    segm_mask=torch.from_numpy(segm_mask)[None, None],
    grid_size=grid_size,
)
vis = Visualizer(linewidth=2, pad_value=100, save_dir='./videos')
vis.visualize(
    video=video,
    visibility=pred_visibility,
    tracks=pred_tracks,
    filename='segm_grid',
)

We are now only tracking points on the object (and around):

In [None]:
# Play the mask-filtered tracks.
show_video(video_path="./videos/segm_grid_pred_track.mp4")

## Dense Tracks

### Tracking forward **and backward** from the frame number x

CoTracker also has a mode to track **every pixel** in a video in a **dense** manner, but it is much slower than the previous examples. Let's downsample the video to make it faster:

In [None]:
# Inspect the clip's (1, T, C, H, W) size before downsampling.
video.size()

In [None]:
import torch.nn.functional as F

# Resize every frame to 100 x 180 (H x W) bilinearly, then restore the batch axis.
video_interp = F.interpolate(video[0], size=(100, 180), mode="bilinear").unsqueeze(0)

The video now has a much lower resolution:

In [None]:
# Confirm the reduced spatial size of the downsampled clip.
video_interp.size()

Again, let's track points in both directions. This will only take a couple of minutes:

In [None]:
# Dense mode: no grid_size is passed, so (per the text above) every pixel of the
# downsampled clip is tracked, in both directions. Use the grid_query_frame
# variable instead of a hard-coded 20 so the query frame always matches the
# one given to the visualiser in the next cell.
pred_tracks, pred_visibility = model(video_interp, grid_query_frame=grid_query_frame, backward_tracking=True)

Visualization with an optical flow color encoding:

In [None]:
# Colour the dense tracks with an optical-flow encoding and save the result.
vis = Visualizer(mode='optical_flow', linewidth=1, pad_value=20, save_dir='./videos')
vis.visualize(
    video=video_interp,
    tracks=pred_tracks,
    visibility=pred_visibility,
    filename='dense',
    query_frame=grid_query_frame,
)

In [None]:
# Play the dense-track visualisation.
show_video(video_path="./videos/dense_pred_track.mp4")

That's all, now you can use CoTracker in your projects!