In [None]:
""" 

はい、ご提示いただいたConvNeXtベースのモデルに、waveformの時系列情報（T=1000）とGeoPhoneの空間情報（C=70）をより効果的に取り込むためのTransformer要素を組み込んだハイブリッドモデルのコードを生成します。

現在のモデルは、入力waveform `(B, C=70, T=1000, W=70)` から特定のタイムステップにおける2Dスライス `(C=70, W=70)`（モデル側では `(H=70, W=70)` の2D画像として扱われているもの）を取り出して2D ConvNetに入力しているようです。これでは時系列情報（T=1000）のほとんどが失われてしまいます。また、GeoPhone数（C=70）は本来チャネルとして扱われるべきですが、現在のConvNeXt Encoderは `in_chans=5` となっており、入力データのチャネル数と一致していません。

ユーザーの要望「時系列や７０か所のGeoPhoneのずれの情報も取込」を真に満たすには、Datasetから waveform の3D情報全体（または適切な時系列ウィンドウ）を取得し、モデル側で3Dデータとして処理する必要があります。

ここでは、以下の構造を持つハイブリッドモデルを提案・実装します。
1.  **Datasetの変更**: 各ファイルから waveform 全体 `(T=1000, C=70, W=70)` （これを `(D=1000, H=70, W=70)` と解釈し、チャネルを1とします `(1, 1000, 70, 70)`）と、対応する2D地層図 `(H=70, W=70)` ラベルを返すように修正します。
2.  **3D Conv Encoder**: 入力の3D waveform `(B, 1, 1000, 70, 70)` に対して、3D Convolutional層を重ねて特徴抽出し、時間・空間次元をダウンサンプリングします。
3.  **Transformer Bottleneck**: 3D Conv Encoderの出力である圧縮された3D特徴マップをflattenしてシーケンスに変換し、Transformer Encoderで処理します。これにより、グローバルな時空間的関係性を捉えます。
4.  **Projection to 2D**: Transformerの出力を、2D Decoderが受け付けられる2D特徴マップの形状に変換します（例: 時間方向のPooling）。
5.  **2D Decoder**: 既存の `UnetDecoder2d` ライクな構造を再利用し、Bottleneckからの2D特徴マップと、3D Conv Encoderの途中からのSkip Connection（これも2Dに投影）を使って、最終的な2D地層図 `(70, 70)` を生成します。
6.  **Segmentation Head**: Decoderの出力を最終的な出力次元に変換します。

これにより、 waveform の時間情報（T=1000）と空間情報（GeoPhoneの配置C=70、Spatial W=70）を3Dデータとして扱い、Conv3Dでローカルな特徴を、Transformerでグローバルな特徴を学習できます。

Datasetの修正と、それに合わせた新しい `HybridModel` クラスのコードを以下に示します。既存のトレーニングループと評価ループは、新しいモデルクラスとDatasetを使用するように調整します。

"""

In [None]:
import random
import os
import time, glob
import numpy as np
import sys # Add sys for stderr
import datetime
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# Removed torch.distributed and torch.cuda.amp based on user's setup

from tqdm import tqdm
# from _cfg import cfg # Replace with SimpleNamespace
from types import SimpleNamespace

In [None]:
# Optional third-party dependencies. If an import fails, try a quiet
# `pip install --no-deps <pkg>` and then import again.
# NOTE(review): the retry imports inside the `except` branches are unguarded,
# so a second failure propagates ImportError and stops the cell.
# `except ImportError` also catches "cannot import name ..." errors (a missing
# *name* in an already-installed package); reinstalling cannot fix that, and
# the retry fails the same way. `--no-deps` means missing transitive
# dependencies are not installed either.
try:
    import timm
    # from timm.models.layers import DropPath, trunc_normal_ # DropPath requires full timm or specific import
    # NOTE(review): newer timm releases expose this as `timm.layers.trunc_normal_`;
    # confirm the `timm.models.layers` path still resolves for the installed version.
    from timm.models.layers import trunc_normal_ # Assuming trunc_normal_ is enough
except ImportError:
    print("timm not found, installing...")
    os.system("pip install --no-deps timm -q")
    import timm
    from timm.models.layers import trunc_normal_

try:
    import monai
    # NOTE(review): confirm that ConvAct, Act and Norm are exported by
    # `monai.networks.blocks` in the installed MONAI version (Act/Norm look like
    # the layer factories MONAI documents under `monai.networks.layers`).
    # If any name is missing, this line raises ImportError and the retry below
    # fails identically.
    from monai.networks.blocks import UpSample, SubpixelUpsample, ConvAct, Act, Norm
    # NOTE(review): likewise confirm where PatchEmbed is exported from (possibly
    # `monai.networks.blocks`); it is not referenced in the code visible here.
    from monai.networks.layers import PatchEmbed # Might not need PatchEmbed directly, but related concepts
except ImportError:
    print("monai not found, installing...")
    os.system("pip install --no-deps monai -q")
    import monai
    from monai.networks.blocks import UpSample, SubpixelUpsample, ConvAct, Act, Norm
    from monai.networks.layers import PatchEmbed

In [None]:
# Use SimpleNamespace for cfg based on the user's snippet
def pair_file_paths(data_glob, label_glob):
    """Pair waveform files with label files matched by two glob patterns.

    Both match lists are sorted and paired position by position, so the i-th
    data file goes with the i-th label file.

    Raises:
        ValueError: if the two patterns match a different number of files.
            Plain ``zip`` would silently drop the surplus and shift every
            later pair onto the wrong label.
    """
    data_files = sorted(glob.glob(data_glob))
    label_files = sorted(glob.glob(label_glob))
    if len(data_files) != len(label_files):
        raise ValueError(
            f"Found {len(data_files)} data files for {data_glob!r} but "
            f"{len(label_files)} label files for {label_glob!r}; "
            "cannot pair them by sort order."
        )
    return list(zip(data_files, label_files))


cfg = SimpleNamespace()
cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg.local_rank = 0 # Assume single GPU or CPU
cfg.seed = 123
cfg.subsample = None # Set to None to use all data, or e.g., 100 for a small subset

data_paths_str = "./datasetfiles/FlatVel_A/data/*.npy"
label_paths_str = "./datasetfiles/FlatVel_A/model/*.npy"
cfg.file_pairs = pair_file_paths(data_paths_str, label_paths_str)

# Model parameters for the new HybridModel
cfg.in_channels = 1 # Input waveform channel (value)
cfg.spatial_size = (1000, 70, 70) # (D, H, W) - Time, GeoPhone Locations, Spatial Width
cfg.encoder_dims = (32, 64, 128, 256) # Channels for 3D Conv stages
cfg.embed_dim = 768 # Dimension for Transformer
cfg.num_heads = 8 # Attention heads in Transformer
cfg.mlp_ratio = 4. # MLP ratio in Transformer
cfg.transformer_depth = 6 # Number of Transformer blocks
cfg.decoder_channels = (256, 128, 64, 32) # Output channels of 2D Decoder blocks (4 blocks)
cfg.decoder_attention_type = "scse" # For 2D Decoder
cfg.upsample_mode = "transpose" # For 2D Decoder

cfg.ema = True
cfg.ema_decay = 0.99

cfg.epochs = 4
cfg.batch_size = 8
cfg.batch_size_val = 8

cfg.early_stopping = {"patience": 3, "streak": 0}
cfg.logging_steps = 10


def set_seed(seed=None):
    """Seed Python's, NumPy's and PyTorch's global random number generators.

    Args:
        seed: Integer seed. ``None`` (the default) means "use ``cfg.seed``".
            It is looked up when the function is called, not when it is
            defined, so later edits to ``cfg.seed`` take effect.
    """
    if seed is None:
        seed = cfg.seed
    random.seed(seed)
    # Only affects Python processes started after this point; the running
    # interpreter's str/bytes hash randomization is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the generators of all devices, CUDA included.
    torch.manual_seed(seed)

In [None]:
# --- CustomDataset Modification ---
class CustomDataset(torch.utils.data.Dataset):
    """Dataset yielding one ``(waveform, label)`` sample per file pair.

    Every data/label ``.npy`` file is memory-mapped once at construction.
    ``__getitem__`` reads one whole file, augments it in train mode, and
    returns float tensors ``(x, y)``:

    - ``x`` has a channel axis prepended for ``nn.Conv3d``, i.e.
      ``(1, *data.shape)`` (expected ``(1, 1000, 70, 70)``).
    - ``y`` is the label array (expected ``(70, 70)``).

    Args:
        cfg: Namespace providing ``local_rank`` (progress bar / log output
            only on rank 0). Optional fields: ``subsample`` caps the number
            of files used; ``temporal_flip`` (default ``False``) enables
            time-reversal augmentation.
        file_pairs: Sequence of ``(data_path, label_path)`` tuples.
        mode: ``"train"`` enables random augmentation; any other value
            returns the arrays unchanged.
    """

    def __init__(
        self,
        cfg,
        file_pairs,
        mode="train",
    ):
        self.cfg = cfg
        self.mode = mode
        self.file_pairs = file_pairs

        # Memory-mapped arrays, index-aligned: data[i] belongs with labels[i].
        self.data, self.labels = self._load_data_arrays()

        # One sample per successfully loaded file pair, optionally capped by
        # cfg.subsample. A falsy subsample (None or 0) means "use all files".
        total_files_available = len(self.data)
        subsample = getattr(self.cfg, "subsample", None)
        self.total_samples = min(subsample, total_files_available) if subsample else total_files_available

        # Identity map from sample index to file index (first total_samples files).
        self.index_map = list(range(self.total_samples))

    def _load_data_arrays(self):
        """Memory-map every file pair and return ``(data_arrays, label_arrays)``.

        A pair that fails to load is reported on stderr and skipped as a
        unit, so the two returned lists stay index-aligned.
        """
        data_arrays = []
        label_arrays = []
        mmap_mode = "r"

        for data_fpath, label_fpath in tqdm(
                        self.file_pairs, desc=f"Loading {self.mode} data (mmap)",
                        disable=self.cfg.local_rank != 0):
            try:
                arr = np.load(data_fpath, mmap_mode=mmap_mode)
                lbl = np.load(label_fpath, mmap_mode=mmap_mode)
            except FileNotFoundError:
                print(f"Error: File not found - {data_fpath} or {label_fpath}", file=sys.stderr)
                continue
            except Exception as e:
                # Best-effort loading: report and skip unreadable pairs.
                print(f"Error loading file pair: {data_fpath}, {label_fpath}", file=sys.stderr)
                print(f"Error: {e}", file=sys.stderr)
                continue
            data_arrays.append(arr)
            label_arrays.append(lbl)

        if self.cfg.local_rank == 0:
            print(f"Finished loading {len(data_arrays)} file pairs for {self.mode} mode.") # Avoid spamming

        return data_arrays, label_arrays

    def __getitem__(self, idx):
        file_idx = self.index_map[idx]

        x = self.data[file_idx]  # expected (T, C, W) = (1000, 70, 70)
        y = self.labels[file_idx]  # expected (H, W) = (70, 70)

        if self.mode == "train":
            # Time reversal is opt-in. Axis 0 is recording time, per the shape
            # comments above. Reversing it yields a trace whose arrival order no
            # real shot produces, while y stays the same velocity model. That
            # trains the network to ignore arrival timing, which is the
            # information the 3D/Transformer model is meant to exploit.
            if getattr(self.cfg, "temporal_flip", False) and np.random.random() < 0.5:
                x = x[::-1, ...]

            # Mirror the last data axis together with the label's width axis.
            # This assumes both describe the same horizontal positions (TODO confirm).
            if np.random.random() < 0.5:
                x = x[..., ::-1]
                y = y[:, ::-1]

        # .copy() materializes the (possibly read-only, negatively strided)
        # mmap view so torch.from_numpy gets a writable, contiguous array.
        # unsqueeze(0) adds the single input channel Conv3d expects: (1, D, H, W).
        x_tensor = torch.from_numpy(x.copy()).float().unsqueeze(0)
        y_tensor = torch.from_numpy(y.copy()).float()

        return x_tensor, y_tensor

    def __len__(self):
        return self.total_samples

