In [14]:

import os
import napari
import numpy as np
from tifffile import imwrite
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, Dataset

from tifffile import imread, imwrite
from skimage import exposure
from pyMSDtorch.core.networks import MSDNet, TUNet
from pyMSDtorch.core import helpers, train_scripts, custom_optimizers, custom_losses, corcoef

import glob
import qlty
import einops
from PIL import Image
from qlty import qlty2D
from tqdm import tqdm

### Helper functions

In [15]:
def display(array1, array2, n=7):
    """
    Show ``n`` randomly chosen image pairs from two index-aligned image stacks.

    The top row shows ``array1[k]`` and the bottom row ``array2[k]`` for the
    same randomly drawn indices ``k``, so both stacks must be indexed the same
    way along axis 0. Images are drawn in grayscale with a fixed [0, 1] range.

    Parameters
    ----------
    array1, array2 : array-like
        Image stacks indexed along axis 0; each element must be renderable by
        ``imshow``.
    n : int, optional
        Number of pairs to show (default 7). Indices are distinct unless fewer
        than ``n`` images are available, in which case repeats are allowed.
    """
    n_available = len(array1)
    # Sample without replacement when possible so no panel is duplicated.
    indices = np.random.choice(n_available, size=n, replace=n_available < n)
    print('The indices of the images are ', indices)

    # squeeze=False keeps a 2-D axes grid even when n == 1.
    fig, axes = plt.subplots(2, n, figsize=(50, 20), squeeze=False)
    for col, idx in enumerate(indices):
        for row, stack in enumerate((array1, array2)):
            ax = axes[row, col]
            # cmap per call instead of plt.gray(), which changes the global default.
            ax.imshow(stack[idx], vmin=0, vmax=1, cmap='gray')
            ax.set_axis_off()

    plt.show()
    
    
def regression_metrics(preds, target):
    """Return ``corcoef.cc`` of the flattened CPU copies of ``preds`` and ``target``.

    Both arguments must support ``.cpu()`` and ``.flatten()`` (torch tensors).
    The metric itself is defined by ``pyMSDtorch.core.corcoef.cc``; the call
    site in ``segment_imgs`` labels it a Pearson correlation.
    """
    flat_preds = preds.cpu().flatten()
    flat_target = target.cpu().flatten()
    return corcoef.cc(flat_preds, flat_target)


