In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import glob
import PIL.Image as Image
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
from ipywidgets import interact, fixed

import os

In [2]:
class VesuviusSegmentationData(Dataset):
    """Map-style dataset that tiles every sample directory into a ``divy`` x ``divx`` grid.

    Each subdirectory of ``train_dir`` is read as one sample containing a
    ``surface_volume/`` folder of ``*.tif`` slices plus ``mask.png`` and
    ``inklabels.png``. Every (subdirectory, tile row, tile column) triple is one
    dataset item; files are only read lazily in ``__getitem__``.

    Args:
        train_dir: Directory whose subdirectories are the samples.
        divx: Number of tiles along the horizontal (column) axis.
        divy: Number of tiles along the vertical (row) axis.
    """

    def __init__(self, train_dir, divx=3, divy=3):
        # os.listdir returns entries in arbitrary order; sort so the
        # index -> (directory, tile) mapping is reproducible across runs and
        # platforms. Non-directory entries (e.g. stray files) are skipped because
        # __getitem__ could never load them.
        sample_dirs = [
            os.path.join(train_dir, name)
            for name in sorted(os.listdir(train_dir))
            if os.path.isdir(os.path.join(train_dir, name))
        ]
        self.train_samples = [
            (sample_dir, y, x)
            for sample_dir in sample_dirs
            for y in range(divy)
            for x in range(divx)
        ]
        self.train_dir = train_dir
        self.divx, self.divy = divx, divy

    def div_xy(self, image: np.ndarray, cx: int, cy: int):
        """Return tile ``(cy, cx)`` of ``image`` from a ``divy`` x ``divx`` grid over its first two axes.

        Tile edges use integer division, so the tiles exactly partition the image
        even when its size is not divisible by the grid. The result is a NumPy
        view: while it is referenced, the whole ``image`` stays in memory.
        """
        sizey, sizex = image.shape[0], image.shape[1]
        y0, y1 = cy * sizey // self.divy, (cy + 1) * sizey // self.divy
        x0, x1 = cx * sizex // self.divx, (cx + 1) * sizex // self.divx
        return image[y0:y1, x0:x1]

    def __len__(self):
        """Number of tiles: sample directories times ``divx * divy``."""
        return len(self.train_samples)

    def __getitem__(self, idx: int):
        """Load tile ``idx``: its surface-volume slices, ink label and mask.

        Returns:
            A tuple ``(volume, label, mask)``:
            - ``volume``: the tile of every ``surface_volume/*.tif`` slice (in sorted
              filename order) stacked with ``np.array``, as float32 divided by
              65535 (presumably mapping 16-bit intensities to [0, 1] — confirm
              against the data).
            - ``label``: boolean array, ``inklabels.png`` tile ``> 0``.
            - ``mask``: boolean array, ``mask.png`` (converted to mode '1') tile ``> 0``.
        """
        sample_dir, cy, cx = self.train_samples[idx]

        pattern = os.path.join(sample_dir, "surface_volume", "*.tif")
        images = []
        for filename in sorted(glob.glob(pattern)):
            # Load one slice at a time and keep only its tile. Scaling *after*
            # cropping produces a fresh tile-sized array; appending the raw
            # div_xy view instead would pin every full-size slice in memory until
            # the final stack.
            with Image.open(filename) as img:
                full_slice = np.array(img, dtype=np.float32)
            images.append(self.div_xy(full_slice, cx, cy) / 65535.0)

        with Image.open(os.path.join(sample_dir, "mask.png")) as img:
            mask = self.div_xy(np.array(img.convert('1')), cx, cy)
        with Image.open(os.path.join(sample_dir, "inklabels.png")) as img:
            label = self.div_xy(np.array(img), cx, cy)

        return np.array(images), label > 0, mask > 0

In [3]:
# Build the dataset over every subdirectory of the training folder, using the
# class's default tiling (divx=3, divy=3).
dataset = VesuviusSegmentationData(train_dir="vesuvius-challenge-ink-detection/train")