In [None]:
import numpy as np
import skfmm
import os
def loadswc(filepath):
    '''
    Load swc file as a N X 7 numpy array
    '''
    swc = []
    with open(filepath) as f:
        lines = f.read().split("\n")
        for l in lines:
            if not l.startswith('#'):
                cells = l.split(' ')
                if len(cells) ==7:
                    cells = [float(c) for c in cells]
                    swc.append(cells)
    return np.array(swc)

def saveswc(filepath, swc):
    if swc.shape[1] > 7:
        swc = swc[:, :7]

    with open(filepath, 'w') as f:
        for i in range(swc.shape[0]):
            print('%d %d %.3f %.3f %.3f %.3f %d' %
                  tuple(swc[i, :].tolist()), file=f)

def loadtiff3d(filepath):
    """Load a tiff file into 3D numpy array"""

    import tifffile as tiff
    a = tiff.imread(filepath)

    stack = []
    for sample in a:
        stack.append(np.rot90(np.fliplr(np.flipud(sample))))
    out = np.dstack(stack)

    return out


def writetiff3d(filepath, block):
    import tifffile as tiff

    try:
        os.remove(filepath)
    except OSError:
        pass

    with tiff.TiffWriter(filepath, bigtiff=False) as tif:
        for z in range(block.shape[2]):
            saved_block = np.rot90(block[:, :, z])
            tif.save(saved_block.astype('uint8'), compress=0)

In [2]:
def crop(img):
    ind = np.argwhere(img>0)
    x = ind[:,0]
    y = ind[:,1]
    z = ind[:,2]
    xmin = max(x.min()-5, 0)
    xmax = min(x.max()+5, img.shape[0])
    ymin = max(y.min()-5, 0)
    ymax = max(y.min()+5, img.shape[1])
    zmin = max(z.min()-5, 0)
    zmax = max(z.min()+5, img.shape[2])
    
    return np.array([[xmin, xmax], [ymin, ymax], [zmin, zmax]])
    

In [1]:
'''
Regenerate tif file from swc and also apply distance transform.
'''
def swc2tif_dt(swc, img):
    import math

    shape = img.shape
    skimg = np.ones(shape)
    zeromask = np.ones(shape)

    # Add nodes the current swc to make sure there is
    # at least one node in each voxel on a branch
    idlist = swc[:, 0]
    extra_nodes = []
    extra_nodes_radius = []
    for i in range(swc.shape[0]):
        cnode = swc[i, 2:5]
        c_radius = swc[i, -2]
        pnode = swc[idlist == swc[i, 6], 2:5]
        if pnode.shape[0] != 0:
            p_radius = swc[idlist == swc[i, 6], -2][0]
            average_radius = int(c_radius+p_radius)/2

        dvec = pnode - cnode # [[x, y, z]]
        dvox = np.floor(np.linalg.norm(dvec)) # eculidean norm
        if dvox >= 1:
            uvec = dvec / (dvox + 1) # unit vector
            extra_nodes.extend(
                [cnode + uvec * i for i in range(1, int(dvox))])
            extra_nodes_radius.extend([average_radius for i in range(1, int(dvox))])

    # Deal with nodes in swc
    for i in range(swc.shape[0]):
        node = [math.floor(n) for n in swc[i, 2:5]]
        for j in range(3):
            if node[j] > shape[j]-1:
                node[j] = shape[j]-1
        r = int(swc[i, -2])
        skimg[node[0], node[1], node[2]] = 0
        zeromask[max(0,node[0]-r): min(node[0]+r, shape[0]), max(0,node[1]-r):min(node[1]+r, shape[1]), max(0, node[2]-r):min(node[2]+r, shape[2])] = 0

    # Deal with the extra nodes
    ex_count = 0
    for ex in extra_nodes:
        node = [math.floor(n) for n in ex[0]] # get integer x, y, z
        for j in range(3):
            if node[j] > shape[j]-1:
                node[j] = shape[j]-1
        skimg[node[0], node[1], node[2]] = 0
        r = int(extra_nodes_radius[ex_count])
        zeromask[max(0,node[0]-r): min(node[0]+r, shape[0]), max(0,node[1]-r):min(node[1]+r, shape[1]), max(0, node[2]-r):min(node[2]+r, shape[2])] = 0
        ex_count += 1

    a, dm = 6, 5
    dt = skfmm.distance(skimg, dx=1)

    dt = np.exp(a * (1 - dt / dm)) - 1
    dt[zeromask == 1] = 0
    dt = (dt/np.max(dt))*255
    return dt

In [7]:
import os
import subprocess

path_prefix = '/media/jacktang/Work/USYD/Research/Deep_Learning/GAN/pytorch-CycleGAN-and-pix2pix/datasets/datasets/fly/fly3d/Tokyo/'

for sub in os.listdir(path_prefix):
    for f in os.listdir(path_prefix+sub):
        new_name = '1' + f.split('.')[0] + '.tif'
        subprocess.call(['mv', path_prefix+sub+'/'+f, path_prefix+sub+'/'+new_name])

In [14]:
import random 

ind = [x for x in range(1,43,1)]
for x in range(111,117,1):
    ind.append(x)
print(ind)
path_prefix = '/media/jacktang/Work/USYD/Research/Deep_Learning/GAN/pytorch-CycleGAN-and-pix2pix/datasets/datasets/fly/fly3d/'
tokyo_prefix = path_prefix + 'Tokyo/'
janelia_prefix = path_prefix + 'Janelia/'

random.shuffle(ind)
print('train: ', ind[:40])
print('val: ', ind[41:])

for i in ind[:40]:
    if int(i) > 100:
        subprocess.call(['cp',tokyo_prefix+'tokyo_gt/'+str(i)+'_gt.tif',path_prefix+'testA/'+str(i)+'_gt.tif'])
        

# for i in ind[:38]:
#     subprocess.call(['cp', path_prefix+str(i)+'_2d.tif', out_prefix+'trainB/'+str(i)+'_2d.tif'])
#     subprocess.call(['cp', path_prefix+str(i)+'_gt_2d.tif', out_prefix+'trainA/'+str(i)+'_gt_2d.tif'])

# for i in ind[38:]:
#     subprocess.call(['cp', path_prefix+str(i)+'_2d.tif', out_prefix+'testA/'+str(i)+'_2d.tif'])
#     subprocess.call(['cp', path_prefix+str(i)+'_gt_2d.tif', out_prefix+'testB/'+str(i)+'_gt_2d.tif'])


[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 111, 112, 113, 114, 115, 116]
train:  [35, 40, 9, 8, 13, 15, 17, 38, 28, 21, 3, 18, 29, 26, 10, 14, 113, 23, 33, 1, 20, 37, 42, 114, 2, 6, 115, 7, 11, 16, 22, 34, 32, 41, 5, 116, 30, 111, 112, 36]
val:  [4, 39, 19, 31, 12, 24, 25]


In [None]:
train:  [35, 40, 9, 8, 13, 15, 17, 38, 28, 21, 3, 18, 29, 26, 10, 14, 113, 23, 33, 1, 20, 37, 42, 114, 2, 6, 115, 7, 11, 16, 22, 34, 32, 41, 5, 116, 30, 111, 112, 36]
val:  [4, 39, 19, 31, 12, 24, 25]