In [1]:
# Here we take care of paths: everything below uses paths relative to the
# project root (e.g. 'kaggle/input/...' and the `Data_Modules` package), so the
# working directory is moved there first.

from pathlib import Path
import os
print('Starting path:' + os.getcwd())
# The notebook lives one folder below the project root; step up unless we are
# already at the root. Compare the folder *name* rather than a string suffix so
# e.g. 'MY_VESUVIUS_Challenge' is not mistaken for the root.
if Path.cwd().name != 'VESUVIUS_Challenge':
    os.chdir(Path.cwd().resolve().parent)

# make sure you are in the VESUVIUS_Challenge folder
print('Current path:' + os.getcwd())

Starting path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge/jupyter notebooks
Current path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge


In [2]:
import torch
import torch.utils.data as data
from torch.utils.data import ConcatDataset, DataLoader, Dataset, ConcatDataset
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import glob
import PIL.Image as Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
from ipywidgets import interact, fixed
from torchvision import transforms

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# torchvision converter from tensors / arrays to PIL images.
transform = transforms.ToPILImage()

# Dataset Modules

### CONFIGS

In [3]:

# Root of the competition data, relative to the project root set in cell 1.
PATH = 'kaggle/input/vesuvius-challenge/'
# Re-declared so this config cell is self-contained.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# scroll_1 size = 8181, 6330
# scroll_2 size = 14830, 9506
# scroll_3 size = 7606, 5249

### Base_Dataset class 
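- Due to multiprocessing issues in IPython (DataLoader worker processes may be unable to load classes defined inside the notebook), the class below is kept for reference only; the version actually used is imported with `from Data_Modules.Base_Dataset import Base_Dataset`.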

class Base_Dataset(Dataset):
    """Map-style dataset of square sub-volume patches centred on selected pixels.

    Item ``index`` is ``(subvolume, inklabel)``:
      * ``subvolume`` -- the window of ``image_stack`` centred on ``pixels[index]``,
        reshaped to ``(1, z_dim, 2 * buffer + 1, 2 * buffer + 1)``;
      * ``inklabel`` -- ``label`` at that pixel, reshaped to ``(1,)``.

    Args:
        image_stack: tensor indexed as ``[z, row, col]``.
        label: tensor indexed as ``[row, col]``.
        pixels: sequence of ``(row, col)`` centres (e.g. the output of ``np.argwhere``).
            Centres must lie at least ``buffer`` away from every edge; otherwise the
            window has the wrong size and ``.view`` raises.
        buffer: half-width of the square window.
        z_dim: number of z-slices in ``image_stack`` (used to shape the output).
    """

    def __init__(self, image_stack, label, pixels, buffer, z_dim):
        self.image_stack = image_stack
        self.label = label
        self.pixels = pixels
        self.buffer = buffer
        self.z_dim = z_dim

    def __len__(self):
        return len(self.pixels)

    def __getitem__(self, index: int):
        row, col = self.pixels[index]
        half = self.buffer
        side = 2 * half + 1
        window = self.image_stack[:, row - half:row + half + 1, col - half:col + half + 1]
        # Prepend a channel axis: (1, z_dim, side, side).
        subvolume = window.view(1, self.z_dim, side, side)
        inklabel = self.label[row, col].view(1)
        return subvolume, inklabel
    



### Scrolls_Dataset wrapper

In [4]:
from Data_Modules.Base_Dataset import Base_Dataset
class Scrolls_Dataset(pl.LightningDataModule):
    """Lightning data module serving sub-volume patches from Vesuvius scroll fragments.

    Every fragment's z-slices, ink labels and mask are resized to a common height
    (``shared_height``) and concatenated side by side along the width axis, so all
    fragments share one coordinate frame. Masked pixels inside ``validation_rect``
    (in that concatenated frame) form the validation set; the remaining masked
    pixels form the training set.

    Args:
        buffer: half-width of the square patch; patches are ``2 * buffer + 1`` wide.
        z_start: index (in sorted filename order) of the first z-slice to load.
        z_dim: number of consecutive z-slices to load.
        validation_rect: ``(x, y, width, height)`` of the validation rectangle in the
            concatenated frame; both bounds are inclusive.
        shared_height: height in pixels that every fragment is resized to.
        downsampling: stored but not used by this class.
        scroll_fragments: fragment ids (directory names under ``train/``) to load.
        stage: ``'train'`` or ``'test'``; selects what ``prepare_data`` loads.
        shuffle: whether the training dataloader shuffles.
        batch_size: batch size for all dataloaders.
        num_workers: number of DataLoader worker processes.
        on_gpu: forwarded to DataLoader ``pin_memory``.
    """

    def __init__(self,
                 buffer=15,
                 z_start=27,
                 z_dim=10,
                 validation_rect=(1100, 3500, 700, 950),
                 shared_height=8000,
                 downsampling=None,
                 scroll_fragments=(1, 2, 3),
                 stage='train',
                 shuffle=True,
                 batch_size=8,
                 num_workers=4,
                 on_gpu=False,
                 ):
        # Lightning requires the base-class initialiser to run.
        super().__init__()
        self.buffer = buffer
        self.z_start = z_start
        self.z_dim = z_dim
        self.validation_rect = validation_rect
        self.shared_height = shared_height
        self.downsampling = downsampling
        # Copy so the caller's sequence is never shared; a tuple default avoids a
        # mutable default argument.
        self.scroll_fragments = list(scroll_fragments)
        self.stage = stage
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.on_gpu = on_gpu

    def prepare_data(self, *args, **kwargs):
        """Load the data for ``self.stage``.

        ``'train'``: builds ``self.data_train`` / ``self.data_val`` and stores the
        concatenated mask in ``self.mask``.
        ``'test'``: only collects the test z-slice paths into ``self.test_z_slices``
        (TODO: build ``self.data_test``).

        NOTE(review): Lightning recommends assigning state in ``setup()`` because
        ``prepare_data()`` runs on a single process; this is fine for the
        single-process use in this notebook.

        Raises:
            ValueError: if ``self.stage`` is neither ``'train'`` nor ``'test'``.
            FileNotFoundError: if a fragment has fewer than ``z_dim`` slices from
                ``z_start`` on.
        """
        stage = self.stage.lower()
        if stage == 'train':
            self._prepare_train()
        elif stage == 'test':
            self._prepare_test()
        else:
            raise ValueError(f"Unknown stage {self.stage!r}; expected 'train' or 'test'")

    def _prepare_train(self):
        """Load every training fragment, concatenate them and build the train/val datasets."""
        images, labels, masks = [], [], []
        # Iterate the fragment ids directly so any subset (e.g. [2] or [1, 3]) works.
        for fragment in self.scroll_fragments:
            images.append(self.load_slices(self._slice_paths('train', fragment)))
            labels.append(self.load_labels('train', fragment))
            masks.append(self.load_mask('train', fragment))

        # All fragments share height `shared_height`, so they can be placed side by side.
        # NOTE(review): split_train_val only excludes the outer border, so patches centred
        # near a seam between two fragments mix pixels from both fragments.
        images_tensor = torch.cat(images, dim=-1)
        label_tensor = torch.cat(labels, dim=-1)
        mask_array = np.concatenate(masks, axis=-1)

        train_pixels, val_pixels = self.split_train_val(mask_array)
        self.mask = mask_array

        self.data_train = Base_Dataset(image_stack=images_tensor, label=label_tensor,
                                       pixels=train_pixels, buffer=self.buffer, z_dim=self.z_dim)
        self.data_val = Base_Dataset(image_stack=images_tensor, label=label_tensor,
                                     pixels=val_pixels, buffer=self.buffer, z_dim=self.z_dim)

    def _prepare_test(self):
        """Collect the z-slice paths of test fragments 'a' and 'b'.

        TODO: load the slices and build ``self.data_test``; note test paths differ from train.
        """
        self.test_z_slices = [self._slice_paths('test', fragment) for fragment in ('a', 'b')]

    def _slice_paths(self, split, fragment):
        """Return the sorted ``.tif`` paths of slices ``z_start .. z_start + z_dim - 1``."""
        folder = f"{PATH}/{split}/{fragment}/surface_volume"
        paths = sorted(glob.glob(f"{folder}/*.tif"))[self.z_start:self.z_start + self.z_dim]
        # Base_Dataset reshapes patches with z_dim, so a short stack would fail later
        # with a confusing view error; fail here with a clear message instead.
        if len(paths) < self.z_dim:
            raise FileNotFoundError(
                f"Expected {self.z_dim} z-slices in {folder} starting at index "
                f"{self.z_start}, found {len(paths)}")
        return paths

    def _make_dataloader(self, dataset, shuffle):
        """Build a DataLoader using this module's batch size, worker count and pinning."""
        # NOTE(review): load_slices/load_labels move tensors to DEVICE; if that is a CUDA
        # device, num_workers > 0 and pin_memory=True are likely to fail -- confirm
        # before training on GPU.
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        """Return the training DataLoader (shuffled unless ``shuffle=False`` was given)."""
        return self._make_dataloader(self.data_train, shuffle=self.shuffle)

    def val_dataloader(self, *args, **kwargs):
        """Return the validation DataLoader (never shuffled)."""
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self, *args, **kwargs):
        """Return the test DataLoader; requires ``self.data_test`` (not built yet, see TODO)."""
        return self._make_dataloader(self.data_test, shuffle=False)

    def load_slices(self, z_slices_fnames):
        """Load z-slice images, resize them, scale to [0, 1] and stack them on DEVICE.

        Returns:
            Float32 tensor of shape ``(len(z_slices_fnames), shared_height, W)`` on DEVICE.
        """
        slices = []
        for filename in tqdm(z_slices_fnames):
            # `with` closes the file; resize returns an independent image.
            with Image.open(filename) as img:
                resized = self.resize(img)
            # Dividing by 65535 assumes 16-bit slices -- TODO confirm for new data.
            slices.append(np.array(resized, dtype="float32") / 65535.0)
        return torch.stack([torch.from_numpy(z_slice) for z_slice in slices], dim=0).to(DEVICE)

    def load_mask(self, split, index):
        """Load ``mask.png`` as a 1-bit image, resize it and return it as a NumPy array."""
        with Image.open(f"{PATH}/{split}/{index}/mask.png") as img:
            mask = self.resize(img.convert('1'))
        return np.array(mask)

    def load_labels(self, split, index):
        """Load ``inklabels.png``, resize it and return a 0/1 float tensor on DEVICE."""
        with Image.open(f"{PATH}/{split}/{index}/inklabels.png") as img:
            # Nearest-neighbour keeps label boundaries intact; an interpolating filter
            # followed by gt(0) would dilate the ink regions.
            resized = self.resize(img, resample=Image.NEAREST)
        return torch.from_numpy(np.array(resized)).gt(0).float().to(DEVICE)

    def resize(self, img, resample=None):
        """Resize ``img`` to height ``shared_height``, preserving its aspect ratio.

        Args:
            img: PIL image.
            resample: optional Pillow resampling filter; ``None`` keeps Pillow's
                default for the image mode.
        Returns:
            The resized PIL image (new width truncated to an int).
        """
        current_width, current_height = img.size
        aspect_ratio = current_width / current_height
        new_size = (int(self.shared_height * aspect_ratio), self.shared_height)
        if resample is None:
            return img.resize(new_size)
        return img.resize(new_size, resample=resample)

    def split_train_val(self, mask):
        """Split masked, non-border pixel coordinates into train and validation sets.

        Pixels within ``buffer`` of the image border are excluded, as are pixels
        outside ``mask``. This guarantees every patch fits inside the image and lies
        on the fragment.

        Returns:
            ``(train_pixels, val_pixels)``: ``np.argwhere`` arrays of ``(row, col)``
            pairs. Validation pixels lie inside ``validation_rect`` (inclusive bounds);
            training pixels lie outside it.
        """
        rect = self.validation_rect
        rect_rows = slice(rect[1], rect[1] + rect[3] + 1)
        rect_cols = slice(rect[0], rect[0] + rect[2] + 1)

        usable = np.zeros(mask.shape, dtype=bool)
        usable[self.buffer:mask.shape[0] - self.buffer, self.buffer:mask.shape[1] - self.buffer] = True
        usable &= np.asarray(mask, dtype=bool)

        in_rect = np.zeros(mask.shape, dtype=bool)
        in_rect[rect_rows, rect_cols] = True

        # The original code masked only the training side; validation pixels must
        # also be restricted to the mask / non-border area.
        val_pixels = np.argwhere(usable & in_rect)
        train_pixels = np.argwhere(usable & ~in_rect)
        return train_pixels, val_pixels


In [5]:
# Initiating the data module with its parameters:
#   buffer           -- half patch size in x/y; patches are 2 * buffer + 1 pixels wide
#   z_start          -- offset of the first slice in the z direction
#   z_dim            -- number of slices in the z direction (at most: number of slices - z_start)
#   validation_rect  -- (x, y, width, height) rectangle held out for validation
#   shared_height    -- height all scrolls are resized to
#   scroll_fragments -- scroll fragments to be used

dataset = Scrolls_Dataset(
    buffer=15,
    z_start=27,
    z_dim=10,
    validation_rect=(1100, 3500, 700, 950),
    shared_height=8000,
    downsampling=None,
    scroll_fragments=[1, 2, 3],
    stage='train',
    shuffle=True,
    batch_size=8,
    num_workers=4,
    on_gpu=False,
)



In [6]:
# Prepare the data: load and resize the scroll images, labels and masks, then
# build the train/validation datasets (dataloaders are created on demand).

dataset.prepare_data()

10it [00:03,  2.74it/s]
10it [00:06,  1.48it/s]
10it [00:03,  3.18it/s]


### Dataloaders

In [10]:
# Peek at the first two training batches to check tensor shapes;
# change the range to inspect more batches.
dataloader = iter(dataset.train_dataloader())
for i in range(2):
    subvolume, inklabel = next(dataloader)
    print('subvolume shape:', subvolume.shape)
    print('inklabel shape:', inklabel.shape)
    

subvolume shape: torch.Size([8, 1, 10, 31, 31])
inklabel shape: torch.Size([8, 1])
subvolume shape: torch.Size([8, 1, 10, 31, 31])
inklabel shape: torch.Size([8, 1])
