# Notebook to convert TIFF files into an OME-TIFF format compatible with Micro.

In [None]:
import tifffile
import numpy as np
from ipywidgets import Dropdown, Button, HBox, VBox, Label, Output, widgets
from IPython.display import display
import json
#from ome_types.model import OME, Image, Pixels, Channel
#from ome_types.model.simple_types import PositiveFloat

In [None]:
def load_tif_with_axis_dropdown(file_path):
    """
    Load a TIFF/OME-TIFF and interactively map its dimensions onto TCZYX.

    Reads the pixel data, the axes string tifffile reports for the first
    series, and any JSON object stored in the first page's ImageDescription
    tag. It then displays one dropdown per target axis (T, Z, Y, X, C) so you
    can choose which file dimension feeds it, plus an "Apply Mapping" button.

    Parameters
    ----------
    file_path : str or path-like
        Path of the TIFF file to read.

    Returns
    -------
    btn : Button widget. After a successful click on "Apply Mapping":
      btn.image5d  → np.ndarray ordered as the chosen standard (TCZYX);
                     targets with no file dimension get a size-1 axis
      btn.meta     → dict with keys 'axes','PhysicalSizeX','PhysicalSizeZ',
                     'TimeIncrement' (taken from the JSON description when
                     present, otherwise None)
      btn.standard → the chosen final standard string
    These attributes do not exist until a mapping has been applied.
    """
    target_axes = ["T", "Z", "Y", "X", "C"]
    standard_options = ('TCZYX',)

    # 1) Read raw pixels + metadata
    with tifffile.TiffFile(file_path) as tif:
        raw = tif.asarray()
        ome_axes = getattr(tif.series[0], "axes", None)
        desc = tif.pages[0].tags.get('ImageDescription')
        desc_text = desc.value if desc is not None else None
        md = {}
        if isinstance(desc_text, str) and desc_text.strip().startswith('{'):
            try:
                parsed = json.loads(desc_text)
            except ValueError:
                # Not valid JSON: metadata stays empty (best effort).
                parsed = None
            # Only a JSON object can provide the keys read below via md.get.
            if isinstance(parsed, dict):
                md = parsed

    shape = raw.shape
    ndim = raw.ndim

    # 2) Let the user pick the output order. The initial value must be one of
    #    the options; a partial axes string like 'ZYX' would be rejected.
    standard_dropdown = Dropdown(
        options=list(standard_options),
        value=ome_axes if ome_axes in standard_options else standard_options[0],
        description='Final standard:'
    )

    # 3) One mapping dropdown per target axis; options name the file dims.
    axis_labels = [f"dim {i} ({s})" for i, s in enumerate(shape)]
    dropdowns = {
        ax: Dropdown(
            options=['<none>'] + axis_labels,
            description=f'{ax}:',
            value='<none>'
        )
        for ax in target_axes
    }

    # Prepopulate from the axes string tifffile reports, when it fits.
    if ome_axes and len(ome_axes) == ndim:
        for file_idx, letter in enumerate(ome_axes):
            if letter in target_axes:
                dropdowns[letter].value = axis_labels[file_idx]

    # 4) Button & output
    out = Output()
    btn = Button(description='Apply Mapping', button_style='success')

    def on_click(_):
        std = standard_dropdown.value

        # mapping[i] is the file-dimension index feeding target_axes[i], or None.
        mapping = []
        for ax in target_axes:
            val = dropdowns[ax].value
            if val == '<none>':
                mapping.append(None)
            else:
                # Labels look like "dim <index> (<size>)".
                mapping.append(int(val.split()[1]))

        # Validate: no file dimension may feed two target axes.
        used = [m for m in mapping if m is not None]
        if len(set(used)) != len(used):
            with out:
                print("⚠️ Duplicate axes selected!")
            return

        # A) File-dimension order matching std, and the full target shape
        #    (size 1 for target axes with no file dimension).
        permuted_axes = []
        final_shape = []
        for letter in std:
            if letter not in target_axes:
                continue
            mapped = mapping[target_axes.index(letter)]
            if mapped is None:
                final_shape.append(1)
            else:
                permuted_axes.append(mapped)
                final_shape.append(shape[mapped])

        # B) File dimensions not placed in std can only be dropped when they
        #    have size 1; otherwise no valid transpose/reshape exists.
        leftover = [i for i in range(ndim) if i not in permuted_axes]
        too_big = [axis_labels[i] for i in leftover if shape[i] != 1]
        if too_big:
            with out:
                print("❌ Unmapped dimensions with size > 1:", ", ".join(too_big))
            return

        # Singleton leftovers go last; the reshape then drops them and inserts
        # size-1 axes for unmapped targets without reordering any data.
        transposed = raw.transpose(permuted_axes + leftover)
        try:
            image5d = transposed.reshape(final_shape)
        except ValueError as e:
            with out:
                print("❌ Reshape failed:", e)
            return

        # collect metadata
        meta = {
            'axes':           std,
            'PhysicalSizeX':  md.get('PhysicalSizeX'),
            'PhysicalSizeZ':  md.get('PhysicalSizeZ'),
            'TimeIncrement':  md.get('TimeIncrement')
        }

        # store results on the button
        btn.image5d  = image5d
        btn.meta     = meta
        btn.standard = std

        with out:
            print(f"✅ Done: shape={image5d.shape}, axes='{std}'")

    btn.on_click(on_click)

    # 5) display
    display(VBox([standard_dropdown] +
                 [dropdowns[ax] for ax in target_axes] +
                 [btn, out]))
    return btn

