In [1]:
# Add necessary imports at the top
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import timm
from timm.models.convnext import ConvNeXtBlock
from types import MethodType
import numpy as np
import random
import os
import time, glob
from tqdm import tqdm
from types import SimpleNamespace
import sys # Added for stderr
from copy import deepcopy

RUN_TRAIN = True # bfloat16 or float32 recommended
RUN_VALID = False
RUN_TEST  = False
USE_DEVICE = 'GPU' #'CPU'  # 'GPU'

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# --- Configuration ---
# Original cfg setup - replace or update as needed
# For this code, we add transformer specific config here
cfg= SimpleNamespace()
cfg.device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
# cfg.local_rank = 0 # Assume single GPU/CPU for simplicity unless distributed training is set up
cfg.seed = 123
cfg.subsample = 100 #None # Set to None to use all available samples

# Assuming file paths are set up correctly
data_paths_str = ".\\datasetfiles\\FlatVel_A\\data\\*.npy"
label_paths_str = ".\\datasetfiles\\FlatVel_A\\model\\*.npy"

# Get all file pairs
# cfg.file_pairs = list(zip(sorted(glob.glob(data_paths_str)), sorted(glob.glob(label_paths_str))))
# Split file pairs for train/validation
data_paths = sorted(glob.glob(data_paths_str))
label_paths = sorted(glob.glob(label_paths_str))
all_file_pairs = list(zip(data_paths, label_paths))
# Simple split (e.g., 80% train, 20% validation)
split_ratio = 0.8
split_idx = int(len(all_file_pairs) * split_ratio)
train_file_pairs = all_file_pairs[:split_idx]
valid_file_pairs = all_file_pairs[split_idx:]


cfg.backbone = "convnext_small.fb_in22k_ft_in1k"
cfg.ema = True
cfg.ema_decay = 0.99

cfg.epochs = 4
cfg.batch_size = 8
cfg.batch_size_val = 8

cfg.early_stopping = {"patience": 3, "streak": 0}
cfg.logging_steps = 10
# --- New Transformer/Dataset related config ---
cfg.num_input_slices = 5 # Number of consecutive input samples (time slices) to stack as channels
# Inferred input height (H_in) for a single sample based on original stem logic transforming T=1000
# This is highly speculative and assumes the original dataset or stem implicitly maps 1000 -> 352
cfg.inferred_input_height = 352
cfg.input_width = 70 # Original waveform width and target output width

In [71]:

