# Ring Matcher Notebook
This notebook sets up the Ring Matcher pipeline:
- Installs dependencies
- Loads models (SAM, CLIP, Real-ESRGAN, YOLO)
- Uploads and indexes catalog images
- Processes a query ring image
- Performs rotation-invariant similarity matching

In [ ]:
# ===============================
# 1️⃣ Setup: Install dependencies (Run once)
# ===============================
import os
import subprocess

def install_packages():
    """Install the notebook's third-party dependencies with pip.

    Every install goes through ``sys.executable -m pip`` so the packages are
    installed for the interpreter that is running this notebook, not for
    whichever ``pip`` executable happens to be first on ``PATH``.

    The installs run in the listed order and stop at the first failure
    instead of silently continuing with a half-installed environment.

    Raises:
        subprocess.CalledProcessError: If any pip invocation exits non-zero.
    """
    import sys

    pip_install = [sys.executable, "-m", "pip", "install"]
    package_specs = [
        ["torch", "torchvision", "--index-url", "https://download.pytorch.org/whl/cpu"],
        ["git+https://github.com/facebookresearch/segment-anything.git"],
        ["basicsr", "--no-deps"],
        ["gfpgan", "facexlib"],
        ["git+https://github.com/xinntao/Real-ESRGAN.git@v0.3.0"],
        ["git+https://github.com/openai/CLIP.git"],
        ["ultralytics"],
    ]
    for spec in package_specs:
        # check=True surfaces a failed install immediately.
        subprocess.run(pip_install + spec, check=True)

# install_packages()  # Uncomment to install packages

In [ ]:
# ===============================
# 2️⃣ Imports
# ===============================
import shutil
import torch
import torchvision
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from ultralytics import YOLO
from sklearn.metrics.pairwise import cosine_similarity


In [ ]:
# Compatibility shim for code that still does
# `from torchvision.transforms.functional_tensor import ...` on torchvision
# builds that no longer ship that submodule.
#
# The import system resolves dotted module names through sys.modules, so an
# attribute on the parent package alone is not enough: the alias must also be
# registered in sys.modules. Run this cell before (re-)importing any package
# that needs the old module path.
import importlib.util
import sys

_FUNCTIONAL_TENSOR_NAME = "torchvision.transforms.functional_tensor"

# find_spec returns None only when the submodule is neither already loaded
# nor present on disk, so a real module is never shadowed.
if importlib.util.find_spec(_FUNCTIONAL_TENSOR_NAME) is None:
    import torchvision.transforms.functional as F
    torchvision.transforms.functional_tensor = F
    sys.modules[_FUNCTIONAL_TENSOR_NAME] = F
    print("Torchvision compatibility patch applied.")

In [ ]:
# ===============================
# 3️⃣ Setup Models
# ===============================
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- SAM (ViT-B) ---
# The checkpoint file must already exist locally (download it manually or
# via a script) before this cell is run.
sam_checkpoint_path = 'sam_vit_b.pth'
sam = sam_model_registry['vit_b'](checkpoint=sam_checkpoint_path)
sam = sam.to(device)
predictor = SamPredictor(sam)

# --- CLIP (ViT-B/32) ---
import clip
clip_model, clip_preprocess = clip.load('ViT-B/32', device=device)

# --- Real-ESRGAN (x4) ---
model_arch = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=4,
)
realesrgan_model = RealESRGANer(
    scale=4,
    model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
    model=model_arch,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    # Half precision is requested only when CUDA is available.
    half=torch.cuda.is_available(),
)

# --- YOLO ---
yolo_model = YOLO('yolo11m.pt')

