In [None]:
import os
import glob
from obspy.io.segy.segy import _read_segy
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from math import log10
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from pytorch_msssim import SSIM

def cargar_shot_gathers(ruta_carpeta, max_archivos=None):
    """Load the ``*.sgy`` files of a folder into a single stacked array.

    Files are taken in sorted filename order and read with obspy's
    ``_read_segy``. For each file, the ``data`` of every trace is collected
    into one array; the per-file arrays are then stacked along a new axis 0.

    Args:
        ruta_carpeta: Directory searched (non-recursively) for ``*.sgy`` files.
        max_archivos: If truthy, only the first ``max_archivos`` files of the
            sorted list are read. ``None`` (or ``0``) reads all of them.

    Returns:
        ``np.ndarray`` whose first axis indexes the files that were read.

    Raises:
        ValueError: If the folder contains no ``*.sgy`` file, or (raised by
            ``np.stack``) if the per-file arrays do not share one shape.
    """
    archivos = sorted(glob.glob(os.path.join(ruta_carpeta, "*.sgy")))
    if max_archivos:
        archivos = archivos[:max_archivos]

    # Fail with an actionable message instead of np.stack's generic
    # "need at least one array to stack" when the path is wrong or empty.
    if not archivos:
        raise ValueError(
            f"No se encontraron archivos .sgy en la carpeta: {ruta_carpeta}"
        )

    lista_gathers = [
        np.array([tr.data for tr in _read_segy(archivo, headonly=False).traces])
        for archivo in archivos
    ]
    return np.stack(lista_gathers, axis=0)
    
def submuestrear_receptores_aleatorios(gathers, porcentaje, rng=None):
    """Zero out a random subset of receiver traces in every shot.

    A different set of receivers is drawn (without replacement) for each
    shot. The input array is not modified; a copy is returned.

    Args:
        gathers: 3-D array unpacked as ``(N_shots, n_receptores, n_muestras)``.
        porcentaje: Fraction in ``[0, 1]`` of receivers to remove per shot;
            ``int(porcentaje * n_receptores)`` traces are zeroed.
        rng: Optional random source exposing ``choice`` (for example a
            ``np.random.Generator`` or ``np.random.RandomState``). When
            ``None`` the global ``np.random`` state is used, so existing
            callers that rely on ``np.random.seed`` behave as before.

    Returns:
        Tuple ``(gathers_sub, idxs_removidos)``: the copy with the selected
        traces set to 0, and a list with one sorted index array per shot.

    Raises:
        ValueError: If ``porcentaje`` is outside ``[0, 1]``.
    """
    if not 0 <= porcentaje <= 1:
        raise ValueError(
            f"porcentaje debe estar en [0, 1]; se recibió {porcentaje!r}"
        )

    _, n_receptores, _ = gathers.shape
    n_remover = int(porcentaje * n_receptores)
    muestreador = np.random if rng is None else rng

    gathers_sub = gathers.copy()
    idxs_removidos = []

    # Iterating over the copy yields views, so the assignment below writes
    # straight into ``gathers_sub``.
    for shot in gathers_sub:
        idx_removidos = np.sort(
            muestreador.choice(n_receptores, n_remover, replace=False)
        )
        shot[idx_removidos, :] = 0
        idxs_removidos.append(idx_removidos)

    return gathers_sub, idxs_removidos

def normalizar_por_traza(gathers_np):
    """Divide every trace by its own peak absolute amplitude.

    The peak is taken along axis 2 (kept as a length-1 axis so it broadcasts)
    and 1e-6 is added to it, which keeps all-zero traces at zero instead of
    dividing by zero.
    """
    pico_por_traza = np.abs(gathers_np).max(axis=2, keepdims=True)
    escala = pico_por_traza + 1e-6
    return gathers_np / escala