class CustomDatasetWithSlices(torch.utils.data.Dataset):
    def __init__(
        self,
        cfg,
        file_pairs,  # list of (data_path, label_path) tuples for this specific split
        mode = "train",
        num_input_slices: int = 3, # Number of consecutive samples to stack
        # inferred_input_height: int = 352, # Assumed H_in from original data processing
        # input_width: int = 70, # W_in from original data processing
        # output_height: int = 70, # H_out for labels
        # output_width: int = 70,  # W_out for labels
    ):
        self.cfg = cfg
        self.mode = mode
        print("[init]-self mode", self.mode)
        self.file_pairs = file_pairs
        print("[init]-self file pairs", self.file_pairs)
        self.num_input_slices = num_input_slices
        print("[init]-self num_input_slices", self.num_input_slices)
        # self.inferred_input_height = inferred_input_height # Use from cfg
        # self.input_width = input_width # Use from cfg
        # self.output_height = output_height # Use from cfg
        # self.output_width = output_width # Use from cfg

        # Load data.
        # ASSUMPTION: Each file pair (data.npy, model.npy) loads arrays
        #   - data.npy contains (N_samples_per_file, 5, H_in, 70)
        #   - model.npy contains (N_samples_per_file, 70, 70)
        # Where N_samples_per_file is the number of pre-processed samples in that file,
        # 5 is the fixed channel count, H_in is the fixed processed height (inferred from original stem logic),
        # and 70 is the width.
        self.data_arrays, self.label_arrays = self._load_data_arrays()
        print(f"[init]-self data_arrays {self.data_arrays[0][0][0][0][:3]} shape {self.data_arrays[0].shape} | label_arrays: {self.label_arrays[0][0][0][0][:3]}  shape {self.label_arrays[0].shape} ")

        # Validate and store shapes based on loaded data
        if not self.data_arrays:
             raise RuntimeError(f"[init]-No data files loaded for mode '{self.mode}'. Check file paths and format.")

        # Use the shape of the first loaded array to define dataset dimensions
        first_data_shape = self.data_arrays[0].shape
        first_label_shape = self.label_arrays[0].shape

        self.samples_per_file_ = first_data_shape[0] # Number of samples per file
        self.channels_in_single = first_data_shape[1] # Should be 5
        self.H_in = self.data_arrays[0].shape[2]        # Assumed fixed input height (matches cfg.inferred_input_height)
        self.W_in = self.data_arrays[0].shape[3]        # Should be 70 (matches cfg.input_width)
        self.H_out = self.label_arrays[0].shape[2]      # Should be 70 (matches cfg.output_height)
        self.W_out = self.label_arrays[0].shape[3]      # Should be 70 (matches cfg.output_width)

        # Validate against cfg expectations (optional but good)
        if self.channels_in_single != 5:
             print(f"[init]-Warning: Loaded data has {self.channels_in_single} channels, expected 5.", file=sys.stderr)
        # Note: Cannot strictly validate H_in against cfg.inferred_input_height here
        # as the cfg value was just an inference based on the original stem.
        # The loaded data's shape dictates the true H_in for the model.
        # self.cfg.inferred_input_height = self.H_in # Update cfg with actual loaded height
        if self.W_in != 70:
            print(f"[init]-Warning: Loaded data has width {self.W_in}, expected 70.", file=sys.stderr)
            self.cfg.input_width = self.W_in # Update cfg
        if self.H_out != 70 or self.W_out != 70:
            print(f"[init]-Warning: Loaded labels have shape ({self.H_out}, {self.W_out}), expected (70, 70).", file=sys.stderr)
            # Update cfg with actual dimensions if needed downstream
            self.cfg.output_height = self.H_out
            self.cfg.output_width = self.W_out
        
        # Update cfg with actual loaded dimensions for model compatibility
        self.cfg.inferred_input_height = self.H_in
        print(f"[init]-self.cfg.inferred_input_height {self.cfg.inferred_input_height} ")
        self.cfg.input_width = self.W_in # Matches W_in
        print(f"[init]-self.cfg.input_width {self.cfg.input_width} ")

        total_files = len(self.data_arrays)
        print(f"[init]-total_files {total_files} ")
        # Total number of *original* samples across all successfully loaded files
        total_samples_available = total_files * self.samples_per_file_

        subsample = getattr(self.cfg, "subsample", None)
        print(f"[init]-subsample {subsample} ")
        # Determine the total number of *effective* samples based on subsampling,
        # but the actual count in index_map might be slightly less due to padding
        # requirements at file boundaries.
        effective_subsample_limit = subsample if subsample and subsample > 0 else float('inf')
        print(f"[init]-effective_subsample_limit {effective_subsample_limit} ")

        # Build list of (file_idx, sample_center_idx) pairs
        # We select num_input_slices consecutive samples centered at sample_center_idx.
        pad = (self.num_input_slices - 1) // 2
        print(f"[init]-pad {pad} ")

        self.index_map = []
        current_effective_samples = 0
        
        for file_idx in range(total_files):
            print(f"[init]-file_idx {file_idx} ")
            N_samples_in_file = self.data_arrays[file_idx].shape[0] # Number of samples in this file's array

            # Iterate through possible center sample indices such that the window
            # [center - pad, center + pad] is entirely within [0, N_samples_in_file - 1].
            valid_start_idx = pad
            valid_end_idx = N_samples_in_file - pad - 1 # Inclusive end index for center

            # Check if there are enough samples in the file to form *at least one* window
            if valid_end_idx < valid_start_idx:
                print(f"[init]-Warning: File {file_idx} (with {N_samples_in_file} samples) is too short for window size {self.num_input_slices}. Skipping file for effective samples.", file=sys.stderr)
                continue # Skip this file if not enough samples for any window

            for sample_center_idx in range(valid_start_idx, valid_end_idx + 1):
                 self.index_map.append((file_idx, sample_center_idx))
                 current_effective_samples += 1
                 # Stop if subsample limit is reached for effective samples
                 if current_effective_samples >= effective_subsample_limit:
                     break
            # Stop if subsample limit is reached for effective samples
            if current_effective_samples >= effective_subsample_limit:
                 break

        self.total_effective_samples = len(self.index_map)

        print(f"[init]-Dataset initialized in {self.mode} mode.")
        print(f"[init]-Loaded {total_files} file pairs containing a total of {total_samples_available} raw samples.")
        print(f"[init]-Input shape per single slice: ({self.channels_in_single}, {self.H_in}, {self.W_in})")
        print(f"[init]-Output label shape: ({self.H_out}, {self.W_out})")
        print(f"[init]-Window size for stacking: {self.num_input_slices} slices (padding {pad} on each side).")
        print(f"[init]-Generated {self.total_effective_samples} effective samples for training/validation after considering windowing and subsampling.")


    def _load_data_arrays(self):
        """
        Loads data and label arrays from file pairs using mmap_mode for efficiency.
        Includes validation for expected shapes.
        """
        data_arrays_list = []
        label_arrays_list = []
        # Use 'r' mode always for memory efficiency with large datasets
        mmap_mode = "r"

        print(f"[_load_data_arrays] - Loading {self.mode} data using mmap_mode='{mmap_mode}'...")

        # Use local_rank to ensure tqdm is only shown on the main process in DDP
        disable_tqdm = getattr(self.cfg, 'local_rank', 0) != 0

        successful_loads = 0
        for data_fpath, label_fpath in tqdm(
                        self.file_pairs, desc=f"[_load_data_arrays / for loop] - Loading {self.mode} data (mmap)",
                        disable=disable_tqdm):
            print(f"[_load_data_arrays / for data_fpath] - data_fpath {data_fpath} | label_fpath={label_fpath} ")
            try:
                # Check if files exist before attempting to load
                if not os.path.exists(data_fpath):
                    print(f"[_load_data_arrays / try] - Warning: Data file not found: {data_fpath}. Skipping pair.", file=sys.stderr)
                    continue
                if not os.path.exists(label_fpath):
                    print(f"[_load_data_arrays / try] - Warning: Label file not found: {label_fpath}. Skipping pair.", file=sys.stderr)
                    continue

                # Load data with expected shape (N_samples, Channels, H_in, W_in)
                # For your data: (N, 5, 1000, 70)
                arr = np.load(data_fpath, mmap_mode=mmap_mode)
                print(f"[_load_data_arrays / try] - arr {arr[0][0][0][:10]} | arr shape {arr.shape}")
                # Load labels with expected shape (N_samples, H_out, W_out)
                # For your data: (N, 70, 70)
                lbl = np.load(label_fpath, mmap_mode=mmap_mode)
                print(f"[_load_data_arrays / try] - lbl {lbl[0][0][0][:10]} | lbl shape {lbl.shape} | lbl ndim {lbl.ndim}")

                # --- Basic shape validation based on YOUR specified structure ---
                expected_data_ndim = 4
                expected_label_ndim = 4
                expected_channels = 5 # Your specific channel count
                expected_data_width = 70 # Your specific GeoPhones dimension
                expected_label_height = 70 # Your specific output height
                expected_label_width = 70  # Your specific output width
                print(f"[_load_data_arrays / expected_label_width] - expected_label_width {expected_label_width} ")

                if arr.ndim != expected_data_ndim or \
                   arr.shape[1] != expected_channels or \
                   arr.shape[3] != expected_data_width:
                     print("[_load_data_arrays / if 1] ")
                     print(f"[_load_data_arrays / try / if] - Warning: Data file {data_fpath} has unexpected shape {arr.shape}. "
                           f"[_load_data_arrays / try / if] - Expected ndim={expected_data_ndim}, shape[1]={expected_channels} (channels), "
                           f"[_load_data_arrays / try / if] - shape[3]={expected_data_width} (width/geophones). Skipping.", file=sys.stderr)
                     continue

                if lbl.ndim != expected_label_ndim or \
                   lbl.shape[2] != expected_label_height or \
                   lbl.shape[3] != expected_label_width:
                     print("[_load_data_arrays / if 2] ")
                     print(f"[_load_data_arrays / try / if2] - Warning: Label file {label_fpath} has unexpected shape {lbl.shape}. "
                           f"[_load_data_arrays / try / if2] - Expected ndim={expected_label_ndim}, shape[1]={expected_label_height} (height), "
                           f"[_load_data_arrays / try / if2] - shape[2]={expected_label_width} (width). Skipping.", file=sys.stderr)
                     continue

                # Validate that the number of samples (batch dimension) matches
                if arr.shape[0] != lbl.shape[0]:
                     print("[_load_data_arrays / if 3] ")
                     print(f"[_load_data_arrays / try / if3] - Warning: Mismatch in number of samples (batch size) between data ({arr.shape[0]}) and label ({lbl.shape[0]}) "
                           f"[_load_data_arrays / try / if3] - in file pair {data_fpath}, {label_fpath}. Skipping.", file=sys.stderr)
                     continue

                # If it passes validation, add to lists
                print("[_load_data_arrays / try / before --- data_arrays_list] ")
                data_arrays_list.append(arr)
                print(f"[_load_data_arrays / try / data_arrays_list] - data_arrays_list {data_arrays_list[0][0][0][0][:3]} | data shape {arr.shape}")
                label_arrays_list.append(lbl)
                print(f"[_load_data_arrays / try / label_arrays_list] - label_arrays_list {label_arrays_list[0][0][0][0][:3]} | label shape {lbl.shape} ")
                successful_loads += 1
                print(f"[_load_data_arrays / try / successful_loads] - successful_loads {successful_loads} ")

            except FileNotFoundError:
                # This check is now redundant with the os.path.exists check above,
                # but keeping it doesn't hurt as a fallback.
                print(f"[_load_data_arrays / except] - Error: File not found - {data_fpath} or {label_fpath}. Skipping pair.", file=sys.stderr)
            except Exception as e:
                print(f"[_load_data_arrays / except Exception] - Error loading or validating file pair: {data_fpath}, {label_fpath}", file=sys.stderr)
                print(f"[_load_data_arrays / except Exception] - Error details: {e}", file=sys.stderr)
                # traceback.print_exc() # Uncomment for detailed error
                continue

        # if self.cfg.local_rank == 0: # Only print summary from main process
        print(f"[_load_data_arrays / try / if - end] -Finished loading {successful_loads} out of {len(self.file_pairs)} file pairs successfully for {self.mode} mode.")

        return data_arrays_list, label_arrays_list

    def __len__(self):
        """
        Returns the total number of effective samples available in the dataset.
        """
        print("[__len__]")
        return self.total_effective_samples

    def __getitem__(self, index):
        """
        Retrieves a single effective sample (input and label) based on the index.
        An effective sample consists of num_input_slices data slices stacked,
        and the corresponding label for the center slice.
        """
        print("[__getitem__]")
        if index < 0 or index >= self.total_effective_samples:
            raise IndexError(f"[__getitem__] - Index {index} out of bounds for dataset of size {self.total_effective_samples}")

        # Get the file index and the center sample index within that file
        file_idx, sample_center_idx = self.index_map[index]
        print(f"[__getitem__] - file_idx: {file_idx} | sample_center_idx: {sample_center_idx}")

        # Calculate the start and end indices for the window of slices
        pad = (self.num_input_slices - 1) // 2
        start_idx = sample_center_idx - pad
        end_idx = sample_center_idx + pad # This is inclusive

        # Retrieve the batch of consecutive data slices
        # The shape will be (num_input_slices, Channels, H_in, W_in)
        data_slices = self.data_arrays[file_idx][start_idx : end_idx + 1, ...]

        # Retrieve the label for the *center* slice
        # The shape will be (H_out, W_out)
        label_slice = self.label_arrays[file_idx][sample_center_idx, ...]

        # --- Stack the data slices ---
        # The original shape is (num_input_slices, Channels, H_in, W_in)
        # We want to combine the 'num_input_slices' and 'Channels' dimensions
        # into a single channel dimension, resulting in (num_input_slices * Channels, H_in, W_in).
        # This is a common way to represent stacked time series data as input channels for CNNs.
        combined_channels = self.num_input_slices * self.channels_in_single
        input_tensor = data_slices.reshape(combined_channels, self.H_in, self.W_in)

        # Convert numpy arrays to PyTorch tensors
        # Ensure correct data types (float32 is common for model inputs/outputs)
        input_tensor = torch.from_numpy(input_tensor).float()
        label_tensor = torch.from_numpy(label_slice).float() # Assuming regression output

        return input_tensor, label_tensor

