# Vesuvius - Data_trn_val_tst
Manually translated from [here](https://www.kaggle.com/code/synset/vesuvius-data-trn-val-tst)


In [None]:
!pip install qunet
#!pip -q install torchinfo        # model structure

import os, gc, sys, time, datetime, math, random, copy, psutil, glob
import numpy as np,  matplotlib.pyplot as plt, pandas as pd
from pathlib import Path
import PIL.Image as Image
from   tqdm.auto import tqdm
import torch, torch.nn as nn
from torchinfo import summary

from qunet import Info, Config, Callback, Data,  MLP, Transformer, plot_histogram

In [None]:
# Global notebook configuration (qunet Config); read as CFG.<name> throughout the notebook.
CFG = Config(
    folder_trn  = '/kaggle/input/vesuvius-challenge-ink-detection/train/',  # one subfolder per training fragment
    folder_tst  = '/kaggle/input/vesuvius-challenge-ink-detection/test/',   # one subfolder per test fragment

    layer_min = 0,   # first surface-volume layer to load
    layer_max = 5,   # exclusive upper bound: layers layer_min .. layer_max-1 are loaded
                     # NOTE(review): originally flagged '!!!!' - deliberately small; confirm before a full run

    patch_h  = 512,  # the height and width of the patches into which the image is split
    patch_w  = 512,

    train    = True,  # the model is trained or loaded from a dataset trained before submission
                      # NOTE(review): not read anywhere in this section of the notebook

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),  # GPU if available, else CPU
)
# Fragment images are very large (up to 14830 x 9506 pixels); None disables
# Pillow's decompression-bomb pixel limit so that they can be opened.
Image.MAX_IMAGE_PIXELS = None
info = Info()   # qunet Info object; called as info("...") for logging throughout the notebook

# Class VesuviusData
Large images: (H, W) ranges from (7606, 5249) to (14830, 9506). As float32 (4 bytes per value), the 65 slices of the largest image occupy 4·(65·14830·9506) bytes ≈ 34 GiB of memory.
For the model to work, each image must be split into small patches of size (h, w).

For augmentation, these patches should be cut from random positions and have varying sizes.
During testing, the patches should completely cover the image with a regular tiling.
If an integer number of patches does not fit, edge bands of overlapping patches are needed.
A better option is probably to use overlapping patches everywhere and then average (or overwrite) the predictions in the overlaps.

Don't forget to try TTA: make several passes with different patch sizes (including random ones) and average the accumulated mask.
Erosion removal

## Validation problems

With random patch sampling, validation patches may overlap with training patches. Let them overlap: trust your CV.

## Loading a new file
Loading is done by the callback class.
In all modes, the requested number of patches n_patches must be chosen so that they fit into memory (taking their size into account).
In training and validation mode, after the masks are loaded, all n_patches are written into the dataset tensors at once (i.e. reload works in a single pass).

In test mode, the total number of patches is determined by the image size and the patch size. Therefore, the requested number of patches n_patches may be smaller (so that they fit into memory). In that case reload is called several times, each time taking the next portion of the tiled patches.


