In [3]:
import argparse

class Args:
    def __init__(self):
        self.annotation_path = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/daasets/mimic/mimic+inter+intra0.001.json"  # 你的json路径
        self.image_root = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/daasets/mimic/image_224"
        self.region_root = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/promot/region_top5_capture/feature_boxes_flat"
        self.base_dir = self.image_root  # 通常是图片根目录
        self.use_sip = True  # 是否使用 SIP 版本
        self.patch_size = 40
        self.dataset = "mimic_cxr"
        self.input_size = 224

args = Args()

train_ds, val_ds, test_ds = create_datasets(args)

print(f"训练集大小: {len(train_ds)}")
print(f"验证集大小: {len(val_ds)}")
print(f"测试集大小: {len(test_ds)}")

# 取一个样本测试
sample = train_ds[0]
print("样本ID:", sample['id'])
print("报告文本:", sample['input_text'])
print("图像Tensor列表长度:", len(sample['image']))
print("patch_tensor shape:", sample['patch_tensor'].shape if sample['patch_tensor'] is not None else None)
print("box_tensor shape:", sample['box_tensor'].shape)
print("class_idx shape:", sample['class_idx'].shape)


NameError: name 'create_datasets' is not defined

In [None]:
import os
import json
import re
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import AutoImageProcessor

class FieldParser:
    def __init__(self, args):
        self.args = args
        self.dataset = args.dataset
        self.vit_feature_extractor = AutoImageProcessor.from_pretrained(
            '/media/wangyujie/CXPMRG_Bench_MambaXray_VL/huggingface/microsoft/swin-base-patch4-window7-224'
        )
    
    def _parse_image(self, img):
        pixel_values = self.vit_feature_extractor(img, return_tensors="pt", size=self.args.input_size).pixel_values
        return pixel_values[0]  # Tensor [3, H, W]

    def clean_report(self, report):
        # 这里简单清理示例，按需扩展
        report_cleaner = lambda t: t.replace('\n', ' ').strip().lower().split('. ')
        tokens = [t.strip() for t in report_cleaner(report) if len(t.strip()) > 0]
        return ' . '.join(tokens) + ' .'

    def parse(self, features):
        image_id = features['id']
        report = features.get("report", "")
        input_text = self.clean_report(report)
        images = []
        for img_path in features['image_path']:
            full_img_path = os.path.join(self.args.base_dir, img_path)
            with Image.open(full_img_path) as pil_img:
                pil_img = pil_img.convert("RGB")
                arr = np.array(pil_img)
                if len(arr.shape) != 3 or arr.shape[-1] != 3:
                    arr = np.array(pil_img.convert("RGB"))
                img_tensor = self._parse_image(arr)
                images.append(img_tensor)
        return {
            'id': image_id,
            'input_text': input_text,
            'image': images
        }


class ChestXrayWithRegionDataset(Dataset):
    def __init__(self, image_dir, region_dir, annotation_json, base_dir, parser, patch_size=40):
        self.image_dir = image_dir
        self.region_dir = region_dir
        self.base_dir = base_dir
        self.parser = parser
        self.patch_size = patch_size

        # 不再过滤，直接使用全部样本
        self.meta = annotation_json
        print(f"Dataset loaded. Total samples: {len(self.meta)}")

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        features = self.meta[idx]
        parsed = self.parser.parse(features)
        input_text = parsed['input_text']
        images = parsed['image']

        image_file = features['image_path'][0]
        image_id = os.path.splitext(image_file)[0]
        region_path = os.path.join(self.region_dir, f"{image_id}.pt")
        region_data = torch.load(region_path)
        patch_tensor = region_data.get('patch_tensor', None)
        boxes = region_data.get('box_tensor', None)
        labels = region_data.get('class_idx', None)

        return {
            'id': image_id,
            'input_text': input_text,
            'image': images,
            'patch_tensor': patch_tensor,
            'box_tensor': boxes,
            'class_idx': labels,
        }


