# In [None]:
import os, re, glob, numpy as np
from obspy.io.segy.segy import _read_segy
from sklearn.model_selection import train_test_split
import torch, torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from pytorch_msssim import SSIM
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
import pandas as pd
from scipy.interpolate import griddata
from scipy.spatial import QhullError

# Directory scanned for the input shot files (shot_SN_<n>.sgy / shot_SN_<n>.segy).
BASE_DIR = "/home/pc-2/Documents/CAVE_minciencias/final_model/shot_cruz_3D"

# Output root; the scenario loop later creates one pct_XX/ subdirectory per scenario here.
OUT_DIR  = os.path.join(
    "/home/pc-2/Documents/CAVE_minciencias/final_model/",
    "salidas_crg_shotmask_ste_3D_ANALITICO"
)
os.makedirs(OUT_DIR, exist_ok=True)

# Each shot gather is cut to its first H traces and first W time samples.
H, W = 512, 4096
BATCH = 1              # receiver gathers (first axis of CRG_tr) per training step
EPOCHS = 40
LR = 1e-4              # Adam learning rate, shared by the U-Net and the mask logits
GLOBAL_SEED = 42
LAMBDA_SPARSITY = 10   # weight of the (mean keep probability - target keep)^2 penalty
SCENARIOS = [10, 20, 30, 40, 50, 60, 70, 80, 90]   # percentage of shots removed per run

# Seed before any model is built so weight initialisation and shuffling are reproducible.
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)
plt.rcParams["figure.dpi"] = 120


def norm_trace_lastaxis(x):
    """Scale every trace (the last axis) by its peak absolute amplitude.

    1e-6 is added to each peak before dividing, so all-zero traces stay zero
    instead of becoming NaN.
    """
    peak = np.abs(x).max(axis=-1, keepdims=True)
    return x / (peak + 1e-6)

def crop_to_mult8_2d(arr):
    """Return a copy of ``arr`` with its last two axes trimmed to multiples of 8.

    Leading axes are kept whole; the excess samples are dropped from the end
    of each of the two trailing axes.
    """
    n_rows, n_cols = arr.shape[-2], arr.shape[-1]
    keep_rows = n_rows - n_rows % 8
    keep_cols = n_cols - n_cols % 8
    return arr[..., :keep_rows, :keep_cols].copy()

def check_divisible_by_8(h, w):
    """Validate that both spatial sizes are multiples of 8.

    UNet2DFull pools three times by 2, so inputs whose sizes are not
    multiples of 8 do not round-trip to the same output size.

    Raises:
        ValueError: if ``h`` or ``w`` is not a multiple of 8. An explicit
            raise is used instead of ``assert`` so the check survives
            ``python -O``.
    """
    for v, n in [(h, "H"), (w, "W")]:
        if v % 8 != 0:
            raise ValueError(f"{n} debe ser múltiplo de 8")


def tag(fname):
    """Return the digit string from a name ending in ``shot_sn_<digits>.sgy``/``.segy``.

    The basename is lower-cased before matching, so the check is
    case-insensitive. Returns ``None`` when the name does not match.
    """
    match = re.search(r"shot_sn_(\d+)\.(?:sgy|segy)$", os.path.basename(fname).lower())
    if match is None:
        return None
    return match.group(1)


def load_all(base_dir):
    """Read every ``shot_SN_*.sgy`` / ``shot_SN_*.segy`` file in ``base_dir``.

    Files are sorted by path (lexicographically) and filtered with ``tag``.
    Each file becomes a float32 array with one row per trace.

    Returns:
        (gathers, files): ``gathers`` has shape (n_files, n_traces, n_samples);
        ``files`` is the matching list of paths.

    Raises:
        RuntimeError: if no matching file is found.
        ValueError: if a file's (n_traces, n_samples) differs from the first
            file's. It is raised while reading, naming the offending file,
            instead of failing later inside ``np.stack``.
    """
    pats = ["shot_SN_*.sgy", "shot_SN_*.segy"]
    # A flat comprehension instead of sum(list_of_lists, []), which is quadratic.
    files = sorted(f for p in pats for f in glob.glob(os.path.join(base_dir, p)))
    files = [f for f in files if tag(f) is not None]
    if not files:
        raise RuntimeError("No encontré archivos shot_SN_*.sgy/segy")

    arrs = []
    for f in files:
        st = _read_segy(f, headonly=False)
        A = np.array([tr.data for tr in st.traces], dtype=np.float32)
        if arrs and A.shape != arrs[0].shape:
            raise ValueError(
                f"{os.path.basename(f)} tiene forma {A.shape}, se esperaba "
                f"{arrs[0].shape} (como {os.path.basename(files[0])})"
            )
        arrs.append(A)

    return np.stack(arrs, 0), files