In [None]:
class DataManagerCallback(Callback):
    """
    Callback that fills a `Data` instance with patches cut from scroll fragments.

    Fragments are taken one at a time from the subfolders of `CFG.folder_trn`
    (training/validation) or `CFG.folder_tst` (testing).  For the current fragment
    a list of patch positions is generated -- random, or a regular tiling that
    covers the whole fragment -- and every `reload` writes the next chunk of at
    most `n_patches` patches into `data.data`:

        [layers (n, n_layers, h, w),  # slices layer_min .. layer_max-1, divided by 65535
         mask   (n, 1, h, w),         # fragment mask
         ink    (n, 1, h, w),         # ink labels (test mode: the fragment mask itself)
         pos    (n, 2) int32]         # top-left corner (y, x) of every patch

    When all positions of a fragment have been used, the next `reload` moves on to
    the next fragment (cyclically).  Nothing is loaded in `__init__`: call
    `on_fit_start` and then `reload` (directly or via `on_train_epoch_start`).
    """

    def __init__(self, data, train, n_patches, rand=True,
                 patch_h = 64, patch_w = 64, patch_dh = 0, patch_dw = 0,
                 shuffle=False, batch_size=64,  whole_batch=False,
                 layer_min=0, layer_max=65):
        """
        Args:
            data:               `Data` instance whose `data` list is (re)filled with patches.
            train:              True for training/validation (ink labels are loaded),
                                False for testing.
            n_patches:          maximum number of patches written to `data` per reload.
            rand:               random patch positions if True, otherwise a regular tiling.
            patch_h, patch_w:   mean patch height and width.
            patch_dh, patch_dw: maximum random deviation of the patch size from the mean;
                                the size is drawn once per fragment.
            shuffle, batch_size, whole_batch: not used by this class.
            layer_min, layer_max: half-open range of surface-volume layers to load.
        """
        self.train       = train      # training/validation (True) or testing (False)
        self.n_patches   = n_patches  # maximum number of patches in self.data per reload
        self.patch_h0    = patch_h    # mean patch height
        self.patch_w0    = patch_w    # mean patch width
        self.patch_dh    = patch_dh   # max random deviation from patch_h0
        self.patch_dw    = patch_dw   # max random deviation from patch_w0
        self.rand        = rand       # random patch positions (True) or tiling (False)

        self.layer_min   = layer_min
        self.layer_max   = layer_max
        self.n_layers    = layer_max - layer_min   # number of fragment layers used (at most 65)
        assert self.n_layers <= 65, "Wrong number of layers"

        self.data          = data     # instance of Data
        self.period_reload = 1        # reload every period_reload epochs (see on_train_epoch_start)

        self.fast          = False    # True: random positions without looking at the fragment mask;
                                      # False: only patches that overlap the fragment mask
        self.patches_pos   = None     # (N, 2) tensor of patch top-left corners (y, x)
        self.patch_id      = 0        # index in patches_pos of the first patch of the next chunk
        self.folder_id     = 0        # index of the current fragment in self.folders
        self.next_new_file = False    # if True, the next reload loads the next fragment

        # Training/validation reads CFG.folder_trn, testing (submission) reads CFG.folder_tst.
        self.folders = sorted(Path(CFG.folder_trn if train else CFG.folder_tst).glob('*'))
        info(f"VesuviusData: tarin={train}, {len(self.folders)} subfoders (scroll fragments)")

    #---------------------------------------------------------------------------

    @staticmethod
    def _read_binary_png(fname):
        """Read an image file as a boolean 2d tensor (PIL mode '1'), closing the file."""
        with Image.open(fname) as img:
            return torch.from_numpy(np.array(img.convert('1')))

    def load_masks(self, folder, verbose=1):
        """
        Load the fragment mask (and, in training mode, the ink labels) from `folder`
        and draw the patch size used for this fragment.

        Args:
            folder:  fragment directory with `surface_volume/*.tif`, `mask.png`
                     and (training mode only) `inklabels.png`.
            verbose: if >= 1, report the shapes of the loaded masks.
        """
        self.files = sorted((Path(folder) / "surface_volume").glob('*.tif'))
        assert len(self.files) == 65, f"Wrong numer of files: {len(self.files)}"

        self.mask = self._read_binary_png(Path(folder) / "mask.png")
        if verbose >= 1: info(f"loaded mask: {self.mask.shape} from {folder}")
        if self.train:  # there are no ink labels in test mode
            self.ink = self._read_binary_png(Path(folder) / "inklabels.png")
            if verbose >= 1: info(f"loaded ink : {self.ink.shape}")

        # Random deviation from the mean patch size (inclusive range [-d, +d]).
        self.patch_h = self.patch_h0 + torch.randint(-self.patch_dh, self.patch_dh+1, (1,)).item()
        self.patch_w = self.patch_w0 + torch.randint(-self.patch_dw, self.patch_dw+1, (1,)).item()

        # Next, get_patches_pos generates the list of patch positions; get_patches then
        # consumes it chunk by chunk, after which the next fragment is loaded.

    #---------------------------------------------------------------------------

    @staticmethod
    def _tile_starts(size, patch):
        """
        Start offsets of patches of length `patch` that together cover [0, size).

        Patches are placed every `patch` pixels; if they do not end exactly at
        `size`, one more patch aligned to the border (overlapping its neighbour)
        is added.  E.g. size=1024, patch=512 -> [0, 512]; size=1000, patch=512 -> [0, 488].
        """
        last = size - patch                      # largest valid start offset
        if last < 0:
            raise ValueError(f"patch size {patch} exceeds image size {size}")
        starts = list(range(0, last + 1, patch))
        if last % patch:                         # regular grid does not reach the border
            starts.append(last)
        return starts

    def get_patches_pos(self):
        """
        Create `self.patches_pos`, an (N, 2) tensor with the top-left corners (y, x)
        of the patches of the current fragment, and reset `self.patch_id`.

        With `rand=True`, `n_patches` random positions are drawn (only positions whose
        patch overlaps the fragment mask unless `fast` is set).  Otherwise the
        fragment is covered by a regular tiling whose last row/column is shifted to
        end exactly at the image border.

        Raises:
            ValueError: the patch is larger than the image (tiling), or the mask is
                empty (random mode without `fast`, which would otherwise never end).
        """
        (H, W), h, w = self.mask.shape, self.patch_h, self.patch_w

        if self.rand:  # random positions
            # Valid top-left corners are 0..H-h and 0..W-w inclusive; randint's `high` is exclusive.
            if self.fast:
                posY = torch.randint(0, H - h + 1, (self.n_patches, 1))
                posX = torch.randint(0, W - w + 1, (self.n_patches, 1))
                self.patches_pos = torch.hstack([posY, posX])   # list of patch coordinates
            else:
                if not bool(self.mask.any()):
                    raise ValueError("fragment mask is empty: no patch can overlap the fragment")
                self.patches_pos = torch.empty((self.n_patches, 2), dtype=torch.long)
                patch_id = 0
                while patch_id < self.n_patches:   # rejection sampling
                    y = torch.randint(0, H - h + 1, (1,)).item()
                    x = torch.randint(0, W - w + 1, (1,)).item()
                    if self.mask[y:y+h, x:x+w].any():   # keep only patches touching the fragment
                        self.patches_pos[patch_id, 0] = y
                        self.patches_pos[patch_id, 1] = x
                        patch_id += 1
        else:          # tiling (with overlap at the right and bottom edges if needed)
            posY = torch.IntTensor(self._tile_starts(H, h))
            posX = torch.IntTensor(self._tile_starts(W, w))
            self.patches_pos = torch.cartesian_prod(posY, posX)

        self.patch_id = 0  # index of the first position of the next chunk

    #---------------------------------------------------------------------------

    def get_patches(self, verbose):
        """
        Write the next chunk of at most `n_patches` patches into `self.data.data`.

        Args:
            verbose: if >= 2, print loading progress.

        Returns:
            True if this chunk used up `patches_pos` (the next reload should load a
            new fragment), False otherwise.
        """
        # Positions of this chunk: up to n_patches positions starting at patch_id.
        patches_pos = self.patches_pos[self.patch_id: self.patch_id + self.n_patches].numpy()

        # Allocate the dataset tensors; the chunk may be shorter than n_patches at the end.
        n, h, w = len(patches_pos), self.patch_h, self.patch_w
        self.data.data = [torch.empty( n, self.n_layers, h, w ),             # patches by layers
                          torch.empty((n, 1, h, w), dtype=self.mask.dtype),  # fragment mask
                          torch.empty((n, 1, h, w), dtype=self.mask.dtype),  # ink mask (target)
                          torch.empty((n, 2),       dtype=torch.int32)]      # patch positions

        def crop(image):
            """Stack the (h, w) crops of a 2d `image` at all chunk positions into an (n, h, w) tensor."""
            return torch.stack([image[y:y+h, x:x+w] for (y, x) in patches_pos])

        self.data.data[1][:, 0] = crop(self.mask)
        if self.train:
            self.data.data[2][:, 0] = crop(self.ink)
        else:                                   # no ink labels in test mode: reuse the fragment mask
            self.data.data[2] = self.data.data[1]
        # Copy (not share) the positions so the dataset does not alias self.patches_pos.
        self.data.data[3] = torch.tensor(patches_pos, dtype=torch.int32)

        # Load the slices one by one and cut each of them at the same positions.
        for d in range(self.layer_min, self.layer_max):
            with Image.open(self.files[d]) as img:          # close each tif right after reading
                image = torch.from_numpy(np.array(img, dtype=np.float32) / 65535.0)
            assert self.mask.shape == image.shape
            self.data.data[0][:, d - self.layer_min] = crop(image)

            if verbose >= 2: print(f"\rdepth:{d:2d}, {self.data.data[0].shape};  {self.data.data[0].is_contiguous()}", end ="    ")

        self.patch_id += n
        if verbose >= 2: print(f" layers loaded, patch_id={self.patch_id}.")

        return self.patch_id >= len(self.patches_pos)   # True if patches_pos is used up

    #---------------------------------------------------------------------------

    def reload(self, train=True, epoch=0, hist=None, best=None, verbose=1):
        """
        Refill the dataset with the next chunk of patches; load a new fragment first
        if the previous call used up the current one.

        `train`, `epoch`, `hist` and `best` are not used here.

        Returns:
            True if this call used up the current fragment's patch positions.
        """
        # Load masks and generate positions only when a new fragment is due,
        # not on every chunk taken from the same position list.
        if self.next_new_file:
            self.next_new_file = False
            self.folder_id = (self.folder_id + 1) % len(self.folders)   # fragments are taken cyclically
            self.patch_id  = 0

            self.load_masks(self.folders[self.folder_id], verbose)      # fragment and ink masks
            self.get_patches_pos()                                      # patch positions

        # When the whole position list has been used, switch fragments on the NEXT call:
        # the batches just written must be consumed by the batch iterator first.
        if self.get_patches(verbose):
            self.next_new_file = True
            return True
        return False

    #---------------------------------------------------------------------------

    def on_fit_start(self, trainer, model):
        """
        Called before trainer.fit() starts: arrange for the first reload to load
        the first fragment (folder_id -1 is advanced to 0).
        """
        self.folder_id = -1
        self.next_new_file = True

    #---------------------------------------------------------------------------

    def on_train_epoch_start(self, trainer, model):
        """
        Called at the start of every training epoch in fit.

        Reloads when there is no trainer, in epoch 1, and every `period_reload` epochs.
        """
        if trainer is None or  (trainer.epoch == 1 or trainer.epoch % self.period_reload == 0):
            self.reload(verbose = 2)

