In [3]:
import numpy as np
import pandas as pd
import os
import yaml
import time
import random
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torchsummary import summary

In [4]:
# Set random seeds for reproducibility
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator, PyTorch's
    default generator and the CUDA generator, then switches cuDNN to
    deterministic mode with autotuning (benchmark) disabled.

    Args:
        seed: Integer seed handed unchanged to each library.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade some speed for run-to-run repeatability of cuDNN kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
# Helper functions for data loading and preparation
def map_seismic_to_velocity_path(input_file):
    """Derive the velocity-model path that pairs with a seismic data path.

    The mapping is purely textual and is applied to the *entire* path string,
    directory components included: every ``'seis'`` becomes ``'vel'`` and then
    every ``'data'`` becomes ``'model'``.

    Args:
        input_file: Path (or string) of the seismic input file.

    Returns:
        A ``Path`` built from the rewritten string.
    """
    velocity_path = str(input_file)
    # Order matters only in that both substitutions run over the same string.
    for seismic_token, velocity_token in (('seis', 'vel'), ('data', 'model')):
        velocity_path = velocity_path.replace(seismic_token, velocity_token)
    return Path(velocity_path)

def get_train_files(data_path):
    """Collect seismic input files under ``data_path`` and their velocity targets.

    Every ``.npy`` file found recursively whose stem contains ``'seis'`` or
    ``'data'`` is treated as an input; its target path is obtained with
    ``map_seismic_to_velocity_path``.

    The inputs are returned in sorted order. ``Path.rglob`` yields entries in
    whatever order the filesystem reports them, so without sorting the file
    lists (and any seeded split built on them) would differ between machines
    and runs.

    Args:
        data_path: Root directory to search.

    Returns:
        Tuple ``(input_files, output_files)`` of equally long, index-aligned
        lists of ``Path`` objects.

    Raises:
        FileNotFoundError: If any mapped velocity file does not exist; the
            message lists up to five of the missing paths.
    """
    # Find all seismic data files (containing 'seis' or 'data' in filename),
    # sorted so the result is deterministic.
    input_files = sorted(
        f for f in Path(data_path).rglob('*.npy')
        if ('seis' in f.stem) or ('data' in f.stem)
    )

    # Map each input file to its corresponding output file
    output_files = [map_seismic_to_velocity_path(f) for f in input_files]

    # Verify all output files exist
    missing_files = [f for f in output_files if not f.exists()]
    if missing_files:
        raise FileNotFoundError(f"Missing velocity model files: {missing_files[:5]}...")

    return input_files, output_files

