In [1]:
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from lang_sam import LangSAM
import torch
import os
import cv2
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
import pandas as pd
import torch
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
#Configs
# path to directory to dump embeddings
run_id = 'run_1'
# number of frames to extract in a minute
frame_rate = 5
#maxID : number of products to process starting: 1001
max_id = 1999

In [3]:
def get_device_type() -> str:
    if torch.backends.mps.is_available():
        return "cpu"
    elif torch.cuda.is_available():
        return "cuda"
    else:
        # logging.warning("No GPU found, using CPU instead")
        return "cpu"

In [4]:
class ObjectDetector:
    def __init__(self):
        self.model = LangSAM(sam_type="sam2.1_hiera_tiny", gdino_type="tiny")
    
    def predict(self, image_pil, text_prompt):
        results = self.model.predict([image_pil], [text_prompt])
        return results
    
    def plot_results(self, image_pil, results, text_prompt):
        
        # Convert the mask to a numpy array
        mask = results[0]['masks'][0]

        # Plot the image
        plt.figure(figsize=(10, 10))
        plt.imshow(image_pil)

        # Plot the bounding box
        box = results[0]['boxes'][0]
        plt.gca().add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], edgecolor='red', facecolor='none', linewidth=2))

        # Plot the mask
        plt.imshow(np.ma.masked_where(mask == 0, mask), alpha=0.5, cmap='jet')

        plt.title(f"Prediction for: {text_prompt}")
        plt.axis('off')
        plt.show()

    # given an image and results, hide the mask on the image by making the values zero within the mask
    def hide_mask(self, image_pil, results):
        mask = results[0]['masks'][0]
        image_np = np.array(image_pil)
        image_np[mask == 1] = 0
        return Image.fromarray(image_np)

    def crop_image(self, image_pil, results):
        box = results[0]['boxes'][0]
        cropped_image = image_pil.crop((box[0], box[1], box[2], box[3]))
        return cropped_image
    
    def predict_and_crop_image(self, image_pil, text_prompt):
        results = self.predict(image_pil, text_prompt)
        cropped_image = self.crop_image(image_pil, results)
        return cropped_image

    def show_image(self, image_pil):
        plt.imshow(image_pil)
        plt.axis('off')
        plt.show()

    def resize(self, image_pil, long_side):
        width, height = image_pil.size
        if width > height:
            new_width = long_side
            new_height = int(long_side * height / width)
        else:
            new_height = long_side
            new_width = int(long_side * width / height)
        return image_pil.resize((new_width, new_height))
object_detector = ObjectDetector()

In [5]:
def extract_frames(video_path, frame_rate):
    video_capture = cv2.VideoCapture(video_path)
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    interval = int(fps / frame_rate)
    frame_number = 0
    frames = []
    while True:
        success, frame = video_capture.read()
        if not success:
            break
        if frame_number % interval == 0:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame_rgb)

        frame_number += 1
    video_capture.release()
    return frames

In [6]:
def toFilter(path, maxNumber= 1999):
    try:
        number = path.split('_')[0]
        return int(number) <= maxNumber
    except:
        return False

In [7]:
embeddings = {}
video_frame_map = {}
video_frame_map_cropped = {}
video_frame_map_masked = {}
def load_embeddings():
    path_to_videos = "./sourceDataVideos/"
    videos = os.listdir(path_to_videos)
    # videos = videos[:2]
    print(videos)
    count = 0
    for video in videos:
        try:
            if(not toFilter(video, max_id)):
                continue
            print("--------------------- Processing video: ", count, video)
            count += 1
            video_path = os.path.join(path_to_videos, video)
            frames = extract_frames(video_path, frame_rate)
            print("Number of frames: ", len(frames))
            id = video.split("_")[0]
            frame_number = 0
            for frame in frames:
                print("Processing frame: ", count, frame_number)
                image = Image.fromarray(frame)
                resultsMasked = object_detector.predict(image, "hand, no background, no object")
                masked_image = object_detector.hide_mask(image, resultsMasked)
                cropped_image = object_detector.predict_and_crop_image(masked_image, "object, no hand, no background")
                # inputs = embeddingProcessor(images=cropped_image, return_tensors="pt", padding=True).to(get_device_type())
                # with torch.no_grad():
                #     features = embeddingModel(**inputs)
                #     features.image_embeds = features.image_embeds
                # embeddings[id + "_" + str(frame_number)] = features
                # video_frame_map[id + "_" + str(frame_number)] = image
                video_frame_map_cropped[id + "_" + str(frame_number)] = cropped_image
                # video_frame_map_masked[id + "_" + str(frame_number)] = masked_image
                frame_number += 1
        except Exception as e:
            print("failed processing video",video, e)
load_embeddings()