# Creating validation and training datasets
We make the validation set static: load the 3 fragments once, split them into patches and collect all the patches into one dataset.
To append one dataset to another, the Data class has an `add` method.
The resulting patches can be seen in the visualization section.

In [None]:
# Validation: a static dataset with 200 random patches from every training fragment.
info.reset()("beg")
data_val = Data(batch_size=50)
data_tmp = Data()
callback = DataManagerCallback(data=data_tmp, train=True, n_patches=200, rand=True,
                               patch_h=CFG.patch_h, patch_w=CFG.patch_w,
                               layer_min=CFG.layer_min, layer_max=CFG.layer_max)
callback.on_fit_start(None, 0)
# With trainer=None every call reloads; in random mode one reload uses up a fragment's
# positions, so each iteration takes 200 patches from the next fragment.
for _ in range(len(callback.folders)):
    callback.on_train_epoch_start(None, 0)
    data_val.add(callback.data)
info(f"data_val samples = {data_val.count()}  batches = {len(data_val)}")

# Training (the validation callback is no longer needed):
data_trn = Data(batch_size=100, shuffle=True)
callback = DataManagerCallback(data=data_trn, train=True, n_patches=1000, rand=True,
                               patch_h=CFG.patch_h, patch_w=CFG.patch_w, patch_dh=16, patch_dw=16,
                               layer_min=CFG.layer_min, layer_max=CFG.layer_max)

