In [5]:
import os
import glob
import random
import argparse

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from subsample import RandomMaskFunc

In [6]:
def label_from_name(fname):
    """Derive the class label from a file name.

    The match is case-insensitive and substring-based. "brain" is checked
    before "knee", so a name containing both maps to 1.

    Args:
        fname: file name (or any string) to inspect.

    Returns:
        1 if the name contains "brain", otherwise 0 if it contains "knee".

    Raises:
        ValueError: if neither keyword occurs in the name.
    """
    lowered = fname.lower()
    # Ordered (keyword, label) pairs; the first keyword found wins.
    for keyword, label in (("brain", 1), ("knee", 0)):
        if keyword in lowered:
            return label
    raise ValueError(f"Cannot extract label from {fname}")


In [15]:
class KSpaceDataset(Dataset):
    """Per-file k-space dataset yielding a 3-channel RSS magnitude image and a label.

    For every HDF5 file the dataset named ``ds`` (k-space) and the dataset
    ``"mask"`` are read. Three slices are taken at indices ``0``, ``S // 3``
    and ``2 * S // 3`` along the first axis. Three variants of those slices
    (randomly re-masked original, flipped + noised file-masked, plain
    file-masked) are summed element-wise, reduced over dim 1 with a
    root-sum-of-squares of the magnitude, standardised to zero mean / unit
    std, and zero-padded on the bottom/right to the largest ``(Ky, Kx)``
    found across all files.

    Args:
        filepaths: iterable of HDF5 file paths; one item per path.
        ds: name of the k-space dataset inside each file.
        num_slices: stored on the instance but not used by ``__getitem__``,
            which always samples exactly three slices.
        noise_std: scale of the additive Gaussian noise, relative to the
            standard deviation of the file-masked k-space magnitude.
        flip_prob: probability of flipping the file-masked volume along its
            last axis before noise is added.
        center_fractions: forwarded to ``RandomMaskFunc``; ``None`` means
            ``[0.08]``.
        accelerations: forwarded to ``RandomMaskFunc``; ``None`` means
            ``[4, 8]``.
        debug: when true, print the shape of the combined slices for every
            item. Defaults to ``True`` to keep the historical output; pass
            ``False`` for training runs.

    Raises:
        ValueError: if a file's k-space dataset is not 4-dimensional.
    """

    # A trailing axis this short (or shorter) is treated as a channel axis
    # stored last and is moved to position 1. This is a size heuristic.
    CHANNEL_LAST_MAX = 32

    def __init__(self, filepaths, ds='kspace', num_slices=4,
                 noise_std=0.01, flip_prob=0.5,
                 center_fractions=None, accelerations=None, debug=True):
        # None sentinels instead of list literals: a list default would be a
        # single object shared by every instance (and by the mask function).
        if center_fractions is None:
            center_fractions = [0.08]
        if accelerations is None:
            accelerations = [4, 8]

        # Materialise so that generators survive the scan below and len() works.
        self.paths = list(filepaths)
        self.ds = ds
        self.num_slices = num_slices
        self.noise_std = noise_std
        self.flip_prob = flip_prob
        self.debug = debug
        self.mask_func = RandomMaskFunc(
            center_fractions=center_fractions,
            accelerations=accelerations,
            allow_any_combination=True
        )

        # Fixed padding size: the largest (Ky, Kx) over all files. Only the
        # dataset's shape metadata is read here; loading each full volume
        # just to inspect its shape is needlessly slow.
        self.max_h, self.max_w = 0, 0
        for path in self.paths:
            with h5py.File(path, 'r') as f:
                raw_shape = tuple(f[self.ds].shape)
            _, _, Ky, Kx = self._channel_first_shape(raw_shape)
            self.max_h = max(self.max_h, Ky)
            self.max_w = max(self.max_w, Kx)

    @classmethod
    def _channel_first_shape(cls, shape):
        """Return ``shape`` as it looks after the channel-last -> channel-first move."""
        shape = tuple(shape)
        if len(shape) != 4:
            raise ValueError(f"Expected a 4-D k-space array, got shape {shape}")
        if shape[-1] <= cls.CHANNEL_LAST_MAX:
            # Same result as np.moveaxis(arr, -1, 1).shape.
            return (shape[0], shape[-1]) + shape[1:-1]
        return shape

    def __len__(self):
        """Number of files, i.e. number of items."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Build the (image, label) pair for file ``idx``.

        Returns:
            tuple: ``x`` is a float32 tensor of shape ``(3, max_h, max_w)``;
            ``y`` is a 0-dim int64 tensor holding the label derived from the
            file's base name.
        """
        path = self.paths[idx]

        # 1) Load the k-space volume and the mask stored alongside it.
        with h5py.File(path, 'r') as f:
            arr = f[self.ds][()]
            maskk = f["mask"][()]
        if arr.shape[-1] <= self.CHANNEL_LAST_MAX:
            arr = np.moveaxis(arr, -1, 1)
            maskk = np.moveaxis(maskk, -1, 1)

        k1 = torch.from_numpy(arr)           # k-space as stored
        k2 = torch.from_numpy(arr * maskk)   # k-space with the file's mask applied

        # Slice indices shared by all three variants.
        S, _, Ky, Kx = k2.shape
        idxs = [0, S // 3, (S * 2) // 3]

        # 2) k2_ori: file-masked slices, taken before flip/noise augmentation.
        k2_ori = k2[idxs]  # (3, C, Ky, Kx)

        # 3) k1_slices: stored k-space re-masked with a freshly drawn mask.
        # NOTE(review): assumes the returned mask broadcasts against
        # (S, C, Ky, Kx) - confirm against subsample.RandomMaskFunc.
        mask_vol, _ = self.mask_func(k1.shape)
        mask_vol = mask_vol.to(k1.dtype)
        k1_masked = k1 * mask_vol
        k1_slices = k1_masked[idxs]  # (3, C, Ky, Kx)

        # 4) k2_slices: optional flip along the last axis, then additive noise.
        if random.random() < self.flip_prob:
            k2 = torch.flip(k2, dims=[-1])
        std = self.noise_std * torch.std(torch.abs(k2))
        k2_noised = k2 + torch.randn_like(k2) * std
        k2_slices = k2_noised[idxs]  # (3, C, Ky, Kx)

        # Tensor `+` is an element-wise sum, so the result stays (3, C, Ky, Kx).
        # NOTE(review): if stacking the variants into 9 slices was intended,
        # this should be torch.cat - confirm before changing the output shape.
        k_slices = k1_slices + k2_slices + k2_ori
        if self.debug:
            print(f"[DEBUG] k_slices.shape = {k_slices.shape}")

        # Root-sum-of-squares over dim 1, then global standardisation.
        mag = torch.abs(k_slices)
        rss = torch.sqrt(torch.sum(mag**2, dim=1))  # (3, Ky, Kx)
        rss = (rss - rss.mean()) / (rss.std() + 1e-8)

        # Zero-pad right/bottom up to the dataset-wide maximum size.
        pad_h = self.max_h - Ky
        pad_w = self.max_w - Kx
        x = F.pad(rss, (0, pad_w, 0, pad_h), mode='constant', value=0.0)
        y = label_from_name(os.path.basename(path))
        return x.float(), torch.tensor(y, dtype=torch.long)


테스트용 디버깅 툴 (debugging tool for testing)


In [17]:
# 3) Data path, Dataset/Loader creation & smoke test
# The path is relative to the notebook's working directory; the file list is
# sorted so that index 0 refers to the same file on every run.
data_dir  = "Data/train/kspace"
paths     = sorted(glob.glob(f"{data_dir}/*.h5"))
print("▶ Found", len(paths), "files")

# Constructing the dataset opens every file once to determine the padding size.
ds = KSpaceDataset(paths)
print("Dataset length:", len(ds))

# (1) Fetch a single sample directly (raises IndexError if no files were found)
x0, y0 = ds[0]
print("Single sample x0.shape:", x0.shape, "y0:", y0)

# (2) Pull one batch through a DataLoader (with num_workers=0)
loader = DataLoader(ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
bx, by = next(iter(loader))
print("Batch bx.shape:", bx.shape, "by.shape:", by.shape)


▶ Found 340 files
Dataset length: 340
[DEBUG] k_slices.shape = torch.Size([3, 16, 768, 396])
Single sample x0.shape: torch.Size([3, 768, 480]) y0: tensor(1)
[DEBUG] k_slices.shape = torch.Size([3, 16, 768, 396])
[DEBUG] k_slices.shape = torch.Size([3, 16, 768, 396])
Batch bx.shape: torch.Size([2, 3, 768, 480]) by.shape: torch.Size([2])