In [None]:
# Improved dataset classes with normalization
class SeismicDataset(Dataset):
    """Map-style dataset pairing seismic inputs with velocity-model targets.

    Samples live in ``.npy`` files and are indexed along each file's first
    axis. Item ``idx`` is sample ``idx % n_examples_per_file`` of file
    ``idx // n_examples_per_file``. Files are opened memory-mapped on every
    access and only the requested sample is copied.

    Args:
        inputs_files: Sequence of paths to the input ``.npy`` files.
        output_files: Sequence of paths to the target ``.npy`` files, aligned
            one-to-one with ``inputs_files``.
        n_examples_per_file: Number of samples addressed in every file. This
            is not validated against the files; a file holding fewer samples
            raises ``IndexError`` when an out-of-range sample is requested.
        normalize: When true, global mean/std statistics are estimated at
            construction time and applied to every sample.
        transform: Optional callable ``(X, y) -> (X, y)`` applied to the NumPy
            arrays after normalization and before tensor conversion.

    Attributes:
        input_stats: ``(mean, std)`` of the inputs as Python floats, or
            ``None`` when ``normalize`` is false.
        output_stats: ``(mean, std)`` of the targets as Python floats, or
            ``None`` when ``normalize`` is false.
    """

    # Upper bound on how many files are read to estimate the statistics.
    _STATS_SAMPLE_FILES = 10

    def __init__(self, inputs_files, output_files, n_examples_per_file=500, normalize=True, transform=None):
        assert len(inputs_files) == len(output_files), \
            "inputs_files and output_files must have the same length"
        self.inputs_files = inputs_files
        self.output_files = output_files
        self.n_examples_per_file = n_examples_per_file
        self.normalize = normalize
        self.transform = transform

        # Always define the attributes so callers can read them (e.g. to pass
        # input_stats on to a test dataset) even when normalization is off.
        self.input_stats = None
        self.output_stats = None
        if normalize:
            self.input_stats, self.output_stats = self._calculate_stats()

    @staticmethod
    def _pool_mean_std(moments):
        """Merge per-file ``(count, mean, variance)`` triples into one (mean, std).

        Uses the law of total variance, so the result equals the mean and
        population std of all values concatenated. A zero (or undefined) std
        is replaced by 1.0 so that normalization never divides by ~0.
        """
        if not moments:
            return 0.0, 1.0
        counts, means, variances = (
            np.asarray(column, dtype=np.float64) for column in zip(*moments)
        )
        total = counts.sum()
        if total == 0:
            return 0.0, 1.0
        mean = float((counts * means).sum() / total)
        # within-file variance + spread of the file means around the global mean
        variance = float((counts * (variances + (means - mean) ** 2)).sum() / total)
        std = float(np.sqrt(variance))
        return mean, (std if std > 0 else 1.0)

    def _calculate_stats(self):
        """Estimate global mean and std of inputs and targets.

        Only the first ``_STATS_SAMPLE_FILES`` files (or fewer, if the dataset
        is smaller) are read.

        Returns:
            ``((input_mean, input_std), (output_mean, output_std))``.
        """
        sample_size = min(self._STATS_SAMPLE_FILES, len(self.inputs_files))

        print(f"Calculating statistics from {sample_size} files...")
        input_moments, output_moments = [], []
        for i in range(sample_size):
            X = np.load(self.inputs_files[i])
            y = np.load(self.output_files[i])

            # Accumulate the mean in float64; it is a streaming reduction and
            # needs no extra full-size temporary.
            input_moments.append((X.size, np.mean(X, dtype=np.float64), np.var(X)))
            output_moments.append((y.size, np.mean(y, dtype=np.float64), np.var(y)))

        input_mean, input_std = self._pool_mean_std(input_moments)
        output_mean, output_std = self._pool_mean_std(output_moments)

        print(f"Input stats - Mean: {input_mean:.4f}, Std: {input_std:.4f}")
        print(f"Output stats - Mean: {output_mean:.4f}, Std: {output_std:.4f}")

        return (input_mean, input_std), (output_mean, output_std)

    def __len__(self):
        """Total number of addressable samples across all files."""
        return len(self.inputs_files) * self.n_examples_per_file

    def __getitem__(self, idx):
        """Return sample ``idx`` as a pair of float32 tensors ``(X, y)``."""
        # Calculate file offset and sample offset within file
        file_idx = idx // self.n_examples_per_file
        sample_idx = idx % self.n_examples_per_file

        # Memory-map so that only the requested sample is read from disk.
        X = np.load(self.inputs_files[file_idx], mmap_mode='r')
        y = np.load(self.output_files[file_idx], mmap_mode='r')
        X_sample, y_sample = X[sample_idx].copy(), y[sample_idx].copy()

        # Apply normalization if enabled
        if self.normalize:
            X_sample = (X_sample - self.input_stats[0]) / (self.input_stats[1] + 1e-8)
            y_sample = (y_sample - self.output_stats[0]) / (self.output_stats[1] + 1e-8)

        # Apply any additional transforms
        if self.transform:
            X_sample, y_sample = self.transform(X_sample, y_sample)

        # Ensure contiguous memory layout: torch.from_numpy rejects the
        # negative strides that flipping transforms can produce.
        X_sample = np.ascontiguousarray(X_sample)
        y_sample = np.ascontiguousarray(y_sample)

        # Convert to torch tensors
        return torch.from_numpy(X_sample).float(), torch.from_numpy(y_sample).float()