class UNet2DFull(nn.Module):
    """Three-level 2D U-Net mapping a 1-channel image to a 1-channel image in [-1, 1].

    Encoder: conv blocks with 64, 128 and 256 channels, each followed by 2x2
    max-pooling; a 512-channel bottleneck; a mirrored decoder of 2x2
    transposed convolutions plus skip connections. The final 1x1 conv goes
    through ``tanh``.

    NOTE(review): with three poolings, the output has the input's spatial
    size only when both sizes are multiples of 8. Otherwise ``crop``
    center-crops the skip features and the output comes out smaller.

    Submodules are created in a fixed order. Under a fixed seed this order
    determines the initial weights, and the attribute names are the
    state_dict keys.
    """

    def __init__(self):
        super().__init__()
        # Two (3x3 conv, padding 1) -> BatchNorm -> in-place LeakyReLU(0.01) stages.
        def blk(cin, cout):
            return nn.Sequential(
                nn.Conv2d(cin, cout, 3, padding=1),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.01, True),
                nn.Conv2d(cout, cout, 3, padding=1),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.01, True)
            )
        # Encoder: conv block, then halve both spatial sizes.
        self.e1, self.p1 = blk(1,64), nn.MaxPool2d(2)
        self.e2, self.p2 = blk(64,128), nn.MaxPool2d(2)
        self.e3, self.p3 = blk(128,256), nn.MaxPool2d(2)
        self.bott = blk(256,512)

        # Decoder: 2x upsampling, then a block over [upsampled, skip] (doubled channels).
        self.u3 = nn.ConvTranspose2d(512,256,2,2); self.d3 = blk(512,256)
        self.u2 = nn.ConvTranspose2d(256,128,2,2); self.d2 = blk(256,128)
        self.u1 = nn.ConvTranspose2d(128,64, 2,2); self.d1 = blk(128,64)
        self.out = nn.Conv2d(64,1,1)

    def crop(self, a, b):
        """Center-crop ``a`` (N, C, H, W) to the spatial size of ``b``.

        H and W here are locals and shadow the module-level constants.
        """
        _,_,h,w = b.shape
        _,_,H,W = a.shape
        dh = (H-h)//2
        dw = (W-w)//2
        return a[:,:,dh:dh+h, dw:dw+w]

    def forward(self, x):
        """Run the U-Net; ``x`` is (N, 1, H, W) and the output passes through tanh."""
        e1=self.e1(x); p1=self.p1(e1)
        e2=self.e2(p1); p2=self.p2(e2)
        e3=self.e3(p2); p3=self.p3(e3)

        b=self.bott(p3)
        # Concatenate each upsampled map with the (cropped) encoder map of the same level.
        u3=self.u3(b); d3=self.d3(torch.cat([u3, self.crop(e3,u3)],1))
        u2=self.u2(d3); d2=self.d2(torch.cat([u2, self.crop(e2,u2)],1))
        u1=self.u1(d2); d1=self.d1(torch.cat([u1, self.crop(e1,u1)],1))
        return torch.tanh(self.out(d1))


class BinaryShotMaskSTE(nn.Module):
    """Learnable per-shot keep/remove mask trained with a straight-through estimator.

    Each shot owns one logit, and ``sigmoid(logit)`` is its soft keep
    probability. The forward pass hard-zeroes the ``K`` shots with the lowest
    keep probability, where ``K = round(frac_remove * n_shots)`` clamped to
    ``[0, n_shots]``. Gradients still reach the logits through
    ``hard + probs - probs.detach()``.
    """

    def __init__(self, n_shots, init_keep_prob=0.9):
        super().__init__()
        # Clip so the initial logit log(p / (1 - p)) stays finite.
        init_keep_prob = np.clip(init_keep_prob, 1e-3, 1-1e-3)
        init_logit = np.log(init_keep_prob/(1-init_keep_prob))
        self.logits = nn.Parameter(torch.full((n_shots,), float(init_logit)))
        self.n_shots = n_shots
        self.frac_remove = 0.1

    def set_frac_remove(self, frac):
        """Set the fraction of shots hard-removed by ``forward``."""
        self.frac_remove = float(frac)

    def _num_removed(self):
        """Return how many shots to remove, clamped to ``[0, n_shots]``.

        Without the clamp, ``torch.topk`` raises for ``frac_remove > 1``.
        """
        k = int(round(self.frac_remove * self.n_shots))
        return min(max(k, 0), self.n_shots)

    def _removed_indices(self, probs):
        """Return indices of the lowest-probability shots, or ``None`` if none are removed."""
        k = self._num_removed()
        if k == 0:
            return None
        return torch.topk(probs, k, largest=False).indices

    def forward(self, x):
        """Mask the shot axis of ``x``.

        ``x`` must broadcast against shape (1, 1, n_shots, 1), so the shots
        are on axis 2. Returns ``(x_masked, probs, hard)``, where ``hard``
        is the 0/1 mask actually applied.
        """
        probs = torch.sigmoid(self.logits)
        hard = torch.ones_like(probs)
        idx_del = self._removed_indices(probs)
        if idx_del is not None:
            hard[idx_del] = 0.0

        # Straight-through: numerically equal to `hard`, but d/dlogits follows `probs`.
        ste_mask = hard + probs - probs.detach()
        x_masked = x * ste_mask.view(1, 1, self.n_shots, 1)
        return x_masked, probs, hard

    @torch.no_grad()
    def hard_indices_removed(self):
        """Return the indices of the shots the hard mask currently removes, as a NumPy array."""
        idx_del = self._removed_indices(torch.sigmoid(self.logits))
        if idx_del is None:
            return np.array([], dtype=int)
        return idx_del.cpu().numpy()


