In [4]:
import os

import numpy as np
import medmnist
from medmnist import INFO
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Notebook-wide settings
batch_size = 64        # samples per mini-batch; used as the default of load_data(batch_size=...)
SIZE = 32, 64, 64      # NOTE(review): not referenced in the visible cells - confirm it is still needed

# Unified preprocessing / augmentation transform for 3D volumes
class Transform3D:
    """Callable preprocessing pipeline applied to one volume at a time.

    The pipeline always min-max normalises the volume. When ``mode`` is
    ``'train'`` it additionally applies a random rotation (up to 15 degrees)
    and a random translation (up to 10% per axis) through torchvision.

    Args:
        mode: ``'train'`` enables the random augmentations; any other value
            yields the normalisation-only pipeline.
        mul: ``'0.5'`` multiplies the volume by 0.5, ``'random'`` multiplies
            it by a ``np.random.uniform()`` sample, anything else applies no
            scaling.

    Calling the instance returns a ``numpy.ndarray`` of dtype ``float32``.
    """

    def __init__(self, mode='train', mul=None):
        self.mode = mode
        self.mul = mul

        # Bound methods are used rather than lambdas (presumably so the
        # transform stays picklable for DataLoader worker processes).
        transform_list = [
            T.Lambda(self._normalize)
        ]

        # Training mode adds data augmentation.
        if mode == 'train':
            transform_list.extend([
                # torchvision's geometric transforms accept PIL images or
                # tensors, not numpy arrays, so convert first.
                T.Lambda(self._to_tensor),
                T.RandomRotation(degrees=15),
                T.RandomAffine(degrees=0, translate=(0.1, 0.1))
            ])

        self.transforms = T.Compose(transform_list)

    def _to_tensor(self, voxel):
        """Return ``voxel`` as a contiguous float32 ``torch.Tensor``."""
        return torch.from_numpy(np.ascontiguousarray(voxel, dtype=np.float32))

    def _normalize(self, voxel):
        """Apply the optional ``mul`` scaling, then min-max normalise to [0, 1].

        A volume whose values are all identical has no contrast to stretch;
        it is mapped to all zeros instead of dividing by zero.
        """
        # Optional multiplicative scaling.
        # NOTE(review): min-max normalisation below is invariant to a positive
        # scale factor, so this multiplication does not change the result for
        # non-constant volumes - confirm whether scaling was meant to happen
        # after normalisation instead.
        if self.mul == '0.5':
            voxel = voxel * 0.5
        elif self.mul == 'random':
            voxel = voxel * np.random.uniform()

        # Min-max normalisation, guarded on the actual divisor (max - min).
        lo = voxel.min()
        span = voxel.max() - lo
        if span == 0:
            return np.zeros_like(voxel)
        return (voxel - lo) / span

    def __call__(self, voxel):
        """Run the pipeline on ``voxel`` and return a float32 numpy array."""
        voxel = self.transforms(voxel)
        # The train pipeline produces a tensor; convert back so the return
        # type is the same in every mode.
        if isinstance(voxel, torch.Tensor):
            voxel = voxel.numpy()
        return voxel.astype(np.float32)

def load_data(batch_size=batch_size, data_flag='organmnist3d', download=True,
              num_workers=None):
    """Build the train / validation / test DataLoaders for a MedMNIST dataset.

    Args:
        batch_size: Number of samples per batch for all three loaders.
        data_flag: Key into ``medmnist.INFO`` selecting the dataset.
        download: Passed through to the dataset class.
        num_workers: Worker processes per loader. ``None`` (default) picks 0
            on Windows and 4 elsewhere. Windows starts workers with
            ``spawn``, and a spawned worker cannot import a transform class
            that was defined inside a notebook cell.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles and uses the ``'train'`` transform with ``mul='random'``;
        validation and test use the ``'val'`` transform with ``mul='0.5'``.
    """
    # Resolve the dataset class from the MedMNIST registry.
    info = INFO[data_flag]
    DataClass = getattr(medmnist, info['python_class'])

    if num_workers is None:
        num_workers = 0 if os.name == 'nt' else 4

    # (split, transform mode, mul, shuffle) for each loader, in return order.
    split_specs = (
        ('train', 'train', 'random', True),
        ('val', 'val', '0.5', False),
        ('test', 'val', '0.5', False),
    )

    loaders = []
    for split, mode, mul, shuffle in split_specs:
        dataset = DataClass(split=split,
                            transform=Transform3D(mode=mode, mul=mul),
                            download=download)
        loaders.append(DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=shuffle,
                                  num_workers=num_workers,
                                  pin_memory=True))

    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader

In [5]:
# Build the three DataLoaders for OrganMNIST3D
train_loader, val_loader, test_loader = load_data(batch_size=64, data_flag='organmnist3d', download=True)

def visualize_processing_results(train_loader):
    """Plot the middle slice of up to nine samples from one batch.

    Pulls a single batch from ``train_loader`` and draws a 3x3 grid; each
    panel shows channel 0 at the middle index of dimension 2 of the batch
    tensor, titled with the sample's label. Panels with no sample (batch
    smaller than nine) are left blank.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` batches.
            Assumes ``images`` is a 5-D tensor indexed as
            (batch, channel, depth, height, width) - TODO confirm.
    """
    # Fetch one batch.
    images, labels = next(iter(train_loader))

    # 3x3 grid for up to nine samples.
    fig, axes = plt.subplots(3, 3, figsize=(12, 12))

    mid_index = images.shape[2] // 2
    n_shown = min(axes.size, len(images))

    for idx, ax in enumerate(axes.flat):
        if idx < n_shown:
            # Show the middle slice of the 3D volume.
            middle_slice = images[idx, 0, mid_index, :, :].numpy()
            ax.imshow(middle_slice, cmap='gray')
            ax.set_title(f'标签: {labels[idx].item()}')
        # Hide ticks/frames on every panel, including unused ones.
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    
# Print the dataset description, class count and label mapping
info = INFO['organmnist3d']
print(
    f"数据集: {info['description']}",
    f"类别数: {len(info['label'])}",
    f"类别: {info['label']}",
    sep="\n",
)

# Render a grid of sample slices from the training loader
visualize_processing_results(train_loader)

Using downloaded and verified file: C:\Users\Administrator\.medmnist\organmnist3d.npz
Using downloaded and verified file: C:\Users\Administrator\.medmnist\organmnist3d.npz
Using downloaded and verified file: C:\Users\Administrator\.medmnist\organmnist3d.npz
数据集: The source of the OrganMNIST3D is the same as that of the Organ{A,C,S}MNIST. Instead of 2D images, we directly use the 3D bounding boxes and process the images into 28×28×28 to perform multi-class classification of 11 body organs. The same 115 and 16 CT scans as the Organ{A,C,S}MNIST from the source training set are used as training and validation set, respectively, and the same 70 CT scans as the Organ{A,C,S}MNIST from the source test set are treated as the test set.
类别数: 11
类别: {'0': 'liver', '1': 'kidney-right', '2': 'kidney-left', '3': 'femur-right', '4': 'femur-left', '5': 'bladder', '6': 'heart', '7': 'lung-right', '8': 'lung-left', '9': 'spleen', '10': 'pancreas'}