def visualizar_comparacion(x_input, y_pred, y_true):
    """Plot input, prediction, reference and absolute error side by side.

    Each array is transposed before being drawn. The first three panels share
    a gray colormap clipped to [-1, 1]; the error panel uses 'inferno' with
    matplotlib's automatic colour limits. The figure is shown with
    ``plt.show()``; nothing is returned.
    """
    escala_gris = dict(cmap='gray', vmin=-1, vmax=1)
    paneles = (
        ("Input con trazas nulas", x_input, escala_gris),
        ("Predicción (U-Net)", y_pred, escala_gris),
        ("Shot original", y_true, escala_gris),
        ("Error absoluto", np.abs(y_pred - y_true), dict(cmap='inferno')),
    )

    fig, axs = plt.subplots(1, len(paneles), figsize=(24, 6))

    for eje, (titulo, imagen, estilo) in zip(axs, paneles):
        dibujo = eje.imshow(imagen.T, aspect='auto', origin='upper', **estilo)
        eje.set_title(titulo)
        eje.set_xlabel("Trazas")
        eje.set_ylabel("Tiempo")
        fig.colorbar(dibujo, ax=eje, fraction=0.046, pad=0.04)

    fig.suptitle("Comparación: Entrada — Predicción — Original — Error", fontsize=14)
    plt.tight_layout()
    plt.show()

def _metricas_de_gather(yt, yp):
    """Return ``(MSE, PSNR, mean per-row SSIM, SNR)`` of ``yp`` against ``yt``."""
    mse = np.mean((yt - yp) ** 2)
    if mse > 0:
        # np.log10 (rather than math.log10) so a flat reference, whose
        # peak-to-peak range is 0, gives -inf like the SNR line below instead
        # of raising ValueError.
        psnr = 20 * np.log10(np.ptp(yt) / np.sqrt(mse + 1e-8))
    else:
        psnr = float('inf')
    ssim_val = np.mean([ssim(fila_t, fila_p, data_range=2) for fila_t, fila_p in zip(yt, yp)])
    snr = 10 * np.log10(np.mean(yt ** 2) / (mse + 1e-8))
    return mse, psnr, ssim_val, snr


def _media_o_nan(valores):
    """Mean of a list, or NaN (without a NumPy warning) when it is empty."""
    return np.mean(valores) if len(valores) > 0 else float('nan')


def calcular_metricas_globales(y_pred_tensor, y_true_tensor, x_input_tensor):
    """Compute reconstruction metrics over a batch of shots.

    Metrics are computed per shot and then averaged. They are reported twice:
    over the whole shot ("Global") and restricted to the rows of the input
    that are entirely zero ("Nulas"), i.e. the traces that had to be rebuilt.

    Args:
        y_pred_tensor: Either an ``nn.Module`` that maps an input batch to a
            prediction, or a tensor of precomputed predictions. A module is
            run one shot at a time, in eval mode and without gradients, on
            the device of its parameters; its previous train/eval mode is
            restored afterwards.
        y_true_tensor: Reference tensor, indexed as ``[shot, 0]``.
        x_input_tensor: Network input tensor, indexed as ``[shot, 0]``; also
            used to locate the all-zero rows.

    Returns:
        Dict with keys ``MSE_*``, ``PSNR_*``, ``SSIM_*`` and ``SNR_*`` for the
        suffixes ``Global`` and ``Nulas``. An entry is NaN when no shot
        contributed to it (for example, no shot has an all-zero row).
    """
    modelo = y_pred_tensor if isinstance(y_pred_tensor, nn.Module) else None
    dispositivo = None
    era_entrenamiento = False
    if modelo is not None:
        try:
            dispositivo = next(modelo.parameters()).device
        except StopIteration:
            # Parameter-free module: leave the input where it already is.
            dispositivo = x_input_tensor.device
        era_entrenamiento = modelo.training
        # Evaluate with fixed normalisation statistics, never batch ones.
        modelo.eval()

    nombres = ("MSE", "PSNR", "SSIM", "SNR")
    globales = {nombre: [] for nombre in nombres}
    nulas = {nombre: [] for nombre in nombres}

    try:
        for i in range(y_true_tensor.shape[0]):
            yt = y_true_tensor[i, 0].cpu().numpy()
            xi = x_input_tensor[i, 0].cpu().numpy()

            if modelo is not None:
                with torch.no_grad():
                    salida = modelo(x_input_tensor[i:i + 1].to(dispositivo))
                yp = salida[0, 0].cpu().numpy()
            else:
                yp = y_pred_tensor[i, 0].detach().cpu().numpy()

            # Whole shot
            for nombre, valor in zip(nombres, _metricas_de_gather(yt, yp)):
                globales[nombre].append(valor)

            # Only the rows that are entirely zero in the input
            idxs_nulas = np.where(np.all(xi == 0, axis=1))[0]
            if len(idxs_nulas) > 0:
                valores = _metricas_de_gather(yt[idxs_nulas], yp[idxs_nulas])
                for nombre, valor in zip(nombres, valores):
                    nulas[nombre].append(valor)
    finally:
        if modelo is not None:
            modelo.train(era_entrenamiento)

    resultado = {f"{nombre}_Global": _media_o_nan(globales[nombre]) for nombre in nombres}
    resultado.update({f"{nombre}_Nulas": _media_o_nan(nulas[nombre]) for nombre in nombres})
    return resultado

