# Enformer human validation 

### Load  pre-trained model 

In [48]:
import tensorflow as tf
import tensorflow_hub as hub
import joblib
import gzip
import kipoiseq
from kipoiseq import Interval
import pyfaidx
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import os
import enformer 
from tqdm import tqdm
import importlib.util

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [115]:
# import enformer.py as module
spec = importlib.util.spec_from_file_location("enformer", os.path.join(os.getcwd() ,"enformer.py"))
enformer = importlib.util.module_from_spec(spec)
spec.loader.exec_module(enformer)
from enformer import * 

# import utils.py as module
spec_utils = importlib.util.spec_from_file_location("enformer", os.path.join(os.getcwd() ,"utils.py"))
utils = importlib.util.module_from_spec(spec_utils)
spec.loader.exec_module(utils)
from utils import * 

### Load files

In [102]:
transform_path = 'gs://dm-enformer/models/enformer.finetuned.SAD.robustscaler-PCA500-robustscaler.transform.pkl'
model_path = 'https://tfhub.dev/deepmind/enformer/1'
datadir = "../../../../data/FED"
fasta_file = os.path.join(datadir, "hg38.fa")
human_sequences = os.path.join(datadir, "data_human_sequences.bed")

In [53]:
pyfaidx.Faidx(fasta_file)

Faidx("../../../../data/FED/hg38.fa")

In [72]:
class Enformer:

    def __init__(self, tfhub_url):
        self._model = hub.load(tfhub_url).model

    def predict_on_batch(self, inputs):
        predictions = self._model.predict_on_batch(inputs)
        return {k: v.numpy() for k, v in predictions.items()}

    @tf.function
    def contribution_input_grad(self, input_sequence,
                                  target_mask, output_head='human'):
        input_sequence = input_sequence[tf.newaxis]

        target_mask_mass = tf.reduce_sum(target_mask)
        with tf.GradientTape() as tape:
            tape.watch(input_sequence)
            prediction = tf.reduce_sum(
              target_mask[tf.newaxis] *
              self._model.predict_on_batch(input_sequence)[output_head]) / target_mask_mass

        input_grad = tape.gradient(prediction, input_sequence) * input_sequence
        input_grad = tf.squeeze(input_grad, axis=0)
        return tf.reduce_sum(input_grad, axis=-1)
    
    def __call__(self, inputs: tf.Tensor,
                   is_training: bool) -> Dict[str, tf.Tensor]:
        trunk_embedding = self.trunk(inputs, is_training=is_training)
        return {
            head: head_module(trunk_embedding, is_training=is_training)
            for head, head_module in self.heads.items()
        }





class EnformerScoreVariantsRaw:

      def __init__(self, tfhub_url, organism='human'):
        self._model = Enformer(tfhub_url)
        self._organism = organism
  
      def predict_on_batch(self, inputs):
        ref_prediction = self._model.predict_on_batch(inputs['ref'])[self._organism]
        alt_prediction = self._model.predict_on_batch(inputs['alt'])[self._organism]

        return alt_prediction.mean(axis=1) - ref_prediction.mean(axis=1)


class EnformerScoreVariantsNormalized:

    def __init__(self, tfhub_url, transform_pkl_path,
                   organism='human'):
        assert organism == 'human', 'Transforms only compatible with organism=human'
        self._model = EnformerScoreVariantsRaw(tfhub_url, organism)
        with tf.io.gfile.GFile(transform_pkl_path, 'rb') as f:
          transform_pipeline = joblib.load(f)
        self._transform = transform_pipeline.steps[0][1]  # StandardScaler.

    def predict_on_batch(self, inputs):
        scores = self._model.predict_on_batch(inputs)
        return self._transform.transform(scores)


class EnformerScoreVariantsPCANormalized:

    def __init__(self, tfhub_url, transform_pkl_path,
                   organism='human', num_top_features=500):
        self._model = EnformerScoreVariantsRaw(tfhub_url, organism)
        with tf.io.gfile.GFile(transform_pkl_path, 'rb') as f:
          self._transform = joblib.load(f)
        self._num_top_features = num_top_features

    def predict_on_batch(self, inputs):
        scores = self._model.predict_on_batch(inputs)
        return self._transform.transform(scores)[:, :self._num_top_features]
    
