#### Install required libraries 

In [None]:
# Install the packages this notebook imports. torch, scikit-image, scikit-learn,
# pandas, tqdm and matplotlib are not installed here; this cell assumes the
# environment already provides them.
!pip install numpy
!python3 -m pip install --upgrade Pillow
# The notebook imports `tifffile` (three f's), so install that package name.
!pip install tifffile
!pip install qlty
!pip install opencv-python
!pip install connected-components-3d
!pip install rdp
!git clone https://github.com/phzwart/dlsia.git
# Editable install of dlsia; a kernel restart may be needed before `import dlsia` works.
!cd dlsia && pip install -e .

In [None]:
import os
import cv2
import glob
import cc3d 
import csv
import random
import rdp 
import numpy as np
import pandas as pd
import collections
from tqdm import tqdm

from PIL import Image
import matplotlib.pyplot as plt
from tifffile import imread, imwrite


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, Dataset

from skimage import exposure,morphology
from sklearn.model_selection import train_test_split

import qlty
from qlty import qlty2D

from dlsia.core import helpers, train_scripts, corcoef
from dlsia.core.networks import msdnet, tunet


### Helper Functions

In [None]:
def display(array1, array2, n=7):
    """
    Show ``n`` randomly chosen images from two aligned arrays.

    The same random indices are used for both arrays. The top row shows
    ``array1[idx]`` and the bottom row shows ``array2[idx]``, column by column.

    Parameters
    ----------
    array1, array2 : numpy.ndarray
        Stacks indexed along their first axis. Each selected element is drawn
        with ``plt.imshow`` using a fixed display range of [0, 1] in grayscale.
    n : int, optional
        Number of images to show (default 7). It is clipped to
        ``len(array1)``, so small stacks never show an image twice.
    """
    n = min(n, len(array1))
    # Sample without replacement so the same image is never picked twice.
    indices = np.random.choice(len(array1), size=n, replace=False)
    print('The indices of the images are ', indices)
    images1 = array1[indices, :]
    images2 = array2[indices, :]
    plt.figure(figsize=(50, 20))

    for i, (image1, image2) in enumerate(zip(images1, images2)):
        # Row 0 holds array1's image, row 1 the matching image of array2.
        for row, image in enumerate((image1, image2)):
            ax = plt.subplot(2, n, i + 1 + row * n)
            plt.imshow(image, vmin=0, vmax=1, cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    plt.show()
    
def regression_metrics(preds, target):
    """
    Return ``corcoef.cc`` evaluated on flattened CPU copies of two tensors.

    Both arguments are moved to the CPU and flattened before being handed to
    ``dlsia.core.corcoef.cc``; the metric itself is defined there.
    """
    flat_preds = preds.cpu().flatten()
    flat_target = target.cpu().flatten()
    return corcoef.cc(flat_preds, flat_target)


def segment_imgs(testloader, net, device=None):
    """
    Run ``net`` over every batch of ``testloader`` and collect the results.

    There is no ground truth. Each batch is a sequence (as produced by
    ``TensorDataset(inputs)``) whose first element is the network input.

    Parameters
    ----------
    testloader : iterable
        Yields batches whose element ``[0]`` is the input tensor.
    net : torch.nn.Module
        Network to evaluate. It is put in ``eval()`` mode for the duration of
        the call, and its previous train/eval mode is restored afterwards.
    device : torch.device or str, optional
        Where to run inference. It defaults to the device of ``net``'s first
        parameter, so the inputs always land where the weights live.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        Network outputs and the matching inputs, concatenated along dim 0 and
        stored on the CPU. If ``testloader`` yields nothing, two empty lists
        are returned instead.
    """
    if device is None:
        device = next(net.parameters()).device
    torch.cuda.empty_cache()

    outputs = []
    inputs = []
    was_training = net.training
    # BatchNorm/Dropout layers must be in eval mode for inference.
    net.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(testloader):
                noisy = batch[0].float().to(device)
                output = net(noisy)
                outputs.append(output.detach().cpu())
                inputs.append(noisy.detach().cpu())
                del output
                del noisy
                torch.cuda.empty_cache()
    finally:
        net.train(was_training)

    if not outputs:
        return [], []
    # Concatenate once instead of re-copying the growing result on every batch.
    return torch.cat(outputs, 0), torch.cat(inputs, 0)

def create_network(model_type, params):
    """
    Instantiate a dlsia network of the requested type.

    Parameters
    ----------
    model_type : str
        One of ``"SMSNet"``, ``"MSDNet"`` or ``"TUNet"``.
    params : dict
        Keyword arguments forwarded to the network constructor.

    Returns
    -------
    (torch.nn.Module, dict)
        The network and a parameter dict describing it. For MSDNet and TUNet
        this is ``params`` itself. For SMSNet it is read back from the
        constructed network's attributes.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported names.
    """
    if model_type == "SMSNet":
        # NOTE(review): the original referenced an undefined name `SMSNet`.
        # dlsia exposes this constructor in `dlsia.core.networks.smsnet`;
        # confirm this against the installed dlsia version.
        from dlsia.core.networks import smsnet
        net = smsnet.random_SMS_network(**params)
        model_params = {
            "in_channels": net.in_channels,
            "out_channels": net.out_channels,
            "in_shape": net.in_shape,
            "out_shape": net.out_shape,
            "scaling_table": net.scaling_table,
            "network_graph": net.network_graph,
            "channel_count": net.channel_count,
            "convolution_kernel_size": net.convolution_kernel_size,
            "first_action": net.first_action,
            "hidden_action": net.hidden_action,
            "last_action": net.last_action,
        }
        return net, model_params
    if model_type == "MSDNet":
        return msdnet.MixedScaleDenseNetwork(**params), params
    if model_type == "TUNet":
        return tunet.TUNet(**params), params
    # Fail loudly; returning (None, None) only surfaced later as an AttributeError.
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected 'SMSNet', 'MSDNet' or 'TUNet'."
    )

