In [None]:
# -*- coding: utf-8 -*-
from util.My_tool1 import *
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

def preprocess_seismic_data(seismic_data, nsub, noise_level=0.1, rng=None):
    """
    Preprocess seismic data by adding noise and subsampling traces.

    Traces are indexed along axis 1 of ``seismic_data``. Starting from trace 0,
    every ``nsub``-th trace is kept and all other traces are set to zero. The
    output keeps the input shape, so it can be compared element-wise with the
    original data.

    Parameters:
        seismic_data (np.ndarray): Original seismic data array with at least two
            dimensions. Traces are indexed along axis 1. It is not modified.
        nsub (int): Subsampling factor. From every 'nsub' traces, only one is kept.
            Must be >= 1; a value of 1 keeps every trace.
        noise_level (float): Standard deviation of the Gaussian noise, relative to
            the standard deviation of ``seismic_data``. For example, 0.1 gives
            noise at 10% of the data's std, and 0 adds no noise.
        rng (np.random.Generator, optional): Random source for the noise. When
            None (default), the global ``np.random`` state is used.

    Returns:
        np.ndarray: Preprocessed seismic data, same shape as ``seismic_data``.

    Raises:
        ValueError: If ``nsub`` is less than 1.
    """
    if nsub < 1:
        raise ValueError(f"nsub must be >= 1, got {nsub}")

    # Add Gaussian noise scaled to the data's own amplitude.
    sampler = np.random if rng is None else rng
    noise = sampler.normal(0, noise_level * np.std(seismic_data), seismic_data.shape)
    # `+` allocates a new array, so the caller's input is never mutated.
    subsampled_data = seismic_data + noise

    # Zero every trace except each nsub-th one. This uses a vectorized boolean
    # mask; a per-trace `in` check against an index array is O(n^2).
    dropped = np.ones(subsampled_data.shape[1], dtype=bool)
    dropped[::nsub] = False
    subsampled_data[:, dropped] = 0

    return subsampled_data

# Load the clean (fully sampled) seismic data.
# The transpose presumably puts traces on axis 1, the axis that
# preprocess_seismic_data subsamples and the plots label as offset — confirm.
x = np.load('test_data/data_dense.npy').T
x = x.astype(np.float64)

# Define preprocessing parameters
nsub = 2  # keep 1 trace out of every 2 (every other trace is zeroed)
noise_level = 0  # no added noise; e.g. 0.1 would add noise at 10% of the data's std

# Decimate (and optionally add noise to) the seismic data
y = preprocess_seismic_data(x, nsub, noise_level)

# Save the preprocessed data; the __main__ section reloads it from this file
np.save('test_data/new_noise_and_miss.npy', y)

if __name__ == '__main__':
    # Use CUDA when available, otherwise fall back to the CPU. Every
    # CUDA-specific call below is guarded so the script also runs without a GPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # torch.load unpickles the whole model object: only load trusted checkpoints.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # NOTE(review): on torch >= 2.6 loading a full pickled module requires
    # passing weights_only=False.
    # A forward slash keeps the path portable and avoids the invalid "\c" escape.
    model = torch.load('models/combined_model_swin_tiny_patch4_window7_224.pth',
                       map_location=device)
    model = model.to(device)
    model.eval()  # evaluation mode

    # Load the same clean data used for preprocessing
    x = np.load('test_data/data_dense.npy').T
    x = x.astype(np.float64)

    y = np.load('test_data/new_noise_and_miss.npy')

    # Explicit check: an assert would be stripped under `python -O`.
    if y.shape != x.shape:
        raise ValueError("Shapes of x and y do not match.")

    # Add batch and channel dimensions: (height, width) -> (1, 1, height, width)
    y_ = torch.from_numpy(y).unsqueeze(0).unsqueeze(1)

    # Pad height and width up to a multiple of the window size.
    # NOTE(review): the checkpoint name mentions patch4/window7; confirm that 8
    # is really the divisor this model needs.
    window_size = 8
    _, _, height, width = y_.shape
    pad_height = (window_size - height % window_size) % window_size
    pad_width = (window_size - width % window_size) % window_size

    # F.pad takes (left, right, top, bottom) for the last two dims: pad right/bottom only.
    y_padded = F.pad(y_, (0, pad_width, 0, pad_height), mode='constant', value=0)

    # Time host->device transfer plus inference. Synchronize so that queued
    # GPU work is included in the measurement.
    if use_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    y_padded = y_padded.to(device=device, dtype=torch.float32)

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        x_ = model(y_padded)

    # Crop the padding back off
    x_ = x_[:, :, :height, :width]

    x_ = x_.squeeze(0).squeeze(0).cpu().numpy().astype(np.float64)
    if use_cuda:
        torch.cuda.synchronize()
    elapsed_time = time.perf_counter() - start_time
    print("Inference time, s: ", elapsed_time)

    # Calculate SNR and SSIM against the clean data.
    # snr_ / ssim_ are presumably provided by `from util.My_tool1 import *`.
    # NOTE(review): the plots clip to [-1, 1]; if the data really spans that
    # range, data_range=2.0 may be the correct SSIM setting — confirm.
    pre_snr = snr_(y, x)
    print("Before SNR: ", pre_snr)
    snr = snr_(x_, x)
    print("After SNR: ", snr)

    pre_ssim = ssim_(y, x, data_range=1.0)
    print("Before SSIM: ", pre_ssim)
    ssim = ssim_(x_, x, data_range=1.0)
    print("After SSIM: ", ssim)

    # Calculate MSE between original and denoised data
    mse = np.mean((x - x_) ** 2)
    print("MSE between original and denoised data: ", mse)

    # Plot the data. Panels, in order:
    #   original data, decimated data, network output, difference.
    # Axis labels: x = distance from source (m), y = travel time (ms).
    panels = [
        (x, "Исходные данные"),
        (y, "Прореженные данные"),
        (x_, "Результат работы нейросети"),
        (x - x_, "Разница"),
    ]
    for image, title in panels:
        plt.imshow(image, cmap='gray', aspect='auto', vmin=-1, vmax=1)
        plt.title(title)
        plt.xlabel('Расстояние от источника, м')
        plt.ylabel('Время свободного пробега, мс')
        plt.show()