In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import os
import sys
import inspect

# Make the parent of the notebook's directory importable.
# Inside a Jupyter kernel, inspect.getfile(inspect.currentframe()) returns the
# synthetic filename of the executing cell (e.g. "<ipython-input-...>" or a
# temp file created by ipykernel), not the notebook's own location, so the
# kernel's working directory is used instead.
# NOTE(review): assumes the kernel was started in the notebook's directory.
currentdir = os.path.abspath(os.getcwd())
parentdir = os.path.dirname(currentdir)
# Only insert when it is not already first, so re-running this cell does not
# stack duplicate entries at the front of sys.path.
if sys.path[:1] != [parentdir]:
    sys.path.insert(0, parentdir)

In [15]:
def _expand_planes_to_rgb(x):
    """Reorder the axes of a 3-D tensor and replicate each plane three times.

    The leading axis is moved to the end, a singleton axis is inserted at
    position 1, and that axis is then repeated three times.
    """
    reordered = x.permute(1, 2, 0)
    return reordered.unsqueeze(1).repeat(1, 3, 1, 1)


# ToTensor -> per-plane 3-channel expansion -> normalisation with the given
# per-channel mean/std.
# NOTE(review): presumably meant for reading (19, 224, 224) uint8 samples back
# from the dumped memmap — TODO confirm. `T` is rebound by a later cell.
T = transforms.Compose(
    [
        transforms.ToTensor(),
        _expand_planes_to_rgb,
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [4]:
# Open the existing dump read-only as a flat uint8 memmap and view it as
# (num_samples, 19, 224, 224); the -1 lets NumPy infer the sample count.
data = np.memmap(
    '/scratch/arihanth.srikar/train_data.bin',
    mode='r',
    dtype=np.uint8,
).reshape((-1, 19, 224, 224))

In [37]:
def _to_uint8_range(x):
    """Rescale a ToTensor output from [0.0, 1.0] back to integral [0, 255] values.

    ToTensor divides 8-bit pixel values by 255, so the inverse factor is 255
    (the previous factor of 225 darkened every pixel). The result is rounded
    because the dump loop casts with ``astype(np.uint8)``, which truncates and
    would otherwise turn values such as 126.99999 into 126.

    Args:
        x: float tensor produced by ``transforms.ToTensor``.

    Returns:
        Float tensor of the same shape holding whole numbers in [0, 255].
    """
    return (x * 255).round()


# Preprocessing used when dumping images: resize to 224x224, convert to a
# tensor, then restore the 0-255 pixel range so the values fit a uint8 memmap.
T = transforms.Compose([
    transforms.Resize((224, 224), antialias=True),
    transforms.ToTensor(),
    _to_uint8_range,
])

In [None]:
print("Loading dataset")

# The annotation JSON may live in either of two locations; they are tried in
# this order. An explicit existence check replaces the previous bare `except:`,
# which also swallowed JSON parse errors and KeyboardInterrupt.
coco_json_candidates = [
    '/home/ssd_scratch/users/arihanth.srikar/physionet.org/files/chest-imagenome/1.0.0/silver_dataset/mimic_coco_filtered.json',
    '/scratch/arihanth.srikar/physionet.org/files/chest-imagenome/1.0.0/silver_dataset/mimic_coco_filtered.json',
]
coco_json_path = next((p for p in coco_json_candidates if os.path.exists(p)), None)
if coco_json_path is None:
    raise FileNotFoundError(
        f'mimic_coco_filtered.json not found in any of: {coco_json_candidates}'
    )
df = pd.read_json(coco_json_path)

# Rename the CSV's 'dicom_id' column so it can be joined on 'image_id'.
# rename() returns a new frame, so re-running the cell never mutates in place.
temp_df = pd.read_csv('data/mimic_cxr_jpg/mimic-cxr-2.0.0-final.csv').rename(
    columns={'dicom_id': 'image_id'}
)
# Left join: every annotation row is kept; rows without a CSV match get NaN
# in the CSV columns.
df = df.merge(temp_df, on='image_id', how='left')
print("Dataset loaded")

In [None]:
class DumpDataset(Dataset):
    """Dataset yielding, per row of ``df``, the full image plus one crop per annotation.

    Each item is a dict with:
        'id':           the integer index that was requested.
        'images':       the transformed full image followed by the transformed
                        crops (sorted by ``category_id``), concatenated along
                        dim 0. With a transform that returns (1, H, W) tensors
                        this is (1 + num_annotations, H, W).
        'y':            the any-positive label across all annotations followed
                        by each annotation's ``attributes``, as floats.
        'anatomy_name': last ``_``-separated token of each annotation ``id``.
        'y_14':         float32 tensor of the 14 columns ``df.columns[-15:-1]``.

    Args:
        df: DataFrame whose rows provide 'subject_id', 'study_id', 'image_id'
            and 'annotations' (a list of dicts with 'bbox', 'category_id',
            'id' and 'attributes').
        transform: callable applied to the full PIL image and to every crop.
    """

    # Secondary image root tried when a file is not found under img_loc_prefix.
    fallback_img_loc_prefix = '/scratch/arihanth.srikar/physionet.org/files/mimic-cxr-jpg/2.0.0/files'

    def __init__(self, df, transform=T):
        self.df = df
        self.img_loc_prefix = '/home/ssd_scratch/users/arihanth.srikar/physionet.org/files/mimic-cxr-jpg/2.0.0/files'
        self.transform = transform

    def __len__(self):
        """Number of rows in the wrapped DataFrame."""
        return len(self.df)

    def _open_image(self, sample_data):
        """Load the row's JPEG from the primary root, falling back to the secondary one.

        The previous implementation wrapped everything in a bare ``except`` and
        called ``__getitem__`` again, which recursed without bound whenever the
        file was missing from both roots or the id fields were invalid.

        Raises:
            FileNotFoundError: the image exists under neither root.
            ValueError / KeyError: the row's id fields are missing or not
                convertible to int (e.g. NaN after an unmatched left merge).
        """
        pid = str(int(sample_data['subject_id']))
        sid = str(int(sample_data['study_id']))
        relative = f'p{pid[:2]}/p{pid}/s{sid}/{sample_data["image_id"]}.jpg'

        tried = []
        # dict.fromkeys de-duplicates while keeping order (both entries are the
        # same once the fallback has been adopted as the primary root).
        for prefix in dict.fromkeys((self.img_loc_prefix, self.fallback_img_loc_prefix)):
            image_file_location = f'{prefix}/{relative}'
            try:
                with Image.open(image_file_location) as handle:
                    # convert() loads the pixels and returns a detached copy, so
                    # the file handle can be closed right away.
                    # NOTE(review): forced to single-channel because the dump
                    # memmap stores one 224x224 plane per image — confirm that
                    # the source JPEGs are grayscale.
                    img = handle.convert('L')
            except FileNotFoundError:
                tried.append(image_file_location)
                continue
            # Remember the root that worked so later items try it first.
            self.img_loc_prefix = prefix
            return img
        raise FileNotFoundError(f'image not found, tried: {tried}')

    def __getitem__(self, idx):
        sample_data = self.df.iloc[idx]
        img = self._open_image(sample_data)

        sub_anatomies = []
        sub_anatomy_labels = []
        sub_anatomy_name = []
        for annotation in sorted(sample_data['annotations'], key=lambda k: k['category_id']):
            x, y, w, h = annotation['bbox']  # box given as corner + width/height
            sub_anatomy = img.crop((x, y, x+w, y+h))
            sub_anatomies.append(self.transform(sub_anatomy))
            sub_anatomy_name.append(annotation['id'].split('_')[-1])
            sub_anatomy_labels.append(annotation['attributes'])

        img = self.transform(img)
        # Concatenate along dim 0 so the full image and the crops share one
        # leading axis. The previous torch.stack mixed a 3-D full image with
        # 4-D (unsqueezed) crops and therefore always raised.
        images = torch.cat([img] + sub_anatomies, dim=0)

        sub_anatomy_labels = torch.tensor(sub_anatomy_labels).float()
        # keepdim=True keeps the leading axis so the global row can be
        # concatenated in front of the per-annotation rows.
        global_label = (torch.sum(sub_anatomy_labels, dim=0, keepdim=True) > 0).float()
        nine_class_labels = torch.cat((global_label, sub_anatomy_labels), dim=0)

        fourteen_class_labels = torch.from_numpy(sample_data[self.df.columns[-15:-1]].to_numpy().astype(np.float32))

        return {
            'id': idx,
            'images': images,
            'y': nine_class_labels,
            'anatomy_name': sub_anatomy_name,
            'y_14': fourteen_class_labels,
        }

In [None]:
# Wrap the merged frame in the dump dataset and iterate it in index order
# (no shuffling), 16 samples per batch, loaded by 32 worker processes.
dump_dataset = DumpDataset(df, T)
dump_dataloader = DataLoader(
    dump_dataset,
    num_workers=32,
    shuffle=False,
    batch_size=16,
)

In [None]:
# Node count per sample; the memmaps created in a later cell use
# num_nodes + 1 as the size of their second axis.
# NOTE(review): presumably 18 anatomy crops plus one full image — confirm
# against the number of annotations per row.
num_nodes = 18

In [22]:
# Output locations of the three on-disk arrays.
save_file_data = '/scratch/arihanth.srikar/data.bin'
save_file_labels_9 = '/scratch/arihanth.srikar/nine_labels.bin'
save_file_labels_14 = '/scratch/arihanth.srikar/fourteen_labels.bin'

# mode='w+' creates each file (overwriting an existing one) and maps it for
# reading and writing. All three arrays have one row per dataset item.
arr = np.memmap(
    save_file_data,
    mode='w+',
    dtype=np.uint8,
    shape=(len(dump_dataset), num_nodes + 1, 224, 224),
)
arr_labels_9 = np.memmap(
    save_file_labels_9,
    mode='w+',
    dtype=np.int8,
    shape=(len(dump_dataset), num_nodes + 1, 9),
)
arr_labels_14 = np.memmap(
    save_file_labels_14,
    mode='w+',
    dtype=np.int8,
    shape=(len(dump_dataset), 14),
)

In [23]:
# Copy every sample of every batch into its own row of the three memmaps.
# Images and labels are iterated together per sample: the previous code wrote
# the whole batch['y'] / batch['y_14'] tensor into a single row, which cannot
# be broadcast into one row's shape for any batch size above 1.
for batch in tqdm(dump_dataloader):
    for idx, image, y_9, y_14 in zip(batch['id'], batch['images'], batch['y'], batch['y_14']):
        row = int(idx)  # collated ids arrive as 0-d tensors
        arr[row] = image.numpy().astype(np.uint8)
        arr_labels_9[row] = y_9.numpy().astype(np.int8)
        # NOTE(review): NaN entries in y_14, if any, have no defined int8
        # value — confirm the label CSV contains none.
        arr_labels_14[row] = y_14.numpy().astype(np.int8)

In [24]:
# Write pending changes of all three memmaps to disk. Previously only the
# image array was flushed, leaving the label files dependent on garbage
# collection to be persisted.
arr.flush()
arr_labels_9.flush()
arr_labels_14.flush()