In [ ]:
# ===============================
# 4️⃣ Helper Functions (Masking, Embeddings, Extraction)
# ===============================
def get_sam_mask(image_bgr):
    """Return a SAM mask for whatever lies at the centre of ``image_bgr``.

    The image is converted from BGR to RGB, handed to the module-level
    ``predictor``, and SAM is prompted with a single point at the centre
    pixel carrying label ``1``. Only one mask is requested
    (``multimask_output=False``) and that mask is returned.
    """
    height, width = image_bgr.shape[:2]

    rgb_image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    predictor.set_image(rgb_image)

    # Point prompts are given as (x, y), hence width before height.
    center_point = np.array([[width // 2, height // 2]])
    center_label = np.array([1])

    masks, _, _ = predictor.predict(
        point_coords=center_point,
        point_labels=center_label,
        multimask_output=False,
    )
    return masks[0]

def process_for_clip(image_bgr, mask=None):
    """Turn a BGR image into a 224x224 RGB PIL image.

    Args:
        image_bgr: Image array in BGR channel order. It is never modified;
            masking is applied to a private copy.
        mask: Optional array indexable against ``image_bgr``; every pixel
            where ``mask == 0`` is painted white before conversion. When
            ``None`` the image is used as-is.

    Returns:
        A ``PIL.Image`` resized to 224x224 with LANCZOS resampling.
    """
    if mask is not None:
        # Copy first so the caller's array is not overwritten as a side
        # effect of preparing the CLIP input.
        image_bgr = image_bgr.copy()
        image_bgr[mask == 0] = [255, 255, 255]
    img_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(img_rgb).resize((224, 224), Image.LANCZOS)

def get_embedding(pil_img):
    """Encode ``pil_img`` with CLIP and return the unit-length embedding.

    The image is run through ``clip_preprocess``, given a leading batch
    dimension of one, and moved to ``device``. The image features are
    computed without gradient tracking, divided by their norm along the last
    axis, and returned as a NumPy array on the CPU.
    """
    batch = clip_preprocess(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        features = clip_model.encode_image(batch)

    # Normalise so that dot products between embeddings are cosine similarities.
    unit_features = features / features.norm(dim=-1, keepdim=True)
    return unit_features.cpu().numpy()

In [ ]:
# ===============================
# 5️⃣ Rotation-Invariant Matching
# ===============================
def rotation_invariant_matching(query_embs, db_embs, top_k=5, threshold=0.65):
    """Rank database entries by their best similarity to any query embedding.

    Cosine similarity is computed between every row of ``query_embs`` and
    every row of ``db_embs``. For each database entry only the highest score
    across the query rows is kept. Entries whose best score is below
    ``threshold`` are discarded and the rest are sorted by descending score.

    Returns:
        A tuple ``(indices, scores)`` holding at most ``top_k`` database
        indices and their corresponding best similarity scores.
    """
    similarity = cosine_similarity(query_embs, db_embs)

    # Collapse the query axis: one best score per database entry.
    per_entry_best = similarity.max(axis=0)

    candidates = np.where(per_entry_best >= threshold)[0]
    descending = np.argsort(per_entry_best[candidates])[::-1]
    ranked = candidates[descending]

    top_indices = ranked[:top_k]
    return top_indices, per_entry_best[top_indices]

In [ ]:
# ===============================
# 6️⃣ Main Pipeline (Provide Paths)
# ===============================
# Working directories: query images go in the upload folder, catalog images
# in the catalog folder. Both are created if they do not exist yet.
upload_folder, db_folder = './uploads', './catalog'
os.makedirs(upload_folder, exist_ok=True)
os.makedirs(db_folder, exist_ok=True)

# Place the catalog images in './catalog' and the query image in './uploads'
# before running this cell.
query_path = './uploads/query_ring.jpg'

# Read the query image and convert it to RGB.
original_img = Image.open(query_path).convert('RGB')

# Process query (clean or extract)
def is_clean_catalog_image(pil_img):
    """Heuristically decide whether ``pil_img`` looks like a clean catalog shot.

    Three pixel ratios are measured on the RGB image:

    * white ratio  - share of pixels with grey level above 240
    * skin ratio   - share of pixels inside the HSV range (0,20,70)-(20,255,255)
    * object ratio - share of pixels with grey level of 245 or below

    The image counts as clean when more than 55% of it is white, fewer than
    0.5% of its pixels fall in the skin range, and the object covers less
    than 60% of the frame.
    """
    rgb = np.array(pil_img.convert('RGB'))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)

    white_ratio = np.mean(gray > 240)

    skin_mask = cv2.inRange(hsv, (0,20,70),(20,255,255))
    skin_ratio = np.mean(skin_mask > 0)

    # Same pixels an inverse binary threshold at 245 would select:
    # everything that is not brighter than 245.
    object_ratio = np.mean(gray <= 245)

    return white_ratio > 0.55 and skin_ratio < 0.005 and object_ratio < 0.6

# Choose the image that will be embedded for the query:
# - a clean catalog-style shot is used unchanged;
# - otherwise the ring is extracted with the user-supplied
#   `extracted_actual_ring` helper (it must be defined before this cell runs);
# - if extraction yields nothing, fall back to a centre square crop of the
#   original image so `extracted_pil` is never left as None.
if is_clean_catalog_image(original_img):
    extracted_pil = original_img
else:
    extracted_pil = extracted_actual_ring(query_path, predictor, realesrgan_model, yolo_model)
    if extracted_pil is None:
        print('Fallback to center crop')
        # Largest centred square that fits inside the original image.
        img_width, img_height = original_img.size
        crop_side = min(img_width, img_height)
        crop_left = (img_width - crop_side) // 2
        crop_top = (img_height - crop_side) // 2
        extracted_pil = original_img.crop(
            (crop_left, crop_top, crop_left + crop_side, crop_top + crop_side)
        )
