In [None]:
from dask.distributed import Client
# create a dask.distributed Client with default arguments (no scheduler
# address is given here); the bare name below shows it as the cell output
client = Client()
client

In [3]:
import os
from pathlib import Path
import xml.etree.ElementTree as ET

import dask
import dask.array as da
import matplotlib.pyplot as plt
import numpy as np
import ome_types
from ome_types.model import StructuredAnnotations, XMLAnnotation
import pandas as pd
from ptufile import PtuFile
from scipy import interpolate
import tifffile

In [4]:
# root directory of this experiment; every input/output path below is derived
# from it, so only this line needs to change to process another experiment
experiment_dir = Path(
    r'Z:\zmbstaff\9309\Raw_Data\240531_Inhibitor_screening_1'
)
# directory holding all converted/derived data of the experiment
converted_dir = experiment_dir / 'converted_data'

# directory where the raw .ptu files are stored
ptu_dir = experiment_dir / 'ptu_exports2' / 'data.sptw'
# directory where the mask TIFFs are stored
mask_dir = converted_dir / 'roiImages'
# directory where the preprocessed roiExtracts will be stored
export_dir_ptus = converted_dir / 'roiExtract'
# path to the exported IRF xlsx file
fn_IRF_xlsx = experiment_dir / 'IRF' / 'export_IRF_data.xlsx'
# directory where the IRF data will be stored
export_dir_IRF = converted_dir / 'IRF'

## Load ptu files, extract ROIs and save as ome.tiffs

### Find ptu files with matching mask files

In [5]:
# find ptu files
fns_ptu = list(ptu_dir.glob("*.ptu"))
# find mask files; a mask's name is everything before '_mask' in its stem
fns_mask = list(mask_dir.glob("*.tif"))
names_mask = [fn_mask.stem.split('_mask')[0] for fn_mask in fns_mask]

# index the ptu files by stem so that each mask name is a direct lookup
ptus_by_stem = {}
for fn_ptu in fns_ptu:
    ptus_by_stem.setdefault(fn_ptu.stem, []).append(fn_ptu)

# keep only ptu files that have a matching mask file (ignore others); they are
# collected in mask order, so fns_ptu[i] belongs to fns_mask[i]
matching_fns = []
for name_mask in names_mask:
    candidates = ptus_by_stem.get(name_mask, [])
    if not candidates:
        raise ValueError(f"No matching name found for mask file: {name_mask}")
    if len(candidates) > 1:
        raise ValueError(f"Multiple matching names found for mask file: {name_mask}")
    matching_fns.append(candidates[0])

fns_ptu = matching_fns

### Load metadata from ptu files

In [7]:
# function to load metadata from ptu files

@dask.delayed
def lazy_load_ptu_metadata(fn):
    """Lazily read the in-pulse micro-time axis and repetition rate of a ptu file.

    Parameters
    ----------
    fn : path-like
        Path of the .ptu file. Its dims must be ('T', 'Y', 'X', 'C', 'H').

    Returns
    -------
    (t_start, t_end, t_step) : tuple
        First and last kept micro-time and the spacing between kept bins, in
        the units of ``ptu.coords['H']`` (later multiplied by 10**9 to get ns,
        so presumably seconds).
    frequency
        ``ptu.frequency``, the repetition rate.
    micro_times : numpy.ndarray
        Copy of the kept micro-times, ``coords['H'][:last]``.

    Raises
    ------
    ValueError
        If fewer than 2 micro-time bins fall within one pulse period, since
        no step size could be computed then.
    """
    # use the file as a context manager so the file handle is closed again
    with PtuFile(fn) as ptu:
        assert ptu.dims == ('T', 'Y', 'X', 'C', 'H')
        frequency = ptu.frequency

        # DETERMINE DATA WITHIN PULSE
        time = ptu.coords['H']
        # determine timepoints within pulse (1/repetition rate)
        last = np.where(time <= 1/frequency)[0][-1]
        # NOTE: logically one would take last+1 at this point, but somehow the data
        # looks more continuous this way. The last timepoint within the pulse
        # somehow has fewer photons, so we leave it out.
        if last < 2:
            raise ValueError(
                f"{fn}: fewer than 2 micro-time bins within one pulse period"
            )
        t_start = time[0]
        t_end = time[last-1]
        # kept bins are time[0] .. time[last-1], i.e. last-1 intervals
        t_step = (t_end-t_start)/(last-1)

        return (t_start, t_end, t_step), frequency, time[:last].copy()