def save_stack(imgx, imgy, imgz, d_type=None, return_stacks=True):
    """
    Load three parallel lists of image files into three NumPy stacks.

    Despite its name, this function only reads files; nothing is written.

    Parameters
    ----------
    imgx, imgy, imgz : sequence of str or path-like
        Image file paths. ``imgx[j]``, ``imgy[j]`` and ``imgz[j]`` are loaded
        together, so all three sequences must have the same length.
    d_type : numpy dtype, optional
        Passed to ``np.array``. ``None`` keeps the dtype PIL produces.
    return_stacks : bool, optional
        When false, the stacks are built but ``None`` is returned.

    Returns
    -------
    tuple of numpy.ndarray or None
        ``(imgx_stack, imgy_stack, imgz_stack)``, with one entry per path
        along axis 0.

    Raises
    ------
    ValueError
        If the three path sequences have different lengths.
    """
    if not len(imgx) == len(imgy) == len(imgz):
        raise ValueError(
            f"imgx, imgy and imgz must have the same length, got "
            f"{len(imgx)}, {len(imgy)} and {len(imgz)}"
        )

    def _load(path):
        # Context manager so the file handle is released deterministically.
        with Image.open(path) as img:
            img.load()
            # dtype=None is np.array's default, so one call covers both cases.
            return np.array(img, dtype=d_type)

    stacks = ([], [], [])
    for paths in tqdm(zip(imgx, imgy, imgz), total=len(imgx)):
        for stack, path in zip(stacks, paths):
            stack.append(_load(path))

    imgx_stack, imgy_stack, imgz_stack = (np.array(s) for s in stacks)
    if return_stacks:
        return imgx_stack, imgy_stack, imgz_stack

In [None]:
# Directory holding the JPEG images to segment.
source = '/data/FIBSEM/testing_scripts/sourceFile/images'
# Directory of the trained TUNet; it must contain `params.npy` and the saved weights file `net`.
maindir = '/data/FIBSEM/testing_scripts/sourceFile/Results/tunet3'

