<a href="https://colab.research.google.com/github/jstjohn/enformer-pytorch/blob/main/evaluate_enformer_pytorch_correlation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!git clone https://github.com/lucidrains/enformer-pytorch.git

Cloning into 'enformer-pytorch'...
remote: Enumerating objects: 643, done.[K
remote: Counting objects: 100% (132/132), done.[K
remote: Compressing objects: 100% (117/117), done.[K
remote: Total 643 (delta 28), reused 28 (delta 13), pack-reused 511[K
Receiving objects: 100% (643/643), 8.88 MiB | 3.09 MiB/s, done.
Resolving deltas: 100% (439/439), done.


In [3]:
!cd enformer-pytorch && pip install .

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing /content/enformer-pytorch
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Collecting einops>=0.3
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Collecting polars
  Downloading polars-0.13.40-cp37-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (11.9 MB)
[K     |████████████████████████████████| 11.9 MB 19.1 MB/s 
[?25hCollecting pyfaidx
  Downloading pyfaidx-0.7.0.tar.gz (102 kB)
[K     |████████████████████████████████| 102 kB 70.3 MB/s 
Collecting transformers
  Downloading transformers-4.19.2-py3-none-any.whl (4.

In [4]:
# Extra runtime dependencies:
#   - kipoiseq: `Interval` and `one_hot_dna`, used by the data-loading cell below.
#   - torchmetrics: imported by the evaluation cell.
#   - BioPython: not used directly in this notebook (presumably needed by a
#     dependency; TODO confirm).
# stdout is discarded; errors (stderr) still show.
!pip install torchmetrics kipoiseq==0.5.2 BioPython --quiet > /dev/null

In [5]:
import torch
import numpy as np
import tensorflow as tf
import os 
import json
import pandas as pd
import pyfaidx
import kipoiseq
import functools
from kipoiseq import Interval

SEQUENCE_LENGTH = 196_608
BIN_SIZE = 128
TARGET_LENGTH = 896
import os
fasta_dir = "/root/data/"
!mkdir -p {fasta_dir}
human_fasta_f = 'hg38.ml.fa.gz'
mouse_fasta_f = 'mm10.ml.fa.gz'
human_fasta_gz_path = f"{fasta_dir}/{human_fasta_f}"
mouse_fasta_gz_path = f"{fasta_dir}/{mouse_fasta_f}"
human_fasta_path = human_fasta_gz_path.rstrip(".gz")
mouse_fasta_path = mouse_fasta_gz_path.rstrip(".gz")

if not os.path.isfile(human_fasta_path):
  !gsutil -m cp -n gs://basenji_barnyard/{human_fasta_f} {human_fasta_gz_path}
  !gunzip {human_fasta_gz_path}
if not os.path.isfile(mouse_fasta_path):
  !gsutil -m cp -n gs://basenji_barnyard/{mouse_fasta_f} {mouse_fasta_gz_path}
  !gunzip {mouse_fasta_gz_path}

class FastaStringExtractor:
    """Extract upper-cased DNA strings from a FASTA file via pyfaidx.

    Positions that fall outside the chromosome are filled with 'N', so the
    returned string always has length ``interval.end - interval.start``.
    """

    def __init__(self, fasta_file):
        self.fasta = pyfaidx.Fasta(fasta_file)
        # Chromosome name -> length, used to clip requests to valid coordinates.
        self._chromosome_sizes = {k: len(v) for k, v in self.fasta.items()}

    def extract(self, interval: "Interval", **kwargs) -> str:
        """Return the sequence for a 0-based, end-exclusive ``interval``.

        Args:
          interval: object with ``chrom``, ``start`` and ``end`` attributes
            (e.g. a ``kipoiseq.Interval``).
          **kwargs: ignored; accepted for interface compatibility.

        Returns:
          Upper-cased sequence of length ``interval.end - interval.start``,
          with out-of-chromosome positions replaced by 'N'.
        """
        chromosome_length = self._chromosome_sizes[interval.chrom]
        # Clip the request to the chromosome.
        start = max(interval.start, 0)
        end = min(interval.end, chromosome_length)
        if start >= end:
            # No overlap with the chromosome at all: the whole window is padding.
            return 'N' * max(interval.end - interval.start, 0)
        # pyfaidx wants a 1-based, end-inclusive interval.
        sequence = str(self.fasta.get_seq(interval.chrom, start + 1, end).seq).upper()
        # Fill the clipped-off flanks with N's.
        pad_upstream = 'N' * (start - interval.start)
        pad_downstream = 'N' * (interval.end - end)
        return pad_upstream + sequence + pad_downstream

    def close(self):
        """Close the underlying pyfaidx handle."""
        return self.fasta.close()


class BasenjiDataSet(torch.utils.data.IterableDataset):
  """Basenji TFRecord targets paired with sequences re-extracted from FASTA.

  Targets are read from the ``gs://basenji_barnyard`` TFRecords. Each input
  sequence is rebuilt from ``fasta_path`` by resizing the matching
  ``sequences.bed`` row to ``seq_len`` bp, so it can be longer than the
  sequence stored in the TFRecord.
  """

  @staticmethod
  def get_organism_path(organism):
    """Return the GCS directory holding ``organism``'s Basenji data."""
    return os.path.join('gs://basenji_barnyard/data', organism)

  @classmethod
  def get_metadata(cls, organism):
    """Load ``statistics.json`` for ``organism``.

    Keys: num_targets, train_seqs, valid_seqs, test_seqs, seq_length,
    pool_width, crop_bp, target_length.
    """
    path = os.path.join(cls.get_organism_path(organism), 'statistics.json')
    with tf.io.gfile.GFile(path, 'r') as f:
      return json.load(f)

  @staticmethod
  def one_hot_encode(sequence):
    """One-hot encode a DNA string with kipoiseq, as float32."""
    return kipoiseq.transforms.functional.one_hot_dna(sequence).astype(np.float32)

  @classmethod
  def get_tfrecord_files(cls, organism, subset):
    """List ``{subset}-*.tfr`` shards sorted by their integer shard index."""
    # Numeric sort: lexical order would put "-10.tfr" before "-9.tfr".
    return sorted(tf.io.gfile.glob(os.path.join(
        cls.get_organism_path(organism), 'tfrecords', f'{subset}-*.tfr'
      )), key=lambda x: int(x.split('-')[-1].split('.')[0]))

  @property
  def num_channels(self):
    """Number of target tracks (``num_targets``) for this organism."""
    # Cache the metadata: get_metadata re-reads statistics.json from GCS each call.
    if getattr(self, '_metadata', None) is None:
      self._metadata = self.get_metadata(self.organism)
    return self._metadata['num_targets']

  @staticmethod
  def deserialize(serialized_example, metadata):
    """Deserialize one TFRecord example.

    Returns a dict with ``sequence_old`` (float32, shape (seq_length, 4)) and
    ``target`` (float32, shape (target_length, num_targets)).
    """
    feature_map = {
        # Only used to validate our own (longer) FASTA extraction in __iter__.
        'sequence': tf.io.FixedLenFeature([], tf.string),
        'target': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_example(serialized_example, feature_map)
    sequence = tf.io.decode_raw(example['sequence'], tf.bool)
    sequence = tf.reshape(sequence, (metadata['seq_length'], 4))
    sequence = tf.cast(sequence, tf.float32)

    target = tf.io.decode_raw(example['target'], tf.float16)
    target = tf.reshape(target,
                        (metadata['target_length'], metadata['num_targets']))
    target = tf.cast(target, tf.float32)

    return {'sequence_old': sequence,
            'target': target}

  @classmethod
  def get_dataset(cls, organism, subset, num_threads=8):
    """Build a ``tf.data`` pipeline of deserialized examples for one split."""
    metadata = cls.get_metadata(organism)
    dataset = tf.data.TFRecordDataset(cls.get_tfrecord_files(organism, subset),
                                      compression_type='ZLIB',
                                      num_parallel_reads=num_threads).map(
                                          functools.partial(cls.deserialize, metadata=metadata)
                                      )
    return dataset

  def __init__(self, organism:str, subset:str, seq_len:int, fasta_path:str, n_to_test:int = -1):
    """
    Args:
      organism: "human" or "mouse".
      subset: "train", "valid" or "test".
      seq_len: length (bp) of the sequences re-extracted from FASTA.
      fasta_path: uncompressed FASTA for ``organism``.
      n_to_test: validate the first ``n_to_test`` extracted sequences against
        the TFRecord sequence; a negative value or 0 disables the check.
    """
    assert subset in {"train", "valid", "test"}
    assert organism in {"human", "mouse"}
    self.organism = organism
    self.subset = subset
    self.base_dir = self.get_organism_path(organism)
    self.seq_len = seq_len
    self.fasta_reader = FastaStringExtractor(fasta_path)
    self.n_to_test = n_to_test
    with tf.io.gfile.GFile(f"{self.base_dir}/sequences.bed", 'r') as f:
      region_df = pd.read_csv(f, sep="\t", header=None)
      region_df.columns = ['chrom', 'start', 'end', 'subset']
      # Row i of region_df is paired with TFRecord example i in __iter__.
      self.region_df = region_df.query('subset==@subset').reset_index(drop=True)

  @staticmethod
  def _assert_centered_equal(old_one_hot, new_one_hot):
    """Assert the shorter array equals the centred window of the longer one."""
    old_len, new_len = old_one_hot.shape[0], new_one_hot.shape[0]
    # Explicit [trim:trim+len] bounds: `[trim:-trim]` is empty when trim == 0
    # and off by one for odd differences. (For odd differences the extra base is
    # assumed to sit at the end -- TODO confirm against Interval.resize.)
    if old_len > new_len:
      trim = (old_len - new_len) // 2
      np.testing.assert_equal(old_one_hot[trim:trim + new_len], new_one_hot)
    elif new_len > old_len:
      trim = (new_len - old_len) // 2
      np.testing.assert_equal(old_one_hot, new_one_hot[trim:trim + old_len])
    else:
      np.testing.assert_equal(old_one_hot, new_one_hot)

  def __iter__(self):
    """Yield ``{"sequence": one-hot float32 from FASTA, "target": TFRecord target}``."""
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info is None, "Only support single process loading"
    # Record order must match region_df row order. With num_threads > 1 the
    # reader changes the record order; the sequence check below only catches
    # that when n_to_test > 0.
    basenji_iterator = self.get_dataset(self.organism, self.subset, num_threads=1).as_numpy_iterator()
    for i, records in enumerate(basenji_iterator):
      loc_row = self.region_df.iloc[i]
      target_interval = Interval(loc_row['chrom'], loc_row['start'], loc_row['end'])
      sequence_one_hot = self.one_hot_encode(self.fasta_reader.extract(target_interval.resize(self.seq_len)))
      if self.n_to_test >= 0 and i < self.n_to_test:
        self._assert_centered_equal(records["sequence_old"], sequence_one_hot)
      yield {
          "sequence": sequence_one_hot,
          "target": records["target"],
      }

Copying gs://basenji_barnyard/hg38.ml.fa.gz...
/ [1/1 files][839.8 MiB/839.8 MiB] 100% Done  58.5 MiB/s ETA 00:00:00           
Operation completed over 1 objects/839.8 MiB.                                    
Copying gs://basenji_barnyard/mm10.ml.fa.gz...
/ [1/1 files][800.8 MiB/800.8 MiB] 100% Done  65.9 MiB/s ETA 00:00:00           
Operation completed over 1 objects/800.8 MiB.                                    


In [6]:
import torch
from enformer_pytorch import Enformer

# Load pretrained Enformer weights from the "EleutherAI/enformer-official-rough"
# checkpoint (the download log below shows ~959 MB on first use).
model = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

Downloading:   0%|          | 0.00/464 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/959M [00:00<?, ?B/s]

In [7]:
# Inference mode, on the GPU when one is available. compute_correlation moves
# its inputs to model.device, so a CPU fallback works (just far slower).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.eval().to(device)

In [14]:
from enformer_pytorch.metrics import MeanPearsonCorrCoefPerChannel

In [15]:
from tqdm import tqdm
from torchmetrics.regression.pearson import PearsonCorrCoef
def compute_correlation(model, organism:str="human", subset:str="valid", max_steps=-1):
  """Mean over tracks of the per-track Pearson R of ``model`` on a Basenji split.

  Args:
    model: model whose output is indexed by organism name and that exposes
      ``.device``.
    organism: "human" or "mouse"; selects the FASTA and the output head.
    subset: "train", "valid" or "test".
    max_steps: evaluate at most this many examples (batch_size=1); a value
      <= 0 evaluates the whole split.

  Returns:
    0-d tensor: per-channel Pearson R averaged over channels.
  """
  import itertools

  fasta_path = human_fasta_path if organism == "human" else mouse_fasta_path
  ds = BasenjiDataSet(organism, subset, SEQUENCE_LENGTH, fasta_path)
  total = len(ds.region_df)  # number of records in this split
  limit = max_steps if max_steps > 0 else None
  n_steps = total if limit is None else min(limit, total)
  dl = torch.utils.data.DataLoader(ds, num_workers=0, batch_size=1)
  corr_coef = MeanPearsonCorrCoefPerChannel(n_channels=ds.num_channels)
  try:
    # islice stops before fetching an extra batch (a post-fetch `break` would not).
    for batch in tqdm(itertools.islice(dl, limit), total=n_steps):
      sequence = batch['sequence'].to(model.device)
      with torch.no_grad():
        pred = model(sequence)[organism]
      # update() only accumulates metric state; calling the metric (forward)
      # would also compute a per-batch correlation that is thrown away.
      # The target stays on CPU; it is only compared on CPU anyway.
      corr_coef.update(preds=pred.cpu(), target=batch['target'])
  finally:
    ds.fasta_reader.close()
  return corr_coef.compute().mean()
compute_correlation(model, organism="human", subset="valid", max_steps=100)

100%|██████████| 100/100 [01:22<00:00,  1.21it/s]


tensor(0.6270)

In [None]:
# Whole human validation split (max_steps=-1 disables the step limit).
compute_correlation(model, organism="human", subset="valid", max_steps=-1)

                not been set for this class (MeanPearsonCorrCoefPerChannel). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
100%|██████████| 2213/2213 [30:51<00:00,  1.20it/s]


tensor(0.6252)

In [None]:
# Whole human test split (max_steps=-1 disables the step limit).
compute_correlation(model, organism="human", subset="test", max_steps=-1)

                not been set for this class (MeanPearsonCorrCoefPerChannel). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
100%|██████████| 1937/1937 [20:05<00:00,  1.61it/s]


tensor(0.6503)

In [None]:
# Whole human training split (max_steps=-1 disables the step limit); by far the
# slowest run (~5 h in the log below).
compute_correlation(model, organism="human", subset="train", max_steps=-1)

                not been set for this class (MeanPearsonCorrCoefPerChannel). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
100%|██████████| 34021/34021 [4:58:22<00:00,  1.90it/s]


tensor(0.7415)