In [1]:
import torch
# Print three facts about the installed PyTorch build, one per line:
#   1. the PyTorch version string
#   2. the CUDA version of the build (None means a CPU-only build)
#   3. whether a usable GPU is available (False means there is none)
for build_fact in (torch.__version__, torch.version.cuda, torch.cuda.is_available()):
    print(build_fact)

1.12.1+cpu
None
False


In [2]:
import os
from glob import glob
import yaml
import torch
import monai
from monai.transforms import (
    LoadImaged,
    AddChanneld,
    ScaleIntensityd,
    RandCropByPosNegLabeld,
    EnsureTyped,
    Compose,
)
from monai.data import Dataset, DataLoader, list_data_collate
from train import dict2obj

# Load the run config and the data config (same files and parsing as train.py).
# NOTE(review): yaml.FullLoader is more permissive than yaml.safe_load; only
# point this at trusted, local config files.
def _read_config(path):
    """Parse the YAML file at ``path`` and wrap the parsed result with ``dict2obj``."""
    with open(path) as handle:
        return dict2obj(yaml.load(handle, Loader=yaml.FullLoader))


config = _read_config("configs/mnist_config.yaml")
dataconfig = _read_config("configs/mnist_dataconfig.yaml")

# Recreate the training transform pipeline so the shapes it produces can be
# inspected. Every dictionary transform below operates on both entries.
paired_keys = ["img", "seg"]
pipeline_steps = [
    LoadImaged(keys=paired_keys),
    AddChanneld(keys=paired_keys),
    ScaleIntensityd(keys=paired_keys),
    # Random crops guided by the "seg" entry; crop size and the number of
    # crops per sample both come from the data config.
    RandCropByPosNegLabeld(
        keys=paired_keys,
        label_key="seg",
        spatial_size=dataconfig.DATA.IMG_SIZE,
        pos=1,
        neg=1,
        num_samples=dataconfig.DATA.NUM_PATCH,
    ),
    EnsureTyped(keys=paired_keys),
]
train_transforms = Compose(pipeline_steps)

# Collect the image / segmentation file pairs used for training.
data_path = dataconfig.DATA.DATA_PATH
file_pattern = "*" + dataconfig.DATA.FORMAT

# Join the path components with os.path.join rather than string concatenation,
# so the lookup works whether or not DATA_PATH ends with a path separator.
images = sorted(glob(os.path.join(data_path, "images", file_pattern)))
segs = sorted(glob(os.path.join(data_path, "labels", file_pattern)))

# Fail early with a readable message instead of building an empty dataset.
if not images or not segs:
    raise FileNotFoundError(
        f"No '{file_pattern}' files found under '{data_path}' "
        f"(images: {len(images)}, labels: {len(segs)})"
    )

# Pairing relies on both sorted listings lining up one-to-one; zip() would
# silently drop the surplus entries if the two selections differed in length.
train_images = images[: dataconfig.DATA.TRAIN_SAMPLES]
train_segs = segs[: dataconfig.DATA.TRAIN_SAMPLES]
if len(train_images) != len(train_segs):
    raise ValueError(
        f"Image/label count mismatch: {len(train_images)} images vs "
        f"{len(train_segs)} labels"
    )

train_files = [
    {"img": img, "seg": seg}
    for img, seg in zip(train_images, train_segs)
]

# Wrap the file pairs in a dataset that applies the transform pipeline, then
# expose it through a shuffling loader that runs in the main process
# (num_workers=0) and collates with MONAI's list_data_collate.
train_ds = Dataset(data=train_files, transform=train_transforms)
loader_options = dict(
    batch_size=config.TRAIN.BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=list_data_collate,
)
train_loader = DataLoader(train_ds, **loader_options)

# Fetch one batch and report shapes for the whole batch, for its first
# element, and for that first element after squeezing size-1 dimensions.
batch = next(iter(train_loader))
imgs = batch["img"]
labels = batch["seg"]

# Each entry pairs a caption template with the view of the tensor to measure.
shape_reports = (
    ("batch {} shape:", lambda tensor: tensor),
    ("single {} shape:", lambda tensor: tensor[0]),
    ("single {} squeezed:", lambda tensor: tensor[0].squeeze()),
)
for caption, select in shape_reports:
    for key, tensor in (("img", imgs), ("seg", labels)):
        print(caption.format(key), select(tensor).shape)


  pkg = __import__(module)  # top level module


