# Aplicación a datos reales - Utah FORGE
### Datos crudos originales

# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from obspy.io.segy.segy import _read_segy

# Folder holding the correlated shot gathers (.sgy/.segy) and the target path
# of the optional PNG export (the savefig call below is commented out).
SGY_DIR = "/home/pc-2/Documents/CAVE_minciencias/utah_model/2D_seismic_data/2D/Correlated_Shot_Gathers"
OUT_PATH = "/home/pc-2/Downloads/shot_gather.png"  

def get_dt_seconds(stream):
    """Return the sample interval of `stream` in seconds, or None if unknown.

    Two sources are tried in order: the first trace header's
    ``sample_interval_in_ms_for_this_trace`` and then the binary file
    header's ``sample_interval_in_microseconds``. Both values are read as
    microseconds (multiplied by 1e-6). The first positive value wins; any
    error while reading a source (missing attribute, no traces, ...) simply
    moves on to the next source.
    """
    readers = (
        lambda s: s.traces[0].header.sample_interval_in_ms_for_this_trace,
        lambda s: s.binary_file_header.sample_interval_in_microseconds,
    )
    for read_interval in readers:
        try:
            interval_us = read_interval(stream)
            if interval_us and interval_us > 0:
                return float(interval_us) * 1e-6
        except Exception:
            continue
    return None

def percentile_clip(A, p=99.0):
    """Return a positive symmetric clip level for displaying `A`.

    The level is the p-th percentile of |A|. When that percentile is not
    positive (e.g. mostly-zero data), it falls back to max(|A|) + 1e-9, so the
    result is strictly positive for finite input.
    """
    magnitude = np.abs(A)
    level = np.percentile(magnitude, p)
    if level > 0:
        return float(level)
    return float(np.max(magnitude) + 1e-9)

# Pick the first SEG-Y file (alphabetical order) found in SGY_DIR.
files = sorted([f for f in os.listdir(SGY_DIR) if f.lower().endswith((".sgy", ".segy"))])
if not files:
    raise FileNotFoundError("No se encontraron archivos .sgy/.segy en la carpeta.")
path = os.path.join(SGY_DIR, files[0])

# Read headers and trace samples with ObsPy's low-level SEG-Y reader.
st = _read_segy(path, headonly=False)

# One row per trace: data has shape (n_traces, n_samples). Assumes all traces
# have the same number of samples.
data = np.array([tr.data for tr in st.traces], dtype=np.float32)
nrec, nt = data.shape

# Vertical axis in milliseconds when the sample interval is known, otherwise
# in sample indices (scale factor 1.0).
dt_s = get_dt_seconds(st)
t_ms = np.arange(nt, dtype=np.float32) * (dt_s * 1e3 if dt_s else 1.0)
t_label = "Tiempo (ms)" if dt_s else "Muestras"

# Symmetric grey-scale clip at the 90th percentile of |amplitude|.
A = percentile_clip(data, p=90.0)

# Traces along x; time increases downwards (origin="upper", extent top = t_ms[0]).
plt.figure(figsize=(10, 6))
plt.imshow(
    data.T, cmap="gray", aspect="auto", origin="upper",
    vmin=-A, vmax=+A,
    extent=[0, nrec, t_ms[-1], t_ms[0]] 
)
plt.title(f"Shot gather: {os.path.basename(path)} (clip ±P90 |amp|)")
plt.xlabel("Receptor")
plt.ylabel(t_label)
cbar = plt.colorbar()
cbar.set_label("Amplitud")
plt.tight_layout()

# Optional PNG export (disabled):
# plt.savefig(OUT_PATH, dpi=300, bbox_inches="tight")
# print(f"✅ Figura guardada en: {OUT_PATH}")

plt.show()

### Parámetros de los datos

# In [None]:
# Inspect the sampling parameters of one specific trimmed gather through
# ObsPy's generic reader (unpack_trace_headers=True asks ObsPy to unpack the
# per-trace SEG-Y headers as well).
path = "/home/pc-2/Documents/CAVE_minciencias/utah_model/2D_seismic_data/2D/Correlated_Shot_Gathers/SGY_trim7/27_1511546140_30100_50100_20171127_150416_752_trim7.sgy"
from obspy import read

st = read(path, format="SEGY", unpack_trace_headers=True)
tr = st[0]
print("delta (s):", tr.stats.delta)         
print("sampling_rate (Hz):", tr.stats.sampling_rate)
print("npts:", tr.stats.npts)

### Preprocesamiento de los datos

# In [None]:
import os, glob, numpy as np
import matplotlib.pyplot as plt
from obspy.io.segy.segy import _read_segy
from obspy.signal.filter import bandpass
from pathlib import Path

# -------- CONFIG --------
# Folder with the trimmed SEG-Y gathers; only the first file found is analysed below.
BASE_DIR = "/home/pc-2/Documents/CAVE_minciencias/utah_model/2D_seismic_data/2D/Correlated_Shot_Gathers/SGY_trim7"   # <--- CHANGE THIS
OUT_DIR  = "/home/pc-2/Downloads/fig_bandpass"  # figure folder (the savefig calls below are commented out)
FREQMIN, FREQMAX = 10.0, 80.0  # band-pass corner frequencies (Hz)
PCLIP = 99.0      # percentile of |amplitude| used as a clip level
BINS = 80         # number of frequency bins in the binned spectrum
SAMPLING_HZ = 1000  # passed as manual_fs, so the headers are not used to infer fs

os.makedirs(OUT_DIR, exist_ok=True)

