# In [None]:
import os, re, glob, numpy as np
from collections import Counter
from obspy.io.segy.segy import _read_segy
from sklearn.model_selection import train_test_split
import torch, torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from pytorch_msssim import SSIM
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim

# --- Experiment configuration ---
# Directory containing the shot_y0[1-7]_*.sgy / *.segy gathers read by load_all.
BASE_DIR = "/home/pc-2/Documents/CAVE_minciencias/utah_model/shot densos 2"
# Crop size applied to every gather: first H traces (receivers) x first W time samples.
H, W = 512, 4096
BATCH = 1  # shots per training batch
EPOCHS = 20  # training epochs for each removal percentage
LR = 1e-4  # Adam learning rate (used for both the U-Net and the mask logits)
PCTS = list(range(10, 100, 10))  # receiver-removal percentages to sweep: 10, 20, ..., 90
PRINT_MAX_EX = 5  # max number of file paths echoed while loading
LAMBDA_SPARSITY = 0.1  # weight of the keep-ratio penalty added to the reconstruction loss
# Output directory for the per-percentage .npz results (created if missing).
OUT_DIR = "./resultados_topk_ste"
os.makedirs(OUT_DIR, exist_ok=True)

# Seed NumPy's global RNG and PyTorch so weight init / shuffling are reproducible.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

def print_header(title):
    """Print *title* on its own line, framed by ten '=' characters on each side.

    A blank line is emitted first so consecutive sections are visually separated.
    """
    bar = "=" * 10
    print(f"\n{bar} {title} {bar}")

def print_basic_stats(name, arr):
    """Print shape, dtype and NaN-ignoring min/max/mean/std of an array or tensor.

    Torch tensors are detached and copied to host memory before summarizing.
    """
    data = arr if isinstance(arr, np.ndarray) else arr.detach().cpu().numpy()
    header = f"[{name}] shape={data.shape}, dtype={data.dtype}, "
    summary = (
        f"min={np.nanmin(data):.4f}, max={np.nanmax(data):.4f}, "
        f"mean={np.nanmean(data):.4f}, std={np.nanstd(data):.4f}"
    )
    print(header + summary)

def sanity_no_nan_inf(name, arr):
    """Assert that *arr* (ndarray or torch tensor) contains no NaN and no +/-Inf.

    The assertion message reports how many NaN and Inf entries were found.
    """
    values = arr.detach().cpu().numpy() if not isinstance(arr, np.ndarray) else arr
    nan_count = np.isnan(values).sum()
    inf_count = np.isinf(values).sum()
    assert not (nan_count or inf_count), f"{name}: NaN={nan_count}, Inf={inf_count}"

def check_divisible_by_8(h, w):
    """Assert that both spatial sizes are multiples of 8.

    The U-Net downsamples three times by 2, so 8 is the smallest size unit that
    keeps encoder and decoder feature maps aligned. H is checked before W.
    """
    for label, size in (("H", h), ("W", w)):
        assert size % 8 == 0, f"{label} debe ser múltiplo de 8 para U-Net: {size}"

def check_train_test_disjoint(idx_tr, idx_te):
    """Assert that no sample index appears in both the train and the test split.

    Both arguments must provide ``.tolist()`` (e.g. NumPy index arrays).
    """
    overlap = set(idx_tr.tolist()) & set(idx_te.tolist())
    assert not overlap, f"Train/Test con intersección! {overlap}"

def gpu_info(device):
    """Print GPU information for a CUDA *device*, or a CPU notice otherwise.

    The name and compute capability are queried for CUDA device 0. The allocated and
    reserved memory figures come from the current CUDA device and are shown in GB (1e9 bytes).
    """
    if device.type != "cuda":
        print("[GPU] No CUDA: usando CPU")
        return
    name = torch.cuda.get_device_name(0)
    capability = torch.cuda.get_device_capability(0)
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"[GPU] {name} | cap={capability} | "
          f"mem_alloc={allocated_gb:.2f} GB | mem_reserved={reserved_gb:.2f} GB")

def tag(fname):
    """Return the ``yNN`` line tag from a ``shot_yNN_...`` file name, else None.

    Only the base name is inspected, and matching is case-insensitive.
    """
    base = os.path.basename(fname).lower()
    match = re.search(r'shot_(y\d{2})_', base)
    if match is None:
        return None
    return match.group(1)