In [None]:
class MetadataEditor:
    """
    An ipywidgets–based metadata editor for OME-TIFF export.

    Parameters
    ----------
    meta : dict
      Should contain 'axes' (e.g. "TCZYX"); "TCZYX" is assumed when the key
      is missing. Positive PhysicalSizeX/Y/Z and TimeIncrement values are used
      as initial widget values. An existing 'Channel' entry (a list of names,
      or a dict whose 'Name' is a list or a single string) seeds the names.
    image5d : np.ndarray
      Your mapped array; the length of its C axis sets how many default
      channel names ("Ch1", "Ch2", ...) are proposed.

    Attributes
    ----------
    standard : str
      Axis order written into the saved metadata.
    saved : bool
      True once "Save Metadata" has been clicked.
    """

    # Initial widget values when meta has no positive value for a field.
    _DEFAULT_PHYSICAL_SIZE = 0.1   # µm
    _DEFAULT_TIME_INCREMENT = 1.0  # s

    def __init__(self, meta: dict, image5d: np.ndarray):
        style      = {'description_width': 'initial'}
        full_width = widgets.Layout(width='100%')

        self.meta    = meta
        self.image5d = image5d
        # Axis order: taken from meta when provided, else TCZYX.
        self.standard = meta.get('axes', 'TCZYX')

        def float_box(key: str, default: float, description: str):
            """Build a bounded float input seeded from meta[key] or default."""
            return widgets.BoundedFloatText(
                value=self._positive_or(meta.get(key), default),
                min=0, max=1e6,
                description=description,
                style=style, layout=full_width,
            )

        # --- PhysicalSizeX, PhysicalSizeY, PhysicalSizeZ, TimeIncrement ---
        self.phys_x = float_box('PhysicalSizeX', self._DEFAULT_PHYSICAL_SIZE, 'PhysicalSizeX (µm):')
        self.phys_y = float_box('PhysicalSizeY', self._DEFAULT_PHYSICAL_SIZE, 'PhysicalSizeY (µm):')
        self.phys_z = float_box('PhysicalSizeZ', self._DEFAULT_PHYSICAL_SIZE, 'PhysicalSizeZ (µm):')
        self.dt     = float_box('TimeIncrement', self._DEFAULT_TIME_INCREMENT, 'TimeIncrement (s):')

        # --- Channel Names: existing names, else inferred from the C axis ---
        default_names = self._default_channel_names(meta, image5d, self.standard)
        self.ch_edit = widgets.Text(
            value=",".join(default_names),
            description='Channel Names (comma separated):',
            placeholder='e.g. Ch1,Ch2,Ch3',
            style=style, layout=full_width
        )

        # --- Save button & feedback ---
        self.save_button = widgets.Button(
            description='Save Metadata',
            button_style='success',
            layout=widgets.Layout(width='180px')
        )
        self.out = widgets.Output()

        self.saved = False
        self.save_button.on_click(self._on_save)

    @property
    def stdandard(self):
        """Misspelled alias of ``standard``, kept for backward compatibility."""
        return self.standard

    @stdandard.setter
    def stdandard(self, value):
        self.standard = value

    @staticmethod
    def _positive_or(value, default):
        """Return ``value`` as a float if it is a positive number, else ``default``."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return default
        return number if number > 0 else default

    @staticmethod
    def _default_channel_names(meta, image5d, axes):
        """Channel names from meta['Channel'] if present, else Ch1..ChN from the C axis."""
        existing = meta.get('Channel')
        if isinstance(existing, dict):
            existing = existing.get('Name')
        if isinstance(existing, str):
            existing = [existing]
        if isinstance(existing, (list, tuple)) and len(existing) > 0:
            return [str(name) for name in existing]

        shape = getattr(image5d, 'shape', ())
        c_count = 1
        # Only trust the C position if the array actually has that axis.
        if axes and 'C' in axes and axes.index('C') < len(shape):
            c_count = max(int(shape[axes.index('C')]), 1)
        return [f"Ch{i + 1}" for i in range(c_count)]

    def _on_save(self, _):
        """Button callback: snapshot the widget values into ``self.meta``."""
        # parse channel names, ignoring empty entries
        ch_list = [n.strip() for n in self.ch_edit.value.split(',') if n.strip()]

        self.meta = {
            'axes':               self.standard,
            'PhysicalSizeX':       float(self.phys_x.value),
            'PhysicalSizeY':       float(self.phys_y.value),
            'PhysicalSizeZ':       float(self.phys_z.value),
            'TimeIncrement':       float(self.dt.value),
            'TimeIncrementUnit':   's',
            'SignificantBits':     16,
            'Channel':             ch_list
        }
        self.saved = True

        sizes = (self.phys_x.value, self.phys_y.value, self.phys_z.value)
        all_defaults = (all(v == self._DEFAULT_PHYSICAL_SIZE for v in sizes)
                        and self.dt.value == self._DEFAULT_TIME_INCREMENT)
        with self.out:
            # Warn when every value is still the untouched default.
            if all_defaults:
                print("⚠️ Warning: Default values used for metadata. Please adjust PhysicalSizeX/Y/Z")
            # OME requires positive physical sizes; the widgets allow 0.
            if any(v <= 0 for v in sizes):
                print("⚠️ Warning: PhysicalSizeX/Y/Z must be > 0 for a valid OME-TIFF.")
            # Confirmation is printed on every save, not only with defaults.
            print("✅ Metadata saved.")

    def display(self):
        """Render the editor in the notebook."""
        form = widgets.VBox([
            self.phys_x,
            self.phys_y,
            self.phys_z,
            self.dt,
            self.ch_edit,
            widgets.HBox([self.save_button, self.out])
        ], layout=widgets.Layout(width='80%'))
        display(form)

    def get_metadata(self) -> dict:
        """Return the edited metadata (raises RuntimeError if not yet saved)."""
        if not self.saved:
            raise RuntimeError("Click ‘Save Metadata’ before retrieval.")
        return self.meta

In [None]:
# Path of the TIFF to convert — edit this for your own data.
INPUT_TIF_PATH = "/Users/nzlab-la/Downloads/suntag-kdm5b-ke1-ki0_04_nopb_diff1_sigma1/tSUN-KDM5B.tif"

# Alternative input used previously. NOTE(review): this is a Leica .lif file,
# which tifffile is not expected to read — confirm before switching to it.
#INPUT_TIF_PATH = "/Users/nzlab-la/Library/CloudStorage/OneDrive-TheUniversityofColoradoDenver/General - Zhao (NZ) Lab/Microscope/Luis Aguilera/20250710_pGG001.lif"

mapper = load_tif_with_axis_dropdown(INPUT_TIF_PATH)


In [None]:
# The mapping widget only sets these attributes after "Apply Mapping" succeeds.
if not hasattr(mapper, "image5d"):
    raise RuntimeError("Click 'Apply Mapping' (and check for errors) before running this cell.")
img5d, meta, std = mapper.image5d, mapper.meta, mapper.standard


In [None]:
# In this widget, you can edit the metadata before it is written to an OME-TIFF.
# The saved values are used when writing the OME-TIFF file, so make sure they
# are correct before clicking "Save Metadata".
# Pixel sizes are often missing from the source TIFF, so check and set them manually.
# Defaults (used when the mapped metadata has no positive value):
# 0.1 µm for PhysicalSizeX/Y/Z and 1.0 s for TimeIncrement.
# Channel names default to Ch1..ChN, with N taken from the C axis of the mapped array.
editor = MetadataEditor(meta, img5d)
editor.display()

In [None]:
# Retrieve the metadata saved in the editor (get_metadata raises if "Save Metadata" was not clicked),
# then show it as indented JSON.
final_meta = editor.get_metadata()
final_meta_json = json.dumps(final_meta, indent=2)
print("Final metadata JSON:", final_meta_json)

In [None]:
# 2) Build the metadata dict in the form tifffile's OME writer expects:
meta = {
    'axes':               std,                               # e.g. 'TCZYX'
    'PhysicalSizeX':      final_meta['PhysicalSizeX'],       # µm
    'PhysicalSizeY':      final_meta['PhysicalSizeY'],       # µm
    'PhysicalSizeZ':      final_meta['PhysicalSizeZ'],       # µm
    'TimeIncrement':      final_meta['TimeIncrement'],       # seconds
    'TimeIncrementUnit':  final_meta['TimeIncrementUnit'],   # e.g. 's'
    'SignificantBits':    final_meta['SignificantBits'],     # e.g. 16
    'Channel':            {'Name': final_meta['Channel']},   # list of channel-name strings
}

# The number of channel names should match the length of the C axis.
if 'C' in std:
    n_channels = img5d.shape[std.index('C')]
    if len(final_meta['Channel']) != n_channels:
        print(f"⚠️ {len(final_meta['Channel'])} channel names given for {n_channels} channels.")

# Convert to uint16 explicitly: a bare astype wraps negative/too-large values
# around and silently truncates floats.
data = img5d
if np.issubdtype(data.dtype, np.floating):
    print("⚠️ Float data will be truncated to integers when converted to uint16.")
if data.size:
    data_min, data_max = float(data.min()), float(data.max())
    if data_min < 0 or data_max > 65535:
        print(f"⚠️ Data range [{data_min}, {data_max}] exceeds uint16; values are clipped to [0, 65535].")
        # Cast first so the bounds fit the dtype before clipping.
        data = np.clip(data.astype(np.float64), 0, 65535)
data_u16 = data.astype(np.uint16)

# 3) Write it out:
tifffile.imwrite(
    "exported.ome.tif",
    data=data_u16,
    imagej=False,        # set True if you need ImageJ compatibility
    metadata=meta,
    ome=True,            # write OME-TIFF format
)


In [None]:
from xml.dom import minidom
from xml.parsers.expat import ExpatError

# Check the written file.
with tifffile.TiffFile("exported.ome.tif") as tif:
    # OME-XML string, or None when tifffile does not recognize an OME-TIFF.
    ome_xml = tif.ome_metadata
    print("Written OME-TIFF metadata (formatted):")
    if ome_xml:
        try:
            print(minidom.parseString(ome_xml).toprettyxml(indent="  "))
        except ExpatError:
            # Not well-formed XML: show it unchanged.
            print(ome_xml)
    else:
        print(tif.pages[0].description)
    # series[0] reports shape/axes from the file structure without reading pixels.
    print("Shape of the written image:", tif.series[0].shape)
    print("Axes in the written image:", tif.series[0].axes)