In [1]:
# In[1]: imports and helpers
import os, tempfile, copy
from pathlib import Path

import pandas as pd
import numpy as np
import pyBigWig
import xarray as xr
import tqdm

import grandata
from grandata.chrom_io import grandata_from_bigwigs, add_bigwig_array

# In[2]: scratch workspace - every artefact below lives under one temporary directory
tmp = tempfile.TemporaryDirectory()
base = Path(tmp.name)
beds = base / "beds"
bws = base / "bigwigs"
for _subdir in (beds, bws):
    _subdir.mkdir()

# In[3]: minimal chrom.sizes describing a single 1000 bp chromosome
chromsizes = base / "chrom.sizes"
with chromsizes.open("w") as f:
    f.write("chr1\t1000\n")

# In[4]: three regions on chr1, written as a headerless, tab-separated BED file
cons_df = pd.DataFrame(
    [("chr1", 100, 200), ("chr1", 300, 400), ("chr1", 350, 450)],
    columns=["chrom", "start", "end"],
)
consensus_file = base / "consensus.bed"
cons_df.to_csv(consensus_file, sep="\t", header=False, index=False)

# read the file back with explicit column names so later steps use the on-disk version
consensus = pd.read_csv(
    consensus_file,
    sep="\t",
    header=None,
    names=["chrom", "start", "end"],
)

# In[5]: build two synthetic bigWigs
# Each file holds one constant-valued interval spanning chr1 0-1000
# (5.0 in one.bw, 4.0 in two.bw).
for fname, val in [("one.bw", 5.0), ("two.bw", 4.0)]:
    path = bws / fname
    bw = pyBigWig.open(str(path), "w")
    try:
        bw.addHeader([("chr1", 1000)])
        bw.addEntries(["chr1"], [0], [1000], [val])
    finally:
        # Close even if writing fails so the handle is never leaked; for files
        # opened for writing, close() is also what flushes entries and the index.
        bw.close()

# In[6]: build the zarr-backed GRAnData from the bigWigs in `bws`
target_width = 100
out_path = base / "out.zarr"

# Settings shared by the initial import (track "X") and the extra track ("Y").
track_kwargs = dict(
    region_table=consensus,
    bigwig_dir=bws,
    obs_dim="obs",
    var_dim="var",
    seq_dim="bin",
    target_region_width=target_width,
    bin_stat="mean",
    chunk_size=2,
    n_bins=1,
    backend="zarr",
    tile_size=2,
)

adata = grandata_from_bigwigs(
    backed_path=out_path,
    array_name="X",  # name of the first track
    **track_kwargs,
)

print("Created GRAnData:")
print(adata)

# In[7]: optionally add a second track "Y" with the same settings
adata = add_bigwig_array(
    adata,
    array_name="Y",  # name of the new track
    **track_kwargs,
)
print("After adding Y:")
print(adata)

# In[8]: split, add sequences, and run a small GRAnDataModule demo
grandata.train_val_test_split(adata, strategy="chr_auto")

# dummy genome: a single chr1 record made of 1000 "A" bases
fa = base / "chr1.fa"
with open(fa, "w") as f:
    f.write(">chr1\n" + "A" * 1000 + "\n")
genome = grandata.Genome(str(fa), chrom_sizes=str(chromsizes))

# attach sequences for the consensus regions
adata = grandata.seq_io.add_genome_sequences_to_grandata(adata, consensus, genome)

# append the additions to the store, then reopen it from disk
store = str(out_path)
adata.to_zarr(store, mode="a")
adata = grandata.GRAnData.open_zarr(store)

# two independent copies to pass to GRAnDataModule as separate datasets
adata1, adata2 = copy.deepcopy(adata), copy.deepcopy(adata)
# add a var split column (all "train") to each copy so DNATransform works
for _copy in (adata1, adata2):
    _copy["var-_-split"] = xr.DataArray(
        np.full(_copy.sizes["var"], "train"), dims=["var"]
    )

# DNATransform + GRAnDataModule over both copies
transform = grandata.seq_io.DNATransform(out_len=80, random_rc=True, max_shift=5)
meta = grandata.GRAnDataModule(
    adatas=[adata1, adata2],
    batch_size=2,
    load_keys={"sequences": "sequences", "X": "X"},
    dnatransform=transform,
)
meta.setup("train")

print("\nA couple of batches from GRAnDataModule:")
for i, batch in enumerate(tqdm.tqdm(meta.train_dataloader)):
    print({k: v.shape for k, v in batch.items()})
    # stop after the second batch
    if i >= 1:
        break

# In[9]: cleanup
# List what was produced under the temp dir, then delete the whole directory.
# NOTE(review): `adata` was reopened from out.zarr inside this directory (write + reload
# step above) and is still inspected by later cells - confirm it does not need to read
# from the store after this point.
print("\nTemp dir contents:", os.listdir(base))
tmp.cleanup()


no sparse


  cls = super().__new__(mcls, name, bases, namespace, **kwargs)


no sparse


  adata.to_zarr(str(backed_path), mode="w")
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
writing X: 100%|██████████| 2/2 [00:00<00:00, 392.39it/s]
  return cls(**configuratio

Created GRAnData:
<xarray.GRAnData> Size: 208B
Dimensions:       (var: 3, bin: 1, obs: 2)
Coordinates:
  * var           (var) object 24B 'chr1:100-200' 'chr1:300-400' 'chr1:350-450'
  * bin           (bin) int64 8B 0
  * obs           (obs) object 16B 'one' 'two'
Data variables:
    var-_-region  (var) object 24B dask.array<chunksize=(3,), meta=np.ndarray>
    var-_-start   (var) int64 24B dask.array<chunksize=(2,), meta=np.ndarray>
    var-_-end     (var) int64 24B dask.array<chunksize=(2,), meta=np.ndarray>
    var-_-index   (var) object 24B dask.array<chunksize=(3,), meta=np.ndarray>
    X             (obs, var, bin) float32 24B dask.array<chunksize=(2, 2, 1), meta=np.ndarray>
    obs-_-index   (obs) object 16B dask.array<chunksize=(2,), meta=np.ndarray>
    var-_-chrom   (var) object 24B dask.array<chunksize=(3,), meta=np.ndarray>
Attributes:
    chunk_size:  2


writing Y: 100%|██████████| 2/2 [00:00<00:00, 303.62it/s]
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)