def infer_sampling_hz(segy_obj, manual_fs=None):
    """Return the sampling frequency (Hz) of a SEG-Y object.

    Parameters
    ----------
    segy_obj : object
        Result of ``obspy.io.segy.segy._read_segy`` (or anything exposing
        ``traces[0].header`` and, optionally, ``binary_file_header``).
    manual_fs : float or None
        If given, it is returned directly and the headers are not inspected.

    Returns
    -------
    float
        Median of all sampling-rate candidates found in the first trace
        header and in the binary file header.

    Raises
    ------
    ValueError
        If no usable sample-interval field is found.

    Notes
    -----
    Each candidate value ``v`` is interpreted as microseconds when
    ``v > 1e-3`` and as seconds otherwise. An interval given in seconds
    larger than 1 ms, or any interval in milliseconds, is therefore
    misread as microseconds.
    """
    if manual_fs is not None:
        return float(manual_fs)

    # The first name of each tuple is the one ObsPy's SEG-Y headers use (the
    # same fields read by get_dt_seconds). The original generic names are
    # kept as extra candidates for other header objects.
    trace_attrs = ("sample_interval_in_ms_for_this_trace", "sample_interval_in_ms",
                   "sample_interval", "dt", "delta", "time_sampling",
                   "trace_sampling_interval")
    binary_attrs = ("sample_interval_in_microseconds", "sample_interval",
                    "sample_interval_in_ms", "interval", "hdt")

    def _rates(header, attrs):
        # Convert every present, non-zero interval field into a 1/dt candidate.
        rates = []
        for attr in attrs:
            v = getattr(header, attr, None)
            if v in (None, 0):
                continue
            v = float(v)
            dt = v * 1e-6 if v > 1e-3 else v  # heuristic: microseconds vs seconds
            if dt > 0:
                rates.append(1.0 / dt)
        return rates

    cand = []
    traces = getattr(segy_obj, "traces", None)
    if traces:
        cand.extend(_rates(traces[0].header, trace_attrs))
    b = getattr(segy_obj, "binary_file_header", None)
    if b is not None:
        cand.extend(_rates(b, binary_attrs))
    if not cand:
        raise ValueError("No fue posible inferir fs; define SAMPLING_HZ.")
    return float(np.median(cand))

def segy_to_numpy(segy_obj):
    """Return the SEG-Y traces as a float64 matrix of shape (n_traces, n_samples).

    Row i holds trace i, so the receiver axis comes first and time second.
    """
    rows = tuple(trace.data.astype(np.float64) for trace in segy_obj.traces)
    return np.stack(rows)

def mean_spectrum_over_receivers(shot, fs_hz):
    """Average amplitude spectrum of a gather over its receivers.

    Args:
        shot: 2-D real array (n_receivers, n_samples).
        fs_hz: sampling frequency in Hz.

    Returns:
        (freqs, amp): the one-sided frequency axis ``rfftfreq(n_samples, 1/fs_hz)``
        and |rfft| along time, averaged over receivers.
    """
    _, n_samples = shot.shape
    magnitudes = np.abs(np.fft.rfft(shot, axis=1))
    freq_axis = np.fft.rfftfreq(n_samples, d=1.0 / fs_hz)
    return freq_axis, magnitudes.mean(axis=0)

def binned_spectrum(freqs, amps, nbins=80):
    """Average `amps` inside `nbins` equal-width bins spanning freqs[0]..freqs[-1].

    Parameters
    ----------
    freqs : 1-D array
        Frequencies, assumed ascending (e.g. from ``np.fft.rfftfreq``).
    amps : 1-D array
        Amplitudes, same length as `freqs`.
    nbins : int
        Number of bins.

    Returns
    -------
    bin_centers, amp_binned : ndarray of shape (nbins,)
        Bin centres and mean amplitude per bin (0 for empty bins).

    Notes
    -----
    Bins are half-open ``[left, right)`` except the last, which also includes
    its right edge (the ``np.histogram`` convention). ``np.digitize`` alone
    maps ``freqs[-1]`` one past the last bin, which silently dropped the
    highest frequency sample (the Nyquist bin of an even-length rfft).
    """
    freqs = np.asarray(freqs, dtype=np.float64)
    amps = np.asarray(amps, dtype=np.float64)
    fmin, fmax = freqs[0], freqs[-1]
    bins = np.linspace(fmin, fmax, nbins + 1)
    idx = np.digitize(freqs, bins) - 1
    # Fold the right edge into the last bin instead of discarding it.
    idx = np.where(freqs == bins[-1], nbins - 1, idx)
    valid = (idx >= 0) & (idx < nbins)
    amp_binned = np.bincount(idx[valid], weights=amps[valid], minlength=nbins)
    counts = np.bincount(idx[valid], minlength=nbins)
    counts[counts == 0] = 1  # empty bins: 0 / 1 -> 0
    amp_binned = amp_binned / counts
    bin_centers = 0.5 * (bins[:-1] + bins[1:])
    return bin_centers, amp_binned

# Analyse only the first gather (.sgy files first, then .segy, each group sorted).
sgy_files = sorted(glob.glob(os.path.join(BASE_DIR, "*.sgy"))) + \
            sorted(glob.glob(os.path.join(BASE_DIR, "*.segy")))
if not sgy_files:
    raise FileNotFoundError("No se encontraron archivos .sgy/.segy en BASE_DIR.")
SHOT_PATH = sgy_files[0]
SHOT_NAME = Path(SHOT_PATH).stem
TAG = f"bp_{int(FREQMIN)}-{int(FREQMAX)}Hz"  # filename tag, only used by the commented-out savefig code

D = _read_segy(SHOT_PATH)
fs = infer_sampling_hz(D, manual_fs=SAMPLING_HZ)  # SAMPLING_HZ is set, so this returns it directly
shot = segy_to_numpy(D)   
nrec, nt = shot.shape
print(f"Cargado: {SHOT_PATH}\nShape: {shot.shape} (rec, time) | fs = {fs:.3f} Hz")

# Zero-phase band-pass (4 corners), applied trace by trace.
shot_f = np.empty_like(shot)
for i in range(nrec):
    shot_f[i] = bandpass(shot[i], freqmin=FREQMIN, freqmax=FREQMAX, df=fs,
                         corners=4, zerophase=True)

# Part of the signal removed by the filter.
shot_diff = shot - shot_f

# NOTE(review): A is never used; the first two panels below use a fixed
# ±0.001 colour range instead of this percentile clip.
A = np.percentile(np.abs(shot), PCLIP)

# Figure 1: original / filtered / difference (receiver on x, sample index on y).
fig1, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
im0 = axes[0].imshow(shot.T, aspect="auto", cmap="Greys", vmin=-0.001, vmax=0.001,
                     extent=[0, nrec, nt, 0])
