# In [None]:
import os, re, glob, numpy as np
from collections import Counter
from obspy.io.segy.segy import _read_segy
from sklearn.model_selection import train_test_split
import torch, torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from pytorch_msssim import SSIM
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from math import log10

# --- Experiment configuration ------------------------------------------------
# Directory scanned by load_all() for the SEG-Y shot files.
BASE_DIR = "/home/pc-2/Documents/CAVE_minciencias/utah_model/shot densos 2"
# Crop applied to every gather: H rows along axis 1 (one row per trace) and
# W columns along axis 2 (samples within a trace). Both are checked to be
# multiples of 8 before cropping.
H, W = 512, 4096
BATCH = 1  # DataLoader batch size
EPOCHS = 20  # training epochs run for each subsampling level
LR = 1e-4  # learning rate passed to Adam
PCTS = list(range(10, 100, 10))  # subsampling percentages: 10, 20, ..., 90
PRINT_MAX_EX = 5  # maximum number of examples shown in diagnostic printouts

# Seed NumPy's legacy global RNG and the torch RNG once, at import time.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

def print_header(title):
    """Print a blank line, then ``title`` framed by ten '=' signs on each side."""
    bar = "=" * 10
    print(f"\n{bar} {title} {bar}")

def print_basic_stats(name, arr):
    """Print shape, dtype and NaN-ignoring min/max/mean/std of ``arr``.

    Args:
        name: Label shown in square brackets at the start of the line.
        arr: A NumPy array, or any object exposing ``.cpu().numpy()``
            (e.g. a torch tensor), which is converted before the statistics
            are computed.
    """
    if isinstance(arr, np.ndarray):
        values = arr
    else:
        values = arr.cpu().numpy()
    lowest, highest = np.nanmin(values), np.nanmax(values)
    average, spread = np.nanmean(values), np.nanstd(values)
    print(
        f"[{name}] shape={values.shape}, dtype={values.dtype}, "
        f"min={lowest:.4f}, max={highest:.4f}, "
        f"mean={average:.4f}, std={spread:.4f}"
    )

def print_tag_dist(label, tags):
    """Print how many items ``tags`` holds and how often each value occurs.

    Args:
        label: Label shown in square brackets at the start of the line.
        tags: Sequence of hashable tags; a NumPy array is converted to a
            plain list before counting.
    """
    if isinstance(tags, np.ndarray):
        items = tags.tolist()
    else:
        items = tags
    histogram = dict(Counter(items))
    print(f"[{label}] N={len(tags)} | dist={histogram}")

def sanity_no_nan_inf(name, arr):
    """Fail loudly if ``arr`` contains any NaN or infinite value.

    Args:
        name: Label used as the prefix of the error message.
        arr: A NumPy array, or any object exposing ``.cpu().numpy()``
            (e.g. a torch tensor).

    Raises:
        AssertionError: If at least one NaN or Inf is present. The exception
            is raised explicitly (not via ``assert``) so the check still runs
            when Python is started with ``-O``.
    """
    arr_np = arr if isinstance(arr, np.ndarray) else arr.cpu().numpy()
    n_nan = int(np.isnan(arr_np).sum())
    n_inf = int(np.isinf(arr_np).sum())
    if n_nan or n_inf:
        raise AssertionError(f"{name}: NaN={n_nan}, Inf={n_inf}")

def check_divisible_by_8(h, w):
    """Verify that both spatial sizes are multiples of 8.

    Args:
        h: Height; reported as "H" in the error message.
        w: Width; reported as "W" in the error message.

    Raises:
        AssertionError: For the first size that is not a multiple of 8. The
            exception is raised explicitly (not via ``assert``) so the check
            is not stripped when Python runs with ``-O``.
    """
    for val, nm in ((h, "H"), (w, "W")):
        if val % 8 != 0:
            raise AssertionError(f"{nm} debe ser múltiplo de 8 para U-Net: {val}")

