In [1]:
import argparse
import gdown
import cv2
import numpy as np
import os
import sys
sys.path.append(sys.path[0]+"/tracker")
sys.path.append(sys.path[0]+"/tracker/model")
from track_anything import TrackingAnything
from track_anything import parse_augment
import requests
import json
import torchvision
import torch 
from tools.painter import mask_painter
import psutil
import time
try: 
    from mmcv.cnn import ConvModule
except ImportError:
    # mmcv is not importable: try to install it with OpenMIM. Only ImportError
    # is caught so unrelated failures (e.g. KeyboardInterrupt) are not swallowed
    # the way the previous bare `except:` did.
    if os.system("mim install mmcv") != 0:
        print("WARNING: 'mim install mmcv' failed; install mmcv manually.")
import matplotlib.pyplot as plt
from pycocotools import mask as maskUtils


In [2]:


# Root folder of the OVIS dataset. Set the OVIS_ROOT environment variable
# (e.g. OVIS_ROOT=../data.nosync/OVIS) instead of editing this cell; the default
# is the location this cell previously ended up using.
_ovis_root = os.environ.get('OVIS_ROOT', 'D:/HADA/data/OVIS').rstrip('/')

# NOTE: the misspelled name `ovis_anotations` is kept because later cells use it.
ovis_anotations = _ovis_root + '/annotations/'
ovis_images = _ovis_root + '/train_images/'

In [3]:
def cargarDatos(ruta_ann):
    """Load the train / valid / test annotation JSON files found under ``ruta_ann``.

    ``ruta_ann`` is concatenated directly with each file name, so it must end
    with a path separator.

    :param ruta_ann: folder prefix of the annotation files
    :return: tuple ``(clases, vidTrain, annTrain, vidValid, annValid, vidTest, annTest)``;
        ``clases`` is the ``'categories'`` entry of the train file, the others are
        the ``'videos'`` and ``'annotations'`` entries of each split
    """
    file_names = ('annotations_train.json', 'annotations_valid.json', 'annotations_test.json')

    # Parse the three files first, in train / valid / test order.
    parsed = []
    for file_name in file_names:
        with open(ruta_ann + file_name) as handle:
            parsed.append(json.load(handle))

    # Only the train file contributes the category list.
    collected = [parsed[0]['categories']]
    for content in parsed:
        collected.append(content['videos'])
        collected.append(content['annotations'])
    return tuple(collected)

# Load categories plus the videos/annotations of every split into notebook globals;
# several helper functions below read vidTrain / annTrain directly.
clases, vidTrain, annTrain, vidValid, annValid, vidTest, annTest = cargarDatos(ovis_anotations) 

In [4]:
def annToRLE(ann, frameId):
    """
    Convert the segmentation of one frame, which can be polygons, uncompressed
    RLE or compressed RLE, to compressed RLE.

    :param ann: annotation dict; reads 'height', 'width' and 'segmentations'
    :param frameId: index into ann['segmentations']
    :return: RLE object (input of maskUtils.decode), or None when the frame
             has no segmentation
    """
    h, w = ann['height'], ann['width']
    segm = ann['segmentations'][frameId]
    if segm is None:
        return None
    # isinstance is required here: comparing type(...) with the string "list"
    # is always False, which sent every input down the "already RLE" branch.
    if isinstance(segm, list):
        # polygon -- a single object might consist of multiple parts
        # we merge all parts into one mask rle code
        rles = maskUtils.frPyObjects(segm, h, w)
        rle = maskUtils.merge(rles)
    elif isinstance(segm['counts'], list):
        # uncompressed RLE
        rle = maskUtils.frPyObjects(segm, h, w)
    else:
        # rle
        rle = segm
    return rle


def annToMask(ann, frameId):
    """
    Decode the segmentation of frame ``frameId`` of ``ann`` into a mask.

    :return: result of maskUtils.decode on the frame's RLE, or None when
             annToRLE reports no segmentation for that frame
    """
    encoded = annToRLE(ann, frameId)
    if encoded is None:
        return None
    return maskUtils.decode(encoded)



def combineMasks(masks, width, height):
    """Merge several (height, width) masks into a single 0/1 mask.

    The masks are summed into a uint8 accumulator; every pixel whose sum is
    greater than zero becomes 1, all others 0.
    """
    # Empty accumulator for the combined mask
    accumulator = np.zeros((height, width), dtype=np.uint8)
    for layer in masks:
        accumulator += layer

    # Threshold the sum to obtain one binary mask
    return np.where(accumulator > 0, 1, 0)

def unifyMasks(masks, width, height):
    """Sum several (height, width) masks into one uint8 mask.

    Unlike combineMasks no threshold is applied, so the pixel values of the
    inputs (e.g. per-object labels) are kept. Where inputs overlap, their
    values are added together.
    """
    # Empty accumulator for the unified mask
    label_map = np.zeros((height, width), dtype=np.uint8)
    for layer in masks:
        label_map += layer
    return label_map

