Permalink
Fetching contributors…
Cannot retrieve contributors at this time
260 lines (202 sloc) 9.17 KB
# python2, 3 compatibility
from __future__ import absolute_import, division, print_function
import six
import os
import sys
import inspect
from builtins import str, open, range, dict
import pickle
import numpy as np
import pandas as pd
import pybedtools
from pybedtools import BedTool
from sklearn.preprocessing import FunctionTransformer
from genomelake.extractors import BaseExtractor, FastaExtractor, one_hot_encode_sequence, NUM_SEQ_CHARS
from pysam import FastaFile
from concise.preprocessing.splines import encodeSplines
from concise.utils.position import extract_landmarks, ALL_LANDMARKS
from gtfparse import read_gtf
from kipoi.metadata import GenomicRanges
import linecache
from kipoi.data import Dataset
import warnings
# Path of this source file as reported by the interpreter for the module's
# top-level frame, and the absolute directory that contains it.
# NOTE(review): DATALOADER_DIR is not referenced in the visible code;
# presumably kept for resolving files shipped next to the dataloader.
filename = inspect.getframeinfo(inspect.currentframe()).filename
DATALOADER_DIR = os.path.split(os.path.abspath(filename))[0]
def sign_log_func(x):
    """Signed log10 transform: ``sign(x) * log10(|x| + 1)``.

    Compresses large magnitudes while keeping the sign; 0 maps to 0.
    Works element-wise on scalars and numpy arrays.
    """
    magnitude = np.log10(np.abs(x) + 1)
    return np.sign(x) * magnitude
def sign_log_func_inverse(x):
    """Inverse of the signed log10 transform: ``sign(x) * (10**|x| - 1)``.

    Works element-wise on scalars and numpy arrays.
    """
    magnitude = np.power(10, np.abs(x)) - 1
    return np.sign(x) * magnitude
class BedToolLinecache(BedTool):
    """Fast BedTool accessor by Ziga Avsec.

    A plain BedTool loops through the whole file to get the line of
    interest, so indexing is O(n). This subclass reads the requested line
    through the standard-library `linecache` module instead, which keeps the
    file's lines cached after the first read.
    """

    def __getitem__(self, idx):
        """Return the interval stored on the `idx`-th (0-based) line of `self.fn`.

        The line is split on tabs and parsed with
        `pybedtools.create_interval_from_list`.

        Raises:
            IndexError: if `idx` is negative or past the last line of the file.
        """
        # linecache is 1-based and has no notion of negative indices: for
        # idx < 0 the line number would be <= 0 and silently yield ''.
        line = linecache.getline(self.fn, idx + 1) if idx >= 0 else ""
        if not line:
            # linecache.getline returns '' past EOF (or for an unreadable file)
            # instead of raising; surface that as a proper IndexError rather
            # than letting the interval parser choke on an empty field list.
            raise IndexError("BedToolLinecache index out of range: {0}".format(idx))
        return pybedtools.create_interval_from_list(line.strip().split("\t"))
class DistanceTransformer:
    """Transforms the raw distances to the appropriate modeling form.

    Uses a pickled, already-fitted pipeline and applies, in order:
    missing-value imputation -> signed log10 (`sign_log_func`) ->
    min-max scaling -> spline-basis encoding of each positional feature.
    """

    def __init__(self, pos_features, pipeline_obj_path):
        """
        Args:
          pos_features: list of positional features to use. Must equal (same
            names, same order) the feature list stored in the pickled pipeline.
          pipeline_obj_path: path to the pickled pipeline object, indexed as
            `(POS_FEATURES, minmax_scaler, imputer)`.

        Raises:
          ValueError: if `pos_features` differs from the pickled feature list.
        """
        self.pos_features = pos_features
        self.pipeline_obj_path = pipeline_obj_path
        # deserialize the pickle file
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # pipeline files from trusted sources.
        with open(self.pipeline_obj_path, "rb") as f:
            pipeline_obj = pickle.load(f)
        self.POS_FEATURES = pipeline_obj[0]
        self.minmax_scaler = pipeline_obj[1]
        self.imp = pipeline_obj[2]
        # For simplicity, assume all current pos_features are the same as the
        # ones the pipeline was fitted with. Checked with an explicit raise
        # (an `assert` disappears under `python -O`), and compared as lists so
        # a tuple/list/array container difference alone does not trip it.
        if list(self.POS_FEATURES) != list(self.pos_features):
            raise ValueError("pos_features {0} do not match the features the pipeline "
                             "was fitted with: {1}".format(list(self.pos_features),
                                                           list(self.POS_FEATURES)))
        # Stateless transform, used without fitting.
        self.funct_transform = FunctionTransformer(func=sign_log_func,
                                                   inverse_func=sign_log_func_inverse)

    def transform(self, x):
        """Impute, log-transform, rescale and spline-encode raw distances.

        Args:
          x: raw distances; column i corresponds to `self.POS_FEATURES[i]`
            (presumably shape (n_intervals, n_features) - TODO confirm).

        Returns:
          dict mapping "dist_<feature>" to `encodeSplines` applied to that
          feature's rescaled distances, passed in as an (n, 1) column.
        """
        # impute missing values and rescale the distances
        xnew = self.minmax_scaler.transform(self.funct_transform.transform(self.imp.transform(x)))
        # convert distances to spline bases
        dist = {"dist_" + k: encodeSplines(xnew[:, i, np.newaxis], start=0, end=1, warn=False)
                for i, k in enumerate(self.POS_FEATURES)}
        return dist