class FastaStringExtractor:
    
    def __init__(self, fasta_file):
        self.fasta = pyfaidx.Fasta(fasta_file)
        self._chromosome_sizes = {k: len(v) for k, v in self.fasta.items()}

    def extract(self, interval: Interval, **kwargs) -> str:
        # Truncate interval if it extends beyond the chromosome lengths.
        chromosome_length = self._chromosome_sizes[interval.chrom]
        trimmed_interval = Interval(interval.chrom,
                                    max(interval.start, 0),
                                    min(interval.end, chromosome_length),
                                    )
        # pyfaidx wants a 1-based interval
        sequence = str(self.fasta.get_seq(trimmed_interval.chrom,
                                          trimmed_interval.start + 1,
                                          trimmed_interval.stop).seq).upper()
        # Fill truncated values with N's.
        pad_upstream = 'N' * max(-interval.start, 0)
        pad_downstream = 'N' * max(interval.end - chromosome_length, 0)
        return pad_upstream + sequence + pad_downstream

    def close(self):
        return self.fasta.close()

In [73]:
model = Enformer(model_path)

In [74]:
fasta_extractor = FastaStringExtractor(fasta_file)

### Check tracks

In [58]:
# Download targets from Basenji2 dataset 
# Cite: Kelley et al Cross-species regulatory sequence activity prediction. PLoS Comput. Biol. 16, e1008050 (2020).
targets_txt = 'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_human.txt'
df_targets = pd.read_csv(targets_txt, sep='\t')
df_targets

Unnamed: 0,index,genome,identifier,file,clip,scale,sum_stat,description
0,0,0,ENCFF833POA,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:cerebellum male adult (27 years) and mal...
1,1,0,ENCFF110QGM,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:frontal cortex male adult (27 years) and...
2,2,0,ENCFF880MKD,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:chorion
3,3,0,ENCFF463ZLQ,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:Ishikawa treated with 0.02% dimethyl sul...
4,4,0,ENCFF890OGQ,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:GM03348
...,...,...,...,...,...,...,...,...
5308,5308,0,CNhs14239,/home/drk/tillage/datasets/human/cage/fantom/C...,384,1,sum,CAGE:epithelioid sarcoma cell line:HS-ES-2R
5309,5309,0,CNhs14240,/home/drk/tillage/datasets/human/cage/fantom/C...,384,1,sum,CAGE:squamous cell lung carcinoma cell line:RE...
5310,5310,0,CNhs14241,/home/drk/tillage/datasets/human/cage/fantom/C...,384,1,sum,CAGE:gastric cancer cell line:GSS
5311,5311,0,CNhs14244,/home/drk/tillage/datasets/human/cage/fantom/C...,384,1,sum,CAGE:carcinoid cell line:NCI-H727


## predict human validation set 

In [59]:
target_interval = kipoiseq.Interval('chr1', 35_082_742, 35_197_430)  

# One hot encode 
# resize only takes the sequence of that interval, center it and gets the reduced sequence length 
sequence_one_hot = one_hot_encode(fasta_extractor.extract(target_interval.resize(393216)))

# Predict on that sequence
predictions = model.predict_on_batch(sequence_one_hot[np.newaxis])['human'][0]

In [60]:
predictions[:,1]

