# Utils

## Imports

In [3]:
# import os
import pickle
import warnings
import seaborn
import pandas as pd
import autograd.numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from skimage.morphology import opening, closing, ball, dilation

from autograd import grad
from skimage import io
from skimage import data
from skimage import color
from skimage.morphology import opening, ball
from scipy.ndimage import fourier_shift, zoom, gaussian_filter
from scipy.sparse import csr_matrix, hstack
from scipy.stats import pearsonr
from scipy.spatial.distance import cdist
from scipy.ndimage.measurements import center_of_mass
from sklearn.decomposition import NMF, non_negative_factorization

from tqdm.notebook import tqdm

import ipywidgets as widgets
from ipywidgets import interactive, IntSlider

## Image Pre-Processing

In [4]:
#clamp image values to 0 at the low end
def high_pass(Y, pct = 10):
    """Subtract a per-channel intensity floor from a 4-D stack.

    For every index along the last axis, the `pct`-th percentile of the
    strictly positive values is taken as a threshold. Values below the
    threshold are raised to it and the threshold is then subtracted, so
    everything that was at or below the threshold ends up at zero.

    The caller's array is not modified: the clamping is done on the copy.
    (Previously the clamp was applied through a view of `Y`, silently
    overwriting the input.)

    Parameters
    ----------
    Y : array indexed as Y[:, :, :, channel]
    pct : percentile (0-100) of the positive values used as the floor

    Returns
    -------
    Array with the same shape and dtype as `Y`.
    """
    Y_new = Y.copy()
    for idx in tqdm(range(Y.shape[3])):
        # view into the copy, so in-place clamping never touches Y
        channel = Y_new[:,:,:,idx]
        thresh = np.percentile(channel[channel > 0], pct)
        channel[channel < thresh] = thresh
        # assign (rather than `-=`) so the result is cast back to the
        # stack's dtype exactly as before, even for integer images
        Y_new[:,:,:,idx] = channel - thresh

    return Y_new

#normalize channels across rounds using average of high pixel values
def percentile_normalize(Y, pct_low = 95, pct_high = 99, return_sf = False):
    """Equalize channel brightness using the mean of each channel's bright pixels.

    For every index along the last axis, the scale factor is the mean of the
    values lying strictly between the `pct_low` and `pct_high` percentiles of
    that channel. Each channel is divided by its factor and the whole stack
    is then multiplied by the median of all factors.

    Returns the normalized stack as doubles, plus the list of per-channel
    factors when `return_sf` is true.
    """
    scale_factors = []
    normalized = Y.copy().astype('d')

    for ch in tqdm(range(Y.shape[3])):
        channel = Y[:,:,:,ch]
        above_low = channel > np.percentile(channel, pct_low)
        below_high = channel < np.percentile(channel, pct_high)
        window = np.logical_and(above_low, below_high)

        selected = Y[window, ch]
        if np.min(selected) <= 0:
            print('WARNING: Percentile values may be poorly scaled...')

        factor = np.mean(selected)
        scale_factors.append(factor)
        normalized[:,:,:,ch] = channel/factor

    # bring all channels back to a common, typical intensity scale
    normalized = normalized*np.median(scale_factors)

    if return_sf:
        return normalized, scale_factors
    return normalized

#use morphological opening to remove background noise
def background_opening(Y, size = 3):
    """Subtract a morphological opening from every channel of a 4-D stack.

    The structuring element is `ball(size)`; for each index along the last
    axis the opened volume is subtracted from the original volume. The input
    is not modified.
    """
    footprint = ball(size)
    result = Y.copy()
    n_channels = Y.shape[3]
    for ch in tqdm(range(n_channels)):
        volume = Y[:,:,:,ch]
        result[:,:,:,ch] = volume - opening(volume, footprint)

    return result

## Visualization

In [5]:
%matplotlib inline

#image visualization utility functions