class DistToClosestLandmarkExtractor(BaseExtractor):
    """Extract distances to the closest genomic landmark

    # Arguments
        gtf_file: Genomic annotation file path (say gencode gtf)
        landmarks: List of landmarks to extract. See `concise.utils.position.extract_landmarks`
        use_strand: Take into account the strand of the intervals
    """
    multiprocessing_safe = True

    def __init__(self, gtf_file, landmarks=ALL_LANDMARKS, use_strand=True, **kwargs):
        super(DistToClosestLandmarkExtractor, self).__init__(gtf_file, **kwargs)
        self._gtf_file = gtf_file
        self.landmarks = extract_landmarks(gtf_file, landmarks=landmarks)
        self.columns = landmarks  # column names. Required for concatenating distances into an array
        self.use_strand = use_strand
        # set index to chromosome and strand - faster access
        self.landmarks = {k: v.set_index(["seqname", "strand"])
                          for k, v in six.iteritems(self.landmarks)}

    @staticmethod
    def _find_closest(ldm, interval, use_strand=True):
        """Signed distance from `interval`'s midpoint to the closest landmark in `ldm`.

        Args:
          ldm: landmark table indexed by (seqname, strand) with a `position` column.
          interval: object exposing `chrom`, `start`, `end` and `strand`.
          use_strand: if True and the interval is stranded, only landmarks on
            the same strand are considered, and the sign is flipped for '-'.

        Returns:
          landmark position minus 1 minus the interval midpoint (negated for
          '-' strand intervals when `use_strand`) for the landmark with the
          smallest absolute distance.

        Raises:
          KeyError: if the chromosome (or strand) has no landmarks in `ldm`.
        """
        # subset the positions to the appropriate strand
        # and extract the positions
        ldm_positions = ldm.loc[interval.chrom]
        if use_strand and interval.strand != ".":
            ldm_positions = ldm_positions.loc[interval.strand]
        # When exactly one landmark matches, `.loc` yields a single row
        # (a Series) whose "position" entry is a scalar rather than a column;
        # atleast_1d turns both cases into a 1-D array of positions.
        ldm_positions = np.atleast_1d(ldm_positions["position"])
        int_midpoint = (interval.end + interval.start) // 2
        # -1 because landmark positions are presumably 1-based (GTF) while
        # interval coordinates are 0-based (BED)
        dist = (ldm_positions - 1) - int_midpoint
        if use_strand and interval.strand == "-":
            dist = - dist
        return dist[np.argmin(np.abs(dist))]

    def _extract(self, intervals, out, **kwargs):
        # One row per interval, one column per landmark (in self.columns order).
        out[:] = np.array([[self._find_closest(self.landmarks[ldm_name], interval, self.use_strand)
                            for ldm_name in self.columns]
                           for interval in intervals], dtype=float)
        return out

    def _get_output_shape(self, num_intervals, width):
        return (num_intervals, len(self.columns))
class TxtDataset(Dataset):
    """Dataset over a text file holding one integer per line.

    The file is read eagerly in the constructor; item `idx` is line `idx`
    with surrounding whitespace stripped and parsed with `int`.
    """

    def __init__(self, path):
        with open(path, "r") as handle:
            self.lines = list(handle)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        raw_line = self.lines[idx]
        return int(raw_line.strip())