def create_datasets(args):
    with open(args.annotation_path, 'r') as f:
        full_data = json.load(f)
    train_ids = full_data['train']
    val_ids = full_data['val']
    test_ids = full_data['test']

    if getattr(args, 'use_sip', False):
        parser = FieldParser(args)
        train_dataset = ChestXrayWithRegionDataset(
            image_dir=args.image_root,
            region_dir=args.region_root,
            annotation_json=train_ids,
            base_dir=args.base_dir,
            parser=parser,
            patch_size=getattr(args, 'patch_size', 40)
        )
        val_dataset = ChestXrayWithRegionDataset(
            image_dir=args.image_root,
            region_dir=args.region_root,
            annotation_json=val_ids,
            base_dir=args.base_dir,
            parser=parser,
            patch_size=getattr(args, 'patch_size', 40)
        )
        test_dataset = ChestXrayWithRegionDataset(
            image_dir=args.image_root,
            region_dir=args.region_root,
            annotation_json=test_ids,
            base_dir=args.base_dir,
            parser=parser,
            patch_size=getattr(args, 'patch_size', 40)
        )
    else:
        # 这里你可以用你之前的 ParseDataset 或其他
        raise NotImplementedError("非 SIP 版本暂未实现")

    return train_dataset, val_dataset, test_dataset


# --------- 测试示例 ---------

class Args:
    def __init__(self):
        self.annotation_path = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/daasets/mimic/mimic+inter+intra1.json"  # 你的json路径
        self.image_root = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/daasets/mimic/image_224"
        self.region_root = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/promot/region_top5_capture/feature_boxes_flat"
        self.base_dir = self.image_root
        self.use_sip = True
        self.patch_size = 40
        self.dataset = "mimic_cxr"
        self.input_size = 224

def test():
    args = Args()
    train_ds, val_ds, test_ds = create_datasets(args)
    print(f"训练集大小: {len(train_ds)}")
    sample = train_ds[0]
    print(f"样本ID: {sample['id']}")
    print(f"报告文本: {sample['input_text']}")
    print(f"图像张数: {len(sample['image'])}")
    print(f"patch_tensor形状: {None if sample['patch_tensor'] is None else sample['patch_tensor'].shape}")
    print(f"box_tensor形状: {None if sample['box_tensor'] is None else sample['box_tensor'].shape}")
    print(f"class_idx: {sample['class_idx']}")

if __name__ == "__main__":
    test()


Dataset loaded. Total samples: 370342
Dataset loaded. Total samples: 2130
Dataset loaded. Total samples: 3858
训练集大小: 370342
样本ID: p10_p10000032_s50414267_02aa804e-bde0afdd-112c0b34-7bc16630-4e384014
报告文本: there is no focal consolidation, pleural effusion or pneumothorax . bilateral  nodular opacities that most likely represent nipple shadows . the  cardiomediastinal silhouette is normal . clips project over the left lung,  potentially within the breast . the imaged upper abdomen is unremarkable . chronic deformity of the posterior left sixth and seventh ribs are noted. .
图像张数: 1
patch_tensor形状: torch.Size([18, 1, 40, 40])
box_tensor形状: torch.Size([18, 4])
class_idx: tensor([0, 0, 0, 2, 2, 2, 3, 3, 3, 4, 4, 4, 1, 1, 1, 7, 7, 7])


In [None]:
import os
import json
import re
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import AutoImageProcessor

class FieldParser:
    def __init__(self, args):
        self.args = args
        self.dataset = args.dataset
        self.vit_feature_extractor = AutoImageProcessor.from_pretrained(
            '/media/wangyujie/CXPMRG_Bench_MambaXray_VL/huggingface/microsoft/swin-base-patch4-window7-224'
        )
    
    def _parse_image(self, img):
        pixel_values = self.vit_feature_extractor(img, return_tensors="pt", size=self.args.input_size).pixel_values
        return pixel_values[0]  # Tensor [3, H, W]

    def clean_report(self, report):
        # 简单示例清理，按需扩展
        report_cleaner = lambda t: t.replace('\n', ' ').strip().lower().split('. ')
        tokens = [t.strip() for t in report_cleaner(report) if len(t.strip()) > 0]
        return ' . '.join(tokens) + ' .'

    def parse(self, features):
        image_id = features['id']
        report = features.get("report", "")
        input_text = self.clean_report(report)
        images = []
        for img_path in features['image_path']:
            full_img_path = os.path.join(self.args.base_dir, img_path)
            with Image.open(full_img_path) as pil_img:
                pil_img = pil_img.convert("RGB")
                arr = np.array(pil_img)
                if len(arr.shape) != 3 or arr.shape[-1] != 3:
                    arr = np.array(pil_img.convert("RGB"))
                img_tensor = self._parse_image(arr)
                images.append(img_tensor)
        return {
            'id': image_id,
            'input_text': input_text,
            'image': images
        }