def imagesc(img, sz1 = 10, sz2 = 10):
    """Show `img` on a new (sz1 x sz2) figure with all ticks and labels hidden.

    The color scale runs from 0 to the 99.9th percentile of the image.
    """
    hidden_ticks = dict(bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
    plt.figure(figsize=(sz1,sz2))
    plt.tick_params(axis='both', which='both', **hidden_ticks)
    # cap the color range so a few extreme pixels do not wash out the image
    upper = np.percentile(img, 99.9)
    plt.imshow(img, vmin = 0, vmax = upper)
    
def viewmask(img, sz1 = 10, sz2 = 15):
    """Show the sum of `img` along axis 2 on a new figure with ticks hidden."""
    hidden_ticks = dict(bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
    plt.figure(figsize=(sz1,sz2))
    plt.tick_params(axis='both', which='both', **hidden_ticks)
    projection = np.sum(img, 2)
    plt.imshow(projection, vmin = 0)

def imagesc3D(vol):
    """Return an ipywidgets control for browsing `vol[:, :, idx]` with a slider.

    Every slice is drawn with a color scale from 0 to the 99th percentile of
    the whole volume.
    """
    def f(idx):
        hidden_ticks = dict(bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
        plt.figure(figsize=(10,10))
        plt.tick_params(axis='both', which='both', **hidden_ticks)
        # percentile of the full volume, so the scale is the same on every slice
        upper = np.percentile(vol, 99)
        plt.imshow(vol[:,:, idx], vmin = 0, vmax = upper)
        plt.show()

    slice_slider = IntSlider(value=0, description='slice', max=vol.shape[-1]-1, min=0, continuous_update = False)
    interactive_plot = interactive(f, idx=slice_slider)
    # fix the output area height (last child of the widget)
    interactive_plot.children[-1].layout.height = '650px'
    return interactive_plot

def imagesc4D(vol):
    """Return an ipywidgets control for browsing `vol[:, :, zidx, cidx]`.

    Two sliders are created: 'channel' over the last axis and 'slice' over
    the second-to-last axis. The color scale runs from 0 to the 99th
    percentile of the whole array.
    """
    def f(cidx, zidx):
        hidden_ticks = dict(bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
        plt.figure(figsize=(10,10))
        plt.tick_params(axis='both', which='both', **hidden_ticks)
        # percentile of the full array, so the scale is the same on every view
        upper = np.percentile(vol, 99)
        plt.imshow(vol[:,:, zidx,cidx], vmin = 0, vmax = upper)
        plt.show()

    channel_slider = IntSlider(value=0, description='channel', max=vol.shape[-1]-1, min=0, continuous_update = False)
    slice_slider = IntSlider(value=0, description='slice', max=vol.shape[-2]-1, min=0, continuous_update = False)
    interactive_plot = interactive(f, cidx=channel_slider, zidx=slice_slider)
    # fix the output area height (last child of the widget)
    interactive_plot.children[-1].layout.height = '650px'
    return interactive_plot

def view_components(vol):
    """Return an ipywidgets control showing, per index of the last axis,
    the sum of `vol[:, :, :, idx]` along axis 2."""
    def f(idx):
        hidden_ticks = dict(bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
        plt.figure(figsize=(10,10))
        plt.tick_params(axis='both', which='both', **hidden_ticks)
        projection = np.sum(vol[:,:, :, idx], 2)
        plt.imshow(projection)
        plt.show()

    cell_slider = IntSlider(value=0, description='cell', max=vol.shape[-1]-1, min=0, continuous_update = False)
    interactive_plot = interactive(f, idx=cell_slider)
    # fix the output area height (last child of the widget)
    interactive_plot.children[-1].layout.height = '650px'
    return interactive_plot

## Patching

In [6]:
#greedily initialize cell bodies of a certain size

def greedy_init(Y, n_components, sigma, verbose = True):
    """Greedily extract `n_components` rank-1 components from a 4-D stack.

    Each iteration:
      1. Gaussian-blurs every channel (last axis) and sums the blurred
         channels into a single volume.
      2. Takes the location of the maximum of that volume as a center and
         cuts a window of half-width `2*sigma + 1` around it (clipped to the
         volume).
      3. Fits a one-component NMF to the window (pixels x channels).
      4. Stores the spatial factor in `W`, the channel factor in `H`, and
         subtracts the fitted signal from the working copy (negative
         residuals are set to 0) before the next iteration.

    Parameters
    ----------
    Y : array of shape (d1, d2, d3, T)
    n_components : number of components to extract
    sigma : sequence of three Gaussian standard deviations, one per spatial axis
    verbose : print a line after each component is found

    Returns
    -------
    W : (d1, d2, d3, n_components) spatial factors, zero outside each window
    H : (n_components, T) channel factors
    centers : (n_components, 3) peak locations used for each component
    """
    devs = 4  # truncate the Gaussian kernel at this many standard deviations
    #save dimensions of Y
    d1, d2, d3, T = Y.shape

    #initialize objects
    dsigma = 2*np.array(sigma)+1  # window half-width per axis
    W = np.zeros((d1, d2, d3, n_components))
    H = np.zeros((n_components,T))
    centers = np.zeros((n_components, 3))
    Y_init = Y.copy()
    Y_init = Y_init.astype('double')

    #gaussian blur image
    rho = np.array([gaussian_filter(Y_init[:,:,:, i], sigma, mode = 'constant', output = 'double', truncate = devs) for i in range(T)])
    rho = np.moveaxis(rho, 0, -1)

    #find sum image
    v = np.sum(rho, 3)

    for k in range(n_components):
        #find max location and boundaries
        # NOTE(review): the upper bounds are clipped to shape-1 but used as
        # exclusive slice ends, so the last index along each axis is never
        # part of a window - confirm this is intended.
        ix, iy, iz = np.unravel_index(np.argmax(v), v.shape)
        xmax, xmin = min(v.shape[0]-1, ix + dsigma[0]), max(0, ix - dsigma[0])
        ymax, ymin = min(v.shape[1]-1, iy + dsigma[1]), max(0, iy - dsigma[1])
        zmax, zmin = min(v.shape[2]-1, iz + dsigma[2]), max(0, iz - dsigma[2])

        window = Y_init[xmin:xmax, ymin:ymax, zmin:zmax, :]
        dims = window.shape
        window = np.reshape(window, (dims[0]*dims[1]* dims[2], dims[3]))
        rA, C, _ = non_negative_factorization(window, n_components = 1, alpha = 0.2, max_iter = 100, tol = 1e-50)

        #save extracted components
        A = np.reshape(rA, (dims[0], dims[1],dims[2]))
        W[xmin:xmax, ymin:ymax, zmin:zmax,k] = A
        H[k, :] = C
        centers[k,:] = [ix, iy, iz]

        if k < (n_components-1): #if not last component
            #update Y_init
            removed_data = np.dot(rA, C)
            removed_data = np.reshape(removed_data, (dims[0], dims[1], dims[2], dims[3]))

            temp = Y_init[xmin:xmax, ymin:ymax, zmin:zmax, :] - removed_data
            temp[temp < 0] = 0 #remove negative values
            Y_init[xmin:xmax, ymin:ymax, zmin:zmax, :] = temp

            #update rho (the whole volume is re-blurred; the previously
            #computed local update bounds were never used and were removed)
            for i in range(T):
                rho[:,:,:,i] = gaussian_filter(Y_init[:,:,:, i], sigma, mode = 'constant', output = 'double', truncate = devs)

            #update sum image
            v = np.sum(rho, 3)
        if verbose:
            print('Found Component #' + str(k+1) + '...')
    return W, H, centers

In [7]:
def create_3D_patches(dims, patchsize, overlap):
    """Tile a 3-D volume with overlapping boxes.

    Along each axis, patch start positions begin at 0 and advance by
    `patch - overlap` until a patch reaches the end of the axis. The last
    patch on an axis is clipped to the volume, so it may be smaller.

    Parameters
    ----------
    dims : (xdim, ydim, zdim) size of the volume
    patchsize : (px, py, pz) size of a patch
    overlap : (ox, oy, oz) overlap between neighbouring patches

    Returns
    -------
    Integer array of shape (2, 3, n_patches); `[0, :, i]` is the inclusive
    minimum corner and `[1, :, i]` the exclusive maximum corner of patch i.
    Patches are ordered with z varying fastest, then y, then x.

    Raises
    ------
    ValueError
        If an axis needs more than one patch but `overlap >= patch size`
        there (the start position could never advance; this previously
        looped forever).
    """
    def axis_starts(dim, patch, ovlp):
        # start offsets of the patches along one axis
        step = patch - ovlp
        if patch < dim and step <= 0:
            raise ValueError(
                'overlap (' + str(ovlp) + ') must be smaller than patch size ('
                + str(patch) + ') to tile an axis of length ' + str(dim))
        starts = [0]
        while starts[-1] + patch < dim:
            starts.append(starts[-1] + step)
        return starts

    xdim, ydim, zdim = dims
    px, py, pz = patchsize
    ovlp_x, ovlp_y, ovlp_z = overlap

    xvals = axis_starts(xdim, px, ovlp_x)
    yvals = axis_starts(ydim, py, ovlp_y)
    zvals = axis_starts(zdim, pz, ovlp_z)

    n_patches = len(xvals) * len(yvals) * len(zvals)
    patches = np.zeros((2,3, n_patches))

    i = 0
    for xmin in xvals:
        for ymin in yvals:
            for zmin in zvals:
                # clip the far corner to the volume
                xmax, ymax, zmax = min(xmin+px, xdim), min(ymin+py, ydim), min(zmin+pz, zdim)
                patches[:,:, i] = np.array([[xmin, ymin, zmin], [xmax, ymax, zmax]])
                i += 1

    return patches.astype('int')

In [8]:
def process_patch(Y, n_components, sigma, patch_list):
    """Initialize, merge and refine NMF components for one patch.

    Steps: greedy initialization with `greedy_init`, merging of overlapping
    components with `merge_components`, then a final
    `non_negative_factorization` on the flattened patch, started from the
    merged factors with `update_H = False`.

    Parameters
    ----------
    Y : array of shape (d1, d2, d3, T), the patch data
    n_components : number of components to initialize
    sigma : Gaussian widths forwarded to `greedy_init`
    patch_list : per-component patch labels forwarded to `merge_components`

    Returns
    -------
    A : (d1, d2, d3, n_merged) spatial components
    C : second factor returned by the final factorization
    pl : patch labels of the components that survived merging
    """
    # use the `sigma` argument; the previous code read a global `sigmas`,
    # ignoring the parameter and failing when that global was not defined
    W_init, H_init, centers = greedy_init(Y, n_components, sigma)

    W_init, H_init, pl = merge_components(W_init, H_init, patch_list)
    n_components = W_init.shape[-1]

    d1,d2,d3,T = Y.shape
    dim = d1*d2*d3

    Y = np.reshape(Y, (dim, T))
    W_init = np.reshape(W_init, (dim, n_components))

    A, C, niter = non_negative_factorization(Y, W = W_init, H = H_init, n_components = n_components, 
                update_H = False, verbose = False, alpha = 0.5, max_iter = 200, tol = 1e-4, init = 'custom', l1_ratio = 1)
    A = np.reshape(A, (d1,d2,d3,n_components))

    return A,C, pl

In [9]:
def embed_patch_results(A, full_dims, patch_min_coords, patch_max_coords):
    """Place patch-sized array `A` into a zero array of the full volume size.

    The result has shape (full_dims[0], full_dims[1], full_dims[2],
    A.shape[-1]); `A` is written into the box between `patch_min_coords`
    (inclusive) and `patch_max_coords` (exclusive) and everything else is 0.
    """
    x0, y0, z0 = patch_min_coords
    x1, y1, z1 = patch_max_coords
    n_last = A.shape[-1]

    canvas = np.zeros([full_dims[0], full_dims[1], full_dims[2], n_last])
    canvas[x0:x1, y0:y1, z0:z1, :] = A
    return canvas

## Post-Processing

In [10]:
def remove_empty_components(A, C, patch_list):
    """Drop components whose spatial or temporal part sums to zero.

    Component k is removed when `A[:, :, :, k]` sums to 0 or row `C[k]`
    sums to 0. The same indices are removed from the last axis of `A`, the
    first axis of `C` and from `patch_list`.
    """
    spatially_empty = np.sum(A, (0,1,2)) == 0
    temporally_empty = np.sum(C, 1) == 0
    drop = np.where(spatially_empty | temporally_empty)[0]

    kept_A = np.delete(A, drop, -1)
    kept_C = np.delete(C, drop, 0)
    kept_patches = np.delete(patch_list, drop, 0)
    return kept_A, kept_C, kept_patches

In [11]:
def connected_components(adjacency, min_size = 0):
    """Group node indices into connected components of an adjacency matrix.

    Node j is reachable from node i when `adjacency[i, j] > 0`. Starting
    from each not-yet-visited node (in increasing index order), every node
    reachable through any chain of such links is collected into one
    component.

    The previous implementation always looked up the *seed's* row instead of
    the row of the node being expanded, so only direct neighbours of the
    seed were grouped and chains such as 0-1-2 were split.

    Parameters
    ----------
    adjacency : square 2-D array (numeric or boolean)
    min_size : components with fewer nodes than this are discarded

    Returns
    -------
    List of components; each component is a list of int node indices in the
    order they were discovered (seed first).
    """
    viewed = set()
    components = []
    for seed in range(adjacency.shape[0]):
        if seed in viewed:
            continue

        viewed.add(seed)
        component = [seed]
        cursor = 0
        # `component` doubles as the breadth-first queue
        while cursor < len(component):
            node = component[cursor]
            cursor += 1
            for neighbour in np.where(adjacency[node, :] > 0)[0]:
                neighbour = int(neighbour)
                if neighbour not in viewed:
                    viewed.add(neighbour)
                    component.append(neighbour)

        components.append(component)

    return [component for component in components if len(component) >= min_size]

In [12]:
def patch_mismatch(patch_list):
    """Return a float matrix whose entry (i, j) is 1.0 when
    `patch_list[i] != patch_list[j]` and 0.0 otherwise.

    `patch_list` is expected to be a 1-D array of labels.
    """
    # compare every label against every other label in one broadcast
    differs = patch_list[np.newaxis, :] != patch_list[:, np.newaxis]
    return differs.astype(float)

In [13]:
#find spatial and temporal correlation between components, merge those that have high correlation in both, remove duplicate components
def merge_components(A, C, patch_list, spatial_thresh = 0.05, crosspatch_spatial_thresh = 0.05, temporal_thresh = 0.85):
    """Merge components that overlap both spatially and temporally.

    Two components are linked when their spatial overlap (from
    `create_spatial_correlation`) exceeds `spatial_thresh` AND their
    temporal correlation (from `create_temporal_correlation`) exceeds
    `temporal_thresh`. Every connected group of two or more linked
    components is replaced by a single component computed by
    `merge_component_clique`, stored at the position of the group's first
    member; the other members are zeroed and then dropped by
    `remove_empty_components`.

    Parameters
    ----------
    A : (d1, d2, d3, n) spatial components
    C : (n, T) temporal components
    patch_list : 1-D array with one patch label per component
    spatial_thresh, temporal_thresh : linking thresholds
    crosspatch_spatial_thresh : currently unused; kept so existing callers
        that pass it keep working

    Returns
    -------
    (A, C, patch_list) after merging and removal of empty components.
    """
    spatial_overlap = create_spatial_correlation(A) > spatial_thresh
    temporal_overlap = create_temporal_correlation(C) > temporal_thresh

    to_merge = np.logical_and(spatial_overlap, temporal_overlap)
    conn_comp = connected_components(to_merge, min_size = 2)

    Anew = A.copy()
    Cnew = C.copy()
    for comp in tqdm(conn_comp):
        merge_idx = comp[0]

        mergeA, mergeC = merge_component_clique(A[:,:,:,comp], C[comp,:], patch_list[comp])

        # clear every member of the group, then store the merged result in
        # the first member's slot (previously the C rows were not cleared and
        # the merged trace was copied into every member's row)
        Anew[:,:,:,comp] = 0
        Cnew[comp,:] = 0

        Anew[:,:,:,merge_idx] = mergeA
        Cnew[merge_idx,:] = np.reshape(mergeC, -1)

    Anew, Cnew, patch_list = remove_empty_components(Anew, Cnew, patch_list)

    return Anew, Cnew, patch_list

In [14]:
#subfunction for merging groups of components
def merge_component_clique(As, Cs, patch_list):
    """Fuse a group of components into one spatial map and one trace.

    As : (d1, d2, d3, n) spatial components of the group
    Cs : (n, T) matching temporal components
    patch_list : length-n labels; components sharing a label are summed first

    Returns (A, C): A has shape (d1, d2, d3); C is the second factor returned
    by the one-component factorization.
    """
    Ys = []
    # reconstruct one (d1, d2, d3, T) signal per distinct patch label by
    # summing the rank-1 reconstructions of that label's components
    for idx in np.unique(patch_list):
        tempAs = As[:,:,:,patch_list == idx]
        tempCs = Cs[patch_list == idx, :]
        tempYs = [np.matmul(tempAs[:, :, :, [i]], tempCs[[i], :]) for i in range(tempAs.shape[-1])]
        tempY = np.sum(tempYs, 0)
        Ys.append(tempY)
    
    # accumulate the per-label signals; Ypos counts, per voxel, how many of
    # them have a positive sum over the last axis there
    Y = Ys.pop()
    Ypos = (np.sum(Y, -1, keepdims = True) > 0).astype('int')

    while len(Ys) > 0:
        current = Ys.pop()
        Y = np.add(Y, current)
        Ypos = np.add(Ypos, (np.sum(current, -1, keepdims = True) > 0).astype('int'))
    
    # average over the contributing signals; voxels with no contribution
    # divide by 1 to avoid division by zero
    Ypos[Ypos == 0] = 1
    Y = Y/Ypos
        
    d1,d2,d3, T = Y.shape
    Y = np.reshape(Y, (d1*d2*d3, T))
    # start the factorization from the mean of the group's factors
    W_init = np.reshape(np.mean(As, 3), (d1*d2*d3, 1))
    H_init = np.mean(Cs, 0, keepdims = True)
    
    A, C, niter = non_negative_factorization(Y, W = W_init, H = H_init, n_components = 1,
                    verbose = False, alpha = 0.5, max_iter = 100, tol = 1e-30, init = 'custom')
    
    A = np.reshape(A, ((d1,d2,d3)))
    return A, C

In [15]:
#calculate correlation graphs
def create_spatial_correlation(A, blur = False, sigma = [0.5,0.5,0.5]):
    """Pairwise spatial overlap between components.

    `A` has the components along its last axis. Optionally each component is
    Gaussian-blurred first (on a copy). The components are flattened and
    each column is divided by its sum. For every pair i < j the score is

        1 + sum over voxels where column i is positive of min(A_j - A_i, 0)

    and is stored symmetrically. The diagonal is left at 0.

    Returns an (n, n) float array.
    """
    n_cells = A.shape[-1]
    if blur:
        A = A.copy()
        for c in tqdm(range(n_cells)):
            A[:,:,:,c] = gaussian_filter(A[:,:,:,c], sigma, mode = 'constant', output = 'double')

    # one column per component, each scaled to sum to 1
    flat = np.reshape(A, (-1, n_cells))
    flat = flat / np.sum(flat, 0, keepdims = True)

    overlap_matrix = np.zeros((n_cells, n_cells))

    for i in tqdm(range(n_cells-1)):
        support = flat[:, i] > 0
        on_support = flat[support, :]
        # how far each later component falls short of component i
        deficit = on_support[:, i+1:] - on_support[:, [i]]
        deficit[deficit > 0] = 0
        shared = np.sum(deficit, 0) + 1
        overlap_matrix[i, i+1:] = shared
        overlap_matrix[i+1:, i] = shared

    return overlap_matrix

def create_temporal_correlation(C):
    """Absolute Pearson correlation between the rows of `C`, minus the identity.

    Subtracting the identity removes the self-correlation of 1 from the
    diagonal.
    """
    n_rows = C.shape[0]
    abs_corr = np.abs(np.corrcoef(C))
    return abs_corr - np.eye(n_rows)

In [16]:
def get_fov_index(fov_list, idx):
    """Translate a global component index into (fov label, index within fov).

    The within-fov index is `idx` minus the position of the first entry of
    `fov_list` carrying the same label (presumably entries of one label are
    contiguous - verify against how `fov_list` is built).
    """
    label = fov_list[idx]
    first_of_label = np.where(fov_list == label)[0][0]
    return label, idx - first_of_label

def get_fov_index_dict(fov_list, idx_list):
    """Group global component indices by fov.

    Returns a dict mapping each fov label to the list of within-fov indices
    (as computed by `get_fov_index`) of the entries in `idx_list`, in the
    order they appear in `idx_list`.
    """
    grouped = {}
    for idx in idx_list:
        label, local_idx = get_fov_index(fov_list, idx)
        grouped.setdefault(label, []).append(local_idx)

    return grouped

## Basecalling

In [17]:
def basecall_round(Y, thresh_c, thresh_d, it):
    """Threshold one round of channel intensities into 0/1 calls.

    `Y` is (n_rows, it). The output is an int array (n_rows, it + 1):
      column 0: channel 0 exceeds thresh_c * channel 1, and column 2 not set
      column 1: channel 1 exceeds thresh_c * channel 0, and column 2 not set
      column 2: channel 2 exceeds thresh_d * max(channel 0, channel 1)
      column 3: none of the first three columns is set

    Raises Exception when the number of columns of `Y` is not `it`.
    """
    n_rows, n_cols = Y.shape
    if n_cols != it:
        raise Exception('Wrong round dimension...')

    vals = Y.copy().astype('double')
    calls = np.zeros((n_rows, n_cols+1)).astype('int')

    ch0, ch1, ch2 = vals[:,0], vals[:,1], vals[:,2]
    third = (ch2 - thresh_d*np.max(vals[:,0:2], 1)) > 0
    not_third = np.logical_not(third)

    # a call on column 2 overrides calls on columns 0 and 1
    calls[:,0] = np.logical_and((ch0 - thresh_c*ch1) > 0, not_third)
    calls[:,1] = np.logical_and((ch1 - thresh_c*ch0) > 0, not_third)
    calls[:,2] = third

    calls[:,3] = np.sum(calls, 1) == 0
    return calls

def matrix_basecall(Y, thresh_c = 1.5, thresh_d = 1.5, it = 3):
    """Apply `basecall_round` to consecutive groups of `it` columns of `Y`.

    Each group of `it` input columns produces `it + 1` output columns, so
    the result has `int(T*(it+1)/it)` columns for an input with T columns.
    """
    n_rows, n_cols = Y.shape
    out_width = it + 1
    calls = np.zeros((n_rows, int(n_cols*out_width/it)), dtype = int)

    for start in tqdm(range(0, Y.shape[-1], it)):
        round_vals = Y[:,start:start+it]
        out_start = int(start*out_width/it)
        calls[:, out_start:out_start+out_width] = basecall_round(round_vals, thresh_c = thresh_c, thresh_d = thresh_d, it = it)

    return calls

In [18]:
def get_A_pixels(Y, Acomp, percentile = 95):
    """Return the rows of `Y` at the brightest voxels of a component.

    The cut-off is the given percentile of the strictly positive values of
    `Acomp`; `Y[x, y, z, :]` is returned for every voxel whose `Acomp` value
    is above it, giving an array of shape (n_selected, Y.shape[-1]).
    """
    positive_vals = Acomp[Acomp > 0]
    cutoff = np.percentile(positive_vals, percentile)
    coords = np.where(Acomp > cutoff)

    return Y[coords[0], coords[1], coords[2], :]

In [19]:
def calc_sy(flatY):
    """Row sums and row (population) standard deviations of a 2-D array.

    Returns (sy, sty), both of shape (n_rows, 1).
    """
    _, n = flatY.shape
    sy = flatY.sum(axis=1, keepdims=True)
    centered = flatY - sy/n
    sty = np.sqrt(np.sum(centered**2, 1, keepdims = True) / n)

    return sy, sty

def barcode_correlation(X, flatY, sy = None, sty = None):
    """Pearson correlation of vector `X` with every row of `flatY`.

    `sy` / `sty` are the row sums and row standard deviations of `flatY`;
    they are computed here unless both are supplied. NaN and infinite
    results (e.g. from zero-variance rows) are replaced by 0.

    Returns (corr, sy, sty) with corr of shape (n_rows, 1).
    """
    _, n = flatY.shape
    if sy is None or sty is None:
        sy = flatY.sum(axis=1, keepdims=True)
        sty = np.sqrt(np.sum((flatY - sy/n)**2, 1, keepdims = True) / n)

    numerator = n*np.sum(flatY*X, 1, keepdims = True) - np.sum(X)*sy
    denominator = n**2 * np.std(X) * sty
    corr = numerator / denominator

    # undefined correlations are treated as "no correlation"
    corr[~np.isfinite(corr)] = 0

    return corr, sy, sty

def barcode_match(barcode, barcode_img, num_bases):
    """Mark rows of `barcode_img` whose dot product with `barcode` equals
    `num_bases`; singleton dimensions are squeezed out of the result."""
    scores = np.dot(barcode_img, barcode.T)
    return np.squeeze(scores == num_bases)

In [20]:
def norm_by_round(vec, bases = 3):
    """Divide every consecutive group of `bases` columns by its row-wise maximum.

    Works on a copy; a 1-D input is treated as a single row. Singleton
    dimensions are squeezed out of the result.
    """
    out = vec.copy()
    if out.ndim == 1:
        out = np.reshape(out, (1, -1))

    n_cols = out.shape[-1]
    for start in range(0, n_cols, bases):
        group = out[:, start:start+bases]
        out[:, start:start+bases] = group / np.max(group, 1, keepdims = True)

    return np.squeeze(out)

In [21]:
def duplicate_row(row):
    """Lay four copies of `row` out block-diagonally.

    Returns a (4, 4*len(row)) float array in which output row i holds `row`
    in columns i*len(row) .. (i+1)*len(row) and zeros elsewhere.
    """
    n_chans = row.shape[0]
    n_copies = 4
    blocks = np.zeros((n_copies, n_copies*n_chans))
    for copy_idx in range(blocks.shape[0]):
        first_col = copy_idx*n_chans
        blocks[copy_idx, first_col:first_col+n_chans] = row
    return blocks

def reshape_data_for_optimization(array):
    """Prepare (n_pix, 15) signals for the basecalling weight optimization.

    The 15 columns are viewed as 5 groups of 3 values. Each group is divided
    by its maximum (a maximum of 0 is replaced by 1 to avoid dividing by
    zero). Each normalized group is then expanded to a 4 x 12 block that
    holds the 3 values block-diagonally: row s contains them in columns
    3*s .. 3*s+2, zeros elsewhere.

    Returns an array of shape (n_pix, 5, 4, 12).
    """
    n_pix, n_col = array.shape
    rounds = np.reshape(array, (n_pix, 5, 3))

    peak = np.max(rounds, -1, keepdims = True)
    peak[peak == 0] = 1
    rounds = rounds / peak

    expanded = np.zeros((n_pix, 5, 4, 12))
    # write the same triplet into each of the four diagonal slots
    for slot in range(4):
        expanded[:, :, slot, slot*3:(slot+1)*3] = rounds

    return expanded

def softmax_by_pixel(inp, beta_s = 250):
    """Softmax over axis 2 of `inp` with inverse temperature `beta_s`.

    The maximum along axis 2 is subtracted before exponentiating. This
    leaves the softmax mathematically unchanged but keeps `exp` from
    overflowing, so the computation no longer needs extended precision
    (`float128`, which is not available on every platform) and no longer
    yields NaN from inf/inf for large inputs.

    Returns a float64 array of shape (inp.shape[0], -1): all axes after the
    first are flattened.
    """
    scaled = beta_s*inp.astype('float64')
    # shift so the largest exponent along the softmax axis is exactly 0
    scaled = scaled - np.max(scaled, 2, keepdims = True)
    ex = np.exp(scaled)

    bc = ex / np.sum(ex, 2, keepdims = True)

    n_pix = bc.shape[0]
    return np.reshape(bc, (n_pix, -1))

In [22]:
def center_of_mass_norm(mask, order):
    """Iteratively move a center point to reduce a power-`order` distance loss.

    mask  : 3-D array; the coordinates of its truthy entries are the points used
    order : exponent applied to the per-axis absolute distances

    Starts from `center_of_mass(mask)` and performs 1000 updates, each moving
    the center by the loss gradient scaled to unit length. Returns the final
    center as a length-3 array.
    """
    xi, yi, zi = np.where(mask)
    center = np.array(center_of_mass(mask))
    def loss(center):
        # sum over all mask points of |coordinate - center|**order, per axis
        xdist = np.sum(np.power(np.abs(xi - center[0]), order))
        ydist = np.sum(np.power(np.abs(yi - center[1]), order))
        zdist = np.sum(np.power(np.abs(zi - center[2]), order))
        return xdist + ydist + zdist
    
    # autograd derivative of the loss with respect to the center
    gradient = grad(loss)
    
    #print("Initial distance:", loss(center))
    for i in range(1000):
        #if i % 100 == 0:
            #print("Current distance:", loss(center))
        
        g = gradient(center) 
        # fixed unit-length step along the negative gradient
        # NOTE(review): a zero gradient norm would divide by zero here - confirm it cannot occur
        center -= g / (np.linalg.norm(g, 2))
        
    #print("Trained distance:", loss(center))
    return center
    

## Basecalling Optimization

In [23]:
def basecalling_loss(weights, alpha = 1, beta = 4, gamma = 0.1):
    """Loss used to train the basecalling weights.

    Reads the module-level list `inputs`; every entry provides three arrays
    (indices 0, 1, 2). Each is multiplied by `weights.T` and passed through
    `softmax_by_pixel`. Per entry:

      variance_loss       = -mean(call(entry[0]) @ call(entry[2]).T)
      discrimination_loss =  mean(call(entry[0]) @ call(entry[1]).T)

    The weighted sum `alpha*variance_loss + beta*discrimination_loss` is
    averaged over all entries and `gamma * sum(weights**2)` is added.
    """
    n_inputs = len(inputs)
    loss = 0
    for entry in inputs:
        calls_a = softmax_by_pixel(np.matmul(entry[0], weights.T))
        calls_b = softmax_by_pixel(np.matmul(entry[1], weights.T))
        calls_c = softmax_by_pixel(np.matmul(entry[2], weights.T))

        #alternative kept from the original: -np.mean(np.matmul(calls_a, calls_a.T))
        variance_loss = -np.mean(np.matmul(calls_a, calls_c.T))
        discrimination_loss = np.mean(np.matmul(calls_a, calls_b.T))

        loss = loss + (alpha*variance_loss + beta*discrimination_loss)/n_inputs

    # L2 penalty on the weights
    loss = loss + gamma * np.sum(weights ** 2)
    return loss

In [24]:
def basecall_with_weights(Y, rweights):
    """Convert a (d1, d2, d3, 15) stack into 0/1 calls using a weight matrix.

    The last axis is viewed as 5 groups of 3 values; each group is
    multiplied by `rweights.T`, and within each group the entries equal to
    the group's maximum are set to 1 (all others 0). The groups are then
    flattened back into the last axis.
    """
    d1, d2, d3, T = Y.shape
    rounds = np.reshape(Y, (d1, d2, d3, 5, 3))
    scores = np.matmul(rounds, rweights.T)

    # an entry is "called" when dividing by the group maximum gives exactly 1
    relative = scores / np.max(scores, -1, keepdims = True)
    calls = (relative == 1).astype('int')

    return np.reshape(calls, (d1,d2,d3, -1))

# Cell Body Search

## Calculate NMF components

In [2]:
def load_fov(filepath, ext = '_affine', raw = False):
    """Load and downsample the 15 image stacks of one field of view.

    File names are built as
    `filepath + 'richieseq_round00<round>_<channel><ext>.tif'` for rounds
    1, 2, 3, 4, 6 and three channel tags ('ch00', 'ch01', 'ch02' when `raw`,
    otherwise 'ch00', 'ch01SHIFT', 'ch02SHIFT'). Every stack is read and
    linearly resampled by a factor of 0.25 along each of its three axes.

    Returns an array in which the file index is the last axis and the first
    axis of each loaded stack has been moved to the second-to-last position.
    """
    num_channels = 3
    bases_sequenced = [1, 2, 3, 4, 6]

    file_strs = ['ch00', 'ch01', 'ch02'] if raw else ['ch00', 'ch01SHIFT', 'ch02SHIFT']

    filenames = [
        filepath+'richieseq_round00'+str(ridx)+'_'+fs+ext+'.tif'
        for ridx in bases_sequenced
        for fs in file_strs
    ]

    n_files = len(bases_sequenced)*num_channels
    volumes = [zoom(io.imread(filenames[j]), (0.25, 0.25, 0.25), order = 1) for j in tqdm(range(n_files))]

    Y = np.array(volumes)
    Y = np.moveaxis(Y, 1, -1)
    Y = np.moveaxis(Y, 0, -1)

    return Y

fovs = ['0_0', '0_1', '1_0', '1_1', '2_0', '2_1', '3_0', '3_1', '4_0', '4_1', '5_0', '5_1', '6_0', '6_1']

# FOVs read from the color-correction stage (file names without suffix);
# every other FOV is read from the registration stage with the default suffix
color_corrected_fovs = ('0_0', '0_1', '1_1')
tile_root = '/mp/nas2/DG/iarpa_virtual_tiles/'

Ys = {}
for fov in fovs:
    if fov in color_corrected_fovs:
        Ys[fov] = load_fov(tile_root + fov + '/2_color-correction/', ext = '')
    else:
        Ys[fov] = load_fov(tile_root + fov + '/4_registration/')

#stitch all vfovs together: tiles '<r>_0' and '<r>_1' are each stacked along
#axis 0 for r = 0..6, then the two stacks are joined along axis 1
stitch_col_0 = np.concatenate([Ys[str(row) + '_0'] for row in range(7)], 0)
stitch_col_1 = np.concatenate([Ys[str(row) + '_1'] for row in range(7)], 0)
full_stitch = np.concatenate([stitch_col_0, stitch_col_1], 1)

NameError: name 'np' is not defined

In [7]:
# remove the per-channel floor, then equalize channel brightness
full_stitch = high_pass(full_stitch)
full_stitch, sf = percentile_normalize(full_stitch, return_sf = True)

# save the per-channel scale factors; the context manager closes the file
# (the handle was previously left open)
with open('full_stitch_normalization_weights.p', 'wb') as sf_file:
    pickle.dump(sf, sf_file)

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))




In [None]:
#preprocess full stitch
# NOTE(review): another cell also applies high_pass/percentile_normalize to
# full_stitch; running both preprocesses the stitch twice - confirm intent.
full_stitch = high_pass(full_stitch)
full_stitch = percentile_normalize(full_stitch)

#preprocess each individual FOV
for fov in fovs:
    try:
        temp = high_pass(Ys[fov])
        temp = percentile_normalize(temp)
        Ys[fov] = temp
    except Exception:
        # best effort: keep going with the other FOVs, but no longer swallow
        # KeyboardInterrupt / SystemExit as the bare `except:` did
        print('High pass/normalization failed for FOV ' + fov)

#make flat copy of full_stitch
df1,df2,df3,fT = full_stitch.shape
fdim = df1*df2*df3

flat_full_stitch = np.reshape(full_stitch, (fdim, fT))

In [None]:
#specify how many cells initialized for each FOV
num_cells_fov = [0, 0, 5, 0, 15, 15, 15, 15, 15, 15, 15, 15, 0, 0]

for fov_idx, fov_file_label in enumerate(fovs):
    print(fov_file_label)
    num_cells = num_cells_fov[fov_idx]
    if num_cells == 0:
        print('Skipping FOV ' + str(fov_file_label) + '...')
    else:
        print('Running FOV ' + str(fov_file_label) + '...')
        Y = Ys[fov_file_label]
        Y = high_pass(Y)
        Y = percentile_normalize(Y)

        d1,d2,d3,T = Y.shape
        dim = d1*d2*d3

        flatY = np.reshape(Y, (dim, T))

        sigmas = [12,12,8]

        #split the FOV into overlapping patches that are processed independently
        patches = create_3D_patches([d1,d2,d3], [150, 150, 100], [30, 30, 30])
        num_patches = patches.shape[-1]

        #room for num_cells components per patch; unused slots stay zero and
        #are removed below
        A = np.zeros([d1,d2,d3, num_cells*num_patches])
        C = np.zeros([num_cells*num_patches, T])
        patch_list = np.zeros([num_cells*num_patches])

        with warnings.catch_warnings():
            # ignore all caught warnings
            warnings.filterwarnings("ignore")

            for i in tqdm(range(num_patches)):
                idx = i*num_cells
                patch_coords = patches[:,:,i]
                x1,y1,z1,x2,y2,z2 = np.ravel(patch_coords)
                Ypatch = Y[x1:x2, y1:y2, z1:z2, :]
                pl = np.array([i]*num_cells)
                Apatch, Cpatch, pl = process_patch(Ypatch, num_cells, sigmas, pl)

                merged_num = Cpatch.shape[0]
                print('Reshaping patch results...')
                patch_list[idx:idx+merged_num] = pl
                C[idx:idx+merged_num] = Cpatch
                A[:,:,:,idx:idx+merged_num] = embed_patch_results(Apatch, [d1,d2,d3], [x1,y1,z1], [x2, y2, z2])

        A, C, patch_list = remove_empty_components(A, C, patch_list)
        A, C, patch_list = merge_components(A, C, patch_list)

        #write results; context managers make sure the files are flushed and
        #closed (the handles were previously left open)
        with open(fov_file_label+"reprocess_complete_A.p", "wb") as a_file:
            pickle.dump(A, a_file)
        with open(fov_file_label+"reprocess_complete_C.p", "wb") as c_file:
            pickle.dump(C, c_file)

## Post-Process NMF components

### Load components

In [88]:
As = {}
Cs = {}

# NOTE: pickle.load executes code embedded in the file; only load files that
# were produced by this pipeline.
# NOTE(review): the processing cell writes '<fov>reprocess_complete_A.p' /
# '..._C.p', while '<fov>_A.p' / '<fov>_C.p' are read here - confirm the
# files are renamed in between.
for fov in tqdm(fovs):
    try:
        with open(fov+"_A.p", "rb") as a_file:
            loaded_A = pickle.load(a_file)
        with open(fov+"_C.p", "rb") as c_file:
            loaded_C = pickle.load(c_file)
    except Exception:
        # best effort: a FOV without saved components is skipped, but
        # KeyboardInterrupt / SystemExit are no longer swallowed
        print('No components loaded for ' + fov)
    else:
        # store both or neither, so As and Cs always have the same keys
        As[fov] = loaded_A
        Cs[fov] = loaded_C
        

HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

No components loaded for 0_1
No components loaded for 1_1
No components loaded for 6_0
No components loaded for 6_1



In [95]:
#label each component with its source fov, store info in fov_list
fov_list = []
for fov in fovs:
    if fov in As:
        n_components_in_fov = As[fov].shape[-1]
        fov_list.extend([fov] * n_components_in_fov)

fov_list = np.array(fov_list)

#store padding needed to convert fov coords -> full stitch coords
#the two digits of the fov label ('<a>_<b>') multiply the fov extent along
#the first and second axis; the third axis gets no offset
ffv1, ffv2, ffv3 = 1870, 1946, 400
fv1, fv2, fv3 = 468, 486, 100
padding = {}
full_padding = {}
for fov in fovs:
    fov_row, fov_col = int(fov[0]), int(fov[2])
    full_padding[fov] = np.array((fov_row*ffv1, fov_col*ffv2, 0))
    padding[fov] = np.array((fov_row*fv1, fov_col*fv2, 0))

### Calculate statistics on NMF components

In [81]:
nmf_distances = []
nmf_coms = []
nmf_df = []

for fov in tqdm(fovs):
    if fov not in As:
        continue

    A = As[fov]
    for cidx in tqdm(range(A.shape[-1])):
        comp = A[:,:,:,cidx]

        #mask of the highest intensity pixels of the component, and the
        #center found for that mask by center_of_mass_norm
        high = comp > np.percentile(comp, 99.99)
        comr = center_of_mass_norm(high, 0.5)
        x,y,z = comr
        xpos, ypos, zpos = np.where(high)

        #euclidean distance from every mask pixel to that center
        distance = [
            ((xpos[idx]-x)**2 + (ypos[idx]-y)**2 + (zpos[idx]-z)**2) ** 0.5
            for idx in range(xpos.shape[-1])
        ]
        nmf_distances.append(distance)

        #density check: how much the mask grows when dilated by ball(1)
        df = np.sum(dilation(high, ball(1))) / np.sum(high)
        nmf_df.append(df)

        #shift the center by the fov's padding offset
        com = comr + padding[fov]
        nmf_coms.append(com)

nmf_median_distances = np.array([np.median(distances) for distances in nmf_distances])

nmf_coms = np.array(nmf_coms)
nmf_df = np.array(nmf_df)

HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=53.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=121.0), HTML(value='')))

  lambda ans, x, y : unbroadcast_f(x, lambda g: g * y * x ** anp.where(y, y - 1, 1.)),
  lambda ans, x : lambda g: g * replace_zero(anp.conj(x), 0.) / replace_zero(ans, 1.))





