In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import micron2

import h5py
import anndata
import tqdm.auto as tqdm

%matplotlib inline
from matplotlib import pyplot as plt

In [None]:
# List the data directory (long format, human-readable sizes, hidden files included)
# to check which files are available before opening anything.
!ls -lha /home/ingn/tmp/micron2-data

## Define neighborhoods from xy-coordinates

In [None]:
# Open the source dataset read-only. The handle is kept open because the
# following cells read from it lazily.
data = h5py.File("/home/ingn/tmp/micron2-data/dataset.hdf5", mode="r")

In [None]:
# The channel names are read back as byte strings; decode each one to `str`.
channel_names = [name.decode('utf-8') for name in data['meta/channel_names'][:]]
print(channel_names)

In [None]:
# Length of axis 1 of the DAPI image stack, used below as the per-image size.
# NOTE(review): this one value is later used for both image dimensions when the
# output datasets are created, so it presumably assumes square images — confirm.
size = data['cells/DAPI'].shape[1]

In [None]:
from sklearn.neighbors import NearestNeighbors

# Read every cell's coordinates into memory.
coords = data['meta/coordinates'][:]
print(coords.shape)

# Euclidean nearest-neighbor index (Minkowski metric with p=2) over the coordinates.
NBR = NearestNeighbors(n_neighbors=5, metric='minkowski', p=2).fit(coords)

# Calling kneighbors() without a query returns the neighbors of each fitted
# point, with the point itself excluded; only the index array is kept.
nbr = NBR.kneighbors(return_distance=False)
print(nbr.shape)

In [None]:
def stack_neighbors(query, neighbors, src_dataset):
    """
    Gather the image of a query cell and the images of its neighbors.

    Args:
        query (int): index of the query cell; its image always comes first
        neighbors (iterable of int): indices of the neighboring cells, stacked
            after the query in the order given
        src_dataset (HDF5 dataset): an open dataset, or any array-like
            where `src_dataset[1,...]` works
    Returns:
        stack (np.ndarray): the 1 + len(neighbors) images stacked along a new
            leading axis, i.e. (N, H, W) for (H, W) images, same dtype as the input
    """
    # The query image leads, followed by one image per neighbor index.
    images = [src_dataset[query, ...]]
    images.extend(src_dataset[idx] for idx in neighbors)
    return np.stack(images, axis=0)

In [None]:
# %%timeit
# ^ uncomment the line above to time this cell instead of just running it.
# Smoke test: stack one randomly chosen cell with its neighbors from the C1q channel.
i = np.random.choice(len(nbr))
s = stack_neighbors(query=i, neighbors=nbr[i], src_dataset=data['cells/C1q'])
print(s.shape)

In [None]:
# Create the output file. Mode "w" truncates any file that already exists at this path.
out_h5 = h5py.File("/home/ingn/tmp/micron2-data/setdataset.hdf5", mode="w")

In [None]:
n_cells = nbr.shape[0]

# To build the output from a random subset of cells instead of all of them,
# use the commented block below in place of the two assignments that follow it.
# sample_rate = 0.25
# n_sample = int(n_cells * sample_rate)
# indices = np.random.choice(n_cells, n_sample, replace=False)
# print(n_sample, indices.shape)

indices = np.arange(n_cells)
n_sample = n_cells

# Take the neighbor count from the neighbor table itself instead of repeating
# the literal passed to NearestNeighbors. Every stack written below holds
# 1 + nbr.shape[1] images, so the dataset shape can no longer drift out of sync
# if the number of neighbors is changed upstream.
n_neighbors = nbr.shape[1]

# One output dataset per channel, laid out as (cell, query + neighbors, size, size).
# NOTE(review): dtype is fixed to uint8; a source channel with a wider dtype
# would be narrowed on write — confirm all inputs are uint8.
datasets = {}
for c in channel_names:
    d = out_h5.create_dataset(f'cells/{c}',
                              shape=(n_sample, n_neighbors + 1, size, size),
                              maxshape=(None, n_neighbors + 1, size, size),
                              dtype='uint8',
                              chunks=(1, 1, size, size),  # one image per chunk
                              compression='gzip')
    datasets[c] = d

for c in tqdm.tqdm(channel_names):
    print(c)
    d = datasets[c]
    # Look the source dataset up once per channel rather than once per cell.
    src = data[f'cells/{c}']
    for nx, i in enumerate(indices):
        # Row nx of the output holds the stack for source cell i.
        d[nx, ...] = stack_neighbors(i, nbr[i], src)

    # Flush after each channel so completed channels are pushed to disk.
    out_h5.flush()

In [None]:
# Close the output file so it is fully written before the same path is
# re-opened for reading in the test below.
# NOTE(review): the input handle `data` is not closed anywhere in the cells shown.
out_h5.close()

## Test the dataset

In [None]:
from micron2 import stream_dataset
# Re-open the file written above through micron2's `stream_dataset` loader.
dataset = stream_dataset('/home/ingn/tmp/micron2-data/setdataset.hdf5')

# Sanity check: print the shape of the first element only, then stop iterating.
# NOTE(review): what one element contains (a single stack or a batch) depends
# on `stream_dataset` — verify against its implementation.
for k in dataset:
    print(k.shape)
    break