def train_unet_with_shotmask(CRG_tr, pct, device):
    """Jointly train a U-Net and a learnable shot mask for one removal scenario.

    Each sample is one receiver gather ``CRG_tr[i]`` laid out as
    (shot, time). The mask hard-removes ``round(pct/100 * n_shots)`` shots
    and the U-Net reconstructs the full gather. The loss is:

    * ``1 - SSIM`` computed only on the removed shot rows, plus
    * ``LAMBDA_SPARSITY * (mean keep probability - (1 - pct/100))**2``.

    Args:
        CRG_tr: array (n_receivers, n_shots, n_samples). Its dtype is kept by
            ``torch.from_numpy``, so it is presumably float32 (TODO confirm
            with the caller). Both trailing sizes should be multiples of 8
            for ``UNet2DFull``.
        pct: percentage of shots to remove.
        device: torch device for the model, the mask and the data.

    Returns:
        ``(model, mask_layer, loss_hist)``. ``loss_hist`` holds the mean loss
        of each epoch.
    """
    S = CRG_tr.shape[1]
    target_keep = 1 - pct/100

    # Self-supervised: the target is the input. Reuse one device tensor for
    # both dataset fields instead of uploading a second identical copy.
    Xt = torch.from_numpy(CRG_tr).unsqueeze(1).to(device)
    loader = DataLoader(TensorDataset(Xt, Xt), batch_size=BATCH, shuffle=True)

    model = UNet2DFull().to(device)
    mask_layer = BinaryShotMaskSTE(n_shots=S, init_keep_prob=target_keep).to(device)
    mask_layer.set_frac_remove(pct/100)

    opt = torch.optim.Adam(
        [{"params": model.parameters(), "lr": LR},
         {"params": mask_layer.parameters(), "lr": LR}]
    )
    # Inputs are trace-normalised to [-1, 1] and the U-Net ends in tanh, hence data_range=2.
    crit = SSIM(data_range=2.0, size_average=True, channel=1)

    loss_hist = []
    model.train()

    for ep in range(EPOCHS):
        run = 0.0
        for xb, yb in loader:
            opt.zero_grad(set_to_none=True)

            xb_masked, probs, hard = mask_layer(xb)
            yp = model(xb_masked)

            # Score the reconstruction only on the shots that were hidden from the network.
            mask_removed = (hard == 0.0)
            if mask_removed.any():
                yp_sub = yp[:, :, mask_removed, :]
                yb_sub = yb[:, :, mask_removed, :]
                loss_rec = 1 - crit(yp_sub, yb_sub)
            else:
                loss_rec = torch.tensor(0., device=device)

            # `probs` is sigmoid(logits), which the mask layer has already computed.
            loss_sp = (probs.mean() - target_keep)**2

            loss = loss_rec + LAMBDA_SPARSITY*loss_sp
            loss.backward()
            opt.step()
            run += loss.item()

        loss_hist.append(run/len(loader))
        print(f"   Época {ep+1}/{EPOCHS} — loss={loss_hist[-1]:.5f}")

    return model, mask_layer, loss_hist


def _trace_metrics(ytn, ypn, data_range):
    """Return ``(mse, psnr, snr, ssim)`` for one reference/predicted trace pair.

    PSNR uses the peak-to-peak amplitude of the reference trace. For a
    constant reference (e.g. a dead, all-zero trace) that amplitude is 0 and
    PSNR is undefined, so NaN is returned instead of ``-inf``.
    """
    mse = np.mean((ytn - ypn) ** 2)
    amp = np.max(ytn) - np.min(ytn)
    if amp > 0:
        psnr = 20 * np.log10(amp / np.sqrt(mse + 1e-12))
    else:
        psnr = np.nan
    snr = 10 * np.log10((np.mean(ytn ** 2) + 1e-12) / (mse + 1e-12))
    return mse, psnr, snr, ssim(ytn, ypn, data_range=data_range)


