# Logo and Icon Extraction Debugging

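This notebook lets you interactively run and visualize each stage of the logo-processing pipeline:
1. Background removal (`rembg`)
2. Text detection (`EasyOCR`)
3. Icon extraction (`MobileSAM`, prompted with a positive point on the icon and negative points on the detected text)

Prerequisites: a sample logo image (set `input_path` in section 1) and the `mobile_sam.pt` checkpoint in the working directory (loaded in section 4).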

In [None]:
%matplotlib inline
import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from rembg import remove
import easyocr
import torch
from mobile_sam import sam_model_registry, SamPredictor
import sys

# Set working directory to the processor directory if needed
# os.chdir('services/enrichment-processor')

# Record the runtime environment so results can be tied to specific library versions.
for _label, _value in (
    ("Python version", sys.version),
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{_label}: {_value}")

## 1. Load Image

In [None]:
# Path to a sample logo, resolved relative to the notebook's working directory.
input_path = "../../data/logos/DMP.png" # Adjust as needed
if not os.path.exists(input_path):
    print("Warning: Sample image not found. Please provide a valid path.")
    # Forget any `img` left over from an earlier run. Otherwise the downstream
    # `'img' in locals()` guards would silently reprocess a stale image.
    globals().pop("img", None)
    # You can also use a URL
    # import httpx
    # from io import BytesIO
    # response = httpx.get("URL_HERE")
    # img = Image.open(BytesIO(response.content))
else:
    img = Image.open(input_path)
    plt.imshow(img)
    plt.title("Original Image")
    plt.show()

## 2. Background Removal

In [None]:
if 'img' in locals():
    # Cut the logo out of its background.
    cutout = remove(img)
    # Trim to the non-empty region. getbbox() yields None for an empty image,
    # in which case the uncropped cutout is kept.
    bbox = cutout.getbbox()
    full_logo = cutout.crop(bbox) if bbox else cutout

    plt.imshow(full_logo)
    plt.title("Background Removed & Cropped")
    plt.show()

## 3. Text Detection

In [None]:
if 'full_logo' in locals():
    # Minimum EasyOCR confidence for a detection to be treated as text.
    # 0.0 keeps every detection. Raise it if OCR reports spurious "text" on icon
    # shapes, because every kept box becomes a negative SAM prompt in section 4.
    MIN_TEXT_CONFIDENCE = 0.0

    # Convert to BGR for OCR (cvtColor drops the alpha channel).
    # NOTE(review): transparent pixels keep whatever RGB values rembg left there,
    # so dark text on a transparent background may be hard for OCR to see.
    # Verify on dark logos.
    image_rgba = np.array(full_logo)
    image_bgr = cv2.cvtColor(image_rgba, cv2.COLOR_RGBA2BGR)
    h, w = image_bgr.shape[:2]

    # Building a Reader loads the OCR models into memory; reuse it when this cell
    # is re-run in the same kernel.
    if 'reader' not in globals():
        reader = easyocr.Reader(['en'])
    results = reader.readtext(image_bgr)

    # Outline each kept detection and rasterize it into a binary text mask.
    # The box centres are used later as negative SAM prompts.
    vis_img = image_bgr.copy()
    text_mask = np.zeros((h, w), dtype=np.uint8)
    text_centers = []

    for (bbox, text, prob) in results:
        if prob < MIN_TEXT_CONFIDENCE:
            continue
        pts = np.array(bbox, np.int32)
        cv2.polylines(vis_img, [pts], True, (0, 255, 0), 2)
        cv2.fillPoly(text_mask, [pts], 255)

        center_x = int(np.mean(pts[:, 0]))
        center_y = int(np.mean(pts[:, 1]))
        text_centers.append([center_x, center_y])

    fig, (ax_det, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
    ax_det.imshow(cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB))
    ax_det.set_title("OCR Detections")
    ax_mask.imshow(text_mask, cmap='gray')
    ax_mask.set_title("Text Mask")
    plt.show()

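## 4. Icon Extraction (SAM)

SAM receives two kinds of prompts:
- one positive point at the centroid of the largest connected region of non-text content;
- one negative point at the centre of each OCR text box.

It returns three candidate masks with quality scores.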

In [None]:
if 'full_logo' in locals():
    # Initialize SAM (the checkpoint path is resolved relative to the working directory)
    model_type = "vit_t"
    sam_checkpoint = "mobile_sam.pt"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    mobile_sam.to(device=device)
    mobile_sam.eval()
    predictor = SamPredictor(mobile_sam)

    # Prepare prompts
    # 1. Dilate the text mask so content just outside the OCR boxes still counts as text.
    kernel = np.ones((5, 5), np.uint8)
    dilated_text_mask = cv2.dilate(text_mask, kernel, iterations=2)

    # 2. Potential icon regions: any non-transparent pixel outside the dilated text.
    alpha = np.array(full_logo.split()[-1])
    content_mask = (alpha > 0).astype(np.uint8) * 255
    non_text_content = cv2.bitwise_and(content_mask, cv2.bitwise_not(dilated_text_mask))

    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(non_text_content)

    # Label 0 is the background. Place one positive prompt at the centroid of the
    # largest remaining component.
    # NOTE(review): the centroid of a non-convex component (ring, "C" shape) can
    # land outside the component itself, i.e. on background.
    icon_points = []
    if num_labels > 1:
        largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
        centroid = centroids[largest_label]
        icon_points.append([int(centroid[0]), int(centroid[1])])

    # Label 1 = positive prompt on the icon; label 0 = negative prompt on each text-box centre.
    input_points = np.array(icon_points + text_centers)
    input_labels = np.array([1]*len(icon_points) + [0]*len(text_centers))

    if len(input_points) == 0:
        # No non-text content and no OCR hits, so there is nothing to prompt SAM with.
        # An empty point array is not a valid prompt for predict().
        # Publish an empty candidate set instead of leaving masks from a previous
        # image around; the final-extraction cell then falls back to non_text_content.
        print("No SAM prompts found; skipping SAM.")
        masks = np.zeros((0, *non_text_content.shape), dtype=bool)
        scores = np.zeros(0, dtype=np.float32)
        logits = None
    else:
        # image_bgr is in OpenCV (BGR) channel order, while SamPredictor.set_image
        # assumes RGB unless image_format says otherwise.
        predictor.set_image(image_bgr, image_format="BGR")

        masks, scores, logits = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=True,
        )

        # Visualize results: one panel per candidate mask with the prompts overlaid
        # (green stars = positive, red crosses = negative).
        positive_points = input_points[input_labels == 1]
        negative_points = input_points[input_labels == 0]
        fig, axes = plt.subplots(1, len(masks), figsize=(15, 5), squeeze=False)
        for i, (ax, mask, score) in enumerate(zip(axes[0], masks, scores)):
            ax.imshow(mask, cmap='gray')
            ax.scatter(positive_points[:, 0], positive_points[:, 1], color='green', marker='*')
            ax.scatter(negative_points[:, 0], negative_points[:, 1], color='red', marker='x')
            ax.set_title(f"Mask {i} (Score: {score:.2f})")
        plt.show()

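## 5. Final Extraction

Keep the highest-scoring SAM mask in which less than 10% of the pixels are text. If no mask qualifies, fall back to the non-text content mask from section 4. Pixels outside the chosen mask are made transparent, and the result is cropped to its content.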

In [None]:
if 'masks' in locals():
    # Share of each candidate's pixels that lie inside the OCR text mask.
    # np.maximum avoids dividing by zero for an empty mask.
    text_fractions = [
        np.logical_and(m, text_mask).sum() / np.maximum(m.sum(), 1)
        for m in masks
    ]

    # Keep candidates that are under 10% text. The highest SAM score wins; the
    # earliest candidate wins a tie.
    best_mask, max_score = None, -1.0
    for m, s, frac in zip(masks, scores, text_fractions):
        if frac < 0.1 and s > max_score:
            best_mask, max_score = m, s

    if best_mask is None:
        print("Using fallback: direct text removal")
        best_mask = non_text_content > 0

    # Zero the alpha channel outside the chosen mask, then trim the transparent margins.
    # np.array copies the image data, so full_logo itself is not modified.
    final_np = np.array(full_logo)
    final_np[~best_mask, 3] = 0
    final_img = Image.fromarray(final_np)
    bbox = final_img.getbbox()
    final_img = final_img.crop(bbox) if bbox else final_img

    fig, ax = plt.subplots()
    ax.imshow(final_img)
    ax.set_title("Final Extracted Icon")
    plt.show()