In [21]:
import numpy as np
import pandas as pd
import os
import yaml
import time
import csv
import random
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, train_test_split
import seaborn as sns

import scipy.stats as stats
import scipy.signal as signal

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torchsummary import summary

In [22]:
# Set random seeds for reproducibility
def set_seed(seed):
    """Seed every RNG the notebook touches and pin cuDNN to deterministic kernels.

    Seeds, in order: Python's ``random``, NumPy's legacy global RNG, the PyTorch
    CPU generator and the current CUDA device generator. Then disables cuDNN
    autotuning (``benchmark``) and requests deterministic algorithms.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

set_seed(42)

In [23]:
# Helper functions for data loading
def map_seismic_to_velocity_path(input_file):
    """Convert a seismic data path to the matching velocity model path.

    The substitutions 'seis' -> 'vel' and 'data' -> 'model' are applied only to
    the file name and its immediate parent directory name, e.g.
    ``.../FlatVel_A/data/data1.npy`` -> ``.../FlatVel_A/model/model1.npy`` and
    ``.../CurveFault_A/seis2_1_0.npy`` -> ``.../CurveFault_A/vel2_1_0.npy``.
    Higher-level directories are left untouched, so a dataset root such as
    ``/home/user/data/...`` is not corrupted.

    Args:
        input_file: str or Path of a seismic ``.npy`` file.

    Returns:
        Path of the corresponding velocity model file (existence not checked).
    """
    def _swap(part):
        # Same substitution order as before: 'seis' first, then 'data'.
        return part.replace('seis', 'vel').replace('data', 'model')

    input_path = Path(input_file)
    parent = input_path.parent
    # Path('.') / Path('/') have an empty name; with_name('') would raise.
    if parent.name:
        parent = parent.with_name(_swap(parent.name))
    return parent / _swap(input_path.name)

def get_train_files(data_path):
    """Find all seismic data files under ``data_path`` and pair them with velocity files.

    A ``.npy`` file counts as seismic input when its stem contains 'seis' or
    'data'. Results are sorted so the file order (and anything indexed from it,
    e.g. random sampling or CV splits) is reproducible across filesystems.

    Args:
        data_path: root directory searched recursively.

    Returns:
        (input_files, output_files): parallel lists of Paths.

    Raises:
        FileNotFoundError: if any mapped velocity model file does not exist.
    """
    # rglob order is filesystem-dependent; sort for determinism.
    input_files = sorted(
        f for f in Path(data_path).rglob('*.npy')
        if 'seis' in f.stem or 'data' in f.stem
    )

    output_files = [map_seismic_to_velocity_path(f) for f in input_files]

    missing_files = [f for f in output_files if not f.exists()]
    if missing_files:
        preview = ', '.join(str(f) for f in missing_files[:5])
        extra = len(missing_files) - 5
        suffix = f" ... and {extra} more" if extra > 0 else ""
        raise FileNotFoundError(
            f"Missing {len(missing_files)} velocity model files: {preview}{suffix}"
        )

    return input_files, output_files

# Data Loading for Analysis

In [24]:
# Data analysis functions
def analyze_dataset_statistics(input_files, output_files, n_samples=10):
    """Analyze basic statistics of the dataset"""
    input_stats = []
    output_stats = []
    
    # Sample files for analysis
    sample_indices = random.sample(range(len(input_files)), min(n_samples, len(input_files)))
    
    for idx in tqdm(sample_indices, desc="Analyzing files"):
        input_data = np.load(input_files[idx])
        output_data = np.load(output_files[idx])
        
        # Print shapes to understand data structure
        print(f"Input data shape: {input_data.shape}")
        print(f"Output data shape: {output_data.shape}")
        
        # Collect statistics for input data
        input_stats.append({
            'file': input_files[idx].name,
            'shape': input_data.shape,
            'min': np.min(input_data),
            'max': np.max(input_data),
            'mean': np.mean(input_data),
            'median': np.median(input_data),
            'std': np.std(input_data),
            'skew': stats.skew(input_data.reshape(-1)),
            'kurtosis': stats.kurtosis(input_data.reshape(-1)),
            'zeros_pct': np.mean(input_data == 0) * 100,
            'unique_values': len(np.unique(input_data)),
        })
        
        # Collect statistics for output data
        output_stats.append({
            'file': output_files[idx].name,
            'shape': output_data.shape,
            'min': np.min(output_data),
            'max': np.max(output_data),
            'mean': np.mean(output_data),
            'median': np.median(output_data),
            'std': np.std(output_data),
            'skew': stats.skew(output_data.reshape(-1)),
            'kurtosis': stats.kurtosis(output_data.reshape(-1)),
            'zeros_pct': np.mean(output_data == 0) * 100,
            'unique_values': len(np.unique(output_data)),
        })
    
    input_df = pd.DataFrame(input_stats)
    output_df = pd.DataFrame(output_stats)
    
    print("\n=== INPUT DATA STATISTICS ===")
    print(input_df[['min', 'max', 'mean', 'median', 'std', 'zeros_pct']].describe())
    
    print("\n=== OUTPUT DATA STATISTICS ===")
    print(output_df[['min', 'max', 'mean', 'median', 'std', 'zeros_pct']].describe())
    
    return input_df, output_df, input_data.shape, output_data.shape

def visualize_sample_pair(input_file, output_file, index=0, sample_idx=0, channel_idx=0, figsize=(15, 10),
                          output_dir='/kaggle/working'):
    """Plot one seismic channel next to its velocity model, plus value histograms.

    Args:
        input_file: seismic ``.npy`` file indexed as ``[sample, channel, ...]``.
        output_file: velocity ``.npy`` file indexed as ``[sample, 0, ...]``.
        index: unused; kept only for backward compatibility with existing callers.
        sample_idx: sample within the files to plot.
        channel_idx: seismic channel to plot.
        figsize: matplotlib figure size.
        output_dir: directory the PNG is written to (created if missing).

    Returns:
        (seismic_sample, velocity_sample): the 2-D arrays that were plotted.
    """
    input_data = np.load(input_file)
    output_data = np.load(output_file)

    print(f"Input data shape: {input_data.shape}")
    print(f"Output data shape: {output_data.shape}")

    # Data structure noted in this notebook: [samples, channels, height, width]
    seismic_sample = input_data[sample_idx, channel_idx]
    velocity_sample = output_data[sample_idx, 0]  # Assuming output has 1 channel

    print(f"Seismic sample shape: {seismic_sample.shape}")
    print(f"Velocity sample shape: {velocity_sample.shape}")

    fig, axs = plt.subplots(2, 2, figsize=figsize)

    # Seismic data
    im0 = axs[0, 0].imshow(seismic_sample, aspect='auto', cmap='seismic')
    axs[0, 0].set_title(f'Seismic Data - Sample {sample_idx}, Channel {channel_idx}')
    axs[0, 0].set_xlabel('X position')
    axs[0, 0].set_ylabel('Time/Depth')
    plt.colorbar(im0, ax=axs[0, 0], fraction=0.046, pad=0.04)

    # Velocity model
    im1 = axs[0, 1].imshow(velocity_sample, aspect='auto', cmap='viridis')
    axs[0, 1].set_title(f'Velocity Model - Sample {sample_idx}')
    axs[0, 1].set_xlabel('X position')
    axs[0, 1].set_ylabel('Depth')
    plt.colorbar(im1, ax=axs[0, 1], fraction=0.046, pad=0.04)

    # Histograms
    axs[1, 0].hist(seismic_sample.ravel(), bins=50, alpha=0.7, color='blue')
    axs[1, 0].set_title('Seismic Data Histogram')
    axs[1, 0].set_xlabel('Amplitude')
    axs[1, 0].set_ylabel('Frequency')

    axs[1, 1].hist(velocity_sample.ravel(), bins=50, alpha=0.7, color='green')
    axs[1, 1].set_title('Velocity Model Histogram')
    axs[1, 1].set_xlabel('Velocity (m/s)')
    axs[1, 1].set_ylabel('Frequency')

    plt.tight_layout()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f'{sample_idx}_visualize_sample_pair_channel{channel_idx}.png')
    plt.show()
    # Release the figure: this is called in a loop and open figures otherwise accumulate.
    plt.close(fig)

    return seismic_sample, velocity_sample

def visualize_channels(input_file, sample_idx=0, figsize=(20, 15), output_dir='/kaggle/working', n_cols=3):
    """Plot every channel of one seismic sample in a grid.

    Args:
        input_file: seismic ``.npy`` file; must be 4-D ``[samples, channels, height, width]``.
        sample_idx: sample to plot.
        figsize: matplotlib figure size.
        output_dir: directory the PNG is written to (created if missing).
        n_cols: number of grid columns; rows grow with the channel count.

    Returns:
        The selected sample, shape ``[channels, height, width]``.

    Raises:
        ValueError: if the loaded array is not 4-D.
    """
    input_data = np.load(input_file)

    if input_data.ndim != 4:  # [samples, channels, height, width]
        raise ValueError(f"Unexpected data shape: {input_data.shape}")
    sample_data = input_data[sample_idx]
    n_channels = sample_data.shape[0]

    # Keep the original 2 x n_cols minimum, but add rows so >6 channels don't IndexError.
    n_rows = max(2, -(-n_channels // n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize)
    axs = np.asarray(axs).flatten()

    for i in range(n_channels):
        channel_data = sample_data[i]

        im = axs[i].imshow(channel_data, aspect='auto', cmap='seismic')
        axs[i].set_title(f'Channel {i} - Shape: {channel_data.shape}')
        axs[i].set_xlabel('X position')
        axs[i].set_ylabel('Time/Depth')
        plt.colorbar(im, ax=axs[i], fraction=0.046, pad=0.04)

    # Hide any unused subplots
    for ax in axs[n_channels:]:
        ax.axis('off')

    plt.tight_layout()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f'{sample_idx}_visualize_channels.png')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across repeated calls

    return sample_data

def analyze_spectral_content(input_file, output_file, sample_idx=0, channel_idx=0, figsize=(15, 10),
                             output_dir='/kaggle/working'):
    """Plot the log-magnitude 2-D FFT of a seismic channel and its velocity model.

    Args:
        input_file: seismic ``.npy`` file indexed as ``[sample, channel, ...]``.
        output_file: velocity ``.npy`` file indexed as ``[sample, 0, ...]``.
        sample_idx: sample to analyze.
        channel_idx: seismic channel to analyze.
        figsize: matplotlib figure size.
        output_dir: directory the PNG is written to (created if missing).

    Returns:
        (seismic_magnitude, velocity_magnitude): ``log1p(|fftshift(fft2(.))|)`` arrays.
    """
    input_data = np.load(input_file)
    output_data = np.load(output_file)

    seismic_sample = input_data[sample_idx, channel_idx]
    velocity_sample = output_data[sample_idx, 0]  # Assuming output has 1 channel

    # Centered 2-D spectra; log1p compresses the dynamic range for display.
    seismic_magnitude = np.log1p(np.abs(np.fft.fftshift(np.fft.fft2(seismic_sample))))
    velocity_magnitude = np.log1p(np.abs(np.fft.fftshift(np.fft.fft2(velocity_sample))))

    fig, axs = plt.subplots(2, 2, figsize=figsize)

    im0 = axs[0, 0].imshow(seismic_sample, aspect='auto', cmap='seismic')
    axs[0, 0].set_title(f'Seismic Data - Sample {sample_idx}, Channel {channel_idx}')
    plt.colorbar(im0, ax=axs[0, 0], fraction=0.046, pad=0.04)

    im1 = axs[0, 1].imshow(velocity_sample, aspect='auto', cmap='viridis')
    axs[0, 1].set_title(f'Velocity Model - Sample {sample_idx}')
    plt.colorbar(im1, ax=axs[0, 1], fraction=0.046, pad=0.04)

    im2 = axs[1, 0].imshow(seismic_magnitude, aspect='auto', cmap='inferno')
    axs[1, 0].set_title('Seismic Data FFT Magnitude (log scale)')
    plt.colorbar(im2, ax=axs[1, 0], fraction=0.046, pad=0.04)

    im3 = axs[1, 1].imshow(velocity_magnitude, aspect='auto', cmap='inferno')
    axs[1, 1].set_title('Velocity Model FFT Magnitude (log scale)')
    plt.colorbar(im3, ax=axs[1, 1], fraction=0.046, pad=0.04)

    plt.tight_layout()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f'{sample_idx}_spectral_content_channel{channel_idx}.png')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across repeated calls

    return seismic_magnitude, velocity_magnitude

def analyze_gradient_patterns(input_file, output_file, sample_idx=0, channel_idx=0, figsize=(15, 15),
                              output_dir='/kaggle/working'):
    """Plot spatial gradients of a seismic channel and its velocity model.

    Gradients come from ``np.gradient`` (axis 0 -> "Y", axis 1 -> "X"). The two
    arrays generally differ in shape, so they are compared with histograms
    rather than scatter plots.

    Args:
        input_file: seismic ``.npy`` file indexed as ``[sample, channel, ...]``.
        output_file: velocity ``.npy`` file indexed as ``[sample, 0, ...]``.
        sample_idx: sample to analyze.
        channel_idx: seismic channel to analyze.
        figsize: matplotlib figure size.
        output_dir: directory the PNG is written to (created if missing).

    Returns:
        (mean seismic gradient magnitude, mean velocity gradient magnitude).
    """
    input_data = np.load(input_file)
    output_data = np.load(output_file)

    seismic_sample = input_data[sample_idx, channel_idx]  # e.g. [1000, 70]
    velocity_sample = output_data[sample_idx, 0]         # e.g. [70, 70]

    print(f"Seismic sample shape: {seismic_sample.shape}")
    print(f"Velocity sample shape: {velocity_sample.shape}")

    seismic_grad_y, seismic_grad_x = np.gradient(seismic_sample)
    velocity_grad_y, velocity_grad_x = np.gradient(velocity_sample)

    seismic_grad_mag = np.sqrt(seismic_grad_x**2 + seismic_grad_y**2)
    velocity_grad_mag = np.sqrt(velocity_grad_x**2 + velocity_grad_y**2)

    fig, axs = plt.subplots(3, 3, figsize=figsize)

    # Row 0: original data
    im0 = axs[0, 0].imshow(seismic_sample, aspect='auto', cmap='seismic')
    axs[0, 0].set_title(f'Seismic Data - Sample {sample_idx}, Channel {channel_idx}')
    plt.colorbar(im0, ax=axs[0, 0], fraction=0.046, pad=0.04)

    im1 = axs[0, 1].imshow(velocity_sample, aspect='auto', cmap='viridis')
    axs[0, 1].set_title(f'Velocity Model - Sample {sample_idx}')
    plt.colorbar(im1, ax=axs[0, 1], fraction=0.046, pad=0.04)

    # Row 1: x-gradients
    im2 = axs[1, 0].imshow(seismic_grad_x, aspect='auto', cmap='coolwarm')
    axs[1, 0].set_title('Seismic X-Gradient')
    plt.colorbar(im2, ax=axs[1, 0], fraction=0.046, pad=0.04)

    im3 = axs[1, 1].imshow(velocity_grad_x, aspect='auto', cmap='coolwarm')
    axs[1, 1].set_title('Velocity X-Gradient')
    plt.colorbar(im3, ax=axs[1, 1], fraction=0.046, pad=0.04)

    axs[1, 2].hist(seismic_grad_x.flatten(), bins=50, alpha=0.5, label='Seismic')
    axs[1, 2].hist(velocity_grad_x.flatten(), bins=50, alpha=0.5, label='Velocity')
    axs[1, 2].set_title('X-Gradient Histograms')
    axs[1, 2].set_xlabel('Gradient Value')
    axs[1, 2].set_ylabel('Frequency')
    axs[1, 2].legend()

    # Row 2: y-gradients
    im4 = axs[2, 0].imshow(seismic_grad_y, aspect='auto', cmap='coolwarm')
    axs[2, 0].set_title('Seismic Y-Gradient')
    plt.colorbar(im4, ax=axs[2, 0], fraction=0.046, pad=0.04)

    im5 = axs[2, 1].imshow(velocity_grad_y, aspect='auto', cmap='coolwarm')
    axs[2, 1].set_title('Velocity Y-Gradient')
    plt.colorbar(im5, ax=axs[2, 1], fraction=0.046, pad=0.04)

    axs[2, 2].hist(seismic_grad_y.flatten(), bins=50, alpha=0.5, label='Seismic')
    axs[2, 2].hist(velocity_grad_y.flatten(), bins=50, alpha=0.5, label='Velocity')
    axs[2, 2].set_title('Y-Gradient Histograms')
    axs[2, 2].set_xlabel('Gradient Value')
    axs[2, 2].set_ylabel('Frequency')
    axs[2, 2].legend()

    # Top-right: seismic gradient magnitude
    im6 = axs[0, 2].imshow(np.log1p(seismic_grad_mag), aspect='auto', cmap='inferno')
    axs[0, 2].set_title('Seismic Gradient Magnitude (log scale)')
    plt.colorbar(im6, ax=axs[0, 2], fraction=0.046, pad=0.04)

    plt.tight_layout()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f'{sample_idx}_gradient_patterns_channel{channel_idx}.png')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across repeated calls

    return np.mean(seismic_grad_mag), np.mean(velocity_grad_mag)

def analyze_initial_pool_effect(input_file, sample_idx=0, channel_idx=0, pool_size=14,
                                output_dir='/kaggle/working'):
    """Show a seismic channel before and after average pooling along the first (time) axis.

    Mirrors an ``AvgPool2d(kernel_size=(pool_size, 1), stride=(pool_size, 1))``
    layer, meant to match the UNet's initial pooling (the default 14 is the
    value previously hard-coded here).

    Args:
        input_file: seismic ``.npy`` file indexed as ``[sample, channel, ...]``.
        sample_idx: sample to analyze.
        channel_idx: seismic channel to analyze.
        pool_size: pooling window/stride along axis 0 of the 2-D channel.
        output_dir: directory the PNG is written to (created if missing).

    Returns:
        (original 2-D array, pooled 2-D numpy array).
    """
    input_data = np.load(input_file)
    seismic_sample = input_data[sample_idx, channel_idx]

    # Add batch and channel dims: [1, 1, H, W]
    seismic_tensor = torch.from_numpy(seismic_sample).float().unsqueeze(0).unsqueeze(0)
    pool = torch.nn.AvgPool2d(kernel_size=(pool_size, 1), stride=(pool_size, 1))
    pooled_np = pool(seismic_tensor).squeeze().numpy()

    fig, axs = plt.subplots(1, 2, figsize=(15, 5))

    im0 = axs[0].imshow(seismic_sample, aspect='auto', cmap='seismic')
    axs[0].set_title(f'Original Seismic Data - Sample {sample_idx}, Channel {channel_idx}')
    axs[0].set_xlabel('X position')
    axs[0].set_ylabel('Time/Depth')
    plt.colorbar(im0, ax=axs[0], fraction=0.046, pad=0.04)

    im1 = axs[1].imshow(pooled_np, aspect='auto', cmap='seismic')
    axs[1].set_title(f'After AvgPool2d({pool_size},1) - Shape: {pooled_np.shape}')
    axs[1].set_xlabel('X position')
    axs[1].set_ylabel('Time/Depth (reduced)')
    plt.colorbar(im1, ax=axs[1], fraction=0.046, pad=0.04)

    plt.tight_layout()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f'{sample_idx}_initial_pool_effect_channel{channel_idx}.png')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across repeated calls

    return seismic_sample, pooled_np

def analyze_frequency_components(input_file, output_file, sample_idx=0, channel_idx=0, figsize=(15, 10),
                                 output_dir='/kaggle/working'):
    """Compare 1-D power spectra of a seismic channel and its velocity model.

    For each array, ``rfft`` is taken along one axis and ``|.|**2`` is averaged
    over the other axis, giving one spectrum for axis 0 ("rows", time/depth)
    and one for axis 1 ("cols", spatial).

    Args:
        input_file: seismic ``.npy`` file indexed as ``[sample, channel, ...]``.
        output_file: velocity ``.npy`` file indexed as ``[sample, 0, ...]``.
        sample_idx: sample to analyze.
        channel_idx: seismic channel to analyze.
        figsize: matplotlib figure size.
        output_dir: directory the PNG is written to (created if missing).

    Returns:
        dict with the four power spectra and the seismic frequency axes
        (``freqs_rows``, ``freqs_cols``, cycles/sample).
    """
    input_data = np.load(input_file)
    output_data = np.load(output_file)

    seismic_sample = input_data[sample_idx, channel_idx]
    velocity_sample = output_data[sample_idx, 0]  # Assuming output has 1 channel

    # Power along axis 0 averaged over axis 1, and vice versa.
    seismic_power_rows = np.mean(np.abs(np.fft.rfft(seismic_sample, axis=0))**2, axis=1)
    seismic_power_cols = np.mean(np.abs(np.fft.rfft(seismic_sample, axis=1))**2, axis=0)
    velocity_power_rows = np.mean(np.abs(np.fft.rfft(velocity_sample, axis=0))**2, axis=1)
    velocity_power_cols = np.mean(np.abs(np.fft.rfft(velocity_sample, axis=1))**2, axis=0)

    fig, axs = plt.subplots(2, 2, figsize=figsize)

    freqs_rows = np.fft.rfftfreq(seismic_sample.shape[0])
    axs[0, 0].semilogy(freqs_rows, seismic_power_rows, label='Seismic')
    axs[0, 0].semilogy(np.fft.rfftfreq(velocity_sample.shape[0]), velocity_power_rows, label='Velocity')
    axs[0, 0].set_title('Power Spectrum - Depth/Time Direction')
    axs[0, 0].set_xlabel('Frequency')
    axs[0, 0].set_ylabel('Power (log scale)')
    axs[0, 0].legend()
    axs[0, 0].grid(True)

    freqs_cols = np.fft.rfftfreq(seismic_sample.shape[1])
    axs[0, 1].semilogy(freqs_cols, seismic_power_cols, label='Seismic')
    axs[0, 1].semilogy(np.fft.rfftfreq(velocity_sample.shape[1]), velocity_power_cols, label='Velocity')
    axs[0, 1].set_title('Power Spectrum - Spatial Direction')
    axs[0, 1].set_xlabel('Frequency')
    axs[0, 1].set_ylabel('Power (log scale)')
    axs[0, 1].legend()
    axs[0, 1].grid(True)

    im0 = axs[1, 0].imshow(seismic_sample, aspect='auto', cmap='seismic')
    axs[1, 0].set_title(f'Seismic Data - Sample {sample_idx}, Channel {channel_idx}')
    plt.colorbar(im0, ax=axs[1, 0], fraction=0.046, pad=0.04)

    im1 = axs[1, 1].imshow(velocity_sample, aspect='auto', cmap='viridis')
    axs[1, 1].set_title(f'Velocity Model - Sample {sample_idx}')
    plt.colorbar(im1, ax=axs[1, 1], fraction=0.046, pad=0.04)

    plt.tight_layout()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f'{sample_idx}_frequency_components_channel{channel_idx}.png')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across repeated calls

    return {
        'seismic_power_rows': seismic_power_rows,
        'seismic_power_cols': seismic_power_cols,
        'velocity_power_rows': velocity_power_rows,
        'velocity_power_cols': velocity_power_cols,
        'freqs_rows': freqs_rows,
        'freqs_cols': freqs_cols
    }

In [25]:
# Main analysis function
def _plot_channel_diagnostics(input_file, output_file, sample_idx, channel_idx):
    """Run every per-channel plot (pair, spectrum, gradients, pooling, frequencies) for one channel."""
    print(f"\n=== SAMPLE VISUALIZATION - CHANNEL {channel_idx} ===")
    visualize_sample_pair(input_file, output_file, sample_idx=sample_idx, channel_idx=channel_idx)

    print("\n=== SPECTRAL ANALYSIS ===")
    analyze_spectral_content(input_file, output_file, sample_idx=sample_idx, channel_idx=channel_idx)

    print("\n=== GRADIENT PATTERN ANALYSIS ===")
    analyze_gradient_patterns(input_file, output_file, sample_idx=sample_idx, channel_idx=channel_idx)

    print("\n=== INITIAL POOLING EFFECT ANALYSIS ===")
    analyze_initial_pool_effect(input_file, sample_idx=sample_idx, channel_idx=channel_idx)

    print("\n=== FREQUENCY COMPONENT ANALYSIS ===")
    analyze_frequency_components(input_file, output_file, sample_idx=sample_idx, channel_idx=channel_idx)


def run_data_analysis(data_path, n_samples=5, sample_idx=2, max_channels=5):
    """Run the exploratory analysis on the training files under ``data_path``.

    Computes summary statistics over a random subset of files, then plots
    per-channel diagnostics for one sample of the first (sorted/discovered) file.

    Args:
        data_path: root directory searched recursively for seismic ``.npy`` files.
        n_samples: number of randomly chosen files used for summary statistics.
        sample_idx: sample (within the first file) to visualize.
        max_channels: maximum number of channels to visualize.

    Returns:
        dict with ``input_stats``/``output_stats`` (mean of per-file means and
        stds, overall min/max) and ``input_shape``/``output_shape`` as reported
        by the statistics pass.

    Raises:
        FileNotFoundError: if no seismic files are found under ``data_path``.
    """
    print(f"Running data analysis on path: {data_path}")

    input_files, output_files = get_train_files(data_path)
    print(f"Found {len(input_files)} input files and {len(output_files)} output files")
    # Fail fast: the statistics pass has nothing to work with otherwise.
    if len(input_files) == 0:
        raise FileNotFoundError(f"No seismic .npy files found under {data_path}")

    print("\n=== DATASET STATISTICS ===")
    input_stats_df, output_stats_df, input_shape, output_shape = analyze_dataset_statistics(
        input_files, output_files, n_samples=min(n_samples, len(input_files)))

    print(f"\nConfirmed input shape: {input_shape}")
    print(f"Confirmed output shape: {output_shape}")

    if input_shape is None:
        print("No files were sampled; skipping visualizations.")
    else:
        print("\n=== VISUALIZING CHANNELS ===")
        visualize_channels(input_files[0], sample_idx=sample_idx)

        # input_shape[1] is the channel axis of [samples, channels, height, width].
        for channel_idx in range(min(max_channels, input_shape[1])):
            _plot_channel_diagnostics(input_files[0], output_files[0], sample_idx, channel_idx)

    return {
        'input_stats': {
            'mean': input_stats_df['mean'].mean(),
            'std': input_stats_df['std'].mean(),
            'min': input_stats_df['min'].min(),
            'max': input_stats_df['max'].max(),
        },
        'output_stats': {
            'mean': output_stats_df['mean'].mean(),
            'std': output_stats_df['std'].mean(),
            'min': output_stats_df['min'].min(),
            'max': output_stats_df['max'].max(),
        },
        'input_shape': input_shape,
        'output_shape': output_shape
    }

In [26]:
# if __name__ == "__main__":
    # Run data analysis on the provided dataset
#    data_path = "/kaggle/input/waveform-inversion"  # Update with your data path
#    stats = run_data_analysis(data_path, n_samples=5)
    
#    print("\n=== ANALYSIS COMPLETE ===")
#    print("\nSummary of findings:")
#    print(f"Input data shape: {stats['input_shape']}")
#    print(f"Output data shape: {stats['output_shape']}")
#    print(f"Input data - Min: {stats['input_stats']['min']:.4f}, Max: {stats['input_stats']['max']:.4f}")
#    print(f"Input data - Mean: {stats['input_stats']['mean']:.4f}, Std: {stats['input_stats']['std']:.4f}")
#    print(f"Output data - Min: {stats['output_stats']['min']:.4f}, Max: {stats['output_stats']['max']:.4f}")
#    print(f"Output data - Mean: {stats['output_stats']['mean']:.4f}, Std: {stats['output_stats']['std']:.4f}")

# Modeling Functions

In [27]:
# Improved dataset classes with normalization
class SeismicDataset(Dataset):
    """Map-style dataset over paired seismic/velocity ``.npy`` files.

    Each file is expected to hold ``n_examples_per_file`` samples along axis 0.
    A flat index is split into (file index, sample index within that file).
    Optional z-score normalization uses a global (mean, std) estimated from up
    to the first 10 file pairs.
    """

    def __init__(self, inputs_files, output_files, n_examples_per_file=500, normalize=True, transform=None):
        assert len(inputs_files) == len(output_files)
        self.inputs_files = inputs_files
        self.output_files = output_files
        self.n_examples_per_file = n_examples_per_file
        self.normalize = normalize
        self.transform = transform

        # Always defined so callers (e.g. building a TestDataset) can read them.
        self.input_stats = None
        self.output_stats = None
        if normalize:
            self.input_stats, self.output_stats = self._calculate_stats()

    @staticmethod
    def _pooled_mean_std(counts, means, variances):
        """Combine per-file (count, mean, population variance) into a global (mean, std).

        Returns std = 1.0 when the pooled std is not positive (or there is no data),
        so normalization never divides by ~0.
        """
        counts = np.asarray(counts, dtype=np.float64)
        means = np.asarray(means, dtype=np.float64)
        variances = np.asarray(variances, dtype=np.float64)
        total = counts.sum()
        if total <= 0:
            return 0.0, 1.0
        grand_mean = float(np.sum(counts * means) / total)
        # Law of total variance: within-file variance + spread of file means.
        pooled_var = float(np.sum(counts * (variances + (means - grand_mean) ** 2)) / total)
        pooled_std = float(np.sqrt(max(pooled_var, 0.0)))
        return grand_mean, (pooled_std if pooled_std > 0 else 1.0)

    def _calculate_stats(self):
        """Estimate (mean, std) for inputs and outputs from up to the first 10 file pairs.

        The std is that of all values in the sampled files, not the std of the
        per-file stds (which only measures how the spread varies between files).
        """
        sample_size = min(10, len(self.inputs_files))

        print(f"Calculating statistics from {sample_size} files...")
        in_counts, in_means, in_vars = [], [], []
        out_counts, out_means, out_vars = [], [], []
        for i in range(sample_size):
            X = np.load(self.inputs_files[i])
            y = np.load(self.output_files[i])

            # float64 accumulation avoids precision loss on large float32 arrays.
            in_counts.append(X.size)
            in_means.append(np.mean(X, dtype=np.float64))
            in_vars.append(np.var(X, dtype=np.float64))

            out_counts.append(y.size)
            out_means.append(np.mean(y, dtype=np.float64))
            out_vars.append(np.var(y, dtype=np.float64))

        input_mean, input_std = self._pooled_mean_std(in_counts, in_means, in_vars)
        output_mean, output_std = self._pooled_mean_std(out_counts, out_means, out_vars)

        print(f"Input stats - Mean: {input_mean:.4f}, Std: {input_std:.4f}")
        print(f"Output stats - Mean: {output_mean:.4f}, Std: {output_std:.4f}")

        return (input_mean, input_std), (output_mean, output_std)

    def __len__(self):
        return len(self.inputs_files) * self.n_examples_per_file

    def __getitem__(self, idx):
        """Return ``(X, y)`` float32 tensors for flat index ``idx``."""
        file_idx = idx // self.n_examples_per_file
        sample_idx = idx % self.n_examples_per_file

        # Memory-map so only the requested sample is read from disk.
        X = np.load(self.inputs_files[file_idx], mmap_mode='r')
        y = np.load(self.output_files[file_idx], mmap_mode='r')

        try:
            X_sample, y_sample = X[sample_idx].copy(), y[sample_idx].copy()

            if self.normalize:
                X_sample = (X_sample - self.input_stats[0]) / (self.input_stats[1] + 1e-8)
                y_sample = (y_sample - self.output_stats[0]) / (self.output_stats[1] + 1e-8)

            if self.transform:
                X_sample, y_sample = self.transform(X_sample, y_sample)

            # torch.from_numpy rejects negative strides (e.g. after np.flip).
            X_sample = np.ascontiguousarray(X_sample)
            y_sample = np.ascontiguousarray(y_sample)

            return torch.from_numpy(X_sample).float(), torch.from_numpy(y_sample).float()
        finally:
            del X, y

class TestDataset(Dataset):
    """Inference dataset: one item per test ``.npy`` file, paired with the file's stem as an id.

    When ``normalize`` is set and ``input_stats`` (a ``(mean, std)`` pair) is
    truthy, each array is z-scored with those stats; otherwise it is returned as-is.
    """

    def __init__(self, test_files, normalize=True, input_stats=None):
        self.test_files = test_files
        self.normalize = normalize
        self.input_stats = input_stats

    def __len__(self):
        return len(self.test_files)

    def __getitem__(self, i):
        path = self.test_files[i]
        array = np.load(path)

        apply_norm = self.normalize and self.input_stats
        if apply_norm:
            center, scale = self.input_stats[0], self.input_stats[1]
            array = (array - center) / (scale + 1e-8)

        return torch.from_numpy(array).float(), path.stem

In [28]:
# Data augmentation class with fixed padding operation
class SeismicAugmentation:
    """Random, label-aware augmentations for a single (seismic, velocity) sample pair.

    Assumed per-sample layouts (matching the ``[samples, 5, 1000, 70]`` input
    and ``[samples, 1, 70, 70]`` output shapes noted in the analysis cells):
    ``X`` is ``(sources/channels, time, receivers)`` and ``y`` is ``(1, depth, x)``.
    TODO confirm if the data layout changes.

    Augmentations, each drawn independently from NumPy's global RNG:
      * horizontal mirror of both X and y (``flip_prob``)
      * additive Gaussian noise on X only (``noise_prob``, ``noise_level``)
      * zero-filled shift of X along its time axis (``shift_prob``, ``max_shift``)
    """

    def __init__(self,
                 flip_prob=0.5,
                 noise_prob=0.3,
                 noise_level=0.05,
                 shift_prob=0.3,
                 max_shift=5):
        self.flip_prob = flip_prob
        self.noise_prob = noise_prob
        self.noise_level = noise_level
        self.shift_prob = shift_prob
        self.max_shift = max_shift

    def __call__(self, X, y):
        """Return augmented copies ``(X, y)``; the inputs are never modified."""
        X = X.copy()
        y = y.copy()

        # Horizontal mirror: flip the lateral (last) axis of both arrays. The previous
        # axis=1 flipped time (X) and depth (y), which is not a physical symmetry.
        # For 3-D X the source axis is reversed as well so each shot gather matches its
        # mirrored source position -- assumes sources are placed symmetrically along the line.
        if np.random.random() < self.flip_prob:
            X = np.flip(X, axis=-1)
            if X.ndim == 3:
                X = np.flip(X, axis=0)
            y = np.flip(y, axis=-1)
            # Copy to drop the negative strides np.flip produces.
            X = X.copy()
            y = y.copy()

        # Additive Gaussian noise on the input only.
        if np.random.random() < self.noise_prob:
            X = X + np.random.normal(0, self.noise_level, X.shape)

        # Zero-filled time shift along axis -2 (the time axis; axis 1 for 3-D X).
        if np.random.random() < self.shift_prob:
            shift = np.random.randint(-self.max_shift, self.max_shift + 1)
            if shift != 0:
                shifted = np.zeros_like(X)
                if shift > 0:
                    shifted[..., shift:, :] = X[..., :-shift, :]
                else:
                    shifted[..., :shift, :] = X[..., -shift:, :]
                X = shifted

        return X, y

In [29]:
# Attention Gate for U-Net
class AttentionGate(nn.Module):
    """Additive attention gate for U-Net skip connections.

    Projects the gating signal ``g`` and the skip features ``x`` to ``F_int``
    channels with 1x1 conv + BN, adds them, applies ReLU, then a 1x1 conv + BN +
    sigmoid to get a single-channel mask ``psi`` in (0, 1). Returns ``x * psi``.
    If spatial sizes differ, ``g``'s projection (and, if needed, ``psi``) is
    bilinearly resized to match ``x``.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip-connection features ``x``.
        F_int: channels of the intermediate projection.
        verbose: print the configuration on construction (debug aid; off by default).
    """

    def __init__(self, F_g, F_l, F_int, verbose=False):
        super().__init__()

        if verbose:
            print(f"Creating AttentionGate with F_g={F_g}, F_l={F_l}, F_int={F_int}")

        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Gate ``x`` (``[B, F_l, H, W]``) using ``g`` (``[B, F_g, h, w]``); output has ``x``'s shape."""
        g1 = self.W_g(g)
        x1 = self.W_x(x)

        # g usually comes from a coarser level; align it to x before the addition.
        if g1.shape[2:] != x1.shape[2:]:
            g1 = F.interpolate(g1, size=x1.shape[2:], mode='bilinear', align_corners=False)

        psi = self.psi(self.relu(g1 + x1))

        if psi.shape[2:] != x.shape[2:]:
            psi = F.interpolate(psi, size=x.shape[2:], mode='bilinear', align_corners=False)

        return x * psi