HBox(children=(FloatProgress(value=0.0, max=104.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=135.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=119.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=126.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=132.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=112.0), HTML(value='')))





In [None]:
#keep only components with a small median distance to their center and a
#small dilation growth factor (others potentially hold multiple cell bodies)
prefilter = np.logical_and(nmf_median_distances < 50, nmf_df < 5)

#find components from different fovs that have a similar center of mass and
#temporal signal (suggests cell is on fov border, was initialized twice)
fullC = np.concatenate([Cs[fov] for fov in fovs if fov in Cs], 0)
tc = create_temporal_correlation(fullC)
pm = patch_mismatch(fov_list)
comd = cdist(nmf_coms, nmf_coms)

close_and_cross_fov = np.logical_and(comd < 90, pm)
merge_across_patches = connected_components(np.logical_and(close_and_cross_fov, tc > 0.85), 2)

#within every such group keep only the component with the largest total
#intensity, mark the rest for removal
remove_indices = []
for cc in tqdm(merge_across_patches):
    cc_sizes = []
    for elm in cc:
        elm_fov, elm_local = get_fov_index(fov_list, elm)
        cc_sizes.append(np.sum(As[elm_fov][:,:,:,elm_local]))

    largest = cc[np.argmax(cc_sizes)]
    remove_indices.extend(index for index in cc if index != largest)