In [None]:
# Collect the JPEG files to segment (.jpg or .jpeg, any letter case) in
# sorted order. Output file names later on are derived from positions in this list.
files = sorted(
    f for f in os.listdir(source)
    if f.lower().endswith(('.jpg', '.jpeg'))
)
print('Number of files to segment: ', len(files))

In [None]:
# Load every selected file as a float32 array and stack the results into an
# (N, 1, H, W) array. This assumes all images share the same size.
test_imgs = []
for file in files:
    with Image.open(os.path.join(source, file)) as pil_img:
        pil_img.load()
        img = np.array(pil_img, dtype='float32')
    # If the images are RGB, uncomment the next line to convert them to grayscale.
    # PIL returns channels in RGB order, so COLOR_RGB2GRAY (not BGR2GRAY) is correct.
    # img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    test_imgs.append(img)
test_imgs = np.expand_dims(np.array(test_imgs), axis=1)

### The following qlty and preprocessing code is the same as in train.ipynb. Any changes made there when training the network must also be made here.

In [None]:
# Tile the (N, 1, Y, X) image stack into overlapping 256x256 windows with a
# 64-pixel step. border=(10, 10) with border_weight=0 presumably excludes the
# window edges when patches are stitched back; see the qlty documentation.
quilt = qlty2D.NCYXQuilt(
    Y=test_imgs.shape[2],
    X=test_imgs.shape[3],
    window=(256, 256),
    step=(64, 64),
    border=(10, 10),
    border_weight=0,
)

def imageSplit(quilt, test_imgs):
    """
    Cut images into patches with ``quilt`` and preprocess every patch.

    Each patch goes through four steps in order:

    1. A bilateral filter (d=5, sigmaColor=50, sigmaSpace=10).
    2. A cast to uint16.
    3. CLAHE with clipLimit=3.
    4. skimage ``exposure.equalize_hist``.

    This preprocessing must stay identical to what was used when training the
    network.

    Parameters
    ----------
    quilt : qlty2D.NCYXQuilt
        Tiler whose ``unstitch`` produces the patches.
    test_imgs : numpy.ndarray
        Image stack, converted to a float tensor before ``quilt.unstitch``.
        In this notebook it is built as (N, 1, Y, X).

    Returns
    -------
    numpy.ndarray
        float32 array of shape (num_patches, 1, patch_Y, patch_X).
    """
    patches = quilt.unstitch(torch.Tensor(test_imgs))
    print("x shape: ", test_imgs.shape)
    print("x_bits shape:", patches.shape)

    # The CLAHE settings never change, so build the operator once, not per patch.
    clahe = cv2.createCLAHE(clipLimit=3)
    diced_imgs = []
    for patch in patches:
        bilateral = cv2.bilateralFilter(patch[0].numpy(), 5, 50, 10)
        # CLAHE only accepts 8- or 16-bit input; the cast truncates fractional values.
        equalized = clahe.apply(bilateral.astype(np.uint16))
        diced_imgs.append(exposure.equalize_hist(equalized).astype(np.float32))
    return np.expand_dims(np.array(diced_imgs), axis=1)

In [None]:
# Load the hyper-parameters saved at training time. allow_pickle=True unpickles
# the file, so only point `maindir` at trusted results directories.
params = np.load(os.path.join(maindir, 'params.npy'), allow_pickle=True)
params = params[0]
print('The following define the network parameters: ', params)

model_type = 'TUNet'
#model_type = 'MSDNet'

net, model_params = create_network(model_type, params)
# map_location='cpu' lets weights saved on any GPU load on this machine.
# The network is moved to the compute device later with net.to(device).
net.load_state_dict(torch.load(os.path.join(maindir, 'net'), map_location='cpu'))

In [None]:
device = helpers.get_device()
# The notebook was run on the second GPU. Only pin 'cuda:1' when that GPU
# exists; otherwise keep the device chosen by helpers.get_device().
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
print('Device we compute on: ', device)
print('Number of parameters: ', helpers.count_parameters(net))
net.to(device)

