In [None]:
""" 

はい、元のConvNeXtベースのモデル構造を理解し、時系列情報とジオフォン位置のずれ情報を取り込むためにTransformer要素を導入するコードを生成します。

現在のコードを分析した結果、いくつか不明瞭な点がありますが、最も可能性の高いデータフローとモデル構造を推測し、それに沿ってTransformerを組み込みます。特に、`CustomDataset`でのデータローディングと`Net`クラスの`_update_stem`の処理に関する推測が含まれるため、実際のデータ形式によっては調整が必要になる可能性があります。

**推測されるデータフローとモデル構造:**

1.  **Raw Data:** `(T=1000, W=70)`または `(C=5, T=1000, W=70)` のwaveformデータと `(H=70, W=70)` のラベルデータ。
2.  **`CustomDataset`:** 各ファイルから`N_samples_per_file`個のサンプルを読み込みます。元のコードの形状プリント`(5, 1000, 70), (70, 70)`と、
    `in_chans=5`、出力`(70, 70)`から、各サンプルは恐らくロード後に`(5, T=1000, W=70)`と`(H=70, W=70)`のペアとして扱われていると考えられます。
    しかし、`__getitem__`では`time_step_idx`でスライスされており、`(5, 70)`のような形状になるはずですが、ConvNeXtは`(C, H, W)`を入力とします。ここが不明瞭です。
3.  **最も可能性の高い解釈:** データセットは`(N_samples_per_file, 5, H_in, 70)`の形状のデータと`(N_samples_per_file, 70, 70)`のラベルをロードしている。
     ここで`H_in`は元の`T=1000`から何らかの方法で変換された固定サイズ（例: 352）。`N_samples_per_file`は、元の`T=1000`から切り出せる独立したサンプル数
     （例: 500）。`__getitem__`は、この`(N_samples, 5, H_in, 70)`から1サンプル`(5, H_in, 70)`を抜き出している。
4.  **`Net` (`_update_stem`):** モデルへの入力は`(B, 5, H_in, 70)`であり、`_update_stem`はConvNeXtのstem層をこの特定の`(H_in, 70)`形状に適応させている。
    特に高さ方向(`H_in`)に対して強いダウンサンプリングを行っている。
5.  **ConvNeXt Encoder:** stem層の後、標準的な2D ConvNeXtステージ（stride 1, 2, 2, 2）で特徴を抽出し、マルチスケール特徴マップを生成する。
6.  **U-Net Decoder:** Encoderからのマルチスケール特徴を入力として、アップサンプリングとSkip Connectionにより解像度を上げ、最終的に`(B, 1, 70, 70)`の出力を得る。

**Transformer導入方針:**

上記の解釈に基づき、TransformerをConvNeXt U-Net構造に組み込みます。
時系列情報を取り込むため、`CustomDataset`を変更し、単一の時間ステップに対応するサンプルだけでなく、
**複数の連続する時間ステップに対応するサンプルをスタック**して入力チャンネルとして使用します。
これにより、モデルは近傍の時間ステップの情報を参照できるようになります。
Transformerブロックは、ConvNeXt Encoderの中間層の後に挿入し、空間的な特徴に大域的なアテンションを適用します。

**実装の変更点:**

1.  **`CustomDataset`の修正:** `num_input_slices`パラメータを追加し、指定された数の連続するサンプル
（各サンプルは`(5, H_in, 70)`形状を想定）をチャンネル次元で結合して返すように変更します。入力形状は`(B, 5 * num_input_slices, H_in, 70)`になります。
2.  **`Net`クラスの修正:**
    *   クラス名を`NetWithTransformer`とします。
    *   `__init__`で`num_input_slices`を受け取り、ConvNeXt backboneの`in_chans`を`5 * num_input_slices`に変更します。
    *   `_update_stem`関数を修正し、新しい`in_chans`に対応できるように、stemの最初の畳み込み層の入力チャンネル数を変更します。
         `H_in`は引き続き固定値として扱います（元のstem logicから推測される352を仮定）。
    *   ConvNeXt Encoderの中間ステージ（例: Stage 0の後）の出力に、`TransformerBlock2d`を挿入します。
        このTransformerブロックは、2D特徴マップをシーケンスに平坦化し、位置エンコーディングを加えた後、Self-Attentionを適用し、再び2D形状に戻します。
    *   Decoderへの入力は、Transformerブロックの出力で元のEncoderステージの出力を置き換えた、マルチスケール特徴マップのリストを使用します。
3.  **`TransformerBlock2d`と`PositionalEncoding2D`の実装:** 2D特徴マップ用のTransformerエンコーダーブロックと位置エンコーディング層を定義します。

以下に修正・追加したコードを生成します。

"""

In [15]:
# Add necessary imports at the top
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import timm
from timm.models.convnext import ConvNeXtBlock
from types import MethodType
import numpy as np
import random
import os
import time, glob
from tqdm import tqdm
from types import SimpleNamespace
import sys # Added for stderr
from copy import deepcopy
from timm.models.features import FeatureListNet

RUN_TRAIN = True # bfloat16 or float32 recommended
RUN_VALID = False
RUN_TEST  = False
USE_DEVICE = 'GPU' #'CPU'  # 'GPU'

In [2]:
# --- Configuration ---
# Original cfg setup - replace or update as needed
# For this code, we add transformer specific config here
cfg= SimpleNamespace()
cfg.device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
# cfg.local_rank = 0 # Assume single GPU/CPU for simplicity unless distributed training is set up
cfg.seed = 123
cfg.subsample = 100 #None # Set to None to use all available samples

# Assuming file paths are set up correctly
data_paths_str = "./datasetfiles/FlatVel_A/data/*.npy"
label_paths_str = "./datasetfiles/FlatVel_A/model/*.npy"

# Get all file pairs
# cfg.file_pairs = list(zip(sorted(glob.glob(data_paths_str)), sorted(glob.glob(label_paths_str))))
# Split file pairs for train/validation
data_paths = sorted(glob.glob(data_paths_str))
label_paths = sorted(glob.glob(label_paths_str))
all_file_pairs = list(zip(data_paths, label_paths))
# Simple split (e.g., 80% train, 20% validation)
split_ratio = 0.8
split_idx = int(len(all_file_pairs) * split_ratio)
train_file_pairs = all_file_pairs[:split_idx]
valid_file_pairs = all_file_pairs[split_idx:]


cfg.backbone = "convnext_small.fb_in22k_ft_in1k"
cfg.ema = True
cfg.ema_decay = 0.99

cfg.epochs = 4
cfg.batch_size = 8
cfg.batch_size_val = 8

cfg.early_stopping = {"patience": 3, "streak": 0}
cfg.logging_steps = 10

In [3]:
# --- New Transformer/Dataset related config ---
cfg.num_input_slices = 5 # Number of consecutive input samples (time slices) to stack as channels
# Inferred input height (H_in) for a single sample based on original stem logic transforming T=1000
# This is highly speculative and assumes the original dataset or stem implicitly maps 1000 -> 352
cfg.inferred_input_height = 352
cfg.input_width = 70 # Original waveform width and target output width

cfg.transformer_config = {
    'num_layers': 2,      # Number of Transformer Encoder layers
    'num_heads': 8,       # Number of attention heads
    'hidden_dim': None,   # Hidden dimension for Transformer (defaults to input channels at insertion point)
    # Decoder channels and scale factors should match the structure needed to upsample
    # from the bottleneck/skip resolutions of the *modified* encoder to 70x70.
    # Assuming the original decoder config was correct for the original encoder output resolutions.
    # The decoder will receive [stage3, stage2, stage1, stage0_transformer_out, stem_out].
    # These correspond to specific spatial sizes.
    # If stage0 output (where Transformer is inserted) has spatial size H', W', and the decoder
    # has 5 levels with scale 2, the bottleneck spatial size is roughly H'/16, W'/16.
    # If the final output is 70x70, and the last decoder output is 70x70 or 72x72 for cropping,
    # the scales should match the reduction factors of the encoder stages.
    # Let's assume the original decoder config (256, 128, 64, 32, 32) and (2,2,2,2,2)
    # is correct for the spatial sizes output by the ConvNeXt stages [stage3, stage2, stage1, stage0, stem].
    'decoder_channels': (256, 128, 64, 32, 32), # 5 levels matching 5 encoder features
    'scale_factors': (2,2,2,2,2),           # 5 factors matching 5 decoder levels
}


try:
    import monai
    from monai.networks.blocks import UpSample, SubpixelUpsample # Ensure these are imported
except ImportError:
    print("MONAI not found. Please install it ('pip install monai').")
    sys.exit(1) # Exit if MONAI is required and not found

import datetime

def format_time(elapsed):
    elapsed_rounded = int(round((elapsed)))
    return str(datetime.timedelta(seconds=elapsed_rounded))

