# Fragment flattening ([kaggle notebook](https://www.kaggle.com/code/thenoodleninja/fragment-flattening))

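Script to flatten the papyrus fragments of the [Vesuvius Challenge ink detection competition](https://www.kaggle.com/competitions/vesuvius-challenge-ink-detection). The papyrus surface is detected in each fragment's surface volume, and every (y, x) column is shifted so that the surface lies at a constant depth.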

In [None]:
import gc
import glob
import hashlib
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import seaborn as sn
from scipy import ndimage
from scipy.ndimage import gaussian_filter
from tqdm.auto import tqdm

In [None]:
def flatten(arr, z_buffer=2, z_layers=7, blur_topo=True, threshold=0.5):
    """
    Shift every (y, x) column of a surface volume so that the detected
    papyrus surface sits at depth ``z_buffer``.

    The volume is normalised to [0, 1], flipped along z and smoothed. The
    surface depth ``topo`` of a column is the first z index at which the
    smoothed Sobel z-gradient reaches ``threshold`` (0 if it never does).
    Output layer ``k`` of a column is layer ``(k + topo - z_buffer) % depth``
    of the smoothed volume, i.e. the shift wraps around cyclically.

    :param arr: numpy array with the surface_volume data, indexed (z, y, x);
        intensities are assumed to be on the uint16 scale (0 .. 2**16-1)
    :param z_buffer: how many layers of air to keep above the papyrus
    :param z_layers: how many layers to keep after transforming; like the
        slice ``[:z_layers]``, this is capped at the volume depth
    :param blur_topo: Gaussian-smooth the surface depth map before shifting
    :param threshold: gradient level that marks the papyrus surface
    :return: ``(flat, topo)`` -- ``flat`` is the shifted, smoothed volume as
        uint16 with shape ``(min(z_layers, depth), y, x)``; ``topo`` is the
        per-column surface depth (counted in the flipped volume) as uint8
    """
    arr = arr.astype(float) / (2**16 - 1)  # convert to float in [0, 1]
    arr = np.flip(arr, axis=0)
    arr = gaussian_filter(arr, sigma=1)
    arr_sob = gaussian_filter(ndimage.sobel(arr, axis=0), sigma=1)
    # argmax over a 0/1 mask picks the first z index where the mask is 1.
    topo = np.argmax(np.where(arr_sob < threshold, 0, 1), axis=0)
    del arr_sob  # release the full-size gradient volume before the gather
    if blur_topo:
        # topo is an integer array, so the filtered result stays integer.
        topo = gaussian_filter(topo, sigma=1)

    depth = arr.shape[0]
    # Gather only the layers that are kept instead of remapping the whole
    # volume through np.indices (three full-size int64 index grids).
    layers = np.arange(depth)[:z_layers]
    src_z = (layers[:, None, None] + topo[None, :, :] - z_buffer) % depth
    flat = np.take_along_axis(arr, src_z, axis=0)
    flat = (flat * (2**16 - 1)).astype(np.uint16)
    # NOTE: the uint8 cast wraps for surface depths above 255.
    return flat, topo.astype(np.uint8)

class ScanData:
    """
    Access to one fragment directory and its stacked surface volume.

    The directory must contain ``mask.png`` and a ``surface_volume`` folder of
    ``*.tif`` slices; ``inklabels.png`` is optional. The slices are stacked
    into a (dim, h, w) volume that is cached as ``.npy`` files in
    ``cache_folder`` and read back through a memory map, so sub-volumes can
    be loaded without reading the whole volume into RAM.
    """

    def __init__(self, baseDir, cache_folder="/kaggle/tmp"):
        """
        :param baseDir: pathlib.Path of the fragment directory
        :param cache_folder: directory holding the cached stacked volumes
        """
        self.dir = baseDir.resolve()
        # The built-in hash() of a str is salted per process (PYTHONHASHSEED),
        # so it changes on every kernel restart; a content hash keeps the
        # cache file name stable and reusable.
        self.id = self._cache_id(self.dir)

        with Image.open(str(baseDir / "mask.png")) as im:
            self.mask = np.array(im.convert("1"))
        self.h, self.w = self.mask.shape

        try:
            with Image.open(str(baseDir / "inklabels.png")) as im:
                self.label = np.array(im.convert("1"))
        except OSError:
            # Missing or unreadable label image (presumably an unlabeled
            # fragment). FileNotFoundError and PIL's UnidentifiedImageError
            # are both OSError subclasses.
            self.label = None

        self.sliceNames = sorted((baseDir / "surface_volume").rglob("*.tif"))
        self.dim = len(self.sliceNames)
        self.cache_folder = cache_folder

    @staticmethod
    def _cache_id(path):
        """Return an integer key for ``path`` that is stable across processes."""
        digest = hashlib.sha1(str(path).encode("utf-8")).hexdigest()
        return int(digest[:16], 16)

    @staticmethod
    def _stripes(height, stripe_width, stripe_overlay):
        """
        Yield ``(y0, y1, start, stop)`` for each horizontal stripe.

        ``[start, stop)`` are the output rows owned by the stripe and
        ``[y0, y1)`` are the same rows padded by ``stripe_overlay`` on both
        sides, clipped to ``[0, height)``.
        """
        for start in range(0, height, stripe_width):
            stop = min(start + stripe_width, height)
            yield max(start - stripe_overlay, 0), min(stop + stripe_overlay, height), start, stop

    def _cached_volume(self, suffix, dtype, convert):
        """
        Return the cached (dim, h, w) volume as a read-only memmap, building
        the cache from the TIFF slices on first use.

        :param suffix: cache-file tag, e.g. ``"uint16"``
        :param dtype: dtype of the cached volume
        :param convert: maps one slice array, as read from its TIFF, to the
            values stored in the cache
        """
        path = f"{self.cache_folder}/{self.id}_{suffix}.npy"
        if not os.path.isfile(path):
            print("caching img")
            os.makedirs(self.cache_folder, exist_ok=True)
            # Fill a memory-mapped temporary file slice by slice, so the whole
            # volume never has to sit in RAM, and move it into place only when
            # complete: an interrupted run cannot leave a truncated cache
            # file that the isfile() check above would accept later.
            tmp_path = f"{path}.partial"
            volume = np.lib.format.open_memmap(
                tmp_path, mode="w+", dtype=dtype, shape=(self.dim, self.h, self.w)
            )
            for idx, filename in enumerate(tqdm(self.sliceNames, leave=False)):
                with Image.open(str(filename)) as im:
                    volume[idx, :, :] = convert(np.array(im))
            volume.flush()
            del volume  # close the memmap before renaming the file
            os.replace(tmp_path, path)
        return np.load(path, mmap_mode="r")

    def get_img_uint16(self, x_slc=slice(None, None), y_slc=slice(None, None), z_slc=slice(None, None)):
        """Return an in-memory copy of ``volume[z_slc, y_slc, x_slc]`` with the raw uint16 intensities."""
        return self._cached_volume("uint16", np.uint16, lambda s: s)[z_slc, y_slc, x_slc].copy()

    def get_img_uint8(self, x_slc=slice(None, None), y_slc=slice(None, None), z_slc=slice(None, None)):
        """Return an in-memory copy of ``volume[z_slc, y_slc, x_slc]`` with intensities reduced to uint8 (value // 256)."""
        return self._cached_volume(
            "uint8", np.uint8, lambda s: (s // 256).astype(np.uint8)
        )[z_slc, y_slc, x_slc].copy()

    def flatten(self, z_layers=7, stripe_width=500, stripe_overlay=20, z_buffer=2):
        """
        Flatten the whole fragment stripe by stripe, so only one stripe of the
        volume is loaded at a time.

        Each stripe of ``stripe_width`` rows is read with ``stripe_overlay``
        extra rows on both sides (clipped to the image), flattened with the
        module-level ``flatten`` function, and the padding is cropped off
        before the rows are written into the result.

        :param z_layers: number of layers to keep
        :param stripe_width: output rows produced per stripe
        :param stripe_overlay: extra context rows read on each side of a stripe
        :param z_buffer: layers of air kept above the papyrus
        :return: ``(flattened, topo)`` -- uint16 array of shape
            (z_layers, h, w) and uint8 surface-depth map of shape (h, w)
        """
        flattened = np.zeros((z_layers, self.h, self.w), dtype=np.uint16)
        topo = np.zeros((self.h, self.w), dtype=np.uint8)
        stripes = list(self._stripes(self.h, stripe_width, stripe_overlay))
        for y0, y1, start, stop in tqdm(stripes):
            img_stripe = self.get_img_uint16(y_slc=slice(y0, y1))
            stripe_flat, stripe_topo = flatten(img_stripe, z_buffer, z_layers)
            # Output rows [start, stop) sit at offset start - y0 inside the
            # padded stripe, whatever the overlap width.
            lo, hi = start - y0, stop - y0
            flattened[:, start:stop, :] = stripe_flat[:, lo:hi, :]
            topo[start:stop, :] = stripe_topo[lo:hi, :]
        return flattened, topo

    def save_flattened(self, folder):
        """
        Flatten the fragment with default settings and write it to ``folder``:
        one TIFF per kept layer as ``surface_volume/NN.tif`` and the surface
        depth map as ``topo.png``.
        """
        os.makedirs(f"{folder}/surface_volume", exist_ok=True)

        flattened, topo = self.flatten()
        for i in tqdm(range(flattened.shape[0])):
            Image.fromarray(flattened[i]).save(f"{folder}/surface_volume/{i:02}.tif")
        Image.fromarray(topo).save(f"{folder}/topo.png")

In [None]:
# Collect the fragment directories and write a flattened copy of each one.
# iterdir() already yields full paths; only directories can be fragments
# (ScanData reads mask.png and surface_volume/ from inside them).

output_folder = Path("/kaggle/working/flat")
base_path = Path("/kaggle/input/vesuvius-challenge-ink-detection")
train_path = base_path / "train"
train_fragments = sorted(p for p in train_path.iterdir() if p.is_dir())
test_path = base_path / "test"
test_fragments = sorted(p for p in test_path.iterdir() if p.is_dir())

allFragments = train_fragments + test_fragments

for fragment in allFragments:
    scan = ScanData(fragment)
    print(fragment.name)
    scan.save_flattened(output_folder / fragment.name)

## Comparison of segmentation approaches

In [None]:
# Rebuild the fragment list (train first, then test) and open one fragment
# for the comparisons below. iterdir() already yields full paths.
output_folder = Path("/kaggle/working/flat")
base_path = Path("/kaggle/input/vesuvius-challenge-ink-detection")
train_path = base_path / "train"
train_fragments = sorted(p for p in train_path.iterdir() if p.is_dir())
test_path = base_path / "test"
test_fragments = sorted(p for p in test_path.iterdir() if p.is_dir())
allFragments = train_fragments + test_fragments
# NOTE(review): positional pick; which fragment index 3 is depends on how
# many train fragments exist -- confirm this is the intended one.
scan = ScanData(allFragments[3])

### Clustering

In this section several clustering algorithms are compared. All of them use two features per pixel: the scaled intensity and the scaled z-position.

In [None]:
# Per-pixel features of the row y=1000, x in [2400, 6000): smoothed intensity
# scaled by 1/(255*0.5) and z-position scaled by 1/64.
# Only that row is read from the cache; loading the whole volume just to
# slice one row copies the entire fragment into memory.
slc = scan.get_img_uint8(x_slc=slice(2400, 6000), y_slc=slice(1000, 1001))
test = gaussian_filter(slc[:, 0, :], sigma=1)
test_z = np.stack((test, np.indices(test.shape)[0]), axis=2).astype(float)
test_z[:, :, 0] /= 255.0 * 0.5
test_z[:, :, 1] /= 64.0

In [None]:
# Scaled intensity (x) against scaled z-position (y) for every pixel of the row.
plt.scatter(x=test_z[..., 0].ravel(), y=test_z[..., 1].ravel(), marker=".", alpha=0.005)

In [None]:
from sklearn.cluster import KMeans

# K-means on the (intensity, z) features of every pixel in the row.
N = 3  # number of clusters; the agglomerative cells below reuse it
kmeans = KMeans(n_clusters=N)
kmeans.fit(test_z.reshape((-1, 2)))
a = kmeans.predict(test_z.reshape((-1, 2))).reshape(test.shape)

# Smoothed intensities vs. cluster labels for x in [1500, 2000).
plt.imshow(test[:, 1500:2000])
plt.show()
plt.imshow(a[:, 1500:2000])
plt.show()
# Feature space coloured by cluster.
sn.scatterplot(x=test_z[..., 0].flatten(), y=test_z[..., 1].flatten(), hue=a.flatten(), alpha=0.005)
plt.show()

In [None]:
from sklearn.cluster import AgglomerativeClustering

# Ward-linkage agglomerative clustering of the (intensity, z) features,
# restricted to the window x in [1500, 2000) of the row (presumably to keep
# the pairwise clustering tractable).
linkage = "ward"
wrd = AgglomerativeClustering(linkage=linkage, n_clusters=N).fit(test_z[:, 1500:2000, :].reshape((-1, 2)))
plt.imshow(test[:, 1500:2000])
plt.show()
# labels_ follow the row-major (z, x) order of the window; take the row count
# from the number of z-layers rather than a fixed depth.
plt.imshow(wrd.labels_.reshape((test.shape[0], -1)))
plt.show()
sn.scatterplot(x=test_z[:, 1500:2000, 0].flatten(), y=test_z[:, 1500:2000, 1].flatten(), hue=wrd.labels_, alpha=0.05)
plt.show()

In [None]:
from sklearn.cluster import AgglomerativeClustering

# Average-linkage agglomerative clustering of the (intensity, z) features,
# restricted to the window x in [1500, 2000) of the row (presumably to keep
# the pairwise clustering tractable).
linkage = "average"
wrd = AgglomerativeClustering(linkage=linkage, n_clusters=N).fit(test_z[:, 1500:2000, :].reshape((-1, 2)))
plt.imshow(test[:, 1500:2000])
plt.show()
# labels_ follow the row-major (z, x) order of the window; take the row count
# from the number of z-layers rather than a fixed depth.
plt.imshow(wrd.labels_.reshape((test.shape[0], -1)))
plt.show()
sn.scatterplot(x=test_z[:, 1500:2000, 0].flatten(), y=test_z[:, 1500:2000, 1].flatten(), hue=wrd.labels_, alpha=0.05)
plt.show()

In [None]:
from sklearn.cluster import AgglomerativeClustering

# Complete-linkage agglomerative clustering of the (intensity, z) features,
# restricted to the window x in [1500, 2000) of the row (presumably to keep
# the pairwise clustering tractable).
linkage = "complete"
wrd = AgglomerativeClustering(linkage=linkage, n_clusters=N).fit(test_z[:, 1500:2000, :].reshape((-1, 2)))
plt.imshow(test[:, 1500:2000])
plt.show()
# labels_ follow the row-major (z, x) order of the window; take the row count
# from the number of z-layers rather than a fixed depth.
plt.imshow(wrd.labels_.reshape((test.shape[0], -1)))
plt.show()
sn.scatterplot(x=test_z[:, 1500:2000, 0].flatten(), y=test_z[:, 1500:2000, 1].flatten(), hue=wrd.labels_, alpha=0.05)
plt.show()

### Edge Detection

In [None]:
# The surface-detection step of flatten(), run on the crop y in [500, 1500),
# x in [2400, 6000) and kept as globals for plotting: normalised + flipped
# volume, its smoothed copy, the smoothed Sobel z-gradient, and the first z
# where that gradient reaches 0.5.
arr = np.flip(
    scan.get_img_uint16(x_slc=slice(2400, 6000), y_slc=slice(500, 1500)).astype(float) / (2**16 - 1),
    axis=0,
)
arr_filtered = gaussian_filter(arr, sigma=1)
arr_sob = gaussian_filter(ndimage.sobel(arr_filtered, axis=0), sigma=1)
topo = np.argmax(np.where(arr_sob < 0.5, 0, 1), axis=0)

In [None]:
# Cross-section (all z, crop row 500, columns 1500..2000): normalised volume,
# smoothed z-gradient, thresholded gradient, and the negated surface depth.
fig, axs = plt.subplots(4, 1)
axs[0].imshow(arr[:, 500, 1500:2000])
axs[1].imshow(arr_sob[:, 500, 1500:2000])
axs[2].imshow(np.where(arr_sob[:, 500, 1500:2000] < 0.5, 0, 1))
axs[3].set_ylim(-64, 0)
axs[3].plot(-topo[500, 1500:2000])

### Thresholding

In [None]:
from skimage.filters import threshold_otsu

# Otsu threshold of the smoothed row y=1000, x in [2400, 6000).
# Only that row is read from the cache instead of copying the whole volume.
slc = scan.get_img_uint8(x_slc=slice(2400, 6000), y_slc=slice(1000, 1001))
test = gaussian_filter(slc[:, 0, :], sigma=1)
thresh = threshold_otsu(test)

In [None]:
# Intensity histogram with the Otsu threshold marked in red.
plt.hist(test.ravel(), bins=256, range=(0, 256))
plt.axvline(thresh, color='r')
plt.show()
# Binary mask (1 above the threshold, 0 elsewhere) in test's own dtype.
threshed = (test > thresh).astype(test.dtype)
plt.imshow(threshed[:, 1500:2000])