In [None]:
import os
import tempfile
from pathlib import Path
import pandas as pd
import numpy as np
import pyBigWig
import copy
import xarray as xr
import tqdm

# Import our new module system and utilities.
import crandata
from crandata import CrAnDataModule, MetaCrAnDataModule, CrAnData
from crandata.chrom_io import import_bigwigs
from crandata.seq_io import add_genome_sequences_to_crandata, DNATransform

# Scratch workspace: every synthetic input lives under one temporary directory.
temp_dir = tempfile.TemporaryDirectory()
base_dir = Path(temp_dir.name)
beds_dir, bigwigs_dir = base_dir / "beds", base_dir / "bigwigs"
for _subdir in (beds_dir, bigwigs_dir):
    _subdir.mkdir(exist_ok=True)

# Chromosome sizes: a single chromosome, chr1, of length 1000.
chromsizes_file = base_dir / "chrom.sizes"
with open(chromsizes_file, "w") as f:
    f.write("chr1\t1000\n")


def _bed_frame(starts, ends):
    """Build a three-column (chrom, start, end) frame with every row on chr1.

    Columns are labelled 0, 1, 2 because the frames are written without a header.
    """
    return pd.DataFrame({0: ["chr1"] * len(starts), 1: starts, 2: ends})


def _write_bed(frame, path):
    """Write ``frame`` to ``path`` as tab-separated text, no header, no index."""
    frame.to_csv(path, sep="\t", header=False, index=False)


# Two per-class BED files (simulating two different classes).
bed_data_A = _bed_frame([100, 300], [200, 400])
bed_data_B = _bed_frame([150, 350], [250, 450])
bed_file_A = beds_dir / "ClassA.bed"
bed_file_B = beds_dir / "ClassB.bed"
_write_bed(bed_data_A, bed_file_A)
_write_bed(bed_data_B, bed_file_B)

# Consensus regions: both ClassA intervals plus the second ClassB interval.
consensus = _bed_frame([100, 300, 350], [200, 400, 450])
consensus_file = base_dir / "consensus.bed"
_write_bed(consensus, consensus_file)

def _write_constant_bigwig(path, value):
    """Write a bigWig with one interval chr1:0-1000 carrying ``value``.

    Returns the (already closed) pyBigWig handle so callers can keep a reference.
    """
    handle = pyBigWig.open(str(path), "w")
    handle.addHeader([("chr1", 1000)])
    handle.addEntries(chroms=["chr1"], starts=[0], ends=[1000], values=[value])
    handle.close()
    return handle


# Two constant-signal tracks that differ only in their value (5.0 vs 4.0).
bigwig_file1 = bigwigs_dir / "test.bw"
bw1 = _write_constant_bigwig(bigwig_file1, 5.0)

bigwig_file2 = bigwigs_dir / "test2.bw"
bw2 = _write_constant_bigwig(bigwig_file2, 4.0)

# Set extraction parameters.
# 100 equals the width of every consensus interval written above (end - start).
target_region_width = 100
# Zarr store inside the temporary directory; it is written by import_bigwigs
# and appended to again further down.
backed_path = base_dir / "chrom_data.zarr"

# Create the CrAnData object from bigWig files and consensus regions.
adata = import_bigwigs(
    bigwigs_folder=str(bigwigs_dir),
    regions_file=str(consensus_file),
    backed_path=str(backed_path),
    target_region_width=target_region_width,
    chromsizes_file=str(chromsizes_file),
)

# The return value is discarded, so the split is presumably recorded on `adata`
# in place — TODO confirm. The "var-_-split" variable is overwritten with
# "train" for every region on the copies made later in this cell.
crandata.train_val_test_split(adata,strategy='chr_auto')

# Create a dummy FASTA file for a genome.
# One record, chr1, of 1000 "A" bases — the same length as in chrom.sizes.
fasta_file = base_dir / "chr1.fa"
with open(fasta_file, "w") as f:
    f.write(">chr1\n")
    f.write("A" * 1000 + "\n")

# Create a Genome object.
# NOTE(review): mid-cell import; a later cell reaches the class as
# crandata.Genome, so this could move to the import block at the top.
from crandata._genome import Genome
dummy_genome = Genome(str(fasta_file), chrom_sizes=str(chromsizes_file))

