In [94]:
import os
import tempfile
from pathlib import Path
import pandas as pd
import numpy as np
import pyBigWig
import copy
import xarray as xr
import tqdm

# Import our new module system and utilities.
import crandata
from crandata import CrAnDataModule, MetaCrAnDataModule, CrAnData
from crandata.chrom_io import import_bigwigs
from crandata.seq_io import add_genome_sequences_to_crandata, DNATransform

# Create temporary directories for synthetic data.
temp_dir = tempfile.TemporaryDirectory()
base_dir = Path(temp_dir.name)
beds_dir = base_dir / "beds"
bigwigs_dir = base_dir / "bigwigs"
beds_dir.mkdir(exist_ok=True)
bigwigs_dir.mkdir(exist_ok=True)

# Create a chromsizes file.
chromsizes_file = base_dir / "chrom.sizes"
with open(chromsizes_file, "w") as f:
    f.write("chr1\t1000\n")

# Create two BED files (simulate two different classes).
bed_data_A = pd.DataFrame({
    0: ["chr1", "chr1"],
    1: [100, 300],
    2: [200, 400]
})
bed_data_B = pd.DataFrame({
    0: ["chr1", "chr1"],
    1: [150, 350],
    2: [250, 450]
})
bed_file_A = beds_dir / "ClassA.bed"
bed_file_B = beds_dir / "ClassB.bed"
bed_data_A.to_csv(bed_file_A, sep="\t", header=False, index=False)
bed_data_B.to_csv(bed_file_B, sep="\t", header=False, index=False)

# Create a consensus BED file.
consensus = pd.DataFrame({
    0: ["chr1", "chr1", "chr1"],
    1: [100, 300, 350],
    2: [200, 400, 450]
})
consensus_file = base_dir / "consensus.bed"
consensus.to_csv(consensus_file, sep="\t", header=False, index=False)

# Create two bigWig files.
bigwig_file1 = bigwigs_dir / "test.bw"
bw1 = pyBigWig.open(str(bigwig_file1), "w")
bw1.addHeader([("chr1", 1000)])
bw1.addEntries(chroms=["chr1"], starts=[0], ends=[1000], values=[5.0])
bw1.close()

bigwig_file2 = bigwigs_dir / "test2.bw"
bw2 = pyBigWig.open(str(bigwig_file2), "w")
bw2.addHeader([("chr1", 1000)])
bw2.addEntries(chroms=["chr1"], starts=[0], ends=[1000], values=[4.0])
bw2.close()

# Set extraction parameters.
target_region_width = 100
backed_path = base_dir / "chrom_data.zarr"

# Create the CrAnData object from bigWig files and consensus regions.
adata = import_bigwigs(
    bigwigs_folder=str(bigwigs_dir),
    regions_file=str(consensus_file),
    backed_path=str(backed_path),
    target_region_width=target_region_width,
    chromsizes_file=str(chromsizes_file),
)

crandata.train_val_test_split(adata,strategy='chr_auto')

# Create a dummy FASTA file for a genome.
fasta_file = base_dir / "chr1.fa"
with open(fasta_file, "w") as f:
    f.write(">chr1\n")
    f.write("A" * 1000 + "\n")

# Create a Genome object.
from crandata._genome import Genome
dummy_genome = Genome(str(fasta_file), chrom_sizes=str(chromsizes_file))

# Add sequences to the CrAnData using the provided seq_io utility.
# Here we use the consensus regions as our ranges.
consensus.columns = ['chrom', 'start', 'end']
adata = add_genome_sequences_to_crandata(adata, consensus, dummy_genome)

# Write the CrAnData object to disk and then reload it to ensure sequences are out-of-memory.
adata.to_zarr(str(backed_path),mode='a')
adata_loaded = CrAnData.open_zarr(str(backed_path))
print("Loaded CrAnData:")
print(adata_loaded)

# Create two copies to simulate two datasets (e.g. two species), and add a "split" column in var metadata.
adata1 = copy.deepcopy(adata_loaded)
adata2 = copy.deepcopy(adata_loaded)
adata1["var-_-split"] = xr.DataArray(np.full(adata1.sizes["var"], "train"), dims=["var"])
adata2["var-_-split"] = xr.DataArray(np.full(adata2.sizes["var"], "train"), dims=["var"])

# Create a DNATransform instance.
transform = DNATransform(out_len=80, random_rc=True, max_shift=5)