After adding Y:
<xarray.GRAnData> Size: 232B
Dimensions:       (var: 3, bin: 1, obs: 2)
Coordinates:
  * var           (var) object 24B 'chr1:100-200' 'chr1:300-400' 'chr1:350-450'
  * bin           (bin) int64 8B 0
  * obs           (obs) object 16B 'one' 'two'
Data variables:
    var-_-start   (var) int64 24B dask.array<chunksize=(2,), meta=np.ndarray>
    X             (obs, var, bin) float32 24B dask.array<chunksize=(2, 2, 1), meta=np.ndarray>
    var-_-region  (var) object 24B dask.array<chunksize=(3,), meta=np.ndarray>
    obs-_-index   (obs) object 16B dask.array<chunksize=(2,), meta=np.ndarray>
    var-_-chrom   (var) object 24B dask.array<chunksize=(3,), meta=np.ndarray>
    var-_-index   (var) object 24B dask.array<chunksize=(3,), meta=np.ndarray>
    var-_-end     (var) int64 24B dask.array<chunksize=(2,), meta=np.ndarray>
    Y             (obs, var, bin) float32 24B dask.array<chunksize=(2, 2, 1), meta=np.ndarray>
Attributes:
    chunk_size:  2


100%|██████████| 3/3 [00:00<00:00, 1953.26it/s]
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)



A couple of batches from GRAnDataModule:


1it [00:00, 12.43it/s]

{'sequences': (2, 80, 4), 'X': (2, 2, 1)}
{'sequences': (2, 80, 4), 'X': (2, 2, 1)}

Temp dir contents: ['beds', 'bigwigs', 'chrom.sizes', 'consensus.bed', 'out.zarr', 'chr1.fa', 'chr1.fa.fai']





In [2]:
# Dataset-level attributes of the reopened GRAnData.
adata.attrs

{'chunk_size': 2,
 'genome_name': 'chr1',
 'genome_fasta': '/scratch/fast/438413/tmpus_l2ytn/chr1.fa',
 'genome_chrom_sizes': '{"chr1": 1000}'}

In [3]:
# Region table (chrom/start/end) that was passed to the bigWig import.
consensus

Unnamed: 0,chrom,start,end
0,chr1,100,200
1,chr1,300,400
2,chr1,350,450


In [4]:
# Per-region (var) metadata as a DataFrame.
# NOTE(review): the saved output of this cell shows start/end = 0 for every region, while
# `consensus` has 100/300/350 and 200/400/450 - worth checking. One possible cause: the
# backing temp store was already removed by tmp.cleanup() in the first cell (unconfirmed).
adata.get_dataframe('var')

Unnamed: 0,start,region,chrom,end,split
chr1:100-200,0,chr1:100-200,chr1,0,train
chr1:300-400,0,chr1:300-400,chr1,0,train
chr1:350-450,0,chr1:350-450,chr1,0,train


In [5]:
# NOTE(review): `sdfs` is an undefined name, so this cell always raises NameError.
# Presumably a deliberate stop so "Run All" halts before the heavy cells below - confirm,
# and consider an explicit, self-explanatory `raise` instead.
sdfs

NameError: name 'sdfs' is not defined

In [None]:
# TODO: Should the fill value in _extract_values_from_bigwig actually be 0? Can we filter out var entries whose values are all 0/NaN without loading everything into memory?

In [7]:
import grandata
import xarray as xr
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import importlib

In [8]:
# Per-species containers, keyed by species name and filled by the cells below.
genomes = {}
bed_dfs = {}
chromsizes_files = {}
bed_files = {}

species = ['human', 'macaque', 'mouse']
# Integer code per species, in the order listed above (human=0, macaque=1, mouse=2).
species_codes = {name: code for code, name in enumerate(species)}

MAX_SHIFT = 5
# Window width in bp. The extra 2*MAX_SHIFT of padding is NOT included here;
# it is added at the point where the genome is binned.
WINDOW_SIZE = 2114
OFFSET = WINDOW_SIZE // 2  # e.g., 50% overlap
N_THRESHOLD = 0.3
# Width used to derive the number of bins per window (previously a bare 50).
BIN_WIDTH = 50
# Integer division: 2114 // 50 == 42 (2114 is not a multiple of 50).
n_bins = WINDOW_SIZE // BIN_WIDTH


In [9]:
# For every species: load the genome into memory, bin it, and keep the resulting table.
for s in species:
    genome_path = f'/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/genome/onehots/{s}'
    fasta_file = os.path.join(genome_path, f'{s}.fa')
    chrom_sizes = os.path.join(genome_path, f'{s}.fa.sizes')
    annotation_gtf_file = os.path.join(genome_path, f'{s}.annotation.gtf')
    chromsizes_files[s] = chrom_sizes

    genome = grandata.Genome(fasta_file, chrom_sizes, annotation_gtf_file)
    genome.to_memory()
    genomes[s] = genome

    # Bins are requested 2*MAX_SHIFT wider than WINDOW_SIZE and are also written to OUTPUT_BED.
    OUTPUT_BED = os.path.join(genome_path, "binned_genome.bed")
    binned_df = grandata.bin_genome(
        genome,
        WINDOW_SIZE + 2 * MAX_SHIFT,
        OFFSET,
        n_threshold=N_THRESHOLD,
        output_path=OUTPUT_BED,
    )
    binned_df = binned_df.reset_index(drop=True)
    bed_dfs[s] = binned_df

    print("Filtered bins:")
    print(binned_df)


KeyboardInterrupt: 

In [10]:
# Same genome loading as the binning cell, but the bins are read back from the
# binned_genome.bed file already on disk instead of being recomputed.
bed_dfs = {}
for s in species:
    print(s)
    genome_path = f'/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/genome/onehots/{s}'
    fasta_file = os.path.join(genome_path, f'{s}.fa')
    chrom_sizes = os.path.join(genome_path, f'{s}.fa.sizes')
    annotation_gtf_file = os.path.join(genome_path, f'{s}.annotation.gtf')
    chromsizes_files[s] = chrom_sizes

    genome = grandata.Genome(fasta_file, chrom_sizes, annotation_gtf_file)
    genome.to_memory()
    genomes[s] = genome

    # Headerless, tab-separated BED: name the three columns after loading.
    OUTPUT_BED = os.path.join(genome_path, "binned_genome.bed")
    bed_table = pd.read_csv(OUTPUT_BED, sep='\t', header=None)
    bed_table.columns = ['chrom', 'start', 'end']
    bed_dfs[s] = bed_table