In [30]:
# Squeeze-and-Excitation block
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel reweighting.

    Global-average-pools each channel, passes the ``[B, C]`` vector through a
    bottleneck MLP (C -> C // reduction -> C, ReLU then sigmoid), and scales
    each input channel by the resulting weight.

    Args:
        channel: number of input/output channels.
        reduction: bottleneck reduction ratio. The hidden width is clamped to at
            least 1 so small ``channel`` values don't produce a zero-width layer
            (which would make the gate a constant 0.5).
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale ``x`` (``[B, C, H, W]``) channel-wise; output has the same shape."""
        b, c, _, _ = x.size()
        weights = self.fc(self.avg_pool(x).view(b, c)).view(b, c, 1, 1)
        return x * weights.expand_as(x)

In [31]:
# Improved Residual Double Conv Block with SE
class ResidualDoubleConv(nn.Module):
    """Two 3x3 conv-BN stages with LeakyReLU, optional SE reweighting, and a residual add.

    Main path: conv1 -> bn1 -> LeakyReLU -> conv2 -> bn2 -> [SE].
    Shortcut: identity when channel counts match, else 1x1 conv + BN.
    The sum of both paths passes through a final LeakyReLU (slope 0.01).
    Spatial size is preserved (padding=1).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, use_se=True):
        super().__init__()
        # A falsy mid_channels (None or 0) defaults to out_channels.
        mid_channels = mid_channels or out_channels

        # NOTE: attribute names and creation order are kept stable (state_dict keys / seeded init).
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.relu = nn.LeakyReLU(negative_slope=0.01, inplace=True)

        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.use_se = use_se
        if use_se:
            self.se = SEBlock(out_channels, reduction=16)

        # Project the input only when channel counts differ.
        self.shortcut = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        )

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))

        if self.use_se:
            hidden = self.se(hidden)

        residual = self.shortcut(x)
        return self.relu(hidden + residual)