def check_train_test_disjoint(idx_tr, idx_te):
    """Verify that the train and test index collections share no element.

    Args:
        idx_tr: Train indices (NumPy array or any sequence of hashable values).
        idx_te: Test indices (NumPy array or any sequence of hashable values).

    Raises:
        AssertionError: If any index appears in both collections; the message
            lists the shared indices. Raised explicitly (not via ``assert``)
            so the check is not stripped when Python runs with ``-O``.
    """
    # np.asarray lets plain lists/tuples through as well as ndarrays.
    inter = set(np.asarray(idx_tr).tolist()) & set(np.asarray(idx_te).tolist())
    if inter:
        raise AssertionError(f"Train/Test con intersección! {inter}")

def verify_subsampling(X, target_frac):
    """Print the fraction of all-zero rows found in the first few shots of ``X``.

    Only the first ``min(N, PRINT_MAX_EX)`` shots are inspected. For each one
    the number of rows that are entirely zero is divided by the row count, and
    the per-shot fractions plus their mean are printed next to ``target_frac``.

    Args:
        X: Array whose ``.shape`` unpacks into three values
            (shots, rows per shot, samples per row).
        target_frac: Expected fraction, shown for comparison only.
    """
    n_shots, n_rows, _ = X.shape
    n_checked = min(n_shots, PRINT_MAX_EX)
    fracs = [
        np.count_nonzero(np.all(X[k] == 0, axis=1)) / n_rows
        for k in range(n_checked)
    ]
    print(f"[Check Submuestreo] target={target_frac:.2f}, "
          f"ejemplos={fracs} (prom={np.mean(fracs):.3f})")

def gpu_info(device):
    """Print a one-line summary of the compute device in use.

    Args:
        device: Object with a ``type`` attribute (e.g. ``torch.device``).
            When it equals ``"cuda"``, the name and compute capability of
            CUDA device 0 are printed together with the currently allocated
            and reserved memory in GB; otherwise a CPU notice is printed.
    """
    if device.type != "cuda":
        print("[GPU] No CUDA: usando CPU")
        return

    dev_name = torch.cuda.get_device_name(0)
    capability = torch.cuda.get_device_capability(0)
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"[GPU] {dev_name} | "
          f"cap={capability} | "
          f"mem_alloc={allocated_gb:.2f} GB | "
          f"mem_reserved={reserved_gb:.2f} GB")

def tag(fname):
    """Extract the ``yNN`` tag from a shot file name.

    The base name is lower-cased and searched for ``shot_yNN_`` where ``NN``
    is exactly two digits.

    Args:
        fname: File name or path.

    Returns:
        The lower-case tag (e.g. ``"y03"``), or ``None`` when the pattern is
        not present.
    """
    base = os.path.basename(fname).lower()
    found = re.search(r'shot_(y\d{2})_', base)
    if found is None:
        return None
    return found.group(1)

def load_all(base_dir):
    """Read every matching SEG-Y shot file in ``base_dir`` into one array.

    Files are matched with the globs ``shot_y0[1-7]_*.sgy`` and
    ``shot_y0[1-7]_*.segy`` and kept only if ``tag()`` yields one of
    ``y01`` .. ``y07``. They are processed in sorted path order.

    Args:
        base_dir: Directory that contains the shot files.

    Returns:
        A tuple ``(gathers, tags)``: ``gathers`` is a float32 array stacked
        along a new first axis, one entry per file, each entry holding the
        file's traces as rows; ``tags`` is a NumPy array with the tag of
        each file, in the same order.

    Raises:
        FileNotFoundError: If no file matches. Previously this case fell
            through to ``np.stack([])`` and failed with an unrelated-looking
            ``ValueError``.
        ValueError: From ``np.stack`` if the files do not all have the same
            trace count and trace length.
    """
    pats = ["shot_y0[1-7]_*.sgy", "shot_y0[1-7]_*.segy"]
    # Built once instead of once per candidate file.
    valid_tags = {f"y0{i}" for i in range(1, 8)}
    files = sorted(
        f
        for p in pats
        for f in glob.glob(os.path.join(base_dir, p))
        if tag(f) in valid_tags
    )
    if not files:
        raise FileNotFoundError(
            f"No SEG-Y shot files matching {pats} found in {base_dir!r}"
        )

    arrs, tags_list = [], []
    for f in files:
        st = _read_segy(f, headonly=False)
        # One row per trace, samples along the second axis.
        arrs.append(np.array([tr.data for tr in st.traces], dtype=np.float32))
        tags_list.append(tag(f))

    print_header("Carga de archivos")
    print(f"Total archivos: {len(files)} | ejemplos: {files[:PRINT_MAX_EX]}")
    for i in range(min(3, len(arrs))):
        print(f"  - arr[{i}] shape={arrs[i].shape}, tag={tags_list[i]}")
    return np.stack(arrs, 0), np.array(tags_list)