def load_all(base_dir):
    """Load every ``shot_y01``..``shot_y07`` SEG-Y gather in *base_dir* and stack them.

    Files matching ``shot_y0[1-7]_*.sgy`` or ``shot_y0[1-7]_*.segy`` are read in
    sorted path order with ObsPy's ``_read_segy``. Each file becomes a 2-D float32
    array with one row per trace, so all files must have the same number of traces
    and samples; otherwise ``np.stack`` raises ``ValueError``.

    Returns:
        ``(gathers, tags)``: ``gathers`` has shape (n_files, n_traces, n_samples);
        ``tags`` holds the ``"y0N"`` tag of each file, in the same order.

    Raises:
        FileNotFoundError: if no matching file exists. This replaces the opaque
            ``np.stack`` error on an empty list.
    """
    valid_tags = {f"y0{i}" for i in range(1, 8)}
    pats = ["shot_y0[1-7]_*.sgy", "shot_y0[1-7]_*.segy"]
    # Escape the directory so glob metacharacters in it (e.g. '[') are matched literally.
    root = glob.escape(base_dir)
    files = sorted(
        f
        for p in pats
        for f in glob.glob(os.path.join(root, p))
        if tag(f) in valid_tags
    )
    if not files:
        raise FileNotFoundError(
            f"No se encontraron archivos shot_y0[1-7]_*.sgy/.segy en {base_dir!r}"
        )

    arrs, tags_list = [], []
    for f in files:
        st = _read_segy(f, headonly=False)
        # One row per trace -> (n_traces, n_samples).
        arrs.append(np.array([tr.data for tr in st.traces], dtype=np.float32))
        tags_list.append(tag(f))
    print_header("Carga de archivos")
    print(f"Total archivos: {len(files)} | ejemplos: {files[:PRINT_MAX_EX]}")
    for i in range(min(3, len(arrs))):
        print(f"  - arr[{i}] shape={arrs[i].shape}, tag={tags_list[i]}")
    return np.stack(arrs, 0), np.array(tags_list)

# Load all gathers, shape (n_files, n_traces, n_samples), plus their line tags.
gathers, tags = load_all(BASE_DIR)
print_header("Antes de recorte")
print_basic_stats("gathers_raw", gathers)

# Explicit check instead of `assert`, so it still runs under `python -O`.
if not (gathers.shape[1] >= H and gathers.shape[2] >= W):
    raise ValueError("Dimensiones de archivo < H/W")
check_divisible_by_8(H, W)
# Keep the first H traces and W samples. `.copy()` makes a compact array so the
# larger uncropped data can be freed once `gathers` is rebound.
gathers = gathers[:, :H, :W].copy()
print_header("Después de recorte")
print_basic_stats("gathers", gathers)
sanity_no_nan_inf("gathers", gathers)

# Random 80/20 split over gather indices, seeded with the global seed for reproducibility.
idx = np.arange(len(tags))
idx_tr, idx_te = train_test_split(idx, test_size=0.2, random_state=GLOBAL_SEED, shuffle=True)
check_train_test_disjoint(idx_tr, idx_te)
Gtr, Gte = gathers[idx_tr], gathers[idx_te]

def norm_trace(x, axis=2, eps=1e-6):
    """Scale each trace by its peak absolute amplitude.

    Args:
        x: array of gathers; by default shape (n_shots, n_traces, n_samples).
        axis: axis holding the samples of one trace (default 2, the time axis of
            the stacked gathers).
        eps: added to each peak so all-zero traces map to 0 instead of NaN.

    Returns:
        ``x / (max|x| along axis + eps)``, with values in roughly [-1, 1].
    """
    peak = np.max(np.abs(x), axis=axis, keepdims=True) + eps
    return x / peak

# Normalize every trace of both splits by its own peak |amplitude|, so values lie in
# roughly [-1, 1]. This matches the tanh output range of the U-Net and the
# data_range=2.0 used for SSIM.
Ytr, Yte = norm_trace(Gtr), norm_trace(Gte)
print_header("Normalización")
print_basic_stats("Ytr", Ytr); print_basic_stats("Yte", Yte)
sanity_no_nan_inf("Ytr", Ytr); sanity_no_nan_inf("Yte", Yte)
print("Rango esperado ~[-1,1]. data_range=2.0 para SSIM.")