In [32]:
# Decoder stage: upsample, align with the skip, optional attention gate, fuse
class Up(nn.Module):
    """Decoder block: 2x upsample, match the skip's size, gate the skip, concat, ResidualDoubleConv.

    Args:
        in_channels: channels of the deeper feature map ``x1``.
        out_channels: channels of the skip feature map ``x2`` and of the block output.
        bilinear: if True, upsample with ``nn.Upsample`` (channel count unchanged);
            otherwise with a stride-2 ``ConvTranspose2d`` that halves the channels.
        use_attention: if True, the skip is passed through
            ``AttentionGate`` (gating signal = upsampled ``x1``) before concatenation.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, use_attention=True):
        super().__init__()
        self.bilinear = bilinear
        self.use_attention = use_attention

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
            up_channels = in_channels  # interpolation keeps the channel count
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            up_channels = in_channels // 2  # transpose conv halves the channel count

        # Fused input = skip (out_channels) concatenated with upsampled features (up_channels).
        self.conv = ResidualDoubleConv(up_channels + out_channels, out_channels)

        if use_attention:
            # g = upsampled deeper features, x = skip connection.
            self.attention = AttentionGate(F_g=up_channels, F_l=out_channels, F_int=out_channels // 2)

    @staticmethod
    def _match_spatial(src, ref):
        """Centre-pad or centre-crop ``src`` (NCHW) so its H and W match ``ref``.

        If either dimension of ``src`` is smaller than ``ref``'s, both are
        adjusted with ``F.pad`` (a negative amount crops that side). Otherwise,
        if either is larger, ``src`` is sliced down to ``ref``'s size.
        """
        d_h = ref.size(2) - src.size(2)
        d_w = ref.size(3) - src.size(3)
        if d_h > 0 or d_w > 0:
            # Pad order: (left, right, top, bottom)
            return F.pad(src, [d_w // 2, d_w - d_w // 2, d_h // 2, d_h - d_h // 2])
        if d_h < 0 or d_w < 0:
            top = (-d_h) // 2
            left = (-d_w) // 2
            return src[:, :, top:top + ref.size(2), left:left + ref.size(3)]
        return src

    def forward(self, x1, x2):
        """Fuse deeper features ``x1`` with skip features ``x2`` (both NCHW)."""
        upsampled = self._match_spatial(self.up(x1), x2)
        skip = self.attention(upsampled, x2) if self.use_attention else x2
        return self.conv(torch.cat([skip, upsampled], dim=1))

In [33]:
# Improved UNet architecture
class UNet(nn.Module):
    """U-Net with residual double-conv blocks, optional SE blocks and attention-gated skips.

    Pipeline:
        1. ``initial_pool``: a depthwise (per-channel) conv with kernel/stride
           (14, 1) shrinks the height ~14x, then a 1x1 conv maps the
           ``n_channels`` inputs to ``init_features`` channels.
        2. ``_pad_or_crop`` centre-pads/crops the result to 70x70.
        3. Encoder: ``inc`` (keeps ``init_features`` channels), then ``depth``
           stages of MaxPool2d(2) -> ResidualDoubleConv, each doubling channels.
           For a 70x70 map the stage resolutions are 70, 35, 17, 8, 4, 2, ...
        4. Bottleneck ResidualDoubleConv at the deepest resolution.
        5. Decoder: ``depth`` Up blocks; each upsamples 2x and fuses the
           encoder feature map of the same resolution and matching channels.
        6. 1x1 ``outc`` followed by the affine map ``logits * 1000 + 1500``.

    NOTE(review): Up builds its ResidualDoubleConv with the default
    ``use_se``, so ``use_se`` here only affects ``inc``, the encoder and the
    bottleneck.

    Args:
        n_channels: input channels.
        n_classes: output channels.
        init_features: channels after the stem; doubled at every encoder stage.
        depth: number of pooling stages (and of decoder blocks).
        bilinear: passed to every Up block.
        use_attention: passed to every Up block.
        use_se: enables SE blocks in the encoder/bottleneck ResidualDoubleConvs.
    """

    def __init__(
        self,
        n_channels=5,
        n_classes=1,
        init_features=64,
        depth=5,
        bilinear=True,
        use_attention=True,
        use_se=True,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.depth = depth
        self.use_attention = use_attention
        self.use_se = use_se

        # Stem: depthwise strided conv shrinks the height axis per channel, then
        # a pointwise conv mixes channels up to init_features.
        self.initial_pool = nn.Sequential(
            nn.Conv2d(n_channels, n_channels, kernel_size=(14, 1), stride=(14, 1),
                      groups=n_channels, bias=False),
            nn.BatchNorm2d(n_channels),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(n_channels, init_features, kernel_size=1, bias=False),
            nn.BatchNorm2d(init_features),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        )

        # --- Encoder ---
        self.encoder_convs = nn.ModuleList()  # inc followed by one conv per stage
        self.encoder_pools = nn.ModuleList()  # one MaxPool2d per stage

        # The stem already outputs init_features channels, so `inc` must take
        # init_features as input. (ResidualDoubleConv(n_channels, ...) raised a
        # channel-mismatch error in forward whenever n_channels != init_features.)
        self.inc = ResidualDoubleConv(init_features, init_features, use_se=use_se)
        self.encoder_convs.append(self.inc)

        current_features = init_features
        for _ in range(depth):
            self.encoder_convs.append(
                ResidualDoubleConv(current_features, current_features * 2, use_se=use_se)
            )
            self.encoder_pools.append(nn.MaxPool2d(2))
            current_features *= 2

        # --- Bottleneck ---
        self.bottleneck = ResidualDoubleConv(current_features, current_features, use_se=use_se)

        # --- Decoder --- (channels halve at every Up block)
        self.decoder_blocks = nn.ModuleList()
        for _ in range(depth):
            self.decoder_blocks.append(
                Up(current_features, current_features // 2,
                   bilinear=bilinear, use_attention=use_attention)
            )
            current_features //= 2

        # --- Output Layer ---
        self.outc = OutConv(current_features, n_classes)

    def _pad_or_crop(self, x, target_h=70, target_w=70):
        """Centre-pad (with zeros) or centre-crop NCHW tensor ``x`` to ``target_h`` x ``target_w``."""
        _, _, h, w = x.shape
        if h < target_h:
            pad_top = (target_h - h) // 2
            x = F.pad(x, (0, 0, pad_top, target_h - h - pad_top))  # height only
            h = target_h
        if w < target_w:
            pad_left = (target_w - w) // 2
            x = F.pad(x, (pad_left, target_w - w - pad_left, 0, 0))  # width only
            w = target_w
        if h > target_h:
            crop_top = (h - target_h) // 2
            x = x[:, :, crop_top:crop_top + target_h, :]
        if w > target_w:
            crop_left = (w - target_w) // 2
            x = x[:, :, :, crop_left:crop_left + target_w]
        return x

    def forward(self, x):
        """Map an ``n_channels`` NCHW input to an ``n_classes`` x 70 x 70 output.

        NOTE(review): the 70x70 working size is hard-coded; inputs whose height
        is ~14x70 are presumably intended (e.g. 1000 -> 71 -> cropped to 70).
        """
        x = self._pad_or_crop(self.initial_pool(x), target_h=70, target_w=70)

        # --- Encoder Path ---
        # Pool *before* each stage's conv so that every stored skip has exactly
        # the resolution and channel count the matching Up block expects.
        # (Conv-then-pool paired each decoder level with the skip one
        # resolution too high: the first level zero-padded 4x4 into 8x8 and the
        # last centre-cropped a 140x140 map, misaligning the output.)
        xi = self.encoder_convs[0](x)
        skip_connections = [xi]  # skip_connections[k]: init_features * 2**k channels
        for i in range(self.depth):
            xi = self.encoder_pools[i](xi)
            xi = self.encoder_convs[i + 1](xi)
            if i < self.depth - 1:  # the deepest stage feeds the bottleneck, not a skip
                skip_connections.append(xi)

        # --- Bottleneck ---
        xu = self.bottleneck(xi)

        # --- Decoder Path --- (consume skips from deepest to shallowest)
        for i, block in enumerate(self.decoder_blocks):
            xu = block(xu, skip_connections[self.depth - 1 - i])

        # --- Final Output ---
        logits = self.outc(xu)
        # Fixed affine rescaling of the logits to the target value range.
        return logits * 1000.0 + 1500.0

In [34]:
# Output Convolution
class OutConv(nn.Module):
    """Final 1x1 convolution projecting decoder features to ``out_channels`` (with bias)."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))

    def forward(self, x):
        """Apply the 1x1 projection; spatial size is unchanged."""
        projected = self.conv(x)
        return projected

