# Image Segmentation for Eyewear Extraction
## Overview
This project segments eyewear from the background and enhances the resulting images for display as catalogue images on the website. The pipeline relies on:
* YOLO-NAS object detection model
    * Detects the eyewear, the turntable and the background
* Segment Anything Model (SAM, Facebook Research)
    * Segments masks out of the detection boxes
* Image enhancement using ML and handcrafted CV models
    * Improves resolution and sharpness

## Project Setup
### Path setup

In [1]:
# IPython shell escape: print the visible NVIDIA GPU(s), driver/CUDA versions
# and current memory use, to confirm a CUDA device is available before the
# models are loaded onto it.
!nvidia-smi

Sun Feb 11 15:25:45 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000001:00:00.0 Off |                  Off |
| N/A   43C    P8               9W /  70W |      2MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

## Add Parameters

In [2]:
import os

In [3]:
# --- Model selection ---
OBJ_MODEL_ARCH = 'yolo_nas_l'            # super-gradients model name passed to models.get
MODEL_TYPE = "vit_h"                     # key looked up in segment_anything's sam_model_registry
SAM_MODEL_ARCH = "sam_vit_h_4b8939.pth"  # SAM checkpoint file name (presumably must match MODEL_TYPE)

# --- Data selection ---
IMAGE_TYPE = '.jpeg'                     # file-name suffix used to pick test images

# --- Paths, resolved relative to the notebook's working directory ---
WORKSPACE = os.getcwd()
MODEL_DIR, IMAGE_DIR = (
    os.path.join(WORKSPACE, "../models/pth"),
    os.path.join(WORKSPACE, "../data/NATT"),
)

In [4]:
# Echo the resolved locations so it is obvious where models are read from.
print(f"WORKSPACE: {WORKSPACE}\nMODEL DIR: {MODEL_DIR}")

WORKSPACE: /workspace/notebook
MODEL DIR: /workspace/notebook/../models/pth


### Download Necessary Models 

In [5]:
MODEL_PATH = os.path.join(MODEL_DIR, SAM_MODEL_ARCH)
if not os.path.exists(MODEL_PATH):
    !wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {MODEL_PATH}

### Imports

In [6]:
import torch
import cv2
import numpy
from matplotlib import pyplot as plt
from super_gradients.training import models
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from matplotlib import pyplot as plt
import supervision as sv

The console stream is logged into /root/sg_logs/console.log


[2024-02-11 15:25:55] INFO - crash_tips_setup.py - Crash tips is enabled. You can set your environment variable to CRASH_HANDLER=FALSE to disable it


### Load Models
* SAM
* OD - YOLO-NAS

In [7]:
# Pick the compute device once, up front: the GPU when PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = "cpu"

In [8]:
# Instantiate the SAM variant named by MODEL_TYPE from the local checkpoint at
# MODEL_PATH, then move its weights to DEVICE.
sam_model = (
    sam_model_registry[MODEL_TYPE](checkpoint=MODEL_PATH)
    .to(device=DEVICE)
)

IsADirectoryError: [Errno 21] Is a directory: '/workspace/notebook/../models/pth/sam_vit_h_4b8939.pth'

In [None]:
# COCO-pretrained YOLO-NAS detector (architecture chosen by OBJ_MODEL_ARCH),
# placed on DEVICE.
od_model = models.get(
    OBJ_MODEL_ARCH,
    pretrained_weights="coco",
).to(DEVICE)

## Load Test Images from the Test Image Directory

In [None]:
# Collect every image under IMAGE_DIR (searched recursively) whose file name
# ends with IMAGE_TYPE, ignoring case so e.g. ".JPEG" files are not skipped.
if not os.path.isdir(IMAGE_DIR):
    # os.walk silently yields nothing for a missing directory; without this
    # check the problem would only surface later as an IndexError on impaths[0].
    raise FileNotFoundError(f"Image directory not found: {IMAGE_DIR}")

_image_suffix = IMAGE_TYPE.lower()
# Sorted so that impaths[0] (used as the test image below) does not depend on
# the file system's directory-listing order.
impaths = sorted(
    os.path.join(root, file)
    for root, _dirs, files in os.walk(IMAGE_DIR)
    for file in files
    if file.lower().endswith(_image_suffix)
)