class UNet2DFull(nn.Module):
    """Three-level 2-D U-Net mapping a 1-channel image to a 1-channel image in [-1, 1].

    Encoder levels are double (Conv3x3 -> BatchNorm -> LeakyReLU(0.01)) blocks with
    64/128/256 channels, each followed by 2x2 max-pooling; the bottleneck has 512
    channels. The decoder upsamples with 2x2 stride-2 transposed convolutions and
    concatenates the center-cropped encoder map of the same level. With spatial dims
    divisible by 8 the output matches the input size; otherwise skip maps are cropped
    and the output is smaller. The final 1x1 convolution is passed through tanh.
    """

    @staticmethod
    def _double_conv(cin, cout):
        """Two Conv3x3(padding=1) -> BatchNorm -> in-place LeakyReLU(0.01) stages."""
        layers = []
        for c_in in (cin, cout):
            layers += [
                nn.Conv2d(c_in, cout, 3, padding=1),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.01, True),
            ]
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Encoder (attribute names and creation order define state_dict keys / init).
        self.e1 = self._double_conv(1, 64)
        self.p1 = nn.MaxPool2d(2, 2)
        self.e2 = self._double_conv(64, 128)
        self.p2 = nn.MaxPool2d(2, 2)
        self.e3 = self._double_conv(128, 256)
        self.p3 = nn.MaxPool2d(2, 2)
        self.bott = self._double_conv(256, 512)
        # Decoder: upsample, then fuse with the skip connection (channels doubled by cat).
        self.u3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.d3 = self._double_conv(512, 256)
        self.u2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.d2 = self._double_conv(256, 128)
        self.u1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.d1 = self._double_conv(128, 64)
        self.out = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        skip1 = self.e1(x)
        skip2 = self.e2(self.p1(skip1))
        skip3 = self.e3(self.p2(skip2))
        y = self.bott(self.p3(skip3))
        stages = (
            (self.u3, self.d3, skip3),
            (self.u2, self.d2, skip2),
            (self.u1, self.d1, skip1),
        )
        for upsample, decode, skip in stages:
            y = upsample(y)
            y = decode(torch.cat([y, self.crop(skip, y)], 1))
        return torch.tanh(self.out(y))

    @staticmethod
    def crop(a, b):
        """Center-crop 4-D tensor *a* to the spatial (last two) size of 4-D tensor *b*."""
        _, _, target_h, target_w = b.shape
        _, _, src_h, src_w = a.shape
        top = (src_h - target_h) // 2
        left = (src_w - target_w) // 2
        return a[:, :, top:top + target_h, left:left + target_w]

