In [8]:
import tifffile
import functools
import json
import os
import re
from typing import Mapping, Tuple, Union

import click
import numpy as np
from slicedimage import ImageFormat

import starfish.core.util.try_import
from starfish.experiment.builder import FetchedTile, TileFetcher, write_experiment_json
from starfish.types import Axes, Coordinates, Features, Number

ROUND_NUM = 1
FOV_NUM = 1

@functools.lru_cache(maxsize=1)
def cached_read_fn(file_path) -> np.ndarray:
    """Read a whole TIFF file into memory, memoizing only the most recent path.

    Because the cache holds a single entry, repeated consecutive requests for the
    same file are served from memory while at most one stack is kept alive.
    """
    stack = tifffile.imread(file_path)
    return stack

class SampleDataTile(FetchedTile):
    """One z-plane of a TIFF stack, exposed to starfish as a FetchedTile."""

    def __init__(self,
                 file_path: str,
                 coordinates: Mapping[Union[str, Coordinates], Union[Number, Tuple[Number, Number]]],
                 z: int) -> None:
        """
        Parameters
        ----------
        file_path : str
            location of the TIFF stack that contains this tile
        coordinates : Mapping[Union[str, Coordinates], Union[Number, Tuple[Number, Number]]]
            coordinates for the tile, as supplied by the tile fetcher
        z : int
            index of the z-plane to slice out of the stack
        """
        self.file_path = file_path
        self._coordinates = coordinates
        self.z = z

    @property
    def shape(self) -> Mapping[Axes, int]:
        """(Y, X) size of the tile, taken from the first two dims of the sliced plane."""
        raw_shape = self.tile_data().shape
        return {Axes.Y: raw_shape[0], Axes.X: raw_shape[1]}

    @property
    def coordinates(self) -> Mapping[Union[str, Coordinates], Union[Number, Tuple[Number, Number]]]:
        return self._coordinates

    def tile_data(self) -> np.ndarray:
        """Return the z-plane ``self.z`` of the stack at ``self.file_path``.

        The stack is read through ``cached_read_fn``, so successive tiles from the same
        file (and the extra read done by ``shape``) don't re-read the whole stack from
        disk. The plane is copied so callers never hold a view into the cached array.
        """
        return cached_read_fn(self.file_path)[self.z].copy()

class SampleDataTileFetcher(TileFetcher):
    """Serve SampleDataTile objects from a directory of per-round, per-channel TIFF stacks.

    Files are expected to be named ``roundRR-chCC.tif``: a 1-based round number and a
    0-based channel index, each zero-padded to two digits (e.g. ``round01-ch00.tif``).
    """

    def __init__(self, input_dir: str) -> None:
        """
        Parameters
        ----------
        input_dir : str
            directory containing the TIFF stacks

        The stack dimensions are inferred from the first file of the series
        (round 1, channel 0).
        """
        self.input_dir = input_dir
        first_file = os.path.join(input_dir, self._tile_basename(0, 0))

        raw_shape = cached_read_fn(first_file).shape
        # NOTE(review): assumes axis 0 is Z and an optional 4th axis holds channels,
        # i.e. (Z, Y, X[, C]) -- confirm against the raw data.
        self.num_c = raw_shape[3] if len(raw_shape) > 3 else 1
        self.num_z = raw_shape[0]

    @staticmethod
    def _tile_basename(round_: int, ch: int) -> str:
        """File name for a 0-based round and channel, e.g. (0, 0) -> 'round01-ch00.tif'."""
        # Zero-pad to two digits rather than prefixing a literal '0', which yields
        # names like 'round010-ch010.tif' once round or channel reaches 10.
        return f"round{round_ + 1:02d}-ch{ch:02d}.tif"

    @property
    def channel_map(self) -> Union[Mapping[str, int], int]:
        """Color-name -> channel-index mapping for multi-channel stacks.

        NOTE(review): single-channel stacks return the bare int 0 rather than a
        mapping; callers must handle both forms.
        """
        if self.num_c > 1:
            return {
                "red": 0,
                "blue": 1,
                "green": 2,
            }
        return 0

    def coordinate_map(self, round_: int, z: int):
        """Return the coordinate ranges for a tile.

        TODO: these are dummy coordinates (independent of ``round_`` and ``z``);
        see the osmFISH example for how to generate them dynamically.
        """
        return {
            Coordinates.X: (0.0, 0.0001),
            Coordinates.Y: (0.0, 0.0001),
            Coordinates.Z: (0.0, 0.0001),
        }

    def get_tile(self, fov: int, r: int, ch: int, z: int) -> FetchedTile:
        """Return the tile for round ``r``, channel ``ch``, z-plane ``z`` (``fov`` is unused)."""
        file_path = os.path.join(self.input_dir, self._tile_basename(r, ch))
        coordinates = self.coordinate_map(r, z)
        return SampleDataTile(file_path, coordinates, z)
    
#     def generate_codebook(self, output_dir: str) -> None:
#         """
#         TODO
#         Generate and save a codebook from the provided mapping of genes to DNA sequences.
#         output_dir : str
#             directory in which to save the generated codebook. Codebook is saved as "codebook.json"

