# Train neural networks with `momics` and `tensorflow`

`momics` provides utilities for training neural networks with `tensorflow`, such as `MomicsDataset` and `ChromNN`. This notebook demonstrates how to train a simple network that predicts ATAC-seq coverage from MNase-seq coverage.

## Connect to the data

First, we need to connect to the data. We will reuse the repository from the [previous tutorial](integrating-multiomics). 

In [4]:
from momics.momics import Momics

## Creating repository
repo = Momics("yeast_CNN_data.momics")

## Check that sequence and some tracks are registered
repo.seq()
repo.tracks()

| idx | label | path |
|---|---|---|
| 0 | ATAC | /data/momics/S288c_atac.bw |
| 1 | MNase | /data/momics/S288c_mnase.bw |
| 2 | ATAC_rescaled | tmp1ae0oz8z |
| 3 | MNase_rescaled | tmpxo8t09eu |


## Modify some tracks

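We first normalize the tracks and save them back to the local repository. Each track is capped at its genome-wide 99th percentile, then min-max scaled to the [0, 1] range chromosome by chromosome, with missing values set to 0.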

In [5]:
import numpy as np

for track in ["ATAC", "MNase"]:
    cov = repo.tracks(track)
    q99 = np.nanpercentile(np.concatenate(list(cov.values())), 99)
    for chrom in cov.keys():
        arr = cov[chrom]
        arr = np.minimum(arr, q99)
        arr = (arr - np.nanmin(arr)) / (np.nanmax(arr) - np.nanmin(arr))
        cov[chrom] = np.nan_to_num(arr, nan=0)
    repo.ingest_track(cov, track + "_rescaled")

repo.tracks()

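ValueError: Provided label 'ATAC_rescaled' already present in `tracks` table

This error is raised because the `ATAC_rescaled` and `MNase_rescaled` tracks were already ingested into the repository, as the track listing at the top of this notebook shows. The existing rescaled tracks are the ones used in the following steps.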
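The normalization can be tried in isolation, without a `momics` repository. The following is a minimal sketch on made-up coverage arrays; `rescale_track` and `toy_cov` are hypothetical names introduced for illustration, and the guard against flat chromosomes is an addition that the cell above does not have.

```python
import numpy as np


def rescale_track(cov):
    """Cap coverage at the genome-wide 99th percentile, then min-max scale each chromosome."""
    q99 = np.nanpercentile(np.concatenate(list(cov.values())), 99)
    rescaled = {}
    for chrom, arr in cov.items():
        arr = np.minimum(arr, q99)  # NaN values stay NaN at this stage
        lo, hi = np.nanmin(arr), np.nanmax(arr)
        span = hi - lo
        # Guard (not in the original cell): a flat chromosome would divide by zero
        arr = (arr - lo) / span if span > 0 else np.zeros_like(arr)
        rescaled[chrom] = np.nan_to_num(arr, nan=0.0)
    return rescaled


# Made-up coverage for two chromosomes, with a few missing values
rng = np.random.default_rng(0)
toy_cov = {
    "I": rng.gamma(2.0, 1.0, size=1000),
    "II": rng.gamma(2.0, 1.0, size=2000),
}
toy_cov["I"][:10] = np.nan

toy_rescaled = rescale_track(toy_cov)
for chrom, arr in toy_rescaled.items():
    print(chrom, arr.min(), arr.max())
```

Because the cap is computed genome-wide but the scaling is applied per chromosome, every chromosome ends up spanning the full [0, 1] range.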

## Define model 

We will define a simple convolutional neural network with `tensorflow` to predict the target variable `ATAC` from the feature variable `MNase`. This first requires defining `MomicsDataset` objects for a training set, a validation set, and a test set. The features are `MNase` coverage scores over tiling genomic windows (width: 2048 bp, stride: 2048 bp). The targets are `ATAC` coverage scores over the same windows, narrowed down to the 16 bp around the center of each window.
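To make the window geometry concrete, here is a small NumPy sketch of tiling windows and center-cropped targets. It uses a hypothetical 10,000 bp chromosome with random coverage, and only illustrates the shapes involved; how `MomicsDataset` extracts the data internally may differ.

```python
import numpy as np

features_size, target_size, stride = 2048, 16, 2048

# Hypothetical 10,000 bp chromosome with made-up coverage values
chrom_len = 10_000
mnase = np.random.default_rng(0).random(chrom_len).astype("float32")
atac = np.random.default_rng(1).random(chrom_len).astype("float32")

# Tiling windows; the last, incomplete window is dropped
starts = np.arange(0, chrom_len - features_size + 1, stride)

# Features: MNase coverage over the full window
x = np.stack([mnase[s : s + features_size] for s in starts])[..., np.newaxis]

# Targets: ATAC coverage over the central target_size bp of each window
offset = (features_size - target_size) // 2
y = np.stack([atac[s + offset : s + offset + target_size] for s in starts])[..., np.newaxis]

print(x.shape, y.shape)  # (4, 2048, 1) (4, 16, 1)
```

Apart from the batch dimension, these shapes correspond to the `(None, 2048, 1)` and `(None, 16, 1)` entries of the dataset's `element_spec`.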

In [6]:
from momics.dataset import MomicsDataset
import momics.utils as mutils

# Fetch data from the momics repository
features = "MNase_rescaled"
features_size = 2048
target = "ATAC_rescaled"
target_size = 16
stride = 2048
batch_size = 1024

bins = repo.bins(width=features_size, stride=stride, cut_last_bin_out=True)
bins_split, bins_test = mutils.split_ranges(bins, 0.8, shuffle=False)
bins_train, bins_val = mutils.split_ranges(bins_split, 0.8, shuffle=True)

train_dataset = MomicsDataset(repo, bins_train, features, target, target_size=target_size, batch_size=batch_size)
val_dataset = MomicsDataset(repo, bins_val, features, target, target_size=target_size, batch_size=batch_size)
test_dataset = MomicsDataset(repo, bins_test, features, target, target_size=target_size, batch_size=batch_size)

train_dataset  # noqa: B018
val_dataset  # noqa: B018
test_dataset


<_ZipDataset element_spec=((TensorSpec(shape=(None, 2048, 1), dtype=tf.float32, name=None),), (TensorSpec(shape=(None, 16, 1), dtype=tf.float32, name=None),))>

In [7]:
from momics.chromnn import ChromNN
import tensorflow as tf
from tensorflow.keras import layers  # type: ignore

model = ChromNN(
    layers.Input(shape=(features_size, 1)),
    layers.Dense(target_size, activation="linear"),
).model

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss="mse")
model.summary()

ValueError: The total size of the tensor must be unchanged. Received: input_shape=(16,), target_shape=(1, 1)
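The exact cause of this error depends on layers that `ChromNN` adds internally, which are not shown in this notebook. One thing worth checking when such a shape error appears is how a dense layer treats a 3D input: it acts on the last axis only. The NumPy sketch below mimics this with zero-filled, hypothetical kernels; it is a shape check under that assumption, not a fix for the cell above.

```python
import numpy as np

batch, features_size, target_size = 4, 2048, 16
x = np.zeros((batch, features_size, 1), dtype="float32")

# A dense layer applied directly to (batch, 2048, 1) uses a (1, target_size) kernel
# and transforms the last axis only
kernel_last_axis = np.zeros((1, target_size), dtype="float32")
out_last_axis = x @ kernel_last_axis
print(out_last_axis.shape)  # (4, 2048, 16)

# Flattening first, then projecting and reshaping, matches the (batch, 16, 1) target
flat = x.reshape(batch, -1)
kernel_flat = np.zeros((features_size, target_size), dtype="float32")
out = (flat @ kernel_flat).reshape(batch, target_size, 1)
print(out.shape)  # (4, 16, 1)
```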

## Fit the model 

In [4]:
import os
from pathlib import Path
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint, ReduceLROnPlateau  # type: ignore

os.makedirs(".chromnn", exist_ok=True)
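callbacks_list = [
    # Log per-epoch metrics to a CSV file
    CSVLogger(str(Path(".chromnn", "epoch_data.csv"))),
    # Save the best model inside `.chromnn`; `val_loss` is the only validation
    # metric available, since the model is compiled without extra metrics
    ModelCheckpoint(filepath=str(Path(".chromnn", "Checkpoint.keras")), monitor="val_loss", save_best_only=True),
    # Stop after 10 epochs without improvement and restore the best weights
    EarlyStopping(monitor="val_loss", patience=10, min_delta=1e-5, restore_best_weights=True),
    # Divide the learning rate by 10 after 3 epochs without improvement, down to 1e-4
    ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=3, min_lr=1e-4),
]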
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=20,
    callbacks=callbacks_list,
    steps_per_epoch=len(bins_train) // batch_size,
)

Epoch 1/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 27s 3s/step - loss: 2.2155 - val_loss: 0.0260 - learning_rate: 0.0010
Epoch 2/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 16s 117ms/step - loss: 1.7935 - val_loss: 0.0258 - learning_rate: 0.0010
Epoch 3/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 2s 546ms/step - loss: 1.5613 - val_loss: 0.0255 - learning_rate: 0.0010
Epoch 4/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 117ms/step - loss: 1.2968 - val_loss: 0.0256 - learning_rate: 0.0010
Epoch 5/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 2s 487ms/step - loss: 1.1518 - val_loss: 0.0258 - learning_rate: 0.0010
Epoch 6/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 111ms/step - loss: 0.9550 - val_loss: 0.0258 - learning_rate: 0.0010
Epoch 7/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 2s 495ms/step - loss: 0.9325 - val_loss: 0.0266 - learning_rate: 1.0000e-04
Epoch 8/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 117ms/step - loss: 0.8977 - val_loss: 0.0268 - learning_rate: 1.0000e-04
Epoch 9/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 2s 552ms/step - loss: 0.8974 - val_loss: 0.0275 - learning_rate: 1.0000e-04
Epoch 10/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 115ms/step - loss: 0.9267 - val_loss: 0.0277 - learning_rate: 1.0000e-04
Epoch 11/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 2s 496ms/step - loss: 0.8864 - val_loss: 0.0283 - learning_rate: 1.0000e-04
Epoch 12/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 111ms/step - loss: 0.8577 - val_loss: 0.0286 - learning_rate: 1.0000e-04
Epoch 13/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 2s 560ms/step - loss: 0.8691 - val_loss: 0.0293 - learning_rate: 1.0000e-04


<keras.src.callbacks.history.History at 0x7fe3d9412110>
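The training log shows the learning rate dropping from epoch 7 onward and training ending after epoch 13 rather than 20. Both follow from the patience settings of the callbacks. The sketch below replays the logged `val_loss` values through a simplified, pure-Python version of the patience logic. It is not the Keras implementation (which also handles cooldown periods and other modes), and the `min_delta=1e-4` used for the learning-rate schedule is assumed to be the Keras default for `ReduceLROnPlateau`.

```python
# Validation losses copied from the training log, epochs 1 to 13
val_losses = [0.0260, 0.0258, 0.0255, 0.0256, 0.0258, 0.0258, 0.0266,
              0.0268, 0.0275, 0.0277, 0.0283, 0.0286, 0.0293]


def early_stopping_epoch(losses, patience, min_delta):
    """Return the 1-based epoch after which training stops, or None if it never stops."""
    best, wait = float("inf"), 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best - min_delta:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


def lr_schedule(losses, lr, factor, patience, min_lr, min_delta=1e-4):
    """Return the learning rate in effect during each epoch."""
    best, wait, used = float("inf"), 0, []
    for loss in losses:
        used.append(lr)
        if loss < best - min_delta:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                # Reduce the rate, but never below min_lr
                lr, wait = max(lr * factor, min_lr), 0
    return used


stop = early_stopping_epoch(val_losses, patience=10, min_delta=1e-5)
schedule = lr_schedule(val_losses, lr=0.001, factor=0.1, patience=3, min_lr=1e-4)
print(stop)  # 13
print(schedule)
```

The best validation loss is reached at epoch 3, so ten epochs without improvement end training after epoch 13, and three epochs without improvement (4 to 6) trigger the learning-rate reduction seen from epoch 7. The rate then stays at its floor of 1e-4.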

## Evaluate and save model 

In [5]:
# Evaluate the model
model.evaluate(test_dataset)

# Save the model
model.save("chromnn_model.keras")

      1/Unknown [1m3s[0m 3s/step - loss: 0.0323

2024-10-21 09:59:58.373818: W external/local_tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.21GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 4s/step - loss: 0.0270


2024-10-21 09:59:59.142341: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 5039837812499872075
2024-10-21 09:59:59.142361: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 18235436505003952707
  self.gen.throw(typ, value, traceback)


## Use the model to predict ATAC-seq coverage

In [9]:
from momics.momicsquery import MomicsQuery

## Define 2048-bp-wide genomic windows over chromosome XVI, with a stride of 16 bp, and extract MNase data from them.
bb = repo.bins(width=features_size, stride=16, cut_last_bin_out=True)["XVI"]
dat = MomicsQuery(repo, bb).query_tracks(tracks=["MNase_rescaled"])
dat = np.array(list(dat.coverage["MNase_rescaled"].values()))
dat = np.nan_to_num(dat, nan=0)

## Run predictions
predictions = model.predict(dat)

## Export predictions as a bigwig
## Narrow each window down to the `target_size` bp around its midpoint (integer arithmetic)
midpoints = (bb.Start + bb.End) // 2
bb.Start = midpoints - target_size // 2
bb.End = midpoints + target_size // 2
bb.to_bed("predictions_bins.bed")
mutils.pyranges_to_bw(bb, predictions, "predictions_cov.bw")
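The stride of 16 bp is chosen to match `target_size`: each window contributes predictions for its central 16 bp, so consecutive windows yield adjacent, non-overlapping intervals that can be written as a continuous coverage track. The following sketch checks this arithmetic on a few hypothetical window starts, assumed here to be 0-based; the coordinates returned by `repo.bins` may follow a different convention.

```python
import numpy as np

features_size, target_size, stride = 2048, 16, 16

# Hypothetical starts of five consecutive windows (assumed 0-based)
starts = np.arange(0, 5 * stride, stride)
ends = starts + features_size

# Same narrowing as in the cell above: keep target_size bp around each midpoint
midpoints = (starts + ends) // 2
target_starts = midpoints - target_size // 2
target_ends = midpoints + target_size // 2

for s, e in zip(target_starts, target_ends):
    print(s, e)

# Each target interval starts where the previous one ends
print(bool((target_starts[1:] == target_ends[:-1]).all()))  # True
```

Under these assumptions, the first predicted interval starts 1016 bp into the chromosome, so the chromosome edges are not covered by any prediction.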

momics :: INFO :: 2024-10-21 10:02:21,179 :: Query completed in 0.2572s.

1848/1848 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step