In [5]:
def load_images_from_folder(path,image_files):
    """Read ``image_files`` (names relative to ``path``) and return them as RGB images.

    :param path: folder that contains the images
    :param image_files: iterable of file names, joined onto ``path``
    :return: list with one RGB image per file, in input order
    :raises FileNotFoundError: when OpenCV cannot read one of the files
    """
    images = []
    for file in image_files:
        full_path = os.path.join(path, file)
        img = cv2.imread(full_path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None;
            # fail here with the offending path instead of inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {full_path}")
        # OpenCV loads images as BGR; convert to RGB
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return images

def load_all_initial_masks_from_dataset():
    """Build the labelled first-frame mask of every video in the global ``vidTrain``.

    For each video, the annotations in the global ``annTrain`` with a matching
    'video_id' are decoded at frame 0. The i-th annotation (0-based) is
    multiplied by ``i + 1`` and the results are merged with unifyMasks.
    Annotations without a frame-0 segmentation are skipped.

    :return: list with one unified mask per video, in ``vidTrain`` order
    """
    all_masks = []
    for video in vidTrain:
        ann = [a for a in annTrain if a['video_id'] == video['id']]
        masks = []
        for i, a in enumerate(ann):
            # Decode once per annotation (previously decoded twice: once for
            # the None check and once for the value).
            first_mask = annToMask(a, 0)
            if first_mask is not None:
                masks.append(first_mask * (i + 1))
        all_masks.append(unifyMasks(masks, video['width'], video['height']))
    return all_masks

def load_all_masks_for_video(video):
    """Return one unified labelled mask per frame of ``video``.

    Uses the annotations in the global ``annTrain`` whose 'video_id' equals
    ``video['id']``. In every frame the i-th annotation (0-based) is multiplied
    by ``i + 1``; annotations with no segmentation in that frame are skipped.

    :param video: video dict; reads 'id', 'length', 'width' and 'height'
    :return: list of ``video['length']`` masks produced by unifyMasks
    """
    video_annotations = [a for a in annTrain if a['video_id'] == video['id']]
    per_frame = []
    for frame_idx in range(video['length']):
        decoded = (
            (label, annToMask(a, frame_idx))
            for label, a in enumerate(video_annotations, start=1)
        )
        layers = [frame_mask * label for label, frame_mask in decoded if frame_mask is not None]
        per_frame.append(unifyMasks(layers, video['width'], video['height']))
    return per_frame

def generate_video_from_frames(frames, output_path, fps=30):
    """Write ``frames`` to ``output_path`` with torchvision, encoded as libx264.

    :param frames: sequence of frames; converted with np.asarray and torch.from_numpy
    :param output_path: destination file; its parent folder is created if missing
    :param fps: frames per second passed to torchvision.io.write_video
    :return: ``output_path``
    """
    frames = torch.from_numpy(np.asarray(frames))
    out_dir = os.path.dirname(output_path)
    # dirname is '' for a bare file name; os.makedirs('') would raise, so only
    # create a folder when there is one. exist_ok avoids a check-then-create race.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
    return output_path

In [6]:
def calculate_iou(mask1, mask2):
    """Compute the IoU between two label masks, overall and per label.

    Every value present in either mask (including 0) is treated as a label.

    :param mask1: label array
    :param mask2: label array with the same shape as ``mask1``
    :return: ``(iou, iou_per_label)`` where ``iou`` is the summed intersection
             over the summed union across all labels and ``iou_per_label``
             maps each label to its own IoU
    """
    # Both masks must describe the same pixel grid
    assert mask1.shape == mask2.shape, "Mask shapes must be the same."

    iou_per_label = {}
    intersection = np.zeros_like(mask1, dtype=np.float32)
    union = np.zeros_like(mask1, dtype=np.float32)

    for label in np.unique(np.concatenate((mask1, mask2))):
        in_first = mask1 == label
        in_second = mask2 == label
        both = np.logical_and(in_first, in_second)
        either = np.logical_or(in_first, in_second)
        # Accumulate per-pixel hits so the overall IoU covers all labels
        intersection += both
        union += either
        iou_per_label[label] = both.sum() / either.sum()

    # Overall IoU
    return intersection.sum() / union.sum(), iou_per_label

def compute_f_measure(mask1, mask2):
    """Compute the F-measure (F1) of every label found in either mask.

    ``mask1`` is treated as the prediction and ``mask2`` as the reference:
    pixels labelled in ``mask1`` only count as false positives, pixels labelled
    in ``mask2`` only count as false negatives.

    :param mask1: label array
    :param mask2: label array with the same shape as ``mask1``
    :return: dict mapping each label to its F-measure (0.0 when there is no overlap)
    """
    # Ensure both masks have the same shape
    assert mask1.shape == mask2.shape, "Mask shapes must be the same."

    # Calculate F-measure for each label
    labels = np.unique(np.concatenate((mask1, mask2)))
    f_measure_per_label = {}

    for label in labels:
        mask1_label = mask1 == label
        mask2_label = mask2 == label

        true_positives = np.logical_and(mask1_label, mask2_label).sum()
        false_positives = np.logical_and(mask1_label, np.logical_not(mask2_label)).sum()
        false_negatives = np.logical_and(np.logical_not(mask1_label), mask2_label).sum()

        # F1 = 2PR / (P + R) = 2TP / (2TP + FP + FN). The second form is used
        # because precision/recall are 0/0 (NaN) whenever the label is missing
        # from one mask. The denominator is never zero here: every label comes
        # from np.unique over both masks, so it occurs in at least one pixel.
        f_measure = 2 * true_positives / (2 * true_positives + false_positives + false_negatives)
        f_measure_per_label[label] = f_measure

    return f_measure_per_label

In [7]:
def run_model_on_ovis_set(name, model,path_set,videos, annotations, compute_metrics = False,save_masks = False, compute_video = False, verbose = True):
    """Track the annotated objects of each video with ``model`` and optionally
    print metrics, save the masks and render a video.

    :param name: suffix for output names; outputs are called 'Video<k>' + name,
        where k is the 1-based position of the video in ``videos``
    :param model: tracker exposing ``generator(images=..., template_mask=...)``
        and ``xmem.clear_memory()``
    :param path_set: folder that contains the frames listed in video['file_names']
    :param videos: iterable of video dicts ('id', 'file_names', 'width', 'height', ...)
    :param annotations: annotation dicts, matched to each video by 'video_id'
    :param compute_metrics: print F-measure and IoU of every frame after the first
    :param save_masks: save each predicted mask under ./result/mask/<output name>/
    :param compute_video: write the painted frames to ./result/track/<output name>.mp4
    :param verbose: print progress messages
    :return: ``(masks, logits, painted_images)`` of the LAST video processed,
        or ``(None, None, None)`` when ``videos`` is empty
    """
    # Bound up-front so an empty `videos` no longer raises UnboundLocalError.
    masks = logits = painted_images = None

    for video_number, video in enumerate(videos, start=1):
        # The first (or only) video keeps the historical 'Video1' prefix; later
        # videos get their own number so outputs no longer overwrite each other.
        output_name = 'Video{}{}'.format(video_number, name)

        # Load all images as np.array
        if verbose: print('Loading dataset images')
        images = load_images_from_folder(path_set,video['file_names'])

        # Turn the first-frame annotations into one labelled mask
        if verbose: print('Creating first annotated mask for VOS model')
        ann = [a for a in annotations if a['video_id'] == video['id']]
        first_frame_masks = []
        for i, a in enumerate(ann):
            first_mask = annToMask(a, 0)
            # Objects absent in frame 0 decode to None; skip them instead of
            # failing on `None * (i + 1)`. Labels still follow the annotation
            # index, as in load_all_masks_for_video.
            if first_mask is not None:
                first_frame_masks.append(first_mask * (i + 1))
        initial_mask = unifyMasks(first_frame_masks, video['width'], video['height'])

        #Compute masks for all images
        if verbose:print('Computing all masks')
        model.xmem.clear_memory()
        masks, logits, painted_images = model.generator(images=images, template_mask=initial_mask)
        model.xmem.clear_memory()  

        if compute_metrics:
            if verbose: print('Computing Metrics')
            ground_truth_masks = load_all_masks_for_video(video)
            # Frame 0 is the given template, so it is excluded; the printed
            # index i therefore refers to frame i + 1.
            for i,(mask_infered, mask_gt) in enumerate(zip(masks[1:],ground_truth_masks[1:])):
                f_measure = compute_f_measure(mask_infered,mask_gt)
                iou, iou_per_label = calculate_iou(mask_infered,mask_gt)
                print(f'Mask {i}: f_mesure{f_measure}, iou {iou}, per label {iou_per_label}')
                
        if compute_video: 
            if verbose: print('Generating video')
            generate_video_from_frames(painted_images, output_path="./result/track/{}.mp4".format(output_name), fps = 30) 

        if save_masks:
            if verbose: print('Saving masks') 
            path_to_masks = './result/mask/{}'.format(output_name)
            os.makedirs(path_to_masks, exist_ok=True)
            for i,mask in enumerate(masks): np.save(os.path.join(path_to_masks, '{:05d}.npy'.format(i)), mask)
                
    return masks, logits, painted_images

In [8]:
# Checkpoint files of the three networks passed to TrackingAnything.
SAM_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
xmem_checkpoint = "./checkpoints/XMem-s012.pth"
e2fgvi_checkpoint = "./checkpoints/E2FGVI-HQ-CVPR22.pth"
# NOTE(review): the triple-quoted string below is disabled code, not a comment.
# Because of it `model`, `modelSamBbox`, `modelSamPoint` and `modelBoth` are NOT
# created by this cell, although later cells reference some of them.
''' args = {
    'use_refinement' : False
        }
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)


args = {
    'use_refinement' : True,
    'refinement_mode' : 'bbox'
         }
modelSamBbox = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)


args = {
   'use_refinement' : True,
   'refinement_mode' : 'point'
       }
modelSamPoint = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)



args = {
   'use_refinement' : True,
   'refinement_mode' : 'both'
       }
modelBoth = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)

'''
# The only model actually built here: refinement enabled with mode
# 'mask_bbox_pos_neg'.
args = {
   'use_refinement' : True,
   'refinement_mode' : 'mask_bbox_pos_neg'
       }
modelBothNeg = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)