def _nanmean(a, axis=None):
    """Return ``np.nanmean`` without the warning for all-NaN slices (those stay NaN)."""
    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.nanmean(a, axis=axis)


def eval_metrics_areal(model, CRG_te, removed_idx, device, binf=2, data_range=2):
    """Evaluate reconstruction of the removed shots for every test receiver gather.

    For each receiver gather in ``CRG_te`` (rec, shot, time), the shots in
    ``removed_idx`` are zeroed, the gather is passed through ``model``, and
    each removed shot's trace is scored against the original.

    Args:
        model: network called on (B, 1, n_shots, n_samples) tensors. Its
            output is presumably the same shape (it is squeezed on axis 1);
            TODO confirm.
        CRG_te: array (n_receivers, n_shots, n_samples).
        removed_idx: indices of removed shots (axis 1 of ``CRG_te``).
        device: torch device for inference.
        binf: number of receiver gathers per inference batch.
        data_range: ``data_range`` passed to skimage's SSIM.

    Returns:
        ``None`` if ``removed_idx`` is empty. Otherwise a dict with:

        * ``*_map`` arrays of shape (n_receivers, n_removed);
        * ``per_shot`` / ``per_rec`` averages;
        * ``global`` averages.

        Averages ignore NaN entries, so undefined PSNRs on constant
        reference traces do not turn the means into ``-inf``.
    """
    if len(removed_idx) == 0:
        return None

    model.eval()
    Htot, _, _ = CRG_te.shape
    R = len(removed_idx)

    mse_map  = np.zeros((Htot, R), np.float32)
    psnr_map = np.zeros((Htot, R), np.float32)
    ssim_map = np.zeros((Htot, R), np.float32)
    snr_map  = np.zeros((Htot, R), np.float32)

    with torch.no_grad():
        for i0 in range(0, Htot, binf):
            i1 = min(i0 + binf, Htot)
            # Hide the removed shots from the network, mirroring the hard mask used in training.
            xb_cpu = CRG_te[i0:i1].copy()
            xb_cpu[:, removed_idx, :] = 0
            xb = torch.from_numpy(xb_cpu).unsqueeze(1).to(device)

            yp = model(xb).squeeze(1).cpu().numpy()
            yt = CRG_te[i0:i1]

            for b in range(yp.shape[0]):
                rec_id = i0 + b
                for j, shot_idx in enumerate(removed_idx):
                    mse, psnr, snr, ss = _trace_metrics(
                        yt[b, shot_idx], yp[b, shot_idx], data_range
                    )
                    mse_map[rec_id, j] = mse
                    psnr_map[rec_id, j] = psnr
                    snr_map[rec_id, j] = snr
                    ssim_map[rec_id, j] = ss

    maps = {"MSE": mse_map, "PSNR": psnr_map, "SSIM": ssim_map, "SNR": snr_map}

    per_shot = {"shot_idx": np.array(removed_idx)}
    per_shot.update({name: _nanmean(m, axis=0) for name, m in maps.items()})

    per_rec = {"rec_idx": np.arange(Htot)}
    per_rec.update({name: _nanmean(m, axis=1) for name, m in maps.items()})

    global_m = {name: _nanmean(m) for name, m in maps.items()}

    return {
        "mse_map": mse_map,
        "psnr_map": psnr_map,
        "ssim_map": ssim_map,
        "snr_map": snr_map,
        "per_shot": per_shot,
        "per_rec": per_rec,
        "global": global_m
    }


def save_npz_areal(pct, areal, out_dir):
    """Write every map and aggregate produced by ``eval_metrics_areal`` to one .npz.

    The compressed archive is ``<out_dir>/run_pct<pct:02d>_areal.npz``.
    """
    shot_stats = areal["per_shot"]
    rec_stats = areal["per_rec"]
    overall = areal["global"]

    payload = {
        "mse_map": areal["mse_map"],
        "psnr_map": areal["psnr_map"],
        "ssim_map": areal["ssim_map"],
        "snr_map": areal["snr_map"],
        "shot_idx": shot_stats["shot_idx"],
        "MSE_per_shot": shot_stats["MSE"],
        "PSNR_per_shot": shot_stats["PSNR"],
        "SSIM_per_shot": shot_stats["SSIM"],
        "SNR_per_shot": shot_stats["SNR"],
        "rec_idx": rec_stats["rec_idx"],
        "MSE_per_rec": rec_stats["MSE"],
        "PSNR_per_rec": rec_stats["PSNR"],
        "SSIM_per_rec": rec_stats["SSIM"],
        "SNR_per_rec": rec_stats["SNR"],
        "MSE_global": overall["MSE"],
        "PSNR_global": overall["PSNR"],
        "SSIM_global": overall["SSIM"],
        "SNR_global": overall["SNR"],
    }
    target = os.path.join(out_dir, f"run_pct{pct:02d}_areal.npz")
    np.savez_compressed(target, **payload)