In [35]:
# Custom Loss Functions
class FocalLoss(nn.Module):
    """Focal-style weighting of element-wise squared error (regression adaptation).

    Each element's squared error ``e = (input - target) ** 2`` is scaled by
    ``(1 - clamp(exp(-e), eps, 1 - eps)) ** gamma``. Elements that are already
    well predicted (small ``e``) are down-weighted. Badly predicted ones keep
    roughly their full squared error, because the weight saturates near 1.

    Args:
        gamma: focusing exponent.
        alpha: optional factor multiplied into the loss (any value that
            broadcasts against the per-element loss).
        reduction: 'mean' or 'sum'; any other value returns the per-element loss.
        eps: clamp margin keeping ``exp(-e)`` strictly inside (0, 1).
    """

    def __init__(self, gamma=2.0, alpha=None, reduction='mean', eps=1e-6):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.eps = eps

    def forward(self, input, target):
        sq_err = F.mse_loss(input, target, reduction='none')
        # "Confidence" of each element: ~1 when well fit, ~0 when far off.
        confidence = torch.exp(-sq_err).clamp(min=self.eps, max=1 - self.eps)
        weighted = (1 - confidence).pow(self.gamma) * sq_err

        if self.alpha is not None:
            weighted = self.alpha * weighted

        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'sum':
            return weighted.sum()
        return weighted