# Add sequences to the CrAnData using the provided seq_io utility.
# Here we use the consensus regions as our ranges.
# Renames the columns of `consensus` in place (previously 0, 1, 2).
consensus.columns = ['chrom', 'start', 'end']
adata = add_genome_sequences_to_crandata(adata, consensus, dummy_genome)

# Write the CrAnData object to disk and then reload it to ensure sequences are out-of-memory.
# mode='a' targets the same store that import_bigwigs was given as backed_path.
adata.to_zarr(str(backed_path),mode='a')
adata_loaded = CrAnData.open_zarr(str(backed_path))
print("Loaded CrAnData:")
print(adata_loaded)

# Create two copies to simulate two datasets (e.g. two species) and label every
# region "train" so that both contribute to the training dataloader.
adata1 = copy.deepcopy(adata_loaded)
adata2 = copy.deepcopy(adata_loaded)
adata1["var-_-split"] = xr.DataArray(np.full(adata1.sizes["var"], "train"), dims=["var"])
adata2["var-_-split"] = xr.DataArray(np.full(adata2.sizes["var"], "train"), dims=["var"])

# Sequence augmentation: 80 bp output, random reverse complement, up to 5 bp shift.
transform = DNATransform(out_len=80, random_rc=True, max_shift=5)

# Instantiate the MetaCrAnDataModule with the two datasets.
meta_module = MetaCrAnDataModule(
    adatas=[adata1, adata2],
    batch_size=[2, 2],  # one entry per dataset; each stays below the 3 consensus regions
    load_keys={'sequences':'sequences','X':'X'},
    shuffle=True,
    dnatransform=transform,
    epoch_size=10
)

# Everything below reads from the temporary directory, so it is removed in a
# `finally` clause: a failure while iterating must not leave the files behind.
try:
    # Setup each underlying module for the "train" stage.
    for mod in meta_module.modules:
        mod.setup(state="train")

    # Retrieve the training dataloader from the meta module and look at two batches.
    meta_train_dl = meta_module.train_dataloader
    print("\nIterating over a couple of training batches from MetaCrAnDataModule:")
    for i, batch in enumerate(tqdm.tqdm(meta_train_dl)):
        # Only shapes are reported; printing the raw batch would flood the output.
        print(f"\nMeta Batch {i}:")
        for key, tensor in batch.items():
            print(f"  {key}: shape {tensor.shape}")
        if i >= 1:
            break

    print("\nTemporary directory contents:")
    print(os.listdir(base_dir))
finally:
    temp_dir.cleanup()


In [None]:
# NOTE(review): `sdfs` is not defined anywhere in the visible notebook, so this
# cell raises NameError — presumably a deliberate "Run All" stop; confirm or delete.
sdfs

In [None]:
# TODO: Should the fill value in _extract_values_from_bigwig actually be 0? Can we filter out var entries that are all 0/NaN without loading everything into memory?

In [None]:
import crandata
import xarray as xr
import pandas as pd
import numpy as np
import os
import crested
from tqdm import tqdm

In [None]:
# Per-species containers, keyed by the names in `species`; filled by later cells.
genomes = {}
beds = {}
chromsizes_files = {}
bed_files = {}
species = ['mouse','human','macaque']

# Maximum shift (bp) handed to DNATransform.
MAX_SHIFT = 5
# Region width in bp. Add 2 * MAX_SHIFT here if stored regions ever need
# padding for shifting (the original left that term commented out).
WINDOW_SIZE = 2114
# Step between consecutive genome bins, e.g. 50% overlap.
OFFSET = WINDOW_SIZE // 2
# Passed to bin_genome as n_threshold.
N_THRESHOLD = 0.3
# Width used to derive the number of signal bins per region.
BIN_WIDTH = 50
n_bins = WINDOW_SIZE // BIN_WIDTH


In [None]:
# Root folder holding one sub-directory per species (FASTA, sizes, GTF).
GENOME_ROOT = '/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/genome/onehots/'