human


[32m2025-04-25 23:30:56.816[0m | [1mINFO    [0m | [36mgrandata._genome[0m:[36mto_memory[0m:[36m142[0m - [1mGenome sequences loaded into memory.[0m


macaque


[32m2025-04-25 23:31:08.985[0m | [1mINFO    [0m | [36mgrandata._genome[0m:[36mto_memory[0m:[36m142[0m - [1mGenome sequences loaded into memory.[0m


mouse


[32m2025-04-25 23:31:20.286[0m | [1mINFO    [0m | [36mgrandata._genome[0m:[36mto_memory[0m:[36m142[0m - [1mGenome sequences loaded into memory.[0m


In [11]:
# grandata.grandata = importlib.reload(grandata.grandata)
# grandata.chrom_io = importlib.reload(grandata.chrom_io)

In [None]:
adatas = {}

for s in species:
    print(s)
    bigwig_dir = os.path.join(
        '/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',
        s,
        'Group_bigwig',
    )
    # One zarr store per species: used as the write target and for the final reopen.
    zarr_path = '/home/matthew.schmitz/Matthew/data/test_grandata/' + s + '_spc_test.zarr'

    adatas[s] = grandata.chrom_io.grandata_from_bigwigs(
        region_table=bed_dfs[s],
        bigwig_dir=bigwig_dir,
        backed_path=zarr_path,
        array_name="X",  # name of the first track
        obs_dim="obs",
        var_dim="var",
        seq_dim="seq_bins",
        target_region_width=WINDOW_SIZE,
        bin_stat='mean',
        chunk_size=256,
        tile_size=5000,
        n_bins=n_bins,
    )

    # Rebuild the region table from the stored var metadata (good to test that this works).
    bed = adatas[s].get_dataframe('var').loc[:, ['chrom', 'start', 'end']]
    adatas[s] = grandata.seq_io.add_genome_sequences_to_grandata(adatas[s], bed, genomes[s])
    print(adatas[s]['sequences'])

    # The same species code for every region, chunked with the dataset's chunk_size attribute.
    species_col = xr.DataArray(
        np.repeat(species_codes[s], adatas[s].sizes['var']), dims='var'
    )
    adatas[s]['var-_-species'] = species_col.chunk({'var': adatas[s].attrs['chunk_size']})

    # Append the new variables to the store, then reopen it from disk.
    adatas[s].to_zarr(adatas[s].encoding['source'], mode='a')
    adatas[s] = grandata.grandata.GRAnData.open_zarr(zarr_path)




human


  adata.to_zarr(str(backed_path), mode="w")
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
writing X: 100%|██████████| 556/556 [15:27<00:00,  1.67s/it]
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
100%|██████████| 2779992/2779992 [04:51<00:00, 9537.48it/s] 


<xarray.DataArray 'sequences' (var: 2779992, seq_len: 2124, nuc: 4)> Size: 24GB
dask.array<xarray-<this-array>, shape=(2779992, 2124, 4), dtype=uint8, chunksize=(256, 2124, 4), chunktype=numpy.ndarray>
Coordinates:
  * var      (var) object 22MB 'chr1:9514-11638' ... 'KI270713.1:38053-40177'
Dimensions without coordinates: seq_len, nuc


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)


macaque


  adata.to_zarr(str(backed_path), mode="w")
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
writing X: 100%|██████████| 539/539 [21:25<00:00,  2.39s/it]
  return cls(**configuration_parsed)
  return cls(**configuration

<xarray.DataArray 'sequences' (var: 2692512, seq_len: 2124, nuc: 4)> Size: 23GB
dask.array<xarray-<this-array>, shape=(2692512, 2124, 4), dtype=uint8, chunksize=(256, 2124, 4), chunktype=numpy.ndarray>
Coordinates:
  * var      (var) object 22MB 'NC_041754.1:1-2125' ... 'NC_005943.1:13742-15...
Dimensions without coordinates: seq_len, nuc


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)


mouse


  adata.to_zarr(str(backed_path), mode="w")
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
writing X: 100%|██████████| 502/502 [16:47<00:00,  2.01s/it]
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
100%|██████████| 2508124/250

<xarray.DataArray 'sequences' (var: 2508124, seq_len: 2124, nuc: 4)> Size: 21GB
dask.array<xarray-<this-array>, shape=(2508124, 2124, 4), dtype=uint8, chunksize=(256, 2124, 4), chunktype=numpy.ndarray>
Coordinates:
  * var      (var) object 20MB 'chr1:2999767-3001891' ... 'JH584292.1:12685-1...
Dimensions without coordinates: seq_len, nuc


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)


In [None]:
#you can clip the number of bins using GRAnData.isel along the seq_bins dimension (to only focus on the center of the region, etc)

In [None]:
# #Alternate workflow to write icechunk stores directly, but this is ~5x slower (better to write pure zarr3, then convert the whole store at once)
# adatas = {}

# for s in species:
#     print(s)
#     bigwigs_dir = os.path.join('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',s,'Group_bigwig')
#     adatas[s] = grandata.chrom_io.import_bigwigs(
#         bgu=bigwigs_dir,
#         regions_file=bed_files[s],
#         backed_path='/home/matthew.schmitz/Matthew/data/test_grandata/'+s+'_spc_test.icechunk',
#         target_region_width=WINDOW_SIZE,
#         chromsizes_file=chromsizes_files[s],
#         target = 'raw',
#         max_stochastic_shift=5,
#         chunk_size=512,
#         backend='icechunk',
#         n_bins=n_bins
#     )
#     bed = adatas[s].get_dataframe('var').loc[:,['chrom','start','end']]
#     adatas[s] = grandata.seq_io.add_genome_sequences_to_grandata(adatas[s], bed, genomes[s])
#     print(adatas[s]['sequences'])
#     adatas[s]['var-_-species'] = xr.DataArray(np.repeat(species_codes[s],adatas[s].sizes['var']),dims='var').chunk({'var':adatas[s].attrs['chunk_size']})
#     adatas[s].to_icechunk(mode='a',commit_name='add_genome_seqs')
#     # adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
#     adatas[s] = grandata.grandata.GRAnData.open_icechunk('/home/matthew.schmitz/Matthew/data/test_grandata/'+s+'_spc_test.icechunk')
    

In [None]:
# Apply the same train/val/test split settings to every species and append the result to its store.
split_kwargs = dict(strategy="region", val_size=0.1, test_size=0.1, random_state=42)
for s in adatas:
    grandata.train_val_test_split(adatas[s], **split_kwargs)
    adatas[s].to_zarr(adatas[s].encoding['source'], mode='a')
    # adatas[s].to_icechunk(mode='a',commit_name='train_val_test_split') #If you're using icechunk store

In [None]:
# Inspect one dataset. `s` is whatever value the most recently run `for s in ...` loop
# left behind, so which species is shown depends on cell execution order.
print(adatas[s])
print(adatas[s]['X'])
adatas[s]['sequences']

In [None]:
# Start from an empty dict and reopen each species' zarr store from disk.
adatas = {}

for s in tqdm(species):
    # adatas[s] = grandata.GRAnData.open_icechunk('/home/matthew.schmitz/Matthew/data/test_grandata/'+s+'_spc_test.icechunk',
    #                                             cache_config={'num_bytes_chunks':int(8e9)})#Cache 8Gb
    zarr_store = f'/home/matthew.schmitz/Matthew/data/test_grandata/{s}_spc_test.zarr'
    adatas[s] = grandata.grandata.GRAnData.open_zarr(zarr_store)

In [None]:
# Development aid: re-import grandata._module so edits to its source are picked up
# without restarting the kernel. Objects created before the reload keep their old classes.
grandata._module = importlib.reload(grandata._module)

In [None]:
# Sequence transform producing WINDOW_SIZE-long outputs with random reverse-complement
# and shifts of up to MAX_SHIFT (the stored sequences are WINDOW_SIZE + 2*MAX_SHIFT long).
transform = grandata.seq_io.DNATransform(out_len=WINDOW_SIZE, random_rc=True, max_shift=MAX_SHIFT)

# One data module over all species datasets.
# load_keys maps dataset variable -> batch key (later cells read batch['y'] and batch['sequence']).
meta_module = grandata.GRAnDataModule(
    adatas=list(adatas.values()),
    batch_size=48,
    load_keys={'X': 'y','sequences':'sequence','var-_-species':'species'},
    dnatransform=transform,
    shuffle_dims=['obs'],
    join='inner',
    num_workers=0
)

# Set up the module; note the stage string passed here is "train".
meta_module.setup("train")

# train_dataloader is read as an attribute (not called); iterate over a few batches from it.
meta_train_dl = meta_module.train_dataloader

print("\nIterating over a couple of training batches from GRAnDataModule:")
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    # stop after six batches (i = 0..5)
    if i == 5:
        break


In [None]:
import cProfile
import pstats

# Profile fetching the first batches and write the raw profile data to a file.
# cProfile.run executes the statement string in the __main__ namespace, so the
# statement relies on the notebook globals `tqdm` and `meta_train_dl`.
# tqdm wraps the dataloader itself (not the enumerate object) so the progress bar
# can see its length, matching the other batch loops in this notebook.
cProfile.run("""
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 10:
        break
""", "profile_output.prof")

# Load the profile data from the file using pstats and show the 50 entries
# with the largest cumulative time.
p = pstats.Stats("profile_output.prof")
p.strip_dirs().sort_stats("cumtime").print_stats(50)
# p.strip_dirs().sort_stats('cumtime').print_stats('grandata')
# p.strip_dirs().sort_stats('cumtime').print_stats('grandata')


In [None]:
# Call unify_convert_chunks on every species dataset with a per-species .icechunk path.
# NOTE(review): presumably this rewrites the zarr data into an icechunk store with unified
# chunking (cf. the "write pure zarr3 then convert" remark above) - confirm against grandata.
for s in tqdm(species):
    adatas[s].unify_convert_chunks('/home/matthew.schmitz/Matthew/data/test_grandata/'+s+'_spc_test.icechunk')

In [None]:
# Replace the zarr-backed datasets with ones opened from the icechunk stores.
adatas = {}
for s in tqdm(species):
    adatas[s] = grandata.GRAnData.open_icechunk('/home/matthew.schmitz/Matthew/data/test_grandata/'+s+'_spc_test.icechunk',
                                                cache_config={'num_bytes_chunks':int(8e9),'num_chunk_refs':5})# num_bytes_chunks = 8e9 bytes (~8 GB)
    # adatas[s] = grandata.grandata.GRAnData.open_zarr('/home/matthew.schmitz/Matthew/data/test_grandata/'+s+'_spc_test.zarr')

In [None]:
# Same data-module demo as for the zarr stores, now over whatever `adatas` currently holds
# (the icechunk-backed datasets if the previous cell was run).
transform = grandata.seq_io.DNATransform(out_len=WINDOW_SIZE, random_rc=True, max_shift=MAX_SHIFT)

# load_keys maps dataset variable -> batch key (later cells read batch['y'] and batch['sequence']).
meta_module = grandata.GRAnDataModule(
    adatas=list(adatas.values()),
    batch_size=48,
    load_keys={'X': 'y','sequences':'sequence','var-_-species':'species'},
    dnatransform=transform,
    shuffle_dims=['obs'],
    join='inner',
    num_workers=0
)

# Set up the module; note the stage string passed here is "train".
meta_module.setup("train")

# train_dataloader is read as an attribute (not called); iterate over a few batches from it.
meta_train_dl = meta_module.train_dataloader

print("\nIterating over a couple of training batches from GRAnDataModule:")
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    # stop after six batches (i = 0..5)
    if i == 5:
        break


In [None]:
import cProfile

# Statement to profile: fetch and describe the first six batches (i = 0..5).
# It is executed by cProfile in the __main__ namespace, so it relies on the
# notebook globals `tqdm` and `meta_train_dl`.
code = '''
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break
'''

# Report sorted by internal time. This was previously written sort=True, which pstats
# only accepts because True == 1, the legacy numeric code for "time".
# cProfile.run prints the report and returns None, so `out` is always None.
out = cProfile.run(code, sort="time")


In [None]:
# NOTE(review): presumably load() pulls the module's data into memory so batches are
# served faster - confirm against GRAnDataModule. The dataloader is fetched again afterwards.
meta_module.load()
meta_train_dl = meta_module.train_dataloader


In [None]:
# Imported here so the cell does not depend on an earlier profiling cell having run.
import cProfile

# Statement to profile: fetch and describe the first 51 batches (i = 0..50).
# It is executed by cProfile in the __main__ namespace, so it relies on the
# notebook globals `tqdm` and `meta_train_dl`.
code = '''
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 50:
        break
'''

# Report sorted by internal time. This was previously written sort=True, which pstats
# only accepts because True == 1, the legacy numeric code for "time".
# cProfile.run prints the report and returns None, so `out` is always None.
out = cProfile.run(code, sort="time")


In [None]:
# Build the model for the same window length the DNATransform emits (WINDOW_SIZE,
# previously hard-coded as 2114), with one output per column of the 'y' batch entry.
# `batch` is the last batch left over from an earlier dataloader loop.
# NOTE(review): `crested` is not imported in any cell visible here - confirm it is
# imported before this cell runs.
model_architecture = crested.tl.zoo.simple_convnet(
    seq_len=WINDOW_SIZE, num_classes=batch['y'].shape[1]
)


In [None]:
import keras
# Create your own configuration
optimizer = keras.optimizers.Adam(learning_rate=1e-5)

# Poisson loss is the loss actually used. The CosineMSELogLoss that used to be built
# just before it was overwritten immediately (a dead assignment); it is kept here as
# the suggested alternative to try for peak regression:
# loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
loss = crested.tl.losses.PoissonLoss()

metrics = [
    keras.metrics.MeanAbsoluteError(),
    # keras.metrics.MeanSquaredError(),
    # keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    # crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    # crested.tl.metrics.PearsonCorrelationLog(),
    # crested.tl.metrics.ZeroPenaltyMetric(),
]

# Bundle optimizer, loss and metrics for the trainer.
alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)


In [None]:
# Shape of the 'sequence' entry of the last batch left over from a dataloader loop.
batch['sequence'].shape

In [None]:
# initialize some lazy model parameters *yawn*
model_architecture(batch['sequence'].float().mean(0).unsqueeze(0))

In [None]:
# Wire the data module, model and task config into a trainer.
# NOTE(review): project_name is "mouse_biccn" although the data module above is built
# from human, macaque and mouse datasets - confirm the name is intended.
trainer = crested.tl.Crested(
    data=meta_module,
    model=model_architecture,
    config=alternative_config,
    project_name="mouse_biccn",  # change to your liking
    run_name="basemodel",  # change to your liking
    logger=None,  # no logger; the original note lists 'dvc' and 'tensorboard' as alternatives
    seed=7,  # For reproducibility
)
# train the model
trainer.fit(
    epochs=60,
    learning_rate_reduce_patience=3,
    early_stopping_patience=6,
)


In [None]:
import icechunk as ic
# Open the mouse icechunk store directly with xarray, bypassing GRAnData.
store_path = '/home/matthew.schmitz/Matthew/data/test_grandata/mouse_spc_test.icechunk'
storage_config = ic.local_filesystem_storage(store_path)
config = ic.RepositoryConfig.default()
# chunk cache sized to 8e9 bytes
config.caching = ic.CachingConfig(num_bytes_chunks=int(8e9))
repo = ic.Repository.open(storage_config, config)
# read-only view of the "main" branch
session = repo.readonly_session("main")
ds = xr.open_zarr(session.store, consolidated=False)


In [None]:
# Print every key of the dataset followed by that variable's chunk layout.
for k in ds.keys():
    print(k)
    print(ds[k].chunks)


In [None]:
# `time` is otherwise only imported by a later cell, so import it here to keep
# this cell runnable on a fresh kernel.
import time

# Read the same random 100-row window of 'X' twice and time both reads, to see
# whether the second read benefits from caching.
# Dataset.sizes is the mapping from dimension name to length.
start_idx = np.random.randint(0, ds.sizes['var'] - 100)
print("Random access start index:", start_idx)

# perf_counter is a monotonic, high-resolution clock intended for timing.
t0 = time.perf_counter()
subset1 = ds.isel(var=slice(start_idx, start_idx + 100))['X'].values
print("First access time: {:.4f} sec".format(time.perf_counter() - t0))

t0 = time.perf_counter()
subset2 = ds.isel(var=slice(start_idx, start_idx + 100))['X'].values
print("Second access time: {:.4f} sec".format(time.perf_counter() - t0))

print("Done.")


In [None]:
# `time` is otherwise only imported by a later cell, so import it here to keep
# this cell runnable on a fresh kernel.
import time

# Same experiment as the previous cell, materialising via np.array(...).
# Both reads now use the identical expression so the two timings are comparable
# (the second read used to add an extra `.values`).
# Dataset.sizes is the mapping from dimension name to length.
start_idx = np.random.randint(0, ds.sizes['var'] - 100)
print("Random access start index:", start_idx)

# perf_counter is a monotonic, high-resolution clock intended for timing.
t0 = time.perf_counter()
subset1 = np.array(ds.isel(var=slice(start_idx, start_idx + 100))['X'])
print("First access time: {:.4f} sec".format(time.perf_counter() - t0))

t0 = time.perf_counter()
subset2 = np.array(ds.isel(var=slice(start_idx, start_idx + 100))['X'])
print("Second access time: {:.4f} sec".format(time.perf_counter() - t0))

print("Done.")


In [None]:
%%time
# Time materialising 1000 consecutive 'var' rows of 'X' selected with an integer index array.
# NOTE(review): `start` is not defined in any cell visible here (the cells above define
# `start_idx`, bounded for a 100-row window only) - confirm where `start` comes from.
np.array(ds.isel({'var':np.arange(start,start+1000)})['X'])
print('done')

In [None]:
%%time
# Repeat of the previous timing cell with the same `start`, to compare a second read.
# NOTE(review): `start` is not defined in any cell visible here - confirm where it comes from.
np.array(ds.isel({'var':np.arange(start,start+1000)})['X'])
print('done')

In [None]:
# Display the selection itself (without converting it to a NumPy array).
# NOTE(review): `start` is not defined in any cell visible here - confirm where it comes from.
ds.isel({'var':np.arange(start,start+1000)})['X']

In [None]:
import tempfile
import time
import numpy as np
import xarray as xr
import icechunk as ic
from icechunk.xarray import to_icechunk

# Self-contained Icechunk caching demo in a temporary directory.
# NOTE: this rebinds the notebook-level name `ds` to the small demo dataset.
with tempfile.TemporaryDirectory() as tmpdir:
    store_path = f"{tmpdir}/example.icechunk"

    # Set up local storage and a repository whose chunk cache is 1024 * 1024 bytes (1 MiB)
    storage_config = ic.local_filesystem_storage(store_path)
    config = ic.RepositoryConfig.default()
    config.caching = ic.CachingConfig(num_bytes_chunks=1024 * 1024)
    repo = ic.Repository.create(storage_config, config)

    # Create a simple xarray dataset with a 'var' dimension and chunk it along 'var'.
    # 'X' has shape (var=100000, y=20) and is split into chunks of 100 rows along 'var'.
    data = np.random.rand(100000, 20)
    ds = xr.Dataset({'X': (('var', 'y'), data)})
    ds = ds.chunk({'var': 100, 'y': 20})

    # Write the dataset to the Icechunk store using a writable session
    session = repo.writable_session("main")
    to_icechunk(ds, session)
    commit_hash = session.commit("initial commit")
    print("Committed with hash:", commit_hash)

    # Read the dataset back from the store using a read-only session
    session = repo.readonly_session("main")
    ds2 = xr.open_zarr(session.store, consolidated=False)
    print("Dataset dimensions:", ds2.dims)

    # Test the caching behavior by timing two consecutive reads of the same window.
    # Dataset.sizes is the mapping from dimension name to length.
    var_dim = ds2.sizes['var']
    # Ensure we have 100 contiguous indices available (avoid overflow)
    start_idx = np.random.randint(0, var_dim - 100)
    print("Random access start index:", start_idx)

    # perf_counter is a monotonic, high-resolution clock intended for timing.
    t0 = time.perf_counter()
    subset1 = ds2.isel(var=slice(start_idx, start_idx + 100))['X'].values
    print("First access time: {:.4f} sec".format(time.perf_counter() - t0))

    t0 = time.perf_counter()
    subset2 = ds2.isel(var=slice(start_idx, start_idx + 100))['X'].values
    print("Second access time: {:.4f} sec".format(time.perf_counter() - t0))

    print("Done.")


In [None]:
# Number of bytes in 1 MiB (the value passed as num_bytes_chunks in the demo above).
1024 * 1024

In [None]:
import h5py
import numpy as np
import pandas as pd
import xarray as xr
from scipy.sparse import csr_matrix
import sparse
from pathlib import Path
from typing import Union, Literal, List
from grandata import GRAnData

def read_h5ad_selective_to_grandata(
    filename: Union[str, Path],
    mode: Literal["r", "r+"] = "r",
    selected_fields: Union[List[str], None] = None,
) -> GRAnData:
    """
    Read selected AnnData fields from an .h5ad file into a GRAnData.

    Only the HDF5 nodes named in ``selected_fields`` (plus their ancestors and
    descendants) are read via h5py. ``obs``/``var`` are unpacked into one
    ``"<axis>-_-<column>"`` variable per column (plus ``"<axis>-_-index"``), so
    no DataFrame is ever passed to ``GRAnData.__init__``. ``X`` and the members
    of ``layers``/``obsm``/``varm``/``obsp`` become DataArrays; groups holding
    ``data``/``indices``/``indptr`` are rebuilt as sparse matrices and wrapped
    in ``sparse.COO``.

    Parameters
    ----------
    filename
        Path to the .h5ad (HDF5) file.
    mode
        Mode passed to ``h5py.File``.
    selected_fields
        Names of HDF5 nodes to keep. Defaults to ``["X", "obs", "var"]``.
        A name is matched anywhere in the file tree, not only at top level.

    Returns
    -------
    GRAnData
        Built from the collected ``data_vars`` and ``coords``.

    Raises
    ------
    ValueError
        If the length of ``obs``/``var`` cannot be determined from its index,
        its plain columns, or its categorical codes.
    """
    selected_fields = selected_fields or ["X", "obs", "var"]

    sparse_keys = {"data", "indices", "indptr"}
    categorical_keys = {"categories", "codes"}

    # ————— Helpers ————————————————————————————————————————————————

    def h5_tree(g):
        """Nested dict mirroring the HDF5 layout (leaf = dataset length or "scalar")."""
        out = {}
        for k, v in g.items():
            if isinstance(v, h5py.Group):
                out[k] = h5_tree(v)
            else:
                try:
                    out[k] = len(v)
                except TypeError:
                    out[k] = "scalar"
        return out

    def dict_to_ete3_tree(d, parent=None):
        """Convert the nested dict into an ete3 Tree (node names = dict keys)."""
        from ete3 import Tree
        if parent is None:
            parent = Tree(name="root")
        for k, v in d.items():
            c = parent.add_child(name=k)
            if isinstance(v, dict):
                dict_to_ete3_tree(v, c)
        return parent

    def ete3_tree_to_dict(t):
        """Inverse of dict_to_ete3_tree; leaves map to their own name."""
        def helper(n):
            if n.is_leaf():
                return n.name
            return {c.name: helper(c) for c in n.get_children()}
        return {c.name: helper(c) for c in t.get_children()}

    def prune_tree(tree, keep_keys):
        """Keep only nodes named in keep_keys, their ancestors and descendants."""
        t = dict_to_ete3_tree(tree)
        keep = set()
        for key in keep_keys:
            for node in t.search_nodes(name=key):
                keep.update(node.iter_ancestors())
                keep.update(node.iter_descendants())
                keep.add(node)
        for n in t.traverse("postorder"):
            if n not in keep and n.up:
                n.detach()
        return ete3_tree_to_dict(t)

    def decode_text(arr):
        """Decode byte strings to ``str`` (UTF-8); leave every other dtype alone."""
        if arr.dtype.kind == "S":
            # astype(str) would assume ASCII and fail on UTF-8 payloads.
            return np.char.decode(arr, "utf-8")
        if arr.dtype.kind == "O":
            # h5py returns variable-length strings as object arrays of bytes.
            out = arr.copy()
            flat = out.reshape(-1)  # view onto the contiguous copy
            for i, item in enumerate(flat):
                if isinstance(item, bytes):
                    flat[i] = item.decode("utf-8")
            return out
        return arr

    def read_h5_to_dict(group, subtree):
        """Read the datasets named in ``subtree`` from ``group`` into nested dicts."""
        def helper(grp, sub):
            out = {}
            for k, v in sub.items():
                if isinstance(v, dict):
                    out[k] = helper(grp[k], v) if k in grp else None
                elif k in grp and isinstance(grp[k], h5py.Dataset):
                    ds = grp[k]
                    if ds.shape == ():
                        scalar = ds[()]
                        if isinstance(scalar, bytes):
                            scalar = scalar.decode("utf-8")
                        out[k] = scalar
                    else:
                        out[k] = decode_text(ds[...])
                else:
                    out[k] = None
            return out
        return helper(group, subtree)

    def convert_to_dataframe(d: dict, length=None) -> pd.DataFrame:
        """Build a DataFrame from the columns of ``d`` that have ``length`` rows."""
        if length is None:
            # Infer from the first 1-D column or categorical code vector;
            # None placeholders and scalars carry no length information.
            for v in d.values():
                if isinstance(v, dict):
                    codes = v.get("codes")
                    if codes is not None and np.ndim(codes) == 1:
                        length = len(codes)
                        break
                elif v is not None and np.ndim(v) >= 1:
                    length = len(v)
                    break
        if length is None:
            raise ValueError("Cannot infer obs/var length")
        cols = {}
        for k, v in d.items():
            if isinstance(v, dict) and categorical_keys <= set(v):
                if v["codes"] is None or v["categories"] is None:
                    continue
                codes = np.asarray(v["codes"], int)
                cats = [c.decode() if isinstance(c, bytes) else c for c in v["categories"]]
                if len(codes) == length:
                    cols[k] = pd.Categorical.from_codes(codes, cats)
            elif isinstance(v, dict) and sparse_keys <= set(v):
                shape = tuple(v.get("shape", (length, max(v["indices"]) + 1)))
                cols[k] = csr_matrix((v["data"], v["indices"], v["indptr"]), shape=shape)
            elif not isinstance(v, dict):
                arr = np.asarray(v)
                if arr.ndim == 1 and arr.shape[0] == length:
                    cols[k] = decode_text(arr)
        return pd.DataFrame(cols)

    def to_array(val, shape=None):
        """Dense ndarray, or sparse.COO for a data/indices/indptr group.

        NOTE(review): the group is interpreted as CSR, as in the original code;
        a CSC-encoded matrix would need csc_matrix instead — confirm against
        the files being read.
        """
        if isinstance(val, dict) and sparse_keys <= set(val):
            if val.get("shape") is not None:
                shape = tuple(int(s) for s in val["shape"])
            # Without an explicit shape scipy infers the column count from
            # max(indices), which drops trailing all-zero columns.
            mat = csr_matrix((val["data"], val["indices"], val["indptr"]), shape=shape)
            return sparse.COO.from_scipy_sparse(mat)
        return np.asarray(val)

    # ————— Read HDF5 and prune ——————————————————————————————————

    with h5py.File(filename, mode) as f:
        full_tree = h5_tree(f)
        pruned = prune_tree(full_tree, selected_fields)
        raw = read_h5_to_dict(f, pruned)

    data_vars = {}
    coords = {}

    # — obs / var: unpack into coords + "<axis>-_-<col>" variables —————————
    for axis in ("obs", "var"):
        table = raw.get(axis)
        if not isinstance(table, dict):
            continue
        idx = table.pop("_index", None)
        # The index, when present, is the authoritative row count; this also
        # covers tables that contain only categorical columns.
        frame = convert_to_dataframe(table, length=None if idx is None else len(idx))
        if idx is not None:
            frame.index = [x.decode("utf-8") if isinstance(x, bytes) else str(x) for x in idx]
        coords[axis] = frame.index.to_numpy()

        for col in frame.columns:
            data_vars[f"{axis}-_-{col}"] = xr.DataArray(
                frame[col].values,
                dims=(axis,),
                coords={axis: coords[axis]},
            )
        # also store the index itself as a variable
        data_vars[f"{axis}-_-index"] = xr.DataArray(coords[axis], dims=(axis,))

    n_obs = len(coords["obs"]) if "obs" in coords else None
    n_var = len(coords["var"]) if "var" in coords else None
    obs_by_var = (n_obs, n_var) if n_obs is not None and n_var is not None else None

    # — X matrix ——————————————————————————————————————————————————
    if raw.get("X") is not None:
        data_vars["X"] = xr.DataArray(
            to_array(raw["X"], shape=obs_by_var), dims=("obs", "var"), coords=coords
        )

    # — layers/obsm/varm/obsp ——————————————————————————————————————————
    for grp in ("layers", "obsm", "varm", "obsp"):
        members = raw.get(grp)
        if not isinstance(members, dict):
            continue
        for name, val in members.items():
            if val is None:
                continue

            if grp == "layers":
                arr = to_array(val, shape=obs_by_var)
                dims, c = ("obs", "var"), coords
            elif grp == "obsp":
                arr = to_array(val, shape=None if n_obs is None else (n_obs, n_obs))
                d2 = f"obsp_{name}"
                dims = ("obs", d2)
                c = {"obs": coords["obs"], d2: coords["obs"]} if "obs" in coords else {}
            else:  # obsm / varm: second axis is a plain positional range
                arr = to_array(val)
                axis = "obs" if grp == "obsm" else "var"
                d2 = f"{grp}_{name}"
                dims = (axis, d2)
                c = {d2: np.arange(arr.shape[1])}
                if axis in coords:
                    c[axis] = coords[axis]

            data_vars[f"{grp}-_-{name}"] = xr.DataArray(arr, dims=dims, coords=c)

    # ——— Finally, build and return GRAnData ——————————————————————
    return GRAnData(data_vars=data_vars, coords=coords)


In [None]:
# Load only X, obs, var and any node named "UMIs" from the h5ad file into a GRAnData.
# NOTE(review): the path is an absolute, site-specific location, so this cell only
# runs where that filesystem is mounted; consider a configurable data directory.
# NOTE(review): presumably "UMIs" is a layer name in this file — confirm; the reader
# only converts members of layers/obsm/varm/obsp besides X/obs/var.
ds=read_h5ad_selective_to_grandata("/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/data/testgenesets/siletti300k_highly_variable.h5ad",selected_fields=["X","obs","var","UMIs"])

In [None]:
# Display the repr of the GRAnData object loaded in the previous cell.
ds

In [None]:
import numpy as np
import xarray as xr
from itertools import product
from typing import Union, List, Tuple, Dict

def group_aggr_xr(
    ds: xr.Dataset,
    array_name: str,
    categories: Union[str, List[str]],
    agg_func=np.mean,
    normalize: bool = False,
) -> Tuple[np.ndarray, Dict[str, List[str]]]:
    """
    Group–aggregate an xarray.Dataset along 'obs' by one or more categorical
    obs columns, using xarray.groupby on the specified data array.

    Parameters
    ----------
    ds
        An xarray.Dataset (e.g. GRAnData) containing:
          - a DataArray `ds[array_name]` whose first two dims are the
            observation dim and the variable dim, in that order,
          - one or more obs columns named "obs-_-<cat>".
    array_name
        Name of the DataArray in `ds` to aggregate (e.g. "X", "layers-_-counts", "obsp-_-contacts").
    categories
        Single category name or list of names (the <cat> in "obs-_-<cat>").
    agg_func
        Aggregation function (e.g. np.mean, np.median, np.std).
    normalize
        If True, each observation is normalized by its row‑sum before grouping.

    Returns
    -------
    result : np.ndarray
        Aggregated values, shape (*category_sizes, num_vars). Row ``i`` along
        each category axis corresponds to ``category_orders[cat][i]``. With
        several categories, combinations that never occur in the data are
        filled with NaN (the result is promoted to float if necessary).
    category_orders : dict
        Maps each category name → list of its observed levels (in first‑appearance order).

    Raises
    ------
    ValueError
        If `categories` is empty.
    """
    # — normalize categories list —
    if isinstance(categories, str):
        categories = [categories]
    categories = list(categories)
    if not categories:
        raise ValueError("Must supply at least one category name")

    # — pick the DataArray and its dims —
    da = ds[array_name]
    obs_dim, var_dim = da.dims[:2]

    # — collect category arrays & orders —
    category_orders: Dict[str, List[str]] = {}
    label_arrays: List[np.ndarray] = []
    for cat in categories:
        labels = ds[f"obs-_-{cat}"].values.astype(str)
        # dict preserves insertion order → first‑appearance order of the levels
        category_orders[cat] = list(dict.fromkeys(labels.tolist()))
        label_arrays.append(labels)

    # — build grouping coordinate —
    # NOTE: with several categories the levels are joined with "__"; levels that
    # themselves contain "__" could make two different combinations collide.
    sep = "__"
    combo = label_arrays[0]
    for labels in label_arrays[1:]:
        combo = np.char.add(np.char.add(combo, sep), labels)
    group_dim = sep.join(categories)
    # Attach as a non-index coordinate along obs; this does not require the
    # dataset to carry an explicit coordinate for obs_dim.
    da = da.assign_coords({group_dim: (obs_dim, combo)})

    # — optional normalize each row by its sum —
    if normalize:
        # xarray broadcasts by dimension name, so the per-observation sum must
        # NOT keep a size-1 var dim (that would clash with the full-size one).
        da = da / da.sum(dim=var_dim)

    # — groupby & reduce —
    grouped = da.groupby(group_dim).reduce(agg_func, dim=obs_dim)
    # Make sure the group axis is axis 0 before positional indexing below.
    grouped = grouped.transpose(group_dim, ...)

    # — extract the raw data, densifying if needed —
    raw = grouped.data
    if hasattr(raw, "todense"):
        arr = raw.todense()
    elif hasattr(raw, "toarray"):
        arr = raw.toarray()
    else:
        arr = np.asarray(raw)

    # Group labels in the order groupby produced them (row i of `arr`).
    observed = grouped[group_dim].values.astype(str).tolist()

    # — reorder and reshape into (*category_sizes, n_vars) —
    if len(categories) == 1:
        # Row i of `arr` belongs to observed[i]; pick rows in first-appearance order.
        row_of = {label: i for i, label in enumerate(observed)}
        order = [row_of[level] for level in category_orders[categories[0]]]
        result = arr[order]
    else:
        level_lists = [category_orders[c] for c in categories]
        sizes = [len(levels) for levels in level_lists]
        slot_of = {sep.join(c): i for i, c in enumerate(product(*level_lists))}
        n_slots = int(np.prod(sizes))
        target = np.array([slot_of[label] for label in observed], dtype=int)

        trailing = arr.shape[1:]
        if len(observed) < n_slots:
            # Some combinations never occur: leave them as NaN.
            if np.issubdtype(arr.dtype, np.inexact):
                fill_dtype = arr.dtype
            else:
                fill_dtype = np.result_type(arr.dtype, np.float64)
            flat = np.full((n_slots,) + trailing, np.nan, dtype=fill_dtype)
        else:
            flat = np.empty((n_slots,) + trailing, dtype=arr.dtype)
        # Scatter each aggregated row into the slot of its combination.
        flat[target] = arr
        result = flat.reshape(*sizes, *trailing)

    return result, category_orders


In [None]:
# Aggregate `ds` per level of the 'dataset' obs column.
# NOTE(review): the cell above defines `group_aggr_xr(ds, array_name, categories, ...)`;
# `group_aggr_grandata_xr` and its `category_column_names` keyword are not defined there —
# confirm this name exists elsewhere in the notebook or update the call.
group_aggr_grandata_xr(ds,category_column_names=['dataset'])