class GradientLoss(nn.Module):
    """MSE plus a Sobel-gradient matching term, to preserve edges/structure.

    Per element the loss is
    ``(input - target)**2 + lambda_grad * ((Gx(input) - Gx(target))**2 + (Gy(input) - Gy(target))**2)``,
    where ``Gx``/``Gy`` are 3x3 Sobel filters applied per channel (grouped
    conv) with zero padding of 1.

    Args:
        lambda_grad: weight of the gradient term.
        reduction: 'mean' or 'sum'; any other value returns the per-element loss.
    """

    def __init__(self, lambda_grad=0.5, reduction='mean'):
        super(GradientLoss, self).__init__()
        self.lambda_grad = lambda_grad
        self.reduction = reduction

    def forward(self, input, target):
        # Horizontal Sobel kernel, one copy per channel; the vertical kernel is its transpose.
        sobel_x = torch.tensor(
            [[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]],
            dtype=input.dtype, device=input.device,
        ).repeat(input.size(1), 1, 1, 1)
        sobel_y = sobel_x.transpose(2, 3).contiguous()

        def sobel(t):
            groups = t.size(1)
            return (F.conv2d(t, sobel_x, padding=1, groups=groups),
                    F.conv2d(t, sobel_y, padding=1, groups=groups))

        pred_gx, pred_gy = sobel(input)
        true_gx, true_gy = sobel(target)

        pixel_term = F.mse_loss(input, target, reduction='none')
        edge_term = (F.mse_loss(pred_gx, true_gx, reduction='none')
                     + F.mse_loss(pred_gy, true_gy, reduction='none'))
        combined = pixel_term + self.lambda_grad * edge_term

        if self.reduction == 'mean':
            return combined.mean()
        if self.reduction == 'sum':
            return combined.sum()
        return combined

