In [1]:
import os
import time 
import numpy as onp
import scipy.stats as stats
from PIL import Image
# --- autograd (automatic differentiation; autograd.numpy is imported as `np`) --- 
import autograd as Agrad
import autograd.numpy as np 
import scipy.optimize

In [2]:
# -- plotting --- 
import matplotlib as mpl 
import matplotlib.pyplot as plt
mpl.rcParams.update({
    # render all text through LaTeX with a serif font
    'text.usetex': True,
    'font.family': 'serif',
    # axes frame width and horizontal data margin
    'axes.linewidth': 1.5,
    'axes.xmargin': 1,
    # larger tick labels and heavier major ticks on both axes
    'xtick.labelsize': 'x-large',
    'xtick.major.size': 5,
    'xtick.major.width': 1.5,
    'ytick.labelsize': 'x-large',
    'ytick.major.size': 5,
    'ytick.major.width': 1.5,
    # legends drawn without a bounding frame
    'legend.frameon': False,
})

## gridding

In [3]:
Ngrid = 64  # the image is Ngrid x Ngrid pixels

# pixel coordinates along each axis: Ngrid evenly spaced points on [0, 1], endpoints included
xpix = np.linspace(0., 1., Ngrid)
ypix = np.linspace(0., 1., Ngrid)

# expand to the full 2d grid and flatten in row-major order, so that
# xypix[k] is the (x, y) coordinate of pixel k of a flattened Ngrid x Ngrid image
_xxpix, _yypix = np.meshgrid(xpix, ypix)
xxpix, yypix = _xxpix.flatten(), _yypix.flatten()
xypix = np.stack([xxpix, yypix], axis=1)

## PSF (point spread function)

In [4]:
# PSF full width at half maximum. 258.21 / 100. is the FWHM in pixels (presumably a
# 258.21 nm FWHM at 100 nm per pixel -- TODO confirm); dividing by Ngrid expresses it
# in the [0, 1] unit-square coordinates used by xypix, NOT in pixels.
fwhm_psf = 258.21 / 100. / float(Ngrid)
# FWHM -> Gaussian standard deviation: FWHM = 2 sqrt(2 ln 2) sigma ~= 2.35482 sigma
sig_psf = fwhm_psf / 2.35482004503
# isotropic 2x2 covariance of the Gaussian PSF and its inverse (used by psi)
cov_psf = np.identity(2) * sig_psf**2
cinv_psf = np.linalg.inv(cov_psf)

In [5]:
def psi(theta, xy=None, cinv=None): 
    ''' measurement model: a 2d Gaussian PSF centred on `theta`, evaluated at every
    pixel of the flattened grid.

    Parameters
    ----------
    theta : array of shape (2,)
        (x, y) source position in the [0, 1] unit-square coordinates of `xypix`.
    xy : array of shape (Npix, 2), optional
        coordinates to evaluate at; defaults to the global pixel grid `xypix`.
    cinv : array of shape (2, 2), optional
        inverse PSF covariance; defaults to the global `cinv_psf`.

    Returns
    -------
    array of shape (Npix,)
        exp(-0.5 (x - theta)^T cinv (x - theta)) for each coordinate x. The value is 1
        at x = theta (no Gaussian normalization constant is applied).
    '''
    if xy is None:
        xy = xypix
    if cinv is None:
        cinv = cinv_psf
    dxy = xy - np.reshape(theta, (1, 2))
    # vectorized quadratic form: summing row k of (dxy @ cinv) * dxy gives
    # dxy[k] @ cinv @ dxy[k] without a Python loop over pixels. Only autograd-supported
    # operations are used, so the result stays differentiable with respect to theta.
    return np.exp(-0.5 * np.sum(np.dot(dxy, cinv) * dxy, axis=1))

