In [1]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import cv2
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
import pandas as pd
import mediapipe as mp
from tqdm import tqdm
from sklearn.decomposition import PCA
import torchvision.transforms as transforms
from typing import Tuple
from PIL import ImageDraw
from scipy.ndimage import distance_transform_edt

import os
# Remember the directory the notebook was started from, then switch the working
# directory to its FastSAM sub-folder before importing `segment`.
# NOTE(review): HOME is read from os.getcwd(), so re-running this cell after the
# %cd below computes a different HOME (already inside FastSAM).
HOME = os.getcwd()
print("HOME:", HOME)
%cd {HOME}/FastSAM

# Presumably `Inference` is a module inside the FastSAM folder and only resolves
# once that folder is the working directory -- confirm against the repo layout.
from Inference import segment

  from .autonotebook import tqdm as notebook_tqdm


HOME: c:\Users\arshs\OneDrive\Documents\GitHub\mimic\segmentation
c:\Users\arshs\OneDrive\Documents\GitHub\mimic\segmentation\FastSAM


In [2]:
def get_frame_point(videoPath, out=False):
    """Find the frames of a video in which a tracked hand comes to rest.

    Every frame is run through MediaPipe Hands. For each hand slot (0 or 1)
    the midpoint of landmarks 4 and 8 is tracked; once the last
    ``window_size`` midpoints have a spread (standard deviation of their
    distances to the window mean) below ``stdev_threshold``, the hand is
    treated as still and the frame is recorded. A hand has to exceed the
    threshold again before it can trigger another record.

    Args:
        videoPath: Path of the video to analyse.
        out: When truthy, the processed frames (with a filled circle drawn on
            every recorded frame) are additionally written to ``out.mp4``.

    Returns:
        A list with one entry per recorded frame. Each entry is a list of
        ``(frame_index, center, wrist, finger_tips)`` tuples, one per hand that
        triggered on that frame. ``center`` and ``wrist`` are ``(x, y)`` pixel
        tuples; ``finger_tips`` is a 2x3 array holding landmarks 4 and 8 with
        x and y scaled to pixels.

    Raises:
        RuntimeError: If the video cannot be opened.
    """
    window_size = 20      # number of most recent hand centers that are compared
    stdev_threshold = 5   # spread (in pixels) below which the hand counts as still

    # Per-hand-slot history collected over the whole video.
    centers = [[], []]
    wrists = [[], []]
    finger_pts = [[], []]

    def calculate_stdev(coordinates):
        # Spread of the points around their own mean position.
        distances = np.linalg.norm(coordinates - np.mean(coordinates, axis=0), axis=1)
        return np.std(distances)

    def draw_circle(frame, center, RED=1):
        cv2.circle(frame, center, 20, (0, 255*RED, 255*(1-RED)), -1)

    def extract_hand_data(hands, frame):
        """Append center, wrist and finger-tip data of every detected hand."""
        results = hands.process(frame)
        if not results.multi_hand_world_landmarks:
            return frame

        height, width = frame.shape[0], frame.shape[1]
        for hand in results.multi_handedness:
            # The handedness class index (0 or 1) doubles as the storage slot.
            # NOTE(review): it is also used as the position inside
            # multi_hand_landmarks, which is ordered by detection -- confirm
            # that pairing is intended.
            hand_idx = hand.classification[0].index
            try:
                hand_landmarks = results.multi_hand_landmarks[hand_idx]
            except IndexError:
                # Only one hand was detected: fall back to slot 0.
                hand_idx = 0
                hand_landmarks = results.multi_hand_landmarks[0]

            # Landmarks 4 and 8 in normalized image coordinates.
            tip_points = np.asarray([
                [hand_landmarks.landmark[k].x, hand_landmarks.landmark[k].y, hand_landmarks.landmark[k].z]
                for k in (4, 8)
            ])

            # Both contact points, with x/y converted to pixels.
            finger_tips = np.copy(tip_points)
            finger_tips[:, 0] *= width
            finger_tips[:, 1] *= height
            finger_pts[hand_idx].append(finger_tips)

            # Hand "center" is the midpoint of the two tips.
            tip_mean = np.mean(tip_points, axis=0)
            centers[hand_idx].append((int(tip_mean[0] * width), int(tip_mean[1] * height)))

            wrist = hand_landmarks.landmark[0]
            wrists[hand_idx].append((int(wrist.x * width), int(wrist.y * height)))

        return frame

    cap = cv2.VideoCapture(videoPath)
    if not cap.isOpened():
        cap.release()
        # Raising (rather than exit()) keeps the notebook kernel alive.
        raise RuntimeError("Error: Could not open video.")

    hands = None
    writer = None
    ret_frame_point = []  # per recorded frame: [(frame, center, wrist, finger_tips), ...]
    try:
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        if out:
            writer = cv2.VideoWriter(
                "out.mp4",
                cv2.VideoWriter_fourcc(*"mp4v"),
                fps, (frame_width, frame_height)
            )

        hands = mp.solutions.hands.Hands(static_image_mode=False,
                                         max_num_hands=2,
                                         min_detection_confidence=0.5,
                                         min_tracking_confidence=0.5)

        centers_window = [[], []]
        on = [False, False]  # True while a hand is inside an already reported rest

        for x in tqdm(range(num_frames)):
            ret, frame = cap.read()
            if not ret:
                break

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = extract_hand_data(hands, frame)

            frame_level = []
            for h in range(2):
                # NOTE(review): this is the latest center ever stored for the
                # slot, so a hand that is no longer detected keeps feeding its
                # last position into the window -- confirm that is intended.
                if len(centers[h]) == 0:
                    continue
                centers_window[h].append(centers[h][-1])

                # Keep only the most recent `window_size` centers.
                if len(centers_window[h]) > window_size:
                    centers_window[h].pop(0)
                if len(centers_window[h]) != window_size:
                    continue

                stdev = calculate_stdev(np.array(centers_window[h]))
                if stdev < stdev_threshold:
                    if not on[h]:
                        on[h] = True
                        latest_center = tuple(centers_window[h][-1])
                        frame_level.append((x, latest_center, wrists[h][-1], finger_pts[h][-1]))
                        draw_circle(frame, latest_center, RED=h)
                else:
                    on[h] = False

            if frame_level:
                ret_frame_point.append(frame_level)

            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            if writer is not None:
                writer.write(frame)
    finally:
        # Always free native resources, even when processing fails midway.
        if hands is not None:
            hands.close()
        if writer is not None:
            writer.release()
        cap.release()
        cv2.destroyAllWindows()

    return ret_frame_point