axes[0].set_title("Shot original")
axes[0].set_xlabel("Receptor"); axes[0].set_ylabel("Tiempo (muestras)")
fig1.colorbar(im0, ax=axes[0], shrink=0.8, label="Amplitud")

im1 = axes[1].imshow(shot_f.T, aspect="auto", cmap="Greys", vmin=-0.001, vmax=0.001,
                     extent=[0, nrec, nt, 0])
axes[1].set_title(f"Shot filtrado (band-pass {FREQMIN}-{FREQMAX} Hz)")
axes[1].set_xlabel("Receptor"); axes[1].set_ylabel("Tiempo (muestras)")
fig1.colorbar(im1, ax=axes[1], shrink=0.8, label="Amplitud")

# The difference panel uses its own PCLIP-percentile clip.
Ad = np.percentile(np.abs(shot_diff), PCLIP)
im2 = axes[2].imshow(shot_diff.T, aspect="auto", cmap="Greys", vmin=-Ad, vmax=+Ad,
                     extent=[0, nrec, nt, 0])
axes[2].set_title("Diferencia (original − filtrado)")
axes[2].set_xlabel("Receptor"); axes[2].set_ylabel("Tiempo (muestras)")
fig1.colorbar(im2, ax=axes[2], shrink=0.8, label="Amplitud")

# NOTE(review): the export below is commented out and the figure is closed right
# away, so (with the inline backend) it is neither displayed nor saved. The same
# applies to fig2 and fig3.
# out1 = os.path.join(OUT_DIR, f"{SHOT_NAME}_{TAG}_comparativo")
# for ext in ("png", "pdf", "svg"):
#     fig1.savefig(f"{out1}.{ext}", dpi=300, bbox_inches="tight")
# print("✅ Guardado:", out1 + ".{png,pdf,svg}")
plt.close(fig1)

# Receiver-averaged amplitude spectra before and after filtering.
freqs_o, amp_o = mean_spectrum_over_receivers(shot, fs)
freqs_f, amp_f = mean_spectrum_over_receivers(shot_f, fs)
# Defensive truncation to a common length (both spectra come from nt samples here).
nf = min(len(freqs_o), len(freqs_f))
freqs_o, amp_o = freqs_o[:nf], amp_o[:nf]
freqs_f, amp_f = freqs_f[:nf], amp_f[:nf]
fbins, amp_o_b = binned_spectrum(freqs_o, amp_o, nbins=BINS)
_,    amp_f_b = binned_spectrum(freqs_f, amp_f, nbins=BINS)

# Figure 2: side-by-side bars per frequency bin.
fig2, ax = plt.subplots(figsize=(12, 4), constrained_layout=True)
width = (fbins[1] - fbins[0]) * 0.45
ax.bar(fbins - width/2, amp_o_b, width=width, alpha=0.7, label="Original")
ax.bar(fbins + width/2, amp_f_b, width=width, alpha=0.7, label="Filtrado")
ax.set_title("Histograma (binned) de amplitud espectral promedio")
ax.set_xlabel("Frecuencia (Hz)"); ax.set_ylabel("Amplitud promedio (a.u.)")
ax.legend()

# out2 = os.path.join(OUT_DIR, f"{SHOT_NAME}_{TAG}_hist_binned")
# for ext in ("png", "pdf", "svg"):
#     fig2.savefig(f"{out2}.{ext}", dpi=300, bbox_inches="tight")
# print("✅ Guardado:", out2 + ".{png,pdf,svg}")
plt.close(fig2)

# Figure 3: full-resolution spectra, x axis limited to 200 Hz (or Nyquist if lower).
fig3, ax3 = plt.subplots(figsize=(12, 4), constrained_layout=True)
ax3.plot(freqs_o, amp_o, label="Original")
ax3.plot(freqs_f, amp_f, label="Filtrado")
ax3.set_xlim(0, min(200, freqs_o[-1]))
ax3.set_title("Espectro de amplitud promedio por receptor")
ax3.set_xlabel("Frecuencia (Hz)"); ax3.set_ylabel("Amplitud (a.u.)")
ax3.legend()

# out3 = os.path.join(OUT_DIR, f"{SHOT_NAME}_{TAG}_spectrum_line")
# for ext in ("png", "pdf", "svg"):
#     fig3.savefig(f"{out3}.{ext}", dpi=300, bbox_inches="tight")
# print("✅ Guardado:", out3 + ".{png,pdf,svg}")
plt.close(fig3)

### Guardar nuevos shots preprocesados (.npy)

# In [None]:
import os, glob, numpy as np
from obspy.io.segy.segy import _read_segy
from obspy.signal.filter import bandpass
import tqdm

BASE_DIR = "/home/pc-2/Documents/CAVE_minciencias/utah_model/2D_seismic_data/2D/Correlated_Shot_Gathers/SGY_trim7"     # <--- change this path
OUT_DIR  = os.path.join(BASE_DIR, "_processed_bandpass")  # filtered gathers are written here as .npy
os.makedirs(OUT_DIR, exist_ok=True)

FREQMIN, FREQMAX = 10.0, 80.0  # band-pass corner frequencies (Hz)
SAMPLING_HZ = 1000.0  # sampling rate used for every file (not read from the headers)

def segy_to_numpy(segy_obj):
    """Convert a SEG-Y object into a float32 matrix of shape (n_receivers, n_samples).

    One row per trace, in file order.
    """
    return np.stack(tuple(trc.data.astype(np.float32) for trc in segy_obj.traces), axis=0)

def bandpass_gather(shot, fs, fmin, fmax):
    """Band-pass filter every trace (row) of a gather.

    Each row is filtered independently with ObsPy's ``bandpass`` using
    4 corners and zero-phase filtering. The output has the same shape and
    dtype as `shot`.
    """
    filtered = np.empty_like(shot)
    for row, trace in enumerate(shot):
        filtered[row] = bandpass(trace, freqmin=fmin, freqmax=fmax,
                                 df=fs, corners=4, zerophase=True)
    return filtered