Initializing BaseSegmenter to cuda:0
Hyperparameters read from the model weights: C^k=64, C^v=512, C^h=64
Single object mode: False
Sam Refinement ACTIVATED. Mode: mask_bbox_pos_neg


In [9]:
# Track training video 37 with modelBothNeg and render the result video.
# NOTE(review): the results are stored in *_refinement_bbox variables although
# the model passed here is modelBothNeg, not a bbox-only model.
masks_refinement_bbox, logits_refinement_bbox, painted_images_refinement_bbox = run_model_on_ovis_set(name = '_mask_pointsPN_lizard',model = modelBothNeg, path_set = ovis_images,videos = vidTrain[37:38],annotations = annTrain,compute_metrics = False, save_masks=False, compute_video=True)

Loading dataset images
Creating first annotated mask for VOS model
Computing all masks


Tracking image:   2%|▏         | 3/175 [00:03<02:40,  1.07it/s]

Correcting Point: 507,414
To: 503,418


Tracking image:   6%|▋         | 11/175 [00:06<00:59,  2.74it/s]

Correcting Point: 486,421
To: 487,431


Tracking image:   8%|▊         | 14/175 [00:07<00:57,  2.79it/s]

Correcting Point: 469,437
To: 464,436


Tracking image:  10%|▉         | 17/175 [00:08<00:55,  2.86it/s]

Correcting Point: 158,357
To: 156,357


Tracking image:  14%|█▎        | 24/175 [00:10<00:52,  2.90it/s]

Correcting Point: 168,363
To: 167,366


Tracking image:  14%|█▍        | 25/175 [00:11<00:51,  2.90it/s]

Correcting Point: 168,364
To: 168,366


Tracking image:  15%|█▍        | 26/175 [00:11<00:51,  2.89it/s]

Correcting Point: 167,363
To: 165,364


Tracking image:  15%|█▌        | 27/175 [00:11<00:51,  2.89it/s]

Correcting Point: 166,363
To: 166,365


Tracking image:  16%|█▌        | 28/175 [00:12<00:50,  2.90it/s]

Correcting Point: 167,366
To: 166,367


Tracking image:  17%|█▋        | 29/175 [00:12<00:50,  2.89it/s]

Correcting Point: 166,368
To: 164,368


Tracking image:  17%|█▋        | 30/175 [00:12<00:50,  2.88it/s]

Correcting Point: 165,370
To: 164,370


Tracking image:  21%|██        | 36/175 [00:14<00:48,  2.87it/s]

Correcting Point: 394,23
To: 392,24


Tracking image:  23%|██▎       | 40/175 [00:16<00:46,  2.89it/s]

Correcting Point: 303,317
To: 306,321


Tracking image:  23%|██▎       | 41/175 [00:16<00:46,  2.88it/s]

Correcting Point: 535,287
To: 535,290


Tracking image:  24%|██▍       | 42/175 [00:17<00:46,  2.87it/s]

Correcting Point: 303,330
To: 306,334
Correcting Point: 538,284
To: 539,285


Tracking image:  25%|██▍       | 43/175 [00:17<00:46,  2.85it/s]

Correcting Point: 296,328
To: 304,335
Correcting Point: 535,282
To: 535,284


Tracking image:  25%|██▌       | 44/175 [00:17<00:46,  2.81it/s]

Correcting Point: 301,321
To: 310,331
Correcting Point: 533,280
To: 533,282


Tracking image:  26%|██▌       | 45/175 [00:18<00:45,  2.83it/s]

Correcting Point: 302,321
To: 310,334
Correcting Point: 176,385
To: 176,386
Correcting Point: 532,276
To: 533,278


Tracking image:  26%|██▋       | 46/175 [00:18<00:45,  2.82it/s]

Correcting Point: 151,299
To: 147,298
Correcting Point: 528,275
To: 529,279


Tracking image:  27%|██▋       | 47/175 [00:18<00:45,  2.82it/s]

Correcting Point: 290,338
To: 301,357
Correcting Point: 525,282
To: 528,291


Tracking image:  27%|██▋       | 48/175 [00:19<00:44,  2.84it/s]

Correcting Point: 299,338
To: 310,356
Correcting Point: 530,286
To: 532,292


Tracking image:  28%|██▊       | 49/175 [00:19<00:44,  2.85it/s]

Correcting Point: 292,338
To: 304,361
Correcting Point: 178,389
To: 176,389
Correcting Point: 520,283
To: 525,294


Tracking image:  29%|██▊       | 50/175 [00:19<00:43,  2.85it/s]

Correcting Point: 304,337
To: 314,358
Correcting Point: 520,283
To: 524,294
Correcting Point: 384,274
To: 384,273


Tracking image:  29%|██▉       | 51/175 [00:20<00:43,  2.84it/s]

Correcting Point: 304,337
To: 311,357
Correcting Point: 520,283
To: 523,294


Tracking image:  30%|██▉       | 52/175 [00:20<00:43,  2.84it/s]

Correcting Point: 241,361
To: 242,375
Correcting Point: 524,281
To: 528,292


Tracking image:  30%|███       | 53/175 [00:20<00:42,  2.84it/s]

Correcting Point: 246,360
To: 249,373


Tracking image:  31%|███       | 54/175 [00:21<00:42,  2.83it/s]

Correcting Point: 348,283
To: 347,283
Correcting Point: 141,323
To: 136,323
Correcting Point: 523,263
To: 523,256


Tracking image:  33%|███▎      | 57/175 [00:22<00:41,  2.86it/s]

Correcting Point: 550,284
To: 550,285


Tracking image:  33%|███▎      | 58/175 [00:22<00:40,  2.87it/s]

Correcting Point: 537,281
To: 540,290


Tracking image:  34%|███▎      | 59/175 [00:23<00:40,  2.87it/s]

Correcting Point: 534,281
To: 537,290


Tracking image:  34%|███▍      | 60/175 [00:23<00:39,  2.88it/s]

Correcting Point: 531,276
To: 536,286


Tracking image:  35%|███▍      | 61/175 [00:23<00:40,  2.81it/s]

Correcting Point: 530,275
To: 537,283


Tracking image:  35%|███▌      | 62/175 [00:24<00:40,  2.82it/s]

Correcting Point: 524,279
To: 527,287


Tracking image:  36%|███▌      | 63/175 [00:24<00:39,  2.85it/s]

Correcting Point: 522,280
To: 525,287


Tracking image:  37%|███▋      | 64/175 [00:24<00:38,  2.86it/s]

Correcting Point: 217,444
To: 216,445
Correcting Point: 525,281
To: 528,286


Tracking image:  37%|███▋      | 65/175 [00:25<00:38,  2.85it/s]

Correcting Point: 520,282
To: 521,286


Tracking image:  38%|███▊      | 66/175 [00:25<00:38,  2.83it/s]

Correcting Point: 298,283
To: 298,282
Correcting Point: 516,283
To: 519,285


Tracking image:  42%|████▏     | 74/175 [00:28<00:34,  2.89it/s]

Correcting Point: 295,406
To: 295,410


Tracking image:  43%|████▎     | 75/175 [00:28<00:34,  2.88it/s]

