In [2]:
import os
import torch

from base64 import b64encode
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from IPython.display import HTML

filename = 'Triathlon_Women_Tokyo_2020_29'

video = read_video_from_path('./videos/Triathlon_Women_Tokyo_2020/' + filename + '.mp4')
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()



In [3]:

import torch 
use_cuda = torch. cuda. is_available()

print(torch.__version__, torch.cuda.is_available())

2.1.1+cu121 True


Import CoTrackerPredictor and create an instance of it. We'll use this object to estimate tracks:

In [4]:
from cotracker.predictor import CoTrackerPredictor

#model = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")

model = CoTrackerPredictor(
    checkpoint=os.path.join(
        './co-tracker/checkpoints/cotracker_stride_4_wind_8.pth'
    )
)

if torch.cuda.is_available():
    model = model.cuda()
    video = video.cuda()

Tracking manually selected points

In [5]:
import os, argparse, json, re
from collections import defaultdict


def load_annotations(file_path):
    frames = defaultdict(list)
    isExist = os.path.exists(file_path)
    if isExist:
      annotations = []
      with open(file_path, 'r') as f:
          annotations = json.load(f)
          
          keypoints = defaultdict(list)
          for person in annotations['annotations']:
            framecount = len(person['frames'])
            for frame_index in range(0, framecount):
              points = []
              if frame_index < framecount:
                frame = person['frames'][str(frame_index)]
                for node in frame['skeleton']['nodes']:
                    points.append({'id': node['name'], 'x' : node['x'], 'y': node['y']})

              if len(points) > 0:
                keypoints[frame_index].append({'person': {'points': points}})

          for frame in range(0, len(keypoints)):
            frames[frame] = keypoints[frame]

    return frames


def get_queries_for_frame(frame_number, annotations):
  i = 0

  for person in annotations[frame_number]:
    for point in person['person']['points']:
      new_tensor = torch.tensor([float(frame_number), point['x'], point['y']])
      if i == 0:
        queries_for_frames = new_tensor
      else:
        queries_for_frames = torch.vstack((queries_for_frames, new_tensor))
      i += 1

  return queries_for_frames


def get_queries_for_frames(start_frame, end_frame, annotations):
  frame = 0
  queries_for_frames = None
  for frame in range(start_frame, end_frame):
    for person in annotations[frame]:
      for point in person['person']['points']:
        new_tensor = torch.tensor([float(frame), point['x'], point['y']])
        if queries_for_frames is None:
          queries_for_frames = new_tensor
        else:
          queries_for_frames = torch.vstack((queries_for_frames, new_tensor))

  return queries_for_frames

In [6]:
grid_query_frame = 0

annot_json = 'videos/annotations/' + filename + '.json'
annotations = load_annotations(annot_json)

queries = get_queries_for_frames(0, 5, annotations)

if torch.cuda.is_available():
    queries = queries.cuda()

In [7]:
import pandas as pd
df_queries = pd.DataFrame(queries.cpu().numpy())
mask_frame = df_queries[0] == 1.0
df = df_queries.loc[mask_frame]
df

Unnamed: 0,0,1,2
26,1.0,562.609985,385.279999
27,1.0,786.75,162.720001
28,1.0,759.710022,148.389999
29,1.0,742.219971,165.919998
30,1.0,740.780029,301.029999
31,1.0,736.929993,109.389999
32,1.0,715.080017,1003.700012
33,1.0,725.549988,147.860001
34,1.0,639.159973,276.549988
35,1.0,863.669983,281.0


In [8]:
pred_tracks, pred_visibility = model(video, queries=queries[None])


In [19]:
import PoseEstimation.customutils as customutils

df_running_annotations = customutils.load_images_dataframe()
df_running_annotations

Unnamed: 0,image_id,file_name,frame
0,1000,Athletics_Mixed_Tokyo_2020_20_1.mp4,0
1,1001,Athletics_Mixed_Tokyo_2020_20_1.mp4,1
2,1002,Athletics_Mixed_Tokyo_2020_20_1.mp4,2
3,1003,Athletics_Mixed_Tokyo_2020_20_1.mp4,3
4,1004,Athletics_Mixed_Tokyo_2020_20_1.mp4,4
...,...,...,...
5586,46160,World_Athletics_Women_Marathon_Oregon_2022_8.mp4,160
5587,46161,World_Athletics_Women_Marathon_Oregon_2022_8.mp4,161
5588,46162,World_Athletics_Women_Marathon_Oregon_2022_8.mp4,162
5589,46163,World_Athletics_Women_Marathon_Oregon_2022_8.mp4,163


In [21]:
mask_file = df_running_annotations['file_name'] == filename + '.mp4'
mask_frame = df_running_annotations['frame'] == 0
image_id = df_running_annotations.loc[mask_file & mask_frame]['image_id'].iloc[0]
image_id

16000

In [25]:


num_frames = pred_tracks.cpu().shape[1]
np_pred = pred_tracks.cpu().numpy()
df = pd.DataFrame(index=range(num_frames), columns=[str(x) for x in range(num_frames)])
frames = []
for i in range(num_frames):
    keypoints = []
    for j in range(26):
        keypoints.append(np_pred[0][i][j][0].astype(float))
        keypoints.append(np_pred[0][i][j][1].astype(float))
        keypoints.append(2)

    frame = {}
    frame['image_id'] = int(image_id) + i
    frame['category_id'] = 1
    frame['keypoints'] = keypoints
    frames.append(frame)

customutils.writeJson(frames,'videos/results/cotracker/' + filename + '.json')

In [47]:

vis = Visualizer(
    save_dir='./videos',
    linewidth=6,
    mode='cool',
    tracks_leave_trace=0
)
vis.visualize(
    video=video,
    tracks=pred_tracks,
    visibility=pred_visibility,
    filename='queries_5',
    query_frame=grid_query_frame)

Video saved to ./videos/queries_5_pred_track.mp4


tensor([[[[[145, 173, 173,  ..., 149, 149, 121],
           [174, 203, 203,  ..., 180, 180, 152],
           [178, 206, 207,  ..., 180, 180, 152],
           ...,
           [180, 215, 216,  ..., 122, 121, 102],
           [179, 214, 215,  ..., 121, 120, 101],
           [151, 186, 187,  ..., 102, 101,  82]],

          [[145, 173, 173,  ..., 157, 157, 129],
           [174, 203, 203,  ..., 188, 188, 160],
           [178, 206, 207,  ..., 186, 186, 158],
           ...,
           [170, 205, 206,  ..., 147, 145, 126],
           [169, 204, 205,  ..., 146, 144, 125],
           [141, 176, 177,  ..., 127, 125, 106]],

          [[142, 170, 170,  ..., 166, 166, 138],
           [171, 200, 200,  ..., 197, 197, 169],
           [175, 203, 204,  ..., 189, 189, 161],
           ...,
           [171, 206, 207,  ..., 120, 121, 102],
           [170, 205, 206,  ..., 119, 120, 101],
           [142, 177, 178,  ..., 100, 101,  82]]],


         [[[145, 173, 173,  ..., 149, 149, 121],
           [1