Copyright 2021 DeepMind Technologies Limited

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

This colab showcases training of the Enformer model published in

**"Effective gene expression prediction from sequence by integrating long-range interactions"**

Žiga Avsec, Vikram Agarwal, Daniel Visentin, Joseph R. Ledsam, Agnieszka Grabska-Barwinska, Kyle R. Taylor, Yannis Assael, John Jumper, Pushmeet Kohli, David R. Kelley


## Steps

- Setup tf.data.Dataset by directly accessing the Basenji2 data on GCS: `gs://basenji_barnyard/data`
- Train the model for a few steps, alternating training on human and mouse data batches
- Evaluate the model on human and mouse genomes

## Setup

**Start the colab kernel with GPU**: Runtime -> Change runtime type -> GPU

### Install dependencies

In [1]:
# !pip install dm-sonnet tqdm

In [2]:
# # Get enformer source code
# !wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/attention_module.py
# !wget -q https://raw.githubusercontent.com/deepmind/deepmind-research/master/enformer/enformer.py

### Import

In [3]:
import tensorflow as tf
print(tf.__version__)
print(tf.test.is_built_with_cuda())
print(tf.test.is_gpu_available())

2024-12-10 14:13:25.145811: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-10 14:13:25.209327: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733811205.236248   42076 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733811205.244722   42076 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-10 14:13:25.314414: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

2.18.0
True
Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.
True


I0000 00:00:1733811206.678277   42076 gpu_device.cc:2022] Created device /device:GPU:0 with 22280 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4090, pci bus id: 0000:06:00.0, compute capability: 8.9


In [4]:
import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

2.2.2
12.1
True


In [5]:
import sonnet as snt
from tqdm import tqdm
from IPython.display import clear_output
import numpy as np
import pandas as pd
import time
import os

import enformer

### Code

In [6]:
# @title `get_targets(organism)`
def get_targets(organism):
  # targets_txt = f'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_{organism}.txt'
  targets_file = f'/home/shared/enformer_data/{organism}/targets.txt'
  return pd.read_csv(targets_file, sep='\t')

In [7]:
# @title `get_dataset(organism, subset, num_threads=8)`
import glob
import json
import functools


def organism_path(organism):
  # return os.path.join('gs://basenji_barnyard/data', organism)
  return os.path.join('/home/shared/enformer_data', organism) # my local path


def get_dataset(organism, subset, num_threads=8):
  metadata = get_metadata(organism)
  dataset = tf.data.TFRecordDataset(tfrecord_files(organism, subset),
                                    compression_type='ZLIB',
                                    num_parallel_reads=num_threads)
  dataset = dataset.map(functools.partial(deserialize, metadata=metadata),
                        num_parallel_calls=num_threads)
  return dataset


def get_metadata(organism):
  # Keys:
  # num_targets, train_seqs, valid_seqs, test_seqs, seq_length,
  # pool_width, crop_bp, target_length
  path = os.path.join(organism_path(organism), 'statistics.json')
  with tf.io.gfile.GFile(path, 'r') as f:
    return json.load(f)


def tfrecord_files(organism, subset):
  # Sort the values by int(*).
  return sorted(tf.io.gfile.glob(os.path.join(
      organism_path(organism), 'tfrecords', f'{subset}-*.tfr'
  )), key=lambda x: int(x.split('-')[-1].split('.')[0]))


def deserialize(serialized_example, metadata):
  """Deserialize bytes stored in TFRecordFile."""
  feature_map = {
      'sequence': tf.io.FixedLenFeature([], tf.string),
      'target': tf.io.FixedLenFeature([], tf.string),
  }
  example = tf.io.parse_example(serialized_example, feature_map)
  sequence = tf.io.decode_raw(example['sequence'], tf.bool)
  sequence = tf.reshape(sequence, (metadata['seq_length'], 4))
  sequence = tf.cast(sequence, tf.float32)

  target = tf.io.decode_raw(example['target'], tf.float16)
  target = tf.reshape(target,
                      (metadata['target_length'], metadata['num_targets']))
  target = tf.cast(target, tf.float32)

  return {'sequence': sequence,
          'target': target}


## Load dataset

In [8]:
df_targets_human = get_targets('human')
df_targets_human.head()

Unnamed: 0,index,genome,identifier,file,clip,scale,sum_stat,description
0,0,0,ENCFF833POA,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:cerebellum male adult (27 years) and mal...
1,1,0,ENCFF110QGM,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:frontal cortex male adult (27 years) and...
2,2,0,ENCFF880MKD,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:chorion
3,3,0,ENCFF463ZLQ,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:Ishikawa treated with 0.02% dimethyl sul...
4,4,0,ENCFF890OGQ,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:GM03348