# --------------------------------------------
class SeqDistDataset(Dataset):
    """Dataset of fixed-width sequences plus distances to genomic landmarks.

    Each item is a dict with
      - 'inputs': {'seq': sequence extracted from `fasta_file`,
                   'dist_<landmark>': spline-encoded distance to the closest landmark}
      - 'targets': 1-element array with the integer target (only if `target_file` is given)
      - 'metadata': {'ranges': GenomicRanges built from the interval}

    Args:
        intervals_file: file path; tsv file
            Assumes bed-like `chrom start end id score strand` format.
            Every interval must be exactly `SEQ_WIDTH` bp wide.
        fasta_file: file path; Genome sequence
        gtf_file: file path; Genome annotation GTF file.
        target_file: file path; path to the targets, one integer per line and
            one line per interval in `intervals_file`. Optional.
        filter_protein_coding: Considering genomic landmarks only for protein coding genes
        position_transformer_file: file path; pickled transformer used for
            pre-processing the landmark distances. Required.
        use_linecache: if True, read intervals through `BedToolLinecache`
            (fast random access) instead of a plain `BedTool`.

    Raises:
        ValueError: if `position_transformer_file` is not given, or if the
            number of targets differs from the number of intervals.
    """
    SEQ_WIDTH = 101

    def __init__(self, intervals_file, fasta_file, gtf_file,
                 target_file=None,
                 filter_protein_coding=True,
                 position_transformer_file=None,
                 use_linecache=True):
        if sys.version_info[0] != 3:
            warnings.warn("Only Python 3 is supported. You are using Python {0}".format(sys.version_info[0]))
        # Validate required arguments before the (slow) GTF parsing.
        if position_transformer_file is None:
            raise ValueError("position_transformer_file needs to be specified")

        self.gtf = read_gtf(gtf_file)
        self.filter_protein_coding = filter_protein_coding
        if self.filter_protein_coding:
            if "gene_type" in self.gtf:
                self.gtf = self.gtf[self.gtf["gene_type"] == "protein_coding"]
            elif "gene_biotype" in self.gtf:
                self.gtf = self.gtf[self.gtf["gene_biotype"] == "protein_coding"]
            else:
                warnings.warn("Gtf doesn't have the field 'gene_type' or 'gene_biotype'. Considering genomic landmarks " +
                              "of all genes not just protein_coding.")
        # Add a "chr" prefix to chromosome names when no name contains "chr".
        if not np.any(self.gtf.seqname.str.contains("chr")):
            # assign() builds a new frame, so no column write into the filtered
            # slice above (which would raise pandas' SettingWithCopyWarning).
            self.gtf = self.gtf.assign(seqname="chr" + self.gtf["seqname"])

        # intervals
        if use_linecache:
            self.bt = BedToolLinecache(intervals_file)
        else:
            self.bt = BedTool(intervals_file)

        # extractors - created lazily in __getitem__
        # (presumably so each worker process builds its own - TODO confirm)
        self.fasta_file = fasta_file
        self.seq_extractor = None
        self.dist_extractor = None

        self.dist_transformer = DistanceTransformer(ALL_LANDMARKS, position_transformer_file)

        # target
        if target_file:
            self.target_dataset = TxtDataset(target_file)
            n_targets, n_intervals = len(self.target_dataset), len(self.bt)
            if n_targets != n_intervals:
                raise ValueError("target_file has {0} entries but intervals_file has {1} intervals"
                                 .format(n_targets, n_intervals))
        else:
            self.target_dataset = None

    def __len__(self):
        return len(self.bt)

    def __getitem__(self, idx):
        if self.seq_extractor is None:
            self.seq_extractor = FastaExtractor(self.fasta_file)
            self.dist_extractor = DistToClosestLandmarkExtractor(gtf_file=self.gtf,
                                                                 landmarks=ALL_LANDMARKS)
        interval = self.bt[idx]

        width = interval.stop - interval.start
        if width != self.SEQ_WIDTH:
            raise ValueError("Expected the interval to be {0} wide. Recieved stop - start = {1}".
                             format(self.SEQ_WIDTH, width))

        # input - sequence (extractors work on batches; squeeze the batch axis)
        inputs = {'seq': np.squeeze(self.seq_extractor([interval]), axis=0)}
        # input - distance
        dist_dict = self.dist_transformer.transform(self.dist_extractor([interval]))
        inputs.update({k: np.squeeze(v, axis=0) for k, v in dist_dict.items()})  # squeeze the batch axis

        out = {'inputs': inputs}
        # targets
        if self.target_dataset is not None:
            out["targets"] = np.array([self.target_dataset[idx]])
        # metadata
        out['metadata'] = {'ranges': GenomicRanges.from_interval(interval)}
        return out
def test_dataset(position_transformer_file=None, example_dir="example_files"):
    """Smoke-test `SeqDistDataset` on the example files.

    Builds the dataset, fetches items 0 and 10, and pulls one batch of 32.

    Args:
        position_transformer_file: path to the pickled distance-transformer
            pipeline. `SeqDistDataset` raises ValueError when this is None, so
            it must be supplied for the smoke test to get past construction.
        example_dir: directory containing the example files. A relative path
            is resolved against the current working directory.
    """
    # File paths
    intervals_file = os.path.join(example_dir, "intervals.bed")
    target_file = os.path.join(example_dir, "targets.tsv")
    gtf_file = os.path.join(example_dir, "gencode.v24.annotation_chr22.gtf")
    fasta_file = os.path.join(example_dir, "hg38_chr22.fa")
    ds = SeqDistDataset(intervals_file, fasta_file, gtf_file,
                        target_file=target_file,
                        position_transformer_file=position_transformer_file)
    ds[0]
    ds[10]
    it = ds.batch_iter(32)
    next(it)