In [62]:
import zarr
import numpy as np
import spatialdata as sd
from tqdm import tqdm

In [2]:
# Download merfish spatialdata example from https://spatialdata.scverse.org/en/latest/tutorials/notebooks/notebooks/examples/technology_merfish.html

In [3]:
merfish_zarr_path = "./data/data.zarr"

In [65]:
out_zarr_path = "./data/out.zarr"

In [4]:
merfish_sdata = sd.read_zarr(merfish_zarr_path)
merfish_sdata

SpatialData object with:
├── Images
│     └── 'rasterized': SpatialImage[cyx] (1, 522, 575)
├── Points
│     └── 'single_molecule': DataFrame with shape: (3714642, 3) (2D points)
├── Shapes
│     ├── 'anatomical': GeoDataFrame shape: (6, 1) (2D shapes)
│     └── 'cells': GeoDataFrame shape: (2399, 2) (2D shapes)
└── Table
      └── AnnData object with n_obs × n_vars = 2399 × 268
    obs: 'cell_id', 'region'
    uns: 'spatialdata_attrs': AnnData (2399, 268)
with coordinate systems:
▸ 'global', with elements:
        rasterized (Images), single_molecule (Points), anatomical (Shapes), cells (Shapes)

In [5]:
merfish_sdata.table.uns['spatialdata_attrs']

{'instance_key': 'cell_id', 'region': 'cells', 'region_key': 'region'}

In [6]:
merfish_sdata.table.obs

Unnamed: 0,cell_id,region
0,0,cells
1,1,cells
2,2,cells
3,3,cells
4,4,cells
...,...,...
2394,2394,cells
2395,2395,cells
2396,2396,cells
2397,2397,cells


In [89]:
out_store = zarr.open(out_zarr_path, mode="a")

In [7]:
merfish_sdata.table.obs['region']

0       cells
1       cells
2       cells
3       cells
4       cells
        ...  
2394    cells
2395    cells
2396    cells
2397    cells
2398    cells
Name: region, Length: 2399, dtype: category
Categories (1, object): ['cells']

In [8]:
df = merfish_sdata.points['single_molecule']

In [9]:
x_min = df['x'].min().compute()
x_max = df['x'].max().compute()

y_min = df['y'].min().compute()
y_max = df['y'].max().compute()

In [10]:
c_cats = df['cell_type'].unique().compute().tolist()
c_cat_to_index = dict([(c_cat, c_cats.index(c_cat)) for c_cat in c_cats])
c_cat_to_index

{'outside_VISp': 0,
 'VISp_wm': 1,
 'VISp_VI': 2,
 'VISp_V': 3,
 'VISp': 4,
 'VISp_IV': 5,
 'VISp_II/III': 6,
 'VISp_I': 7}

In [11]:
x_min, x_max, y_min, y_max

(1154.3634, 3171.979, 4548.483, 6565.997)

In [77]:
scale_factor = 4

pixel_grid_shape = (np.ceil(x_max - x_min) * scale_factor, np.ceil(y_max - y_min) * scale_factor)
pixel_grid_shape

(8072.0, 8072.0)

In [78]:
z = zarr.zeros(pixel_grid_shape)
# Let zarr automatically determine the chunk shape
# TODO: directly call zarr.normalize_chunks()
chunk_shape = z.chunks

In [79]:
pixel_grid_shape[0] % chunk_shape[0]

497.0

In [80]:
nice_num_chunks = (int(np.ceil(pixel_grid_shape[0] / chunk_shape[0])), int(np.ceil(pixel_grid_shape[1] / chunk_shape[1])))

In [81]:
nice_pixel_grid_shape = (nice_num_chunks[0] * chunk_shape[0], nice_num_chunks[1] * chunk_shape[1])

In [82]:
nice_pixel_grid_shape

(8080, 8080)

In [83]:
z_x = zarr.zeros(nice_pixel_grid_shape, chunks=chunk_shape)
z_y = zarr.zeros(nice_pixel_grid_shape, chunks=chunk_shape)
z_c = zarr.zeros(nice_pixel_grid_shape, chunks=chunk_shape)

In [84]:
df['x'] = ((df['x'] - x_min) / (x_max - x_min)) * pixel_grid_shape[0]
df['y'] = ((df['y'] - y_min) / (y_max - y_min)) * pixel_grid_shape[1]

In [85]:
df.head()

Unnamed: 0,x,y,cell_type,c
0,-2628.446533,5367.57959,outside_VISp,0
1,-1855.569946,5419.486328,outside_VISp,0
2,-2414.878662,5986.996094,outside_VISp,0
3,-2273.074707,6602.478027,outside_VISp,0
4,-3369.111572,6684.087402,outside_VISp,0


In [86]:
unique_c = df['cell_type'].unique().compute()
c_str_to_i = dict([(c_str, c_i) for c_i, c_str in enumerate(unique_c.values) ])
c_str_to_i

{'outside_VISp': 0,
 'VISp_wm': 1,
 'VISp_VI': 2,
 'VISp_V': 3,
 'VISp': 4,
 'VISp_IV': 5,
 'VISp_II/III': 6,
 'VISp_I': 7}

In [87]:
df['c'] = df['cell_type'].replace(c_str_to_i)

In [88]:
def pad_with_zeros(chunk_vals, chunk_len, val_shape, val_len):
    out_vals = np.zeros((val_len,))
    out_vals[0:chunk_len] = chunk_vals
    return out_vals.reshape(val_shape).astype(int) # TODO: more sophisticated rounding

In [91]:
# TODO: move this code after the for loop
out_store['/x'] = z_x
out_store['/y'] = z_y
out_store['/c'] = z_c

In [90]:
# TODO: use dask to do this in parallel?
for x_chunk_i in tqdm(range(nice_num_chunks[0])):
    for y_chunk_i in range(nice_num_chunks[1]):
        x_offset = x_chunk_i * chunk_shape[0]
        y_offset = y_chunk_i * chunk_shape[1]
        x_chunk_max = (x_chunk_i + 1) * chunk_shape[0]
        y_chunk_max = (y_chunk_i + 1) * chunk_shape[0]

        chunk_df = df.loc[(df['x'] >= x_offset) & (df['x'] < x_chunk_max) & (df['y'] >= y_offset) & (df['y'] < y_chunk_max)]
        
        chunk_x_vals = chunk_df['x'].sub(x_offset).values.compute()
        chunk_y_vals = chunk_df['y'].sub(y_offset).values.compute()
        chunk_c_vals = chunk_df['c'].values.compute()

        chunk_len = chunk_df.shape[0].compute()

        # All x/y values here should be less than the chunk shape width/height.

        val_shape = (chunk_shape[0], np.ceil(chunk_len / chunk_shape[0]).astype(int))
        val_len = val_shape[0] * val_shape[1]

        if chunk_len > val_len:
            raise ValueError("values do not fit in chunk")
        
        if chunk_len > 0:
            z_x[x_offset:x_offset+val_shape[0], y_offset:y_offset+val_shape[1]] = pad_with_zeros(chunk_x_vals, chunk_len, val_shape, val_len)
            z_y[x_offset:x_offset+val_shape[0], y_offset:y_offset+val_shape[1]] = pad_with_zeros(chunk_y_vals, chunk_len, val_shape, val_len)
            z_c[x_offset:x_offset+val_shape[0], y_offset:y_offset+val_shape[1]] = pad_with_zeros(chunk_c_vals, chunk_len, val_shape, val_len)

100%|███████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:54<00:00,  7.13s/it]


In [48]:
z_c

<zarr.core.Array (8072, 8072) float64>