In [4]:
import os
import cv2
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import gc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
pd.set_option('mode.chained_assignment',  None) # 경고 제어

In [5]:
class MIMICCXR(Dataset):
    def __init__(self, df, args, transform=None, split='train'):
        self.data_dir = df['image_path'].values  # 이미지 경로가 포함된 리스트
        self.transform = transform
        self.args=args
        self.CLASSES = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
                        'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                        'Lung Opacity', 'No Finding', 'Pleural Effusion', 
                        'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']
        
        self.filenames_loaded  = df['dicom_id'].astype(str).values
        self.filenames_to_path = dict(zip(df['dicom_id'].astype(str), df['image_path'].values)) # image path

        # label
        labels = pd.read_csv("C:/Users/gangmin/dahs/data/physionet.org/files/mimic-cxr-jpg/2.1.0/mimic-cxr-2.0.0-negbio.csv")
        labels[self.CLASSES] = labels[self.CLASSES].fillna(0) # NaN을 0으로 변환함
        labels = labels.replace(-1.0,0.0) # -1인 값이 일부 존재함

        metadata_with_labels = filtered_metadata.merge(labels[self.CLASSES+['study_id'] ], how='inner', on='study_id')
        self.filenames_to_labels = dict(zip(metadata_with_labels['dicom_id'].values, metadata_with_labels[self.CLASSES].values))


    def __getitem__(self, index):
        filename = self.filenames_loaded[index]
        
        # 이미지를 로드
        try: 
            img = Image.open(self.filenames_to_path[filename]).convert('RGB')
        except Exception as e: 
            print(f"Error loading image {filename}: {e}")
            return None, None
        
        # 라벨을 로드
        labels = torch.tensor(self.filenames_to_labels[filename]).float()

        # transform이 정의된 경우 적용함
        if self.transform is not None:
            img = self.transform(img)
        
        return img, labels


    def __len__(self):
        return len(self.filenames_loaded)
    

    def get_transforms(args): 
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        processed_images_train = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize
        ]

        processed_images_test = [
            transforms.Resize(args.resize), # 별도의 전처리 위해 args 적용
            transforms.CenterCrop(args.crop),
            transforms.ToTensor(),
            normalize
        ]
        return processed_images_train, processed_images_test
    

    def get_cxr_datasets(args): 
        global filtered_metadata
        train_images, test_images = MIMICCXR.get_transforms(args)

        metadata = pd.read_csv("C:/Users/gangmin/dahs/data/physionet.org/files/mimic-cxr-jpg/2.1.0/mimic-cxr-2.0.0-metadata.csv")
        ap_pa_metadata = metadata[(metadata['ViewPosition']=="AP") | (metadata['ViewPosition']=="PA")]
        filtered_metadata = ap_pa_metadata.iloc[:1000,:]
        base_path = "C:/Users/gangmin/dahs/data/physionet.org/files/mimic-cxr-jpg/2.1.0/files"

        df, image_count = MIMICCXR.find_images(base_path, filtered_metadata)
        print(f'{image_count} images in my computer')

        train_df, test_df = train_test_split(df, test_size=0.3, random_state=0)

        train_dataset = MIMICCXR(train_df, transform=transforms.Compose(train_images), split='train', args=args)
        test_dataset = MIMICCXR(test_df, transform=transforms.Compose(test_images), split='test', args=args)

        print('CXR dataset preprocessing completed')
        return train_dataset, test_dataset


    def find_images(base_path, df):
        image_count = 0
        image_paths = []

        for _, row in tqdm(df.iterrows(), total=len(df), desc="Finding images"):
            for p_folder in range(10, 20):
                p_folder_path = os.path.join(base_path, f'p{p_folder}')
                if not os.path.exists(p_folder_path):
                    continue
                
                subject_folder_path = os.path.join(p_folder_path, f'p{row["subject_id"]}')
                if not os.path.exists(subject_folder_path):
                    continue

                study_folder_path = os.path.join(subject_folder_path, f's{row["study_id"]}')
                if not os.path.exists(study_folder_path):
                    continue

                dicom_file_path = os.path.join(study_folder_path, f"{row['dicom_id']}.jpg")
                if os.path.exists(dicom_file_path):
                    image_count += 1
                    image_paths.append(dicom_file_path)
                    break  # 이미지가 발견되면 다음 row로 이동
            else:
                image_paths.append(None)  # 이미지를 찾지 못한 경우 None을 추가

        df['image_path'] = image_paths
        return df, image_count

    # def load_images(image_paths):
    #     images = []
    #     for path in tqdm(image_paths, desc='Images loading...'):
    #         if os.path.exists(path):
    #             img = cv2.imread(path, cv2.IMREAD_ANYCOLOR)
    #             if img is not None: 
    #                 images.append((path, img))
    #             else: 
    #                 print(f"Failed to load images: {path}")
    #     return images

    # df, image_count = find_images(base_path, filtered_metadata)
    # print(f'{image_count} images in my computer')

    # images = load_images(df['image_path'].dropna())

In [6]:
class Args: 
    resize = 256
    crop = 224

args = Args()
train_dataset, test_dataset = MIMICCXR.get_cxr_datasets(args)

Finding images: 100%|██████████| 1000/1000 [00:00<00:00, 5750.14it/s]


998 images in my computer
CXR dataset preprocessing completed


1. PyTorch의 이미지 데이터 형식
- PyTorch에서 이미지는 [C, H, W] 형식으로 저장됩니다.
    -  C: 채널(Channel) 수 (예: RGB 이미지의 경우 3)
    - H: 이미지의 높이(Height)
    - W: 이미지의 너비(Width)
    
이 형식은 딥러닝 모델에서 효율적으로 처리할 수 있도록 설계되었습니다. 특히 CNN(Convolutional Neural Networks) 모델에서 이 형식이 주로 사용됩니다.

2. Matplotlib의 이미지 데이터 형식:
- 일반적으로 이미지 시각화 라이브러리(예: Matplotlib)에서는 이미지를 [H, W, C] 형식으로 처리합니다.
H: 이미지의 높이(Height)
W: 이미지의 너비(Width)
C: 채널(Channel) 수

In [None]:
img, label = train_dataset[0]
print("Image shape:", img.shape)
print("Label:", label)

plt.imshow(img.permute(1, 2, 0))
plt.title("Label: " + str(label))
plt.axis('off')
plt.show()

In [8]:
print("train dataset 개수:", len(train_dataset))
print("test dataset 개수:", len(test_dataset))

train dataset 개수: 700
test dataset 개수: 300


In [9]:
for i in range(len(train_dataset)):
    img, label = train_dataset[i]
    print(f"Sample {i}: Image shape: {img.shape}, Label: {label}")
    if i == 9:
        break

Sample 0: Image shape: torch.Size([3, 224, 224]), Label: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
Sample 1: Image shape: torch.Size([3, 224, 224]), Label: tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.])
Sample 2: Image shape: torch.Size([3, 224, 224]), Label: tensor([1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.])
Sample 3: Image shape: torch.Size([3, 224, 224]), Label: tensor([0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.])
Sample 4: Image shape: torch.Size([3, 224, 224]), Label: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1.])
Sample 5: Image shape: torch.Size([3, 224, 224]), Label: tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
Sample 6: Image shape: torch.Size([3, 224, 224]), Label: tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0.])
Sample 7: Image shape: torch.Size([3, 224, 224]), Label: tensor([0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1.])
Sample 8: Image shape: t