In [None]:
# folder to save the segmented images and masks 
folder = '/data/FIBSEM/testing_scripts/sourceFile/outputs'
# mapper maps the label and feature. eg 1:"filament" represents label 1 in training which is marked as filament 
# this generates folder filament with  segmented images. {1:"ribosomes",2:"tube",3:"mem"}
output_mapper = {1:"filament"}

# Number of files to process in one batch. reduce or increase this based on available memory. 
file_batch = 5 

In [None]:
out_masks = None
mask_batches = []

for k, v in output_mapper.items():
    # makedirs also creates a missing `{folder}/{v}` parent and tolerates reruns.
    os.makedirs(f'{folder}/{v}/segments', exist_ok=True)

# Range of image indices to segment. Narrow it (e.g. 200, 220) to test on a subset.
first_image, last_image = 0, len(test_imgs)
# Zero-pad file names to a common width so lexicographic sorting (used when
# the segments are read back) matches numeric order.
name_width = max(3, len(str(max(last_image - 1, 0))))

for i in range(first_image, last_image, file_batch):
    imgs = test_imgs[i:i+file_batch]
    dicedtestImgs = imageSplit(quilt, imgs)

    batch_size = file_batch
    num_workers = 0    #increase to 1 or 2 with multiple GPUs
    test_data = TensorDataset(torch.Tensor(dicedtestImgs))
    test_loader_params = {'batch_size': batch_size,
                          'shuffle': False,
                          'num_workers': num_workers,
                          'pin_memory': True,
                          'drop_last': False}
    test_loader = DataLoader(test_data, **test_loader_params)

    output, input_imgs = segment_imgs(test_loader, net)
    # `output` is already a tensor, so no torch.tensor() re-wrap (and copy) is needed.
    stitched_output = quilt.stitch(output)
    # Only the first element returned by stitch is used (the stitched images).
    o = torch.squeeze(stitched_output[0], 1)
    # Per-pixel class = index of the channel with the highest score.
    tunet3_output = torch.argmax(o.cpu(), dim=1)

    masks = tunet3_output.numpy()
    imgs = np.squeeze(imgs, 1)
    mask_batches.append(masks)

    for k, v in output_mapper.items():
        # Keep the original pixel values where the mask equals label k; zero elsewhere.
        idx = (masks == k)
        structures = np.zeros(imgs.shape)
        structures[idx] = imgs[idx]
        out_path = f'{folder}/{v}/segments/'

        for j in range(structures.shape[0]):
            name = f'{i+j:0{name_width}}.jpg'
            print(out_path+name)
            # JPEG is lossy. quality=100 keeps compression noise in the zero
            # background small, because any nonzero pixel is later treated as foreground.
            Image.fromarray(structures[j].astype(np.uint8)).save(out_path+name, quality=100)

    del output
    del tunet3_output
    del input_imgs
    torch.cuda.empty_cache()

if mask_batches:
    # Concatenate once instead of re-stacking the growing array after every batch.
    out_masks = np.concatenate(mask_batches, axis=0)
    imwrite(folder+'/masks.tif', out_masks)
else:
    print('No images were segmented; masks.tif was not written.')

### Clean image segments 
This code removes small objects from the segmented images in each output folder. These small objects are usually false positives.

In [None]:
object_size = 100  # objects smaller than this many pixels are removed


def clean_stack(img_stack, minim):
    """
    Zero out small connected regions in every 2-D slice of a stack.

    Parameters
    ----------
    img_stack : array-like
        Stack of 2-D images indexed along axis 0. It is copied, not modified.
    minim : int
        Minimum object size passed to ``morphology.remove_small_objects``,
        which is called with ``connectivity=1``.

    Returns
    -------
    numpy.ndarray
        Copy of ``img_stack`` in which pixels of removed objects are 0 and
        all other pixels keep their value.
    """
    result = np.copy(img_stack)
    foreground = result != 0
    for idx in tqdm(range(len(result))):
        keep = morphology.remove_small_objects(foreground[idx, :], minim, connectivity=1)
        result[idx, :, :] = np.multiply(result[idx, :, :], keep)
    return result