# Collect .sgy files first, then .segy (each group sorted).
sgy_files = sorted(glob.glob(os.path.join(BASE_DIR, "*.sgy"))) + \
            sorted(glob.glob(os.path.join(BASE_DIR, "*.segy")))

if not sgy_files:
    raise FileNotFoundError(f"No hay archivos .sgy en {BASE_DIR}")

# Filter every gather and save it as <name>.npy in OUT_DIR. A failure on one
# file is printed and the loop continues with the next file.
for path in tqdm.tqdm(sgy_files, desc="Procesando gathers"):
    try:
        D = _read_segy(path)
        shot = segy_to_numpy(D)
        shot_f = bandpass_gather(shot, SAMPLING_HZ, FREQMIN, FREQMAX)

        # Base name (file name without extension)
        name = os.path.splitext(os.path.basename(path))[0]
        out_path = os.path.join(OUT_DIR, name + ".npy")
        np.save(out_path, shot_f.astype(np.float32))

    except Exception as e:
        print(f"[ERROR] {path}: {e}")

print(f"\n✅ Procesamiento completo. Archivos guardados en:\n{OUT_DIR}")

### U-Net + STE

# In [None]:
import os, glob, json, csv, re, numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch, torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from pytorch_msssim import SSIM
from skimage.metrics import structural_similarity as ssim


# Input: the band-passed .npy gathers written by the preprocessing cell.
BASE_DIR = "/home/pc-2/Documents/CAVE_minciencias/utah_model/2D_seismic_data/2D/Correlated_Shot_Gathers/SGY_trim7/_processed_bandpass"
OUT_DIR  = os.path.join(BASE_DIR, "salidas_crg_shotmask_ste_multi")
os.makedirs(OUT_DIR, exist_ok=True)

# Seed NumPy and PyTorch; the same seed also drives the train/test split below.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)
plt.rcParams["figure.dpi"] = 110

EPOCHS = 700           # training epochs per scenario
BATCH  = 1
LR     = 1e-4          # Adam learning rate (U-Net weights and mask logits)
LAMBDA_SPARSITY = 0.01  # weight of the keep-probability penalty in the loss
SCENARIOS = list(range(10, 100, 10))  # percent of shots removed: 10, 20, ..., 90
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def norm_trace_lastaxis(x):
    """Divide each trace (along the last axis) by its peak |amplitude|.

    1e-6 is added to the peak so that all-zero traces stay zero instead of
    producing NaNs.
    """
    peak = np.abs(x).max(axis=-1, keepdims=True)
    return x / (peak + 1e-6)

def load_all_npy(base_dir):
    """Load every ``*.npy`` file in `base_dir` (sorted by path) as a float32 2-D array.

    Arrays with fewer columns than rows are transposed, so each returned
    array has rows <= columns.

    Returns:
        (arrays, paths): the loaded arrays and their file paths, in the same order.

    Raises:
        FileNotFoundError: if `base_dir` contains no .npy file.
        ValueError: if a file does not hold a 2-D array.
    """
    npy_paths = sorted(glob.glob(os.path.join(base_dir, "*.npy")))
    if not npy_paths:
        raise FileNotFoundError("No hay archivos .npy en la carpeta.")
    shots, kept_paths = [], []
    for npy_path in npy_paths:
        arr = np.load(npy_path)
        if arr.ndim != 2:
            raise ValueError(f"{os.path.basename(npy_path)}: se esperaba 2D, obtenido {arr.shape}")
        n_rows, n_cols = arr.shape
        oriented = arr if n_cols >= n_rows else arr.T
        shots.append(oriented.astype(np.float32, copy=False))
        kept_paths.append(npy_path)
    return shots, kept_paths

def unify_shapes(shots_list, target=None, mode="center"):
    """Crop a list of 2-D gathers to one common shape and stack them.

    Args:
        shots_list: sequence of 2-D arrays (rows, cols).
        target: (H, W) output shape; defaults to the smallest row count and
            the smallest column count found in `shots_list`.
        mode: "center" for a centred crop; any other value crops from the
            top-left corner.

    Returns:
        (X, (H, W)) with X of shape (len(shots_list), H, W). Crops are copies.

    Raises:
        ValueError: if a gather is smaller than (H, W) in either dimension.
    """
    n_rows = [shot.shape[0] for shot in shots_list]
    n_cols = [shot.shape[1] for shot in shots_list]
    H, W = (min(n_rows), min(n_cols)) if target is None else target

    def _cut(x):
        r, t = x.shape
        if r < H or t < W:
            raise ValueError(f"Shot {x.shape} más pequeño que el target {(H,W)}.")
        if mode == "center":
            r0, t0 = (r - H) // 2, (t - W) // 2
        else:
            r0, t0 = 0, 0
        return x[r0:r0 + H, t0:t0 + W].copy()

    return np.stack([_cut(shot) for shot in shots_list], 0), (H, W)

def crop_ST_mult8(arr):
    """Crop axes 1 (S) and 2 (T) of a 3-D array down to multiples of 8.

    Returns `arr` itself when both sizes already are multiples of 8;
    otherwise returns a copy of the leading ``[:, :S8, :T8]`` block.
    """
    _, n_s, n_t = arr.shape
    s_fit = n_s - n_s % 8
    t_fit = n_t - n_t % 8
    if s_fit == n_s and t_fit == n_t:
        return arr
    return arr[:, :s_fit, :t_fit].copy()

def save_np(path, arr):
    """Write `arr` to `path` in NumPy .npy format (thin wrapper over ``np.save``)."""
    np.save(file=path, arr=arr)

# Load all filtered gathers and crop them to a common (receivers, samples) shape.
print(">>> Cargando .npy ...")
shots_list, file_list = load_all_npy(BASE_DIR) 
N = len(shots_list)
print(f">>> Cargados {N} shots .npy")

