In [1]:
import os
import tempfile
from pathlib import Path
import pandas as pd
import numpy as np
import pyBigWig
import copy
import xarray as xr
import tqdm

# Import our new module system and utilities.
from crandata import CrAnDataModule, MetaCrAnDataModule, CrAnData
from crandata.chrom_io import import_bigwigs
from crandata.seq_io import add_genome_sequences_to_crandata, DNATransform

# Create temporary directories for synthetic data.
temp_dir = tempfile.TemporaryDirectory()
base_dir = Path(temp_dir.name)
beds_dir = base_dir / "beds"
bigwigs_dir = base_dir / "bigwigs"
beds_dir.mkdir(exist_ok=True)
bigwigs_dir.mkdir(exist_ok=True)

# Create a chromsizes file.
chromsizes_file = base_dir / "chrom.sizes"
with open(chromsizes_file, "w") as f:
    f.write("chr1\t1000\n")

# Create two BED files (simulate two different classes).
bed_data_A = pd.DataFrame({0: ["chr1", "chr1"],
                           1: [100, 300],
                           2: [200, 400]})
bed_data_B = pd.DataFrame({0: ["chr1", "chr1"],
                           1: [150, 350],
                           2: [250, 450]})
bed_file_A = beds_dir / "ClassA.bed"
bed_file_B = beds_dir / "ClassB.bed"
bed_data_A.to_csv(bed_file_A, sep="\t", header=False, index=False)
bed_data_B.to_csv(bed_file_B, sep="\t", header=False, index=False)

# Create a consensus BED file.
consensus = pd.DataFrame({0: ["chr1", "chr1", "chr1"],
                          1: [100, 300, 350],
                          2: [200, 400, 450]})
consensus_file = base_dir / "consensus.bed"
consensus.to_csv(consensus_file, sep="\t", header=False, index=False)

# Create two bigWig files.
bigwig_file1 = bigwigs_dir / "test.bw"
bw1 = pyBigWig.open(str(bigwig_file1), "w")
bw1.addHeader([("chr1", 1000)])
bw1.addEntries(chroms=["chr1"], starts=[0], ends=[1000], values=[5.0])
bw1.close()

bigwig_file2 = bigwigs_dir / "test2.bw"
bw2 = pyBigWig.open(str(bigwig_file2), "w")
bw2.addHeader([("chr1", 1000)])
bw2.addEntries(chroms=["chr1"], starts=[0], ends=[1000], values=[4.0])
bw2.close()

# Set extraction parameters.
target_region_width = 100
backed_path = base_dir / "chrom_data.nc"

# Create the CrAnData object from bigWig files and consensus regions.
adata = import_bigwigs(
    bigwigs_folder=str(bigwigs_dir),
    regions_file=str(consensus_file),
    backed_path=str(backed_path),
    target_region_width=target_region_width,
    chromsizes_file=str(chromsizes_file),
)

# Create a dummy FASTA file for a genome.
fasta_file = base_dir / "chr1.fa"
with open(fasta_file, "w") as f:
    f.write(">chr1\n")
    f.write("A" * 1000 + "\n")

# Create a Genome object.
from crandata._genome import Genome
dummy_genome = Genome(str(fasta_file), chrom_sizes=str(chromsizes_file))

# Add sequences to the CrAnData using the provided seq_io utility.
# Here we use the consensus regions as our ranges.
consensus.columns = ['chrom','start','end']
adata = add_genome_sequences_to_crandata(adata, consensus, dummy_genome)

# Write the CrAnData object to disk and then reload it to ensure sequences are out-of-memory.
adata.to_netcdf(str(backed_path))
adata_loaded = CrAnData.open_dataset(str(backed_path))
print("Loaded CrAnData:")
print(adata_loaded)

# Create two copies to simulate two datasets (e.g. two species), and add a "split" column in var metadata.
adata1 = copy.deepcopy(adata_loaded)
adata2 = copy.deepcopy(adata_loaded)
adata1["var/split"] = xr.DataArray(np.full(adata1.sizes["var"], "train"), dims=["var"])
adata2["var/split"] = xr.DataArray(np.full(adata2.sizes["var"], "train"), dims=["var"])