In [9]:
human_dataset = get_dataset('human', 'train').batch(1).repeat()
# mouse_dataset = get_dataset('mouse', 'train').batch(1).repeat()
# human_mouse_dataset = tf.data.Dataset.zip((human_dataset, mouse_dataset)).prefetch(2)

I0000 00:00:1733811211.300885   42076 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22280 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4090, pci bus id: 0000:06:00.0, compute capability: 8.9


In [10]:
# Example input
it = iter(human_dataset)
example = next(it)
print({k: (v.shape, v.dtype) for k,v in example.items()})

{'sequence': (TensorShape([1, 131072, 4]), tf.float32), 'target': (TensorShape([1, 896, 5313]), tf.float32)}


2024-12-10 14:13:32.300581: E tensorflow/core/util/util.cc:131] oneDNN supports DT_HALF only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.
2024-12-10 14:13:32.303499: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:370] TFRecordDataset `buffer_size` is unspecified, default to 262144


## Model training

In [11]:
def create_step_function(model, optimizer):

  @tf.function
  def train_step(batch, head, optimizer_clip_norm_global=0.2):
    with tf.GradientTape() as tape:
      outputs = model(batch['sequence'], is_training=True)[head]
      loss = tf.reduce_mean(
          tf.keras.losses.poisson(batch['target'], outputs))

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply(gradients, model.trainable_variables)

    return loss
  return train_step

In [12]:
learning_rate = tf.Variable(0., trainable=False, name='learning_rate')
optimizer = snt.optimizers.Adam(learning_rate=learning_rate)
num_warmup_steps = 5000
target_learning_rate = 0.0005

model = enformer.Enformer(channels=1536 // 4,  # Use 4x fewer channels to train faster.
                          num_heads=8,
                          num_transformer_layers=11,
                          pooling_type='max')

train_step = create_step_function(model, optimizer)

In [13]:
# Train the model
steps_per_epoch = 20
num_epochs = 5

data_it = iter(human_mouse_dataset)
global_step = 0
for epoch_i in range(num_epochs):
  for i in tqdm(range(steps_per_epoch)):
    global_step += 1

    if global_step > 1:
      learning_rate_frac = tf.math.minimum(
          1.0, global_step / tf.math.maximum(1.0, num_warmup_steps))
      learning_rate.assign(target_learning_rate * learning_rate_frac)

    batch_human, batch_mouse = next(data_it)

    loss_human = train_step(batch=batch_human, head='human')
    loss_mouse = train_step(batch=batch_mouse, head='mouse')

  # End of epoch.
  print('')
  print('loss_human', loss_human.numpy(),
        'loss_mouse', loss_mouse.numpy(),
        'learning_rate', optimizer.learning_rate.numpy()
        )

NameError: name 'human_mouse_dataset' is not defined

## Evaluate

In [14]:
# @title `PearsonR` and `R2` metrics

def _reduced_shape(shape, axis):
  if axis is None:
    return tf.TensorShape([])
  return tf.TensorShape([d for i, d in enumerate(shape) if i not in axis])