Correcting Point: 291,403
To: 286,403


Tracking image:  43%|████▎     | 76/175 [00:28<00:34,  2.85it/s]

Correcting Point: 287,403
To: 285,403


Tracking image:  45%|████▍     | 78/175 [00:29<00:34,  2.85it/s]

Correcting Point: 357,333
To: 356,333


Tracking image:  45%|████▌     | 79/175 [00:30<00:33,  2.85it/s]

Correcting Point: 280,410
To: 276,414
Correcting Point: 165,305
To: 157,306


Tracking image:  46%|████▌     | 80/175 [00:30<00:33,  2.88it/s]

Correcting Point: 280,401
To: 277,403


Tracking image:  46%|████▋     | 81/175 [00:30<00:32,  2.85it/s]

Correcting Point: 287,395
To: 278,386
Correcting Point: 361,328
To: 361,332


Tracking image:  47%|████▋     | 82/175 [00:31<00:32,  2.88it/s]

Correcting Point: 468,351
To: 468,357


Tracking image:  47%|████▋     | 83/175 [00:31<00:32,  2.87it/s]

Correcting Point: 498,331
To: 499,331
Correcting Point: 423,367
To: 425,370
Correcting Point: 304,307
To: 305,307
Correcting Point: 255,250
To: 251,246


Tracking image:  48%|████▊     | 84/175 [00:31<00:31,  2.90it/s]

Correcting Point: 220,315
To: 221,315


Tracking image:  49%|████▊     | 85/175 [00:32<00:30,  2.92it/s]

Correcting Point: 351,401
To: 350,386


Tracking image:  49%|████▉     | 86/175 [00:32<00:30,  2.90it/s]

Correcting Point: 356,398
To: 356,386


Tracking image:  50%|████▉     | 87/175 [00:32<00:30,  2.91it/s]

Correcting Point: 352,401
To: 352,390


Tracking image:  50%|█████     | 88/175 [00:33<00:30,  2.89it/s]

Correcting Point: 351,401
To: 351,392
Correcting Point: 357,244
To: 358,240
Correcting Point: 212,314
To: 214,314


Tracking image:  52%|█████▏    | 91/175 [00:34<00:28,  2.91it/s]

Correcting Point: 236,307
To: 240,309


Tracking image:  53%|█████▎    | 92/175 [00:34<00:28,  2.92it/s]

Correcting Point: 351,416
To: 351,414
Correcting Point: 217,307
To: 207,306


Tracking image:  53%|█████▎    | 93/175 [00:34<00:28,  2.90it/s]

Correcting Point: 350,416
To: 350,415
Correcting Point: 222,304
To: 208,300


Tracking image:  54%|█████▎    | 94/175 [00:35<00:27,  2.92it/s]

Correcting Point: 342,408
To: 342,407


Tracking image:  54%|█████▍    | 95/175 [00:35<00:27,  2.91it/s]

Correcting Point: 212,294
To: 210,292


Tracking image:  55%|█████▍    | 96/175 [00:35<00:27,  2.91it/s]

Correcting Point: 235,299
To: 244,304


Tracking image:  55%|█████▌    | 97/175 [00:36<00:26,  2.92it/s]

Correcting Point: 229,298
To: 216,295


Tracking image:  57%|█████▋    | 99/175 [00:36<00:26,  2.91it/s]

Correcting Point: 418,405
To: 418,403


Tracking image:  58%|█████▊    | 101/175 [00:37<00:25,  2.89it/s]

Correcting Point: 219,307
To: 215,307


Tracking image:  58%|█████▊    | 102/175 [00:37<00:25,  2.88it/s]

Correcting Point: 219,302
To: 216,302


Tracking image:  59%|█████▉    | 103/175 [00:38<00:25,  2.87it/s]

Correcting Point: 221,288
To: 221,287


Tracking image:  61%|██████    | 106/175 [00:39<00:24,  2.86it/s]

Correcting Point: 226,293
To: 219,293


Tracking image:  61%|██████    | 107/175 [00:39<00:23,  2.87it/s]

Correcting Point: 222,289
To: 219,289


Tracking image:  62%|██████▏   | 108/175 [00:40<00:23,  2.85it/s]

Correcting Point: 192,341
To: 191,341
Correcting Point: 477,395
To: 477,387
Correcting Point: 233,298
To: 248,310


Tracking image:  62%|██████▏   | 109/175 [00:40<00:23,  2.79it/s]

Correcting Point: 475,396
To: 478,393
Correcting Point: 231,275
To: 214,279


Tracking image:  63%|██████▎   | 110/175 [00:40<00:23,  2.80it/s]

Correcting Point: 473,396
To: 477,396
Correcting Point: 222,295
To: 215,295


Tracking image:  63%|██████▎   | 111/175 [00:41<00:22,  2.79it/s]

Correcting Point: 211,298
To: 214,296


Tracking image:  64%|██████▍   | 112/175 [00:41<00:22,  2.80it/s]

Correcting Point: 213,290
To: 218,290
Correcting Point: 234,290
To: 219,295


Tracking image:  65%|██████▍   | 113/175 [00:41<00:22,  2.80it/s]

Correcting Point: 188,313
To: 168,289
Correcting Point: 168,361
To: 158,355


Tracking image:  65%|██████▌   | 114/175 [00:42<00:21,  2.78it/s]

Correcting Point: 173,299
To: 167,293
Correcting Point: 132,408
To: 131,408
Correcting Point: 458,318
To: 459,318
Correcting Point: 512,277
To: 512,276
Correcting Point: 451,247
To: 451,248
Correcting Point: 168,257
To: 166,252


Tracking image:  66%|██████▌   | 115/175 [00:42<00:21,  2.79it/s]

Correcting Point: 257,421
To: 257,422
Correcting Point: 181,315
To: 168,307


Tracking image:  66%|██████▋   | 116/175 [00:42<00:21,  2.78it/s]

Correcting Point: 273,313
To: 268,312
Correcting Point: 242,383
To: 235,386
Correcting Point: 218,240
To: 215,234


Tracking image:  67%|██████▋   | 117/175 [00:43<00:20,  2.80it/s]

Correcting Point: 355,249
To: 356,249


Tracking image:  67%|██████▋   | 118/175 [00:43<00:20,  2.82it/s]

Correcting Point: 305,350
To: 309,350


Tracking image:  68%|██████▊   | 119/175 [00:44<00:19,  2.84it/s]

Correcting Point: 323,345
To: 322,345
Correcting Point: 400,190
To: 409,192


Tracking image:  69%|██████▊   | 120/175 [00:44<00:19,  2.81it/s]

Correcting Point: 302,298
To: 300,296
Correcting Point: 114,475
To: 114,477
Correcting Point: 236,295
To: 236,292


Tracking image:  69%|██████▉   | 121/175 [00:44<00:19,  2.79it/s]

Correcting Point: 93,397
To: 91,397
Correcting Point: 323,305
To: 323,293
Correcting Point: 239,295
To: 237,293


Tracking image:  70%|██████▉   | 122/175 [00:45<00:18,  2.80it/s]

Correcting Point: 311,309
To: 310,295
Correcting Point: 249,291
To: 248,288


Tracking image:  70%|███████   | 123/175 [00:45<00:18,  2.82it/s]

Correcting Point: 307,310
To: 300,298
Correcting Point: 247,290
To: 247,288
Correcting Point: 451,231
To: 448,232


Tracking image:  71%|███████   | 124/175 [00:45<00:18,  2.81it/s]