# For each feature, read its segment images in sorted file-name order, remove
# small objects and write the cleaned stack to `{folder}/{v}/{v}.tiff`.
for k, v in output_mapper.items():
    path = f'{folder}/{v}/segments'
    files = sorted(glob.glob(path + "/*.jpg"))
    if not files:
        print(f'No segment images found in {path}; skipping {v}.')
        continue

    imgs = []
    for seg_file in files:
        # Context manager so each file handle is released right after loading.
        with Image.open(seg_file) as seg:
            seg.load()
            imgs.append(np.array(seg, dtype='float32'))
    imwrite(f'{folder}/{v}/{v}.tiff', clean_stack(imgs, object_size))

### Generate coordinates for subtomogram averaging

In [None]:
def simplify_points(arr, z, epsilon=0.5):
    """
    Reduce the pixels of each labelled object in a 2-D slice to RDP key points.

    Parameters
    ----------
    arr : numpy.ndarray
        2-D label image. Zero pixels are skipped; every other value is
        treated as an object label.
    z : int
        Slice index, copied into every output tuple.
    epsilon : float, optional
        Tolerance passed to ``rdp.rdp_iter``. The default of 0.5 is the value
        that was previously hard-coded.

    Returns
    -------
    list of tuple
        ``(row, col, z, label)`` for every point kept. Points are grouped by
        label, in the order each label is first met in a row-major scan.
    """
    points_by_label = collections.defaultdict(list)
    rows, cols = np.nonzero(arr)
    # arr[rows, cols] fetches all labels in one vectorised lookup.
    for r, c, label in zip(rows, cols, arr[rows, cols]):
        points_by_label[label].append([r, c])

    final_coord = []
    for label, coords in points_by_label.items():
        # NOTE(review): RDP simplifies an ordered polyline, but these points are
        # in raster-scan order rather than traced along the object. Confirm this
        # is the intended simplification.
        simplified = rdp.rdp_iter(np.array(coords), epsilon=epsilon)
        final_coord.extend((p[0], p[1], z, label) for p in simplified)
    return final_coord

# For each feature, label 3-D connected objects in the cleaned stack and write
# the simplified (row, col, z, label) points to `{folder}/{v}/coordinates.csv`.
for k, v in output_mapper.items():
    file = f'{folder}/{v}/{v}.tiff'

    imgs = imread(file)
    if imgs.ndim == 2:
        # A single-slice stack may be read back as 2-D; restore the z axis.
        imgs = imgs[np.newaxis]
    # Binarise: every nonzero pixel of the cleaned stack is foreground.
    binary = (imgs != 0).astype(np.uint8)
    # 6-connectivity: voxels connect through shared faces only.
    labels_out = cc3d.connected_components(binary, connectivity=6)

    # 'w' rather than 'a': re-running the notebook must not append duplicate
    # rows whose label ids collide with those of earlier runs.
    with open(f'{folder}/{v}/coordinates.csv', 'w', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        for z, label_slice in enumerate(labels_out):
            writer.writerows(simplify_points(label_slice, z))

### Simplify co-ordinates 
Keep only one coordinate per feature: one randomly chosen coordinate for each connected filament or ribosome.

In [None]:
# For every feature in output_mapper (not just the last one), pick one random
# coordinate per connected object; `pixel` holds the object label.
# NOTE(review): "x"/"y" hold (row, col) as written by simplify_points. Confirm
# that this axis convention is what the downstream averaging tool expects.
for k, v in output_mapper.items():
    df = pd.read_csv(f'{folder}/{v}/coordinates.csv', names=["x", "y", "z", "pixel"])
    random_points = df.groupby('pixel').sample(n=1)
    random_points.to_csv(f'{folder}/{v}/simplified_coord.csv', index=False)