In [None]:
### Notebook for extracting and annotating clips from Cholec80

### 1. Find frame ranges displaying the same prompt in CholecT50
### 2. Read Cholec80 video (same video but higher FPS than CholecT50)
### 3. Create tool segmentation masks with finetuned YOLOv8 model
### 4. Save resulting mask videos for later use with ControlVideo

In [None]:
import json
import cv2
import os
from PIL import Image
from ultralytics import YOLO
import torch
import random
import numpy as np

def find_equal_subarray_ranges(arr, min_length=10, max_length=10):
    """Locate runs of identical consecutive values in ``arr``.

    The sequence is scanned left to right. A run ends when the value changes
    or when it has reached ``max_length`` elements; in the latter case the
    next chunk starts immediately afterwards, so one long run is split into
    several consecutive chunks.

    Args:
        arr: Indexable sequence of values comparable with ``==``.
        min_length: Chunks shorter than this are discarded. Chunks of a
            single element are never reported, whatever this value is.
        max_length: Maximum number of elements per chunk. ``None`` or a
            negative number disables the cap.

    Returns:
        A list of ``(value, (start, end))`` tuples where ``start`` and ``end``
        are inclusive indices into ``arr``. Chunks whose value equals ``-1``
        are skipped.

    Raises:
        ValueError: If ``max_length`` is 0. A zero-sized cap can never consume
            an element, which previously made the scan loop forever.
    """
    if max_length is not None and max_length == 0:
        raise ValueError("max_length must not be 0")
    capped = max_length is not None and max_length > 0

    ranges = []
    total = len(arr)
    start = 0
    while start < total:
        value = arr[start]
        # The element at `start` always belongs to the current chunk.
        end = start + 1
        while (end < total and value == arr[end]
               and not (capped and end - start >= max_length)):
            end += 1

        run_length = end - start
        if run_length > 1 and value != -1 and run_length >= min_length:
            ranges.append((value, (start, end - 1)))
        start = end

    return ranges



label_file = "" # path to a CholecT50 label file


def _build_frame_prompt(frame_annotations, categories):
    """Compose the text prompt for one annotated frame.

    Entries of ``frame_annotations`` are read positionally: element 0 is
    looked up in ``categories['triplet']``, element 1 in
    ``categories['instrument']`` (``-1`` means no instrument) and element 14
    of the first entry in ``categories['phase']``.

    Frames with several entries are described by their triplet texts joined
    with " and ". Frames with a single entry are described by the lower-cased
    instrument name, or by the phase alone when no instrument is present.
    """
    phase_text = categories["phase"][str(frame_annotations[0][14])]

    if len(frame_annotations) > 1:
        triplet_texts = [categories['triplet'][str(entry[0])]
                         for entry in frame_annotations]
        annotation_text = " and ".join(triplet_texts) + " in " + phase_text
    else:
        # NOTE(review): a lone entry is reduced to its instrument name rather
        # than its triplet text - confirm this is intended.
        tool_texts = [categories['instrument'][str(entry[1])].lower()
                      for entry in frame_annotations if entry[1] != -1]
        if tool_texts:
            annotation_text = " and ".join(tool_texts) + " in " + phase_text
        else:
            annotation_text = phase_text

    # Strip placeholder tokens and commas, then collapse repeated whitespace.
    annotation_text = annotation_text.replace("null_verb", "")
    annotation_text = annotation_text.replace("null_target", "")
    annotation_text = annotation_text.replace(",", " ")
    return " ".join(annotation_text.split())


# The file only needs to stay open while the JSON is parsed.
with open(label_file) as f:
    label_dict = json.load(f)

annotations = label_dict["annotations"]
categories = label_dict["categories"]

# Frames are fetched by their stringified index so that position k of
# prompt_list always describes frame k, whatever the key order in the file.
prompt_list = [
    _build_frame_prompt(annotations[str(frame_idx)], categories)
    for frame_idx in range(len(annotations))
]

frame_ranges = find_equal_subarray_ranges(prompt_list)
frame_ranges


In [None]:
import random

# random.sample raises ValueError when asked for more items than exist, so the
# request is capped at the number of ranges that were actually found.
num_sampled_ranges = min(50, len(frame_ranges))
frame_ranges = random.sample(frame_ranges, num_sampled_ranges)

In [None]:
clips = []
mask_clips = []
cholec80_path = "" # path to Cholec80 data
YOLO_PATH = "" # path to yolov8 model weights
# Label indices are scaled by this factor to address frames of the video.
FRAMES_PER_LABEL = 25
FRAME_SIZE = (128, 128)


def _drop_random_foreground_pixel(mask_image):
    """Return a copy of ``mask_image`` with one random non-zero pixel set to 0.

    The input image is left untouched so masks that were already stored are
    not modified. An image without any non-zero pixel is returned as is.
    """
    mask_array = np.array(mask_image)
    foreground = np.argwhere(mask_array > 0)
    if len(foreground) == 0:
        return mask_image
    row, col = foreground[random.randrange(len(foreground))]
    mask_array[row, col] = 0
    return Image.fromarray(mask_array)


# Load the weights from the configured path instead of an empty literal.
model = YOLO(YOLO_PATH)
video_path = os.path.join(cholec80_path, "videos/", "video{}.mp4".format("01"))
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise IOError(f"Could not open video: {video_path}")

prompt_list = []
try:
    for prompt, (first_label, last_label) in frame_ranges:
        print(prompt)
        prompt_list.append(prompt)
        first_frame = first_label * FRAMES_PER_LABEL
        stop_frame = (last_label + 1) * FRAMES_PER_LABEL
        clip = []
        masks = []
        # Fallback used until the first detection inside this clip.
        prev_mask = Image.fromarray(np.zeros(FRAME_SIZE)).convert(mode="L")
        for frame_idx in range(first_frame, stop_frame):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            res, frame = cap.read()
            if not res:
                # Frames that cannot be read are skipped.
                continue

            frame = cv2.resize(frame, FRAME_SIZE)
            # NOTE(review): the frame is flipped with flip code 0 while the
            # mask below is mirrored left-right - confirm both are intended.
            frame = cv2.flip(frame, 0)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)

            mask_pred = model(frame, verbose=False)
            if mask_pred[0].masks:
                # Merge all predicted instance masks into one 0/255 mask.
                mask = (torch.any(mask_pred[0].masks.data, dim=0).int() * 255).cpu().numpy()
                mask = Image.fromarray(mask).convert(mode="L")
                mask = mask.resize(FRAME_SIZE)
                mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
                prev_mask = mask
            else:
                # No detection: reuse the last detected mask with a single
                # random foreground pixel cleared.
                mask = _drop_random_foreground_pixel(prev_mask)
            clip.append(frame)
            masks.append(mask)
        clips.append(clip)
        mask_clips.append(masks)
finally:
    # Free the decoder even if inference or decoding fails midway.
    cap.release()


In [None]:
# Stack the mask frames of every clip into one array and show its dimensions.
np_mask = np.asarray(mask_clips)
np_mask.shape

In [None]:
# Output settings for the mask videos.
size = (128,128)
fps = 25
dest_path = "./"

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# The writer is given the frame size as (size[1], size[0]).
frame_size = (size[1], size[0])

# Write one single-channel video per clip of masks.
for idx, clip_masks in enumerate(np_mask):
    video_file = os.path.join(dest_path, f'output_video{str(idx)}.mp4')
    out = cv2.VideoWriter(video_file, fourcc, fps, frame_size, False)
    for mask_frame in clip_masks:
        out.write(mask_frame)
    out.release()