prefilter[remove_indices] = False

#from here on prefilter holds the indices of the components to keep
prefilter = np.where(prefilter)[0]

In [87]:
#remove marked components
#keep_dict maps fov label -> within-fov indices of the components to keep;
#it must be built before fov_list itself is filtered
keep_dict = get_fov_index_dict(fov_list,prefilter)
fov_list = fov_list[prefilter]

for fov in fovs:
    if fov in As:
        #a fov whose components were all removed has no entry in keep_dict;
        #an empty index list keeps zero components instead of raising KeyError
        kept = keep_dict.get(fov, [])
        As[fov] = As[fov][:,:,:, kept]
        Cs[fov] = Cs[fov][kept, :]

nmf_coms = nmf_coms[prefilter, :]
#nmf_df is 1-D (one value per component), so it takes a single index; the
#previous `nmf_df[prefilter, :]` raised IndexError
nmf_df = nmf_df[prefilter]
# NOTE(review): nmf_median_distances is not filtered here - confirm it is not
# used together with the filtered arrays afterwards.
fullC = np.concatenate([Cs[fov] for fov in fovs if fov in Cs], 0)

KeyboardInterrupt: 

### Train Basecalling Weights

In [325]:
#make input data for training
num_pixels = 1000 # how many top pixel values used to calculate basecalling loss
inputs = []
cross_comp_vals = []

