In [1]:
import scanpy as sc
import numpy as np
import pandas as pd

from data_utils import *
from tqdm.notebook import tqdm

import scipy.sparse as sp

import gc

from joblib import Parallel, delayed



In [2]:
# Load the simulated dataset and store its count matrix as CSR so that the
# downsampling helpers below can slice it cheaply by rows (cells).
# NOTE(review): absolute, user-specific path — consider a configurable DATA_DIR.
INPUT_H5AD = "/home/gokul/Splatter_libloc_750_11.4.h5ad"
ad = sc.read_h5ad(INPUT_H5AD)
counts_matrix = ad.X
ad.X = sp.csr_matrix(counts_matrix)

In [3]:
# Mean total count per cell in the freshly loaded (full-depth) data,
# i.e. before any downsampling has been applied.
total_counts = ad.X.sum()
full_umi = total_counts / ad.n_obs

In [4]:
def _process_batch(X_batch, quality):
    """
    Downsample every row (cell) of a sparse count matrix independently.

    The stored values of each non-empty row are cast to int32, passed to
    ``downsample_array(counts, quality)`` (provided by ``data_utils``), and the
    returned values are written back in place of the originals. The input
    matrix is never modified.

    Parameters
    ----------
    X_batch : scipy sparse matrix (any format)
        Rows are cells; stored values are treated as counts.
    quality : float
        Forwarded unchanged to ``downsample_array``.

    Returns
    -------
    scipy.sparse.csr_matrix
        Downsampled copy of ``X_batch``. Entries that were downsampled to 0 are
        removed from the sparsity structure.
    """
    # tocsr(copy=True) guarantees we own the buffers, whatever the input format.
    X_out = X_batch.tocsr(copy=True)
    indptr, data = X_out.indptr, X_out.data
    for i in range(X_out.shape[0]):
        start, end = indptr[i], indptr[i + 1]
        if start == end:  # empty row: nothing to downsample
            continue
        counts = data[start:end].astype(np.int32)
        # Writing into the slice casts back to the matrix dtype, as the old
        # LIL -> CSR conversion did.
        data[start:end] = downsample_array(counts, quality)
    # Downsampling turns many stored counts into 0. Drop them so nnz (and the
    # size of each written .h5ad) shrinks along with the counts instead of
    # carrying explicit zeros through every round.
    X_out.eliminate_zeros()
    return X_out

def batched_get_ad_with_quality(adata, quality, batch_size=1000, n_jobs=-1, verbose=True):
    """
    Downsample the counts of every cell in ``adata`` using parallel row batches.

    The count matrix is split into consecutive batches of ``batch_size`` rows.
    Each batch is processed by ``_process_batch`` in a joblib worker, and the
    results are stacked back together in the original row order.

    Parameters
    ----------
    adata : AnnData
        Input object. It is not modified; a copy is returned.
    quality : float
        Forwarded unchanged to ``_process_batch``.
    batch_size : int, default 1000
        Number of rows (cells) per parallel task; must be positive.
    n_jobs : int, default -1
        Passed to ``joblib.Parallel`` (-1 uses all available cores).
    verbose : bool, default True
        If True, print the mean total count per cell after downsampling.

    Returns
    -------
    AnnData
        Copy of ``adata`` whose ``.X`` is the downsampled CSR matrix.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    adata = adata.copy()
    # sp.csr_matrix accepts dense arrays as well as any sparse format,
    # whereas `.tocsr()` only exists on sparse matrices.
    X = sp.csr_matrix(adata.X)
    n_cells = X.shape[0]
    if n_cells == 0:
        # sp.vstack([]) would raise; with no cells there is nothing to downsample.
        adata.X = X
        return adata

    # Row-slicing a CSR matrix already returns a new matrix, so no .copy() needed.
    batches = [X[start:start + batch_size] for start in range(0, n_cells, batch_size)]

    processed = Parallel(n_jobs=n_jobs)(
        delayed(_process_batch)(batch, quality) for batch in batches
    )

    X_down = sp.vstack(processed, format="csr")
    if verbose:
        print(X_down.sum() / X_down.shape[0])

    adata.X = X_down
    return adata

In [5]:
# Repeatedly downsample with the same per-step quality and save every level.
# The loaded `ad` is never rebound, so re-running this cell always starts from
# the full-depth data instead of silently continuing from the last level.
QUALITY = 0.5  # per-step value passed to batched_get_ad_with_quality
N_STEPS = 10

downsampled = ad
for i in range(N_STEPS):
    downsampled = batched_get_ad_with_quality(downsampled, QUALITY)
    # Cumulative level label after i+1 steps (assumes quality multiplies
    # per step, as the printed means suggest). For QUALITY = 0.5 this is exactly
    # the previous 2**-(i+1), so the output file names are unchanged.
    suffix = QUALITY ** (i + 1)
    mean_counts = downsampled.X.sum() / len(downsampled)
    print(f"Q{suffix:.3f}: {mean_counts} counts/cell ({mean_counts / full_umi:.4f} of full depth)")
    downsampled.write_h5ad(f"../splatter/Splatter_downsampled_Q{suffix:.3f}.h5ad")



45558.296576
45558.296576
22778.898432
22778.898432
11389.19936
11389.19936
5694.348288
5694.348288
2846.924032
2846.924032
1423.211904
1423.211904
711.35552
711.35552
355.427904
355.427904
177.464304
177.464304
88.482264
88.482264
