In [3]:
# === Minimal one-cell GAN debug: load (handles .utils), instantiate, smoke test, summary ===
import os, sys, importlib.util
import torch
from types import SimpleNamespace
from torchinfo import summary

# Project layout used below: the generator source and the helper module it
# imports relatively (`.utils`) both live under <PROJECT_ROOT>/src/.
PROJECT_ROOT = "/home/mingyeong/GAL2DM_ASIM_GAN"
MODEL_PATH, UTILS_PATH = (
    os.path.join(PROJECT_ROOT, "src", filename)
    for filename in ("model.py", "utils.py")
)

def _exec_module_as(name, path):
    """Execute the source file *path* as a fresh module registered as *name*.

    The module is put into ``sys.modules`` before it runs, as the import
    system does. If execution raises, the entry is rolled back to whatever
    was registered under *name* beforehand (or removed), so a
    half-initialised module is never left cached. The exception then
    propagates.

    Raises:
        ImportError: if no import spec/loader can be built for *path*.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"cannot create an import spec for {path!r}", name=name, path=path
        )
    # module_from_spec sets __package__ from spec.parent (the part of *name*
    # before the last dot), which is what lets `from .utils import ...` resolve.
    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(name)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the import system: don't leave a broken module cached.
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return module


def load_with_pkg(model_path, utils_path, symbol, pkg="pix2pixcc3d", *, reload_utils=False):
    """Load *symbol* from *model_path*, executed as module ``<pkg>.model``.

    *utils_path* is first registered as ``<pkg>.utils``, so a relative
    ``from .utils import ...`` inside the model file resolves without
    installing the project. The utils module is cached in ``sys.modules``
    and loaded only once unless *reload_utils* is true. The model module
    is re-executed on every call.

    Args:
        model_path: path to the model source file.
        utils_path: path to the helper module the model imports as ``.utils``.
        symbol: attribute name to fetch from the executed model module.
        pkg: package prefix under which both modules are registered.
        reload_utils: re-execute *utils_path* even if ``<pkg>.utils`` is
            already cached (useful after editing utils.py in a notebook).

    Returns:
        The object bound to *symbol* in the freshly executed model module.

    Raises:
        ImportError: if an import spec cannot be built for either path.
        AttributeError: if the model module defines no *symbol*.
        Any exception raised while executing either file propagates. The
        failed module is not left in ``sys.modules``.
    """
    utils_name = f"{pkg}.utils"
    if reload_utils or utils_name not in sys.modules:
        _exec_module_as(utils_name, utils_path)
    model_module = _exec_module_as(f"{pkg}.model", model_path)
    return getattr(model_module, symbol)

# 1) Load the generator class (model.py is re-executed each time this cell runs)
GeneratorPix2PixCC3D = load_with_pkg(MODEL_PATH, UTILS_PATH, "GeneratorPix2PixCC3D")

# 2) Minimal opt namespace & instantiate in eval mode.
#    NOTE(review): the meaning of each field (n_D, ch_balance, n_CC, ccc,
#    gpu_ids, data_type, ...) is defined by GeneratorPix2PixCC3D / the
#    training options, not here — confirm against model.py.
opt = SimpleNamespace(
    input_ch=2, target_ch=1,
    n_gf=32, n_df=32,
    n_downsample=3, n_residual=6,
    norm_type="InstanceNorm3d", padding_type="reflection",
    trans_conv=True, n_D=3, ch_balance=0.0,
    lambda_LSGAN=1.0, lambda_FM=10.0, lambda_CC=5.0,
    n_CC=2, ccc=True, eps=1e-8, gpu_ids=0, data_type=32
)
device = "cuda" if torch.cuda.is_available() else "cpu"
G = GeneratorPix2PixCC3D(opt).to(device).eval()

# 3) Smoke forward on a small cube. The generator should keep the spatial
#    size and emit opt.target_ch channels, so check that explicitly instead
#    of only printing the shape.
SMOKE_SIZE = 32
with torch.inference_mode():
    y = G(torch.randn(1, opt.input_ch, SMOKE_SIZE, SMOKE_SIZE, SMOKE_SIZE, device=device))
expected_shape = (1, opt.target_ch, SMOKE_SIZE, SMOKE_SIZE, SMOKE_SIZE)
if tuple(y.shape) != expected_shape:
    raise RuntimeError(
        f"unexpected generator output shape {tuple(y.shape)}, expected {expected_shape}"
    )
print("✅ forward OK | out shape:", tuple(y.shape))

# 4) Layer summary at full 128³ resolution; lower SUMMARY_SIZE (e.g. 64)
#    if this runs out of memory. Kept as the cell's last expression so
#    Jupyter displays the returned statistics.
SUMMARY_SIZE = 128
summary(
    G,
    input_size=(1, opt.input_ch, SUMMARY_SIZE, SUMMARY_SIZE, SUMMARY_SIZE),
    device=device,
    depth=2,
    col_names=("input_size", "output_size", "num_params"),
)


✅ forward OK | out shape: (1, 1, 32, 32, 32)


Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
GeneratorPix2PixCC3D                          [1, 2, 128, 128, 128]     [1, 1, 128, 128, 128]     --
├─Sequential: 1-1                             [1, 2, 128, 128, 128]     [1, 1, 128, 128, 128]     --
│    └─ReflectionPad3d: 2-1                   [1, 2, 128, 128, 128]     [1, 2, 134, 134, 134]     --
│    └─Conv3d: 2-2                            [1, 2, 134, 134, 134]     [1, 32, 128, 128, 128]    21,984
│    └─InstanceNorm3d: 2-3                    [1, 32, 128, 128, 128]    [1, 32, 128, 128, 128]    --
│    └─Mish: 2-4                              [1, 32, 128, 128, 128]    [1, 32, 128, 128, 128]    --
│    └─Conv3d: 2-5                            [1, 32, 128, 128, 128]    [1, 64, 64, 64, 64]       256,064
│    └─InstanceNorm3d: 2-6                    [1, 64, 64, 64, 64]       [1, 64, 64, 64, 64]       --
│    └─Mish: 2-7                              [1, 64, 64, 64, 64]       [1, 6