def save_csv_areal(pct, areal, out_dir):
    """Write the per-shot and per-receiver metric tables as CSV files.

    Output files: ``metrics_per_shot_pct<pct:02d>.csv`` and
    ``metrics_per_rec_pct<pct:02d>.csv`` in ``out_dir``. Each has the index
    column followed by MSE, PSNR, SSIM and SNR, and no DataFrame index.
    """
    targets = (
        ("per_shot", "shot_idx", f"metrics_per_shot_pct{pct:02d}.csv"),
        ("per_rec", "rec_idx", f"metrics_per_rec_pct{pct:02d}.csv"),
    )
    for section, index_col, filename in targets:
        stats = areal[section]
        columns = [index_col, "MSE", "PSNR", "SSIM", "SNR"]
        table = pd.DataFrame({col: stats[col] for col in columns})
        table.to_csv(os.path.join(out_dir, filename), index=False)

def save_npz_per_pct(pct, probs, hard, logits, removed_idx, loss_hist, metrics_dict, out_dir):
    """Save one scenario's mask state, loss curve and global metrics to ``run_pct<pct:02d>.npz``.

    Missing metrics, including ``metrics_dict=None``, are stored as NaN.
    Storing ``None`` would create object arrays, which ``np.load`` refuses
    to read with its default ``allow_pickle=False``.
    """
    metrics_dict = metrics_dict or {}
    path = os.path.join(out_dir, f"run_pct{pct:02d}.npz")
    np.savez_compressed(
        path,
        probs=np.asarray(probs, dtype=np.float32),
        hard=np.asarray(hard, dtype=np.float32),
        logits=np.asarray(logits, dtype=np.float32),
        removed_idx=np.asarray(removed_idx, dtype=np.int32),
        loss_hist=np.asarray(loss_hist, dtype=np.float32),
        MSE=metrics_dict.get("MSE", np.nan),
        PSNR=metrics_dict.get("PSNR", np.nan),
        SSIM=metrics_dict.get("SSIM", np.nan),
        SNR=metrics_dict.get("SNR", np.nan),
    )


def _apply_coordinate_scalar(value, scalar):
    """Apply a SEG-Y coordinate scalar (trace header bytes 71-72) to a raw coordinate.

    A positive scalar multiplies, a negative one divides by its magnitude,
    and 0 means no scaling.
    """
    value = float(value)
    scalar = int(scalar or 0)
    if scalar > 0:
        return value * scalar
    if scalar < 0:
        return value / -scalar
    return value


def build_and_save_geometry(base_dir=BASE_DIR, out_dir=OUT_DIR):
    """Read shot/receiver coordinates from SEG-Y headers and save geometry_shot_rec.npz.

    * One shot position per file, taken from the first trace's
      ``source_coordinate_x/y``.
    * Receiver positions come from every trace of the FIRST file only.
      NOTE(review): this assumes every shot records the same receivers in
      the same order; TODO confirm.

    Coordinates are scaled by each header's
    ``scalar_to_be_applied_to_all_coordinates``; raw header integers are
    only correct when that scalar is 0 or 1. Files are selected and sorted
    exactly as in ``load_all``, so shot index i matches data shot i.

    Prints a message and returns without saving when there are no files or
    a coordinate field is missing.
    """
    pats = ["shot_SN_*.sgy", "shot_SN_*.segy"]
    files = sorted(f for p in pats for f in glob.glob(os.path.join(base_dir, p)))
    files = [f for f in files if tag(f) is not None]
    if not files:
        print("No hay SEGY para geometría")
        return

    shot_x = []
    shot_y = []
    rec_x = None
    rec_y = None

    for f in files:
        st = _read_segy(f, headonly=True)
        hdr0 = st.traces[0].header

        sx = getattr(hdr0, "source_coordinate_x", None)
        sy = getattr(hdr0, "source_coordinate_y", None)
        if sx is None or sy is None:
            print("Faltan source_coordinate_x/y en el header")
            return

        scal0 = getattr(hdr0, "scalar_to_be_applied_to_all_coordinates", 0)
        shot_x.append(_apply_coordinate_scalar(sx, scal0))
        shot_y.append(_apply_coordinate_scalar(sy, scal0))

        if rec_x is None:
            rx = []
            ry = []
            for tr in st.traces:
                h = tr.header
                gx = getattr(h, "group_coordinate_x", None)
                gy = getattr(h, "group_coordinate_y", None)
                if gx is None or gy is None:
                    print("Faltan group_coordinate_x/y en el header")
                    return
                scal = getattr(h, "scalar_to_be_applied_to_all_coordinates", 0)
                rx.append(_apply_coordinate_scalar(gx, scal))
                ry.append(_apply_coordinate_scalar(gy, scal))
            rec_x = np.array(rx, float)
            rec_y = np.array(ry, float)

    shot_x = np.array(shot_x, float)
    shot_y = np.array(shot_y, float)

    np.savez_compressed(
        os.path.join(out_dir, "geometry_shot_rec.npz"),
        shot_x=shot_x, shot_y=shot_y,
        rec_x=rec_x, rec_y=rec_y
    )
    print(">>> Geometría generada:", len(shot_x), "shots |", len(rec_x), "recs")