class ChestXrayWithRegionDataset(Dataset):
    def __init__(self, image_dir, region_dir, annotation_json, base_dir, parser, patch_size=40):
        self.image_dir = image_dir
        self.region_dir = region_dir
        self.base_dir = base_dir
        self.parser = parser
        self.patch_size = patch_size

        # 不再过滤，直接使用全部样本
        self.meta = annotation_json
        print(f"Dataset loaded. Total samples: {len(self.meta)}")

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        features = self.meta[idx]
        parsed = self.parser.parse(features)
        input_text = parsed['input_text']
        images = parsed['image']

        image_file = features['image_path'][0]
        image_id = os.path.splitext(image_file)[0]
        region_path = os.path.join(self.region_dir, f"{image_id}.pt")
        region_data = torch.load(region_path)
        patch_tensor = region_data.get('patch_tensor', None)
        boxes = region_data.get('box_tensor', None)
        labels = region_data.get('class_idx', None)

        return {
            'id': image_id,
            'input_text': input_text,
            'image': images,
            'patch_tensor': patch_tensor,
            'box_tensor': boxes,
            'class_idx': labels,
        }



def create_datasets(args):
    with open(args.annotation_path, 'r') as f:
        full_data = json.load(f)
    train_ids = full_data['train']
    val_ids = full_data['val']
    test_ids = full_data['test']

    if getattr(args, 'use_sip', False):
        parser = FieldParser(args)
        train_dataset = ChestXrayWithRegionDataset(
            image_dir=args.image_root,
            region_dir=args.region_root,
            annotation_json=train_ids,
            base_dir=args.base_dir,
            parser=parser,
            patch_size=getattr(args, 'patch_size', 40)
        )
        val_dataset = ChestXrayWithRegionDataset(
            image_dir=args.image_root,
            region_dir=args.region_root,
            annotation_json=val_ids,
            base_dir=args.base_dir,
            parser=parser,
            patch_size=getattr(args, 'patch_size', 40)
        )
        test_dataset = ChestXrayWithRegionDataset(
            image_dir=args.image_root,
            region_dir=args.region_root,
            annotation_json=test_ids,
            base_dir=args.base_dir,
            parser=parser,
            patch_size=getattr(args, 'patch_size', 40)
        )
    else:
        raise NotImplementedError("非 SIP 版本暂未实现")

    return train_dataset, val_dataset, test_dataset

# 测试函数
class Args:
    def __init__(self):
        self.annotation_path = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/daasets/mimic/mimic+inter+intra1.json"
        self.image_root = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/daasets/mimic/image_224"
        self.region_root = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/promot/region_top5_capture/feature_boxes_flat"
        self.base_dir = self.image_root
        self.use_sip = True
        self.patch_size = 40
        self.dataset = "mimic_cxr"
        self.input_size = 224

def test():
    args = Args()
    train_ds, val_ds, test_ds = create_datasets(args)
    print(f"训练集大小: {len(train_ds)}")

    sample = train_ds[0]
    print(f"样本ID: {sample['id']}")
    print(f"报告文本: {sample['input_text']}")
    print(f"图像张数: {len(sample['image'])}")

    image = sample['image'][0]
    print(f"图像 shape: {image.shape}")  # 应该是 torch.Size([3, 224, 224])


    # region 信息
    patch_tensor = sample['patch_tensor']
    boxes = sample['box_tensor']
    labels = sample['class_idx']

    print(f"patch_tensor形状: {None if patch_tensor is None else patch_tensor.shape}")
    print(f"box_tensor形状: {None if boxes is None else boxes.shape}")
    print(f"class_idx: {labels}")


if __name__ == "__main__":
    test()


Dataset loaded. Total samples: 370342
Dataset loaded. Total samples: 2130
Dataset loaded. Total samples: 3858
训练集大小: 370342
样本ID: p10_p10000032_s50414267_02aa804e-bde0afdd-112c0b34-7bc16630-4e384014
报告文本: there is no focal consolidation, pleural effusion or pneumothorax . bilateral  nodular opacities that most likely represent nipple shadows . the  cardiomediastinal silhouette is normal . clips project over the left lung,  potentially within the breast . the imaged upper abdomen is unremarkable . chronic deformity of the posterior left sixth and seventh ribs are noted. .
图像张数: 1
图像 shape: torch.Size([3, 224, 224])
patch_tensor形状: torch.Size([18, 1, 40, 40])
box_tensor形状: torch.Size([18, 4])
class_idx: tensor([0, 0, 0, 2, 2, 2, 3, 3, 3, 4, 4, 4, 1, 1, 1, 7, 7, 7])


