In [25]:
import os
from os.path import join, isfile, isdir
from urllib.request import urlretrieve
import zipfile
import shutil

In [26]:
from spatialdata import read_zarr

In [31]:
import zarr
import numpy as np
from tqdm import tqdm

In [48]:
import dask.array as da

In [27]:
data_dir = "data"
zip_filepath = join(data_dir, "xenium_rep1_io.spatialdata.zarr.zip")
spatialdata_filepath = join(data_dir, "xenium_rep1_io.spatialdata.zarr")

In [28]:
if not isdir(spatialdata_filepath):
    if not isfile(zip_filepath):
        os.makedirs(data_dir, exist_ok=True)
        urlretrieve('https://s3.embl.de/spatialdata/spatialdata-sandbox/xenium_rep1_io.zip', zip_filepath)
    with zipfile.ZipFile(zip_filepath,"r") as zip_ref:
        zip_ref.extractall(data_dir)
        os.rename(join(data_dir, "data.zarr"), spatialdata_filepath)
        
        # This Xenium dataset has an AnnData "raw" element.
        # Reference: https://github.com/giovp/spatialdata-sandbox/issues/55
        raw_dir = join(spatialdata_filepath, "tables", "table", "raw")
        if isdir(raw_dir):
            shutil.rmtree(raw_dir)


In [29]:
from importlib.metadata import version
version('spatialdata')

'0.5.0'

In [30]:
sdata = read_zarr(spatialdata_filepath)
sdata

version mismatch: detected: RasterFormatV02, requested: FormatV04
  compressor, fill_value = _kwargs_compat(compressor, fill_value, kwargs)
version mismatch: detected: RasterFormatV02, requested: FormatV04


SpatialData object, with associated Zarr store: /Users/mkeller/research/dbmi/vitessce/point-cloud-exploration/data/xenium_rep1_io.spatialdata.zarr
├── Images
│     ├── 'morphology_focus': DataTree[cyx] (1, 25778, 35416), (1, 12889, 17708), (1, 6444, 8854), (1, 3222, 4427), (1, 1611, 2213)
│     └── 'morphology_mip': DataTree[cyx] (1, 25778, 35416), (1, 12889, 17708), (1, 6444, 8854), (1, 3222, 4427), (1, 1611, 2213)
├── Points
│     └── 'transcripts': DataFrame with shape: (<Delayed>, 8) (3D points)
├── Shapes
│     ├── 'cell_boundaries': GeoDataFrame shape: (167780, 1) (2D shapes)
│     └── 'cell_circles': GeoDataFrame shape: (167780, 2) (2D shapes)
└── Tables
      └── 'table': AnnData (167780, 313)
with coordinate systems:
    ▸ 'global', with elements:
        morphology_focus (Images), morphology_mip (Images), transcripts (Points), cell_boundaries (Shapes), cell_circles (Shapes)

In [32]:
out_zarr_path = join(data_dir, "xenium_points.zarr")

In [33]:
out_store = zarr.open(out_zarr_path, mode="a")

In [34]:
ddf = sdata.points['transcripts']

In [35]:
ddf.head()

Unnamed: 0,x,y,z,feature_name,cell_id,overlaps_nucleus,transcript_id,qv
0,4.395842,328.666473,12.019493,SEC11C,565,0,281474976710656,18.662479
1,5.074415,236.964844,7.60851,NegControlCodeword_0502,540,0,281474976710657,18.634956
2,4.702023,322.79715,12.289083,SEC11C,562,0,281474976710658,18.662479
3,4.906601,581.42865,11.222615,DAPK3,271,0,281474976710659,20.821745
4,5.660699,720.851746,9.265523,TCIM,291,0,281474976710660,18.017488


In [36]:
# Convert feature_name to feature_index
table_name = 'table'
var_name_col = 'feature_name'
var_df = sdata.tables[table_name].var
var_index = var_df.index.values.tolist()

def try_index(gene_name):
    try:
        return var_index.index(gene_name)
    except BaseException:
        return -1
ddf['c'] = ddf[var_name_col].apply(try_index).astype('int32')

You did not provide metadata, so Dask is running your function on a small dataset to guess output types. It is possible that Dask will guess incorrectly.
To provide an explicit output types or to silence this message, please provide the `meta=` keyword, as described in the map or apply function that you are using.
  Before: .apply(func)
  After:  .apply(func, meta=('feature_name', 'category'))



In [37]:
x_min = ddf['x'].min().compute()
x_max = ddf['x'].max().compute()

y_min = ddf['y'].min().compute()
y_max = ddf['y'].max().compute()

z_min = ddf['z'].min().compute()
z_max = ddf['z'].max().compute()

In [38]:
scale_factor = 4.0
pixel_grid_shape = (np.ceil(x_max - x_min) * scale_factor, np.ceil(y_max - y_min) * scale_factor)
#pixel_grid_shape = (np.ceil(x_max - x_min) * scale_factor, np.ceil(y_max - y_min) * scale_factor, np.ceil(z_max - z_min) * scale_factor)
pixel_grid_shape

(np.float32(30100.0), np.float32(21880.0))

In [39]:
z = zarr.zeros(pixel_grid_shape)
# Let zarr automatically determine the chunk shape
# TODO: directly call zarr.normalize_chunks()
chunk_shape = z.chunks
chunk_shape

  shape = normalize_shape(shape) + dtype.shape


(471, 684)

In [40]:
pixel_grid_shape[0] % chunk_shape[0]

np.float32(427.0)

In [41]:
nice_num_chunks = (int(np.ceil(pixel_grid_shape[0] / chunk_shape[0])), int(np.ceil(pixel_grid_shape[1] / chunk_shape[1])))
#nice_num_chunks = (int(np.ceil(pixel_grid_shape[0] / chunk_shape[0])), int(np.ceil(pixel_grid_shape[1] / chunk_shape[1])), int(np.ceil(pixel_grid_shape[2] / chunk_shape[2])))
nice_num_chunks

