# POC


#### Image Cropping Mechanism Test

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO
import sys
import os
import random
import json
import urllib.parse
from pathlib import Path
import torch

# Configuration
class Config:
    """Static settings for the smart-crop proof of concept."""

    # Side length (pixels) of the square image produced by the final resize.
    INPUT_SIZE: int = 512

    # Input locations - ensure these match the local environment.
    RAW_IMG_DIR: Path = Path("data") / "Device_Placement"
    RAW_JSON: Path = Path("data") / "project-4-at-2025-11-19-03-05-8b131c6a.json"

    # Mock output layout; every derived folder lives under WORK_DIR.
    WORK_DIR: Path = Path("output_mock")
    IMG_OUT: Path = WORK_DIR / "images"
    MASK_OUT: Path = WORK_DIR / "masks"
    SPLIT_DIR: Path = WORK_DIR / "splits"

# Main Visualization Class
class PipelineVisualizer:
    """
    Visualizes the four stages of the Smart Crop Preprocessing for a single random image.
    This class is structurally derived from the original Preprocessor logic.

    Stages shown: (1) the raw image, (2) the YOLO person-segmentation overlay,
    (3) the square crop taken from the background-blackened image and
    (4) that crop resized to ``cfg.INPUT_SIZE``.
    """

    # Lower-case file suffixes accepted as images.
    IMAGE_EXTENSIONS = frozenset({'.png', '.jpg', '.jpeg', '.bmp', '.tif'})

    def __init__(self, cfg):
        """Store ``cfg`` and load the YOLO segmentation model; exits the process if loading fails."""
        self.cfg = cfg
        print("Initializing YOLO model for subject detection...")
        try:
            self.yolo = YOLO('yolo11n-seg.pt')
        except Exception as e:
            print(f"Error loading YOLO model: {e}")
            sys.exit(1)

    @staticmethod
    def _filename_from_url(url):
        """Return the bare file name referenced by an annotation ``image`` URL."""
        if "?d=" in url:
            url = url.split("?d=")[1]
        clean_url = urllib.parse.unquote(url).replace("\\", "/")
        return Path(clean_url).name

    @staticmethod
    def _parse_device_points(item):
        """Return the first label's polygon as an (N, 2) float array in pixel units.

        The stored points are scaled by ``original_width / 100`` and
        ``original_height / 100``. An empty (0, 2) array is returned when the
        item has no label.
        """
        if not item.get('label'):
            return np.empty((0, 2), dtype=np.float64)
        lbl = item['label'][0]
        # Force a float dtype: integer-valued JSON points would otherwise yield
        # an int array and the in-place float scaling below would raise.
        pts = np.array(lbl['points'], dtype=np.float64).reshape(-1, 2)
        pts[:, 0] *= (lbl['original_width'] / 100.0)
        pts[:, 1] *= (lbl['original_height'] / 100.0)
        return pts

    def select_random_image_data(self):
        """Selects a random image entry and loads its JSON annotation.

        Returns:
            ``(img_path, item)`` for a randomly chosen annotation whose image
            exists in ``cfg.RAW_IMG_DIR`` and has a supported suffix.

        Exits the process (status 1) when the JSON file is missing, contains no
        image entries, or none of the listed images can be found.
        """
        try:
            with open(self.cfg.RAW_JSON, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            print(f"ERROR: JSON file not found at {self.cfg.RAW_JSON}. Cannot proceed.")
            sys.exit(1)

        candidates = [x for x in data if x.get('image')]
        if not candidates:
            print("ERROR: JSON file is empty or missing image entries.")
            sys.exit(1)

        # A single shuffle visits every entry at most once in random order,
        # replacing the repeated random.choice + list.remove scan.
        random.shuffle(candidates)
        for item in candidates:
            fname = self._filename_from_url(item['image'])
            img_path = self.cfg.RAW_IMG_DIR / fname
            if img_path.exists() and img_path.suffix.lower() in self.IMAGE_EXTENSIONS:
                print(f"Selected Random Image: {fname}")
                return img_path, item

        print(f"ERROR: None of the images listed in JSON were found in {self.cfg.RAW_IMG_DIR}")
        sys.exit(1)

    def visualize(self):
        """Run the crop pipeline on one random image and plot its four stages."""
        img_path, item = self.select_random_image_data()

        # STAGE 1: REAL IMAGE
        img_bgr = cv2.imread(str(img_path))
        if img_bgr is None:
            # cv2.imread signals an unreadable/corrupt file by returning None.
            print(f"ERROR: Could not read image file {img_path}.")
            sys.exit(1)
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        h, w = img_rgb.shape[:2]

        raw_img = img_rgb.copy()

        # STAGE 2: YOLO SEGMENTED IMAGE
        results = self.yolo(img_rgb, verbose=False, retina_masks=True)
        person_mask = np.zeros((h, w), dtype=np.uint8)

        if results[0].masks:
            boxes = results[0].boxes
            persons = boxes.cls == 0
            if persons.any():
                # Keep the class-0 detection with the largest box area.
                idx = torch.argmax(boxes.xywh[persons, 2] * boxes.xywh[persons, 3])
                real_idx = persons.nonzero(as_tuple=True)[0][idx]
                m = results[0].masks.data[real_idx].cpu().numpy()
                if m.shape != (h, w):
                    m = cv2.resize(m, (w, h))
                person_mask = (m > 0.5).astype(np.uint8)

        # (h, w, 1) boolean mask; broadcasts over the three colour channels.
        is_person = person_mask[:, :, None] > 0

        # Create a visual overlay for the YOLO segmentation
        red_mask = np.zeros_like(img_rgb)
        red_mask[:, :, 0] = 255
        yolo_segmented_img = np.where(
            is_person,
            cv2.addWeighted(img_rgb, 0.7, red_mask, 0.3, 0),
            img_rgb
        )

        # STAGE 3 & 4: CROPPED AND FINAL RESIZE

        # Parse Device Points
        device_points = self._parse_device_points(item)

        # Determine Boundaries
        y_head_top = int(np.argmax(np.any(person_mask, axis=1))) if person_mask.any() else 0
        y_device_bottom = int(np.max(device_points[:, 1])) if len(device_points) > 0 else h

        # Calculate Square Dim and Center
        roi_height = max(y_device_bottom - y_head_top, h // 3)
        square_dim = max(int(roi_height * 1.5), 256)
        center_y = y_head_top + (roi_height // 2)

        # Horizontal centre = mask-weighted mean column inside the ROI rows.
        mask_slice = person_mask[max(0, y_head_top):min(h, y_device_bottom), :]
        cols = np.sum(mask_slice, axis=0)
        center_x = int(np.dot(np.arange(w), cols) / np.sum(cols)) if np.sum(cols) > 0 else w // 2

        # Apply Crop and Padding
        half = square_dim // 2
        x1, y1 = center_x - half, center_y - half
        x2, y2 = x1 + square_dim, y1 + square_dim

        # Replicate black background logic
        clean_img = np.where(is_person, img_rgb, np.zeros_like(img_rgb))

        # Pad with black wherever the square extends past the image borders.
        pad_l, pad_t = max(0, -x1), max(0, -y1)
        pad_r, pad_b = max(0, x2 - w), max(0, y2 - h)

        padded_img = cv2.copyMakeBorder(clean_img, pad_t, pad_b, pad_l, pad_r, cv2.BORDER_CONSTANT, value=[0, 0, 0])

        # Shift the crop origin into the padded image's coordinate frame.
        cx1, cy1 = x1 + pad_l, y1 + pad_t
        cropped_img_roi = padded_img[cy1:cy1 + square_dim, cx1:cx1 + square_dim]

        # STAGE 4: 512x512 FINAL INPUT
        final_input_512 = cv2.resize(cropped_img_roi, (self.cfg.INPUT_SIZE, self.cfg.INPUT_SIZE), interpolation=cv2.INTER_AREA)

        # VISUALIZATION
        self._plot_stages(raw_img, yolo_segmented_img, cropped_img_roi, final_input_512, Path(img_path).name)

    def _plot_stages(self, raw, yolo_seg, cropped_roi, final_512, filename):
        """Show the four stage images side by side; sizes in titles are width x height."""
        # Select the style before the figure exists so it applies to this figure.
        plt.style.use('default')
        plt.figure(figsize=(20, 6))

        titles = [
            f"1. Real Image\n({raw.shape[1]}x{raw.shape[0]})",
            "2. YOLO Segmented (Subject Isolation)",
            f"3. Smart Crop ROI\n({cropped_roi.shape[1]}x{cropped_roi.shape[0]})",
            f"4. Model Input\n({final_512.shape[1]}x{final_512.shape[0]} - Normalized)"
        ]
        images = [raw, yolo_seg, cropped_roi, final_512]

        for position, (image, title) in enumerate(zip(images, titles), start=1):
            plt.subplot(1, 4, position)
            plt.imshow(image)
            plt.title(title, fontsize=12, fontweight='bold')
            plt.axis('off')

        plt.suptitle(f"Smart Crop Pipeline Stages for: {filename}", fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show()

# Execution Block
if __name__ == "__main__":

    # CRITICAL ACTION: Update these paths to match your local setup.
    CFG_VIZ = Config()

    # Demo paths
    CFG_VIZ.RAW_IMG_DIR = Path("/kaggle/input/patch-placement/Work/InnerGize/Datasets/Device_Placement")
    CFG_VIZ.RAW_JSON = Path("/kaggle/input/patch-placement/Work/project-4-at-2025-11-19-03-05-8b131c6a.json")

    print("Starting Smart Crop Visualization Script...")

    # Report only the inputs that are actually absent, so a present path is
    # never labelled "Missing".
    missing_inputs = [
        (label, path)
        for label, path in (
            ("Image Directory", CFG_VIZ.RAW_IMG_DIR),
            ("JSON File", CFG_VIZ.RAW_JSON),
        )
        if not path.exists()
    ]
    if missing_inputs:
        print("\n[SETUP ERROR] Please configure the file paths correctly.")
        for label, path in missing_inputs:
            print(f"Missing {label}: {path.resolve()}")
        sys.exit(1)

    viz_pipeline = PipelineVisualizer(CFG_VIZ)
    viz_pipeline.visualize()

# Final Training Pipeline

In [None]:
import os
import sys
import json
import cv2
import numpy as np
import torch
import random
import warnings
import urllib.parse
import matplotlib.pyplot as plt
import gc
import time
import logging
import shutil
from pathlib import Path
from dataclasses import dataclass, field
from tqdm.auto import tqdm
from sklearn.model_selection import KFold
from datetime import datetime

# Dependency Installation
import subprocess

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("SETUP")

logger.info("Setting up Environment...")


def _pip(*args):
    """Run ``pip`` for the interpreter executing this code and return its exit status.

    ``sys.executable -m pip`` guarantees the packages land in the environment
    this process imports from (a bare ``pip`` on PATH may belong to another
    interpreter). A non-zero exit status is logged instead of being ignored.
    """
    status = subprocess.run([sys.executable, "-m", "pip", *args], check=False).returncode
    if status != 0:
        logger.warning("pip %s exited with status %d", " ".join(args), status)
    return status


# Fix OpenCV Conflict
# NOTE(review): cv2 is already imported at the top of this cell, so the
# reinstalled build presumably only takes effect after a kernel restart - confirm.
_pip("uninstall", "-y", "opencv-python", "opencv-contrib-python", "opencv-python-headless")
_pip("install", "-q", "opencv-python-headless==4.10.0.84")

# Install ML Libraries
_pip("install", "-q", "numpy<2.0", "ultralytics>=8.0.0",
     "segmentation-models-pytorch", "albumentations", "torchmetrics")

# Import the ML stack installed above. If the first attempt raises ImportError,
# re-run the `site` module's initialisation and try the same imports once more.
try:
    import segmentation_models_pytorch as smp
    from ultralytics import YOLO
    import albumentations as albu
    from albumentations.pytorch import ToTensorV2
    import torchmetrics
except ImportError:
    # NOTE(review): presumably site.main() refreshes sys.path so that packages
    # installed moments ago become importable without a restart - confirm.
    import site
    site.main()
    # Second attempt; an ImportError raised here is not handled and propagates.
    import segmentation_models_pytorch as smp
    from ultralytics import YOLO
    import albumentations as albu
    from albumentations.pytorch import ToTensorV2
    import torchmetrics

from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Suppress every Python warning from here on (this also hides deprecation notices).
warnings.filterwarnings("ignore")

# Configuration & Logging Setup
@dataclass
class Config:
    """Settings for the training pipeline.

    The four ``*_OUT`` / ``*_DIR`` fields declared with ``field(init=False)``
    are derived from ``WORK_DIR`` in ``__post_init__``, which also creates
    those directories on disk.
    """

    # Paths
    RAW_JSON: Path = Path("/kaggle/input/patch-placement/Work/project-4-at-2025-11-19-03-05-8b131c6a.json")
    RAW_IMG_DIR: Path = Path("/kaggle/input/patch-placement/Work/InnerGize/Datasets/Device_Placement")
    WORK_DIR: Path = Path("/kaggle/working/prod_pipeline_v1")

    # Model
    ARCH: str = "UnetPlusPlus"
    ENCODER: str = "resnet34"
    INPUT_SIZE: int = 512

    # Training
    FOLDS: int = 5
    BATCH_SIZE: int = 12
    LR: float = 1e-4
    EPOCHS: int = 100
    PATIENCE: int = 15
    # Evaluated once, when the class body runs.
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
    SEED: int = 42

    # MLOps
    # Timestamp is taken once, when the class body runs, not per instance.
    EXP_NAME: str = "exp_" + datetime.now().strftime("%Y%m%d_%H%M")

    # Derived Paths (filled in by __post_init__)
    IMG_OUT: Path = field(init=False)
    MASK_OUT: Path = field(init=False)
    MODEL_DIR: Path = field(init=False)
    LOG_DIR: Path = field(init=False)

    def __post_init__(self):
        """Derive the output locations from WORK_DIR and make sure they exist."""
        base = self.WORK_DIR
        self.IMG_OUT = base / "images"
        self.MASK_OUT = base / "masks"
        self.MODEL_DIR = base / "models"
        self.LOG_DIR = base / "logs" / self.EXP_NAME

        for directory in (self.IMG_OUT, self.MASK_OUT, self.MODEL_DIR, self.LOG_DIR):
            directory.mkdir(parents=True, exist_ok=True)

# Global configuration instance; constructing it runs Config.__post_init__.
CFG = Config()

def setup_logger(cfg):
    """Configures a file logger and a console logger.

    Every handler currently attached to the root logger is detached, then the
    root logger is configured at INFO level with a file handler writing to
    ``cfg.LOG_DIR / "training.log"`` and a stream handler on ``sys.stdout``.

    Args:
        cfg: object exposing ``LOG_DIR`` (a path supporting the ``/`` operator).

    Returns:
        The logger named ``"TRAINER"``.
    """
    log_file = cfg.LOG_DIR / "training.log"

    root = logging.getLogger()
    # Iterate over a snapshot: removeHandler() mutates root.handlers, and
    # removing while iterating the live list skips entries. Any handler left
    # behind makes the basicConfig() call below a silent no-op.
    for handler in list(root.handlers):
        root.removeHandler(handler)
        if isinstance(handler, logging.FileHandler):
            # Release the file held by a previously configured file handler.
            handler.close()

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout)
        ]
    )
    return logging.getLogger("TRAINER")

# Module-wide logger used by the functions below.
LOGGER = setup_logger(CFG)

def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also stores the seed in the ``PYTHONHASHSEED`` environment variable.
    """
    LOGGER.info(f"Seeding with {seed}...")
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all random generators with the configured seed before anything else runs.
seed_everything(CFG.SEED)

# Unit Tests
def run_unit_tests(cfg):
    """Runs sanity checks on model and data pipeline before training starts.

    Test 1 pushes a random batch through an untrained UNet++ and checks the
    output shape. Test 2 runs the training augmentations on a random image and
    a random binary mask and checks both output shapes.

    Args:
        cfg: configuration exposing ``ENCODER`` and ``INPUT_SIZE``.

    Raises:
        Exception: whatever a failing check raised, re-raised after logging.
    """
    LOGGER.info("Running Pre-Flight Unit Tests...")
    size = cfg.INPUT_SIZE

    # Test 1: Model Architecture
    try:
        # Pretrained weights are irrelevant for a shape check, so skip loading them.
        model = smp.UnetPlusPlus(encoder_name=cfg.ENCODER, encoder_weights=None, in_channels=3, classes=1)
        model.eval()
        dummy_in = torch.randn(2, 3, size, size)
        with torch.no_grad():
            dummy_out = model(dummy_in)
        expected_shape = (2, 1, size, size)
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if tuple(dummy_out.shape) != expected_shape:
            raise AssertionError(f"expected {expected_shape}, got {tuple(dummy_out.shape)}")
        LOGGER.info("    [PASS] Model Architecture Output Shape")
    except Exception as e:
        LOGGER.error(f"    [FAIL] Model Architecture: {e}")
        raise

    # Test 2: Augmentation Pipeline
    try:
        aug = get_transforms('train')
        # randint's upper bound is exclusive: 256 covers the full uint8 range
        # and 2 yields a real 0/1 mask (an upper bound of 1 only produces zeros).
        dummy_img = np.random.randint(0, 256, (size, size, 3), dtype=np.uint8)
        dummy_mask = np.random.randint(0, 2, (size, size), dtype=np.uint8)
        res = aug(image=dummy_img, mask=dummy_mask)
        if tuple(res['image'].shape) != (3, size, size):
            raise AssertionError(f"unexpected image shape {tuple(res['image'].shape)}")
        if tuple(res['mask'].shape) != (size, size):
            raise AssertionError(f"unexpected mask shape {tuple(res['mask'].shape)}")
        LOGGER.info("    [PASS] Augmentation Pipeline")
    except Exception as e:
        LOGGER.error(f"    [FAIL] Augmentations: {e}")
        raise

    LOGGER.info("All Unit Tests Passed. Proceeding to Pipeline.")

# Data Pipeline
class ComboLoss(nn.Module):
    """Equal-weight sum of a binary Dice loss and BCE-with-logits loss."""

    def __init__(self):
        super().__init__()
        self.dice = smp.losses.DiceLoss(mode='binary', from_logits=True)
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, y_pred, y_true):
        """Return ``0.5 * dice + 0.5 * bce`` for predictions ``y_pred`` against ``y_true``."""
        dice_term = self.dice(y_pred, y_true)
        bce_term = self.bce(y_pred, y_true)
        return 0.5 * dice_term + 0.5 * bce_term

class Preprocessor:
    """Builds the cropped training images and masks from the raw annotated photos.

    For each annotated photo a person is segmented with YOLO, a square region
    is placed using the person mask and the annotated device polygon, and the
    region is cropped (with black padding where needed) and resized to
    ``cfg.INPUT_SIZE``. Images go to ``cfg.IMG_OUT``, masks to ``cfg.MASK_OUT``.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg

    def process_data(self):
        """Generate the dataset if it is (nearly) absent and return the usable file names.

        Generation runs when ``cfg.IMG_OUT`` holds fewer than 5 entries.

        Returns:
            Sorted names of files present in ``cfg.IMG_OUT`` that also have a
            same-named mask in ``cfg.MASK_OUT``.
        """
        existing_files = list(self.cfg.IMG_OUT.iterdir())
        if len(existing_files) < 5:
            LOGGER.info(f"Starting Smart Crop Pipeline...")
            self._generate_images()
        else:
            LOGGER.info(f"Data found ({len(existing_files)} images). Skipping generation.")

        names = sorted(f.name for f in self.cfg.IMG_OUT.iterdir())
        # An image whose mask is missing would make the dataset raise during
        # training, so leave it out here and say so.
        paired = [name for name in names if (self.cfg.MASK_OUT / name).exists()]
        if len(paired) != len(names):
            LOGGER.warning(f"Ignoring {len(names) - len(paired)} image(s) without a matching mask.")
        return paired

    def _generate_images(self):
        """Run the smart-crop pipeline over every annotation that references an image."""
        LOGGER.info("Loading YOLO for Smart Cropping...")
        yolo = YOLO('yolo11n-seg.pt')

        with open(self.cfg.RAW_JSON, 'r', encoding='utf-8') as f:
            data = json.load(f)
        valid_items = [x for x in data if x.get('image')]

        success_count = 0
        for item in tqdm(valid_items, desc="Cropping"):
            # Best effort by design: one bad item must not abort the whole run.
            try:
                if self._process_item(yolo, item):
                    success_count += 1
                    if success_count % 10 == 0:
                        gc.collect()
            except Exception as e:
                LOGGER.warning(f"Skipped image {item.get('image')!r} due to error: {e}")
                continue

        LOGGER.info(f"Smart Crop finished: {success_count}/{len(valid_items)} images written.")

        del yolo
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _source_filename(url):
        """Return the bare file name referenced by an annotation ``image`` URL."""
        if "?d=" in url:
            url = url.split("?d=")[1]
        clean_url = urllib.parse.unquote(url).replace("\\", "/")
        return Path(clean_url).name

    @staticmethod
    def _largest_person_mask(yolo, img_rgb):
        """Return a uint8 0/1 mask (image-sized) of the class-0 detection with the largest box."""
        h, w = img_rgb.shape[:2]

        # Safety Resize: run inference on a copy whose longest side is capped at 1500 px.
        infer_img = img_rgb.copy()
        if max(h, w) > 1500:
            scale = 1500 / max(h, w)
            infer_img = cv2.resize(img_rgb, (0, 0), fx=scale, fy=scale)

        with torch.no_grad():
            results = yolo(infer_img, verbose=False, retina_masks=False)

        person_mask = np.zeros((h, w), dtype=np.uint8)
        if results[0].masks:
            boxes = results[0].boxes
            persons = boxes.cls == 0
            if persons.any():
                idx = torch.argmax(boxes.xywh[persons, 2] * boxes.xywh[persons, 3])
                real_idx = persons.nonzero(as_tuple=True)[0][idx]
                m = results[0].masks.data[real_idx].cpu().numpy()
                # Bring the mask back to the original image resolution.
                m = cv2.resize(m, (w, h))
                person_mask = (m > 0.5).astype(np.uint8)
        return person_mask

    def _process_item(self, yolo, item):
        """Crop one annotated image and write its image/mask pair.

        Returns:
            True when both files were written, False when the source image is
            missing or unreadable.

        Raises:
            IOError: if OpenCV reports that an output file could not be written.
        """
        fname = self._source_filename(item['image'])
        img_path = self.cfg.RAW_IMG_DIR / fname
        if not img_path.exists():
            return False

        img_bgr = cv2.imread(str(img_path))
        if img_bgr is None:
            return False
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        h, w = img_rgb.shape[:2]

        person_mask = self._largest_person_mask(yolo, img_rgb)
        torch.cuda.empty_cache()

        # Ground-truth mask from the first label's polygon.
        gt_mask_full = np.zeros((h, w), dtype=np.uint8)
        device_points = np.empty((0, 2), dtype=np.float64)
        if item.get('label'):
            lbl = item['label'][0]
            # Force a float dtype: integer-valued JSON points would otherwise
            # give an int array and the in-place float scaling would raise.
            device_points = np.array(lbl['points'], dtype=np.float64).reshape(-1, 2)
            device_points[:, 0] *= (lbl['original_width'] / 100.0)
            device_points[:, 1] *= (lbl['original_height'] / 100.0)
            if len(device_points) > 0:
                cv2.fillPoly(gt_mask_full, [device_points.astype(np.int32)], 1)

        # Top of the person and mask-weighted horizontal centre.
        if np.sum(person_mask) > 0:
            rows = np.any(person_mask, axis=1)
            y_head_top = int(np.argmax(rows))
            cols = np.sum(person_mask, axis=0)
            center_x = int(np.dot(np.arange(w), cols) / (np.sum(cols) + 1e-6))
        else:
            y_head_top = 0
            center_x = w // 2

        y_device_bottom = int(np.max(device_points[:, 1])) if len(device_points) > 0 else h
        roi_height = y_device_bottom - y_head_top
        if roi_height < 50:
            roi_height = h // 3

        square_dim = max(int(roi_height * 1.5), 256)
        center_y = y_head_top + (roi_height // 2)

        half = square_dim // 2
        x1, y1 = center_x - half, center_y - half
        x2, y2 = x1 + square_dim, y1 + square_dim

        # Pad with black wherever the square extends past the image borders.
        pad_l, pad_t = max(0, -x1), max(0, -y1)
        pad_r, pad_b = max(0, x2 - w), max(0, y2 - h)

        padded_img = cv2.copyMakeBorder(img_rgb, pad_t, pad_b, pad_l, pad_r, cv2.BORDER_CONSTANT, value=[0, 0, 0])
        padded_mask = cv2.copyMakeBorder(gt_mask_full, pad_t, pad_b, pad_l, pad_r, cv2.BORDER_CONSTANT, value=0)

        # Shift the crop window into the padded image's coordinate frame.
        cx1, cy1 = x1 + pad_l, y1 + pad_t
        cx2, cy2 = cx1 + square_dim, cy1 + square_dim

        fin_img = padded_img[cy1:cy2, cx1:cx2]
        fin_mask = padded_mask[cy1:cy2, cx1:cx2]

        target = (self.cfg.INPUT_SIZE, self.cfg.INPUT_SIZE)
        fin_img = cv2.resize(fin_img, target, interpolation=cv2.INTER_AREA)
        # Nearest-neighbour keeps the mask strictly binary.
        fin_mask = cv2.resize(fin_mask, target, interpolation=cv2.INTER_NEAREST)

        out_name = Path(fname).stem + ".png"
        img_ok = cv2.imwrite(str(self.cfg.IMG_OUT / out_name), cv2.cvtColor(fin_img, cv2.COLOR_RGB2BGR))
        mask_ok = cv2.imwrite(str(self.cfg.MASK_OUT / out_name), fin_mask * 255)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not (img_ok and mask_ok):
            raise IOError(f"Failed to write outputs for {out_name}")
        return True

class SegmentationDataset(Dataset):
    """Dataset of image/mask pairs stored under ``cfg.IMG_OUT`` and ``cfg.MASK_OUT``.

    Args:
        file_list: file names present in both folders.
        cfg: configuration exposing ``IMG_OUT`` and ``MASK_OUT``.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            and returning a mapping with ``'image'`` and ``'mask'``.
    """

    def __init__(self, file_list, cfg, transform=None):
        self.files = file_list
        self.cfg = cfg
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for item ``idx``; the mask is a float tensor with a leading channel axis, divided by 255.

        Raises:
            ValueError: if either file cannot be read or their sizes differ.
        """
        fname = self.files[idx]
        img_path = str(self.cfg.IMG_OUT / fname)
        mask_path = str(self.cfg.MASK_OUT / fname)

        image = cv2.imread(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # Data Validation - must run before cvtColor, which raises its own
        # error when handed the None that imread returns for unreadable files.
        if image is None:
            raise ValueError(f"Failed to load image: {img_path}")
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")
        if image.shape[:2] != mask.shape[:2]:
            raise ValueError(f"Shape mismatch: {fname}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            res = self.transform(image=image, mask=mask)
            image, mask = res['image'], res['mask']

        # No-op for a tensor mask; converts the NumPy mask when no transform
        # (or a non-tensor transform) was applied.
        mask = torch.as_tensor(mask)
        return image, mask.unsqueeze(0).float() / 255.0

def get_transforms(phase):
    """Return the albumentations pipeline for ``phase``.

    ``'train'`` gets the augmentations followed by normalisation and tensor
    conversion; any other value gets normalisation and tensor conversion only.
    """
    finishing_steps = [
        albu.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    if phase != 'train':
        return albu.Compose(finishing_steps)

    augmentations = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(scale_limit=0.15, rotate_limit=15, shift_limit=0.1, border_mode=0, p=0.7),
        albu.RandomBrightnessContrast(p=0.4),
        albu.HueSaturationValue(p=0.3),
        albu.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
    ]
    return albu.Compose(augmentations + finishing_steps)

# Training & Metric Tracking
class MetricTracker:
    """In-memory stand-in for a remote experiment tracker.

    Records are accumulated in ``self.data`` and written to
    ``<log_dir>/metrics.json`` by :meth:`save`.
    """

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self.data = []

    def log(self, epoch, train_loss, val_iou, fold):
        """Append one record, stamped with the current local time in ISO format."""
        record = dict(
            fold=fold,
            epoch=epoch,
            train_loss=train_loss,
            val_iou=val_iou,
            timestamp=datetime.now().isoformat(),
        )
        self.data.append(record)

    def save(self):
        """Write every record collected so far as indented JSON."""
        destination = self.log_dir / "metrics.json"
        with open(destination, "w") as handle:
            json.dump(self.data, handle, indent=4)

def train_fold(fold_idx, train_files, val_files, cfg, tracker):
    """Train one cross-validation fold and return its best validation IoU.

    The best-scoring weights are saved to
    ``cfg.MODEL_DIR / f"model_fold_{fold_idx}.pth"``. Training stops early once
    ``cfg.PATIENCE`` consecutive epochs bring no improvement.

    Args:
        fold_idx: zero-based fold number (used in logs and the checkpoint name).
        train_files: file names for the training split.
        val_files: file names for the validation split.
        cfg: pipeline configuration.
        tracker: object with a ``log(epoch, train_loss, val_iou, fold)`` method.

    Returns:
        The highest validation IoU observed (0.0 if no epoch ran).
    """
    LOGGER.info(f"STARTING FOLD {fold_idx+1}/{cfg.FOLDS} (Train: {len(train_files)}, Val: {len(val_files)})")

    train_ds = SegmentationDataset(train_files, cfg, get_transforms('train'))
    val_ds = SegmentationDataset(val_files, cfg, get_transforms('val'))

    train_loader = DataLoader(train_ds, batch_size=cfg.BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=cfg.BATCH_SIZE, shuffle=False, num_workers=2)

    model = smp.UnetPlusPlus(encoder_name=cfg.ENCODER, encoder_weights="imagenet", in_channels=3, classes=1, activation=None).to(cfg.DEVICE)

    optimizer = optim.AdamW(model.parameters(), lr=cfg.LR, weight_decay=1e-3)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2, eta_min=1e-6)
    criterion = ComboLoss()
    metric = torchmetrics.JaccardIndex(task="binary").to(cfg.DEVICE)

    # CUDA mixed precision only makes sense on a CUDA device; on CPU the
    # autocast context and the scaler are disabled and act as pass-throughs.
    use_amp = str(cfg.DEVICE).startswith('cuda')
    scaler = GradScaler(enabled=use_amp)

    best_iou = 0.0
    checkpoint_saved = False
    patience_counter = 0

    for epoch in range(cfg.EPOCHS):
        model.train()
        t_loss = 0

        for imgs, masks in train_loader:
            imgs, masks = imgs.to(cfg.DEVICE), masks.to(cfg.DEVICE)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                preds = model(imgs)
                loss = criterion(preds, masks)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            t_loss += loss.item()

        # Validation
        model.eval()
        metric.reset()
        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs, masks = imgs.to(cfg.DEVICE), masks.to(cfg.DEVICE)
                with autocast(enabled=use_amp):
                    preds = model(imgs)
                metric.update(preds.sigmoid() > 0.5, (masks > 0.5).int())

        val_iou = metric.compute().item()
        # The cosine schedule depends on epoch count only; the validation
        # metric must not be mixed into its position.
        scheduler.step()

        # Track Metrics
        t_loss_avg = t_loss / max(len(train_loader), 1)
        tracker.log(epoch, t_loss_avg, val_iou, fold_idx)

        if (epoch+1) % 10 == 0:
            LOGGER.info(f"    Ep {epoch+1}: Loss {t_loss_avg:.4f} | IoU {val_iou:.4f}")

        # Always checkpoint the first epoch so a weights file exists even if
        # the IoU never rises above its initial 0.0.
        if val_iou > best_iou or not checkpoint_saved:
            best_iou = max(best_iou, val_iou)
            torch.save(model.state_dict(), cfg.MODEL_DIR / f"model_fold_{fold_idx}.pth")
            checkpoint_saved = True
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= cfg.PATIENCE:
                LOGGER.info(f"    Early stopping at epoch {epoch+1}")
                break

    del model, optimizer, scaler
    torch.cuda.empty_cache()
    gc.collect()
    return best_iou

# Utils: Export & TTA
def export_to_onnx(cfg, fold_idx=0):
    """Export one fold's checkpoint to ``cfg.WORK_DIR / "device_segmentation.onnx"``.

    Args:
        cfg: pipeline configuration (``ENCODER``, ``INPUT_SIZE``, ``MODEL_DIR``, ``WORK_DIR``).
        fold_idx: which ``model_fold_<n>.pth`` checkpoint to export (default 0).

    Failures are logged rather than raised; nothing is returned.
    """
    LOGGER.info("Exporting Best Model to ONNX...")
    weight_path = cfg.MODEL_DIR / f"model_fold_{fold_idx}.pth"

    # Check for the checkpoint before building the network.
    if not weight_path.exists():
        LOGGER.error("Model weights not found for export.")
        return

    # encoder_weights=None: the checkpoint loaded below overwrites every
    # weight, so fetching pretrained encoder weights would be wasted work.
    model = smp.UnetPlusPlus(encoder_name=cfg.ENCODER, encoder_weights=None, in_channels=3, classes=1).to("cpu")
    model.load_state_dict(torch.load(weight_path, map_location="cpu"))
    model.eval()

    dummy_input = torch.randn(1, 3, cfg.INPUT_SIZE, cfg.INPUT_SIZE)
    out_path = cfg.WORK_DIR / "device_segmentation.onnx"

    try:
        torch.onnx.export(
            model, dummy_input, str(out_path),
            input_names=["input"], output_names=["output"],
            opset_version=11,
            # Only the batch dimension is dynamic; spatial size stays fixed.
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
        )
        LOGGER.info(f"Export successful: {out_path}")
    except Exception as e:
        LOGGER.error(f"Export failed: {e}")

def predict_with_tta(model, image_tensor):
    """Test Time Augmentation.

    Averages the sigmoid of the model output on the input with the
    (un-flipped) sigmoid output on the input flipped along dimension 3.
    Runs without gradient tracking.
    """
    flip_dims = [3]
    with torch.no_grad():
        plain = model(image_tensor).sigmoid()
        mirrored_input = torch.flip(image_tensor, flip_dims)
        mirrored = model(mirrored_input).sigmoid()
        restored = torch.flip(mirrored, flip_dims)
    return (plain + restored) / 2.0

# Main Execution
if __name__ == "__main__":
    try:
        # 1. Run Unit Tests
        run_unit_tests(CFG)

        # 2. Preprocess
        prep = Preprocessor(CFG)
        all_files = np.array(prep.process_data())

        # 3. Setup Experiment Tracker
        tracker = MetricTracker(CFG.LOG_DIR)

        # 4. Cross Validation
        kf = KFold(n_splits=CFG.FOLDS, shuffle=True, random_state=CFG.SEED)
        fold_scores = []

        LOGGER.info(f"Starting Training: {CFG.FOLDS} Folds | {CFG.ARCH} | {CFG.ENCODER}")

        for fold, (train_idx, val_idx) in enumerate(kf.split(all_files)):
            score = train_fold(fold, all_files[train_idx], all_files[val_idx], CFG, tracker)
            fold_scores.append(score)
            LOGGER.info(f"Fold {fold+1} Result: {score:.4f}")

        # 5. Save Metrics
        tracker.save()
        LOGGER.info(f"AVERAGE IoU: {np.mean(fold_scores):.4f}")
        LOGGER.info(f"Metrics saved to {CFG.LOG_DIR}/metrics.json")

        # 6. Export
        export_to_onnx(CFG)

        # 7. Visual Validation
        LOGGER.info("Generating Visual Report...")
        # encoder_weights=None: the fold-0 checkpoint loaded next overwrites
        # every weight, so pretrained encoder weights are not needed.
        model = smp.UnetPlusPlus(encoder_name=CFG.ENCODER, encoder_weights=None, in_channels=3, classes=1).to(CFG.DEVICE)
        # map_location keeps loading independent of the device the
        # checkpoint's tensors were saved from.
        model.load_state_dict(torch.load(CFG.MODEL_DIR / "model_fold_0.pth", map_location=CFG.DEVICE))
        model.eval()

        # KFold has a fixed random_state, so the first split generated here is
        # the same validation split fold 0 was evaluated on.
        _, val_idx = next(kf.split(all_files))
        val_ds = SegmentationDataset(all_files[val_idx], CFG, get_transforms('val'))
        val_loader = DataLoader(val_ds, batch_size=3, shuffle=True)

        imgs, masks = next(iter(val_loader))
        imgs = imgs.to(CFG.DEVICE)
        preds = predict_with_tta(model, imgs)

        # Plot
        imgs = imgs.cpu().numpy()
        masks = masks.cpu().numpy()
        preds = preds.cpu().numpy()

        # Constants to undo the normalisation for display.
        mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(len(imgs), 3, figsize=(10, 3*len(imgs)), squeeze=False)

        for i in range(len(imgs)):
            viz_img = np.clip((imgs[i] * std + mean).transpose(1, 2, 0), 0, 1)
            axes[i, 0].imshow(viz_img)
            axes[i, 0].set_title("Input")
            axes[i, 1].imshow(masks[i].squeeze(), cmap='gray')
            axes[i, 1].set_title("Truth")
            axes[i, 2].imshow(preds[i].squeeze() > 0.5, cmap='jet')
            axes[i, 2].set_title("Pred (TTA)")
            for ax in axes[i]:
                ax.axis('off')

        plt.tight_layout()
        plt.show()
        LOGGER.info("Pipeline Completed Successfully.")

    except Exception as e:
        # Log with the traceback, then re-raise the original exception unchanged.
        LOGGER.critical(f"Critical Pipeline Failure: {e}", exc_info=True)
        raise

# Code to use the model for placement prediction

In [None]:
import os
import subprocess
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from ultralytics import YOLO
import urllib.parse

# Install Inference Dependencies
# The install has to run BEFORE onnxruntime is imported: importing first fails
# with ImportError on an environment where the package is not yet present.
# `sys.executable -m pip` targets the interpreter running this code.
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "onnxruntime-gpu"], check=False)

import onnxruntime as ort

class PatchPredictor:
    """Runs the exported ONNX segmentation model on smart-cropped photos.

    A YOLO segmentation model locates the person, a square region below the
    top of the person mask is cropped and resized to ``img_size``, the ONNX
    model predicts a probability map, and the thresholded result is drawn on
    the original image.
    """

    def __init__(self, onnx_model_path, yolo_model_path='yolo11n-seg.pt'):
        """Create the ONNX Runtime session (CUDA first, CPU as fallback) and load YOLO."""
        print(f"Loading ONNX Model: {onnx_model_path}...")

        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        try:
            self.session = ort.InferenceSession(str(onnx_model_path), providers=providers)
        except Exception:
            print("CUDA not found. Falling back to CPU.")
            self.session = ort.InferenceSession(str(onnx_model_path), providers=['CPUExecutionProvider'])

        self.input_name = self.session.get_inputs()[0].name
        self.img_size = 512

        print("Loading YOLO for Smart Cropping...")
        self.yolo = YOLO(yolo_model_path)

    def preprocess(self, image_path):
        """Load an image, smart-crop it and build the network input.

        Returns:
            ``(blob, meta)``: ``blob`` is a float32 array of shape
            ``(1, 3, img_size, img_size)``; ``meta`` holds ``'padded_img'``,
            ``'crop_coords'`` as ``(cx1, cy1, cx2, cy2)`` in padded-image
            coordinates and ``'pad_info'`` as ``(pad_t, pad_l, h, w)``.

        Raises:
            ValueError: if the image cannot be read.
        """
        img_bgr = cv2.imread(str(image_path))
        if img_bgr is None:
            raise ValueError(f"Image not found: {image_path}")
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        h, w = img_rgb.shape[:2]

        # Smart Crop: inference runs on a copy capped at 1500 px on its longest side.
        infer_img = img_rgb.copy()
        if max(h, w) > 1500:
            scale = 1500 / max(h, w)
            infer_img = cv2.resize(img_rgb, (0, 0), fx=scale, fy=scale)

        results = self.yolo(infer_img, verbose=False, retina_masks=False)

        person_mask = np.zeros((h, w), dtype=np.uint8)
        if results[0].masks:
            boxes = results[0].boxes
            persons = boxes.cls == 0
            if persons.any():
                # Pick the person with the largest box area - the same rule the
                # training-time Preprocessor uses - so that inference crops
                # match the crops the model was trained on.
                areas = boxes.xywh[persons, 2] * boxes.xywh[persons, 3]
                idx = areas.argmax()
                real_idx = persons.nonzero(as_tuple=True)[0][idx]
                m = results[0].masks.data[real_idx].cpu().numpy()
                m = cv2.resize(m, (w, h))
                person_mask = (m > 0.5).astype(np.uint8)

        if np.sum(person_mask) > 0:
            rows = np.any(person_mask, axis=1)
            y_head_top = int(np.argmax(rows))
            cols = np.sum(person_mask, axis=0)
            # Mask-weighted mean column.
            center_x = int(np.dot(np.arange(w), cols) / (np.sum(cols) + 1e-6))
        else:
            y_head_top = 0
            center_x = w // 2

        # No device annotation exists at inference time, so the region height
        # is fixed at one third of the image height.
        roi_height = h // 3
        square_dim = max(int(roi_height * 1.5), 256)
        center_y = y_head_top + (roi_height // 2)

        half = square_dim // 2
        x1, y1 = center_x - half, center_y - half
        x2, y2 = x1 + square_dim, y1 + square_dim

        # Pad with black wherever the square extends past the image borders.
        pad_l, pad_t = max(0, -x1), max(0, -y1)
        pad_r, pad_b = max(0, x2 - w), max(0, y2 - h)

        padded_img = cv2.copyMakeBorder(img_rgb, pad_t, pad_b, pad_l, pad_r, cv2.BORDER_CONSTANT, value=[0, 0, 0])

        cx1, cy1 = x1 + pad_l, y1 + pad_t
        cx2, cy2 = cx1 + square_dim, cy1 + square_dim

        crop_img = padded_img[cy1:cy2, cx1:cx2]
        input_img = cv2.resize(crop_img, (self.img_size, self.img_size), interpolation=cv2.INTER_AREA)

        # Normalization (same mean/std constants as the training transforms)
        norm_img = input_img.astype(np.float32) / 255.0
        norm_img = (norm_img - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

        # HWC -> CHW, then add the batch axis.
        blob = np.transpose(norm_img, (2, 0, 1))
        blob = np.expand_dims(blob, axis=0)

        meta = {
            'padded_img': padded_img,
            'crop_coords': (cx1, cy1, cx2, cy2),
            'pad_info': (pad_t, pad_l, h, w)
        }

        return blob.astype(np.float32), meta

    def predict_raw(self, input_tensor):
        """Run the ONNX model and return the sigmoid of the first output's ``[0, 0]`` plane."""
        outputs = self.session.run(None, {self.input_name: input_tensor})
        logits = outputs[0][0, 0, :, :]
        # Numerically stable sigmoid: exp() only ever sees non-positive
        # arguments, so large-magnitude logits cannot overflow.
        decay = np.exp(-np.abs(logits))
        probs = np.where(logits >= 0, 1 / (1 + decay), decay / (1 + decay))
        return probs

    def visualize(self, image_path):
        """Predict the patch for one image and display the overlay.

        Thresholds 0.5, 0.3 and 0.15 are tried in that order; the first one
        yielding more than 50 positive pixels is used. Any error is reported
        and the image is skipped.
        """
        # Accept plain strings as well as Path objects (``.name`` is used below).
        image_path = Path(image_path)
        try:
            input_tensor, meta = self.preprocess(image_path)

            # Get Probabilities
            probs = self.predict_raw(input_tensor)
            max_conf = np.max(probs)
            print(f"    Max Conf: {max_conf:.4f}")

            # Adaptive Thresholding
            mask = None
            for thresh in (0.5, 0.3, 0.15):
                temp_mask = (probs > thresh).astype(np.uint8)
                if np.sum(temp_mask) > 50:  # Ensure at least 50 pixels are detected
                    mask = temp_mask
                    print(f"    Patch detected at threshold: {thresh}")
                    break

            padded_img = meta['padded_img']
            cx1, cy1, cx2, cy2 = meta['crop_coords']
            pad_t, pad_l, h, w = meta['pad_info']

            overlay = padded_img.copy()

            if mask is not None:
                # Resize mask to crop size
                crop_h, crop_w = cy2 - cy1, cx2 - cx1
                real_scale_mask = cv2.resize(mask, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST)

                # Place in full image
                full_mask = np.zeros((padded_img.shape[0], padded_img.shape[1]), dtype=np.uint8)
                full_mask[cy1:cy2, cx1:cx2] = real_scale_mask

                # Draw Green Overlay
                color_mask = np.zeros_like(padded_img)
                color_mask[:, :] = [0, 255, 0]

                overlay = np.where(full_mask[:, :, None] == 1,
                                   cv2.addWeighted(padded_img, 0.7, color_mask, 0.3, 0),
                                   padded_img)

                # Draw White Contour
                contours, _ = cv2.findContours(full_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(overlay, contours, -1, (255, 255, 255), 3)
            else:
                print("    No patch detected.")
                # Draw Text on image indicating failure
                cv2.putText(overlay, f"No Patch Detected (Max Conf: {max_conf:.2f})",
                            (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 3)

            # Remove padding
            final_view = overlay[pad_t:pad_t+h, pad_l:pad_l+w]

            plt.figure(figsize=(12, 12))
            plt.imshow(final_view)
            plt.axis('off')
            plt.title(f"File: {image_path.name} | Conf: {max_conf:.2f}", fontsize=14)
            plt.show()

        except Exception as e:
            # Best effort: one failing image must not stop a batch run.
            print(f"    Skipping {image_path.name}: {str(e)}")

if __name__ == "__main__":
    ONNX_PATH = Path("/kaggle/working/prod_pipeline_v1/device_segmentation.onnx")

    # Update this path to your target directory
    RAW_DIR = Path("/kaggle/input/sample3")

    if not ONNX_PATH.exists():
        # Nothing to run without an exported model.
        print("Please run training first.")
    else:
        predictor = PatchPredictor(ONNX_PATH)

        # Gather all images (.jpg, .png, .jpeg) in one sorted list
        all_images = sorted(
            found
            for pattern in ("*.jpg", "*.png", "*.jpeg")
            for found in RAW_DIR.glob(pattern)
        )

        if not all_images:
            print("No images found to test.")
        else:
            total = len(all_images)
            print(f"Found {total} images in directory. Processing all...")
            for position, test_img in enumerate(all_images, start=1):
                print(f"\n[{position}/{total}] Processing: {test_img.name}")
                predictor.visualize(test_img)