In [None]:
import pandas as pd
import seaborn as sns
import h5py
import numpy as np

df = pd.read_hdf('/media/Backup/smlm_z_data/20240625_NUP_ifluor647/FOV1/storm_1/storm_1_MMStack_Default.ome_locs_undrift.hdf5', key='locs')

with h5py.File('/media/Backup/smlm_z_data/20240625_NUP_ifluor647/FOV1/storm_1/storm_1_MMStack_Default.ome_spots.hdf5') as f:
    spots = np.array(f['spots'])


In [None]:
from scipy.optimize import curve_fit
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import erf
from skimage.transform import resize

def mean_squared_error(x, y):
    """Return the mean of the element-wise squared difference between ``x`` and ``y``."""
    residual = x - y
    return np.mean(residual ** 2)

def reduce_img(stack):
    """Collapse axes 1 and 2 of ``stack`` by taking the maximum over them."""
    peak_per_frame = stack.max(axis=(1, 2))
    return peak_per_frame
    
def get_lat_fwhm(image, px_size_xy, debug=False, mse_thres=0.001):
    """Estimate the lateral FWHM of a spot by fitting a rotated 2D Gaussian.

    The image is divided by its maximum, a 7-parameter Gaussian
    (amplitude, centre, two sigmas, rotation, offset) is fitted with
    ``scipy.optimize.curve_fit`` under box bounds, and the two fitted sigmas
    are converted to FWHM and scaled by ``px_size_xy``.

    Args:
        image: 2D array. The pixel grid is built from ``image.shape[1]`` only,
            so the image is assumed to be square.
        px_size_xy: factor applied to the fitted sigmas (in pixels); the debug
            output labels the result as nm.
        debug: when True, print the fitted values and show the normalised image.
        mse_thres: currently unused (the thresholding code is disabled); kept
            so existing callers keep working.

    Returns:
        ``(fwhm_xy, error)`` where ``fwhm_xy`` is the mean of the two FWHM
        values and ``error`` is the mean squared error between the rendered
        model and the normalised image. ``fwhm_xy`` is NaN when the fit does
        not converge; both values are NaN when the image cannot be normalised
        (zero or non-finite maximum).
    """
    def gaussian_2d(xy, amplitude, xo, yo, sigma_x, sigma_y, theta, offset):
        x, y = xy
        xo = float(xo)
        yo = float(yo)
        a = (np.cos(theta)**2) / (2 * sigma_x**2) + (np.sin(theta)**2) / (2 * sigma_y**2)
        b = -(np.sin(2 * theta)) / (4 * sigma_x**2) + (np.sin(2 * theta)) / (4 * sigma_y**2)
        c = (np.sin(theta)**2) / (2 * sigma_x**2) + (np.cos(theta)**2) / (2 * sigma_y**2)
        g = offset + amplitude * np.exp(- (a * ((x - xo)**2) + 2 * b * (x - xo) * (y - yo) + c * ((y - yo)**2)))
        return g.ravel()

    # Pixel-coordinate grid matching the (square) image.
    image_size = image.shape[1]
    x = np.linspace(0, image_size - 1, image_size)
    y = np.linspace(0, image_size - 1, image_size)
    x, y = np.meshgrid(x, y)

    # Dividing by a zero / non-finite peak would fill the image with NaN/inf
    # and make curve_fit raise; report "no measurement" instead.
    peak = image.max()
    if not np.isfinite(peak) or peak == 0:
        return np.nan, np.nan
    image = image / peak

    # Initial guess and box bounds, in the parameter order of gaussian_2d.
    p0 = [1, image_size / 2, image_size / 2, 2, 2, 0, 0]
    bounds = [
        (0, np.inf),
        (image_size * (1/5), image_size * (4/5)),
        (image_size * (1/5), image_size * (4/5)),
        (0, image_size/3),
        (0, image_size/3),
        (-np.inf, np.inf),
        (0, np.inf),
    ]

    fit_converged = True
    try:
        popt, _pcov = curve_fit(gaussian_2d, (x, y), image.ravel(), p0=p0, bounds=list(zip(*bounds)))
    except RuntimeError:
        # The optimiser gave up. Keep the initial guess so an error value can
        # still be rendered, but do not report its sigmas as a measurement.
        popt = p0
        fit_converged = False
    render = gaussian_2d((x, y), *popt).reshape(image.shape)

    # Computed inline so the result does not depend on which
    # `mean_squared_error` happens to be bound in the notebook namespace.
    error = np.mean((render - image) ** 2)

    if fit_converged:
        amplitude, xo, yo, sigma_x, sigma_y, theta, offset = popt
        f = 2 * np.sqrt(2 * np.log(2))  # sigma -> FWHM conversion factor
        fwhm_x = sigma_x * f * px_size_xy
        fwhm_y = sigma_y * f * px_size_xy
    else:
        fwhm_x = fwhm_y = np.nan
    fwhm_xy = np.mean([fwhm_x, fwhm_y])

    if debug:
        plt.figure(figsize=(2,2))
        print('FWHM x:', round(fwhm_x, 3), 'nm')
        print('FWHM y:', round(fwhm_y, 3), 'nm')
        print('MSE   :', '{:.2e}'.format(error))
        plt.imshow(image)
        plt.show()
        print('\n')
    return fwhm_xy, error

from tqdm import tqdm
spots_fwhm = [get_lat_fwhm(x, 106) for x in tqdm(spots)]


In [None]:
sns.scatterplot(data=df.groupby('frame').mean(), x='frame', y='z [nm]', alpha=0.01)

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential
from tensorflow.keras import optimizers
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback
from tensorflow.keras.layers import Input, Dense, Flatten, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras import Sequential, layers

In [None]:
# Build a two-input Keras model: an image branch (pretrained backbone) and a
# 2-value coordinate branch, concatenated and fed to a small dense head that
# outputs a single value.
args = {
    'image_size': 32,
    'architecture': 'mobilenet',
    'dense1': 32,
    'dense2': 32
}

# Single-channel input is expanded to 3 channels before the backbone, which is
# built below with a 3-channel input_shape.
imshape = (args['image_size'], args['image_size'], 1)
img_input = Input(shape=imshape, name='img_input')
gray_to_rgb = layers.Lambda(tf.image.grayscale_to_rgb, name='gray_to_rgb', output_shape=(args['image_size'], args['image_size'], 3))
img = gray_to_rgb(img_input)


# Coordinate branch: two Dense(64) layers with no activation argument.
coords_input = layers.Input((2,))
x_coords = layers.Dense(64)(coords_input)

x_coords = layers.Dense(64)(x_coords)

# Backbone constructor lookup; an architecture name that is not a key here
# raises KeyError at this point.
model_version = {
    'mobilenet': keras.applications.MobileNetV3Small,
    'mobilenet_large': keras.applications.MobileNetV3Large,
    'vgg': keras.applications.VGG19,
    'resnet': keras.applications.ResNet50V2,
    'resnet_large': keras.applications.ResNet101V2,
}[args['architecture']]

# NOTE(review): none of the keys above contains 'vit_', so the lookup would
# already have raised KeyError for such a name and the first branch below
# cannot be reached as written.
if 'vit_' in args['architecture']:
    feat_model = model_version(image_size=args['image_size'], 
                            activation='sigmoid',
                            pretrained=True,
                            include_top=False,
                            pretrained_top=False)
else:
    feat_model = model_version(input_shape=(args['image_size'], args['image_size'], 3),
                              weights='imagenet',
                              include_top=False)

x = feat_model(img)
# Add additional layers for regression prediction
x = Flatten()(x)
# Join the flattened image features with the coordinate features.
x = tf.concat([x, x_coords], axis=-1)
x = Dense(args['dense1'], activation='gelu')(x)
x = Dropout(0.5)(x)
# The second dense block is skipped when 'dense2' is 0.
if args['dense2'] != 0:
    x = Dense(args['dense2'], activation='gelu')(x)
    x = Dropout(0.5)(x)
regression_output = Dense(1, activation='tanh')(x)  # tanh (not linear): the output is bounded to (-1, 1)
model = Model(inputs=[img_input, coords_input], outputs=regression_output)

# aug_model = Model(inputs=img_input, outputs=img_aug_out)

In [None]:
import numpy as np
tmp = np.random.uniform(size=(5, 32, 32, 1))
coords = np.random.uniform(size=(5, 2))

model((tmp, coords))

In [None]:
import tensorflow as tf
import os
import sys
sys.path.insert(0,'../publication')

from util.util import _apply_img_norm, get_model_report, get_model_img_norm

os.environ['CUDA_VISIBLE_DEVICES'] = ''

model_path = '/home/miguel/Projects/smlm_z/autofocus/VIT_zeiss_mobilenet_fov_max/out/'
model = tf.keras.models.load_model(model_path + '/latest_vit_model3')

model_report = get_model_report(model_path)
img_norm = get_model_img_norm(model_report)


dataset = tf.data.Dataset.load('/home/miguel/Projects/smlm_z/autofocus/VIT_zeiss_mobilenet_fov_max/out/test/')
from tifffile import imread


In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

class RandomGaussianNoise(keras.layers.Layer):
    """Training-time augmentation that adds floored Gaussian noise to some samples.

    For every sample in the batch a noise mean and standard deviation are
    drawn uniformly from ``mean_range`` and ``std_range``. With probability
    ``perc_chance`` the sample receives ``floor(N(mean, std))`` noise,
    otherwise it is left unchanged. The whole batch is then divided by its
    per-sample maximum and passed through ReLU. Inputs are assumed to be
    4D (reductions run over axes 1, 2, 3).

    NOTE(review): flooring the noise quantises it to integers before the
    rescale; kept from the original implementation - confirm it is intended.
    """

    def __init__(self, mean_range=(-0.1, 0.1), std_range=(0.0, 0.5), perc_chance = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.mean_range = mean_range
        self.std_range = std_range
        self.perc = perc_chance

    def call(self, inputs, training=None):
        """Return ``inputs`` unchanged unless ``training`` is truthy."""
        if not training:
            return inputs

        # One random value per sample, shaped to broadcast over axes 1-3.
        # Using the dynamic batch size avoids a Python loop over
        # `inputs.shape[0]`, which fails when the batch dimension is unknown.
        per_sample = tf.stack([tf.shape(inputs)[0], 1, 1, 1])
        means = tf.random.uniform(per_sample, minval=self.mean_range[0], maxval=self.mean_range[1])
        stds = tf.random.uniform(per_sample, minval=self.std_range[0], maxval=self.std_range[1])
        prob = tf.random.uniform(per_sample, minval=0, maxval=1)

        # floor(mean + std * N(0, 1)) has the same distribution as floor(N(mean, std)).
        noise = tf.math.floor(means + stds * tf.random.normal(tf.shape(inputs)))
        noise = tf.where(prob <= self.perc, noise, tf.zeros_like(noise))

        output = inputs + noise
        maxs = tf.math.reduce_max(output, axis=(1,2,3), keepdims=True)
        return tf.nn.relu(output / maxs)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            "mean_range": self.mean_range,
            "std_range": self.std_range,
            # The value is stored on `self.perc`; there is no `self.perc_chance`.
            'perc_chance': self.perc
        })
        return config
    

def tamper_imgs(imgs):
    """Run ``imgs`` through a freshly built ``RandomGaussianNoise`` layer in training mode.

    Returns the layer output converted to a NumPy array.
    """
    noise_layer = RandomGaussianNoise(mean_range=(0.5, 1), std_range=(0.1, 0.3), perc_chance=.5)
    return noise_layer(imgs, training=True).numpy()

def eval_imgs(s):
    """Show the frames of ``s`` as a ``grid_psfs`` image titled with their mean SNR.

    ``s`` is averaged over its last axis before being passed to ``grid_psfs``.
    """
    per_frame_snr = np.array([snr(frame) for frame in s])
    title = str(round(per_frame_snr.mean(), 3))

    channel_mean = s.mean(axis=-1)
    plt.imshow(grid_psfs(channel_mean).T)
    plt.title(title)
    plt.show()
    
# Preview the noise augmentation on the first stack only (note the `break`).
# NOTE(review): `stacks` and `apply_img_norm` are not defined in this cell;
# they must come from earlier kernel state - confirm before Run All.
for s in stacks:
    idx = np.argmax(s.max(axis=(1,2)))
    # Clamp the window start at 0: a negative start would be interpreted
    # relative to the end of the stack and yield an empty/wrong slice when the
    # brightest frame lies within the first 100 frames.
    s = s[max(idx-100, 0):idx+100, :, :, np.newaxis]
    s = apply_img_norm(s).numpy()
    eval_imgs(s)
    for _ in range(5):
        imgs = tamper_imgs(s)
        print(imgs.sum())
        eval_imgs(imgs)
    
    break

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error
from data.visualise import grid_psfs
from tensorflow.keras import Sequential, layers

import os

def norm_imgs(imgs, img_norm):
    """Normalise ``imgs`` via ``_apply_img_norm`` and return them as a NumPy array.

    ``_apply_img_norm`` takes an ``(images, coords)`` pair and a ``z`` value;
    only the images matter here, so ``None`` is passed explicitly for the other
    two instead of relying on IPython's last-output variable ``_``.
    """
    imgs_xy, _z = _apply_img_norm((imgs, None), None, img_norm)
    return imgs_xy[0].numpy()


zrange = 1000
n_images = 200


def snr(img):
    """Return the ratio of the maximum of ``img`` to its mean."""
    peak = img.max()
    average = img.mean()
    return peak / average


args = {
    'aug_gauss': 0.5,
    'aug_brightness': 0.1,
    'aug_poisson_lam': 1000,
    'seed': 42
}

def tamper_imgs(imgs):
    """Return a copy of ``imgs`` with Gaussian noise added and negatives clipped to 0.

    One noise map of shape ``(H, W, 1)`` is drawn per image from
    ``N(0.01, 0.05)`` and broadcast over the last axis. The input is not
    modified and the output keeps the input dtype.
    """
    imgs2 = np.array(imgs)
    # Draw the noise for the whole batch at once and clip a single time; the
    # previous per-image loop re-clipped the entire batch on every iteration.
    noise = np.random.normal(0.01, 0.05, size=(imgs2.shape[0], *imgs2.shape[1:3], 1))
    imgs2 += noise
    np.maximum(imgs2, 0, out=imgs2)
    return imgs2
    
def eval_images(model, img_norm, imgs, xy, zs):
    """Normalise ``imgs``, predict z with ``model`` and plot a three-panel summary.

    Panels: true vs predicted z, a ``grid_psfs`` view of the normalised images
    (averaged over the last axis), and the per-image max/min intensity.
    Predictions are multiplied by the notebook-level ``zrange``. The figure
    title reports the mean absolute error and the mean SNR.
    """
    imgs = norm_imgs(imgs, img_norm)
    print(imgs.min(), imgs.max())

    snrs = str(round(np.mean([snr(x) for x in imgs]), 3))
    fig, axs = plt.subplots(1, 3)
    z_pred = model.predict((imgs, xy)) * zrange
    err = str(round(mean_absolute_error(zs, z_pred), 3))
    # Title the whole figure; `plt.title` would only label the last subplot.
    fig.suptitle(f'mae: {err}, snr: {snrs}')
    axs[0].scatter(zs, z_pred)
    axs[0].set_xlabel('True z (nm)')
    axs[0].set_ylabel('Pred z (nm)')
    axs[1].imshow(grid_psfs(imgs.mean(axis=-1)))
    axs[2].plot(imgs.max(axis=(1,2,3)))
    axs[2].plot(imgs.min(axis=(1,2,3)))
    plt.show()
    




import tensorflow as tf
dataset = tf.data.Dataset.load('/home/miguel/Projects/smlm_z/autofocus/VIT_zeiss_mobilenet_fov_max/out/test/')


for ((imgs, xy), zs) in dataset.as_numpy_iterator():
    zs = np.array(zs) * zrange
    zs = zs[:n_images]
    imgs = imgs[:n_images]
    xy = xy[:n_images]
    # eval_images(imgs, xy, zs)
    tampered_imgs = [tamper_imgs(imgs) for _ in range(5)]

    for _ in range(2):
        imgs2 = tamper_imgs(imgs)

        # eval_images(imgs2, xy, zs)
    break

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = ''

import sys
sys.path.insert(0,'../publication')

import gc
from util.util import _apply_img_norm, get_model_report, get_model_img_norm



from glob import glob
models = glob('/home/miguel/Projects/smlm_z/autofocus/VIT_zeiss_mobilenet/out_standard_aug*')
for model_path in models:
    imgs_model = np.array(imgs)
    tamp_imgs_model = np.array(tampered_imgs)
    
    model = tf.keras.models.load_model(os.path.join(model_path,  'latest_vit_model3'))
    
    model_report = get_model_report(model_path)
    img_norm = get_model_img_norm(model_report)
    print(img_norm)
    eval_images(model, img_norm, imgs, xy, zs)

    for tamp_imgs in tamp_imgs_model:
        eval_images(model, img_norm, tamp_imgs, xy, zs)
    del model
    gc.collect()
    print('------------'*10)

    



In [None]:
from tifffile import imread
import h5py
import numpy as np
import pandas as pd

stacks = imread('/home/miguel/Projects/smlm_z/autofocus/VIT_zeiss_resnet2/stacks.ome.tif')

with h5py.File('/media/Backup/smlm_z_data/20240606_bacteria_Miguel_zeiss/15min_timelapse_20nm_red_beads_every_1sec_1/15min_timelapse_20nm_red_beads_every_1sec_1_MMStack_Default.ome_spots.hdf5') as f:
    timelapse_spots = np.array(f['spots'])

timelapse_locs = pd.read_hdf('/media/Backup/smlm_z_data/20240606_bacteria_Miguel_zeiss/15min_timelapse_20nm_red_beads_every_1sec_1/15min_timelapse_20nm_red_beads_every_1sec_1_MMStack_Default.ome_locs.hdf5', key='locs')

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
bead_imgs = []
for s in stacks:
    peak = np.argmax(s.max(axis=(1,2)))
    bead_imgs.append(s[peak-20:peak+20])

bead_imgs = np.stack(bead_imgs)
import h5py


args = {
    'baseline': 100,
    'sensitivity': 0.45,
    'gain': 1
}

timelapse_spots = (timelapse_spots * args['gain'] / args['sensitivity']) + args['baseline']

In [None]:
from tqdm import tqdm
aug_imgs = []
train_imgs = np.concatenate(bead_imgs)
n_aug_points = train_imgs.shape[0]
idx = np.random.randint(0, train_imgs.shape[0], size=n_aug_points)
for i in tqdm(idx):
    img = train_imgs[i]
    noise_mean = np.random.uniform(0, 1.2)
    img_range = img.max() - img.min()
    noise = np.random.normal(img.max() * noise_mean, img_range * np.random.uniform(0.1, 0.7), size=img.shape)
    aug_imgs.append(img+noise)

aug_imgs = np.stack(aug_imgs)
print(aug_imgs.shape)

In [None]:
def snr(img):
    """Return max(img) / mean(img), used in this notebook as a simple SNR proxy."""
    brightest = img.max()
    return brightest / img.mean()

stack_snrs = np.array([snr(x) for x in np.concatenate(bead_imgs)])
aug_stack_snrs = np.array([snr(x) for x in aug_imgs])
timelapse_snrs = np.array([snr(x) for x in timelapse_spots])


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 3))
frame_bins = np.linspace(0, timelapse_locs['frame'].max(), 5)
for i in range(len(frame_bins)-1):
    f_min = int(frame_bins[i])
    f_max = int(frame_bins[i+1])
    idx = np.argwhere((f_min <= timelapse_locs['frame']) & (timelapse_locs['frame'] <= f_max)).squeeze()
    sns.histplot(timelapse_snrs[idx], label=f'timelapse [{f_min}, {f_max}]', stat="probability", alpha=0.5)

sns.histplot(stack_snrs, label='stack', stat="probability", alpha=0.5)
sns.histplot(aug_stack_snrs, label='stack_aug', stat="probability", alpha=0.5)
plt.legend()
plt.yscale('log')
plt.xlabel('SNR (img_max  / img_mean)')
plt.title('SNR of timelapse at various intervals of frames vs training data')

In [None]:
import sys
sys.path.insert(0,'../publication')
import tensorflow as tf

from util.util import get_model_report, get_model_img_norm, show_psf_axial, _apply_img_norm

def norm_imgs(imgs, img_norm):
    """Normalise ``imgs`` via ``_apply_img_norm`` and return them as a NumPy array.

    ``None`` is passed for the coordinate slot and for ``z``; the previous code
    passed IPython's last-output variable ``_``, whose value depends on
    whatever cell was run last.
    """
    imgs_xy, _z = _apply_img_norm((imgs, None), None, img_norm)
    return imgs_xy[0].numpy()
    
def snr(x):
    """Return max/mean of ``x``; for non-2D input, the mean of the per-item ratios.

    Prints ``x.shape`` on every call. Non-2D input is reduced over axes
    1, 2 and 3, so it must be 4D.
    """
    print(x.shape)
    if x.ndim != 2:
        per_item = x.max(axis=(1,2,3)) / x.mean(axis=(1,2,3))
        return np.mean(per_item)
    return x.max() / x.mean()

def add_noise(x, mean):
    """Return ``x`` plus Gaussian noise scaled to the intensity range of ``x``.

    The noise is centred on ``x.max() * mean``; its standard deviation is the
    range of ``x`` times a factor drawn uniformly from [0.1, 0.6).
    """
    centre = x.max() * mean
    spread = (x.max() - x.min()) * np.random.uniform(0.1, 0.6)
    return x + np.random.normal(centre, spread, size=x.shape)

def plot_ax_max(x, label, img_norm):
    """Normalise ``x`` with ``img_norm`` and pass its last-axis mean to ``show_psf_axial``.

    ``label`` is accepted for call compatibility but is not used, so nothing
    drawn here carries a legend entry.
    """
    x = norm_imgs(x, img_norm)
    show_psf_axial(x.mean(axis=-1))
    
test_stack = bead_imgs[0, :, :, :, np.newaxis]
for img_norm in ['standard']:
    plot_ax_max(test_stack, 'orig', img_norm)
    for mean in np.linspace(0, 1, 5):
        noise_img = add_noise(test_stack, mean)
        plot_ax_max(noise_img, str(round(mean, 3)), img_norm)
        
    plt.legend()
    plt.title(img_norm)
    plt.show()
    


In [None]:
import sys
sys.path.insert(0,'../publication')
import tensorflow as tf


from util.util import get_model_report, get_model_img_norm


def _apply_img_norm(img_xy, z, img_norm):
    """Normalise the image part of an ``(images, coords)`` pair.

    Supported ``img_norm`` values:
      * ``'frame-mean'`` / ``'frame-min'``: subtract the per-frame mean / min,
        divide by the per-frame max of the shifted frame, then ReLU.
      * ``'frame-max'``: divide by the per-frame max.
      * ``'fov-max'``: divide by the global max, then ReLU.
      * ``'fov-minmax'``: global min-max scaling.
      * ``'standard'``: subtract the global mean, divide by the global std.

    Per-frame statistics are taken over axes 1, 2 and 3. Any other value
    prints a message and raises ``NotImplementedError``.

    Returns:
        ``((normalised_images, coords), z)`` with ``coords`` and ``z`` passed
        through unchanged.
    """
    img_xy = list(img_xy)
    imgs = img_xy[0]
    frame_axes = (1, 2, 3)

    if img_norm in ('frame-mean', 'frame-min'):
        if img_norm == 'frame-mean':
            shift = tf.math.reduce_mean(imgs, axis=frame_axes, keepdims=True)
        else:
            shift = tf.math.reduce_min(imgs, axis=frame_axes, keepdims=True)
        imgs -= shift
        # The max is taken after the shift, as in the per-branch original.
        frame_peak = tf.math.reduce_max(imgs, axis=frame_axes, keepdims=True)
        imgs = tf.nn.relu(imgs / frame_peak)
    elif img_norm == 'frame-max':
        imgs = imgs / tf.math.reduce_max(imgs, axis=frame_axes, keepdims=True)
    elif img_norm == 'fov-max':
        imgs = tf.nn.relu(imgs / tf.math.reduce_max(imgs))
    elif img_norm == 'fov-minmax':
        hi = tf.math.reduce_max(imgs)
        lo = tf.math.reduce_min(imgs)
        imgs = (imgs - lo) / (hi - lo)
    elif img_norm == 'standard':
        centre = tf.math.reduce_mean(imgs)
        spread = tf.math.reduce_std(imgs)
        imgs = (imgs - centre) / spread
    else:
        print(f'img_norm: {img_norm} not supported')
        raise NotImplementedError()
    return (imgs, img_xy[1]), z
    
def norm_imgs(imgs, img_norm):
    """Normalise ``imgs`` via ``_apply_img_norm`` and return them as a NumPy array.

    ``_apply_img_norm`` passes the second element of the pair and ``z``
    straight through, so ``None`` is supplied for both instead of IPython's
    last-output variable ``_``.
    """
    imgs_xy, _z = _apply_img_norm((imgs, None), None, img_norm)
    return imgs_xy[0].numpy()


In [None]:
import seaborn as sns
import numpy as np

def sample(c):
    """Draw 1000 entries of ``c`` along axis 0 (with replacement).

    A new axis is inserted at position 3 of the result.
    """
    candidates = np.arange(c.shape[0])
    picked = np.random.choice(candidates, size=1000)
    return c[picked, :, :, np.newaxis]

for norm in ['frame-min', 'frame-mean', 'frame-max', 'fov-max', 'fov-minmax', 'standard']:
    for label, c in (('beads', sample(bead_imgs)), ('timelapse', sample(timelapse_spots))):
        norm_c = norm_imgs(np.array(c), norm)
        sns.histplot(norm_c.flatten(), label=label, stat="probability")
    plt.title(norm)
    plt.legend()
    plt.show()


In [None]:
for norm in reversed(['frame-min', 'frame-mean', 'frame-max', 'fov-max', 'fov-minmax', 'standard']):
    outdir = f'./out_{norm}'
    print(f'python3 ../../publication/train_model.py -o {outdir} --architecture=mobilenet --aug-brightness=0 --aug-gauss=0 --aug-poisson-lam=0 --batch_size=4096 --dataset=20240603_Miguel_Zeiss_Stacks --dense1=256 --dense2=32 --learning_rate=0.0001 --system=zeiss --project=autofocus --zrange=1000 --norm {norm};')

In [None]:
datasets = [
    ('20230601_MQ_celltype_beads', 20, '/home/miguel/Projects/data/20230601_MQ_celltype/20230601_MQ_celltype_beads/combined/stacks.ome.tif'),
    ('20231020_20nm_beads_10um_range_10nm_step', 10, '/home/miguel/Projects/data/20231020_20nm_beads_10um_range_10nm_step/combined/stacks.ome.tif'),
    ('20231128_tubulin_miguel', 10, '/home/miguel/Projects/data/20231128_tubulin_miguel/combined/stacks.ome.tif'),
    ('20231205_miguel_mitochondria', 10, '/home/miguel/Projects/data/all_openframe_beads/20231205_miguel_mitochondria/combined/stacks.ome.tif'),
    ('20231212_miguel_openframe', 10, '/home/miguel/Projects/data/all_openframe_beads/20231212_miguel_openframe/combined/stacks.ome.tif'),
    ('20240510_Miguel_beads_20nm', 10, '/media/Backup/smlm_z_data/20240510_Miguel_beads/zstacks/20nm_red/stacks/combined/stacks.ome.tif'),
    ('20240510_Miguel_beads_100nm', 10, '/media/Backup/smlm_z_data/20240510_Miguel_beads/zstacks/100nm_tetraspeck/beads/combined/stacks.ome.tif'),
]
from tifffile import imread


In [None]:
import seaborn as sns
stacks = imread(datasets[-1][2])
print(stacks.shape)




In [None]:
import matplotlib.pyplot as plt
from skimage.transform import resize
from sklearn.metrics import euclidean_distances

def get_central_bead_idx(df, i=1):
    """Return the positional index of the ``i``-th closest bead to the field centre.

    The centre is taken as ``(max(x) / 2, max(y) / 2)``. As a side effect the
    distance of every row to that centre is stored in ``df['dist']``.

    Args:
        df: DataFrame with ``x`` and ``y`` columns.
        i: rank of the bead to return (0 = closest). Defaults to 1, the value
            this function previously hard-coded; accepting it keeps the
            function call-compatible with the two-argument definition used in
            a later cell.
    """
    centre_x = df['x'].max() / 2
    centre_y = df['y'].max() / 2
    df['dist'] = np.hypot(df['x'].to_numpy() - centre_x, df['y'].to_numpy() - centre_y)
    ref_idx = np.argsort(df['dist'].to_numpy())[i]
    return ref_idx

stack_size = 400
def subsample_stack(s):
    """Return up to ``stack_size`` frames of ``s`` centred on its brightest frame.

    The brightest frame is the one with the largest maximum over axes 1 and 2.
    The window is truncated at the start and end of the stack, so fewer than
    ``stack_size`` frames are returned when the peak is close to either edge.
    """
    inten = s.max(axis=(1,2))
    peak = int(np.argmax(inten))
    half_stack = stack_size // 2
    # A negative start would be interpreted relative to the end of the stack
    # and produce an empty or wrong slice; clamp it at the first frame.
    start = max(peak - half_stack, 0)
    return s[start:peak + half_stack]