class CorrelationStats(tf.keras.metrics.Metric):
  """Contains shared code for PearsonR and R2."""

  def __init__(self, reduce_axis=None, name='pearsonr'):
    """Pearson correlation coefficient.

    Args:
      reduce_axis: Specifies over which axis to compute the correlation (say
        (0, 1). If not specified, it will compute the correlation across the
        whole tensor.
      name: Metric name.
    """
    super(CorrelationStats, self).__init__(name=name)
    self._reduce_axis = reduce_axis
    self._shape = None  # Specified in _initialize.

  def _initialize(self, input_shape):
    # Remaining dimensions after reducing over self._reduce_axis.
    self._shape = _reduced_shape(input_shape, self._reduce_axis)

    weight_kwargs = dict(shape=self._shape, initializer='zeros')
    self._count = self.add_weight(name='count', **weight_kwargs)
    self._product_sum = self.add_weight(name='product_sum', **weight_kwargs)
    self._true_sum = self.add_weight(name='true_sum', **weight_kwargs)
    self._true_squared_sum = self.add_weight(name='true_squared_sum',
                                             **weight_kwargs)
    self._pred_sum = self.add_weight(name='pred_sum', **weight_kwargs)
    self._pred_squared_sum = self.add_weight(name='pred_squared_sum',
                                             **weight_kwargs)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Update the metric state.

    Args:
      y_true: Multi-dimensional float tensor [batch, ...] containing the ground
        truth values.
      y_pred: float tensor with the same shape as y_true containing predicted
        values.
      sample_weight: 1D tensor aligned with y_true batch dimension specifying
        the weight of individual observations.
    """
    if self._shape is None:
      # Explicit initialization check.
      self._initialize(y_true.shape)
    y_true.shape.assert_is_compatible_with(y_pred.shape)
    y_true = tf.cast(y_true, 'float32')
    y_pred = tf.cast(y_pred, 'float32')

    self._product_sum.assign_add(
        tf.reduce_sum(y_true * y_pred, axis=self._reduce_axis))

    self._true_sum.assign_add(
        tf.reduce_sum(y_true, axis=self._reduce_axis))

    self._true_squared_sum.assign_add(
        tf.reduce_sum(tf.math.square(y_true), axis=self._reduce_axis))

    self._pred_sum.assign_add(
        tf.reduce_sum(y_pred, axis=self._reduce_axis))

    self._pred_squared_sum.assign_add(
        tf.reduce_sum(tf.math.square(y_pred), axis=self._reduce_axis))

    self._count.assign_add(
        tf.reduce_sum(tf.ones_like(y_true), axis=self._reduce_axis))

  def result(self):
    raise NotImplementedError('Must be implemented in subclasses.')

  def reset_states(self):
    if self._shape is not None:
      tf.keras.backend.batch_set_value([(v, np.zeros(self._shape))
                                        for v in self.variables])


class PearsonR(CorrelationStats):
  """Pearson correlation coefficient.

  Computed as:
  ((x - x_avg) * (y - y_avg) / sqrt(Var[x] * Var[y])
  """

  def __init__(self, reduce_axis=(0,), name='pearsonr'):
    """Pearson correlation coefficient.

    Args:
      reduce_axis: Specifies over which axis to compute the correlation.
      name: Metric name.
    """
    super(PearsonR, self).__init__(reduce_axis=reduce_axis,
                                   name=name)

  def result(self):
    true_mean = self._true_sum / self._count
    pred_mean = self._pred_sum / self._count

    covariance = (self._product_sum
                  - true_mean * self._pred_sum
                  - pred_mean * self._true_sum
                  + self._count * true_mean * pred_mean)

    true_var = self._true_squared_sum - self._count * tf.math.square(true_mean)
    pred_var = self._pred_squared_sum - self._count * tf.math.square(pred_mean)
    tp_var = tf.math.sqrt(true_var) * tf.math.sqrt(pred_var)
    correlation = covariance / tp_var

    return correlation


class R2(CorrelationStats):
  """R-squared  (fraction of explained variance)."""

  def __init__(self, reduce_axis=None, name='R2'):
    """R-squared metric.

    Args:
      reduce_axis: Specifies over which axis to compute the correlation.
      name: Metric name.
    """
    super(R2, self).__init__(reduce_axis=reduce_axis,
                             name=name)

  def result(self):
    true_mean = self._true_sum / self._count
    total = self._true_squared_sum - self._count * tf.math.square(true_mean)
    residuals = (self._pred_squared_sum - 2 * self._product_sum
                 + self._true_squared_sum)

    return tf.ones_like(residuals) - residuals / total


class MetricDict:
  def __init__(self, metrics):
    self._metrics = metrics

  def update_state(self, y_true, y_pred):
    for k, metric in self._metrics.items():
      metric.update_state(y_true, y_pred)

  def result(self):
    return {k: metric.result() for k, metric in self._metrics.items()}

In [15]:
def evaluate_model(model, dataset, head, max_steps=None):
  metric = MetricDict({'PearsonR': PearsonR(reduce_axis=(0,1))})
  @tf.function
  def predict(x):
    return model(x, is_training=False)[head]

  for i, batch in tqdm(enumerate(dataset)):
    if max_steps is not None and i > max_steps:
      break
    metric.update_state(batch['target'], predict(batch['sequence']))

  return metric.result()

In [16]:
model = enformer.Enformer(channels=1536 // 4,  # Use 4x fewer channels to train faster.
                          num_heads=8,
                          num_transformer_layers=11,
                          pooling_type='max')


In [17]:
metrics_human = evaluate_model(model,
                               dataset=get_dataset('human', 'valid').batch(1).prefetch(2),
                               head='human',
                               max_steps=100)
print('')
print({k: v.numpy().mean() for k, v in metrics_human.items()})

0it [00:00, ?it/s]

0it [00:00, ?it/s]


TypeError: in user code:

    File "/tmp/ipykernel_42076/3450058308.py", line 5, in predict  *
        return model(x, is_training=False)[head]
    File "/home/hxcai/anaconda3/envs/torch/lib/python3.10/site-packages/sonnet/src/utils.py", line 85, in _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    File "/home/hxcai/anaconda3/envs/torch/lib/python3.10/site-packages/sonnet/src/base.py", line 262, in wrap_with_name_scope  *
        return method(*args, **kwargs)
    File "/home/hxcai/cell_type_specific_CRE/MPRA_predict/notebooks/pretrained_models/enformer.py", line 175, in __call__  *
        trunk_embedding = self.trunk(inputs, is_training=is_training)
    File "/home/hxcai/anaconda3/envs/torch/lib/python3.10/site-packages/sonnet/src/utils.py", line 85, in _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    File "/home/hxcai/anaconda3/envs/torch/lib/python3.10/site-packages/sonnet/src/base.py", line 262, in wrap_with_name_scope  *
        return method(*args, **kwargs)
    File "/home/hxcai/cell_type_specific_CRE/MPRA_predict/notebooks/pretrained_models/enformer.py", line 231, in __call__  *
        outputs = mod(outputs, **kwargs)
    File "/home/hxcai/anaconda3/envs/torch/lib/python3.10/site-packages/sonnet/src/utils.py", line 85, in _decorate_unbound_method  *
        return decorator_fn(bound_method, self, args, kwargs)
    File "/home/hxcai/anaconda3/envs/torch/lib/python3.10/site-packages/sonnet/src/base.py", line 262, in wrap_with_name_scope  *
        return method(*args, **kwargs)

    TypeError: outer_factory.<locals>.inner_factory.<locals>.tf____call__() missing 1 required positional argument: 'is_training'


In [None]:
# metrics_mouse = evaluate_model(model,
#                                dataset=get_dataset('mouse', 'valid').batch(1).prefetch(2),
#                                head='mouse',
#                                max_steps=100)
# print('')
# print({k: v.numpy().mean() for k, v in metrics_mouse.items()})

101it [00:21,  6.54it/s]


{'PearsonR': 0.005183698}


# Restore Checkpoint

Note: For the TF-Hub Enformer model, the required input sequence length is 393,216 which actually gets cropped within the model to 196,608. The open source module does not internally crop the sequence. Therefore, the code below crops the central `196,608 bp` of the longer sequence to reproduce the output of the TF hub from the reloaded checkpoint.

In [None]:
np.random.seed(42)
EXTENDED_SEQ_LENGTH = 393_216
SEQ_LENGTH = 196_608
inputs = np.array(np.random.random((1, EXTENDED_SEQ_LENGTH, 4)), dtype=np.float32)
inputs_cropped = enformer.TargetLengthCrop1D(SEQ_LENGTH)(inputs)

In [None]:
checkpoint_gs_path = 'gs://dm-enformer/models/enformer/sonnet_weights/*'
checkpoint_path = '/tmp/enformer_checkpoint'

In [None]:
!mkdir /tmp/enformer_checkpoint

mkdir: cannot create directory ‘/tmp/enformer_checkpoint’: File exists


In [None]:
# Copy checkpoints from GCS to temporary directory.
# This will take a while as the checkpoint is ~ 1GB.
for file_path in tf.io.gfile.glob(checkpoint_gs_path):
  print(file_path)
  file_name = os.path.basename(file_path)
  tf.io.gfile.copy(file_path, f'{checkpoint_path}/{file_name}', overwrite=True)

gs://dm-enformer/models/enformer/sonnet_weights/checkpoint
gs://dm-enformer/models/enformer/sonnet_weights/enformer-fine-tuned-human-1.data-00000-of-00001
gs://dm-enformer/models/enformer/sonnet_weights/enformer-fine-tuned-human-1.index


In [None]:
!ls -lh /tmp/enformer_checkpoint

total 959M
-rw-r--r-- 1 root root  111 May 25 10:58 checkpoint
-rw-r--r-- 1 root root 959M May 25 10:59 enformer-fine-tuned-human-1.data-00000-of-00001
-rw-r--r-- 1 root root 5.7K May 25 10:59 enformer-fine-tuned-human-1.index


In [None]:
enformer_model = enformer.Enformer()

In [None]:
checkpoint = tf.train.Checkpoint(module=enformer_model)

In [None]:
latest = tf.train.latest_checkpoint(checkpoint_path)
print(latest)
status = checkpoint.restore(latest)

/tmp/enformer_checkpoint/enformer-fine-tuned-human-1


In [None]:
# Using `is_training=False` to match TF-hub predict_on_batch function.
restored_predictions = enformer_model(inputs_cropped, is_training=False)

In [None]:
import tensorflow_hub as hub
enformer_tf_hub_model = hub.load("https://tfhub.dev/deepmind/enformer/1").model

In [None]:
hub_predictions = enformer_tf_hub_model.predict_on_batch(inputs)

In [None]:
np.allclose(hub_predictions['human'], restored_predictions['human'], atol=1e-5)

True

In [None]:
# Can run with 'is_training=True' but note that this will
# change the predictions as the batch statistics will be updated
# and the outputs will likley not match the TF-hub model.
# enformer(inputs_cropped, is_training=True)