In [1]:
import os
import tempfile
from pathlib import Path
import pandas as pd
import numpy as np
import pyBigWig
import copy
import xarray as xr
import tqdm
import crandata


# Create temporary directories for synthetic data.
temp_dir = tempfile.TemporaryDirectory()
base_dir = Path(temp_dir.name)
beds_dir = base_dir / "beds"
bigwigs_dir = base_dir / "bigwigs"
beds_dir.mkdir(exist_ok=True)
bigwigs_dir.mkdir(exist_ok=True)

# Create a chromsizes file.
chromsizes_file = base_dir / "chrom.sizes"
with open(chromsizes_file, "w") as f:
    f.write("chr1\t1000\n")

# Create two BED files (simulate two different classes).
bed_data_A = pd.DataFrame({
    0: ["chr1", "chr1"],
    1: [100, 300],
    2: [200, 400]
})
bed_data_B = pd.DataFrame({
    0: ["chr1", "chr1"],
    1: [150, 350],
    2: [250, 450]
})
bed_file_A = beds_dir / "ClassA.bed"
bed_file_B = beds_dir / "ClassB.bed"
bed_data_A.to_csv(bed_file_A, sep="\t", header=False, index=False)
bed_data_B.to_csv(bed_file_B, sep="\t", header=False, index=False)

# Create a consensus BED file.
consensus = pd.DataFrame({
    0: ["chr1", "chr1", "chr1"],
    1: [100, 300, 350],
    2: [200, 400, 450]
})
consensus_file = base_dir / "consensus.bed"
consensus.to_csv(consensus_file, sep="\t", header=False, index=False)

# Create two bigWig files.
bigwig_file1 = bigwigs_dir / "test.bw"
bw1 = pyBigWig.open(str(bigwig_file1), "w")
bw1.addHeader([("chr1", 1000)])
bw1.addEntries(chroms=["chr1"], starts=[0], ends=[1000], values=[5.0])
bw1.close()

bigwig_file2 = bigwigs_dir / "test2.bw"
bw2 = pyBigWig.open(str(bigwig_file2), "w")
bw2.addHeader([("chr1", 1000)])
bw2.addEntries(chroms=["chr1"], starts=[0], ends=[1000], values=[4.0])
bw2.close()

# Set extraction parameters.
target_region_width = 100
backed_path = base_dir / "chrom_data.zarr"
print(backed_path)
# Create the CrAnData object from bigWig files and consensus regions.
adata = crandata.chrom_io.import_bigwigs(
    bigwigs_folder=str(bigwigs_dir),
    regions_file=str(consensus_file),
    backed_path=str(backed_path),
    target_region_width=target_region_width,
    chromsizes_file=str(chromsizes_file),
)

crandata.train_val_test_split(adata,strategy='chr_auto')

# Create a dummy FASTA file for a genome.
fasta_file = base_dir / "chr1.fa"
with open(fasta_file, "w") as f:
    f.write(">chr1\n")
    f.write("A" * 1000 + "\n")

dummy_genome = crandata.Genome(str(fasta_file), chrom_sizes=str(chromsizes_file))

# Add sequences to the CrAnData using the provided seq_io utility.
# Here we use the consensus regions as our ranges.
consensus.columns = ['chrom', 'start', 'end']
adata = crandata.seq_io.add_genome_sequences_to_crandata(adata, consensus, dummy_genome)

print(adata)
# Write the CrAnData object to disk and then reload it to ensure sequences are out-of-memory.
adata.to_zarr(str(backed_path),mode='a')
adata_loaded = crandata.CrAnData.open_zarr(str(backed_path))
print("Loaded CrAnData:")
print(adata_loaded)

# Create two copies to simulate two datasets (e.g. two species), and add a "split" column in var metadata.
adata1 = copy.deepcopy(adata_loaded)
adata2 = copy.deepcopy(adata_loaded)
adata1["var-_-split"] = xr.DataArray(np.full(adata1.sizes["var"], "train"), dims=["var"])
adata2["var-_-split"] = xr.DataArray(np.full(adata2.sizes["var"], "train"), dims=["var"])

# Create a DNATransform instance.
transform = crandata.seq_io.DNATransform(out_len=80, random_rc=True, max_shift=5)

# Instantiate the MetaCrAnDataModule with the two datasets.
# Note: The batch_size is now 3, matching the number of consensus regions (var dimension).
meta_module = crandata.MetaCrAnDataModule(
    adatas=[adata1, adata2],
    batch_size=[2,2],        # adjust batch size to not exceed var length (3)
    load_keys={'sequences':'sequences','X':'X'},
    shuffle=True,
    dnatransform=transform,
    epoch_size=10
)

meta_module.setup('train')

