---
title: "Comparing different prediction matrices to ground truth"
author: "Laura Vairus"
date: "2023-08-16"
categories: [method, code]
---

When building the epigenome to query, we want it to be as accurate as possible. Some papers describe obtaining their predictions by running Enformer on the target interval, on copies of it shifted up and down, and on the reverse complements, then averaging all of the outputs together. To check which approach works best, I made 4 different prediction matrices and compared each one against the ground truth:
- (forward) normal forward prediction on the target interval
- (reverse) reverse complement prediction on the target interval
- (forrev) average of the forward and reverse complement predictions above
- (forrevshift) average of the forward and reverse complement predictions on the target interval as given, shifted up 3 bp, and shifted down 3 bp (6 predictions in total)
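
A minimal sketch of the averaging idea behind these four matrices, on toy arrays rather than real Enformer output. It assumes the one-hot channels are ordered A, C, G, T, and the helper names `reverse_complement_one_hot` and `average_predictions` are hypothetical. Whether a reverse-complement prediction needs its bin axis flipped before averaging depends on how the prediction pipeline writes its output, so that step is not shown here.

```python
import numpy as np

def reverse_complement_one_hot(seq):
    """Reverse-complement a one-hot (length, 4) array.

    Assumes channel order A, C, G, T: reversing the position axis reverses
    the sequence, and reversing the channel axis swaps A<->T and C<->G.
    """
    return seq[::-1, ::-1]

def average_predictions(preds):
    """Element-wise mean of equally shaped (bins, tracks) matrices."""
    return np.mean(np.stack(preds, axis=0), axis=0)

# Toy sequence A C C T G, encoded with channel order A, C, G, T.
seq = np.eye(4, dtype=np.float32)[[0, 1, 1, 3, 2]]
rc = reverse_complement_one_hot(seq)  # encodes C A G G T

# Toy "predictions": two (bins, tracks) matrices averaged together.
avg = average_predictions([np.zeros((2, 3)), np.ones((2, 3))])
```

The same `average_predictions` call covers both the 2-matrix (forrev) and 6-matrix (forrevshift) cases, since it accepts any number of equally shaped matrices.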

## Getting ground truth

First, I had to get the ground truth data. I used the matrices from "/grand/TFXcan/imlab/data/enformer_training_data/basenji_data_h5/no_groups_popseq_revised_order/test_pop_seq.hdf5", which were used to test the training of enformer.

This hdf5 file contains 4 datasets: pop_sequence, query_regions, sequence, and target.

- query_regions has 1937 rows of interval data: chromosome number, start position, and end position. shape (1937, 3)
- sequence has 1937 DNA sequences that are 131072 bp long, one-hot encoded over the 4 bases. shape (1937, 131072, 4)
- target has 1937 target Enformer outputs, the usual (896, 5313) bin-by-track matrices. shape (1937, 896, 5313)

I will use the target matrices as the ground truth.
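
To relate a row of a target matrix back to genomic coordinates, here is a small sketch. It assumes the 896 bins are 128 bp wide and centered in the 131072 bp input interval, which is the layout described for the Enformer/Basenji training data. The function name `bin_coordinates` is hypothetical and not part of any library.

```python
N_BINS = 896      # rows in each target matrix
BIN_SIZE = 128    # assumed bin width in bp

def bin_coordinates(start, end, n_bins=N_BINS, bin_size=BIN_SIZE):
    """Return (bin_start, bin_end) pairs for bins centered within [start, end)."""
    covered = n_bins * bin_size               # bp covered by all bins
    offset = ((end - start) - covered) // 2   # bp cropped from each side
    first = start + offset
    return [(first + i * bin_size, first + (i + 1) * bin_size)
            for i in range(n_bins)]

# Interval used later in this post (index 1000 of the dataset).
bins = bin_coordinates(20693681, 20824753)
print(bins[0], bins[-1])
```

Under these assumptions the bins cover 114688 bp, leaving 8192 bp of flanking sequence on each side of the interval without a target value.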


It was unclear whether the 1937 'target' entries and 'query_regions' intervals lined up, so I wrote some code to compare each 'sequence' entry against the one-hot-encoded sequence extracted from the hg38 reference genome for the corresponding 'query_regions' interval.

### code setup

In [2]:
# imports

import h5py
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import joblib
import gzip
import kipoiseq
from kipoiseq import Interval
import pyfaidx
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt  # needed by plot_hist below


In [3]:
# # @title `Enformer`, `EnformerScoreVariantsNormalized`, `EnformerScoreVariantsPCANormalized`,
# SEQUENCE_LENGTH = 393216

# class Enformer:

#   def __init__(self, tfhub_url):
#     self._model = hub.load(tfhub_url).model

#   def predict_on_batch(self, inputs):
#     predictions = self._model.predict_on_batch(inputs)
#     return {k: v.numpy() for k, v in predictions.items()}

#   @tf.function
#   def contribution_input_grad(self, input_sequence,
#                               target_mask, output_head='human'):
#     input_sequence = input_sequence[tf.newaxis]

#     target_mask_mass = tf.reduce_sum(target_mask)
#     with tf.GradientTape() as tape:
#       tape.watch(input_sequence)
#       prediction = tf.reduce_sum(
#           target_mask[tf.newaxis] *
#           self._model.predict_on_batch(input_sequence)[output_head]) / target_mask_mass

#     input_grad = tape.gradient(prediction, input_sequence) * input_sequence
#     input_grad = tf.squeeze(input_grad, axis=0)
#     return tf.reduce_sum(input_grad, axis=-1)


# class EnformerScoreVariantsRaw:

#   def __init__(self, tfhub_url, organism='human'):
#     self._model = Enformer(tfhub_url)
#     self._organism = organism

#   def predict_on_batch(self, inputs):
#     ref_prediction = self._model.predict_on_batch(inputs['ref'])[self._organism]
#     alt_prediction = self._model.predict_on_batch(inputs['alt'])[self._organism]

#     return alt_prediction.mean(axis=1) - ref_prediction.mean(axis=1)


# class EnformerScoreVariantsNormalized:

#   def __init__(self, tfhub_url, transform_pkl_path,
#                organism='human'):
#     assert organism == 'human', 'Transforms only compatible with organism=human'
#     self._model = EnformerScoreVariantsRaw(tfhub_url, organism)
#     with tf.io.gfile.GFile(transform_pkl_path, 'rb') as f:
#       transform_pipeline = joblib.load(f)
#     self._transform = transform_pipeline.steps[0][1]  # StandardScaler.

#   def predict_on_batch(self, inputs):
#     scores = self._model.predict_on_batch(inputs)
#     return self._transform.transform(scores)


# class EnformerScoreVariantsPCANormalized:

#   def __init__(self, tfhub_url, transform_pkl_path,
#                organism='human', num_top_features=500):
#     self._model = EnformerScoreVariantsRaw(tfhub_url, organism)
#     with tf.io.gfile.GFile(transform_pkl_path, 'rb') as f:
#       self._transform = joblib.load(f)
#     self._num_top_features = num_top_features

#   def predict_on_batch(self, inputs):
#     scores = self._model.predict_on_batch(inputs)
#     return self._transform.transform(scores)[:, :self._num_top_features]


# # TODO(avsec): Add feature description: Either PCX, or full names.

In [4]:
# @title `variant_centered_sequences`

class FastaStringExtractor:

    def __init__(self, fasta_file):
        self.fasta = pyfaidx.Fasta(fasta_file)
        self._chromosome_sizes = {k: len(v) for k, v in self.fasta.items()}

    def extract(self, interval: Interval, **kwargs) -> str:
        # Truncate interval if it extends beyond the chromosome lengths.
        chromosome_length = self._chromosome_sizes[interval.chrom]
        trimmed_interval = Interval(interval.chrom,
                                    max(interval.start, 0),
                                    min(interval.end, chromosome_length),
                                    )
        # pyfaidx wants a 1-based interval
        sequence = str(self.fasta.get_seq(trimmed_interval.chrom,
                                          trimmed_interval.start + 1,
                                          trimmed_interval.stop).seq).upper()
        # Fill truncated values with N's.
        pad_upstream = 'N' * max(-interval.start, 0)
        pad_downstream = 'N' * max(interval.end - chromosome_length, 0)
        return pad_upstream + sequence + pad_downstream

    def close(self):
        return self.fasta.close()


def variant_generator(vcf_file, gzipped=False):
  """Yields a kipoiseq.dataclasses.Variant for each row in VCF file."""
  def _open(file):
    # Use the argument rather than the enclosing vcf_file variable.
    return gzip.open(file, 'rt') if gzipped else open(file)

  with _open(vcf_file) as f:
    for line in f:
      if line.startswith('#'):
        continue
      chrom, pos, id, ref, alt_list = line.split('\t')[:5]
      # Split ALT alleles and return individual variants as output.
      for alt in alt_list.split(','):
        yield kipoiseq.dataclasses.Variant(chrom=chrom, pos=pos,
                                           ref=ref, alt=alt, id=id)


def one_hot_encode(sequence):
  return kipoiseq.transforms.functional.one_hot_dna(sequence).astype(np.float32)


def variant_centered_sequences(vcf_file, sequence_length, gzipped=False,
                               chr_prefix=''):
  seq_extractor = kipoiseq.extractors.VariantSeqExtractor(
    reference_sequence=FastaStringExtractor(fasta_file))

  for variant in variant_generator(vcf_file, gzipped=gzipped):
    interval = Interval(chr_prefix + variant.chrom,
                        variant.pos, variant.pos)
    interval = interval.resize(sequence_length)
    center = interval.center() - interval.start

    reference = seq_extractor.extract(interval, [], anchor=center)
    alternate = seq_extractor.extract(interval, [variant], anchor=center)

    yield {'inputs': {'ref': one_hot_encode(reference),
                      'alt': one_hot_encode(alternate)},
           'metadata': {'chrom': chr_prefix + variant.chrom,
                        'pos': variant.pos,
                        'id': variant.id,
                        'ref': variant.ref,
                        'alt': variant.alt}}

In [5]:
fasta_extractor38 = FastaStringExtractor('/lus/grand/projects/TFXcan/imlab/data/hg_sequences/hg38/Homo_sapiens_assembly38.fasta')
fasta_extractor19 = FastaStringExtractor('/lus/grand/projects/TFXcan/imlab/data/hg_sequences/hg19/raw/genome.fa')

### checking values

In [3]:
with h5py.File(f'/grand/TFXcan/imlab/data/enformer_training_data/basenji_data_h5/no_groups_popseq_revised_order/test_pop_seq.hdf5') as f:
    intervals = f['query_regions'][()]

In [5]:
intervals[200]

array([       0, 39808995, 39940067])

In [38]:
def check_int_seqs(intervals, i):
    # Compare each stored 'sequence' entry with the hg38 extraction of its interval.
    # `i` is the dataset index of the first interval in `intervals`.
    path = '/grand/TFXcan/imlab/data/enformer_training_data/basenji_data_h5/no_groups_popseq_revised_order/test_pop_seq.hdf5'

    # Open the file once instead of once per interval.
    with h5py.File(path) as f:
        for interval in intervals:

            if i % 20 == 0:
                print(f'now on index {i}, interval {interval}')

            chrom, start, end = interval

            # chromosome code 0 is treated as chrX
            if chrom == 0:
                chrom = 'X'

            # Read only row i rather than loading the whole dataset.
            target_seq = f['sequence'][i]

            target_interval = kipoiseq.Interval(f'chr{chrom}', start, end)
            seq38 = one_hot_encode(fasta_extractor38.extract(target_interval))

            if not np.array_equal(target_seq, seq38):
                print(f'not equal at index {i}, interval {interval}')

            i += 1


In [None]:
check_int_seqs(intervals[541:], 541)

now on index 1160, interval [        0 125746319 125877391]
now on index 1180, interval [        0 122535027 122666099]
now on index 1200, interval [      14 77070902 77201974]
now on index 1220, interval [      14 58376595 58507667]
now on index 1240, interval [        3 189435603 189566675]
now on index 1260, interval [      15 23391563 23522635]
now on index 1280, interval [      12 11085476 11216548]
now on index 1300, interval [      14 96797410 96928482]
now on index 1320, interval [      14 58605973 58737045]
now on index 1340, interval [       14 105169707 105300779]
now on index 1360, interval [      14 36585685 36716757]
now on index 1380, interval [      14 87507601 87638673]
now on index 1400, interval [      10 37784915 37915987]
now on index 1420, interval [       14 100926214 101057286]
now on index 1440, interval [      14 45416738 45547810]
now on index 1460, interval [      11 32207325 32338397]
now on index 1480, interval [       14 104825640 104956712]
now on index 

Every sequence except the ones at indices 317, 541, and 1159 matched the corresponding extracted sequence. The sequences at 317 and 1159 raised a KeyError suggesting there might be an "M" in the extracted region, and the sequence at 541 simply did not match its corresponding extraction. Both issues should be looked into further.

In [23]:
# "M" error output:

# 'M' error at index 317, interval ('chr1', 248654924, 248785996)
target_interval = kipoiseq.Interval('chr1', 248654924, 248785996)
extract = fasta_extractor38.extract(target_interval)
seq38 = one_hot_encode(extract)

KeyError: 'M'
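
`M` is the IUPAC ambiguity code for "A or C". One possible explanation, not verified here, is that the reference FASTA contains an ambiguity code in that region and the encoder has no entry for it. As a sketch of one workaround, assuming it is acceptable to encode every non-ACGT character as an all-zero row (a simplification; `kipoiseq` may use a different neutral value for N), a hypothetical tolerant encoder could look like this:

```python
import numpy as np

BASES = "ACGT"  # assumed channel order

def one_hot_encode_tolerant(sequence):
    """One-hot encode DNA; any character outside ACGT becomes an all-zero row."""
    # Lookup table indexed by ASCII code: one row per possible byte value.
    lookup = np.zeros((256, 4), dtype=np.float32)
    for channel, base in enumerate(BASES):
        lookup[ord(base), channel] = 1.0
    codes = np.frombuffer(sequence.upper().encode("ascii"), dtype=np.uint8)
    return lookup[codes]

encoded = one_hot_encode_tolerant("ACMGTN")
print(encoded.sum(axis=1))  # rows for M and N sum to 0
```

Because this changes how ambiguous bases are encoded, sequences containing them would still not be expected to match the stored 'sequence' entries exactly; the sketch only avoids the exception so the comparison can continue.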

In [None]:
# 'M' error at index 1159, interval [12 132224362 132355434]

In [None]:
# not equal at 541
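
For the mismatch at index 541, a useful next step would be to find where the stored sequence and the extracted sequence differ, for example whether the difference is a few isolated bases or a whole shifted block. A sketch on toy arrays, using the hypothetical helper `mismatch_positions`:

```python
import numpy as np

def mismatch_positions(seq_a, seq_b):
    """Return the positions where two one-hot (length, 4) arrays differ."""
    if seq_a.shape != seq_b.shape:
        raise ValueError(f"shape mismatch: {seq_a.shape} vs {seq_b.shape}")
    # A position differs if any of its 4 channels differs.
    return np.flatnonzero(np.any(seq_a != seq_b, axis=1))

# Toy example: A C G T A versus A C G A A (one substitution at position 3).
a = np.eye(4, dtype=np.float32)[[0, 1, 2, 3, 0]]
b = a.copy()
b[3] = [1, 0, 0, 0]

positions = mismatch_positions(a, b)
print(positions, len(positions) / len(a))  # positions and fraction mismatched
```

Applied to the real arrays, a small number of scattered positions would point to base-level differences, while mismatches across most of the sequence would point to an offset or a wrong interval.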

## Comparing matrices to ground truth

For my first test, I chose index 1000 from the basenji dataset, whose interval is "chr17_20693681_20824753". I ran the Enformer pipeline on that interval as given, shifted up 3 bp, and shifted down 3 bp, with the option for computing the reverse complement selected. After getting the prediction results, I read them all in and calculated the averages I wanted.
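
The three interval names can be generated from the original interval. This is a sketch with a hypothetical helper, `shifted_interval_names`; it only builds the `chrom_start_end` strings that appear in the prediction file names and does not call the Enformer pipeline itself.

```python
def shifted_interval_names(chrom, start, end, shift=3):
    """Return names for the interval shifted down, unshifted, and shifted up."""
    return [f"{chrom}_{start + d}_{end + d}" for d in (-shift, 0, shift)]

names = shifted_interval_names("chr17", 20693681, 20824753)
for name in names:
    print(name)
```

Shifting both ends by the same amount keeps every interval at 131072 bp, which is the input length the stored sequences use.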

In [None]:
# read in predictions

with h5py.File("/grand/TFXcan/imlab/users/lvairus/reftile_project/predictions_folder/chr17_reference_avg6_shifted_regions/predictions_2023-08-15/chr17_predictions/chr17_reference_avg6/haplotype0/chr17_20693681_20824753_predictions.h5") as f:
    forward = f['chr17_20693681_20824753'][()]

with h5py.File("/grand/TFXcan/imlab/users/lvairus/reftile_project/predictions_folder/chr17_reference_avg6_shifted_regions/predictions_2023-08-15/chr17_predictions/chr17_reference_avg6/haplotype0_rc/chr17_20693681_20824753_predictions.h5") as f:
    reverse = f['chr17_20693681_20824753'][()]

forrev = (forward + reverse) / 2

with h5py.File("/grand/TFXcan/imlab/users/lvairus/reftile_project/predictions_folder/chr17_reference_avg6_shifted_regions/predictions_2023-08-15/chr17_predictions/chr17_reference_avg6/haplotype0/chr17_20693678_20824750_predictions.h5") as f:
    shift_down = f['chr17_20693678_20824750'][()]

with h5py.File("/grand/TFXcan/imlab/users/lvairus/reftile_project/predictions_folder/chr17_reference_avg6_shifted_regions/predictions_2023-08-15/chr17_predictions/chr17_reference_avg6/haplotype0_rc/chr17_20693678_20824750_predictions.h5") as f:
    shift_down_rc = f['chr17_20693678_20824750'][()]

with h5py.File("/grand/TFXcan/imlab/users/lvairus/reftile_project/predictions_folder/chr17_reference_avg6_shifted_regions/predictions_2023-08-15/chr17_predictions/chr17_reference_avg6/haplotype0/chr17_20693684_20824756_predictions.h5") as f:
    shift_up = f['chr17_20693684_20824756'][()]

with h5py.File("/grand/TFXcan/imlab/users/lvairus/reftile_project/predictions_folder/chr17_reference_avg6_shifted_regions/predictions_2023-08-15/chr17_predictions/chr17_reference_avg6/haplotype0_rc/chr17_20693684_20824756_predictions.h5") as f:
    shift_up_rc = f['chr17_20693684_20824756'][()]

forrevshift = (forward + reverse + shift_down + shift_down_rc + shift_up + shift_up_rc) / 6

with h5py.File("/grand/TFXcan/imlab/data/enformer_training_data/basenji_data_h5/no_groups_popseq_revised_order/test_pop_seq.hdf5") as f:
    truth = f['target'][1000]

In [None]:
# Comparison functions

def get_diffmat(mat1, mat2):
    # Relative absolute difference: |mat1 - mat2| scaled per column (track)
    # by the larger of the two matrices' column maxima.
    diffmat = mat1 - mat2

    colwise_maxes1 = np.max(mat1, axis=0)
    colwise_maxes2 = np.max(mat2, axis=0)

    colwise_maxes_maxes = np.maximum(colwise_maxes1, colwise_maxes2)

    # Note: a column whose maximum is 0 in both matrices divides by zero here.
    relmax3_diffmat = np.abs(diffmat / colwise_maxes_maxes)

    return relmax3_diffmat


def get_summary(arr):
    summary = {
        "mean": np.mean(arr),
        "median": np.median(arr),
        "minimum": np.min(arr),
        "maximum": np.max(arr),
        "q1": np.percentile(arr, 25),
        "q3": np.percentile(arr, 75),
    }
    return summary


def plot_hist(arr, bin_num, xlab='Value', ylab='Frequency', title='Histogram'):
    plt.hist(arr, bins=bin_num)
    plt.title(title)
    plt.xlabel(xlab)
    plt.ylabel(ylab)
    plt.show()


In [None]:
# getting diffmats

diff_forward = get_diffmat(forward, truth)
diff_reverse = get_diffmat(reverse, truth)
diff_forrev = get_diffmat(forrev, truth)
diff_forrevshift = get_diffmat(forrevshift, truth)

In [None]:
# getting summaries

get_summary(diff_forward), get_summary(diff_reverse), get_summary(diff_forrev), get_summary(diff_forrevshift)

The normal forward matrix had the smallest median and mean difference from the ground truth matrix.
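
As a worked toy example of the metric that `get_diffmat` computes: each difference is divided by the larger of the two matrices' column maxima, so every track is scaled by its own peak. The variant below, `relative_diff` (a hypothetical name), additionally returns 0 where that maximum is 0 instead of dividing by zero. That guard is an addition for illustration and is not part of the function used here.

```python
import numpy as np

def relative_diff(mat1, mat2):
    """|mat1 - mat2| scaled per column by the larger column maximum."""
    mat1 = np.asarray(mat1, dtype=float)
    mat2 = np.asarray(mat2, dtype=float)
    scale = np.maximum(mat1.max(axis=0), mat2.max(axis=0))
    out = np.zeros_like(mat1)
    # Divide only where the scale is nonzero; other entries stay 0.
    np.divide(mat1 - mat2, scale, out=out, where=scale != 0)
    return np.abs(out)

pred = [[1, 0], [3, 0]]
truth_toy = [[2, 0], [4, 0]]
rel = relative_diff(pred, truth_toy)
print(rel)
```

In the first column both entries differ by 1 and the larger column maximum is 4, so each relative difference is 0.25; the all-zero second column yields 0 rather than NaN.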

I also compared the matrices by their column-wise correlations with the ground truth.

In [None]:

# Pearson correlation between each predicted track (column) and the matching ground truth track
corvec_forward = np.empty(5313)
corvec_reverse = np.empty(5313)
corvec_forrev = np.empty(5313)
corvec_forrevshift = np.empty(5313)

for col in range(5313):
    # Store each matrix's correlation in its own vector
    # (a single shared variable would be overwritten before being stored).
    corvec_forward[col] = np.corrcoef(forward[:, col], truth[:, col])[0, 1]
    corvec_reverse[col] = np.corrcoef(reverse[:, col], truth[:, col])[0, 1]
    corvec_forrev[col] = np.corrcoef(forrev[:, col], truth[:, col])[0, 1]
    corvec_forrevshift[col] = np.corrcoef(forrevshift[:, col], truth[:, col])[0, 1]
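
The per-column loop can also be written as a single vectorized computation. This is a sketch using a hypothetical helper, `columnwise_pearson`, checked against `np.corrcoef` on random toy data rather than on the real matrices.

```python
import numpy as np

def columnwise_pearson(mat_a, mat_b):
    """Pearson correlation between matching columns of two (rows, cols) arrays."""
    a = mat_a - mat_a.mean(axis=0)
    b = mat_b - mat_b.mean(axis=0)
    numerator = (a * b).sum(axis=0)
    denominator = np.sqrt((a ** 2).sum(axis=0) * (b ** 2).sum(axis=0))
    return numerator / denominator

# Toy data: y is x plus noise, so every column should correlate positively.
rng = np.random.default_rng(0)
x = rng.normal(size=(50, 3))
y = x + rng.normal(scale=0.5, size=(50, 3))

vec = columnwise_pearson(x, y)
expected = np.array([np.corrcoef(x[:, c], y[:, c])[0, 1] for c in range(3)])
print(np.allclose(vec, expected))
```

As with `np.corrcoef`, a column that is constant in either matrix has zero variance and produces NaN, so summaries over the result may need `np.nanmean` and `np.nanmedian`.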


In [None]:
plot_hist(corvec_forward, 1000, title="forward")
get_summary(corvec_forward)

In [None]:
plot_hist(corvec_reverse, 1000, title="reverse")
get_summary(corvec_reverse)

In [None]:
plot_hist(corvec_forrev, 1000, title="forrev")
get_summary(corvec_forrev)

In [None]:
plot_hist(corvec_forrevshift, 1000, title="forrevshift")
get_summary(corvec_forrevshift)

Again, the normal forward matrix was the closest to the ground truth, with the highest median and mean column-wise correlation.

Conclusion: in this test, using just the normal forward prediction of an interval gave the closest results to the ground truth, so that is what we will use for our predictions.