# Instantiate the MetaCrAnDataModule with the two datasets.
# Note: The batch_size is now 3, matching the number of consensus regions (var dimension).
meta_module = MetaCrAnDataModule(
    adatas=[adata1, adata2],
    batch_size=[2,2],        # adjust batch size to not exceed var length (3)
    load_keys={'sequences':'sequences','X':'X'},
    shuffle=True,
    dnatransform=transform,
    epoch_size=10
)

# Setup each underlying module for the "train" stage.
for mod in meta_module.modules:
    mod.setup(state="train")

# Retrieve the training dataloader from the meta module and iterate over a couple of batches.
meta_train_dl = meta_module.train_dataloader
print("\nIterating over a couple of training batches from MetaCrAnDataModule:")
for i, batch in enumerate(tqdm.tqdm(meta_train_dl)):
    print(batch)
    print(f"\nMeta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i >= 1:
        break

print("\nTemporary directory contents:")
print(os.listdir(base_dir))
temp_dir.cleanup()


100%|██████████| 2/2 [00:00<00:00, 13231.24it/s]

2025-03-24T20:55:42.212571-0700 INFO Extracting values from 2 bigWig files...



  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
100%|██████████| 2/2 [00:00<00:00, 42.15it/s]
  adata['X'] = adata['X'].chunk({'obs':adata.dims['obs'],'var':chunk_size,'seq_bins':adata.dims['seq_bins']}) #enforce the same as before
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_par

Loaded CrAnData:
CrAnData object
Array names: ['var-_-start', 'var-_-end', 'var-_-chrom', 'obs-_-file_path', 'var-_-index', 'obs-_-index', 'var-_-chunk_index', 'X', 'sequences', 'var-_-split']
Coordinates: ['var', 'obs', 'seq_bins']



  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  dim_dict[self.batch_dim] = self.batch_size
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  dim_dict[self.batch_dim] = self.batch_size



Iterating over a couple of training batches from MetaCrAnDataModule:


1it [00:00, 26.36it/s]

{'sequences': tensor([[[1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         ...,
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0]],

        [[1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         ...,
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0]],

        [[1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         ...,
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0]],

        [[1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         ...,
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0]]], dtype=torch.uint8), 'X': tensor([[[6.9079e-310, 4.6558e-310,  0.0000e+00,  0.0000e+00, 6.9072e-310,
          6.9072e-310, 6.9072e-310, 6.9072e-310, 6.9072e-310, 6.9072e-310,
          6.9072e-310, 6.9072e-310, 6.9072e-310, 6.9072e-310, 6.9072e-310,
          6.9072e-310, 6.9072e-310, 6.9072e-310, 6.9072e-310, 6.9072e-310,
          6.9072e-310, 6.9072e-310, 6.9072e-3




In [None]:
sdfs

In [None]:
# Should the fill in _extract_values_from_bigwig actually be 0? Can we filter var where all is 0/nan without loading everything into memory?

In [1]:
import crandata
import xarray as xr
import pandas as pd
import numpy as np
import os
import crested
from tqdm import tqdm

  cls = super().__new__(mcls, name, bases, namespace, **kwargs)


In [75]:
genomes = {}
beds = {}
chromsizes_files = {}
bed_files = {}
species = ['human','macaque','mouse']
species_codes = {'human':0,'macaque':1,'mouse':2}

MAX_SHIFT = 5
WINDOW_SIZE = 2114
WINDOW_SIZE = WINDOW_SIZE #+ 2*MAX_SHIFT
OFFSET = WINDOW_SIZE // 2  # e.g., 50% overlap
N_THRESHOLD = 0.3
n_bins = WINDOW_SIZE//50


In [9]:
for s in species:
    genome_path = '/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/genome/onehots/'+s
    fasta_file = os.path.join(genome_path,s+'.fa')
    chrom_sizes = os.path.join(genome_path,s+'.fa.sizes')
    annotation_gtf_file = os.path.join(genome_path,s+'.annotation.gtf')
    chromsizes_files[s] = chrom_sizes
    genome = crandata.Genome(fasta_file, chrom_sizes, annotation_gtf_file)
    genome.to_memory()
    genomes[s] = genome
    OUTPUT_BED = os.path.join(genome_path, "binned_genome.bed")
    bed_files[s] = OUTPUT_BED
    # Generate bins and optionally write to disk.
    binned_df = crandata.bin_genome(genome, WINDOW_SIZE, OFFSET, n_threshold=N_THRESHOLD, output_path=OUTPUT_BED).reset_index(drop=True)
    print("Filtered bins:")
    print(binned_df)


2025-03-24T14:12:06.933762-0700 INFO Genome sequences loaded into memory.


Calculating N content: 100%|██████████| 2583507/2583507 [02:08<00:00, 20092.55it/s]


Filtered bins:
              chrom    start      end    prop_n
0              chr1  2999767  3001881  0.110638
1              chr1  3002938  3005052  0.085579
2              chr1  3003995  3006109  0.000000
3              chr1  3005052  3007166  0.000000
4              chr1  3006109  3008223  0.000000
...             ...      ...      ...       ...
2509462  JH584292.1     8457    10571  0.000000
2509463  JH584292.1     9514    11628  0.000000
2509464  JH584292.1    10571    12685  0.000000
2509465  JH584292.1    11628    13742  0.000000
2509466  JH584292.1    12685    14799  0.000000

[2509467 rows x 4 columns]
2025-03-24T14:14:50.616433-0700 INFO Genome sequences loaded into memory.


Calculating N content: 100%|██████████| 2932321/2932321 [02:27<00:00, 19817.85it/s]


Filtered bins:
              chrom  start    end   prop_n
0              chr1   9514  11628  0.23026
1              chr1  10571  12685  0.00000
2              chr1  11628  13742  0.00000
3              chr1  12685  14799  0.00000
4              chr1  13742  15856  0.00000
...             ...    ...    ...      ...
2786513  KI270518.1      1   2115  0.00000
2786514  KI270530.1      1   2115  0.00000
2786515  KI270304.1      1   2115  0.00000
2786516  KI270418.1      1   2115  0.00000
2786517  KI270424.1      1   2115  0.00000

[2786518 rows x 4 columns]
2025-03-24T14:17:52.924853-0700 INFO Genome sequences loaded into memory.


Calculating N content: 100%|██████████| 2806701/2806701 [02:21<00:00, 19821.34it/s]


Filtered bins:
               chrom  start    end  prop_n
0        NC_041754.1      1   2115     0.0
1        NC_041754.1   1058   3172     0.0
2        NC_041754.1   2115   4229     0.0
3        NC_041754.1   3172   5286     0.0
4        NC_041754.1   4229   6343     0.0
...              ...    ...    ...     ...
2773980  NC_005943.1   9514  11628     0.0
2773981  NC_005943.1  10571  12685     0.0
2773982  NC_005943.1  11628  13742     0.0
2773983  NC_005943.1  12685  14799     0.0
2773984  NC_005943.1  13742  15856     0.0

[2773985 rows x 4 columns]


In [4]:
# s='mouse'
# adatas = {}

# bigwigs_dir = os.path.join('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',s,'Group_bigwig')
# adatas[s] = crandata.chrom_io.import_bigwigs(
#     bigwigs_folder=bigwigs_dir,
#     regions_file='/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/genome/onehots/mouse/binned_genome_test.bed',
#     backed_path='/home/matthew.schmitz/Matthew/'+s+'_spc_test.zarr',
#     target_region_width=WINDOW_SIZE,
#     chromsizes_file=chromsizes_files[s],
#     target = 'raw',
#     max_stochastic_shift=5,
#     chunk_size=2048,
#     n_bins=n_bins
# )
# bed = adatas[s].get_dataframe('var').loc[:,['chrom','start','end']]
# adatas[s] = crandata.seq_io.add_genome_sequences_to_crandata(adatas[s], bed, genomes[s])
# print(adatas[s]['sequences'])
# adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')


In [77]:
adatas = {}

for s in species:
    print(s)
    bigwigs_dir = os.path.join('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',s,'Group_bigwig')
    adatas[s] = crandata.chrom_io.import_bigwigs(
        bigwigs_folder=bigwigs_dir,
        regions_file=bed_files[s],
        backed_path='/home/matthew.schmitz/Matthew/'+s+'_spc_test.zarr',
        target_region_width=WINDOW_SIZE,
        chromsizes_file=chromsizes_files[s],
        target = 'raw',
        max_stochastic_shift=5,
        chunk_size=256,
        n_bins=n_bins
    )
    bed = adatas[s].get_dataframe('var').loc[:,['chrom','start','end']]
    adatas[s] = crandata.seq_io.add_genome_sequences_to_crandata(adatas[s], bed, genomes[s])
    print(adatas[s]['sequences'])
    adatas[s]['var-_-species'] = xr.DataArray(np.repeat(species_codes[s],adatas[s].sizes['var']),dims='var').chunk({'var':adatas[s].attrs['chunk_size']})
    adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
    adatas[s] = crandata.crandata.CrAnData.open_zarr('/home/matthew.schmitz/Matthew/'+s+'_spc_test.zarr')
    

  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
 

In [92]:
adatas[s].sizes

Frozen({'var': 2508122, 'seq_bins': 42, 'obs': 49, 'seq_len': 2124, 'nuc': 4})

In [60]:
# import numpy as np
# adatas['mouse'].uns['chunk_size'] = 512
# adatas['human'].uns['chunk_size'] = 512
# adatas['macaque'].uns['chunk_size'] = 512
# adatas['mouse'].var["chunk_index"] = np.arange(adatas['mouse'].var.shape[0]) // 512
# adatas['human'].var["chunk_index"] = np.arange(adatas['human'].var.shape[0]) // 512
# adatas['macaque'].var["chunk_index"] = np.arange(adatas['macaque'].var.shape[0]) // 512


In [79]:
for s in adatas.keys():
    crandata.train_val_test_split(
        adatas[s], strategy="region", val_size=0.1, test_size=0.1, random_state=42
    )
    adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
    


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
 

In [62]:
# for s in adatas.keys():
#     adatas[s]['sequences'] = adatas[s]['sequences'].chunk({'var':2048,'seq_len':adatas[s].dims['seq_len'],'nuc':adatas[s].dims['nuc']})
#     adatas[s]['X'] = adatas[s]['X'].chunk({'obs':adatas[s].dims['obs'],'var':2048,'seq_bins':adatas[s].dims['seq_bins']})
#     adatas[s].to_zarr(adatas[s].encoding['source'],mode='w',safe_chunks=False)

In [80]:
import importlib
crandata = importlib.reload(crandata)
# crandata._module.MetaCrAnDataModule = importlib.reload(crandata._module.MetaCrAnDataModule)

In [81]:
adatas[s]['X']

Unnamed: 0,Array,Chunk
Bytes,38.46 GiB,4.02 MiB
Shape,"(49, 2508122, 42)","(49, 256, 42)"
Dask graph,9798 chunks in 2 graph layers,9798 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 38.46 GiB 4.02 MiB Shape (49, 2508122, 42) (49, 256, 42) Dask graph 9798 chunks in 2 graph layers Data type float64 numpy.ndarray",42  2508122  49,

Unnamed: 0,Array,Chunk
Bytes,38.46 GiB,4.02 MiB
Shape,"(49, 2508122, 42)","(49, 256, 42)"
Dask graph,9798 chunks in 2 graph layers,9798 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [82]:
adatas[s]['sequences']

Unnamed: 0,Array,Chunk
Bytes,19.85 GiB,2.07 MiB
Shape,"(2508122, 2124, 4)","(256, 2124, 4)"
Dask graph,9798 chunks in 2 graph layers,9798 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 19.85 GiB 2.07 MiB Shape (2508122, 2124, 4) (256, 2124, 4) Dask graph 9798 chunks in 2 graph layers Data type uint8 numpy.ndarray",4  2124  2508122,

Unnamed: 0,Array,Chunk
Bytes,19.85 GiB,2.07 MiB
Shape,"(2508122, 2124, 4)","(256, 2124, 4)"
Dask graph,9798 chunks in 2 graph layers,9798 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray


In [93]:
transform = crandata.seq_io.DNATransform(out_len=WINDOW_SIZE, random_rc=True, max_shift=MAX_SHIFT)

meta_module = crandata.MetaCrAnDataModule(
    adatas=list(adatas.values()),
    batch_size=[16,16,16],
    load_keys={'X': 'y','sequences':'sequence','var-_-species':'species'},
    dnatransform=transform,
    join='inner',
    num_workers=0,
    epoch_size=1000000    # small epoch size for quick testing
)

# Setup the meta module for the "fit" stage (train/val)
meta_module.setup("train")

# Retrieve the training dataloader from the meta module and iterate over a couple of batches.
meta_train_dl = meta_module.train_dataloader

print("\nIterating over a couple of training batches from MetaAnnDataModule:")
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  dim_dict[self.batch_dim] = self.batch_size
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  dim_dict[self.batch_dim] = self.batch_size
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  dim_dict[self.batch_dim] = self.batch_size



Iterating over a couple of training batches from MetaAnnDataModule:


1it [00:00,  2.32it/s]

Meta Batch 0:
  y: shape torch.Size([47, 48, 42])
  sequence: shape torch.Size([48, 2124, 4])
  species: shape torch.Size([48])


2it [00:00,  2.83it/s]

Meta Batch 1:
  y: shape torch.Size([47, 48, 42])
  sequence: shape torch.Size([48, 2124, 4])
  species: shape torch.Size([48])


3it [00:01,  3.04it/s]

Meta Batch 2:
  y: shape torch.Size([47, 48, 42])
  sequence: shape torch.Size([48, 2124, 4])
  species: shape torch.Size([48])


4it [00:01,  3.18it/s]

Meta Batch 3:
  y: shape torch.Size([47, 48, 42])
  sequence: shape torch.Size([48, 2124, 4])
  species: shape torch.Size([48])


5it [00:01,  3.17it/s]

Meta Batch 4:
  y: shape torch.Size([47, 48, 42])
  sequence: shape torch.Size([48, 2124, 4])
  species: shape torch.Size([48])


5it [00:01,  2.59it/s]

Meta Batch 5:
  y: shape torch.Size([47, 48, 42])
  sequence: shape torch.Size([48, 2124, 4])
  species: shape torch.Size([48])





In [91]:
batch

{'y': tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ...,

In [90]:
import cProfile

code = '''
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 3:
        break
'''

out = cProfile.run(code,sort=True)


1it [00:00,  2.30it/s]

Meta Batch 0:
  y: shape torch.Size([51, 48, 42])
  sequence: shape torch.Size([48, 2124, 4])
  species: shape torch.Size([48])


2it [00:00,  2.35it/s]

Meta Batch 1:
  y: shape torch.Size([51, 48, 42])
  sequence: shape torch.Size([48, 2124, 4])
  species: shape torch.Size([48])


3it [00:01,  2.34it/s]

Meta Batch 2:
  y: shape torch.Size([51, 48, 42])
  sequence: shape torch.Size([48, 2124, 4])
  species: shape torch.Size([48])


3it [00:01,  1.75it/s]

Meta Batch 3:
  y: shape torch.Size([51, 48, 42])
  sequence: shape torch.Size([48, 2124, 4])
  species: shape torch.Size([48])
         3687612 function calls (3676112 primitive calls) in 1.717 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  3259436    0.426    0.000    0.435    0.000 _collections_abc.py:868(__iter__)
     1432    0.289    0.000    0.539    0.000 {method 'update' of 'set' objects}
      108    0.221    0.002    0.244    0.002 blockwise.py:625(get_output_keys)
      148    0.182    0.001    0.357    0.002 {method 'intersection' of 'set' objects}
       12    0.073    0.006    0.945    0.079 optimization.py:37(optimize)
    40/25    0.050    0.001    0.009    0.000 cpu.py:188(as_numpy_array_wrapper)
       12    0.019    0.002    0.815    0.068 highlevelgraph.py:707(cull)
    72/57    0.017    0.000    0.009    0.000 {built-in method _io.open}
     1908    0.015    0.000    0.024    0.000 ipkernel.py:775(_




In [38]:
model_architecture = crested.tl.zoo.simple_convnet(
    seq_len=2114, num_classes=batch['y'].shape[1]
)


In [39]:
import keras
# Create your own configuration
# I recommend trying this for peak regression with a weighted cosine mse log loss function
optimizer = keras.optimizers.Adam(learning_rate=1e-5)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
loss = crested.tl.losses.PoissonLoss()

metrics = [
    keras.metrics.MeanAbsoluteError(),
    # keras.metrics.MeanSquaredError(),
    # keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    # crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    # crested.tl.metrics.PearsonCorrelationLog(),
    # crested.tl.metrics.ZeroPenaltyMetric(),
]

alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)


TaskConfig(optimizer=<keras.src.backend.torch.optimizers.torch_adam.Adam object at 0x7f251de7b5f0>, loss=<crested.tl.losses._poisson.PoissonLoss object at 0x7f251de7b380>, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <PearsonCorrelation name=pearson_correlation>])


In [53]:
batch['sequence'].shape

torch.Size([16, 2124, 4])

In [50]:
# initialize some lazy model parameters *yawn*
model_architecture(batch['sequence'].float().mean(0).unsqueeze(0))

ValueError: Input 0 of layer "functional_2" is incompatible with the layer: expected shape=(None, 2114, 4), found shape=(1, 2124, 4)

In [49]:
trainer = crested.tl.Crested(
    data=meta_module,
    model=model_architecture,
    config=alternative_config,
    project_name="mouse_biccn",  # change to your liking
    run_name="basemodel",  # change to your liking
    logger=None,  # or None, 'dvc', 'tensorboard'
    seed=7,  # For reproducibility
)
# train the model
trainer.fit(
    epochs=60,
    learning_rate_reduce_patience=3,
    early_stopping_patience=6,
)




None


AttributeError: 'MetaCrAnDataModule' object has no attribute 'train_dataset'