In [1]:
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import numpy as np
import cv2
import glob

# Slice end applied to VGG16's `features` stack: modules [0, FEATURE_LAYER) are kept.
FEATURE_LAYER = 16  # Layer in VGG16 to extract features
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Use the GPU when CUDA is available, otherwise the CPU
TEMPLATE_SIZE = (32, 32)  # Size of template for matching (NOTE(review): not referenced in the visible cells - confirm it is still needed)


In [2]:
# ...existing code...
import os
# --- NEW: helper to auto-crop a template (remove blank margins) ---
def auto_crop_digit(gray_tpl, pad=1):
    """Crop a grayscale template to the bounding box of its foreground.

    The foreground is whatever an inverse Otsu threshold marks as non-zero.

    Args:
        gray_tpl: 2-D grayscale image array.
        pad: number of pixels kept around the bounding box, clamped to the
            image borders.

    Returns:
        A slice of ``gray_tpl``; the unmodified input when the threshold
        finds no foreground pixel at all.
    """
    otsu_inverse = cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    mask = cv2.threshold(gray_tpl, 0, 255, otsu_inverse)[1]

    rows, cols = np.where(mask > 0)
    if len(cols) == 0:
        # Nothing to crop around: hand the template back untouched.
        return gray_tpl

    height, width = gray_tpl.shape[0], gray_tpl.shape[1]
    top = max(0, rows.min() - pad)
    bottom = min(height, rows.max() + 1 + pad)
    left = max(0, cols.min() - pad)
    right = min(width, cols.max() + 1 + pad)
    return gray_tpl[top:bottom, left:right]

In [None]:
def advanced_noise_removal(gray_img):
    """Suppress small specks and long straight lines in a grayscale image.

    Works on an inverse-Otsu binarisation of ``gray_img`` (foreground = 255):
    connected components smaller than ``min_area`` pixels are dropped, then
    anything that survives a 1x10 (vertical) or 10x1 (horizontal) opening is
    subtracted.

    Returns:
        The cleaned binary mask, inverted with ``cv2.bitwise_not`` so the
        foreground is 0 and the background 255.
    """
    otsu_inverse = cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    foreground = cv2.threshold(gray_img, 0, 255, otsu_inverse)[1]

    # Step 1: keep only components with enough pixels (label 0 is the background).
    _count, label_map, stats, _centroids = cv2.connectedComponentsWithStats(foreground)
    min_area = 20  # Adjust based on your image size
    big_enough = np.flatnonzero(stats[1:, cv2.CC_STAT_AREA] >= min_area) + 1
    mask = np.zeros_like(gray_img)
    mask[np.isin(label_map, big_enough)] = 255

    # Step 2: an opening with a long thin element keeps only straight runs of
    # that orientation; subtracting it removes them. Vertical first, then
    # horizontal on the already-updated mask.
    for kernel_size in ((1, 10), (10, 1)):
        line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
        straight_runs = cv2.morphologyEx(mask, cv2.MORPH_OPEN, line_kernel)
        mask = cv2.subtract(mask, straight_runs)

    # Flip polarity back: foreground becomes 0, background 255.
    return cv2.bitwise_not(mask)