#use all remaining NMF components to calculate maximally discriminative basecalling weights
for fov in fovs:
    if fov in As:
        A = As[fov]
        Y = Ys[fov]
        C = Cs[fov]
        
        #pick random x pixels from top 0.01 percent highest values of each component, assign those pixels to that components
        #also pool values from each component into cross_comp_vals
        selected_components = range(A.shape[-1])
        for fidx in tqdm(selected_components):
            input_vals = []
            mask = A[:,:,:, fidx] > np.percentile(A[:,:,:, fidx], 99.99)
            masked = Y[mask, :]

            input_vals.append(reshape_data_for_optimization(masked[np.random.choice(masked.shape[0], size=num_pixels, replace = False), :]))    
            inputs.append(input_vals)

            cross_comp_vals.append(masked[np.random.choice(masked.shape[0], size=num_pixels, replace = False), :])


cross_comp_vals = np.concatenate(cross_comp_vals, 0)

#assign each component x pixels drawn randomly from the pool of pixels used by ALL components
idx = 0
for fov in fovs:
    if fov in As:
        A = As[fov]
        Y = Ys[fov]
        C = Cs[fov]
        selected_components = range(A.shape[-1])
        
        for fidx in tqdm(selected_components):
            inputs[idx].append(reshape_data_for_optimization(cross_comp_vals[np.random.choice(cross_comp_vals.shape[0], size=num_pixels, replace = False), :]))
            inputs[idx].append(reshape_data_for_optimization(C[[fidx], :]))
            idx+=1
            