class CombinedLoss(nn.Module):
    """Weighted sum of several loss modules evaluated on the same (input, target).

    Args:
        losses: sequence of loss modules, each called as ``loss(input, target)``.
        weights: one weight per loss (defaults to 1.0 for each).

    Losses whose weight is exactly 0 are skipped. They could only contribute
    ``0 * value``, so skipping saves their compute (e.g. an FFT per batch).
    It also stops a NaN/inf from an unused term poisoning the total
    (``0 * nan == nan``).
    """

    def __init__(self, losses, weights=None):
        super(CombinedLoss, self).__init__()
        self.losses = nn.ModuleList(losses)
        self.weights = list(weights) if weights is not None else [1.0] * len(self.losses)
        assert len(self.losses) == len(self.weights), "Number of losses and weights must match"

    def forward(self, input, target):
        total_loss = None
        for weight, loss_fn in zip(self.weights, self.losses):
            if weight == 0:
                continue
            term = weight * loss_fn(input, target)
            total_loss = term if total_loss is None else total_loss + term
        if total_loss is None:
            # Every weight is zero: return a zero that stays connected to the graph
            # so that .backward() still works.
            return input.sum() * 0.0
        return total_loss

class FrequencyDomainLoss(nn.Module):
    """Loss computed on the 2-D real FFT (``torch.fft.rfft2``) of input and target.

    Inputs are cast to float32 before the FFT. Supported ``loss_type`` values:
        'mse' / 'mae': magnitude + real + imaginary error (squared / absolute).
        'magnitude':   squared error of the spectrum magnitudes.
        'phase':       squared *wrapped* phase difference, in [-pi, pi], so that
                       phases just either side of +/-pi count as close.
        'complex_mse': squared real error + squared imaginary error.

    Args:
        reduction: 'mean' or 'sum'; any other value returns the per-frequency loss.
        loss_type: one of the values above.
        eps: stored but currently unused by ``forward``.
        lambda_freq: stored but currently unused by ``forward`` (weighting is
            expected to be done by the caller, e.g. a combined loss).

    Raises:
        ValueError: from ``forward`` when ``loss_type`` is not supported.
    """

    def __init__(self, reduction='mean', loss_type='mse', eps=1e-6, lambda_freq=0.3):
        super(FrequencyDomainLoss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type
        self.eps = eps
        self.lambda_freq = lambda_freq

    def forward(self, input, target):
        # 1. Transform to the frequency domain (rfft2 over the last two dims).
        input_fft = torch.fft.rfft2(input.float())
        target_fft = torch.fft.rfft2(target.float())

        # 2. Per-frequency loss.
        if self.loss_type == 'mse':
            mag_loss = F.mse_loss(torch.abs(input_fft), torch.abs(target_fft), reduction='none')
            real_loss = F.mse_loss(input_fft.real, target_fft.real, reduction='none')
            imag_loss = F.mse_loss(input_fft.imag, target_fft.imag, reduction='none')
            loss = mag_loss + real_loss + imag_loss
        elif self.loss_type == 'mae':
            mag_loss = F.l1_loss(torch.abs(input_fft), torch.abs(target_fft), reduction='none')
            real_loss = F.l1_loss(input_fft.real, target_fft.real, reduction='none')
            imag_loss = F.l1_loss(input_fft.imag, target_fft.imag, reduction='none')
            loss = mag_loss + real_loss + imag_loss
        elif self.loss_type == 'magnitude':
            loss = F.mse_loss(torch.abs(input_fft), torch.abs(target_fft), reduction='none')
        elif self.loss_type == 'phase':
            # Phases live on a circle: wrap the difference into [-pi, pi] so that
            # e.g. (pi - d) vs (-pi + d) is a small error, not ~2*pi.
            phase_diff = torch.angle(input_fft) - torch.angle(target_fft)
            wrapped = torch.atan2(torch.sin(phase_diff), torch.cos(phase_diff))
            loss = wrapped ** 2
        elif self.loss_type == 'complex_mse':
            real_loss = F.mse_loss(input_fft.real, target_fft.real, reduction='none')
            imag_loss = F.mse_loss(input_fft.imag, target_fft.imag, reduction='none')
            loss = real_loss + imag_loss
        else:
            raise ValueError(f"Unsupported loss type: {self.loss_type}")

        # 3. Reduction.
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


In [36]:
# Training function with mixed precision
def _make_grad_scaler(enabled):
    """Create a CUDA GradScaler via ``torch.amp`` when available, else the older ``torch.cuda.amp``."""
    amp_module = getattr(torch, 'amp', None)
    if amp_module is not None and hasattr(amp_module, 'GradScaler'):
        return amp_module.GradScaler('cuda', enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)


def _mean_loss(values, loader_name):
    """Average per-batch losses; raise a clear error when a loader produced no batches."""
    if not values:
        raise ValueError(f"{loader_name} produced no batches")
    return sum(values) / len(values)


def train_with_mixed_precision(model, train_loader, valid_loader, criterion, optimizer, scheduler, config, device,
                               checkpoint_path='best_model.pth'):
    """Train with (CUDA) mixed precision, validating after every epoch.

    Each epoch does one pass over ``train_loader`` and then computes the mean
    validation loss in fp32. Training uses fp16 autocast with gradient
    scaling on CUDA and plain fp32 elsewhere, and clips gradients to norm 1.0.
    The validation loss is passed to ``scheduler.step``. When it improves, a
    checkpoint is written to ``checkpoint_path``. Training stops early after
    ``config['es_epochs']`` epochs without improvement. Before returning, the
    model is restored to the weights of the best validation epoch.

    Args:
        model: module to train (already on ``device``).
        train_loader: iterable of ``(data, target)`` batches.
        valid_loader: iterable of ``(data, target)`` batches.
        criterion: loss called as ``criterion(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: scheduler whose ``step`` takes the validation loss
            (e.g. ``ReduceLROnPlateau``).
        config: dict with 'max_epochs', 'print_freq' and 'es_epochs'.
        device: ``torch.device`` or device string.
        checkpoint_path: file the best checkpoint is written to.

    Returns:
        ``(model, history)``; ``history`` has per-epoch 'train_loss' and 'valid_loss'.

    Raises:
        ValueError: if either loader yields no batches.
    """
    import contextlib

    device = torch.device(device)
    # fp16 autocast + loss scaling only make sense on CUDA; elsewhere train in fp32.
    use_amp = device.type == 'cuda'
    scaler = _make_grad_scaler(use_amp)

    def amp_context():
        if use_amp:
            return torch.autocast(device_type='cuda', dtype=torch.float16)
        return contextlib.nullcontext()

    best_valid_loss = float('inf')
    best_state = None
    early_stop_counter = 0
    history = {'train_loss': [], 'valid_loss': []}

    for epoch in range(config['max_epochs']):
        start_time = time.time()

        # Training phase
        model.train()
        train_losses = []

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()

            with amp_context():
                output = model(data)
                loss = criterion(output, target)

            scaler.scale(loss).backward()

            # Unscale first so clipping applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # The scaler skips the step if the gradients contain inf/NaN.
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss.item())

            if batch_idx % config['print_freq'] == 0:
                print(f'Epoch: {epoch+1}/{config["max_epochs"]} '
                      f'[{batch_idx*len(data)}/{len(train_loader.dataset)} '
                      f'({100. * batch_idx / len(train_loader):.0f}%)]\t'
                      f'Loss: {loss.item():.6f}')

        avg_train_loss = _mean_loss(train_losses, 'train_loader')
        history['train_loss'].append(avg_train_loss)

        # Validation phase (full precision)
        model.eval()
        valid_losses = []

        with torch.no_grad():
            for data, target in valid_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                valid_losses.append(criterion(output, target).item())

        avg_valid_loss = _mean_loss(valid_losses, 'valid_loader')
        history['valid_loss'].append(avg_valid_loss)

        epoch_time = time.time() - start_time

        print(f'Epoch: {epoch+1} Train Loss: {avg_train_loss:.6f} Valid Loss: {avg_valid_loss:.6f} '
              f'Time: {epoch_time:.1f}s LR: {optimizer.param_groups[0]["lr"]:.8f}')

        scheduler.step(avg_valid_loss)

        if avg_valid_loss < best_valid_loss:
            best_valid_loss = avg_valid_loss
            early_stop_counter = 0
            # CPU copy so the best weights can be restored at the end without re-reading the file.
            best_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': best_valid_loss,
            }, checkpoint_path)
            print(f"Model saved with validation loss: {best_valid_loss:.6f}")
        else:
            early_stop_counter += 1

        if early_stop_counter >= config['es_epochs']:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Return the best-validation weights, not whatever the last epoch left behind.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, history