for s in species:
    genome_path = GENOME_ROOT + s
    # All three inputs share the species name as their file stem.
    fasta_file, chrom_sizes, annotation_gtf_file = (
        os.path.join(genome_path, s + suffix)
        for suffix in ('.fa', '.fa.sizes', '.annotation.gtf')
    )
    chromsizes_files[s] = chrom_sizes

    # Build the genome, call to_memory() on it, and register it.
    genome = crandata.Genome(fasta_file, chrom_sizes, annotation_gtf_file)
    genome.to_memory()
    genomes[s] = genome

    # Bin the genome; bin_genome receives OUTPUT_BED as its output_path.
    OUTPUT_BED = os.path.join(genome_path, "binned_genome.bed")
    bed_files[s] = OUTPUT_BED
    binned_df = crandata.bin_genome(
        genome,
        WINDOW_SIZE,
        OFFSET,
        n_threshold=N_THRESHOLD,
        output_path=OUTPUT_BED,
    )
    binned_df = binned_df.reset_index(drop=True)
    print("Filtered bins:")
    print(binned_df)


In [None]:
# s='mouse'
# adatas = {}

# bigwigs_dir = os.path.join('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',s,'Group_bigwig')
# adatas[s] = crandata.chrom_io.import_bigwigs(
#     bigwigs_folder=bigwigs_dir,
#     regions_file='/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/genome/onehots/mouse/binned_genome_test.bed',
#     backed_path='/home/matthew.schmitz/Matthew/'+s+'_spc_test.zarr',
#     target_region_width=WINDOW_SIZE,
#     chromsizes_file=chromsizes_files[s],
#     target = 'raw',
#     max_stochastic_shift=5,
#     chunk_size=2048,
#     n_bins=n_bins
# )
# bed = adatas[s].get_dataframe('var').loc[:,['chrom','start','end']]
# adatas[s] = crandata.seq_io.add_genome_sequences_to_crandata(adatas[s], bed, genomes[s])
# print(adatas[s]['sequences'])
# adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')


In [None]:
# Build one CrAnData per species from its bigWig folder, attach genome
# sequences, persist to zarr, then reopen from the store.
adatas = {}

for s in species:
    bigwigs_dir = os.path.join('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',s,'Group_bigwig')
    # Built once so the store that is written and the store that is reopened
    # cannot drift apart.
    zarr_path = '/home/matthew.schmitz/Matthew/'+s+'_spc_test.zarr'
    adatas[s] = crandata.chrom_io.import_bigwigs(
        bigwigs_folder=bigwigs_dir,
        regions_file=bed_files[s],
        backed_path=zarr_path,
        target_region_width=WINDOW_SIZE,
        chromsizes_file=chromsizes_files[s],
        target = 'raw',
        # Same constant DNATransform receives as max_shift, so the two stay in sync.
        max_stochastic_shift=MAX_SHIFT,
        chunk_size=1024,
        n_bins=n_bins
    )
    # Regions as stored in var metadata drive sequence extraction.
    bed = adatas[s].get_dataframe('var').loc[:,['chrom','start','end']]
    adatas[s] = crandata.seq_io.add_genome_sequences_to_crandata(adatas[s], bed, genomes[s])
    print(adatas[s]['sequences'])
    adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
    # Alternative path kept for reference: reopen an existing store and rebuild
    # start/end plus sequences instead of re-importing the bigWigs.
    # adatas[s] = crandata.crandata.CrAnData.open_zarr('/home/matthew.schmitz/Matthew/'+s+'_spc_test.zarr')
    # # beware zarr issue with saving standard int (set to np.[u]int64)
    # adatas[s] = adatas[s].drop_vars('sequences')
    # adatas[s]["var-_-start"] = xr.DataArray(adatas[s].get_dataframe('var')['index'].str.split(":").str[1].str.split("-").str[0].astype('int64').to_numpy(),dims=['var'])
    # adatas[s]["var-_-end"] = xr.DataArray(adatas[s].get_dataframe('var')['index'].str.split(":").str[1].str.split("-").str[1].astype('int64').to_numpy(),dims=['var'])
    # bed = adatas[s].get_dataframe('var').loc[:,['chrom','start','end']]
    # adatas[s] = crandata.seq_io.add_genome_sequences_to_crandata(adatas[s], bed, genomes[s])
    # print(adatas[s]['sequences'])
    # adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
    # Reopen from disk so later cells work on the zarr-backed object.
    adatas[s] = crandata.crandata.CrAnData.open_zarr(zarr_path)
    