def plot_geometry_colored_by_shot_metric(pct, out_dir=OUT_DIR, metric="SSIM"):
    """Scatter-plot the survey geometry, coloring and sizing removed shots by a metric.

    Reads ``geometry_shot_rec.npz``, building it first if needed, and
    ``pct_XX/metrics_per_shot_pctXX.csv``. Writes
    ``pct_XX/geometry_<metric>_pctXX.png``. Color limits are the 5th-95th
    percentiles of the finite metric values.

    Prints a message and returns without plotting when an input is missing,
    no shot was removed, or the metric has no finite values.
    """
    geom_path = os.path.join(out_dir, "geometry_shot_rec.npz")
    csv_path  = os.path.join(out_dir, f"pct_{pct:02d}", f"metrics_per_shot_pct{pct:02d}.csv")

    if not os.path.exists(geom_path):
        build_and_save_geometry(BASE_DIR, out_dir)

    if not os.path.exists(geom_path) or not os.path.exists(csv_path):
        print("Falta geometría o CSV para el escenario", pct)
        return

    geom = np.load(geom_path)
    sx, sy = geom["shot_x"], geom["shot_y"]
    rx, ry = geom["rec_x"], geom["rec_y"]

    df = pd.read_csv(csv_path)
    removed_idx = df["shot_idx"].values.astype(int)
    vals = df[metric].values.astype(float)

    # np.percentile raises IndexError on an empty array.
    if len(removed_idx) == 0:
        print(f"No hay shots eliminados para pct={pct}")
        return

    # Color limits from finite values only, so one NaN metric cannot make them NaN.
    finite = vals[np.isfinite(vals)]
    if finite.size == 0:
        print(f"Sin valores finitos de {metric} para pct={pct}")
        return
    vmin = np.percentile(finite, 5)
    vmax = np.percentile(finite, 95)
    if vmin == vmax:
        vmin = finite.min() - 1e-3
        vmax = finite.max() + 1e-3

    # Marker area grows linearly with the clipped metric (80..580); NaN entries get the minimum.
    vals_norm = np.clip((vals - vmin) / (vmax - vmin + 1e-12), 0.0, 1.0)
    sizes = 80 + 500 * np.nan_to_num(vals_norm, nan=0.0)

    fig, ax = plt.subplots(figsize=(10,7))

    ax.plot(rx, ry, "-", color="lightgray", linewidth=2, alpha=0.6, label="Receptores")
    ax.scatter(sx, sy, c="silver", s=40, alpha=0.5, label="Shots (todos)")

    sc = ax.scatter(
        sx[removed_idx], sy[removed_idx],
        c=vals,
        s=sizes,
        cmap="viridis",
        vmin=vmin, vmax=vmax,
        edgecolors="white",
        linewidths=1.2,
        alpha=0.95,
        label=f"Shots eliminados ({metric})"
    )

    cb = plt.colorbar(sc, ax=ax, shrink=0.8)
    cb.set_label(metric, fontsize=12)

    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("X (m)")
    ax.set_ylabel("Y (m)")
    ax.set_title(f"Métrica {metric} por shot — {pct}%", fontsize=15)
    ax.legend()
    plt.tight_layout()

    save_path = os.path.join(out_dir, f"pct_{pct:02d}", f"geometry_{metric}_pct{pct:02d}.png")
    fig.savefig(save_path, dpi=160)
    # Close this figure explicitly rather than whatever figure is "current".
    plt.close(fig)
    print(">>> Figura scatter", metric, "guardada en:", save_path)