print(f"{len(impaths)} test images found.")

In [None]:
%matplotlib inline
for idx, impath in enumerate(impaths[:8]):
    img = cv2.imread(impath)
    #Show the image with matplotlib
    plt.subplot(4, 2, idx+1)
    plt.axis('off')
    plt.imshow(img)

## Object Detection 

In [None]:
# Load the first test image, resize it to a fixed 1200x720 working size and
# derive the RGB copy that is fed to the detector.
image_bgr = cv2.imread(impaths[0])
if image_bgr is None:
    # cv2.imread signals failure by returning None; fail here with the path
    # instead of with an opaque assertion inside cv2.resize.
    raise FileNotFoundError(f"Could not read image: {impaths[0]}")
# NOTE(review): cv2.resize takes (width, height); a fixed 1200x720 target
# ignores the source aspect ratio and may distort the eyewear — confirm intent.
image_bgr = cv2.resize(image_bgr, (1200, 720), interpolation=cv2.INTER_LINEAR)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

In [None]:
# Run YOLO-NAS on the RGB image with a 0.15 confidence threshold and keep the
# prediction for the first (and only) input image.
result = [*od_model.predict(image_rgb, conf=0.15)][0]

In [None]:
# Wrap the YOLO-NAS prediction in a supervision Detections object so that
# supervision's annotators can draw it.
detections = sv.Detections(
    xyxy=result.prediction.bboxes_xyxy,
    confidence=result.prediction.confidence,
    class_id=result.prediction.labels.astype(int)
)

box_annotator = sv.BoxAnnotator()

# One "<class name> <confidence>" label per detection. Read the per-field
# arrays directly instead of unpacking `for ... in detections`: the tuple
# yielded by Detections.__iter__ has changed length between supervision
# releases, so a fixed 5-way unpack breaks on upgrade.
labels = [
    f"{result.class_names[class_id]} {confidence:0.2f}"
    for class_id, confidence in zip(detections.class_id, detections.confidence)
]

# supervision annotators follow OpenCV's BGR convention and sv.plot_image
# converts BGR -> RGB for display, so annotate the BGR image; annotating
# image_rgb would display the photo with red and blue swapped.
annotated_frame = box_annotator.annotate(
    scene=image_bgr.copy(),
    detections=detections,
    labels=labels
)

In [None]:
%matplotlib inline
sv.plot_image(annotated_frame, (12, 12))

## Detection and SAM Model 

In [None]:
# Reload the first test image with the same fixed 1200x720 resize used for
# detection, and build SAM's automatic mask generator with default settings.
image_bgr = cv2.imread(impaths[0])
if image_bgr is None:
    # cv2.imread signals failure by returning None rather than raising.
    raise FileNotFoundError(f"Could not read image: {impaths[0]}")
image_bgr = cv2.resize(image_bgr, (1200, 720), interpolation=cv2.INTER_LINEAR)
mask_generator = SamAutomaticMaskGenerator(sam_model)

In [None]:
# Hand SAM an RGB copy of the resized image (OpenCV decoded it as BGR).
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

# Propose masks over the whole image; the result holds one entry per mask.
sam_result = mask_generator.generate(image_rgb)

In [None]:
%matplotlib inline
for idx, ires in enumerate(sam_result):
    start = (sam_result[idx]["bbox"][0], sam_result[idx]["bbox"][1])
    end = (sam_result[idx]["bbox"][2], sam_result[idx]["bbox"][3])
    test_img = cv2.rectangle(image_bgr, start, end, (200, 0, 200), 2)
    plt.imshow(test_img)
    plt.show()
    

In [None]:
# Convert SAM's per-mask output into supervision Detections.
detections = sv.Detections.from_sam(sam_result=sam_result)

# Paint every mask onto a copy of the source image, colouring by mask index.
mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
annotated_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)

# Show the untouched source next to the mask overlay.
sv.plot_images_grid(
    images=[image_bgr, annotated_image],
    grid_size=(1, 2),
    titles=['source image', 'segmented image'],
)