batch img shape: (8, 1, 48, 48)
batch seg shape: (8, 1, 48, 48)
single img shape: (1, 48, 48)
single seg shape: (1, 48, 48)
single img squeezed: (48, 48)
single seg squeezed: (48, 48)




In [4]:
# Run train.py in a subprocess (IPython shell escape) with the MNIST run and data configs.
!python3 train.py --config configs/mnist_config.yaml --dataconfig configs/mnist_dataconfig.yaml

  pkg = __import__(module)  # top level module

*** Config file
configs/mnist_config.yaml

*** Dataconfig file
configs/mnist_dataconfig.yaml
Running on cpu
  0%|                                                    | 0/10 [00:00<?, ?it/s]
  0%|                                                   | 0/563 [00:00<?, ?it/s][A
  0%|                                                    | 0/10 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/dsand/projects/Betti-matching/train.py", line 353, in <module>
    main(args)
  File "/home/dsand/projects/Betti-matching/train.py", line 268, in main
    loss, dic = loss_function(outputs, labels)
  File "/home/dsand/miniconda3/envs/bm39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/dsand/projects/Betti-matching/loss_functions.py", line 56, in forward
    losses.append(compute_BettiMatchingLoss(pair, sigmoid=True, filtration=self.filtration, relative=s

In [None]:
import torch
from BettiMatching import CubicalPersistence

# Import MetaTensor in a way that works across MONAI versions: try the
# package-level export first and fall back to the defining module. Only an
# ImportError triggers the fallback, so unrelated failures are not hidden.
try:
    from monai.data import MetaTensor
except ImportError:
    from monai.data.meta_tensor import MetaTensor

def inspect_picture(pic, name):
    """Print how two type checks affect squeezing ``pic`` before CubicalPersistence.

    The "OLD" rule squeezes only when ``type(pic) == torch.Tensor`` (exact type
    match); the "NEW" rule squeezes whenever ``isinstance(pic, torch.Tensor)``
    holds, which also covers tensor subclasses. Both results are then passed to
    ``CubicalPersistence`` and the outcome (or the raised error) is printed.

    Args:
        pic: Tensor-like object exposing ``.shape``.
        name: Label printed in the section header.

    Returns:
        None. All results are written to stdout.
    """
    print(f"\n=== {name} ===")

    is_exact_tensor = type(pic) == torch.Tensor
    is_tensor_like = isinstance(pic, torch.Tensor)
    print("type:", type(pic))
    print("isinstance(pic, torch.Tensor):", is_tensor_like)
    print("type(pic) == torch.Tensor:", is_exact_tensor)
    print("shape before:", tuple(pic.shape))

    # Mirror the logic currently in the code: exact type comparison.
    squeezed_old = torch.squeeze(pic) if is_exact_tensor else pic
    print("shape after OLD logic:", tuple(squeezed_old.shape))

    # Recommended logic: isinstance also accepts subclasses of torch.Tensor.
    squeezed_new = torch.squeeze(pic) if is_tensor_like else pic
    print("shape after NEW logic:", tuple(squeezed_new.shape))

    # Check whether CubicalPersistence actually accepts each variant.
    persistence_options = dict(
        relative=False,   # avoid triggering the extra relative+training branch
        reduced=False,
        filtration="superlevel",
        construction="V",
        training=False,
    )
    for label, tensor in (("OLD", squeezed_old), ("NEW", squeezed_new)):
        try:
            result = CubicalPersistence(tensor, **persistence_options)
            print(f"{label} -> CubicalPersistence OK, m,n=({result.m},{result.n})")
        except Exception as err:
            # Deliberately broad: the error itself is the diagnostic output.
            print(f"{label} -> CubicalPersistence ERROR: {type(err).__name__}: {err}")

# A 3-D input of shape [1, H, W], much like a single sample during training,
# inspected once as a plain tensor and once wrapped in a MONAI MetaTensor.
x = torch.rand(1, 48, 48)
mx = MetaTensor(x.clone())

for picture, title in ((x, "Plain torch.Tensor"), (mx, "MONAI MetaTensor")):
    inspect_picture(picture, title)


=== Plain torch.Tensor ===
type: <class 'torch.Tensor'>
isinstance(pic, torch.Tensor): True
type(pic) == torch.Tensor: True
shape before: (1, 48, 48)
shape after OLD logic: (48, 48)
shape after NEW logic: (48, 48)
OLD -> CubicalPersistence OK, m,n=(48,48)
NEW -> CubicalPersistence OK, m,n=(48,48)

=== MONAI MetaTensor ===
type: <class 'monai.data.meta_tensor.MetaTensor'>
isinstance(pic, torch.Tensor): True
type(pic) == torch.Tensor: False
shape before: (1, 48, 48)
shape after OLD logic: (1, 48, 48)
shape after NEW logic: (48, 48)
OLD -> CubicalPersistence ERROR: ValueError: too many values to unpack (expected 2)
NEW -> CubicalPersistence OK, m,n=(48,48)


In [7]:
import torch
from monai.data.meta_tensor import MetaTensor
def _show_shape_and_type(caption, tensor):
    """Print ``caption`` followed by the tensor's shape tuple and its type."""
    print(caption, tuple(tensor.shape), type(tensor))


# Compare the functional and the method form of squeeze on a MetaTensor.
x = MetaTensor(torch.rand(1, 48, 48))
print('x.shape', tuple(x.shape))

y = torch.squeeze(x)
_show_shape_and_type('torch.squeeze(x).shape', y)

z = x.squeeze()
_show_shape_and_type('x.squeeze().shape', z)

x.shape (1, 48, 48)
torch.squeeze(x).shape (48, 48) <class 'monai.data.meta_tensor.MetaTensor'>
x.squeeze().shape (48, 48) <class 'monai.data.meta_tensor.MetaTensor'>


In [11]:
# Re-execute the BettiMatching module in the running kernel and rebind
# CubicalPersistence from the reloaded module. The order matters: the
# from-import must come after the reload to pick up the new class object.
# NOTE(review): presumably run after editing BettiMatching on disk - confirm.
import importlib
import BettiMatching
importlib.reload(BettiMatching)
from BettiMatching import CubicalPersistence

In [12]:
import torch
from monai.data.meta_tensor import MetaTensor
from BettiMatching import CubicalPersistence
# Construct CubicalPersistence from a [1, 48, 48] MetaTensor with relative
# switched off and then on, printing the resulting m and n each time.
x = MetaTensor(torch.rand(1, 48, 48))
common_options = dict(filtration='superlevel', construction='V', training=True)

cp = CubicalPersistence(x, relative=False, **common_options)
print('relative=False OK', cp.m, cp.n)

cp2 = CubicalPersistence(x, relative=True, **common_options)
print('relative=True OK', cp2.m, cp2.n)


relative=False OK 48 48
relative=True OK 50 50


In [13]:
# Run train.py again in a subprocess (IPython shell escape) with the same MNIST run and data configs.
!python3 train.py --config configs/mnist_config.yaml --dataconfig configs/mnist_dataconfig.yaml

  pkg = __import__(module)  # top level module

*** Config file
configs/mnist_config.yaml

*** Dataconfig file
configs/mnist_dataconfig.yaml
Running on cpu
  0%|                                                    | 0/10 [00:00<?, ?it/s]
  0%|                                                   | 0/563 [00:00<?, ?it/s][A
  0%|                                           | 1/563 [00:05<47:41,  5.09s/it][A
  0%|▏                                          | 2/563 [00:10<48:35,  5.20s/it][A
  1%|▏                                          | 3/563 [00:15<49:23,  5.29s/it][A
  1%|▎                                          | 4/563 [00:21<49:30,  5.31s/it][A
  1%|▍                                          | 5/563 [00:26<49:37,  5.34s/it][A
  1%|▍                                          | 6/563 [00:31<49:42,  5.36s/it][A
  1%|▌                                          | 7/563 [00:37<49:40,  5.36s/it][A
  1%|▌                                          | 8/563 [00:42<49:47,  5.38s/it][A
  2%|▋ 

In [10]:
import torch
from monai.data.meta_tensor import MetaTensor
from BettiMatching import CubicalPersistence

# Build CubicalPersistence with relative=True and training=True, then report
# the type, dtype, device and shape of its PixelMap attribute.
x = MetaTensor(torch.rand(1, 48, 48))
cp = CubicalPersistence(x, relative=True, filtration='superlevel', construction='V', training=True)

pixel_map = cp.PixelMap
print("type:", type(pixel_map))
print("is torch tensor:", isinstance(pixel_map, torch.Tensor))
# dtype/device are read defensively in case PixelMap is not a tensor.
for attribute in ("dtype", "device"):
    print(f"{attribute}:", getattr(pixel_map, attribute, None))
print("shape:", pixel_map.shape)

type: <class 'monai.data.meta_tensor.MetaTensor'>
is torch tensor: True
dtype: torch.float32
device: cpu
shape: (50, 50)