In [None]:
# candidate source positions for the linear minimization oracle (lmo): the same
# Ngrid evenly spaced points on [0, 1] per axis as the pixel grid
theta_xgrid = np.linspace(0., 1., Ngrid)
theta_ygrid = np.linspace(0., 1., Ngrid)
_theta_xxgrid, _theta_yygrid = np.meshgrid(theta_xgrid, theta_ygrid)
# one (x, y) row per candidate position, flattened in row-major order
theta_xygrid = np.stack([_theta_xxgrid.flatten(), _theta_yygrid.flatten()], axis=1)
# row j is the PSF image psi(theta_xygrid[j]); one row per candidate position
grid_psi = np.array([psi(candidate) for candidate in theta_xygrid])

In [None]:
def Psi(ws, thetas): 
    ''' "forward operator", i.e. the forward model 
    
        Psi = int psi(theta) dmu(theta) 

    where mu is the signal parameter. For a discrete measure
    mu = sum_i w_i delta(theta_i) the integral is the weighted sum sum_i w_i psi(theta_i).

    Parameters
    ----------
    ws : sequence of floats
        source weights (intensities), one per source.
    thetas : array
        source positions, either of shape (k, 2) or flattened to shape (2k,)
        (the layout scipy.optimize works with); an empty array means no sources.

    Returns
    -------
    array of shape (len(xypix),)
        the model image on the flattened pixel grid; all zeros when there are no sources.
    '''
    # one (x, y) row per source. Unlike np.atleast_2d this also handles the flattened
    # (2k,) layout, and maps the empty initial support (shape (0,)) to shape (0, 2).
    _thetas = np.reshape(thetas, (-1, 2))
    image = np.zeros(len(xypix))
    # NOTE: zip stops at the shorter of ws / _thetas; callers must pass matching lengths
    for w, tt in zip(ws, _thetas):
        # accumulate out-of-place so autograd can trace the sum
        image = image + w * psi(tt)
    return image

## ADCG (Alternating Descent Conditional Gradient)

