In [2]:
# diagnose_invert_smart.py
import os
import numpy as np
import torch
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.datasets import EMNIST
from torchvision import transforms

# Reuse the same transform logic classes (copy from your train_expert_analyze_h.py)
import torch.nn.functional as F

class SobelGrayTransform:
    """Callable transform returning a min-max normalised Sobel edge map.

    The input is filtered with the horizontal and vertical 3x3 Sobel kernels
    (zero padding of 1, so the spatial size is preserved), the gradient
    magnitude ``sqrt(gx**2 + gy**2)`` is computed, and the result is rescaled
    to approximately ``[0, 1]`` using its own min and max.
    """

    def __init__(self):
        kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        # conv2d weights are laid out as [out_channels, in_channels, kH, kW].
        self.fx = kx.view(1, 1, 3, 3)
        self.fy = ky.view(1, 1, 3, 3)

    def __call__(self, tensor):
        """Compute the normalised Sobel magnitude of ``tensor``.

        Args:
            tensor: ``torch.Tensor`` shaped ``[H, W]`` or ``[C, H, W]``.
                Multi-channel inputs are averaged over the channel axis
                before filtering; non-floating inputs are cast to float.

        Returns:
            Tensor of shape ``[1, H, W]`` on the same device as the input.

        Raises:
            TypeError: if ``tensor`` is not a ``torch.Tensor``.
            ValueError: if ``tensor`` is neither 2-D nor 3-D.
        """
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(f"SobelGrayTransform expected tensor, got {type(tensor)}")

        if tensor.ndim == 2:
            gray = tensor.unsqueeze(0)  # [1,H,W]
        elif tensor.ndim == 3:
            gray = tensor
        else:
            raise ValueError(
                f"SobelGrayTransform expected a [H,W] or [C,H,W] tensor, "
                f"got shape {tuple(tensor.shape)}"
            )

        # conv2d needs a floating dtype; integer images are promoted.
        if not gray.is_floating_point():
            gray = gray.float()
        # The kernels have a single input channel, so collapse colour
        # channels to one gray channel instead of failing inside conv2d.
        if gray.shape[0] != 1:
            gray = gray.mean(dim=0, keepdim=True)

        t = gray.unsqueeze(0)  # [1,1,H,W]
        # Match both device and dtype, otherwise conv2d rejects e.g. float64.
        fx = self.fx.to(device=t.device, dtype=t.dtype)
        fy = self.fy.to(device=t.device, dtype=t.dtype)

        gx = F.conv2d(t, fx, padding=1)
        gy = F.conv2d(t, fy, padding=1)
        mag = (gx ** 2 + gy ** 2).sqrt().squeeze(0).squeeze(0)  # (H,W)

        minv = mag.min()
        maxv = mag.max()
        # The epsilon keeps a perfectly flat magnitude map from dividing by zero.
        mag = (mag - minv) / (maxv - minv + 1e-6)
        return mag.unsqueeze(0)  # [1,H,W]