In [37]:
# K-fold cross-validation function
def k_fold_cross_validation(model_class, inputs_files, output_files, criterion, config, n_folds=5):
    """Train one model per K-fold split and save each fold's best-validation weights.

    For each fold a fresh ``model_class(**config['model']['unet_params'])`` is
    trained with ``train_with_mixed_precision``. The training split is
    augmented with ``SeismicAugmentation``; the validation split is not. The
    weights of the fold's best validation epoch are then saved to
    ``model_fold_<k>.pth``.

    Args:
        model_class: model constructor.
        inputs_files: seismic input file paths.
        output_files: velocity-model file paths aligned with ``inputs_files``.
        criterion: loss module.
        config: experiment configuration (uses 'seed', 'batch_size', 'model',
            'optimizer', 'scheduler' and the keys read by the training loop).
        n_folds: number of folds.

    Returns:
        list of dicts with keys 'fold', 'best_val_loss' and 'history'.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=config['seed'])

    # numpy arrays allow fancy indexing with the fold index arrays
    inputs_files = np.array(inputs_files)
    output_files = np.array(output_files)

    # train_with_mixed_precision writes its best-validation checkpoint to this file.
    best_checkpoint = Path('best_model.pth')

    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(inputs_files)):
        print(f"Training fold {fold+1}/{n_folds}")

        train_input_files = inputs_files[train_idx]
        train_output_files = output_files[train_idx]
        val_input_files = inputs_files[val_idx]
        val_output_files = output_files[val_idx]

        # Augment training data only
        transform = SeismicAugmentation(flip_prob=0.5, noise_prob=0.3, noise_level=0.05)

        train_dataset = SeismicDataset(train_input_files, train_output_files,
                                       normalize=True, transform=transform)
        val_dataset = SeismicDataset(val_input_files, val_output_files,
                                     normalize=True, transform=None)

        train_loader = DataLoader(
            train_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        # Fresh model, optimizer and scheduler for every fold
        model = model_class(**config['model']['unet_params'])
        model.to(device)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['optimizer']['lr'],
            weight_decay=config['optimizer']['weight_decay']
        )

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            **config['scheduler']['params']
        )

        # Remove a checkpoint left by a previous fold/run, so a fold that never
        # improves cannot silently pick up another fold's weights below.
        best_checkpoint.unlink(missing_ok=True)

        model, history = train_with_mixed_precision(model, train_loader, val_loader,
                                                    criterion, optimizer, scheduler,
                                                    config, device)

        # Reload the best-validation weights the training loop saved, so the fold
        # file holds the best epoch rather than the final one.
        if best_checkpoint.exists():
            checkpoint = torch.load(best_checkpoint, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])

        fold_results.append({
            'fold': fold + 1,
            'best_val_loss': min(history['valid_loss']),
            'history': history
        })

        torch.save({
            'fold': fold + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'history': history,
        }, f'model_fold_{fold+1}.pth')

    avg_best_val_loss = sum(result['best_val_loss'] for result in fold_results) / n_folds
    print(f"Average best validation loss across {n_folds} folds: {avg_best_val_loss:.6f}")

    return fold_results

In [38]:
# Function to create test predictions
def predict(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and collect one output per file.

    Args:
        model: trained network; switched to eval mode here.
        test_loader: iterable yielding ``(data, file_names)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        dict mapping each file name to its model output as a numpy array.
        No de-normalisation is applied; the model output is stored as-is.
    """
    model.eval()
    results = {}

    with torch.no_grad():
        for batch, names in test_loader:
            batch_out = model(batch.to(device)).cpu().numpy()
            for idx, name in enumerate(names):
                results[name] = batch_out[idx]

    return results

# Function to create model ensemble
def ensemble_predict(models, test_loader, device):
    """Average the predictions of several models over ``test_loader``.

    Args:
        models: non-empty iterable of models; each is switched to eval mode.
        test_loader: iterable yielding ``(data, file_names)`` batches.
        device: device each batch is moved to.

    Returns:
        dict mapping each file name to the mean prediction (numpy array).

    Raises:
        ValueError: if ``models`` is empty.
    """
    # Materialise once: the models are iterated again for every batch.
    models = list(models)
    if not models:
        raise ValueError("ensemble_predict requires at least one model")

    for model in models:
        model.eval()

    predictions = {}

    with torch.no_grad():
        for data, file_names in test_loader:
            data = data.to(device)

            # Accumulate into a tensor we own. Adding in place into the first
            # model's output would also modify `data` whenever a model returns
            # its input unchanged, corrupting later models' inputs.
            ensemble_sum = None
            for model in models:
                outputs = model(data)
                if ensemble_sum is None:
                    ensemble_sum = outputs.clone()
                else:
                    ensemble_sum += outputs

            outputs_cpu = (ensemble_sum / len(models)).cpu().numpy()

            for i, file_name in enumerate(file_names):
                predictions[file_name] = outputs_cpu[i]

    return predictions

In [39]:
# Function to create TTA (Test Time Augmentation) predictions
def tta_predict(model, test_loader, device, n_augmentations=5):
    """Generate predictions using Test Time Augmentation.

    For each batch, the model is run once on the original input. It is then
    run ``n_augmentations`` more times on inputs passed through
    ``SeismicAugmentation(flip_prob=1.0, noise_prob=0.0, noise_level=0.0)``.
    Each augmented prediction is flipped back along dim 3, and all
    ``n_augmentations + 1`` predictions are averaged.

    Args:
        model: trained network; switched to eval mode here.
        test_loader: iterable yielding ``(data, file_names)`` batches, where
            ``data`` must be a CPU tensor (``data.numpy()`` is called on it).
        device: device the model lives on; batches are moved there.
        n_augmentations: number of augmented forward passes per batch.

    Returns:
        dict mapping each file name to its averaged prediction (numpy array).

    NOTE(review): with flip_prob=1.0 and no noise the augmentation is
    presumably deterministic. In that case every augmented pass yields the
    same prediction, and the average weights the flipped prediction
    n/(n+1) against 1/(n+1) for the original. Confirm whether a 1:1 average
    (a single flipped pass) was intended.
    NOTE(review): the inverse transform assumes SeismicAugmentation flips the
    input along the axis that maps to the width (dim 3) of the model output.
    Verify against SeismicAugmentation's implementation.
    """
    model.eval()
    predictions = {}
    
    # Create augmentation object
    augmentation = SeismicAugmentation(flip_prob=1.0, noise_prob=0.0, noise_level=0.0)
    
    with torch.no_grad():
        for data, file_names in test_loader:
            # Original prediction
            data_orig = data.to(device)
            outputs_orig = model(data_orig)
            
            # Initialize outputs tensor with original prediction
            tta_outputs = outputs_orig.clone()
            
            # Apply augmentations and accumulate predictions
            for _ in range(n_augmentations):
                # Apply horizontal flip augmentation (per sample, on the CPU copy of the batch)
                data_np = data.numpy()
                augmented_data = []
                
                for i in range(data_np.shape[0]):
                    # Flip horizontally. The sample is passed as both input and
                    # "target"; the second return value is discarded.
                    aug_data, _ = augmentation(data_np[i], data_np[i])
                    augmented_data.append(aug_data)
                
                # Convert back to tensor and predict
                data_aug = torch.tensor(np.array(augmented_data), dtype=torch.float).to(device)
                outputs_aug = model(data_aug)
                
                # Apply inverse augmentation to predictions (flip horizontally)
                outputs_aug_flipped = torch.flip(outputs_aug, dims=[3])  # flip along width dimension
                
                # Accumulate
                tta_outputs += outputs_aug_flipped
            
            # Average predictions
            tta_outputs /= (n_augmentations + 1)  # +1 for original prediction
            
            # Convert tensors to numpy arrays
            outputs_cpu = tta_outputs.cpu().numpy()
            
            # Store predictions by file name
            for i, file_name in enumerate(file_names):
                predictions[file_name] = outputs_cpu[i]
    
    return predictions

In [40]:
# Main function to run the training and evaluation process
def _build_loader(dataset, config, shuffle):
    """Wrap ``dataset`` in a DataLoader with the notebook's shared settings."""
    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=config.get('num_workers', 4),
        pin_memory=True,
    )


def _find_test_files(data_path):
    """Return the ``*.npy`` files under ``data_path`` whose file stem contains 'test'."""
    test_files = [f for f in Path(data_path).rglob('*.npy') if 'test' in f.stem]
    if not test_files:
        # Only file *stems* are matched; a directory named 'test' alone does not count.
        print(f"Warning: no test files found under {data_path}")
    return test_files


def _save_predictions(predictions):
    """Write each prediction to ``prediction_<file_name>.npy`` in the working directory."""
    for file_name, prediction in predictions.items():
        np.save(f'prediction_{file_name}.npy', prediction)