#result is input data with length 3
#1st item: 15-long signal vector for 1000 top pixels belonging to each NMF component
#2nd item: 15-long signal vector for 1000 pixels belonging to random components
#3rd item: 15-long vector = NMF C value for that component

HBox(children=(FloatProgress(value=0.0, max=35.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=37.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=89.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=82.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=65.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=48.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=19.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=35.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=37.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=69.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=89.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=82.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=65.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=48.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=19.0), HTML(value='')))




In [332]:
#fit the basecalling weights by gradient descent: the loss is meant to maximize
#basecalling agreement within a component while maximizing the basecalling
#difference between the component and randomly drawn pixels
learning_rates = [0.01, 0.001, 0.0001]  #one step size per training stage
training_steps = [500, 500, 500]        #number of updates in each stage
weights = np.array([[1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0]]).astype('d')
matrix_gradient = grad(basecalling_loss)

print("Initial loss:", basecalling_loss(weights))
for step_size, n_steps in zip(learning_rates, training_steps):
    for step in tqdm(range(n_steps)):
        #report the loss on every 10th update, before that update is applied
        if step % 10 == 0:
            print("Current loss:", basecalling_loss(weights))
        #in-place update, so `weights` remains the same array object throughout
        weights -= matrix_gradient(weights) * step_size

