In [1]:
import os
import tempfile
from pathlib import Path
import pandas as pd
import numpy as np
import pyBigWig
import copy
import xarray as xr
import tqdm
import crandata


# Create temporary directories for synthetic data.
temp_dir = tempfile.TemporaryDirectory()
base_dir = Path(temp_dir.name)
beds_dir = base_dir / "beds"
bigwigs_dir = base_dir / "bigwigs"
beds_dir.mkdir(exist_ok=True)
bigwigs_dir.mkdir(exist_ok=True)

# Create a chromsizes file.
chromsizes_file = base_dir / "chrom.sizes"
with open(chromsizes_file, "w") as f:
    f.write("chr1\t1000\n")

# Create two BED files (simulate two different classes).
bed_data_A = pd.DataFrame({
    0: ["chr1", "chr1"],
    1: [100, 300],
    2: [200, 400]
})
bed_data_B = pd.DataFrame({
    0: ["chr1", "chr1"],
    1: [150, 350],
    2: [250, 450]
})
bed_file_A = beds_dir / "ClassA.bed"
bed_file_B = beds_dir / "ClassB.bed"
bed_data_A.to_csv(bed_file_A, sep="\t", header=False, index=False)
bed_data_B.to_csv(bed_file_B, sep="\t", header=False, index=False)

# Create a consensus BED file.
consensus = pd.DataFrame({
    0: ["chr1", "chr1", "chr1"],
    1: [100, 300, 350],
    2: [200, 400, 450]
})
consensus_file = base_dir / "consensus.bed"
consensus.to_csv(consensus_file, sep="\t", header=False, index=False)

# Create two bigWig files.
bigwig_file1 = bigwigs_dir / "test.bw"
bw1 = pyBigWig.open(str(bigwig_file1), "w")
bw1.addHeader([("chr1", 1000)])
bw1.addEntries(chroms=["chr1"], starts=[0], ends=[1000], values=[5.0])
bw1.close()

bigwig_file2 = bigwigs_dir / "test2.bw"
bw2 = pyBigWig.open(str(bigwig_file2), "w")
bw2.addHeader([("chr1", 1000)])
bw2.addEntries(chroms=["chr1"], starts=[0], ends=[1000], values=[4.0])
bw2.close()

# Set extraction parameters.
target_region_width = 100
backed_path = base_dir / "chrom_data.zarr"
print(backed_path)
# Create the CrAnData object from bigWig files and consensus regions.
adata = crandata.chrom_io.import_bigwigs(
    bigwigs_folder=str(bigwigs_dir),
    regions_file=str(consensus_file),
    backed_path=str(backed_path),
    target_region_width=target_region_width,
    chromsizes_file=str(chromsizes_file),
)

crandata.train_val_test_split(adata,strategy='chr_auto')

# Create a dummy FASTA file for a genome.
fasta_file = base_dir / "chr1.fa"
with open(fasta_file, "w") as f:
    f.write(">chr1\n")
    f.write("A" * 1000 + "\n")

dummy_genome = crandata.Genome(str(fasta_file), chrom_sizes=str(chromsizes_file))

# Add sequences to the CrAnData using the provided seq_io utility.
# Here we use the consensus regions as our ranges.
consensus.columns = ['chrom', 'start', 'end']
adata = crandata.seq_io.add_genome_sequences_to_crandata(adata, consensus, dummy_genome)

print(adata)
# Write the CrAnData object to disk and then reload it to ensure sequences are out-of-memory.
adata.to_zarr(str(backed_path),mode='a')
adata_loaded = crandata.CrAnData.open_zarr(str(backed_path))
print("Loaded CrAnData:")
print(adata_loaded)

# Create two copies to simulate two datasets (e.g. two species), and add a "split" column in var metadata.
adata1 = copy.deepcopy(adata_loaded)
adata2 = copy.deepcopy(adata_loaded)
adata1["var-_-split"] = xr.DataArray(np.full(adata1.sizes["var"], "train"), dims=["var"])
adata2["var-_-split"] = xr.DataArray(np.full(adata2.sizes["var"], "train"), dims=["var"])

# Create a DNATransform instance.
transform = crandata.seq_io.DNATransform(out_len=80, random_rc=True, max_shift=5)