def run_training(config):
    """Train the configured model and write test-set predictions.

    With ``config['use_kfold']`` true, K-fold cross-validation is run and the
    per-fold checkpoints ``model_fold_<k>.pth`` are averaged as an ensemble.
    Otherwise a single shuffled train/validation split is used: the loss
    curves go to ``training_history.png``, and ``best_model.pth`` is used
    with test-time augmentation. Predictions are saved as
    ``prediction_<file_name>.npy``.

    Args:
        config: configuration dict as produced by ``create_config_file``.
    """
    set_seed(config['seed'])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    input_files, output_files = get_train_files(config['data_path'])
    print(f'Found {len(input_files)} input files and {len(output_files)} output files')

    # Weighted MSE + Sobel-gradient loss. The frequency-domain term is kept in
    # the list but currently has weight 0.0.
    criterion = CombinedLoss(
        [nn.MSELoss(), GradientLoss(lambda_grad=0.3), FrequencyDomainLoss(lambda_freq=0.3)],
        [0.8, 0.2, 0.0],
    )

    model_params = config['model']['unet_params']

    if config.get('use_kfold', False):
        # Read n_folds once so the log line and the training/ensembling agree
        # (config['n_folds'] raised KeyError when the key was omitted).
        n_folds = config.get('n_folds', 5)
        print(f"Using {n_folds}-fold cross-validation")
        model_class = globals()[config['model']['name']]
        k_fold_cross_validation(
            model_class=model_class,
            inputs_files=input_files,
            output_files=output_files,
            criterion=criterion,
            config=config,
            n_folds=n_folds,
        )

        # Load every fold's checkpoint for ensembling.
        models = []
        for fold in range(n_folds):
            model = model_class(**model_params)
            checkpoint = torch.load(f'model_fold_{fold+1}.pth', map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.to(device)
            models.append(model)

        test_loader = _build_loader(TestDataset(_find_test_files(config['data_path'])), config, shuffle=False)

        print(f"Generating ensemble predictions from {len(models)} models")
        _save_predictions(ensemble_predict(models, test_loader, device))

    else:
        print("Using simple train/val split")

        # Shuffle inputs and outputs with the same permutation.
        order = list(range(len(input_files)))
        random.shuffle(order)
        input_files = [input_files[i] for i in order]
        output_files = [output_files[i] for i in order]

        # valid_frac / train_frac are divisors: 5 means "1/5 of the files".
        val_size = len(input_files) // config['valid_frac']
        train_size = len(input_files) // config['train_frac'] if config.get('train_frac', 0) > 0 else None
        # A falsy train_size (None, or 0) means "use every remaining file".
        train_end = val_size + train_size if train_size else None

        val_input_files, val_output_files = input_files[:val_size], output_files[:val_size]
        train_input_files = input_files[val_size:train_end]
        train_output_files = output_files[val_size:train_end]

        print(f'Training with {len(train_input_files)} files, validating with {len(val_input_files)} files')

        # Augment training data only
        transform = SeismicAugmentation(flip_prob=0.5, noise_prob=0.3, noise_level=0.05)
        train_dataset = SeismicDataset(train_input_files, train_output_files,
                                       normalize=True, transform=transform)
        val_dataset = SeismicDataset(val_input_files, val_output_files,
                                     normalize=True, transform=None)

        train_loader = _build_loader(train_dataset, config, shuffle=True)
        val_loader = _build_loader(val_dataset, config, shuffle=False)

        model_class = globals()[config['model']['name']]
        model = model_class(**model_params)
        model.to(device)

        print(model)
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f'Total parameters: {total_params:,}')
        print(f'Trainable parameters: {trainable_params:,}')

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['optimizer']['lr'],
            weight_decay=config['optimizer']['weight_decay']
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            **config['scheduler']['params']
        )

        model, history = train_with_mixed_precision(
            model=model,
            train_loader=train_loader,
            valid_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            config=config,
            device=device
        )

        # Training curves
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(history['train_loss'], label='Train Loss')
        ax.plot(history['valid_loss'], label='Validation Loss')
        ax.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')
        ax.legend()
        ax.grid(True)
        fig.savefig('training_history.png')
        plt.show()

        test_loader = _build_loader(TestDataset(_find_test_files(config['data_path'])), config, shuffle=False)

        # Inference uses the best-validation checkpoint written during training.
        best_model = model_class(**model_params)
        checkpoint = torch.load('best_model.pth', map_location=device)
        best_model.load_state_dict(checkpoint['model_state_dict'])
        best_model.to(device)

        print("Generating predictions with Test Time Augmentation")
        _save_predictions(tta_predict(best_model, test_loader, device, n_augmentations=5))

    print("Training and prediction complete!")

In [41]:
# Save config to YAML file
def create_config_file():
    """Build the training configuration, write it to ``config.yaml`` and return it.

    Returns:
        dict: nested configuration consumed by ``run_training``.
    """
    unet_params = dict(
        n_channels=5,
        n_classes=1,
        init_features=64,    # increased from 32
        depth=5,
        bilinear=True,
        use_attention=True,  # attention-gated skip connections
        use_se=True,         # squeeze-and-excitation blocks
    )
    optimizer_params = dict(lr=1e-4, weight_decay=1e-3)
    scheduler_params = dict(
        factor=0.5,   # halve the LR on a plateau
        patience=2,   # epochs without improvement before reducing the LR
        min_lr=1e-6,  # lower bound on the LR
    )

    config = dict(
        data_path='/kaggle/input/waveform-inversion',
        model=dict(name='UNet', unet_params=unet_params),
        read_weights=None,
        batch_size=32,   # reduced from 64 for the larger model
        print_freq=100,
        max_epochs=30,   # increased from 20
        es_epochs=5,     # early-stopping patience
        seed=42,
        valid_frac=5,    # 1/5 of the files go to validation
        train_frac=0,    # 0 -> use all remaining files for training
        use_kfold=True,  # K-fold cross-validation
        n_folds=5,
        optimizer=optimizer_params,
        scheduler=dict(params=scheduler_params),
    )

    with open('config.yaml', 'w') as config_file:
        yaml.dump(config, config_file, default_flow_style=False)

    print("Configuration file 'config.yaml' created.")
    return config

# Run Training

In [None]:
# Build the configuration and launch training
if __name__ == "__main__":
    # create_config_file() already serialises the config to config.yaml,
    # so a second yaml.dump here would just rewrite the same file.
    config = create_config_file()

    # Run training process
    run_training(config)

Configuration file 'config.yaml' created.
Using device: cuda


# Inference and Submission

In [None]:
# Inference and submission
def create_submission(config, model_path='best_model.pth'):
    """Run inference on the test set and write ``submission.csv``.

    If ``config['use_kfold']`` is truthy, the checkpoints ``model_fold_1.pth`` …
    ``model_fold_{n_folds}.pth`` are loaded and their outputs averaged;
    otherwise the single checkpoint at ``model_path`` is used.

    Every test sample produces 70 CSV rows (one per ``y`` index, keyed
    ``{oid}_y_{y}``), each holding columns ``x_1, x_3, …, x_69`` of output
    channel 0.

    Args:
        config: Configuration dict. Reads ``model.name``, ``model.unet_params``
            and ``batch_size``; optionally ``test_path`` (default
            ``'input/test'``), ``use_kfold`` (default False) and ``n_folds``
            (default 5).
        model_path: Checkpoint used when k-fold ensembling is disabled.

    Raises:
        ValueError: If k-fold ensembling is enabled and ``n_folds`` < 1.
    """
    t0 = time.time()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Get test files
    test_dir = Path(config.get('test_path', 'input/test'))
    test_files = list(test_dir.glob("*.npy"))
    print(f'Found {len(test_files)} test files')

    # The submission keeps only the odd x positions (x_1, x_3, ..., x_69).
    x_positions = range(1, 70, 2)
    x_cols = [f"x_{i}" for i in x_positions]
    fieldnames = ["oid_ypos"] + x_cols

    model_class = globals()[config['model']['name']]
    model_params = config['model']['unet_params']

    use_ensemble = config.get('use_kfold', False)
    if use_ensemble:
        print("Using ensemble model for predictions")
        n_folds = config.get('n_folds', 5)
        if n_folds < 1:
            raise ValueError(f"n_folds must be >= 1 for ensemble inference, got {n_folds}")
        checkpoint_paths = [f'model_fold_{fold + 1}.pth' for fold in range(n_folds)]
    else:
        print("Using single best model for predictions")
        checkpoint_paths = [model_path]

    models = [
        _load_eval_model(model_class, model_params, ckpt_path, device)
        for ckpt_path in checkpoint_paths
    ]

    # Create test dataset
    test_dataset = TestDataset(test_files, normalize=True)
    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'] * 2,  # Doubled batch size for inference
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    if use_ensemble:
        print(f"Generating ensemble predictions from {len(models)} models")

    with open("submission.csv", "wt", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for inputs, oids_test in test_loader:
            y_preds = _predict_mean(models, inputs, device)
            _write_prediction_rows(writer, y_preds, oids_test, x_cols, x_positions)

    t1 = time.time() - t0
    print(f"Inference Time: {t1:.2f} seconds")
    print("Submission file created: submission.csv")


def _load_eval_model(model_class, model_params, checkpoint_path, device):
    """Instantiate ``model_class(**model_params)``, load ``checkpoint_path`` and return it in eval mode on ``device``."""
    model = model_class(**model_params)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model


def _predict_mean(models, inputs, device):
    """Return channel 0 of the mean output of ``models`` on ``inputs`` as a float32 NumPy array.

    fp16 autocast is used only on CUDA. Without CUDA, a CUDA autocast context
    only warns and disables itself. Each model's output is cast to float32
    before summing: accumulating several fp16 outputs in place loses precision
    as the running sum grows.
    """
    import contextlib

    inputs = inputs.to(device, non_blocking=True)
    amp_context = (
        torch.autocast(device_type="cuda", dtype=torch.float16)
        if device.type == "cuda"
        else contextlib.nullcontext()
    )
    total = None
    with torch.no_grad(), amp_context:
        for model in models:
            outputs = model(inputs).float()
            total = outputs if total is None else total + outputs
    mean_outputs = total / len(models)
    return mean_outputs[:, 0].cpu().numpy()


def _write_prediction_rows(writer, y_preds, oids, x_cols, x_positions, n_rows=70):
    """Write ``n_rows`` CSV rows per sample.

    Each row is keyed ``{oid}_y_{y}`` and holds the values at ``x_positions``
    of row ``y``. The default of 70 rows presumably matches the height of the
    velocity grid; confirm against the model output.
    """
    for y_pred, oid in zip(y_preds, oids):
        writer.writerows(
            {
                "oid_ypos": f"{oid}_y_{y_pos}",
                **{col: y_pred[y_pos, x_pos] for col, x_pos in zip(x_cols, x_positions)},
            }
            for y_pos in range(n_rows)
        )

# Load the saved configuration and produce the submission file
if __name__ == "__main__":
    # config.yaml is written by the configuration/training cells above.
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    # Fall back to the Kaggle test directory when the config does not name one.
    config.setdefault('test_path', '/kaggle/input/waveform-inversion/test')

    create_submission(config)