print("Trained loss:", basecalling_loss(weights))

Initial loss: 8.522170937175297


HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))

Current loss: 8.522170937175297
Current loss: 2.7997774957012953
Current loss: 2.7776894066682765
Current loss: 2.758026535462781
Current loss: 2.739238573726101
Current loss: 2.7215074862661353
Current loss: 2.704826877473797
Current loss: 2.688989622417308
Current loss: 2.673917326982906
Current loss: 2.6644719520610685
Current loss: 2.6490747475164906
Current loss: 2.6355120237912764
Current loss: 2.622719284111408
Current loss: 2.610459778147777
Current loss: 2.598625857005801
Current loss: 2.5870546856178356
Current loss: 2.5755079251034507
Current loss: 2.563894225364197
Current loss: 2.552718711262135
Current loss: 2.5424163473199677
Current loss: 2.5325399538561397
Current loss: 2.5231516275312327
Current loss: 2.5172584329845913
Current loss: 2.513222586433723
Current loss: 2.502602429066288
Current loss: 2.497713330486314
Current loss: 2.4902563306221053
Current loss: 2.4818159286631327
Current loss: 2.4741786905193104
Current loss: 2.4650044751601836
Current loss: 2.45529112

HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))

Current loss: 2.3679608445568117
Current loss: 2.3525622514774036
Current loss: 2.3515700759751605
Current loss: 2.3510813700117827
Current loss: 2.350646590221187
Current loss: 2.3502252580815814
Current loss: 2.349806835144144
Current loss: 2.349387947844261
Current loss: 2.3489674553317976
Current loss: 2.348544889888356
Current loss: 2.3481199563282966
Current loss: 2.347692352535708
Current loss: 2.347261678562711
Current loss: 2.346827361877793
Current loss: 2.3463885724612283
Current loss: 2.345944111040012
Current loss: 2.3454922500084416
Current loss: 2.345030494757455
Current loss: 2.344555211847452
Current loss: 2.344061034555688
Current loss: 2.3435398974738506
Current loss: 2.342979457181779
Current loss: 2.3423604846993626
Current loss: 2.3416523109982808
Current loss: 2.3408031850492574
Current loss: 2.339714629942582
Current loss: 2.3381959674018384
Current loss: 2.336133498641753
Current loss: 2.3339683310275126
Current loss: 2.3321032403811888
Current loss: 2.33058145

HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))

Current loss: 2.32071266876076
Current loss: 2.320680748426688
Current loss: 2.3206488437452366
Current loss: 2.320616954464105
Current loss: 2.3205850803351273
Current loss: 2.320553221114115
Current loss: 2.3205213765606687
Current loss: 2.320489546438018
Current loss: 2.3204577305128655
Current loss: 2.320425928555204
Current loss: 2.3203941403382022
Current loss: 2.3203623656380277
Current loss: 2.3203306042337073
Current loss: 2.3202988559069864
Current loss: 2.3202671204422036
Current loss: 2.320235397626135
Current loss: 2.3202036872478717
Current loss: 2.320171989098709
Current loss: 2.3201403029719923
Current loss: 2.320108628663031
Current loss: 2.3200769659689415
Current loss: 2.320045314688571
Current loss: 2.3200136746223485
Current loss: 2.319982045572207
Current loss: 2.319950427341458
Current loss: 2.3199188197346787
Current loss: 2.3198872225576443
Current loss: 2.3198556356171736
Current loss: 2.3198240587210903
Current loss: 2.3197924916780726
Current loss: 2.3197609

## Filter Duplicate Barcodes

In [None]:
# TODO: filtering of duplicate barcodes is not implemented yet

## Save Cell Barcodes / Weights

In [None]:
#save relevant info for exseq integration
run_name = 'test1'  #prefix shared by every file written for this run

#each object is pickled to the file named run_name + suffix
for suffix, artifact in (
    ("_weights.p", weights),
    ("_centerofmass.p", nmf_coms),
    ("_dilationfactor.p", nmf_df),
    ("_fovlist.p", fov_list),
):
    #the context manager closes (and flushes) the file even if pickling fails,
    #instead of leaving the handle open until garbage collection
    with open(run_name + suffix, "wb") as out_file:
        pickle.dump(artifact, out_file)

In [None]:
#save the full C matrix; run_name is expected to be set by an earlier cell
#the context manager closes the file instead of leaking the handle
with open(run_name + "_cvec.p", "wb") as out_file:
    pickle.dump(fullC, out_file)

#unpacking also requires fullC to be 2-D
b1, bT = fullC.shape
#split every row into 5 groups of 3 values (the reshape requires bT == 15)
barcodeC = np.reshape(fullC, (b1, 5, 3))
#NOTE(review): rweights is not defined in this cell; presumably the trained
#weights reshaped to (4, 3) -- confirm against the cell that creates it
barcodeC = np.matmul(barcodeC, rweights.T)
#set entries whose ratio to the maximum along the last axis is exactly 1 to 1,
#everything else to 0
#NOTE(review): a group whose maximum is 0 gives NaN/inf ratios and therefore
#an all-zero code -- confirm this is intended
barcodeC = (barcodeC / np.max(barcodeC, -1, keepdims = True) == 1).astype('int')
#flatten the per-group codes back into one row per entry of fullC
barcodeC = np.reshape(barcodeC, (b1,-1))

with open(run_name + "_barcodes.p", "wb") as out_file:
    pickle.dump(barcodeC, out_file)