## Cropping Images Pipeline
### Description
Crops image pairs from stpt2imc/data/{IMC, STPT}/ into 256x256 tiles and saves each tile as a tensor under processed_data/{IMC, STPT}/.
### Notes
- normalize to 8 bits b/c most values are between 0-255 anyways
- in process_imc_image() and process_stpt_image(), don't convert the numpy array to double - it slows things down A LOT (don't know why)
- convert tensor to double when using torch.save - MASSIVE speed ups
- make a clone of the stpt grid tensor when saving b/c torch.save will save the entire grid for some reason

In [None]:
# import packages
import os, time
import numpy as np
from skimage import io
import cv2 as cv
import torch
from torch.multiprocessing import Pool, set_start_method
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt

In [None]:
# ====== functions to load an stpt and imc image ====== 

def process_stpt_image(file_name):
    """Load an STPT image, rescale it to the 8-bit range and log-compress it.

    Args:
        file_name: path of the image file, read with ``io.imread``.

    Returns:
        A ``torch.Tensor`` built from the processed array.  The values are
        floating point in ``[0, 255]``; they are not cast back to an
        integer type.
    """
    img = io.imread(file_name)

    # min-max normalize into [0, 255] (8 bits); the result is written into
    # a copy so the loaded image itself is left untouched
    norm_img = img.copy()
    cv.normalize(img, norm_img, alpha=0, beta=2**8 - 1, norm_type=cv.NORM_MINMAX)

    # log transform: log1p takes the pixel values directly, so an 8-bit
    # integer image cannot wrap 255 + 1 around to 0 before the logarithm
    log_image = np.log1p(norm_img)

    # rescale so the brightest pixel maps back to 255; a blank image has a
    # peak of 0 and is left as zeros instead of turning into inf * 0 = NaN
    peak = log_image.max()
    if peak > 0:
        log_image = ((2**8 - 1) / peak) * log_image

    return torch.from_numpy(log_image)

def process_imc_image(file_name):
    """Load one IMC channel image, rescale it to 8 bits and log-compress it.

    Args:
        file_name: path of the image file, read with ``cv.imread`` using
            ``cv.IMREAD_UNCHANGED``.

    Returns:
        A ``torch.Tensor`` built from the processed array.  The values are
        floating point in ``[0, 255]``; they are not cast back to an
        integer type.

    Raises:
        OSError: if OpenCV could not read ``file_name``.
    """
    # read image file
    img = cv.imread(file_name, cv.IMREAD_UNCHANGED)
    if img is None:
        # cv.imread reports failure by returning None rather than raising
        raise OSError('could not read image file: {}'.format(file_name))

    # min-max normalize into [0, 255] (8 bits); the result is written into
    # a copy so the loaded image itself is left untouched
    norm_img = img.copy()
    cv.normalize(img, norm_img, alpha=0, beta=2**8 - 1, norm_type=cv.NORM_MINMAX)

    # log transform: log1p takes the pixel values directly, so an 8-bit
    # integer image cannot wrap 255 + 1 around to 0 before the logarithm
    log_image = np.log1p(norm_img)

    # rescale so the brightest pixel maps back to 255; a blank channel has a
    # peak of 0 and is left as zeros instead of turning into inf * 0 = NaN
    peak = log_image.max()
    if peak > 0:
        log_image = ((2**8 - 1) / peak) * log_image

    return torch.from_numpy(log_image)

def save_imc(phys_sec, grid, row, column):
    """Save one IMC tile to ``processed_data/IMC/<section>/<row>_<column>.pt``.

    Args:
        phys_sec: physical section number; zero-padded to 2 digits for the
            folder name.
        grid: nested sequence of tensors indexed as ``grid[row][column]``.
        row: row index of the tile; zero-padded to 2 digits in the file name.
        column: column index of the tile; zero-padded to 2 digits in the
            file name.
    """
    section_dir = 'processed_data/IMC/{0}'.format(str(phys_sec).zfill(2))

    # makedirs also creates missing parent folders and does not fail when
    # the section folder already exists
    os.makedirs(section_dir, exist_ok=True)

    tile_path = '{0}/{1}_{2}.pt'.format(section_dir,
                                        str(row).zfill(2),
                                        str(column).zfill(2))
    # the tile is stored as double (see the speed note in the file header)
    torch.save(grid[row][column].double(), tile_path)
    
def save_stpt(phys_sec, grid, row, column):
    """Save one STPT tile to ``processed_data/STPT/<section>/<row>_<column>.pt``.

    Args:
        phys_sec: physical section number; zero-padded to 2 digits for the
            folder name.
        grid: nested sequence of tensors indexed as ``grid[row][column]``.
        row: row index of the tile; zero-padded to 2 digits in the file name.
        column: column index of the tile; zero-padded to 2 digits in the
            file name.
    """
    section_dir = 'processed_data/STPT/{0}'.format(str(phys_sec).zfill(2))

    # makedirs also creates missing parent folders and does not fail when
    # the section folder already exists
    os.makedirs(section_dir, exist_ok=True)

    tile_path = '{0}/{1}_{2}.pt'.format(section_dir,
                                        str(row).zfill(2),
                                        str(column).zfill(2))
    # clone before saving: per the author's notes, saving the tile without
    # a clone writes the entire grid to disk
    torch.save(grid[row][column].clone().double(), tile_path)