In [None]:
# Pick the species that the inspection cells below index with `adatas[s]`.
# Without this, `s` would still hold the last value of the preceding
# `for s in species` loop.
# NOTE(review): the later split loop also iterates with `s`, so cells that run
# after it see that loop's last key instead.
s = 'mouse'

In [None]:
# Open question: is the sequence length 2114 (WINDOW_SIZE) or 2124
# (WINDOW_SIZE + 2 * MAX_SHIFT)? In other words, is
# _filter_and_adjust_chromosome_data actually working?
adatas[s]['sequences']

In [None]:
# import numpy as np
# adatas['mouse'].uns['chunk_size'] = 512
# adatas['human'].uns['chunk_size'] = 512
# adatas['macaque'].uns['chunk_size'] = 512
# adatas['mouse'].var["chunk_index"] = np.arange(adatas['mouse'].var.shape[0]) // 512
# adatas['human'].var["chunk_index"] = np.arange(adatas['human'].var.shape[0]) // 512
# adatas['macaque'].var["chunk_index"] = np.arange(adatas['macaque'].var.shape[0]) // 512


In [None]:
# Same split settings for every dataset; the fixed random_state keeps reruns repeatable.
split_kwargs = dict(strategy="region", val_size=0.1, test_size=0.1, random_state=42)

for s in adatas:
    crandata.train_val_test_split(adatas[s], **split_kwargs)
    # Append the result to the store the dataset was opened from.
    adatas[s].to_zarr(adatas[s].encoding['source'], mode='a')
    


In [None]:
# for s in adatas.keys():
#     adatas[s]['sequences'] = adatas[s]['sequences'].chunk({'var':2048,'seq_len':adatas[s].dims['seq_len'],'nuc':adatas[s].dims['nuc']})
#     adatas[s]['X'] = adatas[s]['X'].chunk({'obs':adatas[s].dims['obs'],'var':2048,'seq_bins':adatas[s].dims['seq_bins']})
#     adatas[s].to_zarr(adatas[s].encoding['source'],mode='w',safe_chunks=False)

In [None]:
import importlib
# crandata = importlib.reload(crandata)
# crandata._module.MetaCrAnDataModule = importlib.reload(crandata._module.MetaCrAnDataModule)

In [None]:
# Show the chunk layout of 'X' for whichever dataset `s` currently names
# (the last key of the most recent `for s in ...` loop unless `s` was reset).
adatas[s]['X'].chunksizes

In [None]:
# Show the chunk layout of 'sequences' for the dataset currently named by `s`.
adatas[s]['sequences'].chunksizes

In [None]:
# Augmentation shared by all datasets: keep WINDOW_SIZE bases, random reverse
# complement, shift by at most MAX_SHIFT.
transform = crandata.seq_io.DNATransform(out_len=WINDOW_SIZE, random_rc=True, max_shift=MAX_SHIFT)

meta_module = crandata.MetaCrAnDataModule(
    adatas=list(adatas.values()),
    # One batch size per dataset, derived from the dict so that adding or
    # removing a species cannot leave the two lists out of step.
    batch_size=[8] * len(adatas),
    load_keys={'X': 'y','sequences':'sequences'},
    dnatransform=transform,
    num_workers=0,
    epoch_size=1000000,  # the loop below stops after six batches regardless
)

# Setup the meta module for the "train" stage.
meta_module.setup("train")

# Retrieve the training dataloader from the meta module and look at the first few batches.
meta_train_dl = meta_module.train_dataloader

print("\nIterating over a couple of training batches from MetaCrAnDataModule:")
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break


In [None]:
import cProfile

# Profile the first four batches of the training dataloader. cProfile.run
# executes the snippet in the __main__ namespace, so it relies on `tqdm` and
# `meta_train_dl` having been defined by earlier cells.
code = '''
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 3:
        break
'''

# Sort by internal time. cProfile.run prints the report and returns None, so
# there is no result worth binding to a name.
cProfile.run(code, sort="tottime")


In [None]:
import numpy as np
import xarray as xr
import xbatcher

