## Import Libraries

In [11]:
import tensorflow as tf
# Make sure the GPU is enabled 
assert tf.config.list_physical_devices('GPU'), 'Start the colab kernel with GPU: Runtime -> Change runtime type -> GPU'
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
import tensorflow_hub as hub # for interacting with saved models and tensorflow hub
import joblib
import gzip # for manipulating compressed files
import kipoiseq # for manipulating fasta files
from kipoiseq import Interval # same as above, really
import pyfaidx # to index our reference genome file
import pandas as pd # for manipulating dataframes
import numpy as np # for numerical computations
import matplotlib.pyplot as plt # for plotting
import matplotlib as mpl # for plotting
import seaborn as sns # for plotting
import pickle # for saving large objects
import os, sys # functions for interacting with the operating system

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

Num GPUs Available:  1


## Define Paths

In [12]:
# Pickled transform pipeline for variant scores; this is the kind of file the
# EnformerScoreVariants*Normalized classes below load with joblib
# (presumably a scikit-learn Pipeline -- TODO confirm).
transform_path = 'gs://dm-enformer/models/enformer.finetuned.SAD.robustscaler-PCA500-robustscaler.transform.pkl'
# TensorFlow Hub handle handed to `hub.load` by the `Enformer` class.
model_path = 'https://tfhub.dev/deepmind/enformer/1'
# Reference genome FASTA, opened with pyfaidx by `FastaStringExtractor`.
# NOTE(review): absolute, user-specific path -- adjust for your environment.
fasta_file = '/home/s1mi/enformer-tutorial/genome.fa'
# Tab-separated targets table hosted in the Basenji repository
# (presumably one row per human output track -- TODO confirm).
targets_txt = 'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_human.txt'
df_targets = pd.read_csv(targets_txt, sep='\t')

## Define Functions

In [13]:
# @title `Enformer`, `EnformerScoreVariantsNormalized`, `EnformerScoreVariantsPCANormalized`,
# Width that every gene interval is resized to before its sequence is extracted
# and passed to the model (393216 = 3072 * 128).
SEQUENCE_LENGTH = 393216

class Enformer:
  """Thin wrapper around a model loaded from TensorFlow Hub."""

  def __init__(self, tfhub_url):
    """Load the hub module at `tfhub_url` and keep its `.model` attribute."""
    self._model = hub.load(tfhub_url).model

  def predict_on_batch(self, inputs):
    """Run the wrapped model on `inputs`.

    Returns:
      A dict with the same keys as the model's prediction dict, where every
      value has been converted with `.numpy()`.
    """
    predictions = self._model.predict_on_batch(inputs)
    return {k: v.numpy() for k, v in predictions.items()}

  @tf.function
  def contribution_input_grad(self, input_sequence,
                              target_mask, output_head='human'):
    """Compute gradient-times-input contribution scores for one sequence.

    Parameters:
      input_sequence: a single, un-batched input; a leading batch axis is
        added here (presumably a (length, 4) one-hot tensor -- TODO confirm).
      target_mask: weights multiplied element-wise with the predictions of
        the chosen head; the weighted sum is divided by the mask's total.
      output_head: key into the model's prediction dict (default 'human').

    Returns:
      The gradient of the mask-weighted mean prediction with respect to the
      input, multiplied by the input and summed over the last axis.
    """
    # Add the batch axis the model expects.
    input_sequence = input_sequence[tf.newaxis]

    target_mask_mass = tf.reduce_sum(target_mask)
    with tf.GradientTape() as tape:
      # The input is not a tf.Variable, so it must be watched explicitly
      # before the forward pass for the tape to record gradients for it.
      tape.watch(input_sequence)
      prediction = tf.reduce_sum(
          target_mask[tf.newaxis] *
          self._model.predict_on_batch(input_sequence)[output_head]) / target_mask_mass

    # Gradient * input, then drop the batch axis and collapse the last axis.
    input_grad = tape.gradient(prediction, input_sequence) * input_sequence
    input_grad = tf.squeeze(input_grad, axis=0)
    return tf.reduce_sum(input_grad, axis=-1)


class EnformerScoreVariantsRaw:
  """Scores variants as the alt-minus-ref difference of mean predictions."""

  def __init__(self, tfhub_url, organism='human'):
    """Wrap an `Enformer` loaded from `tfhub_url`; `organism` picks the output head."""
    self._model = Enformer(tfhub_url)
    self._organism = organism

  def predict_on_batch(self, inputs):
    """Return mean(alt) - mean(ref), each mean taken over axis 1 of the head output.

    Parameters:
      inputs: a mapping with 'ref' and 'alt' entries, each a model input batch.
    """
    head = self._organism
    ref_mean, alt_mean = [
        self._model.predict_on_batch(inputs[allele])[head].mean(axis=1)
        for allele in ('ref', 'alt')
    ]
    return alt_mean - ref_mean


class EnformerScoreVariantsNormalized:
  """Raw variant scores passed through the first step of a saved transform pipeline."""

  def __init__(self, tfhub_url, transform_pkl_path,
               organism='human'):
    """Load the model and the pickled pipeline found at `transform_pkl_path`.

    Only `organism='human'` is accepted; anything else fails the assertion.
    """
    assert organism == 'human', 'Transforms only compatible with organism=human'
    self._model = EnformerScoreVariantsRaw(tfhub_url, organism)
    # NOTE(review): joblib.load unpickles the file -- only use trusted sources.
    with tf.io.gfile.GFile(transform_pkl_path, 'rb') as pkl_file:
      pipeline = joblib.load(pkl_file)
    # Keep only the estimator of the first pipeline step. The original comment
    # called it a StandardScaler; the tutorial's file name suggests a robust
    # scaler -- TODO confirm.
    first_step = pipeline.steps[0]
    self._transform = first_step[1]

  def predict_on_batch(self, inputs):
    """Score `inputs` with the raw scorer, then apply the stored transform."""
    raw_scores = self._model.predict_on_batch(inputs)
    return self._transform.transform(raw_scores)


class EnformerScoreVariantsPCANormalized:
  """Raw variant scores passed through a saved transform, truncated to the leading features."""

  def __init__(self, tfhub_url, transform_pkl_path,
               organism='human', num_top_features=500):
    """Load the model and the whole pickled transform at `transform_pkl_path`.

    `num_top_features` is the number of leading output columns to keep.
    """
    self._model = EnformerScoreVariantsRaw(tfhub_url, organism)
    # NOTE(review): joblib.load unpickles the file -- only use trusted sources.
    with tf.io.gfile.GFile(transform_pkl_path, 'rb') as pkl_file:
      self._transform = joblib.load(pkl_file)
    self._num_top_features = num_top_features

  def predict_on_batch(self, inputs):
    """Score `inputs`, apply the transform, and keep the first `num_top_features` columns."""
    raw_scores = self._model.predict_on_batch(inputs)
    transformed = self._transform.transform(raw_scores)
    return transformed[:, :self._num_top_features]


# TODO(avsec): Add feature description: Either PCX, or full names.


# @title `variant_centered_sequences`

class FastaStringExtractor:
    """Reads upper-case sequence strings for intervals out of a FASTA file."""

    def __init__(self, fasta_file):
        """Open `fasta_file` with pyfaidx and record the length of every sequence in it."""
        self.fasta = pyfaidx.Fasta(fasta_file)
        self._chromosome_sizes = {
            name: len(record) for name, record in self.fasta.items()
        }

    def extract(self, interval: Interval, **kwargs) -> str:
        """Return the sequence for `interval`, padded with 'N' where it leaves the chromosome.

        The result covers the full requested span: any part before position 0
        or beyond the chromosome end is filled with 'N'.
        """
        chrom_len = self._chromosome_sizes[interval.chrom]
        # Clip the request to the part that actually exists on the chromosome.
        clipped = Interval(interval.chrom,
                           max(interval.start, 0),
                           min(interval.end, chrom_len))
        # pyfaidx wants a 1-based interval, hence the +1 on the start.
        fetched = self.fasta.get_seq(clipped.chrom,
                                     clipped.start + 1,
                                     clipped.stop)
        sequence = str(fetched.seq).upper()
        # Number of positions that fell off either end of the chromosome.
        n_upstream = max(-interval.start, 0)
        n_downstream = max(interval.end - chrom_len, 0)
        return 'N' * n_upstream + sequence + 'N' * n_downstream

    def close(self):
        """Close the underlying pyfaidx handle."""
        return self.fasta.close()


def one_hot_encode(sequence):
  """One-hot encode a DNA string with kipoiseq and return it as float32."""
  encoded = kipoiseq.transforms.functional.one_hot_dna(sequence)
  return encoded.astype(np.float32)



# @title `plot_tracks`

def plot_tracks(tracks, interval, height=1.5):
  """Draw one filled track per entry of `tracks`, stacked vertically on a shared x axis.

  Parameters:
    tracks: a dict mapping a title to a 1-D sequence of values.
    interval: an object with `start` and `end`; the values of each track are
      spread evenly between them, and `str(interval)` labels the x axis.
    height: height of each track's panel (the figure width is fixed at 20).

  Raises:
    ValueError: if `tracks` is empty.
  """
  n_tracks = len(tracks)
  if n_tracks == 0:
    raise ValueError('plot_tracks needs at least one track to draw.')

  # squeeze=False makes subplots always return a 2-D array of Axes. Without it
  # a single track yields a bare Axes object and the zip below fails.
  fig, axes = plt.subplots(n_tracks, 1, figsize=(20, height * n_tracks),
                           sharex=True, squeeze=False)
  axes = axes[:, 0]
  for ax, (title, y) in zip(axes, tracks.items()):
    ax.fill_between(np.linspace(interval.start, interval.end, num=len(y)), y)
    ax.set_title(title)

  # despine acts on every axes of the figure, so one call after the loop is enough.
  sns.despine(fig=fig, top=True, right=True, bottom=True)
  # Label only the bottom panel; the x axis is shared.
  axes[-1].set_xlabel(str(interval))
  fig.tight_layout()

In [14]:
import Bio

from Bio.Seq import Seq
def create_rev_complement(dna_string):
    """Return the reverse complement of `dna_string` as a plain str, via Biopython's `Seq`."""
    reverse_complemented = Seq(dna_string).reverse_complement()
    return str(reverse_complemented)

In [15]:
def prepare_for_quantify_prediction_per_TSS(predictions, gene, tss_df, track_index=5110):

  '''
  Collect the inputs needed by `quantify_prediction_per_TSS` for one gene.

  Parameters:
          predictions (a 2D numpy array): all predictions; one column per track
          gene (a gene name, character): a gene
          tss_df: a list of dataframes with a `genes` column and a
                  `txStart_Sites` column holding comma-separated TSS positions
          track_index (int): column of `predictions` to extract; defaults to
                  5110, the track this notebook treats as CAGE
  Returns:
          A dictionary with
            'cage_predictions': column `track_index` of `predictions`
            'gene_TSS': a de-duplicated list of the gene's TSS positions (order
                        not guaranteed). The list is empty when the gene is in
                        none of the dataframes (this used to raise
                        UnboundLocalError); matches from several dataframes
                        are merged rather than the last one winning.

  '''

  tss_positions = set()
  for tdf in tss_df:
    matching_sites = tdf.loc[tdf['genes'] == gene, 'txStart_Sites']
    for entry in matching_sites.apply(str):
      # Split on the bare comma so both "1, 2" and "1,2" parse; int() ignores
      # the surrounding whitespace.
      tss_positions.update(int(site) for site in entry.split(',') if site.strip())

  output = dict()
  output['cage_predictions'] = predictions[:, track_index] # a numpy array
  output['gene_TSS'] = list(tss_positions) # a list

  return(output) # a dictionary

def quantify_prediction_per_TSS(low_range, TSS, cage_predictions, bin_size=128, n_offset_bins=768 + 320):

  '''
  Sum the predictions of the bin containing each TSS and its two neighbours.

  Parameters:
          low_range (int): start coordinate of the input window
          TSS (list of integers): a list of TSS for a gene
          cage_predictions: a 1D numpy array (or any indexable sequence) of
                  per-bin predictions for a single track
          bin_size (int): width of one prediction bin; defaults to 128
          n_offset_bins (int): number of bins between `low_range` and the
                  first predicted bin; defaults to 768 + 320
  Returns:
          A dictionary mapping each TSS to the sum of three consecutive bins
          centred on the bin that contains it. A TSS is left out when that
          bin or either neighbour lies outside `cage_predictions`. (TSSs in
          or before the first bin used to be scored with wrapped-around
          negative indices, i.e. bins from the far end of the track.)
  '''
  first_bin_start = low_range + n_offset_bins * bin_size
  n_bins = len(cage_predictions)

  tss_predictions = dict()
  for tss in TSS:
    # Closed form of "step bin_size at a time until the bin start reaches the
    # TSS": same bin index for every TSS inside the window, without the loop.
    tss_bin = (tss - first_bin_start - 1) // bin_size
    # Both neighbours must exist; never let a negative index wrap around.
    if tss_bin < 1 or tss_bin >= n_bins - 1:
      continue
    tss_predictions[tss] = (cage_predictions[tss_bin - 1]
                            + cage_predictions[tss_bin]
                            + cage_predictions[tss_bin + 1])

  return(tss_predictions)

def collect_intervals(chromosomes = ["22"], gene_list=None,
                      gene_dir="/home/s1mi/enformer-tutorial/gene_chroms"):

  '''
    Read per-chromosome gene tables and return the interval of each gene.

    Parameters :
      chromosomes : a list of chromosome names as strings (a single string is
                    treated as a one-element list)
      gene_list : an iterable of gene names (or a single name) to keep; None
                  keeps every gene found in the files
      gene_dir : directory holding the tab-separated `gene_<chrom>.txt` files

    Returns :
      A dictionary mapping gene name (column 3 of the file) to
      [column 1 as a string, int(column 4), int(column 5)]. Callers in this
      notebook use these as chromosome, start and end. Genes in `gene_list`
      that are not in the files are left out silently.
  '''

  if isinstance(chromosomes, str):
    chromosomes = [chromosomes] # otherwise "22" would be read as "2" and "2"

  gene_intervals = {} # intervals for every gene found in the files

  for chrom in chromosomes:
    gene_file = os.path.join(gene_dir, "gene_" + chrom + ".txt")
    with open(gene_file, "r") as chrom_genes:
      for line in chrom_genes:
        stripped_line = line.strip()
        if not stripped_line: # tolerate blank lines, e.g. at the end of the file
          continue
        fields = stripped_line.split("\t")
        gene_intervals[fields[2]] = [fields[0], int(fields[3]), int(fields[4])]

  if gene_list is None:
    return(gene_intervals)

  # Any iterable of names is accepted. Tuples, arrays and sets used to fall
  # through both isinstance checks and return None.
  if isinstance(gene_list, str):
    gene_list = [gene_list]
  return({k: gene_intervals[k] for k in gene_list if k in gene_intervals})


def _apply_variants_to_haplotypes(reference_seq, variants, individual, window_start):
  '''Return the two haplotype sequences (lists of characters) of `individual`:
  copies of `reference_seq` with the ALT allele written at every in-window
  variant whose phased genotype carries a "1" on that haplotype.'''
  haplo_1 = list(reference_seq)
  haplo_2 = list(reference_seq)
  n_bases = len(reference_seq)

  for pos, alt, genotype in zip(variants["POS"], variants["ALT"], variants[individual]):
    offset = pos - window_start - 1
    if offset < 0 or offset >= n_bases: # variant lies outside the window
      continue
    alleles = genotype.split("|")
    if alleles[0] == "1":
      haplo_1[offset] = alt
    if alleles[1] == "1":
      haplo_2[offset] = alt

  return haplo_1, haplo_2


def run_predictions(gene_intervals, tss_dataframe, individuals_list=None,
                    bed_dir="/home/s1mi/enformer-tutorial/individual_beds"):
  '''
  Predict on both haplotypes of each individual for each gene.

  Uses the notebook-level `model` and `fasta_extractor` objects, which must
  exist before this function is called.

  Parameters :
    gene_intervals : the results from calling `collect_intervals`
    tss_dataframe : a list of the TSSs dataframes i.e. the TSS for the genes in the chromosomes
    individuals_list : any iterable of individuals on which we want to make predictions;
                       None (the default) uses every column of the gene's bed file from the fifth onward
    bed_dir : directory holding `chr<N>/chr<N>_<gene>.bed` variant files

  Returns :
    A list of two dictionaries keyed by gene, then by individual. The first holds, per haplotype,
    the summed predictions around each TSS. The second holds, per haplotype, the full track-5110
    (CAGE) prediction. A gene whose bed file cannot be read or parsed is reported and skipped.
  '''

  cage_track = 5110 # column of the predictions this notebook treats as CAGE

  gene_output = dict()
  gene_predictions = dict()

  for gene, gene_interval in gene_intervals.items():
    chrom = "chr" + gene_interval[0]
    target_interval = kipoiseq.Interval(chrom, gene_interval[1], gene_interval[2])
    # Resize once; the same window gives the sequence and the coordinates.
    window_coords = target_interval.resize(SEQUENCE_LENGTH)
    target_fa = fasta_extractor.extract(window_coords)

    bed_path = os.path.join(bed_dir, chrom, chrom + "_" + gene + ".bed")
    try:
      cur_gene_vars = pd.read_csv(bed_path, sep="\t", header=0)
    except (OSError, ValueError) as err:
      # Best-effort as before, but no longer silent, and no longer swallowing
      # KeyboardInterrupt. pandas' parser errors are ValueError subclasses.
      print('Skipping gene {}: could not read {} ({})'.format(gene, bed_path, err))
      continue

    if individuals_list is None:
      use_individuals = cur_gene_vars.columns[4:]
    else:
      use_individuals = individuals_list

    individual_results = dict()
    individual_prediction = dict()

    for individual in use_individuals:
      print('Currently on gene {}, and predicting on individual {}...'.format(gene, individual))
      haplotypes = _apply_variants_to_haplotypes(
          target_fa, cur_gene_vars, individual, window_coords.start)

      # predict on the individual's two haplotypes
      haplo_predictions = [
          model.predict_on_batch(one_hot_encode("".join(haplo))[np.newaxis])['human'][0]
          for haplo in haplotypes
      ]
      individual_prediction[individual] = [pred[:, cage_track] for pred in haplo_predictions]

      # one {TSS: summed prediction} dictionary per haplotype
      tss_per_haplotype = list()
      for pred in haplo_predictions:
        prepared = prepare_for_quantify_prediction_per_TSS(predictions=pred, gene=gene, tss_df=tss_dataframe)
        tss_per_haplotype.append(
            quantify_prediction_per_TSS(low_range=window_coords.start,
                                        TSS=prepared['gene_TSS'],
                                        cage_predictions=prepared['cage_predictions']))
      individual_results[individual] = tss_per_haplotype

    gene_output[gene] = individual_results
    gene_predictions[gene] = individual_prediction

  return([gene_output, gene_predictions])


def collect_target_intervals(gene_intervals):

  '''
  Convert each gene's [chromosome, start, end] entry into a kipoiseq Interval.

  Returns a dictionary with the same gene keys; the chromosome name is
  prefixed with "chr".
  '''

  return({
      gene: kipoiseq.Interval("chr" + coords[0], coords[1], coords[2])
      for gene, coords in gene_intervals.items()
  })

def prepare_for_plot_tracks(gene, individual, all_predictions, chromosome=['22']):

  '''
  Build the inputs that `plot_tracks` needs for one gene and one individual.

  Parameters:
    - gene: gene name; a key of `all_predictions`
    - individual: individual id; a key of `all_predictions[gene]`
    - all_predictions: nested dictionary gene -> individual -> the two per-haplotype prediction arrays
    - chromosome: list of chromosome names passed on to `collect_intervals`

  Returns a dictionary with 'gene_tracks' (log10(1 + prediction) per
  haplotype, keyed by a descriptive title) and 'gene_intervals' (the gene's
  kipoiseq Interval).
  '''

  haplo_predictions = all_predictions[gene][individual]
  title_prefix = gene + ' | ' + individual + ' | haplotype '

  gene_tracks = dict()
  for haplo_index in (0, 1):
    gene_tracks[title_prefix + str(haplo_index + 1)] = np.log10(1 + haplo_predictions[haplo_index])

  interval_by_gene = collect_target_intervals(
      collect_intervals(chromosomes=chromosome, gene_list=[gene]))

  return({'gene_tracks': gene_tracks, 'gene_intervals': interval_by_gene[gene]})

def check_individuals(path_to_bed_file, list_of_individuals):

  '''
  Checks if an individual is missing in bed variation files.
  These individuals should be removed prior to training

  Parameters:
    path_to_bed_file: tab-separated file whose first line is a header; the
      names from the fifth column onward are taken as the individuals
    list_of_individuals: the individuals to look for

  Returns:
    The individuals that are not in the header, de-duplicated and in the
    order given; an empty list when all are present. A summary is printed.
  '''

  with open(path_to_bed_file, 'r') as bed_file:
    header_line = bed_file.readline()

  # Strip the line terminator first. Without this the last column name keeps
  # its "\n" and the last individual is always reported as missing.
  bed_names = set(header_line.rstrip('\r\n').split('\t')[4:])

  requested = list(dict.fromkeys(list_of_individuals)) # de-duplicate, keep order
  missing = [individual for individual in requested if individual not in bed_names]

  if missing:
    print('This (or these) individual(s) is/are not present: {}'.format(missing))
  else:
    print('All individuals are present in the bed file.')

  return(missing)

In [16]:
def geno_to_seq(gene, individual, reference_seq=None, variants=None, window_start=None):
  '''
  Build the two haplotype sequences of one individual.

  Parameters:
    gene: not used by the computation; kept for backward compatibility
    individual: column of `variants` holding phased genotypes such as "0|1"
    reference_seq: reference sequence of the window; defaults to the
      notebook-level `target_fa`
    variants: dataframe with "POS" and "ALT" columns plus one column per
      individual; defaults to the notebook-level `cur_gene_vars`
    window_start: start coordinate of the window; defaults to the
      notebook-level `window_coords.start`

  Returns:
    (haplo_1, haplo_2): two lists of characters, copies of the reference with
    ALT written wherever the matching side of the genotype is "1". Variants
    outside the window are ignored.
  '''
  # Fall back to the notebook globals so existing two-argument calls still work.
  if reference_seq is None:
    reference_seq = target_fa
  if variants is None:
    variants = cur_gene_vars
  if window_start is None:
    window_start = window_coords.start

  # two haplotypes per individual
  haplo_1 = list(reference_seq)
  haplo_2 = list(reference_seq)
  n_bases = len(reference_seq)

  # zip over the three needed columns instead of iterrows(): same values, much faster.
  for pos, alt, genotype in zip(variants["POS"], variants["ALT"], variants[individual]):
    offset = pos - window_start - 1
    if offset < 0 or offset >= n_bases:
      continue
    alleles = genotype.split("|")
    if alleles[0] == "1":
      haplo_1[offset] = alt
    if alleles[1] == "1":
      haplo_2[offset] = alt
  return haplo_1, haplo_2

      # predict on the individual's two haplotypes
    

## Prepare input data

We want to predict the epigenome around the ERAP2 TSS on chromosome 5.

In [17]:
# TSS table for chromosome 5 (presumably one row per gene -- TODO confirm).
chr5_tss = pd.read_table('/home/s1mi/enformer-tutorial/tss_by_chr/chr5_tss_by_gene.txt', sep='\t')
# Variant table for ERAP2; the same file is read again below as `cur_gene_vars`.
erap2_variations = pd.read_table('/home/s1mi/enformer-tutorial/individual_beds/chr5/chr5_ERAP2.bed', sep='\t')
# Expression table with one column per individual after the annotation columns.
# The identifier columns are forced to str so that values such as Chr are not parsed as numbers.
geuvadis_gene_expression = pd.read_table('https://uchicago.box.com/shared/static/5vwc7pjw9qmtv7298c4rc7bcuicoyemt.gz', sep='\t',
                                         dtype={'gene_id': str, 'gene_name':str, 'TargetID':str, 'Chr':str})
geuvadis_gene_expression.head(5)


Unnamed: 0,gene_id,gene_name,TargetID,Chr,Coord,HG00096,HG00097,HG00099,HG00100,HG00101,...,NA20810,NA20811,NA20812,NA20813,NA20814,NA20815,NA20816,NA20819,NA20826,NA20828
0,ENSG00000223972.4,DDX11L1,ENSG00000223972.4,1,11869,0.320818,0.344202,0.354225,0.478064,-0.102815,...,1.008605,0.384489,0.581284,0.513981,0.667449,0.35089,0.186103,-0.037976,0.405439,0.199143
1,ENSG00000227232.3,WASH7P,ENSG00000227232.3,1,29806,33.714457,20.185174,18.095407,24.100871,29.018719,...,30.980194,34.086207,39.678442,29.643513,27.12042,29.121624,31.117198,32.047074,22.798959,23.563874
2,ENSG00000243485.1,MIR1302-11,ENSG00000243485.1,1,29554,0.240408,0.157456,0.218806,0.320878,0.067833,...,0.06594,0.228784,0.140642,0.283905,0.273821,0.286311,0.32406,0.049574,0.255288,0.15744
3,ENSG00000238009.2,RP11-34P13.7,ENSG00000238009.2,1,133566,0.328272,0.327932,0.090064,0.420443,0.220269,...,0.274071,0.384179,0.533693,0.307221,0.307367,0.400278,0.612321,0.666633,0.281138,1.346129
4,ENSG00000239945.1,RP11-34P13.8,ENSG00000239945.1,1,91105,0.332171,-0.032164,0.017323,0.424677,0.214025,...,0.347323,0.346744,0.07358,0.400396,0.470517,0.069749,0.299353,0.090019,0.282554,-0.15717


In [19]:
# Look up the interval of ERAP2 in the chromosome 5 gene table.
gene_intervals = collect_intervals(chromosomes=['5'], gene_list=['ERAP2'])


In [20]:
model = Enformer(model_path) # load the model from TensorFlow Hub, wrapped in the Enformer class

fasta_extractor = FastaStringExtractor(fasta_file) # extractor instance used to pull raw sequence strings out of the reference FASTA

## Run Predictions

We'll pick one individual at random.

In [21]:
# The first six columns are annotations (index, gene_id, gene_name, TargetID,
# Chr, Coord); every column after them is an individual. The previous slice
# `[6:-1]` silently left the last individual out of the draw.
individual_ids = geuvadis_gene_expression.columns[6:]
# A seeded generator makes the pick reproducible under Restart & Run All.
# NOTE(review): the chosen individual must also be a column of the gene's bed
# file for the cells below to work -- see `check_individuals`.
rng = np.random.default_rng(seed=0)
rand_individual = rng.choice(individual_ids) # individual we are interested in
rand_individual

'NA19160'

In [22]:
# Resize the ERAP2 interval to the model's input window, fetch the reference
# sequence for that window, and load the gene's variant table.
gene = 'ERAP2'
gene_interval = gene_intervals[gene]
target_interval = kipoiseq.Interval(
    "chr" + gene_interval[0], gene_interval[1], gene_interval[2])
window_coords = target_interval.resize(SEQUENCE_LENGTH)
target_fa = fasta_extractor.extract(window_coords)
# read in the appropriate bed file for the gene
cur_gene_vars = pd.read_csv(
    "/home/s1mi/enformer-tutorial/individual_beds/chr" + gene_interval[0]
    + "/chr" + gene_interval[0] + "_" + gene + ".bed",
    sep="\t", header=0)


In [23]:
# Build the individual's two haplotype sequences (the gene argument is not
# used inside geno_to_seq; the window and variants come from the cell above).
haplo_1, haplo_2 = geno_to_seq('ERAP2', rand_individual)

# One-hot encode each haplotype and add a leading batch axis.
haplo_1_enc = one_hot_encode("".join(haplo_1))[np.newaxis]
haplo_2_enc = one_hot_encode("".join(haplo_2))[np.newaxis]
# Element-wise mean of the two encodings: the input for the "pre-average" prediction.
average_enc = np.add(haplo_1_enc, haplo_2_enc) / 2

In [24]:
# Human-head predictions for each haplotype; [0] drops the batch axis.
prediction_1 = model.predict_on_batch(haplo_1_enc)['human'][0]
prediction_2 = model.predict_on_batch(haplo_2_enc)['human'][0]

# pre_average: predict once on the averaged encoding.
pre_average = model.predict_on_batch(average_enc)['human'][0]
# post_average: predict on each haplotype, then average the two predictions.
post_average = (prediction_1 + prediction_2) / 2


In [25]:
# Show both results for a quick visual comparison (numpy abbreviates large arrays).
print(pre_average)
print(post_average)

[[0.09901776 0.09164002 0.0699944  ... 0.00303012 0.00990325 0.01237618]
 [0.05784911 0.05089986 0.03969494 ... 0.00734875 0.02510322 0.0280255 ]
 [0.00580059 0.00483631 0.00361973 ... 0.0036824  0.00902823 0.01077321]
 ...
 [0.00023449 0.00016057 0.00016778 ... 0.00083343 0.0026698  0.00314449]
 [0.00304364 0.00254643 0.00265122 ... 0.00092347 0.00226813 0.00265463]
 [0.05069465 0.03909001 0.03585194 ... 0.00824723 0.03236783 0.0296576 ]]
[[0.09910323 0.09164958 0.07005598 ... 0.00303381 0.00984134 0.01238221]
 [0.05783322 0.05087575 0.03971903 ... 0.00735789 0.02497963 0.0280641 ]
 [0.00579469 0.00483093 0.00361714 ... 0.00366731 0.00894659 0.01073787]
 ...
 [0.00024071 0.00016465 0.00017168 ... 0.0007882  0.00254766 0.00304421]
 [0.0030893  0.00256867 0.00267111 ... 0.00086807 0.0020461  0.00242939]
 [0.05074473 0.03918105 0.03594471 ... 0.00819363 0.03179348 0.02920263]]


## Comparing across tracks

In [30]:
# Pearson correlation between the two averaging strategies, one value per
# output track. The track count comes from the array itself instead of a
# hard-coded 5313, so this keeps working if the model head changes.
res = [
    np.corrcoef(pre_average[:, track], post_average[:, track])[0, 1]
    for track in range(pre_average.shape[1])
]


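The results from both methods (predicting once on the averaged haplotype encoding versus averaging the two per-haplotype predictions) are nearly identical: the per-track correlations between them are all close to 1.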

In [31]:
# Smallest and largest per-track correlation between the two averaging strategies.
print(min(res), max(res))

0.9910764472374001 0.999995405486603