In [None]:
if not fns_ptu:
    raise ValueError(f"No ptu files with matching masks found in {ptu_dir}")

# load the timing metadata of all ptu files in parallel
output = dask.compute(*[lazy_load_ptu_metadata(fn) for fn in fns_ptu])

# check if metadata (time axis and repetition rate) is the same for all files:
# a single (t_start, t_end, t_step) is written into every exported ome.tiff
mismatching = [
    fn.name
    for fn, file_meta in zip(fns_ptu, output)
    if file_meta[:2] != output[0][:2]
]
assert not mismatching, (
    f"metadata differs from {fns_ptu[0].name} for: {mismatching}"
)

(t_start, t_end, t_step), frequency, micro_times = output[0]
f"repetition rate is {frequency/10**6:.3f} MHz"

### Load ptu files

In [10]:
# function to load ptu files

@dask.delayed
def lazy_load_ptu(fn, dtype='uint16'):
    """Lazily load the frame-summed decay images of the first channel of a ptu file.

    Parameters
    ----------
    fn : path-like
        Path of the .ptu file. Its dims must be ('T', 'Y', 'X', 'C', 'H').
    dtype : str or numpy dtype, optional
        Unsigned integer type of the returned image (default 'uint16').

    Returns
    -------
    numpy.ndarray
        Decay images with axes (H, X, Y): micro-time bins within one pulse
        (the last in-pulse bin is dropped, see NOTE), channel 0 only, with the
        Y axis flipped so that the orientation matches the masks.

    Raises
    ------
    ValueError
        If the summed counts do not fit into ``dtype``. A plain ``astype``
        would silently wrap around instead.
    """
    # use the file as a context manager so the file handle is closed again
    with PtuFile(fn) as ptu:
        assert ptu.dims == ('T', 'Y', 'X', 'C', 'H')

        # DETERMINE DATA WITHIN PULSE
        time = ptu.coords['H']
        # determine timepoints within pulse (1/repetition rate)
        last = np.where(time <= 1/ptu.frequency)[0][-1]
        # NOTE: logically one would take last+1 at this point, but somehow the data
        # looks more continuous this way. The last timepoint within the pulse
        # somehow has fewer photons, so we leave it out.

        # LOAD DECAY-IMAGES
        # sum all repetitions (axis 0 = 'T') -> dims (Y, X, C, H); accumulate
        # in uint64 so the overflow check below sees the true counts
        flim_img = ptu[..., :last].sum(axis=0, dtype=np.uint64)

    max_count = np.iinfo(dtype).max
    if flim_img.size and flim_img.max() > max_count:
        raise ValueError(
            f"{fn}: summed photon counts exceed the range of {dtype}; "
            "pass a wider dtype"
        )
    flim_img = flim_img.astype(dtype)

    # get flim-img into correct shape (make sure it matches mask!):
    # (Y, X, C, H) -> (H, C, X, Y), then flip the Y axis
    data = np.transpose(flim_img, (3, 2, 1, 0))
    data = np.flip(data, axis=3)

    # keep only the first channel -> (H, X, Y)
    return data[:, 0]

In [None]:
# load the first file eagerly; its shape/dtype is used for every file
# (assumes all ptu files were acquired with the same settings)
meta = lazy_load_ptu(fns_ptu[0]).compute()

# wrap every (still lazy) file as one dask chunk of that shape/dtype
arrays = list(
    map(
        lambda fn: da.from_delayed(
            lazy_load_ptu(fn), shape=meta.shape, dtype=meta.dtype
        ),
        fns_ptu,
    )
)

# one entry per measurement along the new leading axis
flim_imgs_da = da.stack(arrays, axis=0)
flim_imgs_da