def _make_demo_dataset(n_lat, n_time=20, n_lon=15):
    """Return a random demo Dataset with three variables on a (time, lat, lon) grid.

    ``var1`` and ``var2`` are 3-D (time, lat, lon); ``var3`` is 2-D (time, lat).
    Coordinates are integer ranges matching each dimension's length.
    """
    return xr.Dataset(
        {
            'var1': (('time', 'lat', 'lon'), np.random.rand(n_time, n_lat, n_lon)),
            'var2': (('time', 'lat', 'lon'), np.random.rand(n_time, n_lat, n_lon)),
            'var3': (('time', 'lat'), np.random.rand(n_time, n_lat)),
        },
        coords={
            'time': list(range(n_time)),
            'lat': list(range(n_lat)),
            'lon': list(range(n_lon)),
        },
    )


# First dataset: 10 latitudes. input_dims covers every dimension in full;
# Dataset.sizes is the name -> length mapping (Dataset.dims is moving to a set).
ds = _make_demo_dataset(n_lat=10)
bgen1 = xbatcher.BatchGenerator(ds=ds[['var1','var2','var3']], input_dims=dict(ds.sizes))
print(f'bgen1 has {len(bgen1)} batches')
print("First batch:")
print(bgen1[0])

# Second dataset: identical layout except for 8 latitudes.
ds2 = _make_demo_dataset(n_lat=8)
bgen2 = xbatcher.BatchGenerator(ds=ds2[['var1','var2','var3']], input_dims=dict(ds2.sizes))
print(f'bgen2 has {len(bgen2)} batches')
print("First batch:")
print(bgen2[0])


In [None]:
 # Stack the first batch of each generator along 'time'. The two datasets differ
 # in 'lat' (10 vs 8 entries); join='inner' presumably keeps only the shared
 # coordinates — TODO confirm against the displayed result.
 xr.concat([bgen1[0],bgen2[0]],dim='time',join='inner')

In [None]:
from torchdata.nodes import Mapper, MultiNodeWeightedSampler, IterableWrapper, Loader, BaseNode,ParallelMapper
import collections
# Settings read by combine_samples as module-level globals (not parameters).
concat_axis = 'time'
join_param = 'inner'
def combine_samples(x):
    """Advance every element of ``x`` once and concatenate what they return.

    Calls ``next()`` on each element of ``x`` and joins the results with
    ``xr.concat`` along ``concat_axis`` using ``join_param``.

    NOTE(review): ``x`` must therefore be an iterable of iterators. As wired
    below, the mapper presumably receives bgen1 and bgen2 one at a time, in
    which case ``next()`` would be applied to a generator's batches rather
    than to the generators themselves — confirm before relying on this cell.
    """
    return xr.concat([next(i) for i in x],dim=concat_axis,join=join_param)

# Source node over the two batch generators.
datasets = IterableWrapper([bgen1, bgen2])

# Despite the name, this is a ParallelMapper (three worker threads), not a
# MultiNodeWeightedSampler.
multi_node_sampler = ParallelMapper(datasets, map_fn=combine_samples, num_workers=3, method="thread")

# Since nodes are iterators, they need to be manually .reset() between epochs.
# We can wrap the root node in Loader to convert it to a more conventional Iterable.
# This loader is never iterated in this cell, and `loader` is rebound by later cells.
loader = Loader(multi_node_sampler)


In [None]:
import xarray as xr
import xbatcher
from torchdata.nodes import BaseNode, IterableWrapper, Loader, ParallelMapper

def combine_round_robin(*batches, concat_dim='time', join='inner'):
    """Merge the given batches into one object with ``xr.concat``.

    Parameters
    ----------
    *batches
        Objects to concatenate, in the order given.
    concat_dim : str
        Dimension passed to ``xr.concat`` as ``dim``.
    join : str
        Value passed to ``xr.concat`` as ``join``.

    Returns
    -------
    The result of ``xr.concat``.
    """
    merged = xr.concat(batches, dim=concat_dim, join=join)
    return merged