gathers_shot, (Hc, Wc) = unify_shapes(shots_list, target=None, mode="center")
# Trim both the receiver (Hc) and sample (Wc) axes down to multiples of 8.
H8, W8 = (Hc // 8) * 8, (Wc // 8) * 8
if (H8 != Hc) or (W8 != Wc):
    gathers_shot = np.ascontiguousarray(gathers_shot[:, :H8, :W8])
    Hc, Wc = H8, W8
assert Hc % 8 == 0 and Wc % 8 == 0

# Per-trace peak normalisation: every trace ends up within [-1, 1].
gathers_shot = norm_trace_lastaxis(gathers_shot).astype(np.float32)

# (shot, receiver, time) -> (receiver, shot, time), assuming each .npy is
# (receivers, samples) as written by the preprocessing cell. Each entry along
# the first axis is then a common-receiver gather (CRG) of S shots x T samples.
CRG_all = np.transpose(gathers_shot, (1, 0, 2)).copy()
H, S, T = CRG_all.shape
print(f">>> CRG_all: (H={H}, S={S}, T={T})")

# 80/20 train/test split over receivers, i.e. over whole CRGs.
idx_rec = np.arange(H)
idx_tr, idx_te = train_test_split(idx_rec, test_size=0.2, random_state=GLOBAL_SEED, shuffle=True)
CRG_tr, CRG_te = CRG_all[idx_tr], CRG_all[idx_te]

# Crop the shot (S) and time (T) axes to multiples of 8, as train_quick asserts
# (the U-Net pools three times by 2). NOTE(review): this can drop up to 7 shots.
CRG_all = crop_ST_mult8(CRG_all)
CRG_tr  = crop_ST_mult8(CRG_tr)
CRG_te  = crop_ST_mult8(CRG_te)
H_all, S_all, T_all = CRG_all.shape
H_tr,  S_tr,  T_tr  = CRG_tr.shape
H_te,  S_te,  T_te  = CRG_te.shape
assert S_tr == S_all == S_te and T_tr == T_all == T_te
print(">>> Shapes tras hotfix:",
      "CRG_all", CRG_all.shape,
      "CRG_tr",  CRG_tr.shape,
      "CRG_te",  CRG_te.shape)

class UNet2DFull(nn.Module):
    """Three-level 2-D U-Net mapping a 1-channel image to a 1-channel image.

    Encoder widths 64-128-256 with 2x2 max-pooling, a 512-channel bottleneck,
    and a decoder with 2x2 transposed-conv upsampling and skip concatenations.
    The output goes through tanh, so predictions lie in (-1, 1).

    When input height and width are multiples of 8, the output has the same
    size as the input. Otherwise `crop` trims the skip tensors and the output
    is smaller than the input.
    """
    def __init__(self):
        super().__init__()
        def blk(cin, cout):
            # Two (3x3 conv -> BatchNorm -> in-place LeakyReLU(0.01)) stages;
            # padding=1 keeps height and width unchanged.
            return nn.Sequential(
                nn.Conv2d(cin, cout, 3, padding=1),
                nn.BatchNorm2d(cout), nn.LeakyReLU(0.01, True),
                nn.Conv2d(cout, cout, 3, padding=1),
                nn.BatchNorm2d(cout), nn.LeakyReLU(0.01, True)
            )
        # Encoder blocks with their pooling layers.
        self.e1, self.p1 = blk(1,64), nn.MaxPool2d(2,2)
        self.e2, self.p2 = blk(64,128), nn.MaxPool2d(2,2)
        self.e3, self.p3 = blk(128,256), nn.MaxPool2d(2,2)
        self.bott = blk(256,512)
        # Decoder: upsample x2, then a block over [upsampled, skip] channels.
        self.u3 = nn.ConvTranspose2d(512,256,2,2); self.d3 = blk(512,256)
        self.u2 = nn.ConvTranspose2d(256,128,2,2); self.d2 = blk(256,128)
        self.u1 = nn.ConvTranspose2d(128, 64,2,2); self.d1 = blk(128, 64)
        self.out = nn.Conv2d(64,1,1)  # 1x1 projection to a single channel
    def forward(self,x):
        e1=self.e1(x); p1=self.p1(e1)
        e2=self.e2(p1); p2=self.p2(e2)
        e3=self.e3(p2); p3=self.p3(e3)
        b=self.bott(p3)
        # Skip connections: encoder features are centre-cropped to the decoder
        # tensor's size and concatenated along channels.
        u3=self.u3(b); d3=self.d3(torch.cat([u3, self.crop(e3,u3)],1))
        u2=self.u2(d3); d2=self.d2(torch.cat([u2, self.crop(e2,u2)],1))
        u1=self.u1(d2); d1=self.d1(torch.cat([u1, self.crop(e1,u1)],1))
        return torch.tanh(self.out(d1))
    @staticmethod
    def crop(a,b):
        """Centre-crop 4-D tensor `a` spatially to the height/width of `b`."""
        _,_,h,w=b.shape; _,_,H,W=a.shape
        dh,dw=(H-h)//2,(W-w)//2
        return a[:,:,dh:dh+h, dw:dw+w]

class BinaryShotMaskSTE(nn.Module):
    """Learnable binary mask over the shot axis (dim 2) trained with a straight-through estimator.

    Each shot owns one logit, and ``sigmoid(logits)`` gives its keep
    probability. The forward pass zeroes the ``round(frac_remove * n_shots)``
    shots with the lowest probability (hard mask). Gradients reach the logits
    as if the soft probabilities had been used (``hard + p - p.detach()``).
    """

    def __init__(self, n_shots, init_keep_prob=0.5):
        super().__init__()
        self.n_shots = int(n_shots)
        p0 = float(np.clip(init_keep_prob, 1e-3, 1 - 1e-3))
        logit0 = np.log(p0 / (1 - p0))  # inverse sigmoid of the initial keep probability
        self.logits = nn.Parameter(torch.full((self.n_shots,), float(logit0)))
        self.frac_remove = 0.5

    def set_frac_remove(self, frac):
        """Set the fraction of shots to remove, clipped to [0, 1]."""
        self.frac_remove = float(np.clip(frac, 0.0, 1.0))

    def _n_remove(self):
        # Number of shots dropped by the hard mask.
        return int(round(self.frac_remove * self.n_shots))

    def forward(self, x):
        """Mask input `x` of shape (B, C, n_shots, T).

        Returns (masked x, keep probabilities (n_shots,), hard 0/1 mask (n_shots,)).
        """
        assert x.dim()==4 and x.shape[2]==self.n_shots, f"Esperaba S={self.n_shots}, got {x.shape}"
        probs = torch.sigmoid(self.logits)
        hard = torch.ones_like(probs)
        n_drop = self._n_remove()
        if n_drop > 0:
            hard[torch.topk(probs, n_drop, largest=False).indices] = 0.0
        # Straight-through: value equals `hard`, gradient flows through `probs`.
        ste_mask = hard + probs - probs.detach()
        return x * ste_mask.view(1, 1, self.n_shots, 1), probs, hard

    @torch.no_grad()
    def hard_indices_removed(self):
        """Return (as a NumPy array) the indices of the shots the hard mask removes."""
        n_drop = self._n_remove()
        if n_drop <= 0:
            return np.array([], dtype=int)
        probs = torch.sigmoid(self.logits)
        return torch.topk(probs, n_drop, largest=False).indices.cpu().numpy()

def train_quick(CRG_tr, pct, epochs=EPOCHS, batch=BATCH, lr=LR, lambda_sp=LAMBDA_SPARSITY):
    """Jointly train a U-Net and a learnable shot mask for one removal percentage.

    Args:
        CRG_tr: training CRGs as an (H, S, T) array (receivers, shots, samples).
            S and T must be multiples of 8. Expected to be float32 to match
            the model weights.
        pct: percentage of shots to remove (e.g. 10 means 10 %).
        epochs, batch, lr: number of epochs, mini-batch size and Adam
            learning rate (the same rate is used for U-Net weights and
            mask logits).
        lambda_sp: weight of the keep-probability penalty.

    Returns:
        (model, mask_layer, loss_hist), where loss_hist holds the mean loss of each epoch.

    Loss per batch:
        (1 - SSIM(prediction, full CRG)) + lambda_sp * (mean(sigmoid(logits)) - target_keep)**2
    """
    H, S, T = CRG_tr.shape
    assert S % 8 == 0 and T % 8 == 0, f"(S,T)=({S},{T}) deben ser múltiplos de 8"
    target_keep = 1.0 - (pct/100.0)

    # The whole training set is moved to `device` once. Input and target are
    # the same tensor: the model reconstructs the full CRG from its masked copy.
    Xt = torch.from_numpy(CRG_tr).unsqueeze(1).to(device) 
    loader = DataLoader(TensorDataset(Xt, Xt), batch_size=batch, shuffle=True)

    model = UNet2DFull().to(device)
    # Logits start at target_keep, so the sparsity penalty starts at 0.
    mask_layer = BinaryShotMaskSTE(n_shots=S, init_keep_prob=target_keep).to(device)
    mask_layer.set_frac_remove(pct/100.0)

    opt = torch.optim.Adam([
        {"params": model.parameters(), "lr": lr},
        {"params": mask_layer.parameters(), "lr": lr}
    ])
    # data_range=2.0 matches values in [-1, 1] (tanh output / normalised traces).
    crit = SSIM(data_range=2.0, size_average=True, channel=1)

    loss_hist = []
    model.train()
    for ep in range(1, epochs+1):
        run = 0.0
        for xb, yb in loader:
            opt.zero_grad(set_to_none=True)
            xb_masked, probs, hard = mask_layer(xb)
            yp = model(xb_masked)
            loss_rec = 1.0 - crit(yp, yb)
            keep_mean = torch.sigmoid(mask_layer.logits).mean()
            loss_sp = (keep_mean - target_keep) ** 2
            loss = loss_rec + lambda_sp * loss_sp
            loss.backward(); opt.step()
            run += float(loss.item())
        loss_hist.append(run / len(loader))
        # keep_mean comes from the last batch of the epoch. NOTE(review): with an
        # empty CRG_tr the loader has no batches, so len(loader) is 0 and
        # keep_mean is undefined here.
        print(f"[PCT={pct:02d} | E{ep:02d}] loss={loss_hist[-1]:.4f} | keep_mean≈{float(keep_mean):.3f}")
    return model, mask_layer, loss_hist

def eval_removed_only(model, CRG_te, removed_idx, binf=4):
    """Score the reconstruction of the removed shots only, on test CRGs.

    Each test CRG (H, S, T) has the shots in `removed_idx` zeroed and is
    passed through `model` in chunks of `binf`. Only the removed-shot rows of
    the prediction and of the ground truth are compared. Per CRG it computes:

    - MSE
    - PSNR = 20 log10(peak-to-peak(true) / RMSE), or inf when MSE == 0
    - SSIM averaged over the removed shots' traces, each scored as a 1-D
      signal with data_range=2
    - SNR = 10 log10(mean(true**2) / MSE)

    Returns a dict with the mean of each metric over all test CRGs, or all
    None values when no shot was removed.
    """
    if len(removed_idx) == 0:
        return {"MSE": None, "PSNR": None, "SSIM": None, "SNR": None}
    model.eval()
    n_crg, _, _ = CRG_te.shape
    per_crg = {"MSE": [], "PSNR": [], "SSIM": [], "SNR": []}

    with torch.no_grad():
        for start in range(0, n_crg, binf):
            stop = min(start + binf, n_crg)
            truth = CRG_te[start:stop]
            masked = truth.copy()
            masked[:, removed_idx, :] = 0.0
            batch = torch.from_numpy(masked).unsqueeze(1).to(device)
            recon = model(batch).squeeze(1).cpu().numpy()
            for k in range(recon.shape[0]):
                pred_rm = recon[k, removed_idx]
                true_rm = truth[k, removed_idx]
                err = np.mean((true_rm - pred_rm)**2)
                peak = np.ptp(true_rm) + 1e-8
                psnr = 20*np.log10(peak / (np.sqrt(err + 1e-12))) if err > 0 else float('inf')
                ssim_val = np.mean([ssim(t_row, p_row, data_range=2)
                                    for t_row, p_row in zip(true_rm, pred_rm)])
                snr = 10*np.log10((np.mean(true_rm**2)+1e-12)/(err+1e-12))
                per_crg["MSE"].append(err)
                per_crg["PSNR"].append(psnr)
                per_crg["SSIM"].append(ssim_val)
                per_crg["SNR"].append(snr)

    def _avg(values):
        return float(np.mean(values)) if len(values) else None

    return {name: _avg(values) for name, values in per_crg.items()}

def save_mask_bars_and_array(removed_idx, S, save_stub, title=None, height=40):
    """Save the shot mask as a bar image (PNG) and as an array (.npy).

    The image has `height` identical rows of length S, with 1 for kept shots
    and 0 for removed ones. Writes ``<save_stub>_mask_bars.png`` and
    ``<save_stub>_mask_bars.npy`` and returns the image array.
    """
    keep = np.ones(S, dtype=np.float32)
    keep[removed_idx] = 0.0
    bars = np.tile(keep, (height, 1))

    fig = plt.figure(figsize=(12, 2.2))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(bars, cmap="gray", aspect="auto", origin="upper", vmin=0, vmax=1)
    ax.set_yticks([])
    ax.set_xlabel("Shot (índice)")
    ax.set_title(title or "Máscara de shots (1 keep / 0 removed)")
    fig.tight_layout()
    fig.savefig(save_stub + "_mask_bars.png", dpi=140)
    plt.close(fig)

    save_np(save_stub + "_mask_bars.npy", bars)
    return bars

def viz_one_shot_and_arrays(model, removed_idx, CRG_all, shot_idx, save_dir, stub, vmin=-1, vmax=1, binf=4):
    """Reconstruct one shot from masked CRGs and save figures and arrays.

    Every CRG in `CRG_all` (H, S, T) is fed to `model` in chunks of `binf`
    with the shots in `removed_idx` zeroed. The prediction for shot
    `shot_idx` is collected across all H receivers into an (H, T) section
    and compared with the original.

    Files written to `save_dir`:
        <stub>_maps.png        predicted / real / |error| images
        <stub>_traces.png      three evenly spaced receivers, real vs predicted
        <stub>_pred.npy, <stub>_real.npy, <stub>_err.npy   (H, T) arrays
        <stub>_traces_idx.npy  receiver indices used in the trace plot
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    Htot, S, T = CRG_all.shape
    pred = np.zeros((Htot, T), dtype=np.float32)
    real = CRG_all[:, shot_idx, :].copy()
    with torch.no_grad():
        for i0 in range(0, Htot, binf):
            i1 = min(i0+binf, Htot)
            xb_cpu = CRG_all[i0:i1].copy()
            if len(removed_idx) > 0:
                xb_cpu[:, removed_idx, :] = 0.0
            xb = torch.from_numpy(xb_cpu).unsqueeze(1).to(device)
            yb = model(xb).squeeze(1).cpu().numpy()
            # Keep only the requested shot from each reconstructed CRG.
            pred[i0:i1, :] = yb[:, shot_idx, :]
    err = np.abs(pred - real)

    # Panel figure: prediction, ground truth and absolute error (receiver x time).
    fig = plt.figure(figsize=(16, 4.2))
    ax1 = fig.add_subplot(1,3,1); im0 = ax1.imshow(pred.T, cmap='gray', aspect='auto', vmin=vmin, vmax=vmax)
    ax1.set_title(f"Predicho — Shot {shot_idx}"); ax1.set_xlabel("Receptor"); ax1.set_ylabel("Tiempo")
    fig.colorbar(im0, ax=ax1, fraction=0.046, pad=0.04)
    ax2 = fig.add_subplot(1,3,2); im1 = ax2.imshow(real.T, cmap='gray', aspect='auto', vmin=vmin, vmax=vmax)
    ax2.set_title(f"Real — Shot {shot_idx}"); ax2.set_xlabel("Receptor"); ax2.set_ylabel("Tiempo")
    fig.colorbar(im1, ax=ax2, fraction=0.046, pad=0.04)
    ax3 = fig.add_subplot(1,3,3); im2 = ax3.imshow(err.T, cmap='inferno', aspect='auto')
    ax3.set_title("Error absoluto"); ax3.set_xlabel("Receptor"); ax3.set_ylabel("Tiempo")
    fig.colorbar(im2, ax=ax3, fraction=0.046, pad=0.04)
    fig.tight_layout(); fig.savefig(os.path.join(save_dir, f"{stub}_maps.png"), dpi=140); plt.close(fig)

    # Trace comparison at the first, middle and last receiver.
    recs = np.linspace(0, pred.shape[0]-1, num=3, dtype=int)
    t = np.arange(pred.shape[1])
    fig2 = plt.figure(figsize=(12,3.6))
    for k, r in enumerate(recs):
        ax = fig2.add_subplot(1,3,k+1)
        ax.plot(t, real[r], label="Real", linewidth=1.2)
        ax.plot(t, pred[r], label="Predicho", linewidth=1.0)
        ax.set_title(f"Rec {r}"); ax.set_xlabel("Tiempo"); ax.set_ylabel("Amp"); ax.grid(True)
        if k==0: ax.legend()
    fig2.tight_layout(); fig2.savefig(os.path.join(save_dir, f"{stub}_traces.png"), dpi=140); plt.close(fig2)

    save_np(os.path.join(save_dir, f"{stub}_pred.npy"), pred)
    save_np(os.path.join(save_dir, f"{stub}_real.npy"), real)
    save_np(os.path.join(save_dir, f"{stub}_err.npy"),  err)
    save_np(os.path.join(save_dir, f"{stub}_traces_idx.npy"), recs)

# Loss history per scenario, plus one metrics row per scenario for the CSV.
all_loss_hist = {}     
metrics_rows = []        
print(">>> Ejecutando escenarios:", SCENARIOS)

for PCT in SCENARIOS:
    print("\n" + "="*70)
    print(f">>> Entrenando escenario {PCT}% shots eliminados")
    sc_dir = os.path.join(OUT_DIR, f"pct_{PCT:02d}")
    os.makedirs(sc_dir, exist_ok=True)

    # Train a fresh U-Net + mask for this removal percentage.
    model, mask_layer, loss_hist = train_quick(CRG_tr, pct=PCT)
    all_loss_hist[PCT] = loss_hist
    save_np(os.path.join(sc_dir, "loss_hist.npy"), np.array(loss_hist, dtype=np.float32))

    # Shots dropped by the learned hard mask (lowest keep probabilities).
    removed_idx = mask_layer.hard_indices_removed()
    S_tot = CRG_all.shape[1]
    print(f">>> Shots eliminados (duros): {len(removed_idx)} de {S_tot}")

    mask_stub = os.path.join(sc_dir, "mask")
    _ = save_mask_bars_and_array(removed_idx, S_tot, mask_stub,
                                 title=f"Máscara — {PCT}% eliminados")
    save_np(os.path.join(sc_dir, "removed_idx.npy"), removed_idx.astype(np.int32))

    # Metrics on the held-out receivers, restricted to the removed shots.
    metrics = eval_removed_only(model, CRG_te, removed_idx, binf=4)
    with open(os.path.join(sc_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)
    print(">>> Métricas (solo shots eliminados en TEST):", metrics)

    # Visual check of the first two entries of removed_idx, using all
    # receivers (train + test) from CRG_all.
    for s in removed_idx[:2]:
        viz_one_shot_and_arrays(model, removed_idx, CRG_all, s,
                                save_dir=sc_dir, stub=f"shot_{s}")

    plt.figure(figsize=(6,4))
    plt.plot(range(1,len(loss_hist)+1), loss_hist)
    plt.title(f"Pérdida (1-SSIM + λ·sparsity) — {PCT}%")
    plt.xlabel("Época"); plt.ylabel("Loss"); plt.grid(True); plt.tight_layout()
    plt.savefig(os.path.join(sc_dir, f"loss_curve_{PCT:02d}.png"), dpi=140); plt.close()

    metrics_rows.append({
        "pct_removed": PCT,
        "num_removed": int(len(removed_idx)),
        "S_total": int(S_tot),
        "MSE": metrics["MSE"],
        "PSNR": metrics["PSNR"],
        "SSIM": metrics["SSIM"],
        "SNR": metrics["SNR"]
    })

# All loss curves in one figure.
plt.figure(figsize=(8,5))
for PCT in SCENARIOS:
    plt.plot(range(1, len(all_loss_hist[PCT])+1), all_loss_hist[PCT], label=f"{PCT}%")
plt.title("Curvas de pérdida por escenario (1-SSIM + λ·sparsity)")
plt.xlabel("Época"); plt.ylabel("Loss"); plt.grid(True)
plt.legend(ncol=3, fontsize=9)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "loss_curves_all.png"), dpi=160)
plt.close()

# Object array of (pct, loss_history) pairs; np.load needs allow_pickle=True to read it back.
np.save(os.path.join(OUT_DIR, "all_loss_hist.npy"), 
        np.array([ (pct, np.array(all_loss_hist[pct], dtype=np.float32)) for pct in SCENARIOS ], dtype=object),
        allow_pickle=True)

# One CSV row per scenario (metrics are None when no shot was removed).
csv_path = os.path.join(OUT_DIR, "metrics_by_pct.csv")
with open(csv_path, "w", newline="") as f:
    w = csv.DictWriter(f, fieldnames=["pct_removed","num_removed","S_total","MSE","PSNR","SSIM","SNR"])
    w.writeheader(); w.writerows(metrics_rows)

print("\n>>> FIN. Salidas en:", OUT_DIR)
print(" - loss_curves_all.png")
print(" - metrics_by_pct.csv")
print(" - pct_xx/* (pérdidas, máscara PNG/NPY, índices, métricas, preds/real/err)")

### Imprimir métricas

# In [None]:
import os
import numpy as np
import pandas as pd
from pathlib import Path

import json  # local import: this cell does not import json otherwise

# Results folder with one pct_XX sub-folder per removal percentage.
# NOTE(review): the training cell writes to ".../salidas_crg_shotmask_ste_multi"
# (one metrics.json per pct_XX), while this path points to a different run
# ("..._all_equal_quick"). Point it at whichever run should be summarised.
BASE_DIR = Path("/home/pc-2/Documents/CAVE_minciencias/utah_model/2D_seismic_data/2D/Correlated_Shot_Gathers/SGY_trim7/_processed_bandpass/salidas_crg_shotmask_ste_all_equal_quick")


def _load_pct_metrics(pct_dir):
    """Return (metrics, source_path) for one pct_XX folder, or None if it holds no metrics file.

    ``metrics_removed_only.npz`` is preferred; array-valued entries are
    averaged. Otherwise the ``metrics.json`` written by the training cell is
    used, with null values mapped to NaN.
    """
    npz_path = pct_dir / "metrics_removed_only.npz"
    if npz_path.is_file():
        # The context manager closes the NpzFile (np.load keeps it open otherwise).
        with np.load(npz_path) as data:
            metrics = {k: float(v) if np.ndim(v) == 0 else float(np.mean(v)) for k, v in data.items()}
        return metrics, npz_path
    json_path = pct_dir / "metrics.json"
    if json_path.is_file():
        with open(json_path, encoding="utf-8") as fh:
            raw = json.load(fh)
        return {k: float("nan") if v is None else float(v) for k, v in raw.items()}, json_path
    return None


rows = []
metric_files = []  # metric files actually read, one per pct_XX folder

for pct_dir in sorted(p for p in BASE_DIR.glob("pct_*") if p.is_dir()):
    loaded = _load_pct_metrics(pct_dir)
    if loaded is None:
        continue
    metrics, src = loaded
    metrics["pct_remove"] = int(pct_dir.name.replace("pct_", ""))
    rows.append(metrics)
    metric_files.append(src)

if not rows:
    raise FileNotFoundError("No se encontraron archivos metrics_removed_only.npz ni metrics.json dentro de pct_*")

df = pd.DataFrame(rows).sort_values("pct_remove").reset_index(drop=True)

print("\n===== Métricas por porcentaje de submuestreo =====")
print(df.to_string(index=False, float_format="%.6f"))

OUT_CSV = BASE_DIR / "resumen_metricas_removed_only.csv"
df.to_csv(OUT_CSV, index=False, float_format="%.6f")
print(f"\n✅ Resumen guardado en: {OUT_CSV}")