class UNet2DFull(nn.Module):
    """Three-level 2-D U-Net with a tanh output.

    The encoder has three double-convolution stages (64, 128, 256 channels),
    each followed by 2x2 max pooling, and a 512-channel bottleneck. The
    decoder mirrors it with stride-2 transposed convolutions; at every level
    the encoder feature map is center-cropped to the upsampled size and
    concatenated on the channel axis. A 1x1 convolution followed by ``tanh``
    produces the single-channel output.
    """

    def __init__(self):
        super().__init__()

        # Contracting path
        self.enc1 = self._doble_conv(1, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc2 = self._doble_conv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc3 = self._doble_conv(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Deepest level
        self.bottleneck = self._doble_conv(256, 512)

        # Expanding path: each dec* receives upsampled + skip channels
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self._doble_conv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self._doble_conv(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self._doble_conv(128, 64)

        # Projection to one output channel
        self.final = nn.Conv2d(64, 1, kernel_size=1)

    @staticmethod
    def _doble_conv(canales_in, canales_out):
        """Build two (3x3 conv, batch norm, LeakyReLU 0.01) layers in sequence."""
        return nn.Sequential(
            nn.Conv2d(canales_in, canales_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(canales_out),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Conv2d(canales_out, canales_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(canales_out),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
        )

    def forward(self, x):
        """Run the network on ``x`` and return the tanh-activated output."""
        # Contracting path; keep each stage's output for its skip connection.
        salto1 = self.enc1(x)
        salto2 = self.enc2(self.pool1(salto1))
        salto3 = self.enc3(self.pool2(salto2))

        actual = self.bottleneck(self.pool3(salto3))

        # Expanding path, deepest level first.
        etapas = (
            (self.upconv3, self.dec3, salto3),
            (self.upconv2, self.dec2, salto2),
            (self.upconv1, self.dec1, salto1),
        )
        for subir, refinar, salto in etapas:
            subido = subir(actual)
            recorte = self.center_crop(salto, subido)
            actual = refinar(torch.cat([subido, recorte], dim=1))

        return torch.tanh(self.final(actual))

    @staticmethod
    def center_crop(enc_feature, target_feature):
        """Crop the last two axes of ``enc_feature`` to ``target_feature``'s size.

        The window is centered; when the size difference is odd, the extra
        row/column is dropped from the far (bottom/right) side.
        """
        alto, ancho = target_feature.shape[2:]
        alto_enc, ancho_enc = enc_feature.shape[2:]

        fila0 = (alto_enc - alto) // 2
        col0 = (ancho_enc - ancho) // 2

        return enc_feature[:, :, fila0:fila0 + alto, col0:col0 + ancho]

In [None]:
# Experiment: for each removal percentage, train a fresh U-Net to rebuild
# zeroed receiver traces, plot one test shot and collect metrics on the test
# split.
ruta = "/home/pc-2/Documents/CAVE_minciencias/utah_model/2D_seismic_data/2D/Correlated_Shot_Gathers"
gathers = cargar_shot_gathers(ruta, max_archivos=200)

# Seed the global RNGs so the random trace removal and the weight
# initialisation are repeatable, matching the fixed seed of the split below.
SEMILLA = 42
np.random.seed(SEMILLA)
torch.manual_seed(SEMILLA)

N_shots = gathers.shape[0]
idx_train, idx_test = train_test_split(np.arange(N_shots), test_size=0.2, random_state=SEMILLA)
gathers_train = gathers[idx_train]
gathers_test = gathers[idx_test]

porcentajes = list(range(10, 100, 10))
resultados_metricas = {}
curvas_loss = {}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The targets do not depend on the removal percentage: normalise and crop
# them once. Traces are normalised over their full length, then cropped.
N_RECEPTORES, N_MUESTRAS = 160, 2048
y_np = normalizar_por_traza(gathers_train)[:, :N_RECEPTORES, :N_MUESTRAS]
y_eval_np = normalizar_por_traza(gathers_test)[:, :N_RECEPTORES, :N_MUESTRAS]
y_tensor = torch.tensor(y_np, dtype=torch.float32).unsqueeze(1).to(device)
y_eval_tensor = torch.tensor(y_eval_np, dtype=torch.float32).unsqueeze(1).to(device)

for pct in porcentajes:
    print(f"\n📌 Submuestreo del {pct}%")

    # Remove traces from the cropped, normalised target so that exactly pct%
    # of the receivers the network sees are missing and every removed index
    # is a valid row of the crop. Zeroing a normalised trace gives the same
    # values as normalising a zeroed one.
    x_np, idxs_removidos_train = submuestrear_receptores_aleatorios(y_np, porcentaje=pct/100)
    x_tensor = torch.tensor(x_np, dtype=torch.float32).unsqueeze(1).to(device)

    dataset = TensorDataset(x_tensor, y_tensor)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    model = UNet2DFull().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # Inputs and the tanh output live in [-1, 1], so the dynamic range is 2
    # (the same value the evaluation SSIM uses).
    criterion = SSIM(data_range=2.0, size_average=True, channel=1)
    n_epochs = 200
    loss_history = []

    for epoch in range(n_epochs):
        model.train()
        running_loss = 0.0
        for x_batch, y_batch in tqdm(loader, desc=f"Submuestreo {pct}% - Época {epoch+1}"):
            optimizer.zero_grad()
            y_pred = model(x_batch)
            loss = 1 - criterion(y_pred, y_batch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        loss_history.append(running_loss / len(loader))

    curvas_loss[f"{pct}%"] = loss_history

    x_eval_np, idxs_removidos_test = submuestrear_receptores_aleatorios(y_eval_np, porcentaje=pct/100)
    x_eval_tensor = torch.tensor(x_eval_np, dtype=torch.float32).unsqueeze(1).to(device)

    # Shot shown in the figures; clamped so small test sets do not break it.
    idx = min(10, x_eval_tensor.shape[0] - 1)
    model.eval()
    with torch.no_grad():
        pred = model(x_eval_tensor[idx:idx+1]).cpu().squeeze().numpy()
    entrada = x_eval_tensor[idx, 0].cpu().numpy()
    real = y_eval_tensor[idx, 0].cpu().numpy()
    visualizar_comparacion(entrada, pred, real)

    # Overlay reference and prediction for up to three of the removed traces.
    trazas_eliminadas = idxs_removidos_test[idx]
    trazas_a_mostrar = trazas_eliminadas[:3]

    tiempo = np.arange(real.shape[1])

    plt.figure(figsize=(30, 4))
    for i, traza_idx in enumerate(trazas_a_mostrar):
        plt.subplot(1, len(trazas_a_mostrar), i+1)
        plt.plot(tiempo, real[traza_idx], label='Real', linewidth=1.5)
        plt.plot(tiempo, pred[traza_idx], label='Predicho', linewidth=1.5)
        plt.title(f"Traza eliminada #{traza_idx}")
        plt.xlabel("Tiempo")
        plt.ylabel("Amplitud")
        plt.legend()
        plt.grid(True)

    plt.suptitle(f"Comparación real vs predicho - Shot {idx}")
    plt.tight_layout()
    plt.show()

    metricas = calcular_metricas_globales(model, y_eval_tensor, x_eval_tensor)
    resultados_metricas[f"{pct}"] = metricas

plt.figure(figsize=(10, 6))
for pct, curva in curvas_loss.items():
    plt.plot(range(1, len(curva)+1), curva, label=f"{pct}")
plt.title("Curvas de pérdida por submuestreo")
plt.xlabel("Época")
# The plotted quantity is the training loss, i.e. 1 - SSIM, not SSIM itself.
plt.ylabel("Pérdida (1 - SSIM)")
plt.legend(title="Submuestreo")
plt.grid(True)
plt.tight_layout()
plt.show()

print(resultados_metricas)