Correcting Point: 304,310
To: 299,299
Correcting Point: 247,291
To: 249,292
Correcting Point: 444,228
To: 443,227


Tracking image:  71%|███████▏  | 125/175 [00:46<00:17,  2.83it/s]

Correcting Point: 302,311
To: 295,302


Tracking image:  72%|███████▏  | 126/175 [00:46<00:17,  2.83it/s]

Correcting Point: 303,311
To: 298,302
Correcting Point: 249,291
To: 249,289


Tracking image:  73%|███████▎  | 127/175 [00:46<00:16,  2.87it/s]

Correcting Point: 311,309
To: 307,299


Tracking image:  73%|███████▎  | 128/175 [00:47<00:16,  2.86it/s]

Correcting Point: 312,310
To: 309,300


Tracking image:  74%|███████▎  | 129/175 [00:47<00:15,  2.88it/s]

Correcting Point: 309,312
To: 305,305


Tracking image:  74%|███████▍  | 130/175 [00:47<00:15,  2.88it/s]

Correcting Point: 314,312
To: 313,306


Tracking image:  75%|███████▍  | 131/175 [00:48<00:15,  2.86it/s]

Correcting Point: 314,312
To: 314,307


Tracking image:  75%|███████▌  | 132/175 [00:48<00:14,  2.88it/s]

Correcting Point: 310,316
To: 310,315
Correcting Point: 266,295
To: 265,295


Tracking image:  76%|███████▌  | 133/175 [00:48<00:14,  2.89it/s]

Correcting Point: 309,320
To: 309,318
Correcting Point: 266,297
To: 266,296
Correcting Point: 488,258
To: 487,258


Tracking image:  82%|████████▏ | 143/175 [00:52<00:11,  2.89it/s]

Correcting Point: 272,369
To: 275,371
Correcting Point: 628,167
To: 630,165


Tracking image:  82%|████████▏ | 144/175 [00:52<00:10,  2.88it/s]

Correcting Point: 270,370
To: 268,373
Correcting Point: 628,167
To: 633,167


Tracking image:  83%|████████▎ | 145/175 [00:53<00:10,  2.87it/s]

Correcting Point: 431,397
To: 429,395


Tracking image:  83%|████████▎ | 146/175 [00:53<00:10,  2.85it/s]

Correcting Point: 424,405
To: 420,399


Tracking image:  84%|████████▍ | 147/175 [00:53<00:09,  2.87it/s]

Correcting Point: 418,411
To: 414,402


Tracking image:  85%|████████▍ | 148/175 [00:54<00:09,  2.85it/s]

Correcting Point: 415,422
To: 403,411


Tracking image:  85%|████████▌ | 149/175 [00:54<00:09,  2.86it/s]

Correcting Point: 414,424
To: 428,423


Tracking image:  86%|████████▌ | 150/175 [00:54<00:08,  2.88it/s]

Correcting Point: 418,419
To: 428,423


Tracking image:  86%|████████▋ | 151/175 [00:55<00:08,  2.81it/s]

Correcting Point: 419,418
To: 428,420


Tracking image:  87%|████████▋ | 152/175 [00:55<00:08,  2.84it/s]

Correcting Point: 420,416
To: 428,415


Tracking image:  87%|████████▋ | 153/175 [00:55<00:07,  2.85it/s]

Correcting Point: 421,413
To: 429,414
Correcting Point: 505,284
To: 505,286


Tracking image:  88%|████████▊ | 154/175 [00:56<00:07,  2.86it/s]

Correcting Point: 422,414
To: 428,416
Correcting Point: 507,288
To: 507,289


Tracking image:  89%|████████▊ | 155/175 [00:56<00:06,  2.87it/s]

Correcting Point: 472,348
To: 472,354


Tracking image:  89%|████████▉ | 156/175 [00:56<00:06,  2.85it/s]

Correcting Point: 486,346
To: 486,349


Tracking image:  90%|████████▉ | 157/175 [00:57<00:06,  2.83it/s]

Correcting Point: 438,401
To: 434,394


Tracking image:  90%|█████████ | 158/175 [00:57<00:05,  2.84it/s]

Correcting Point: 414,414
To: 408,407


Tracking image:  91%|█████████ | 159/175 [00:57<00:05,  2.84it/s]

Correcting Point: 480,348
To: 485,350


Tracking image:  91%|█████████▏| 160/175 [00:58<00:05,  2.85it/s]

Correcting Point: 410,418
To: 403,411
Correcting Point: 488,384
To: 480,383
Correcting Point: 524,279
To: 527,281


Tracking image:  92%|█████████▏| 161/175 [00:58<00:04,  2.84it/s]

Correcting Point: 409,419
To: 402,412


Tracking image:  93%|█████████▎| 162/175 [00:59<00:04,  2.83it/s]

Correcting Point: 408,419
To: 402,412
Correcting Point: 527,276
To: 531,279


Tracking image:  93%|█████████▎| 163/175 [00:59<00:04,  2.84it/s]

Correcting Point: 408,418
To: 402,412
Correcting Point: 531,278
To: 531,279


Tracking image:  94%|█████████▎| 164/175 [00:59<00:03,  2.86it/s]

Correcting Point: 408,418
To: 402,412
Correcting Point: 532,277
To: 533,278


Tracking image:  94%|█████████▍| 165/175 [01:00<00:03,  2.85it/s]

Correcting Point: 422,412
To: 418,402
Correcting Point: 525,280
To: 527,282


Tracking image:  95%|█████████▍| 166/175 [01:00<00:03,  2.81it/s]

Correcting Point: 423,410
To: 418,402
Correcting Point: 499,286
To: 495,296


Tracking image:  95%|█████████▌| 167/175 [01:00<00:02,  2.81it/s]

Correcting Point: 416,415
To: 411,407
Correcting Point: 519,279
To: 525,284


Tracking image:  96%|█████████▌| 168/175 [01:01<00:02,  2.82it/s]

Correcting Point: 419,414
To: 415,404
Correcting Point: 396,175
To: 394,175


Tracking image:  97%|█████████▋| 169/175 [01:01<00:02,  2.82it/s]

Correcting Point: 487,345
To: 488,345


Tracking image:  98%|█████████▊| 172/175 [01:02<00:01,  2.86it/s]

Correcting Point: 285,224
To: 289,227


Tracking image:  99%|█████████▉| 173/175 [01:02<00:00,  2.86it/s]

Correcting Point: 274,228
To: 277,231


Tracking image: 100%|██████████| 175/175 [01:03<00:00,  2.75it/s]


Generating video


In [None]:
# `modelBoth` is only mentioned inside the disabled string block of the setup
# cell, so on a fresh kernel it does not exist. Build it here on first use with
# the configuration listed in that block.
if 'modelBoth' not in globals():
    modelBoth = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, {'use_refinement': True, 'refinement_mode': 'both'})

# NOTE(review): this overwrites the *_refinement_bbox results of the previous run.
masks_refinement_bbox, logits_refinement_bbox, painted_images_refinement_bbox = run_model_on_ovis_set(name = '_both_TA_lizard',model = modelBoth, path_set = ovis_images,videos = vidTrain[37:38],annotations = annTrain,compute_metrics = False, save_masks=False, compute_video=True)

In [None]:
# `model` is otherwise only created in a later cell, so on a fresh kernel it
# does not exist yet. Build the no-refinement tracker here on first use.
if 'model' not in globals():
    model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, {'use_refinement': False})