(64, 32)

In [42]:
nice_pixel_grid_shape = (nice_num_chunks[0] * chunk_shape[0], nice_num_chunks[1] * chunk_shape[1])
#nice_pixel_grid_shape = (nice_num_chunks[0] * chunk_shape[0], nice_num_chunks[1] * chunk_shape[1], nice_num_chunks[2] * chunk_shape[2])
nice_pixel_grid_shape

(30144, 21888)

In [43]:
z_x = zarr.zeros(nice_pixel_grid_shape, chunks=chunk_shape)
z_y = zarr.zeros(nice_pixel_grid_shape, chunks=chunk_shape)
#z_z = zarr.zeros(nice_pixel_grid_shape, chunks=chunk_shape)
z_c = zarr.zeros(nice_pixel_grid_shape, chunks=chunk_shape)

In [47]:
# Save the zeros arrays to disk.
out_store['/x'] = z_x
out_store['/y'] = z_y
#out_store['/z'] = z_z
out_store['/c'] = z_c

In [44]:
ddf['x'] = ((ddf['x'] - x_min) / (x_max - x_min)) * pixel_grid_shape[0]
ddf['y'] = ((ddf['y'] - y_min) / (y_max - y_min)) * pixel_grid_shape[1]
#ddf['z'] = ((ddf['z'] - y_min) / (z_max - z_min)) * pixel_grid_shape[2]

In [45]:
ddf.head()

Unnamed: 0,x,y,z,feature_name,cell_id,overlaps_nucleus,transcript_id,qv,c
0,25.078289,1297.217773,12.019493,SEC11C,565,0,281474976710656,18.662479,253
1,27.792721,930.350403,7.60851,NegControlCodeword_0502,540,0,281474976710657,18.634956,-1
2,26.303076,1273.736572,12.289083,SEC11C,562,0,281474976710658,18.662479,253
3,27.121431,2308.434082,11.222615,DAPK3,271,0,281474976710659,20.821745,93
4,30.137976,2866.21875,9.265523,TCIM,291,0,281474976710660,18.017488,284


In [46]:
def pad_with_zeros(chunk_vals, chunk_len, val_shape, val_len):
    out_vals = np.zeros((val_len,))
    out_vals[0:chunk_len] = chunk_vals
    return out_vals.reshape(val_shape) # TODO: more sophisticated rounding

In [49]:
df = ddf.compute()

In [77]:
def create_fill_func(dim_name):
    def fill_chunk_for_dim(arr_in, block_info=None):
        if block_info is not None:
            #print(block_info[0])
            x_offset = block_info[0]["array-location"][0][0]
            y_offset = block_info[0]["array-location"][1][0]
            x_chunk_max = block_info[0]["array-location"][0][1]
            y_chunk_max = block_info[0]["array-location"][1][1]
    
            chunk_df = df.loc[(df['x'] >= x_offset) & (df['x'] < x_chunk_max) & (df['y'] >= y_offset) & (df['y'] < y_chunk_max)]
            
            chunk_x_vals = chunk_df['x'].values
            chunk_y_vals = chunk_df['y'].values
            #chunk_z_vals = chunk_df['z'].values
            chunk_c_vals = chunk_df['c'].values
    
            chunk_len = chunk_df.shape[0]
    
            # All x/y values here should be less than the chunk shape width/height.
            val_shape = (x_chunk_max - x_offset, y_chunk_max - y_offset)
            val_len = val_shape[0] * val_shape[1]
    
            if chunk_len > val_len:
                raise ValueError("values do not fit in chunk. try increasing scale_factor")
            
            if chunk_len > 0:
                x_out = pad_with_zeros(chunk_x_vals, chunk_len, val_shape, val_len) # TODO: convert to output dtype using .astype
                y_out = pad_with_zeros(chunk_y_vals, chunk_len, val_shape, val_len)
                #z_out = pad_with_zeros(chunk_z_vals, chunk_len, val_shape, val_len)
                c_out = pad_with_zeros(chunk_c_vals, chunk_len, val_shape, val_len)

                if dim_name == "x":
                    return x_out
                elif dim_name == "y":
                    return y_out
                elif dim_name == "z":
                    raise ValueError("notyetimplemented: z_out")
                    #return z_out
                elif dim_name == "c":
                    return c_out
                raise ValueError("unknown dim_name")
        return arr_in
    return fill_chunk_for_dim

In [78]:
da_x = da.from_zarr(url=out_zarr_path, component="/x")
da_y = da.from_zarr(url=out_zarr_path, component="/y")
#da_z = da.from_zarr(url=out_zarr_path, component="/z")
da_c = da.from_zarr(url=out_zarr_path, component="/c")

  compressor, fill_value = _kwargs_compat(compressor, fill_value, kwargs)


In [87]:
out_da_x = da_x.map_blocks(create_fill_func("x")) # TODO: specify output dtype
out_da_y = da_y.map_blocks(create_fill_func("y"))
#out_da_z = da_z.map_blocks(create_fill_func("z"))
out_da_c = da_c.map_blocks(create_fill_func("c"))

In [88]:
out_da_x.to_zarr(url=out_zarr_path, component="/x", overwrite=True, compute=True)
out_da_y.to_zarr(url=out_zarr_path, component="/y", overwrite=True, compute=True)
out_da_c.to_zarr(url=out_zarr_path, component="/c", overwrite=True, compute=True)

In [89]:
# Check

In [None]:
# Save in OME-Zarr
# Define scale/translate transformations that reflect the normalization operations performed.