['1014_ButterDelite-Biscuit_5.mp4', '1080_Sugar500g-Grains_24.mp4', '1022_Lux-Soap_30.mp4', '1030_SurfExcel-Surf_10.mp4', '1007_XXX-Soap_10.mp4', '1059_TurmericPowder-Spices_5.mp4', '1069_SmallRava1kg-Grains_40.mp4', '1001_Dove-Soap_35.mp4', '1036_Lime-Pickle_10.mp4', '1035_Mango-Pickle_10.mp4', '1039_Wheel-Surf_38.mp4', '1040_XXX-Surf_38.mp4', '1003_Lifeboy-Soap_30.mp4', '1031_Wheel-Surf_10.mp4', '1015_AllRounder-Biscuit_5.mp4', "1011_Mom'sMagic-Biscuit_11.mp4", '1083_ChanaDal500g-Grains_39.mp4', '1073_ChanaDal1kg-Grains_78.mp4', '1055_GodavariGhee-Milk_10.mp4', '1024_Dettol-Soap_10.mp4', '1021_Glucose-Biscuit_5.mp4', '1056_Yippee-Noodles_10.mp4', '1019_KrackJack-Biscuit_10.mp4', '1032_Ariel-Surf_10.mp4', '1018_Glucose-Biscuit_10.mp4', '1034_Tide-Surf_45.mp4', '1016_Oreo-Biscuit_10.mp4', '1042_Parachute-Oil_20.mp4', '1061_NaniGhee-Milk_10.mp4', '1017_KrackJack-Biscuit_5.mp4', '1047_Yippee-Noodles_14.mp4', '1066_GroundNut1kg-Grains_140.mp4', '1062_ChickenMasala-Spices_5.mp4', '1058_Efk

  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()


Predicting 1 masks
Predicted 1 masks
Processing frame:  1 1
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Processing frame:  1 2
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Processing frame:  1 3
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Processing frame:  1 4
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Processing frame:  1 5
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Processing frame:  1 6
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Processing frame:  1 7
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Processing frame:  1 8
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Processing frame:  1 9
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Processing frame:  1 10
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 

In [8]:
# images = {}
# image_map = {}
image_cropped_map = {}
# image_masked_map = {}
def load_images():
    path_to_images = "./sourceDataImages/"
    imageDirs = os.listdir(path_to_images)
    print(imageDirs)
    imageDirCount = 0
    for imageDir in imageDirs:
        try:
            if(not toFilter(imageDir, max_id)):
                continue
            print("Processing imageDir: ", imageDirCount)
            imagePaths = os.listdir(os.path.join(path_to_images, imageDir))
            count = 0
            print("Processing imageDir: ", imageDir)
            for imagePath in imagePaths:
                image = Image.open(os.path.join(path_to_images, imageDir, imagePath))
                resultsMasked = object_detector.predict(image, "hand, no background, no object")
                masked_image = object_detector.hide_mask(image, resultsMasked)
                cropped_image = object_detector.predict_and_crop_image(masked_image, "object, no hand, no background")
                # inputs = embeddingProcessor(images=cropped_image, return_tensors="pt", padding=True).to(get_device_type())
                # with torch.no_grad():
                #     features = embeddingModel(**inputs)
                #     features.image_embeds = features.image_embeds
                # images[imageDir + "_" + str(count) ] = features
                # image_map[imageDir + "_" + str(count) ] = image
                image_cropped_map[imageDir + "_" + str(count) ] = cropped_image
                # image_masked_map[imageDir + "_" + str(count)] = masked_image
                count += 1
        except:
            print("Error processing imageDir: ", imageDir)
        imageDirCount += 1
        
load_images()

['1056', '1016', '1026', '1012', '1019', '1025', '1037', '1035', '1009', '1021', '1051', '1067', '1065', '1042', '1071', '1057', '1070', '1003', '1074', '1058', '1005', '1050', '1046', '1007', '1008', '1032', '1020', '1077', '1036', '1022', '1048', '1055', '1062', '1049', '1023', '1034', '1081', '1045', '1029', '1075', '1054', '1082', '1066', '1064', '1002', '1017', '1024', '1018', '1080', '1011', '1060', '1039', '1030', '.DS_Store', '1078', '1013', '1040', '1061', '1063', '1028', '1015', '1033', '1041', '1038', '1014', '1079', '1043', '1068', '1083', '1076', '1069', '1073', '1010', '1001', '1031', '1044', '1027', '1052', '1072', '1006', '1053', '1047', '1059']
Processing imageDir:  0
Processing imageDir:  1056
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Processing imageDir:  1
Processing imageDir:  1016
Predicting 1 masks
Predicted 1 masks
Predicting 1 masks
Predicted 1 masks
Predic

In [9]:
saveDir = "./SegementedData/InputImages"
if not os.path.exists(saveDir):
    os.makedirs(saveDir)
for key, value in video_frame_map_cropped.items():
    directory = key.split("_")[0]
    number = key.split("_")[1]
    if not os.path.exists(os.path.join(saveDir, directory)):
        os.makedirs(os.path.join(saveDir, directory))
    value.save(os.path.join(os.path.join(saveDir, directory), key + ".png"))

In [10]:
saveDir = "./SegementedData/OutputImages"
if not os.path.exists(saveDir):
    os.makedirs(saveDir)
for key, value in image_cropped_map.items():
    directory = key.split("_")[0]
    number = key.split("_")[1]
    if not os.path.exists(os.path.join(saveDir, directory)):
        os.makedirs(os.path.join(saveDir, directory))
    value.save(os.path.join(os.path.join(saveDir, directory), key + ".png"))