def plot_geometry_heatmap_metric(pct, out_dir=OUT_DIR, metric="SSIM",
                                 nx=200, ny=200):
    """Render an interpolated map of a per-shot metric over the survey area.

    The metric values at the removed shots' positions are interpolated onto
    an ``nx`` x ``ny`` grid that spans all shots plus a 2% margin. Linear
    interpolation is used where available; nearest-neighbour fills the
    remaining NaN cells, and also replaces the linear surface entirely when
    Qhull cannot triangulate the points. Receivers and removed shots are
    drawn on top. The figure is saved as
    ``pct_XX/geometry_<metric>_heatmap_pctXX.png``.

    Prints a message and returns early when the geometry or CSV is missing,
    or when no shot was removed.
    """
    scenario_dir = os.path.join(out_dir, f"pct_{pct:02d}")
    geom_path = os.path.join(out_dir, "geometry_shot_rec.npz")
    csv_path = os.path.join(scenario_dir, f"metrics_per_shot_pct{pct:02d}.csv")

    if not os.path.exists(geom_path):
        build_and_save_geometry(BASE_DIR, out_dir)

    if not (os.path.exists(geom_path) and os.path.exists(csv_path)):
        print("Falta geometría o CSV para heatmap en pct=", pct)
        return

    geometry = np.load(geom_path)
    shot_x, shot_y = geometry["shot_x"], geometry["shot_y"]
    rec_x, rec_y = geometry["rec_x"], geometry["rec_y"]

    table = pd.read_csv(csv_path)
    removed = table["shot_idx"].values.astype(int)
    values = table[metric].values.astype(float)

    if not len(removed):
        print(f"No hay shots eliminados para pct={pct}")
        return

    xs_removed, ys_removed = shot_x[removed], shot_y[removed]
    sites = np.column_stack([xs_removed, ys_removed])

    def padded(lo, hi):
        # Widen [lo, hi] by 2% of its span on each side.
        margin = 0.02 * (hi - lo + 1e-9)
        return lo - margin, hi + margin

    x_lo, x_hi = padded(shot_x.min(), shot_x.max())
    y_lo, y_hi = padded(shot_y.min(), shot_y.max())

    grid_x, grid_y = np.meshgrid(
        np.linspace(x_lo, x_hi, nx),
        np.linspace(y_lo, y_hi, ny)
    )

    try:
        linear = griddata(sites, values, (grid_x, grid_y), method="linear")
    except QhullError:
        print("Advertencia: puntos casi colineales, se omite 'linear'.")
        linear = np.full_like(grid_x, np.nan, dtype=float)

    nearest = griddata(sites, values, (grid_x, grid_y), method="nearest")
    surface = np.where(np.isnan(linear), nearest, linear)

    finite = np.isfinite(surface)
    if not np.any(finite):
        vmin, vmax = values.min(), values.max()
    else:
        good = surface[finite]
        vmin, vmax = np.percentile(good, 5), np.percentile(good, 95)
        if vmin == vmax:
            vmin, vmax = good.min() - 1e-3, good.max() + 1e-3

    fig, ax = plt.subplots(figsize=(10,7))

    image = ax.imshow(
        surface,
        origin="lower",
        extent=[x_lo, x_hi, y_lo, y_hi],
        cmap="viridis",
        vmin=vmin, vmax=vmax,
        aspect="equal"
    )
    colorbar = plt.colorbar(image, ax=ax, shrink=0.8)
    colorbar.set_label(metric, fontsize=12)

    ax.scatter(rec_x, rec_y, s=10, c="lightgray", alpha=0.7, label="Receptores")
    ax.scatter(
        xs_removed, ys_removed,
        c=values,
        cmap="viridis",
        vmin=vmin, vmax=vmax,
        edgecolors="black",
        linewidths=0.8,
        s=60,
        label=f"Shots eliminados ({metric})"
    )

    ax.set_xlabel("X (m)")
    ax.set_ylabel("Y (m)")
    ax.set_title(f"Heatmap espacial de {metric} — {pct}% shots eliminados", fontsize=15)
    ax.legend(loc="best")
    ax.set_aspect("equal", adjustable="box")
    plt.tight_layout()

    out_png = os.path.join(scenario_dir, f"geometry_{metric}_heatmap_pct{pct:02d}.png")
    plt.savefig(out_png, dpi=160)
    plt.close()
    print(">>> Heatmap", metric, "guardado en:", out_png)


print(">>> Cargando SEGY…")
gathers_shot, file_list = load_all(BASE_DIR)
print(">>> Shots cargados:", len(file_list))

# gathers_shot is (n_shots, n_traces, n_samples), as stacked by load_all.
# Explicit raise instead of assert, so the check survives `python -O`.
if gathers_shot.shape[1] < H or gathers_shot.shape[2] < W:
    raise ValueError(
        f"Los gathers tienen forma {gathers_shot.shape[1:]}, "
        f"se requiere al menos ({H}, {W})"
    )