# We imitate a trainer

In [None]:
class Trainer:
    """
    Imitation of a trainer, used to exercise DataManagerCallback without a model.

    `fit` calls each callback's `on_fit_start` once and `on_train_epoch_start` at
    the start of every epoch, then iterates over the training data and the
    validation data, printing the shape of the first tensor of each batch.
    """

    def __init__(self, data_trn, data_val = data_val, callbacks=None):
        """
        Args:
            data_trn:  training Data.
            data_val:  validation Data; defaults to the global `data_val` that
                       existed when this class was defined.
            callbacks: list of callbacks (a fresh empty list if None).
        """
        self.data      = Config(trn=data_trn, val=data_val)
        self.callbacks = [] if callbacks is None else callbacks
        self.epoch     = 0

    def fit(self, epochs, period_reload=1):
        """
        Imitate the fit method for `epochs` epochs, numbered from 1.

        `period_reload` is accepted for interface compatibility but is not used;
        DataManagerCallback reads its own `period_reload` attribute instead.
        """
        for callback in self.callbacks:
            callback.on_fit_start(self, 0)
        for epoch in range(1, epochs + 1):
            self.epoch = epoch

            info("training start")
            for callback in self.callbacks:
                callback.on_train_epoch_start(self, 0)
            self._iterate(self.data.trn)

            info("validation start")
            self._iterate(self.data.val)

    @staticmethod
    def _iterate(data):
        """Iterate over all batches of `data`, printing progress and the batch count."""
        for batch_id, batch in enumerate(data):
            info.info(f"\r{batch_id+1}: {batch[0].shape}", pref="\r", end="   ")
        print(f" {len(data)} batches")


