This is an experiment to see whether we can make training go faster
without loading the whole dataset into RAM.

The question is, if we pick the channels that we care about and store the nuclei as a single dataset like:

```
hdf5:
    /images/cells  # shape N x H x W x C
```

Does that fix the loading bottleneck?


In [None]:
# Stretch the notebook's `.container` element to the full browser width so
# wide outputs are not squeezed into the default centered column.
# `display` and `HTML` come from the public `IPython.display` module: importing
# `display` from `IPython.core.display` has been deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
%load_ext autoreload
%autoreload 2
import h5py
import numpy as np
from micron2 import hdf5_info

import tqdm.auto as tqdm

In [None]:
# Source HDF5 file; later cells read its `meta/*` and `cells/<channel>` entries.
h5in = '/storage/codex/datasets_v1/bladder_merged_v4.hdf5'
# Inspect the source file with micron2's helper
# (presumably prints the file layout -- confirm against micron2).
hdf5_info(h5in)
# Manual alternative for listing the top-level keys:
# with h5py.File(h5in, "r") as f:
#     print(f.keys())

In [None]:
# Collect what the repacking step needs from the source file: the decoded
# channel names, the number of cells (length of `meta/Cell_IDs`) and the image
# size, taken from axis 1 of the DAPI stack.
with h5py.File(h5in, 'r') as f:
    channels = [name.decode('utf-8') for name in f['meta/channel_names'][:]]
    n_cells = f['meta/Cell_IDs'].shape[0]
    shape = f['cells/DAPI'].shape[1]
# One value per line, same output as three separate print calls.
print(channels, n_cells, shape, sep='\n')

In [None]:
# Destination HDF5 file for the repacked data; it is opened in "w" mode by the
# conversion cell, so an existing file at this path is overwritten.
h5out = '/home/ingn/tmp/micron2-data/bladder_merged_v4.cells.hdf5'
# Alternative output paths kept for reference (presumably for other source
# datasets -- confirm before re-enabling):
# h5out = '/home/ingn/tmp/micron2-data/210113_Breast_Cassette11_reg1_nosubtract.mergedCells.hdf5'
# h5out = '/home/ingn/tmp/micron2-data/210122_Breast_Cassette7_reg2.hdf5.cells.hdf5'

In [None]:
# Repack the per-channel `cells/<channel>` stacks of the source file into a
# single `/images/cells` dataset of shape (n_cells, shape, shape, n_channels),
# then carry the `meta` group over to the new file.
n_channels = len(channels)

# Chunk cache budget: shape*shape*n_channels bytes is one uint8 cell, so this
# leaves room for the uncompressed chunks of 200000 cells.
with h5py.File(h5out, "w", rdcc_nbytes=shape*shape*n_channels*200000) as fout, \
        h5py.File(h5in, "r") as fin:
    # NOTE(review): each chunk is one channel plane of one cell, so reading a
    # whole cell touches n_channels chunks. If whole-cell reads are the goal,
    # consider chunks=(1, shape, shape, n_channels) -- confirm with a benchmark.
    d = fout.create_dataset('/images/cells', shape=(n_cells, shape, shape, n_channels),
                            chunks=(1, shape, shape, 1),
                            dtype='uint8',
                            compression='gzip')
    print('finished creating dataset')

    # Look up each channel's source dataset once instead of once per cell.
    sources = [fin[f'cells/{ch}'] for ch in channels]

    # A single staging buffer is reused for every cell: all channel planes are
    # overwritten on each iteration, so no stale data can be written.
    # It keeps numpy's default float64; values are converted to the dataset's
    # uint8 dtype when written.
    img = np.zeros((1, shape, shape, n_channels))

    with tqdm.trange(n_cells) as pbar:
        for j in pbar:
            for i, src in enumerate(sources):
                # src[j] is the (shape, shape) image of cell j in channel i.
                img[0, :, :, i] = src[j]
            d[j:j+1, :, :, :] = img

    # Copy the whole `meta` group in one call. Unlike reading every entry with
    # `[:]` and re-creating it, this also handles scalar datasets and nested
    # groups, and it preserves dtypes and attributes.
    fin.copy('meta', fout, name='meta')


In [None]:
# Inspect the newly written file with the same helper used on the source file.
hdf5_info(h5out)