## One Model

rosshandler: https://pubmed.ncbi.nlm.nih.gov/37982461/ \
nowotschin: https://pubmed.ncbi.nlm.nih.gov/30959515/ \
see: https://scanpy.readthedocs.io/en/stable/tutorials/experimental/dask.html \
see: https://rapids-singlecell.readthedocs.io/en/latest/notebooks/demo_gpu-seuratv3-brain-1M.html

In [1]:
%%time
# importing our library 
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
#from scvi.models.utils import mde
import tempfile
import os
import scipy
from scipy.io import mmread
from scipy.io import mmwrite
import anndata as ad
import warnings 
import anndata as ad

import dask
import dask.array as da
#import dask.distributed as dd
import graphviz

from dask import delayed
import h5py


CPU times: user 2.41 s, sys: 636 ms, total: 3.05 s
Wall time: 29.3 s


## Dask Setup

In [2]:
## setting graphics
dask.config.set({"visualization.engine": "graphviz"});
## Problem with memory spilling just solved itself
os.environ['MALLOC_TRIM_THRESHOLD_'] = '0'

In [3]:
#cluster = dd.LocalCluster(n_workers=24, memory_limit = 8*10**9)
#client = dd.Client(cluster)

In [4]:
#client

## Data loading & setting up anndata objects

In [5]:
chunk_size = [5000, -1]

In [6]:
%%time
## reading in the raw counts via raw/X
with h5py.File("/data/hadjantalab/atlas/extAtlas/embryo_complete.h5ad", "r") as f:
    adata_rosshandler = ad.AnnData(
        obs=ad.experimental.read_elem(f["obs"]),
        var=ad.experimental.read_elem(f["var"]),
        X = ad.experimental.read_elem(f["raw/X"])    
    )
adata_rosshandler.X = adata_rosshandler.X.astype(np.float32)
adata_rosshandler.X = da.from_array(adata_rosshandler.X)
adata_rosshandler.X = da.rechunk(adata_rosshandler.X, chunk_size)

CPU times: user 1min 7s, sys: 4.24 s, total: 1min 11s
Wall time: 1min 13s


In [7]:
## adata_rosshandler object
print(adata_rosshandler)
adata_rosshandler.X

AnnData object with n_obs × n_vars = 430339 × 27669
    obs: 'cell', 'sample', 'embryo_version', 'stage', 'somite_count', 'anatomy', 'S_score', 'G2M_score', 'phase', 'louvain', 'leiden', 'celltype_PijuanSala2019', 'celltype_extended_atlas'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mgi_symbol'


Unnamed: 0,Array,Chunk
Shape,"(430339, 27669)","(5000, 27669)"
Dask graph,87 chunks in 3 graph layers,87 chunks in 3 graph layers
Data type,float32 scipy.sparse._csr.csr_matrix,float32 scipy.sparse._csr.csr_matrix
"Array Chunk Shape (430339, 27669) (5000, 27669) Dask graph 87 chunks in 3 graph layers Data type float32 scipy.sparse._csr.csr_matrix",27669  430339,

Unnamed: 0,Array,Chunk
Shape,"(430339, 27669)","(5000, 27669)"
Dask graph,87 chunks in 3 graph layers,87 chunks in 3 graph layers
Data type,float32 scipy.sparse._csr.csr_matrix,float32 scipy.sparse._csr.csr_matrix


In [8]:
# modulating some metadata
adata_rosshandler.var['mgi_symbol'] = adata_rosshandler.var['mgi_symbol'].astype('object')
adata_rosshandler.obs['batch'] = 'pijuan-sala'
adata_rosshandler.obs['batch'][adata_rosshandler.obs['embryo_version'] == 'Extension'] = 'rosshandler'
adata_rosshandler.var['id'] = adata_rosshandler.var_names
adata_rosshandler.var.index = adata_rosshandler.var['mgi_symbol']

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  adata_rosshandler.obs['batch'][adata_rosshandler.obs['embryo_version'] == 'Extension'] = 'rosshandler'


In [9]:
%%time
with h5py.File("/data/hadjantalab/atlas/endoderm/sc_endoderm_all_cells.h5ad", "r") as f:
    adata_nowotschin = ad.AnnData(
        obs=ad.experimental.read_elem(f["obs"]),
        var=ad.experimental.read_elem(f["var"]),
        X = ad.experimental.read_elem(f["X"])    
    )