class InvertIfSmart:
    """Callable transform that inverts an image when it looks like dark-on-light.

    Two signals are used, in order:

    1. Primary: the mean of a border frame is brighter than the mean of the
       remaining centre region.
    2. Secondary (only if the primary did not fire): Otsu binarisation yields
       more white than black pixels AND the mean of the Sobel edge map is
       below ``sobel_thresh``. Without a ``sobel_transform`` the secondary
       signal never fires.

    After every call ``last_inverted`` holds whether that call inverted its
    input; it reflects only the most recent call.
    """

    def __init__(self, border_width=8, otsu_threshold=True, sobel_transform=None, sobel_thresh=0.3):
        """Store the configuration.

        Args:
            border_width: thickness in pixels of the border frame.
            otsu_threshold: enable the Otsu part of the secondary signal.
            sobel_transform: optional callable mapping a ``[1,H,W]`` tensor to
                an edge map; required for the secondary signal to fire.
            sobel_thresh: mean edge-map value below which edges count as weak.
        """
        self.border_width = border_width
        self.otsu_enabled = otsu_threshold
        self.sobel = sobel_transform
        self.sobel_thresh = sobel_thresh
        self.last_inverted = False

    @staticmethod
    def _to_gray_tensor(img):
        """Convert a tensor or PIL image to a single-channel [1,H,W] tensor."""
        # The tensor check comes first so pure-tensor pipelines never touch PIL.
        if isinstance(img, torch.Tensor):
            tensor = img.clone()
            if tensor.ndim == 2:
                tensor = tensor.unsqueeze(0)
            elif tensor.ndim == 3:
                if tensor.shape[0] != 1:
                    tensor = tensor.mean(dim=0, keepdim=True)
            else:
                raise ValueError(
                    f"InvertIfSmart expected a [H,W] or [C,H,W] tensor, "
                    f"got shape {tuple(tensor.shape)}"
                )
            return tensor
        if isinstance(img, Image.Image):
            # Non-grayscale modes (RGB, 1, P, ...) would otherwise yield an
            # array that is not a plain H x W map of 0..255 values.
            if img.mode != "L":
                img = img.convert("L")
            arr = np.array(img).astype(np.float32) / 255.0  # H,W
            return torch.from_numpy(arr).unsqueeze(0)  # [1,H,W]
        raise TypeError(f"Unsupported type {type(img)} for InvertIfSmart")

    def _border_brighter_than_center(self, tensor):
        """Return True when the border frame's mean exceeds the centre's mean."""
        _, H, W = tensor.shape
        # Clamp so at least one centre pixel remains; with H//2 an even-sized
        # image could end up with an empty centre and a NaN mean.
        bw = max(min(self.border_width, (H - 1) // 2, (W - 1) // 2), 0)
        if bw == 0:
            # No border frame to compare against.
            return False

        # NOTE: top/bottom strips span the full width and left/right strips
        # the full height, so corner pixels are counted twice in the mean.
        border_pixels = torch.cat(
            [
                tensor[:, :bw, :].reshape(-1),
                tensor[:, H - bw:, :].reshape(-1),
                tensor[:, :, :bw].reshape(-1),
                tensor[:, :, W - bw:].reshape(-1),
            ],
            dim=0,
        )
        center_pixels = tensor[:, bw:H - bw, bw:W - bw].reshape(-1)

        border_mean = float(border_pixels.mean().item())
        center_mean = float(center_pixels.mean().item())
        return border_mean > center_mean

    def _otsu_suggests_invert(self, tensor):
        """Return True when Otsu binarisation has more white than black pixels."""
        if not self.otsu_enabled:
            return False
        # Clamp before the uint8 cast so out-of-range values cannot wrap
        # around; detach so tensors that require grad can reach numpy.
        scaled = tensor.detach().squeeze(0).clamp(0.0, 1.0).cpu().numpy() * 255.0
        img_np = scaled.astype(np.uint8)
        _, bin_img = cv2.threshold(img_np, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        num_white = int((bin_img == 255).sum())
        num_black = int((bin_img == 0).sum())
        return num_white > num_black

    def _sobel_is_weak(self, tensor):
        """Return True when a Sobel transform is set and its mean is below threshold."""
        if self.sobel is None:
            return False
        mean_sobel = float(self.sobel(tensor).mean().item())
        return mean_sobel < self.sobel_thresh

    def __call__(self, img):
        """Return ``img`` as a ``[1,H,W]`` tensor, inverted if judged necessary.

        Args:
            img: ``PIL.Image`` or ``torch.Tensor`` (``[H,W]`` or ``[C,H,W]``)
                with values expected in ``[0, 1]`` for tensors.

        Returns:
            ``[1,H,W]`` tensor: ``1.0 - x`` when inversion was chosen,
            otherwise the (copied) grayscale input.

        Raises:
            TypeError: if ``img`` is neither a tensor nor a PIL image.
            ValueError: if a tensor input is not 2-D or 3-D.
        """
        tensor = self._to_gray_tensor(img)

        # primary signal: border brighter than center
        inverted = self._border_brighter_than_center(tensor)
        if not inverted:
            # secondary: Otsu suggests inversion AND Sobel is weak. The Sobel
            # map is only computed when Otsu already votes for inversion.
            inverted = self._otsu_suggests_invert(tensor) and self._sobel_is_weak(tensor)

        self.last_inverted = inverted
        return 1.0 - tensor if inverted else tensor  # [1,H,W]

def show_pair(orig, processed, out_path, title_suffix=""):
    """Save a side-by-side grayscale plot of an image before and after processing.

    Args:
        orig: tensor whose leading size-1 dimension (if any) is squeezed away
            before plotting.
        processed: tensor handled the same way as ``orig``.
        out_path: destination file path passed to ``Figure.savefig``.
        title_suffix: text used as the figure's super-title.
    """
    # detach/cpu make the conversion work for tensors that require grad or
    # live on another device; for plain CPU tensors they are no-ops.
    o = orig.detach().squeeze(0).cpu().numpy()
    p = processed.detach().squeeze(0).cpu().numpy()

    fig, axs = plt.subplots(1, 2, figsize=(4, 2))
    try:
        panels = zip(axs, (o, p), ('Original', 'After InvertIfSmart'))
        for ax, data, title in panels:
            # Fixed vmin/vmax so both panels share the same gray scale.
            ax.imshow(data, cmap='gray', vmin=0, vmax=1)
            ax.set_title(title)
            ax.axis('off')
        # Operate on this figure explicitly rather than pyplot's current one.
        fig.suptitle(title_suffix)
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

def main(data_root="data", out_dir="invert_diagnosis", max_samples=10):
    """Count how often InvertIfSmart inverts EMNIST training images and save samples.

    Scans the EMNIST ``balanced`` training split, records the indices that
    the inverter flips, saves before/after plots for the first few of them
    and writes a ``summary.txt`` with the totals.

    Args:
        data_root: directory passed to ``EMNIST`` as the dataset root.
        out_dir: directory that receives the sample plots and the summary.
        max_samples: maximum number of inverted examples to plot.
    """
    os.makedirs(out_dir, exist_ok=True)

    # dataset without any transform to get raw PILs
    ds = EMNIST(data_root, split='balanced', train=True, download=True, transform=None)
    sobel = SobelGrayTransform()
    inverter = InvertIfSmart(sobel_transform=sobel, sobel_thresh=0.3)

    total = len(ds)
    inverted_indices = []
    for i in range(total):
        pil_img, _ = ds[i]
        # Only the decision matters here; the processed tensor is discarded.
        inverter(pil_img)
        if inverter.last_inverted:
            inverted_indices.append(i)
        if i % 5000 == 0:
            print(f"Scanned {i}/{total}, current inverted count {len(inverted_indices)}")

    num_inverted = len(inverted_indices)
    # Guard the ratio so an empty dataset does not raise ZeroDivisionError.
    ratio = num_inverted / total if total else 0.0
    print(f"Total inverted on train set: {num_inverted}/{total} ({ratio:.3%})")

    # sample up to max_samples inverted examples for inspection
    sample_idxs = inverted_indices[:max_samples]
    to_tensor = transforms.ToTensor()
    for rank, idx in enumerate(sample_idxs):
        pil_img, label = ds[idx]
        orig_tensor = to_tensor(pil_img)  # [1,H,W]
        # The scan did not keep outputs, so recompute the processed image once.
        processed_tensor = inverter(pil_img)
        show_pair(orig_tensor, processed_tensor,
                  out_path=os.path.join(out_dir, f"invert_sample_{rank}_idx_{idx}_lbl_{label}.png"),
                  title_suffix=f"idx={idx} label={label} inverted={inverter.last_inverted}")

    # also save stats
    with open(os.path.join(out_dir, "summary.txt"), "w", encoding="utf-8") as f:
        f.write(f"total={total}\n")
        f.write(f"inverted={num_inverted}\n")
        f.write(f"ratio={ratio:.4f}\n")
        f.write("sample_indices=" + ",".join(str(i) for i in sample_idxs) + "\n")

# Run the diagnosis only when this file is executed as a script, so that
# importing it for its classes does not start the dataset scan.
if __name__ == "__main__":
    main()


Scanned 0/112800, current inverted count 0
Scanned 5000/112800, current inverted count 167
Scanned 10000/112800, current inverted count 353
Scanned 15000/112800, current inverted count 533
Scanned 20000/112800, current inverted count 718
Scanned 25000/112800, current inverted count 911
Scanned 30000/112800, current inverted count 1100
Scanned 35000/112800, current inverted count 1267
Scanned 40000/112800, current inverted count 1467
Scanned 45000/112800, current inverted count 1659
Scanned 50000/112800, current inverted count 1859
Scanned 55000/112800, current inverted count 2043
Scanned 60000/112800, current inverted count 2229
Scanned 65000/112800, current inverted count 2408
Scanned 70000/112800, current inverted count 2604
Scanned 75000/112800, current inverted count 2800
Scanned 80000/112800, current inverted count 2952
Scanned 85000/112800, current inverted count 3118
Scanned 90000/112800, current inverted count 3275
Scanned 95000/112800, current inverted count 3458
Scanned 10000