In [72]:

train_ds = CustomDatasetWithSlices(cfg=cfg, file_pairs=train_file_pairs, mode="train", num_input_slices=cfg.num_input_slices)
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size= cfg.batch_size,
    num_workers= 4 if cfg.device.type == 'cuda' else 0, # Use more workers on GPU
    shuffle=True,
    pin_memory=cfg.device.type == 'cuda', # Pin memory for faster GPU transfer
    drop_last=True, # Drop last batch if batch size doesn't divide dataset size
)

[init]-self mode train
[init]-self file pairs [('.\\datasetfiles\\FlatVel_A\\data\\data1.npy', '.\\datasetfiles\\FlatVel_A\\model\\model1.npy')]
[init]-self num_input_slices 5
[_load_data_arrays] - Loading train data using mmap_mode='r'...


[_load_data_arrays / for loop] - Loading train data (mmap): 100%|██████████| 1/1 [00:00<00:00, 667.35it/s]

[_load_data_arrays / for data_fpath] - data_fpath .\datasetfiles\FlatVel_A\data\data1.npy | label_fpath=.\datasetfiles\FlatVel_A\model\model1.npy 
[_load_data_arrays / try] - arr [-0.00038193  0.          0.          0.          0.          0.
  0.          0.          0.          0.        ] | arr shape (500, 5, 1000, 70)
[_load_data_arrays / try] - lbl [1524. 1524. 1524. 1524. 1524. 1524. 1524. 1524. 1524. 1524.] | lbl shape (500, 1, 70, 70) | lbl ndim 4
[_load_data_arrays / expected_label_width] - expected_label_width 70 
[_load_data_arrays / try / before --- data_arrays_list] 
[_load_data_arrays / try / data_arrays_list] - data_arrays_list [-0.00038193  0.          0.        ] | data shape (500, 5, 1000, 70)
[_load_data_arrays / try / label_arrays_list] - label_arrays_list [1524. 1524. 1524.] | label shape (500, 1, 70, 70) 
[_load_data_arrays / try / successful_loads] - successful_loads 1 
[_load_data_arrays / try / if - end] -Finished loading 1 out of 1 file pairs successfully for