## reading in the raw counts via raw/X
adata_nowotschin.X = da.from_array(adata_nowotschin.X)
adata_nowotschin.X = da.rechunk(adata_nowotschin.X, chunk_size)

CPU times: user 3.49 s, sys: 1.15 s, total: 4.64 s
Wall time: 7.2 s


In [10]:
# adata_nowotschin object
print(adata_nowotschin)
adata_nowotschin.X

AnnData object with n_obs × n_vars = 113051 × 20897
    obs: 'Cluster', 'Timepoint', 'CellType'


Unnamed: 0,Array,Chunk
Shape,"(113051, 20897)","(5000, 20897)"
Dask graph,23 chunks in 3 graph layers,23 chunks in 3 graph layers
Data type,float32 scipy.sparse._csr.csr_matrix,float32 scipy.sparse._csr.csr_matrix
"Array Chunk Shape (113051, 20897) (5000, 20897) Dask graph 23 chunks in 3 graph layers Data type float32 scipy.sparse._csr.csr_matrix",20897  113051,

Unnamed: 0,Array,Chunk
Shape,"(113051, 20897)","(5000, 20897)"
Dask graph,23 chunks in 3 graph layers,23 chunks in 3 graph layers
Data type,float32 scipy.sparse._csr.csr_matrix,float32 scipy.sparse._csr.csr_matrix


In [11]:
# modulating some metadata 
adata_nowotschin.obs['cell'] = adata_nowotschin.obs_names
adata_nowotschin.obs.index = adata_nowotschin.obs['cell']
adata_nowotschin.obs['celltype_nowotschin'] = adata_nowotschin.obs['CellType']
del adata_nowotschin.obs['CellType']
adata_nowotschin.obs['timepoint_nowotschin'] = adata_nowotschin.obs['Timepoint']
del adata_nowotschin.obs['Timepoint']
adata_nowotschin.obs['cluster_nowotschin'] = adata_nowotschin.obs['Cluster']
del adata_nowotschin.obs['Cluster']
adata_nowotschin.obs['batch'] = 'nowotschin'
adata_nowotschin.var['mgi_symbol'] = adata_nowotschin.var_names

## Concatenating with dask

In [12]:
## var.index unique? otherwise concat will not work
print('Nowotschin: ', len(adata_nowotschin.var.index) == len(set(adata_nowotschin.var.index)))
print('Rosshandler: ', len(adata_rosshandler.var.index) == len(set(adata_rosshandler.var.index)))


Nowotschin:  True
Rosshandler:  False


In [13]:
adata_rosshandler.var_names_make_unique()

In [14]:
%%time
import dask
from dask import delayed
import anndata as ad
import math

# Configure Dask to use multiple processors
dask.config.set(scheduler='threads', num_workers=10, memory_limit='230GB')
dask.config.set({'array.slicing.split_large_chunks': True})

# Define a function to perform delayed concatenation of all chunks at once
@delayed
def concatenate_all_chunks(chunks):
    return ad.concat(chunks, axis=0, join='outer', fill_value=0)

# Step 1: Determine appropriate chunk sizes
chunk_size_ross = int(math.ceil(adata_rosshandler.shape[0] / 5 / 1e4) * 1e4)
chunk_size_nowo = int(math.ceil(adata_nowotschin.shape[0] / 1e4) * 1e4)

# Step 2: Create chunks for both datasets
chunks_ross = [adata_rosshandler[i:i + chunk_size_ross] for i in range(0, len(adata_rosshandler), chunk_size_ross)]
chunks_nowo = [adata_nowotschin[i:i + chunk_size_nowo] for i in range(0, len(adata_nowotschin), chunk_size_nowo)]

# Step 3: Concatenate chunks from both datasets
# Concatenate both chunk lists at once (no need to pair them manually)
all_chunks = chunks_ross + chunks_nowo

# Step 4: Delayed final concatenation of all chunks
final_concatenated_data = concatenate_all_chunks(all_chunks)

# Compute the final concatenated anndata object
adata = dask.compute(final_concatenated_data)[0]

# Shape verification
expected_shape = (
    adata_rosshandler.shape[0] + adata_nowotschin.shape[0],
    len(set(adata_rosshandler.var_names).union(set(adata_nowotschin.var_names))),
)
print("Expected shape:", expected_shape)
print("Observed shape:", adata.shape)