class TestDataset(Dataset):
    """Dataset yielding one whole ``.npy`` test file per item.

    Args:
        test_files: Sequence of test file paths (``Path`` objects or strings).
        normalize: Whether to normalize the loaded array.
        input_stats: ``(mean, std)`` used for normalization, typically the
            statistics estimated on the training inputs. When ``None``, the
            array is returned unnormalized even if ``normalize`` is true.
    """

    def __init__(self, test_files, normalize=True, input_stats=None):
        self.test_files = test_files
        self.normalize = normalize
        self.input_stats = input_stats

    def __len__(self):
        """Number of test files."""
        return len(self.test_files)

    def __getitem__(self, i):
        """Return ``(tensor, stem)`` for the ``i``-th file.

        ``tensor`` is the file content as float32; ``stem`` is the file name
        without its extension.
        """
        # Accept plain strings as well as Path objects.
        test_file = Path(self.test_files[i])
        X = np.load(test_file)

        # Compare against None explicitly: truthiness is ambiguous (and
        # raises) when the stats are supplied as NumPy arrays.
        if self.normalize and self.input_stats is not None:
            X = (X - self.input_stats[0]) / (self.input_stats[1] + 1e-8)

        X = torch.from_numpy(X).float()
        return X, test_file.stem

In [None]:
# Data augmentation class with fixed padding operation
class SeismicAugmentation:
    """Random flip / noise / shift augmentation for an ``(X, y)`` NumPy pair.

    Each augmentation is applied independently with its own probability,
    drawing from NumPy's global random generator.

    Args:
        flip_prob: Probability of reversing both ``X`` and ``y`` along axis 1.
        noise_prob: Probability of adding zero-mean Gaussian noise to ``X``.
        noise_level: Standard deviation of that noise.
        shift_prob: Probability of drawing a shift for ``X`` along axis 1.
        max_shift: Largest shift magnitude; the offset is drawn uniformly from
            ``[-max_shift, max_shift]`` (0 leaves ``X`` unshifted).
    """

    def __init__(self,
                 flip_prob=0.5,
                 noise_prob=0.3,
                 noise_level=0.05,
                 shift_prob=0.3,
                 max_shift=5):
        self.flip_prob, self.noise_prob = flip_prob, noise_prob
        self.noise_level = noise_level
        self.shift_prob, self.max_shift = shift_prob, max_shift

    @staticmethod
    def _shift_along_axis1(arr, offset):
        """Shift ``arr`` by ``offset`` positions along axis 1, zero-filling the gap.

        A positive offset moves content towards higher indices, a negative one
        towards lower indices; the shape is unchanged. ``offset == 0`` returns
        ``arr`` as is.
        """
        if offset == 0:
            return arr

        # No padding anywhere except on axis 1.
        padding = [(0, 0) for _ in range(arr.ndim)]
        if offset > 0:
            # Zeros in front, then drop the same number of trailing entries.
            padding[1] = (offset, 0)
            return np.pad(arr, padding, mode='constant')[:, :-offset]

        # Negative offset: zeros at the back, then drop the leading entries.
        padding[1] = (0, -offset)
        return np.pad(arr, padding, mode='constant')[:, -offset:]

    def __call__(self, X, y):
        """Return an augmented copy of ``(X, y)``; the inputs are not modified.

        Only the flip touches ``y``; noise and shift affect ``X`` alone.
        """
        # Work on private copies so the caller's arrays stay intact.
        X, y = X.copy(), y.copy()

        # Flip: both arrays are reversed along axis 1 and re-copied so the
        # result has no negative strides.
        # NOTE(review): whether axis 1 is the horizontal axis of both arrays
        # depends on their layout, which is not visible here - confirm.
        if np.random.random() < self.flip_prob:
            X = np.flip(X, axis=1).copy()
            y = np.flip(y, axis=1).copy()

        # Additive Gaussian noise on the input only.
        if np.random.random() < self.noise_prob:
            X = X + np.random.normal(0, self.noise_level, X.shape)

        # Random shift of the input along axis 1.
        # NOTE(review): presumably axis 1 is the time axis - confirm.
        if np.random.random() < self.shift_prob:
            offset = np.random.randint(-self.max_shift, self.max_shift + 1)
            X = self._shift_along_axis1(X, offset)

        return X, y