#         """
#         dinucleotides_to_channels = {
#             "CT": 3,
#             "GT": 2,
#             "TT": 1,
#             "AG": 3,
#             "GG": 1,
#             "TG": 2,
#             "AC": 2,
#             "CC": 1,
#             "TC": 3,
#             "AA": 1,
#             "CA": 2,
#             "GA": 3,
#         }
        
#         with open(os.path.join(self.input_dir, "genes.csv"), "r") as f:
#             codes = [l.strip().split(",") for l in f.readlines()]  # List[(gene, dna_barcode), ...]

#         def iter_dinucleotides(sequence):
#             i = 0
#             while i + 1 < len(sequence):
#                 yield sequence[i:i + 2]
#                 i += 1

#         # construct codebook target mappings
#         code_array = []
#         for gene, dna_barcode in codes:
#             dna_barcode = dna_barcode[::-1]  # reverse barcode
#             spacetx_barcode = [
#                 {
#                     Axes.ROUND.value: r,
#                     Axes.CH.value: dinucleotides_to_channels[dinucleotide],
#                     Features.CODE_VALUE: 1
#                 } for r, dinucleotide in enumerate(iter_dinucleotides(dna_barcode))
#             ]
#             code_array.append({
#                 Features.CODEWORD: spacetx_barcode,
#                 Features.TARGET: gene
#             })

#         codebook = Codebook.from_code_array(code_array)
#         codebook.to_json(os.path.join(output_dir, "codebook.json"))

# class SampleDataDapiTileFetcher(TileFetcher):

#     def __init__(self, input_dir: str) -> None:
#         """
#         The red channel in the TIF files contains staining that labels all cells, 
#         which can serve a similar function to DAPI
#         """
#         self.input_dir = input_dir

#     def get_tile(self, fov: int, r: int, ch: int, z: int) -> FetchedTile:
#         basename = f"round0{r + 1}-ch03.tif"
#         file_path = os.path.join(self.input_dir, basename)
#         return StarMapTile(file_path, z)

    
def cli(input_dir, output_dir) -> None:
    """Convert a directory of raw TIFF stacks into a starfish experiment.

    Parameters
    ----------
    input_dir : str
        directory containing the ``roundRR-chCC.tif`` stacks; ``~`` is expanded
    output_dir : str
        directory that receives experiment.json and the formatted tiles; it is
        created if missing, and ``~`` is expanded

    TODO: the osmFISH example does not use auxiliary images. They could be added
    here, but their dimensions are needed first.
    """
    # Expand `~` and resolve to absolute paths so that the directory created below
    # is exactly the directory the experiment is written to.
    abs_output_dir = os.path.abspath(os.path.expanduser(output_dir))
    abs_input_dir = os.path.abspath(os.path.expanduser(input_dir))
    os.makedirs(abs_output_dir, exist_ok=True)

    primary_tile_fetcher = SampleDataTileFetcher(abs_input_dir)
    primary_image_dimensions: Mapping[Union[str, Axes], int] = {
        Axes.ROUND: ROUND_NUM,
        Axes.CH: primary_tile_fetcher.num_c,
        Axes.ZPLANE: primary_tile_fetcher.num_z,
    }

    write_experiment_json(
        path=abs_output_dir,
        fov_count=FOV_NUM,
        tile_format=ImageFormat.TIFF,
        primary_image_dimensions=primary_image_dimensions,
        primary_tile_fetcher=primary_tile_fetcher,
        aux_name_to_dimensions={},
        dimension_order=(Axes.ROUND, Axes.CH, Axes.ZPLANE)
    )

    # TODO: write the codebook once SampleDataTileFetcher.generate_codebook exists.
    # primary_tile_fetcher.generate_codebook(abs_output_dir)
    
if __name__ == "__main__":
    # Convert the sample raw stacks into a formatted starfish experiment.
    cli(input_dir='sample_data/raw2', output_dir='sample_data/formatted2')

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)


In [None]:
# The smFISH example has no cell-segmentation step;
# it only filters, detects spots, and decodes.

# Segmentation is used in the in-situ sequencing (ISS) example:
# 1) detect spots, 2) decode, 3) segment cells, 4) assign decoded spots to cells.

In [2]:
from typing import Optional, Tuple
from IPython import get_ipython

import starfish
import starfish.data
from starfish import FieldOfView, IntensityTable

ipython = get_ipython()
# Start the Qt event loop so napari viewers opened by starfish.display are responsive.
# `run_line_magic` replaces the deprecated `InteractiveShell.magic`; the None check
# keeps this cell from crashing when the code runs outside an IPython kernel.
if ipython is not None:
    ipython.run_line_magic("gui", "qt5")

In [15]:
# Load the experiment written by the conversion cell and list its fields of view.
experiment_json_path = "sample_data/formatted2/experiment.json"
experiment = starfish.Experiment.from_json(experiment_json_path)
for key in experiment.keys():
    print(key)

fov_000


In [None]:
from starfish.types import Axes
# Select the field of view listed by the previous cell.
fov = experiment["fov_000"]

