In [6]:
import sys, os

#Our project root directory
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname("__file__"), 
                                            os.pardir))
sys.path.append(PROJECT_ROOT)

#Local packages loaded from src specifying useful constants, and our custom loader
from util.constants import DATA_PATHS, IMG_TRAIN_CELL_DIR, ANN_TRAIN_CELL_DIR, META_PATH
from PIL import Image
import numpy as np
import pandas as pd
import json
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from os.path import join, split
from os import listdir

from torchvision import transforms as trfms

In [67]:
def load_image(abspath):
    assert os.path.isfile(abspath), f"File {os.path.split(abspath)[1]} not found."
    image = Image.open(abspath)
    return image

def load_csv(abspath):
    assert os.path.isfile(abspath), f"File {split(abspath)[1]} not found."
    try:
        return pd.read_csv(abspath).to_numpy()
    except:
        return np.empty((0,3))

def load_json(abspath, obj_name):
    assert os.path.isfile(abspath), f"File {split(abspath)[1]} not found."
    with open(abspath) as jsonFile:
        jsonObject = json.load(jsonFile)
        jsonFile.close()
    x_start = jsonObject['sample_pairs'][obj_name]['cell']['x_start']
    x_end = jsonObject['sample_pairs'][obj_name]['cell']['x_end']
    y_start = jsonObject['sample_pairs'][obj_name]['cell']['y_start']
    y_end = jsonObject['sample_pairs'][obj_name]['cell']['y_end']
    return np.array([x_start, x_end]), np.array([y_start, y_end])

class MyDataLoader(Dataset):
    def __init__(self, paths, data_to_load = None, transforms=None):
        self.cell_image_path, self.cell_image_file = paths[0], listdir(paths[0])
        self.cell_ann_path, self.cell_ann_file = paths[1], listdir(paths[1])
        self.tiss_image_path, self.tiss_image_file = paths[2], listdir(paths[2])
        self.tiss_mask_path, self.tiss_mask_file = paths[3], listdir(paths[3])
        self.metadata_path = paths[4]
        self.transforms = transforms
        self.data_to_load = data_to_load
    
    def __len__(self):
        assert len(self.cell_image_file) == len(self.cell_ann_file), "Cell-cell data size mismatch. Rebalancing needed."
        assert len(self.tiss_image_file) == len(self.tiss_mask_file), "Tisue-tissue data size mismatch. Rebalancing needed."
        assert len(self.cell_image_file) == len(self.tiss_image_file), "Cell-tissue data size mismatch. Rebalancing needed."
        return len(self.cell_image_file)

    def __getitem__(self, idx):
        cell_images, cell_anns, tiss_images, tiss_masks, x_coords, y_coords = [], [], [], [], [], []

        if isinstance(idx, slice):

            #get 
            for a, b, c, d, e in zip(self.cell_image_file[idx], 
                                     self.cell_ann_file[idx], 
                                     self.tiss_image_file[idx],
                                     self.tiss_mask_file[idx],
                                     self.cell_image_file[idx]
                                     ):
                cell_image = load_image(join(self.cell_image_path, a))
                cell_image = self.transforms(cell_image) if self.transforms else cell_image
                cell_images.append(cell_image)

            #for x in self.cell_ann_file[idx]:
                cell_ann = load_csv(join(self.cell_ann_path, b))
                cell_anns.append(cell_ann)

            #for x in self.tiss_image_file[idx]:
                tiss_image = load_image(join(self.tiss_image_path, c))
                tiss_image = self.transforms(tiss_image) if self.transforms else tiss_image
                tiss_images.append(tiss_image)

            #for x in self.tiss_mask_file[idx]:
                tiss_mask = load_image(join(self.tiss_mask_path, d))
                tiss_mask = self.transforms(tiss_mask) if self.transforms else tiss_mask
                tiss_masks.append(tiss_mask)

            #for x in self.cell_image_file[idx]:
                x_coord, y_coord = load_json(self.metadata_path, os.path.splitext(e)[0])
                x_coords.append(x_coord), y_coords.append(y_coord)
        
        else:
            cell_image = load_image(join(self.cell_image_path, self.cell_image_file[idx]))
            cell_image = self.transforms(cell_image) if self.transforms else cell_image
            cell_images.append(cell_image)
            cell_ann = load_csv(join(self.cell_ann_path, self.cell_ann_file[idx]))
            cell_anns.append(cell_ann)
            
            tiss_image = load_image(join(self.tiss_image_path, self.tiss_image_file[idx]))
            tiss_image = self.transforms(tiss_image) if self.transforms else tiss_image
            tiss_images.append(tiss_image)
            tiss_mask = load_image(join(self.tiss_mask_path, self.tiss_mask_file[idx]))
            tiss_mask = self.transforms(tiss_mask) if self.transforms else tiss_mask
            tiss_masks.append(tiss_mask)
            
            x_coord, y_coord = load_json(self.metadata_path, os.path.splitext(self.cell_image_file[idx])[0])
            x_coords.append(x_coord), y_coords.append(y_coord)
        
        if self.transforms:
            cell_images = self.transforms(cell_images)
        
        if self.data_to_load == 'Tissue' or self.data_to_load == 'tissue':
            return tiss_images, tiss_masks
        
        elif self.data_to_load == 'Cell' or self.data_to_load == 'cell':
            return cell_images, cell_anns
        
        else:
            return cell_images, cell_anns, tiss_images, tiss_masks, x_coords, y_coords


transforms = trfms.Compose([trfms.ToTensor()])
x = MyDataLoader(DATA_PATHS, transforms=transforms, data_to_load='Tissue')

In [68]:
print(x[0][0])

TypeError: pic should be PIL Image or ndarray. Got <class 'list'>