class RoundRobinNode(BaseNode):
    """Node that takes one batch from every child per step and combines them.

    Each ``next`` call fetches one item from each node in ``nodes`` (in list
    order) and hands them to ``combine_fn`` together with ``concat_dim`` and
    ``join``. A child that raises ``StopIteration`` is reset and asked again,
    so this node does not signal exhaustion by itself; callers have to bound
    the iteration.

    Parameters
    ----------
    nodes : list
        Child nodes supporting ``reset``, ``get_state`` and ``next()``.
    combine_fn : callable
        Called as ``combine_fn(*batches, concat_dim=..., join=...)``.
    concat_dim, join : str
        Forwarded unchanged to ``combine_fn``.
    """

    def __init__(self, nodes, combine_fn, concat_dim='time', join='inner'):
        super().__init__()
        self.nodes = nodes
        self.combine_fn = combine_fn
        self.concat_dim = concat_dim
        self.join = join

    def reset(self, initial_state=None):
        """Reset this node and every child.

        ``initial_state`` is either ``None`` (fresh start) or a mapping shaped
        like the one returned by ``get_state``: child index -> child state.
        Each child receives only its own entry, not the whole mapping.
        """
        super().reset(initial_state)
        for index, node in enumerate(self.nodes):
            child_state = None if initial_state is None else initial_state[index]
            node.reset(child_state)

    def get_state(self):
        """Return the children's states keyed by their position in ``nodes``."""
        return {i: node.get_state() for i, node in enumerate(self.nodes)}

    def _get_next_batch(self, node):
        """Return the next item of ``node``, restarting it once if exhausted."""
        try:
            return next(node)
        except StopIteration:
            # Exhausted child: start over so the round-robin keeps going.
            node.reset()
            return next(node)

    def next(self):
        """Fetch one batch per child and return their combination."""
        # Children are polled in order on the calling thread; a thread pool
        # built per call would only add start-up cost to every batch.
        batches = [self._get_next_batch(node) for node in self.nodes]
        return self.combine_fn(*batches, concat_dim=self.concat_dim, join=self.join)

# Example usage:
# Assume bgen1 and bgen2 are your xbatcher BatchGenerators.
from itertools import islice

# RoundRobinNode restarts exhausted children instead of stopping, so the loader
# never ends by itself; cap the demo at a handful of batches.
MAX_DEMO_BATCHES = 3

node1 = IterableWrapper(bgen1)
node2 = IterableWrapper(bgen2)

# Create the round-robin node that retrieves one batch from each generator per step.
round_robin_node = RoundRobinNode(
    [node1, node2],
    combine_round_robin,
    concat_dim='time',
    join='inner'
)

# Wrap it in a Loader for a dataloader-like interface.
loader = Loader(round_robin_node)
print('made loader')
# Iterate over the loader to obtain mixed batches.
for mixed_batch in islice(loader, MAX_DEMO_BATCHES):
    print(mixed_batch)


In [None]:
import xarray as xr
import xbatcher
from torchdata.nodes import BaseNode, IterableWrapper, Loader

# Custom node that sequentially yields batches from a list of nodes.
class SequentialNode(BaseNode):
    """Node that drains its children one after another.

    Items come from ``nodes[0]`` until it raises ``StopIteration``, then from
    ``nodes[1]``, and so on. When the last child is exhausted this node raises
    ``StopIteration`` itself.

    Parameters
    ----------
    nodes : list
        Child nodes supporting ``reset``, ``get_state`` and ``next()``.
    """

    def __init__(self, nodes):
        super().__init__()
        self.nodes = nodes
        # Index of the child currently being drained.
        self.current = 0

    def reset(self, initial_state=None):
        """Reset this node and its children.

        With ``initial_state=None`` every child restarts and reading begins at
        the first child. Otherwise ``initial_state`` must have the layout
        returned by ``get_state``: ``"current"`` restores the active child
        index and ``"states"`` supplies one state per child.
        """
        super().reset(initial_state)
        if initial_state is None:
            for node in self.nodes:
                node.reset(None)
            self.current = 0
        else:
            for node, child_state in zip(self.nodes, initial_state["states"]):
                node.reset(child_state)
            self.current = initial_state["current"]

    def get_state(self):
        """Return the active child index and every child's own state."""
        return {
            "current": self.current,
            "states": [node.get_state() for node in self.nodes]
        }

    def next(self):
        """Return the next item, moving on to the following child when one runs dry."""
        # Loop until we find a node with data or we run out of nodes.
        while self.current < len(self.nodes):
            try:
                return next(self.nodes[self.current])
            except StopIteration:
                # If exhausted, move to the next node.
                self.current += 1
        # If all nodes are exhausted, signal end-of-iteration.
        raise StopIteration