In [27]:
primary_images = fov.get_image(FieldOfView.PRIMARY_IMAGES)
# Show channel 0 of round 0 (all z-planes) of the primary images in napari.
raw_image_viewer = starfish.display(primary_images.sel({Axes.CH: 0, Axes.ROUND: 0}))
raw_image_viewer

100%|██████████| 169/169 [00:12<00:00, 13.63it/s]


<napari.components._viewer.model.Viewer at 0x1956edf98>

In [4]:
# Filters used by processing_pipeline, listed in the order they are applied.

# 1) Pre-filter clip: remove low-intensity background signal.
clip1 = starfish.image.Filter.Clip(p_min=50, p_max=100)

# 2) Bandpass filter: remove cellular background and camera noise.
bandpass = starfish.image.Filter.Bandpass(
    lshort=0.5,
    llong=7,
    threshold=0.0,
)

# 3) Gaussian blur that smooths along the z-axis only.
glp = starfish.image.Filter.GaussianLowPass(sigma=(1, 0, 0), is_volume=True)

# 4) Post-filter clip: eliminate all but the highest-intensity peaks.
clip2 = starfish.image.Filter.Clip(
    p_min=99,
    p_max=100,
    is_volume=True,
)

In [5]:
# Spot detection with trackpy's local-maximum peak finder.
tlmpf = starfish.spots.DetectSpots.TrackpyLocalMaxPeakFinder(
    # spot geometry
    spot_diameter=5,  # must be an odd integer
    max_size=2,       # maximum spot radius
    separation=7,
    # intensity threshold
    min_mass=0.02,
    # preprocessing is disabled, so noise_size has no effect
    preprocess=False,
    noise_size=0.65,
    # irrelevant once min_mass, spot_diameter, and max_size are set properly
    percentile=10,
    verbose=True,
    is_volume=True,
)

In [28]:
def processing_pipeline(
    experiment: starfish.Experiment,
    fov_name: str,
    n_processes: Optional[int] = None,
) -> starfish.ImageStack:
    """Filter the primary images of a single field of view of an experiment.

    Each primary image is run, in place, through the module-level filters
    ``clip1`` -> ``bandpass`` -> ``glp`` -> ``clip2``. Spot calling and decoding
    are not performed yet (see the TODOs below).

    Parameters
    ----------
    experiment : starfish.Experiment
        starfish experiment containing fields of view to analyze
    fov_name : str
        name of the field of view to process
    n_processes : Optional[int]
        forwarded unchanged to each filter's ``run`` call

    Returns
    -------
    starfish.ImageStack :
        the filtered primary image. If the field of view yields more than one
        primary image, all of them are filtered but only the last is returned.

    Raises
    ------
    ValueError
        if the field of view has no primary images
    """
    # The same options apply to every filter, so build them once.
    filter_kwargs = dict(
        in_place=True,
        verbose=True,
        n_processes=n_processes,
    )

    print("Loading images...")

    primary_image = None
    images = experiment[fov_name].get_images(FieldOfView.PRIMARY_IMAGES)
    for image_number, primary_image in enumerate(images):
        print(f"Filtering image {image_number}...")
        # in_place=True and the return values are ignored: the filters are
        # expected to modify primary_image directly.
        clip1.run(primary_image, **filter_kwargs)
        bandpass.run(primary_image, **filter_kwargs)
        glp.run(primary_image, **filter_kwargs)
        clip2.run(primary_image, **filter_kwargs)

        # TODO: re-enable spot calling (with `all_intensities = []` before the loop):
        # print("Calling spots...")
        # spot_attributes = tlmpf.run(primary_image)
        # all_intensities.append(spot_attributes)

    # TODO: concatenate and decode once spot calling and a codebook are available:
    # codebook = experiment.codebook
    # spot_attributes = IntensityTable.concatenate_intensity_tables(all_intensities)
    # print("Decoding spots...")
    # decoded = codebook.decode_per_round_max(spot_attributes)
    # decoded = decoded[decoded["total_intensity"] > .025]

    if primary_image is None:
        raise ValueError(f"field of view {fov_name!r} has no primary images")
    return primary_image

# Alternative input: the bundled Allen smFISH test data. Note that
# processing_pipeline returns a single ImageStack, so the two-value unpacking
# below would need updating before use.
# experiment = starfish.data.allen_smFISH(use_test_data=True)
# image, intensities = processing_pipeline(experiment, fov_name='fov_001')
processed_image = processing_pipeline(experiment, 'fov_000')

  0%|          | 0/169 [00:00<?, ?it/s]

Loading images...


100%|██████████| 169/169 [00:49<00:00,  3.42it/s]
0it [00:00, ?it/s]

Filtering image 0...


169it [00:04, 33.95it/s]
169it [00:23,  7.15it/s]
1it [00:00, 15.20it/s]
1it [00:00, 16.00it/s]


In [29]:
# Visualize the filtered image in napari (spot calling is not run yet, so no spots are shown).
processed_image_viewer = starfish.display(processed_image)
processed_image_viewer

<napari.components._viewer.model.Viewer at 0x1957037f0>

In [None]:
'''
1) compare spot calls between Neurotator and Starfish
2) compare spot calls with or without filter/different filters
'''