## Preprocessing: Image slicing

In [None]:
from dotenv import load_dotenv
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
import math
import os
import glob

# Populate os.environ from a .env file; later cells read TEST_FILE,
# IMAGE_DIR, MASK_DIR and DIR_LABEL_SLICED via os.getenv.
load_dotenv()
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
def _num_patches(length, n, stride):
    """Return how many patch origins 0, stride, 2*stride, ... cover one axis.

    Enough patches are produced to cover every pixel of an axis of size
    ``length``; the last patch may hang past the border (the caller zero
    pads it). An axis shorter than ``n`` still gets one padded patch, and an
    empty axis gets none.
    """
    if length <= 0:
        return 0
    # Integer ceil((length - n) / stride), clamped at 0 so an axis shorter
    # than the patch yields exactly one patch instead of zero or a negative
    # count (a negative count made np.zeros raise).
    return 1 + max(0, -(-(length - n) // stride))


def slice_image(img, n, stride, channels, dtype=int):
    """Slice an image into square n x n patches laid out on a regular grid.

    Patches start at multiples of ``stride`` along both spatial axes. Patches
    that extend past the image border are zero padded to full size.

    Args:
        img: array indexed as ``img[row, col]`` or ``img[row, col, channel]``.
            With ``channels == 1`` a 2-D array is accepted and given a
            trailing channel axis.
        n: side length of each patch in pixels (must be > 0).
        stride: step between consecutive patch origins (must be > 0).
        channels: number of channels in each output patch; should match the
            last axis of ``img`` (or be 1 for a 2-D ``img``).
        dtype: dtype of the returned array. Defaults to ``int`` for backward
            compatibility; pass ``img.dtype`` to avoid truncating float
            inputs or inflating uint8 images to a wider integer type.

    Returns:
        Array of shape ``(num_patches, n, n, channels)``. Patches are ordered
        column strip by column strip: all patches of the first column strip
        (moving down the rows) come first, then the next strip, and so on.

    Raises:
        ValueError: if ``n`` or ``stride`` is not positive.
    """
    if n <= 0 or stride <= 0:
        raise ValueError(f'n and stride must be positive, got n={n}, stride={stride}')

    rows, cols = img.shape[0], img.shape[1]
    num_x = _num_patches(rows, n, stride)
    num_y = _num_patches(cols, n, stride)

    results = np.zeros((num_x * num_y, n, n, channels), dtype=dtype)
    counter = 0
    for i in range(num_y):
        y = i * stride
        for j in range(num_x):
            x = j * stride
            cube = img[x:x + n, y:y + n]
            if channels == 1:
                cube = cube.reshape((cube.shape[0], cube.shape[1], 1))
            # Border patches may be smaller than n x n; the untouched part of
            # the zero-initialised slot acts as padding.
            results[counter, :cube.shape[0], :cube.shape[1]] = cube
            counter += 1
    return results

### Example usage on single image

In [None]:
# Demo settings: non-overlapping 128x128 patches with 3 channels.
n = 128
stride = 128
channels = 3

f = os.getenv('TEST_FILE')
if not f:
    raise RuntimeError('TEST_FILE is not set; add it to your .env file')
# Open the image in a context manager so the file handle is released once
# the pixels have been copied into a NumPy array.
with Image.open(f) as im:
    # Keep only the first `channels` bands (e.g. RGBA -> RGB).
    # NOTE(review): a single-band (grayscale/palette) image yields a 2-D
    # array and this 3-D slice would raise IndexError.
    im_arr = np.array(im)[:, :, 0:channels]
plt.figure(figsize=(10, 10))
plt.imshow(im_arr)

In [None]:
# Cut the demo image loaded above into patches using the demo settings.
results = slice_image(im_arr, n=n, stride=stride, channels=channels)

### Slice all images and masks, saving output

In [None]:
im_dir = os.getenv('IMAGE_DIR')
mask_dir = os.getenv('MASK_DIR')
postfix = os.getenv('DIR_LABEL_SLICED')

# Fail early with a clear message instead of a TypeError on `None + None`.
missing = [
    name
    for name, value in (
        ('IMAGE_DIR', im_dir),
        ('MASK_DIR', mask_dir),
        ('DIR_LABEL_SLICED', postfix),
    )
    if not value
]
if missing:
    raise RuntimeError('Missing environment variable(s): ' + ', '.join(missing))

# Patch size and stride for the batch run. `stride` is defined here too so
# this section does not depend on the single-image example cell being run.
n = 128
stride = 128

# Output dirs sit next to the inputs, e.g. "images" -> "images<postfix>".
# normpath strips a trailing separator so "images/" does not become the
# nested path "images/<postfix>".
im_dir_out = os.path.normpath(im_dir) + postfix
mask_dir_out = os.path.normpath(mask_dir) + postfix

In [None]:
# .npy files directly inside each input dir (non-recursive). Directory names
# are glob-escaped so characters like "[" are matched literally; sorting
# makes the processing order and progress output deterministic.
im_filenames = sorted(glob.glob(os.path.join(glob.escape(im_dir), '*.npy')))
# Not used below: each mask is looked up by its image's file name.
mask_filenames = sorted(glob.glob(os.path.join(glob.escape(mask_dir), '*.npy')))

# Create output directories (no-op if they already exist).
os.makedirs(im_dir_out, exist_ok=True)
os.makedirs(mask_dir_out, exist_ok=True)

for i, im_filename in enumerate(im_filenames):
    filename = os.path.basename(im_filename)
    # splitext strips only the final extension: "a.v1.npy" -> "a.v1" rather
    # than "a", so distinct inputs cannot overwrite each other's slices.
    stem = os.path.splitext(filename)[0]

    im = np.load(im_filename)
    # The mask is expected under the same file name inside mask_dir.
    mask = np.load(os.path.join(mask_dir, filename))

    im_slices = slice_image(im, n, stride, 3)
    mask_slices = slice_image(mask, n, stride, 1)

    for j, im_slice in enumerate(im_slices):
        np.save(os.path.join(im_dir_out, f'{stem}_{j}.npy'), im_slice)

    for k, mask_slice in enumerate(mask_slices):
        np.save(os.path.join(mask_dir_out, f'{stem}_{k}.npy'), mask_slice)

    print('processed image', i)