array([2.50549219e-03, 1.58413667e-02, 5.63042536e-02, 3.39204110e-02,
       4.74021807e-02, 8.28567967e-02, 6.69879243e-02, 2.55072415e-02,
       2.35070698e-02, 8.42880532e-02, 1.09183863e-01, 1.03348963e-01,
       1.13832407e-01, 1.35337815e-01, 1.30810678e-01, 1.26478016e-01,
       1.30831897e-01, 1.08674467e-01, 1.68840867e-02, 1.57137564e-03,
       6.97181970e-02, 1.64800078e-01, 1.79596424e-01, 1.39135689e-01,
       1.06212147e-01, 1.00273758e-01, 7.68458992e-02, 5.08518554e-02,
       2.55599953e-02, 8.27139467e-02, 7.20480606e-02, 2.39708293e-02,
       8.59023130e-04, 9.85492952e-03, 6.36985246e-03, 3.91527824e-02,
       5.10640815e-02, 1.69637613e-03, 1.57112014e-02, 3.00747268e-02,
       1.82608829e-03, 2.42718067e-02, 1.20249167e-02, 2.82154116e-03,
       3.24195391e-03, 6.53163269e-02, 1.17052682e-01, 4.80120666e-02,
       5.21726627e-03, 3.82351503e-02, 9.86452103e-02, 1.02631025e-01,
       1.13904156e-01, 6.09156340e-02, 5.77645972e-02, 7.97166303e-03,
      

In [75]:
predictions.shape

(896, 5313)

In [101]:
df = pd.read_csv(human_sequences, memory_map=True, header=None, index_col=False, delimiter="\t")

# keep only validation intervals 
validation_intervals= df[df[3]=="valid"]
validation_intervals = validation_intervals.head()


# create list with interval
interval_list = list()
validation_intervals.apply(lambda row : interval_list.append(kipoiseq.Interval(row[0],row[1], row[2])), axis = 1)
interval_list



NameError: name 'human_sequences' is not defined

In [100]:
validation_intervals

NameError: name 'validation_intervals' is not defined

In [90]:
my_interval = interval_list[1]

In [91]:
# resize only takes the sequence of that interval, center it and gets the reduced sequence length 
sequence_one_hot = one_hot_encode(fasta_extractor.extract(my_interval.resize(SEQUENCE_LENGTH)))
# Predict on that sequence
predictions = model.predict_on_batch(sequence_one_hot[np.newaxis])['human'][0]

In [93]:
predictions.shape

(896, 5313)

In [65]:
# %load utils
# @title `get_dataset(organism, subset, num_threads=8)`
import glob
import json
import os
import functools


def organism_path(organism):
    return os.path.join('gs://basenji_barnyard/data', organism)

#
def get_dataset(organism, subset, num_threads=8):
    
    metadata = get_metadata(organism)
    
    dataset = tf.data.TFRecordDataset(tfrecord_files(organism, subset),
                                        compression_type='ZLIB',
                                        num_parallel_reads=num_threads)
    dataset = dataset.map(functools.partial(deserialize, metadata=metadata),
                            num_parallel_calls=num_threads)
    return dataset


def get_metadata(organism):
  # Keys:
  # num_targets, train_seqs, valid_seqs, test_seqs, seq_length,
  # pool_width, crop_bp, target_length
    path = os.path.join(organism_path(organism), 'statistics.json')
    with tf.io.gfile.GFile(path, 'r') as f:
        return json.load(f)


def tfrecord_files(organism, subset):
  # Sort the values by int(*).
  return sorted(tf.io.gfile.glob(os.path.join(
      organism_path(organism), 'tfrecords', f'{subset}-*.tfr'
  )), key=lambda x: int(x.split('-')[-1].split('.')[0]))


def deserialize(serialized_example, metadata):
    """Deserialize bytes stored in TFRecordFile."""
    feature_map = {
          'sequence': tf.io.FixedLenFeature([], tf.string),
          'target': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_example(serialized_example, feature_map)
    sequence = tf.io.decode_raw(example['sequence'], tf.bool)
    sequence = tf.reshape(sequence, (metadata['seq_length'], 4))
    sequence = tf.cast(sequence, tf.float32)

    target = tf.io.decode_raw(example['target'], tf.float16)
    target = tf.reshape(target,
                          (metadata['target_length'], metadata['num_targets']))
    target = tf.cast(target, tf.float32)

    return {'sequence': sequence,
              'target': target}


### compute score (how well predicted)

In [76]:
# Get real values 
# Get the 896 long string

human_dataset = get_dataset('human', 'valid').batch(1).repeat()


# Get predicted values 




In [99]:
ds_size = sum(1 for _ in human_dataset)

KeyboardInterrupt: 

In [77]:
def _reduced_shape(shape, axis):
    if axis is None:
        return tf.TensorShape([])
    return tf.TensorShape([d for i, d in enumerate(shape) if i not in axis])


class CorrelationStats(tf.keras.metrics.Metric):
    """Contains shared code for PearsonR and R2."""

    def __init__(self, reduce_axis=None, name='pearsonr'):
        """Pearson correlation coefficient.

        Args:
          reduce_axis: Specifies over which axis to compute the correlation (say
            (0, 1). If not specified, it will compute the correlation across the
            whole tensor.
          name: Metric name.
        """
        super(CorrelationStats, self).__init__(name=name)
        self._reduce_axis = reduce_axis
        self._shape = None  # Specified in _initialize.

    def _initialize(self, input_shape):
        # Remaining dimensions after reducing over self._reduce_axis.
        self._shape = _reduced_shape(input_shape, self._reduce_axis)

        weight_kwargs = dict(shape=self._shape, initializer='zeros')
        self._count = self.add_weight(name='count', **weight_kwargs)
        self._product_sum = self.add_weight(name='product_sum', **weight_kwargs)
        self._true_sum = self.add_weight(name='true_sum', **weight_kwargs)
        self._true_squared_sum = self.add_weight(name='true_squared_sum',
                                                 **weight_kwargs)
        self._pred_sum = self.add_weight(name='pred_sum', **weight_kwargs)
        self._pred_squared_sum = self.add_weight(name='pred_squared_sum',
                                                 **weight_kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Update the metric state.

        Args:
          y_true: Multi-dimensional float tensor [batch, ...] containing the ground
            truth values.
          y_pred: float tensor with the same shape as y_true containing predicted
            values.
          sample_weight: 1D tensor aligned with y_true batch dimension specifying
            the weight of individual observations.
        """
        if self._shape is None:
          # Explicit initialization check.
          self._initialize(y_true.shape)
        y_true.shape.assert_is_compatible_with(y_pred.shape)
        y_true = tf.cast(y_true, 'float32')
        y_pred = tf.cast(y_pred, 'float32')

        self._product_sum.assign_add(
            tf.reduce_sum(y_true * y_pred, axis=self._reduce_axis))

        self._true_sum.assign_add(
            tf.reduce_sum(y_true, axis=self._reduce_axis))

        self._true_squared_sum.assign_add(
            tf.reduce_sum(tf.math.square(y_true), axis=self._reduce_axis))

        self._pred_sum.assign_add(
            tf.reduce_sum(y_pred, axis=self._reduce_axis))

        self._pred_squared_sum.assign_add(
            tf.reduce_sum(tf.math.square(y_pred), axis=self._reduce_axis))

        self._count.assign_add(
            tf.reduce_sum(tf.ones_like(y_true), axis=self._reduce_axis))

    def result(self):
        raise NotImplementedError('Must be implemented in subclasses.')

    def reset_states(self):
        if self._shape is not None:
            tf.keras.backend.batch_set_value([(v, np.zeros(self._shape))
                                        for v in self.variables])


class PearsonR(CorrelationStats):
    """Pearson correlation coefficient.

          Computed as:
      ((x - x_avg) * (y - y_avg) / sqrt(Var[x] * Var[y])
      """

    def __init__(self, reduce_axis=(0,), name='pearsonr'):
        """Pearson correlation coefficient.

        Args:
          reduce_axis: Specifies over which axis to compute the correlation.
          name: Metric name.
        """
        super(PearsonR, self).__init__(reduce_axis=reduce_axis,
                                       name=name)

    def result(self):
        true_mean = self._true_sum / self._count
        pred_mean = self._pred_sum / self._count

        covariance = (self._product_sum
                      - true_mean * self._pred_sum
                      - pred_mean * self._true_sum
                      + self._count * true_mean * pred_mean)

        true_var = self._true_squared_sum - self._count * tf.math.square(true_mean)
        pred_var = self._pred_squared_sum - self._count * tf.math.square(pred_mean)
        tp_var = tf.math.sqrt(true_var) * tf.math.sqrt(pred_var)
        correlation = covariance / tp_var

        return correlation


class R2(CorrelationStats):
    """R-squared  (fraction of explained variance)."""

    def __init__(self, reduce_axis=None, name='R2'):
        """R-squared metric.

        Args:
          reduce_axis: Specifies over which axis to compute the correlation.
          name: Metric name.
        """
        super(R2, self).__init__(reduce_axis=reduce_axis,
                                 name=name)

    def result(self):
        true_mean = self._true_sum / self._count
        total = self._true_squared_sum - self._count * tf.math.square(true_mean)
        residuals = (self._pred_squared_sum - 2 * self._product_sum
                     + self._true_squared_sum)

        return tf.ones_like(residuals) - residuals / total


class MetricDict:
    def __init__(self, metrics):
        self._metrics = metrics

    def update_state(self, y_true, y_pred):
        for k, metric in self._metrics.items():
            metric.update_state(y_true, y_pred)

    def result(self):
        return {k: metric.result() for k, metric in self._metrics.items()}

In [78]:
def evaluate_model(model, dataset, head, max_steps=None):
    
    metric = MetricDict({'PearsonR': PearsonR(reduce_axis=(0,1))})
    print("Metric dictionary created")
    
    @tf.function
    def predict(x):    
        return model(x, is_training=False)[head]
    
    print("")
    for i, batch in tqdm(enumerate(dataset)):
        if max_steps is not None and i > max_steps:
            break
        #metric.update_state(batch['target'], predict(batch['sequence']))
        metric.update_state(batch['target'], batch['target'])

    return metric.result()

In [None]:
def evaluate_model_all_sequences(model, dataset, head, max_steps=None):
    
    metric = MetricDict({'PearsonR': PearsonR(reduce_axis=(0,1))})
    print("Metric dictionary created")
    
    @tf.function
    def predict(x):    
        return model(x, is_training=False)[head]
    print(" Predict funciton loaded")
    
    for i, batch in tqdm(enumerate(dataset)):
        if max_steps is not None and i > max_steps:
            break
        #metric.update_state(batch['target'], predict(batch['sequence']))
        metric.update_state(batch['target'], batch['target'])

    return metric.result()

In [None]:
# Evaluate model on first ten 
# Right now it evaluates the whole model and 
metrics_human = evaluate_model(model,
                               dataset=get_dataset('human', 'valid').batch(2).prefetch(2),
                               head='human',
                               max_steps=10)
print('')
print({k: v.numpy().mean() for k, v in metrics_human.items()})

In [85]:
test_dataset = get_dataset('human', 'valid').batch(1).prefetch(2)

In [90]:
human_dataset = get_dataset('human', 'valid')

In [96]:
dataset = human_dataset

In [105]:
for i, batch in tqdm(enumerate(test_dataset)):
    print(i)
    print(batch)
    batch_one = batch
    if i > 0: 
        break
        

1it [00:06,  6.27s/it]

0
{'sequence': <tf.Tensor: shape=(1, 131072, 4), dtype=float32, numpy=
array([[[1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        ...,
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.]]], dtype=float32)>, 'target': <tf.Tensor: shape=(1, 896, 5313), dtype=float32, numpy=
array([[[0.09924316, 0.0927124 , 0.01834106, ..., 0.        ,
         0.        , 0.        ],
        [0.11126709, 0.1685791 , 0.03396606, ..., 0.        ,
         0.984375  , 0.        ],
        [0.14318848, 0.23217773, 0.01850891, ..., 0.        ,
         0.        , 0.        ],
        ...,
        [0.00662994, 0.01672363, 0.00756454, ..., 0.01852417,
         0.11566162, 0.        ],
        [0.00411224, 0.00155735, 0.        , ..., 0.        ,
         0.        , 0.        ],
        [0.06958008, 0.03845215, 0.04312134, ..., 0.        ,
         0.        , 0.        ]]], dtype=float32)>}
1
{'sequence': <tf.Tensor: shape=(1, 131072, 4), dtype=float32, num




In [114]:
batch_one["sequence"].numpy()

array([[[0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        ...,
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.]]], dtype=float32)