# Drive the training callback through three imitation epochs.
trainer = Trainer(data_trn=data_trn, callbacks=[callback])
trainer.fit(3)

# Submission

In [None]:
def submission(patch_h, patch_w):
    """
    Run the submission loop over all test fragments (predictions not implemented yet).

    Every test fragment is tiled into (patch_h, patch_w) patches and the batches
    are iterated.  A per-fragment coverage map starts at 1 and is zeroed under
    every patch. After two pixels are set back to 1 for the csv, its sum should
    therefore be exactly 2 when the tiles cover the whole fragment.

    Args:
        patch_h, patch_w: tile height and width.

    Returns:
        dict mapping each test fragment folder name to a placeholder string.
    """
    result = {}
    ink_pred = None
    folder = None

    data_tst = Data(batch_size=100)
    callback = DataManagerCallback(data=data_tst, train=False, n_patches=1000, rand=False,
                                   patch_h=patch_h, patch_w=patch_w, patch_dh=0, patch_dw=0,
                                   layer_min=CFG.layer_min, layer_max=CFG.layer_max)
    if not callback.folders:   # reload would fail on an empty fragment list
        info("submission: no test fragments found")
        return result

    callback.on_fit_start(None, 0)
    while True:                # one iteration per chunk of tiles
        callback.on_train_epoch_start(None, 0)   # trainer=None: always loads the next chunk

        if ink_pred is None:   # first chunk of a new fragment
            ink_pred = np.ones(tuple(callback.mask.shape), dtype=bool)
            folder = callback.folders[callback.folder_id].parts[-1]
            print(f"\n*** Create new mask {ink_pred.shape}  folder:{folder}\n")

        h, w = callback.patch_h, callback.patch_w   # actual tile size used by the callback
        for *_, pos in data_tst:                    # batches: patches, mask, ink, pos
            for y, x in pos.tolist():
                ink_pred[y:y+h, x:x+w] = 0          # mark pixels covered by this tile

        if callback.next_new_file:                  # this fragment is finished
            ink_pred[0,1]=ink_pred[1,0]=1           # for csv
            result[folder] = ""                     # placeholder until predictions exist
            print(f"\n*** Save submission to '{folder}'; ink_pred.shape:{ink_pred.shape}, sum:{ink_pred.sum()} == 2\n")
            ink_pred = None

            if callback.folder_id + 1 >=  len(callback.folders):
                break

    info(f"result: {result}")
    # TODO: write the result to submission.csv, e.g. pd.DataFrame(result).to_csv("submission.csv")
    return result

In [None]:
#---------------------------------------------------------------------------
info("beg")
# Optionally free the training/validation data first:
#del data_trn, data_val
submission(patch_w=CFG.patch_w, patch_h=CFG.patch_h)
info("the End");

# Visualization