### Load mask files

In [12]:
# function to load masks

@dask.delayed
def lazy_load_mask(fn):
    """Lazily read one mask TIFF with ``tifffile.imread`` and return the array."""
    return tifffile.imread(fn)

In [None]:
# read the first mask eagerly; its shape/dtype is used for all masks
meta = lazy_load_mask(fns_mask[0]).compute()

arrays = []
for fn_mask in fns_mask:
    arrays.append(
        da.from_delayed(lazy_load_mask(fn_mask), shape=meta.shape, dtype=meta.dtype)
    )

# masks in the same order as fns_mask (and therefore as fns_ptu)
masks_da = da.stack(arrays, axis=0)
masks_da

In [None]:
# visual check that the decay images are oriented like the masks
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(masks_da[0])
axs[0].set_title(f'mask: {fns_mask[0].stem}')
# maximum over the micro-time axis (axis 0) of the first measurement
axs[1].imshow(flim_imgs_da[0].max(axis=0))
axs[1].set_title(f'max over micro-time: {fns_ptu[0].stem}')
plt.tight_layout()

In [None]:
# largest label over all masks; ROIs 1..max_label are extracted below
max_label = da.max(masks_da).compute()
max_label

### Extract ROIs

In [16]:
@dask.delayed
def lazy_extract_roi(flim_img, mask, max_label, dtype='uint32'):
    """Lazily sum the decay curves of all pixels belonging to each labelled ROI.

    Parameters
    ----------
    flim_img : numpy.ndarray
        Decay images with micro-time along axis 0; ``flim_img.shape[1:]`` must
        equal ``mask.shape``. Assumed to hold non-negative integer counts.
    mask : numpy.ndarray
        Label image; label ``i + 1`` selects ROI ``i``. Pixels labelled 0 or
        above ``max_label`` are ignored.
    max_label : int
        Number of ROIs (labels 1..max_label).
    dtype : str or numpy dtype, optional
        Integer type of the result (default 'uint32'). Sums over many pixels
        easily exceed the uint16 range of the input images.

    Returns
    -------
    numpy.ndarray
        Array of shape (flim_img.shape[0], max_label): one decay per ROI.

    Raises
    ------
    ValueError
        If a summed count does not fit into ``dtype``.
    """
    # accumulate in uint64 so no intermediate sum can wrap around
    decays = np.zeros((flim_img.shape[0], max_label), dtype=np.uint64)
    for i in range(max_label):
        # select the ROI pixels once and sum them for all micro-time bins
        decays[:, i] = flim_img[:, mask == i + 1].sum(axis=1, dtype=np.uint64)

    if decays.size and decays.max() > np.iinfo(dtype).max:
        raise ValueError(
            f"ROI sums exceed the range of {dtype}; pass a wider dtype"
        )
    return decays.astype(dtype)

In [None]:
# run the extraction eagerly for the first measurement to learn shape/dtype
meta = lazy_extract_roi(flim_imgs_da[0], masks_da[0], max_label).compute()

# pair each decay image with its mask (both stacks are in fns_ptu order)
arrays = []
for flim_img_da, mask_da in zip(flim_imgs_da, masks_da):
    delayed_decays = lazy_extract_roi(flim_img_da, mask_da, max_label)
    arrays.append(
        da.from_delayed(delayed_decays, shape=meta.shape, dtype=meta.dtype)
    )

roiExtracts_da = da.stack(arrays, axis=0)
roiExtracts_da

### Save as ome.tiff