def set_seed(seed=cfg.seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Removed torch.cuda.manual_seed, torch.backends.cudnn settings

In [4]:

class CustomDataset(torch.utils.data.Dataset):
    def __init__(
        self, 
        cfg,
        file_pairs,  #list of (data_path, label_path) tuples for this specific split
        mode = "train", 
    ):
        self.cfg = cfg
        self.mode = mode
        self.file_pairs = file_pairs
        
        self.data, self.labels = self._load_data_arrays()

        self.samples_per_file = 500  # assuming each file has 500 time steps
        total_samples_available = len(self.data) * self.samples_per_file

        # Subsample logic
        subsample = getattr(self.cfg, "subsample", None)
        self.total_samples = min(subsample, total_samples_available) if subsample else total_samples_available

        
        # Build list of (file_idx, time_step_idx) pairs
        self.index_map = []
        for file_idx in range(len(self.data)):
            for time_step_idx in range(self.samples_per_file):
                self.index_map.append((file_idx, time_step_idx))
                if len(self.index_map) >= self.total_samples:
                    break
            if len(self.index_map) >= self.total_samples:
                break

    def _load_data_arrays(self, ):
               
        data_arrays = []
        label_arrays = []
        mmap_mode = "r"

        for data_fpath, label_fpath in tqdm(
                        self.file_pairs, desc=f"Loading {self.mode} data (mmap)",
                        disable=self.cfg.local_rank != 0):
            try:
                # Load the numpy arrays using memory mapping
                arr = np.load(data_fpath, mmap_mode=mmap_mode)
                lbl = np.load(label_fpath, mmap_mode=mmap_mode)
                print(f"Loaded {data_fpath}: {arr.shape}, {lbl.shape}")
                data_arrays.append(arr)
                label_arrays.append(lbl)
            except FileNotFoundError:
                print(f"Error: File not found - {data_fpath} r {label_fpath}", file=sys.stderr)
            except Exception as e:
                print(f"Error loading file pari: {data_fpath}, {label_fpath}", file=sys.stderr)
                print(f"Error: {e}", file=sys.stderr)
                continue

            if self.cfg.local_rank == 0:
                print(f"Finished loading {len(data_arrays)} file pairs for {self.mode} mode.")

        return data_arrays, label_arrays

    def __getitem__(self, idx):
        # file_idx= idx // 500
        # time_step_idx= idx % 500
        # self.idx = idx

        file_idx, time_step_idx = self.index_map[idx]
        
        x_full = self.data[file_idx]
        y_full = self.labels[file_idx]

        # --- Augmentations ---
        # Apply augmentations to the full 3D blocks *before* slicing out the time step.
        # Make copies after slicing and augmentation to ensure memory safety.
        x_augmented = x_full
        y_augmented = y_full

        # Augs 
        if self.mode == "train":
            
            # Temporal flip
            if np.random.random() < 0.5:
                x_augmented = x_full[::-1, :, ::-1] # Time flip (dim 0), Spatial flip (dim 2)
                y_augmented = y_full[..., ::-1]  # Spatial flip (dim 2) only

        # --- Slicing and Copying ---
        # Get the specific time step from the (potentially augmented) full array
        # This reslts in a 2D array (Dim1, Dim2)
        x_sample = x_augmented[time_step_idx, ...]
        y_sample = y_augmented[time_step_idx, ...]

        # make copies to return independent arrays/tensors.
        # This is important especially with mmap and multiprocessing DataLoaders.
        x_sample = x_sample.copy()
        y_sample = y_sample.copy()

        x_tensor = torch.from_numpy(x_sample).float()
        y_tensor = torch.from_numpy(y_sample).float()
        
        return x_tensor, y_tensor

    def __len__(self, ):
        return self.total_samples

In [18]:
# --- Transformer Modules ---

class ModelEMA(nn.Module):
    def __init__(self, model, decay=0.99, device=None):
        super().__init__()
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)

