In [17]:
import torch, cv2
import numpy as np
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from IPython.display import HTML, display
from base64 import b64encode
from cotracker.predictor import CoTrackerPredictor
import os
import matplotlib.pyplot as plt
import json, subprocess, labelme
import PIL.Image as Image
from labelme import utils

torch.cuda.empty_cache()
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
model = CoTrackerPredictor(
    checkpoint=os.path.join(
        './checkpoints/scaled_offline.pth'
    )
)
video = read_video_from_path("videos/video2.mp4")
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()

if torch.cuda.is_available():
    model = model.cuda()
    video = video.cuda()

mask_laparoscope = np.load('c:/github/occlusions/ischemia_models_pigs/laparoscope_masks/pig1_bowel_baseline.npy')

# Select points and track

In [None]:
video = video[:,0:300,:,:,:]

#select a point, and the model will track it
def select_points(frame):
    frame = frame.transpose(1, 2, 0)  # Convert from (C, H, W) to (H, W, C)
    frame = frame.astype(np.uint8)    # Convert to uint8 data type
    # RGB to BGR
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    cv2.imshow('frame', frame)
    selected_points = []

    # Callback to capture the selected points
    def mouse_callback(event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            cv2.circle(frame, (x, y), 5, (0, 255, 0), -1)
            selected_points.append((x, y))

    cv2.setMouseCallback('frame', mouse_callback)

    while True:
        cv2.imshow('frame', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):  # Exit on pressing 'q'
            break

    cv2.destroyAllWindows()
    queries = []
    for point in selected_points:
        queries.append([0.0, point[0], point[1]])  # Format: [0., y, x]
    
    queries = torch.tensor(queries)
    if torch.cuda.is_available():
        queries = queries.cuda()
    return queries

queries = select_points(video[0][0].cpu().numpy())

chunk_size = 50  # Number of frames per chunk
pred_tracks = None
pred_visibility = None
for i in range(0, video.shape[1], chunk_size):

    if i == 0:
        queries = queries
    else:
        queries = chunk_pred_tracks[0, -1,:,:]
        queries = torch.cat([torch.tensor([[0]], device=queries.device), queries], dim=1)
    video_chunk = video[:, i:i+chunk_size, :, :]

    # Pad the last chunk if it has fewer frames
    if video_chunk.shape[1] < chunk_size:
        padding = (0, 0, 0, 0, 0, chunk_size - video_chunk.shape[1])  # Pad along dim=1
        video_chunk = F.pad(video_chunk, padding, mode='constant', value=0)
        chunk_pred_tracks = F.pad(chunk_pred_tracks, (0, 0, 0, 0, 0, chunk_size - chunk_pred_tracks.shape[1]), mode='constant', value=0)
        chunk_pred_visibility = F.pad(chunk_pred_visibility, (0, chunk_size - chunk_pred_visibility.shape[1]), mode='constant', value=0)
    else:
        chunk_pred_tracks, chunk_pred_visibility = model(video_chunk, queries=queries[None])

    if pred_tracks is None:
        pred_tracks = chunk_pred_tracks
        pred_visibility = chunk_pred_visibility
    else:
        pred_tracks = torch.cat([pred_tracks, chunk_pred_tracks], dim=1)
        pred_visibility = torch.cat([pred_visibility, chunk_pred_visibility], dim=1)

pred_tracks = pred_tracks[:,:video.shape[1], :, :]
pred_visibility = pred_visibility[:, :video.shape[1],:]

vis = Visualizer(
    save_dir='./videos',
    linewidth=3,
    mode='cool',
    tracks_leave_trace=-1
)
vis.visualize(
    video=video,
    tracks=pred_tracks,
    visibility=pred_visibility,
    filename='queries')

torch.Size([1, 100, 1, 2])
torch.Size([1, 150, 1, 2])
torch.Size([1, 200, 1, 2])
torch.Size([1, 250, 1, 2])
torch.Size([1, 300, 1, 2])


IndexError: too many indices for tensor of dimension 3

In [27]:
pred_visibility.shape

torch.Size([1, 300, 1])

# Track a grid of points along with a segmentation mask

In [None]:
#get mask
def mask_using_labelme(rgb):

    #delete previous files
    if os.path.exists('image_to_annotate.jpg'):
        os.remove('image_to_annotate.jpg')
    if os.path.exists('labelme.json'):
        os.remove('labelme.json')
')
    
    image = Image.fromarray(rgb)
    image.save('image_to_annotate.jpg')
    subprocess.run(['labelme', 'image_to_annotate.jpg'])

    # run labelme in terminal and annotate. Save annotations to json file
    with open('labelme.json') as f:
        data = json.load(f)

    img = utils.img_b64_to_arr(data['imageData'])

    label_name_to_value = {shape['label']: i + 1 for i, shape in enumerate(data['shapes'])}
    lbl, _ = utils.shapes_to_label(img.shape, data['shapes'], label_name_to_value=label_name_to_value)
    return lbl.astype(np.uint8)

mask = mask_using_labelme(video[0][0].cpu().numpy().transpose(1, 2, 0).astype(np.uint8))  
plt.imshow((mask[...,None]*video[0,0].permute(1,2,0).cpu().numpy()/255.)) 

In [None]:
grid_size = 30
pred_tracks, pred_visibility = model(video, grid_size=grid_size, segm_mask=torch.from_numpy(mask)[None, None])
vis = Visualizer(
    save_dir='./videos',
    pad_value=100,
    linewidth=2,
)
vis.visualize(
    video=video,
    tracks=pred_tracks,
    visibility=pred_visibility,
    filename='segm_grid')