Expected shape: (543390, 28652)
Observed shape: (543390, 28652)
CPU times: user 411 ms, sys: 28.7 ms, total: 439 ms
Wall time: 2.3 s


In [None]:
final_concatenated_chunks.visualize()

In [None]:
final_concatenated_chunks.dask

In [18]:
adata

AnnData object with n_obs × n_vars = 543390 × 28652
    obs: 'cell', 'sample', 'embryo_version', 'stage', 'somite_count', 'anatomy', 'S_score', 'G2M_score', 'phase', 'louvain', 'leiden', 'celltype_PijuanSala2019', 'celltype_extended_atlas', 'batch', 'celltype_nowotschin', 'timepoint_nowotschin', 'cluster_nowotschin'

In [25]:
adata.X = adata.X.rechunk(chunk_size)
adata.X

Unnamed: 0,Array,Chunk
Shape,"(543390, 28652)","(5000, 28652)"
Dask graph,109 chunks in 25 graph layers,109 chunks in 25 graph layers
Data type,float32 scipy.sparse._csr.csr_matrix,float32 scipy.sparse._csr.csr_matrix
"Array Chunk Shape (543390, 28652) (5000, 28652) Dask graph 109 chunks in 25 graph layers Data type float32 scipy.sparse._csr.csr_matrix",28652  543390,

Unnamed: 0,Array,Chunk
Shape,"(543390, 28652)","(5000, 28652)"
Dask graph,109 chunks in 25 graph layers,109 chunks in 25 graph layers
Data type,float32 scipy.sparse._csr.csr_matrix,float32 scipy.sparse._csr.csr_matrix


In [26]:
adata.layers['dask'] = adata.X.copy()

In [27]:
adata.obs.index = adata.obs['cell']
del adata.obs['cell']

KeyError: 'cell'

## Back to csr matrix

In [28]:
%%time
@dask.delayed
def get_chunk(c): 
    return c

chunks = adata.X.to_delayed().ravel()
results = [da.from_delayed(get_chunk(c), shape = chunk_size, dtype = np.float32) for c in chunks]
arr = da.concatenate(results, axis = 0, allow_unknown_chunksizes = True)

CPU times: user 54.3 ms, sys: 4.78 ms, total: 59.1 ms
Wall time: 58.3 ms


In [29]:
arr

Unnamed: 0,Array,Chunk
Bytes,-2180000 B,-20000 B
Shape,"(545000, -1)","(5000, -1)"
Dask graph,109 chunks in 220 graph layers,109 chunks in 220 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes -2180000 B -20000 B Shape (545000, -1) (5000, -1) Dask graph 109 chunks in 220 graph layers Data type float32 numpy.ndarray",-1  545000,

Unnamed: 0,Array,Chunk
Bytes,-2180000 B,-20000 B
Shape,"(545000, -1)","(5000, -1)"
Dask graph,109 chunks in 220 graph layers,109 chunks in 220 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [30]:
#arr.visualize()

In [31]:
%%time
adata_x = arr.compute()

  self._set_arrayXarray(i, j, x)


CPU times: user 26min 29s, sys: 2min 3s, total: 28min 32s
Wall time: 3min 49s


In [32]:
adata_x

<543390x28652 sparse matrix of type '<class 'numpy.float32'>'
	with 3350555488 stored elements in Compressed Sparse Row format>

In [33]:
adata.X = adata_x

## Saving data to disk

In [34]:
adata.obs.to_csv("/data/hadjantalab/lucas/atlas/data/dask/adata_obs.csv", index = 'cell')
adata.var.to_csv("/data/hadjantalab/lucas/atlas/data/dask/adata_var.csv")

In [35]:
%%time
mmwrite("/data/hadjantalab/lucas/atlas/data/dask/adata_X.mtx", a = adata_x)


CPU times: user 36min 58s, sys: 20.7 s, total: 37min 18s
Wall time: 4min 22s


In [36]:
def save_sparse_csr(filename, array):
    np.savez(filename, data=array.data, indices=array.indices,
             indptr=array.indptr, shape=array.shape)

In [37]:
%%time
save_sparse_csr("/data/hadjantalab/lucas/atlas/data/dask/adata_X.npz", adata_x)

CPU times: user 33.2 s, sys: 7.32 s, total: 40.5 s
Wall time: 41 s