class PositionalEncoding2D(nn.Module):
    """
    2D positional encoding for Transformer.
    Adds sin/cos embeddings based on height and width coordinates.
    Assumes input sequence is flattened from (H, W) to (H*W).
    """
    def __init__(self, d_model, height, width):
        super().__init__()
        if d_model % 4 != 0:
            raise ValueError("d_model must be divisible by 4 for 2D positional encoding")

        d_model_half = d_model // 2
        d_model_quarter = d_model // 4

        # Compute the positional encodings
        pe = torch.zeros(d_model_half, height, width) # Combine H and W encodings
        position_h = torch.arange(0., height).unsqueeze(1) # (H, 1)
        position_w = torch.arange(0., width).unsqueeze(1)  # (W, 1)

        div_term = torch.exp(torch.arange(0., d_model_quarter, 2) * -(math.log(10000.0) / d_model_quarter)) # (d_model/8)

        # PE for height dimension
        pe[0:d_model_quarter:2, :, :] = torch.sin(position_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[1:d_model_quarter:2, :, :] = torch.cos(position_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

        # PE for width dimension
        pe[d_model_quarter:d_model_half:2, :, :] = torch.sin(position_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model_quarter+1:d_model_half:2, :, :] = torch.cos(position_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)

        # Extend to full d_model if needed (e.g., repeat) or truncate if odd
        pe = pe[:d_model_half, :, :] # Ensure we have d_model_half channels

        # Final PE shape (d_model_half, H, W). Need to project/match d_model later.
        # Store as (1, d_model_half, H, W)
        self.register_buffer('pe', pe.unsqueeze(0))

        # Need a linear projection if d_model_half != feature_channels
        # Or add PE directly to features if d_model_half == feature_channels
        # Let's assume we add PE to the flattened sequence directly, so PE should have dim d_model.
        # Re-calculating PE to have d_model channels:
        pe_full = torch.zeros(d_model, height, width)
        pe_full[:d_model_half, :, :] = pe # Use the computed pe_half
        # Could add another set of sin/cos or zeros for the second half if needed,
        # depending on how d_model is structured in the Transformer layer.
        # Simplest: assume d_model is for concatenation or addition. Let's make PE have d_model channels.

        pe = torch.zeros(d_model, height, width)
        position_h = torch.arange(0., height).unsqueeze(1) # (H, 1)
        position_w = torch.arange(0., width).unsqueeze(1)  # (W, 1)

        # PE for height dimension (using first d_model/2 channels)
        div_term_h = torch.exp(torch.arange(0., d_model//2, 2) * -(math.log(10000.0) / (d_model//2)))
        pe[0:d_model//2:2, :, :] = torch.sin(position_h * div_term_h).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[1:d_model//2:2, :, :] = torch.cos(position_h * div_term_h).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

        # PE for width dimension (using second d_model/2 channels)
        div_term_w = torch.exp(torch.arange(0., d_model//2, 2) * -(math.log(10000.0) / (d_model//2)))
        pe[d_model//2::2, :, :] = torch.sin(position_w * div_term_w).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model//2+1::2, :, :] = torch.cos(position_w * div_term_w).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)

        self.register_buffer('pe', pe.unsqueeze(0)) # (1, d_model, height, width)


    def forward(self, x):
        # x shape: (B, C, H, W) where C should match d_model of transformer
        # PE shape: (1, d_model, H, W)
        # Add PE to the features. Requires C == d_model.
        # If C != d_model, a projection is needed before adding PE.
        # Assuming input features x are projected to d_model first.
        # The TransformerBlock2d handles the projection. PE is added to the sequence *after* projection.

        # PE needs to be reshaped to (1, H*W, d_model) to match the sequence shape
        H, W = x.shape[-2], x.shape[-1]
        pe_seq = self.pe[:, :, :H, :W].view(1, self.pe.shape[1], -1).permute(0, 2, 1) # (1, H*W, d_model)
        return pe_seq


class TransformerBlock2d(nn.Module):
    """
    Applies a Transformer Encoder block to a 2D feature map.
    Input (B, C, H, W) -> Project channels -> Flatten (B, H*W, hidden_dim)
    -> Add Positional Encoding -> Transformer Encoder -> Reshape (B, hidden_dim, H, W)
    -> Project channels back -> Add skip connection.
    """
    def __init__(self, in_channels, hidden_dim, num_layers=1, num_heads=8, height=None, width=None):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        self.height = height
        self.width = width

        if height is None or width is None:
             # If height/width are not provided, cannot initialize PE.
             # This means the block must be initialized after knowing the exact feature map size.
             # Or use a dynamic PE (less common) or no PE.
             # For this implementation, height and width must be provided.
             raise ValueError("Height and Width must be provided for Positional Encoding and Transformer block")

        # Project input channels to hidden_dim for Transformer
        # Use a simple Conv2d 1x1
        self.proj_in = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)

        # Positional Encoding
        self.pos_embed = PositionalEncoding2D(hidden_dim, height, width) # PE has shape (1, H*W, hidden_dim)

        # Transformer Encoder
        # batch_first=True means input/output shape is (batch_size, sequence_length, features)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4, # Typical feedforward dimension
            batch_first=True,
            norm_first=True # Pre-LayerNorm
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Project hidden_dim back to in_channels (for residual connection)
        self.proj_out = nn.Conv2d(hidden_dim, in_channels, kernel_size=1)

        # Layer norm after projection out and residual connection (optional, depends on design)
        # self.norm = nn.LayerNorm(in_channels) # Applied to (B, C, H, W) -> needs transpose


    def forward(self, x):
        # x: (B, C, H, W) - Input feature map
        B, C, H, W = x.shape

        # Check if spatial dimensions match the PE size (should be handled during init)
        if H != self.height or W != self.width:
             raise ValueError(f"Input spatial size mismatch. Expected ({self.height}, {self.width}), got ({H}, {W})")

        # Project channels
        x_proj = self.proj_in(x) # (B, hidden_dim, H, W)

        # Flatten spatial dimensions and permute to (B, H*W, hidden_dim)
        x_seq = x_proj.view(B, self.hidden_dim, -1).permute(0, 2, 1) # (B, H*W, hidden_dim)

        # Add positional encoding (PE shape: (1, H*W, hidden_dim))
        # PE is added broadcastingly over the batch dimension
        x_seq = x_seq + self.pos_embed(x_proj) # Add PE to the sequence

        # Transformer Encoder
        transformer_output_seq = self.transformer_encoder(x_seq) # (B, H*W, hidden_dim)

        # Reshape back to (B, hidden_dim, H, W)
        transformer_output_reshaped = transformer_output_seq.permute(0, 2, 1).view(B, self.hidden_dim, H, W)

        # Project channels back to original in_channels
        out = self.proj_out(transformer_output_reshaped) # (B, in_channels, H, W)

        # Add residual connection
        out = out + x

        # Optional final norm
        # out = self.norm(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) # Apply LayerNorm over channel dim

        return out


# --- Custom Dataset ---

class CustomDatasetWithSlices(torch.utils.data.Dataset):
    def __init__(
        self,
        cfg,
        file_pairs,  # list of (data_path, label_path) tuples for this specific split
        mode = "train",
        num_input_slices: int = 3, # Number of consecutive samples to stack
        # inferred_input_height: int = 352, # Assumed H_in from original data processing
        # input_width: int = 70, # W_in from original data processing
        # output_height: int = 70, # H_out for labels
        # output_width: int = 70,  # W_out for labels
    ):
        self.cfg = cfg
        self.mode = mode
        self.file_pairs = file_pairs
        self.num_input_slices = num_input_slices
        # self.inferred_input_height = inferred_input_height # Use from cfg
        # self.input_width = input_width # Use from cfg
        # self.output_height = output_height # Use from cfg
        # self.output_width = output_width # Use from cfg

        # Load data.
        # ASSUMPTION: Each file pair (data.npy, model.npy) loads arrays
        #   - data.npy contains (N_samples_per_file, 5, H_in, 70)
        #   - model.npy contains (N_samples_per_file, 70, 70)
        # Where N_samples_per_file is the number of pre-processed samples in that file,
        # 5 is the fixed channel count, H_in is the fixed processed height (inferred from original stem logic),
        # and 70 is the width.
        self.data_arrays, self.label_arrays = self._load_data_arrays()

        # Validate and store shapes based on loaded data
        if not self.data_arrays:
             raise RuntimeError(f"No data files loaded for mode '{self.mode}'. Check file paths and format.")

        # Use the shape of the first loaded array to define dataset dimensions
        first_data_shape = self.data_arrays[0].shape
        first_label_shape = self.label_arrays[0].shape

        self.samples_per_file_ = first_data_shape[0] # Number of samples per file
        self.channels_in_single = first_data_shape[1] # Should be 5
        self.H_in = self.data_arrays[0].shape[2]        # Assumed fixed input height (matches cfg.inferred_input_height)
        self.W_in = self.data_arrays[0].shape[3]        # Should be 70 (matches cfg.input_width)
        self.H_out = self.label_arrays[0].shape[2]      # Should be 70 (matches cfg.output_height)
        self.W_out = self.label_arrays[0].shape[3]      # Should be 70 (matches cfg.output_width)

        # Validate against cfg expectations (optional but good)
        if self.channels_in_single != 5:
             print(f"Warning: Loaded data has {self.channels_in_single} channels, expected 5.", file=sys.stderr)
        # Note: Cannot strictly validate H_in against cfg.inferred_input_height here
        # as the cfg value was just an inference based on the original stem.
        # The loaded data's shape dictates the true H_in for the model.
        # self.cfg.inferred_input_height = self.H_in # Update cfg with actual loaded height
        if self.W_in != 70:
            print(f"Warning: Loaded data has width {self.W_in}, expected 70.", file=sys.stderr)
            self.cfg.input_width = self.W_in # Update cfg
        if self.H_out != 70 or self.W_out != 70:
            print(f"Warning: Loaded labels have shape ({self.H_out}, {self.W_out}), expected (70, 70).", file=sys.stderr)
            # Update cfg with actual dimensions if needed downstream
            self.cfg.output_height = self.H_out
            self.cfg.output_width = self.W_out
        
        # Update cfg with actual loaded dimensions for model compatibility
        self.cfg.inferred_input_height = self.H_in
        self.cfg.input_width = self.W_in # Matches W_in

        total_files = len(self.data_arrays)
        # Total number of *original* samples across all successfully loaded files
        total_samples_available = total_files * self.samples_per_file_

        subsample = getattr(self.cfg, "subsample", None)
        # Determine the total number of *effective* samples based on subsampling,
        # but the actual count in index_map might be slightly less due to padding
        # requirements at file boundaries.
        effective_subsample_limit = subsample if subsample and subsample > 0 else float('inf')

        # Build list of (file_idx, sample_center_idx) pairs
        # We select num_input_slices consecutive samples centered at sample_center_idx.
        pad = (self.num_input_slices - 1) // 2

        self.index_map = []
        current_effective_samples = 0
        
        for file_idx in range(total_files):
            N_samples_in_file = self.data_arrays[file_idx].shape[0] # Number of samples in this file's array

            # Iterate through possible center sample indices such that the window
            # [center - pad, center + pad] is entirely within [0, N_samples_in_file - 1].
            valid_start_idx = pad
            valid_end_idx = N_samples_in_file - pad - 1 # Inclusive end index for center

            # Check if there are enough samples in the file to form *at least one* window
            if valid_end_idx < valid_start_idx:
                print(f"Warning: File {file_idx} (with {N_samples_in_file} samples) is too short for window size {self.num_input_slices}. Skipping file for effective samples.", file=sys.stderr)
                continue # Skip this file if not enough samples for any window

            for sample_center_idx in range(valid_start_idx, valid_end_idx + 1):
                 self.index_map.append((file_idx, sample_center_idx))
                 current_effective_samples += 1
                 # Stop if subsample limit is reached for effective samples
                 if current_effective_samples >= effective_subsample_limit:
                     break
            # Stop if subsample limit is reached for effective samples
            if current_effective_samples >= effective_subsample_limit:
                 break

        self.total_effective_samples = len(self.index_map)

        print(f"Dataset initialized in {self.mode} mode.")
        print(f"Loaded {total_files} file pairs containing a total of {total_samples_available} raw samples.")
        print(f"Input shape per single slice: ({self.channels_in_single}, {self.H_in}, {self.W_in})")
        print(f"Output label shape: ({self.H_out}, {self.W_out})")
        print(f"Window size for stacking: {self.num_input_slices} slices (padding {pad} on each side).")
        print(f"Generated {self.total_effective_samples} effective samples for training/validation after considering windowing and subsampling.")


    def _load_data_arrays(self):
        """
        Loads data and label arrays from file pairs using mmap_mode for efficiency.
        Includes validation for expected shapes.
        """
        data_arrays_list = []
        label_arrays_list = []
        # Use 'r' mode always for memory efficiency with large datasets
        mmap_mode = "r"

        print(f"Loading {self.mode} data using mmap_mode='{mmap_mode}'...")

        # Use local_rank to ensure tqdm is only shown on the main process in DDP
        disable_tqdm = getattr(self.cfg, 'local_rank', 0) != 0

        successful_loads = 0
        for data_fpath, label_fpath in tqdm(
                        self.file_pairs, desc=f"Loading {self.mode} data (mmap)",
                        disable=disable_tqdm):
            try:
                # Check if files exist before attempting to load
                if not os.path.exists(data_fpath):
                    print(f"Warning: Data file not found: {data_fpath}. Skipping pair.", file=sys.stderr)
                    continue
                if not os.path.exists(label_fpath):
                    print(f"Warning: Label file not found: {label_fpath}. Skipping pair.", file=sys.stderr)
                    continue

                # Load data with expected shape (N_samples, Channels, H_in, W_in)
                # For your data: (N, 5, 1000, 70)
                arr = np.load(data_fpath, mmap_mode=mmap_mode)
                # Load labels with expected shape (N_samples, H_out, W_out)
                # For your data: (N, 70, 70)
                lbl = np.load(label_fpath, mmap_mode=mmap_mode)

                # --- Basic shape validation based on YOUR specified structure ---
                expected_data_ndim = 4
                expected_label_ndim = 4
                expected_channels = 5 # Your specific channel count
                expected_data_width = 70 # Your specific GeoPhones dimension
                expected_label_height = 70 # Your specific output height
                expected_label_width = 70  # Your specific output width

                if arr.ndim != expected_data_ndim or \
                   arr.shape[1] != expected_channels or \
                   arr.shape[3] != expected_data_width:
                     print(f"Warning: Data file {data_fpath} has unexpected shape {arr.shape}. "
                           f"Expected ndim={expected_data_ndim}, shape[1]={expected_channels} (channels), "
                           f"shape[3]={expected_data_width} (width/geophones). Skipping.", file=sys.stderr)
                     continue

                if lbl.ndim != expected_label_ndim or \
                   lbl.shape[2] != expected_label_height or \
                   lbl.shape[3] != expected_label_width:
                     print(f"Warning: Label file {label_fpath} has unexpected shape {lbl.shape}. "
                           f"Expected ndim={expected_label_ndim}, shape[1]={expected_label_height} (height), "
                           f"shape[2]={expected_label_width} (width). Skipping.", file=sys.stderr)
                     continue

                # Validate that the number of samples (batch dimension) matches
                if arr.shape[0] != lbl.shape[0]:
                     print(f"Warning: Mismatch in number of samples (batch size) between data ({arr.shape[0]}) and label ({lbl.shape[0]}) "
                           f"in file pair {data_fpath}, {label_fpath}. Skipping.", file=sys.stderr)
                     continue

                # If it passes validation, add to lists
                data_arrays_list.append(arr)
                label_arrays_list.append(lbl)
                successful_loads += 1

            except FileNotFoundError:
                # This check is now redundant with the os.path.exists check above,
                # but keeping it doesn't hurt as a fallback.
                print(f"Error: File not found - {data_fpath} or {label_fpath}. Skipping pair.", file=sys.stderr)
            except Exception as e:
                print(f"Error loading or validating file pair: {data_fpath}, {label_fpath}", file=sys.stderr)
                print(f"Error details: {e}", file=sys.stderr)
                # traceback.print_exc() # Uncomment for detailed error
                continue

        # if self.cfg.local_rank == 0: # Only print summary from main process
        print(f"Finished loading {successful_loads} out of {len(self.file_pairs)} file pairs successfully for {self.mode} mode.")

        return data_arrays_list, label_arrays_list

    def __len__(self):
        """
        Returns the total number of effective samples available in the dataset.
        """
        return self.total_effective_samples

    def __getitem__(self, index):
        """
        Retrieves a single effective sample (input and label) based on the index.
        An effective sample consists of num_input_slices data slices stacked,
        and the corresponding label for the center slice.
        """
        if index < 0 or index >= self.total_effective_samples:
            raise IndexError(f"Index {index} out of bounds for dataset of size {self.total_effective_samples}")

        # Get the file index and the center sample index within that file
        file_idx, sample_center_idx = self.index_map[index]

        # Calculate the start and end indices for the window of slices
        pad = (self.num_input_slices - 1) // 2
        start_idx = sample_center_idx - pad
        end_idx = sample_center_idx + pad # This is inclusive

        # Retrieve the batch of consecutive data slices
        # The shape will be (num_input_slices, Channels, H_in, W_in)
        data_slices = self.data_arrays[file_idx][start_idx : end_idx + 1, ...]

        # Retrieve the label for the *center* slice
        # The shape will be (H_out, W_out)
        label_slice = self.label_arrays[file_idx][sample_center_idx, ...]

        # --- Stack the data slices ---
        # The original shape is (num_input_slices, Channels, H_in, W_in)
        # We want to combine the 'num_input_slices' and 'Channels' dimensions
        # into a single channel dimension, resulting in (num_input_slices * Channels, H_in, W_in).
        # This is a common way to represent stacked time series data as input channels for CNNs.
        combined_channels = self.num_input_slices * self.channels_in_single
        input_tensor = data_slices.reshape(combined_channels, self.H_in, self.W_in)

        # Convert numpy arrays to PyTorch tensors
        # Ensure correct data types (float32 is common for model inputs/outputs)
        input_tensor = torch.from_numpy(input_tensor).float()
        label_tensor = torch.from_numpy(label_slice).float() # Assuming regression output

        return input_tensor, label_tensor

# --- Decoder (Keep original, ensure it works with modified encoder outputs) ---
# The UnetDecoder2d class remains the same. It takes a list of encoder features
# and skip connections. We will feed it the features from our modified encoder.
# (Paste ConvBnAct2d class here)

class ConvBnAct2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding: int = 0,
        stride: int = 1,
        norm_layer: nn.Module = nn.Identity,
        act_layer: nn.Module = nn.ReLU,
    ):
        super().__init__()

        self.conv= nn.Conv2d(
            in_channels, 
            out_channels,
            kernel_size,
            stride=stride, 
            padding=padding, 
            bias=False,
        )
        self.norm = norm_layer(out_channels) if norm_layer != nn.Identity else nn.Identity()
        self.act= act_layer(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


# (Paste SCSEModule2d class here)


class SCSEModule2d(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.Tanh(),
            nn.Conv2d(in_channels // reduction, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1), 
            nn.Sigmoid(),
            )

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)
    
# (Paste Attention2d class here)


class Attention2d(nn.Module):
    def __init__(self, name, **params):
        super().__init__()
        if name is None:
            self.attention = nn.Identity(**params)
        elif name == "scse":
            self.attention = SCSEModule2d(**params)
        else:
            raise ValueError("Attention {} is not implemented".format(name))

    def forward(self, x):
        return self.attention(x)
    

# (Paste DecoderBlock2d class here)


class DecoderBlock2d(nn.Module):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer: nn.Module = nn.Identity,
        attention_type: str = None,
        intermediate_conv: bool = False,
        upsample_mode: str = "deconv",
        scale_factor: int = 2,
    ):
        super().__init__()

        # Upsample block
        if upsample_mode == "pixelshuffle":
            self.upsample= SubpixelUpsample(
                spatial_dims= 2,
                in_channels= in_channels,
                scale_factor= scale_factor,
            )
        else:
            self.upsample = UpSample(
                spatial_dims= 2,
                in_channels= in_channels,
                out_channels= in_channels,
                scale_factor= scale_factor,
                mode= upsample_mode,
            )

        if intermediate_conv:
            k= 3
            c= skip_channels if skip_channels != 0 else in_channels
            self.intermediate_conv = nn.Sequential(
                ConvBnAct2d(c, c, k, k//2),
                ConvBnAct2d(c, c, k, k//2),
                )
        else:
            self.intermediate_conv= None

        self.attention1 = Attention2d(
            name= attention_type, 
            in_channels= in_channels + skip_channels,
            )

        self.conv1 = ConvBnAct2d(
            in_channels + skip_channels,
            out_channels,
            kernel_size= 3,
            padding= 1,
            norm_layer= norm_layer,
        )

        self.conv2 = ConvBnAct2d(
            out_channels,
            out_channels,
            kernel_size= 3,
            padding= 1,
            norm_layer= norm_layer,
        )
        self.attention2 = Attention2d(
            name= attention_type, 
            in_channels= out_channels,
            )

    def forward(self, x, skip=None):
        x = self.upsample(x)

        if self.intermediate_conv is not None:
            if skip is not None:
                skip = self.intermediate_conv(skip)
            else:
                x = self.intermediate_conv(x)

        if skip is not None:
            # print(x.shape, skip.shape)
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x
    

# (Paste UnetDecoder2d class here)

class UnetDecoder2d(nn.Module):
    """
    Unet decoder.
    Source: https://arxiv.org/abs/1505.04597
    """
    def __init__(
        self,
        encoder_channels: tuple[int],
        skip_channels: tuple[int] = None,
        decoder_channels: tuple = (256, 128, 64, 32),
        scale_factors: tuple = (2,2,2,2),
        norm_layer: nn.Module = nn.Identity,
        attention_type: str = None,
        intermediate_conv: bool = False,
        upsample_mode: str = "deconv",
    ):
        super().__init__()
        
        if len(encoder_channels) == 4:
            decoder_channels= decoder_channels[1:]
        self.decoder_channels= decoder_channels
        
        if skip_channels is None:
            skip_channels= list(encoder_channels[1:]) + [0]

        # Build decoder blocks
        in_channels= [encoder_channels[0]] + list(decoder_channels[:-1])
        self.blocks = nn.ModuleList()

        for i, (ic, sc, dc) in enumerate(zip(in_channels, skip_channels, decoder_channels)):
            # print(i, ic, sc, dc)
            self.blocks.append(
                DecoderBlock2d(
                    ic, sc, dc, 
                    norm_layer= norm_layer,
                    attention_type= attention_type,
                    intermediate_conv= intermediate_conv,
                    upsample_mode= upsample_mode,
                    scale_factor= scale_factors[i],
                    )
            )

    def forward(self, feats: list[torch.Tensor]):
        res= [feats[0]]
        feats= feats[1:]

        # Decoder blocks
        for i, b in enumerate(self.blocks):
            skip= feats[i] if i < len(feats) else None
            res.append(
                b(res[-1], skip=skip),
                )
            
        return res
    
# (Paste SegmentationHead2d class here)

class SegmentationHead2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        scale_factor: tuple[int] = (2,2),
        kernel_size: int = 3,
        mode: str = "nontrainable",
    ):
        super().__init__()
        self.conv= nn.Conv2d(
            in_channels, out_channels, kernel_size= kernel_size,
            padding= kernel_size//2
        )
        self.upsample = UpSample(
            spatial_dims= 2,
            in_channels= out_channels,
            out_channels= out_channels,
            scale_factor= scale_factor,
            mode= mode,
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.upsample(x)
        return x        
        

class NetWithTransformer(nn.Module):
    """
    A ConvNeXt U-Net style model with a Transformer block inserted after an early encoder stage.
    Takes stacked input slices as channels to incorporate temporal context.
    """
    def __init__(
        self,
        cfg,
        backbone: str,
        pretrained: bool = True,
    ):
        super().__init__()
        self.cfg = cfg
        num_input_slices = cfg.num_input_slices
        print(f"[NetWithTransformer - init / num_input_slices] : {num_input_slices}")
        transformer_config = cfg.transformer_config

        # Assumed input dimensions for a single sample (after dataset preprocessing)
        # These will be updated by the dataset upon loading if necessary.
        self.H_in = cfg.inferred_input_height # Assumed input height for a single slice
        print(f"[NetWithTransformer - init / self.H_in] : {self.H_in}")
        self.W_in = cfg.input_width          # Assumed input width
        print(f"[NetWithTransformer - init / self.W_in] : {self.W_in}")

        # Encoder backbone (ConvNeXt stages, WITHOUT the initial stem as provided by timm)
        # We will handle the initial layers (stem) separately.
        # Create the full ConvNeXt model, then we will modify its stem and use its stages.
        # Set in_chans to a temporary value (e.g., 5 or 3) here, as it will be updated in _update_stem.
        # self.backbone = timm.create_model(
        #     backbone,
        #     in_chans=5, # This will be replaced in _update_stem
        #     pretrained=pretrained,
        #     features_only=True, # Returns list of features after stem and each stage
        #     drop_path_rate=0.0, # Keep original drop_path_rate or set via cfg if needed
        # )
        real_backbone = timm.create_model(
            backbone,
            pretrained=False,
            features_only=True  # so you can access `stem`
        )
        feature_info = real_backbone.feature_info
        self.backbone = FeatureListNet(
            real_backbone,
            out_indices=[0,1,2,3],
            output_fmt='default',
        )
        print(f"[NetWithTransformer - init / self.backbone] : {self.backbone.default_cfg['architecture']}")

        # Modify the stem to handle stacked input channels and the inferred height/width
        # This function also determines the exact shape of features after the stem and stages.
        # It will update self.transformer_stage_idx_in_features, transformer_in_channels, etc.
        self._update_stem(
            self.backbone,
            in_chans=5 * num_input_slices,
            input_height=self.H_in,
            input_width=self.W_in,
        )

        # Transformer Block - Initialized in _update_stem after spatial dims are known
        # self.transformer_block = ...

        # Decoder - Initialized in _update_stem after actual encoder channel counts are known
        # self.decoder = ...

        # Seg Head - Initialized in _update_stem
        # self.seg_head = ...

        # Apply replacements to backbone and decoder modules
        self.replace_activations(self.backbone) # Apply to ConvNeXt backbone
        self.replace_norms(self.backbone)     # Apply to ConvNeXt backbone
        # Apply custom forward for ConvNeXt blocks if needed (copied from original code)
        self.replace_forwards(self.backbone)

        # Apply replacements to the decoder and seg head as well
        self.replace_activations(self.decoder)
        self.replace_norms(self.decoder)
        self.replace_activations(self.seg_head)
        self.replace_norms(self.seg_head)


    def _update_stem(self, backbone, in_chans, input_height, input_width):
        """
        Adapts the ConvNeXt stem to handle 'in_chans' and determines
        the actual spatial dimensions of feature maps after each stage.
        Initializes the Transformer block and Decoder based on these dimensions.
        """
        print(f"[NetWithTransformer - _update_stem] : inside update_stems")
        # if not backbone.name.startswith("convnext"):
        if not backbone.default_cfg['architecture'].startswith("convnext"):
            raise ValueError("Custom stem modification implemented only for convnext backbone.")

        # Find the first Conv2d layer in the stem (assuming it's in a Sequential module)
        first_conv_layer = None
        # Iterate through children, looking for the first Conv2d within the stem
        # for name, module in backbone.stem.named_children():
        real_backbone = timm.models.get_model(backbone.default_cfg['architecture'], pretrained=False)
        for name, module in real_backbone.stem.named_children():
            print(f"[NetWithTransformer - name] : {name}")
            if isinstance(module, nn.Conv2d):
                first_conv_layer = module
                break
            # Also check within children if the stem has nested Sequential/modules
            for sub_name, sub_module in module.named_modules():
                 if isinstance(sub_module, nn.Conv2d):
                      first_conv_layer = sub_module
                      break
            if first_conv_layer: break

        print(f"[NetWithTransformer - first_conv_layer in None section] : first_conv_layer: {first_conv_layer}")
        if first_conv_layer is None:
             raise RuntimeError("Could not find the first Conv2d layer in the ConvNeXt stem.")

        # Create a new first convolution layer with updated in_channels
        new_first_conv = nn.Conv2d(
            in_chans,
            first_conv_layer.out_channels,
            kernel_size=first_conv_layer.kernel_size,
            stride=first_conv_layer.stride,
            padding=first_conv_layer.padding,
            bias=hasattr(first_conv_layer, 'bias') and first_conv_layer.bias is not None
        )
        # Copy weights for the first original 5 channels (if pretrained)
        if hasattr(first_conv_layer, 'weight') and first_conv_layer.weight is not None:
             # Simple initialization: copy weights for the original 5 channels, zero out/repeat others
             # Or just use default initialization if complex handling is not needed
             # For simplicity, use default Kaiming init of the new layer.
             pass # new_first_conv.weight.data[:, :original_in_chans, ...] = first_conv_layer.weight.data


        # Replace the original first conv layer within the stem module(s).
        # This is fragile depending on how timm builds the stem Sequential.
        # A safer way might be to create a new stem module from scratch and replace backbone.stem.
        # Let's try creating a new Sequential module.
        
        # Assuming original stem was something like:
        # Sequential( ReflectionPad2d, Conv2d (stride=(4,1)), Conv2d (stride=(4,1)) )
        
        # Reconstruct the stem with the new first convolution.
        # This requires knowing the order and types of layers in the original stem.
        # Based on the original code, it seems like Pad, Conv, Conv.
        
        new_stem_modules = []
        conv_count = 0
        for name, module in backbone.stem.named_children():
             if isinstance(module, nn.Conv2d) and conv_count == 0:
                  new_stem_modules.append(new_first_conv)
                  conv_count += 1
                  print(f"Replaced first Conv2d in stem with new layer handling {in_chans} channels.")
             else:
                  new_stem_modules.append(module)

        if conv_count == 0:
             # Handle cases where the first layer is not Conv2d, but a child contains it.
             # This gets complicated. Let's assume the first Conv2d is a direct child.
             raise RuntimeError("Could not find/replace the first Conv2d layer as a direct child in the ConvNeXt stem.")
             
        # Replace the backbone's stem with the new sequential module
        backbone.stem = nn.Sequential(*new_stem_modules)

        print(f"Updated ConvNeXt stem to handle {in_chans} input channels.")
        print(f"Assumed model input shape: (B, {in_chans}, {input_height}, {input_width})")


        # --- Determine Feature Map Shapes and Initialize Transformer/Decoder ---
        
        # Get actual output shapes after stem and stages from the modified backbone
        # Use dummy input with the actual expected shape (B, in_chans, input_height, input_width)
        dummy_input_for_shape = torch.randn(1, in_chans, input_height, input_width).to(self.cfg.device)
        
        # Pass dummy data through backbone to get feature shapes at each output point
        # `features_only=True` returns a list of tensors.
        # For ConvNeXt in timm, this list usually contains outputs after stem and each stage:
        # [stem_out, stage0_out, stage1_out, stage2_out, stage3_out].
        
        try:
            # Temporarily disable gradients for shape inference
            with torch.no_grad():
                # The backbone's forward method with features_only=True handles the sequential pass
                features_list = self.backbone(dummy_input_for_shape)
                features_shapes = [f.shape for f in features_list]
        except Exception as e:
            print(f"Error during dummy forward pass for shape inference: {e}", file=sys.stderr)
            print("Please check input dimensions, stem modification logic, and backbone compatibility.", file=sys.stderr)
            raise

        # features_shapes list corresponds to [stem_out, stage0_out, stage1_out, stage2_out, stage3_out]
        # The index in this list where we insert the Transformer. Stage 0 is index 1.
        self.transformer_stage_idx_in_features = 1 
        if self.transformer_stage_idx_in_features >= len(features_shapes):
             raise ValueError(f"Transformer insertion index {self.transformer_stage_idx_in_features} is out of bounds for backbone features ({len(features_shapes)} levels).")

        # Get channels and spatial size for Transformer input from the chosen stage output shape
        transformer_in_channels = features_shapes[self.transformer_stage_idx_in_features][1]
        transformer_spatial_h = features_shapes[self.transformer_stage_idx_in_features][2]
        transformer_spatial_w = features_shapes[self.transformer_stage_idx_in_features][3]

        print(f"Inserting Transformer after backbone feature index {self.transformer_stage_idx_in_features}.")
        print(f"Transformer input feature shape determined: (B, {transformer_in_channels}, {transformer_spatial_h}, {transformer_spatial_w})")

        # Initialize the Transformer Block
        transformer_config = self.cfg.transformer_config
        self.transformer_block = TransformerBlock2d(
            in_channels=transformer_in_channels,
            hidden_dim=transformer_config.get('hidden_dim', transformer_in_channels), # Use config or default
            num_layers=transformer_config.get('num_layers', 1),
            num_heads=transformer_config.get('num_heads', 8),
            height=transformer_spatial_h, # Pass determined spatial dimensions
            width=transformer_spatial_w,  # Pass determined spatial dimensions
        )
        print("Initialized TransformerBlock2d.")

        # --- Initialize the Decoder ---
        # The decoder needs encoder_channels corresponding to the channel counts
        # of the features it receives, in reverse order of spatial size (bottleneck to high-res skip).
        # features_list is [stem, stage0, stage1, stage2, stage3]
        # Decoder receives [stage3, stage2, stage1, stage0_transformer_out, stem]
        # The channel counts are the second dimension of the shapes.
        # The list of channel counts in the order the decoder expects skips/bottleneck:
        # [stage3_chs, stage2_chs, stage1_chs, stage0_chs, stem_chs]
        ecs = [s[1] for s in features_shapes[::-1]]

        # Check if the number of encoder channels matches the decoder levels + 1 bottleneck
        # Decoder channels tuple size should be ecs size - 1 (skips) + 1 (bottleneck)
        # No, decoder_channels is the list of output channels *of the decoder blocks*.
        # UnetDecoder2d's `encoder_channels` argument is the list of *input* channel counts
        # it receives from the encoder side. This list is `ecs`.
        # The number of levels in `decoder_channels` and `scale_factors` should match
        # the number of decoder blocks, which is typically len(encoder_channels) - 1.
        # ecs has 5 elements. So decoder should have 4 blocks.
        # BUT the original decoder config (256, 128, 64, 32, 32) has 5 elements, and scale_factors has 5 elements.
        # This suggests the original decoder might have had 5 blocks, possibly using stem output as a skip.
        # Let's stick to the 5-level decoder as in original config.
        # UnetDecoder2d takes `encoder_channels`, which are the channels of the features *before* the decoder blocks.
        # The first element of `encoder_channels` is the bottleneck input channels.
        # The remaining elements are the skip connection channels, in decreasing spatial resolution order.
        # So `encoder_channels` = [stage3_chs, stage2_chs, stage1_chs, stage0_chs, stem_chs]. This matches `ecs`.

        # Ensure decoder_channels and scale_factors lengths match
        decoder_channels_cfg = transformer_config.get('decoder_channels', (256, 128, 64, 32)) # Default 4 levels?
        scale_factors_cfg = transformer_config.get('scale_factors', (2,2,2,2))             # Default 4 factors?

        # Check consistency between encoder feature levels and decoder levels
        num_encoder_features = len(ecs) # 5 features: stage3, stage2, stage1, stage0, stem
        # The decoder has num_encoder_features - 1 skip connections and 1 bottleneck input.
        # So, typically len(decoder_channels) should be num_encoder_features - 1.
        # Original config had 5 decoder_channels and 5 scale_factors. This is unusual.
        # It might mean the decoder structure is slightly different, or one level was added/handled differently.
        # Let's adjust the default decoder levels to match the common U-Net structure (N encoder features -> N-1 decoder blocks).
        # If ecs has 5 elements, decoder should have 4 blocks.
        # Let's use the last 4 elements of the original decoder_channels config.
        
        default_decoder_channels = (256, 128, 64, 32) # 4 levels for 5 encoder features
        default_scale_factors = (2,2,2,2)
        
        decoder_channels_cfg = transformer_config.get('decoder_channels', default_decoder_channels)
        scale_factors_cfg = transformer_config.get('scale_factors', default_scale_factors)

        if len(decoder_channels_cfg) != num_encoder_features - 1:
             print(f"Warning: Number of decoder levels ({len(decoder_channels_cfg)}) does not match standard U-Net structure for {num_encoder_features} encoder features. Expected {num_encoder_features - 1}.", file=sys.stderr)
             # Adjust if lengths are inconsistent, or trust user config if provided explicitly
             if transformer_config.get('decoder_channels') is None:
                  print("Using default decoder channels matching encoder levels - 1.", file=sys.stderr)
                  decoder_channels_cfg = default_decoder_channels
             if transformer_config.get('scale_factors') is None:
                  print("Using default scale factors matching decoder levels.", file=sys.stderr)
                  scale_factors_cfg = default_scale_factors
                  
        if len(decoder_channels_cfg) != len(scale_factors_cfg):
             raise ValueError(f"Number of decoder channels ({len(decoder_channels_cfg)}) must match number of scale factors ({len(scale_factors_cfg)}).")


        self.decoder= UnetDecoder2d(
            encoder_channels= ecs, # Use actual channel counts from modified backbone
            decoder_channels = decoder_channels_cfg,
            scale_factors = scale_factors_cfg,
            norm_layer = nn.InstanceNorm2d, # Use InstanceNorm consistent with replacement
            attention_type = 'scse', # Keep original attention
        )
        print("Initialized UnetDecoder2d.")
        print(f"Decoder levels: {len(self.decoder.blocks)}")
        print(f"Decoder input channels (encoder_channels): {self.decoder.encoder_channels}")
        print(f"Decoder output channels (decoder_channels): {self.decoder.decoder_channels}")
        print(f"Decoder scale factors: {self.decoder.blocks[0].upsample.scale_factor}") # Check first block

        # Initialize Seg Head
        self.seg_head= SegmentationHead2d(
            in_channels= self.decoder.decoder_channels[-1], # Last decoder block output channels
            out_channels= 1,
            scale_factor= 1, # Assuming the last decoder block outputs at the final spatial resolution (70x70 or 72x72)
            mode = "nontrainable" # Keep original mode
        )
        print("Initialized SegmentationHead2d.")


    # --- Replacement Helper Methods ---

    def replace_activations(self, module):
        """Replaces all activations with GELU recursively, avoiding Transformer block."""
        for name, child in module.named_children():
            # Skip replacements within the TransformerBlock2d
            if isinstance(child, TransformerBlock2d):
                 continue

            if isinstance(child, (
                nn.ReLU, nn.LeakyReLU, nn.Mish, nn.Sigmoid,
                nn.Tanh, nn.Softmax, nn.Hardtanh, nn.ELU,
                nn.SELU, nn.PReLU, nn.CELU, nn.GELU, nn.SiLU,
            )):
                # Only replace if it's *not* GELU or SiLU already (common in newer models)
                if not isinstance(child, (nn.GELU, nn.SiLU)):
                    setattr(module, name, nn.GELU())
            else:
                self.replace_activations(child)

    def replace_norms(self, mod):
        """Replaces BatchNorm2d with InstanceNorm2d recursively, avoiding Transformer block."""
        for name, c in mod.named_children():
            # Skip replacements within the TransformerBlock2d
            if isinstance(c, TransformerBlock2d):
                 continue

            # Replace BatchNorm2d with InstanceNorm2d
            if isinstance(c, nn.BatchNorm2d):
                 n_feats= c.num_features
                 new = nn.InstanceNorm2d(
                     n_feats,
                     affine=True, # Keep affine=True as in original replacement
                     )
                 setattr(mod, name, new)
                 # print(f"Replaced BatchNorm2d {name} with InstanceNorm2d")
            # Optional: Replace other norms if necessary, but avoid LayerNorm within standard Transformer layers
            # elif isinstance(c, nn.LayerNorm):
            #      # Careful: Only replace LayerNorms that are part of the ConvNeXt/UNet structure if needed
            #      # E.g., ConvNeXt blocks have LayerNorm.
            #      pass # Avoid replacing LayerNorm unless specifically targeted

            else:
                # Recursively apply to children
                self.replace_norms(c)

    def replace_forwards(self, mod):
        """Replaces forward pass for specific modules like ConvNeXtBlock if needed."""
        # Original code replaced ConvNeXtBlock forward. Keep this logic.
        for name, c in mod.named_children():
            # Skip Transformer block
            if isinstance(c, TransformerBlock2d):
                 continue

            if isinstance(c, ConvNeXtBlock):
                # Apply the custom forward method
                c.forward = MethodType(_convnext_block_forward, c)
            else:
                self.replace_forwards(c)

    # --- Forward Pass ---

    def forward(self, batch):
        # Input batch shape is (B, 5 * N_slices, H_in, 70)
        x = batch

        # Run through the modified backbone.
        # It returns a list of features: [stem_out, stage0_out, stage1_out, stage2_out, stage3_out].
        feats = self.backbone(x)

        # Apply Transformer block to the output of the specified stage (e.g., stage0)
        # The feature list index for stage0 output is self.transformer_stage_idx_in_features (should be 1)
        stage_output_for_transformer = feats[self.transformer_stage_idx_in_features]
        transformer_output = self.transformer_block(stage_output_for_transformer)

        # Replace the original stage output with the transformer output in the feature list
        modified_feats = list(feats) # Create a copy to not modify the original list from backbone
        modified_feats[self.transformer_stage_idx_in_features] = transformer_output

        # Decoder input expects features in reverse order of spatial size (bottleneck to high-res skip)
        # features_list from backbone: [stem_out, stage0_out, stage1_out, stage2_out, stage3_out]
        # Original decoder expects: [stage3_out, stage2_out, stage1_out, stage0_out, stem_out]
        # Our modified decoder input list: [stage3_out, stage2_out, stage1_out, transformer_output, stem_out]
        
        # The list `modified_feats` is already in the order [stem, stage0, stage1, stage2, stage3]
        # Reversing it gives [stage3, stage2, stage1, modified_stage0, stem]
        decoder_input_feats = modified_feats[::-1] # Reverse the list of features

        # Decoder takes the list of features
        # UnetDecoder2d expects: [bottleneck, skip3, skip2, skip1, skip0]
        # decoder_input_feats[0] = stage3_out (bottleneck)
        # decoder_input_feats[1] = stage2_out (skip3)
        # decoder_input_feats[2] = stage1_out (skip2)
        # decoder_input_feats[3] = transformer_output (replacing stage0_out, acting as skip1)
        # decoder_input_feats[4] = stem_out (acting as skip0)
        # This matches the structure if the decoder has 4 blocks processing skips and one bottleneck input.
        # Our decoder_channels config has 4 levels, matching 4 blocks.
        # The UnetDecoder expects N encoder features, uses the first as bottleneck, and the rest as N-1 skips.
        # So ecs = [stage3, stage2, stage1, stage0, stem]. len=5.
        # Decoder expects bottleneck ecs[0] (stage3). Skips ecs[1:] (stage2, stage1, stage0, stem).
        # UnetDecoder2d takes encoder_channels=[stage3_chs, stage2_chs, stage1_chs, stage0_chs, stem_chs]
        # Its first block uses stage3_chs as input, skip stage2_chs.
        # Second block uses output of first, skip stage1_chs.
        # ...
        # Fourth block uses output of third, skip stem_chs.
        # The list passed to self.decoder(list) needs to be in the order of skips/bottleneck.
        # The `modified_feats[::-1]` is already in the correct order for `UnetDecoder2d`.

        x_dec_outputs = self.decoder(decoder_input_feats)

        # Seg Head takes the last decoder output (highest resolution)
        # The last decoder output should have spatial dimensions matching the target 70x70 or slightly larger.
        x_seg = self.seg_head(x_dec_outputs[-1])

        # Final output processing (cropping, scaling)
        # Crop 1 pixel border. Assuming seg_head output is 72x72 to get 70x70.
        # This depends on the exact upsampling in the decoder and seg_head.
        # If the last decoder block outputs 70x70 and seg_head has scale_factor=1, output is 70x70.
        # Then cropping is not needed or needs adjustment.
        # Let's keep the original cropping logic, assuming the final output is slightly larger than 70x70.
        x_seg = x_seg[..., 1:-1, 1:-1] # Crop to 70x70

        x_seg = x_seg * 1500 + 3000 # Scale output


        if self.training:
            return x_seg
        else:
             # Test-time augmentation with flip
             # The proc_flip function re-runs the model forward path on flipped input.
             # It must include the Transformer block in its logic.
             p1 = self.proc_flip(x_in=batch) # Pass original batch for flip
             x_seg = torch.mean(torch.stack([x_seg, p1]), dim=0)
             return x_seg

    def proc_flip(self, x_in):
        """
        Processes input with spatial flip for test-time augmentation.
        Assumes input is (B, C_stacked, H_in, W_in).
        Flips spatial dimensions H_in and W_in.
        """
        # x_in shape: (B, 5 * N_slices, H_in, 70)
        # Spatial flip on H_in and W_in dimensions (-2, -1)
        x_in_flipped = torch.flip(x_in, dims=[-2, -1])

        # Run flipped input through backbone to get flipped features
        # Returns [stem_out, stage0_out, stage1_out, stage2_out, stage3_out] (flipped)
        feats_flipped = self.backbone(x_in_flipped)

        # Apply Transformer block to the flipped stage output
        stage_output_for_transformer_flipped = feats_flipped[self.transformer_stage_idx_in_features]
        transformer_output_flipped = self.transformer_block(stage_output_for_transformer_flipped)

        # Replace original stage output with transformer output in the flipped feature list
        modified_feats_flipped = list(feats_flipped)
        modified_feats_flipped[self.transformer_stage_idx_in_features] = transformer_output_flipped

        # Decoder input expects features in reverse order
        decoder_input_feats_flipped = modified_feats_flipped[::-1]

        # Run flipped features through decoder
        x_dec_outputs_flipped = self.decoder(decoder_input_feats_flipped)

        # Seg Head on the last decoder output
        x_seg_flipped = self.seg_head(x_dec_outputs_flipped[-1])

        # Post-processing - Crop and Flip back the output spatially
        # Crop should match the size expected before flipping back.
        x_seg_flipped = x_seg_flipped[..., 1:-1, 1:-1] # Crop to 70x70

        # Flip back spatial dimensions (H and W) of the output
        # Original code only flipped W (-1). Let's stick to flipping H and W (-2, -1) for symmetry.
        x_seg_flipped = torch.flip(x_seg_flipped, dims=[-2, -1]) # Flip H and W back

        x_seg_flipped = x_seg_flipped * 1500 + 3000 # Scale

        return x_seg_flipped


# --- Helper method for ConvNeXtBlock forward replacement (from original code) ---
# This method needs to be defined outside the class or within it if it's a static/class method.
# Keeping it outside and using MethodType in __init__.
def _convnext_block_forward(self, x):
    """Custom forward pass for ConvNeXtBlock, potentially from original code."""
    # This method is intended to replace the original forward of timm.models.convnext.ConvNeXtBlock
    # It needs access to self (the ConvNeXtBlock instance).
    # It should be copied verbatim from the original code if possible.
    # The provided snippet has this function. Let's copy it.

    shortcut = x
    x = self.conv_dw(x)

    if self.use_conv_mlp:
        x = self.norm(x) # This norm might be targeted by replace_norms
        x = self.mlp(x)
    else:
        # Original code permutes for layernorm/mlp if not use_conv_mlp
        # The norm here is nn.LayerNorm in standard ConvNeXt
        x = self.norm(x) # This norm is LayerNorm, replace_norms should skip it
        x = x.permute(0, 2, 3, 1) # (B, H, W, C)
        x = x.contiguous()
        x = self.mlp(x)
        x = x.permute(0, 3, 1, 2) # (B, C, H, W)
        x = x.contiguous()

    if self.gamma is not None:
        # Gamma is for LayerScale
        x = x * self.gamma.reshape(1, -1, 1, 1)

    # Drop path and residual connection
    x = self.drop_path(x) + self.shortcut(shortcut)
    return x

# --- End of Helper Method ---


# --- Main Execution Block ---
if __name__ == "__main__":
    # Set device based on config
    device = torch.device("cuda" if USE_DEVICE == 'GPU' and torch.cuda.is_available() else "cpu")
    cfg.device = device
    cfg.local_rank = 0 # Assuming single process

    print(f"[main loop]- Using device: {cfg.device}")

    set_seed(cfg.seed)

    # File paths (assuming these are set up correctly)
    data_paths_str = "./datasetfiles/FlatVel_A/data/*.npy"
    label_paths_str = "./datasetfiles/FlatVel_A/model/*.npy"

    # Get all file pairs
    data_paths = sorted(glob.glob(data_paths_str))
    label_paths = sorted(glob.glob(label_paths_str))
    if not data_paths or not label_paths or len(data_paths) != len(label_paths):
         print("[main loop]- Error: Data or label files not found or mismatch count.", file=sys.stderr)
         # sys.exit(1) # Exit if no data found

    all_file_pairs = list(zip(data_paths, label_paths))

    if not all_file_pairs:
        print("[main loop]- No file pairs found. Please check data_paths_str and label_paths_str.", file=sys.stderr)
        RUN_TRAIN = RUN_VALID = RUN_TEST = False # Disable runs if no data


    # Split file pairs for train/validation/test
    # Simple split (e.g., 80% train, 20% validation/test combined)
    split_ratio_train = 0.8
    split_idx_train = int(len(all_file_pairs) * split_ratio_train)

    train_file_pairs = all_file_pairs[:split_idx_train]
    # Use the rest for validation for this example
    valid_file_pairs = all_file_pairs[split_idx_train:]
    # If a separate test set is needed, split valid_file_pairs further.


    # Create datasets using the modified class
    if RUN_TRAIN:
        # train_ds = CustomDataset(cfg=cfg, file_pairs=train_file_pairs, mode="train")
        train_ds = CustomDatasetWithSlices(cfg=cfg, file_pairs=train_file_pairs, mode="train", num_input_slices=cfg.num_input_slices)
        train_dl = torch.utils.data.DataLoader(
            train_ds,
            batch_size= cfg.batch_size,
            num_workers= 4 if cfg.device.type == 'cuda' else 0, # Use more workers on GPU
            shuffle=True,
            pin_memory=cfg.device.type == 'cuda', # Pin memory for faster GPU transfer
            drop_last=True, # Drop last batch if batch size doesn't divide dataset size
        )
        # Update cfg with actual H_in from dataset
        cfg.inferred_input_height = train_ds.H_in
        cfg.input_width = train_ds.W_in # Update if dataset detected different width


    if RUN_VALID or RUN_TEST: # Create validation dataset if needed for validation or testing
        # valid_ds = CustomDataset(cfg=cfg, file_pairs=valid_file_pairs, mode="valid")
        valid_ds = CustomDatasetWithSlices(cfg=cfg, file_pairs=valid_file_pairs, mode="valid", num_input_slices=cfg.num_input_slices)
        valid_dl = torch.utils.data.DataLoader(
            valid_ds,
            batch_size= cfg.batch_size_val,
            num_workers= 4 if cfg.device.type == 'cuda' else 0, # Use more workers on GPU
            shuffle=False, # No shuffle for validation/test
            pin_memory=cfg.device.type == 'cuda',
            drop_last=False, # Don't drop last batch for evaluation
        )
        # Update cfg with actual H_in from dataset if not already set by train_ds
        if not hasattr(cfg, 'inferred_input_height'):
             cfg.inferred_input_height = valid_ds.H_in
             cfg.input_width = valid_ds.W_in


    if RUN_TRAIN or RUN_VALID or RUN_TEST:
        # Create the new model
        model = NetWithTransformer(cfg=cfg, backbone=cfg.backbone).to(cfg.device)
        print(f"[main loop]- Model created with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters.")

        if cfg.ema and RUN_TRAIN: # Only initialize EMA if training is enabled
            print("[main loop]- Initializing EMA model..")
            ema_model = ModelEMA(
                model,
                decay=cfg.ema_decay,
                device=cfg.device, # EMA model on the same device as the main model
            )
        else:
            ema_model = None

        criterion = nn.L1Loss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)


        best_loss= 1_000_000 # Initialize with a high value
        # Initialize val_loss for logging on epoch 0 if validation is run
        val_loss = float('inf') if RUN_VALID else -1.0 # Use infinity if validating, else placeholder


        print(f"[main loop]- Starting training for {cfg.epochs} epochs...")

        for epoch in range(1, cfg.epochs + 1): # Start from epoch 1
            tstart= time.time()

            # Train loop
            if RUN_TRAIN:
                model.train()
                total_loss = []
                train_loop = tqdm(train_dl, disable=cfg.local_rank != 0, desc=f"Epoch {epoch} Training")
                for i, (x, y) in enumerate(train_loop):
                    x = x.to(cfg.device)
                    y = y.to(cfg.device) # y is (B, 70, 70)

                    logits = model(x) # logits is (B, 1, 70, 70)

                    # Criterion expects (B, 1, H, W) and (B, 1, H, W) or (B, H, W)
                    # Add channel dim to y (B, 70, 70) -> (B, 1, 70, 70)
                    loss = criterion(logits, y.unsqueeze(1))

                    loss.backward()

                    # UD tracking - Removed for simplicity in this version
                    # lr = optimizer.param_groups[0]['lr']
                    # with torch.no_grad():
                    #    pass # UD tracking omitted

                    optimizer.step()
                    optimizer.zero_grad()

                    total_loss.append(loss.item())

                    if ema_model is not None:
                        ema_model.update(model)
                    
                    # Update tqdm description
                    train_loop.set_postfix(loss=np.mean(total_loss[-cfg.logging_steps:] if len(total_loss) >= cfg.logging_steps else total_loss)) # Removed lr from postfix


                avg_train_loss = np.mean(total_loss)
                print(f"\nEpoch {epoch} Train Loss: {avg_train_loss:.4f}")


            # Validation loop
            if RUN_VALID:
                model.eval()
                val_logits = []
                val_targets = []
                valid_loop = tqdm(valid_dl, disable=cfg.local_rank != 0, desc=f"Epoch {epoch} Validation")
                with torch.no_grad():
                    for x, y in valid_loop:
                        x = x.to(cfg.device)
                        y = y.to(cfg.device)

                        # Use EMA model if available, otherwise use main model
                        current_model = ema_model.module if ema_model is not None else model
                        out = current_model(x) # This includes TTA (flip average)

                        val_logits.append(out.cpu())
                        val_targets.append(y.cpu())

                    val_logits= torch.cat(val_logits, dim=0) # (N_samples_val, 1, 70, 70)
                    val_targets= torch.cat(val_targets, dim=0) # (N_samples_val, 70, 70)

                    # Criterion expects (B, 1, H, W) and (B, 1, H, W) or (B, H, W)
                    val_loss = criterion(val_logits, val_targets.unsqueeze(1)).item() # Add channel dim to targets

                    print(f"Epoch {epoch} Validation Loss: {val_loss:.4f}")

                    # Save best model (if validation loss is lower)
                    if val_loss < best_loss:
                        best_loss = val_loss
                        print(f"Validation loss improved. Saving model to 'best_model.pth'.")
                        # Save the state dict of the main model (or EMA model)
                        save_model = ema_model.module if ema_model is not None else model
                        # Save to a fixed name 'best_model.pth' to easily load the best one
                        torch.save(save_model.state_dict(), f"best_model.pth")
                        cfg.early_stopping['streak'] = 0 # Reset streak
                    else:
                        cfg.early_stopping['streak'] += 1 # Increment streak
                        print(f"Validation loss did not improve. Early stopping streak: {cfg.early_stopping['streak']}/{cfg.early_stopping['patience']}.")


            t_epoch = time.time() - tstart
            print(f"Epoch {epoch} finished in {format_time(t_epoch)}.")

            # Early Stopping check
            if RUN_VALID and cfg.early_stopping['streak'] >= cfg.early_stopping['patience']:
                 print(f"Early stopping triggered after {cfg.early_stopping['patience']} epochs without improvement.")
                 break # Exit training loop

        print("Training finished.")


    # --- Testing (Evaluation on validation/test set) ---
    # Use RUN_TEST flag to control final evaluation
    if RUN_TEST:
        print("\nRunning final evaluation...")
        # Load the best model if it was saved during validation
        test_model = NetWithTransformer(cfg=cfg, backbone=cfg.backbone).to(cfg.device)
        best_model_path = "best_model.pth" # Use the fixed name
        if os.path.exists(best_model_path):
            print(f"Loading best model from {best_model_path}")
            try:
                test_model.load_state_dict(torch.load(best_model_path, map_location=cfg.device))
            except RuntimeError as e:
                print(f"Error loading model state_dict: {e}", file=sys.stderr)
                print("This might happen if model architecture or saved state_dict mismatch.", file=sys.stderr)
                # Attempt partial load or skip loading if critical error
                print("Attempting to load without strict matching...", file=sys.stderr)
                try:
                    test_model.load_state_dict(torch.load(best_model_path, map_location=cfg.device), strict=False)
                    print("Partial load successful.")
                except Exception as e_partial:
                    print(f"Partial load failed: {e_partial}", file=sys.stderr)
                    print("Proceeding with randomly initialized model (or last epoch if no best saved).", file=sys.stderr)

        else:
            print("No best model found at 'best_model.pth'. Using the last state from training (if RUN_TRAIN was true).")
            # If training didn't save 'best_model.pth' (e.g. RUN_VALID=False),
            # the 'model' variable holds the last trained state.
            # If RUN_TRAIN was also false, this is a fresh model.

        test_model.eval()
        test_logits = []
        test_targets = []
        # Use valid_dl for testing for this example, assuming it's the evaluation set
        test_loop = tqdm(valid_dl, disable=cfg.local_rank != 0, desc="Evaluating")
        with torch.no_grad():
            for x, y in test_loop:
                 x = x.to(cfg.device)
                 y = y.to(cfg.device)

                 # Model forward includes TTA (flip average) when not in training mode
                 out = test_model(x)

                 test_logits.append(out.cpu())
                 test_targets.append(y.cpu())

        test_logits = torch.cat(test_logits, dim=0) # (N_samples_test, 1, 70, 70)
        test_targets = torch.cat(test_targets, dim=0) # (N_samples_test, 70, 70)

        # Calculate final loss on the test set
        final_test_loss = criterion(test_logits, test_targets.unsqueeze(1)).item()
        print(f"\nFinal Evaluation Loss: {final_test_loss:.4f}")

        # Optional: Save predictions or visualize results
        # print("Saving predictions (optional)...")
        # np.save("test_predictions.npy", test_logits.numpy())
        # np.save("test_targets.npy", test_targets.numpy())

    print("Script finished.")


[main loop]- Using device: cpu
Loading train data using mmap_mode='r'...


Loading train data (mmap): 100%|██████████| 1/1 [00:00<00:00, 294.34it/s]

Finished loading 1 out of 1 file pairs successfully for train mode.
Dataset initialized in train mode.
Loaded 1 file pairs containing a total of 500 raw samples.
Input shape per single slice: (5, 1000, 70)
Output label shape: (70, 70)
Window size for stacking: 5 slices (padding 2 on each side).
Generated 100 effective samples for training/validation after considering windowing and subsampling.
[NetWithTransformer - init / num_input_slices] : 5
[NetWithTransformer - init / self.H_in] : 1000
[NetWithTransformer - init / self.W_in] : 70





ValueError: 'default' is not a valid Format

In [None]:
-