def count_params(model):
    """Return the total number of scalar parameters in *model* (trainable or not)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

class BinaryReceiverMaskSTE(nn.Module):
    """Learnable receiver mask: hard Top-K removal in the forward pass, STE gradients.

    Each of the ``n_rec`` receivers owns a logit, and ``sigmoid(logit)`` is its keep
    probability. ``forward`` zeroes the ``K = round(frac_remove * n_rec)`` receivers
    with the lowest keep probability. The forward value is an exact 0/1 mask, while
    gradients reach the logits as if the soft probabilities had been applied
    (straight-through estimator).
    """

    def __init__(self, n_rec, init_keep_prob=0.9):
        super().__init__()
        self.n_rec = n_rec
        # Clamp so the initial logit stays finite (p=1 would divide by zero, p=0 gives -inf).
        p0 = float(np.clip(init_keep_prob, 1e-6, 1.0 - 1e-6))
        init_logit = np.log(p0 / (1 - p0))
        self.logits = nn.Parameter(torch.full((n_rec,), float(init_logit)))
        self.frac_remove = 0.10

    def set_frac_remove(self, frac):
        """Set the fraction of receivers to remove, clipped to [0, 1]."""
        self.frac_remove = float(np.clip(frac, 0.0, 1.0))

    def _num_removed(self):
        """Number K of receivers removed for the current ``frac_remove``."""
        return int(round(self.frac_remove * self.n_rec))

    def forward(self, x):
        """Mask receivers (dim 2) of a 4-D input.

        Returns:
            ``(x_masked, probs, hard)``: the masked input, the per-receiver keep
            probabilities, and the hard 0/1 mask actually applied.
        """
        assert x.dim() == 4 and x.shape[2] == self.n_rec
        probs = torch.sigmoid(self.logits)
        hard = torch.ones_like(probs)
        K = self._num_removed()
        if K > 0:
            hard[torch.topk(probs, K, largest=False).indices] = 0.0
        # Parenthesised so the forward value is exactly `hard`: (hard + p) - p may be
        # off by one float32 ulp, whereas p - p.detach() is exactly 0. Gradient is unchanged.
        ste_mask = hard + (probs - probs.detach())
        x_masked = x * ste_mask.view(1, 1, self.n_rec, 1)
        return x_masked, probs, hard

    @torch.no_grad()
    def hard_indices_removed(self):
        """Indices of the receivers the hard mask currently removes, as a NumPy array."""
        K = self._num_removed()
        if K <= 0:
            return np.array([], dtype=int)
        probs = torch.sigmoid(self.logits)
        return torch.topk(probs, K, largest=False).indices.cpu().numpy()

def eval_metrics_removed_only_with_mask(model, Xe, Ye, hard_mask_01, device):
    """Average reconstruction metrics computed only on the receivers the mask removed.

    For each shot ``i``, ``model(Xe[i:i+1])`` is squeezed and compared with
    ``Ye[i, 0]`` on the rows where ``hard_mask_01 == 0``. Per-shot metrics:
      * MSE over the removed traces;
      * PSNR = 20*log10(peak-to-peak of the true removed traces / RMSE), or inf if MSE is 0;
      * SSIM = mean 1-D SSIM over the removed traces (data_range=2);
      * SNR = 10*log10(mean(true^2) / MSE).

    Args:
        model: network applied to one shot at a time; its train/eval mode is
            restored afterwards.
        Xe: masked inputs, indexable along dim 0.
        Ye: targets; ``Ye[i, 0]`` is assumed to be the (receiver, time) image of shot i.
        hard_mask_01: 1-D per-receiver 0/1 mask (torch tensor, ndarray or list).
        device: device the inputs are moved to before calling ``model``.

    Returns:
        Dict with keys "MSE", "PSNR", "SSIM", "SNR" holding means over shots. All
        values are None if the mask removes no receiver or there are no shots.
    """
    if torch.is_tensor(hard_mask_01):
        mask_np = hard_mask_01.detach().cpu().numpy()
    else:
        mask_np = np.asarray(hard_mask_01)
    idx0 = np.where(mask_np == 0)[0]
    if len(idx0) == 0:
        return {"MSE": None, "PSNR": None, "SSIM": None, "SNR": None}

    MSE, PSNR, SSIMg, SNR = [], [], [], []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i in range(Ye.shape[0]):
                yt = Ye[i, 0].detach().cpu().numpy()
                yp = model(Xe[i:i+1].to(device)).cpu().squeeze().numpy()
                ytn, ypn = yt[idx0], yp[idx0]
                mse = np.mean((ytn - ypn) ** 2)
                amp = np.ptp(ytn) + 1e-8
                psn = 20 * np.log10(amp / np.sqrt(mse + 1e-12)) if mse > 0 else float('inf')
                ssi = np.mean([ssim(ytn[j], ypn[j], data_range=2) for j in range(ytn.shape[0])])
                snr = 10 * np.log10((np.mean(ytn ** 2) + 1e-12) / (mse + 1e-12))
                MSE.append(mse); PSNR.append(psn); SSIMg.append(ssi); SNR.append(snr)
    finally:
        # Do not leave the caller's model silently switched to eval mode.
        model.train(was_training)

    def mean(v):
        return float(np.mean(v)) if len(v) else None

    return {"MSE": mean(MSE), "PSNR": mean(PSNR), "SSIM": mean(SSIMg), "SNR": mean(SNR)}

def visualizar_comparacion(x_input, y_pred, y_true, titulo_extra=""):
    """Show four panels: masked input, prediction, original shot and absolute error.

    Each 2-D array is transposed before display, so the first axis (labelled
    "Receptor") runs horizontally and the second ("Tiempo") runs downwards. The first
    three panels share a fixed grey scale in [-1, 1]; the error panel uses 'inferno'
    with automatic scaling.
    """
    signed_style = dict(cmap='gray', aspect='auto', origin='upper', vmin=-1, vmax=1)
    error_style = dict(cmap='inferno', aspect='auto', origin='upper')
    panels = [
        (x_input, "Input (enmascarado)", signed_style),
        (y_pred, "Predicción (U-Net)", signed_style),
        (y_true, "Shot original", signed_style),
        (np.abs(y_pred - y_true), "Error absoluto", error_style),
    ]

    fig, axs = plt.subplots(1, 4, figsize=(24, 6))
    for ax, (image, panel_title, style) in zip(axs, panels):
        im = ax.imshow(image.T, **style)
        ax.set_title(panel_title)
        ax.set_xlabel("Receptor")
        ax.set_ylabel("Tiempo")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.suptitle(f"Comparación: Entrada — Predicción — Original — Error {titulo_extra}", fontsize=14)
    plt.tight_layout()
    plt.show()

def plot_schematic_bars(hard_mask_01, height=40, title="Esquemático de receptores (1 keep / 0 removed)"):
    """Draw a 1-D receiver mask as a strip of vertical bars (white = kept, black = removed).

    Entries equal to 0 become black columns and every other entry is drawn white, in
    an image ``height`` rows tall with one column per receiver.
    """
    mask = np.asarray(hard_mask_01, dtype=np.float32)
    strip = np.ones(mask.shape[0], dtype=np.float32)
    strip[mask == 0] = 0.0
    img = np.tile(strip, (height, 1))

    plt.figure(figsize=(14, 2.2))
    plt.imshow(img, cmap="gray", aspect="auto", origin="upper", vmin=0, vmax=1)
    plt.yticks([])
    plt.xlabel("Receptor (índice 0..H-1)")
    plt.title(title)
    plt.tight_layout()
    plt.show()

def plot_gather_with_removed_bars(gather_2d, hard_mask_01, vmin=-1, vmax=1, title="Shot gather (barras negras = 0)"):
    """Plot a 2-D gather with a black vertical line at every removed receiver.

    The gather is transposed for display (receivers horizontally, samples downwards)
    with a grey scale in [vmin, vmax]. Receivers whose mask entry equals 0 get an
    ``axvline``.
    """
    _n_rec, _n_samples = gather_2d.shape  # unpacking also rejects non-2-D input
    plt.figure(figsize=(16, 6))
    plt.imshow(gather_2d.T, cmap="gray", aspect="auto", origin="upper", vmin=vmin, vmax=vmax)
    for receiver in np.flatnonzero(np.asarray(hard_mask_01) == 0):
        plt.axvline(receiver, color="k", linewidth=1.2)
    plt.xlabel("Receptor")
    plt.ylabel("Tiempo (muestras)")
    plt.title(title)
    plt.tight_layout()
    plt.show()

def save_npz_per_pct(pct, probs, hard, logits, removed_idx, Xe_masked_shot, Y_true_shot, Y_pred_shot, loss_hist, metrics_dict, out_dir=None):
    """Save mask state, one example shot and the metrics of a run to ``run_pctNN.npz``.

    Args:
        pct: integer removal percentage, zero-padded to 2 digits in the file name.
        probs, hard, logits: per-receiver keep probabilities, hard 0/1 mask and logits.
        removed_idx: indices of the removed receivers.
        Xe_masked_shot, Y_true_shot, Y_pred_shot: 2-D arrays for one example shot.
            The shape of ``Y_true_shot`` is also stored as ``H`` and ``W``.
        loss_hist: per-epoch loss values.
        metrics_dict: dict with optional "MSE"/"PSNR"/"SSIM"/"SNR" entries, or None.
        out_dir: destination directory (created if missing); defaults to ``OUT_DIR``.

    Missing metrics (``metrics_dict`` is None or empty, a key is absent, or its value
    is None) are stored as NaN. Storing ``None`` would create object arrays, which
    ``np.load`` refuses to read unless ``allow_pickle=True``.
    """
    if out_dir is None:
        out_dir = OUT_DIR
    os.makedirs(out_dir, exist_ok=True)
    metrics = metrics_dict or {}

    def _metric(key):
        value = metrics.get(key)
        return np.float64(np.nan if value is None else value)

    y_true = np.asarray(Y_true_shot, dtype=np.float32)
    path = os.path.join(out_dir, f"run_pct{pct:02d}.npz")
    np.savez_compressed(
        path,
        probs=np.asarray(probs, dtype=np.float32),
        hard=np.asarray(hard, dtype=np.float32),
        logits=np.asarray(logits, dtype=np.float32),
        removed_idx=np.asarray(removed_idx, dtype=np.int32),
        Xe_masked_shot=np.asarray(Xe_masked_shot, dtype=np.float32),
        Y_true_shot=y_true,
        Y_pred_shot=np.asarray(Y_pred_shot, dtype=np.float32),
        loss_hist=np.asarray(loss_hist, dtype=np.float32),
        MSE=_metric("MSE"),
        PSNR=_metric("PSNR"),
        SSIM=_metric("SSIM"),
        SNR=_metric("SNR"),
        H=np.int32(y_true.shape[0]),
        W=np.int32(y_true.shape[1]),
    )
    print(f"💾 Guardado: {path}")

# Use the GPU when available; all datasets below are placed on this device up front.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print_header("Dispositivo")
print(f"device={device} | PCTS={PCTS} | BATCH={BATCH} | EPOCHS={EPOCHS}")
gpu_info(device)

# Inputs and targets are built from the same normalized gathers, with a channel axis
# added -> (N, 1, H, W). The receiver mask is applied on the fly in the training loop,
# so inputs only differ from targets after masking.
# NOTE: on CPU, torch.from_numpy shares memory with Ytr/Yte (no copy is made).
Xt_base = torch.from_numpy(Ytr).unsqueeze(1).to(device) 
Yt_t    = torch.from_numpy(Ytr).unsqueeze(1).to(device)
Xe_base = torch.from_numpy(Yte).unsqueeze(1).to(device)
Ye_t    = torch.from_numpy(Yte).unsqueeze(1).to(device)

# Per-percentage results, keyed by str(pct): epoch loss curves and removed-only metrics.
curvas_loss = {}
resultados_metricas = {}

# Sweep over removal percentages: for each one, train a fresh U-Net jointly with a
# Top-K STE receiver mask, visualize one test shot, compute metrics on the removed
# receivers, and save the run to OUT_DIR.
for pct in PCTS:
    print_header(f"Máscara Top-K STE con objetivo {pct}% receptores ELIMINADOS")
    frac = pct / 100.0
    target_keep = 1.0 - frac

    loader = DataLoader(TensorDataset(Xt_base, Yt_t), batch_size=BATCH, shuffle=True)

    model = UNet2DFull().to(device)
    print(f"Parámetros U-Net: {count_params(model):,}")

    # Start the keep-probabilities near the target keep ratio, clamped to [0.05, 0.95].
    init_keep_prob = max(0.05, min(0.95, target_keep))
    mask_layer = BinaryReceiverMaskSTE(n_rec=H, init_keep_prob=init_keep_prob).to(device)
    mask_layer.set_frac_remove(frac)

    opt = torch.optim.Adam([
        {"params": model.parameters(), "lr": LR},
        {"params": mask_layer.parameters(), "lr": LR}
    ], lr=LR)

    crit = SSIM(data_range=2.0, size_average=True, channel=1)

    loss_hist = []
    model.train()
    for ep in range(EPOCHS):
        # Accumulate every loss term so the log reports epoch means for all of them.
        # Previously `rec`/`sp` were last-batch values printed next to an epoch-mean total.
        run = 0.0
        run_rec = 0.0
        run_sp = 0.0
        for xb, yb in tqdm(loader, desc=f"{pct}% | Época {ep+1}/{EPOCHS}", leave=False):
            opt.zero_grad(set_to_none=True)
            xb_masked, probs, hard = mask_layer(xb)
            yp = model(xb_masked)
            # Reconstruction loss: 1 - SSIM over the full (unmasked) target.
            loss_rec = 1 - crit(yp, yb)

            # Penalize deviation of the mean keep probability from the target keep ratio.
            keep_mean = torch.sigmoid(mask_layer.logits).mean()
            loss_sp = (keep_mean - target_keep) ** 2

            loss = loss_rec + LAMBDA_SPARSITY * loss_sp
            loss.backward(); opt.step()
            run += float(loss.item())
            run_rec += float(loss_rec.item())
            run_sp += float(loss_sp.item())

        n_batches = len(loader)
        ep_loss = run / n_batches
        loss_hist.append(ep_loss)
        with torch.no_grad():
            p = torch.sigmoid(mask_layer.logits).detach().cpu().numpy()
            q10, q50, q90 = np.percentile(p, [10, 50, 90])
        # keep_mean is reported from the end-of-epoch probabilities, like the quantiles.
        print(f"[{pct}%] Epoch {ep+1:02d}/{EPOCHS} | total={ep_loss:.6f} "
              f"| rec={run_rec / n_batches:.6f} | sp={run_sp / n_batches:.6f} "
              f"| keep_mean={float(p.mean()):.4f} | probs q10={q10:.3f}, q50={q50:.3f}, q90={q90:.3f}")

    curvas_loss[str(pct)] = loss_hist

    model.eval()
    with torch.no_grad():
        mask_layer.set_frac_remove(frac)
        Xe_masked, probs_test, hard_test = mask_layer(Xe_base)

        idx_vis = 0
        pred = model(Xe_masked[idx_vis:idx_vis+1]).cpu().squeeze().numpy()
        entrada = Xe_masked[idx_vis, 0].cpu().numpy()
        real = Ye_t[idx_vis, 0].cpu().numpy()
        n_elim = int((hard_test==0).sum().item())
        print(f"[Visualización] receptores eliminados por máscara={n_elim} de {H}")

        visualizar_comparacion(entrada, pred, real, titulo_extra=f"(Test idx={idx_vis}, {pct}%)")

        # Overlay real vs predicted samples for up to three removed traces.
        elim_idx = torch.where(hard_test==0)[0].cpu().numpy()
        if len(elim_idx) > 0:
            tiempo = np.arange(real.shape[1])
            sub = elim_idx[:3]
            plt.figure(figsize=(30, 4))
            for k, tr_i in enumerate(sub):
                plt.subplot(1, len(sub), k+1)
                plt.plot(tiempo, real[tr_i], label='Real', linewidth=1.5)
                plt.plot(tiempo, pred[tr_i], label='Predicho', linewidth=1.5)
                plt.title(f"Traza eliminada #{tr_i}")
                plt.xlabel("Tiempo"); plt.ylabel("Amplitud")
                plt.legend(); plt.grid(True)
            plt.suptitle(f"Comparación real vs predicho - Shot {idx_vis} - {pct}% (máscara STE Top-K)")
            plt.tight_layout(); plt.show()
        else:
            print("⚠ La máscara no eliminó receptores (frac_remove=0).")

        plot_schematic_bars(hard_test.detach().cpu().numpy(),
                            title=f"Esquemático de receptores — {pct}% eliminados")
        plot_gather_with_removed_bars(real, hard_test.detach().cpu().numpy(),
                                      title=f"Shot TEST idx={idx_vis} — {pct}% (barras negras = receptores eliminados)")

    # Metrics restricted to the removed receivers over the whole test set.
    with torch.no_grad():
        Xm_test, _, hard_test = mask_layer(Xe_base)
    m_removed = eval_metrics_removed_only_with_mask(model, Xm_test, Ye_t, hard_test, device)
    resultados_metricas[f"{pct}"] = m_removed
    print(f"[Métricas {pct}%] {m_removed}")

    probs_np = torch.sigmoid(mask_layer.logits).detach().cpu().numpy()
    hard_np  = hard_test.detach().cpu().numpy()
    logits_np = mask_layer.logits.detach().cpu().numpy()
    removed_idx = np.where(hard_np == 0)[0]

    save_npz_per_pct(
        pct=pct,
        probs=probs_np,
        hard=hard_np,
        logits=logits_np,
        removed_idx=removed_idx,
        Xe_masked_shot=entrada,
        Y_true_shot=real,
        Y_pred_shot=pred,
        loss_hist=loss_hist,
        metrics_dict=m_removed
    )

# Overlay the per-epoch total-loss curve of every removal percentage.
plt.figure(figsize=(10, 6))
for pct, curva in curvas_loss.items():
    plt.plot(range(1, len(curva)+1), curva, label=f"{pct}%")
plt.title("Curvas de pérdida (total = reconstrucción + λ·sparsity)")
plt.xlabel("Época"); plt.ylabel("Pérdida")
plt.legend(title="Porcentaje eliminado")
plt.grid(True); plt.tight_layout(); plt.show()

# Dump the removed-receiver metrics collected for every percentage.
print("\n=== Resultados de métricas (diccionario) ===")
print(resultados_metricas)