In [18]:
def construct_omexml(data, flim_start, flim_end, flim_step):
    """Build the OME-XML description for one roiExtract stack.

    The stack is described as a 'TYX' image with one page per entry along
    axis 0. The micro-time axis is attached as an OME ``ModuloAlongT``
    structured annotation of type 'lifetime' with unit 'ns'.

    Parameters
    ----------
    data : numpy.ndarray
        3-D array whose last axis has length 1 (enforced by the asserts).
    flim_start, flim_end, flim_step : float
        First/last micro-time and the spacing. They are multiplied by 10**9
        for the 'ns' unit, i.e. expected in seconds (TODO confirm).

    Returns
    -------
    str
        The serialized OME-XML document.
    """
    assert data.ndim == 3
    assert data.shape[-1] == 1

    # create ome Image object
    omexml = tifffile.OmeXml()
    omexml.addimage(
        dtype=data.dtype,
        shape=data.shape,
        # tifffile storedshape order: (pages, separate samples, depth,
        # length, width, contig samples) -> one page per entry along axis 0,
        # each page data.shape[1] pixels long and 1 pixel wide
        storedshape=(data.shape[0], 1, 1, data.shape[1], 1, 1),
        axes='TYX',
    )
    ome_image = ome_types.from_xml(omexml.tostring()).images[0]
    # link the image to the Modulo annotation created below
    ome_image.annotation_refs = [{'id': 'Annotation:0'}]

    # construct XML Annotation for FLIM metadata
    root = ET.Element('Modulo', {'namespace': 'http://www.openmicroscopy.org/Schemas/Additions/2011-09'})
    ET.SubElement(
        root,
        'ModuloAlongT',
        {
            'End': str(flim_end * 10**9),
            'Start': str(flim_start * 10**9),
            'Step': str(flim_step * 10**9),
            'Type': 'lifetime',
            'TypeDescription': '',
            'Unit': 'ns'
        }
    )
    xml_str_modulo = ET.tostring(root, encoding='unicode')
    sa = StructuredAnnotations(
        xml_annotations=[
            XMLAnnotation(
                id='Annotation:0',
                namespace='openmicroscopy.org/omero/dimension/modulo',
                value=xml_str_modulo,
            )
        ],
    )

    # combine everything into an OME object
    ome = ome_types.OME(
        images=[ome_image],
        structured_annotations=sa,
    )

    return ome.to_xml()

In [19]:
@dask.delayed
def lazy_save_roiExtract(roiExtract, export_dir_ptus, name, time_metadata):
    """Lazily write one roiExtract to ``<export_dir_ptus>/<name>.ome.tiff``.

    A trailing axis of length 1 is appended, and every entry along axis 0 is
    written as its own TIFF page. The OME-XML from ``construct_omexml`` is
    stored as the description of the first page only.
    """
    stack = roiExtract.reshape(roiExtract.shape + (1,))
    t_start, t_end, t_step = time_metadata
    first_description = construct_omexml(stack, t_start, t_end, t_step)

    save_name = export_dir_ptus / f"{name}.ome.tiff"

    with tifffile.TiffWriter(save_name, ome=False, shaped=False) as tif:
        for index, page in enumerate(stack):
            description = first_description if index == 0 else None
            tif.write(page, description=description, contiguous=True)

In [20]:
# make export directory:
export_dir_ptus.mkdir(parents=True, exist_ok=True)

# run the saving: one delayed write task per measurement, all with the same
# micro-time metadata
names = [fn.stem for fn in fns_ptu]
time_metadata = (t_start, t_end, t_step)
save_tasks = [
    lazy_save_roiExtract(roiExtract_da, export_dir_ptus, name, time_metadata)
    for roiExtract_da, name in zip(roiExtracts_da, names)
]
_ = dask.compute(*save_tasks)

## Combine exported IRFs

In [22]:
# LOAD IRF DATA
# read IRFs (two header rows -> MultiIndex columns)
df = pd.read_excel(fn_IRF_xlsx, header=[0, 1])

# extract IRF columns: top-level column names containing 'IRF'
# (na=False skips non-string labels instead of failing the indexing)
columns = df.columns.get_level_values(0).unique()
columns = columns[columns.str.contains('IRF', na=False)]

# extract IRF data: first sub-column = time, second sub-column = IRF value.
# Drop rows where *either* value is missing. Dropping NaNs per column
# independently could shift the two columns against each other or leave
# them with different lengths.
irf_data = []
for column in columns:
    pairs = df[column].iloc[:, :2].dropna()
    time = pairs.iloc[:, 0].to_numpy()
    irf = pairs.iloc[:, 1].to_numpy()
    irf_data.append(np.array([time, irf]))