# Retrieve the training dataloader from the meta module and iterate over a couple of batches.
meta_train_dl = meta_module.train_dataloader
print("\nIterating over a couple of training batches from MetaCrAnDataModule:")
for i, batch in enumerate(tqdm.tqdm(meta_train_dl)):
    print(batch)
    print(f"\nMeta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i >= 1:
        break

print("\nTemporary directory contents:")
print(os.listdir(base_dir))
temp_dir.cleanup()


  cls = super().__new__(mcls, name, bases, namespace, **kwargs)


/scratch/fast/168840/tmp99yl2au7/chrom_data.zarr


100%|██████████| 2/2 [00:00<00:00, 4362.25it/s]
[32m2025-04-14 10:36:11.599[0m | [1mINFO    [0m | [36mcrandata.chrom_io[0m:[36mimport_bigwigs[0m:[36m330[0m - [1mExtracting values from 2 bigWig files...[0m
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
Processing consensus peak memor

<xarray.CrAnData> Size: 7kB
Dimensions:            (obs: 2, var: 3, seq_bins: 100, seq_len: 100, nuc: 4)
Coordinates:
  * obs                (obs) object 16B 'test' 'test2'
  * var                (var) object 24B 'chr1:100-200' ... 'chr1:350-450'
  * seq_bins           (seq_bins) int64 800B 0 1 2 3 4 5 6 ... 94 95 96 97 98 99
Dimensions without coordinates: seq_len, nuc
Data variables:
    obs-_-index        (obs) object 16B dask.array<chunksize=(2,), meta=np.ndarray>
    var-_-end          (var) int64 24B dask.array<chunksize=(3,), meta=np.ndarray>
    obs-_-file_path    (obs) object 16B dask.array<chunksize=(2,), meta=np.ndarray>
    var-_-chunk_index  (var) int64 24B dask.array<chunksize=(3,), meta=np.ndarray>
    var-_-chrom        (var) object 24B dask.array<chunksize=(3,), meta=np.ndarray>
    var-_-index        (var) object 24B dask.array<chunksize=(3,), meta=np.ndarray>
    var-_-start        (var) int64 24B dask.array<chunksize=(3,), meta=np.ndarray>
    X                  (ob

0it [00:00, ?it/s]

{'sequences': array([[[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]]], dtype=uint8), 'X': array([[[6.90964373e-310, 6.90964373e-310, 4.68488233e-310,
         4.68488233e-310, 6.90943458e-310, 6.90943452e-310,
         6.90943482e-310, 4.68488233e-310, 6.90943501e-310,
         6.90960911e-310, 6.90943501e-310, 0.00000000e+000,
         6.90943617e-310, 6.90960911e-310, 6.90943617e-310,
         6.90943469e-310, 6.90960911e-310, 6.90943469e-310,
         6.90943469

1it [00:00, 18.01it/s]



Meta Batch 0:
  sequences: shape (4, 100, 4)
  X: shape (2, 4, 100)
{'sequences': array([[[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]]], dtype=uint8), 'X': array([[[6.90964373e-310, 6.90964373e-310, 4.68488233e-310,
         4.68488233e-310, 6.90943458e-310, 6.90943452e-310,
         6.90943482e-310, 4.68488233e-310, 6.90943501e-310,
         6.90960911e-310, 6.90943501e-310, 0.00000000e+000,
         6.90943617e-310, 6.90960911e-310, 6.90943617e-310,
         




In [2]:
sdfs

NameError: name 'sdfs' is not defined

In [None]:
# Should the fill in _extract_values_from_bigwig actually be 0? Can we filter var where all is 0/nan without loading everything into memory?

In [1]:
import crandata
import xarray as xr
import pandas as pd
import numpy as np
import os
import crested
from tqdm import tqdm
import importlib

  cls = super().__new__(mcls, name, bases, namespace, **kwargs)


In [2]:
genomes = {}
beds = {}
chromsizes_files = {}
bed_files = {}
species = ['human','macaque','mouse']
species_codes = {'human':0,'macaque':1,'mouse':2}

MAX_SHIFT = 5
WINDOW_SIZE = 2114
WINDOW_SIZE = WINDOW_SIZE #+ 2*MAX_SHIFT
OFFSET = WINDOW_SIZE // 2  # e.g., 50% overlap
N_THRESHOLD = 0.3
n_bins = WINDOW_SIZE//50


In [3]:
for s in species:
    genome_path = '/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/genome/onehots/'+s
    fasta_file = os.path.join(genome_path,s+'.fa')
    chrom_sizes = os.path.join(genome_path,s+'.fa.sizes')
    annotation_gtf_file = os.path.join(genome_path,s+'.annotation.gtf')
    chromsizes_files[s] = chrom_sizes
    genome = crandata.Genome(fasta_file, chrom_sizes, annotation_gtf_file)
    genome.to_memory()
    genomes[s] = genome
    OUTPUT_BED = os.path.join(genome_path, "binned_genome.bed")
    bed_files[s] = OUTPUT_BED
    # Generate bins and optionally write to disk.
    binned_df = crandata.bin_genome(genome, WINDOW_SIZE, OFFSET, n_threshold=N_THRESHOLD, output_path=OUTPUT_BED).reset_index(drop=True)
    print("Filtered bins:")
    print(binned_df)


2025-04-15T00:05:12.810741-0700 INFO Genome sequences loaded into memory.


Calculating N content: 100%|██████████| 2932321/2932321 [04:46<00:00, 10250.21it/s]


Filtered bins:
              chrom  start    end   prop_n
0              chr1   9514  11628  0.23026
1              chr1  10571  12685  0.00000
2              chr1  11628  13742  0.00000
3              chr1  12685  14799  0.00000
4              chr1  13742  15856  0.00000
...             ...    ...    ...      ...
2786513  KI270518.1      1   2115  0.00000
2786514  KI270530.1      1   2115  0.00000
2786515  KI270304.1      1   2115  0.00000
2786516  KI270418.1      1   2115  0.00000
2786517  KI270424.1      1   2115  0.00000

[2786518 rows x 4 columns]
2025-04-15T00:10:51.943113-0700 INFO Genome sequences loaded into memory.


Calculating N content: 100%|██████████| 2806701/2806701 [04:31<00:00, 10348.06it/s]


Filtered bins:
               chrom  start    end  prop_n
0        NC_041754.1      1   2115     0.0
1        NC_041754.1   1058   3172     0.0
2        NC_041754.1   2115   4229     0.0
3        NC_041754.1   3172   5286     0.0
4        NC_041754.1   4229   6343     0.0
...              ...    ...    ...     ...
2773980  NC_005943.1   9514  11628     0.0
2773981  NC_005943.1  10571  12685     0.0
2773982  NC_005943.1  11628  13742     0.0
2773983  NC_005943.1  12685  14799     0.0
2773984  NC_005943.1  13742  15856     0.0

[2773985 rows x 4 columns]
2025-04-15T00:16:11.924751-0700 INFO Genome sequences loaded into memory.


Calculating N content: 100%|██████████| 2583507/2583507 [04:09<00:00, 10350.88it/s]


Filtered bins:
              chrom    start      end    prop_n
0              chr1  2999767  3001881  0.110638
1              chr1  3002938  3005052  0.085579
2              chr1  3003995  3006109  0.000000
3              chr1  3005052  3007166  0.000000
4              chr1  3006109  3008223  0.000000
...             ...      ...      ...       ...
2509462  JH584292.1     8457    10571  0.000000
2509463  JH584292.1     9514    11628  0.000000
2509464  JH584292.1    10571    12685  0.000000
2509465  JH584292.1    11628    13742  0.000000
2509466  JH584292.1    12685    14799  0.000000

[2509467 rows x 4 columns]


In [4]:
crandata.crandata = importlib.reload(crandata.crandata)
crandata.chrom_io = importlib.reload(crandata.chrom_io)

  cls = super().__new__(mcls, name, bases, namespace, **kwargs)


In [None]:
adatas = {}

for s in species:
    print(s)
    bigwigs_dir = os.path.join('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',s,'Group_bigwig')
    adatas[s] = crandata.chrom_io.import_bigwigs(
        bigwigs_folder=bigwigs_dir,
        regions_file=bed_files[s],
        backed_path='/home/matthew.schmitz/Matthew/data/test_crandata/'+s+'_spc_test.zarr',
        target_region_width=WINDOW_SIZE,
        chromsizes_file=chromsizes_files[s],
        target = 'raw',
        max_stochastic_shift=5,
        tile_size=5000,
        chunk_size=512,
        n_bins=n_bins
    )
    bed = adatas[s].get_dataframe('var').loc[:,['chrom','start','end']]
    adatas[s] = crandata.seq_io.add_genome_sequences_to_crandata(adatas[s], bed, genomes[s])
    print(adatas[s]['sequences'])
    adatas[s]['var-_-species'] = xr.DataArray(np.repeat(species_codes[s],adatas[s].sizes['var']),dims='var').chunk({'var':adatas[s].attrs['chunk_size']})
    adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
    adatas[s] = crandata.crandata.CrAnData.open_zarr('/home/matthew.schmitz/Matthew/data/test_crandata/'+s+'_spc_test.zarr')
    

human


100%|██████████| 49/49 [00:01<00:00, 26.20it/s]


2025-04-15T00:21:17.125226-0700 INFO Extracting values from 49 bigWig files...


  adata.to_zarr(str(backed_path),mode='w')
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
Processing memory tiles: 100%|██████████| 556/556 [36:48<00:00,  3.97s/it]
  adata['X'] = adata['X'].chunk({'obs':adata.dims['obs'],'var':chunk_size,'seq_bins':adata.dims['seq_bins']}) #enforce the same as

<xarray.DataArray 'sequences' (var: 2779994, seq_len: 2124, nuc: 4)> Size: 24GB
dask.array<xarray-<this-array>, shape=(2779994, 2124, 4), dtype=uint8, chunksize=(512, 2124, 4), chunktype=numpy.ndarray>
Coordinates:
  * var      (var) object 22MB 'chr1:9514-11628' ... 'KI270713.1:38053-40167'
Dimensions without coordinates: seq_len, nuc


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)


macaque


100%|██████████| 50/50 [00:02<00:00, 24.94it/s]


2025-04-15T01:29:49.946587-0700 INFO Extracting values from 50 bigWig files...


  adata.to_zarr(str(backed_path),mode='w')
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
Processing memory tiles: 100%|██████████| 539/539 [45:24<00:00,  5.05s/it]
  adata['X'] = adata['X'].chunk({'obs':adata.dims['obs'],'var':chunk_size,'seq_bins':adata.dims['seq_bins']}) #enforce the same as before
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_

<xarray.DataArray 'sequences' (var: 2692517, seq_len: 2124, nuc: 4)> Size: 23GB
dask.array<xarray-<this-array>, shape=(2692517, 2124, 4), dtype=uint8, chunksize=(512, 2124, 4), chunktype=numpy.ndarray>
Coordinates:
  * var      (var) object 22MB 'NC_041754.1:1-2115' ... 'NC_005943.1:13742-15...
Dimensions without coordinates: seq_len, nuc


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)


mouse


100%|██████████| 49/49 [00:01<00:00, 29.92it/s]


2025-04-15T02:47:20.994900-0700 INFO Extracting values from 48 bigWig files...


  adata.to_zarr(str(backed_path),mode='w')
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
Processing memory tiles: 100%|██████████| 502/502 [34:34<00:00,  4.13s/it]
  adata['X'] = adata['X'].chunk({'obs':adata.dims['obs'],'var':chunk_size,'seq_bins':adata.dims['seq_bins']}) #enforce the same as

In [None]:
# #Alternate workflow to directly write icechunks, but this is ~5x slower (better to write pure zarr3 then convert the the whole store at once)
# adatas = {}

# for s in species:
#     print(s)
#     bigwigs_dir = os.path.join('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',s,'Group_bigwig')
#     adatas[s] = crandata.chrom_io.import_bigwigs(
#         bigwigs_folder=bigwigs_dir,
#         regions_file=bed_files[s],
#         backed_path='/home/matthew.schmitz/Matthew/data/test_crandata/'+s+'_spc_test.icechunk',
#         target_region_width=WINDOW_SIZE,
#         chromsizes_file=chromsizes_files[s],
#         target = 'raw',
#         max_stochastic_shift=5,
#         chunk_size=512,
#         backend='icechunk',
#         n_bins=n_bins
#     )
#     bed = adatas[s].get_dataframe('var').loc[:,['chrom','start','end']]
#     adatas[s] = crandata.seq_io.add_genome_sequences_to_crandata(adatas[s], bed, genomes[s])
#     print(adatas[s]['sequences'])
#     adatas[s]['var-_-species'] = xr.DataArray(np.repeat(species_codes[s],adatas[s].sizes['var']),dims='var').chunk({'var':adatas[s].attrs['chunk_size']})
#     adatas[s].to_icechunk(mode='a',commit_name='add_genome_seqs')
#     # adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
#     adatas[s] = crandata.crandata.CrAnData.open_icechunk('/home/matthew.schmitz/Matthew/data/test_crandata/'+s+'_spc_test.icechunk')
    

human


100%|██████████| 49/49 [00:00<00:00, 500.38it/s]


2025-04-14T12:02:42.163578-0700 INFO Extracting values from 49 bigWig files...


  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
Processing memory tiles:   0%|          | 2/556 [03:11<14:10:06, 92.07s/it] 

In [None]:
for s in adatas.keys():
    crandata.train_val_test_split(
        adatas[s], strategy="region", val_size=0.1, test_size=0.1, random_state=42
    )
    adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
    # adatas[s].to_icechunk(mode='a',commit_name='train_val_test_split') #If you're using icechunk store

In [None]:
print(adatas[s])
print(adatas[s]['X'])
adatas[s]['sequences']

In [22]:
adatas = {}

for s in tqdm(species):
    # adatas[s] = crandata.CrAnData.open_icechunk('/home/matthew.schmitz/Matthew/data/test_crandata/'+s+'_spc_test.icechunk',
    #                                             cache_config={'num_bytes_chunks':int(8e9)})#Cache 8Gb
    adatas[s] = crandata.crandata.CrAnData.open_zarr('/home/matthew.schmitz/Matthew/data/test_crandata/'+s+'_spc_test.zarr')

  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
100%|██████████| 3/3 [00:12<00:00,  4.02s/it]


In [9]:
crandata._module = importlib.reload(crandata._module)

AttributeError: module 'crandata' has no attribute '_module'

In [23]:
transform = crandata.seq_io.DNATransform(out_len=WINDOW_SIZE, random_rc=True, max_shift=MAX_SHIFT)

meta_module = crandata.CrAnDataModule(
    adatas=list(adatas.values()),
    batch_size=48,
    load_keys={'X': 'y','sequences':'sequence','var-_-species':'species'},
    dnatransform=transform,
    shuffle_dims=['obs'],
    join='inner',
    num_workers=0
)

# Setup the meta module for the "fit" stage (train/val)
meta_module.setup("train")

# Retrieve the training dataloader from the meta module and iterate over a couple of batches.
meta_train_dl = meta_module.train_dataloader

print("\nIterating over a couple of training batches from CrAnDataModule:")
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break





Iterating over a couple of training batches from MetaAnnDataModule:


1it [00:00,  5.89it/s]

Meta Batch 0:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


2it [00:00,  6.50it/s]

Meta Batch 1:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


3it [00:00,  7.00it/s]

Meta Batch 2:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


4it [00:00,  7.34it/s]

Meta Batch 3:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


5it [00:00,  7.48it/s]

Meta Batch 4:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


5it [00:00,  6.07it/s]

Meta Batch 5:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)





In [24]:
import cProfile
import pstats

# Run your code and write the profile data to a file.
cProfile.run("""
for i, batch in tqdm(enumerate(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 10:
        break
""", "profile_output.prof")

# Load the profile data from the file using pstats.
p = pstats.Stats("profile_output.prof")
p.strip_dirs().sort_stats("cumtime").print_stats(50)
# p.strip_dirs().sort_stats('cumtime').print_stats('crandata')


1it [00:00,  4.77it/s]

Meta Batch 0:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


3it [00:00,  4.92it/s]

Meta Batch 1:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)
Meta Batch 2:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


5it [00:01,  5.04it/s]

Meta Batch 3:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)
Meta Batch 4:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


7it [00:01,  5.02it/s]

Meta Batch 5:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)
Meta Batch 6:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


8it [00:01,  4.97it/s]

Meta Batch 7:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)
Meta Batch 8:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


10it [00:02,  4.97it/s]

Meta Batch 9:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


10it [00:02,  4.50it/s]

Meta Batch 10:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)
Thu Apr 17 22:26:58 2025    profile_output.prof

         3303356 function calls (3277978 primitive calls) in 2.362 seconds

   Ordered by: cumulative time
   List reduced from 1463 to 50 due to restriction <50>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     11/1    0.000    0.000    2.006    2.006 dataset.py:873(load)
     11/1    0.000    0.000    2.006    2.006 daskmanager.py:80(compute)
     11/1    0.000    0.000    2.006    2.006 base.py:600(compute)
1559/1541    0.004    0.000    1.467    0.001 {method 'run' of '_contextvars.Context' objects}
   618/41    0.006    0.000    1.227    0.030 base_events.py:1909(_run_once)
       12    0.000    0.000    0.972    0.081 _module.py:182(<lambda>)
       12    0.001    0.000    0.922    0.077 _module.py:48(__call__)
       12    0.001    0.000    0.887    0.074 seq_io.py:321(apply_rc)
    12/11    0.000    0.000    




<pstats.Stats at 0x7fa9a3b7ffb0>

In [10]:
for s in tqdm(species):
    adatas[s].unify_convert_chunks('/home/matthew.schmitz/Matthew/data/test_crandata/'+s+'_spc_test.icechunk')

  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
100%|██████████| 3/3 [36:53<00:00, 737.94s/it]


In [11]:
adatas = {}
for s in tqdm(species):
    adatas[s] = crandata.CrAnData.open_icechunk('/home/matthew.schmitz/Matthew/data/test_crandata/'+s+'_spc_test.icechunk',
                                                cache_config={'num_bytes_chunks':int(8e9),'num_chunk_refs':5})#Cache 8Gb
    # adatas[s] = crandata.crandata.CrAnData.open_zarr('/home/matthew.schmitz/Matthew/data/test_crandata/'+s+'_spc_test.zarr')

  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
100%|██████████| 3/3 [47:48<00:00, 956.28s/it]


In [15]:
transform = crandata.seq_io.DNATransform(out_len=WINDOW_SIZE, random_rc=True, max_shift=MAX_SHIFT)

meta_module = crandata.CrAnDataModule(
    adatas=list(adatas.values()),
    batch_size=48,
    load_keys={'X': 'y','sequences':'sequence','var-_-species':'species'},
    dnatransform=transform,
    shuffle_dims=['obs'],
    join='inner',
    num_workers=0
)

# Setup the meta module for the "fit" stage (train/val)
meta_module.setup("train")

# Retrieve the training dataloader from the meta module and iterate over a couple of batches.
meta_train_dl = meta_module.train_dataloader

print("\nIterating over a couple of training batches from CrAnDataModule:")
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break





Iterating over a couple of training batches from MetaAnnDataModule:


1it [00:00,  7.56it/s]

Meta Batch 0:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


2it [00:00,  8.04it/s]

Meta Batch 1:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)
Meta Batch 2:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


4it [00:00,  8.96it/s]

Meta Batch 3:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


5it [00:10,  2.20s/it]

Meta Batch 4:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)
Meta Batch 5:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)





In [21]:
import cProfile

code = '''
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break
'''

out = cProfile.run(code,sort=True)


1it [00:00,  3.32it/s]

Meta Batch 0:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


2it [00:00,  4.04it/s]

Meta Batch 1:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


3it [00:00,  4.32it/s]

Meta Batch 2:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


4it [00:00,  4.38it/s]

Meta Batch 3:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


5it [00:01,  4.50it/s]

Meta Batch 4:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)


5it [00:01,  3.60it/s]

Meta Batch 5:
  y: shape (47, 48, 42)
  sequence: shape (48, 2114, 4)
  species: shape (48,)
         1715553 function calls (1705109 primitive calls) in 1.462 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   818592    0.129    0.000    0.134    0.000 _collections_abc.py:868(__iter__)
  1171/13    0.095    0.000    0.000    0.000 {method 'acquire' of '_thread.lock' objects}
     1662    0.091    0.000    0.166    0.000 {method 'update' of 'set' objects}
       54    0.090    0.002    0.102    0.002 blockwise.py:625(get_output_keys)
       84    0.054    0.001    0.112    0.001 {method 'intersection' of 'set' objects}
   176/28    0.048    0.000    0.319    0.011 threading.py:323(wait)
     6426    0.040    0.000    0.131    0.000 <frozen importlib._bootstrap_external>:1593(find_spec)
    22/14    0.037    0.002    0.020    0.001 cpu.py:188(as_numpy_array_wrapper)
     6426    0.034    0.000    0.034    0.000 {built-in met




In [22]:
meta_module.load()
meta_train_dl = meta_module.train_dataloader



KeyboardInterrupt



In [None]:
code = '''
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 50:
        break
'''

out = cProfile.run(code,sort=True)


In [None]:
model_architecture = crested.tl.zoo.simple_convnet(
    seq_len=2114, num_classes=batch['y'].shape[1]
)


In [None]:
import keras
# Create your own configuration
# I recommend trying this for peak regression with a weighted cosine mse log loss function
optimizer = keras.optimizers.Adam(learning_rate=1e-5)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
loss = crested.tl.losses.PoissonLoss()

metrics = [
    keras.metrics.MeanAbsoluteError(),
    # keras.metrics.MeanSquaredError(),
    # keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    # crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    # crested.tl.metrics.PearsonCorrelationLog(),
    # crested.tl.metrics.ZeroPenaltyMetric(),
]

alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)


In [None]:
batch['sequence'].shape

In [None]:
# initialize some lazy model parameters *yawn*
model_architecture(batch['sequence'].float().mean(0).unsqueeze(0))

In [None]:
trainer = crested.tl.Crested(
    data=meta_module,
    model=model_architecture,
    config=alternative_config,
    project_name="mouse_biccn",  # change to your liking
    run_name="basemodel",  # change to your liking
    logger=None,  # or None, 'dvc', 'tensorboard'
    seed=7,  # For reproducibility
)
# train the model
trainer.fit(
    epochs=60,
    learning_rate_reduce_patience=3,
    early_stopping_patience=6,
)


In [95]:
import icechunk as ic
store_path = '/home/matthew.schmitz/Matthew/data/test_crandata/mouse_spc_test.icechunk'
storage_config = ic.local_filesystem_storage(store_path)
config = ic.RepositoryConfig.default()
config.caching = ic.CachingConfig(num_bytes_chunks=int(8e9))
repo = ic.Repository.open(storage_config, config)
session = repo.readonly_session("main")
ds = xr.open_zarr(session.store, consolidated=False)


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)


In [59]:
for k in ds.keys():
    print(k)
    print(ds[k].chunks)


X
((48,), (512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512,

In [99]:
start_idx = np.random.randint(0, ds.dims['var'] - 100)
print("Random access start index:", start_idx)

t0 = time.time()
subset1 = ds.isel(var=slice(start_idx, start_idx + 100))['X'].values
print("First access time: {:.4f} sec".format(time.time() - t0))

t0 = time.time()
subset2 = ds.isel(var=slice(start_idx, start_idx + 100))['X'].values
print("Second access time: {:.4f} sec".format(time.time() - t0))

print("Done.")


Random access start index: 1822385
First access time: 0.0130 sec
Second access time: 0.0102 sec
Done.


  start_idx = np.random.randint(0, ds.dims['var'] - 100)


In [94]:
start_idx = np.random.randint(0, ds.dims['var'] - 100)
print("Random access start index:", start_idx)

t0 = time.time()
subset1 = np.array(ds.isel(var=slice(start_idx, start_idx + 100))['X'])
print("First access time: {:.4f} sec".format(time.time() - t0))

t0 = time.time()
subset2 = np.array(ds.isel(var=slice(start_idx, start_idx + 100))['X'].values)
print("Second access time: {:.4f} sec".format(time.time() - t0))

print("Done.")


Random access start index: 686658
First access time: 0.0125 sec
Second access time: 0.0112 sec
Done.


  start_idx = np.random.randint(0, ds.dims['var'] - 100)


In [61]:
%%time
np.array(ds.isel({'var':np.arange(start,start+1000)})['X'])
print('done')



done
CPU times: user 323 ms, sys: 123 ms, total: 446 ms
Wall time: 442 ms


In [62]:
%%time
np.array(ds.isel({'var':np.arange(start,start+1000)})['X'])
print('done')



done
CPU times: user 316 ms, sys: 114 ms, total: 430 ms
Wall time: 423 ms


In [63]:
ds.isel({'var':np.arange(start,start+1000)})['X']

  ds.isel({'var':np.arange(start,start+1000)})['X']


Unnamed: 0,Array,Chunk
Bytes,15.38 MiB,7.86 MiB
Shape,"(48, 1000, 42)","(48, 511, 42)"
Dask graph,2 chunks in 3 graph layers,2 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 15.38 MiB 7.86 MiB Shape (48, 1000, 42) (48, 511, 42) Dask graph 2 chunks in 3 graph layers Data type float64 numpy.ndarray",42  1000  48,

Unnamed: 0,Array,Chunk
Bytes,15.38 MiB,7.86 MiB
Shape,"(48, 1000, 42)","(48, 511, 42)"
Dask graph,2 chunks in 3 graph layers,2 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [115]:
import tempfile
import time
import numpy as np
import xarray as xr
import icechunk as ic
from icechunk.xarray import to_icechunk

# Create a temporary directory for the Icechunk store
with tempfile.TemporaryDirectory() as tmpdir:
    store_path = f"{tmpdir}/example.icechunk"
    
    # Set up local storage and repository with a caching configuration (1 MB in this demo)
    storage_config = ic.local_filesystem_storage(store_path)
    config = ic.RepositoryConfig.default()
    config.caching = ic.CachingConfig(num_bytes_chunks=1024 * 1024)
    repo = ic.Repository.create(storage_config, config)
    
    # Create a simple xarray dataset with a 'var' dimension and chunk it along 'var'
    # In this example, the dataset 'X' has shape (var=1000, y=20) and chunks of size 100 along 'var'
    data = np.random.rand(100000, 20)
    ds = xr.Dataset({'X': (('var', 'y'), data)})
    ds = ds.chunk({'var': 100, 'y': 20})
    
    # Write the dataset to the Icechunk store using a writable session
    session = repo.writable_session("main")
    to_icechunk(ds, session)
    commit_hash = session.commit("initial commit")
    print("Committed with hash:", commit_hash)
    
    # Read the dataset back from the store using a read-only session
    session = repo.readonly_session("main")
    ds2 = xr.open_zarr(session.store, consolidated=False)
    print("Dataset dimensions:", ds2.dims)
    
    # Test the caching behavior by timing two consecutive random accesses along 'var'
    var_dim = ds2.dims['var']
    # Ensure we have 100 contiguous indices available (avoid overflow)
    start_idx = np.random.randint(0, var_dim - 100)
    print("Random access start index:", start_idx)
    
    t0 = time.time()
    subset1 = ds2.isel(var=slice(start_idx, start_idx + 100))['X'].values
    print("First access time: {:.4f} sec".format(time.time() - t0))
    
    t0 = time.time()
    subset2 = ds2.isel(var=slice(start_idx, start_idx + 100))['X'].values
    print("Second access time: {:.4f} sec".format(time.time() - t0))
    
    print("Done.")


Committed with hash: V6J86ASCJMYC72RA9NR0
Random access start index: 49587
First access time: 0.0063 sec
Second access time: 0.0047 sec
Done.


  var_dim = ds2.dims['var']


In [66]:
1024 * 1024

1048576

In [48]:
import h5py
import numpy as np
import pandas as pd
import xarray as xr
from scipy.sparse import csr_matrix
import sparse
from pathlib import Path
from typing import Union, Literal, List
from crandata import CrAnData

def read_h5ad_selective_to_crandata(
    filename: Union[str, Path],
    mode: Literal["r", "r+"] = "r",
    selected_fields: List[str] = None,
) -> CrAnData:
    """
    Read just the specified top‐level AnnData fields (e.g. "X","obs","var","layers", etc.)
    from an .h5ad file via h5py, reconstruct sparse/categorical if needed,
    and return a CrAnData (xarray.Dataset).  This version unpacks obs/var
    into -_- columns so we never pass a DataFrame into CrAnData.__init__.
    """
    selected_fields = selected_fields or ["X", "obs", "var"]

    # ————— Helpers (same as before) ——————————————————————————————————

    def h5_tree(g):
        out = {}
        for k, v in g.items():
            if isinstance(v, h5py.Group):
                out[k] = h5_tree(v)
            else:
                try: out[k] = len(v)
                except TypeError: out[k] = "scalar"
        return out

    def dict_to_ete3_tree(d, parent=None):
        from ete3 import Tree
        if parent is None: parent = Tree(name="root")
        for k, v in d.items():
            c = parent.add_child(name=k)
            if isinstance(v, dict):
                dict_to_ete3_tree(v, c)
        return parent

    def ete3_tree_to_dict(t):
        def helper(n):
            if n.is_leaf(): return n.name
            return {c.name: helper(c) for c in n.get_children()}
        return {c.name: helper(c) for c in t.get_children()}

    def prune_tree(tree, keep_keys):
        t = dict_to_ete3_tree(tree)
        keep = set()
        for key in keep_keys:
            for node in t.search_nodes(name=key):
                keep.update(node.iter_ancestors())
                keep.update(node.iter_descendants())
                keep.add(node)
        for n in t.traverse("postorder"):
            if n not in keep and n.up:
                n.detach()
        return ete3_tree_to_dict(t)

    def read_h5_to_dict(group, subtree):
        def helper(grp, sub):
            out = {}
            for k, v in sub.items():
                if isinstance(v, dict):
                    out[k] = helper(grp[k], v) if k in grp else None
                else:
                    if k in grp and isinstance(grp[k], h5py.Dataset):
                        ds = grp[k]
                        if ds.shape == ():
                            out[k] = ds[()]
                        else:
                            arr = ds[...]
                            if arr.dtype.kind == "S":
                                arr = arr.astype(str)
                            out[k] = arr
                    else:
                        out[k] = None
            return out
        return helper(group, subtree)

    def convert_to_dataframe(d: dict) -> pd.DataFrame:
        # infer length
        length = next((len(v) for v in d.values() if not isinstance(v, dict)), None)
        if length is None:
            raise ValueError("Cannot infer obs/var length")
        cols = {}
        for k, v in d.items():
            if isinstance(v, dict) and {"categories","codes"} <= set(v):
                codes = np.asarray(v["codes"], int)
                cats  = [c.decode() if isinstance(c, bytes) else c for c in v["categories"]]
                if len(codes)==length:
                    cols[k] = pd.Categorical.from_codes(codes, cats)
            elif isinstance(v, dict) and {"data","indices","indptr"} <= set(v):
                shape = tuple(v.get("shape",(length, max(v["indices"])+1)))
                cols[k] = csr_matrix((v["data"], v["indices"], v["indptr"]), shape=shape)
            elif not isinstance(v, dict):
                arr = np.asarray(v)
                if arr.ndim==1 and arr.shape[0]==length:
                    if arr.dtype.kind=="S":
                        arr = arr.astype(str)
                    cols[k] = arr
        return pd.DataFrame(cols)

    # ————— Read HDF5 and prune ——————————————————————————————————

    with h5py.File(filename, mode) as f:
        full_tree = h5_tree(f)
        pruned    = prune_tree(full_tree, selected_fields)
        raw       = read_h5_to_dict(f, pruned)

    data_vars = {}
    coords     = {}

    # — obs: unpack into coords + obs-_-col ——————————————————————————————
    if "obs" in raw:
        od = raw["obs"]
        idx = od.pop("_index", None)
        obs_df = convert_to_dataframe(od)
        if idx is not None:
            obs_df.index = [str(x) for x in idx]
        coords["obs"] = obs_df.index.to_numpy()

        # now unpack columns
        for col in obs_df.columns:
            data_vars[f"obs-_-{col}"] = xr.DataArray(
                obs_df[col].values,
                dims=("obs",),
                coords={"obs": coords["obs"]}
            )
        # also store index
        data_vars["obs-_-index"] = xr.DataArray(coords["obs"], dims=("obs",))

    # — var: same pattern ——————————————————————————————————————————————
    if "var" in raw:
        vd = raw["var"]
        idx = vd.pop("_index", None)
        var_df = convert_to_dataframe(vd)
        if idx is not None:
            var_df.index = [str(x) for x in idx]
        coords["var"] = var_df.index.to_numpy()

        for col in var_df.columns:
            data_vars[f"var-_-{col}"] = xr.DataArray(
                var_df[col].values,
                dims=("var",),
                coords={"var": coords["var"]}
            )
        data_vars["var-_-index"] = xr.DataArray(coords["var"], dims=("var",))

    # — X matrix ——————————————————————————————————————————————————
    if "X" in raw:
        xraw = raw["X"]
        print(xraw)
        if isinstance(xraw, dict) and {"data","indices","indptr"} <= set(xraw):
            csr_mat = csr_matrix((xraw["data"], xraw["indices"], xraw["indptr"]))
                                  #shape=tuple(xraw["shape"]))
            arr = sparse.COO.from_scipy_sparse(csr_mat)
        else:
            arr = np.asarray(xraw)
        data_vars["X"] = xr.DataArray(arr, dims=("obs","var"), coords=coords)

    # — layers/obsm/varm/obsp ——————————————————————————————————————————
    for grp in ("layers","obsm","varm","obsp"):
        if grp in raw:
            for name, val in raw[grp].items():
                if val is None:
                    continue
                if isinstance(val, dict) and {"data","indices","indptr"} <= set(val):
                    csr_mat = csr_matrix((val["data"], val["indices"], val["indptr"]))
                                          #shape=tuple(val.get("shape",arr.shape)))
                    arr = sparse.COO.from_scipy_sparse(csr_mat)
                else:
                    arr = np.asarray(val)

                if grp=="layers":
                    dims, c = ("obs","var"), coords
                elif grp=="obsm":
                    d2 = f"obsm_{name}"
                    dims, c = ("obs",d2), {"obs":coords["obs"],d2:np.arange(arr.shape[1])}
                elif grp=="varm":
                    d2 = f"varm_{name}"
                    dims, c = ("var",d2), {"var":coords["var"],d2:np.arange(arr.shape[1])}
                else:  # obsp
                    d2 = f"obsp_{name}"
                    dims, c = ("obs",d2), {"obs":coords["obs"],d2:coords["obs"]}

                data_vars[f"{grp}-_-{name}"] = xr.DataArray(arr, dims=dims, coords=c)

    # ——— Finally, build and return CrAnData ——————————————————————
    return CrAnData(data_vars=data_vars, coords=coords)


In [46]:
ds=read_h5ad_selective_to_crandata("/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/data/testgenesets/siletti300k_highly_variable.h5ad",selected_fields=["X","obs","var","UMIs"])

{'data': array([1, 1, 1, ..., 1, 1, 1], dtype=int16), 'indices': array([8712, 5191, 8660, ...,  357, 1115, 1541], dtype=int32), 'indptr': array([        0,       667,      1325, ..., 484405706, 484406692,
       484407236], dtype=int32)}


In [49]:
ds

0,1
Format,coo
Data Type,int16
Shape,"(299721, 9751)"
nnz,484407236
Density,0.16574647184333144
Read-only,True
Size,4.5G
Storage ratio,0.83


In [53]:
import numpy as np
import xarray as xr
from itertools import product
from typing import Union, List, Tuple, Dict

def group_aggr_xr(
    ds: xr.Dataset,
    array_name: str,
    categories: Union[str, List[str]],
    agg_func=np.mean,
    normalize: bool = False,
) -> Tuple[np.ndarray, Dict[str, List[str]]]:
    """
    Group–aggregate an xarray.Dataset along 'obs' by one or more categorical
    obs columns, using xarray.groupby on the specified data array.

    Parameters
    ----------
    ds
        An xarray.Dataset (e.g. CrAnData) containing:
          - a DataArray `ds[array_name]` with dims ("obs","var") or similar,
          - one or more obs columns named "obs-_-<cat>".
    array_name
        Name of the DataArray in `ds` to aggregate (e.g. "X", "layers-_-counts", "obsp-_-contacts").
    categories
        Single category name or list of names (the <cat> in "obs-_-<cat>").
    agg_func
        Aggregation function (e.g. np.mean, np.median, np.std).
    normalize
        If True, each observation is normalized by its row‑sum before grouping.

    Returns
    -------
    result : np.ndarray
        Aggregated values, shape (*category_sizes, num_vars).
    category_orders : dict
        Maps each category name → list of its observed levels (in first‑appearance order).
    """
    # — normalize categories list —
    if isinstance(categories, str):
        categories = [categories]
    if not categories:
        raise ValueError("Must supply at least one category name")

    # — pick the DataArray and its dims —
    da = ds[array_name]
    obs_dim, var_dim = da.dims[:2]
    n_vars = da.sizes[var_dim]

    # — collect category arrays & orders —
    category_orders: Dict[str, List[str]] = {}
    cat_arrs: List[np.ndarray] = []
    for cat in categories:
        arr = ds[f"obs-_-{cat}"].values.astype(str)
        # preserve first‑appearance order
        seen = dict.fromkeys(arr.tolist())
        category_orders[cat] = list(seen.keys())
        cat_arrs.append(arr)

    # — build grouping coordinate —
    if len(categories) == 1:
        group_dim = categories[0]
        grouping = xr.DataArray(cat_arrs[0], dims=obs_dim, coords={obs_dim: ds.coords[obs_dim]})
    else:
        sep = "__"
        combo = cat_arrs[0]
        for arr in cat_arrs[1:]:
            combo = np.char.add(np.char.add(combo, sep), arr)
        group_dim = sep.join(categories)
        grouping = xr.DataArray(combo, dims=obs_dim, coords={obs_dim: ds.coords[obs_dim]})

    da = da.assign_coords(**{group_dim: grouping})

    # — optional normalize each row by its sum —
    if normalize:
        da = da / da.sum(dim=var_dim, keepdims=True)

    # — groupby & reduce —
    grouped = da.groupby(group_dim).reduce(agg_func, dim=obs_dim)

    # — extract the raw data, densifying if needed —
    raw = grouped.data
    if hasattr(raw, "todense"):
        arr = raw.todense()
    elif hasattr(raw, "toarray"):
        arr = raw.toarray()
    else:
        arr = np.asarray(raw)

    # — reorder and reshape into (*category_sizes, n_vars) —
    if len(categories) == 1:
        cats = category_orders[categories[0]]
        # ensure our output follows the same order
        idx = [cats.index(v) for v in grouped[ group_dim ].values.astype(str)]
        result = arr[idx, :]
    else:
        lists = [category_orders[c] for c in categories]
        combos = list(product(*lists))
        combo_strs = [sep.join(c) for c in combos]
        idx = [combo_strs.index(v) for v in grouped[group_dim].values.astype(str)]
        reshaped = arr[idx, :]
        sizes = [len(category_orders[c]) for c in categories]
        result = reshaped.reshape(*sizes, n_vars)

    return result, category_orders


In [None]:
group_aggr_crandata_xr(ds,category_column_names=['dataset'])