def segment_imgs(testloader, net, device=None):
    """Run ``net`` over every batch of ``testloader`` (inputs only, no ground truth).

    Parameters
    ----------
    testloader : iterable
        Yields batches whose element 0 is the input tensor, e.g. a
        ``DataLoader`` over a single-tensor ``TensorDataset``.
    net : torch.nn.Module
        Network to apply. It is put in eval mode for the duration of the call
        (inference behaviour for any dropout / batch-norm layers). Its previous
        training flag is restored afterwards.
    device : torch.device or str, optional
        Device to run on. Defaults to the device of the network's first
        parameter, or the CPU if it has none. Previously this silently
        depended on a global ``device`` variable defined in a later cell.

    Returns
    -------
    (outputs, inputs)
        Network outputs and the matching inputs, each concatenated along dim 0
        on the CPU. If the loader yields no batches, two empty lists are
        returned, as before.
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    torch.cuda.empty_cache()
    outputs, inputs = [], []
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(testloader):
                noisy = batch[0].float().to(device)
                output = net(noisy)
                # Collect per-batch results and concatenate once at the end:
                # torch.cat inside the loop re-copied all previous results.
                outputs.append(output.detach().cpu())
                inputs.append(noisy.detach().cpu())
                del output, noisy
    finally:
        net.train(was_training)
        torch.cuda.empty_cache()

    if not outputs:
        return [], []
    return torch.cat(outputs, 0), torch.cat(inputs, 0)

def create_network(model_type, params):
    """Instantiate the network architecture named by ``model_type``.

    Parameters
    ----------
    model_type : str
        One of ``"SMSNet"``, ``"MSDNet"``, ``"MSDNet_partialconv"``,
        ``"UNet"``, ``"TUNet"`` or ``"TUNet_Eric"``.
    params : dict
        Keyword arguments forwarded to the architecture's constructor.

    Returns
    -------
    (net, model_params)
        The network and the parameters describing it. For ``SMSNet`` these are
        read back from the constructed network. For every other type,
        ``params`` is returned unchanged. ``(None, None)`` is returned for an
        unrecognised ``model_type``.

    NOTE(review): only ``MSDNet`` and ``TUNet`` appear in the notebook's import
    cell. The ``SMSNet``, ``UNet`` and ``TUNet_Eric`` branches will raise
    ``NameError`` unless those modules are imported elsewhere.
    """
    def _sms(kwargs):
        # The describing parameters come from the constructed SMS network itself.
        sms = SMSNet.random_SMS_network(**kwargs)
        described = {}
        for attr in ("in_channels", "out_channels", "in_shape", "out_shape",
                     "scaling_table", "network_graph", "channel_count",
                     "convolution_kernel_size", "first_action",
                     "hidden_action", "last_action"):
            described[attr] = getattr(sms, attr)
        return sms, described

    # (name, builder) pairs, compared in order with ``==``; module names are
    # looked up lazily so unused architectures never need to be importable.
    builders = (
        ("SMSNet", _sms),
        ("MSDNet", lambda kw: (MSDNet.MixedScaleDenseNetwork(**kw), kw)),
        ("MSDNet_partialconv", lambda kw: (MSDNet.MixedScaleDenseNetwork(**kw), kw)),
        ("UNet", lambda kw: (UNet.UNet(**kw), kw)),
        ("TUNet", lambda kw: (TUNet.TUNet(**kw), kw)),
        ("TUNet_Eric", lambda kw: (TUNet_Eric.TUNet(**kw), kw)),
    )
    for name, build in builders:
        if model_type == name:
            return build(params)
    return None, None

### Load data 

Data has been saved using the parse_and_save notebook, which simply recasts data as numpy arrays/tiff files.

Data was saved at the original resolution and also downsampled by factors of 2 and 4. Here we load the original-resolution images.

### Prep data for network ingestion

We cast the data as a tensor and load it into a PyTorch `DataLoader`.

In [None]:
folder = '/data/FIBSEM/nuclie_align_ROI1_membrane_1700/nuclie01'
# Number of .jpg slices in the folder. The loading cell reads them by
# zero-padded index (00000.jpg, 00001.jpg, ...), so numbering must be contiguous.
files = sum(1 for name in os.listdir(folder) if name.endswith('.jpg'))
print(files)

In [17]:
# Read slices 0..files-1 in order and convert each to float32.
slices = []
for index in range(files):
    with Image.open(f'{folder}/{index:05}.jpg') as img:
        slices.append(np.array(img, dtype='float32'))

# Stack along a new leading axis and insert a singleton axis at position 1
# (presumably (n_slices, 1, H, W) for grayscale JPEGs — confirm with the shape cell).
test_imgs = np.expand_dims(np.array(slices), axis=1)

In [None]:
np.shape(test_imgs)  # sanity check of the stacked input volume's dimensions

In [20]:
batch_size = 1
num_workers = 0    #increase to 1 or 2 with multiple GPUs

# Loader over the whole stack in slice order. The segmentation cells below
# build their own per-chunk loaders rather than using this one.
test_loader_params = dict(
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=False,
)
test_data = TensorDataset(torch.Tensor(test_imgs))
test_loader = DataLoader(test_data, **test_loader_params)

### Instantiate network and load trained parameters

In [None]:
maindir = '/data/FIBSEM/napari_venv/training_scripts/networks_output/tunet3_1700/tunet3'
patterndir = '/data/FIBSEM/napari_venv/training_scripts/networks_output/tunet3_1700_patterns/'

# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
# parameter files from a trusted source.
# Element 0 of each saved array is what gets passed to create_network as
# constructor kwargs (presumably a dict; its type is printed below).
params = np.load(maindir + '/params.npy', allow_pickle=True)[0]
params_patterns = np.load(patterndir + '/params.npy', allow_pickle=True)[0]

print(type(params))
print('The following define the network parameters: ', params)
print('The following define the network parameters for patterns: ', params_patterns)

In [22]:
# `mem` must still hold this folder path when the network-loading cell runs.
mem = '/data/FIBSEM/napari_venv/training_scripts/networks_output/tunet3_membrane/'
params_mem = np.load(mem + '/params.npy', allow_pickle=True)[0]

In [None]:
model_type = 'TUNet'  # set to 'MSDNet' only when loading an MSDNet output folder

# Initialize the network architectures from their saved parameters.
net, model_params = create_network(model_type, params)
net_patterns, model_params_patterns = create_network(model_type, params_patterns)
net_mem, model_params_mem = create_network(model_type, params_mem)
if net is None or net_patterns is None or net_mem is None:
    raise ValueError(f'create_network does not recognise model_type {model_type!r}')

# Load the trained weights onto the CPU. Without map_location, tensors saved
# from a GPU are restored onto that same GPU index, which fails on machines
# without it. The networks are moved to the compute device in a later cell.
net.load_state_dict(torch.load(maindir + '/net_best', map_location='cpu'))
net_mem.load_state_dict(torch.load(mem + '/net_best', map_location='cpu'))
net_patterns.load_state_dict(torch.load(patterndir + '/net_best', map_location='cpu'))

### Segment and save

The device is either the CPU or a CUDA GPU; this notebook uses the second GPU (`cuda:1`). The last line of the cell below also displays a summary of the membrane network's architecture.

In [None]:
# Prefer the second GPU (cuda:1) when it exists; otherwise fall back to the
# library's default choice instead of crashing on machines with fewer GPUs.
device = 'cuda:1' if torch.cuda.device_count() > 1 else helpers.get_device()
print('Device we compute on: ', device)

# Inference only: eval() puts any dropout / batch-norm layers in inference mode.
net.to(device).eval()
net_patterns.to(device).eval()
net_mem.to(device).eval()

In [None]:
batch_size = 1
num_workers = 0    #increase to 1 or 2 with multiple GPUs
nucleoluspath = folder+'/nucleolus/'
chromopath = folder+'/chromosome/'
chunk_size = 100   # slices segmented per iteration, to bound memory use

# Create each output folder independently. Previously the chromosome folder was
# only created when the nucleolus folder was missing. exist_ok makes re-runs safe.
os.makedirs(nucleoluspath, exist_ok=True)
os.makedirs(chromopath, exist_ok=True)

test_loader_params = {'batch_size': batch_size,
                      'shuffle': False,
                      'num_workers': num_workers,
                      'pin_memory': True,
                      'drop_last': False}

# Inference only: put any dropout / batch-norm layers in inference mode.
net.eval()

# NOTE(review): JPEG is lossy, so the zero background around masked regions
# picks up compression artifacts. A lossless format (PNG/TIFF) would keep the
# masks exact; .jpg is kept because downstream cells read *.jpg.
n_slices = len(test_imgs)  # not `files`, which a later cell reassigns to a list
for i in range(0, n_slices, chunk_size):
    test_data = TensorDataset(torch.Tensor(test_imgs[i:i+chunk_size]))
    test_loader = DataLoader(test_data, **test_loader_params)
    output, input_imgs = segment_imgs(test_loader, net)
    output = torch.squeeze(output, 1)
    input_imgs = torch.squeeze(input_imgs, 1)
    # Predicted class per pixel: argmax over dim 1 of the network output.
    tunet3_output = torch.argmax(output.cpu(), dim=1)
    imgs, masks = input_imgs.numpy(), tunet3_output.numpy()

    # Keep the input intensities where class 2 (chromosome) / class 3
    # (nucleolus) is predicted; everything else is 0.
    chromosome = np.zeros(imgs.shape)
    chromosome[masks == 2] = imgs[masks == 2]
    nucleolus = np.zeros(imgs.shape)
    nucleolus[masks == 3] = imgs[masks == 3]

    del output, tunet3_output, input_imgs
    torch.cuda.empty_cache()

    for j in range(nucleolus.shape[0]):
        name = f'{i+j:03}.jpg'
        Image.fromarray(nucleolus[j].astype(np.uint8)).save(nucleoluspath+name)
        Image.fromarray(chromosome[j].astype(np.uint8)).save(chromopath+name)



In [None]:
batch_size = 1
num_workers = 0    #increase to 1 or 2 with multiple GPUs
chunk_size = 100   # slices segmented per iteration, to bound memory use

patternpath = folder+'/pattern/'
os.makedirs(patternpath, exist_ok=True)  # created once, safe to re-run

test_loader_params = {'batch_size': batch_size,
                      'shuffle': False,
                      'num_workers': num_workers,
                      'pin_memory': True,
                      'drop_last': False}

# Inference only: put any dropout / batch-norm layers in inference mode.
net_patterns.eval()

# NOTE(review): JPEG output is lossy (see the nucleolus/chromosome cell).
n_slices = len(test_imgs)  # not `files`, which a later cell reassigns to a list
for i in range(0, n_slices, chunk_size):
    test_data = TensorDataset(torch.Tensor(test_imgs[i:i+chunk_size]))
    test_loader = DataLoader(test_data, **test_loader_params)
    output_patterns, input_imgs = segment_imgs(test_loader, net_patterns)
    output_patterns = torch.squeeze(output_patterns, 1)
    input_imgs = torch.squeeze(input_imgs, 1)
    # Predicted class per pixel: argmax over dim 1 of the network output.
    tunet3_output_patterns = torch.argmax(output_patterns.cpu(), dim=1)
    imgs, patterns = input_imgs.numpy(), tunet3_output_patterns.numpy()

    # Keep the input intensities where class 4 is predicted; everything else is 0.
    pattern = np.zeros(patterns.shape)
    pattern[patterns == 4] = imgs[patterns == 4]

    del output_patterns, tunet3_output_patterns, input_imgs
    torch.cuda.empty_cache()

    for j in range(pattern.shape[0]):
        name = f'{i+j:03}.jpg'
        Image.fromarray(pattern[j].astype(np.uint8)).save(patternpath+name)

In [None]:
batch_size = 1
num_workers = 0    #increase to 1 or 2 with multiple GPUs
chunk_size = 100   # slices segmented per iteration, to bound memory use

mempath = folder+'/membrane/'
os.makedirs(mempath, exist_ok=True)  # created once, safe to re-run

test_loader_params = {'batch_size': batch_size,
                      'shuffle': False,
                      'num_workers': num_workers,
                      'pin_memory': True,
                      'drop_last': False}

# Inference only: put any dropout / batch-norm layers in inference mode.
net_mem.eval()

# NOTE(review): JPEG output is lossy (see the nucleolus/chromosome cell).
n_slices = len(test_imgs)  # not `files`, which a later cell reassigns to a list
for i in range(0, n_slices, chunk_size):
    test_data = TensorDataset(torch.Tensor(test_imgs[i:i+chunk_size]))
    test_loader = DataLoader(test_data, **test_loader_params)
    output, input_imgs = segment_imgs(test_loader, net_mem)
    output = torch.squeeze(output, 1)
    input_imgs = torch.squeeze(input_imgs, 1)
    # Predicted class per pixel: argmax over dim 1 of the network output.
    tunet3_output = torch.argmax(output.cpu(), dim=1)
    # Do not reuse the name `mem` here: that global holds the membrane model
    # folder path used when loading the network weights.
    imgs, mem_labels = input_imgs.numpy(), tunet3_output.numpy()

    # Keep the input intensities where class 2 is predicted; everything else is 0.
    membrane = np.zeros(mem_labels.shape)
    membrane[mem_labels == 2] = imgs[mem_labels == 2]

    del output, tunet3_output, input_imgs
    torch.cuda.empty_cache()

    for j in range(membrane.shape[0]):
        name = f'{i+j:03}.jpg'
        Image.fromarray(membrane[j].astype(np.uint8)).save(mempath+name)

In [56]:
# NOTE: this reassigns `folder`; earlier cells that use it now point at ROI2.
folder = '/data/FIBSEM/nuclie_align_ROI2_1700/nuclie02/full_nuclie/chromosome'
jpg_names = [name for name in os.listdir(folder) if name.endswith('.jpg')]
files = len(jpg_names)
print(files)

1306


In [None]:
import glob
from os import walk


def _slice_index(path):
    """Numeric slice index parsed from a file name such as '.../123.jpg'."""
    return int(os.path.splitext(os.path.basename(path))[0])


# Slice files in numeric (not lexicographic) order.
files = sorted(glob.glob(f'{folder}/*.jpg'), key=_slice_index)

# Rebuild a label volume: every non-zero pixel of a masked slice gets label 2,
# everything else 0. uint8 is enough for these labels (the old float64 volume
# used 8x the memory) and avoids shadowing the builtin `chr`.
# NOTE(review): the slices are JPEGs, so compression artifacts around masked
# regions also count as non-zero here; a small threshold may be more faithful.
test_imgs = []
for file in files:
    with Image.open(file) as img:
        slice_arr = np.array(img, dtype='float32')
    test_imgs.append(np.where(slice_arr != 0, 2, 0).astype(np.uint8))

# Stack into one (n_slices, ...) ndarray for 3-D connected components.
test_imgs = np.stack(test_imgs)

In [None]:
import cc3d

# Keep the 6 largest 26-connected components of the label volume. The input is
# converted to an ndarray explicitly because it may still be a list of slices.
labels_out, N = cc3d.largest_k(np.asarray(test_imgs), k=6, connectivity=26, delta=10, return_N=True)

# Save before opening the GUI so the result is written even if the viewer fails.
imwrite(f'{folder}/chromosomes_split_12.tif', labels_out)
viewer = napari.view_image(labels_out, name='nucl1')