In [None]:
def plot_masks(folder=CFG.folder_trn,  subfolder="1/", w=10, h=6):
    """
    Show a fragment's infrared image next to its ink labels overlaid with the fragment mask.

    Args:
        folder:    root folder containing the fragment subfolders.
        subfolder: fragment subfolder (e.g. "1/") with ir.png, inklabels.png and mask.png.
        w, h:      figure width and height in inches.
    """
    path = Path(folder) / subfolder

    with Image.open(path / "ir.png") as img:
        sample = img.copy()                      # detached copy, so the file can be closed
    with Image.open(path / "inklabels.png") as img:
        label = torch.from_numpy(np.array(img)).gt(0).float()   # plotting only: keep on CPU
    with Image.open(path / "mask.png") as img:
        mask = np.array(img.convert('1'))

    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(w, h), facecolor ='w')
    ax0.set_title(subfolder+"ir.png")
    ax0.imshow(sample, cmap='gray')
    # Right panel: ink labels with the fragment mask drawn semi-transparently on top.
    ax1.set_title(subfolder+"inklabels.png + mask.png")
    ax1.imshow(label.numpy(), cmap='gray')
    ax1.imshow(mask, cmap='gray', alpha=0.5)
    plt.show()
    print(f"label: {label.shape},  mask: {mask.shape}")

info("beg")
plot_masks(subfolder="1/")
info("end")

In [None]:
def plot_data(folder=CFG.folder_trn,  subfolder="1/", start=0, num=4,  w=12, h=6):
    """
    Plot `num` slices of a fragment's 3d x-ray scan, starting at slice `start`.

    Slices are loaded one at a time from `surface_volume/*.tif`, divided by 65535
    and immediately downscaled 20x for display. Therefore only one full-resolution
    slice is held in memory at a time.
    """
    path  = Path(folder) / subfolder / "surface_volume"
    files = sorted(path.glob('*.tif'))
    print(f"total files: {len(files)}")

    thumbs = []
    for fname in tqdm(files[start:start+num]):
        with Image.open(fname) as img:
            image = np.array(img, dtype=np.float32) / 65535.0
        small = Image.fromarray(image).resize((image.shape[1]//20, image.shape[0]//20))
        thumbs.append(np.array(small, dtype=np.float32))

    if not thumbs:        # nothing selected: plt.subplots(1, 0) would fail
        return
    # squeeze=False keeps a 2d axes array even for a single slice.
    fig, axes = plt.subplots(1, len(thumbs), figsize=(w, h), squeeze=False)
    for thumb, ax in zip(thumbs, axes[0]):
        ax.imshow(thumb, cmap='gray')
        ax.set_xticks([]); ax.set_yticks([])
    fig.tight_layout()
    plt.show()

info("beg")
plot_data(subfolder="1/")
info("end")

In [None]:
def plot_patches(data, idx=0, n_images=5, w=12, h=4, start=0):
    """
    Plot one example from the first batch of `data`.

    Panels, left to right: fragment mask, ink mask, layers start .. start+n_images-4
    of the patch, and the mean over all layers.  The first panel's title shows
    the patch position.

    Args:
        data:     dataset yielding (patch, mask, ink, pos) batches; it is reset first.
        idx:      index of the example within the first batch.
        n_images: total number of panels (expected to be >= 3).
        w, h:     figure width and height in inches.
        start:    first layer to show.

    Raises:
        ValueError: if `data` yields no batches.
    """
    data.reset()
    batch = next(iter(data), None)        # only the first batch is used
    if batch is None:
        raise ValueError("plot_patches: data has no batches")
    patch, mask, ink, pos = batch
    print(patch.shape, mask.shape, ink.shape)

    images  = [mask[idx][0].float().numpy(), ink[idx][0].float().numpy()]
    images += [patch[idx, start + i].numpy() for i in range(n_images - 3)]
    images += [patch[idx].mean(0).numpy()]

    # squeeze=False keeps a 2d axes array even for n_images == 1.
    fig, axes = plt.subplots(1, n_images, figsize=(w, h), squeeze=False)
    for i, (image, ax) in enumerate(zip(images, axes[0])):
        ax.imshow(image, cmap='gray', vmin=0, vmax=1)
        ax.set_xticks([]); ax.set_yticks([])
        if i == 0:
            ax.set_title(f"y:{pos[idx][0]}, x:{pos[idx][1]}", fontsize=8)

    fig.tight_layout()
    plt.show()

# Inspect the first 20 examples of the first validation batch.
for example in range(20):
    plot_patches(data_val, idx=example)