masks, logits, painted_images = run_model_on_ovis_set(name = '_no_ref_lizards',model = model, path_set = ovis_images,videos = vidTrain[37:38],annotations = annTrain,compute_metrics = False, save_masks=False, compute_video=True)

In [None]:
# `modelSamPoint` is only mentioned inside the disabled string block of the
# setup cell, so on a fresh kernel it does not exist. Build it here on first use
# with the configuration listed in that block.
if 'modelSamPoint' not in globals():
    modelSamPoint = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, {'use_refinement': True, 'refinement_mode': 'point'})

masks_refinement_point, logits_refinement_point, painted_images_refinement_point = run_model_on_ovis_set(name = '_point_ref_lizard',model = modelSamPoint, path_set = ovis_images,videos = vidTrain[37:38],annotations = annTrain,compute_metrics = False, save_masks=False, compute_video=True)

In [None]:
def print_images(image1, image2, image3):
    """Show three images side by side, titled 'Image', 'No Ref' and 'Ref'."""
    plt.figure(figsize=(30, 15))

    # One row with three columns, one column per image
    layout = plt.GridSpec(1, 3)
    panels = ((image1, 'Image'), (image2, 'No Ref'), (image3, 'Ref'))
    for column, (picture, title) in enumerate(panels):
        axis = plt.subplot(layout[column])
        axis.imshow(picture)
        axis.set_title(title)

    # Tighten the horizontal gap between the panels, then render
    plt.subplots_adjust(wspace=0.1)
    plt.show()

In [None]:
# Load every frame of training video 37 (the video tracked in the runs above).
images = load_images_from_folder(ovis_images,vidTrain[37]['file_names'])

In [None]:
# Display frames 153 to 156 of the loaded video, one figure per frame.
for i in range(153,157):
    plt.imshow(images[i])
    plt.show()

In [None]:
# Compare the first ten frames: raw image, `painted_images` and
# `painted_images_refinement_bbox`. The loop variable is not called `id`, so the
# builtin of that name is no longer shadowed for the rest of the session.
for frame_idx in range(0,10):
    print_images(images[frame_idx],painted_images[frame_idx],painted_images_refinement_bbox[frame_idx])

In [None]:
# Display every painted frame of the point-refinement run.
for image in painted_images_refinement_point: 
    plt.imshow(image)
    plt.show()

In [None]:
# Display the mask at index 2 of the point-refinement run.
plt.imshow(masks_refinement_point[2])

In [None]:
# Display every frame stored in `painted_images`.
for image in painted_images: 
    plt.imshow(image)
    plt.show()

JUNK TESTING 

In [None]:
# Load only the first ten frames of training video 0.
# NOTE: this replaces the `images` list loaded for video 37 above.
images = load_images_from_folder(ovis_images,vidTrain[0]['file_names'][0:10])

In [None]:
def get_best_point_of_interest(segmentation_mask):
    """Return the centroid of every external contour found in the mask.

    :param segmentation_mask: mask passed to cv2.findContours
    :return: int array of shape (n_contours, 2) holding (x, y) centroids;
        shape (0, 2) when no contour is found
    """
    # Find the external contours in the segmentation mask
    points = []
    contours, _ = cv2.findContours(segmentation_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        # Centroid from the image moments: (m10 / m00, m01 / m00)
        M = cv2.moments(contour)
        if M["m00"] != 0:
            cx, cy = M["m10"] / M["m00"], M["m01"] / M["m00"]
        else:
            # Zero-area contour (e.g. a single pixel or a thin line): the
            # moment formula would divide by zero, so use the mean vertex.
            cx, cy = np.asarray(contour).reshape(-1, 2).mean(axis=0)
        points.append([int(cx), int(cy)])

    # reshape keeps the (n, 2) layout even when no contour was found
    return np.array(points).astype('int').reshape(-1, 2)

In [None]:
# First-frame masks of training video 0: the i-th annotation (0-based) is
# labelled i + 1. Objects without a frame-0 segmentation decode to None and are
# skipped; multiplying None by the label would raise TypeError.
ann = [a for a in annTrain if a['video_id'] == vidTrain[0]['id']]
masks = []
for i, a in enumerate(ann):
    first_mask = annToMask(a, 0)
    if first_mask is not None:
        masks.append(first_mask * (i + 1))
initial_mask = unifyMasks(masks, vidTrain[0]['width'], vidTrain[0]['height'])

In [None]:
# Contour centroids of the mask at index 4.
get_best_point_of_interest(masks[4])

In [None]:
# For every first-frame mask: compute its contour centroids, collect them in
# all_points and plot the mask with its FIRST centroid marked in red.
# NOTE(review): points[0] raises IndexError for a mask without any contour.
all_points = []
for mask in masks:
    
    points = get_best_point_of_interest(mask)
    all_points.append(points)
    plt.imshow(mask)
    print(points[0][0])
    print(all_points)
    plt.scatter(points[0][0], points[0][1], c='red', marker='o')

    # Set the axis limits
    plt.xlim(0, mask.shape[1])
    plt.ylim(mask.shape[0], 0)

    # Show the plot
    plt.show()

In [None]:
# Stack the per-mask centroid arrays. Pixel coordinates are kept as platform
# ints: uint8 only holds 0..255, so larger coordinates (the tracking log above
# shows points such as 507,414) would silently wrap around.
# NOTE(review): np.array() needs every mask to contribute the same number of
# points to build a regular array.
all_points = np.array(all_points).astype('int')
# One positive label (1) per entry of all_points
all_labels = np.ones((all_points.shape[0],1)).astype('uint8')

In [None]:
# Reset the SAM controller and give it frame 3 as the current image.
# NOTE(review): `modelSam` is only created in a later cell; on a fresh kernel
# this cell fails unless that cell has been run first.
modelSam.xmem.sam_model.sam_controler.reset_image()
modelSam.xmem.sam_model.sam_controler.set_image(images[3])

In [None]:
# Inspect the points of the first mask.
all_points[0]

In [None]:
# Inspect the labels of the first mask.
all_labels[0]

In [None]:
# Call the underlying predictor directly with the points/labels of the first
# mask, asking for a single mask (multimask_output=False).
modelSam.xmem.sam_model.sam_controler.predictor.predict(point_coords=all_points[0], 
                                point_labels=all_labels[0], 
                                multimask_output=False)

In [None]:
# Number of points stored for the first mask.
all_points[0].shape[0]

In [None]:
# Inspect the labels of the first mask again.
all_labels[0]

In [None]:
# Same query through the controller wrapper: point prompts of the first mask,
# mode 'point', single-mask output.
mode = 'point'
prompts = {
    'point_coords': all_points[0],
    'point_labels': all_labels[0], 
}
modelSam.xmem.sam_model.sam_controler.predict(prompts, mode, multimask=False)

In [None]:
# Call the predictor with the points/labels of ALL masks at once.
# NOTE(review): all_points is built from one array per mask, so it presumably
# has an extra leading axis compared with the single-mask call above - confirm
# the predictor accepts that layout.
modelSam.xmem.sam_model.sam_controler.predictor.predict(point_coords=all_points, 
                                point_labels=all_labels, 
                                multimask_output=False)

In [None]:
# Display the first loaded frame.
plt.imshow(images[0])

In [None]:
# Scratch setup: re-declares the same checkpoint paths as the setup cell near
# the top and builds two trackers: `model` without refinement and `modelSam`
# with refinement enabled (no 'refinement_mode' given).
# NOTE(review): earlier cells already use `model` and `modelSam`, so this cell
# has to be run before them on a fresh kernel.
SAM_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
xmem_checkpoint = "./checkpoints/XMem-s012.pth"
e2fgvi_checkpoint = "./checkpoints/E2FGVI-HQ-CVPR22.pth"
args = {'use_refinement':False}
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
args = {'use_refinement':True}
modelSam = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)

