In [None]:
from pathlib import Path

import cv2
import torch
import torchvision.transforms as transforms
from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor, SamPad
from efficientvit.sam_model_zoo import create_efficientvit_sam_model
from src.core.utils import cv2_itervideo
from src.api.services import gaze_service


In [2]:
torch.cuda.empty_cache()

In [3]:
VIDEO_PATH = Path("../data/recordings/af47ccce-c344-49d9-9916-5729e2ddc021.mp4")
GAZE_DATA_PATH = Path("../data/recordings/af47ccce-c344-49d9-9916-5729e2ddc021.tsv")

torch.cuda.empty_cache()
frames = cv2_itervideo(VIDEO_PATH, device="cuda")

In [None]:
resolution = (frames.shape[3], frames.shape[2])
gaze_data = gaze_service.parse_gazedata_file(GAZE_DATA_PATH)
gaze_points = gaze_service.get_gaze_points(gaze_data, resolution)
frame_gaze_mapping = gaze_service.match_frames_to_gaze(
    num_frames=frames.shape[0],
    gaze_points=gaze_points,
)

In [5]:
model = create_efficientvit_sam_model(
    name="efficientvit-sam-xl1", pretrained=True, weight_url="../checkpoints/efficientvit_sam_xl1.pt"
)
model = model.cuda().eval()
model_predictor = EfficientViTSamPredictor(model)

In [6]:
def preprocess_batch(
    frames: torch.Tensor,
    crop_points: list[tuple[int, int] | None],
    crop_size: int,
    start_index: int,
    end_index: int,
) -> tuple[torch.Tensor, list[tuple[int, int]]]:
    """
    Preprocesses a batch of frames by cropping them around the given crop points

    Args:
        - frames: batch of frames to preprocess. Shape: (B, H, W, C)
        - crop_points: list of crop points for each frame. If a frame does not have a crop point, the value is None
        - crop_size: size of the crop window

    Returns:
        - preprocessed_batch: preprocessed_batch batch of frames. Shape: (B, C, H, W)
        - new_centers: list of local crop points for each cropped frame.
        - frame_indexes: list of original frame indices that were processed
    """
    HALF_CROP = crop_size // 2
    W, H = frames.shape[2], frames.shape[1]

    # create an empty tensor to store the cropped frames
    crop_points = crop_points[start_index:end_index]
    frame_count = sum([1 for point in crop_points if point is not None])
    preprocessed_batch = torch.zeros((frame_count, frames.shape[3], crop_size, crop_size), device=frames.device)
    new_centers = []
    valid_frame_idx = 0

    for frame_idx in range(start_index, end_index):
        crop_point = crop_points[frame_idx]
        if crop_point is not None:
            # Gaze coordinates
            frame = frames[frame_idx].cpu().numpy()
            cx = crop_point[0]
            cy = crop_point[1]

            # Compute bounding-box edges with clamping
            left = max(cx - HALF_CROP, 0)
            right = min(cx + HALF_CROP, W)
            top = max(cy - HALF_CROP, 0)
            bottom = min(cy + HALF_CROP, H)

            # Crop the region
            cropped_frame = frame[top:bottom, left:right]
            local_cx = cx - left
            local_cy = cy - top

            # Apply preprocessing steps
            # TODO create const or UI setting for kernel size?
            preprocessed_frame = cv2.GaussianBlur(cropped_frame, (7, 7), 0)

            # Apply necessary transforms (See efficientvit.models.efficientvit.sam.EfficientViTSam.transform)
            tf = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[123.675 / 255, 116.28 / 255, 103.53 / 255],
                    std=[58.395 / 255, 57.12 / 255, 57.375 / 255],
                ),
                SamPad(size=crop_size),
            ])

            # Store results
            temp_tensor = tf(preprocessed_frame).cuda()
            preprocessed_batch[valid_frame_idx] = temp_tensor
            torch.cuda.synchronize()
            del temp_tensor

            # Store the new crop point and frame index
            new_centers.append((local_cx, local_cy))

            valid_frame_idx += 1
            del frame

    return preprocessed_batch, new_centers

In [7]:
BATCH_SIZE = 50
BATCH_NUM = 0


crop_points = [gaze_points[0].position if len(gaze_points) > 0 else None for gaze_points in frame_gaze_mapping]
start_index = BATCH_NUM * BATCH_SIZE
end_index = (BATCH_NUM + 1) * BATCH_SIZE

crop_points_batch = crop_points[start_index:end_index]
crop_points_batch = [[point] for point in crop_points_batch if point is not None]


preprocessed_batch, new_centers = preprocess_batch(frames, crop_points, 1024, start_index, end_index)
preprocessed_batch.shape

torch.Size([43, 3, 1024, 1024])

In [8]:
model_predictor.set_image_batch(preprocessed_batch)
del preprocessed_batch

In [9]:
print(torch.tensor(crop_points_batch).shape)
print(torch.tensor([[1]] * len(crop_points_batch)).shape)

torch.Size([43, 1, 2])
torch.Size([43, 1])


In [10]:
for i in range(len(crop_points_batch)):
    masks, iou_predictions, low_res_masks = model_predictor.predict_torch(
        point_coords=torch.tensor(crop_points_batch).cuda(),
        point_labels=torch.tensor([[1]] * len(crop_points_batch)).cuda(),
        image_index=i,
    )

    print(masks.shape)

torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Size([43, 3, 1024, 1024])
torch.Si