gathers_shot = gathers_shot[:, :H, :W]

# Shot-major -> trace-major: CRG_all[r, s, t] is trace r of shot s.
# NOTE(review): treating axis 0 as "receiver" assumes trace r is the same
# receiver in every shot file — TODO confirm.
CRG_all = np.transpose(gathers_shot, (1,0,2))
CRG_all = norm_trace_lastaxis(CRG_all)
CRG_all = crop_to_mult8_2d(CRG_all)
S8, T8 = CRG_all.shape[1], CRG_all.shape[2]
check_divisible_by_8(S8, T8)

# Split receiver gathers (not shots) into train/test with the run's global seed.
idx_rec = np.arange(CRG_all.shape[0])
idx_tr, idx_te = train_test_split(idx_rec, test_size=0.2, random_state=GLOBAL_SEED)
CRG_tr, CRG_te = CRG_all[idx_tr], CRG_all[idx_te]

print(">>> Datos preparados. CRG_tr:", CRG_tr.shape, " CRG_te:", CRG_te.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-scenario results, keyed by str(pct).
dict_losses = {}
dict_metrics_global = {}
dict_removed_counts = {}

for pct in SCENARIOS:
    print(f"\n==============================\n>> Escenario {pct}%\n==============================")
    sc_dir = os.path.join(OUT_DIR, f"pct_{pct:02d}")
    os.makedirs(sc_dir, exist_ok=True)

    # Re-seed every scenario. Otherwise a scenario's initial weights and data
    # shuffling depend on how much RNG the earlier scenarios consumed. With
    # the re-seed, each scenario is reproducible on its own and all scenarios
    # start from the same initialisation.
    np.random.seed(GLOBAL_SEED)
    torch.manual_seed(GLOBAL_SEED)

    model, mask_layer, loss_hist = train_unet_with_shotmask(CRG_tr, pct, device)
    dict_losses[str(pct)] = loss_hist

    removed_idx = mask_layer.hard_indices_removed()
    dict_removed_counts[str(pct)] = len(removed_idx)
    print("   Shots eliminados:", removed_idx)

    # eval_metrics_areal returns None when no shot was removed; the savers cannot take that.
    areal = eval_metrics_areal(model, CRG_te, removed_idx, device)
    if areal is not None:
        save_npz_areal(pct, areal, sc_dir)
        save_csv_areal(pct, areal, sc_dir)
        metrics_global = areal["global"]
        dict_metrics_global[str(pct)] = metrics_global
    else:
        print("   Sin shots eliminados: se omiten las métricas de este escenario")
        metrics_global = {}

    with torch.no_grad():
        probs_np = torch.sigmoid(mask_layer.logits).cpu().numpy()
        hard_np  = np.ones_like(probs_np)
        hard_np[removed_idx] = 0.0
        logits_np = mask_layer.logits.detach().cpu().numpy()
    save_npz_per_pct(pct, probs_np, hard_np, logits_np, removed_idx, loss_hist, metrics_global, sc_dir)

    fig = plt.figure(figsize=(8,4))
    plt.plot(loss_hist, '-o')
    plt.title(f"Pérdida (1-SSIM + λ·sparsity) — {pct}%")
    plt.xlabel("Época")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(sc_dir, "loss_curve.png"), dpi=150)
    plt.close(fig)

    # The geometry plots read this scenario's per-shot CSV. Skip them when it
    # was not written, so a stale CSV from an earlier run is never plotted.
    if areal is not None:
        plot_geometry_colored_by_shot_metric(pct, OUT_DIR, metric="SSIM")
        plot_geometry_heatmap_metric(pct, OUT_DIR, metric="SSIM")

# One summary row per scenario that produced metrics, ordered by removal percentage.
summary_columns = ["pct_removed", "shots_removed_count", "MSE", "PSNR", "SSIM", "SNR"]
rows = [
    {
        "pct_removed": int(k),
        "shots_removed_count": dict_removed_counts.get(k),
        "MSE": dict_metrics_global[k]["MSE"],
        "PSNR": dict_metrics_global[k]["PSNR"],
        "SSIM": dict_metrics_global[k]["SSIM"],
        "SNR": dict_metrics_global[k]["SNR"],
    }
    for k in sorted(dict_metrics_global, key=int)
]

# Explicit columns keep sort_values valid even when no scenario produced metrics.
df = pd.DataFrame(rows, columns=summary_columns).sort_values("pct_removed")
csv_path = os.path.join(OUT_DIR, "metrics_removed_shots_global.csv")
df.to_csv(csv_path, index=False, float_format="%.6f")

print("\n>>> Listo.")
print(">>> Salidas en:", OUT_DIR)
print(">>> CSV global:", csv_path)
