In [2]:
data_dir = "../../dataset/rsna-2024-lumbar-spine-degenerative-classification/"

import torch
from torch.utils.data import Dataset
from torchvision import datasets
import pandas as pd
import os
import pydicom

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

class DCMImageDataset(Dataset):
    """Per-slice DICOM dataset for one series type (e.g. ``'Axial T2'``).

    Every row of ``descriptions_file`` whose ``series_description`` equals
    ``series`` is expanded into one sample per image file of that series, so
    ``len(dataset)`` is the total number of slices over all selected series.
    Each item is ``(image, label)``:

    - ``image`` is ``pydicom.dcmread(path).pixel_array``.
    - ``label`` is the list of ``train_file`` label values for the series' study.

    Both are passed through the optional transforms.

    Args:
        series: ``series_description`` value to select.
        coordinates_file: DataFrame kept on ``self.coordinates``; not used by this class.
        descriptions_file: DataFrame with at least ``study_id``, ``series_id`` and
            ``series_description`` columns.
        train_file: DataFrame keyed by ``study_id``. Every other column is
            treated as a label.
        img_dir: Dataset root. Slices are read from
            ``<img_dir>/train_images/<study_id>/<series_id>/<number>.dcm``.
        file_counts: Nested mapping ``{str(study_id): {str(series_id): n_files}}``.
        transform: Optional callable applied to the image.
        target_transform: Optional callable applied to the label list.

    Raises:
        KeyError: if a selected series has no entry in ``file_counts``.
    """

    def __init__(self, series, coordinates_file, descriptions_file, train_file, img_dir, file_counts, transform=None, target_transform=None):
        self.coordinates = coordinates_file
        self.descriptions = descriptions_file
        self.train = train_file
        self.series = series
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

        merged = descriptions_file.merge(train_file, on='study_id', how='left')
        # merge() emits the left frame's columns first, then the right frame's
        # non-key columns, so everything after the description columns is a
        # label. Computed before 'number' is appended, so no label is dropped.
        self.label_columns = list(merged.columns[len(descriptions_file.columns):])

        selected = merged[merged['series_description'] == series]
        counts = [
            file_counts[str(study_id)][str(series_id)]
            for study_id, series_id in zip(selected['study_id'], selected['series_id'])
        ]
        # Repeat each series row once per image file and number the copies
        # 1..N. All slices of one series stay contiguous and ordered.
        # A count of 0 simply contributes no rows.
        # NOTE(review): assumes each series' files are named 1.dcm..N.dcm with
        # no gaps. TODO confirm; a gap would make dcmread fail in __getitem__.
        expanded = selected.loc[selected.index.repeat(counts)].reset_index(drop=True)
        expanded['number'] = [n for count in counts for n in range(1, count + 1)]
        self.df = expanded

    def __len__(self):
        """Return the number of (series, slice) samples."""
        return len(self.df)

    def __getitem__(self, idx):
        """Read slice ``idx`` from disk and return ``(image, label)``."""
        row = self.df.iloc[idx]
        img_path = os.path.join(
            self.img_dir,
            'train_images',
            str(row['study_id']),
            str(row['series_id']),
            str(row['number']) + '.dcm',
        )

        image = pydicom.dcmread(img_path).pixel_array
        label = row[self.label_columns].tolist()

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label
            

In [4]:
coordinates = pd.read_csv(data_dir + 'train_label_coordinates.csv')
descriptions = pd.read_csv(data_dir + 'train_series_descriptions.csv')
train = pd.read_csv(data_dir + 'train.csv')

# file_counts[study_id][series_id] = number of entries in that series directory.
# Keys are directory names (strings), so lookups must use str(study_id) / str(series_id).
images_root = data_dir + 'train_images'
study_ids = os.listdir(images_root)

file_counts = {}
for study_id in study_ids:
    study_dir = images_root + '/' + study_id
    file_counts[study_id] = {
        series_id: len(os.listdir(study_dir + '/' + series_id))
        for series_id in os.listdir(study_dir)
    }

In [None]:
# One sample per image file of every 'Axial T2' series listed in the descriptions CSV.
dataset = DCMImageDataset(
    series='Axial T2',
    img_dir=data_dir,
    descriptions_file=descriptions,
    train_file=train,
    coordinates_file=coordinates,
    file_counts=file_counts,
)

In [15]:
# Total number of (series, slice) samples.
num_samples = len(dataset)
print(num_samples)

79979