In [3]:
class ContactPointMatching:

    def __init__(self):
        self.REPO_NAME = "facebookresearch/dinov2"
        self.MODEL_NAME = "dinov2_vitb14"


        self.DEFAULT_SMALLER_EDGE_SIZE = 448
        self.IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        self.IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

        self.model = torch.hub.load(repo_or_dir=self.REPO_NAME, model=self.MODEL_NAME)
        self.model.eval()

        self.ref_pca_img = None
        self.query_pca_img = None
        self.ref_tokens = None 
        self.query_tokens = None
        self.ref_mask = None
        self.query_mask = None
        self.ref_grid_size = None
        self.query_grid_size = None
        self.ref_scale = None
        self.query_scale = None
        self.heatmap = None

        self.ref_contact_pt = None
        self.pred_contact_pt = None
    
    def zero_pixel(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        val = mean[0] / std[0] + mean[1] / std[1] + mean[2]/std[2]
        return -1 * val

    def make_transform(self, smaller_edge_size: int) -> transforms.Compose:
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        interpolation_mode = transforms.InterpolationMode.BICUBIC

        return transforms.Compose([
            transforms.Resize(size=smaller_edge_size, interpolation=interpolation_mode, antialias=True),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ])


    def prepare_image(self, image: Image,
                    smaller_edge_size: float,
                    patch_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        transform = self.make_transform(int(smaller_edge_size))
        image_tensor = transform(image)
        resize_scale = image.width / image_tensor.shape[2]

        # Crop image to dimensions that are a multiple of the patch size
        height, width = image_tensor.shape[1:] # C x H x W
        cropped_width, cropped_height = width - width % patch_size, height - height % patch_size
        image_tensor = image_tensor[:, :cropped_height, :cropped_width]

        grid_size = (cropped_height // patch_size, cropped_width // patch_size) # h x w (TODO: check)
        return image_tensor, grid_size, resize_scale


    def make_foreground_mask(self, image_tensor):
        mask = torch.sum(image_tensor, dim=0)
        threshold = self.zero_pixel()
        mask = (torch.abs(mask - threshold) > 0.001).int()
        new_size = (mask.size(0) // 14, mask.size(1) // 14)
        resized_mask = torch.empty(new_size, dtype=torch.bool)
        for i in range(new_size[0]):
            for j in range(new_size[1]):
                ones = torch.sum(mask[i*14:(i+1)*14, j*14:(j+1)*14])
                if ones <= (14 * 14 * 0.8):
                    resized_mask[i, j] = False
                else:
                    resized_mask[i, j] = True
        
        mask = resized_mask.flatten()
        return mask.flatten()

    def render_patch_pca(self, ref_image: Image,
                        query_image: Image,
                        smaller_edge_size: float = 448,
                        patch_size: int = 14):
        """Compute patch tokens for both images and render PCA visualizations.

        Both images are preprocessed, masked (``make_foreground_mask``) and fed
        through ``self.model.get_intermediate_layers``. A 3-component PCA is
        fitted on the foreground tokens of both images together; every token is
        then projected, min-max normalized per component and rendered as an RGB
        image at the original image size (background patches are black).

        Side effects: sets ``ref_/query_grid_size``, ``ref_/query_scale``,
        ``ref_/query_mask``, ``ref_/query_tokens`` and ``ref_/query_pca_img``.

        Args:
            ref_image: Reference image.
            query_image: Query image.
            smaller_edge_size: Resize target passed to ``prepare_image``.
            patch_size: Patch side length passed to ``prepare_image``.
        """

        def to_pca_image(projected, mask, grid_size, output_size):
            # Min-max normalize each PCA component to 0..255 and paint it as RGB.
            t = torch.tensor(projected)
            t_min = t.min(dim=0, keepdim=True).values
            t_max = t.max(dim=0, keepdim=True).values
            span = t_max - t_min
            # A constant component would divide by zero; map it to 0 instead.
            span = torch.where(span > 0, span, torch.ones_like(span))
            normalized_t = (t - t_min) / span

            array = (normalized_t * 255).byte().numpy()
            array[(~mask).numpy()] = 0  # black out background patches
            array = array.reshape(*grid_size, 3)
            # Resampling filter 0 is nearest-neighbour, keeping patches blocky.
            return Image.fromarray(array).resize(output_size, 0)

        ref_image_tensor, self.ref_grid_size, self.ref_scale = self.prepare_image(ref_image, smaller_edge_size, patch_size)
        query_image_tensor, self.query_grid_size, self.query_scale = self.prepare_image(query_image, smaller_edge_size, patch_size)

        print("image shape: ", end="")
        print(ref_image_tensor.shape, query_image_tensor.shape)

        self.ref_mask = self.make_foreground_mask(ref_image_tensor)
        self.query_mask = self.make_foreground_mask(query_image_tensor)

        print("mask shape: ", end="")
        print(self.ref_mask.shape, self.query_mask.shape)

        with torch.inference_mode():
            self.ref_tokens = self.model.get_intermediate_layers(ref_image_tensor.unsqueeze(0))[0].squeeze()
            self.query_tokens = self.model.get_intermediate_layers(query_image_tensor.unsqueeze(0))[0].squeeze()

        print("tokens shape: ", end="")
        print(self.ref_tokens.shape, self.query_tokens.shape)

        # Fit the PCA on foreground tokens only so background does not dominate.
        masked_tokens = torch.cat([self.ref_tokens[self.ref_mask], self.query_tokens[self.query_mask]], dim=0)
        pca = PCA(n_components=3)
        pca.fit(masked_tokens)

        self.ref_pca_img = to_pca_image(
            pca.transform(self.ref_tokens), self.ref_mask, self.ref_grid_size,
            (ref_image.width, ref_image.height))
        self.query_pca_img = to_pca_image(
            pca.transform(self.query_tokens), self.query_mask, self.query_grid_size,
            (query_image.width, query_image.height))

    def source_position_to_idx(self, row, col, grid_size, resize_scale):
        idx = ((row / resize_scale) // (14)) * grid_size[1] + ((col / resize_scale) // (14))
        return int(idx)

    def idx_to_source_position(self, idx, grid_size, resize_scale):
        row = (idx // grid_size[1])*14*resize_scale + 14 / 2
        col = (idx % grid_size[1])*14*resize_scale + 14 / 2
        return int(row), int(col)

    def closest_embedding(self, ref_embedding, query_embeddings, query_mask):
        distances = torch.norm(query_embeddings - ref_embedding, dim=1)
        dist_copy = distances.clone()
        distances[~query_mask] = float('inf')
        return torch.argmin(distances).item(), dist_copy

    def generate_heatmap(self, distances, mask, grid_size, image_size):
        distances = distances.reshape(grid_size)
        mask = mask.reshape(grid_size)
        heatmap_np = distances.numpy()
        heatmap_np *= -1
        heatmap_np = (heatmap_np - np.min(heatmap_np)) / (np.max(heatmap_np) - np.min(heatmap_np))
        heatmap_np[~mask] = 0

        cmap = plt.get_cmap('jet')
        heatmap = cmap(heatmap_np)
        heatmap_rgb = (heatmap[:, :, :3] * 255).astype(np.uint8)
        resized_heatmap = cv2.resize(heatmap_rgb, (image_size[0], image_size[1]))
        
        return resized_heatmap

    def map_ref_contact_point(self, contact_pt):
        self.ref_contact_pt = contact_pt
        idx = self.source_position_to_idx(contact_pt[0], contact_pt[1], self.ref_grid_size, self.ref_scale)
        matched_idx, distances = self.closest_embedding(self.ref_tokens[idx, :], self.query_tokens, self.query_mask)
        self.heatmap = self.generate_heatmap(distances, self.query_mask, self.query_grid_size, self.query_pca_img.size)
        row, col = self.idx_to_source_position(matched_idx, self.query_grid_size, self.query_scale)
        self.pred_contact_pt = [row, col]
        return self.pred_contact_pt
    
    def visualize(self, query_image_path, ref_image_path):
        """Display the PCA images, both marked contact points and the heatmap.

        Requires ``map_ref_contact_point`` to have been called so that
        ``ref_contact_pt``, ``pred_contact_pt`` and ``heatmap`` are set.
        """
        def mark(image, pt):
            # pt is indexed as (row, col); the ellipse box is [x0, y0, x1, y1].
            ImageDraw.Draw(image).ellipse(
                [pt[1]-5, pt[0]-5, pt[1]+5, pt[0]+5], fill=(255, 0, 0))

        query_image = Image.open(query_image_path)
        ref_image = Image.open(ref_image_path)
        mark(query_image, self.pred_contact_pt)
        mark(ref_image, self.ref_contact_pt)

        for picture in (self.ref_pca_img, self.query_pca_img, ref_image, query_image):
            display(picture)

        # Re-open the query image so the overlay does not include the red marker.
        clean_query = Image.open(query_image_path)
        overlay = cv2.addWeighted(np.array(clean_query), 0.5, self.heatmap, 0.5, 0)

        plt.imshow(overlay)
        plt.axis('off')
        plt.show()
        
    def main(self, ref_img_path, query_img_path, contact_pts):
        query_image = Image.open(query_img_path)
        ref_image = Image.open(ref_img_path)

        self.render_patch_pca(ref_image, query_image)
        
        pred_pts = []
        for i in range(contact_pts.shape[0]):
            pred_pt = self.map_ref_contact_point(contact_pts[i])
            pred_pts.append(pred_pt)

        return pred_pts

In [6]:
class TrajectoryGenerator:

    def __init__(self):
        self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.hf_model = AutoModel.from_pretrained('facebook/dinov2-base')
        self.reference_embeddings = []
        self.reference_patch_embeddings = []
        self.ref_contact_pts = []
        self.key_frames = []
        self.all_ref_masks = []
        self.query_masks = None
        self.query_patch_embeddings = None
        self.query_embeddings = None
        self.matched_query_masks = None
        self.matched_query_patch_embeddings = None


    def create_mask_images(self, original_image, masks):
        original_image_array = np.array(original_image)
        all_cropped_imgs = []

        for i in range(masks.shape[0]):
            mask = masks[i, :, :]
            if isinstance(mask, np.ndarray):
                mask = torch.from_numpy(mask)
            np_mask = mask.unsqueeze(dim=2).numpy()
            masked_img = (original_image_array * np_mask).astype(np.uint8)  # Convert to uint8
            ys, xs = np.where(mask)
            if ys.size == 0 or xs.size == 0:
                continue
            bbox = np.min(xs), np.min(ys), np.max(xs), np.max(ys)
            cropped_image_array = masked_img[bbox[1]:bbox[3]+1, bbox[0]:bbox[2]+1, :]
            cropped_image_pil = Image.fromarray(cropped_image_array)
            all_cropped_imgs.append(cropped_image_pil)

        return all_cropped_imgs

    def get_obj_masks(self, img_path, point="[[0,0]]", point_label="[0]", filter=True):
        segs = segment(img_path, point_prompt=point, point_label=point_label, filter=filter)
        print(segs.shape)
        return segs

    def get_obj_embeddings(self, img_path, masks):
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        cropped_objs = self.create_mask_images(image, masks)

        if len(cropped_objs) != 0:
            inputs = self.processor(images=cropped_objs, return_tensors="pt")
            outputs = self.hf_model(**inputs)
            last_hidden_states = outputs.last_hidden_state
            # cls_embeddings = last_hidden_states[:, 0, :].squeeze()
            cls_embeddings = last_hidden_states

            return cls_embeddings
        
        return None
    
    def save_frame_as_image(self, video_path, frame_number, save_path, point=None, radius=5, color=(0, 255, 0), thickness=-1):
        cap = cv2.VideoCapture(video_path)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, frame = cap.read()
        cap.release()
        
        # cv2.circle(frame, point, radius, color, thickness)
        cv2.imwrite(save_path, frame)

        return frame

    def closest_pt_to_mask(self, point, mask, crop_transform=False):
        mask = mask.squeeze().numpy()
        dist_transform, indices = distance_transform_edt(1 - mask, return_indices=True)
        closest_point = [indices[1, point[1], point[0]], indices[0, point[1], point[0]]]

        if crop_transform:
            ys, xs = np.where(mask)
            x_min, y_min = np.min(xs), np.min(ys)
            closest_point[0] = closest_point[0] - x_min
            closest_point[1] = closest_point[1] - y_min

        return tuple(closest_point)
    
    def get_contour_points(self, mask):
        mask_np = mask.cpu().numpy().astype(np.uint8)
        contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        points = np.vstack(contours).squeeze()
        return points

    def match_points(self, points1, points2, threshold):
        matched_points = []
        used_indices = set()
        for p1 in points1:
            min_dist = float('inf')
            match_idx = -1
            for idx, p2 in enumerate(points2):
                if idx in used_indices:
                    continue
                dist = np.linalg.norm(p1 - p2)
                if dist < min_dist and dist < threshold:
                    min_dist = dist
                    match_idx = idx
            if match_idx != -1:
                matched_points.append((p1, points2[match_idx]))
                used_indices.add(match_idx)
        return matched_points
    
    def find_centroid(self, mask):
        non_zero_coords = mask.nonzero(as_tuple=True)
        if len(non_zero_coords[0]) == 0:
            return None
        centroid_y = non_zero_coords[0].float().mean().item()
        centroid_x = non_zero_coords[1].float().mean().item()
        return int(centroid_y), int(centroid_x)

    def visualize_masks_on_frame(self, frame, masks, save_path):
        frame_np = np.array(frame)
        overlay = frame_np.copy()
        centroids = []
        colors = [
            (255, 0, 0),    # Red
            (0, 255, 0),    # Green
            (0, 0, 255),    # Blue
            (255, 255, 0),  # Yellow
            (255, 0, 255),  # Magenta
            (0, 255, 255),  # Cyan
            (255, 165, 0),  # Orange
            (128, 0, 128),  # Purple
            (128, 128, 0),  # Olive
            (0, 128, 128)   # Teal
        ]

        for idx, mask in enumerate(masks):
            color = colors[idx % len(colors)]
            mask_np = mask.cpu().numpy().astype(np.uint8)
            colored_mask = np.zeros_like(frame_np)
            for j in range(3):
                colored_mask[:, :, j] = mask_np * color[j]
            overlay = cv2.addWeighted(overlay, 1, colored_mask, 0.5, 0)
            centroid = self.find_centroid(mask)
            if centroid:
                centroids.append(centroid)
                cv2.circle(overlay, (centroid[1], centroid[0]), 8, color, -1)

        for i in range(1, len(centroids)):
            cv2.line(overlay, (centroids[i-1][1], centroids[i-1][0]), (centroids[i][1], centroids[i][0]), colors[i - 1], 4)

        overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
        result_image = Image.fromarray(overlay_rgb)
        result_image.save(save_path)

    def find_interacting_mask(self, masks, mask, unseen=None, distance_threshold=10):
        if unseen is None:
            unseen = torch.ones(masks.shape[0], dtype=torch.bool)

        max_matches = 0
        bestMatches = None
        interacting_mask = None
        mask2 = mask
        for i in range(masks.shape[0]):  
            mask1 = masks[i]
            if unseen[i] == False or torch.equal(mask1, mask2):
                unseen[i] = False
                continue 

            points1 = self.get_contour_points(mask1)
            points2 = self.get_contour_points(mask2)
            
            matches = self.match_points(points1, points2, distance_threshold)
            
            if len(matches) > max_matches:
                max_matches = len(matches)
                bestMatches = matches
                interacting_mask = mask1
        
        return interacting_mask, bestMatches

    def find_closest_mask(self, masks, point1, point2, num_to_match=10):
        min_distance = float('inf')
        min_idx = -1
        closest_mask = None
        
        for i, mask in enumerate(masks):
            contour_points = self.get_contour_points(mask)
            distances_point1 = np.linalg.norm(contour_points - np.array(point1), axis=1)
            distances_point2 = np.linalg.norm(contour_points - np.array(point2), axis=1)
            closest_distances_point1 = np.sort(distances_point1)   # [:num_to_match]
            closest_distances_point2 = np.sort(distances_point2)   # [:num_to_match]
            total_distance = np.mean(closest_distances_point1) + np.mean(closest_distances_point2)
            
            if total_distance < min_distance:
                min_distance = total_distance
                min_idx = i
                closest_mask = mask
                
        return closest_mask, min_idx
    
    def filter_arms_mask_by_points(self, masks, wrist_point):
        idx = -1
        area = float('-inf')
        for i in range(masks.shape[0]):
            if masks[i, wrist_point[0], wrist_point[1]] and torch.sum(masks[i]) > area:
                area = torch.sum(masks[i])
                idx = i

        if idx == -1:
            print("No mask that contains wrist point.")
            return masks

        mask = torch.ones(masks.shape[0], dtype=torch.bool)
        mask[idx] = False
        filtered_masks = masks[mask]

        return filtered_masks
    
    def find_interaction_chain(self, masks, mask):
        interaction_chain = []
        contact_pts_chain = []
        prev_mask = None
        matches = None
        mask_ratio_thresh = 0.5
        unseen = torch.ones(masks.shape[0], dtype=torch.bool)
        while (prev_mask is None or torch.sum(mask) > mask_ratio_thresh * torch.sum(prev_mask)):
            prev_mask = mask
            interaction_chain.append(prev_mask)
            contact_pts_chain.append(matches)
            mask, matches = self.find_interacting_mask(masks, prev_mask, unseen)
        
        return interaction_chain, contact_pts_chain
    
    def process_ref_video(self, ref_video_path):
        frames = get_frame_point(ref_video_path)
        for i, points in enumerate(frames):
            point = points[0]
            frame_num, fingers_point, wrist_point, finger_tips = point
            save_path = f"../images/current_vid_{str(i)}.jpg"
            frame = self.save_frame_as_image(ref_video_path, frame_num, save_path, fingers_point)
            # point_input = f"[[{str(fingers_point[0])},{str(fingers_point[1])}], [{str(wrist_point[0])},{str(wrist_point[1])}]]"
            
            masks = self.get_obj_masks(save_path)
            filtered_masks = self.filter_arms_mask_by_points(masks, wrist_point[::-1])
            self.all_ref_masks.append(masks)

            # find closest mask
            thumb_pt = [int(finger_tips[0, 0]), int(finger_tips[0, 1])]
            index_pt = [int(finger_tips[1, 0]), int(finger_tips[1, 1])]

            closest_mask, closest_idx = self.find_closest_mask(filtered_masks, thumb_pt, index_pt)

            # compute closest finger tips to mask

            thumb_pt = self.closest_pt_to_mask(thumb_pt, closest_mask, crop_transform=True)
            index_pt = self.closest_pt_to_mask(index_pt, closest_mask, crop_transform=True)

            self.ref_contact_pts.append([[thumb_pt[0], thumb_pt[1]], [index_pt[0], index_pt[1]]])

            # embeddings
            embedding = self.get_obj_embeddings(save_path, closest_mask.unsqueeze(dim=0))
            cls_embedding = embedding[:, 0, :].squeeze()
            self.reference_embeddings.append(cls_embedding)
            self.reference_patch_embeddings.append(embedding[:, 1:, :].squeeze())

            # find interaction chain
            interaction_chain, contact_pts_chain = self.find_interaction_chain(filtered_masks, closest_mask)


            output_path = f"./output/current_vid_{str(i)}.jpg"
            output_path2 = f"./output/current_interaction_{str(i)}.jpg"

            frame = cv2.imread(save_path)

            self.visualize_masks_on_frame(frame, interaction_chain, output_path2)
            print(f"Written interaction image to {output_path2}")


            masked_frame = self.create_mask_images(frame, closest_mask.unsqueeze(dim=0))[0]
            masked_frame_np = np.array(masked_frame)

            # saving image of only closest mask (cropped)
            cv2.imwrite(output_path, masked_frame_np)
            print(f"Written closest mask image to {output_path}")


        self.ref_contact_pts = torch.tensor(self.ref_contact_pts)
        self.reference_embeddings = torch.stack(self.reference_embeddings)
        self.reference_patch_embeddings = torch.stack(self.reference_patch_embeddings)
    
    def process_query_image(self, query_img_path):
        self.query_masks = self.get_obj_masks(query_img_path, filter=False)
        query_embeddings = self.get_obj_embeddings(query_img_path, self.query_masks)
        self.query_patch_embeddings = query_embeddings[:, 1:, :]
        self.query_embeddings = query_embeddings[:, 0, :].squeeze()    
    
    def sim_matrix(self, a, b, eps=1e-8):
        a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
        a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
        b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
        sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
        return sim_mt
    
    def match_ref_and_query(self, query_path):
        similarities = self.sim_matrix(self.query_embeddings, self.reference_embeddings)
        matched_query_masks_idx = torch.argmax(similarities, dim=0)
        self.matched_query_masks = self.query_masks[matched_query_masks_idx, :, :]
        self.matched_query_patch_embeddings = self.query_patch_embeddings[matched_query_masks_idx, :, :]

        query_img = cv2.imread(query_path)
        for i in range(self.matched_query_masks.shape[0]):
            masked_frame = self.create_mask_images(query_img, self.matched_query_masks[i:i+1])[0]
            masked_frame_np = np.array(masked_frame)

            output_path = f"./output/query_img_{str(i)}.jpg"
            cv2.imwrite(output_path, masked_frame_np)
            print(f"Written matching query image to {output_path}")

    
    def main(self, ref_path, query_img_path):
        self.process_ref_video(ref_path)
        self.process_query_image(query_img_path)
        self.match_ref_and_query(query_img_path)
        contact_matcher = ContactPointMatching()

        pred_pts = []
        num_frames = len(self.key_frames)
        for i in range(num_frames):
            ref_frame_path = f"./output/current_vid_{str(i)}.jpg"
            query_obj_path = f"./output/query_img_{str(i)}.jpg"

            contact_pts = self.ref_contact_pts[i]
            pred_pts.append(contact_matcher.main(ref_frame_path, query_obj_path, contact_pts))



In [7]:
# Run the pipeline: reference video -> key frames and contact points, then
# match against the query image. Paths are relative to the current working
# directory, which the first cell changed with %cd.
tg = TrajectoryGenerator()
tg.main("../videos/IMG_3288.MOV", "../images/query_img.jpg")

100%|██████████| 511/511 [00:29<00:00, 17.16it/s]

0: 576x1024 37 objects, 4637.7ms
Speed: 15.0ms preprocess, 4637.7ms inference, 917.3ms postprocess per image at shape (1, 3, 1024, 1024)


torch.Size([10, 1080, 1920])
Written interaction image to ./output/current_interaction_0.jpg
Written closest mask image to ./output/current_vid_0.jpg



0: 576x1024 39 objects, 5009.9ms
Speed: 14.0ms preprocess, 5009.9ms inference, 1168.7ms postprocess per image at shape (1, 3, 1024, 1024)


torch.Size([14, 1080, 1920])
Written interaction image to ./output/current_interaction_1.jpg
Written closest mask image to ./output/current_vid_1.jpg



0: 576x1024 30 objects, 5356.2ms
Speed: 15.2ms preprocess, 5356.2ms inference, 849.4ms postprocess per image at shape (1, 3, 1024, 1024)


torch.Size([16, 1080, 1920])
Written interaction image to ./output/current_interaction_2.jpg
Written closest mask image to ./output/current_vid_2.jpg



0: 576x1024 34 objects, 5494.3ms
Speed: 67.5ms preprocess, 5494.3ms inference, 1015.7ms postprocess per image at shape (1, 3, 1024, 1024)


torch.Size([7, 1080, 1920])
Written interaction image to ./output/current_interaction_3.jpg
Written closest mask image to ./output/current_vid_3.jpg



0: 576x1024 42 objects, 5195.5ms
Speed: 16.0ms preprocess, 5195.5ms inference, 1002.3ms postprocess per image at shape (1, 3, 1024, 1024)


torch.Size([15, 1080, 1920])
Written interaction image to ./output/current_interaction_4.jpg
Written closest mask image to ./output/current_vid_4.jpg



0: 576x1024 43 objects, 4695.4ms
Speed: 14.4ms preprocess, 4695.4ms inference, 1023.9ms postprocess per image at shape (1, 3, 1024, 1024)


torch.Size([14, 1080, 1920])
Written interaction image to ./output/current_interaction_5.jpg
Written closest mask image to ./output/current_vid_5.jpg



0: 576x1024 48 objects, 4829.7ms
Speed: 16.3ms preprocess, 4829.7ms inference, 1072.1ms postprocess per image at shape (1, 3, 1024, 1024)


torch.Size([16, 1080, 1920])
Written interaction image to ./output/current_interaction_6.jpg
Written closest mask image to ./output/current_vid_6.jpg



0: 576x1024 45 objects, 4891.0ms
Speed: 15.0ms preprocess, 4891.0ms inference, 976.2ms postprocess per image at shape (1, 3, 1024, 1024)


torch.Size([15, 1080, 1920])
Written interaction image to ./output/current_interaction_7.jpg
Written closest mask image to ./output/current_vid_7.jpg



0: 544x1024 28 objects, 4602.2ms
Speed: 14.1ms preprocess, 4602.2ms inference, 604.4ms postprocess per image at shape (1, 3, 1024, 1024)


torch.Size([28, 970, 1834])


Using cache found in C:\Users\arshs/.cache\torch\hub\facebookresearch_dinov2_main