# Instantiate the MetaCrAnDataModule with the two datasets.
# Note: The batch_size is now 3, matching the number of consensus regions (var dimension).
meta_module = crandata.MetaCrAnDataModule(
    adatas=[adata1, adata2],
    batch_size=[2,2],        # adjust batch size to not exceed var length (3)
    load_keys={'sequences':'sequences','X':'X'},
    shuffle=True,
    dnatransform=transform,
    epoch_size=10
)

meta_module.setup('train')

# Retrieve the training dataloader from the meta module and iterate over a couple of batches.
meta_train_dl = meta_module.train_dataloader
print("\nIterating over a couple of training batches from MetaCrAnDataModule:")
for i, batch in enumerate(tqdm.tqdm(meta_train_dl)):
    print(batch)
    print(f"\nMeta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i >= 1:
        break

print("\nTemporary directory contents:")
print(os.listdir(base_dir))
temp_dir.cleanup()


  cls = super().__new__(mcls, name, bases, namespace, **kwargs)


/scratch/fast/168840/tmp99yl2au7/chrom_data.zarr


100%|██████████| 2/2 [00:00<00:00, 4362.25it/s]
[32m2025-04-14 10:36:11.599[0m | [1mINFO    [0m | [36mcrandata.chrom_io[0m:[36mimport_bigwigs[0m:[36m330[0m - [1mExtracting values from 2 bigWig files...[0m
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
Processing consensus peak memor

<xarray.CrAnData> Size: 7kB
Dimensions:            (obs: 2, var: 3, seq_bins: 100, seq_len: 100, nuc: 4)
Coordinates:
  * obs                (obs) object 16B 'test' 'test2'
  * var                (var) object 24B 'chr1:100-200' ... 'chr1:350-450'
  * seq_bins           (seq_bins) int64 800B 0 1 2 3 4 5 6 ... 94 95 96 97 98 99
Dimensions without coordinates: seq_len, nuc
Data variables:
    obs-_-index        (obs) object 16B dask.array<chunksize=(2,), meta=np.ndarray>
    var-_-end          (var) int64 24B dask.array<chunksize=(3,), meta=np.ndarray>
    obs-_-file_path    (obs) object 16B dask.array<chunksize=(2,), meta=np.ndarray>
    var-_-chunk_index  (var) int64 24B dask.array<chunksize=(3,), meta=np.ndarray>
    var-_-chrom        (var) object 24B dask.array<chunksize=(3,), meta=np.ndarray>
    var-_-index        (var) object 24B dask.array<chunksize=(3,), meta=np.ndarray>
    var-_-start        (var) int64 24B dask.array<chunksize=(3,), meta=np.ndarray>
    X                  (ob

0it [00:00, ?it/s]

{'sequences': array([[[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]]], dtype=uint8), 'X': array([[[6.90964373e-310, 6.90964373e-310, 4.68488233e-310,
         4.68488233e-310, 6.90943458e-310, 6.90943452e-310,
         6.90943482e-310, 4.68488233e-310, 6.90943501e-310,
         6.90960911e-310, 6.90943501e-310, 0.00000000e+000,
         6.90943617e-310, 6.90960911e-310, 6.90943617e-310,
         6.90943469e-310, 6.90960911e-310, 6.90943469e-310,
         6.90943469

1it [00:00, 18.01it/s]



Meta Batch 0:
  sequences: shape (4, 100, 4)
  X: shape (2, 4, 100)
{'sequences': array([[[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        ...,
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]]], dtype=uint8), 'X': array([[[6.90964373e-310, 6.90964373e-310, 4.68488233e-310,
         4.68488233e-310, 6.90943458e-310, 6.90943452e-310,
         6.90943482e-310, 4.68488233e-310, 6.90943501e-310,
         6.90960911e-310, 6.90943501e-310, 0.00000000e+000,
         6.90943617e-310, 6.90960911e-310, 6.90943617e-310,
         




In [2]:
sdfs

NameError: name 'sdfs' is not defined

In [None]:
# Should the fill in _extract_values_from_bigwig actually be 0? Can we filter var where all is 0/nan without loading everything into memory?

In [5]:
import crandata
import xarray as xr
import pandas as pd
import numpy as np
import os
import crested
from tqdm import tqdm
import importlib

In [2]:
genomes = {}
beds = {}
chromsizes_files = {}
bed_files = {}
species = ['human','macaque','mouse']
species_codes = {'human':0,'macaque':1,'mouse':2}

MAX_SHIFT = 5
WINDOW_SIZE = 2114
WINDOW_SIZE = WINDOW_SIZE #+ 2*MAX_SHIFT
OFFSET = WINDOW_SIZE // 2  # e.g., 50% overlap
N_THRESHOLD = 0.3
n_bins = WINDOW_SIZE//50


In [3]:
for s in species:
    genome_path = '/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/genome/onehots/'+s
    fasta_file = os.path.join(genome_path,s+'.fa')
    chrom_sizes = os.path.join(genome_path,s+'.fa.sizes')
    annotation_gtf_file = os.path.join(genome_path,s+'.annotation.gtf')
    chromsizes_files[s] = chrom_sizes
    genome = crandata.Genome(fasta_file, chrom_sizes, annotation_gtf_file)
    genome.to_memory()
    genomes[s] = genome
    OUTPUT_BED = os.path.join(genome_path, "binned_genome.bed")
    bed_files[s] = OUTPUT_BED
    # Generate bins and optionally write to disk.
    binned_df = crandata.bin_genome(genome, WINDOW_SIZE, OFFSET, n_threshold=N_THRESHOLD, output_path=OUTPUT_BED).reset_index(drop=True)
    print("Filtered bins:")
    print(binned_df)


2025-04-14T14:50:52.539431-0700 INFO Genome sequences loaded into memory.


Calculating N content: 100%|██████████| 2932321/2932321 [04:31<00:00, 10781.34it/s]


Filtered bins:
              chrom  start    end   prop_n
0              chr1   9514  11628  0.23026
1              chr1  10571  12685  0.00000
2              chr1  11628  13742  0.00000
3              chr1  12685  14799  0.00000
4              chr1  13742  15856  0.00000
...             ...    ...    ...      ...
2786513  KI270518.1      1   2115  0.00000
2786514  KI270530.1      1   2115  0.00000
2786515  KI270304.1      1   2115  0.00000
2786516  KI270418.1      1   2115  0.00000
2786517  KI270424.1      1   2115  0.00000

[2786518 rows x 4 columns]
2025-04-14T14:56:16.284754-0700 INFO Genome sequences loaded into memory.


Calculating N content: 100%|██████████| 2806701/2806701 [04:09<00:00, 11236.23it/s]


Filtered bins:
               chrom  start    end  prop_n
0        NC_041754.1      1   2115     0.0
1        NC_041754.1   1058   3172     0.0
2        NC_041754.1   2115   4229     0.0
3        NC_041754.1   3172   5286     0.0
4        NC_041754.1   4229   6343     0.0
...              ...    ...    ...     ...
2773980  NC_005943.1   9514  11628     0.0
2773981  NC_005943.1  10571  12685     0.0
2773982  NC_005943.1  11628  13742     0.0
2773983  NC_005943.1  12685  14799     0.0
2773984  NC_005943.1  13742  15856     0.0

[2773985 rows x 4 columns]
2025-04-14T15:01:11.719357-0700 INFO Genome sequences loaded into memory.


Calculating N content: 100%|██████████| 2583507/2583507 [03:44<00:00, 11513.40it/s]


Filtered bins:
              chrom    start      end    prop_n
0              chr1  2999767  3001881  0.110638
1              chr1  3002938  3005052  0.085579
2              chr1  3003995  3006109  0.000000
3              chr1  3005052  3007166  0.000000
4              chr1  3006109  3008223  0.000000
...             ...      ...      ...       ...
2509462  JH584292.1     8457    10571  0.000000
2509463  JH584292.1     9514    11628  0.000000
2509464  JH584292.1    10571    12685  0.000000
2509465  JH584292.1    11628    13742  0.000000
2509466  JH584292.1    12685    14799  0.000000

[2509467 rows x 4 columns]


In [17]:
crandata.chrom_io = importlib.reload(crandata.chrom_io)

In [18]:
adatas = {}

for s in species[-1:]:
    print(s)
    bigwigs_dir = os.path.join('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',s,'Group_bigwig')
    adatas[s] = crandata.chrom_io.import_bigwigs(
        bigwigs_folder=bigwigs_dir,
        regions_file=bed_files[s],
        backed_path='/home/matthew.schmitz/Matthew/'+s+'_spc_test.zarr',
        target_region_width=WINDOW_SIZE,
        chromsizes_file=chromsizes_files[s],
        target = 'raw',
        max_stochastic_shift=5,
        tile_size=5000,
        chunk_size=512,
        n_bins=n_bins
    )
    bed = adatas[s].get_dataframe('var').loc[:,['chrom','start','end']]
    adatas[s] = crandata.seq_io.add_genome_sequences_to_crandata(adatas[s], bed, genomes[s])
    print(adatas[s]['sequences'])
    adatas[s]['var-_-species'] = xr.DataArray(np.repeat(species_codes[s],adatas[s].sizes['var']),dims='var').chunk({'var':adatas[s].attrs['chunk_size']})
    # adatas[s].to_icehunk(mode='a',commit_name='add_genome_seqs')
    adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
    # adatas[s] = crandata.crandata.CrAnData.open_zarr('/home/matthew.schmitz/Matthew/'+s+'_spc_test.zarr')
    

mouse


100%|██████████| 49/49 [00:00<00:00, 549.31it/s]


2025-04-14T18:50:28.812808-0700 INFO Extracting values from 48 bigWig files...


  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  for chrom, start, end in [line.split("\t")[:3] for line in lines]
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
Processing memory tiles: 100%|██████████| 502/502 [33:14<00:00,  3.97s/it]
  raise ValueError("Unsupported backend. Use 'zarr' or 'icechunk'.")

<xarray.DataArray 'sequences' (var: 2508122, seq_len: 2124, nuc: 4)> Size: 21GB
dask.array<xarray-<this-array>, shape=(2508122, 2124, 4), dtype=uint8, chunksize=(512, 2124, 4), chunktype=numpy.ndarray>
Coordinates:
  * var      (var) object 20MB 'chr1:2999767-3001881' ... 'JH584292.1:12685-1...
Dimensions without coordinates: seq_len, nuc


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)


In [None]:
# #Alternate workflow to directly write icechunks, but this is ~5x slower (better to write pure zarr3 then convert the the whole store at once)
# adatas = {}

# for s in species:
#     print(s)
#     bigwigs_dir = os.path.join('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',s,'Group_bigwig')
#     adatas[s] = crandata.chrom_io.import_bigwigs(
#         bigwigs_folder=bigwigs_dir,
#         regions_file=bed_files[s],
#         backed_path='/home/matthew.schmitz/Matthew/'+s+'_spc_test.icechunk',
#         target_region_width=WINDOW_SIZE,
#         chromsizes_file=chromsizes_files[s],
#         target = 'raw',
#         max_stochastic_shift=5,
#         chunk_size=512,
#         backend='icechunk',
#         n_bins=n_bins
#     )
#     bed = adatas[s].get_dataframe('var').loc[:,['chrom','start','end']]
#     adatas[s] = crandata.seq_io.add_genome_sequences_to_crandata(adatas[s], bed, genomes[s])
#     print(adatas[s]['sequences'])
#     adatas[s]['var-_-species'] = xr.DataArray(np.repeat(species_codes[s],adatas[s].sizes['var']),dims='var').chunk({'var':adatas[s].attrs['chunk_size']})
#     adatas[s].to_icechunk(mode='a',commit_name='add_genome_seqs')
#     # adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
#     adatas[s] = crandata.crandata.CrAnData.open_icechunk('/home/matthew.schmitz/Matthew/'+s+'_spc_test.icechunk')
    

human


100%|██████████| 49/49 [00:00<00:00, 500.38it/s]


2025-04-14T12:02:42.163578-0700 INFO Extracting values from 49 bigWig files...


  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
  return cls(**configuration_parsed)
  result = await AsyncArray._create_v3(
Processing memory tiles:   0%|          | 2/556 [03:11<14:10:06, 92.07s/it] 

In [None]:
for s in adatas.keys():
    crandata.train_val_test_split(
        adatas[s], strategy="region", val_size=0.1, test_size=0.1, random_state=42
    )
    adatas[s].to_zarr(adatas[s].encoding['source'],mode='a')
    # adatas[s].to_icechunk(mode='a',commit_name='train_val_test_split') #If you're using icechunk store
    


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  _add(result[val], k)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  meta = AsyncArray._create_metadata_v3(
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)


In [None]:
print(adatas[s])
print(adatas[s]['X'])
print(adatas[s]['sequences'])

In [28]:
# crandata.utils = importlib.reload(crandata.utils)
for s in tqdm(species):
    crandata.utils.convert_zarr_store_to_icechunk('/home/matthew.schmitz/Matthew/'+s+'_spc_test.zarr', 
                                                  '/home/matthew.schmitz/Matthew/'+s+'_spc_test.icechunk')
    adatas[s] = crandata.CrAnData.open_icechunk('/home/matthew.schmitz/Matthew/'+s+'_spc_test.icechunk',
                                                cache_config={'num_bytes_chunks':int(8e9)})#Cache 8Gb

  0%|          | 0/3 [00:00<?, ?it/s]

Conversion complete. Commit ID: SRWYSK94ST47318A5TV0


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
 33%|███▎      | 1/3 [04:50<09:40, 290.02s/it]

Conversion complete. Commit ID: 8Y62G1KR1WQW3WJVJFW0


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  path = os.fspath(arg)


Conversion complete. Commit ID: GZJ42CCZ6A0YNS45XP60


  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
  return cls(**configuration_parsed)
100%|██████████| 3/3 [14:21<00:00, 287.07s/it]


In [None]:
transform = crandata.seq_io.DNATransform(out_len=WINDOW_SIZE, random_rc=True, max_shift=MAX_SHIFT)

meta_module = crandata.MetaCrAnDataModule(
    adatas=list(adatas.values()),
    batch_size=[16,16,16],
    load_keys={'X': 'y','sequences':'sequence','var-_-species':'species'},
    dnatransform=transform,
    join='inner',
    num_workers=0,
    epoch_size=1000000
)

# Setup the meta module for the "fit" stage (train/val)
meta_module.setup("train")

# Retrieve the training dataloader from the meta module and iterate over a couple of batches.
meta_train_dl = meta_module.train_dataloader

print("\nIterating over a couple of training batches from MetaAnnDataModule:")
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break


In [None]:
import cProfile

code = '''
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break
'''

out = cProfile.run(code,sort=True)


In [None]:
meta_module.load()
meta_train_dl = meta_module.train_dataloader


In [None]:
code = '''
for i, batch in enumerate(tqdm(meta_train_dl)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 50:
        break
'''

out = cProfile.run(code,sort=True)


In [None]:
model_architecture = crested.tl.zoo.simple_convnet(
    seq_len=2114, num_classes=batch['y'].shape[1]
)


In [None]:
import keras
# Create your own configuration
# I recommend trying this for peak regression with a weighted cosine mse log loss function
optimizer = keras.optimizers.Adam(learning_rate=1e-5)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
loss = crested.tl.losses.PoissonLoss()

metrics = [
    keras.metrics.MeanAbsoluteError(),
    # keras.metrics.MeanSquaredError(),
    # keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    # crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    # crested.tl.metrics.PearsonCorrelationLog(),
    # crested.tl.metrics.ZeroPenaltyMetric(),
]

alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)


In [None]:
batch['sequence'].shape

In [None]:
# initialize some lazy model parameters *yawn*
model_architecture(batch['sequence'].float().mean(0).unsqueeze(0))

In [None]:
trainer = crested.tl.Crested(
    data=meta_module,
    model=model_architecture,
    config=alternative_config,
    project_name="mouse_biccn",  # change to your liking
    run_name="basemodel",  # change to your liking
    logger=None,  # or None, 'dvc', 'tensorboard'
    seed=7,  # For reproducibility
)
# train the model
trainer.fit(
    epochs=60,
    learning_rate_reduce_patience=3,
    early_stopping_patience=6,
)