In [None]:
# Load every frame of training video 0 (replaces the previous `images` list).
images = load_images_from_folder(ovis_images,vidTrain[0]['file_names'])
#model.xmem.sam_refinement()

In [None]:
# Display the first frame of video 0.
plt.imshow(images[0])

In [None]:
# Decode the frame-0 segmentation of every annotation of training video 0 once.
# Objects absent in frame 0 decode to None; they are skipped, because both
# `None * (i + 1)` and the later bounding-box computation would fail on them.
ann = [a for a in annTrain if a['video_id'] == vidTrain[0]['id']]
first_frame = [(i, annToMask(a, 0)) for i, a in enumerate(ann)]
first_frame = [(i, m) for i, m in first_frame if m is not None]

# Labelled version (i-th annotation -> value i + 1), merged into one template mask
initial_mask = unifyMasks([m * (i + 1) for i, m in first_frame], vidTrain[0]['width'], vidTrain[0]['height'])
# `masks` keeps the unlabelled per-object masks for the cells below
masks = [m for _, m in first_frame]

In [None]:
#resized_mask = cv2.resize(masks[3], (256, 256), interpolation=cv2.INTER_NEAREST)

In [None]:
# Row/column indices of the non-zero pixels of the mask at index 3.
np.nonzero(masks[3])

In [None]:
def compute_bounding_box(segmentation_mask):
    """Return the tight bounding box of the non-zero pixels of a 2D mask.

    :param segmentation_mask: 2D array; non-zero entries belong to the object
    :return: list ``[min_col, min_row, max_col, max_row]`` (x0, y0, x1, y1),
        with both corners inclusive
    """
    # np.nonzero yields one index array per axis: rows first, then columns
    hits = np.nonzero(segmentation_mask)
    top, bottom = hits[0].min(), hits[0].max()
    left, right = hits[1].min(), hits[1].max()
    return [left, top, right, bottom]

