In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
from oiffile import imread
import random
import glob
import torch
import sys
sys.path.insert(0, '..')
from other import dognet
from other import functions
import skimage
from skimage.transform import resize
import tifffile as tiff
import torch
from tqdm import tnrange

In [None]:
# Load the ensemble of trained networks and the per-slice normalization medians.
N_NETWORKS = 10
NETWORK_PATH_TEMPLATE = "../networks/multipleNetworks-flocculusA/net%s"
MEDIAN_PATH = "../../datasets/flocculusA/normalizationMatrices/medianTotal.npy"
# Each channel's per-slice medians are floored at this fraction of that
# channel's largest median (presumably so the per-slice division in the next
# cell is not dominated by slices with very small medians).
MEDIAN_FLOOR_FRACTION = 0.6


def _load_network(path):
    """Load a network file that (presumably) holds a whole pickled model object.

    From PyTorch 2.6 on, ``torch.load`` defaults to ``weights_only=True``, which
    refuses to unpickle full ``nn.Module`` objects, so ``weights_only=False`` is
    requested explicitly. PyTorch < 1.13 does not know that keyword and raises
    TypeError, in which case the plain call is used.

    Unpickling can execute arbitrary code: only load trusted files.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:  # torch < 1.13: no `weights_only` keyword
        return torch.load(path)


networks = [_load_network(NETWORK_PATH_TEMPLATE % i) for i in range(N_NETWORKS)]
# NOTE(review): the networks are not switched to eval() here; confirm that
# functions.inference does so, otherwise dropout/batch-norm layers (if any)
# would run in training mode.

# load the pre-computed medians (first axis = channel) and floor each channel
medianTotal = np.double(np.load(MEDIAN_PATH))
medianclipped0, medianclipped1, medianclipped2 = (
    np.maximum(channel_medians, np.max(channel_medians) * MEDIAN_FLOOR_FRACTION)
    for channel_medians in medianTotal[:3]
)

In [None]:
# Generate predicted synapse maps for every z-slice of every track.
#
# Per slice: normalize the raw channels by the clipped per-slice medians,
# resample to RESIZED x RESIZED, cut four overlapping tiles, run every network
# on each tile and take the per-pixel median, upsample the tile maps by
# OUT_SCALE, then stitch them keeping the per-pixel maximum where tiles overlap.
# Each track's prediction stack is appended to the raw image as one extra
# channel and, if SAVE_PREDICTIONS is set, written out as a TIFF.
#
# Requires `networks` and `medianclipped0/1/2` from the previous cell.

RAW_DIR_TEMPLATE = r'E:\pcp2cre_syptom_568_mglur1_1to200_647_1to250_vgat_1to200_488_1to250\FV10__20190507_224650_flocculusA/Track%s'
OUT_PATH_TEMPLATE = r'C:\Users\CGuo\Desktop\flocculusA\%s.tif'
SAVE_PREDICTIONS = False       # saving was commented out originally; set True to write one TIFF per track
FIRST_TRACK, LAST_TRACK = 1, 203
N_Z_SLICES = 16
DEVICE = torch.device("cuda")
CHANNEL_SCALES = (20, 25, 25)  # channel c is divided by medianclipped<c>[z] * CHANNEL_SCALES[c]
RESIZED = 400                  # slices are resampled to RESIZED x RESIZED before tiling
TILE_TOL = 30                  # each tile extends TILE_TOL pixels past the slice centre
OUT_SCALE = 2                  # tile prediction maps are upsampled by this factor
UINT16_GAIN = 4000             # prediction values are multiplied by this before the uint16 cast


def _normalize_slice(image, z, medians):
    """Return z-slice ``z`` of ``image`` as a float (RESIZED, RESIZED, C) array.

    ``image[:, z]`` is transposed so the channel axis is last, then channel c
    (for the first len(medians) channels) is divided by
    ``medians[c][z] * CHANNEL_SCALES[c]`` before bilinear resampling.
    """
    slice_ = np.double(image[:, z, :, :].transpose(2, 1, 0))
    for c, (median, scale) in enumerate(zip(medians, CHANNEL_SCALES)):
        slice_[:, :, c] = slice_[:, :, c] / (median[z] * scale)
    return resize(slice_, (RESIZED, RESIZED), order=1, preserve_range=True)


def _predict_tile(tile, side):
    """Ensemble prediction for one (side, side, C) tile.

    Every network is run on the tile (as a (C, rows, cols) array); the outputs
    are combined by a per-pixel median, upsampled to OUT_SCALE*side and the
    first channel is returned with shape (1, OUT_SCALE*side, OUT_SCALE*side).
    """
    net_input = np.transpose(resize(tile, (side, side)), (2, 0, 1))  # -> (C, rows, cols)
    outputs = [
        functions.inference(net, net_input, get_inter=False, device=DEVICE)
        for net in networks
    ]
    median_map = np.median(np.asarray(outputs), axis=0)
    # assumes the networks emit 2 output channels — TODO confirm
    upsampled = resize(median_map, (2, OUT_SCALE * side, OUT_SCALE * side),
                       order=1, preserve_range=True)
    return upsampled[:1, :, :]


def _stitch_max(tiles, origins, side):
    """Place (1, h, w) ``tiles`` at (row, col) ``origins`` on a (1, side, side)
    canvas, keeping the per-pixel maximum wherever tiles overlap."""
    canvas = np.full((1, side, side), -np.inf)
    for tile, (r, c) in zip(tiles, origins):
        region = canvas[:, r:r + tile.shape[1], c:c + tile.shape[2]]
        np.maximum(region, tile, out=region)  # region is a view into canvas
    assert not np.isneginf(canvas).any(), "tiles do not cover the whole canvas"
    return canvas


def _to_uint16(pred_map):
    """Scale a stitched (1, rows, cols) map to uint16 and swap its spatial axes
    back to the orientation of ``image[:, z]``."""
    info = np.iinfo(np.uint16)
    # Clip first: casting negative or >65535 floats to uint16 wraps/is undefined.
    scaled = np.clip(UINT16_GAIN * pred_map, info.min, info.max)
    # np.flip with no axis reverses every axis, so the former flip(flip(...))
    # was an identity; only the transpose has an effect.
    return scaled.astype(np.uint16).transpose(0, 2, 1)


# nn.Module.to() moves parameters in place; do it once rather than per tile.
for net in networks:
    net.to(DEVICE)

medians = (medianclipped0, medianclipped1, medianclipped2)
tile_side = RESIZED // 2 + TILE_TOL      # 230
tile_offset = RESIZED // 2 - TILE_TOL    # 170
# upper-left, upper-right, lower-left, lower-right tiles of the resized slice
tile_origins = [(0, 0), (0, tile_offset), (tile_offset, 0), (tile_offset, tile_offset)]
canvas_side = OUT_SCALE * RESIZED        # 800
canvas_origins = [(OUT_SCALE * r, OUT_SCALE * c) for r, c in tile_origins]

for i in tnrange(FIRST_TRACK, LAST_TRACK + 1):
    n = '%04d' % i
    # load image
    fdir = RAW_DIR_TEMPLATE % n
    im_name = 'Image%s_01.oib' % n
    image = imread(os.path.join(fdir, im_name))
    # np.vstack below needs the prediction stack to match the image's z and
    # spatial sizes; fail here instead of after all the inference work.
    assert image.shape[1] == N_Z_SLICES and image.shape[2:] == (canvas_side, canvas_side), (
        "unexpected image shape %s for %s" % (image.shape, im_name))

    predictedZImages = []
    for z_slice in range(N_Z_SLICES):
        resized = _normalize_slice(image, z_slice, medians)
        tile_maps = [
            _predict_tile(resized[r:r + tile_side, c:c + tile_side, :], tile_side)
            for r, c in tile_origins
        ]
        stitched = _stitch_max(tile_maps, canvas_origins, canvas_side)
        finalImage = _to_uint16(stitched)
        predictedZImages.append(finalImage)

    stack = np.stack(predictedZImages, axis=1)
    # append prediction maps to the raw image as one extra channel
    x = np.vstack((image, stack)).astype(np.uint16)

    fileName = OUT_PATH_TEMPLATE % (im_name[:12])
    if SAVE_PREDICTIONS:
        tiff.imwrite(fileName, x)