In [3]:
import numpy as np
import torch
import os
from torch.utils.data import Dataset
import cv2
import pandas as pd

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
MY_DEVICE = torch.device(device='cuda' if torch.cuda.is_available() else 'cpu')
# Compare the device *type* string: a torch.device never compares equal to a
# plain str, so `MY_DEVICE == 'cuda'` was always False and pinning was never enabled.
PIN_MEMORY = MY_DEVICE.type == 'cuda'
D_TYPE = torch.float32

# Dataset layout, relative to the notebook's working directory.
DATA_DIR = '../../ocelot2023_v0.1.2/'
ANN_CELL_DIR = DATA_DIR + "annotations/train/cell/"
ANN_TISS_DIR = DATA_DIR + "annotations/train/tissue/"
IMG_CELL_DIR = DATA_DIR + "images/train/cell"
IMG_TISS_DIR = DATA_DIR + "images/train/tissue"
META_PATH = DATA_DIR + 'metadata.json'

In [5]:
class SegmentationDataset(Dataset):
    """Paired cell/tissue samples read from four parallel directories.

    Sample ``idx`` is assembled from the ``idx``-th file name (in sorted
    order) of each directory, so the directories must contain the same
    number of files and their names must sort into the same sample order.

    Args:
        cellImagePaths: directory holding the cell images.
        cellAnnotationPaths: directory holding the cell annotation CSV files.
        tissImagePaths: directory holding the tissue images.
        tissAnnotationPaths: directory holding the tissue mask images.
        metadataAbsPath: path to the metadata file (stored, not read yet).
        cellTransforms: optional cell transforms (stored, not applied yet).
        tissTransforms: optional tissue transforms (stored, not applied yet).
    """

    def __init__(self, cellImagePaths, cellAnnotationPaths, tissImagePaths, tissAnnotationPaths, metadataAbsPath, cellTransforms = None, tissTransforms = None):
        # os.listdir returns entries in arbitrary order; sort every listing so
        # that index i refers to the same sample in all four directories.
        self.cellImagePaths = cellImagePaths
        self.cellImageFileNames = sorted(os.listdir(cellImagePaths))
        self.cellAnnotationPaths = cellAnnotationPaths
        self.cellAnnFileNames = sorted(os.listdir(cellAnnotationPaths))

        self.tissImagePaths = tissImagePaths
        self.tissImageFileNames = sorted(os.listdir(tissImagePaths))
        self.tissAnnotationPaths = tissAnnotationPaths
        self.tissAnnFileNames = sorted(os.listdir(tissAnnotationPaths))

        self.metadataAbsPath = metadataAbsPath

        self.cellTransforms = cellTransforms
        self.tissTransforms = tissTransforms

    def __len__(self):
        """Return the number of samples (the number of cell image files)."""
        # Always make sure all dataset tissue/cell subfolders have proper size
        # correspondence or this will not work.
        # Use the cached listing so the length always matches what
        # __getitem__ can index, without hitting the filesystem on every call.
        return len(self.cellImageFileNames)

    @staticmethod
    def _imreadOrRaise(path, *flags):
        """Read an image with cv2.imread, raising instead of returning None."""
        image = cv2.imread(path, *flags)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"cv2.imread could not read image (missing or unreadable): {path}")
        return image

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            Tuple ``(cellImage, cellAnn, tissImage, tissMask, x_coord, y_coord)``:
            the two images converted from BGR to RGB, the cell annotation CSV
            as a numpy array, the tissue mask read as grayscale, and two
            coordinate placeholders that are currently ``Ellipsis``.

        Raises:
            OSError: if any of the image files cannot be read.
            IndexError: if ``idx`` is out of range for any of the directories.
        """
        cellImageAbsPath = os.path.join(self.cellImagePaths, self.cellImageFileNames[idx])
        tissImageAbsPath = os.path.join(self.tissImagePaths, self.tissImageFileNames[idx])
        cellAnnAbsPath = os.path.join(self.cellAnnotationPaths, self.cellAnnFileNames[idx])
        tissAnnAbsPath = os.path.join(self.tissAnnotationPaths, self.tissAnnFileNames[idx])

        # OpenCV decodes to BGR channel order; convert to RGB.
        cellImage = self._imreadOrRaise(cellImageAbsPath)
        cellImage = cv2.cvtColor(cellImage, cv2.COLOR_BGR2RGB) #QUESTION: .reshape(3,1024,1024) #reshape???

        tissImage = self._imreadOrRaise(tissImageAbsPath)
        tissImage = cv2.cvtColor(tissImage, cv2.COLOR_BGR2RGB) #QUESTION: .reshape(3,1024,1024) #reshape???

        # Single-channel read (IMREAD_GRAYSCALE == 0) for the tissue mask.
        tissMask = self._imreadOrRaise(tissAnnAbsPath, cv2.IMREAD_GRAYSCALE)
        # NOTE(review): read_csv treats the first row as a header by default; if
        # the annotation CSVs have no header row the first annotation is dropped
        # (and an empty file raises) - confirm against the data, consider header=None.
        cellAnn = pd.read_csv(cellAnnAbsPath, delimiter=',').to_numpy()

        x_coord = ... #TODO: get from metadata.json in form (x_start, x_end) for corresponding sample
        y_coord = ... #TODO: get from metadata.json in form (y_start, y_end) for corresponding sample

        #TODO: cell transforms
        #TODO: tissue transforms

        return (cellImage, cellAnn, tissImage, tissMask, x_coord, y_coord)

In [6]:
# Build the dataset from the directories configured above.
mydata = SegmentationDataset(
    cellImagePaths=IMG_CELL_DIR,
    cellAnnotationPaths=ANN_CELL_DIR,
    tissImagePaths=IMG_TISS_DIR,
    tissAnnotationPaths=ANN_TISS_DIR,
    metadataAbsPath=META_PATH,
)
# Element 3 of the sample tuple is the tissue mask; as the cell's last
# expression it is displayed as the cell output.
mydata[0][3]

array([[2, 2, 2, ..., 2, 2, 2],
       [2, 2, 2, ..., 2, 2, 2],
       [2, 2, 2, ..., 2, 2, 2],
       ...,
       [2, 2, 2, ..., 1, 1, 1],
       [2, 2, 2, ..., 1, 1, 1],
       [2, 2, 2, ..., 1, 1, 1]], dtype=uint8)