In [None]:
def ell(ws, thetas, yobs): 
    ''' loss function: sum of squared residuals between the forward model
    Psi(ws, thetas) and the observed image yobs.

    `thetas` may arrive flattened (1-d with more than two entries, as scipy.optimize
    passes it); in that case it is first reshaped to one (x, y) row per source.
    '''
    arrived_flat = thetas.ndim == 1 and thetas.shape[0] > 2
    if arrived_flat:
        thetas = thetas.reshape((thetas.shape[0] // 2, 2))
    residual = Psi(ws, thetas) - yobs
    return (residual ** 2).sum()


def gradell(ws, thetas, yobs):  
    ''' gradient of the loss function `ell` with respect to the model image.

    ell = sum((Psi(ws, thetas) - yobs)**2), so d ell / d Psi = 2 (Psi(ws, thetas) - yobs).

    `lmo` only takes an argmin of inner products with this vector, which is unchanged
    by positive rescaling. Returning the true gradient therefore selects the same new
    source as any positively rescaled residual. It also avoids the NaN that dividing
    by the squared residual norm would give for a perfect fit.

    Returns
    -------
    array of the same shape as yobs
    '''
    # evaluate the (expensive) forward model only once
    residual = Psi(ws, thetas) - yobs
    return 2. * residual


def lmo(v): 
    ''' step 1 of ADCG: the linear oracle. Returns the grid position

        argmin_theta < psi(theta), v >

    where, in ADCG, v is the gradient of the loss. For simplicity theta is restricted
    to the precomputed candidate grid theta_xygrid, whose PSF images are the rows of
    grid_psi. The search is therefore one matrix-vector product followed by an argmin.
    '''
    inner_products = grid_psi @ v
    best = np.argmin(inner_products)
    return theta_xygrid[best]


def coordinate_descent(thetas, yobs, lossFn, iter=35, min_drop=1e-5, **lossfn_kwargs):  
    ''' step 2 of ADCG: local nonconvex improvement of the support by block coordinate
    descent. Each sweep

      1. fits the weights for the current positions by non-negative least squares, then
      2. keeping those weights fixed, moves the positions with L-BFGS-B (bounded to the
         unit square), using autograd gradients of `lossFn`.

    Sweeps stop after `iter` iterations or once the loss drops by less than `min_drop`.
    Sources are NOT pruned here; zero-weight sources from NNLS are kept.

    Parameters
    ----------
    thetas : array of shape (k, 2) or (2,)
        initial source positions.
    yobs : array of shape (Npix,)
        observed image on the flattened pixel grid.
    lossFn : callable
        lossFn(ws, thetas_flat, yobs, **lossfn_kwargs) -> scalar. It receives the
        positions flattened to shape (2k,) and must be differentiable with autograd.
    iter : int
        maximum number of sweeps.
    min_drop : float
        stop when the loss decreases by less than this between sweeps.

    Returns
    -------
    ws : array of shape (k,)
        weights from the last NNLS fit (computed before the final position update).
    thetas : array
        optimized positions: shape (k, 2) for k > 1, shape (2,) for a single source.
    '''
    def min_ws(): 
        # non-negative least squares; the columns of the design matrix are the PSF images
        design = np.stack([psi(tt) for tt in thetas]).T
        return scipy.optimize.nnls(design, yobs)[0]

    def min_thetas(): 
        # scipy.optimize.minimize requires a 1-d x0, so optimize the flattened positions;
        # lossFn is responsible for reshaping them back to one row per source
        x0 = np.ravel(thetas)
        res = scipy.optimize.minimize(
                Agrad.value_and_grad(lambda tts: lossFn(ws, tts, yobs, **lossfn_kwargs)), x0,
                jac=True, method='L-BFGS-B', bounds=[(0.0, 1.0)] * x0.shape[0])
        # res['fun'] may be a numpy scalar, a 0-d array or a Python float
        return res['x'], float(res['fun'])

    old_f_val = np.inf  # np.Inf was removed in NumPy 2.0
    for _ in range(iter): 
        thetas = np.atleast_2d(thetas)

        ws = min_ws()  # weights that minimize the loss for the current positions

        thetas, f_val = min_thetas()  # keeping weights fixed, minimize loss over positions

        if thetas.shape[0] > 2:
            # back to one (x, y) row per source; a single source stays shape (2,)
            thetas = thetas.reshape((-1, 2))

        if old_f_val - f_val < min_drop:  # loss no longer improving meaningfully
            break 
        old_f_val = f_val
    return ws, thetas 


def adcg(yobs, lossFn, gradlossFn, local_update, max_iters, min_loss_drop=1., **lossfn_kwargs): 
    ''' Alternating Descent Conditional Gradient (ADCG).

    Greedily grows a sparse set of point sources. Each iteration adds the grid position
    returned by `lmo` for the current loss gradient, then re-optimizes all weights and
    positions with `local_update`.

    Parameters
    ----------
    yobs : array of shape (Npix,)
        observed image on the flattened pixel grid.
    lossFn : callable
        lossFn(ws, thetas, yobs, **lossfn_kwargs) -> scalar loss.
    gradlossFn : callable
        gradlossFn(ws, thetas, yobs, **lossfn_kwargs) -> array of shape (Npix,); its
        output is passed to `lmo`.
    local_update : callable
        local_update(thetas, yobs, lossFn, **lossfn_kwargs) -> (ws, thetas), e.g.
        `coordinate_descent`.
    max_iters : int
        maximum number of iterations, i.e. of sources added.
    min_loss_drop : float
        early-stopping threshold (default 1., the previously hard-coded value). From
        the fourth iteration on, stop once the loss recorded at the start of the
        iteration improved by less than this over the previous iteration's.
    **lossfn_kwargs
        forwarded to lossFn, gradlossFn and local_update.

    Returns
    -------
    loss : float
        loss of the returned (ws, thetas).
    ws, thetas
        weights and positions of the recovered sources.
    '''
    thetas, ws = np.zeros(0), np.zeros(0)  # start from an empty support

    history = []  # (loss, ws, thetas) recorded at the start of each iteration
    for i in range(max_iters): 
        loss = lossFn(ws, thetas, yobs, **lossfn_kwargs) 
        print('  iter=%i, loss=%f' % (i, loss)) 
        history.append((loss, ws, thetas))
    
        # gradient of the loss function 
        grad = gradlossFn(ws, thetas, yobs, **lossfn_kwargs) 
        # compute new support point
        theta = lmo(grad)
        # append it as a new (x, y) row; reshaping first also handles the empty initial
        # support and a single source stored with shape (2,)
        _thetas = np.concatenate([np.reshape(thetas, (-1, 2)), np.reshape(theta, (1, 2))], axis=0)

        # jointly refine weights and positions
        ws, thetas = local_update(_thetas, yobs, lossFn, **lossfn_kwargs)

        if (i > 2) and (history[-2][0] - history[-1][0] < min_loss_drop):  
            break

    # report the loss of the model actually returned. The loop's `loss` refers to the
    # state before the last update, and is undefined when max_iters == 0.
    loss = lossFn(ws, thetas, yobs, **lossfn_kwargs)
    return loss, ws, thetas

## example 0: fake data

In [None]:
np.random.seed(0)  # fix the global RNG so the fake data below is reproducible


def obs2d(N_source=5, sig_noise=5., weight=100.):  
    ''' generate a fake 2d observation on the xypix grid.

    Draws `N_source` point-source positions uniformly on the unit square and gives each
    the same weight (intensity) `weight`. It then renders them through the forward model
    `Psi`, which places a PSF at each position, and adds i.i.d. Gaussian noise of
    standard deviation `sig_noise` to every pixel. Uses the global numpy RNG; nothing
    is written to disk.

    Returns
    -------
    thetas : array of shape (N_source, 2)
        true (x, y) positions.
    weights : array of shape (N_source,)
        true weights, all equal to `weight`.
    yobs : array of shape (len(xypix),)
        noisy observed image on the flattened pixel grid.
    '''
    # draw x then y: keep this order so a given seed reproduces the same positions
    thetas = np.array([np.random.uniform(0, 1, N_source), np.random.uniform(0, 1, N_source)]).T
    weights = np.repeat(float(weight), N_source)
    yobs = Psi(weights, thetas) + sig_noise * np.random.randn(len(xypix))
    return thetas, weights, yobs


eg0_thetas, eg0_weights, eg0_data = obs2d()
eg0_true_data = Psi(eg0_weights, eg0_thetas)  # noiseless image, for comparison

In [None]:
# run ADCG on the fake data (at most 30 iterations), then render the recovered model image
loss, ws, thetas = adcg(eg0_data, ell, gradell, coordinate_descent, max_iters=30)
output_adcg = Psi(ws, thetas)

In [None]:
# plot data: noisy observation, noiseless truth, ADCG model and fit residual
panels = [
    (eg0_data, r'$y_{\rm obs}$'),
    (eg0_true_data, r'$y_{\rm true}$'),
    (output_adcg, r'$y_{\rm adcg}$'),
    # the residual is taken with respect to the *observed* (noisy) data
    (eg0_data - output_adcg, r'$y_{\rm obs} - y_{\rm adcg}$'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
for sub, (image, title) in zip(axes, panels):
    # images are stored flattened on the pixel grid; restore the 2d layout
    sub.imshow(image.reshape(_xxpix.shape))
    sub.set_title(title, fontsize=20)
plt.show()

## example 1: BTLS (bundled tubes, long sequence) SMLM data

In [None]:
# root of the bundled-tubes long-sequence SMLM dataset. Set the SMLM_BTLS_DIR
# environment variable to point elsewhere instead of editing this absolute local path.
dir_tub = os.environ.get('SMLM_BTLS_DIR', '/Users/ChangHoon/data/locahbay/smlm/bundled_tubes_long_seq/')

In [None]:
def read_frame(i): 
    ''' read frame `i` of the image sequence and subtract a background level.

    Loads `<dir_tub>/sequence/<i zero-padded to 5 digits>.tif`. The background is
    estimated as the median of the pixel values that survive sigma clipping
    (scipy.stats.sigmaclip with high=3. and its default lower bound).

    Returns
    -------
    array
        the frame's pixel array minus the background estimate.
    '''
    f_frame = os.path.join(dir_tub, 'sequence', '%s.tif' % str(i).zfill(5))

    # close the file handle as soon as the pixels are copied into an array
    with Image.open(f_frame) as im:
        imarr = np.array(im)
    
    # background subtraction 
    noise_level = np.median(stats.sigmaclip(imarr.flatten(), high=3.)[0])
    return imarr - noise_level

def read_fluorophorses(i, nm_per_pixel=100.): 
    ''' read the ground-truth fluorophore positions for frame `i`.

    Loads columns 2-5 of `<dir_tub>/fluorophores/frames/<i zero-padded to 5 digits>.csv`,
    skipping one header row. Columns 2-4 are x, y, z in nm; they are converted to pixels
    by dividing by `nm_per_pixel`. Column 5 is read but not used.

    Returns
    -------
    array of shape (3, n_fluorophores)
        rows are x, y, z in pixels; n_fluorophores may be 0.
    '''
    f_fluor = os.path.join(dir_tub, 'fluorophores', 'frames', '%s.csv' % str(i).zfill(5)) 
    try: 
        # ndmin=2 keeps each unpacked column 1-d even when the file has a single row;
        # otherwise x, y, z would be scalars and truth[0, :] downstream would fail
        x, y, z, _ = np.loadtxt(f_fluor, delimiter=',', skiprows=1, unpack=True,
                                usecols=[2, 3, 4, 5], ndmin=2)  # positions in nm
    except ValueError: 
        # no fluorophores
        x = np.array([])
        y = np.array([])
        z = np.array([])
        
    # convert to pixels
    return np.array([x / nm_per_pixel, y / nm_per_pixel, z / nm_per_pixel])

In [None]:
# ADCG outputs (model images and diagnostic plots) are written next to the input sequence
dir_adcg = os.path.join(dir_tub, 'adcg')
os.makedirs(dir_adcg, exist_ok=True)

for i in range(1, 101): 
    print('Frame %i' % i)
    frame = read_frame(i)
    truth = read_fluorophorses(i)
    # the forward model is defined on the Ngrid x Ngrid pixel grid
    assert frame.shape == _xxpix.shape, frame.shape

    loss, ws, thetas = adcg(frame.flatten(), ell, gradell, coordinate_descent, 5)
    output_adcg = Psi(ws, thetas) 
    output_img = output_adcg.reshape(_xxpix.shape)

    # writeout the ADCG model image
    Image.fromarray(output_img).save(os.path.join(dir_adcg, '%s.tiff' % str(i).zfill(5)))

    # plot data, model (true fluorophore positions circled) and residual
    fig, (ax_obs, ax_adcg, ax_res) = plt.subplots(1, 3, figsize=(15, 5))
    ax_obs.imshow(frame)
    ax_obs.set_title(r'$y_{\rm obs}$', fontsize=20) 

    ax_adcg.imshow(output_img)
    ax_adcg.scatter(truth[0, :], truth[1, :], s=200, facecolors='none', edgecolor='r', linewidths=4)   
    ax_adcg.set_title(r'$y_{\rm adcg}$', fontsize=20) 

    # residual with respect to the observed frame
    ax_res.imshow(frame - output_img)
    ax_res.set_title(r'$y_{\rm obs} - y_{\rm adcg}$', fontsize=20) 
    fig.savefig(os.path.join(dir_adcg, '%s.png' % str(i).zfill(5)))
    plt.close(fig)  # release each figure so 100 frames do not accumulate open figures