In [None]:
# --- Transformer Block Definition ---
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        # Using nn.MultiheadAttention. batch_first=True is crucial.
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop, bias=qkv_bias, batch_first=True)
        self.drop_path = nn.Identity() # Replace with DropPath if timm is available
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            act_layer(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop)
        )

    def forward(self, x):
        # x is (B, SeqLen, Dim)
        residual = x
        x = self.norm1(x)
        # MultiheadAttention needs query, key, value. Self-attention uses x for all three.
        # attn returns output, weights. We only need output [0].
        x = self.attn(x, x, x)[0]
        x = residual + self.drop_path(x)

        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = residual + self.drop_path(x)

        return x

# --- Helper function to calculate output size of convolution ---
def calc_conv_out_size(size_in, kernel_size, stride, padding, dilation=1):
    """Return a convolution's output length along one spatial axis.

    Implements the PyTorch ConvNd shape rule
    ``floor((size_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``.
    With the default ``dilation=1`` this is
    ``(size_in + 2*padding - kernel_size) // stride + 1``.
    """
    # Span of input positions covered by one (possibly dilated) kernel.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size_in + 2 * padding - effective_kernel) // stride + 1


# --- Hybrid Model Class ---
class HybridModel(nn.Module):
    def __init__(
        self,
        in_channels=1,
        spatial_size=(1000, 70, 70), # (D, H, W) input shape
        encoder_dims=(32, 64, 128, 256), # Output channels of 3D Conv stages
        embed_dim=768, # Dimension for Transformer
        num_heads=8, # Attention heads in Transformer
        mlp_ratio=4., # MLP ratio in Transformer
        transformer_depth=6, # Number of Transformer blocks
        decoder_channels: tuple = (256, 128, 64, 32), # Output channels of 2D Decoder blocks
        decoder_attention_type: str = "scse", # For 2D Decoder
        upsample_mode: str = "transpose", # For 2D Decoder UpSample layers
    ):
        super().__init__()

        self.in_channels = in_channels
        self.spatial_size = spatial_size # (D, H, W)
        self.encoder_dims = encoder_dims
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.transformer_depth = transformer_depth
        self.decoder_channels = decoder_channels

        self.num_encoder_stages = len(encoder_dims)
        self.conv_stages = nn.ModuleList()
        self._skip_3d_spatial_sizes = [] # Store (D, H, W) size after each 3D conv stage

        # --- 3D Conv Encoder Stages ---
        # Downsample (D, H, W) dimensions
        in_dim = in_channels
        current_spatial_size = list(spatial_size) # [D, H, W]

        # Example 3D Conv configurations (kernel, stride, padding)
        # Adjust these based on desired downsampling and output sizes
        conv3d_configs = [
            ((4, 4, 4), (4, 4, 4), (2, 2, 2)), # Stage 1: Aggressive downsampling
            ((2, 2, 2), (2, 2, 2), (1, 1, 1)), # Stage 2
            ((2, 2, 2), (2, 2, 2), (1, 1, 1)), # Stage 3
            ((2, 2, 2), (2, 2, 2), (1, 1, 1)), # Stage 4
        ]
        assert len(conv3d_configs) == self.num_encoder_stages

        for i in range(self.num_encoder_stages):
            k, s, p = conv3d_configs[i]
            out_dim = encoder_dims[i]

            self.conv_stages.append(
                nn.Sequential(
                    nn.Conv3d(in_dim, out_dim, kernel_size=k, stride=s, padding=p),
                    nn.GroupNorm(out_dim // 8, out_dim), # Using GroupNorm for 3D
                    nn.GELU(),
                )
            )
            in_dim = out_dim

            # Calculate and store output spatial size
            current_spatial_size = [ calc_conv_out_size(sz, k_i, s_i, p_i) for sz, k_i, s_i, p_i in zip(current_spatial_size, k, s, p) ]
            self._skip_3d_spatial_sizes.append(tuple(current_spatial_size))
            # print(f"3D Conv Stage {i} output size (D, H, W): {current_spatial_size}")

        # Final 3D feature map size after Conv stages
        self.bottleneck_spatial_size_3d = self._skip_3d_spatial_sizes[-1]
        # print(f"Bottleneck 3D size: {self.bottleneck_spatial_size_3d}")
        D_b, H_b, W_b = self.bottleneck_spatial_size_3d


        # --- Transformer Bottleneck ---
        # Flatten the spatial and temporal dimensions of the bottleneck feature map
        # Sequence length will be D_b * H_b * W_b
        sequence_length = D_b * H_b * W_b
        transformer_input_channels = encoder_dims[-1] # Channels from last 3D Conv stage

        # Project channels to embed_dim if necessary
        self.transformer_input_proj = nn.Linear(transformer_input_channels, embed_dim) if transformer_input_channels != embed_dim else nn.Identity()

        # Positional Embedding (learnable)
        self.pos_embed = nn.Parameter(torch.zeros(1, sequence_length, embed_dim))
        trunc_normal_(self.pos_embed, std=.02)

        # Transformer Encoder Layers
        norm_layer = nn.LayerNorm # Use LayerNorm for Transformer
        act_layer = nn.GELU # Use GELU for Transformer
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=False, drop=0., attn_drop=0., drop_path=0., # Set dropout/droppath
                norm_layer=norm_layer, act_layer=act_layer
            ) for i in range(transformer_depth)])

        self.transformer_norm = norm_layer(embed_dim)


        # --- Projection from Transformer output to 2D Decoder input ---
        # Transformer output shape: (B, sequence_length, embed_dim)
        # Reshape to (B, embed_dim, D_b, H_b, W_b) and pool over time dimension D_b
        # to get (B, embed_dim, H_b, W_b) for 2D Decoder input (bottleneck feature)
        self.bottleneck_projection_3d_to_2d = nn.Sequential(
             # Reshape is done in forward pass
             # Average pool over the Time dimension (dim 2 after permute)
             # nn.AdaptiveAvgPool3d((1, H_b, W_b)) # Or just mean over dim 2
             # Squeeze time dimension (dim 2)
        )
        self.bottleneck_2d_channels = embed_dim # Output channels of the 2D bottleneck feature


        # --- 2D Skip Connection Projection Layers ---
        # Project 3D features from intermediate Conv stages to 2D features for skip connections
        # Decoder needs skips from levels 1, 2, 3 (corresponding to encoder_dims[0], [1], [2])
        # These skips need to match spatial size (H, W) and channel count required by the decoder blocks.
        # The decoder_channels tuple (256, 128, 64, 32) implies 4 decoder blocks.
        # Decoder block 0 takes bottleneck + skip from level 3.
        # Decoder block 1 takes output of block 0 + skip from level 2.
        # Decoder block 2 takes output of block 1 + skip from level 1.
        # Decoder block 3 takes output of block 2 + no skip (or 0 channels skip).

        # Channels expected for skips by the decoder (matching decoder_channels[1:])
        # Target channels for skips: decoder_channels[1], decoder_channels[2], decoder_channels[3]
        target_skip_2d_channels = list(decoder_channels[1:]) # [128, 64, 32]

        self.skip_projection_layers = nn.ModuleList()
        # Iterate through 3D Conv stages outputs for skip connections (excluding the last one)
        # We need skips from stages 0, 1, 2 (corresponding to encoder_dims[0,1,2])
        # The spatial sizes are self._skip_3d_spatial_sizes[0, 1, 2]
        # Pair 3D skip info with target 2D decoder skip channels
        # Skips for decoder are typically from finer to coarser features.
        # Decoder Block 0 needs skip from Conv3D Stage 3 -> projected 2D
        # Decoder Block 1 needs skip from Conv3D Stage 2 -> projected 2D
        # Decoder Block 2 needs skip from Conv3D Stage 1 -> projected 2D
        # This aligns with skip_features_3d [stage0, stage1, stage2] and target_skip_2d_channels [128, 64, 32] (reversed).

        # Iterate through 3D Conv stages for skips (from coarsest to finest)
        # Use encoder_dims[:-1] for input channels and _skip_3d_spatial_sizes[:-1] for sizes
        # Use target_skip_2d_channels for output channels (reversed for decoder)
        skip_3d_info = list(zip(encoder_dims[:-1], self._skip_3d_spatial_sizes[:-1])) # [(32, sz0), (64, sz1), (128, sz2)]

        # Pair 3D skip info (reversed for decoder) with target 2D channels (reversed for decoder)
        skip_pairs = list(zip(skip_3d_info[::-1], target_skip_2d_channels)) # [((128, sz2), 128), ((64, sz1), 64), ((32, sz0), 32)]

        for (skip_c_3d, skip_sz_3d), target_2d_channels in skip_pairs:
            self.skip_projection_layers.append(
                 nn.Sequential(
                     nn.AdaptiveAvgPool3d((1, skip_sz_3d[1], skip_sz_3d[2])), # Pool time D_i to 1 -> (B, C_in, 1, H_i, W_i)
                     nn.Conv2d(skip_c_3d, target_2d_channels, kernel_size=1), # Project channels -> (B, C_out, H_i, W_i)
                 )
            )

        # --- 2D Decoder Setup for UnetDecoder2d ---
        # UnetDecoder2d expects:
        # encoder_channels: tuple of channels from encoder stages, starting *before* the first skip used.
        #   Effectively channels of the input list `feats`[::-1].
        #   feats list will be [Bottleneck_2D, Projected_Skip_Level_1_2D, Projected_Skip_Level_2_2D, Projected_Skip_Level_3_2D]
        #   Channels: [embed_dim, target_skip_2d_channels[2], target_skip_2d_channels[1], target_skip_2d_channels[0]]
        #   Example: [embed_dim, 32, 64, 128]
        decoder_enc_channels_for_init = [embed_dim] + target_skip_2d_channels[::-1]

        # skip_channels: tuple of channels for skip connections, starting from the first skip used.
        #   Effectively channels of `feats`[1:] + [0].
        #   Channels: [target_skip_2d_channels[2], target_skip_2d_channels[1], target_skip_2d_channels[0], 0]
        #   Example: [32, 64, 128, 0]
        decoder_skip_channels_for_init = target_skip_2d_channels[::-1] + [0]


        # Re-calculate decoder_channels if needed, ensure consistency with decoder blocks.
        # The decoder_channels tuple passed to UnetDecoder2d defines the *output* channels of its blocks.
        # The input channels to its blocks are derived from encoder_channels and skip_channels.
        # Let's use the provided decoder_channels tuple directly.
        # Ensure the lengths match: len(decoder_channels) == num_decoder_blocks.
        # Number of decoder blocks seems to be len(decoder_channels).
        # UnetDecoder2d internally creates len(decoder_channels) blocks.
        # It needs len(decoder_channels) + 1 items in `feats` (1 bottleneck, len(decoder_channels) skips).
        # This implies we need skips from all `num_encoder_stages` levels *before* the bottleneck.
        # With 4 encoder stages, and last one is bottleneck, we need 3 skips.
        # If decoder_channels has length 4, UnetDecoder2d wants 4 skips + 1 bottleneck = 5 items in feats.
        # This means we need 4 skips from the encoder levels before the bottleneck.
        # This implies the bottleneck should come from stage 5 if encoder had 5 stages.
        # Or, UnetDecoder2d is designed for encoders with 1 more stage than decoder blocks.

        # Let's align the model structure. If decoder_channels is length 4, we need 4 decoder blocks.
        # Decoder needs 4 skips (from encoder stages 0, 1, 2, 3) + 1 bottleneck (from encoder stage 4).
        # This means the last 3D Conv stage (encoder_dims[-1]) should be the source of the bottleneck.
        # And encoder_dims[:-1] (stages 0, 1, 2) are sources for skips.
        # We seem to be missing one skip connection needed by the decoder if decoder_channels has length 4.

        # Let's assume the original decoder was meant to work with the 4 encoder stages provided in `encoder_dims`.
        # This would mean the decoder receives 1 bottleneck feature and 3 skip features.
        # UnetDecoder2d has `decoder_channels = (256, 128, 64, 32)`, length 4. It builds 4 blocks.
        # It expects `encoder_channels` (inputs to concat) and `skip_channels` (the skips themselves).
        # Its loop is `for i, b in enumerate(self.blocks): skip = feats[i] if i < len(feats) else None`
        # If feats has [bottleneck, skip1, skip2, skip3], it uses skip1 for block 0, skip2 for block 1, skip3 for block 2, and None for block 3.
        # This implies decoder_channels length should be len(feats) - 1 = 3 if we have 3 skips + 1 bottleneck.
        # But decoder_channels has length 4.
        # This means UnetDecoder2d with decoder_channels len 4 expects 4 skips + 1 bottleneck.

        # Re-interpreting the original UnetDecoder2d and how it's used in the base code:
        # base code: ecs = [_["num_chs"] for _ in self.backbone.feature_info][::-1] # 4 channels from ConvNeXt stages reversed
        # decoder = UnetDecoder2d(encoder_channels=ecs, ...)
        # decoder_channels = (256, 128, 64, 32) # len 4
        # UnetDecoder2d init: `in_channels= [encoder_channels[0]] + list(decoder_channels[:-1])`
        # `skip_channels= list(encoder_channels[1:]) + [0]`
        # This means UnetDecoder2d is initialized with 4 encoder_channels.
        # ecs example: [1024, 512, 256, 128] from ConvNeXt.
        # encoder_channels = (1024, 512, 256, 128)
        # decoder_channels = (256, 128, 64, 32)
        # in_channels_to_decoder_blocks = [1024] + [256, 128, 64] = [1024, 256, 128, 64]
        # skip_channels_for_decoder_init = [512, 256, 128] + [0] = [512, 256, 128, 0]
        # Decoder blocks loop: uses `zip(in_channels_to_blocks, skip_channels_for_init, decoder_channels)`
        # (1024, 512, 256), (256, 256, 128), (128, 128, 64), (64, 0, 32)
        # Block 0 input chans: 1024 (from prev block/bottleneck) + 512 (skip) -> concat -> 1024+512. Output 256.
        # Block 1 input chans: 256 (from prev block) + 256 (skip) -> concat -> 512. Output 128.
        # Block 2 input chans: 128 (from prev block) + 128 (skip) -> concat -> 256. Output 64.
        # Block 3 input chans: 64 (from prev block) + 0 (skip) -> concat -> 64. Output 32.

        # This structure requires `len(encoder_channels)` features as input `feats`, where feats[0] is bottleneck, feats[1:] are skips.
        # It uses `encoder_channels` for the *first* channel in the concatenation input to the blocks, and `skip_channels` for the *second*.
        # This is confusing. Let's assume the standard U-Net decoder input structure:
        # Feats list: [bottleneck_feature, skip_feature_level1, skip_feature_level2, ..., skip_feature_levelN]
        # Decoder has N blocks.
        # Block 0 input: concat(bottleneck, skip_feature_level1)
        # Block 1 input: concat(Output_Block0, skip_feature_level2)
        # ...
        # Block N-1 input: concat(Output_BlockN-2, skip_feature_levelN)

        # Re-aligning `UnetDecoder2d` to a more standard U-Net implementation structure.
        # If `decoder_channels` has length N, there are N decoder blocks.
        # Need 1 bottleneck feature and N skip features.
        # So `feats` list should have length N+1.
        # `encoder_channels` should be the list of channels of `feats`[::-1].
        # `skip_channels` should be the list of channels of `feats`[1:] + [0].

        # Let's match `decoder_channels` length (4) with the number of encoder stages (4).
        # This implies encoder stage 3 (idx 3) is bottleneck, and stages 0, 1, 2 are skips.
        # We need 4 skips if decoder has 4 blocks. This contradicts the 4 encoder stages -> 3 skips + 1 bottleneck idea.
        # Perhaps the decoder is meant to work with the 4 ConvNeXt stages, and takes 4 features (bottleneck + 3 skips).
        # If `decoder_channels` length is 4, and UnetDecoder2d builds 4 blocks, it needs 4 skips and 1 bottleneck.
        # This requires 5 levels of features from the encoder.

        # Let's assume the decoder_channels length defines the number of upsampling steps / resolution levels in the decoder.
        # If `decoder_channels` is (256, 128, 64, 32) -> 4 levels.
        # Decoder goes from lowest res up to highest res.
        # Low res input comes from bottleneck. Intermediate res inputs come from skips.
        # Upsampling factor needed: Initial H, W (70, 70) -> Final Bottleneck H, W (~4, 4) -> Target H, W (70, 70).
        # Total spatial downsample factor in encoder: ~17.5 (70->4).
        # Decoder needs total upsample factor ~17.5. 4 blocks with scale 2 gives 16.

        # Let's assume the provided `decoder_channels` (len 4) and the number of encoder stages (len 4) are designed to match, but maybe in a slightly non-standard U-Net way as per the original `UnetDecoder2d`.

        # Based on the original `UnetDecoder2d` usage:
        # `encoder_channels=ecs` (reversed ConvNeXt stages, len 4)
        # `skip_channels=None` initially in base code, then calculated as `list(encoder_channels[1:]) + [0]`
        # `decoder_channels=(256, 128, 64, 32)` (len 4)
        # This implies the decoder takes 4 input features (from encoder stages) and processes them through 4 blocks.

        # Adapting this to our Hybrid model:
        # Our 3D Conv stages provide features at 4 levels (encoder_dims).
        # We need to convert these 4 levels of 3D features into 4 levels of 2D features for the decoder.
        # Level 4 (finest 3D): embed_dim channels, (D_b, H_b, W_b) size. -> Pool time -> (embed_dim, H_b, W_b) 2D bottleneck.
        # Level 3 (3D): encoder_dims[2] channels, size (_skip_3d_spatial_sizes[2]). -> Project to 2D (target chan), Pool time -> (target chan, H_3, W_3) 2D skip.
        # Level 2 (3D): encoder_dims[1] channels, size (_skip_3d_spatial_sizes[1]). -> Project to 2D (target chan), Pool time -> (target chan, H_2, W_2) 2D skip.
        # Level 1 (3D): encoder_dims[0] channels, size (_skip_3d_spatial_sizes[0]). -> Project to 2D (target chan), Pool time -> (target chan, H_1, W_1) 2D skip.

        # We need to match the channels and spatial sizes for these 4 levels of 2D features to what UnetDecoder2d expects.
        # The original base code uses ConvNeXt features reversed. Let's assume our 3D Conv + Transformer features correspond to these.
        # If original ConvNeXt features (reversed) are [1024, 512, 256, 128].
        # Our features should mimic this relative scaling if possible.
        # Our levels: Bottleneck_2D (embed_dim=768), Projected Skip 3 (?), Skip 2 (?), Skip 1 (?)

        # Let's make the projected skip channels match the decoder_channels inputs roughly.
        # If decoder_channels = (256, 128, 64, 32) are output channels of blocks.
        # Input channels to blocks are sum of prev output + skip.
        # Let's simplify and assume UnetDecoder2d expects 4 input features, and we need to project our 4 levels to match expected channels.
        # Let's try to match the number of channels in `encoder_channels` tuple passed to UnetDecoder2d.
        # Original `ecs` had length 4. So `encoder_channels` param has length 4.
        # Let's define the channels for our 4 levels of 2D features:
        # Level 4 (bottleneck): embed_dim
        # Level 3: encoder_dims[2] -> project to D1=256?
        # Level 2: encoder_dims[1] -> project to D2=128?
        # Level 1: encoder_dims[0] -> project to D3=64?

        # Let's redefine the skip projection target channels to match the decoder_channels tuple directly.
        # Skip projection target channels: decoder_channels[0], decoder_channels[1], decoder_channels[2]
        # Skip from Stage 3 (encoder_dims[2]): project to decoder_channels[0] (256)
        # Skip from Stage 2 (encoder_dims[1]): project to decoder_channels[1] (128)
        # Skip from Stage 1 (encoder_dims[0]): project to decoder_channels[2] (64)

        target_skip_2d_channels_for_proj = list(decoder_channels[:-1]) # [256, 128, 64]

        self.skip_projection_layers = nn.ModuleList()
        # Iterate through 3D Conv stages 0, 1, 2 for skips
        # Use encoder_dims[:-1] for input channels and _skip_3d_spatial_sizes[:-1] for sizes
        skip_3d_info = list(zip(encoder_dims[:-1], self._skip_3d_spatial_sizes[:-1])) # [(32, sz0), (64, sz1), (128, sz2)]

        # Pair 3D skip info with target 2D channels for projection
        skip_proj_pairs = list(zip(skip_3d_info, target_skip_2d_channels_for_proj)) # [((32, sz0), 256), ((64, sz1), 128), ((128, sz2), 64)]

        for (skip_c_3d, skip_sz_3d), target_2d_channels in skip_proj_pairs:
            self.skip_projection_layers.append(
                 nn.Sequential(
                     nn.AdaptiveAvgPool3d((1, skip_sz_3d[1], skip_sz_3d[2])), # Pool time D_i to 1 -> (B, C_in, 1, H_i, W_i)
                     nn.Conv2d(skip_c_3d, target_2d_channels, kernel_size=1), # Project channels -> (B, C_out, H_i, W_i)
                 )
            )

        # Now define the 4 features feeding the decoder.
        # Feats list: [Bottleneck_2D, Projected_Skip_Level_3_2D, Projected_Skip_Level_2_2D, Projected_Skip_Level_1_2D]
        # Channels: [embed_dim, target_skip_2d_channels_for_proj[0], target_skip_2d_channels_for_proj[1], target_skip_2d_channels_for_proj[2]]
        # Example: [embed_dim, 256, 128, 64]

        # Define encoder_channels and skip_channels for UnetDecoder2d init:
        # encoder_channels (inputs to concat, from feats[::-1]): [target_skip_2d_channels_for_proj[2], target_skip_2d_channels_for_proj[1], target_skip_2d_channels_for_proj[0], embed_dim]
        # Example: [64, 128, 256, embed_dim]
        decoder_enc_channels_for_init = target_skip_2d_channels_for_proj[::-1] + [embed_dim]

        # skip_channels (the skips themselves, from feats[1:] + [0]): [target_skip_2d_channels_for_proj[0], target_skip_2d_channels_for_proj[1], target_skip_2d_channels_for_proj[2], 0]
        # Example: [256, 128, 64, 0]
        decoder_skip_channels_for_init = target_skip_2d_channels_for_proj + [0]

        # print("UnetDecoder2d encoder_channels (init):", decoder_enc_channels_for_init)
        # print("UnetDecoder2d skip_channels (init):", decoder_skip_channels_for_init)
        # print("UnetDecoder2d decoder_channels (init):", decoder_channels)

        # Determine scale factors for the decoder
        # Decoder goes from bottleneck spatial size (H_b, W_b) to target (70, 70).
        # Number of decoder blocks = len(decoder_channels) = 4.
        # Total spatial upsample factor needed = (70 / H_b, 70 / W_b).
        # Example: (70 / 4, 70 / 4) = (17.5, 17.5) if H_b, W_b = 4.
        # Average upsample factor per block = (17.5)**(1/4) ~ 2.04
        # Use float scale factors in MONAI UpSample

        H_b, W_b = self.bottleneck_spatial_size_3d[1:] # H, W from the bottleneck 3D features
        num_decoder_blocks = len(decoder_channels)
        final_decoder_output_H = H_b * (2**num_decoder_blocks) # Assuming scale_factor=2 for all blocks
        final_decoder_output_W = W_b * (2**num_decoder_blocks)
        # Example: 4 * (2^4) = 4 * 16 = 64. So from (4,4) to (64,64) with scale 2.

        # We need to reach (70, 70).
        # Option 1: Calculate scale factors to directly reach 70. (70/4)**(1/4) for each step?
        # Scale factor for block i should upsample from spatial_i to spatial_i-1.
        # E.g., H_b -> H_b*s_1 -> H_b*s_1*s_2 ... -> 70.
        # Total scale = Product(s_i) = 70 / H_b.
        # If s_i are all equal, s = (70 / H_b)**(1/num_decoder_blocks)
        # Let's calculate scale factors for each stage to reach 70x70 from bottleneck H_b, W_b
        # The spatial size after N Conv3D stages is _skip_3d_spatial_sizes[-1]. Let's call this (D_last, H_last, W_last).
        # Decoder input spatial size is (H_last, W_last).
        # Decoder output size should be (70, 70).
        # Number of decoder blocks is len(decoder_channels).
        # Total spatial upsampling factor required is (70/H_last, 70/W_last).
        # Let's assume the decoder upsamples by a factor of 2 in each block for simplicity and match the original code.
        # This results in (H_last * 2^N, W_last * 2^N) spatial size after the decoder.
        # We need to use interpolation in the final head to reach exactly (70, 70).

        # Decoder spatial size sequence (if starting from H_last, W_last with scale 2 per block):
        # (H_last, W_last) -> (H_last*2, W_last*2) -> (H_last*4, W_last*4) -> (H_last*8, W_last*8) -> (H_last*16, W_last*16)
        # Example: (4, 4) -> (8, 8) -> (16, 16) -> (32, 32) -> (64, 64)

        decoder_scale_factors = [2] * num_decoder_blocks # Use integer scale factors for decoder blocks

        self.decoder = UnetDecoder2d(
            encoder_channels=decoder_enc_channels_for_init,
            skip_channels=decoder_skip_channels_for_init,
            decoder_channels=decoder_channels,
            scale_factors=decoder_scale_factors,
            norm_layer=monai.networks.blocks.Norm("INSTANCE", spatial_dims=2), # Use MONAI Norm wrapper
            attention_type=decoder_attention_type,
            upsample_mode=upsample_mode,
        )
        # Need to adjust norm_layer definition for UnetDecoder2d

        # --- Segmentation Head ---
        # Decoder last output shape: (B, decoder_channels[-1], H_dec_out, W_dec_out)
        # H_dec_out, W_dec_out = (self.bottleneck_spatial_size_3d[1] * (2**num_decoder_blocks),
        #                         self.bottleneck_spatial_size_3d[2] * (2**num_decoder_blocks))
        # Example: (4 * 16, 4 * 16) = (64, 64)

        # Target shape: (B, 1, 70, 70)
        # Need to go from (H_dec_out, W_dec_out) to (70, 70) and change channels.
        final_upsample_scale = (70 / final_decoder_output_H,
                                70 / final_decoder_output_W)
        # Example: (70/64, 70/64) = (1.09375, 1.09375)

        self.seg_head = SegmentationHead2d(
            in_channels=decoder_channels[-1],
            out_channels=1,
            scale_factor=final_upsample_scale, # Use calculated float scale for interpolation
            kernel_size=3,
            mode="nontrainable", # Use interpolation mode
        )

        # Final activation/scaling - handled in forward pass for test-time aug


        # Initialize weights (optional, but good practice)
        self._init_weights()


    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.InstanceNorm2d)): # Added InstanceNorm2d
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        # Initialize positional embedding
        if hasattr(self, 'pos_embed') and self.pos_embed is not None:
             trunc_normal_(self.pos_embed, std=.02)


    def forward(self, x):
        """Run the network, adding flip test-time augmentation in eval mode.

        Args:
            x: input volume; the comments in this class give its shape as
                (B, 1, 1000, 70, 70).

        Returns:
            Training mode: ``self._forward_unscaled(x)`` with dim 1 squeezed,
            i.e. the unscaled (B, H, W) output.
            Eval mode: the mean of two predictions, each mapped back with
            ``* 1500 + 3000``. The first is the plain prediction. The second is
            the prediction for an input flipped along dims 2 and 4, flipped back
            along its last axis.
        """
        if self.training:
            # No TTA and no rescaling while training.
            return self._forward_unscaled(x).squeeze(1)

        def rescale(pred):
            return pred * 1500 + 3000

        plain = rescale(self._forward_unscaled(x).squeeze(1))  # (B, H, W)

        # NOTE(review): for a (B, 1, T, H, W) input, dims 2 and 4 are time and W.
        # Reversing the time axis is not obviously a symmetry of the data. The
        # usual flip TTA mirrors only spatial/receiver axes; confirm this is intended.
        mirrored_input = torch.flip(x, dims=[2, 4])
        mirrored_raw = self._forward_unscaled(mirrored_input).squeeze(1)
        # Undo the W flip on the (B, H, W) prediction so both maps are aligned.
        mirrored = rescale(torch.flip(mirrored_raw, dims=[2]))

        return torch.stack([plain, mirrored], dim=0).mean(dim=0)


    # A class body keeps only the last binding of a name, so only the
    # `_forward_unscaled` defined further down in this class is ever used.
    # Module construction (bottleneck / skip projections / decoder) must happen
    # in `__init__`, never inside a forward method.

    @staticmethod
    def build_skip_projection(in_channels_3d, out_channels_2d, out_size_2d):
        """Build a module that turns a 3D encoder skip into a fixed-size 2D decoder skip.

        The returned module maps (B, in_channels_3d, D, H, W) to
        (B, out_channels_2d, out_size_2d[0], out_size_2d[1]) in four steps:

          1. average over the depth/time axis D (AdaptiveAvgPool3d to depth 1);
          2. drop that singleton axis, because ``nn.Conv2d`` rejects the 5D tensor
             the pooling produces;
          3. apply a 1x1 ``nn.Conv2d`` channel projection;
          4. apply ``nn.AdaptiveAvgPool2d`` to force the spatial size the matching
             decoder block expects (e.g. an encoder skip of 35x35 feeding a 32x32
             decoder stage).

        Args:
            in_channels_3d: channel count of the 3D skip feature.
            out_channels_2d: channel count of the projected 2D skip.
            out_size_2d: target spatial size, an int or an (H, W) tuple.

        Returns:
            An ``nn.Sequential`` implementing the steps above.
        """
        return nn.Sequential(
            nn.AdaptiveAvgPool3d((1, None, None)),  # (B, C, D, H, W) -> (B, C, 1, H, W)
            nn.Flatten(start_dim=1, end_dim=2),     # (B, C, 1, H, W) -> (B, C, H, W)
            nn.Conv2d(in_channels_3d, out_channels_2d, kernel_size=1),
            nn.AdaptiveAvgPool2d(out_size_2d),
        )


    def _forward_unscaled(self, x):
        # Helper method to run the core forward pass without final scaling/TTA

        # --- 3D Conv Encoder ---
        skip_features_3d = []
        x_3d = x
        for i, stage in enumerate(self.conv_stages):
            x_3d = stage(x_3d)
            if i < self.num_encoder_stages - 1: # Save skips from stages 0, 1, 2
                skip_features_3d.append(x_3d)

        # x_3d is now the bottleneck 3D feature map from stage 3 (B, C_last, D_b, H_b, W_b)


        # --- Transformer Bottleneck ---
        B, C_last, D_b, H_b, W_b = x_3d.shape
        x_flat = x_3d.permute(0, 2, 3, 4, 1).contiguous().view(B, -1, C_last) # (B, sequence_length, C_last)

        x_flat = self.transformer_input_proj(x_flat) # (B, sequence_length, embed_dim)

        # Add positional embedding
        x_flat = x_flat + self.pos_embed

        # Apply Transformer blocks
        for block in self.transformer_blocks:
            x_flat = block(x_flat)

        # Apply final norm
        x_flat = self.transformer_norm(x_flat) # (B, sequence_length, embed_dim)


        # --- Project Transformer output to 2D and pool time ---
        # Reshape back to 3D-like structure: (B, embed_dim, D_b, H_b, W_b)
        x_3d_like = x_flat.view(B, D_b, H_b, W_b, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous()

        # Pool over the time dimension (D_b) to get 2D feature map
        # Then project channels to match decoder_enc_channels_for_init[-1] (256)
        x_2d_bottleneck_pooled = torch.mean(x_3d_like, dim=2).squeeze(2) # (B, embed_dim, H_b, W_b)
        x_2d_bottleneck = self.bottleneck_2d_proj(x_2d_bottleneck_pooled) # (B, 256, H_b, W_b)
        # print("Bottleneck 2D shape:", x_2d_bottleneck.shape)


        # --- Prepare 2D Skip Connections ---
        # Process the saved 3D skip features from stages 0, 1, 2
        skip_features_2d = []
        for i in range(self.num_encoder_stages - 1):
            skip_3d = skip_features_3d[i] # Get 3D feature from stage i (0, 1, or 2)
            skip_2d = self.skip_projection_layers[i](skip_3d) # Project to 2D using corresponding layer
            skip_features_2d.append(skip_2d)
            # print(f"Projected Skip {i} shape: {skip_2d.shape}")


        # Combine bottleneck and skip features for the Decoder input
        # Feats list for UnetDecoder2d: [Bottleneck_2D, Skip1_2D, Skip2_2D, Skip3_2D]
        # Skip1 from stage 0, Skip2 from stage 1, Skip3 from stage 2.
        decoder_feats = [x_2d_bottleneck] + skip_features_2d # This order is [Bottleneck, Stg0_skip, Stg1_skip, Stg2_skip]
        # The required order for UnetDecoder2d feats list is [Bottleneck, Skip_coarsest, Skip_mid, Skip_finest]
        # So, skip_features_2d is [Stg0_skip, Stg1_skip, Stg2_skip], which is already ordered coarsest to finest.
        # Correct: decoder_feats = [x_2d_bottleneck] + skip_features_2d
        # This means feats[0]=Bottleneck, feats[1]=Stg0_skip, feats[2]=Stg1_skip, feats[3]=Stg2_skip.

        # Check shapes just before decoder:
        # print("Shapes feeding decoder:")
        # for i, f in enumerate(decoder_feats): print(f"  feats[{i}]: {f.shape}")
        # Expected:
        # feats[0]: (B, 256, H_b, W_b) eg (B, 256, 4, 4)
        # feats[1]: (B, 64, H_stg0_proj, W_stg0_proj) eg (B, 64, 32, 32)
        # feats[2]: (B, 128, H_stg1_proj, W_stg1_proj) eg (B, 128, 16, 16)
        # feats[3]: (B, 256, H_stg2_proj, W_stg2_proj) eg (B, 256, 8, 8)

        # This matches the expected input list structure and order for the modified UnetDecoder2d based on re-interpretation.

        # --- 2D Decoder ---
        decoder_output_features = self.decoder(decoder_feats) # Processes feats[::-1] internally

        # --- Segmentation Head ---
        # The final output of the decoder is the last element in the list.
        seg_output = self.seg_head(decoder_output_features[-1])
        # print("Seg Head output shape:", seg_output.shape) # Should be (B, 1, 70, 70)

        return seg_output # (B, 1, 70, 70) - raw, unscaled prediction


# --- Reuse original UnetDecoder2d and SegmentationHead2d ---
# Copy paste the original definitions or ensure they are available.
# Assuming the original definitions are in the user's provided code block.

# Self-contained DropPath (stochastic depth); does not need timm at call time.
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    In training mode each sample in the batch is zeroed with probability
    ``drop_prob``; surviving samples are rescaled by ``1 / (1 - drop_prob)`` so
    the expected activation is unchanged. In eval mode, or when ``drop_prob`` is
    ``None`` or 0, the input is returned unchanged.

    Args:
        drop_prob: probability of dropping a sample's residual path (``None`` = 0).
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Treat None like 0 so the default constructor is usable in training mode.
        drop_prob = float(self.drop_prob or 0.0)
        if drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dimensions.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            mask.div_(keep_prob)
        return x * mask


# Replace DropPath in TransformerBlock if timm.models.layers.DropPath is not used directly
# from timm.models.layers import DropPath # Use this if timm is fully installed and imported

# Re-define TransformerBlock to use the dummy DropPath if needed
# class TransformerBlock(nn.Module):
#     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
#         super().__init__()
#         self.norm1 = norm_layer(dim)
#         self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop, bias=qkv_bias, batch_first=True)
#         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # Use dummy DropPath
#         self.norm2 = norm_layer(dim)
#         mlp_hidden_dim = int(dim * mlp_ratio)
#         self.mlp = nn.Sequential(
#             nn.Linear(dim, mlp_hidden_dim),
#             act_layer(),
#             nn.Dropout(drop),
#             nn.Linear(mlp_hidden_dim, dim),
#             nn.Dropout(drop)
#         )
#
#     def forward(self, x):
#         residual = x
#         x = self.norm1(x)
#         x = self.attn(x, x, x)[0]
#         x = residual + self.drop_path(x)
#
#         residual = x
#         x = self.norm2(x)
#         x = self.mlp(x)
#         x = residual + self.drop_path(x)
#         return x
#
# # Ensure DropPath is available if drop_path > 0 is used in HybridModel init

In [None]:
# --- Main Training Loop Integration ---
# Build the train/valid datasets and their plain (non-distributed) DataLoaders
# for the HybridModel. The training loader reshuffles every epoch.

set_seed(cfg.seed)

train_ds = CustomDataset(cfg=cfg, file_pairs=cfg.file_pairs, mode="train")
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=cfg.batch_size,
    shuffle=True,
    # Single-process loading keeps debugging simple; raise for throughput.
    num_workers=0,
    # Pinned host memory only pays off when batches are copied to a CUDA device.
    pin_memory=torch.cuda.is_available(),
)