In [3]:
import os
import json
import numpy as np
from PIL import Image
import torch
import torch.utils.data as data
from transformers import AutoImageProcessor

# FieldParser复用你的版本
class FieldParser:
    def __init__(self, args):
        self.args = args
        self.dataset = args.dataset
        self.vit_feature_extractor = AutoImageProcessor.from_pretrained(
            '/media/wangyujie/CXPMRG_Bench_MambaXray_VL/huggingface/microsoft/swin-base-patch4-window7-224'
        )

    def _parse_image(self, img):
        pixel_values = self.vit_feature_extractor(img, return_tensors="pt", size=self.args.input_size).pixel_values
        return pixel_values[0]

    def clean_report(self, report):
        report_cleaner = lambda t: t.replace('\n', ' ').strip().lower().split('. ')
        tokens = [t.strip() for t in report_cleaner(report) if len(t.strip()) > 0]
        return ' . '.join(tokens) + ' .'

    def parse(self, features):
        image_id = features['id']
        report = features.get("report", "")
        input_text = self.clean_report(report)
        images = []
        for img_path in features['image_path']:
            full_img_path = os.path.join(self.args.base_dir, img_path)
            with Image.open(full_img_path) as pil_img:
                pil_img = pil_img.convert("RGB")
                arr = np.array(pil_img)
                if len(arr.shape) != 3 or arr.shape[-1] != 3:
                    arr = np.array(pil_img.convert("RGB"))
                img_tensor = self._parse_image(arr)
                images.append(img_tensor)
        return {
            'id': image_id,
            'input_text': input_text,
            'image': images
        }

    def transform_with_parse(self, inputs):
        return self.parse(inputs)

# ParseDataset原始版
class ParseDataset(data.Dataset):
    def __init__(self, args, split='train'):
        self.args = args
        self.meta = json.loads(open(args.annotation_path, 'r', encoding='utf-8').read())
        self.meta = self.meta[split]
        self.parser = FieldParser(args)

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, index):
        return self.parser.transform_with_parse(self.meta[index])

# # ChestXrayWithRegionDataset带region的版本
# class ChestXrayWithRegionDataset(data.Dataset):
#     def __init__(self, image_dir, region_dir, annotation_json, base_dir, parser, patch_size=40):
#         self.image_dir = image_dir
#         self.region_dir = region_dir
#         self.base_dir = base_dir
#         self.parser = parser
#         self.patch_size = patch_size

#         filtered_meta = []
#         for feat in annotation_json:
#             image_file = feat['image_path'][0]
#             image_id = os.path.splitext(image_file)[0]
#             region_path = os.path.join(self.region_dir, f"{image_id}.pt")
#             if os.path.exists(region_path):
#                 filtered_meta.append(feat)
#             else:
#                 print(f"[Warning] region file missing, filtering out: {region_path}")
#         self.meta = filtered_meta
#         print(f"Filtered dataset size: {len(self.meta)}")

#     def __len__(self):
#         return len(self.meta)

#     def __getitem__(self, idx):
#         features = self.meta[idx]
#         parsed = self.parser.parse(features)
#         input_text = parsed['input_text']
#         images = parsed['image']

#         image_file = features['image_path'][0]
#         image_id = os.path.splitext(image_file)[0]
#         region_path = os.path.join(self.region_dir, f"{image_id}.pt")

#         region_data = torch.load(region_path)
#         patch_tensor = region_data.get('patch_tensor', None)
#         boxes = region_data.get('box_tensor', None)
#         labels = region_data.get('class_idx', None)