In [None]:
# plot raw IRFs (as exported, not yet normalized), one line per IRF column
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

for column, (time, irf) in zip(columns, irf_data):
    ax1.plot(time, irf, label=column)
    ax2.semilogy(time, irf)

ax1.set_xlabel('Time (ns)')
ax1.set_ylabel('IRF (a.u.)')
ax1.set_title('Exported IRFs (linear)')
ax1.legend()

ax2.set_xlabel('Time (ns)')
ax2.set_ylabel('IRF (a.u.)')
ax2.set_title('Exported IRFs (log)')

plt.tight_layout()

In [24]:
# COMBINE IRFs
if not irf_data:
    raise ValueError(f"No IRF columns found in {fn_IRF_xlsx}")

# interpolate IRFs onto the micro-time axis of the FLIM data
irf_data_interp = []
IRF_time = micro_times * 10**9 # time from FLIM data in ns
for data in irf_data:
    # extrapolate=False -> NaN outside the exported time range, set to 0 below
    # NOTE(review): a cubic spline can overshoot (e.g. slightly negative
    # values next to a steep rise); consider clipping if that matters
    interp = interpolate.CubicSpline(data[0], data[1], extrapolate=False)
    data_interp = np.nan_to_num(interp(IRF_time), nan=0)
    irf_data_interp.append(np.array([IRF_time, data_interp]))

# normalize IRFs to unit sum on the FLIM time axis (normalizing to the
# maximum, data_interp[1].max(), would be the alternative). Raw and
# interpolated curves share the same factor so they remain comparable.
irf_data_norm = []
irf_data_interp_norm = []
for data, data_interp in zip(irf_data, irf_data_interp):
    norm = data_interp[1].sum()
    if norm == 0:
        raise ValueError(
            "An interpolated IRF is zero everywhere (does its time range "
            "overlap the FLIM micro-time axis?)"
        )
    irf_data_norm.append(np.array([data[0], data[1] / norm]))
    irf_data_interp_norm.append(np.array([data_interp[0], data_interp[1] / norm]))

# take mean of normalized IRFs
irf_data_interp_norm_mean = np.mean(np.array(irf_data_interp_norm), axis=0)

In [None]:
# plot result: normalized exported IRFs (markers), their interpolations onto
# the FLIM time axis (lines) and the mean IRF (black dashed)
xlims = (
    min(min(data[0]) for data in irf_data),
    max(max(data[0]) for data in irf_data),
)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

for raw, interpolated in zip(irf_data_norm, irf_data_interp_norm):
    ax1.plot(raw[0], raw[1], 'o')
    ax1.plot(interpolated[0], interpolated[1])

    ax2.semilogy(raw[0], raw[1], 'o')
    ax2.plot(interpolated[0], interpolated[1])

mean_t, mean_irf = irf_data_interp_norm_mean
ax1.plot(mean_t, mean_irf, 'k--', label='Mean')
ax2.plot(mean_t, mean_irf, 'k--')

for ax, scale in ((ax1, 'linear'), (ax2, 'log')):
    ax.set_xlabel('Time (ns)')
    ax.set_ylabel('IRF (a.u.)')
    ax.set_title(f'Normalized IRFs & mean ({scale})')
ax1.legend()

for ax in (ax1, ax2):
    ax.set_xlim(*xlims)
plt.tight_layout()

In [26]:
# make export directory:
os.makedirs(export_dir_IRF, exist_ok=True)

# export IRF as csv: time in ps (ns * 10**3) and the mean IRF scaled to a
# maximum of 1. Round before casting: astype('int') truncates, so a value
# like 47.99999999 ps (floating-point error) would otherwise become 47.
export_time = np.rint(irf_data_interp_norm_mean[0] * 10**3).astype('int')
export_data = irf_data_interp_norm_mean[1] / irf_data_interp_norm_mean[1].max()
irf_df = pd.DataFrame({'t': export_time, 'irf_ch1': export_data})
irf_df.to_csv(export_dir_IRF / 'irf_est.csv', index=False, float_format='%.6f')