In [None]:
def show_box(box, ax):
    """
    Draw the outline of a box on the given axes.
    :param box: indexable whose first four entries are read as x0, y0, x1, y1.
    :param ax: object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill, only the border is visible
        lw=2,
    )
    ax.add_patch(outline)

In [None]:
# One bounding box per mask in `masks`.
# A mask with no non-zero pixels makes compute_bounding_box raise ValueError.
bboxes = [compute_bounding_box(mask) for mask in masks]

In [None]:
# Pack the boxes into a tensor placed on the same device as the SAM predictor.
input_boxes = torch.tensor(bboxes, device=modelSam.xmem.sam_model.sam_controler.predictor.device)

In [None]:
# Pass the boxes through the predictor's transform, giving it the first two dims of the
# first frame's shape (presumably (height, width) of the original image — TODO confirm).
transformed_boxes = modelSam.xmem.sam_model.sam_controler.predictor.transform.apply_boxes_torch(input_boxes, images[0].shape[:2])

In [None]:
# Reset the SAM controller's current image, then register the first frame with it.
modelSam.xmem.sam_model.sam_controler.reset_image()
modelSam.xmem.sam_model.sam_controler.set_image(images[0])

In [None]:
# Show the frame, then query the SAM controller with the transformed boxes.
plt.imshow(images[0])
plt.show()
# box prompt ------------------------
mode = 'bounding_boxes'
prompts = {'bounding_boxes': transformed_boxes}

# NOTE(review): the shape comment below may not hold for this mode — a later cell indexes
# masksout[i][0] and calls .numpy() on it, which suggests a tensor with an extra axis.
masksout, scores, logits = modelSam.xmem.sam_model.sam_controler.predict(prompts, mode, multimask=False)  # masks (n, h, w), scores (n,), logits (n, 256, 256)

In [None]:
# Display the masks returned by the last predict call.
masksout

In [None]:
# All-zero array with the same shape and dtype as the first predicted mask.
pp_mask = np.zeros_like(masksout[0])

In [None]:
# Distinct values currently stored in pp_mask.
np.unique(pp_mask)

In [None]:
# Weighted sum of the first two masks (weights 1 and 2); where both masks are set the
# values add up rather than one label overriding the other.
(masksout[0] * 1 + masksout[1] * 2 )

In [None]:
# Write 2 into pp_mask at the positions selected by the second mask
# (assumes masksout[1] is usable as a boolean index — TODO confirm its dtype).
pp_mask[masksout[1]] = 2

In [None]:
# Element-wise comparison of the first predicted mask with pp_mask.
masksout[0] == pp_mask

In [None]:
# Overlay every predicted mask on the first frame together with its prompt box,
# printing 'Selected' just before the highest-scoring one.
best = np.argmax(scores)
for i in range(len(scores)):
    if i == best:
        print('Selected')
    overlay_mask = masksout[i][0].numpy().astype('uint8')
    painted_image = mask_painter(images[0], overlay_mask)
    show_box(input_boxes[i], plt.gca())
    plt.imshow(painted_image)
    plt.show()

In [None]:

# Variant of the box-prompt experiment using a different mode name and prompt key.
plt.imshow(images[0])
plt.show()
# box prompt ------------------------
# NOTE(review): the earlier box cell uses mode 'bounding_boxes' with key 'bounding_boxes';
# here it is 'bbox' / 'bounding_box' — confirm which pair sam_controler.predict accepts.
mode = 'bbox'
prompts = {'bounding_box': transformed_boxes}

masksout, scores, logits = modelSam.xmem.sam_model.sam_controler.predict(prompts, mode, multimask=False)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
for i in range(0,len(scores)):
    if i == np.argmax(scores): print('Selected')
    # NOTE(review): the frame shown above is images[0] but the mask is painted onto
    # images[1] — confirm this is intended.
    painted_image = mask_painter(images[1], masksout[i].astype('uint8'))
    # NOTE(review): `bb` is not defined anywhere in this section; it must come from
    # earlier in the notebook, otherwise this raises NameError.
    show_box(bb, plt.gca())
    plt.imshow(painted_image)
    plt.show()

In [None]:
# Display the transform object attached to the SAM predictor.
modelSam.xmem.sam_model.sam_controler.predictor.transform

In [None]:
# Mask-prompt experiment: feed the annotation mask at index 3 to the SAM controller.
#modelSam.xmem.sam_model.sam_controler.reset_image()
#modelSam.xmem.sam_model.sam_controler.set_image(images[0])
plt.imshow(images[0])

# mask only ------------------------
mode = 'mask'
# [None, :, :] adds a leading axis of size 1 in front of the mask.
# NOTE(review): the mask is passed at its decoded resolution; other cells feed
# 256x256 logits as 'mask_input' — confirm which size predict expects.
prompts = {'mask_input': masks[3][None,:,:]}

# NOTE(review): this overwrites `masks` (the prompt source above) with the prediction,
# so re-running the cell uses a different input than the first run.
masks, scores, logits = modelSam.xmem.sam_model.sam_controler.predict(prompts, mode, multimask=True)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
for i in range(0,len(scores)):
    if i == np.argmax(scores): print('Selected')
    # NOTE(review): painted onto images[1] while images[0] is the frame displayed above.
    painted_image = mask_painter(images[1], masks[i].astype('uint8'))
    plt.imshow(painted_image)
    plt.show()

In [None]:
# Run the tracker on the first two frames, seeded with the merged template mask.
# The XMem memory is cleared before and after the run.
model.xmem.clear_memory()
masksout, logitsout, painted_imagesout = model.generator(images=images[0:2], template_mask=initial_mask)
model.xmem.clear_memory() 

In [None]:
# Resize transform with a fixed 256x256 output size.
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell.
from torchvision.transforms import Resize
resizer = Resize([256, 256])

In [None]:
# NOTE(review): ind_logits is assigned in the following cell; on a fresh kernel this
# cell raises NameError unless ind_logits is also defined earlier in the file.
ind_logits.shape

In [None]:
# Take element [1][0] of the tracker logits and add a leading axis of size 1
# (presumably frame 1, first channel/object — TODO confirm the layout of logitsout).
triallogits = logitsout[1][0].unsqueeze(0)
# Resize to 256x256, move to CPU and convert to a numpy array.
ind_logits = resizer(triallogits).cpu().numpy()

In [None]:
# Shape of the first painted frame returned by the tracker.
painted_imagesout[0].shape

In [None]:
# Shape of the resized logits that will be used as a mask prompt.
ind_logits.shape

In [None]:
# Mask-prompt experiment using the resized tracker logits as SAM's mask input.
modelSam.xmem.sam_model.sam_controler.reset_image()
modelSam.xmem.sam_model.sam_controler.set_image(images[0])
plt.imshow(images[0])

# mask only ------------------------
mode = 'mask'
prompts = {'mask_input': ind_logits}

masks, scores, logits = modelSam.xmem.sam_model.sam_controler.predict(prompts, mode, multimask=True)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
for i in range(0,len(scores)):
    if i == np.argmax(scores): print('Selected')
    # NOTE(review): the controller was given images[0] but the overlay uses images[1],
    # while ind_logits was taken from logitsout[1] — confirm which frame is intended.
    painted_image = mask_painter(images[1], masks[i].astype('uint8'))
    plt.imshow(painted_image)
    plt.show()

In [None]:
# Drop the image currently registered with the SAM controller.
modelSam.xmem.sam_model.sam_controler.reset_image()

In [None]:

# Two-stage prompt: a single positive point first, then the best resulting logits
# are fed back as a mask prompt.
modelSam.xmem.sam_model.sam_controler.set_image(images[0])
#model.samcontroler.sam_controler.set_image(images[0])
mode = 'point'
prompts = {
    'point_coords': np.array([[500, 650]]),  # one hard-coded point (presumably (x, y) pixels — TODO confirm)
    'point_labels': np.array([1]), 
}
masks, scores, logits = modelSam.xmem.sam_model.sam_controler.predict(prompts, mode, multimask=False)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
# Overlay the highest-scoring mask on the first frame.
painted_image = mask_painter(images[0], masks[np.argmax(scores)].astype('uint8'))
#cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)


plt.imshow(painted_image)
plt.show()

# mask only ------------------------
mode = 'mask'
# Logits of the best-scoring point prediction, reused as the mask prompt.
mask_input  = logits[np.argmax(scores), :, :]

# [None, :, :] adds a leading axis of size 1.
prompts = {'mask_input': mask_input[None, :, :]}
print(prompts['mask_input'].shape)

masks, scores, logits = modelSam.xmem.sam_model.sam_controler.predict(prompts, mode, multimask=True)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
# Show every candidate mask, flagging the highest-scoring one.
for i in range(0,len(scores)):
    if i == np.argmax(scores): print('Selected')
    painted_image = mask_painter(images[0], masks[i].astype('uint8'))
    plt.imshow(painted_image)
    plt.show()

JUNK

In [None]:
# Read the first frame of the video folder and display it.
# NOTE(review): first_video_folder is not defined in this section; it must come from
# earlier in the notebook.
frame_path = ovis_images + first_video_folder + '/img_0000001.jpg'
img = cv2.imread(frame_path)
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None instead of raising;
    # fail here with the path rather than with an opaque error inside cvtColor.
    raise FileNotFoundError("Could not read image: " + frame_path)
# OpenCV loads images as BGR; convert to RGB for matplotlib.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.axis('off')

In [None]:
# Palette of 20 colour triples (0-255 per channel), used to paint masks.
colors = [
    (255, 0, 0),    # Red
    (0, 255, 0),    # Green
    (0, 0, 255),    # Blue
    (255, 255, 0),  # Yellow
    (255, 0, 255),  # Magenta
    (0, 255, 255),  # Cyan
    (128, 0, 0),    # Dark red (maroon)
    (0, 128, 0),    # Dark green
    (0, 0, 128),    # Dark blue
    (128, 128, 0),  # Dark yellow
    (128, 0, 128),  # Dark magenta
    (0, 128, 128),  # Dark cyan
    (255, 128, 0),  # Orange
    (128, 255, 0),  # Lime
    (255, 0, 128),  # Pink
    (128, 0, 255),  # Violet
    (0, 255, 128),  # Turquoise
    (0, 128, 255),  # Light blue
    (255, 128, 128), # Light pink
    (128, 255, 128)  # Light green
]

In [None]:
# Pick the first training video and the first annotation that refers to it;
# `ann` ends up holding at most that single annotation.
video = vidTrain[0]
first_match = next((a for a in annTrain if a['video_id'] == video['id']), None)
ann = [first_match] if first_match is not None else []

# Decode frame 0 of that annotation and paint it with the first palette colour,
# one channel at a time.
mask = annToMask(ann[0], 0)
colored_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
for channel in range(3):
    colored_mask[:, :, channel] = mask * colors[0][channel]
plt.imshow(colored_mask)

In [None]:
# Every annotation that belongs to the current `video`.
ann = [a for a in annTrain if a['video_id'] == video['id']]

# Frame-0 mask of each annotation, scaled by its 1-based position so that
# each object carries a distinct value.
masks = [annToMask(a, 0) * (i + 1) for i, a in enumerate(ann)]

# Merge the per-object masks into one template using the video's dimensions.
w, h = video['width'], video['height']
unified = unifyMasks(masks, w, h)

In [None]:
# Run the tracker on the first five frames, seeded with the merged template mask.
# The XMem memory is cleared before and after the run.
model.xmem.clear_memory()
masks, logits, painted_images = model.generator(images=images[0:5], template_mask=unified)
model.xmem.clear_memory()

In [None]:
# Show the tracker's painted output at index 1.
plt.imshow(painted_images[1])

In [None]:
def print_rgb_mask(mask, color=None):
    """
    Display a 2D mask as a single-colour image.
    :param mask: array indexed as (row, col); each value is multiplied by the colour's
        channel intensities, so a 0/1 mask shows the colour wherever it is 1.
    :param color: optional sequence whose first three entries are the per-channel
        intensities; defaults to the first entry of the module-level ``colors`` palette.
    :return: None; the image is shown with matplotlib.
    """
    if color is None:
        # Default kept for backward compatibility with existing one-argument calls.
        color = colors[0]

    colored_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for channel in range(3):
        colored_mask[:, :, channel] = mask * color[channel]
    plt.imshow(colored_mask)
    plt.show()
    

In [None]:
# Render entry 29 of `logits` as a coloured mask.
# NOTE(review): the tracker cell in this section passes only images[0:5]; if `logits`
# holds one entry per frame, index 29 is out of range — confirm which `logits` is meant.
print_rgb_mask(logits[29])