In [35]:
import numpy as np
import rasterio as rio
from skimage.measure import label, regionprops
from skimage.morphology import binary_opening, disk
import torch

from models.rrdbnet import RRDBNet
from segment_anything import sam_model_registry, SamPredictor

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR_SCALE = 4
MIN_AREA = 16
BOX_BUFFER = 2
BATCH_BOXES = 64
MULTIMASK_OUTPUT = False

S2_RGB_TIF = "/home/tidop/masterIA/TFM_BRCD/data/filtered/s2/s2_26318.tif"
BUILD_MASK_TIF = "/home/tidop/masterIA/TFM_BRCD/data/filtered/building/building_26318.tif"
ROAD_MASK_TIF  = "/home/tidop/masterIA/TFM_BRCD/data/filtered/road/road_26318.tif"
REAL_ESRGAN_CKPT = "../checkpoints/checkpoint.tar"
SAM_CKPT = "../checkpoints/sam_vit_b_01ec64.pth"

# ---------- helpers de forma ----------
def chw01_to_hwc_uint8(chw):
    """(3,H,W)->(H,W,3) uint8 0..255"""
    assert chw.ndim == 3 and chw.shape[0] == 3, f"Esperaba CHW con C=3, tengo {chw.shape}"
    arr = np.clip(chw, 0, 1)
    arr = (arr * 255.0).round().astype(np.uint8)
    return np.moveaxis(arr, 0, -1)  # CHW -> HWC

def read_s2_rgb_chw01(path):
    """Lee S2 RGB como float01 CHW (3,H,W). Ajusta los índices a tu orden real."""
    with rio.open(path) as src:
        # si tu archivo ya trae RGB en [3,2,1], usa esos; si tiene 4 bandas RGBN, usa [4,3,2].
        if src.count >= 4:
            rgb = src.read(indexes=[4,3,2]).astype(np.float32)  # (3,H,W)
        else:
            rgb = src.read(indexes=[3,2,1]).astype(np.float32)
    rgb /= max(1e-6, rgb.max())
    assert rgb.ndim == 3 and rgb.shape[0] == 3, f"RGB esperado como CHW (3,H,W), obtuve {rgb.shape}"
    return rgb  # CHW

def read_mask_binary_hw(path):
    """Lee máscara HW binaria {0,1}."""
    with rio.open(path) as src:
        m = src.read(1)
    m = (m > 0).astype(np.uint8)
    assert m.ndim == 2, f"Máscara no es HW, shape={m.shape}"
    return m

def component_boxes_xyxy(mask01, buffer_px=2, min_area=16):
    """
    A partir de una máscara HW binaria, obtiene boxes en XYXY (x1,y1,x2,y2)
    NOTA: regionprops entrega bbox en (y1, x1, y2, x2); convertimos a XY.
    """
    m = (mask01 > 0).astype(np.uint8)
    if m.sum() == 0:
        return []
    m = binary_opening(m, disk(1)).astype(np.uint8)

    lab = label(m, connectivity=1)
    H, W = m.shape
    boxes = []
    for rp in regionprops(lab):
        if rp.area < min_area:
            continue
        y1, x1, y2, x2 = rp.bbox  # (y1, x1, y2, x2) half-open
        # cerramos a píxeles + buffer
        x1 = max(0, x1 - buffer_px)
        y1 = max(0, y1 - buffer_px)
        x2 = min(W-1, x2 - 1 + buffer_px)
        y2 = min(H-1, y2 - 1 + buffer_px)
        boxes.append([float(x1), float(y1), float(x2), float(y2)])  # XYXY
    return boxes

def upscale_boxes_xyxy(boxes, scale):
    if scale == 1 or not boxes:
        return boxes
    return [[x1*scale, y1*scale, x2*scale, y2*scale] for (x1,y1,x2,y2) in boxes]

def run_sam_boxes(predictor, image_uint8, boxes_xyxy, multimask_output=False):
    """
    Procesa UNA caja por llamada (compatible con builds de SAM que no soportan batch de boxes).
    - image_uint8: (H,W,3) uint8 en la MISMA grilla que 'boxes_xyxy'
    - boxes_xyxy: lista de [x1,y1,x2,y2] (floats)
    Devuelve máscara HW binaria (OR de todas las cajas).
    """
    predictor.set_image(image_uint8)
    H, W = image_uint8.shape[:2]
    out_mask = np.zeros((H, W), dtype=bool)

    for b in boxes_xyxy:
        box = np.asarray(b, dtype=np.float32).reshape(1, 4)  # (1,4) una sola caja
        masks, _, _ = predictor.predict(
            box=box, point_coords=None, point_labels=None,
            multimask_output=multimask_output
        )
        # masks: (1,H,W) si multimask_output=False ; (1,M,H,W) si True
        m = masks[0] if masks.ndim == 3 else masks.any(axis=1)[0]
        out_mask |= m

    return out_mask.astype(np.uint8)