#         return {
#             'id': image_id,
#             'input_text': input_text,
#             'image': images,
#             'patch_tensor': patch_tensor,
#             'box_tensor': boxes,
#             'class_idx': labels,
#         }
class ChestXrayWithRegionDataset(data.Dataset):
    def __init__(self, image_dir, region_dir, annotation_json, base_dir, parser, patch_size=40):
        self.image_dir = image_dir
        self.region_dir = region_dir
        self.base_dir = base_dir
        self.parser = parser
        self.patch_size = patch_size

        # 不再过滤，直接使用全部样本
        self.meta = annotation_json
        print(f"Dataset loaded. Total samples: {len(self.meta)}")

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        features = self.meta[idx]
        parsed = self.parser.parse(features)
        input_text = parsed['input_text']
        images = parsed['image']

        image_file = features['image_path'][0]
        image_id = os.path.splitext(image_file)[0]
        region_path = os.path.join(self.region_dir, f"{image_id}.pt")
        region_data = torch.load(region_path)
        patch_tensor = region_data.get('patch_tensor', None)
        boxes = region_data.get('box_tensor', None)
        labels = region_data.get('class_idx', None)

        return {
            'id': image_id,
            'input_text': input_text,
            'image': images,
            'patch_tensor': patch_tensor,
            'box_tensor': boxes,
            'class_idx': labels,
        }



# create_datasets根据args.use_sip决定用哪个dataset
def create_datasets(args):
    with open(args.annotation_path, 'r') as f:
        full_data = json.load(f)
    train_ids = full_data['train']
    val_ids = full_data['val']
    test_ids = full_data['test']

    if getattr(args, 'use_sip', False):
        parser = FieldParser(args)
        train_dataset = ChestXrayWithRegionDataset(
            image_dir=args.image_root,
            region_dir=args.region_root,
            annotation_json=train_ids,
            base_dir=args.base_dir,
            parser=parser,
            patch_size=getattr(args, 'patch_size', 40)
        )
        val_dataset = ChestXrayWithRegionDataset(
            image_dir=args.image_root,
            region_dir=args.region_root,
            annotation_json=val_ids,
            base_dir=args.base_dir,
            parser=parser,
            patch_size=getattr(args, 'patch_size', 40)
        )
        test_dataset = ChestXrayWithRegionDataset(
            image_dir=args.image_root,
            region_dir=args.region_root,
            annotation_json=test_ids,
            base_dir=args.base_dir,
            parser=parser,
            patch_size=getattr(args, 'patch_size', 40)
        )
    else:
        train_dataset = ParseDataset(args, 'train')
        val_dataset = ParseDataset(args, 'val')
        test_dataset = ParseDataset(args, 'test')

    return train_dataset, val_dataset, test_dataset


# 测试示例
class Args:
    def __init__(self):
        self.annotation_path = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/daasets/mimic/mimic+inter+intra1.json"
        self.image_root = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/daasets/mimic/image_224"
        self.region_root = "/media/wangyujie/CXPMRG_Bench_MambaXray_VL/promot/region_top5_capture/feature_boxes_flat"
        self.base_dir = self.image_root
        # self.use_sip = True  # 你想测试region的话设True，否则False
        self.use_sip = False  # 你想测试region的话设True，否则False
        self.patch_size = 40
        self.dataset = "mimic_cxr"
        self.input_size = 224

def test():
    args = Args()
    train_ds, val_ds, test_ds = create_datasets(args)
    print(f"训练集大小: {len(train_ds)}")
    sample = train_ds[0]
    print(f"样本ID: {sample['id']}")
    print(f"报告文本: {sample['input_text']}")
    print(f"图像张数: {len(sample['image'])}")
    print(f"第一张图像Tensor形状: {sample['image'][0].shape}")

    if args.use_sip:
        print(f"patch_tensor形状: {None if sample['patch_tensor'] is None else sample['patch_tensor'].shape}")
        print(f"box_tensor: {sample['box_tensor']}")
        print(f"class_idx: {sample['class_idx']}")
    else:
        print("未使用region数据")

if __name__ == "__main__":
    test()


训练集大小: 370342
样本ID: 02aa804e-bde0afdd-112c0b34-7bc16630-4e384014
报告文本: there is no focal consolidation, pleural effusion or pneumothorax . bilateral  nodular opacities that most likely represent nipple shadows . the  cardiomediastinal silhouette is normal . clips project over the left lung,  potentially within the breast . the imaged upper abdomen is unremarkable . chronic deformity of the posterior left sixth and seventh ribs are noted. .
图像张数: 1
第一张图像Tensor形状: torch.Size([3, 224, 224])
未使用region数据
