In [12]:
import numpy as np
import skimage.io
import torch
from torchvision.transforms import ToTensor, Lambda, RandomCrop
from PIL import Image
import pandas as pd
import torchvision.transforms as T
import matplotlib.pyplot as plt

from scipy.ndimage import gaussian_filter
from scipy.ndimage.measurements import center_of_mass

In [13]:
### FOR PREPROCESSING ####

In [18]:
def set_rand_seed(seed=10):
    '''
    Creates random seed for each library so that randomness is repeatable. Initialized first to set all randomness
    '''
    np.random.seed(seed)
    torch.manual_seed(seed)

def save(sample, mask, name, i, sample_name, sample_address, orig_name, cropped):
    """
    Saves file in sim_data folder with iterating number to correspond to number of samples desired as pytorch tensors
    """

    if cropped:
        for j in range (len(sample)):
            input_tensor = torch.tensor(sample[j])
            mask_tensor = torch.tensor(mask[j])
            
            file_name = f"/nsls2/users/maire1/unet/data/labeled_images/cropped_data/img{i}cropped{j}_pp.pt"
            torch.save({"input": input_tensor, "target": mask_tensor}, file_name)
            sample_name.append(f"img{i}cropped{j}_pp.pt")
            sample_address.append(file_name)

    else:
        input_tensor = torch.Tensor(sample)
        mask_tensor = torch.Tensor(mask)
        
        file_name = f"/nsls2/users/maire1/unet/data/labeled_images/real_data/img{i}_pp.pt"
        torch.save({"input": input_tensor, "target": mask_tensor}, file_name)
        sample_name.append(f"img{i}_pp.pt")
        sample_address.append(file_name)
    orig_name.append(name)

    return(sample_name, sample_address, orig_name)

def save_total_data(name, address, orig_name, cropped):
    """
    Saving titles and sample addresses into a separate csv file for use in the neural network.
    
    TODO: Save as a torch tensor in the future. 
    """

#     d = {'sample': name, 'address': address, 'original file name': orig_name[:len(address)]}
#     df = pd.DataFrame(data=d)
    print(len(name), len(address), len(orig_name))
    if cropped:
        d = {'sample': name, 'address': address}
        filename = '/nsls2/users/maire1/unet/data/labeled_images/cropped_data/cropped_img_address_pp.csv'
    else:
        d = {'sample': name, 'address': address, 'original file name': orig_name[:len(address)]}
        filename = '/nsls2/users/maire1/unet/data/labeled_images/real_data/img_address_pp.csv'
    df = pd.DataFrame(data=d)
    df.to_csv(filename, index = False)
    print("All samples completed. Data saved.")

In [19]:
def crop(orig_img, target, sample_num):
    set_rand_seed()
    pt_img = torch.tensor(orig_img)
    pt_target = torch.tensor(target)
    cropper = T.RandomCrop(size=(256, 256))
    crops = [cropper(pt_img) for j in range(sample_num)]
    set_rand_seed()
    mask_crops = [cropper(pt_target) for j in range(sample_num)]
    #plot(pt_img, crops)
    #plot(pt_target, mask_crops)
    return crops, mask_crops

def plot(orig_img, imgs, with_orig=True, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
    


In [20]:
def resize(x, n):
    xr = np.zeros((x.shape[0]//n,x.shape[1]//n ))
    for j in range(0, x.shape[0]-n, n):
        for i in range(0, x.shape[1]-n, n):
            xr[j//n, i//n] = np.nanmean(x[j:j+n, i:i+n])
    return xr

# I haven't checked it yet. But the idea should be to selec the nearest label
# that is not a background
def resize_target(x, n): 
    xr = np.zeros((x.shape[0]//n,x.shape[1]//n ))
    for j in range(0, x.shape[0]-n, n):
        for i in range(0, x.shape[1]-n, n):
            xr[j//n, i//n] = np.nanmax(x[j:j+n, i:i+n])
    return xr

def transform(img):
    # transformation of CHX data
    img[img>2000] = 0
    img[255:260, :] = 0
    img[805:810, :] = 0
    img[1357:1361, :] = 0
    img[1908:1912, :] = 0
    img[920:950, 1150:] = 0

    x_c, y_c = center_of_mass(img)

    # add a random dispacement if needed
    # maybe, leave it for later for data augmentation
    # x_c += np.random.randint(-100, 100)
    # y_c += np.random.randint(-100, 100

    size = 512
    x_c = np.max(int(x_c)-size//2, 0)
    y_c = np.max(int(y_c)-size//2, 0)
    #print(x_c, y_c)
    #plt.figure(dpi = 300)
    cropped_x = img[x_c:x_c+size, y_c:y_c+size]
    #plt.imshow(cropped_x , vmin = 0, vmax =5, origin = 'lower')
    
    y = resize(cropped_x, 2)
    y_smooth = gaussian_filter(y, sigma=2)
    y = gaussian_filter(y, sigma=0.3)
    y = (y - y_smooth.min())/(y_smooth.max() - y_smooth.min())
    
    return y

In [21]:
datafile = '/nsls2/users/maire1/unet/data/labeled_images/real_data/data_address.csv'
file = pd.read_csv(datafile)
partition = file['address']# IDs
labels = file['sample']# Labels
name = file['original file name']
sample_num = 10

sample_name = []
cropped_sample_name = []
sample_address = []
cropped_sample_address = []
orig_name = []

size = (256,256)

for i in range (len(partition)):
    #img = torch.load(f'/nsls2/users/maire1/unet/data/unet/real_data/{label}')
    img = np.load(f'/nsls2/users/maire1/unet/data/labeled_images/images_chx/{name[i]}.npy')
    img_mask = skimage.io.imread(f'/nsls2/users/maire1/unet/data/labeled_images/labeled_images_chx/{name[i]}_labeled.tif')
    
    img = transform(img)
    img_mask = transform(img_mask)
    
    #orig_img = img['input']
    #target = img['target']
    
    sample_name, sample_address, orig_name = save(img, img_mask, name, i, sample_name, sample_address, orig_name, False)
    crops, mask_crops = crop(img, img_mask, sample_num)
    cropped_sample_name, cropped_sample_address, orig_name = save(crops, mask_crops, name, i, cropped_sample_name, cropped_sample_address, orig_name, True)

    
save_total_data(sample_name, sample_address, orig_name, False)
save_total_data(cropped_sample_name, cropped_sample_address, orig_name, True)

  input_tensor = torch.tensor(sample[j])
  mask_tensor = torch.tensor(mask[j])


19 19 38
All samples completed. Data saved.
190 190 38
All samples completed. Data saved.