In [None]:
def process_imc_images(phys_sec, grid_size=256):
    '''
    Tile the multi-channel IMC image of one physical section and save the tiles.

    phys_sec = physical section where the image came from
    grid_size = how large each cropped image will be

    1. get image paths corresponding to phys_sec (sorted by file name)
    2. concatenate IMC images within folder to form a single multi-channel
       IMC image (40 channels expected)
    3. crop 16 pixels from each side (assumes 18720x18720 inputs)
    4. crop images into grid_size x grid_size squares
    5. drop the references to the intermediate tensors
    6. save processed tensors sequentially

    Raises FileNotFoundError if the section folder holds no .tif image.
    '''

    # ====== GET IMAGE PATHS ======

    imc_section_folder = os.path.join('../data/IMC/',
                                      'SECTION_{}'.format(str(phys_sec).zfill(2)))

    # os.listdir returns names in arbitrary order; sort the .tif paths so
    # the channel order of the stacked image is reproducible
    imc_img_paths = sorted(os.path.join(imc_section_folder, imc_img_name)
                           for imc_img_name in os.listdir(imc_section_folder)
                           if imc_img_name.endswith('.tif'))
    if not imc_img_paths:
        raise FileNotFoundError(
            'no .tif images found in {}'.format(imc_section_folder))

    # ====== LOAD IMAGES ======
    # one worker task per channel image; imap keeps the input order
    with Pool(maxtasksperchild=100) as p:
        imc_imgs = list(p.imap(process_imc_image, imc_img_paths))

    imc_imgs = [torch.unsqueeze(img, 0) for img in imc_imgs]  # add an extra dimension for channel
    imc_imgs_cat = torch.cat(imc_imgs, 0)  # expected (40, 18720, 18720)

    cropped = imc_imgs_cat[:, 16:18704, 16:18704]  # crop 16 pixels from each side -> (40, 18688, 18688)

    # ====== CONSTRUCT GRID ======

    temp = torch.split(cropped, grid_size, dim=1)  # row slices of height grid_size
    grid = [torch.split(curr, grid_size, dim=2) for curr in temp]  # grid[row][column] tiles

    # ====== FREE UP MEMORY ======

    del imc_imgs
    del imc_imgs_cat
    del cropped
    del temp

    # ====== SAVE PROCESSED TENSORS ======
    for i, grid_row in enumerate(grid):
        for j in range(len(grid_row)):
            save_imc(phys_sec, grid, i, j)

    print('IMC: Done physical section:', phys_sec)

In [None]:
def process_stpt_images(phys_sec, grid_size=256):
    '''
    Cut the STPT image of one physical section into square tiles and save them.

    phys_sec = physical section where the image came from
    grid_size = side length of each cropped tile

    Steps:
    1. build the paths of the two optical sections (Z00 and Z01) of phys_sec
    2. load both images and join them along the channel axis
       (8 channels expected)
    3. resize to the IMC size, then crop 16 pixels from each side
    4. split the result into grid_size x grid_size squares
    5. drop the references to the intermediate tensors
    6. save the tiles one by one
    '''

    # ====== GET IMAGE PATHS ======
    section_tag = str(phys_sec).zfill(3)
    stpt_img_paths = []
    for optical_section in ('0', '1'):
        tif_name = 'S{0}_Z{1}.tif'.format(section_tag, optical_section.zfill(2))
        stpt_img_paths.append(os.path.join('../data/STPT/', tif_name))

    # ====== LOAD IMAGES ======
    # images are loaded one after the other; each is permuted to (C, H, W)
    channel_first = [process_stpt_image(tif_path).permute((2, 0, 1))
                     for tif_path in stpt_img_paths]
    stacked = torch.cat(channel_first, 0)  # expected (8, 20800, 20800)
    del channel_first

    # bring the STPT image to the IMC size, then trim the 16-pixel border
    stacked = transforms.Resize(18720)(stacked)  # (..., 18720, 18720)
    trimmed = stacked[:, 16:18704, 16:18704]  # (8, 18688, 18688)

    # ====== CONSTRUCT GRID ======
    row_bands = torch.split(trimmed, grid_size, dim=1)  # bands of height grid_size
    grid = [torch.split(band, grid_size, dim=2) for band in row_bands]  # grid[row][column] tiles

    # ====== FREE UP MEMORY ======
    del stacked
    del trimmed
    del row_bands

    # ====== SAVE PROCESSED TENSORS ======
    for row_idx in range(len(grid)):
        for col_idx in range(len(grid[0])):
            save_stpt(phys_sec, grid, row_idx, col_idx)

    print('STPT: Done physical section:', phys_sec)

In [None]:
# main function

# physical sections 2 through 18 are processed, one per iteration
for i in range(2, 19):
    # physical section 16 is deprecated, so it is left out
    if i != 16:
        # IMC processing is currently switched off:
        # process_imc_images(i)
        process_stpt_images(i)