def robust_preprocessing(img_path, noise_reduction=True, contrast_enhance=True):
    """Load a captcha as grayscale and optionally denoise / contrast-enhance it.

    Args:
        img_path: path of the image file to read.
        noise_reduction: when true, run ``advanced_noise_removal``, a bilateral
            filter and a 2x2 morphological opening.
        contrast_enhance: when true, apply CLAHE (clipLimit=2.0, 8x8 tiles).

    Returns:
        The processed single-channel image.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.

    Note:
        A later cell of this notebook defines another ``robust_preprocessing``;
        whichever cell runs last is the one in effect.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    if noise_reduction:
        # Remove lines and dots
        img = advanced_noise_removal(img)

        # Bilateral filter to reduce remaining noise while preserving edges
        img = cv2.bilateralFilter(img, 9, 75, 75)

        # Morphological opening to clean up
        kernel = np.ones((2, 2), np.uint8)
        img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)

    if contrast_enhance:
        # CLAHE for adaptive contrast enhancement
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img = clahe.apply(img)

    return img

In [120]:
# --------------------------
# 1. Load CNN (feature extractor)
# --------------------------
# Keep only the first FEATURE_LAYER modules of VGG16's convolutional `features`
# stack, move them to DEVICE and switch to eval mode.
# NOTE(review): `pretrained=True` is deprecated in recent torchvision releases in
# favour of the `weights=` argument - confirm against the installed version.
cnn = models.vgg16(pretrained=True).features[:FEATURE_LAYER].to(DEVICE).eval()

# Image preprocessing: PIL image -> float tensor, then per-channel normalisation
# with the mean/std below (presumably the ImageNet statistics the pretrained
# weights expect - verify).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def extract_feat_map(pil_img):
    """Return the truncated-VGG feature map of a PIL image, without batch axis.

    The image goes through the module-level ``transform``, gets a batch
    dimension, is moved to ``DEVICE`` and fed to ``cnn`` with gradients
    disabled. Does the same work as ``extract_features`` in the next cell.

    NOTE(review): an earlier comment here claimed a spatial stride of about 8.
    The real downsampling depends on how many pooling layers fall inside
    ``features[:FEATURE_LAYER]``; with FEATURE_LAYER = 16 that looks like two
    2x2 pools (about 4x). Confirm before relying on a fixed stride.
    """
    batch = transform(pil_img).unsqueeze(0).to(DEVICE)  # [1, 3, H, W]
    with torch.no_grad():
        return cnn(batch).squeeze(0)  # [C, Hf, Wf]

In [121]:
def extract_features(img):
    """Compute CNN features for one image.

    Applies the module-level ``transform``, adds a batch dimension, moves the
    tensor to ``DEVICE`` and runs ``cnn`` without tracking gradients.

    Returns:
        The network output with the batch dimension squeezed out: [C, H, W].
    """
    batch = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        feature_map = cnn(batch)
    return feature_map.squeeze(0)

In [122]:
# Defaults for build_box_templates(): every source template is resized to each
# base height, then rescaled again by each scale factor.
BASE_TARGET_HEIGHTS = [16, 20, 24,32]          # template heights in pixels
SCALE_FACTORS = [0.8, 0.9, 1.0, 1.1]        # enlarge / shrink templates
# Two same-sized templates whose mean squared pixel difference is below this
# value are treated as duplicates (only the first one is kept).
DEDUP_MSE_THRESH = 15.0

def build_box_templates(template_dir="templates",
                        base_heights=BASE_TARGET_HEIGHTS,
                        scale_factors=SCALE_FACTORS,
                        dedup=True,
                        mse_thresh=DEDUP_MSE_THRESH):
    """Build a multi-scale pyramid of grayscale templates for digits 0-9.

    For each digit ``d`` the images in ``<template_dir>/<d>/`` (png/jpg/jpeg)
    are min-max normalised, auto-cropped, resized to every height in
    ``base_heights`` (width follows the aspect ratio, minimum 8 px) and then
    rescaled by every factor in ``scale_factors`` (both sides minimum 8 px).

    Args:
        template_dir: directory holding one sub-directory per digit.
        base_heights: target heights in pixels.
        scale_factors: extra multiplicative scales applied to each base size.
        dedup: when true, drop templates that nearly equal an earlier one of
            the same size.
        mse_thresh: mean-squared-error threshold used by the dedup step.

    Returns:
        dict mapping ``str(d)`` to a list of ``{'img', 'h', 'w'}`` dicts.
    """
    out = {}
    for d in range(10):
        # Sorted so the first-wins dedup below gives the same result on every
        # run and file system (glob order is otherwise unspecified).
        paths = sorted(glob.glob(os.path.join(template_dir, str(d), "*.*")))
        lst = []
        n_loaded = 0  # source images actually read, reported as "originals"
        for p in paths:
            if not p.lower().endswith((".png", ".jpg", ".jpeg")):
                continue
            base = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
            if base is None:
                continue
            n_loaded += 1
            base = cv2.normalize(base, None, 0, 255, cv2.NORM_MINMAX)
            base = auto_crop_digit(base)  # tighter source template
            for bh in base_heights:
                scale0 = bh / max(1, base.shape[0])
                w0 = max(8, int(base.shape[1] * scale0))
                base_scaled = cv2.resize(base, (w0, bh), interpolation=cv2.INTER_LINEAR)
                for sf in scale_factors:
                    h2 = max(8, int(bh * sf))
                    w2 = max(8, int(w0 * sf))
                    tpl2 = cv2.resize(base_scaled, (w2, h2), interpolation=cv2.INTER_LINEAR)
                    lst.append({'img': tpl2, 'h': h2, 'w': w2})

        # Deduplicate by MSE (same size only); the earliest template wins.
        if dedup:
            kept = []
            for t in lst:
                t_pixels = t['img'].astype(np.float32)
                is_duplicate = any(
                    k['h'] == t['h'] and k['w'] == t['w']
                    and np.mean((k['img'].astype(np.float32) - t_pixels) ** 2) < mse_thresh
                    for k in kept
                )
                if not is_duplicate:
                    kept.append(t)
            lst = kept

        out[str(d)] = lst
        print(f"Digit {d}: originals={n_loaded} final_templates={len(lst)}")
    return out
def precompute_cnn_features(pixel_templates):
    """Run every pixel template through the CNN once and cache the result.

    Args:
        pixel_templates: dict mapping a digit key to a list of dicts that each
            hold a 2-D grayscale array under ``'img'``.

    Returns:
        dict with the same keys; each value is a list of
        ``{'features', 'height', 'width'}`` dicts. Templates whose feature
        extraction raises are reported via ``print`` and left out.
    """
    print("Precomputing CNN features for templates...")
    template_features = {}

    for digit, templates in pixel_templates.items():
        entries = []
        template_features[digit] = entries

        for i, template_info in enumerate(templates):
            gray = template_info['img']
            height, width = gray.shape

            # The CNN expects three channels, so replicate the gray plane.
            pil_rgb = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))

            try:
                features = extract_features(pil_rgb)
            except Exception as e:
                # Best effort: report the failure and move on to the next one.
                print(f"Error processing template {digit}_{i}: {e}")
                continue

            entries.append({
                'features': features,
                'height': height,
                'width': width
            })

        print(f"Digit {digit}: {len(template_features[digit])} templates processed")

    return template_features
# Build the pixel-template pyramid from the default "templates" directory, then
# cache the CNN feature map of every template for the matching cells.
box_templates = build_box_templates()
template_features = precompute_cnn_features(box_templates)

Digit 0: originals=5 final_templates=80
Digit 1: originals=5 final_templates=80
Digit 2: originals=5 final_templates=80
Digit 3: originals=5 final_templates=80
Digit 4: originals=5 final_templates=80
Digit 5: originals=5 final_templates=80
Digit 6: originals=5 final_templates=80
Digit 7: originals=5 final_templates=80
Digit 8: originals=5 final_templates=66
Digit 9: originals=5 final_templates=80
Precomputing CNN features for templates...
Digit 0: 80 templates processed
Digit 1: 80 templates processed
Digit 2: 80 templates processed
Digit 3: 80 templates processed
Digit 4: 80 templates processed
Digit 5: 80 templates processed
Digit 6: 80 templates processed
Digit 7: 80 templates processed
Digit 8: 66 templates processed
Digit 9: 80 templates processed


In [123]:
def _find_projection_valleys(profile, cutoff):
    """Return the centre column of every interior flat minimum below ``cutoff``."""
    valleys = []
    n = len(profile)
    start = 1
    while start < n - 1:
        # Extend over the run of equal values beginning at `start`.
        end = start
        while end + 1 < n and profile[end + 1] == profile[start]:
            end += 1
        # A valley needs a strictly higher neighbour on both sides, so runs
        # touching either border of the box are never reported.
        if (end < n - 1
                and profile[start] < cutoff
                and profile[start] < profile[start - 1]
                and profile[start] < profile[end + 1]):
            valleys.append((start + end) // 2)
        start = end + 1
    return valleys


def separate_touching_digits(boxes, gray_img, min_width=8):
    """Split suspiciously wide boxes in two at a vertical-projection valley.

    A box wider than 1.8x its height is binarised (inverse Otsu) and its
    column sums are searched for valleys: interior minima lower than 30% of
    the mean column sum. A minimum may be a single column or a flat run of
    equal columns (e.g. a blank gap several pixels wide), in which case the
    centre of the run is used. The box is cut at the middle valley, provided
    both halves are at least ``min_width`` wide.

    Args:
        boxes: iterable of ``(x1, y1, x2, y2)`` boxes.
        gray_img: 2-D grayscale image the boxes refer to.
        min_width: minimum width allowed for each half of a split.

    Returns:
        List of ``[x1, y1, x2, y2]`` boxes; a box is either passed through or
        replaced by its two halves.
    """
    separated_boxes = []

    for x1, y1, x2, y2 in boxes:
        width = x2 - x1
        height = y2 - y1

        # Degenerate (zero-height) boxes cannot be thresholded: pass them through.
        if height > 0 and width > height * 1.8:  # Likely contains multiple digits
            digit_region = gray_img[y1:y2, x1:x2]

            _, binary = cv2.threshold(digit_region, 0, 255,
                                      cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
            vertical_proj = binary.sum(axis=0, dtype=np.int64)
            cutoff = vertical_proj.mean() * 0.3  # hoisted: constant per box

            valleys = _find_projection_valleys(vertical_proj, cutoff)
            if valleys:
                # Split at the middle valley
                split_x = int(valleys[len(valleys) // 2]) + x1

                if split_x - x1 >= min_width and x2 - split_x >= min_width:
                    separated_boxes.append([x1, y1, split_x, y2])
                    separated_boxes.append([split_x, y1, x2, y2])
                    continue

        separated_boxes.append([x1, y1, x2, y2])

    return separated_boxes

In [124]:
def cnn_feature_matching(img_path, templates_features, similarity_thresh=0.85):
    """Slide every template's CNN features over the image's feature map.

    Args:
        img_path: path of the image to search.
        templates_features: dict mapping a digit key to a list of dicts with
            ``'features'`` ([C, h, w] tensor), ``'height'`` and ``'width'``
            (pixel size of the template).
        similarity_thresh: a window is reported when its cosine similarity to
            the template is strictly greater than this value.

    Returns:
        List of ``[x1, y1, x2, y2, similarity]`` in image pixel coordinates.
    """
    img = Image.open(img_path).convert('RGB')
    img_w, img_h = img.size
    img_features = extract_features(img)  # Shape: [C, H, W]
    feat_h, feat_w = img_features.shape[1], img_features.shape[2]

    detections = []
    if feat_h == 0 or feat_w == 0:
        return detections

    # Feature-cell -> pixel scale. Measured from the actual sizes rather than
    # hard-coded, because it depends on how many pooling layers the truncated
    # network contains.
    stride_x = img_w / feat_w
    stride_y = img_h / feat_h

    for template_list in templates_features.values():
        for template_feat in template_list:
            tpl_features = template_feat['features']
            tpl_h, tpl_w = tpl_features.shape[1], tpl_features.shape[2]

            if feat_h < tpl_h or feat_w < tpl_w:
                continue

            # The template vector is the same for every window: flatten it once.
            tpl_flat = tpl_features.flatten().unsqueeze(0)

            for y in range(feat_h - tpl_h + 1):
                for x in range(feat_w - tpl_w + 1):
                    img_patch = img_features[:, y:y+tpl_h, x:x+tpl_w]

                    similarity = F.cosine_similarity(
                        img_patch.flatten().unsqueeze(0),
                        tpl_flat
                    ).item()

                    if similarity > similarity_thresh:
                        # Map the window's top-left feature cell back to pixels.
                        px1 = int(round(x * stride_x))
                        py1 = int(round(y * stride_y))
                        px2 = px1 + template_feat['width']
                        py2 = py1 + template_feat['height']

                        detections.append([px1, py1, px2, py2, similarity])

    return detections

In [125]:
# ...existing code...
def refine_box(gray, x1, y1, x2, y2, min_size=3, dilate_iters=1, margin=1):
    """Tighten a raw match box around the foreground strokes inside it.

    Boxes use exclusive right/bottom ends (``gray[y1:y2, x1:x2]``), so they
    are clamped to ``[0, w]`` x ``[0, h]``; clamping to ``w-1``/``h-1`` would
    cut off the last column and row of the image.

    Args:
        gray: 2-D grayscale image.
        x1, y1, x2, y2: raw box.
        min_size: if the tightened box is narrower or shorter than this, the
            clamped input box is returned instead.
        dilate_iters: iterations of a 3x3 dilation applied to the foreground
            mask so thin parts are not cut; 0 disables it.
        margin: pixels added around the tight box (clamped to the image).

    Returns:
        ``(x1, y1, x2, y2)`` as plain ints.
    """
    h, w = gray.shape
    x1 = max(0, int(x1)); y1 = max(0, int(y1))
    x2 = min(w, int(x2)); y2 = min(h, int(y2))
    if x2 <= x1 + 1 or y2 <= y1 + 1:
        return x1, y1, x2, y2

    patch = gray[y1:y2, x1:x2]
    # Otsu inverse (assuming dark digit)
    _, bin_inv = cv2.threshold(patch, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    if dilate_iters > 0:
        k = np.ones((3, 3), np.uint8)
        bin_inv = cv2.dilate(bin_inv, k, iterations=dilate_iters)

    ys, xs = np.where(bin_inv > 0)
    if len(xs) == 0:
        return x1, y1, x2, y2

    # int(...) keeps numpy scalar types out of the returned tuple.
    ny1, ny2 = y1 + int(ys.min()), y1 + int(ys.max()) + 1
    nx1, nx2 = x1 + int(xs.min()), x1 + int(xs.max()) + 1
    if (nx2 - nx1) < min_size or (ny2 - ny1) < min_size:
        return x1, y1, x2, y2

    # Add small margin (clamped)
    nx1 = max(0, nx1 - margin)
    ny1 = max(0, ny1 - margin)
    nx2 = min(w, nx2 + margin)
    ny2 = min(h, ny2 + margin)
    return nx1, ny1, nx2, ny2

In [126]:
def nms_xyxy(boxes_scores, iou_threshold=0.3):
    """Greedy non-maximum suppression over scored boxes.

    Args:
        boxes_scores: list of ``[x1, y1, x2, y2, score]``.
        iou_threshold: a box survives a round only if its IoU with the
            current best box is strictly below this value.

    Returns:
        The kept entries of ``boxes_scores``, highest score first.
    """
    if not boxes_scores:
        return []

    table = np.array(boxes_scores, dtype=float)
    left, top, right, bottom, scores = (table[:, col] for col in range(5))
    areas = (right - left) * (bottom - top)

    remaining = scores.argsort()[::-1]  # indices, best score first
    selected = []
    while remaining.size > 0:
        best, others = remaining[0], remaining[1:]
        selected.append(best)

        # Overlap rectangle between the best box and every other candidate.
        overlap_w = np.clip(
            np.minimum(right[best], right[others]) - np.maximum(left[best], left[others]),
            0, None)
        overlap_h = np.clip(
            np.minimum(bottom[best], bottom[others]) - np.maximum(top[best], top[others]),
            0, None)
        inter = overlap_w * overlap_h
        union = areas[best] + areas[others] - inter
        iou = inter / (union + 1e-6)

        remaining = others[iou < iou_threshold]

    return [boxes_scores[idx] for idx in selected]

In [127]:
# ...existing code...
def suppress_nested_and_duplicates(dets, center_dist_thresh=6, nested_area_ratio=0.85):
    """Drop detections that duplicate, or are nested in, a better-scored one.

    Detections are visited from highest to lowest score. A detection is
    discarded when, compared with any detection already kept, either

      - both centre offsets (|dx| and |dy|) are within ``center_dist_thresh``, or
      - the overlap covers at least ``nested_area_ratio`` of the smaller of
        the two box areas.

    Args:
        dets: list of ``[x1, y1, x2, y2, score]``.

    Returns:
        The surviving detections, highest score first (``dets`` itself when
        it is empty).
    """
    if not dets:
        return dets

    def area_of(x1, y1, x2, y2):
        return max(1, (x2 - x1) * (y2 - y1))

    def is_redundant(candidate, reference):
        cx1, cy1, cx2, cy2, _ = candidate
        rx1, ry1, rx2, ry2, _ = reference

        # Near-identical centres: treat as the same object.
        dx = abs(0.5 * (cx1 + cx2) - 0.5 * (rx1 + rx2))
        dy = abs(0.5 * (cy1 + cy2) - 0.5 * (ry1 + ry2))
        if dx <= center_dist_thresh and dy <= center_dist_thresh:
            return True

        # Nesting: most of the smaller box lies inside the other one.
        overlap_w = max(0, min(cx2, rx2) - max(cx1, rx1))
        overlap_h = max(0, min(cy2, ry2) - max(cy1, ry1))
        overlap = overlap_w * overlap_h
        if overlap <= 0:
            return False
        smaller = min(area_of(cx1, cy1, cx2, cy2), area_of(rx1, ry1, rx2, ry2))
        return overlap / smaller >= nested_area_ratio

    survivors = []
    for candidate in sorted(dets, key=lambda det: det[4], reverse=True):
        if not any(is_redundant(candidate, ref) for ref in survivors):
            survivors.append(candidate)
    return survivors

In [149]:
def advanced_morphological_noise_removal(gray_img):
    """Clean a grayscale captcha with a fixed chain of morphological steps.

    Pipeline, on an inverse-Otsu mask where the foreground is 255:
      1. 2x2 elliptical opening to remove isolated specks;
      2. subtract whatever survives a 1x8, then an 8x1 rectangular opening
         (long vertical / horizontal runs);
      3. 2x2 elliptical closing to bridge small gaps;
      4. discard connected components smaller than 25 pixels;
      5. dilation with a 1x1 element.

    Returns:
        The mask inverted with ``cv2.bitwise_not`` (foreground 0, background 255).
    """
    otsu_inverse = cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    mask = cv2.threshold(gray_img, 0, 255, otsu_inverse)[1]

    # 1. Opening = erosion followed by dilation; removes small specks.
    speck_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, speck_kernel)

    # 2. Line removal: vertical pass first, horizontal pass on the result.
    for kernel_size in ((1, 8), (8, 1)):
        line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
        straight_runs = cv2.morphologyEx(mask, cv2.MORPH_OPEN, line_kernel)
        mask = cv2.subtract(mask, straight_runs)

    # 3. Closing = dilation followed by erosion; fills small gaps in strokes.
    gap_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, gap_kernel)

    # 4. Keep only components large enough to be part of a digit
    #    (label 0 is the background and is never kept).
    _count, label_map, stats, _centroids = cv2.connectedComponentsWithStats(mask)
    min_area = 25  # Minimum area for digit components
    big_enough = np.flatnonzero(stats[1:, cv2.CC_STAT_AREA] >= min_area) + 1
    kept = np.zeros_like(mask)
    kept[np.isin(label_map, big_enough)] = 255

    # 5. NOTE(review): a 1x1 structuring element should leave the mask
    #    unchanged, so this "restore thickness" step has no effect as written.
    restore_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (1, 1))
    kept = cv2.dilate(kept, restore_kernel, iterations=1)

    # Flip polarity back: foreground 0 on a 255 background.
    return cv2.bitwise_not(kept)

def robust_preprocessing(img_path, noise_reduction=True, contrast_enhance=True):
    """Load a captcha as grayscale and clean it with morphological operations.

    This definition replaces the ``robust_preprocessing`` from the earlier
    cell once this cell has run.

    Args:
        img_path: path of the image file to read.
        noise_reduction: when true (default), apply the bilateral filter,
            ``advanced_morphological_noise_removal`` and the final erosion.
        contrast_enhance: when true (default), apply CLAHE
            (clipLimit=2.0, 8x8 tiles).

    Returns:
        The processed single-channel image. With both flags false the image
        is returned exactly as loaded.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    if noise_reduction:
        # 1. Initial denoising
        img = cv2.bilateralFilter(img, 9, 75, 75)

        # 2. Advanced morphological noise removal
        img = advanced_morphological_noise_removal(img)

        # 3. Erosion meant to separate touching components.
        # NOTE(review): with a 1x1 structuring element this should leave the
        # image unchanged; use a larger kernel if real separation is wanted.
        separate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (1, 1))
        img = cv2.erode(img, separate_kernel, iterations=1)

    if contrast_enhance:
        # 4. Final contrast enhancement
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img = clahe.apply(img)

    return img