# ---------- 1) SR en CHW y conversión a HWC ----------
# Carga SR
net_hr = RRDBNet(num_in_ch=3, num_out_ch=3, scale=SR_SCALE).to(DEVICE)
state = torch.load(REAL_ESRGAN_CKPT, map_location="cpu")
net_hr.load_state_dict(state['net_g_ema'])
net_hr.eval().half()
for p in net_hr.parameters(): p.requires_grad = False

# S2 como CHW
s2_chw = read_s2_rgb_chw01(S2_RGB_TIF)                    # (3,H,W)
s2_tensor = torch.from_numpy(s2_chw).unsqueeze(0).to(DEVICE)  # (1,3,H,W)

with torch.no_grad():
    sr_tensor = net_hr(s2_tensor.half())                  # (1,3,4H,4W)
sr_chw = sr_tensor.squeeze(0).float().cpu().numpy().clip(0,1)  # (3,4H,4W)

# SAM: HWC uint8
img_uint8_sr = chw01_to_hwc_uint8(sr_chw)                 # (4H,4W,3)

print("CHECK shapes -> S2 CHW:", s2_chw.shape,
      "| SR CHW:", sr_chw.shape, "| SAM HWC:", img_uint8_sr.shape)

# ---------- 2) Máscaras HW y boxes en grilla nativa ----------
build_mask = read_mask_binary_hw(BUILD_MASK_TIF)  # (H,W) 256x256
road_mask  = read_mask_binary_hw(ROAD_MASK_TIF)   # (H,W) 256x256
assert build_mask.shape == road_mask.shape, "Máscaras con distinta forma"

build_boxes_native = component_boxes_xyxy(build_mask, buffer_px=BOX_BUFFER, min_area=MIN_AREA)
road_boxes_native  = component_boxes_xyxy(road_mask,  buffer_px=BOX_BUFFER, min_area=MIN_AREA)

# como pasamos a SAM la SR (4H,4W), escalamos cajas ×4
build_boxes_sr = upscale_boxes_xyxy(build_boxes_native, SR_SCALE)
road_boxes_sr  = upscale_boxes_xyxy(road_boxes_native,  SR_SCALE)

# ---------- 3) SAM ----------
sam = sam_model_registry["vit_b"](checkpoint=SAM_CKPT).to(DEVICE)
predictor = SamPredictor(sam)

# Smoke test (opcional): 1–2 cajas para validar shapes
if len(build_boxes_sr) > 0:
    test_chunk = np.asarray(build_boxes_sr[:2], dtype=np.float32)  # (n,4)
    predictor.set_image(img_uint8_sr)
    _m, _, _ = predictor.predict(box=test_chunk, multimask_output=False)
    print("OK SAM test masks shape:", _m.shape)  # esperado: (n, Hsr, Wsr)

# ---------- 4) Refinado por clase en SR ----------
build_ref_sr = run_sam_boxes(predictor, img_uint8_sr, build_boxes_sr,
                             batch=BATCH_BOXES, multimask_output=MULTIMASK_OUTPUT)
road_ref_sr  = run_sam_boxes(predictor, img_uint8_sr, road_boxes_sr,
                             batch=BATCH_BOXES, multimask_output=MULTIMASK_OUTPUT)

# ---------- 5) Bajar a 256x256 y componer semántica ----------
from skimage.transform import resize

def down_nn(mask_hw, out_hw):
    return (resize(mask_hw.astype(float), out_hw, order=0, preserve_range=True) > 0.5).astype(np.uint8)

build_ref_native = down_nn(build_ref_sr, build_mask.shape)
road_ref_native  = down_nn(road_ref_sr,  road_mask.shape)

final_sem = np.zeros_like(build_mask, dtype=np.uint8)
final_sem[build_ref_native == 1] = 1
final_sem[(final_sem == 0) & (road_ref_native == 1)] = 2

print("FINAL shapes -> build_ref_sr:", build_ref_sr.shape,
      "| final_sem (nativo):", final_sem.shape)  # debería ser (256,256)

CHECK shapes -> S2 CHW: (3, 64, 64) | SR CHW: (3, 256, 256) | SAM HWC: (256, 256, 3)


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 2 for tensor number 1 in the list.