# Load every shot gather together with its per-file tag.
gathers, tags = load_all(BASE_DIR)
print_header("Antes de recorte")
print_basic_stats("gathers_raw", gathers)
print_tag_dist("tags_total", tags)

# Explicit raise (instead of `assert`) so the size check survives `python -O`.
if gathers.shape[1] < H or gathers.shape[2] < W:
    raise AssertionError("Dimensiones de archivo < H/W")
check_divisible_by_8(H, W)
# Keep the first H traces and the first W samples of each gather; .copy()
# makes the crop an independent array rather than a view of the full data.
gathers = gathers[:, :H, :W].copy()
print_header("Después de recorte")
print_basic_stats("gathers", gathers)
sanity_no_nan_inf("gathers", gathers)

# Shuffled 80/20 split over shot indices. The seed comes from GLOBAL_SEED so
# the whole experiment is controlled by a single constant.
idx = np.arange(len(tags))
idx_tr, idx_te = train_test_split(
    idx, test_size=0.2, random_state=GLOBAL_SEED, shuffle=True)
check_train_test_disjoint(idx_tr, idx_te)
Gtr, Gte = gathers[idx_tr], gathers[idx_te]
Ttr, Tte = tags[idx_tr], tags[idx_te]
print_header("Split Train/Test")
print(f"N_total={len(tags)} | N_tr={len(Ttr)} | N_te={len(Tte)}")
print_tag_dist("tags_train", Ttr)
print_tag_dist("tags_test", Tte)

def norm_trace(x):
    """Scale every row along axis 2 by its own peak absolute value.

    Args:
        x: Array with at least three dimensions; the maximum of ``|x|`` is
            taken along axis 2.

    Returns:
        ``x`` divided by (peak + 1e-6). The small offset keeps all-zero rows
        at zero instead of producing a division by zero.
    """
    peak = np.abs(x).max(axis=2, keepdims=True)
    return x / (peak + 1e-6)

# Per-trace normalised gathers; these are the reconstruction targets used for
# training (Ytr) and testing (Yte).
Ytr, Yte = norm_trace(Gtr), norm_trace(Gte)
print_header("Normalización")
print_basic_stats("Ytr (targets train)", Ytr)
print_basic_stats("Yte (targets test)", Yte)
# Abort early if normalisation produced any NaN/Inf.
sanity_no_nan_inf("Ytr", Ytr)
sanity_no_nan_inf("Yte", Yte)
# Dividing by the per-trace peak bounds values to about [-1, 1], hence the
# data_range of 2.0 used for SSIM further down.
print("Rangos esperados ~[-1,1]. data_range=2.0 para SSIM.")