# Create a DNATransform instance.
transform = DNATransform(out_len=80, random_rc=True, max_shift=5)

# Instantiate the MetaCrAnDataModule with the two datasets.
meta_module = MetaCrAnDataModule(
    adatas=[adata1, adata2],
    batch_size=4,        # small batch size for testing
    shuffle=True,
    dnatransform=transform,
    epoch_size=10
)

# Setup each underlying module for the "train" stage.
for mod in meta_module.modules:
    mod.setup(state="train")

# Retrieve the training dataloader from the meta module and iterate over a couple of batches.
meta_train_dl = meta_module.train_dataloader
print("\nIterating over a couple of training batches from MetaCrAnDataModule:")
for i, batch in enumerate(tqdm.tqdm(meta_train_dl)):
    print(f"\nMeta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i >= 1:
        break

print("\nTemporary directory contents:")
print(os.listdir(base_dir))
temp_dir.cleanup()


  cls = super().__new__(mcls, name, bases, namespace, **kwargs)
100%|██████████| 2/2 [00:00<00:00, 13819.78it/s]
[32m2025-03-17 21:58:11.698[0m | [1mINFO    [0m | [36mcrandata.chrom_io[0m:[36mimport_bigwigs[0m:[36m414[0m - [1mExtracting values from 2 bigWig files...[0m
2it [00:00, 101.20it/s]
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_name, self.data_vars[key])
  setattr(self, safe_n

Loaded CrAnData:
CrAnData object
Array names: ['sequences', 'var/index', 'var/chrom', 'var/start', 'var/end', 'var/chunk_index', 'obs/file_path', 'obs/index', 'X', 'always_convert_df']
Coordinates: ['var', 'obs', 'seq_bins']



ValueError: input sample length must be less than or equal to the dimension length, but the sample length of 4 is greater than the dimension length of 3 for var

In [None]:
import json
json.dumps({'hi':74})

In [None]:
for n in adata.array_names:
    print('Array:',n)
    print(adata[n])

In [None]:
adata.always_convert_df = ['obs','var']

In [None]:
dict(adata.dims)

In [None]:
import xbatcher

dim_dict = dict(adata.dims)
batch_size = 1  # Each batch will include one "var" slice.
dim_dict['var'] = batch_size
del dim_dict['item'] #Delete any 0 dimensional dims items
bgen = xbatcher.BatchGenerator(
    ds=adata,
    input_dims=dim_dict,
    batch_dims={'var':batch_size},
)

# Test the dataloader by iterating over the batches and printing variable shapes.
for i, batch in enumerate(bgen):
    print(f"Batch {i}:")
    for var_name, da in batch.data_vars.items():
        print(f"  {var_name}: shape {da.shape}")

#you can also index in as if it were a list (so this can be given to pytorch which can random sample batches)
bgen[1]

In [None]:
batch['X']

In [None]:
batch['X'].coords

In [None]:
ds.dims

In [None]:
dim_dict

In [None]:
batch['X'].shape

In [None]:
adata.array_names

In [None]:
adata['var/chrom']

In [None]:
adata.always_convert_df = []

In [None]:
adata_loaded['obsm/gex']

In [None]:
adata_loaded['obs/index']

In [None]:
var_df = consensus_peaks.set_index("region")
var_df["chunk_index"] = np.arange(var_df.shape[0]) // chunk_size
# do something like:
var_df = consensus_peaks.set_index("region")
var_df["chunk_index"] = np.arange(var_df.shape[0]) // chunk_size


In [None]:
adata = crandata.CrAnData(always_convert_df=["df"], global_axis_order=["obs"])

# Create two 1D arrays of length 10 and store them under hierarchical keys "df/col1" and "df/col2".
adata["df/col1"] = xr.DataArray(np.arange(10), dims=["obs"])
adata["df/col2"] = xr.DataArray(np.arange(10, 20), dims=["obs"])

# Check that array_names is populated.
print("Array names:", adata.array_names)  # Should print: ['df/col1', 'df/col2']

# Using get_dataframe() returns a grouped DataFrame.
print("\nAccess grouped DataFrame via get_dataframe('df'):")
df_group = adata.get_dataframe("df")
print(df_group)

# Standard pandas indexing.
print("\nAccess a specific element using .loc on the grouped DataFrame:")
print(df_group.loc[5, "col2"])

# Hierarchical indexing: since "df" is in always_convert_df, adata["df/col1"] returns that column.
print("\nAccess column 'col1' via hierarchical indexing:")
print(adata["df/col1"])

# Dynamic attribute access: "df/col1" is accessible as adata.df_col1.
print("\nAccess column 'col1' via dynamic attribute (underscore notation):")
print(adata.df_col1)

# To see the grouped DataFrame view, call get_dataframe.
print("\nAccess grouped DataFrame via get_dataframe('df'):")
print(adata.get_dataframe("df"))

# ----- Write to disk and read back -----
adata.to_netcdf("test_crandata.nc")
print("\nWrote CrAnData to 'test_crandata.nc'.")

# Use our class method to open the file.
adata_loaded = crandata.CrAnData.open_dataset("test_crandata.nc")
print("\nLoaded CrAnData from disk:")
print(adata_loaded)

print("\nAccess grouped DataFrame from loaded data via get_dataframe('df'):")
print(adata_loaded.get_dataframe("df"))

print("\nAccess column 'col1' from loaded data using hierarchical indexing:")
print(adata_loaded["df/col1"])

os.remove("test_crandata.nc")


In [None]:
adata_loaded.df

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import os
import json

class CrAnData(xr.Dataset):
    __slots__ = ("always_convert_df", "__dict__")
    
    def __init__(self, 
                 data_vars=None,
                 coords=None,
                 always_convert_df=None,  # list of top-level keys to be grouped into a DataFrame on access
                 **kwargs):
        """
        Create a CrAnData object as a subclass of xarray.Dataset.
        
        Parameters:
          data_vars: dictionary of data variables.
          coords: dictionary of coordinates.
          always_convert_df: list (or array) of top-level keys that should be grouped into a DataFrame on access.
          kwargs: additional data variables (merged into data_vars).
        """
        # Merge provided data_vars with kwargs.
        if data_vars is None:
            data_vars = {}
        data_vars = dict(data_vars)  # make a copy
        data_vars.update(kwargs)
        # Ensure every variable is an xr.DataArray.
        for key, var in data_vars.items():
            if not isinstance(var, xr.DataArray):
                data_vars[key] = xr.DataArray(var)
        if coords is None:
            coords = {}
            
        # Initialize the underlying xarray.Dataset.
        super().__init__(data_vars=data_vars, coords=coords)
        
        if always_convert_df is None:
            always_convert_df = []
        self.always_convert_df = always_convert_df

        # Save custom attributes (as JSON) into ds.attrs so they are stored on disk.
        self.attrs["always_convert_df"] = json.dumps(self.always_convert_df)
        
        # For keys that contain '/', add an instance attribute with "/" replaced by "_" 
        # for direct dynamic attribute access.
        for key in self.data_vars:
            if "/" in key:
                safe_name = key.replace("/", "_")
                object.__setattr__(self, safe_name, self.data_vars[key])

        if "var" in self.always_convert_df:
            grouped_var = self.get_dataframe("var")
            if grouped_var is not None:
                self.__dict__["var"] = grouped_var

    @property
    def array_names(self):
        """Return a list of the names of the data variables."""
        return list(self.data_vars.keys())
    
    def get_dataframe(self, top):
        """
        Group all data variables whose keys start with 'top/' into a pandas DataFrame.
        Assumes that each such variable is 1D and they share the same length.
        """
        cols = {}
        index = None
        for key in list(self.data_vars.keys()):
            if key.startswith(top + "/"):
                col_name = key.split("/", 1)[1]
                # Use super().__getitem__ to bypass our custom __getitem__
                da = super().__getitem__(key)
                cols[col_name] = da.values
                if top in da.coords:
                    index = da.coords[top].values
                else:
                    index = np.arange(da.shape[0])
        if cols:
            return pd.DataFrame(cols, index=index)
        else:
            return None
    
    def __getitem__(self, key):
        """
        Support hierarchical indexing.
        If key is exactly a top-level key that is in always_convert_df,
        return the grouped DataFrame.
        If key is a string containing "/" and its top-level part is in always_convert_df,
        then return the corresponding column from the grouped DataFrame.
        Otherwise, return the data variable corresponding to the full key.
        """
        if isinstance(key, str):
            if key in self.always_convert_df:
                df = self.get_dataframe(key)
                if df is None:
                    raise KeyError(f"No grouped data found for key '{key}'")
                return df
            if "/" in key:
                top, sub = key.split("/", 1)
                if top in self.always_convert_df:
                    df = self.get_dataframe(top)
                    if df is None:
                        raise KeyError(f"No grouped data found for key '{top}'")
                    return df[sub]
        return super().__getitem__(key)
    
    def __getattr__(self, attr):
        """
        Fallback attribute lookup.
        If attr is one of the keys in always_convert_df (e.g. "obs" or "var"),
        return the grouped DataFrame using get_dataframe.
        Otherwise, try to find a data variable whose key (with "/" replaced by "_") matches attr.
        """
        # If attr is one of the top-level keys to be grouped, return the full DataFrame.
        if attr in self.always_convert_df:
            df = self.get_dataframe(attr)
            if df is not None:
                return df
        # Otherwise, look for a matching data variable (e.g., "obs_somecol" for "obs/somecol").
        dv = object.__getattribute__(self, "data_vars")
        for key in dv:
            safe = key.replace("/", "_")
            if safe == attr:
                return dv[key]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {attr!r}")
    
    def __repr__(self):
        # Custom repr showing our array names, coordinate keys, and global axis order.
        rep = f"CrAnData object\nArray names: {self.array_names}\n"
        rep += f"Coordinates: {list(self.coords.keys())}\n"
        return rep
        
    def _repr_html_(self):
        return self.__repr__()
        
    @classmethod
    def open_dataset(cls, path, **kwargs):
        """
        Class method to open a NetCDF file and wrap it as a CrAnData object.
        Reads custom attributes (always_convert_df) from ds.attrs.
        """
        ds = xr.open_dataset(path, **kwargs)
        always_convert_df = json.loads(ds.attrs.get("always_convert_df", "[]"))
        return cls(data_vars=ds.data_vars, coords=ds.coords,
                   always_convert_df=always_convert_df)

    @classmethod
    def open_zarr(cls, store, **kwargs):
        """
        Class method to open a Zarr store and wrap it as a CrAnData object.
        Reads custom attributes (always_convert_df) from ds.attrs.
        """
        ds = xr.open_zarr(store, **kwargs)
        always_convert_df = json.loads(ds.attrs.get("always_convert_df", "[]"))
        return cls(data_vars=ds.data_vars, coords=ds.coords,
                   always_convert_df=always_convert_df)

# ----- Example usage -----

if __name__ == "__main__":
    # Create a new CrAnData object.
    # Indicate that keys under the top-level "df" should be grouped into a DataFrame.
    adata = CrAnData(always_convert_df=["df"], global_axis_order=["obs"])
    
    # Create two 1D arrays of length 10 and store them under hierarchical keys "df/col1" and "df/col2".
    adata["df/col1"] = xr.DataArray(np.arange(10), dims=["obs"])
    adata["df/col2"] = xr.DataArray(np.arange(10, 20), dims=["obs"])
    
    # Check that array_names is populated.
    print("Array names:", adata.array_names)  # Should print: ['df/col1', 'df/col2']
    
    # Using get_dataframe() returns a grouped DataFrame.
    print("\nAccess grouped DataFrame via get_dataframe('df'):")
    df_group = adata.get_dataframe("df")
    print(df_group)
    
    # Standard pandas indexing.
    print("\nAccess a specific element using .loc on the grouped DataFrame:")
    print(df_group.loc[5, "col2"])
    
    # Hierarchical indexing: since "df" is in always_convert_df, adata["df/col1"] returns that column.
    print("\nAccess column 'col1' via hierarchical indexing:")
    print(adata["df/col1"])
    
    # Dynamic attribute access: "df/col1" is accessible as adata.df_col1.
    print("\nAccess column 'col1' via dynamic attribute (underscore notation):")
    print(adata.df_col1)
    
    # To see the grouped DataFrame view, call get_dataframe.
    print("\nAccess grouped DataFrame via get_dataframe('df'):")
    print(adata.get_dataframe("df"))
    
    # ----- Write to disk and read back -----
    adata.to_netcdf("test_crandata.nc")
    print("\nWrote CrAnData to 'test_crandata.nc'.")
    
    # Use our class method to open the file.
    adata_loaded = CrAnData.open_dataset("test_crandata.nc")
    print("\nLoaded CrAnData from disk:")
    print(adata_loaded)

    print("\nAccess grouped DataFrame from loaded data via get_dataframe('df'):")
    print(adata_loaded.get_dataframe("df"))
    
    print("\nAccess column 'col1' from loaded data using hierarchical indexing:")
    print(adata_loaded["df/col1"])
    
    os.remove("test_crandata.nc")


In [None]:
adata_loaded.df

In [None]:
adata.obs

In [None]:
adata.var()

In [None]:
It strikes me that this is not a very good way to read from disk, as it might require you to read the whole object into memory and then try to initialize a crandata (does it?). What would it take to create a classmethod for CrAnData that just wrap the xarray open_dataset or open_zarr

In [None]:
import crandata
import xarray as xr
import pandas as pd
import numpy as np

In [None]:
X = xr.DataArray(np.arange(20).reshape(4, 5), dims=["obs", "var"])
obsm = {"embedding": xr.DataArray(np.random.rand(4, 2), dims=["obs", "other"])}
varm = {"feature": xr.DataArray(np.random.rand(5, 3), dims=["var", "other"])}
layers = {"layer1": X.copy()}
varp = {"contacts": xr.DataArray(np.random.rand(5, 5), dims=["var_0", "var_1"])}
obsp = {"adj": xr.DataArray(np.random.rand(4, 4), dims=["obs_0", "obs_1"])}
data = crandata.crandata.CrAnData(
    X, uns={"extra": "test"},
    obsm=obsm, varm=varm, layers=layers, varp=varp, obsp=obsp
)


In [None]:
data

In [None]:
adata.var['train_probs']

In [None]:
adata._propagate_missing_coordinates()

In [None]:
adata

In [None]:
for i, batch in enumerate(tqdm.tqdm(meta_train_dl.data)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
        # print(tensor)
    # For quick testing, you can uncomment the following to break early:
    # if i == 1:
    #     break

print("Final directory contents:", os.listdir(base_dir))


In [None]:
batch['sequence'].shape

In [None]:
batch['hic'].shape

In [None]:
adata.global_axis_order

In [None]:
fff

In [None]:
import cProfile

code = '''
for i, batch in enumerate(meta_train_dl.data):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
'''

cProfile.run(code)


In [None]:
import crandata
import os
import crested
from tqdm import tqdm

In [None]:
genomes = {}
beds = {}
chromsizes_files = {}
bed_files = {}
species = ['mouse','human','macaque']

WINDOW_SIZE = 2114
OFFSET = WINDOW_SIZE // 2  # e.g., 50% overlap
N_THRESHOLD = 0.3
n_bins = WINDOW_SIZE//50

In [None]:
for s in species:
    genome_path = '/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/genome/onehots/'+s
    fasta_file = os.path.join(genome_path,s+'.fa')
    chrom_sizes = os.path.join(genome_path,s+'.fa.sizes')
    annotation_gtf_file = os.path.join(genome_path,s+'.annotation.gtf')
    chromsizes_files[s] = chrom_sizes
    genome = crandata.Genome(fasta_file, chrom_sizes, annotation_gtf_file)
    genomes[s] = genome
    # Set parameters for binning.
    
    # Optionally specify an output path for the BED file.
    OUTPUT_BED = os.path.join(genome_path, "binned_genome.bed")
    bed_files[s] = OUTPUT_BED
    # Generate bins and optionally write to disk.
    # binned_df = crandata.bin_genome(genome, WINDOW_SIZE, OFFSET, n_threshold=N_THRESHOLD, output_path=OUTPUT_BED)
    # print("Filtered bins:")
    # print(binned_df)

In [None]:
adatas = {}

for s in species:
    # bigwigs_dir = os.path.join('/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/SpinalCord/manuscript/ATAC',s,'Group_bigwig')
    # adata = crandata.chrom_io.import_bigwigs(
    #     bigwigs_folder=bigwigs_dir,
    #     regions_file=bed_files[s],
    #     backed_path='/home/matthew.schmitz/Matthew/'+s+'_spc_test.h5',
    #     target_region_width=WINDOW_SIZE,
    #     chromsizes_file=chromsizes_files[s],
    #     target = 'mean',
    #     n_bins=n_bins
    # )
    # adatas[s] = adata
    adatas[s] = crandata.crandata.CrAnData.from_h5('/home/matthew.schmitz/Matthew/'+s+'_spc_test.h5')
    

In [None]:
# import numpy as np
# adatas['mouse'].uns['chunk_size'] = 512
# adatas['human'].uns['chunk_size'] = 512
# adatas['macaque'].uns['chunk_size'] = 512
# adatas['mouse'].var["chunk_index"] = np.arange(adatas['mouse'].var.shape[0]) // 512
# adatas['human'].var["chunk_index"] = np.arange(adatas['human'].var.shape[0]) // 512
# adatas['macaque'].var["chunk_index"] = np.arange(adatas['macaque'].var.shape[0]) // 512


In [None]:
for s in adatas.keys():
    crested.pp.train_val_test_split(
        adatas[s], strategy="region", val_size=0.1, test_size=0.1, random_state=42
    )


In [None]:
meta_module = crandata._anndatamodule.MetaAnnDataModule(
    adatas=list(adatas.values()),
    genomes=list(genomes.values()),
    data_sources={'y': 'X'},
    in_memory=False,
    random_reverse_complement=True,
    max_stochastic_shift=10,
    deterministic_shift=False,
    shuffle_obs=False, obs_alignment = 'intersect',
    shuffle=True,
    batch_size=32,    # small batch size for testing
    epoch_size=1000000    # small epoch size for quick testing
)

# Setup the meta module for the "fit" stage (train/val)
meta_module.setup("fit")

# Retrieve the training dataloader from the meta module and iterate over a couple of batches.
meta_train_dl = meta_module.train_dataloader

print("\nIterating over a couple of training batches from MetaAnnDataModule:")
for i, batch in enumerate(tqdm(meta_train_dl.data)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break


In [None]:
for i, batch in enumerate(tqdm(meta_train_dl.data)):
    print(f"Meta Batch {i}:")
    for key, tensor in batch.items():
        print(f"  {key}: shape {tensor.dtype}")
    if i == 5:
        break


In [None]:
import cProfile

code = '''
for i, batch in enumerate(meta_train_dl.data):
    # print(f"Meta Batch {i}:")
    # for key, tensor in batch.items():
    #     print(f"  {key}: shape {tensor.shape}")
    if i == 5:
        break
'''

out = cProfile.run(code,sort=True)


In [None]:
model_architecture = crested.tl.zoo.simple_convnet(
    seq_len=2114, num_classes=batch['y'].shape[1]
)


In [None]:
import keras
# Create your own configuration
# I recommend trying this for peak regression with a weighted cosine mse log loss function
optimizer = keras.optimizers.Adam(learning_rate=1e-5)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
loss = crested.tl.losses.PoissonLoss()

metrics = [
    keras.metrics.MeanAbsoluteError(),
    # keras.metrics.MeanSquaredError(),
    # keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    # crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    # crested.tl.metrics.PearsonCorrelationLog(),
    # crested.tl.metrics.ZeroPenaltyMetric(),
]

alternative_config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(alternative_config)


In [None]:
# initialize some lazy model parameters *yawn*
model_architecture(batch)

In [None]:
trainer = crested.tl.Crested(
    data=meta_module,
    model=model_architecture,
    config=alternative_config,
    project_name="mouse_biccn",  # change to your liking
    run_name="basemodel",  # change to your liking
    logger=None,  # or None, 'dvc', 'tensorboard'
    seed=7,  # For reproducibility
)
# train the model
trainer.fit(
    epochs=60,
    learning_rate_reduce_patience=3,
    early_stopping_patience=6,
)