# Example: two single-variable datasets of different length along 'time'.
# ds1 values are offset by 100 so its batches are easy to tell apart from ds2's.
dims_3d = ('time', 'lat', 'lon')
ds1 = xr.Dataset({'var1': (dims_3d, 100 + np.random.rand(20, 10, 15))})
ds2 = xr.Dataset({'var1': (dims_3d, np.random.rand(14, 10, 15))})

# One BatchGenerator per dataset, both cut into windows of 5 time steps that
# span the full lat/lon extent.
window = {'time': 5, 'lat': 10, 'lon': 15}
bgen1 = xbatcher.BatchGenerator(ds=ds1, input_dims=dict(window))
bgen2 = xbatcher.BatchGenerator(ds=ds2, input_dims=dict(window))

# Turn each generator into a torchdata node.
node1 = IterableWrapper(bgen1)
node2 = IterableWrapper(bgen2)

# Chain them: everything from node1 first, then node2.
seq_node = SequentialNode([node1, node2])

# Loader gives the chained node a dataloader-like, re-iterable interface.
loader = Loader(seq_node)

# Print every batch in sequence.
for batch in loader:
    print(batch)


In [None]:
# NOTE(review): `a` is not defined anywhere in the visible notebook, so this
# cell raises NameError — presumably a deliberate "Run All" stop; confirm or delete.
a

In [None]:
# Iterate the loader a second time. `loader` is assigned in three earlier
# cells, so which object this walks depends on which of them ran last.
for i in loader:
    print(i)

In [None]:
# Report the dtype of every entry in the first six batches, reading from the
# dataloader's `.data` attribute (presumably the underlying batch source — TODO confirm).
for i, batch in enumerate(tqdm(meta_train_dl.data)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        # The value printed is the dtype, so the label says so.
        print(f"  {key}: dtype {tensor.dtype}")
    if i == 5:
        break


In [None]:
import cProfile

# Profile pulling six batches from `meta_train_dl.data` with printing disabled.
# cProfile.run executes the snippet in the __main__ namespace, so it relies on
# `meta_train_dl` having been defined by an earlier cell.
code = '''
for i, batch in enumerate(meta_train_dl.data):
    # print(f"Meta Batch {i}:")
    # for key, tensor in batch.items():
    #     print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break
'''

# Sort by internal time. cProfile.run prints the report and returns None, so
# there is no result worth binding to a name.
cProfile.run(code, sort="tottime")


In [None]:
# Sequence length follows WINDOW_SIZE, the same constant DNATransform uses as
# out_len, so the model input cannot drift from the data.
# NOTE(review): `batch` is whatever an earlier loop last bound; it must be a
# meta-module batch that contains 'y'.
model_architecture = crested.tl.zoo.simple_convnet(
    seq_len=WINDOW_SIZE, num_classes=batch['y'].shape[1]
)


In [None]:
import keras
# Create your own configuration
optimizer = keras.optimizers.Adam(learning_rate=1e-5)
# Poisson loss is the one in effect. The weighted cosine-MSE-log loss that was
# recommended for peak regression is kept as a commented alternative; it used
# to be constructed and then overwritten on the next line.
# loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
loss = crested.tl.losses.PoissonLoss()

metrics = [
    keras.metrics.MeanAbsoluteError(),
    # keras.metrics.MeanSquaredError(),
    # keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    # crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    # crested.tl.metrics.PearsonCorrelationLog(),
    # crested.tl.metrics.ZeroPenaltyMetric(),
]

alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)


In [None]:
# Call the model once so that lazily created parameters get initialized.
# NOTE(review): this passes the whole `batch` object left over from an earlier
# loop; confirm the model accepts that rather than only the sequence tensor.
model_architecture(batch)

In [None]:
# Assemble the trainer from the meta data module, the model and the task config.
trainer = crested.tl.Crested(
    data=meta_module,
    model=model_architecture,
    config=alternative_config,
    project_name="mouse_biccn",  # change to your liking
    run_name="basemodel",  # change to your liking
    logger=None,  # no logging; the original note lists 'dvc' and 'tensorboard' as alternatives
    seed=7,  # For reproducibility
)
# Train for at most 60 epochs; the two patience values are passed for
# learning-rate reduction (3) and early stopping (6).
trainer.fit(
    epochs=60,
    learning_rate_reduce_patience=3,
    early_stopping_patience=6,
)