valid_ds = CustomDataset(cfg=cfg, file_pairs=cfg.file_pairs, mode="valid")
valid_dl = torch.utils.data.DataLoader(
    valid_ds,
    batch_size=cfg.batch_size_val,
    shuffle=False,  # fixed order for validation
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

In [None]:

# Test dataset output shape: sanity-check one training batch against the
# configured input shape and the (70, 70) label shape. This is best-effort:
# any failure is reported and the notebook keeps running.
try:
    x, y = next(iter(train_dl))
    print("Sample data batch shape:", x.shape, "Sample label batch shape:", y.shape)

    # Unpacking works whether cfg.spatial_size is a tuple or a list
    # (tuple + list concatenation would raise TypeError).
    expected_x = (cfg.in_channels, *cfg.spatial_size)
    expected_y = (70, 70)
    got_x = tuple(x.shape[1:])
    got_y = tuple(y.shape[1:])
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if got_x != expected_x:
        raise ValueError(f"Data shape mismatch: Expected {expected_x}, got {got_x}")
    if got_y != expected_y:
        raise ValueError(f"Label shape mismatch: Expected {expected_y}, got {got_y}")
    print("Dataset shapes confirmed.")
except Exception as e:
    print(f"Error testing dataset shapes: {e}")
    print("Please check CustomDataset logic and data file shapes.")
    # sys.exit(1) # Exit if dataset shapes are wrong

In [None]:

# ========== Model / Optim ==========
# Instantiate the HybridModel, then run epochs 0..cfg.epochs.
# Epoch 0 only validates (baseline of the untrained model); later epochs train, then validate.
model = HybridModel(
    in_channels=cfg.in_channels,
    spatial_size=cfg.spatial_size,
    encoder_dims=cfg.encoder_dims,
    embed_dim=cfg.embed_dim,
    num_heads=cfg.num_heads,
    mlp_ratio=cfg.mlp_ratio,
    transformer_depth=cfg.transformer_depth,
    decoder_channels=cfg.decoder_channels,
    decoder_attention_type=cfg.decoder_attention_type,
    upsample_mode=cfg.upsample_mode,
).to(cfg.device)


if cfg.ema:
    if cfg.local_rank == 0:
        print("Initializing EMA model..")
    # EMA model on the same device as the main model
    ema_model = ModelEMA(
        model,
        decay=cfg.ema_decay,
        device=cfg.device,
    )
else:
    ema_model = None

criterion = nn.L1Loss()  # Mean Absolute Error on velocity maps
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Labels are velocities in physical units (the original scaling assumes roughly
# [3000, 4500]). The model is trained on the normalized target
# (y - VEL_OFFSET) / VEL_SCALE. Its output is mapped back with
# pred * VEL_SCALE + VEL_OFFSET before validation.
VEL_OFFSET = 3000.0
VEL_SCALE = 1500.0


def _align_to_target(pred, target):
    """Drop a singleton channel axis so ``pred`` has the same rank as ``target``.

    HybridModel.forward is documented as returning (B, 1, 70, 70), while the
    labels are (B, 70, 70). Without this, nn.L1Loss would broadcast the pair to
    (B, B, 70, 70) and compare every prediction with every label.
    """
    if pred.ndim == target.ndim + 1 and pred.shape[1] == 1:
        pred = pred.squeeze(1)
    return pred


# Optional update/data ratio ("UD") tracking for 2D and 3D conv weights.
parameters_conv = [
    m.weight
    for m in model.modules()
    if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.weight.requires_grad and m.weight.ndim >= 4
]
ud = []
eps = 1e-8

best_loss = 1_000_000
val_loss = 1_000_000  # Initialize val_loss for logging on epoch 0


print(f"Starting training for {cfg.epochs} epochs...")

for epoch in range(0, cfg.epochs + 1):
    if epoch > 0:  # Epoch 0 is validation-only
        tstart = time.time()

        # ========== Train ==========
        model.train()
        total_loss = []
        train_loop = tqdm(train_dl, disable=cfg.local_rank != 0, desc=f"Epoch {epoch} Training")
        for x, y in train_loop:
            x = x.to(cfg.device)
            y = y.to(cfg.device).float()  # physical units; assumed (B, 70, 70)

            # The model predicts the normalized velocity map.
            pred_norm = _align_to_target(model(x), y)

            # Normalize the label to the model's range. The clamp guards against
            # labels slightly outside the assumed [3000, 4500] range.
            y_norm = torch.clamp((y - VEL_OFFSET) / VEL_SCALE, 0.0, 1.0)

            loss = criterion(pred_norm, y_norm)
            loss.backward()

            # Track UD (optional): log10(lr * std(grad) / std(weight)) per conv weight.
            if parameters_conv:
                lr = optimizer.param_groups[0]['lr']
                with torch.no_grad():
                    ud.append([
                        (lr * (p.grad.std() + eps) / (p.data.std() + eps)).log10().item()
                        if p.grad is not None else float('-inf')
                        for p in parameters_conv
                    ])

            optimizer.step()
            optimizer.zero_grad()

            total_loss.append(loss.item())

            if ema_model is not None:
                ema_model.update(model)

        # Guard against an empty loader (np.mean([]) would warn and return nan).
        avg_train_loss = float(np.mean(total_loss)) if total_loss else float("nan")
        if cfg.local_rank == 0:
            print(f"Epoch {epoch} Train Loss: {avg_train_loss:.4f}")
            print(f"Epoch {epoch} Training Time: {format_time(time.time() - tstart)}")

    # ========== Valid ==========
    model.eval()
    val_logits = []   # predictions in physical units
    val_targets = []  # labels in physical units
    valid_loop = tqdm(valid_dl, disable=cfg.local_rank != 0, desc=f"Epoch {epoch} Validation")
    with torch.no_grad():
        for x, y in valid_loop:
            x = x.to(cfg.device)
            y = y.to(cfg.device).float()  # physical units

            # HybridModel returns the normalized prediction in train and eval
            # mode alike, so de-normalize it before comparing with `y`.
            net = ema_model.module if ema_model is not None else model
            out = _align_to_target(net(x), y) * VEL_SCALE + VEL_OFFSET

            val_logits.append(out.cpu())
            val_targets.append(y.cpu())

        val_logits = torch.cat(val_logits, dim=0)
        val_targets = torch.cat(val_targets, dim=0)

        # MAE in physical units (same scale as the labels)
        val_loss = criterion(val_logits, val_targets).item()

    if cfg.local_rank == 0:
        suffix = " (Initial Eval)" if epoch == 0 else ""
        print(f"Epoch {epoch} Validation Loss: {val_loss:.4f}{suffix}")

        # Early Stopping Check
        if val_loss < best_loss:
            best_loss = val_loss
            cfg.early_stopping["streak"] = 0
            # Save best model (main model and EMA if exists)
            print("Validation loss improved, saving model...")
            os.makedirs("./checkpoints", exist_ok=True)
            torch.save(model.state_dict(), f"./checkpoints/best_model_epoch{epoch:03d}.pth")
            if ema_model is not None:
                torch.save(ema_model.module.state_dict(), f"./checkpoints/best_ema_model_epoch{epoch:03d}.pth")
        else:
            cfg.early_stopping["streak"] += 1
            print(f"Validation loss did not improve. Streak: {cfg.early_stopping['streak']}")
            if cfg.early_stopping["streak"] >= cfg.early_stopping["patience"]:
                print(f"Early stopping triggered after {cfg.early_stopping['patience']} epochs without improvement.")
                break  # Exit training loop

    # Single-process training: no barrier/all_reduce needed.

print("Training finished.")

# --- Re-define UnetDecoder2d and SegmentationHead2d here if they weren't in the original block ---
# (Copy paste from the user's original code)
class ConvBnAct2d(nn.Module):
    """Conv2d -> normalization -> activation.

    Args:
        in_channels, out_channels, kernel_size, padding, stride: forwarded to ``nn.Conv2d``.
        norm_layer: one of
            - an ``nn.Module`` subclass, constructed as ``norm_layer(out_channels)``;
            - a non-class factory (e.g. ``functools.partial(nn.GroupNorm, 8)``),
              also called with ``out_channels``;
            - an already-built module, used as-is;
            - ``nn.Identity`` / ``None``, meaning no normalization.
        act_layer: activation class. It is constructed with ``inplace=True`` when
            it accepts that keyword (e.g. ``nn.ReLU``). Otherwise it is constructed
            without arguments (e.g. ``nn.GELU``, which has no ``inplace`` option).
            ``nn.Identity`` / ``None`` disables the activation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding: int = 0,
        stride: int = 1,
        norm_layer: nn.Module = nn.Identity,
        act_layer: nn.Module = nn.ReLU,
    ):
        super().__init__()

        # NOTE(review): bias stays False even when no norm follows. Enabling it
        # would add `conv.bias` to the state_dict, so it is left unchanged here.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.norm = self._build_norm(norm_layer, out_channels)
        self.act = self._build_act(act_layer)

    @staticmethod
    def _build_norm(norm_layer, channels):
        """Turn the ``norm_layer`` argument into a module (see class docstring)."""
        if isinstance(norm_layer, nn.Module):
            return norm_layer  # already-initialized instance
        if norm_layer is None or norm_layer is nn.Identity:
            return nn.Identity()
        if isinstance(norm_layer, type):
            # Only nn.Module classes are valid; anything else falls back to Identity.
            return norm_layer(channels) if issubclass(norm_layer, nn.Module) else nn.Identity()
        if callable(norm_layer):
            return norm_layer(channels)  # factory such as functools.partial
        return nn.Identity()

    @staticmethod
    def _build_act(act_layer):
        """Instantiate the activation, using ``inplace=True`` only where supported."""
        if act_layer is None or act_layer is nn.Identity:
            return nn.Identity()
        try:
            return act_layer(inplace=True)
        except TypeError:
            # e.g. nn.GELU / nn.SiLU variants whose constructor has no `inplace` kwarg
            return act_layer()

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


class SCSEModule2d(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (scSE) for 2D maps.

    Returns ``x * cSE(x) + x * sSE(x)``. The cSE branch gates each channel from
    globally average-pooled statistics. The sSE branch gates each pixel from a
    1x1 projection to a single channel.

    Args:
        in_channels: channels of the input (and output) tensor.
        reduction: bottleneck ratio of the cSE branch. The hidden width is
            clamped to at least 1, so inputs with fewer than ``reduction``
            channels get a working gate instead of a zero-width layer.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        hidden_channels = max(1, in_channels // reduction)
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden_channels, 1),
            nn.GELU(),  # GELU (instead of Tanh) for consistency with the rest of the decoder
            nn.Conv2d(hidden_channels, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)

class Attention2d(nn.Module):
    """Select an attention module by name.

    Args:
        name: ``None`` for a pass-through (``nn.Identity``) or ``"scse"`` for
            ``SCSEModule2d``.
        **params: constructor arguments for the selected module (e.g.
            ``in_channels``); ignored by the identity.

    Raises:
        ValueError: if ``name`` is neither ``None`` nor ``"scse"``.
    """

    def __init__(self, name, **params):
        super().__init__()
        if name == "scse":
            self.attention = SCSEModule2d(**params)
        elif name is None:
            self.attention = nn.Identity(**params)
        else:
            raise ValueError(f"Attention {name} is not implemented")

    def forward(self, x):
        out = self.attention(x)
        return out

class DecoderBlock2d(nn.Module):
    """One U-Net decoder stage: upsample -> (concat skip) -> attention -> 2x conv -> attention.

    Args:
        in_channels: channels of the incoming feature (previous block or bottleneck).
        skip_channels: channels of the skip feature concatenated after upsampling.
            Use 0 when this block receives no skip.
        out_channels: channels produced by this block.
        norm_layer: normalization passed to both ``ConvBnAct2d`` layers.
        attention_type: ``None`` or ``"scse"``; forwarded to ``Attention2d``.
        intermediate_conv: accepted for interface compatibility; unused.
        upsample_mode: ``"pixelshuffle"`` uses MONAI ``SubpixelUpsample``. Any other
            value is passed as ``mode`` to MONAI ``UpSample`` (e.g. ``"deconv"``).
        scale_factor: spatial upsampling factor.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer: nn.Module = nn.Identity,
        attention_type: str = None,
        intermediate_conv: bool = False,
        upsample_mode: str = "deconv",
        scale_factor: int = 2,
    ):
        super().__init__()

        # Both upsamplers keep the channel count at `in_channels`. The tensor
        # concatenated in forward() therefore has `in_channels + skip_channels`
        # channels. Only one upsampler is built for the selected mode.
        if upsample_mode == "pixelshuffle":
            # MONAI SubpixelUpsample: out_channels defaults to in_channels.
            self.upsample = SubpixelUpsample(
                spatial_dims=2,
                in_channels=in_channels,
                scale_factor=scale_factor,
            )
        else:
            self.upsample = UpSample(
                spatial_dims=2,
                in_channels=in_channels,
                out_channels=in_channels,
                scale_factor=scale_factor,
                mode=upsample_mode,
            )

        # `intermediate_conv` is intentionally unused (kept for interface compatibility).

        # attention1 and conv1 see the upsampled feature plus the skip feature,
        # or only the upsampled feature when this block has no skip.
        fused_channels = in_channels + skip_channels if skip_channels > 0 else in_channels
        self.attention1 = Attention2d(
            name=attention_type,
            in_channels=fused_channels,
        )

        self.conv1 = ConvBnAct2d(
            fused_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            norm_layer=norm_layer,
            act_layer=nn.GELU,
        )
        self.conv2 = ConvBnAct2d(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            norm_layer=norm_layer,
            act_layer=nn.GELU,
        )

        # Attention applied after the second convolution
        self.attention2 = Attention2d(
            name=attention_type,
            in_channels=out_channels,
        )

        # Report an upsample/skip spatial-size mismatch once per block, not every batch.
        self._warned_size_mismatch = False

    def forward(self, x, skip=None):
        """Upsample ``x``, fuse it with ``skip`` (if given) and refine it.

        Args:
            x: feature from the previous decoder block or the bottleneck.
            skip: optional encoder feature. If its spatial size differs from the
                upsampled ``x``, it is resized with nearest-neighbour interpolation.
        """
        x = self.upsample(x)

        if skip is not None:
            if x.shape[-2:] != skip.shape[-2:]:
                if not getattr(self, "_warned_size_mismatch", False):
                    print(
                        f"Warning: Spatial size mismatch in DecoderBlock. Upsampled: {x.shape[-2:]}, "
                        f"Skip: {skip.shape[-2:]}. Using interpolation on skip. (Reported once per block.)"
                    )
                    self._warned_size_mismatch = True
                skip = F.interpolate(skip, size=x.shape[-2:], mode='nearest')
            x = torch.cat([x, skip], dim=1)  # concatenate along the channel dimension

        # Applied with or without a skip. attention1 is nn.Identity when
        # attention_type is None, and is sized for `in_channels` alone when
        # skip_channels == 0.
        x = self.attention1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        return self.attention2(x)


class UnetDecoder2d(nn.Module):
    """
    Unet decoder adapted from the provided code.
    Assumes input `feats` list is [bottleneck_feature, skip_coarsest, ..., skip_finest].
    """
    def __init__(
        self,
        encoder_channels: tuple[int], # Channels corresponding to feats[::-1]
        skip_channels: tuple[int], # Channels corresponding to feats[1:] + [0]
        decoder_channels: tuple = (256, 128, 64, 32), # Output channels of decoder blocks
        scale_factors: tuple = (2,2,2,2), # Upsampling factors for blocks
        norm_layer: nn.Module = nn.Identity,
        attention_type: str = None,
        intermediate_conv: bool = False, # Not used, kept for compatibility
        upsample_mode: str = "deconv",
    ):
        super().__init__()

        self.encoder_channels = encoder_channels # Channels of feats[::-1]
        self.skip_channels = skip_channels       # Channels of feats[1:] + [0]
        self.decoder_channels = decoder_channels # Output channels of blocks
        self.scale_factors = scale_factors
        self.norm_layer = norm_layer
        self.attention_type = attention_type
        self.upsample_mode = upsample_mode

        # Number of decoder blocks
        num_blocks = len(decoder_channels)
        assert len(encoder_channels) == num_blocks, "encoder_channels length must match decoder_channels length"
        assert len(skip_channels) == num_blocks, "skip_channels length must match decoder_channels length"
        assert len(scale_factors) == num_blocks, "scale_factors length must match decoder_channels length"


        # Build decoder blocks
        self.blocks = nn.ModuleList()

        # UnetDecoder2d's internal logic creates input channel lists based on encoder_channels and decoder_channels
        # in_channels_to_blocks = [encoder_channels[0]] + list(decoder_channels[:-1]) # Channels from previous block / bottleneck
        # skip_channels_for_blocks = list(encoder_channels[1:]) + [0] # Channels from skip connections

        # Let's re-implement the blocks loop directly based on standard U-Net flow
        # Block 0 input: bottleneck + skip_coarsest
        # Block 1 input: block0_output + skip_mid1
        # ...
        # Block N-1 input: blockN-2_output + skip_finest

        # This requires knowing the input channels and skip channels for EACH block.
        # Let's define the input channels to each block's concatenation explicitly.
        # Block i takes input from block i-1 (or bottleneck for i=0) and skip feature i.
        # Input channel to Block 0 concat: [Channels from bottleneck] + [Channels from skip_coarsest]
        # Input channel to Block i concat: [Channels from Block i-1 Output] + [Channels from skip_i]
        # Output channel of Block i is decoder_channels[i].

        # Let's trace the channels based on the `encoder_channels` and `skip_channels` params and `decoder_channels` outputs
        # These lists define the channels of the features passed to the forward method `feats`.
        # feats = [bottleneck, skip1, skip2, ..., skipN] where skip1 is coarsest. len(feats) = N+1.
        # UnetDecoder2d params len = N+1 (seems wrong based on base code).
        # Let's go back to base code logic: params len = len(decoder_channels) = N.
        # encoder_channels param (len N): defines channels of feats[::-1]. [S_N, ..., S_1, B]
        # skip_channels param (len N): defines channels of feats[1:] + [0]. [S_1, S_2, ..., S_N, 0]
        # decoder_channels param (len N): defines output channels. [D_1, ..., D_N]

        # This implies `feats` list has length N+1.
        # feats = [Bottleneck, Skip1, Skip2, ..., SkipN]
        # len(feats) = N+1.
        # But UnetDecoder2d forward loop is `for i, b in enumerate(self.blocks)` (N blocks).
        # `skip = feats[i] if i < len(feats) else None`.
        # This means skip for block 0 is feats[0] (Bottleneck)? Skip for block 1 is feats[1] (Skip1)?
        # This structure is confusing and seems non-standard.

        # Let's make a reasonable assumption that the provided UnetDecoder2d class is intended for a standard U-Net flow:
        # Takes a list of features `feats = [bottleneck, skip_coarsest, ..., skip_finest]`
        # Processes N blocks, where block i uses `feats[i]` as skip (for i > 0).
        # The first input to the first block comes from `feats[0]`.

        # Let's define the blocks based on this standard interpretation.
        # Number of blocks = len(decoder_channels).
        num_blocks = len(decoder_channels)
        # Need num_blocks skips + 1 bottleneck = num_blocks + 1 features in `feats`.
        # Ensure input list `feats` to forward method has length num_blocks + 1.

        # Channel tracking for standard U-Net decoder:
        # Block 0 input channels: Channels of bottleneck + Channels of skip_coarsest (feats[1])
        # Output channels: decoder_channels[0]
        # Block 1 input channels: Channels of Block 0 Output + Channels of skip_mid1 (feats[2])
        # Output channels: decoder_channels[1]
        # ...
        # Block N-1 input channels: Channels of Block N-2 Output + Channels of skip_finest (feats[N])
        # Output channels: decoder_channels[N-1]

        # We need to tell each DecoderBlock its input channels (after concat).
        # DecoderBlock init needs `in_channels` (from previous block/bottleneck) and `skip_channels`.
        # Let's pass the required channels directly.

        # Input channel to Block 0: from bottleneck. Channels = Channels of feats[0].
        # Skip channel for Block 0: from skip_coarsest. Channels = Channels of feats[1].
        # Output channel of Block 0: decoder_channels[0].

        # Input channel to Block 1: from Block 0 output. Channels = decoder_channels[0].
        # Skip channel for Block 1: from skip_mid1. Channels = Channels of feats[2].
        # Output channel of Block 1: decoder_channels[1].

        # Input channel to Block i: Channels of decoder_channels[i-1].
        # Skip channel for Block i: Channels of feats[i+1].
        # Output channel of Block i: decoder_channels[i].

        # This requires knowing the channels of the `feats` list *before* calling UnetDecoder2d.
        # The `encoder_channels` and `skip_channels` params in UnetDecoder2d's __init__ seem intended
        # to *derive* these channels, based on the input list `feats`.
        # Let's assume:
        # `encoder_channels` param = channels of `feats` in forward pass: [B_chan, S1_chan, S2_chan, ..., SN_chan]
        # `skip_channels` param = channels of `feats`[1:] + [0]: [S1_chan, S2_chan, ..., SN_chan, 0]
        # `decoder_channels` param = output channels: [D1_chan, ..., DN_chan]

        # Based on this assumption, let's redefine the blocks:
        # Block 0: in_channels = encoder_channels[0] (bottleneck). skip_channels = skip_channels[0] (skip1). out_channels = decoder_channels[0].
        # Block 1: in_channels = decoder_channels[0] (prev output). skip_channels = skip_channels[1] (skip2). out_channels = decoder_channels[1].
        # Block i: in_channels = decoder_channels[i-1]. skip_channels = skip_channels[i]. out_channels = decoder_channels[i].

        input_channels_to_blocks = [encoder_channels[0]] + list(decoder_channels[:-1])
        skip_channels_for_blocks = list(skip_channels) # Use the skip_channels param directly

        assert len(input_channels_to_blocks) == num_blocks, "Input channel calculation error"
        assert len(skip_channels_for_blocks) == num_blocks, "Skip channel list length mismatch"

        for i in range(num_blocks):
             ic = input_channels_to_blocks[i]
             sc = skip_channels_for_blocks[i]
             dc = decoder_channels[i]
             sf = scale_factors[i]

             self.blocks.append(
                 DecoderBlock2d(
                     ic, sc, dc,
                     norm_layer= self.norm_layer,
                     attention_type= self.attention_type,
                     intermediate_conv= intermediate_conv, # Kept False
                     upsample_mode= self.upsample_mode,
                     scale_factor= sf,
                 )
             )


    def forward(self, feats: list[torch.Tensor]):
        # feats: list of tensors [bottleneck_feature, skip_coarsest, ..., skip_finest]
        # Number of blocks = len(self.blocks) = len(self.decoder_channels)
        # Expected len(feats) = len(self.decoder_channels) + 1

        num_blocks = len(self.blocks)
        if len(feats) != num_blocks + 1:
             print(f"Error: UnetDecoder2d expected {num_blocks+1} features but got {len(feats)}")
             # Adjusting input list based on expectation for demo purposes
             # If we have 4 features (1 bottleneck, 3 skips) and 4 blocks are expected
             # The original base code seemed to feed only 4 features to a 4-block decoder?
             # Let's assume the original base code's use of feats was `feats = [f_stg4, f_stg3, f_stg2, f_stg1]`.
             # And its decoder had 4 blocks.
             # UnetDecoder2d loop: res = [feats[0]]; feats = feats[1:]
             # Block 0: skip=feats[0] (original feats[1]), input=res[-1] (original feats[0]) -> concat(f_stg4, f_stg3)
             # Block 1: skip=feats[1] (original feats[2]), input=res[-1] (output block 0) -> concat(output_b0, f_stg2)
             # ...
             # Block 3: skip=feats[3] (original feats[4] - out of bounds), input=res[-1] -> concat(output_b2, None)

             # Let's align with the original base code's UnetDecoder2d forward pass logic.
             # It seems to take `feats` as [feature_level_N, feature_level_N-1, ..., feature_level_1]
             # where level N is the coarsest skip, and level 1 is the finest skip.
             # And it treats `feats[0]` as the bottleneck feature *for the first block's input*, and `feats[1:]` as the skips.
             # This implies the input `feats` list should be [bottleneck_input_to_block0, skip_for_block0, skip_for_block1, ...]
             # No, the base code was `res = [feats[0]]` and `feats= feats[1:]` and `skip=feats[i]`.
             # This means feats[0] is res[-1] for block 0, and feats[0] (after slicing) is skip for block 0.
             # This is not standard U-Net.

             # Let's assume the standard U-Net flow where `feats` is [bottleneck, skip_coarsest, ..., skip_finest]
             # And the loop is adjusted.
             # This requires modifying the UnetDecoder2d forward pass.
             # Option 1: Modify UnetDecoder2d to match standard U-Net.
             # Option 2: Prepare `feats` list in `HybridModel.forward` to match the current `UnetDecoder2d` forward logic.
             # Based on `res = [feats[0]]; feats = feats[1:]; for i, b in enumerate(self.blocks): skip = feats[i] if i < len(feats) else None; res.append(b(res[-1], skip=skip))`.
             # If `feats_in` = [B, S1, S2, S3] (B=bottleneck, S1=coarsest, S2, S3=finest)
             # res = [B]
             # feats = [S1, S2, S3]
             # Block 0: skip=feats[0]=S1. Input=res[-1]=B. Block0(B, S1). res=[B, Output_B0]
             # Block 1: skip=feats[1]=S2. Input=res[-1]=Output_B0. Block1(Output_B0, S2). res=[B, Output_B0, Output_B1]
             # Block 2: skip=feats[2]=S3. Input=res[-1]=Output_B1. Block2(Output_B1, S3). res=[B, Output_B0, Output_B1, Output_B2]
             # Block 3: skip=feats[3] (index out of bounds). skip=None. Input=res[-1]=Output_B2. Block3(Output_B2, None). res=[B, Output_B0, Output_B1, Output_B2, Output_B3]
             # This seems to be the intended logic of the provided UnetDecoder2d.
             # It takes feats=[Bottleneck, Skip1, Skip2, Skip3] and has 4 blocks.
             # This perfectly matches the 4 levels of features (1 bottleneck, 3 skips) produced by our 4-stage encoder.

             # So the `decoder_feats` list constructed in `HybridModel.forward` is correct:
             # `decoder_feats = [x_2d_bottleneck] + skip_features_2d` where `skip_features_2d` is [Stg0_skip, Stg1_skip, Stg2_skip].
             # This means `feats` list will be [Bottleneck, Stg0_skip, Stg1_skip, Stg2_skip].
             # And UnetDecoder2d is initialized with `encoder_channels` and `skip_channels` reflecting the channels of this `feats` list.
             # This means `encoder_channels` param is [Stg2_skip_chan, Stg1_skip_chan, Stg0_skip_chan, Bottleneck_chan].
             # And `skip_channels` param is [Stg0_skip_chan, Stg1_skip_chan, Stg2_skip_chan, 0].

             # Let's trace channels again with this understanding:
             # Feats list: [Bottleneck(embed_dim), Stg0_skip(proj to 64), Stg1_skip(proj to 128), Stg2_skip(proj to 256)]
             # Feats channels: [embed_dim, 64, 128, 256]

             # UnetDecoder2d `encoder_channels` param = [256, 128, 64, embed_dim]
             # UnetDecoder2d `skip_channels` param = [64, 128, 256, 0]
             # UnetDecoder2d `decoder_channels` param = [256, 128, 64, 32]

             # UnetDecoder2d internal loop:
             # res = [feats[0]] (feats[0] is Bottleneck, channel embed_dim)
             # feats = feats[1:] (feats is now [Stg0_skip, Stg1_skip, Stg2_skip])

             # Block 0:
             # ic = input_channels_to_blocks[0] = encoder_channels[0] = 256
             # sc = skip_channels_for_blocks[0] = skip_channels[0] = 64
             # dc = decoder_channels[0] = 256
             # Block input concat channels: res[-1] (embed_dim) + skip (feats[0]=Stg0_skip, chan 64).
             # DecoderBlock2d expects in_channels=embed_dim, skip_channels=64. Actual call: Block0(res[-1], feats[0]).
             # The `in_channels`, `skip_channels` params in DecoderBlock2d init must match the *actual* channels passed during forward.

             # This setup is very confusing. The most likely scenario is that the original `UnetDecoder2d` expects `feats` list channels
             # and uses `encoder_channels` and `skip_channels` params to define block input/skip channels *indirectly*.

             # Let's assume the most standard U-Net structure where `feats` is [bottleneck, skip1, skip2, ...] (from coarsest to finest)
             # And the decoder iterates from finest spatial up.
             # Requires `feats = [Bottleneck, Skip_finest, Skip_mid, Skip_coarsest]` order in the list for the forward pass.
             # And UnetDecoder2d block 0 uses Bottleneck and Skip_finest. Block 1 uses Block0 output and Skip_mid.

             # Let's prepare `decoder_feats` in `HybridModel.forward` in the order [Bottleneck_2D, Skip_finest_2D, Skip_mid_2D, Skip_coarsest_2D]
             # Skip_finest is from Stage 2 (index 2). Skip_mid from Stage 1 (index 1). Skip_coarsest from Stage 0 (index 0).
             # `skip_features_2d` is [Stg0_skip, Stg1_skip, Stg2_skip]
             # Corrected `decoder_feats`: [x_2d_bottleneck, skip_features_2d[2], skip_features_2d[1], skip_features_2d[0]]
             # This matches `feats` order: [Bottleneck, Skip_Stg2, Skip_Stg1, Skip_Stg0]

             # Then, UnetDecoder2d needs params matching this `feats` list structure.
             # `encoder_channels` param: [Chan(feats[0]), Chan(feats[1]), Chan(feats[2]), Chan(feats[3])] = [embed_dim, 256, 128, 64] ? No.
             # The internal logic of UnetDecoder2d is key. Based on its code:
             # `res = [feats[0]]`. `feats = feats[1:]`. `skip = feats[i]`.
             # If feats_in = [B, S1, S2, S3] (B=bottleneck, S1=coarsest, S2, S3=finest).
             # res = [B]. feats = [S1, S2, S3].
             # Block 0: skip=feats[0]=S1. Input=res[-1]=B. Block0(B, S1). res=[B, Output_B0].
             # Block 1: skip=feats[1]=S2. Input=res[-1]=Output_B0. Block1(Output_B0, S2). res=[B, Output_B0, Output_B1].
             # Block 2: skip=feats[2]=S3. Input=res[-1]=Output_B1. Block2(Output_B1, S3). res=[B, Output_B0, Output_B1, Output_B2].
             # Block 3: skip=feats[3] (OOB). skip=None. Input=res[-1]=Output_B2. Block3(Output_B2, None). res=[..., Output_B3].

             # The `feats` list should be ordered [Bottleneck, Skip_coarsest, Skip_mid, Skip_finest].
             # Our `skip_features_2d` is already [Stg0_skip, Stg1_skip, Stg2_skip] (coarsest to finest).
             # So `decoder_feats = [x_2d_bottleneck] + skip_features_2d` is indeed correct for this interpretation.

             # Final structure seems consistent with this interpretation of UnetDecoder2d.

             # The problem is the parameters `encoder_channels` and `skip_channels` passed to UnetDecoder2d init.
             # They seem to define the *expected* input channels to blocks' concatenation.
             # Let's follow the original base code's logic for these params:
             # `encoder_channels_param` = `ecs` (reversed encoder chans) = [S3_chan, S2_chan, S1_chan, B_chan] (finest to coarsest)
             # `skip_channels_param` = `ecs[1:] + [0]` = [S2_chan, S1_chan, B_chan, 0] (mid-finest to bottleneck)
             # `decoder_channels_param` = [D1, D2, D3, D4] (output chans)

             # This contradicts how feats are used in forward.
             # Let's ignore the parameter names in UnetDecoder2d init and just pass the necessary channels for the blocks based on
             # the standard U-Net flow that the forward pass *seems* to imply (despite the confusing implementation).

             # Redefine UnetDecoder2d init and forward to be standard U-Net.
             # Input: feats = [bottleneck, skip1, ..., skipN] (skip1 is coarsest, skipN finest). len = N+1.
             # Decoder has N blocks (len decoder_channels).
             # Block 0: input concat(feats[0], feats[N]). Output decoder_channels[0].
             # Block 1: input concat(Output Block 0, feats[N-1]). Output decoder_channels[1].
             # Block i: input concat(Output Block i-1, feats[N-i]). Output decoder_channels[i].

             # This requires the skips in `feats` to be in reverse order (finest to coarsest).
             # `feats = [Bottleneck, Skip_finest, Skip_mid, Skip_coarsest]`.
             # So `decoder_feats = [x_2d_bottleneck] + skip_features_2d[::-1]`.

             # Let's adjust UnetDecoder2d forward to use skips from feats[1:] reversed.
             # This matches standard U-Net and makes channel logic clearer.

             pass # Modified UnetDecoder2d forward below


class UnetDecoder2d(nn.Module):
    """U-Net 2D decoder draft that only records its configuration.

    NOTE(review): this draft builds an empty ``self.blocks`` list and defines
    no ``forward``, so it inherits ``nn.Module``'s unimplemented default.
    ``UnetDecoder2d`` is redefined later in this file with an explicit
    ``feature_channels`` argument, and that later definition shadows this one.

    The intended input for an eventual ``forward`` is a ``feats`` list ordered
    ``[bottleneck_feature, skip_coarsest, ..., skip_finest]``.

    Args:
        decoder_channels: Output channel count of each decoder block. Its
            length sets the number of blocks.
        scale_factors: Upsampling factor of each decoder block. It must have
            the same length as ``decoder_channels``.
        norm_layer: Normalization layer *class* (not an instance). It is only
            stored here.
        attention_type: Optional attention variant name. It is only stored here.
        upsample_mode: Upsampling mode name (default ``"deconv"``). It is only
            stored here.

    Raises:
        ValueError: If ``scale_factors`` and ``decoder_channels`` differ in
            length.
    """

    def __init__(
        self,
        decoder_channels: tuple = (256, 128, 64, 32),  # Output channels of decoder blocks
        scale_factors: tuple = (2, 2, 2, 2),  # Upsampling factors for blocks
        norm_layer: type[nn.Module] = nn.Identity,
        attention_type: "str | None" = None,
        upsample_mode: str = "deconv",
    ):
        super().__init__()

        self.decoder_channels = decoder_channels
        self.scale_factors = scale_factors
        self.norm_layer = norm_layer
        self.attention_type = attention_type
        self.upsample_mode = upsample_mode

        # One decoder block per entry in decoder_channels.
        num_blocks = len(decoder_channels)
        # Raise explicitly instead of using `assert`, so the check still runs
        # under `python -O`.
        if len(scale_factors) != num_blocks:
            raise ValueError(
                "scale_factors length must match decoder_channels length"
            )

        # Left empty on purpose. Building the blocks needs the channel count of
        # each incoming feature map (bottleneck plus skips), and this signature
        # does not receive it. The later redefinition adds a
        # `feature_channels` argument for that reason.
        self.blocks = nn.ModuleList()


class UnetDecoder2d(nn.Module):
    """
    Unet decoder - Modified for standard U-Net skip connection order and channel definition.
    Input `feats` list is [bottleneck_feature, skip_coarsest, ..., skip_finest].
    """
    def __init__(
        self,
        feature_channels: list[int], # Channels of the expected `feats` list [B_chan, S1_chan, ..., SN_chan]
        decoder_channels: tuple = (256, 128, 64, 32), # Output channels of decoder blocks
        scale_factors: tuple = (2,2,2,2), # Upsampling factors for blocks
        norm_layer: nn.Module = nn.Identity,
        attention_type: str = None,
        upsample_mode: str = "deconv",
    ):
        super().__init__()

        self.feature_channels = feature_channels # Channels of [bottleneck, skip1, ..., skipN]
        self.decoder_channels = decoder_channels # Output channels of blocks
        self.scale_factors = scale_factors
        self.norm_layer = norm_layer
        self.attention_type = attention_type
        self.upsample_mode = upsample_mode

        num_blocks = len(decoder_channels)
        # Need num_blocks skips + 1 bottleneck = num_
-----------------