def subsample_all_with_idx(X, frac, seed):
    """
    Zero out a fraction ``frac`` of the receiver rows in EVERY shot of ``X``.

    ``X`` is modified in place. The number of rows removed per shot is
    ``round(frac * nrec)``; rows are drawn without replacement from a
    generator seeded with ``seed``, one draw per shot in shot order.

    Returns ``X`` (the same object) and a list holding, for each shot, the
    sorted array of zeroed row indices (an empty int array when nothing is
    removed).
    """
    rng = np.random.default_rng(seed)
    n_shots, n_rec, _ = X.shape
    n_drop = int(round(frac * n_rec))

    # Nothing to remove: leave X untouched and report empty index arrays.
    if n_drop <= 0:
        return X, [np.array([], dtype=int) for _ in range(n_shots)]

    removed_per_shot = []
    for shot in range(n_shots):
        rows = np.sort(rng.choice(n_rec, size=n_drop, replace=False))
        X[shot, rows, :] = 0.0
        removed_per_shot.append(rows)
    return X, removed_per_shot

class UNet2DFull(nn.Module):
    """2D U-Net with three down-sampling stages, a bottleneck and three
    up-sampling stages with skip connections.

    Takes a single-channel input and returns a single-channel output passed
    through ``tanh``, so output values lie in (-1, 1).
    """
    def __init__(self):
        """Build the encoder, bottleneck, decoder and output layers."""
        super().__init__()
        def blk(cin, cout):
            """Two 3x3 same-padding convolutions, each followed by batch
            norm and an in-place LeakyReLU (negative slope 0.01)."""
            return nn.Sequential(
                nn.Conv2d(cin, cout, 3, padding=1),
                nn.BatchNorm2d(cout), nn.LeakyReLU(0.01, True),
                nn.Conv2d(cout, cout, 3, padding=1),
                nn.BatchNorm2d(cout), nn.LeakyReLU(0.01, True)
            )
        # Encoder: conv block, then 2x2 max pooling (halves each spatial dim).
        self.e1, self.p1 = blk(1,64), nn.MaxPool2d(2,2)
        self.e2, self.p2 = blk(64,128), nn.MaxPool2d(2,2)
        self.e3, self.p3 = blk(128,256), nn.MaxPool2d(2,2)
        # Bottleneck at 1/8 of the input resolution.
        self.bott = blk(256,512)
        # Decoder: kernel-2/stride-2 transposed conv (doubles each spatial
        # dim), then a conv block whose input is the upsampled tensor
        # concatenated with the matching encoder output (hence the doubled
        # input channel count).
        self.u3 = nn.ConvTranspose2d(512,256,2,2); self.d3 = blk(512,256)
        self.u2 = nn.ConvTranspose2d(256,128,2,2); self.d2 = blk(256,128)
        self.u1 = nn.ConvTranspose2d(128, 64,2,2); self.d1 = blk(128, 64)
        # 1x1 convolution down to a single output channel.
        self.out = nn.Conv2d(64,1,1)
    def forward(self,x):
        """Run the network on ``x`` and return the tanh-activated output."""
        # Encoder path; e1..e3 are kept for the skip connections.
        e1=self.e1(x); p1=self.p1(e1)
        e2=self.e2(p1); p2=self.p2(e2)
        e3=self.e3(p2); p3=self.p3(e3)
        b=self.bott(p3)
        # Decoder path: upsample, centre-crop the encoder feature map to the
        # upsampled size, concatenate along channels, then convolve.
        u3=self.u3(b); d3=self.d3(torch.cat([u3, self.crop(e3,u3)],1))
        u2=self.u2(d3); d2=self.d2(torch.cat([u2, self.crop(e2,u2)],1))
        u1=self.u1(d2); d1=self.d1(torch.cat([u1, self.crop(e1,u1)],1))
        return torch.tanh(self.out(d1))
    @staticmethod
    def crop(a,b):
        """Centre-crop the last two dimensions of ``a`` to those of ``b``.

        Both arguments must be 4-D. When the sizes already match, ``a`` is
        returned as a full slice.
        """
        _,_,h,w=b.shape; _,_,H,W=a.shape
        # Offsets of the centred window inside ``a``.
        dh,dw=(H-h)//2,(W-w)//2
        return a[:,:,dh:dh+h, dw:dw+w]