def detect_with_multiple_thresholds(gray_img, template, thresholds=[0.3, 0.4, 0.5]):
    """Template-match once and collect the hits for each threshold.

    Args:
        gray_img: image to search.
        template: template image; its shape gives the box size.
        thresholds: score cut-offs (``TM_CCOEFF_NORMED`` response >= cut-off).
            The default list is never modified.

    Returns:
        List of ``[x1, y1, x2, y2, score]``, grouped by threshold in the
        order given. A location is listed once per threshold it satisfies,
        so strong matches appear several times; deduplicate downstream
        (e.g. with NMS) if that matters.
    """
    all_detections = []
    thresholds = list(thresholds)
    if not thresholds:
        return all_detections

    # The response map does not depend on the threshold: compute it once.
    res = cv2.matchTemplate(gray_img, template, cv2.TM_CCOEFF_NORMED)
    tpl_h, tpl_w = template.shape[0], template.shape[1]

    for thresh in thresholds:
        ys, xs = np.where(res >= thresh)
        for y, x in zip(ys, xs):
            all_detections.append([x, y, x + tpl_w, y + tpl_h, res[y, x]])

    return all_detections

In [150]:
def visualize_boxes_only(img_path, boxes, out_path="boxes_only.png"):
    """Draw green rectangles for ``boxes`` on the image and save the result.

    Args:
        img_path: path of the image to annotate.
        boxes: iterable of ``(x1, y1, x2, y2)``.
        out_path: where the annotated image is written.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
        OSError: if OpenCV reports that writing ``out_path`` failed.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    for (x1, y1, x2, y2) in boxes:
        # Plain ints: coordinates may arrive as numpy scalars.
        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)

    # imwrite returns False on failure; do not claim success in that case.
    if not cv2.imwrite(out_path, img):
        raise OSError(f"Could not write image: {out_path}")
    print("Saved", out_path)

In [151]:
def cnn_feature_matching_with_edge_filtering(img_path, templates_features, similarity_thresh=0.5):
    """CNN feature matching that ignores hits centred near the image border.

    Every template is slid over the image feature map with a step of 2
    feature cells. A window is reported when its cosine similarity is
    strictly above ``similarity_thresh`` and its centre lies outside a 5%
    margin along each image edge. Reported boxes are centred on the window
    and sized 1.5x the template's pixel size, clamped to the image.

    Args:
        img_path: path of the image to search.
        templates_features: dict mapping a digit key to a list of dicts with
            ``'features'`` ([C, h, w] tensor), ``'height'`` and ``'width'``.
        similarity_thresh: minimum (exclusive) cosine similarity.

    Returns:
        List of ``[x1, y1, x2, y2, similarity]`` in pixel coordinates. A
        per-digit detection count is printed for digits with any hit.
    """
    img = Image.open(img_path).convert('RGB')
    # Size of the image that was actually fed to the CNN; no second decode needed.
    orig_w, orig_h = img.size
    img_features = extract_features(img)

    detections = []
    feat_h, feat_w = img_features.shape[1], img_features.shape[2]
    if feat_h == 0 or feat_w == 0:
        return detections

    # Feature-cell -> pixel scale along each axis.
    stride_h = orig_h / feat_h
    stride_w = orig_w / feat_w

    # Define edge margins (avoid detections too close to edges)
    edge_margin_x = int(orig_w * 0.05)  # 5% margin from edges
    edge_margin_y = int(orig_h * 0.05)

    for digit, template_list in templates_features.items():
        digit_detections = 0

        for template_feat in template_list:
            tpl_features = template_feat['features']
            tpl_h, tpl_w = tpl_features.shape[1], tpl_features.shape[2]

            if feat_h < tpl_h or feat_w < tpl_w:
                continue

            # Invariant across windows: normalise the template and size the box once.
            tpl_flat = F.normalize(tpl_features.flatten(), dim=0).unsqueeze(0)
            box_w = template_feat['width'] * 1.5
            box_h = template_feat['height'] * 1.5

            for y in range(0, feat_h - tpl_h + 1, 2):
                for x in range(0, feat_w - tpl_w + 1, 2):
                    img_patch = img_features[:, y:y+tpl_h, x:x+tpl_w]
                    img_flat = F.normalize(img_patch.flatten(), dim=0)

                    similarity = F.cosine_similarity(
                        img_flat.unsqueeze(0),
                        tpl_flat,
                        dim=1
                    ).item()

                    if similarity <= similarity_thresh:
                        continue

                    center_x = (x + tpl_w / 2) * stride_w
                    center_y = (y + tpl_h / 2) * stride_h

                    # Skip detections centred inside the edge margin.
                    if (center_x < edge_margin_x or center_x > orig_w - edge_margin_x or
                            center_y < edge_margin_y or center_y > orig_h - edge_margin_y):
                        continue

                    px1 = max(0, int(center_x - box_w / 2))
                    py1 = max(0, int(center_y - box_h / 2))
                    px2 = min(orig_w, int(center_x + box_w / 2))
                    py2 = min(orig_h, int(center_y + box_h / 2))

                    detections.append([px1, py1, px2, py2, similarity])
                    digit_detections += 1

        if digit_detections > 0:
            print(f"Digit {digit}: {digit_detections} detections")

    return detections

def detect_digits_cnn_filtered(img_path, template_features, similarity_thresh=0.6):
    """Detect digit boxes with the edge-filtered CNN matcher plus suppression.

    Raw matches from ``cnn_feature_matching_with_edge_filtering`` are pruned
    with ``nms_xyxy`` (IoU threshold 0.3) and then with
    ``suppress_nested_and_duplicates`` (centre distance threshold 15).

    Returns:
        List of ``[x1, y1, x2, y2]`` boxes ordered by ``x1`` (left to right).
    """
    raw_matches = cnn_feature_matching_with_edge_filtering(
        img_path,
        template_features,
        similarity_thresh=similarity_thresh
    )

    # Two pruning passes: IoU-based NMS first, then centre/nesting suppression.
    after_nms = nms_xyxy(raw_matches, iou_threshold=0.3)
    survivors = suppress_nested_and_duplicates(after_nms, center_dist_thresh=15)

    # Drop the score and order the boxes left to right.
    boxes = [[x1, y1, x2, y2] for x1, y1, x2, y2, _ in survivors]
    boxes.sort(key=lambda box: box[0])
    return boxes

In [None]:
# Run the filtered CNN detector on one sample captcha and save an annotated copy.
# os.path.join keeps the path valid on non-Windows systems as well.
CAPTCHA_PATH = os.path.join("captcha", "21136.png")

boxes_filtered = detect_digits_cnn_filtered(
    CAPTCHA_PATH,
    template_features,
    similarity_thresh=0.4  # looser than the function's 0.6 default
)
visualize_boxes_only(CAPTCHA_PATH, boxes_filtered, "result_cnn_filtered.png")
print(f"Filtered CNN detected {len(boxes_filtered)} boxes")

Digit 0: 12 detections
Digit 2: 10 detections
Digit 5: 2 detections
Digit 6: 13 detections
Digit 8: 1 detections
Digit 9: 8 detections
Saved result_cnn_filtered.png
Filtered CNN detected 6 boxes


: 