def resize_side_profile(img):
    """Resize a 2D side-view image to a quarter of its rows and 30 columns.

    The row count is computed with integer division (minimum 1) so that
    ``resize`` always receives an integer output shape, including when the
    number of rows is not a multiple of 4.
    """
    n_rows = max(img.shape[0] // 4, 1)
    return resize(img, (n_rows, 30))
        
def gen_xyz_profiles(s):
    """Plot three sum-projections of a bead stack side by side.

    The stack is first cropped with ``subsample_stack``. The first panel sums
    over axis 0; the other two sum over axis 1 and axis 2 respectively and are
    rescaled with ``resize_side_profile``.

    NOTE(review): the 'ZY'/'ZX' titles assume a particular axis order of the
    stack - confirm they match the axes actually being collapsed.
    """
    s = subsample_stack(s)
    print('Stack shape', s.shape)

    fig, axs = plt.subplots(1, 3, layout='constrained', figsize=(10, 4))

    # Top view: collapse the first axis.
    axs[0].imshow(s.sum(axis=(0)))
    axs[0].set_title('XY')

    # Side views: collapse one of the two remaining axes each.
    for ax, collapsed_axis, title in ((axs[1], 1, 'ZY'), (axs[2], 2, 'ZX')):
        ax.imshow(resize_side_profile(s.sum(axis=collapsed_axis)))
        ax.set_title(title)
    plt.show()

for stack in [
    '/home/miguel/Projects/data/all_openframe_beads/20231205_miguel_mitochondria/combined/stacks.ome.tif',
    '/media/Backup/smlm_z_data/20240510_Miguel_beads/zstacks/20nm_red/stacks/combined/stacks.ome.tif'
    ]:
    print(stack)
    locs = stack.replace('stacks.ome.tif', 'locs.hdf')
    locs = pd.read_hdf(locs, key='locs')
    idx = get_central_bead_idx(locs)
    gen_xyz_profiles(imread(stack)[idx])



In [None]:
s = stacks.max(axis=(2,3))
print(s.shape)
for x in s[0:50]:
    plt.plot(norm_zero_one(x))
plt.show()

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm
def get_fwhm(s):
    """Return the width, in frames, of the region above half maximum.

    The profile is the per-frame maximum over axes 1 and 2, shifted so its
    minimum is 0. The result is the distance between the first and the last
    frame whose value exceeds half of the profile maximum, returned as a
    length-1 array.
    """
    z_profile = s.max(axis=(1, 2))
    z_profile -= z_profile.min()
    threshold = z_profile.max() / 2

    hits = np.argwhere(z_profile > threshold)
    return hits[-1] - hits[0]
    
    

data_fwhms = []
for _, zstep, fpath in datasets:
    ds = imread(fpath)
    fwhms = []
    for s in tqdm(ds):
        fwhms.append(get_fwhm(s) * zstep)
    fwhms = np.array(fwhms).squeeze()
    data_fwhms.append(fwhms)
    

In [None]:
print(data_fwhms[0].shape)
plt.boxplot(data_fwhms, labels=[x[0] for x in datasets])
plt.title('FWHM measured on cut-out bead stacks from various datasets')
plt.ylabel('FWHM (z) in nm')
plt.xticks(rotation=90)

In [None]:
from tifffile import imread
import pandas as pd
from sklearn.metrics import euclidean_distances
import os
import numpy as np
from keras.metrics import mean_squared_error
from tqdm import trange
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import tensorflow as tf

stacks = imread('/media/Backup/smlm_z_data/20240510_Miguel_beads/zstacks/20nm_red/stacks/combined/stacks.ome.tif')
locs = pd.read_hdf('/media/Backup/smlm_z_data/20240510_Miguel_beads/zstacks/20nm_red/stacks/combined/locs.hdf', key='locs')

def norm_zero_one(s):
    """Linearly rescale ``s`` so its minimum maps to 0 and its maximum to 1."""
    lo, hi = s.min(), s.max()
    return (s - lo) / (hi - lo)


def realign_beads(psfs, df, z_step, i=0):
    """Re-estimate per-bead offsets by rolling each PSF against a reference bead.

    The reference is the bead returned by ``get_central_bead_idx(df, i)``.
    Every other PSF is min-max normalised and matched to the reference with
    ``tf_find_optimal_roll``; the negated roll times ``z_step`` plus the
    reference bead's existing offset is written to ``df['offset']``.

    Returns:
        ``(psfs, df)`` - ``psfs`` unchanged, ``df`` modified in place.
    """
    # Kept from the original body; the name is not used in this function.
    from sklearn.metrics import euclidean_distances
    ref_idx = get_central_bead_idx(df, i)
    ref_offset = df.iloc[ref_idx]['offset']

    ref_tf = tf.convert_to_tensor(norm_zero_one(psfs[ref_idx]))

    def roll_for(idx):
        # The reference bead is aligned with itself by definition.
        if idx == ref_idx:
            return 0
        return -tf_find_optimal_roll(ref_tf, norm_zero_one(psfs[idx]))

    rolls = np.array([roll_for(idx) for idx in trange(df.shape[0])])
    df['offset'] = rolls * z_step + ref_offset
    return psfs, df



def tf_eval_roll(ref_psf, psf, roll):
    """Return the mean squared error between ``ref_psf`` and ``psf`` rolled by ``roll`` along axis 0."""
    shifted = tf.roll(psf, roll, axis=0)
    return tf.reduce_mean(mean_squared_error(ref_psf, shifted))
    
def tf_find_optimal_roll(ref_tf, img):
    """Find the roll along axis 0 that best aligns ``img`` with ``ref_tf``.

    Candidate rolls span ``[-n // 4, n // 4)`` where ``n`` is the length of
    ``ref_tf`` along axis 0. Each candidate is scored with ``tf_eval_roll``
    and the roll with the smallest error is returned.
    """
    
    img_tf = tf.convert_to_tensor(img)

    # Search window: a quarter of the reference length in each direction.
    roll_range = ref_tf.shape[0]//4
    rolls = np.arange(-roll_range, roll_range).astype(int)
    errors = tf.map_fn(lambda roll: tf_eval_roll(ref_tf, img_tf, roll), rolls, dtype=tf.float64)
    # idx = 0
    # for roll in tqdm(rolls):
    #     error = tf.eval_roll(ref_tf, img_tf, roll)
    #     print(i, error)
    #     errors[idx] = error
    #     idx += 1

    best_roll = rolls[tf.argmin(errors).numpy()]
    # Prefer small backwards roll to large forwards roll
    # NOTE(review): with rolls limited to a quarter of the stack length, this
    # condition looks unreachable whenever img and ref_tf have the same number
    # of frames (|best_roll - n| is then always larger than best_roll) - confirm.
    if abs(best_roll - img.shape[0]) < best_roll:
        best_roll = best_roll - img.shape[0]

    return best_roll

def get_central_bead_idx(df, i):
    """Return the positional index of the ``i``-th closest bead to the field centre.

    The centre is ``(max(x) / 2, max(y) / 2)``. Distances are written to
    ``df['dist']`` as a side effect and the chosen index is printed.
    """
    centre = [[df['x'].max()/2, df['y'].max()/2]]
    df['dist'] = euclidean_distances(df[['x', 'y']].to_numpy(), centre)
    order = np.argsort(df['dist'].to_numpy())
    ref_idx = order[i]
    print(ref_idx)
    return ref_idx

locs_sets = [realign_beads(stacks, locs.copy(deep=True), 10, i)[1] for i in range(5)]


In [None]:
locs_sets2 = [realign_beads(stacks, locs.copy(deep=True), 10, i)[1] for i in range(5, 10)]

In [None]:
locs_sets[0]['offset']

In [None]:
offsets = np.array([locs_set['offset'].to_numpy() for locs_set in locs_sets+locs_sets2])

In [None]:
offsets.shape

In [None]:
import matplotlib.pyplot as plt

# Compare the first 10 bead offsets of every realignment run, each shifted so
# that its first bead sits at 0.
for o in offsets:
    # Subtract into a new array: `o[0:10]` is a view into `offsets`, so an
    # in-place `-=` would silently overwrite the data that is averaged below
    # and later written back to the locs file.
    o2 = o[0:10] - o[0]
    plt.scatter(np.arange(o2.shape[0]), o2, alpha=0.1)
# plt.show()

# Mean offset per bead across runs, shifted so the first bead is 0 (for display only).
mean_offsets = np.mean(offsets, axis=0)
mean_offsets -= mean_offsets[0]
plt.scatter(np.arange(mean_offsets.shape[0]), mean_offsets, marker='+')
plt.show()

In [None]:
mean_offsets = np.mean(offsets, axis=0)

In [None]:
locs['offset'] = mean_offsets

In [None]:
locs.to_hdf('/media/Backup/smlm_z_data/20240510_Miguel_beads/zstacks/20nm_red/stacks/combined/locs.hdf', key='locs')

In [None]:
from keras.applications import VGG19, MobileNetV3Small

MobileNetV3Small(input_shape=(64,64,3), weights='imagenet',
                                  include_top=False)


In [None]:
import os
import numpy as np

os.environ['CUDA_VISIBLE_DEVICES']=''
from tensorflow.keras import Sequential, layers
import tensorflow as tf

image_size = 64
imshape = (image_size, image_size)
img_preprocessing = Sequential([
    layers.Resizing(*imshape),
    layers.Lambda(tf.image.grayscale_to_rgb)
])


imgs = tf.convert_to_tensor(np.random.uniform(0, 1, size=(3, 15, 15, 1)))

imgs2 = img_preprocessing(imgs)
print(imgs2.shape)



In [None]:
import os
import numpy as np

os.environ['CUDA_VISIBLE_DEVICES']=''
from tifffile import imread

stacks = imread('/home/miguel/Projects/smlm_z/publication/VIT_openframe/stacks.ome.tif')
stack = np.concatenate(stacks[0:1, 60:120]).astype(float)

In [None]:
import sys
sys.path.append('/home/miguel/Projects/smlm_z/publication/')
from util import util

stack_c = stack[:, :, :, np.newaxis]

stack_norm = util._apply_img_norm((stack_c, None), None, {'norm': 'frame-min'})[0][0].numpy()

In [None]:
print(stack_norm.min(axis=(1,2,3)))
print(stack_norm.max(axis=(1,2,3)))

args = {
    'aug_gauss': 0,
    'aug_brightness': 0.2,
    'aug_poisson_lam': 0,
    'seed': 42
}



extra_aug = Sequential([], name='extra_aug')

if args['aug_gauss']:
    extra_aug.add(layers.GaussianNoise(stddev=args['aug_gauss'], seed=args['seed']))

if args['aug_brightness']:
    extra_aug.add(layers.RandomBrightness(args['aug_brightness'], value_range=[0, 1], seed=args['seed']))
    
# layers.RandomTranslation(1/imshape[0], 1/imshape[0], seed=args['seed']),
if args['aug_poisson_lam']:
    extra_aug.add(RandomPoissonNoise(imshape, 1, args['aug_poisson_lam'], seed=args['seed']))

stack_aug = extra_aug(stack_norm).numpy()
print(stack_aug.min(axis=(1,2,3)))
print(stack_aug.max(axis=(1,2,3)))

In [None]:
def _apply_img_norm(imgs):
    """Min-max normalise ``imgs`` over axes 1 and 2, then apply ReLU.

    Prints the dtype of the input on every call.
    """
    print(imgs.dtype)
    spatial_axes = (1, 2)
    imgs -= tf.math.reduce_min(imgs, axis=spatial_axes, keepdims=True)
    peak = tf.math.reduce_max(imgs, axis=spatial_axes, keepdims=True)
    return tf.nn.relu(imgs / peak)

data = data.map(_apply_img_norm)
data = data.batch(64)

In [None]:
import matplotlib.pyplot as plt

for batch in data.as_numpy_iterator():
    imgs = batch.squeeze()

for i in range(norm_stacks.shape[1]):
    plt.figure(figsize=(1, 1), dpi=60)
    plt.imshow(norm_stacks[0, i])
    plt.show()
    plt.figure(figsize=(1, 1), dpi=60)
    plt.imshow(imgs[i])
    plt.show()
    print('---------')


print(d.shape)
plt.plot(d.max(axis=(1,2)))
plt.show()
    

In [None]:
import numpy as np

psfs = np.random.uniform(0, 1, size=(100, 10, 15, 15))

i = 0
psf_mins = psfs[i].min(axis=(1,2), keepdims=True)
print(psf_mins.shape)
psfs[i] -= psf_mins
print(psf_mins.shape)
psf_sums = psfs[i].sum(axis=(1,2), keepdims=True)
psfs[i] /= psf_sums
print(psfs[i].sum(axis=(1,2)))

In [None]:
# `os` is used on the next line, so the import must not be commented out.
import os
# NOTE(review): presumably this only takes effect if it runs before TensorFlow
# initialises its devices in this kernel - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import tensorflow as tf
import joblib

run_dir = '/home/miguel/Projects/smlm_z/publication/VIT_Zeiss_gauss_aug5/out_roll_alignment/'
train_data = tf.data.Dataset.load(run_dir + 'train')

# Keep only the first element yielded by the dataset.
for (imgs, xy), z in train_data.as_numpy_iterator():
    break

# Two example images used for the augmentation preview below.
img = imgs[500:502]



from tensorflow.keras import layers, Sequential

class RandomPoissonNoise(layers.Layer):
    """Training-time augmentation that adds scaled Poisson noise to its input.

    On each training call a rate ``lam`` is drawn uniformly from
    ``[lam_min, lam_max)``, Poisson noise of the fixed ``shape`` given at
    construction is sampled, divided by ``rescale`` and added to the input.

    NOTE(review): the default ``rescale=65336`` looks like it may have been
    meant as 65535 or 65536 - left unchanged, confirm.
    """

    def __init__(self, shape, lam_min, lam_max, rescale=65336, seed=42):
        super(RandomPoissonNoise, self).__init__()
        # Sets TensorFlow's global seed, not a layer-local one.
        tf.random.set_seed(seed)

        self.shape = shape
        self.lam_min = lam_min
        self.lam_max = lam_max
        self.rescale = rescale

    def call(self, input, training=False):
        """Return ``input`` unchanged unless ``training`` is truthy."""
        # `not training` also covers training=None; the previous
        # `training==False` check let None through and added noise.
        if not training:
            return input
        lam = tf.random.uniform((1,), self.lam_min, self.lam_max)[0]
        noise = tf.random.poisson(self.shape, lam, dtype=tf.float32) / self.rescale
        return input + noise


extra_aug = Sequential([
    layers.GaussianNoise(stddev=0.005, seed=42),
    layers.RandomTranslation(0.1, 0.1, seed=42),
    layers.RandomBrightness(0.4, value_range=[0, 1], seed=42),
    RandomPoissonNoise(img.shape, 100, 10000)
], name='extra_aug')

print(img.min(), img.max())
plt.figure(figsize=(2, 2), dpi=60)
plt.imshow(img[0].mean(axis=-1))
plt.show()


def norm_zero_one(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1."""
    offset = x.min()
    span = x.max() - offset
    return (x - offset) / span


for _ in range(5):
    aug_img = extra_aug(img, training=True).numpy()[0]
    print(aug_img.min(), aug_img.max())
    plt.figure(figsize=(2, 2), dpi=60)
    plt.imshow(aug_img.mean(axis=-1))
    plt.show()


In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import tensorflow as tf
import joblib

run_dir = '/home/miguel/Projects/smlm_z/publication/VIT_Zeiss_gauss_aug10/out_roll_alignment/'
train_data = tf.data.Dataset.load(run_dir + 'train')


    
import numpy as np
train_img_mins = []
train_img_maxs = []
train_img_means = []
for (imgs, xy), z in train_data.as_numpy_iterator():
    train_img_mins.append(imgs.min(axis=(1,2,3)))
    train_img_maxs.append(imgs.max(axis=(1,2,3)))
    train_img_means.append(imgs.mean(axis=(1,2,3)))

train_img_mins = np.concatenate(train_img_mins)
train_img_maxs = np.concatenate(train_img_maxs)
train_img_means = np.concatenate(train_img_means)
    




In [None]:
import h5py 

datagen = joblib.load(run_dir + 'datagen.gz')

exp_data = '/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/FOV1/storm_1/old_locs/storm_1_MMStack_Default.ome_spots.hdf5'
with h5py.File(exp_data, 'r') as f:
    spots = np.array(f['spots']).astype(float)

BASELINE = 100
SENSITIVITY = 1
GAIN = 1

spots = (spots * GAIN / SENSITIVITY) + BASELINE

spots = datagen.standardize(spots)
spots_mins = spots.min(axis=(1,2))
spots_maxs = spots.max(axis=(1,2))
spots_mean = spots.mean(axis=(1,2))

In [None]:
import matplotlib.pyplot as plt
plt.boxplot((spots_mins, train_img_mins))
plt.show()
plt.boxplot((spots_maxs, train_img_maxs))
plt.show()
plt.boxplot((spots_mean, train_img_means))
plt.show()

In [None]:
from sklearn.metrics import mean_absolute_error
import matplotlib.pyplot as plt
from data.visualise import grid_psfs
from tensorflow.keras import layers, Sequential
import numpy as np
    
SEED = 42
n = 200
for (imgs, xy), z in test_data.as_numpy_iterator():
    imgs = imgs[:n]
    xy = xy[:n]
    z = z[:n]
    plt.imshow(grid_psfs(imgs.mean(axis=-1)).T)
    plt.show()
    z_pred = model.predict((imgs, xy), verbose=False).squeeze()
    mae = mean_absolute_error(z_pred, z)

    plt.title(str(round(mae, 2)))
    plt.scatter(z, z_pred, marker='x')
    plt.show()


    
    aug = Sequential([
        layers.RandomBrightness(0.5, value_range=[0, 1], seed=SEED)
    ])

    aug_imgs = aug(imgs).numpy()

    plt.imshow(grid_psfs(aug_imgs.mean(axis=-1)).T)
    plt.show()
    z_pred = model.predict((aug_imgs, xy), verbose=False).squeeze()

    idx = np.argwhere(np.sum(aug_imgs, axis=(1,2,3))!=0).squeeze()
    z = z[idx]
    z_pred = z_pred[idx]
    
    mae = mean_absolute_error(z_pred, z)

    plt.title(str(round(mae, 2)))
    plt.scatter(z, z_pred, marker='x')
    plt.show()
    

    
    break

In [None]:
import pandas as pd
import seaborn as sns
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt

from sklearn.mixture import GaussianMixture

df = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/VIT_openframe_newaug5/out_roll_alignment/out_nup/locs_3d.hdf5', key='locs')
picked_locs = pd.read_hdf('/home/miguel/Projects/data/20230601_MQ_celltype/nup/fov2/storm_1/storm_1_MMStack_Default.ome_locs_undrifted_picked_4.hdf5', key='locs')
# Attach the pick ('group') information to the 3D localisations by matching
# on every shared 2D-localisation column.
df = df.merge(picked_locs, on=['x', 'y', 'photons', 'bg', 'lpx', 'lpy', 'net_gradient', 'iterations', 'frame', 'likelihood', 'sx', 'sy'])
df['clusterID'] = df['group']

cov_type = 'full'

# Only group ids 0..99 are inspected.
for group_id in range(100):

    data = df[df['group'] == group_id][['z']].to_numpy()
    if data.shape[0] < 2:
        # A two-component mixture cannot be fitted to fewer than two samples.
        continue
    sns.histplot(data=data, bins=40, stat='density')

    gm = GaussianMixture(n_components=2, n_init=20, covariance_type=cov_type).fit(data)

    # sns.histplot(data=gm_df, x='pred', hue='cluster_id', stat='density', alpha=0.2, bins=20)

    # Overlay each fitted component as a weighted normal pdf.
    x_axis = np.linspace(data.min(), data.max(), 50)
    sub_df2 = pd.DataFrame.from_dict({'x': x_axis})
    # Separate index name: the original reused `i` for both loops.
    for k in range(gm.n_components):
        # Pick the variance of component k along z; the layout of
        # gm.covariances_ depends on covariance_type.
        if cov_type == 'tied':
            cov = gm.covariances_.squeeze()
        elif cov_type == 'full' or cov_type is None:
            cov = gm.covariances_[k][0][0]
        elif cov_type == 'spherical':
            cov = gm.covariances_[k]
        elif cov_type == 'diag':
            # The original assigned this to `cov_type`, leaving `cov` unset
            # (or stale) and corrupting `cov_type` for later iterations.
            cov = gm.covariances_[k][0]

        sub_df2[f'y_{k}'] = norm.pdf(x_axis, float(gm.means_[k][0]), np.sqrt(cov)) * gm.weights_[k]
        sns.lineplot(data=sub_df2, x='x', y=f'y_{k}')

    # Show only groups whose two component means are 40-60 apart in z.
    diff = gm.means_.max() - gm.means_.min()
    if 40 < diff < 60:
        plt.show()
    else:
        plt.close()


In [None]:
# Attention map for VIT model
import matplotlib.pyplot as plt
from vit_keras import visualize
import tensorflow as tf
import numpy as np
import h5py
import joblib

dataset = '/home/miguel/Projects/smlm_z/publication/VIT_openframe_newaug3/out_roll_alignment/test'
model = '/home/miguel/Projects/smlm_z/publication/VIT_openframe_newaug3/out_roll_alignment/latest_vit_model3'
# `model` starts as the path and ends as the third layer of the loaded model,
# which is what gets handed to visualize.attention_map below.
model = tf.keras.models.load_model(model).layers[2]
# NOTE(review): named train_data, but the path above points at '.../test'.
train_data = tf.data.Dataset.load(dataset)


def norm_zero_one(x):
    """Linearly rescale x so that its minimum maps to 0 and its maximum to 1."""
    lo, hi = x.min(), x.max()
    return (x - lo) / (hi - lo)


# Only the first batch is used; images 30..49 of that batch are visualised.
for (imgs, xy), z in train_data.as_numpy_iterator():
    for i in range(30, 50):
        image = imgs[i]
        print(image.min(), image.max())
        # The image is rescaled to 0..255 before being passed to vit_keras.
        attention_map = visualize.attention_map(model=model, image=norm_zero_one(image)*255)
        print(attention_map.min(), attention_map.max())

        fig, (ax1, ax2) = plt.subplots(ncols=2)
        panels = (
            (ax1, 'Original', image),
            (ax2, 'Attention Map', attention_map),
        )
        for ax, _title, _ in panels:
            ax.axis('off')
            ax.set_title(_title)
        for ax, _, panel_img in panels:
            _ = ax.imshow(norm_zero_one(panel_img))
        plt.show()
    break

In [None]:
# Compare pixel vals between training data and exp data
import tensorflow as tf
import numpy as np
import h5py
import joblib

dataset = '/home/miguel/Projects/smlm_z/publication/VIT_openframe_no_imagenet2/out_roll_alignment/train'
locs = '/home/miguel/Projects/data/20230601_MQ_celltype/nup/fov2/storm_1/storm_1_MMStack_Default.ome_spots.hdf5'
datagen = joblib.load('/home/miguel/Projects/smlm_z/publication/VIT_openframe_no_imagenet2/out_roll_alignment/datagen.gz')

train_data = tf.data.Dataset.load(dataset)

# Collect the image part of the first 10 batches of the training dataset.
psfs = []
for i, ((psf, xy), z) in enumerate(train_data.as_numpy_iterator(), start=1):
    psfs.append(psf)
    if i == 10:
        break

psfs = np.concatenate(psfs)

# Experimental spots: cast to uint16 first, then to float, then apply the
# standardisation saved with the training run.
with h5py.File(locs, 'r') as f:
    nup_spots = np.array(f['spots']).astype(np.uint16)
nup_spots = datagen.standardize(nup_spots.astype(float))

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def snr(x):
    """Ratio of the maximum of x to its median."""
    return np.max(x) / np.median(x)


# One boxplot per statistic, comparing training PSFs ('beads') with
# experimental spots ('locs').
for op in [np.max, np.min, np.mean, snr]:
    frames = []
    for label, images in (('beads', psfs), ('locs', nup_spots)):
        frame = pd.DataFrame.from_dict({'val': [op(x) for x in images]})
        frame['ds'] = label
        frames.append(frame)
    df = pd.concat(frames)
    fig = plt.figure(figsize=(5, 3))

    plt.title(str(op.__name__))
    sns.boxplot(data=df, x='ds', y='val')
    plt.show()

In [None]:
# Overall mean and standard deviation of the collected training PSFs.
print(np.mean(psfs), psfs.std())

In [None]:
import os
# NOTE(review): presumably meant to hide every GPU so this cell runs on CPU.
# That can only work if it is set before anything in the kernel has
# initialised CUDA - confirm against the actual execution order.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tifffile import imread
import h5py


# Simulated spots.
locs = '/home/miguel/Projects/smlm_z/publication/simul/fd-loco-simul/nup_spots.hdf5'
with h5py.File(locs, 'r') as f:
    simul_spots = np.array(f['spots']).astype(np.uint16)

# Real spots; `locs` is reused for the second path.
locs = '/home/miguel/Projects/data/fd-loco/roi_startpos_810_790_split.ome_spots.hdf5'
with h5py.File(locs, 'r') as f:
    real_spots = np.array(f['spots']).astype(np.uint16)

# min / max / mean of each set, simulated first.
for spot_set in (simul_spots, real_spots):
    print(spot_set.min(), spot_set.max(), spot_set.mean())



In [None]:
from data.visualise import show_psf_axial
# Number of frames kept on each side of the brightest frame of a bead stack.
HALF_WINDOW = 100

keep_beads = []

for stack in beads:
    # Frame whose maximum pixel value is largest.
    max_val = int(np.argmax(stack.max(axis=(1, 2))))
    # Clamp the start at 0: a negative start would be read as an index from
    # the end of the stack and silently give an empty or wrong window when
    # the brightest frame lies within the first HALF_WINDOW frames.
    start = max(max_val - HALF_WINDOW, 0)
    keep_beads.append(stack[start:max_val + HALF_WINDOW])

keep_beads = np.concatenate(keep_beads)


In [None]:
# Add a trailing channel axis and convert to float. Guarded on ndim so that
# re-running the cell does not keep appending extra axes.
if keep_beads.ndim == 3:
    keep_beads = keep_beads[:, :, :, np.newaxis]
keep_beads = keep_beads.astype(float)

if spots.ndim == 3:
    spots = spots[:, :, :, np.newaxis]
spots = spots.astype(float)


In [None]:
# Largest pixel value across all retained bead frames.
np.max(keep_beads)

In [None]:
def snr(x):
    """Ratio of the maximum of x to its median."""
    return np.max(x) / np.median(x)


# One boxplot per statistic, bead frames ('beads') against spots ('locs').
for op in [np.max, np.min, np.mean, snr]:

    frames = []
    for label, images in (('beads', keep_beads), ('locs', spots)):
        frame = pd.DataFrame.from_dict({'val': [op(x) for x in images]})
        frame['ds'] = label
        frames.append(frame)
    df = pd.concat(frames)
    fig = plt.figure(figsize=(5, 3))

    plt.title(str(op.__name__))
    sns.boxplot(data=df, x='ds', y='val')
    plt.show()


In [None]:
def norm(x, maxval=1):
    """Min-max rescale x to the range [0, maxval].

    Note: this name shadows `norm` imported from scipy.stats in an earlier
    cell of this notebook.
    """
    lowest = x.min()
    span = x.max() - lowest
    return ((x - lowest) / span) * maxval


# Min-max normalise every image independently (overwrites both arrays).
keep_beads = np.stack(list(map(norm, keep_beads)))
spots = np.stack(list(map(norm, spots)))

# max, min and mean of each set, beads first.
for stat in ('max', 'min', 'mean'):
    print(getattr(keep_beads, stat)(), getattr(spots, stat)())

In [None]:

# Random batch used only to sanity-check the per-image normalisation below.
fake_images = np.random.uniform(low=0, high=255, size=(4, 255, 255, 3))


def norm_zero_one(imgs):
    """Rescale each image of a 4-D batch to [0, 1] independently.

    The minimum and maximum are taken over axes 1, 2 and 3, so every entry
    along axis 0 is normalised with its own extrema.
    """
    reduce_axes = (1, 2, 3)
    mins = imgs.min(axis=reduce_axes, keepdims=True)
    maxs = imgs.max(axis=reduce_axes, keepdims=True)
    span = maxs - mins
    return (imgs - mins) / span


norm_imgs = norm_zero_one(fake_images)
# Per-image minima, then per-image maxima.
for extreme in (norm_imgs.min, norm_imgs.max):
    print(extreme(axis=(1, 2, 3)))

In [None]:
from tensorflow.keras import Sequential, layers

# Candidate augmentation: brightness shift, additive Gaussian noise and
# contrast jitter, applied in that order.
aug = Sequential([
   layers.RandomBrightness(factor=[-0.2, 0.2], value_range=(0, 1)),
   layers.GaussianNoise(stddev=0.5),
   layers.RandomContrast(factor=0.2)
])

# Disabled preview of original vs augmented frames, kept for reference:
# new_aug = aug(keep_beads.copy()).numpy()

# for i in [0, 25, 50, 600]:
#     plt.figure(figsize=(1,1))
#     img1 = keep_beads[i]
#     print(img1.min(), img1.max(), img1.dtype)
#     img2 = aug(img1).numpy()
#     print(img2.min(), img2.max(), img2.dtype)
#     img = np.concatenate((norm(img1), norm(img2)), axis=1)
#     plt.imshow(img)
#     plt.show()



In [None]:
from tensorflow.keras import Sequential, layers

# Previous augmentation: weak Gaussian noise plus a small brightness shift,
# both scaled to the current value range of keep_beads.
aug_pipeline = Sequential([
    layers.GaussianNoise(stddev=0.001*keep_beads.max()),
    # layers.RandomTranslation(MAX_TRANSLATION_PX/img_size, MAX_TRANSLATION_PX/img_size, seed=args['seed']),
    layers.RandomBrightness(0.01, value_range=[0, keep_beads.max()]),
])

# NOTE(review): Keras noise/augmentation layers may be inactive unless called
# with training=True - confirm these calls actually perturb the images.

# Augmented copy followed by the untouched beads. The original indexed the
# result with `[:, :, : np.newaxis]`, which is just a no-op slice.
old_aug = aug_pipeline(keep_beads.copy()).numpy()
old_aug = np.concatenate((old_aug, keep_beads))

# aug = Sequential([
#    layers.RandomBrightness([-0.4, 0], value_range=(0, 1)),
#    layers.GaussianNoise(0.5),
#    layers.RandomContrast(0.5)
# ])

# Uses `aug` defined in an earlier cell. The original indexed the result with
# `[:, :, : :, np.newaxis]`, which inserts an extra axis and makes the
# concatenation with keep_beads fail on mismatched dimensions.
new_aug = aug(keep_beads.copy()).numpy()
new_aug = np.concatenate((new_aug, keep_beads))




In [None]:
def snr(x):
    """Ratio of the maximum of x to its median."""
    return np.max(x) / np.median(x)


# One boxplot per statistic across the four image sets.
for op in [np.max, np.min, np.mean, snr]:

    frames = []
    for label, images in (
        ('beads', keep_beads),
        ('old_aug', old_aug),
        ('new_aug', new_aug),
        ('locs', spots),
    ):
        frame = pd.DataFrame.from_dict({'val': [op(x) for x in images]})
        frame['ds'] = label
        frames.append(frame)
    df = pd.concat(frames)
    fig = plt.figure(figsize=(5, 3))

    plt.title(str(op.__name__))
    sns.boxplot(data=df, x='ds', y='val')
    plt.show()


In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def _fit_datagen(data):
    """Build an ImageDataGenerator with feature-wise centring/scaling and fit it on `data`."""
    # NOTE(review): 65336 looks like a typo for 65535 or 65536 - confirm
    # before changing; the value is kept exactly as the notebook had it.
    gen = ImageDataGenerator(
        rescale=1.0/65336.0,
        samplewise_center=False,
        samplewise_std_normalization=False,
        featurewise_center=True,
        featurewise_std_normalization=True,
        horizontal_flip=False)
    gen.fit(data)
    return gen


# 1) Statistics from the plain beads. standardize() is always given a copy.
datagen = _fit_datagen(keep_beads)
std_beads = datagen.standardize(keep_beads.copy())
std_spots = datagen.standardize(spots.copy())

# 2) Statistics from the old augmentation plus the beads.
# NOTE(review): if old_aug / new_aug already end with keep_beads (as built in
# an earlier cell), the beads are counted twice here - confirm intent.
print(old_aug.shape, keep_beads.shape)
old_aug_data = np.concatenate((old_aug, keep_beads))
datagen = _fit_datagen(old_aug_data)
old_aug_data_norm = datagen.standardize(old_aug_data.copy())
old_aug_spots = datagen.standardize(spots.copy())

# 3) Statistics from the new augmentation plus the beads. `datagen` is left
# pointing at this last generator.
new_aug_data = np.concatenate((new_aug, keep_beads))
datagen = _fit_datagen(new_aug_data)
new_aug_data_norm = datagen.standardize(new_aug_data.copy())
new_aug_spots = datagen.standardize(spots.copy())


In [None]:
def snr(x):
    """Ratio of the maximum of x to its median."""
    return np.max(x) / np.median(x)


# One boxplot per statistic: standardised beads and spots under the old
# augmentation statistics ('beads', 'locs') and the new ones ('beads_2', 'locs_2').
for op in [np.max, np.min, np.mean, snr]:
    frames = []
    for label, images in (
        ('beads', old_aug_data_norm),
        ('locs', old_aug_spots),
        ('beads_2', new_aug_data_norm),
        ('locs_2', new_aug_spots),
    ):
        frame = pd.DataFrame.from_dict({'val': [op(x) for x in images]})
        frame['ds'] = label
        frames.append(frame)
    df = pd.concat(frames)
    fig = plt.figure(figsize=(5, 3))

    plt.title(str(op.__name__))
    sns.boxplot(data=df, x='ds', y='val')
    plt.show()


In [None]:
# Per-image maximum pixel value for beads and spots.
# The builtin max() iterates over the rows of each image and compares whole
# arrays, which raises for multi-pixel images; the array method reduces over
# all pixels instead.
# NOTE(review): the names say "snrs" but the values are maxima - confirm
# whether an SNR was intended here.
snrs_beads = [x.max() for x in keep_beads]
snrs_locs = [x.max() for x in spots]

# The original repeated this block twice, drawing every histogram twice.
plt.hist(snrs_beads, label='beads')
plt.hist(snrs_locs, label='locs')
plt.legend()
plt.show()

In [None]:
import os
# NOTE(review): presumably hides every GPU so the dataset below is read on
# CPU; it is set before the tensorflow import in this cell, but an earlier
# cell may already have imported tensorflow in this kernel - confirm.
os.environ['CUDA_VISIBLE_DEVICES']=''
import tensorflow as tf

ds = tf.data.Dataset.load('/home/miguel/Projects/smlm_z/publication/VIT_Zeiss_gauss_aug/out_roll_alignment/train')

# Keep only the image part of every ((img, xy), z) element of the dataset.
imgs = [img for (img, xy), z in ds.as_numpy_iterator()]



In [None]:
# Merge the collected batches into one array. Guarded so that re-running the
# cell does not concatenate the already-merged array along its first axis,
# which would silently drop a dimension.
if isinstance(imgs, list):
    imgs = np.concatenate(imgs)
print(imgs.shape)

In [None]:
# min / max / mean: experimental spots first, then training images.
for pixel_set in (spots, imgs):
    print(pixel_set.min(), pixel_set.max(), pixel_set.mean())

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# new_nup = pd.read_csv('/home/miguel/Projects/smlm_z/publication/VIT_openframe_newaug/out_roll_alignment/out_nup/nup_renders3/nup_report.csv')
# old_nup = pd.read_csv('/home/miguel/Projects/smlm_z/publication/VIT_openframe/out_roll_alignment/out_nup/nup_renders3/nup_report.csv')


# Per-structure reports for the run with Gaussian augmentation ('new') and
# the run without it ('old').
new_nup = pd.read_csv('/home/miguel/Projects/smlm_z/publication/VIT_Zeiss_gauss_aug/out_roll_alignment/out_nup/nup_renders3/nup_report.csv')
old_nup = pd.read_csv('/home/miguel/Projects/smlm_z/publication/VIT_Zeiss/out_roll_alignment/out_nup/nup_renders3/nup_report.csv')

new_nup['dataset'] = 'new'
old_nup['dataset'] = 'old'

df = pd.concat((old_nup, new_nup))

# 'seperation' (sic) is the column name used in the CSV files.
# No plt.legend(): nothing in this plot carries a label, so the call only
# produced a warning; the x tick labels already identify the two datasets.
sns.boxplot(data=df, x='dataset', y='seperation')
plt.show()

def count_valid(x):
    """Count the entries of x that lie strictly between 40 and 60."""
    in_range = (x > 40) & (x < 60)
    return int(in_range.sum())

# Number of in-range separations: new run first, then old run.
for report in (new_nup, old_nup):
    print(count_valid(report['seperation']))

# new_locs = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/VIT_Zeiss_gauss_aug/out_roll_alignment/out_nup/locs_3d.hdf5', key='locs')
# old_locs = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/VIT_Zeiss/out_roll_alignment/out_nup/locs_3d.hdf5', key='locs')


In [None]:
# new_locs / old_locs were only ever loaded by commented-out lines in the
# previous cell, so this cell failed on a fresh kernel. Load them here from
# those same paths.
new_locs = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/VIT_Zeiss_gauss_aug/out_roll_alignment/out_nup/locs_3d.hdf5', key='locs')
old_locs = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/VIT_Zeiss/out_roll_alignment/out_nup/locs_3d.hdf5', key='locs')

# Overlay the predicted z distributions of both runs, labelled so the two
# histograms can be told apart.
sns.histplot(new_locs['z [nm]'], label='new')
sns.histplot(old_locs['z [nm]'], label='old')
plt.legend()
plt.show()


In [None]:

import sys, os



# # TODO remove this
# Fall back to device '0' when CUDA_VISIBLE_DEVICES is unset OR empty. Note
# that this also overrides the empty string that earlier cells of this
# notebook assign to the same variable.
if not os.environ.get('CUDA_VISIBLE_DEVICES'):
    os.environ['CUDA_VISIBLE_DEVICES']='0'

import joblib
import json
import shutil
import argparse
import pandas as pd
import h5py
import numpy as np
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Resizing, Lambda
from tensorflow.keras import Sequential
import tensorflow as tf
from picasso import io

from data.visualise import grid_psfs



# Number of visible GPUs, never less than 1.
_gpus = tf.config.experimental.list_physical_devices("GPU")
N_GPUS = max(len(_gpus), 1)

# Written into report.json as 'code_version'.
VERSION = '0.1'

# Picasso localisation parameters
BASELINE = 100
SENSITIVITY = 1
GAIN = 1

# Dataset defaults; all unset here. The commented-out presets that follow
# show the values used for individual datasets.
DEFAULT_LOCS = None
DEFAULT_SPOTS = None
DEFAULT_PIXEL_SIZE = None
PICKED = None
# Optional (low, high) crop limits in x and y; None disables the crop.
XLIM = None
YLIM = None

# NUP FD-LOCO
# DEFAULT_LOCS = '/home/miguel/Projects/data/fd-loco/roi_startpos_810_790_split.ome_locs.hdf5'
# DEFAULT_SPOTS = '/home/miguel/Projects/data/fd-loco/roi_startpos_810_790_split.ome_spots.hdf5'
# PICKED = '/home/miguel/Projects/data/fd-loco/roi_startpos_810_790_split.ome_locs_picked.hdf5'
# DEFAULT_PIXEL_SIZE = 110
# XLIM, YLIM = None, None



# NUP OPENFRAME
# DEFAULT_LOCS = '/home/miguel/Projects/data/20230601_MQ_celltype/nup/fov2/storm_1/storm_1_MMStack_Default.ome_locs_undrifted.hdf5'
# DEFAULT_SPOTS = '/home/miguel/Projects/data/20230601_MQ_celltype/nup/fov2/storm_1/storm_1_MMStack_Default.ome_spots.hdf5'
# PICKED = '/home/miguel/Projects/data/20230601_MQ_celltype/nup/fov2/storm_1/storm_1_MMStack_Default.ome_locs_undrifted_picked_4.hdf5'
# DEFAULT_PIXEL_SIZE = 86
# XLIM, YLIM = None, None



# Zeiss
# DEFAULT_LOCS = '/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/FOV1/storm_1/storm_1_MMStack_Default.ome_locs_undrifted.hdf5'
# DEFAULT_SPOTS = '/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/FOV1/storm_1/storm_1_MMStack_Default.ome_spots.hdf5'
# PICKED = '/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/FOV1/storm_1/storm_1_MMStack_Default.ome_locs_picked.hdf5'
# DEFAULT_PIXEL_SIZE = 106
# XLIM, YLIM = None, None


# Unused below

# # Mitochondria (older)
# DEFAULT_LOCS = '/home/miguel/Projects/data/20231205_miguel_mitochondria/mitochondria/FOV2/storm_1/storm_1_MMStack_Default.ome_locs_undrifted.hdf5'
# DEFAULT_SPOTS = '/home/miguel/Projects/data/20231205_miguel_mitochondria/mitochondria/FOV2/storm_1/storm_1_MMStack_Default.ome_spots.hdf5'
# DEFAULT_PIXEL_SIZE = 86
# XLIM, YLIM = None, None

# Mitochondria (newer) (still not clearly working)
# DEFAULT_LOCS = '/media/Data/smlm_z_data/20231212_miguel_openframe/mitochondria/FOV2/storm_1/storm_1_MMStack_Default.ome_locs_undrift.hdf5'
# DEFAULT_SPOTS = '/media/Data/smlm_z_data/20231212_miguel_openframe/mitochondria/FOV2/storm_1/storm_1_MMStack_Default.ome_spots.hdf5'
# PICKED = None
# DEFAULT_PIXEL_SIZE = 86
# XLIM = 400, 600
# YLIM = 700, 1000



# Tubulin
# DEFAULT_LOCS = '/media/Data/smlm_z_data/20231212_miguel_openframe/tubulin/FOV1/storm_1/storm_1_MMStack_Default.ome_locs_undrifted.hdf5'
# DEFAULT_SPOTS = '/media/Data/smlm_z_data/20231212_miguel_openframe/tubulin/FOV1/storm_1/storm_1_MMStack_Default.ome_spots.hdf5'
# PICKED = None
# DEFAULT_PIXEL_SIZE = 86
# XLIM = 200, 800
# YLIM = 500, 1000

def write_arg_log(args):
    """Dump `args` as indented JSON to `<args['outdir']>/config.json`.

    Args:
        args: JSON-serialisable dict that contains an 'outdir' entry.
    """
    serialised = json.dumps(args, indent=4)
    with open(os.path.join(args['outdir'], 'config.json'), 'w') as fp:
        fp.write(serialised)
        fp.write('\n')


def save_copy_script(outdir):
    """Copy the running script into `outdir` as 'localise_exp_sample.py.bak'.

    When the code runs somewhere without a `__file__` (for example inside a
    notebook cell) there is no script to copy; the function then prints a
    notice and returns instead of raising NameError.

    Args:
        outdir: Directory that receives the backup copy.
    """
    script_path = globals().get('__file__')
    if script_path is None:
        print('No __file__ available (running interactively?); skipping script backup')
        return
    outpath = os.path.join(outdir, 'localise_exp_sample.py.bak')
    shutil.copy(os.path.abspath(script_path), outpath)


def gen_2d_plot(locs, outdir):
    """Save an x/y scatter plot of the localisations to `<outdir>/2d_scatterplot.png`.

    The plot uses equal axis scaling and an inverted y axis.
    """
    print('Gen 2d plot')
    target = os.path.join(outdir, '2d_scatterplot.png')
    sns.scatterplot(data=locs, x='x', y='y', marker='.', alpha=0.1)
    plt.axis('equal')
    current_ax = plt.gca()
    current_ax.invert_yaxis()
    plt.savefig(target)
    plt.close()


def gen_example_spots(spots, outdir):
    """Save a montage of the first 100 spots to `<outdir>/example_spots.png`.

    Side effect: changes the global matplotlib default figure size to 10x10.
    """
    print('Gen example splots')
    plt.rcParams['figure.figsize'] = [10, 10]
    montage = grid_psfs(spots[0:100])
    plt.imshow(montage)
    target = os.path.join(outdir, 'example_spots.png')
    plt.savefig(target)
    plt.close()


def apply_normalisation(locs, spots, args):
    """Apply the saved training pre-processing to coordinates and spots.

    Args:
        locs: DataFrame with 'x' and 'y' columns.
        spots: Array of spot images.
        args: Dict with 'coords_scaler' and 'datagen' paths to joblib files.
            joblib files are pickles, so only load files you trust.

    Returns:
        (coords, spots): the scaler-transformed x/y coordinates and the
        standardised float32 spots with a trailing axis appended.
    """
    print('Applying pre-processing')
    scaler = joblib.load(args['coords_scaler'])
    datagen = joblib.load(args['datagen'])

    xy = locs[['x', 'y']].to_numpy()
    coords = scaler.transform(xy)

    standardised = datagen.standardize(spots.astype(np.float32))
    return coords, standardised[:, :, :, np.newaxis]



def pred_z(model, spots, coords):
    """Predict z for every spot with `model`.

    Args:
        model: Keras model taking (image, coords) inputs.
        spots: Spot images; cast to float32 here.
        coords: One coordinate row per spot.

    Returns:
        Whatever `model.predict` returns for the batched dataset.
    """
    spots = spots.astype(np.float32)
    print('Predicting z locs')

    # Inputs are (image, coords) pairs.
    model_inputs = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(spots),
        tf.data.Dataset.from_tensor_slices(coords),
    ))

    # Placeholder targets (all zeros, one per spot); presumably only there to
    # give the dataset an (inputs, target) structure.
    placeholder_z = tf.data.Dataset.from_tensor_slices(np.zeros((coords.shape[0],)))
    exp_data = tf.data.Dataset.zip((model_inputs, placeholder_z))

    # Images are resized to 64x64 and expanded from grayscale to RGB.
    image_size = 64
    img_preprocessing = Sequential([
        Resizing(image_size, image_size),
        Lambda(tf.image.grayscale_to_rgb)
    ])

    def apply_rescaling(x, y):
        img, xy = x[0], x[1]
        return (img_preprocessing(img), xy), y

    BATCH_SIZE = 2048
    exp_data = exp_data.map(apply_rescaling, num_parallel_calls=tf.data.AUTOTUNE).batch(BATCH_SIZE)

    return model.predict(exp_data, batch_size=BATCH_SIZE, workers=4)

def write_locs(locs, z_coords, args):
    """Add the z predictions to `locs` and write them to `<outdir>/locs_3d.hdf5`.

    `locs` is modified in place: columns 'z [nm]', 'z', 'x [nm]' and 'y [nm]'
    are added. If a .yaml file exists next to args['locs'], it is copied next
    to the output file.

    Args:
        locs: DataFrame of 2D localisations with 'x' and 'y' columns.
        z_coords: Predicted z values, one per row of `locs`.
        args: Dict with 'pixel_size', 'outdir' and 'locs' entries.
    """
    locs['z [nm]'] = z_coords
    locs['z'] = locs['z [nm]']
    # locs['z'] = z_coords / args['pixel_size']
    for axis_name in ('x', 'y'):
        locs[f'{axis_name} [nm]'] = locs[axis_name] * args['pixel_size']

    locs_path = os.path.join(args['outdir'], 'locs_3d.hdf5')
    with h5py.File(locs_path, "w") as locs_file:
        locs_file.create_dataset("locs", data=locs.to_records())

    # Carry the metadata file of the 2D localisation over, if it exists.
    dest_yaml = None
    yaml_file = args['locs'].replace('.hdf5', '.yaml')
    if not os.path.exists(yaml_file):
        print('Could not write yaml file (original from 2D localisation not found)')
    else:
        dest_yaml = locs_path.replace('.hdf5', '.yaml')
        shutil.copy(yaml_file, dest_yaml)

    print('Wrote results to:')
    written = [locs_path]
    if dest_yaml:
        written.append(dest_yaml)
    for path in written:
        print(f'\t- {os.path.abspath(path)}')


def write_report_data(args):
    """Write `args` plus the code version as JSON to `<args['outdir']>/report.json`.

    A 'code_version' entry in `args` takes precedence over VERSION.
    """
    report_data = {'code_version': VERSION, **args}
    report_path = os.path.join(args['outdir'], 'report.json')
    with open(report_path, 'w') as fp:
        fp.write(json.dumps(report_data, indent=4))
        fp.write('\n')


def extract_fov(spots, locs, xlim=None, ylim=None):
    """Keep only the localisations (and their spots) inside a rectangular window.

    Args:
        spots: Array with one entry per row of `locs`.
        locs: DataFrame with 'x' and 'y' columns.
        xlim: Optional (low, high) limits in x; defaults to the module-level XLIM.
        ylim: Optional (low, high) limits in y; defaults to the module-level YLIM.
            A limit that is still None after the fallback is not applied, so
            setting only one of XLIM / YLIM no longer raises.

    Returns:
        (spots, locs) restricted to rows strictly inside the limits. `locs`
        is always a DataFrame and `spots` keeps its leading axis, even when a
        single row matches.
    """
    xlim = XLIM if xlim is None else xlim
    ylim = YLIM if ylim is None else ylim

    print(locs.shape)
    keep = np.ones(locs.shape[0], dtype=bool)
    if xlim is not None:
        x = locs['x'].to_numpy()
        keep &= (xlim[0] < x) & (xlim[1] > x)
    if ylim is not None:
        y = locs['y'].to_numpy()
        keep &= (ylim[0] < y) & (ylim[1] > y)

    # flatnonzero always yields a 1-D index array; argwhere(...).squeeze()
    # collapsed to a scalar index when exactly one row matched.
    idx = np.flatnonzero(keep)
    spots = spots[idx]
    locs = locs.iloc[idx]
    return spots, locs

def tmp_filter_locs(new_locs, spots, args):
    """Restrict localisations and spots to those present in the picked-locs file.

    Rows are matched on the 'x' column only.

    Args:
        new_locs: DataFrame of localisations with an 'x' column.
        spots: Array with one entry per row of `new_locs`.
        args: Dict whose 'picked_locs' entry is the path of an HDF5 file
            readable with pd.read_hdf(key='locs').

    Returns:
        (new_locs, spots) restricted to matching rows. `new_locs` is always a
        DataFrame and `spots` keeps its leading axis, even for a single match.
    """
    old_locs = pd.read_hdf(args['picked_locs'], key='locs')

    # Work on a plain NumPy mask and take 1-D positions from it;
    # argwhere(...).squeeze() collapsed to a scalar index for a single match.
    keep = new_locs['x'].isin(old_locs['x']).to_numpy()
    idx = np.flatnonzero(keep)
    new_locs = new_locs.iloc[idx]
    spots = spots[idx]
    return new_locs, spots



# Inputs for this run: trained model, 2D localisations and their spots, the
# saved coordinate scaler and image standardiser, and the picked subset.
args = {
    'model': '/home/miguel/Projects/smlm_z/publication/VIT_Zeiss/out_roll_alignment/latest_vit_model3',
    'locs': '/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/FOV1/storm_1/storm_1_MMStack_Default.ome_locs_undrifted.hdf5',
    'spots': '/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/FOV1/storm_1/storm_1_MMStack_Default.ome_spots.hdf5',
    'coords_scaler': '/home/miguel/Projects/smlm_z/publication/VIT_Zeiss/out_roll_alignment/scaler.save',
    'datagen': '/home/miguel/Projects/smlm_z/publication/VIT_Zeiss/out_roll_alignment/datagen.gz',
    'picked_locs': '/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/FOV1/storm_1/storm_1_MMStack_Default.ome_locs_picked.hdf5'
}

mirrored_strategy = tf.distribute.MirroredStrategy()

with mirrored_strategy.scope():
    model = tf.keras.models.load_model(args['model'])

locs, info = io.load_locs(args['locs'])
locs = pd.DataFrame.from_records(locs)

with h5py.File(args['spots'], 'r') as f:
    spots = np.array(f['spots']).astype(np.uint16)

spots = (spots * GAIN / SENSITIVITY) + BASELINE

# The filters below index `spots` with row positions taken from `locs`, so
# the two must be aligned BEFORE filtering; checking only afterwards (as the
# original did) cannot detect a mismatch.
assert locs.shape[0] == spots.shape[0], \
    f'locs ({locs.shape[0]}) and spots ({spots.shape[0]}) have different lengths'

# # TODO remove temp subset of locs
if args.get('picked_locs'):
    locs, spots = tmp_filter_locs(locs, spots, args)

assert locs.shape[0] == spots.shape[0]
print(locs.shape)
if XLIM or YLIM:
    spots, locs = extract_fov(spots, locs)

# gen_2d_plot(locs, args['outdir'])
# gen_example_spots(spots, args['outdir'])
coords, spots = apply_normalisation(locs, spots, args)
print(coords.shape)

z_coords = pred_z(model, spots, coords)



In [None]:
# Re-run the prediction with all-zero coordinates and compare against the
# predictions obtained with the real coordinates.
zero_coords = np.zeros(coords.shape)
z_coords2 = pred_z(model, spots, zero_coords)
plt.scatter(z_coords, z_coords2)
plt.show()
# Mean absolute difference between the two sets of predictions.
print(np.abs(z_coords - z_coords2).mean())

In [None]:
# Montage of the first 100 spots (averaged over the last axis) and the same
# montage with zero-mean Gaussian noise (sigma 0.05) added.
im1 = grid_psfs(spots[0:100].mean(axis=-1))
im2 = im1 + np.random.normal(loc=0, scale=0.05, size=im1.shape)

# Original, noisy version, then their absolute difference.
for panel in (im1, im2, abs(im2 - im1)):
    plt.imshow(panel)
    plt.show()

In [None]:
from argparse import ArgumentParser
import os
from glob import glob
from natsort import natsorted
import subprocess
import h5py
import pandas as pd
# Switch off pandas' chained-assignment warning for the whole kernel session.
pd.options.mode.chained_assignment = None

import numpy as np
from sklearn.metrics import euclidean_distances
from tifffile import imread, imwrite
from skimage.feature import match_template
from skimage.filters import butterworth
from skimage.filters import gaussian
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from scipy.optimize import curve_fit
from scipy.special import erf
from sklearn.metrics import mean_squared_error
from scipy.interpolate import UnivariateSpline
from scipy.ndimage import gaussian_filter
import json
import shutil
from data.visualise import grid_psfs
import seaborn as sns

def norm_zero_one(s):
    """Linearly rescale `s` so that its minimum maps to 0 and its maximum to 1."""
    lowest = s.min()
    return (s - lowest) / (s.max() - lowest)

def validate_args(args):
    """Drop ignored bead stacks from `args` and list the remaining ones.

    Paths containing 'ignored' are removed from args['bead_stacks'] (in
    place). If none remain, the process exits with status 1.
    """
    kept = []
    for stack_path in args['bead_stacks']:
        if 'ignored' not in stack_path:
            kept.append(stack_path)
    args['bead_stacks'] = kept

    n_stacks = len(kept)
    print(f"Found {n_stacks} bead stacks")
    if n_stacks == 0:
        quit(1)
    for f in natsorted(kept):
        print(f'\t - {f}')


def test_picasso_exec():
    """Check that the `picasso` command-line tool can be run.

    Raises:
        EnvironmentError: if the executable cannot be found, or if
            `picasso -h` exits with a non-zero status (its output is
            printed first).
    """
    try:
        res = subprocess.run(['picasso', '-h'], capture_output=True, text=True)
    except FileNotFoundError as exc:
        # Without this, a missing executable surfaced as a bare
        # "No such file or directory" error instead of the intended message.
        raise EnvironmentError('Picasso not found/working (executable not on PATH)') from exc
    if res.returncode != 0:
        print(res.stdout)
        print(res.stderr)
        print('\n')
        raise EnvironmentError('Picasso not found/working (see above)')

def transform_args(args):
    """Expand the raw arguments in place and return them.

    - args['outpath'] becomes the original bead-stack directory.
    - args['bead_stacks'] becomes the absolute paths of all .tif files found
      recursively below that directory, excluding paths containing
      '_slice.ome.tif' and files named 'stacks.ome.tif'.
    - args['gaussian_blur'] is parsed from 'a,b,c' into a list of ints.
    """
    stack_dir = args['bead_stacks']
    candidates = glob(f"{stack_dir}/**/*.tif", recursive=True)

    args['outpath'] = stack_dir

    kept = []
    for path in candidates:
        if '_slice.ome.tif' in path:
            continue
        if os.path.basename(path) == 'stacks.ome.tif':
            continue
        kept.append(os.path.abspath(path))
    args['bead_stacks'] = kept

    args['gaussian_blur'] = [int(sigma) for sigma in args['gaussian_blur'].split(',')]
    return args

def get_or_create_slice(bead_stack, slice_path):
    """Return `slice_path`, writing the middle frame of `bead_stack` there first if it is missing.

    The frame at index shape[0] // 2 is saved as uint16. An existing file is
    left untouched.
    """
    if os.path.exists(slice_path):
        return slice_path

    middle_frame = bead_stack[bead_stack.shape[0] // 2]
    imwrite(slice_path, middle_frame.astype(np.uint16))
    return slice_path

def get_or_create_locs(slice_path, args):
    """Load the Picasso localisations and spots for a slice, running Picasso if needed.

    Picasso is run when either output file is missing or args['regen'] is
    truthy.

    Args:
        slice_path: Path of the '.ome.tif' slice image.
        args: Dict with 'regen', 'box_size_length', 'gradient', 'pixel_size'
            and optionally 'qe', 'sensitivity', 'gain', 'baseline',
            'fit-method'.

    Returns:
        (locs, spots): the localisation DataFrame (with an added 'fname'
        column) and the spots array. Returns None if Picasso fails, so
        callers must check before unpacking.
    """
    spots_path = slice_path.replace('.ome.tif', '.ome_spots.hdf5')
    locs_path = slice_path.replace('.ome.tif', '.ome_locs.hdf5')

    if not os.path.exists(spots_path) or not os.path.exists(locs_path) or args['regen']:
        cmd = ['picasso', 'localize', slice_path, '-b', args['box_size_length'], '-g', args['gradient'], '-px', args['pixel_size']]
        for extra_arg in ['qe', 'sensitivity', 'gain', 'baseline', 'fit-method']:
            if extra_arg in args and args[extra_arg]:
                cmd.extend([f'-{extra_arg}', args[extra_arg]])
        cmd = [str(part) for part in cmd]
        # Logged after the optional arguments are appended, so the message
        # shows the command that is actually run.
        print(f'Running {" ".join(cmd)}')
        tqdm.write('Running picasso...', end='')
        try:
            # Argument list without a shell: paths with spaces or shell
            # metacharacters are passed through unchanged.
            res = subprocess.run(cmd, capture_output=True, text=True)
        except FileNotFoundError as exc:
            print('Picasso error occured')
            print(exc)
            return
        if res.returncode != 0:
            print('Picasso error occured')
            print(res.stdout)
            print(res.stderr)
            return
        tqdm.write('finished!')

    with h5py.File(spots_path, 'r') as f:
        spots = np.array(f['spots'])

    locs = pd.read_hdf(locs_path, key='locs')
    # Tag each row with the last three path components of the slice.
    locs['fname'] = '___'.join(slice_path.split('/')[-3:])
    print(f'Found {locs.shape[0]} beads')
    return locs, spots

def remove_colocal_beads(locs, spots, args):
    """Drop beads whose nearest neighbour is too close.

    A bead is kept when the distance to its nearest neighbour exceeds
    0.8 * sqrt(2) * args['box_size_length'].

    Args:
        locs: DataFrame with 'x' and 'y' columns, one row per bead.
        spots: Array with one entry per row of `locs`.
        args: Dict with a 'box_size_length' entry.

    Returns:
        (locs, spots) restricted to the kept beads. `locs` is always a
        DataFrame and `spots` keeps its leading axis, even when a single bead
        survives.
    """
    tqdm.write('Removing overlapping beads...')
    if locs.shape[0] == 0:
        # Nothing to compare; the distance computation rejects empty input.
        return locs, spots

    coords = locs[['x', 'y']].to_numpy()
    dists = euclidean_distances(coords, coords)
    # Ignore each bead's zero distance to itself.
    np.fill_diagonal(dists, np.inf)
    min_dists = dists.min(axis=1)

    error_margin = 0.8
    min_seperation = (np.sqrt(2)  * args['box_size_length']) * error_margin
    # flatnonzero always yields a 1-D index array; argwhere(...).squeeze()
    # collapsed to a scalar index when exactly one bead survived.
    idx = np.flatnonzero(min_dists > min_seperation)
    locs = locs.iloc[idx]
    spots = spots[idx]

    return locs, spots


def extract_training_stacks(spots, bead_stack, args) -> np.array:
    """Cut one z-stack per spot out of `bead_stack`.

    Each spot is located in the middle frame of the stack by template
    matching; the window of side args['box_size_length'] starting at the
    best-match position is then cropped from every frame.

    Returns:
        Array holding one cropped stack per spot.
    """
    spot_size = args['box_size_length']
    frame = bead_stack[bead_stack.shape[0] // 2]

    def crop_stack(spot):
        response = match_template(frame, spot)
        row, col = np.unravel_index(np.argmax(response), response.shape)
        return bead_stack[:, row:row+spot_size, col:col+spot_size]

    return np.array([crop_stack(spot) for spot in spots])

def snr(psf):
    """Ratio of the maximum of `psf` to its median."""
    peak = psf.max()
    background = np.median(psf)
    return peak / background


def has_fwhm(psf, args):
    """Return True when the axial intensity profile crosses its half-maximum at least twice.

    The stack is low-pass filtered (Butterworth, then Gaussian), reduced to
    one maximum per frame, and the crossings of the level halfway between the
    profile's minimum and maximum are counted. `args` is not used.
    """
    low_passed = butterworth(psf, cutoff_frequency_ratio=0.2, high_pass=False)
    profile = np.max(gaussian(low_passed), axis=(1, 2))

    lowest, highest = np.min(profile), np.max(profile)
    half_max = lowest + ((highest - lowest) / 2)

    # A crossing is a change of side between two consecutive frames.
    above = profile > half_max
    crossCount = np.sum(above[:-1] != above[1:])
    return crossCount >= 2


def filter_mse_zprofile(psf, args, i):
    """Check how well a skewed Gaussian fits the axial intensity profile of a stack.

    The per-frame maximum of `psf` is normalised to [0, 1] and fitted against
    z (frame index * args['zstep']). If the fit fails, the initial guess is
    used as the curve. `i` is not used.

    Returns:
        True when the mean squared residual is below 0.02 and the largest
        squared residual is below 0.1.
    """
    z_step = args['zstep']
    z_extent = psf.shape[0] * z_step

    def skewed_gaussian(x, A, x0, sigma, alpha, offset):
        """
        A: Amplitude
        x0: Center
        sigma: Standard Deviation
        alpha: Skewness parameter
        offset: Vertical offset
        """
        return A * np.exp(-(x - x0)**2 / (2 * sigma**2)) * (1 + erf(alpha * (x - x0))) + offset

    # Profile to fit: brightest pixel of every frame, rescaled to [0, 1].
    x_data = np.arange(psf.shape[0]) * z_step
    y_data = norm_zero_one(psf.max(axis=(1, 2)))

    # Parameter order: A, x0, sigma, alpha, offset.
    initial_guess = [1, z_extent / 2, z_extent / 4, 0.0, np.median(y_data)]
    lower = (0.6, z_extent / 8, z_extent / 20, -np.inf, y_data.min())
    upper = (1.2, z_extent, z_extent / 4, np.inf, y_data.max())

    try:
        params, _ = curve_fit(skewed_gaussian, x_data, y_data, p0=initial_guess, bounds=[lower, upper])
    except RuntimeError:
        print('Failed to find Z fit')
        params = initial_guess

    squared_residuals = (skewed_gaussian(x_data, *params) - y_data) ** 2
    avg_mse = np.mean(squared_residuals)
    max_mse = np.max(squared_residuals)

    permitted_avg_mse = 0.02
    permitted_max_mse = 0.1
    return avg_mse < permitted_avg_mse and max_mse < permitted_max_mse
        

def get_sharpness(array):
    """Mean gradient magnitude of a 2-D array."""
    grad_rows, grad_cols = np.gradient(array)
    magnitude = np.sqrt(grad_cols**2 + grad_rows**2)
    return np.average(magnitude)


def reduce_img(psf):
    """Sharpness of every frame of a stack, as a 1-D array.

    Note: this file's opening section defines another `reduce_img` (per-frame
    maximum); whichever definition runs last wins.
    """
    return np.stack(list(map(get_sharpness, psf)))


def est_bead_offsets(psfs, locs, args):
    """Estimate the axial peak position of every bead stack and store it in locs['offset'].

    For each stack the frames are Gaussian-blurred, reduced to a per-frame
    sharpness profile, normalised to [0, 1] and smoothed with a cubic spline;
    the offset is the z position where the spline is largest on a 10x finer
    grid.

    Args:
        psfs: Iterable of bead stacks (a trailing 4th axis is averaged away).
        locs: DataFrame with one row per stack; modified in place.
        args: Dict with 'gaussian_blur' (sigma per axis) and 'zstep'.

    Returns:
        None. The result is the new 'offset' column of `locs`.
    """
    UPSCALE_RATIO = 10
    sigmas = np.array(args['gaussian_blur'])

    def denoise(img):
        return gaussian_filter(img.copy(), sigma=sigmas)

    def find_peak(psf):
        if psf.ndim == 4:
            psf = psf.mean(axis=-1)
        x = np.arange(psf.shape[0]) * args['zstep']
        psf = denoise(psf)

        inten = norm_zero_one(reduce_img(psf))

        cs = UnivariateSpline(x, inten, k=3, s=0.2)

        # NOTE(review): this grid ends at shape[0] * zstep, one z-step past
        # the last sample in `x`, so the final stretch is spline
        # extrapolation - confirm that is intended.
        x_ups = np.linspace(0, psf.shape[0], len(x) * UPSCALE_RATIO) * args['zstep']

        return x_ups[np.argmax(cs(x_ups))]

    # Build a real list first: np.array(map(...)) wraps the lazy map object
    # in a 0-d object array instead of evaluating it, so no offsets were
    # ever computed.
    offsets = np.array([find_peak(psf) for psf in psfs])

    locs['offset'] = offsets


def filter_mse_xy(stack, max_mse, i):
    """Check how well a rotated 2-D Gaussian fits the sharpest frame of a stack.

    The frame with the highest sharpness is scaled by its maximum and fitted;
    if the fit fails, the initial guess is rendered instead. `i` is not used.

    Returns:
        True when the mean squared error between the rendered Gaussian and
        the frame is at most `max_mse`.
    """
    def gaussian_2d(xy, amplitude, xo, yo, sigma_x, sigma_y, theta, offset):
        x, y = xy
        xo = float(xo)
        yo = float(yo)
        a = (np.cos(theta)**2) / (2 * sigma_x**2) + (np.sin(theta)**2) / (2 * sigma_y**2)
        b = -(np.sin(2 * theta)) / (4 * sigma_x**2) + (np.sin(2 * theta)) / (4 * sigma_y**2)
        c = (np.sin(theta)**2) / (2 * sigma_x**2) + (np.cos(theta)**2) / (2 * sigma_y**2)
        g = offset + amplitude * np.exp(- (a * ((x - xo)**2) + 2 * b * (x - xo) * (y - yo) + c * ((y - yo)**2)))
        return g.ravel()

    # Sharpest frame of the stack.
    sharp = reduce_img(stack)
    image = stack[np.argmax(sharp)]

    # Pixel grid; the side length is taken from axis 1 of the frame.
    image_size = image.shape[1]
    pixel_axis = np.linspace(0, image_size - 1, image_size)
    x, y = np.meshgrid(pixel_axis, pixel_axis)

    image = image / image.max()

    # Parameter order: amplitude, xo, yo, sigma_x, sigma_y, theta, offset.
    p0 = [1, image_size / 2, image_size / 2, 2, 2, 0, 0]
    centre_range = (image_size * (1/5), image_size * (4/5))
    sigma_range = (0, image_size/3)
    bounds = [
        (0, np.inf),
        centre_range,
        centre_range,
        sigma_range,
        sigma_range,
        (-np.inf, np.inf),
        (0, np.inf),
    ]
    lower = tuple(lo for lo, _ in bounds)
    upper = tuple(hi for _, hi in bounds)

    try:
        popt, pcov = curve_fit(gaussian_2d, (x, y), image.ravel(), p0=p0, bounds=[lower, upper])
    except RuntimeError:
        print('XY fit failed')
        popt = p0
    render = gaussian_2d((x, y), *popt).reshape(image.shape)

    error = mean_squared_error(render, image)

    return error <= max_mse


def filter_beads(spots, locs, stacks, args, rejected_outpath):
    """Keep only the beads whose stacks pass every quality filter.

    Four boolean filters are evaluated per stack: the lateral Gaussian-fit
    MSE (``filter_mse_xy``), the SNR against ``args['min_snr']``,
    ``has_fwhm`` and the axial-profile MSE (``filter_mse_zprofile``). The
    reasons for rejection are stored in ``locs['rejected']``, bead offsets
    are estimated with ``est_bead_offsets`` and, when ``args['debug']`` is
    set, figures for the rejected beads are written to ``rejected_outpath``.

    Args:
        spots: per-bead array indexed along axis 0.
        locs: DataFrame with one row per bead; gains a ``rejected`` column.
        stacks: array of bead stacks indexed along axis 0.
        args: settings dict; ``min_snr`` and ``debug`` are read here and the
            dict is forwarded to the individual filters.
        rejected_outpath: output directory for rejected-bead figures.

    Returns:
        Tuple ``(spots, locs, stacks)`` restricted to the accepted beads.
    """
    print('Removing poorly imaged beads...', end='')
    # Filter by SNR threshold

    # Debug display: show every stack before filtering.
    for i in range(stacks.shape[0]):
        psf = stacks[i]
        plt.title(str(i))
        plt.imshow(grid_psfs(psf))
        plt.show()
    mse_xy = np.array([filter_mse_xy(psf, 10000, i) for i, psf in enumerate(stacks)])
    snrs = np.array([snr(psf) > args['min_snr'] for psf in stacks])
    fwhms = np.array([has_fwhm(psf, args) for psf in stacks])
    mse_z = np.array([filter_mse_zprofile(psf, args, i) for i, psf in enumerate(stacks)])
    # TODO re-enable
    # mse_xy[:] = True
    
    # snrs[:] = True
    # mse_filters[:] = True

    # Debug display: print each filter verdict next to the stack it refers to.
    for i in range(stacks.shape[0]):
        psf = stacks[i]
        print(filter_mse_xy(psf, 10000, i))
        print(snr(psf) > args['min_snr'])
        print(has_fwhm(psf, args))
        print(filter_mse_zprofile(psf, args, i))
        plt.imshow(grid_psfs(psf))
        plt.show()

    accepted = snrs & fwhms & mse_z & mse_xy
    # flatnonzero always yields a 1-D index array. The previous
    # argwhere(...).squeeze() collapsed to a 0-d array when exactly one bead
    # passed, which made the indexing below drop the bead axis.
    idx = np.flatnonzero(accepted)

    reasons = [''] * len(snrs)
    for i in range(spots.shape[0]):
        if not fwhms[i]:
            reasons[i] += 'fwhm'
        if not snrs[i]:
            reasons[i] += f',snrs({round(snr(stacks[i]), 3)})' 
        if not mse_z[i]:
            reasons[i] += ',mse_z'
        if not mse_xy[i]:
            reasons[i] += ',mse_xy'

    locs['rejected'] = reasons

    est_bead_offsets(stacks, locs, args)

    if args['debug']:
        rejected_idx = np.flatnonzero(~accepted)
        print('\n', 'Rejected: ', rejected_idx)

        if len(rejected_idx):
            print('Writing rejected figures...')

            write_stack_figures(stacks[rejected_idx], locs.iloc[rejected_idx], rejected_outpath)

    spots = spots[idx]
    locs = locs.iloc[idx]
    stacks = stacks[idx]

    print('finished!')

    return spots, locs, stacks


def write_combined_data(stacks, locs, args):
    """Write the combined stacks, localisations, config and offsets plot.

    Everything goes into ``<args['outpath']>/combined``: the stacks as
    ``stacks.ome.tif``, the localisations as ``locs.hdf`` (key ``locs``), a
    ``stacks_config.json`` holding ``zstep`` and the full ``args`` dict, and
    an ``offsets.png`` scatter of x/y coloured by ``offset``.
    """
    combined_dir = os.path.join(args['outpath'], 'combined')
    os.makedirs(combined_dir, exist_ok=True)

    locs_outpath = os.path.join(combined_dir, 'locs.hdf')
    stacks_outpath = os.path.join(combined_dir, 'stacks.ome.tif')

    imwrite(stacks_outpath, stacks)
    locs.to_hdf(locs_outpath, key='locs')

    stacks_config_outpath = os.path.join(combined_dir, 'stacks_config.json')
    with open(stacks_config_outpath, 'w') as fp:
        config_text = json.dumps({'zstep': args['zstep'], 'gen_args': args}, indent=4)
        # Same bytes as print(..., file=fp): the JSON text plus a newline.
        fp.write(config_text)
        fp.write('\n')

    figpath = os.path.join(combined_dir, 'offsets.png')
    sns.scatterplot(data=locs, x='x', y='y', hue='offset')
    plt.savefig(figpath)
    plt.close()

    print('Saved results to:')
    for written_path in (locs_outpath, stacks_outpath, stacks_config_outpath):
        print(f'\t{written_path}')
    print(f'Total beads: {locs.shape[0]}')


def write_stack_figure(i, stacks, locs, outpath, fname):
    """Save a three-panel summary figure for bead ``i``.

    Panels: the stack tiled by ``grid_psfs``, the per-frame maximum pixel
    value against z, and the bead's row of ``locs`` dumped as JSON text.
    The z axis uses ``args['zstep']`` from the module-level ``args`` (it is
    not a parameter) and is shifted by the bead's ``offset``. The figure is
    written to ``<outpath>/<fname>_bead_<i>.png``.
    """
    bead_stack = stacks[i]
    bead_info = locs.iloc[i].to_dict()

    fig = plt.figure(layout="constrained", figsize=(20, 15), dpi=64)
    grid = plt.GridSpec(2, 3, figure=fig)
    frames_ax = fig.add_subplot(grid[0:, 0])
    profile_ax = fig.add_subplot(grid[0, 1:])
    text_ax = fig.add_subplot(grid[1, 1:])

    fig.suptitle(f'Bead: {i}')

    frames_ax.imshow(grid_psfs(bead_stack, cols=20))
    frames_ax.set_title('Ordered by frame')

    # Brightest pixel of each frame, plotted against offset-corrected z.
    peak_per_frame = bead_stack.max(axis=(1,2))
    lowest = min(peak_per_frame)
    highest = max(peak_per_frame)
    z_positions = (np.arange(len(peak_per_frame)) * args['zstep']) - bead_info['offset']
    profile_ax.plot(z_positions, peak_per_frame)
    profile_ax.vlines(0, lowest, highest, colors='orange')
    profile_ax.set_title('Max normalised pixel intensity over z')
    profile_ax.set_xlabel('z (nm)')
    profile_ax.set_ylabel('pixel intensity')

    # Round floats so the JSON dump stays readable.
    bead_info = {
        key: (round(value, 5) if isinstance(value, float) else value)
        for key, value in bead_info.items()
    }
    info_text = json.dumps(bead_info, indent=4)
    text_ax.axis((0, 10, 0, 10))
    text_ax.text(0,0, info_text, fontsize=18, wrap=True)

    outfpath = os.path.join(outpath, f'{fname}_bead_{i}.png')
    plt.savefig(outfpath)
    plt.close()
    print(f'Wrote {outfpath}')


from multiprocessing import Pool
from itertools import repeat


def write_stack_figures(stacks, locs, outpath):
    """Write one summary figure per stack using a pool of 8 worker processes.

    The file-name prefix is one element of the set of ``locs['fname']``
    values with ``.ome.tif`` stripped; if several file names are present,
    which one is used is arbitrary.
    """
    fname = set(locs['fname']).pop().replace('.ome.tif', '')
    os.makedirs(outpath, exist_ok=True)

    jobs = [(bead_i, stacks, locs, outpath, fname) for bead_i in np.arange(stacks.shape[0])]
    with Pool(8) as pool:
        pool.starmap(write_stack_figure, jobs)

# def filter_by_tmp_locs(locs, spots):
#     original_locs = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/original_locs.hdf', key='locs')
#     x_coords = set(original_locs['x'])
#     idx = np.argwhere([x in x_coords for x in locs['x']]).squeeze()
#     locs = locs.iloc[idx]
#     spots = spots[idx]
#     print(locs.shape)
#     return locs, spots



# Most recent output of extract_training_stacks(), exposed at module level so
# it can be inspected interactively after main() has run.
stackss = None


def main(args):
    """Build the combined bead dataset from every bead stack in ``args``.

    For each stack: load it, get or create its slice and localisations,
    remove co-located beads, extract per-bead stacks and filter them with
    ``filter_beads``. The surviving stacks are cropped to the shortest stack
    length, concatenated and written with ``write_combined_data``; with
    ``args['debug']`` set, per-bead figures are written as well.

    Args:
        args: settings dict; ``outpath``, ``debug`` and ``bead_stacks`` are
            read here and the dict is forwarded to the helper functions.

    Raises:
        ValueError: if no stack produced any data to combine.
    """
    # Without this declaration the assignment below only bound a local name
    # and the module-level ``stackss`` stayed None.
    global stackss

    all_stacks = []
    all_spots = []
    all_locs = []

    found_beads = 0
    retained_beads = 0

    rejected_outpath = os.path.join(args['outpath'], 'combined', 'rejected')
    shutil.rmtree(rejected_outpath, ignore_errors=True)

    if args['debug']:
        os.makedirs(rejected_outpath, exist_ok=True)

    for bead_stack_path in tqdm(natsorted(args['bead_stacks'])):
        # NOTE(review): hard-coded restriction to paths containing
        # 'stack__2_'; looks like a debugging leftover - confirm before
        # running on a full dataset.
        if 'stack__2_' not in bead_stack_path:
            continue
        tqdm.write(f'Preparing {os.path.basename(bead_stack_path)}')

        bead_stack = imread(bead_stack_path)
        slice_path = bead_stack_path.replace('.ome', '_slice.ome')
        slice_path = get_or_create_slice(bead_stack, slice_path)

        raw_locs, spots = get_or_create_locs(slice_path, args)
        
        # raw_locs, spots = filter_by_tmp_locs(raw_locs, spots)
        found_beads += raw_locs.shape[0]
        if raw_locs.shape[0] == 0:
            # Nothing to process; also avoids dividing by zero below.
            tqdm.write('No beads found, skipping')
            continue

        locs, spots = remove_colocal_beads(raw_locs, spots, args)
        perc_removed = round(100*(1-(locs.shape[0]/raw_locs.shape[0])), 2)
        print(f'Removed {perc_removed}% due to co-location')

        stacks = extract_training_stacks(spots, bead_stack, args)
        stackss = stacks
        # An unconditional ``raise EnvironmentError`` used to sit here, which
        # made the filtering and everything after it unreachable.
        spots, locs, stacks = filter_beads(spots, locs, stacks, args, rejected_outpath)
        retained_beads += locs.shape[0]
        tqdm.write(f'Retained {stacks.shape[0]} beads')
        all_stacks.append(stacks)
        all_spots.append(spots)
        all_locs.append(locs)

    print(f'Found {found_beads} total beads')
    if not all_stacks:
        raise ValueError('No bead stacks were processed; nothing to combine')

    # Crop every group of stacks to the shortest z length so they concatenate.
    min_stack_length = min(s.shape[1] for s in all_stacks)
    stacks = [s[:, :min_stack_length] for s in all_stacks]
    locs = pd.concat(all_locs)
    stacks = np.concatenate(stacks)

    # original_locs = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/original_locs.hdf', key='locs')
    # x_coords = set(original_locs['x'])
    # print(len(set(locs['x'])), len(set(original_locs['x'])))
    # print(len(set(locs['x']).intersection(set(original_locs['x']))))
    # locs = pd.concat((locs, locs))
    # stacks = np.concatenate((stacks, stacks))

    print(locs.shape)
    print(stacks.shape)
    print(f'Kept {locs.shape[0]} total beads')

    write_combined_data(stacks, locs, args)

    if args['debug']:
        outpath = os.path.join(args['outpath'], 'combined', 'debug')
        write_stack_figures(stacks, locs, outpath)
        

def parse_args():
    """Parse the command-line options of the bead-stack preparation script.

    Returns:
        dict: the parsed options keyed by destination name (dashes in the
        option names become underscores, e.g. ``box_size_length``).
    """
    parser = ArgumentParser(description='')
    parser.add_argument('bead_stacks', help='Path to TIFF bead stacks / directory containing bead stacks.')
    # The help text used to read 'Pixel size (nm)', duplicating --pixel_size.
    parser.add_argument('-z', '--zstep', help='Z step (nm)', default=10, type=int)
    parser.add_argument('-px', '--pixel_size', help='Pixel size (nm)', default=86, type=int)
    parser.add_argument('-g', '--gradient', help='Min. net gradient', default=1000, type=int)
    parser.add_argument('-b', '--box-size-length', help='Box size', default=15, type=int)
    parser.add_argument('-qe', '--qe', help='Quantum efficiency', type=float)
    parser.add_argument('-s', '--sensitivity', help='Sensitivity', type=float)
    parser.add_argument('-ga', '--gain', help='Gain', type=float)
    parser.add_argument('-bl', '--baseline', help='Baseline', type=int)
    parser.add_argument('-a', '--fit-method', help='Fit method', choices=['mle', 'lq', 'avg'])
    parser.add_argument('--regen', action='store_true')
    parser.add_argument('-snr', '--min-snr', type=float, default=2.0)
    parser.add_argument('-gb', '--gaussian-blur', default='3,2,2', help='Gaussian pixel-blur in Z/Y/X for bead offset estimation')
    parser.add_argument('--debug', action='store_true')

    args = vars(parser.parse_args())
    return args



In [None]:
# Notebook stand-in for the command-line arguments built by parse_args().
# NOTE(review): filter_beads() reads args['min_snr'] and main() reads
# args['outpath']; neither key is set here, so transform_args() presumably
# supplies them - confirm.
args = {
    'zstep': 10,
    'pixel_size': 86,
    'bead_stacks': '/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/stacks/',
    'gaussian_blur': '3,2,2',
    'debug': False,
    'regen': False,
    'box_size_length': 15
}

args = transform_args(args)
validate_args(args)
main(args)

In [None]:
# Manual replay of the first steps of main() for a single stack, leaving the
# intermediate results (locs, spots, stacks) in the notebook namespace.
# Relies on `args` from the previous cell.
bead_stack_path = '/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/stacks/stack__2/stack__2_MMStack_Default.ome.tif'

bead_stack = imread(bead_stack_path)
slice_path = bead_stack_path.replace('.ome', '_slice.ome')
slice_path = get_or_create_slice(bead_stack, slice_path)

raw_locs, spots = get_or_create_locs(slice_path, args)

# raw_locs, spots = filter_by_tmp_locs(raw_locs, spots)
locs, spots = remove_colocal_beads(raw_locs, spots, args)
# Percentage of localisations dropped by the co-location filter.
perc_removed = round(100*(1-(locs.shape[0]/raw_locs.shape[0])), 2)
print(f'Removed {perc_removed}% due to co-location')

stacks = extract_training_stacks(spots, bead_stack, args)

In [None]:
# Evaluate the individual bead filters on `stacks` from the previous cell
# (the lateral-MSE filter is currently disabled).
# mse_xy = np.array([filter_mse_xy(psf, 10000, i) for i, psf in enumerate(stacks)])
snrs = np.array([snr(psf) > args['min_snr'] for psf in stacks])
fwhms = np.array([has_fwhm(psf, args) for psf in stacks])
mse_z = np.array([filter_mse_zprofile(psf, args, i) for i, psf in enumerate(stacks)])

In [None]:

# Show every extracted stack as a tiled image, titled with its index.
for i in range(stacks.shape[0]):
    plt.title(str(i))
    plt.imshow(grid_psfs(stacks[i]))
    plt.show()

In [None]:
import pandas as pd
import h5py

# Convert the INSPR reconstruction CSV into a 'locs' HDF5 dataset.
df = pd.read_csv('~/Projects/smlm_z/publication/comparisons/inspr/reconstruction.csv')
# 86 matches the pixel_size used in this notebook's args dicts, so these are
# presumably nm -> pixel conversions - confirm.
df['x'] = df['x [nm]'] / 86
df['y'] = df['y [nm]'] / 86
df['z'] = df['z [nm]'] / 86

# Constant placeholder values for the remaining columns.
df['photons'] = 1000
df['sx'] = 1
df['sy'] = 1
df['bg'] = 0
df['lpx'] = 0.1
df['lpy'] = 0.1
df['frame'] = df['Frame-ID']

locs_path = '/home/miguel/Projects/smlm_z/publication/comparisons/inspr/reconstruction.hdf5'
# Mode "w" replaces any existing file at locs_path.
with h5py.File(locs_path, "w") as locs_file:
    locs_file.create_dataset("locs", data=df.to_records())
        
print(list(df))

In [None]:
# x/y overview (in nm) of the frame loaded in the previous cell.
plt.scatter(df['x [nm]'], df['y [nm]'], marker='.')



In [None]:
import seaborn as sns
# Distribution of the z coordinate (nm).
sns.histplot(df['z [nm]'])

In [None]:
import mat73
import scipy
import pandas as pd
# Load the spline-fit localisation result from a MATLAB file.
data_dict = mat73.loadmat('/home/miguel/Projects/smlm_z/publication/comparisons/spline/nup_11_sml.mat')


In [None]:
# Display the camera pixel size recorded in the loaded file (key name suggests micrometres).
data_dict['saveloc']['file']['info']['cam_pixelsize_um']

In [None]:
import pandas as pd
import shutil
# Convert the spline localisations in `data_dict` into a 'locs' HDF5 dataset
# whose columns follow those of the reference file loaded into df2.
df = pd.DataFrame.from_records(data_dict['saveloc']['loc'])

df2 = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/VIT_openframe/out_roll_alignment/out_nup/locs_3d.hdf5', key='locs')

PIXEL_SIZE = 86

# (source column, target column) pairs; a source may feed several targets.
map_cols = [
    ('xnm', 'x [nm]'),
    ('ynm', 'y [nm]'),
    ('znm', 'z [nm]'),
    ('xpix', 'x'),
    ('ypix', 'y'),
    ('PSFxnm', 'sx'),
    ('PSFynm', 'sy'),
    ('phot', 'photons'),
    ('locprecnm', 'lpx'),
    ('locprecnm', 'lpy'),
    ('znm', 'z'),
    ('zerr', 'lpz'),
]

for src, trgt in map_cols:
    df[trgt] = df[src].copy()

# Drop each source column once, after all of its targets have been filled.
src_cols = list(set([x[0] for x in map_cols]))
for col in src_cols:
    del df[col]

# NOTE(review): 'z' is copied from 'znm' and is not in rescale_cols, so it is
# left unscaled while the columns below are divided by PIXEL_SIZE - confirm
# this is intended.
rescale_cols = ['lpx', 'lpy', 'sx', 'sy', 'lpz']
for col in rescale_cols:
    df[col] = df[col] / PIXEL_SIZE

# Keep only the columns also present in df2, plus 'lpz'.
# NOTE(review): if df2 already has an 'lpz' column this selects it twice - confirm.
df = df[[c for c in list(df) if c in list(df2)] + ['lpz']]
df['frame'] = df['frame'].astype(int)
df.reset_index(drop=True, inplace=True)
print(list(df))
import h5py

locs_path = '/home/miguel/Projects/smlm_z/publication/comparisons/spline/locs_3d.hdf5'
# Mode "w" replaces any existing file at locs_path.
with h5py.File(locs_path, "w") as locs_file:
    locs_file.create_dataset("locs", data=df.to_records())

# shutil.copy('/home/miguel/Projects/smlm_z/publication/VIT_openframe/out_roll_alignment/out_nup/locs_3d.yaml', locs_path.replace('hdf5', 'yaml'))
# frame=frame
# xnm=x
# ynm=y
# phot=photons
# PSFxnm=sx
# PSFynm=sy
# bg=bg
# locprecnm=lpx
# znm=z
# cam_pixelsize_um=100
# factor=100    1

In [None]:
import pandas as pd
import seaborn as sns
from sklearn.metrics import euclidean_distances
# Load the undrifted spline localisations and the original picked
# localisations, then overlay their x/y positions.
df_new = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/comparisons/spline/locs_3d_undrift.hdf5', key='locs')
df_orig = pd.read_hdf('/home/miguel/Projects/data/20230601_MQ_celltype/nup/fov2/storm_1/storm_1_MMStack_Default.ome_locs_undrifted_picked_4.hdf5', key='locs')
print(df_new.shape)
sns.scatterplot(data=df_new, x='x', y='y', alpha=0.01)
sns.scatterplot(data=df_orig, x='x', y='y', alpha=0.01)

In [None]:
import numpy as np
from tqdm import trange
import tensorflow as tf

def _euclidean_distances(coords1, coords2, reduce_min=None):
    """
    Pairwise Euclidean distances between two coordinate sets, via TensorFlow.

    Args:
        coords1: array-like of shape (N, D).
        coords2: array-like of shape (M, D).
        reduce_min: optional axis of the distance matrix to take the minimum
            over; ``None`` returns the full matrix.

    Returns:
        np.ndarray: with ``reduce_min=None`` an array of shape (M, N) whose
        entry [j, i] is the distance between ``coords2[j]`` and
        ``coords1[i]``. With ``reduce_min=0`` an array of shape (N,) holding,
        for every point of ``coords1``, the distance to its nearest point in
        ``coords2``; with ``reduce_min=1`` the shape is (M,).
    """
    # Ensure inputs are TensorFlow tensors
    pts_a = tf.convert_to_tensor(coords1, dtype=tf.float32)
    pts_b = tf.convert_to_tensor(coords2, dtype=tf.float32)

    # Broadcast (1, N, D) against (M, 1, D) to get all pairwise differences.
    diffs = pts_a[tf.newaxis, :, :] - pts_b[:, tf.newaxis, :]
    dists = tf.norm(diffs, ord='euclidean', axis=2)
    if reduce_min is not None:
        dists = tf.math.reduce_min(dists, axis=reduce_min, keepdims=False)
    return dists.numpy()
    
def batched_euclidean_distance(coords1, coords2, batch_size, reduce_min=None):
    """Evaluate ``_euclidean_distances`` over ``coords1`` in batches.

    Args:
        coords1: array of shape (N, D), processed ``batch_size`` rows at a time.
        coords2: array of shape (M, D), used in full for every batch.
        batch_size: number of ``coords1`` rows per batch.
        reduce_min: forwarded to ``_euclidean_distances``. Each batch result
            is stored into a 1-D slice of the output, so it must be 1-D with
            one value per ``coords1`` row (as it is for ``reduce_min=0``).

    Returns:
        np.ndarray of shape (N,) with one value per row of ``coords1``.
    """
    from tqdm import trange
    n_coords = coords1.shape[0]
    min_dists = np.zeros((n_coords,))
    # Iterate up to n_coords: the previous bound of n_coords - 1 skipped the
    # final row whenever n_coords - 1 was a multiple of batch_size, and
    # skipped everything for a single-row input.
    for start in trange(0, n_coords, batch_size):
        end = start + batch_size
        min_dists[start:end] = _euclidean_distances(coords1[start:end], coords2, reduce_min)
    return min_dists

# For every localisation in df_new, distance to the nearest localisation in
# df_orig (reduce_min=0 takes the minimum over the df_orig points).
xy_new = df_new[['x', 'y']].to_numpy()
xy_orig = df_orig[['x', 'y']].to_numpy()

# # Adjust batch size to fit in GPU memory
BATCH_SIZE = 2**14

# print(xy_new.shape)
# print(BATCH_SIZE)

min_dists = batched_euclidean_distance(xy_new, xy_orig, BATCH_SIZE, 0)

print(min_dists.min(), min_dists.mean(), min_dists.max())


In [None]:
from sklearn.metrics import euclidean_distances
import matplotlib.pyplot as plt
# Keep only the new localisations within `cutoff` of an original
# localisation. 86 matches the pixel size used elsewhere in this notebook, so
# the cutoff is presumably 60 nm expressed in pixels - confirm.
cutoff = (120/2) / 86
# Take an explicit copy: the column edits below must not act on a slice of
# df_new (pandas SettingWithCopy behaviour).
sub_df = df_new[min_dists<cutoff].copy()

sub_df['lpz'] /= 10
# sns.scatterplot(data=df_orig, x='x', y='y', alpha=0.01)
# plt.show()
# sns.scatterplot(data=sub_df, x='x', y='y', alpha=0.01)
# plt.show()

nearest_neighbour = np.zeros((sub_df.shape[0]), dtype=object)
xy_new = sub_df[['x', 'y']].to_numpy()

# Mean position of each picked group in the original localisations.
df_orig_groups = df_orig.groupby('group').mean().reset_index(drop=False)

group_centers = df_orig_groups[['x', 'y']].to_numpy()

print(xy_new.shape, group_centers.shape)
dists = euclidean_distances(xy_new, group_centers)

# Assign every kept localisation to its nearest group centre. The lookup is
# positional, so it does not depend on the index labels of df_orig_groups.
dists_idx = np.argmin(dists, axis=1)
sub_df['clusterID'] = df_orig_groups['group'].to_numpy()[dists_idx].astype(int)


In [None]:
import h5py
# Write the cluster-assigned subset back to disk as a 'locs' dataset.
if 'index' in list(sub_df):
    del sub_df['index']
sub_df['group'] = sub_df['clusterID']
# NOTE(review): this is the same path df_new is read from earlier in the
# notebook, and mode "w" replaces the file - re-running the notebook would
# then start from the already-filtered data. Confirm this is intended.
locs_path = '/home/miguel/Projects/smlm_z/publication/comparisons/spline/locs_3d_undrift.hdf5'
with h5py.File(locs_path, "w") as locs_file:
    locs_file.create_dataset("locs", data=sub_df.to_records())

In [None]:
# List the rendered image of one cluster across the render sub-directories.
# NOTE(review): this path uses 'publication/spline_comparison/' whereas the
# next cell globs 'publication/comparisons/spline/' - confirm which is current.
!ls /home/miguel/Projects/smlm_z/publication/spline_comparison/nup_renders3/*/nup_88_gaussian.png

In [None]:
import matplotlib.pyplot as plt
import os
import glob
from natsort import natsorted
import matplotlib.image as mpimg

# Side-by-side comparison of the renders produced by the two pipelines for
# every image name that exists in both directory trees.
dir1 = '/home/miguel/Projects/smlm_z/publication/comparisons/spline/nup_renders3/*/'
dir2 = '/home/miguel/Projects/smlm_z/publication/VIT_openframe/out*/out_nup/nup_renders3/*/'

# Base names of the PNGs found under each glob pattern.
imgs = [set(list(map(os.path.basename, glob.glob(f'{dirname}/*.png')))) for dirname in (dir1, dir2)]
imnames = imgs[0].intersection(imgs[1])
print(len(imnames))

for imname in natsorted(imnames):
    # The patterns contain wildcards; only the first match is used.
    impath = glob.glob(dir1+imname)[0]
    img1 = mpimg.imread(impath)

    impath2 = glob.glob(dir2+imname)[0]
    # Only show images whose second-pipeline path contains 'good'.
    if 'good' not in impath2:
        continue
    print(imname)

    img2 = mpimg.imread(impath2)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(30, 15))
    # Display the first image
    ax1.imshow(img1)
    ax1.set_title('MLE')
    ax1.axis('off')
    
    # Display the second image
    ax2.imshow(img2)
    ax2.set_title('Vision transformer')
    ax2.axis('off')
    
    # Adjust the spacing between subplots
    plt.subplots_adjust(wspace=0)
    
    # Display the figure
    plt.show()
        
        

In [None]:
# Print the rows of sub_df group by group, in ascending group order.
for c in sorted(set(sub_df['group'])):
    tmp_df = sub_df[sub_df['group']==c]
    print(tmp_df)


In [None]:
import pandas as pd

# NOTE(review): unfinished cell - the string literal below is never closed,
# so this cell cannot be parsed as written.
new_df = '

In [None]:
import matplotlib.pyplot as plt
from picasso import io
from picasso.render import render
import pandas as pd
import numpy as np

import seaborn as sns
from sklearn.neighbors import KernelDensity
from scipy.signal import find_peaks
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm
fontprops = fm.FontProperties(size=18)

# Exclusive bounds applied to 'sx' and 'sy' in filter_locs().
min_sigma = 0
max_sigma = 3

# z_min / z_max / min_log_likelihood are only referenced from commented-out
# lines in filter_locs().
z_min = 400
z_max = 600
min_log_likelihood = -100
# min_kde = np.log(0.007)
# Threshold compared against 10 ** (KDE score) in filter_locs().
min_kde = 0.05

# z range passed to color_histplot() when color_by_depth is enabled.
cmap_min_z = -600
cmap_max_z = -300
BLUR = 'gaussian'
color_by_depth = False

MIN_BLUR=0.001

# One dict per plotted cluster, appended by write_nup_plots().
records = []

def filter_locs(l):
    """Filter localisations by PSF width and by density along z.

    Rows are kept when both ``sx`` and ``sy`` lie strictly between the
    module-level ``min_sigma`` and ``max_sigma``. A Gaussian kernel density
    estimate (bandwidth 0.5) is then fitted to the ``z`` column, its score
    stored in a ``kde`` column, and only rows with ``10 ** kde > min_kde``
    are kept.

    Args:
        l: DataFrame with at least ``sx``, ``sy`` and ``z`` columns. It is
            not modified.

    Returns:
        Filtered DataFrame with an added ``kde`` column; it is empty when no
        row survives the sigma filter.
    """
    n_points = l.shape[0]
    print(f'From {n_points} points')

    sigma_ok = (
        (min_sigma < l['sx']) & (l['sx'] < max_sigma)
        & (min_sigma < l['sy']) & (l['sy'] < max_sigma)
    )
    # Copy so that adding the 'kde' column below never writes into a slice
    # of the caller's frame.
    l = l[sigma_ok].copy()
    # print(f'{n_points-l.shape[0]} removed by sx/sy')

    if l.shape[0] == 0:
        # KernelDensity.fit() rejects empty input; return the empty frame
        # (with the same extra column) so callers can test for it.
        l['kde'] = np.array([], dtype=float)
        print('N points: 0')
        return l

    X = l[['z']]
    kde = KernelDensity(kernel='gaussian', bandwidth=0.5).fit(X)
    l['kde'] = kde.score_samples(X)

    # l = l[l['z [nm]'] > z_min]
    # l = l[l['z [nm]'] < z_max]
    # sns.scatterplot(data=l, x='z', y='kde')
    # plt.show()

    # NOTE(review): score_samples() is documented as a natural-log density,
    # so 10 ** score is not the density itself (np.exp would be). The
    # threshold was presumably tuned against this expression, so it is left
    # unchanged - confirm.
    l = l[np.power(10, l['kde']) > min_kde]
    # print(f'{n_points-l.shape[0]} removed by kde')

    # l = l[l['likelihood']>min_log_likelihood]

    n_points2 = l.shape[0]
    # print(f'Removed {n_points-n_points2} pts')
    # print(f'{n_points2} remaining')
    print(f'N points: {n_points2}')

    return l


# Default figure size for the plots produced in this cell.
plt.rcParams['figure.figsize'] = [18, 6]


def apply_cmap_img(img, cmap_min_coord, cmap_max_coord, img_min_coord, img_max_coord, cmap='gist_rainbow', brightness_factor = 20):
    """Colour a 2-D intensity image with a colormap gradient along its columns.

    The image's column range ``[img_min_coord, img_max_coord]`` is mapped
    linearly onto the colormap, with ``cmap_min_coord`` at 0 and
    ``cmap_max_coord`` at 1. Each pixel's intensity scales the RGBA colour of
    its column.

    Args:
        img: intensity image; squeezed to 2-D before use.
        cmap_min_coord, cmap_max_coord: coordinates mapped to the two ends of
            the colormap.
        img_min_coord, img_max_coord: coordinates of the first and last image
            column.
        cmap: colormap name or object passed to ``plt.get_cmap``.
        brightness_factor: multiplier applied after scaling to 0..255.

    Returns:
        Integer array of shape ``img.shape + (4,)`` with the alpha channel
        set to 255. The other channels are not clipped, so they can exceed
        255 when ``brightness_factor`` is greater than 1.
    """
    img = img.squeeze()

    cmap_zrange = cmap_max_coord - cmap_min_coord

    def map_z_to_cbar(z_val):
        return (z_val - cmap_min_coord) / cmap_zrange

    min_coord_color = map_z_to_cbar(img_min_coord)
    max_coord_color = map_z_to_cbar(img_max_coord)

    # Honour the `cmap` argument; it used to be ignored in favour of a
    # hard-coded 'gist_rainbow' (which is still the default).
    cmap = plt.get_cmap(cmap)

    # One colormap position per column, repeated down every row.
    gradient = np.repeat(np.linspace(min_coord_color, max_coord_color, img.shape[1])[np.newaxis, :], img.shape[0], 0)

    base = cmap(gradient)
    img = img[:, :, np.newaxis]
    cmap_img = img * base
    # cmap_img /= 2
    # Black background
    cmap_img = (cmap_img / cmap_img.max()) * 255
    cmap_img *= brightness_factor

    cmap_img[:, :, 3] = 255

    cmap_img = cmap_img.astype(int)

    return cmap_img
    
def color_histplot(barplot, cmap_min_z, cmap_max_z):
    """Colour each bar of a histogram by the position of its centre.

    Bar centres are scaled so that ``cmap_min_z`` maps to 0 and
    ``cmap_max_z`` to 1, then looked up in the 'gist_rainbow' colormap.
    """
    from matplotlib.colors import rgb2hex
    colormap = plt.get_cmap('gist_rainbow')

    z_span = cmap_max_z - cmap_min_z
    scaled_centres = np.array([
        ((bar.get_x() + bar.get_width()/2) - cmap_min_z) / z_span
        for bar in barplot.patches
    ])

    for bar, rgba in zip(barplot.patches, colormap(scaled_centres)):
        bar.set_facecolor(rgb2hex(rgba))
        

def center_view(locs, zrange=200):
    """Keep the localisations within ``zrange`` of the most populated z bin.

    'z [nm]' is histogrammed in 25-wide bins; the window is centred on the
    left edge of the fullest bin plus 12 (``25 // 2``). If the histogram has
    no bins, the mean z is used as the centre instead. Both window limits
    are inclusive.
    """
    z_nm = locs['z [nm]']
    bin_width = 25
    edges = np.arange(z_nm.min(), z_nm.max(), bin_width)
    counts, edges = np.histogram(z_nm, bins=edges)
    try:
        centre = edges[np.argmax(counts)] + (bin_width // 2)
    except ValueError:
        # np.argmax raises on an empty histogram.
        centre = np.mean(z_nm)

    in_window = z_nm.between(centre - zrange, centre + zrange)
    return locs[in_window]

def get_viewport(locs, axes, margin=1):
    """Return ``[[mins...], [maxs...]]`` of the given columns, padded by ``margin``."""
    lower = [locs[axis].min() - margin for axis in axes]
    upper = [locs[axis].max() + margin for axis in axes]
    # mins[:] = min(mins)
    # maxs[:] = max(maxs)
    return np.array([lower, upper])

def disable_axis_ticks():
    """Remove the x and y ticks from the current pyplot axes."""
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks([])

def get_extent(viewport, pixel_size):
    """Turn a ``[mins, maxs]`` viewport into an imshow-style extent.

    The result is ``[mins[1], maxs[1], mins[0], maxs[0]]`` multiplied by
    ``pixel_size``.
    """
    lower, upper = viewport
    bounds = np.array([lower[1], upper[1], lower[0], upper[0]])
    return bounds * pixel_size


def render_locs(locs, args, ang_xyz=(0,0,0), barsize=None, ax=None):
    """Render localisations with picasso and draw the image on ``ax``.

    A copy of ``locs`` is centred on its mean in every coordinate column,
    given ``lpz``/``sz`` values derived from the lateral columns, rendered
    with ``picasso.render.render`` and shown with a colorbar and an optional
    scale bar.

    Args:
        locs: DataFrame with ``x``, ``y``, ``z``, their ``[nm]`` variants and
            ``lpx``, ``lpy``, ``sx``, ``sy``. It is not modified.
        args: dict providing ``blur_method``, ``min_blur``, ``oversample``
            and ``pixel_size``.
        ang_xyz: rotation angles passed to ``render`` as ``ang``; also used
            to choose the axis labels.
        barsize: scale-bar length (labelled in nm), or ``None`` for no bar.
        ax: axes to draw on; defaults to the current axes. All drawing now
            goes through this axes object (ticks, labels, image and colorbar
            previously went to pyplot's current axes regardless of ``ax``).
    """
    if ax is None:
        ax = plt.gca()

    locs = locs.copy()
    # Overwrite the axial precision/width with values derived from the
    # lateral columns.
    locs['lpz'] = np.mean(locs[['lpx', 'lpy']].to_numpy()) / 2
    locs['sz'] = np.mean(locs[['sx', 'sy']].to_numpy()) / 3
    # locs['lpx'] = 0.1
    # locs['sx'] = 0.1
    # locs['lpy'] = 0.1
    # locs['sy'] = 0.1
    ax.set_xticks([])
    ax.set_yticks([])

    # Centre every coordinate column on zero.
    for col in ('x [nm]', 'y [nm]', 'z [nm]', 'x', 'y', 'z'):
        locs[col] -= locs[col].mean()

    viewport = get_viewport(locs, ('y', 'x'))

    _, img = render(locs.to_records(), blur_method=args['blur_method'], viewport=viewport, min_blur_width=args['min_blur'], ang=ang_xyz, oversampling=args['oversample'])
    if ang_xyz == (0, 0, 0):
        ax.set_xlabel('x [nm]')
        ax.set_ylabel('y [nm]')
    elif ang_xyz == (np.pi/2, 0, 0):
        ax.set_xlabel('z [nm]')
        ax.set_ylabel('x [nm]')
        # Transpose the image and swap the viewport columns to match.
        img = img.T
        viewport = np.fliplr(viewport)

    elif ang_xyz == (0, np.pi/2, 0):
        ax.set_xlabel('z [nm]')
        ax.set_ylabel('y [nm]')
    else:
        print('Axis labels uncertain due to rotation angle')

    extent = get_extent(viewport, args['pixel_size'])
    ax.set_aspect('equal', 'box')
    img_plot = ax.imshow(img, extent=extent)
    plt.colorbar(img_plot, ax=ax)

    if barsize is not None:
        scalebar = AnchoredSizeBar(ax.transData,
                            barsize, f'{barsize} nm', 'lower center', 
                            pad=0.1,
                            color='white',
                            frameon=False,
                            size_vertical=1,
                            fontproperties=fontprops)
        ax.add_artist(scalebar)

def write_nup_plots(locs, args, good_dir, other_dir):
    """Plot three renders and a z histogram for each selected cluster.

    For every selected ``clusterID`` the localisations are filtered with
    ``filter_locs`` and ``center_view``, rendered from three angles, and
    their z distribution is plotted with a KDE. The distance between the two
    highest KDE peaks is shown in the title and appended to the module-level
    ``records`` list. Figures are only shown, not saved.

    Args:
        locs: DataFrame with a ``clusterID`` column and the columns needed by
            ``filter_locs``, ``center_view`` and ``render_locs``.
        args: render settings forwarded to ``render_locs``.
        good_dir: directory label for clusters whose separation is within
            10 of 50; currently only selected, never written to.
        other_dir: directory label for all other clusters.
    """
    for cid in set(locs['clusterID']):
        # if not cid in [1, 6, 18, 19, 21, 22]:
        # NOTE(review): hard-coded to cluster 6 only - confirm before reuse.
        if not cid in [6]:
            continue
        print('Cluster ID', cid)

        df = locs[locs['clusterID']==cid]
        df = filter_locs(df)

        if df.shape[0] == 0:
            continue
        df = center_view(df)

        # Drop the 'index' column when present. The previous
        # `del df['index']` raised KeyError for a missing column, which the
        # surrounding `except ValueError` did not catch.
        df = df.drop(columns='index', errors='ignore')

        if df.shape[0] < 5:
            print('No remaining localisations, continuing...')
            continue

        fig = plt.figure()
        gs = fig.add_gridspec(1, 4)
        plt.subplots_adjust(wspace=0.3, hspace=0)
        
        ax1 = fig.add_subplot(gs[0, 0])
        render_locs(df, args, (0,0,0), barsize=110, ax=ax1)

        ax2 = fig.add_subplot(gs[0, 1])
        render_locs(df, args, (np.pi/2,0,0), barsize=50, ax=ax2)

        ax3 = fig.add_subplot(gs[0, 2])
        render_locs(df, args, (0, np.pi/2,0), barsize=50, ax=ax3)

        ax4 = fig.add_subplot(gs[0, 3])
        
        histplot = sns.histplot(data=df, x='z [nm]', ax=ax4, stat='density', legend=False)
        if color_by_depth:
            color_histplot(histplot, cmap_min_z, cmap_max_z)
        sns.kdeplot(data=df, x='z [nm]', ax=ax4, bw_adjust=0.5, color='black', bw_method='silverman')

        # The first line on ax4 is the KDE curve drawn just above.
        kde_x = ax4.lines[0].get_xdata()
        kde_y = ax4.lines[0].get_ydata()
        peaks, _ = find_peaks(kde_y)

        # Keep at most the two highest peaks.
        top_peaks = sorted(peaks, key=lambda peak_index: kde_y[peak_index], reverse=True)[:2]

        peak_x = kde_x[top_peaks]
        peak_y = kde_y[top_peaks]
        for px, py in zip(peak_x, peak_y):
            ax4.vlines(px, 0, py, label=str(round(px)), color='black')

        if len(peak_x):
            sep = abs(max(peak_x) - min(peak_x))
            septxt = 'Sep: '+ str(round(sep))+ 'nm'
        else:
            # No peak found: previously max() on an empty array raised here.
            sep = float('nan')
            septxt = 'Sep: n/a'

        records.append({
            'id': cid,
            'seperation': sep,
        })

        margin=10
        # Classification only; nothing is written to either directory.
        if 50-margin <= sep and sep <= 50+margin:
            cluster_outdir = good_dir
        else:
            cluster_outdir = other_dir
        plt.suptitle(f'Nup ID: {cid}, N points: {df.shape[0]}, {septxt}')
        plt.show()

def load_and_filter_locs(args):
    """Load localisations with picasso and prepare them for plotting.

    Checks that the pixel size recorded in the file's info matches
    ``args['pixel_size']`` (printing a message and calling ``quit(1)``
    otherwise), optionally inner-joins the picked localisations from
    ``args['picked_locs']``, and adds ``clusterID`` (copied from ``group``)
    and ``z`` (``z [nm]`` divided by the pixel size).
    """
    raw_locs, info = io.load_locs(args['locs'])
    locs = pd.DataFrame.from_records(raw_locs)

    if not (info[1]['Pixelsize'] == args['pixel_size']):
        print('Pixel size mismatch', info[1]['Pixelsize'],  args['pixel_size'])
        quit(1)

    if args['picked_locs']:
        raw_picked, old_info = io.load_locs(args['picked_locs'])
        picked_locs = pd.DataFrame.from_records(raw_picked)
        # Join on every column the two files share for a localisation.
        join_cols = ['x', 'y', 'photons', 'bg', 'lpx', 'lpy', 'net_gradient', 'iterations', 'frame', 'likelihood', 'sx', 'sy']
        locs = locs.merge(picked_locs, on=join_cols)

    locs['clusterID'] = locs['group']
    locs['z'] = locs['z [nm]'] / args['pixel_size']
    return locs

# Input files. `locs` starts as a path and is rebound to the loaded
# DataFrame further down.
locs = '/home/miguel/Projects/smlm_z/publication/VIT_openframe/out_roll_alignment/out_nup/locs_3d.hdf5'
picked_locs = '/home/miguel/Projects/data/20230601_MQ_celltype/nup/fov2/storm_1/storm_1_MMStack_Default.ome_locs_undrifted_picked_4.hdf5'

# Replaces any earlier `args` dict in the notebook namespace.
args = {
    'locs': locs,
    'picked_locs': picked_locs,
    'pixel_size': 86,
    'blur_method': 'gaussian',
    'min_blur': 0.001,
    'oversample': 20,
}

locs = load_and_filter_locs(args)

# No output directories are passed; the figures are only displayed.
write_nup_plots(locs, args, None, None)

In [None]:
from scipy.stats import gaussian_kde

# For a fixed set of clusters: estimate the z density with gaussian_kde, find
# its peaks, and compare renders before and after dropping the points whose
# density lies below the valley between the two highest-z peaks.
plt.rcParams['figure.figsize'] = [15, 5]

for cid in [6, 18, 19, 21, 22]:
    cid_locs = locs[locs['clusterID']==cid]
    df2 = filter_locs(cid_locs).copy()
    # Requires an 'index' column to be present.
    del df2['index']
    
    kde = gaussian_kde(df2['z [nm]'].to_numpy())
    kde.set_bandwidth(bw_method='silverman')
    # Shrink the Silverman bandwidth factor to 75 %.
    kde.set_bandwidth(kde.factor * 0.75)
    
    # Evaluate the density on 50 points spanning the data plus 25 either side.
    zvals = np.linspace(df2['z [nm]'].min()-25, df2['z [nm]'].max()+25, 50)
    
    score = kde(zvals)
    zvals = zvals.squeeze()
    
    peaks, _ = find_peaks(score)
    
    # Peaks ordered by ascending z position.
    sorted_peaks = sorted(peaks, key=lambda peak_index: zvals[peak_index], reverse=False)
    peak_vals = zvals[peaks]
    
    # n_peaks is computed but not used below.
    if len(peak_vals) == 1:
        n_peaks = 1
    else:
        n_peaks = 2
        
    # The two peaks with the largest z.
    sorted_peaks = sorted_peaks[-2:]
    z_peaks = zvals[sorted_peaks]

    # Raises IndexError when fewer than two peaks were found.
    seperation = np.diff(z_peaks)[0]

    z_between_peaks = np.linspace(min(z_peaks), max(z_peaks), 50)
    scores = kde(z_between_peaks)
    
    # Cutoff: 5 % above the lowest density found between the two peaks.
    density_cutoff = min(scores) * 1.05
    print('Cutoff', density_cutoff)
    
    df2['density'] = kde(df2['z [nm]'].to_numpy())
    
    df3 = df2[df2['density']>=density_cutoff]
    
    
    for _df, title in zip([df2, df3], ['raw', 'w/ min density']):
        fig = plt.figure()
        plt.axis('off')
        plt.subplots_adjust(wspace=0.3, hspace=0)
        plt.title(title)
        gs = fig.add_gridspec(1, 3)
        ax1 = fig.add_subplot(gs[0, 0])
        render_locs(_df, args, (np.pi/2,0,0), barsize=50, ax=ax1)
        ax2 = fig.add_subplot(gs[0, 1])
        render_locs(_df, args, (0,np.pi/2,0), barsize=50, ax=ax2)

        ax3 = fig.add_subplot(gs[0,2])
        histplot = sns.histplot(data=_df, x='z [nm]', stat='density', legend=False)
        if title == 'raw':
            # Plot KDE
            ax3.plot(zvals, score)
            # PLOT cutoff line
            x = [_df['z [nm]'].min(), _df['z [nm]'].max()]
            y = [density_cutoff, density_cutoff]
            ax3.plot(x, y, 'r--')

            # Plot peaks
            for peak in z_peaks:
                x = [peak, peak]
                y = [0, kde(peak).squeeze()]
                ax3.plot(x, y, 'r--')
                ax3.set_title(f'Sep: {round(seperation, 2)}')
        plt.show()

In [None]:
from seaborn._statistics import KDE

# Overlay a manually evaluated KDE (red) on seaborn's histogram and KDE plot
# for cluster 6.
plt.rcParams['figure.figsize'] = [3, 3]

# NOTE: this narrows the notebook-level `locs` to cluster 6, so the cell is
# not idempotent and later cells see only that cluster.
locs = locs[locs['clusterID']==6]
df2 = filter_locs(locs).copy()
del df2['index']
fig = plt.figure()
ax2 = plt.gca()
histplot = sns.histplot(data=df2, x='z [nm]', ax=ax2, stat='density', legend=False)
sns.kdeplot(data=df2, x='z [nm]', ax=ax2, bw_adjust=0.5, color='black', bw_method='silverman')

# NOTE(review): KDE comes from the private module seaborn._statistics;
# fit()/score_samples() mirror sklearn's KernelDensity API - confirm that
# seaborn's class actually provides them.
kde = KDE(bw_method='silverman', bw_adjust=0.5).fit(df2[['z [nm]']])

zvals = np.linspace(df2['z [nm]'].min()-25, df2['z [nm]'].max()+25, 50).reshape(-1, 1)

score = kde.score_samples(zvals)
# The score is treated as a natural-log density here.
score = np.exp(score)
print(score)
zvals = zvals.squeeze()
plt.plot(zvals, score, c='red')


plt.show()


# peaks, _ = find_peaks(y)

# sorted_peaks = sorted(peaks, key=lambda peak_index: y[peak_index], reverse=True)
# peak_vals = y[peaks]
# if len(peak_vals) == 1:
#     n_peaks = 1
# else:
#     n_peaks = 2
    
# sorted_peaks = sorted_peaks[:n_peaks]


# fig = plt.figure()
# ax = plt.gca()
# render_locs(df2, args, (np.pi/2,0,0), barsize=50, ax=ax)



In [None]:
from sklearn.metrics import euclidean_distances
import pandas as pd
import numpy as np


def norm_zero_one(s):
    """Rescale ``s`` linearly so its minimum maps to 0 and its maximum to 1."""
    lowest, highest = s.min(), s.max()
    return (s - lowest) / (highest - lowest)


import tensorflow as tf
from keras.metrics import mean_squared_error
from tqdm import tqdm, trange
from tifffile import imread

# Default divisor applied to the roll returned by tf_find_optimal_roll().
UPSCALE_RATIO = 1

def norm_sum_imgs(psf):
    """Divide every frame of a 3-D stack by its own sum over axes 1 and 2."""
    frame_totals = psf.sum(axis=(1,2), keepdims=True)
    return psf / frame_totals


def tf_eval_roll(ref_psf, psf, roll):
    """Mean squared error between ``ref_psf`` and ``psf`` rolled by ``roll`` along axis 0."""
    shifted_psf = tf.roll(psf, roll, axis=0)
    per_element_error = mean_squared_error(ref_psf, shifted_psf)
    return tf.reduce_mean(per_element_error)
    
def tf_find_optimal_roll(ref_tf, img, upscale_ratio=UPSCALE_RATIO):
    """Find the axis-0 roll of ``img`` that minimises the MSE against ``ref_tf``.

    Candidate rolls span minus to plus a quarter of ``ref_tf.shape[0]``
    (upper end exclusive). The best roll is returned divided by
    ``upscale_ratio``.
    """
    
    img_tf = tf.convert_to_tensor(img)

    roll_range = ref_tf.shape[0]//4
    rolls = np.arange(-roll_range, roll_range).astype(int)
    # One error value per candidate roll.
    errors = tf.map_fn(lambda roll: tf_eval_roll(ref_tf, img_tf, roll), rolls, dtype=tf.float64)
    # idx = 0
    # for roll in tqdm(rolls):
    #     error = tf.eval_roll(ref_tf, img_tf, roll)
    #     print(i, error)
    #     errors[idx] = error
    #     idx += 1
    best_roll = rolls[tf.argmin(errors).numpy()]
    # Prefer small backwards roll to large forwards roll
    # NOTE(review): candidate rolls never exceed a quarter of
    # ref_tf.shape[0]; if img has the same length this condition cannot be
    # true - confirm whether the branch is still needed.
    if abs(best_roll - img.shape[0]) < best_roll:
        best_roll = best_roll - img.shape[0]
    return best_roll/upscale_ratio



def realign_beads(psfs, df, z_step, i):
    """Estimate per-bead z shifts relative to a reference bead.

    The reference is the bead with the ``i``-th smallest distance to the
    point ``(max(x)/2, max(y)/2)``. Every other bead is rolled against it
    with ``tf_find_optimal_roll``; the negated rolls, multiplied by
    ``z_step``, are returned. As a side effect a ``dist`` column is written
    to ``df``.
    """
    from sklearn.metrics import euclidean_distances
    centre_point = [[df['x'].max()/2, df['y'].max()/2]]
    df['dist'] = euclidean_distances(df[['x', 'y']].to_numpy(), centre_point)
    ref_idx = np.argsort(df['dist'].to_numpy())[i]
    print(ref_idx)
    # Looked up but not used further; the lookup still requires an 'offset' column.
    ref_offset = df.iloc[ref_idx]['offset']

    ref_tf = tf.convert_to_tensor(norm_zero_one(psfs[ref_idx]))

    n_beads = df.shape[0]
    rolls = []
    for idx in trange(n_beads):
        # The reference bead is aligned with itself by definition.
        roll = 0 if idx == ref_idx else -tf_find_optimal_roll(ref_tf, norm_zero_one(psfs[idx]))
        print(roll)
        rolls.append(roll)
    return np.array(rolls) * z_step


# Load the combined stacks and their localisations, keep every fifth bead,
# and estimate offsets against five different reference beads.
psfs = imread('/home/miguel/Projects/smlm_z/publication/VIT_openframe/stacks.ome.tif')
locs = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/VIT_openframe/locs.hdf', key='locs')
z_step = 10

psfs = psfs[::5]
# Copy the subsample so the columns added below (and by realign_beads) are
# written to an independent frame rather than to a slice of the loaded one.
locs = locs.iloc[::5].copy()

for i in range(5):
    locs[f'offset_{i}'] = realign_beads(psfs, locs, z_step, i)



In [None]:
# Build an averaged reference PSF: normalise each stack, roll it back by its
# 'offset_0' (converted to frames), sum, and normalise the sum.
empty_psf = np.zeros(psfs.shape[1:])
# show_psf_axial is imported here but not used in this cell.
from data.visualise import show_psf_axial
for p, offset in zip(psfs, locs['offset_0'].to_numpy()):
    p = norm_zero_one(p)
    rolled = np.roll(p, shift=-int(offset//z_step), axis=0)
    empty_psf += rolled

ref_psf = norm_zero_one(empty_psf)
    


In [None]:
# NOTE(review): 'offset_avg_ref' is assigned by the realign_beads2 cell further
# down; unless locs.hdf already stores that column, this raises KeyError on a
# fresh top-to-bottom run.
print(locs['offset_avg_ref'].mean())

In [None]:
# Frame index of the brightest plane of ref_psf, scaled by z_step.
np.argmax(ref_psf.max(axis=(1,2))) * z_step

In [None]:
def realign_beads2(psfs, df, z_step, ref_psf):
    """Estimate per-bead axial offsets relative to an external reference PSF.

    Every bead in ``psfs`` is normalised and aligned to ``ref_psf`` with
    ``tf_find_optimal_roll``.

    The previous version special-cased ``idx == ref_idx``, but ``ref_idx`` is
    not defined in this function. It raised NameError on a fresh kernel, or
    silently zeroed whichever bead a stale global pointed at. The reference is
    passed in directly here, so no bead needs special treatment.

    Returns an array with one offset per row of ``df`` (roll * ``z_step``).
    """
    ref_psf = norm_zero_one(ref_psf)
    ref_tf = tf.convert_to_tensor(ref_psf)
    rolls = []
    for idx in trange(df.shape[0]):
        psf2 = norm_zero_one(psfs[idx])
        # Sign is flipped relative to the roll that minimises the error.
        roll = -tf_find_optimal_roll(ref_tf, psf2)
        rolls.append(roll)
    rolls = np.array(rolls)
    return rolls * z_step

# Use the notebook's z_step (set to 10 where the stacks are loaded) instead of
# a repeated literal, so this stays consistent with the offset_{i} columns.
locs['offset_avg_ref'] = realign_beads2(psfs, locs, z_step, ref_psf)

In [None]:
# Bead index for the horizontal axis. Defined here because the only other
# definition of `x` lives in a later cell, which breaks a fresh run.
x = np.arange(locs.shape[0])

# For every offset column: mean-centred offset minus the average-reference offset.
for c in list(locs):
    if 'offset' in c:
        plt.scatter(x, (locs[c]-locs[c].mean())-locs['offset_avg_ref'], label=c)
        
# plt.scatter(x, locs['offset_avg_ref'], label='avg')
# for i in range(5):
#     plt.scatter(x, locs[f'offset_{i}'], label=str(i))
plt.legend()


In [None]:
import matplotlib.pyplot as plt
x = np.arange(locs.shape[0])

# Absolute disagreement between each reference choice and reference 0
# (i == 0 compares the column with itself and is therefore all zeros).
for i in range(5):
    y = abs(locs[f'offset_{i}'] - locs['offset_0'].to_numpy())
    plt.scatter(x, y)
    print(np.mean(y))
plt.show()

In [None]:
import pandas as pd

# Load two localisation tables and compare their sizes.
old_df = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/VIT_013/locs.hdf', key='locs')
new_df = pd.read_hdf('/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/stacks/combined/locs.hdf', key='locs')

print(old_df.shape, new_df.shape)

In [None]:
from tifffile import imread

# `d` is the image data read from tmp.tif; later cells iterate over and index its first axis.
d = imread('/home/miguel/Projects/smlm_z/publication/tmp/tmp.tif')
print(d.shape)

In [None]:
import numpy as np
def snr(p):
    """Return the ratio of the maximum of ``p`` to its mean."""
    peak = p.max()
    average = p.mean()
    return peak / average

snrs = np.array([snr(p) for p in d])
min_snr = 2.0
# flatnonzero always returns a 1-D index array. argwhere(...).squeeze()
# collapses to a 0-d array when exactly one entry passes, and len() of that fails.
idx = np.flatnonzero(snrs > min_snr)
print(len(idx))

from data.visualise import show_psf_axial

In [None]:
from scipy.optimize import curve_fit
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

def get_sharpness(array):
    """Return the average gradient magnitude of a 2-D array."""
    grad_rows, grad_cols = np.gradient(array)
    magnitude = np.sqrt(grad_cols**2 + grad_rows**2)
    return np.average(magnitude)


def reduce_img(psf):
    """Return one ``get_sharpness`` value per element along the first axis of ``psf``."""
    per_frame = []
    for frame in psf:
        per_frame.append(get_sharpness(frame))
    return np.stack(per_frame)


def filter_mse_xy(stack, max_mse, i):
    """Check how well the sharpest frame of ``stack`` matches a 2-D Gaussian.

    The frame with the highest ``get_sharpness`` score is normalised by its
    maximum and fitted with a rotated 2-D Gaussian plus constant offset. The
    mean squared error between fit and frame is compared with ``max_mse``.
    The frame, the initial guess and the fit are plotted side by side.

    Parameters
    ----------
    stack : array whose first axis enumerates frames.
    max_mse : largest mean squared error still accepted.
    i : identifier of the stack; kept for interface compatibility (only the
        commented-out debugging lines use it).

    Returns
    -------
    bool : True when the fit error is at most ``max_mse``.
    """
    # Define a 2D Gaussian function
    def gaussian_2d(xy, amplitude, xo, yo, sigma_x, sigma_y, theta, offset):
        x, y = xy
        xo = float(xo)
        yo = float(yo)
        a = (np.cos(theta)**2) / (2 * sigma_x**2) + (np.sin(theta)**2) / (2 * sigma_y**2)
        b = -(np.sin(2 * theta)) / (4 * sigma_x**2) + (np.sin(2 * theta)) / (4 * sigma_y**2)
        c = (np.sin(theta)**2) / (2 * sigma_x**2) + (np.cos(theta)**2) / (2 * sigma_y**2)
        g = offset + amplitude * np.exp(- (a * ((x - xo)**2) + 2 * b * (x - xo) * (y - yo) + c * ((y - yo)**2)))
        return g.ravel()

    # Pick the sharpest frame of the stack.
    sharp = reduce_img(stack)
    idx = np.argmax(sharp)
    image = stack[idx]

    # Pixel coordinate grid (the second image dimension is used for both axes).
    image_size = image.shape[1]
    x = np.linspace(0, image_size - 1, image_size)
    y = np.linspace(0, image_size - 1, image_size)
    x, y = np.meshgrid(x, y)

    # Initial guess and (lower, upper) bounds, in the parameter order of gaussian_2d.
    p0 = [1, image_size / 2, image_size / 2, 2, 2, 0, 0]
    bounds = [
        (0, np.inf),
        (image_size * (1/5), image_size * (4/5)),
        (image_size * (1/5), image_size * (4/5)),
        (0, image_size/3),
        (0, image_size/3),
        (-np.inf, np.inf),
        (0, np.inf),
    ]

    image = image / image.max()

    try:
        # curve_fit expects bounds as (all_lowers, all_uppers), hence the zip.
        popt, pcov = curve_fit(gaussian_2d, (x, y), image.ravel(), p0=p0, bounds=list(zip(*bounds)))
    except RuntimeError:
        print('XY fit failed')
        # Fall back to the initial guess so the error can still be computed.
        popt = p0
    render = gaussian_2d((x, y), *popt).reshape(image.shape)

    render0 = gaussian_2d((x, y), *p0).reshape(image.shape)
    error = mean_squared_error(render, image)

    # show_psf_axial(stack, f'{i} {d} ' + '{:.3E}'.format(error), 30)
    # Visualize the original image and the fitted Gaussian
    print(p0)
    print(popt)
    plt.figure(figsize=(5, 3))
    # Figure-level title: plt.title here would create an extra axes underneath
    # the three subplots instead of labelling the figure.
    plt.suptitle('Error' +  '{:.8E}'.format(error))

    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title('Original Image')

    plt.subplot(1, 3, 2)
    plt.imshow(render0)

    plt.title('Initial Gaussian')

    plt.subplot(1, 3, 3)
    plt.imshow(render)
    plt.title('Fitted Gaussian')

    plt.tight_layout()
    plt.show()
    # plt.savefig(f'./{d}/{i}.png')
    # plt.close()
    print(error, max_mse, error <= max_mse)

    return error <= max_mse

# problem_idx = 
# Inspect a single stack (index 1548) with an MSE threshold of 0.005.
for i in [1548]:
    filter_mse_xy(d[i], 0.005, i)
    

In [None]:
from tensorflow import keras
import os
import tensorflow as tf


# Directory holding the saved model and the saved test dataset.
model_dir = '/home/miguel/Projects/smlm_z/publication/VIT_035/out_roll_alignment'

# NOTE(review): `args` is a notebook-global name that later cells rebind with
# different keys.
args = {
    'outdir': model_dir
}

model = keras.models.load_model(os.path.join(args['outdir'], './latest_vit_model3'))
test_data = tf.data.Dataset.load(os.path.join(args['outdir'], 'test'))
# NOTE(review): confirm that batch_size is honoured when the input is a
# tf.data.Dataset (datasets normally carry their own batching).
pred_zs = model.predict(test_data, batch_size=4096)


In [None]:
import numpy as np
def get_z_coordinates(dataset):
    """Collect the xy inputs and z targets of ``dataset``.

    ``dataset`` must yield ``((_, xy), z)`` items from ``as_numpy_iterator()``.
    All items are concatenated along the first axis and squeezed.

    Returns ``(xys, zs)``.
    """
    items = list(dataset.as_numpy_iterator())
    xy_parts = [xy for (_, xy), _ in items]
    z_parts = [z for _, z in items]
    return np.concatenate(xy_parts).squeeze(), np.concatenate(z_parts).squeeze()

# Ground-truth xy inputs and z targets of the test dataset.
xys, zs = get_z_coordinates(test_data)

In [None]:
# Drop singleton dimensions from the model predictions.
pred_zs = pred_zs.squeeze()

In [None]:
from scipy import optimize as opt
from sklearn.metrics import root_mean_squared_error

def bestfit_error(z_true, z_pred):
    """Fit ``z_pred = z_true + c`` and report the residuals.

    The fitted line is evaluated at the actual ``z_true`` positions. The
    previous version evaluated it on an evenly spaced grid between
    ``z_true.min()`` and ``z_true.max()``, which only lines up with the
    predictions when ``z_true`` happens to be sorted and evenly spaced.

    Parameters
    ----------
    z_true, z_pred : 1-D arrays of equal length.

    Returns
    -------
    error : root mean squared residual after removing the constant ``c``.
    c : fitted constant offset.
    y_fit : ``z_true + c``, one value per sample.
    residuals : ``|y_fit - z_pred|``, one value per sample.
    """
    def linfit(x, c):
        return x + c

    z_true = np.asarray(z_true, dtype=float)
    z_pred = np.asarray(z_pred, dtype=float)
    popt, _ = opt.curve_fit(linfit, z_true, z_pred, p0=[0])

    y_fit = linfit(z_true, popt[0])
    residuals = np.abs(y_fit - z_pred)
    # RMSE computed directly from the residuals.
    error = np.sqrt(np.mean(residuals ** 2))
    return error, popt[0], y_fit, residuals

def remove_constant_error(xys, zs, pred_zs):
    """Absolute z errors after removing a constant offset per xy position.

    Samples sharing the same ``xys`` row are grouped. ``bestfit_error`` is run
    on each group and the resulting per-sample errors are written back to the
    positions of the samples they belong to.

    The previous version concatenated the groups in ``set`` iteration order,
    so the output did not line up with ``zs``/``pred_zs`` and could not be
    plotted against them. It also rescanned all samples once per group.

    Returns a 1-D float array aligned with ``zs`` and ``pred_zs``.
    """
    zs = np.asarray(zs)
    pred_zs = np.asarray(pred_zs)
    # Same grouping key as before: the row's values joined as strings.
    keys = np.array(['_'.join(x.astype(str)) for x in xys])
    all_errors = np.empty(len(keys), dtype=float)
    if len(keys) == 0:
        return all_errors

    _, group_ids = np.unique(keys, return_inverse=True)
    group_ids = np.asarray(group_ids).reshape(-1)
    # Sort once, then split the sorted sample indices at group boundaries.
    order = np.argsort(group_ids, kind='stable')
    boundaries = np.flatnonzero(np.diff(group_ids[order])) + 1
    for idx in np.split(order, boundaries):
        _, _, _, errors = bestfit_error(zs[idx], pred_zs[idx])
        all_errors[idx] = errors
    return all_errors

# Errors with the per-position constant offset removed, and the raw absolute errors.
corrected_errors = remove_constant_error(xys, zs, pred_zs)
errors = abs(pred_zs-zs)

In [None]:
import pandas as pd
# Comparison table; the column filters below split it into two sub-tables.
ries_data = pd.read_csv('/home/miguel/Projects/smlm_z/publication/ries_comparison_data.csv')
cols = list(ries_data)
# Keep the 'z(nm)' column plus every column whose name contains 'DeepLoc'
# (respectively 'CRLB'), drop incomplete rows and index by 'z(nm)'.
ries_deeploc = ries_data[[c for c in cols if ('DeepLoc' in c) or ('z(nm)' in c)]].dropna().set_index('z(nm)')
crlb_deeploc = ries_data[[c for c in cols if ('CRLB' in c) or ('z(nm)' in c)]].dropna().set_index('z(nm)')

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
pred_zs = pred_zs.squeeze()
zs = zs.squeeze()

# Binned corrected error (bins of width 50 over [-1000, 1000)) with a 5th-order
# regression, overlaid on the DeepLoc columns of the comparison table.
sns.regplot(x=pred_zs, y=corrected_errors, scatter=True, ci=95, order=5, x_bins=np.arange(-1000, 1000, 50), label='Our method')
sns.lineplot(data=ries_deeploc)
plt.show()


# Same regression, overlaid on the CRLB columns instead.
sns.regplot(x=pred_zs, y=corrected_errors, scatter=True, ci=95, order=5, x_bins=np.arange(-1000, 1000, 50), label='Our method')
sns.lineplot(data=crlb_deeploc)
plt.show()


In [None]:
# NOTE(review): imread is used before this cell's neighbour imports it; this
# relies on an earlier `from tifffile import imread`.
stacks = imread('/media/Data/smlm_z_data/20231121_nup_miguel_zeiss/stacks/tmp.tif')
print(stacks.shape)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tifffile import imread
from scipy.optimize import curve_fit
from sklearn.metrics import mean_squared_error

def get_sharpness(array):
    """Mean magnitude of the 2-D gradient of ``array``."""
    d_rows, d_cols = np.gradient(array)
    return np.average(np.sqrt(d_cols**2 + d_rows**2))


def reduce_img(psf):
    """Sharpness score of every element along the first axis of ``psf``, stacked into one array."""
    scores = list(map(get_sharpness, psf))
    return np.stack(scores)
    
def filter_mse_xy(stack, max_mse, i):
    """Fit a 2-D Gaussian to the sharpest frame of ``stack`` and test the fit error.

    Unlike the earlier definition of the same name, the frame is not
    normalised, the centre bounds are tighter (2/5..3/5 of the image size),
    the sigma bounds are open-ended, and diagnostics are only plotted when the
    error reaches ``max_mse`` or the fit fails.

    Returns True when the mean squared error between the fitted Gaussian and
    the frame is at most ``max_mse``.

    NOTE(review): a failed fit falls back to ``p0`` and can still return True
    if that error is below ``max_mse`` — confirm this is intended.
    """
    # Define a 2D Gaussian function
    def gaussian_2d(xy, amplitude, xo, yo, sigma_x, sigma_y, theta, offset):
        x, y = xy
        xo = float(xo)
        yo = float(yo)
        a = (np.cos(theta)**2) / (2 * sigma_x**2) + (np.sin(theta)**2) / (2 * sigma_y**2)
        b = -(np.sin(2 * theta)) / (4 * sigma_x**2) + (np.sin(2 * theta)) / (4 * sigma_y**2)
        c = (np.sin(theta)**2) / (2 * sigma_x**2) + (np.cos(theta)**2) / (2 * sigma_y**2)
        g = offset + amplitude * np.exp(- (a * ((x - xo)**2) + 2 * b * (x - xo) * (y - yo) + c * ((y - yo)**2)))
        return g.ravel()

    # Select the frame with the highest sharpness score.
    sharp = reduce_img(stack)
    idx = np.argmax(sharp)
    image = stack[idx]

    # Pixel coordinate grid; the second image dimension is used for both axes.
    image_size = image.shape[1]
    x = np.linspace(0, image_size - 1, image_size)
    y = np.linspace(0, image_size - 1, image_size)
    x, y = np.meshgrid(x, y)



    # Fit the Gaussian to the image data
    p0 = [1, image_size / 2, image_size / 2, 5, 5, 0, 0]  # Initial guess for parameters
    # (lower, upper) per parameter, in the argument order of gaussian_2d.
    bounds = [
        (-np.inf, np.inf),
        (image_size * (2/5), image_size * (3/5)),
        (image_size * (2/5), image_size * (3/5)),
        (0, np.inf),
        (0, np.inf),
        (-np.inf, np.inf),
        (0, np.inf),
    ]

    fit_failed = False
    try:
        # curve_fit expects bounds as (all_lowers, all_uppers), hence the zip.
        popt, pcov = curve_fit(gaussian_2d, (x, y), image.ravel(), p0=p0, bounds=list(zip(*bounds)))
    except RuntimeError:
        print('XY fit failed')
        fit_failed = True
        # Fall back to the initial guess so an error can still be computed.
        popt = p0

    render = gaussian_2d((x, y), *popt).reshape(image.shape)

    error = mean_squared_error(render, image)

    # Visualize the original image and the fitted Gaussian
    # NOTE(review): diagnostics use `>=` while the return value uses `<=`, so a
    # stack with error == max_mse is plotted yet accepted.
    if error >= max_mse or fit_failed:
        print(i)

        # Show each fitted parameter next to its bounds.
        for limits, val in zip(bounds, popt):
            print(limits, val)
    
        plt.plot(sharp)
        plt.show()
        plt.figure(figsize=(5, 3))
        plt.subplot(1, 2, 1)
        plt.imshow(image)
        plt.title('Original Image')
        
        plt.subplot(1, 2, 2)
        plt.imshow(render)
        plt.title('Fitted Gaussian')
        
        plt.tight_layout()
        plt.show()
        print(error, max_mse, error <= max_mse)
        # show_psf_axial is imported at module level after this definition.
        show_psf_axial(stack, '', 30)


    return error <= max_mse
from data.visualise import show_psf_axial
# Run the check on every stack with a threshold of 10000; the return value is
# discarded, so only the printed/plotted diagnostics are used.
for i in range(stacks.shape[0]):
    print(i)
    filter_mse_xy(stacks[i], 10000, i)

In [None]:
# Plane alignment for beads


from numpy.linalg import lstsq
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import linear_model
def realign_beads(psfs, df):
    """Replace bead offsets with the values of a plane fitted over (x, y).

    A plane ``offset = a1*x + a2*y + c`` is fitted to all beads. Beads whose
    absolute residual is above the 95th percentile are dropped, the plane is
    refitted on the rest, and the 'offset' column of the kept beads is
    overwritten with the refitted plane values.

    Parameters
    ----------
    psfs : array indexed by bead along its first axis.
    df : DataFrame with 'x', 'y' and 'offset' columns, one row per bead.

    Returns
    -------
    (psfs, df) restricted to the kept beads. ``df`` is an independent copy, so
    the caller's frame is not modified.
    """
    X_data = df[['x', 'y']].to_numpy()
    Y_data = df['offset'].to_numpy()

    reg = linear_model.LinearRegression().fit(X_data, Y_data)

    print("coefficients of equation of plane, (a1, a2): ", reg.coef_)

    print("value of intercept, c:", reg.intercept_)

    z_fit = reg.predict(X_data)

    error = abs(z_fit-Y_data)
    perc_cutoff = np.percentile(error, 95)

    # A boolean mask keeps every selection at least 1-D.
    # argwhere(...).squeeze() collapsed to a scalar index when only one bead
    # survived, which dropped a dimension from X_data and psfs.
    keep = error <= perc_cutoff
    X_data = X_data[keep]
    Y_data = Y_data[keep]

    psfs = psfs[keep]
    # Copy so that writing 'offset' below does not assign onto a slice.
    df = df.iloc[keep].copy()

    reg = linear_model.LinearRegression().fit(X_data, Y_data)
    print("coefficients of equation of plane, (a1, a2): ", reg.coef_)

    print("value of intercept, c:", reg.intercept_)

    z_fit = reg.predict(X_data)
    df['offset'] = z_fit

    return psfs, df

df = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/VIT_031_redo2/locs.hdf', key='locs')

# A zero array stands in for the PSFs; the return value is discarded.
realign_beads(np.zeros(df.shape), df)
# NOTE(review): `error`, `idx` and `z_fit` below are local variables of
# realign_beads and are never returned, so these lines raise NameError on a
# fresh kernel. The function would have to return them for this cell to work.
plt.hist(error)
plt.show()

sns.scatterplot(data=df, x='x', y='y', hue='offset')
plt.show()
sns.scatterplot(data=df.iloc[idx], x='x', y='y', hue=z_fit)
plt.show()





In [None]:
import glob
import os
# Print a shell command for every output directory that has a trained model
# but no nup render yet.
for n2 in glob.glob('/home/miguel/Projects/smlm_z/publication/VIT_03*/out*'):
    n = n2.replace('/home/miguel/Projects/smlm_z/publication/', '')
    if 'subset' in n or '035' in n:
        continue
    dirname, outdir = n.split('/')

    nup_path = os.path.join(n2, 'out_nup', 'nup_renders2')
    model_path = os.path.join(n2, 'latest_vit_model3')
    if not os.path.exists(nup_path) and os.path.exists(model_path):
        # NOTE(review): the command always uses VIT_031_redo, not `dirname`,
        # although the glob matches every VIT_03* directory — confirm.
        print(f'cd /home/miguel/Projects/smlm_z/publication/VIT_031_redo/{outdir} && python3 ../../localise_exp_sample.py -mo . -o out_nup && cd out_nup && python3 ../../../render_nup.py && cd /home/miguel/Projects/smlm_z/publication/VIT_031_redo;')



In [None]:
import glob
import json
import os
import pandas as pd
import numpy as np

# One summary dict per report is appended to this list further down.
records = []

# Every CSV under out*/out_nup/nup_renders2 of every VIT_03* run.
reports = glob.glob('/home/miguel/Projects/smlm_z/publication/VIT_03*/out*/out_nup/nup_renders2/*.csv')

def within_n(seps, sep, center=50):
    """Count the values of ``seps`` within ``center ± sep`` (bounds inclusive).

    Parameters
    ----------
    seps : array-like of numbers (Series, ndarray or list).
    sep : half-width of the accepted interval.
    center : middle of the accepted interval; defaults to 50, the value that
        was previously hard-coded.

    Returns
    -------
    int : number of values inside the interval (NaNs never count).
    """
    seps = np.asarray(seps)
    inside = (seps >= center - sep) & (seps <= center + sep)
    return int(np.count_nonzero(inside))

# Build one summary row per report, combining the training report, the
# nup-render CSV and the bead results of the same output directory.
for r in reports:
    # Of the 'plane_alignment' runs, keep only 'plane_alignment_3'.
    if 'plane_alignment' in r and not 'plane_alignment_3' in r:
        continue
    # The first 8 '/'-separated components of the path are the output directory.
    outdir = '/'.join(r.split('/')[0:8])
    name = '_'.join(outdir.split('/')[-2:])
    training_report = os.path.join(outdir, 'results', 'report.json')
    with open(training_report) as f:
        training_report_data = json.load(f)

    nup_report = pd.read_csv(r)

    bead_report = os.path.join(outdir, 'results', 'results.csv')
    bead_report_data = pd.read_csv(bead_report)
    # Only rows of the 'test' dataset contribute to the bead statistics.
    bead_report_data = bead_report_data[bead_report_data['dataset']=='test']
    
    args = training_report_data['args']
    gen_args = args['gen_args']
    records.append({
        'name': name,
        'train_mae': training_report_data['train_mae'],
        'val_mae': training_report_data['val_mae'],
        'test_mae': training_report_data['test_mae'],
        'aug_ratio': args['aug_ratio'],
        'brightness': args['brightness'],
        'gauss': args['gauss'],
        'min_snr': gen_args['min_snr'],
        'mean_seperation': np.mean(nup_report['seperation']),
        'stdev_seperation': np.std(nup_report['seperation']),
        'sep_count_40-60': within_n(nup_report['seperation'], 10),
        'sep_count_45-55': within_n(nup_report['seperation'], 5),
        'mean_bead_offset': np.mean(np.abs(bead_report_data['offset']))
    })
pd.set_option("display.precision", 3)
df = pd.DataFrame.from_records(records)
# Runs with the most separations in the 45-55 window come first.
df.sort_values(['sep_count_45-55'], ascending=False, inplace=True)
print(df.shape)
print(df)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Test MAE against the mean absolute bead offset.
sns.regplot(data=df, x='mean_bead_offset', y='test_mae')
plt.show()

# One scatter plot per summary column against the 40-60 separation count.
for c in ['train_mae', 'val_mae', 'test_mae', 'min_snr', 'aug_ratio', 'gauss', 'brightness', 'stdev_seperation', 'mean_seperation', 'mean_bead_offset']:
    sns.scatterplot(data=df, x=c, y='sep_count_40-60')
    plt.show()

In [None]:
import tensorflow as tf
from tifffile import imread
import numpy as np
import pandas as pd

# Paths and z settings used by the rest of this cell.
args = {
    'stacks': '/home/miguel/Projects/smlm_z/publication/VIT_031_redo_subset/stacks.ome.tif',
    'locs': '/home/miguel/Projects/smlm_z/publication/VIT_031_redo_subset/locs.hdf',
    'zstep': 10,
    'zrange': 1000
}
# A trailing axis of length 1 is appended to the stacks read from disk.
psfs = imread(args['stacks'])[:, :, :, :, np.newaxis]
locs = pd.read_hdf(args['locs'], key='locs')
# Positional index of every localisation.
locs['idx'] = np.arange(locs.shape[0])
# idx = (xlim[0] < all_locs['x']) & (all_locs['x'] < xlim[1]) & (ylim[0] < all_locs['y']) & (all_locs['y'] < ylim[1])
# locs = all_locs[idx]
# psfs = all_psfs[locs['idx']]

def filter_zrange(X, zs, zrange=None):
    """Keep only the samples whose ``|z|`` is below ``zrange``.

    Parameters
    ----------
    X : array indexed by sample along its first axis.
    zs : z value per sample (singleton dimensions are ignored for the test).
    zrange : exclusive limit on ``|z|``. Defaults to ``args['zrange']``, the
        notebook setting that was previously the only option.

    Returns
    -------
    (X, zs) restricted to the valid samples. Both keep their sample axis even
    when only one sample is valid; the earlier ``argwhere(...).squeeze()``
    produced a scalar index in that case and dropped the axis.
    """
    if zrange is None:
        zrange = args['zrange']
    valid_ids = np.flatnonzero(np.abs(np.squeeze(zs)) < zrange)
    return X[valid_ids], zs[valid_ids]
        
# z value of every frame of every stack: frame index * zstep minus the stack's offset.
ys = []
for offset in locs['offset']:
    zs = ((np.arange(psfs.shape[1])) * args['zstep']) - offset
    ys.append(zs)

ys = np.array(ys)

# Flatten the stack axis so that every frame becomes one sample.
psfs = np.concatenate(psfs)
ys = np.concatenate(ys)

psfs, ys = filter_zrange(psfs, ys)

print(psfs.shape, ys.shape)

print(ys.min(), ys.max())

train = psfs

In [None]:
import h5py
import pandas as pd
# Read the 'spots' dataset of the experimental sample into memory.
spots_path = '/home/miguel/Projects/data/20230601_MQ_celltype/nup/fov2/storm_1/storm_1_MMStack_Default.ome_spots.hdf5'
with h5py.File(spots_path, "r") as f:
    test = np.array(f['spots'])



In [None]:
# sensitivity = 0.45, quantum efficiency = 0.9

# em gain (1) baseline (100)
# new_spots = (spots - baseline) * sensitivity / (gain)

# Camera parameters for the conversion described in the comments above.
baseline = 100
sensitivity = 0.45
gain = 1

# Inverse of the commented formula: values * gain / sensitivity + baseline.
# NOTE(review): raw_test is only loaded in a later cell, so this fails on a
# fresh top-to-bottom run.
test2 = (raw_test * gain / sensitivity) + baseline

# NOTE(review): the cast does not clip; confirm values stay within 0..65535.
test2 = test2.astype(np.uint16)

In [None]:
# Raw bead acquisition, compared with the other datasets in the range checks below.
raw_train = imread('/home/miguel/Projects/data/all_openframe_beads/20231205_miguel_mitochondria/stack__10_MMStack_Default.ome.tif')


In [None]:
# Raw experimental acquisition, compared with the other datasets in the range checks below.
raw_test = imread('/home/miguel/Projects/data/20230601_MQ_celltype/nup/fov2/storm_1_MMStack_Default_2.ome.tif')

In [None]:
# Value range and dtype of the training frames.
train.min(), train.max(), train.dtype

In [None]:
# Value range and dtype of the experimental spots.
test.min(), test.max(), test.dtype

In [None]:
# Value range and dtype of the raw bead acquisition.
raw_train.min(), raw_train.max(), raw_train.dtype

In [None]:
# Value range and dtype of the raw experimental acquisition. This used to
# report raw_train.dtype, a copy-paste slip from the cell above.
raw_test.min(), raw_test.max(), raw_test.dtype

In [None]:
# Value range and dtype of the converted experimental data.
test2.min(), test2.max(), test2.dtype

In [None]:
# import keras 

# from keras import layers
# from keras.layers.pre

# class NoClipRandomContrast(layers.RandomContrast):
#     def __init__(self, factor, seed=None, clip=False, **kwargs):
#         super().__init__(factor, seed, **kwargs)
#         self.clip = clip

#     def call(self, inputs, training=True):
#         # inputs = self.backend.cast(inputs, self.compute_dtype)
#         if training:
#             seed_generator = self._get_seed_generator(self.backend._backend)
#             factor = self.backend.random.uniform(
#                 shape=(),
#                 minval=1.0 - self.lower,
#                 maxval=1.0 + self.upper,
#                 seed=seed_generator,
#                 dtype=self.compute_dtype,
#             )

#             outputs = self._adjust_constrast(inputs, factor)
#             if self.clip:
#                 outputs = self.backend.numpy.clip(outputs, 0, 255)
#             self.backend.numpy.reshape(outputs, self.backend.shape(inputs))
#             return outputs
#         else:
#             return inputs


In [None]:
from keras import Sequential, layers
# Augmentation strengths; only BRIGHTNESS is used by the active layers below.
MAX_GAUSS_NOISE = 0.001
MAX_TRANSLATION_PX = 0
BRIGHTNESS = 0.2
# NOTE(review): NoClipRandomContrast only exists as commented-out code in the
# cell above, so building this pipeline raises NameError on a fresh kernel.
aug_pipeline = Sequential([
    # layers.GaussianNoise(stddev=MAX_GAUSS_NOISE*X_train[0].max(), seed=args['seed']),
    # layers.RandomTranslation(MAX_TRANSLATION_PX/img_size, MAX_TRANSLATION_PX/img_size, seed=args['seed']),
    layers.RandomBrightness(BRIGHTNESS, value_range=[0, psfs.max()], seed=42),
    NoClipRandomContrast(0.2, seed=42),
    layers.RandomContrast(0.2, seed=42)
])

# Apply the augmentations in training mode and convert the result to NumPy.
new_ds = aug_pipeline(train.astype(float), training=True).numpy()
print(new_ds.shape)

In [None]:
# Min, standard deviation, mean and max of the training, experimental and augmented data.
for ds in [train, test, new_ds]:
    print(ds.min(), ds.std(), ds.mean(), ds.max())


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_hdf('/home/miguel/Projects/data/20231212_miguel_openframe/tubulin/FOV1/storm_1/storm_1_MMStack_Default.ome_locs_undrifted.hdf5', key='locs')


# Scatter of the localisations, zoomed to x in [400, 600] and y in [700, 1000],
# with the y axis inverted.
sns.scatterplot(data=df, x='x', y='y', alpha=0.1)
plt.xlim((400, 600))
plt.ylim((700, 1000))
plt.gca().invert_yaxis()




plt.show()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# The name holds the file path first and is then rebound to the loaded table.
old_locs = '/home/miguel/Projects/data/20230601_MQ_celltype/nup/fov2/storm_1/storm_1_MMStack_Default.ome_locs_undrifted_picked_4.hdf5'
old_locs = pd.read_hdf(old_locs, key='locs')
# Only group 10 is inspected: print the x extent of its localisations.
for c in set(old_locs['group']):
    if c != 10:
        continue
    sub_df = old_locs[old_locs['group']==c]
    print(sub_df['x'].max()-sub_df['x'].min())
    # sns.scatterplot(data=sub_df, x='x', y='y')
    # plt.show()

In [None]:
new_locs = pd.read_hdf('/home/miguel/Projects/smlm_z/publication/VIT_031_redo/out6/out_nup/locs_3d.hdf5', key='locs')
# Keep only localisations whose x value also occurs in the picked table.
new_locs = new_locs[new_locs['x'].isin(old_locs['x'])]

# Inner join on all shared localisation columns; 'group' comes from old_locs.
locs = new_locs.merge(old_locs, on=['x', 'y', 'photons', 'bg', 'lpx', 'lpy', 'net_gradient', 'iterations', 'frame', 'likelihood', 'sx', 'sy'])
locs['clusterID'] = locs['group']
print(list(locs))
# Group 10 before and after the merge, for visual comparison.
sns.scatterplot(data=old_locs[old_locs['group']==10], x='x', y='y')
plt.show()
sns.scatterplot(data=locs[locs['group']==10], x='x', y='y')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pyotf.otf import HanserPSF, apply_aberration, apply_named_aberration
from pyotf.phaseretrieval import retrieve_phase
from pyotf.utils import prep_data_for_PR, center_data
import sys
sys.path.append('/home/miguel/Projects/uni/phd/smlm_z')
from data.estimate_offset import get_peak_sharpness
from data.visualise import grid_psfs, show_psf_axial

def mse(y, y_pred):
    """Mean of the squared differences between ``y`` and ``y_pred``."""
    residual = y - y_pred
    squared = residual ** 2
    return np.mean(squared).sum()

# NOTE(review): `X` is not defined in the visible part of the notebook; it must
# come from an earlier cell.
psf = X[0]
# Settings passed to retrieve_phase; the lateral and axial sizes follow the
# shape of the selected stack.
model_kwargs = dict(
    wl=0.665,
    na=0.9,
    ni=1.34,
    res=0.09,
    zres=0.01,
    size=psf.shape[1],
    zsize=psf.shape[0],
    vec_corr="none",
    condition="none",
)
print(model_kwargs)
psf = prep_data_for_PR(psf, multiplier=1.0)

# Retrieve phase for experimental PSF
PR_result = retrieve_phase(psf, model_kwargs, 100, 1e-4, 1e-4)

PR_result.fit_to_zernikes(16)
PR_result.plot()
PR_result.zd_result.plot()
PR_result.zd_result.plot_named_coefs()
PR_result.plot_convergence()

# Simulate HanserPSF with parameters

result_psf = PR_result.generate_zd_psf(sphase=slice(None))

# this part is very kludgy
# NOTE(review): the first assignment is overwritten immediately by the second.
PR_result.model.PSFi = psf

PR_result.model.PSFi = result_psf

# Scale both stacks to a maximum of 1 before comparing them.
psf = psf / psf.max()
result_psf = result_psf / result_psf.max()

print(mse(psf, result_psf))

plt.show()
print('Experimental')
show_psf_axial(psf / psf.max(), '', 15)
print('Simulated')
show_psf_axial(result_psf / result_psf.max(), '', 15)


In [None]:
# NOTE(review): np.zeros requires a shape argument; as written this raises TypeError.
fake_img = np.zeros()

In [None]:
# Interactive lookup left over from exploration; not needed for any result.
help(PR_result.model.OTFa)

In [None]:
import numpy as np
import seaborn as sns

# 50 random points in [-1, 1)^2 with random values in [0, 500); unseeded, so
# every run draws different numbers.
xy = np.random.uniform(-1, 1, size=(50, 2))
vals = np.random.uniform(0, 500, size=(50, 1))


In [None]:
import pyotf
import pyotf.zernike
# Display pyotf's name2noll object from the zernike module.
pyotf.zernike.name2noll

In [None]:
import numpy as np
from sklearn.metrics import euclidean_distances
import seaborn as sns
import matplotlib.pyplot as plt

# 25 x 25 grid of points covering [-1, 1]^2, flattened to an (n, 2) array.
s = np.linspace(-1, 1, 25)
x, y = np.meshgrid(s, s)
x = x.flatten()
y = y.flatten()
xy = np.stack((x,y)).T

# Cubed distance from the origin, scaled so that the largest value is 1.
dists = euclidean_distances([[0, 0]], xy).squeeze()
dists = np.power(dists, 3)
dists /= dists.max()
sns.scatterplot(x=x, y=y, hue=dists)
plt.show()

In [None]:
import sys
sys.path.append('/home/miguel/Projects/uni/phd/smlm_z')
from data.visualise import grid_psfs, show_psf_axial
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import euclidean_distances
import seaborn as sns
# Settings for the simulated HanserPSF models generated below.
# NOTE(review): wl/res/zres are given as 647/90/50 here, whereas the
# phase-retrieval cell uses 0.665/0.09/0.01 — confirm the unit convention.
model_kwargs = dict(
    wl=647,
    na=1.3,
    ni=1.51,
    res=90,
    zres=50,
    size=32,
    zsize=100,
    vec_corr="none",
    condition="none",
)
from pyotf.otf import HanserPSF, apply_aberration, apply_named_aberration

def get_ab(a1, a2):
    """Return a length-7 coefficient array: five zeros followed by ``a1`` and ``a2``."""
    coefficients = [0] * 5 + [a1, a2]
    return np.array(coefficients)
    
def gen_fake_psf(model_kwargs, ab):
    """Simulate a PSF stack with the given aberration coefficients.

    Builds a ``HanserPSF`` from ``model_kwargs``, applies ``ab`` through
    ``apply_aberration`` (with an all-zero array of the same shape as the
    other coefficient argument) and returns the model's ``PSFi`` as float.
    """
    zero_coefs = np.zeros(ab.shape)
    aberrated_model = apply_aberration(HanserPSF(**model_kwargs), zero_coefs, ab)
    return aberrated_model.PSFi.astype(float)

def add_noise(psf):
    """Return ``psf`` plus zero-mean Gaussian noise (standard deviation 1e-2) from NumPy's global RNG."""
    noise = np.random.normal(0, 1e-2, size=psf.shape)
    return psf + noise

coords = []
psfs = []

n_points = 50

lim = 1
# Random bead positions in [-1, 1)^2; unseeded, so each run differs.
xy = np.random.uniform(-1, 1, size=(n_points, 2))
# xy = np.stack([np.linspace(-lim, lim, n_points), np.linspace(-lim, lim, n_points)]).T
center = [[0, 0]]
# Aberration strength: cubed distance from the origin, scaled to a maximum of 1.
dists = euclidean_distances(xy, center).squeeze()
# dists = 2**dists
dists = np.power(dists, 3)
dists /= dists.max()

a1s = []
a2s = []

for i in range(n_points):
    x, y = xy[i]
    dist = dists[i]

    # The sign of each coefficient follows the sign of the matching coordinate.
    a1 = (1 if x>0 else -1) * dist
    a2 = (1 if y>0 else -1) * dist
    a1s.append(a1)
    a2s.append(a2)
    ab = get_ab(a1, a2)
    # Coefficients are doubled before simulating the PSF.
    psf = gen_fake_psf(model_kwargs, ab*2)
#     psf = add_noise(psf)
    coords.append([x, y])
    psfs.append(psf)
#     plt.imshow(grid_psfs(psf[::7]).squeeze())
#     plt.show()

psfs = np.array(psfs)
df = pd.DataFrame(coords, columns=['x', 'y'])
# Maps of the two coefficients over the field of view.
sns.scatterplot(x=xy[:, 0], y=xy[:, 1], hue=a1s)
plt.show()
sns.scatterplot(x=xy[:, 0], y=xy[:, 1], hue=a2s)
plt.show()

In [None]:
# Shape of the stacks, and of the flattened (n_psfs, n_features) view used for PCA.
print(psfs.shape)
psfs.reshape((psfs.shape[0], -1)).shape

In [None]:
from sklearn.decomposition import PCA
import seaborn as sns
# Fit a PCA on the flattened PSFs and plot the first two components,
# coloured by the aberration strength `dists`.
pca = PCA().fit(psfs.reshape((psfs.shape[0], -1)))

d = pca.transform(psfs.reshape((psfs.shape[0], -1)))
sns.scatterplot(x=d[:, 0], y=d[:, 1], hue=dists)
plt.show()

In [None]:
# randomly roll psfs
import seaborn as sns

# Axial step used to convert frame rolls into offsets below.
# NOTE(review): this rebinds the notebook-global z_step, set to 10 in an earlier cell.
z_step = 50

# subsample psfs
# subsample = 5
# psfs = psfs[:, ::subsample, :, :]
# z_step *= subsample

def roll_psf(psf, roll):
    """Return ``psf`` circularly shifted by ``roll`` positions along axis 0.

    The previous version ignored its ``psf`` argument and rolled the global
    ``psfs[i]``; it only worked because the calling loop happened to use the
    same ``i``.
    """
    rolled_psf = np.roll(psf, shift=roll, axis=0)
#     if roll < 0:
#         rolled_psf[roll:] = 0
#     else:
#         rolled_psf[:roll] = 0
    return rolled_psf
#     show_psf_axial(psfs[i], '', 7)
#     show_psf_axial(rolled_psf, roll, 7)

# Roll every PSF by a random number of frames in [-5, 4] (randint excludes the
# upper bound) and store the applied shift, scaled by z_step, as ground truth.
# NOTE(review): psfs is modified in place, so re-running this cell rolls the
# already-rolled stacks again while df['roll'] only records the latest shift.
rolls = []
for i in range(psfs.shape[0]):
    roll = np.random.randint(-5, 5)
    psfs[i] = roll_psf(psfs[i], roll)
    rolls.append(roll)
df['roll'] = np.array(rolls) * z_step
# df['roll'] *= 10



sns.scatterplot(data=df, x='x', y='y', hue='roll')
plt.show()
for i in range(psfs.shape[0]):
    show_psf_axial(psfs[i], i, 7)

In [None]:
# Project the rolled PSFs with the PCA that was fitted before rolling.
d = pca.transform(psfs.reshape((psfs.shape[0], -1)))
sns.scatterplot(x=d[:, 0], y=d[:, 1], hue=dists)
plt.show()

In [None]:
# # Load real data
# import pandas as pd
# from tifffile import imread
# import matplotlib.pyplot as plt
# df = pd.read_csv('/home/miguel/Projects/uni/phd/smlm_z/smlm-z/data/05_model_input/coords.csv')
# psfs = imread('/home/miguel/Projects/uni/phd/smlm_z/smlm-z/data/05_model_input/spots.tif')
# df['roll'] = 0
# assert psfs.shape[0] == df.shape[0]

In [None]:
import numpy as np
from skimage.exposure import match_histograms
from sklearn.metrics.pairwise import euclidean_distances
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from multiprocessing import Pool
from itertools import product
from functools import partial
import tqdm
from data.estimate_offset import get_peak_sharpness
from data.visualise import show_psf_axial
from keras.losses import MeanSquaredError
from multiprocessing.spawn import prepare
import numpy as np
from scipy.interpolate import UnivariateSpline
from tqdm import trange
import tensorflow as tf

# NOTE(review): `mse` is rebound further down in this cell with reduction='sum',
# and also shadows the `mse` function defined in an earlier cell.
mse = MeanSquaredError()
# When True, classic_align_psfs shows each aligned pair.
DEBUG = False
# Default upscaling factor used by upsample_psf and tf_find_optimal_roll.
UPSCALE_RATIO = 10
def norm_zero_one(s):
    """Linearly rescale ``s`` so that its minimum maps to 0 and its maximum to 1."""
    lowest, highest = s.min(), s.max()
    span = highest - lowest
    return (s - lowest) / span


def pad_and_fit_spline(coords, psf, z, z_ups):
    """Fit a linear smoothing spline to one pixel's axial profile.

    ``coords`` is the pixel's ``(x, y)`` index pair into ``psf[:, x, y]``, and
    ``z`` holds the sample positions of that profile. The spline
    (``k=1``, ``s=1e-3``) is evaluated at ``z_ups``.

    Returns ``(x, y, values_at_z_ups)``.
    """
    row, col = coords
    profile = psf[:, row, col]
    spline = UnivariateSpline(z, profile, k=1, s=1e-3)
    return row, col, spline(z_ups)
    
def upsample_psf(psf, ratio=UPSCALE_RATIO):
    """Upsample a stack along axis 0 by ``ratio`` using per-pixel splines.

    The stack is edge-padded by 10 frames on both ends before fitting. Each
    pixel's profile is fitted with ``pad_and_fit_spline`` and evaluated on a
    grid with spacing ``1/ratio`` over the original frame range.
    """
    pad_width = 10
    n_frames = psf.shape[0]
    z = np.arange(-pad_width, n_frames + pad_width)
    z_ups = np.arange(0, n_frames, 1/ratio)
    upsampled_psf = np.zeros((z_ups.shape[0], *psf.shape[1:]))

    padded = np.pad(psf, ((pad_width, pad_width), (0, 0), (0, 0)), mode='edge')
    fit_pixel = partial(pad_and_fit_spline, psf=padded, z=z, z_ups=z_ups)
    for pixel in product(np.arange(padded.shape[1]), np.arange(padded.shape[2])):
        x, y, z_col = fit_pixel(pixel)
        upsampled_psf[:, x, y] = z_col

    return upsampled_psf


def plot_correction(target, img, psf_corrected, errors):
    """Plot the per-frame maximum of the target, original and corrected stacks.

    ``errors`` is accepted but not used.
    """
    traces = (
        ('target', target),
        ('original', img),
        ('corrected', psf_corrected),
    )
    for label, stack in traces:
        plt.plot(stack.max(axis=(1,2)), label=label)

    plt.legend()
    plt.show()

        
# Rebinds `mse` (created above with default settings) to a sum-reduced loss;
# tf_find_optimal_roll below reads this global.
mse = MeanSquaredError(reduction='sum')
def loss_func(true_m, pred_m):
    """Mean of ``(|true_m - pred_m| * pred_m)**2``."""
    abs_diff = tf.math.abs(true_m - pred_m)
    weighted = abs_diff * pred_m
    return tf.math.reduce_mean(tf.math.square(weighted))

def tf_find_optimal_roll(target, img, upscale_ratio=UPSCALE_RATIO):
    """Find the circular axis-0 shift of ``img`` that best matches ``target``.

    Every shift from 0 to ``img.shape[0] - 1`` is scored with the global
    ``mse`` loss. A best shift past the halfway point is reported as the
    equivalent negative shift. The correction is plotted, and the shift
    divided by ``upscale_ratio`` is returned.
    """
    ref_tf = tf.convert_to_tensor(target)
    shifted_tf = tf.convert_to_tensor(img)
    n_frames = img.shape[0]

    errors = []
    for _ in range(n_frames):
        errors.append(mse(ref_tf, shifted_tf))
        # Advance by one frame so the next pass scores the next shift.
        shifted_tf = tf.roll(shifted_tf, 1, axis=0)

    best_i = tf.argmin(errors).numpy()
    # Prefer small backwards roll to large forwards roll
    backwards = best_i - n_frames
    if abs(backwards) < best_i:
        best_i = backwards

    psf_corrected = np.roll(img, int(best_i), axis=0)
    plot_correction(target, img, psf_corrected, errors)

    return best_i/upscale_ratio

def prepare_psf(psf):
    """Square a copy of ``psf``, rescale it to [0, 1] and upsample it along axis 0."""
    squared = np.square(psf.copy())
    normalised = norm_zero_one(squared)
    return upsample_psf(normalised)


def align_psfs(psf, psf2):
    """Offset between two stacks: the optimal roll scaled by the global ``z_step``.

    Both stacks are prepared with ``prepare_psf``; the first is then
    histogram-matched to the second before the optimal roll is searched.
    """
    prepared_first = prepare_psf(psf)
    prepared_second = prepare_psf(psf2)
    matched_first = match_histograms(prepared_first, prepared_second)
    return tf_find_optimal_roll(matched_first, prepared_second) * z_step

def find_seed_psf(df):
    """Positional index of the row whose (x, y) is closest to the mean position.

    The index is printed before being returned.
    """
    positions = df[['x', 'y']].to_numpy()
    centroid = df[['x', 'y']].mean(axis=0).to_numpy()
    first_point = np.argmin(euclidean_distances([centroid], positions).squeeze())
    print(first_point)
    return first_point

def get_or_prepare_psf(prepped_psfs, raw_psfs, idx):
    """Return the prepared PSF for ``idx``, computing and caching it on first use.

    ``prepped_psfs`` is the cache mapping index to prepared stack; it is
    updated in place when ``idx`` is missing.
    """
    if idx in prepped_psfs:
        return prepped_psfs[idx]
    prepped_psfs[idx] = prepare_psf(raw_psfs[idx])
    return prepped_psfs[idx]

# Bead positions, read as a global inside classic_align_psfs.
xys = df[['x', 'y']].to_numpy()

# Global list that classic_align_psfs appends to; it keeps growing if the
# function is called more than once.
errors = []
def classic_align_psfs(psfs, df):
    """Align every PSF to the seed PSF and return the offsets.

    The seed is chosen with ``find_seed_psf``. Every other stack is prepared
    and aligned to the prepared seed with ``tf_find_optimal_roll``; the result
    is scaled by the global ``z_step``.

    Side effects: appends one value per PSF to the global ``errors`` list (0
    for the seed, otherwise the difference of the 'roll' values of seed and
    PSF) and prints one line per aligned PSF.

    Returns an array with one offset per PSF; the seed's entry is 0.
    """
    print(f'Aligning {psfs.shape} psfs...')

    seed_psf = find_seed_psf(df)
    ref_psf = prepare_psf(psfs[seed_psf])
    offsets = np.zeros((psfs.shape[0]))

    # Only used by the commented-out `offsets -= ref_0` line at the end.
    ref_0 = get_peak_sharpness(psfs[seed_psf])

    for i in trange(0, psfs.shape[0]):
        if i == seed_psf:
            offsets[i] = 0
            errors.append(0)
            continue
        psf = psfs[i]
        psf = prepare_psf(psf)
#         psf = match_histograms(psf, ref_psf)
        offset = tf_find_optimal_roll(ref_psf, psf) * z_step
        offsets[i] = offset
        # Difference of the ground-truth 'roll' values of seed and current PSF.
        correct_dist = df['roll'][seed_psf] - df['roll'][i]
        # NOTE(review): euc_dist is computed but never used.
        euc_dist = euclidean_distances([xys[i]], [xys[seed_psf]]).squeeze()
        errors.append(correct_dist)
        print(f"{seed_psf} -> {i}, {offset}, {correct_dist}")
        if DEBUG:
            # Show reference and shifted stack side by side.
            offset_psf = np.roll(psf, shift=-int(offset), axis=0)
            imgs = np.concatenate((ref_psf, offset_psf), axis=2)
            show_psf_axial(imgs, subsample_n=30)
            
#         plt.imshow(grid_psfs(psf[::5]).T)
#         plt.show()
#         plt.imshow(grid_psfs(ref_psf[::5]).T)
#         plt.show()

#     offsets -= ref_0

    return offsets

# Offsets of the simulated, randomly rolled PSFs relative to the seed PSF.
classic_offsets = classic_align_psfs(psfs, df)

In [None]:

import numpy as np
import scipy.optimize as opt
import skimage.filters as filters

def gaussian(x, amplitude, mean, stddev):
    """1-D Gaussian without offset; the model fitted by ``measure_psf_fwhm``."""
    return amplitude * np.exp(-(x - mean) ** 2 / (2 * stddev ** 2))

def measure_psf_fwhm(psf):
    """Estimate the FWHM (in pixels) of a single 2-D PSF image.

    The image is rescaled to [0, 1], the column through the brightest pixel
    is extracted, and a Gaussian is fitted to that profile.

    Changes from the previous version:
    * the profile is taken at the peak's column index. The peak's row index
      was used as a column index before, which selects the wrong column and
      fails on non-square images;
    * the absolute value of the fitted sigma is used, because the model is
      symmetric in sigma and a negative fit would give a negative FWHM that
      always wins an ``argmin``;
    * a constant image returns ``np.inf`` instead of failing on 0/0.

    Returns ``np.inf`` when no FWHM can be determined.
    """
    psf_min = np.min(psf)
    psf_range = np.max(psf) - psf_min
    if psf_range == 0:
        # Flat image: no peak to fit.
        return np.inf

    # Normalize the PSF to range [0, 1]
    psf_norm = (psf - psf_min) / psf_range

    # Find the brightest pixel as (row, column)
    center = np.unravel_index(np.argmax(psf_norm), psf_norm.shape)
    peak_row, peak_col = center[0], center[1]
    # Extract the 1D profile along axis 0 that passes through the peak
    z_slice = psf_norm[:, peak_col]

    # Estimate the initial parameters of the Gaussian fit
    amplitude = np.max(z_slice) - np.min(z_slice)
    mean = peak_row
    stddev = 2

    # Fit the Gaussian to the 1D slice using least squares optimization
    try:
        popt, _ = opt.curve_fit(gaussian, np.arange(z_slice.size), z_slice, p0=[amplitude, mean, stddev])
    except RuntimeError:
        return np.inf
    # Compute the FWHM of the Gaussian fit
    fwhm = 2 * np.sqrt(2 * np.log(2)) * abs(popt[2])

    return fwhm

def determine_best_focus_slice(psf):
    """Return the index of the frame of ``psf`` with the smallest FWHM.

    Each frame is scored with ``measure_psf_fwhm``. The best frame is shown
    together with the frames 5 positions before and after it.

    Side effects: sets the global matplotlib figure size to 2 x 2, prints the
    best index and displays the selected frames.

    The previous version displayed ``[best-5, best, best]``, repeating the
    best frame where ``best+5`` belongs, and let a negative ``best-5`` wrap
    around to the end of the stack. The neighbours are now clipped to the
    valid frame range.
    """
    # Measure the FWHM of the PSF for each z-slice
    fwhm_values = [measure_psf_fwhm(psf[i]) for i in range(psf.shape[0])]

    # Find the index of the z-slice with the minimum FWHM value
    best_slice_idx = np.argmin(fwhm_values)
    plt.rcParams['figure.figsize'] = [2, 2]
    print(f'best slice: {best_slice_idx}')
    neighbours = np.clip(
        [best_slice_idx - 5, best_slice_idx, best_slice_idx + 5],
        0,
        psf.shape[0] - 1,
    )
    slices = psf[neighbours, :, :]
    show_psf_axial(slices, '', 1)
    return best_slice_idx

def fwhm_offsets(psfs):
    """Best-focus frame index of every stack, expressed relative to the mean.

    Calls ``determine_best_focus_slice`` once per stack (which also plots and
    prints), then subtracts the mean index so the offsets are zero-centred.
    """
    focus_frames = []
    for stack in psfs:
        focus_frames.append(determine_best_focus_slice(stack))
    focus_frames = np.array(focus_frames).astype(float)
    return focus_frames - np.mean(focus_frames)

# measure_psf_fwhm(psfs[0][0])
# Zero-centred best-focus offsets (in frames) for every stack in `psfs`.
new_offsets = fwhm_offsets(psfs)

In [None]:
# Compare df['roll'] (presumably the reference roll of each bead — confirm)
# with the two offset estimates.
plt.title('Alignment error (frames, 50nm z-step)')
# Spatial map coloured by |roll - FWHM-based offset|.
sns.scatterplot(data=df, x='x', y='y', hue=abs(df['roll']-new_offsets))
plt.show()
# NOTE(review): the classic offsets are *added* to the roll here, whereas the
# FWHM offsets above are subtracted — presumably the two use opposite sign
# conventions; confirm.
plt.scatter(df['roll'], abs(df['roll']+classic_offsets))
plt.show()

In [None]:
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.utils.graph_shortest_path import graph_shortest_path
import networkx as nx

def get_center_point(df):
    """Return the row of ``df`` whose (x, y) position is closest to the origin.

    Side effect: adds/overwrites a ``'dists'`` column on ``df`` holding each
    row's Euclidean distance to the origin.

    Args:
        df: DataFrame with numeric ``'x'`` and ``'y'`` columns.

    Returns:
        The closest row as a Series; its ``.name`` is the row's index label.
    """
    center_point = [[0, 0]]
    cx, cy = center_point[0]
    df['dists'] = np.hypot(df['x'] - cx, df['y'] - cy)
    # idxmin() yields an index *label*, so the row must be fetched with .loc;
    # .iloc would silently pick the wrong row on a non-default index.
    idx = df['dists'].idxmin()
    return df.loc[idx]

# Pairwise Euclidean distances between all localisations.
dists = euclidean_distances(df[['x', 'y']])

# Fully connected weighted graph reduced to its minimum spanning tree.
# nx.from_numpy_matrix was removed in networkx 3.0; from_numpy_array is the
# drop-in replacement (available since networkx 2.0).
G = nx.from_numpy_array(dists)
G = nx.minimum_spanning_tree(G)
center_point = get_center_point(df)

from itertools import combinations
# Make sure every remaining edge carries its Euclidean length as weight.
for src, target in G.edges:
    G[src][target]['weight'] = dists[src, target]

nx.draw(G, pos=df[['x', 'y']].values, with_labels=True, node_size=100, node_color='lightgreen')

In [None]:
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.utils.graph_shortest_path import graph_shortest_path
from tqdm import trange
import networkx as nx

# Default axial upsampling factor: upsample_psf produces this many samples per
# original frame, and tf_find_optimal_roll divides its result by it.
UPSCALE_RATIO = 10

def pad_and_fit_spline(coords, psf, z, z_ups):
    """Resample the axial intensity profile of one pixel onto ``z_ups``.

    A degree-1 ``UnivariateSpline`` (smoothing ``s=1e-6``) is fitted to
    ``psf[:, coords[0], coords[1]]`` sampled at ``z`` and evaluated at
    ``z_ups``.

    Returns:
        ``(coords[0], coords[1], resampled_profile)``.
    """
    row, col = coords
    intensities = psf[:, row, col]
    spline = UnivariateSpline(z, intensities, k=1, s=1e-6)

    # Debug plot, permanently disabled.
    show_debug_plot = False
    if show_debug_plot:
        plt.scatter(z, intensities, label='raw')
        plt.plot(z_ups, spline(z_ups), label='smooth')
        plt.legend()
        plt.show()

    resampled = spline(z_ups)
    return row, col, resampled
    
def upsample_psf(psf, ratio=UPSCALE_RATIO):
    """Upsample a PSF stack along axis 0 by ``ratio`` using per-pixel splines.

    The stack is edge-padded by 10 frames on both ends of axis 0 before
    fitting; the splines are evaluated on ``arange(0, n_frames, 1/ratio)``,
    i.e. only inside the original frame range.

    Returns:
        Array of shape ``(len(arange(0, n_frames, 1/ratio)), *psf.shape[1:])``.
    """
    pad_width = 10
    n_frames = psf.shape[0]
    z_padded = np.arange(-pad_width, n_frames + pad_width)
    z_fine = np.arange(0, n_frames, 1/ratio)
    upsampled_psf = np.zeros((z_fine.shape[0], *psf.shape[1:]))

    padded = np.pad(psf, ((pad_width, pad_width), (0, 0), (0, 0)), mode='edge')

    # Fit and resample each pixel's axial profile independently.
    for row in range(padded.shape[1]):
        for col in range(padded.shape[2]):
            _, _, z_col = pad_and_fit_spline((row, col), psf=padded, z=z_padded, z_ups=z_fine)
            upsampled_psf[:, row, col] = z_col

    return upsampled_psf


def plot_correction(target, img, psf_corrected, errors):
    """Plot the per-frame maximum intensity of three stacks on one axis.

    ``errors`` is accepted for interface compatibility but is not plotted.
    """
    labelled_stacks = (
        (target, 'target'),
        (img, 'original'),
        (psf_corrected, 'corrected'),
    )
    for stack, label in labelled_stacks:
        plt.plot(stack.max(axis=(1,2)), label=label)

    plt.legend()
    plt.show()

        
# Loss object used by tf_find_optimal_roll; configured with reduction='sum'.
mse = MeanSquaredError(reduction='sum')
def loss_func(true_m, pred_m):
    """Mean of the squared absolute difference weighted by ``true_m``.

    Computes ``mean((|true_m - pred_m| * true_m) ** 2)``.
    """
    weighted_abs_diff = tf.math.abs(true_m-pred_m) * true_m
    return tf.math.reduce_mean(tf.math.square(weighted_abs_diff))

def tf_find_optimal_roll(target, img, upscale_ratio=UPSCALE_RATIO):
    """Brute-force the circular shift of ``img`` along axis 0 that best matches ``target``.

    ``img`` is rolled forward one frame at a time and scored against
    ``target`` with the module-level ``mse`` loss; the shift with the lowest
    error wins. A shift closer to a full revolution than to zero is reported
    as the equivalent negative shift. Also calls ``plot_correction``.

    Returns:
        The best shift (in frames of the inputs) divided by ``upscale_ratio``.
    """
    n_frames = img.shape[0]
    target_tensor = tf.convert_to_tensor(target)
    rolled_tensor = tf.convert_to_tensor(img)

    # errors[k] is the loss after rolling img forward by k frames.
    errors = []
    for _ in range(n_frames):
        errors.append(mse(target_tensor, rolled_tensor))
        rolled_tensor = tf.roll(rolled_tensor, 1, axis=0)

    best_i = tf.argmin(errors).numpy()

    # Prefer small backwards roll to large forwards roll
    backwards_roll = best_i - n_frames
    if abs(backwards_roll) < best_i:
        best_i = backwards_roll

    psf_corrected = np.roll(img, int(best_i), axis=0)
    plot_correction(target, img, psf_corrected, errors)

    return best_i/upscale_ratio


def prepare_psf(psf):
    """Pre-process a PSF stack for alignment.

    Steps: square the intensities element-wise, rescale with
    ``norm_zero_one``, then upsample along axis 0 with ``upsample_psf``.
    The input array is not modified. (Gaussian smoothing and masking were
    tried at some point and are currently disabled.)
    """
    squared = psf.copy()
    squared = squared * squared
    normalised = norm_zero_one(squared.copy())
    return upsample_psf(normalised)


def align_psfs(psf, psf2):
    """Estimate the axial offset between two PSF stacks.

    Both stacks go through ``prepare_psf``; the first is then
    histogram-matched to the second and used as the target of
    ``tf_find_optimal_roll``.

    Returns:
        The roll found (in original frames) multiplied by the module-level
        ``z_step``, which must be defined before this function is called.
    """
    target_stack = prepare_psf(psf)
    rolled_stack = prepare_psf(psf2)
    target_stack = match_histograms(target_stack, rolled_stack)
    frame_offset = tf_find_optimal_roll(target_stack, rolled_stack)
    return frame_offset * z_step

# Pairwise offset cache: offsets[a, b] is NaN until the offset from node a to
# node b has been computed.
offsets = np.full((df.shape[0], df.shape[0]), np.nan)

# Every node is aligned to the node closest to the origin.
target_node = center_point.name
src_node = 0
# Filled by get_path_offset / the driver loop, one entry per node.
all_offsets = []
def get_path_offset(G, src_node, target_node):
    """Accumulate the offset from ``src_node`` to ``target_node`` along ``G``.

    The offset is the sum of pairwise offsets between consecutive nodes on the
    shortest path in ``G``. Pairwise results are cached in the module-level
    ``offsets`` matrix (antisymmetric: ``offsets[b, a] == -offsets[a, b]``),
    and missing ones are computed with ``align_psfs`` on the module-level
    ``psfs``.

    Side effects: fills ``offsets``, appends the result to the module-level
    ``all_offsets`` list, and prints each newly computed pairwise offset next
    to the difference of the corresponding ``df['roll']`` values.

    Args:
        G: networkx graph whose nodes index into ``psfs`` / ``offsets``.
        src_node: start node.
        target_node: end node.

    Returns:
        The cumulative offset (also appended to ``all_offsets``).
    """
    # Look the cached value up with this call's own nodes (not with whatever
    # `src` / `target` globals an earlier loop happened to leave behind).
    if not np.isnan(offsets[src_node, target_node]):
        cumul = offsets[src_node, target_node]
    else:
        spath = nx.shortest_path(G, src_node, target_node)
        cumul = 0
        for a, b in zip(spath[:-1], spath[1:]):
            if not np.isnan(offsets[a, b]):
                offset = offsets[a, b]
            else:
                offset = align_psfs(psfs[a], psfs[b])
                offsets[a, b] = offset
                offsets[b, a] = -offset
                diff = (df['roll'][a] - df['roll'][b])
                print(f'{a} -> {b}: {offsets[a, b]}, {diff}')

            cumul += offset
        offsets[src_node, target_node] = cumul
        offsets[target_node, src_node] = -cumul
    all_offsets.append(cumul)
    return cumul
    
# Offset of every node relative to target_node, in node order. The target
# itself contributes 0; get_path_offset appends to all_offsets for the rest.
for i in trange(0, df.shape[0]):
    if i == target_node:
        all_offsets.append(0)
        continue
    get_path_offset(G, i, target_node)
# NOTE: rebinding the list to an array makes this cell non-idempotent —
# re-running it without the setup above fails on `.append`.
all_offsets = np.array(all_offsets)

# def roll_psf(psf, roll):
#     rolled_psf = np.roll(psf, shift=roll, axis=0)
# #     if roll < 0:
# #         rolled_psf[roll:] = 0
# #     else:
# #         rolled_psf[:roll] = 0
#     return rolled_psf
# rolls = df['roll']
# print(rolls[5], rolls[6])
# psf = psfs[5]
# rolled_psf = roll_psf(psf, 10)
# print(align_psfs(psf, rolled_psf))



In [None]:
# Sanity check: both offset arrays should have one entry per bead.
print(classic_offsets.shape)
print(all_offsets.shape)

In [None]:
import seaborn as sns
from sklearn.metrics import mean_absolute_error

# Compare both offset estimates with df['roll']; the red line is y = x.
# NOTE(review): the constant 150 and the sign flip on classic_offsets are
# hard-coded here — presumably a fixed reference shift; confirm.
plt.scatter(df['roll'], -classic_offsets-150, label='classic')
plt.scatter(df['roll'], all_offsets-150, label='new')
plt.plot(df['roll'], df['roll'], c='red', label='1')
plt.legend()
plt.show()

# Mean absolute error of each method against df['roll'].
print(round(mean_absolute_error(df['roll'], -classic_offsets-150), 5))
print(round(mean_absolute_error(df['roll'], all_offsets-150), 5))


In [None]:
# Classic offsets with flipped sign, scaled by 50 (presumably the 50 nm
# z-step, giving nanometres — confirm).
co = -classic_offsets*50
# Spatial map of the offsets, shifted so the smallest one is 0.
sns.scatterplot(data=df, x='x', y='y', hue=co-co.min())
plt.title('Offsets [nm]')
plt.xlabel('x [nm]')
plt.ylabel('y [nm]')
plt.show()

In [None]:
coords = df[['x', 'y']].to_numpy()
# NOTE(review): hard-coded node index 69; target_node_coords is not used below.
target_node_coords = coords[69]
# Distance of every point from the origin. NOTE: this rebinds the global
# `dists` to an (n, 1) array, replacing any earlier pairwise matrix.
dists = euclidean_distances(coords, [[0, 0]])
# Zero-based offsets (co from the previous cell) against x, then against y.
plt.scatter(df['x'], co-co.min())
plt.ylabel('offset [nm]')
plt.xlabel('x [nm]')
plt.show()
plt.scatter(df['y'], co-co.min())
plt.ylabel('offset [nm]')
plt.xlabel('y [nm]')
plt.show()

In [None]:
import networkx as nx
# Build a graph from the triangulation `tri` (defined elsewhere in the notebook).
# NOTE(review): nx.add_path links consecutive vertices of each simplex only, so
# the closing edge of a triangle is added only if another simplex supplies it —
# confirm this is intended.
G = nx.Graph()
for path in tri.simplices:
    nx.add_path(G, path)

from itertools import combinations
# Edge weights are point-to-point distances. They are recomputed here because
# the global `dists` may have been rebound to an (n, 1) distance-to-origin
# array by an earlier cell, which cannot be indexed as dists[src, target].
pairwise_dists = euclidean_distances(df[['x', 'y']])
for src, target in G.edges:
    G[src][target]['weight'] = pairwise_dists[src, target]

# Colour / width of each edge taken from the pairwise offset matrix.
# NOTE(review): the `-1` shift on both indices looks like a leftover from
# 1-based node ids — confirm against how `offsets` is indexed.
edge_weights = []
for edge in G.edges():
    src, target = edge
    edge_weights.append(offsets[src-1, target-1])

nx.draw(G, pos=df[['x', 'y']].values, edge_color=edge_weights, width=edge_weights, with_labels=True, node_size=100, node_color='lightgreen')

In [None]:
# Axial spacing between frames; also read as a global by align_psfs.
z_step = 50
# z position of every frame of the first stack.
psf_z = np.arange(0, psfs.shape[1]) * z_step

# z positions of the second stack, shifted by `overall_align`
# (NOTE(review): `overall_align` is not defined in this part of the notebook).
psf2_z = (np.arange(0, psfs.shape[1]) * z_step) + overall_align

# Interleave the frames of both stacks in order of increasing z.
all_z = np.concatenate((psf_z, psf2_z))
all_psfs = np.concatenate((psfs[src_node], psfs[target_node]))
idx = np.argsort(all_z)
all_psfs = all_psfs[idx]

def norm_zero_one(psf):
    """Min-max rescale ``psf`` so its smallest value is 0 and its largest is 1."""
    lowest = psf.min()
    highest = psf.max()
    return (psf - lowest) / (highest - lowest)

# Normalise every frame independently so frames are comparable in the grid.
all_psfs = np.stack([norm_zero_one(p) for p in all_psfs])

from data.visualise import grid_psfs

# Source stack, target stack, then the z-sorted interleaving of both.
plt.imshow(grid_psfs(norm_zero_one(psfs[src_node])))
plt.show()


plt.imshow(grid_psfs(norm_zero_one(psfs[target_node])))
plt.show()

plt.imshow(grid_psfs(all_psfs))
plt.show()

In [None]:
import pyotf
from tifffile import imread, imwrite

# First stack of the bead TIFF (hard-coded absolute path on the author's machine).
psf = imread('/home/miguel/Projects/uni/phd/smlm_z/test/psfs/20220506_Miguel_beads_zeiss_training_3_beads.tif')[0]

from pyotf.otf import HanserPSF
from pyotf.phaseretrieval import retrieve_phase
from pyotf.zernike import zernike

# Optical model passed to pyotf; lateral and axial sizes follow the loaded stack.
# NOTE(review): wl / res / zres look like micrometres (0.647, 0.106, 0.01) —
# confirm the units and that zres matches the acquisition z-step.
model_kwargs = dict(
    wl=0.647,
    na=1.3,
    ni=1.33,
    res=0.106,
    size=psf.shape[1],
    zsize=psf.shape[0],
    zres=0.01,
    vec_corr="none",
    condition="none",
)


In [None]:
# Shape of the loaded bead stack (its axes 0 and 1 feed zsize / size above).
print(psf.shape)

In [None]:
# Phase retrieval on the measured stack. Both tolerances are 0, so presumably
# the solver always runs the full 200 iterations — confirm against pyotf.
PR_result = retrieve_phase(
    psf, model_kwargs, max_iters=200, pupil_tol=0, mse_tol=0, phase_only=False
)


In [None]:
# List the attributes stored on the phase-retrieval result object.
PR_result.__dict__.keys()

In [None]:
# Diagnostic plots of the retrieval, then a fit with 64 Zernike modes.
PR_result.plot()
PR_result.plot_convergence()
PR_result.fit_to_zernikes(64)
PR_result.zd_result.plot()
PR_result.zd_result.plot_named_coefs()
PR_result.zd_result.plot_coefs()

from pyotf.otf import apply_aberration
# Rebuild a model PSF and apply the fitted magnitude / phase coefficients.
result_psf = HanserPSF(**model_kwargs)
result_psf = apply_aberration(result_psf, PR_result.zd_result.mcoefs, PR_result.zd_result.pcoefs)

In [None]:
import numpy as np
from data.visualise import show_psf_axial
# Measured and modelled stacks joined along axis 2 for visual comparison.
# NOTE: this rebinds the global `psfs` used by earlier cells.
psfs = np.concatenate((psf, result_psf.PSFi), axis=2)
show_psf_axial(psfs)

In [None]:
%load_ext autoreload
%autoreload 2
from data.datasets import TrainingPicassoDataset
from config.datasets import dataset_configs
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree

cfg = dataset_configs['paired_bead_stacks']['training_1']

locs = cfg['bpath'] / cfg['locs']
df = pd.read_hdf(locs, 'locs')

# Distance of every localisation to the centroid of all localisations.
center = df[['x', 'y']].mean(axis=0).to_numpy()
coords = df[['x', 'y']].to_numpy()
dists = euclidean_distances([center], coords).squeeze()
df['dist_to_center'] = dists

# The point closest to the centroid is the reference ("source") PSF.
# NOTE(review): first_point is a position; using it with .loc assumes df has a
# default 0..n-1 index — confirm.
first_point = np.argmin(dists)
df['source'] = False
df.loc[first_point, 'source'] = True

# Minimum spanning tree over all points, stored as a 0/1 adjacency matrix.
dists = euclidean_distances(coords, coords)
m = csr_matrix(dists)
tree = minimum_spanning_tree(m).toarray()
tree[tree>0] = 1
edges = np.where(tree>0)

# abs_offsets doubles as a "visited" marker: entries > 0 are already placed.
# The source starts at 10 and every other node gets its parent's value + 1.
abs_offsets = np.zeros((df.shape[0]))
abs_offsets[first_point] = 10

# Symmetrise: minimum_spanning_tree only fills one triangle of the matrix.
tree += tree.T
print(tree)
print(first_point)
# flatten() (not squeeze()) so a source with exactly one neighbour still
# yields a 1-D array that set() can iterate over.
alignable_psfs = set(np.argwhere(tree[first_point, :] > 0).flatten())
while len(alignable_psfs):
    unaligned_psf = alignable_psfs.pop()
    known_offsets = set(np.argwhere(abs_offsets>0).flatten())
    connected_points = set(np.argwhere(tree[unaligned_psf, :] > 0).flatten())

    # A tree neighbour that has already been placed acts as the parent.
    target_psf = known_offsets.intersection(connected_points).pop()
    abs_offsets[unaligned_psf] += abs_offsets[target_psf] + 1
    # Queue this node's neighbours, then drop everything already placed.
    alignable_psfs = alignable_psfs.union(set(np.argwhere(tree[:, unaligned_psf]).flatten()))
    alignable_psfs = alignable_psfs.difference(np.argwhere(abs_offsets>0).flatten())

print(abs_offsets)

# Draw the tree edges and colour every point by its placement value.
edges = np.where(tree>0)
for i in range(edges[0].shape[0]):
    src, target = edges[0][i], edges[1][i]
    x = [coords[src][0], coords[target][0]]
    y = [coords[src][1], coords[target][1]]
    plt.plot(x, y, color='0')
sns.scatterplot(data=df, x='x', y='y', hue=abs_offsets.astype(int))
plt.show()

In [None]:
%load_ext autoreload
%autoreload 2
from pyotf.otf import HanserPSF, apply_aberration, apply_named_aberration
import matplotlib.pyplot as plt
import numpy as np

plt.rcParams['figure.figsize'] = [5, 5]

# NOTE: these imports shadow the notebook-local align_psfs / tf_find_optimal_roll /
# norm_zero_one defined in earlier cells.
from data.visualise import show_psf_axial, grid_psfs
from data.align_psfs import align_psfs, tf_find_optimal_roll, mask_img_stack, norm_zero_one

# Sets an attribute on the imported function object (presumably read by the
# module to toggle debug plots — confirm).
align_psfs.debug = False

# Simulated PSF model: 32x32 pixels, 200 z-planes.
# NOTE(review): wl / res / zres are presumably in nanometres here — confirm.
kwargs = dict(
    wl=647,
    na=1.3,
    ni=1.51,
    res=106,
    zres=50,
    size=32,
    zsize=200,
    vec_corr="none",
    condition="none",
)
psf = HanserPSF(**kwargs)
# All magnitude coefficients 0; only the fifth phase coefficient is set to 1.
psf = apply_aberration(psf, np.array([0, 0, 0, 0, 0]), np.array([0, 0, 0, 0, 1]))

blank_psf = psf.PSFi

# Rescale the noise-free stack to the range [0, 255].
blank_psf = norm_zero_one(blank_psf) * 255

from experiments.noise.noise_psf import EMCCD
import matplotlib.pyplot as plt
import numpy as np

e = EMCCD(noise_background=125)
# Result is discarded; this call only exercises add_noise on random input.
e.add_noise(np.random.randint(0, 255, size=(32,32)))

# Middle frame of the stack, shown before and after adding camera noise.
psf = blank_psf[blank_psf.shape[0]//2]
plt.imshow(psf)
plt.show()
plt.imshow(e.add_noise(psf))
plt.show()

In [None]:
%load_ext autoreload
%autoreload 2
from experiments.noise.noise_psf import generate_noisy_psf
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = [5, 5]

# Shift the noise-free stack by a known number of frames, then add noise.
offset = 5
rolled_psf = np.roll(blank_psf, offset, axis=0)

rolled_psf = generate_noisy_psf(rolled_psf)
from skimage.exposure import match_histograms

# Per-frame maximum of the clean ("target") and the rolled, noisy stack.
psfs = np.stack((blank_psf, rolled_psf))
plt.plot(psfs[0].max(axis=(1,2)), label='target')
plt.plot(psfs[1].max(axis=(1,2)), label='psf')
plt.legend()
plt.show()

# Recover the roll with upscale_ratio=1 so the result is in whole frames.
res = tf_find_optimal_roll(rolled_psf, blank_psf, 1)

offsets = [200, 200-res]

print('offsets', offsets)
# NOTE(review): the expected sign depends on the roll convention of the
# tf_find_optimal_roll in scope; exact equality also assumes the added noise
# does not move the optimum — confirm.
assert res == -offset

imgs = np.concatenate(psfs)
zs = []
labels = ['target', 'psf']

# Re-plot both intensity profiles on a common axis after applying the offsets.
# NOTE(review): the factor 10 differs from the zres=50 used to simulate the
# stack — confirm which spacing is intended.
for i, o in enumerate(offsets):
    vals = (np.arange(0, psfs[0].shape[0])*10) - (o*10)
    print(vals[0:5])
    plt.plot(vals, np.max(psfs[i], axis=(1,2)), label=labels[i])
plt.legend()
plt.show()



In [None]:
from data.estimate_offset import get_peak_sharpness


# Three copies of the noise-free stack, rolled by 0, 5 and 10 frames.
# NOTE: the loop variable rebinds the global `offset` from the previous cell.
fake_stacks = []
offsets = []
for offset in [0, 5, 10]:
    offsets.append(offset)
    rolled_psf = np.roll(blank_psf, offset, axis=0)
#     rolled_psf = generate_noisy_psf(rolled_psf)
    fake_stacks.append(rolled_psf)

fake_stacks = np.array(fake_stacks)
# Clip negative intensities to zero.
fake_stacks[fake_stacks<0] = 0

# NOTE(review): the factors 50 and 2 are unexplained here — presumably the
# z-step in nm and a convention of get_peak_sharpness; confirm.
pos0 = (get_peak_sharpness(fake_stacks[0]) * 50) * 2
print(pos0)


In [None]:
# Uses the align_psfs imported from data.align_psfs (single-argument form),
# not the two-argument function defined earlier in this notebook.
corr_offsets = align_psfs(fake_stacks)
plt.rcParams['figure.figsize'] = [5, 5]

print(corr_offsets)
voxel_size = 50
zs = []
# Shift each stack's frame axis by its offset, convert to physical units and
# reference everything to pos0; the profiles should then overlap.
for offset, psf in zip(corr_offsets, fake_stacks):
    z = ((np.arange(0, psf.shape[0]) - offset)  * voxel_size) -  pos0
    print(z.min(), z.max())
    zs.append(z)
    plt.plot(z, psf.max(axis=(1,2)))
plt.show()

    




In [None]:
from data.datasets import stack_offset_to_z
# Per-frame z values for every fake stack via the project helper.
# NOTE(review): the third argument is 10 here while voxel_size is 50 in the
# previous cell — confirm which is intended.
zs = []
for psf, offset in zip(fake_stacks, corr_offsets):
    z = stack_offset_to_z(offset, psf, 10)
    print(z.min(), z.max())
    zs.append(z)

In [None]:
# Inspect training dataset
from data.visualise import grid_psfs
plt.rcParams['figure.figsize'] = [50, 50]

# Flatten all stacks into one list of frames ordered by z.
# NOTE: `zs` is rebound from a list to a flat array, so this cell cannot be
# re-run without re-running the previous one.
zs = np.concatenate(zs)
imgs = np.concatenate(fake_stacks)
idx = np.argsort(zs)
imgs = imgs[idx]

plt.imshow(grid_psfs(imgs.squeeze()))
plt.show()





In [None]:
import h5py

def read_spots(dirpath):
    """Load the ``'spots'`` dataset of an HDF5 file with a trailing channel axis.

    Args:
        dirpath: path of the HDF5 file (despite the name, a file path).

    Returns:
        The dataset as a numpy array with a new last axis of length 1; the
        indexing requires the stored dataset to be 3-dimensional.
    """
    # The context manager closes the file even if the dataset is missing or
    # the read fails.
    with h5py.File(dirpath, 'r') as f:
        return np.array(f['spots'])[:, :, :, np.newaxis]

# Simulated spots (hard-coded absolute path on the author's machine).
dirpath = '/home/miguel/Projects/uni/data/smlm_3d/picasso_sim/grid_pairs/grid_pairs_spots.hdf5'
test_spots = read_spots(dirpath)

# Bead spots from the training acquisition.
dirpath = '/home/miguel/Projects/uni/data/smlm_3d/picasso_test/training/NPC-A647-3D-BEADS/0021_spots.hdf5'

train_spots = read_spots(dirpath)

# Compare shapes and intensity statistics (min, mean, std, max) of both sets.
print(train_spots.shape)
print(test_spots.shape)
for spots in [train_spots, test_spots]:
    print(spots.min(), spots.mean(), spots.std(), spots.max())

In [None]:
import matplotlib.pyplot as plt
from skimage.exposure import match_histograms

plt.rcParams['figure.figsize'] = [3, 3]
train_spot = train_spots[0]

# Pixel-wise mean over all training spots, used as histogram reference.
mean_img = np.mean(train_spots, axis=0)

# Original spot, reference image, and the spot after histogram matching.
plt.imshow(train_spot)
plt.show()
plt.imshow(mean_img)
plt.show()
plt.imshow(match_histograms(train_spot, mean_img))
plt.show()

# Intensity ranges before and after matching.
print(train_spot.min(), train_spot.max())

print(mean_img.min(), mean_img.max())

print(match_histograms(train_spot, mean_img).min(), match_histograms(train_spot, mean_img).max())

In [None]:
import numpy as np
from skimage.exposure import equalize_hist
plt.rcParams['figure.figsize'] = [5,5]

mean_img = np.mean(train_spots, axis=0)

# First ten spots next to their histogram-matched versions.
# NOTE(review): show_imgs is not defined in this part of the notebook.
for i in range(0, 10):
    show_imgs(train_spots[i], match_histograms(train_spots[i], mean_img))

In [None]:
%load_ext autoreload
%autoreload 2
from data.datasets import TrainingPicassoDataset
from config.datasets import dataset_configs

# Build the Picasso training dataset from its project config entry.
cfg = dataset_configs['paired_bead_stacks']['training_1']
print(cfg)

ds = TrainingPicassoDataset(cfg)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Coordinates of two stacked rings of emitters: n_elements points evenly
# spaced on a circle of the given radius, repeated at every depth in z_depths.
radius = 107
n_elements = 8
z_depths = [0, 200]
offset = (0, 0)

angles = [i*((np.pi*2)/n_elements) for i in range(n_elements)]
x = np.array([np.cos(a)*radius for a in angles])
y = np.array([np.sin(a)*radius for a in angles])

# Shift the ring centre: x by the first offset component, y by the second.
x += offset[0]
y += offset[1]

# One copy of the ring per depth; sorted z keeps each ring's points together.
xs = [round(n, 3) for n in np.concatenate((x, x))]
ys = [round(n, 3) for n in np.concatenate((y, y))]
zs = [round(n, 3) for n in (sorted(z_depths*n_elements))]

print(len(xs), len(ys), len(zs))

plt.scatter(xs, ys)
plt.show()

print(xs)
print(ys)
print(zs)

In [None]:
# Name of the dataset whose trained model is loaded below.
dataset ='picasso_test'
# NOTE(review): BOUND is not used in the visible cells.
BOUND = 31

from model.model import load_trained_model


model = load_trained_model(dataset)


In [None]:
# Print the layer summary of the loaded model.
model.summary()

In [None]:
'''
--------
|    x |   <- 50nm deeper than other
|      |
| x    |
--------

5 structures, frame 16 px,
structureX: 1000,3000
structureY: 1000,3000

structure3D: 0,50
ExchangeLabels:1,1
'''
# Common path prefix of the simulation files; later cells append
# '.yaml', '_locs.hdf5' and '_spots.hdf5' (hard-coded absolute path).
dirpath = '/home/miguel/Projects/uni/data/smlm_3d/picasso_sim/50nm/50nm'

# dirpath = '/home/miguel/Projects/uni/data/smlm_3d/picasso_sim/grid_single/grid_single'
# dirpath = '/home/miguel/Projects/uni/data/smlm_3d/picasso_sim/grid_pairs/grid_pairs'

In [None]:
import yaml

yaml_file = f'{dirpath}.yaml'

# SafeLoader variant whose fallback constructor maps any tag without a
# registered constructor to None instead of raising.
class SafeLoaderIgnoreUnknown(yaml.SafeLoader):
    def ignore_unknown(self, node):
        return None 

# Registering under the tag `None` makes ignore_unknown the fallback constructor.
SafeLoaderIgnoreUnknown.add_constructor(None, SafeLoaderIgnoreUnknown.ignore_unknown)

# WARNING: the file is rewritten in place. Any value with an unknown tag is
# written back as null, and yaml.load returns only the first document, so the
# original content cannot be recovered from the rewritten file.
with open(yaml_file, "r") as stream:
    root = yaml.load(stream, Loader=SafeLoaderIgnoreUnknown)
with open(yaml_file, "w") as stream:
    yaml.dump(root, stream)

In [None]:
%load_ext autoreload
%autoreload 2

import h5py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

plt.rcParams['figure.figsize'] = [5,5]

def norm_zero_one(img):
    """Min-max rescale ``img`` to the range [0, 1]."""
    return (img - img.min()) / (img.max() - img.min())

def norm_one_one(img):
    """Min-max rescale ``img`` to the range [-1, 1]."""
    unit_scaled = norm_zero_one(img)
    return (2 * unit_scaled) - 1

# Localisation table and spot images share the `dirpath` prefix.
locs = f'{dirpath}_locs.hdf5'
spots = f'{dirpath}_spots.hdf5'
print(locs)
print(spots)
df = pd.read_hdf(locs, 'locs')

# `spots` is rebound from the file path to the image array, with a trailing
# channel axis. The context manager closes the file even if the read fails.
with h5py.File(spots, 'r') as f:
    spots = np.array(f['spots'])[:, :, :, np.newaxis]

df['id'] = np.arange(0, df.shape[0])
print(spots.shape)
print(df.shape, spots.shape)

from data.visualise import grid_psfs
# plt.imshow(grid_psfs(spots.squeeze()))
# plt.show()

# Every localisation must have exactly one spot image.
assert df.shape[0] == spots.shape[0]



from data.datasets import norm_dataset_from_config, standardise, load_ref_img_and_norm

# Intensity range before normalisation: overall, then for the first 5 spots.
print(spots.shape)
print(spots.min(), spots.max())
# spots = standardise(spots)
# spots = load_ref_img_and_norm(spots)
for s in spots[0:5]:
    print(s.min(), s.max())
    
# Rescale every spot independently to [-1, 1].
spots = np.stack([norm_one_one(img) for img in spots])

print(spots.min(), spots.max())



In [None]:
plt.rcParams['figure.figsize'] = [20, 20]
plt.imshow(grid_psfs(spots[0:50].squeeze()).squeeze())
plt.show()

# NOTE: despite the name this is a single spot (index 1), not a mean image;
# it is the reference for histogram matching below.
mean_img = spots[1]

# match_histograms comes from an earlier cell's skimage.exposure import.
spots_matched_hist = np.stack([match_histograms(img, mean_img) for img in spots])
plt.rcParams['figure.figsize'] = [5, 5]
plt.imshow(mean_img)
plt.show()
plt.rcParams['figure.figsize'] = [20, 20]
plt.imshow(grid_psfs(spots_matched_hist[0:50].squeeze()).squeeze())
plt.show()


In [None]:
# # Rescale to [-1, 1]
# spots = np.stack([norm_one_one(img) for img in spots])


# Inverts the y axis of the current axes (creating them if needed); the
# scatter plot below is drawn onto those axes.
plt.gca().invert_yaxis()

df['index'] = np.arange(0, df.shape[0])
print(spots.shape)
plt.rcParams['figure.figsize'] = [5, 5]
sns.scatterplot(data=df, x='x', y='y', marker='+')
plt.show()


# The model takes (images, coords); the coordinate input is all zeros here.
coords = np.zeros((spots.shape[0], 2))

# Predict on the histogram-matched spots.
pred = model.predict((spots_matched_hist, coords)).squeeze()
print(np.std(pred))
df['pred'] = pred
print(pred.min(), pred.max())

plt.rcParams['figure.figsize'] = [5, 5]
sns.histplot(pred, bins=50)
plt.xlabel('Z position (nm)')
plt.show()


# First 100 spots displayed in order of increasing predicted z.
sub_imgs = spots[0:100]
sub_preds = pred[0:100]
plt.rcParams['figure.figsize'] = [100, 100]
from data.visualise import grid_psfs
print(np.sort(sub_preds.squeeze()))

plt.imshow(grid_psfs(sub_imgs[np.argsort(sub_preds.squeeze())].squeeze()))
plt.show()

In [None]:
# Spread of the predicted z values.
print(np.std(pred))

In [None]:
%load_ext autoreload
%autoreload 2

from config.datasets import dataset_configs

from data.datasets import TrainingDataSet, ExperimentalDataSet, GenericDataSet, MultiTrainingDataset, TrainingPicassoDataset

# Load the training split of the same dataset the model was trained on.
dataset = 'picasso_test'
# NOTE(review): `version` is not used in the visible cells.
version = ''
cfg = dataset_configs[dataset]['training']
print(cfg)

train_dataset = TrainingPicassoDataset(cfg)

In [None]:
# Maximum of the [0][0] entry of the 'train' split (used as images in the
# histogram cell below).
train_dataset.data['train'][0][0].max()

In [None]:
plt.rcParams['figure.figsize'] = [5, 5]

# Pixel-value histogram of the spots under test ...
test_imgs = spots.flatten()

plt.hist(test_imgs, label='exp', alpha=0.5)

# ... overlaid with the pixel values of every split of the training dataset.
for k in train_dataset.data.keys():
    imgs = train_dataset.data[k][0][0].flatten()
    print(imgs.min(), imgs.max())
    plt.hist(imgs.flatten(), label=k, alpha=0.5)

plt.legend()
plt.yscale('log')
plt.show()

In [None]:
# Joint density of each spot's mean pixel value and its predicted z.
mean_pixel_vals = np.mean(spots, axis=(1,2))
plt.rcParams['figure.figsize'] = [5, 5]
# seaborn takes x / y as keywords (positional y was deprecated in 0.11 and
# later removed); ravel() drops a trailing channel axis so both are 1-D.
sns.kdeplot(x=np.ravel(mean_pixel_vals), y=np.ravel(pred))
plt.xlabel('Mean pixel value')
plt.ylabel('Pred')
plt.show()

In [None]:
plt.rcParams['figure.figsize'] = [5, 5]
# df['emitter'] = df['x'] < 60
df['pred'] = pred.squeeze()
# df['pred2'] = df['pred']>-250
# Crude per-spot signal-to-noise proxy: peak value over median value.
df['snr'] = [np.max(s)/np.median(s) for s in spots]
# NOTE(review): hard-coded +130 — presumably the expected z is -130; confirm.
df['error'] = df['pred']+130
sns.scatterplot(data=df, x='snr', y='pred')
plt.show()
sns.scatterplot(data=df, x='snr', y='error')
plt.show()
# sns.histplot(daata=df, x='pred', hue='emitter')
# plt.show()

# sns.scatterplot(data=df, x='x', y='y', hue='pred2')
# plt.show()



In [None]:
"""
Checks influence of polar coords on Z localisation
"""

# from itertools import product

# plt.rcParams['figure.figsize'] = [5, 5]

# spot = spots[0:1][:, :, :, np.newaxis]
# thetas = np.linspace(0, 1, 20)
# rhos = np.linspace(0, 1, 20)

# coords = np.array(list(product(thetas, rhos))).squeeze()
# spot = np.repeat(spot, coords.shape[0], axis=0)

# preds = model.predict((spot, coords))

# tmp_df = pd.DataFrame.from_dict({'theta': coords[:, 0], 'rho': coords[:, 1], 'z': preds.squeeze()})
# sns.scatterplot(data=tmp_df, x='theta', y='z')
# plt.show()
# sns.scatterplot(data=tmp_df, x='rho', y='z')
# plt.show()
# sns.scatterplot(data=tmp_df, x='theta', y='rho', hue='z')
# plt.show()

In [None]:
# %load_ext autoreload
# %autoreload 2
# from config.datasets import dataset_configs
# from data.datasets import StormDataset, ExperimentalDataSet
# import seaborn as sns
# import matplotlib.pyplot as plt

# dataset ='picasso_test'

# cfg = dataset_configs[dataset]['nucleopore']
# ds = StormDataset(cfg, normalize_psf=True, lazy=True, apply_clustering=False)
# # ds.neighbour_radius = 15
# # ds.max_off_frames = 1000
# ds.csv_data = ds.csv_data[(ds.csv_data['x [nm]'] > 8250) 
#                           & (ds.csv_data['x [nm]'] < 8425) 
#                           & (ds.csv_data['y [nm]'] > 9425) 
#                           & (ds.csv_data['y [nm]'] < 9575) 
#                          ]

# plt.rcParams['figure.figsize'] = [5, 5]
# sns.scatterplot(data=ds.csv_data, x='x [nm]', y='y [nm]', marker='.')
# plt.show()


# ds.prepare_data()

# from scipy.ndimage import median_filter
# ds.data[0] = np.stack([median_filter(d, size=2) for d in ds.data[0]])

# df = ds.csv_data
# print(df.shape)
# print(ds.data[0].shape)

In [None]:
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.cluster import DBSCAN

plt.rcParams['figure.figsize'] = [5,5]

# NOTE(review): coordinates are scaled by 100 here, while later cells mention
# 90 nm and 106 nm pixels — confirm the intended pixel size.
coords = df[['x', 'y']].to_numpy() * 100
distance_matrix = euclidean_distances(coords, coords)
eps = 15
min_count = 50

# Density pre-filter: count neighbours within eps for every point (the point
# itself included) and keep points with more than min_count of them.
distance_matrix = (distance_matrix < eps).astype(int).sum(axis=0)
# flatnonzero always returns a 1-D index array, even for a single survivor.
idx = np.flatnonzero(distance_matrix > min_count)
sub_coords = coords[idx]
print(coords.shape)
print(sub_coords.shape)

# Cluster the dense points; DBSCAN labels noise as -1.
cluster_ids = DBSCAN(eps=eps, min_samples=min_count).fit_predict(sub_coords)
sns.scatterplot(x=sub_coords[:, 0], y=sub_coords[:, 1], hue=cluster_ids.astype(str))
plt.axis('equal')
plt.show()

# Copy before adding a column so the assignment does not target a view of df.
sub_df = df.iloc[idx].copy()
sub_df['cluster_id'] = cluster_ids.squeeze().astype(str)
sub_spots = spots[idx]


In [None]:
# Histogram of predictions for all (non histogram-matched) spots with zeroed
# coordinate input.
sns.histplot(model.predict((spots, np.zeros((spots.shape[0], 2)))).squeeze())
plt.show()

In [None]:
from sklearn.cluster import DBSCAN, OPTICS, KMeans
from sklearn.mixture import GaussianMixture
from scipy.stats import norm
from data.visualise import grid_psfs

plt.rcParams['figure.figsize'] = [20, 5]
# df['cluster_id'] = OPTICS().fit_predict(df[['x', 'y']].to_numpy()).squeeze().astype(str)
plt.axis('equal')
plt.title('Units: pixels (90nm each)')
# Dense points coloured by their DBSCAN cluster.
sns.scatterplot(data=sub_df, x='x', y='y', marker='.', hue='cluster_id', legend=False)
plt.show()

# NOTE(review): n_components is not used by the code below.
n_components = 2

# NOTE(review): neither dict is populated anywhere in the visible cells.
gm_mean_diffs = {}
mean_img_diffs = {}


def gm_min_bic(data, imgs):
    """Fit 1- to 6-component Gaussian mixtures and return the best model's means.

    For every component count a ``GaussianMixture`` with tied covariance is
    fitted to ``data``; the model with the lowest BIC is kept. Each fit is
    drawn in its own subplot: a density histogram of the data coloured by
    assigned component, overlaid with every component's weighted normal pdf.
    The BIC values and the standard deviation of the mixture weights of all
    six fits are printed.

    Args:
        data: array of shape (n_samples, 1).
        imgs: unused; kept for interface compatibility.

    Returns:
        1-D array with the component means of the minimum-BIC model.
    """
    gm_df = pd.DataFrame({'pred': data.squeeze()}, index=np.arange(0, data.squeeze().shape[0]))

    best_gm = None
    min_bic = np.inf
    bics = []
    cov_type = 'tied'
    stdevs = []

    fig, axes = plt.subplots(1, 6)
    for n in range(1, 7):
        gm = GaussianMixture(n_components=n, n_init=20, covariance_type=cov_type).fit(data)
        bic = gm.bic(data)

        bics.append(round(bic, 3))
        stdevs.append(round(np.std(gm.weights_), 3))
        if bic < min_bic:
            min_bic = bic
            best_gm = gm

        ax = axes[n-1]
        labels = gm.predict(data).squeeze()

        gm_df['cluster_id'] = labels.astype(str)

        sns.histplot(data=gm_df, x='pred', hue='cluster_id', stat='density', alpha=0.2, bins=20, ax=ax)

        # Overlay the weighted pdf of every component of *this* fit (not of
        # the best fit so far, which may have fewer components).
        x_axis = np.linspace(data.min(), data.max(), 50)
        sub_df2 = pd.DataFrame.from_dict({'x': x_axis})
        for i in range(gm.n_components):
            # Variance of component i; where it lives depends on covariance_type.
            if cov_type == 'tied':
                cov = gm.covariances_.squeeze()
            elif cov_type == 'full' or cov_type is None:
                cov = gm.covariances_[i][0][0]
            elif cov_type == 'spherical':
                cov = gm.covariances_[i]
            elif cov_type == 'diag':
                cov = gm.covariances_[i][0]

            sub_df2[f'y_{i}'] = norm.pdf(x_axis, float(gm.means_[i][0]), np.sqrt(cov))*gm.weights_[i]
            sns.lineplot(data=sub_df2, x='x', y=f'y_{i}', ax=ax)
    plt.show()

    print(bics)
    print(stdevs)

    return best_gm.means_[:, 0]

def apply_gm(data, imgs, cid):
    """Reshape ``data`` into a single-feature column and run ``gm_min_bic`` on it.

    ``cid`` is accepted for interface compatibility but not used.
    """
    column = data.reshape(-1, 1)
    return gm_min_bic(column, imgs)
    

# For every DBSCAN cluster: predict z for its spots, fit Gaussian mixtures to
# the predictions and record one (x, y, z, cluster_id) row per mixture mean.
all_coords = []
for cid in set(sub_df['cluster_id']):
    # '-1' is DBSCAN's noise label.
    if cid == '-1':
        continue
    idx = np.argwhere(sub_df['cluster_id'].to_numpy()==cid).squeeze()
    # NOTE(review): adds a trailing axis; if sub_spots already carries a
    # channel axis this yields a 5-D array — confirm the model's input shape.
    imgs = sub_spots[idx][:, :, :, np.newaxis]
    coords = np.zeros((imgs.shape[0], 2))
    preds = model.predict((imgs, coords)).squeeze()
    # Shift to strictly positive values, then take the square root.
    # NOTE(review): the mixture means below are therefore in this transformed
    # space and are not mapped back — confirm this is intended for `z`.
    preds -= preds.min()
    preds += 0.00000001
    preds = np.sqrt(preds)
    z_pos = apply_gm(preds, imgs, cid)
    # Cluster centroid scaled by 106 (presumably nm per pixel — confirm).
    x, y = sub_df.iloc[idx][['x', 'y']].mean(axis=0).to_numpy() * 106
    # NOTE: `coords` is rebound to a plain list here and stays one afterwards.
    coords = [[x, y, z, int(cid)] for z in z_pos]
    all_coords.extend(coords)

all_coords = np.array(all_coords)
res = pd.DataFrame.from_dict({
    k: all_coords[:, i] for k, i in zip(['x', 'y', 'z', 'cluster_id'], [0, 1, 2, 3])
})


In [None]:
# Mean of the per-key maxima and minima of gm_mean_diffs, and their gap.
# NOTE(review): gm_mean_diffs is initialised empty and never filled in the
# visible cells, in which case these means are NaN.
d = gm_mean_diffs
tops = [np.max(v) for k, v in d.items()]
bottoms = [np.min(v) for k, v in d.items()]
print(np.mean(tops))
print(np.mean(bottoms))
print(np.mean(tops) - np.mean(bottoms))

In [None]:
plt.rcParams['figure.figsize'] = [15, 5]
res['cluster_id'] = res['cluster_id'].astype(int)
fig = plt.figure()
# NOTE: called before any axes exist, so it has nothing to lay out yet.
fig.tight_layout() 

plt.subplots_adjust(wspace = 0.4) 
# Left: 3D scatter of the mixture means, coloured by cluster.
ax = fig.add_subplot(121, projection='3d')

ax.scatter(res['x'], res['y'], res['z'], marker='o', c=res['cluster_id'])

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# Right: z of every mixture mean against its cluster id.
ax2 = fig.add_subplot(122)

ax2.scatter(res['cluster_id'], res['z'], c=res['cluster_id'])
plt.xlabel('Cluster ID')
plt.title('All units in nm')
plt.show()

In [None]:
print(spots.shape)
# `coords` may be a plain list at this point (the clustering cell rebinds it),
# so use np.shape, which works for lists and arrays alike.
print(np.shape(coords))
all_coords = np.zeros((spots.shape[0], 2))
# NOTE(review): adds a trailing axis; if `spots` already carries a channel
# axis this yields a 5-D array — confirm the model's input shape.
all_spots = spots[:, :, :, np.newaxis]

# Predict z for every spot with zeroed coordinate input.
pred = model.predict((all_spots, all_coords))
df['z'] = pred.squeeze()
# for c in list(df):
#     sns.scatterplot(data=df, x=c, y='z')
#     plt.show()
# Predicted z over the first 1000 rows by frame, then its overall distribution.
sns.scatterplot(data=df.iloc[0:1000], x='frame', y='z')
plt.show()
sns.histplot(data=df, x='z')
plt.show()

In [None]:

# Noise sensitivity check: add Gaussian noise of random mean / scale to one
# spot 1000 times and see how the prediction moves.
# NOTE: no random seed is set, so results differ between runs.
best_img = spots[0]

recs = []

imgs = []
coords = []

for _ in range(1000):
    noise_loc = np.random.uniform(0, 0.8)
    noise_scale = np.random.uniform(0, 0.7)
    new_img = np.random.normal(loc=noise_loc, scale=noise_scale, size=best_img.shape)
    new_img += best_img
    new_img = norm_zero_one(new_img)
    
    # Second normalisation is a no-op on an image already spanning [0, 1].
    new_img = norm_zero_one(new_img)
    
    # Add a leading batch axis so the images can be concatenated below.
    new_img = np.array([new_img])
    imgs.append(new_img)
    coords.append(np.zeros((1, 2)))
    
    recs.append({
        'img_mean': np.mean(new_img),
        'img_median': np.median(new_img),
        'noise_loc': noise_loc,
        'noise_scale': noise_scale,
    })

# NOTE: rebinds the global `df` (previously the localisation table).
df = pd.DataFrame.from_dict(recs)

imgs = np.concatenate(imgs)
print(imgs.shape)
coords = np.concatenate(coords)
print(coords.shape)
pred = model.predict((imgs, coords)).squeeze()
# No reference z is subtracted: 'err' is simply the raw prediction.
err = pred
df['err'] = err
df['pred'] = pred
# Prediction against every recorded column (including 'err' and 'pred').
for col in list(df):
    sns.scatterplot(data=df, x=col, y='err')
    plt.show()
sns.scatterplot(data=df, x='noise_loc', y='noise_scale', hue='err')
plt.show()