def count_params(model):
    """Return the total number of scalar elements across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def eval_metrics_removed_only(model, Xe, Ye, device, return_counts=False):
    """
    Compute reconstruction metrics ONLY on the removed traces.

    A trace counts as removed when its row in the input ``Xe`` is entirely
    zero. Shots without any such row are skipped. For every remaining shot
    the model prediction is compared with the target on the removed rows only
    and MSE, PSNR, SSIM and SNR are computed; the per-shot values are then
    averaged.

    Args:
        model: ``torch.nn.Module`` mapping a (1, 1, H, W) tensor to a tensor
            of the same layout.
        Xe: Input tensor indexed as ``Xe[i, 0]`` -> 2-D (rows, samples).
        Ye: Target tensor with the same layout as ``Xe``.
        device: Device the input slice is moved to before the forward pass.
        return_counts: If True, also return ``(shots_used,
            total_removed_traces)``.

    Returns:
        A dict with keys "MSE", "PSNR", "SSIM", "SNR" (Python floats, or
        ``None`` for every key when no shot had removed traces). With
        ``return_counts=True`` a tuple ``(dict, (shots_used, n_traces))``.

    Notes:
        * PSNR uses the peak-to-peak amplitude of the target traces as the
          signal range and always goes through the epsilon-stabilised
          formula. A perfect reconstruction therefore gives a large finite
          value rather than ``inf``, which would otherwise turn the averaged
          PSNR into ``inf``.
        * The model is switched to eval mode for the duration of the call and
          its previous training/eval mode is restored afterwards.
    """
    MSE, PSNR, SSIMg, SNR = [], [], [], []
    shots_used = 0
    total_removed_traces = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i in range(Ye.shape[0]):
                xi = Xe[i, 0].cpu().numpy()
                idx0 = np.where(np.all(xi == 0, axis=1))[0]
                if len(idx0) == 0:
                    continue

                yt = Ye[i, 0].cpu().numpy()
                # Index batch/channel explicitly; unlike squeeze() this stays
                # 2-D even when a spatial dimension has size 1.
                yp = model(Xe[i:i+1].to(device))[0, 0].cpu().numpy()
                ytn = yt[idx0]
                ypn = yp[idx0]

                mse = np.mean((ytn - ypn)**2)
                amp = np.ptp(ytn) + 1e-8
                psn = 20 * np.log10(amp / np.sqrt(mse + 1e-12))
                # SSIM is evaluated trace by trace (1-D) and then averaged.
                ssi = np.mean([ssim(ytn[j], ypn[j], data_range=2) for j in range(ytn.shape[0])])
                snr = 10 * np.log10((np.mean(ytn**2) + 1e-12) / (mse + 1e-12))

                MSE.append(mse); PSNR.append(psn); SSIMg.append(ssi); SNR.append(snr)
                shots_used += 1
                total_removed_traces += len(idx0)
    finally:
        model.train(was_training)

    if not MSE:
        out = {"MSE": None, "PSNR": None, "SSIM": None, "SNR": None}
        return (out, (0, 0)) if return_counts else out

    out = {
        "MSE": float(np.mean(MSE)),
        "PSNR": float(np.mean(PSNR)),
        "SSIM": float(np.mean(SSIMg)),
        "SNR": float(np.mean(SNR)),
    }
    return (out, (shots_used, total_removed_traces)) if return_counts else out

def visualizar_comparacion(x_input, y_pred, y_true, titulo_extra=""):
    """Show input, prediction, target and absolute error side by side.

    Each 2-D array is displayed transposed. The first three panels use a gray
    colour map clipped to [-1, 1]; the error panel uses 'inferno' with
    automatic scaling. Every panel gets its own colour bar.

    Args:
        x_input: Model input.
        y_pred: Model prediction.
        y_true: Reference data.
        titulo_extra: Text appended to the figure title.
    """
    error = np.abs(y_pred - y_true)
    vmin, vmax = -1, 1
    amplitude_style = {"cmap": 'gray', "vmin": vmin, "vmax": vmax}
    panels = (
        (x_input, "Input (trazas anuladas)", amplitude_style),
        (y_pred, "Predicción (U-Net)", amplitude_style),
        (y_true, "Shot original", amplitude_style),
        (error, "Error absoluto", {"cmap": 'inferno'}),
    )

    fig, axs = plt.subplots(1, 4, figsize=(24, 6))
    for ax, (image, title, style) in zip(axs, panels):
        handle = ax.imshow(image.T, aspect='auto', origin='upper', **style)
        ax.set_title(title)
        ax.set_xlabel("Receptor")
        ax.set_ylabel("Tiempo")
        fig.colorbar(handle, ax=ax, fraction=0.046, pad=0.04)

    fig.suptitle(f"Comparación: Entrada — Predicción — Original — Error {titulo_extra}", fontsize=14)
    plt.tight_layout()
    plt.show()

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print_header("Dispositivo")
print(f"device={device} | PCTS={PCTS} | BATCH={BATCH} | EPOCHS={EPOCHS}")
gpu_info(device)

# Per-experiment results, keyed by the subsampling percentage as a string.
curvas_loss = {}  # pct -> list with the mean training loss of each epoch
resultados_metricas = {}  # pct -> metrics dict from eval_metrics_removed_only

# One independent experiment per subsampling level: mask the data, train a
# fresh U-Net, visualise one test shot and compute metrics on removed traces.
for pct in PCTS:
    print_header(f"Submuestreo {pct}%")
    frac = pct / 100.0

    # Re-seed torch for every level so that the weight initialisation and the
    # DataLoader shuffle order of this experiment do not depend on which
    # levels were run before it (the torch RNG was otherwise seeded only once
    # at the top of the file).
    torch.manual_seed(GLOBAL_SEED + pct)

    # Inputs start as copies of the targets; the masking below is in place.
    Xtr = Ytr.copy()
    Xte = Yte.copy()

    seed_pct = 1000 + pct
    # NOTE(review): train and test use the same mask seed, so test shot i
    # gets exactly the rows removed from train shot i. Confirm this is
    # intended.
    Xtr, idxs_removidos_train = subsample_all_with_idx(Xtr, frac, seed=seed_pct)
    Xte, idxs_removidos_test  = subsample_all_with_idx(Xte, frac, seed=seed_pct)

    verify_subsampling(Xtr, frac)
    verify_subsampling(Xte, frac)
    sanity_no_nan_inf("Xtr", Xtr)
    sanity_no_nan_inf("Xte", Xte)

    # Add the channel axis -> (N, 1, H, W) and move everything to the device.
    Xt = torch.from_numpy(Xtr).unsqueeze(1).to(device)
    Yt = torch.from_numpy(Ytr).unsqueeze(1).to(device)
    Xe = torch.from_numpy(Xte).unsqueeze(1).to(device)
    Ye = torch.from_numpy(Yte).unsqueeze(1).to(device)

    print_basic_stats("Tensor Xt", Xt)
    print_basic_stats("Tensor Yt", Yt)
    print(f"Dataset sizes -> train={Xt.shape[0]}, test={Xe.shape[0]}")
    loader = DataLoader(TensorDataset(Xt, Yt), batch_size=BATCH, shuffle=True)

    model = UNet2DFull().to(device)
    print(f"Parámetros del modelo: {count_params(model):,}")

    opt   = torch.optim.Adam(model.parameters(), lr=LR)
    # Training loss is 1 - SSIM with data_range=2.0.
    # NOTE(review): the tensors here are roughly in [-1, 1]; confirm that
    # pytorch_msssim treats negative-valued inputs as intended.
    crit  = SSIM(data_range=2.0, size_average=True, channel=1)

    loss_hist = []
    model.train()
    for ep in range(EPOCHS):
        run = 0.0
        for xb, yb in tqdm(loader, desc=f"{pct}% | Época {ep+1}/{EPOCHS}", leave=False):
            opt.zero_grad(set_to_none=True)
            yp = model(xb)
            loss = 1 - crit(yp, yb)
            loss.backward(); opt.step()
            run += loss.item()
        ep_loss = run / len(loader)
        loss_hist.append(ep_loss)
        print(f"[{pct}%] Epoch {ep+1:02d}/{EPOCHS} | loss(1-SSIM)={ep_loss:.6f}")

    curvas_loss[str(pct)] = loss_hist

    # Pick the first test shot that has at least one all-zero (removed) row.
    idx_vis = None
    for i in range(Xe.shape[0]):
        xi = Xe[i, 0].detach().cpu().numpy()
        if np.any(np.all(xi == 0, axis=1)):
            idx_vis = i
            break
    if idx_vis is None:
        print("⚠ No se encontró shot de test con trazas anuladas; usando idx 0.")
        idx_vis = 0

    model.eval()
    with torch.no_grad():
        pred = model(Xe[idx_vis:idx_vis+1]).cpu().squeeze().numpy()
        entrada = Xe[idx_vis, 0].cpu().numpy()
        real = Ye[idx_vis, 0].cpu().numpy()
        # Removed rows of the visualised shot; computed once and reused for
        # both the count below and the per-trace plots.
        trazas_eliminadas = np.where(np.all(entrada == 0, axis=1))[0]
        n_zero = trazas_eliminadas.size
        print(f"[Visualización] test_idx={idx_vis} | trazas anuladas en ese shot={n_zero}")
        visualizar_comparacion(entrada, pred, real, titulo_extra=f"(Test idx={idx_vis}, {pct}%)")

        # Slicing already copes with fewer than three removed traces.
        trazas_a_mostrar = trazas_eliminadas[:3]
        if len(trazas_a_mostrar) > 0:
            tiempo = np.arange(real.shape[1])
            plt.figure(figsize=(30, 4))
            for k, traza_idx in enumerate(trazas_a_mostrar):
                plt.subplot(1, len(trazas_a_mostrar), k+1)
                plt.plot(tiempo, real[traza_idx], label='Real', linewidth=1.5)
                plt.plot(tiempo, pred[traza_idx], label='Predicho', linewidth=1.5)
                plt.title(f"Traza eliminada #{traza_idx}")
                plt.xlabel("Tiempo"); plt.ylabel("Amplitud")
                plt.legend(); plt.grid(True)
            plt.suptitle(f"Comparación real vs predicho - Shot {idx_vis} - {pct}%")
            plt.tight_layout(); plt.show()
        else:
            print("⚠ Ese shot de test no tenía trazas anuladas para mostrar trazas individuales.")

    m_removed, (shots_used, total_removed) = eval_metrics_removed_only(model, Xe, Ye, device, return_counts=True)
    resultados_metricas[f"{pct}"] = m_removed
    print(f"[Métricas {pct}%] usados {shots_used} shots | trazas anuladas evaluadas={total_removed} | {m_removed}")

# Overlay the training-loss curve of every subsampling level in one figure.
plt.figure(figsize=(10, 6))
# `pct_label` (not `pct`) so the integer `pct` left by the experiment loop is
# not overwritten with a string key.
for pct_label, curva in curvas_loss.items():
    plt.plot(range(1, len(curva)+1), curva, label=f"{pct_label}%")
plt.title("Curvas de pérdida por submuestreo (1 - SSIM)")
plt.xlabel("Época"); plt.ylabel("Pérdida")
plt.legend(title="Submuestreo")
plt.grid(True)
plt.tight_layout()
plt.show()

print("\n=== Resultados de métricas (diccionario) ===")
print(resultados_metricas)
