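# NOTE: the lines "Permalink" and "Cannot retrieve contributors at this time"
# that appeared here were GitHub web-page residue, not part of the module.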
# Python 2/3 compatibility
from __future__ import absolute_import, division, print_function | |
import six | |
import os | |
import sys | |
import inspect | |
from builtins import str, open, range, dict | |
import pickle | |
import numpy as np | |
import pandas as pd | |
import pybedtools | |
from pybedtools import BedTool | |
from sklearn.preprocessing import FunctionTransformer | |
from genomelake.extractors import BaseExtractor, FastaExtractor, one_hot_encode_sequence, NUM_SEQ_CHARS | |
from pysam import FastaFile | |
from concise.preprocessing.splines import encodeSplines | |
from concise.utils.position import extract_landmarks, ALL_LANDMARKS | |
from gtfparse import read_gtf | |
from kipoi.metadata import GenomicRanges | |
import linecache | |
from kipoi.data import Dataset | |
import warnings | |
# Locate this source file through the interpreter's frame information and
# remember the absolute directory that contains it.
filename = inspect.getframeinfo(inspect.currentframe()).filename
DATALOADER_DIR = os.path.split(os.path.abspath(filename))[0]
def sign_log_func(x):
    """Compress magnitudes with log10 while keeping the sign.

    Computes ``sign(x) * log10(|x| + 1)`` element-wise.
    """
    magnitude = np.log10(np.abs(x) + 1)
    return np.sign(x) * magnitude
def sign_log_func_inverse(x):
    """Undo the signed log10 compression.

    Computes ``sign(x) * (10 ** |x| - 1)`` element-wise.
    """
    magnitude = np.power(10, np.abs(x)) - 1
    return np.sign(x) * magnitude
class BedToolLinecache(BedTool):
    """Fast BedTool accessor by Ziga Avsec

    Normal BedTools loops through the whole file to get the
    line of interest. Hence the access is O(n). Here the requested line is
    fetched through `linecache` instead.
    """

    def __getitem__(self, idx):
        """Return the interval stored on line `idx` (0-based) of the file.

        Args:
            idx: zero-based integer line index.

        Returns:
            Interval created from the tab-separated fields of that line.

        Raises:
            IndexError: if the file has no such line (negative indices
                included).
        """
        # linecache numbers lines from 1 and returns '' (it never raises)
        # when the requested line does not exist.
        line = linecache.getline(self.fn, idx + 1)
        if not line:
            raise IndexError("interval index {0} out of range for {1}".format(idx, self.fn))
        return pybedtools.create_interval_from_list(line.strip().split("\t"))
class DistanceTransformer:
    """Transforms the raw distances to the appropriate modeling form
    """

    def __init__(self, pos_features, pipeline_obj_path):
        """
        Args:
            pos_features: list of positional features to use
            pipeline_obj_path: path to the serialized pipeline object

        Raises:
            ValueError: if `pos_features` differs from the feature list
                stored in the serialized pipeline.
        """
        self.pos_features = pos_features
        self.pipeline_obj_path = pipeline_obj_path

        # deserialize the pickle file
        # NOTE: unpickling can execute arbitrary code - only load pipeline
        # files that come from a trusted source.
        with open(self.pipeline_obj_path, "rb") as f:
            pipeline_obj = pickle.load(f)
        self.POS_FEATURES = pipeline_obj[0]
        self.minmax_scaler = pipeline_obj[1]
        self.imp = pipeline_obj[2]

        # for simplicity, assume all current pos_features are the
        # same as from before. Compared as lists so that list/tuple/array
        # containers with the same content are accepted, and raised
        # explicitly so the check survives `python -O`.
        if list(self.POS_FEATURES) != list(self.pos_features):
            raise ValueError("pos_features {0!r} do not match the features of the "
                             "serialized pipeline {1!r}".format(list(self.pos_features),
                                                                list(self.POS_FEATURES)))

        self.funct_transform = FunctionTransformer(func=sign_log_func,
                                                   inverse_func=sign_log_func_inverse)

    def transform(self, x):
        """Convert raw distances into spline-encoded features.

        Args:
            x: 2D array of raw distances; column `i` is indexed as feature
                `self.POS_FEATURES[i]`.

        Returns:
            dict mapping "dist_<feature>" to the `encodeSplines` output for
            that feature's column.
        """
        # impute missing values and rescale the distances
        xnew = self.minmax_scaler.transform(self.funct_transform.transform(self.imp.transform(x)))

        # convert distances to spline bases
        dist = {"dist_" + k: encodeSplines(xnew[:, i, np.newaxis], start=0, end=1, warn=False)
                for i, k in enumerate(self.POS_FEATURES)}
        return dist
class DistToClosestLandmarkExtractor(BaseExtractor):
    """Extract distances to the closest genomic landmark

    # Arguments
        gtf_file: Genomic annotation file path (say gencode gtf)
        landmarks: List of landmarks to extract. See `concise.utils.position.extract_landmarks`
        use_strand: Take into account the strand of the intervals
    """
    multiprocessing_safe = True

    def __init__(self, gtf_file, landmarks=ALL_LANDMARKS, use_strand=True, **kwargs):
        super(DistToClosestLandmarkExtractor, self).__init__(gtf_file, **kwargs)
        self._gtf_file = gtf_file
        # column names. Required for concatenating distances into an array.
        # Stored as a copy so that later changes to the caller's list (or to
        # the shared default) cannot alter the output column layout.
        self.columns = list(landmarks)
        self.use_strand = use_strand
        # set index to chromosome and strand - faster access
        self.landmarks = {k: v.set_index(["seqname", "strand"])
                          for k, v in six.iteritems(extract_landmarks(gtf_file, landmarks=landmarks))}

    @staticmethod
    def _find_closest(ldm, interval, use_strand=True):
        """Return the signed distance from the interval midpoint to the nearest landmark.

        # Arguments
            ldm: landmark table indexed by (seqname, strand) with a `position` column
            interval: interval providing `chrom`, `start`, `end` and `strand`
            use_strand: restrict to the interval's strand and flip the sign
                on the "-" strand

        # Raises
            KeyError: if the table holds no landmark for the interval's
                chromosome (or strand).
        """
        # subset the positions to the appropriate chromosome and strand
        ldm_positions = ldm.loc[interval.chrom]
        if use_strand and interval.strand != ".":
            # list indexer: always yields a frame, even when a single
            # landmark matches (a scalar label would return one row as a Series)
            ldm_positions = ldm_positions.loc[[interval.strand]]
        positions = np.atleast_1d(np.asarray(ldm_positions["position"]))

        int_midpoint = (interval.end + interval.start) // 2
        dist = (positions - 1) - int_midpoint  # -1 for 0, 1 indexed positions
        if use_strand and interval.strand == "-":
            dist = -dist
        return dist[np.argmin(np.abs(dist))]

    def _extract(self, intervals, out, **kwargs):
        """Fill `out` with one row per interval and one column per landmark."""
        distances = [[self._find_closest(self.landmarks[ldm_name], interval, self.use_strand)
                      for ldm_name in self.columns]
                     for interval in intervals]
        # explicit reshape keeps the (n, n_landmarks) layout for empty input
        out[:] = np.array(distances, dtype=float).reshape(len(distances), len(self.columns))
        return out

    def _get_output_shape(self, num_intervals, width):
        """Output is (number of intervals, number of landmarks); `width` is unused."""
        return (num_intervals, len(self.columns))
class TxtDataset(Dataset):
    """Dataset exposing a text file that stores one integer per line."""

    def __init__(self, path):
        # All lines are read up front; values are parsed on access.
        with open(path, "r") as handle:
            self.lines = handle.readlines()

    def __len__(self):
        """Number of lines read from the file."""
        return len(self.lines)

    def __getitem__(self, idx):
        """Return line `idx` parsed as an int."""
        raw = self.lines[idx]
        return int(raw.strip())
# -------------------------------------------- | |
class SeqDistDataset(Dataset):
    """Dataset returning the sequence and landmark-distance features of each interval.

    Args:
        intervals_file: file path; tsv file
            Assumes bed-like `chrom start end id score strand` format.
        fasta_file: file path; Genome sequence
        gtf_file: file path; Genome annotation GTF file.
        target_file: file path; path to the targets (optional). Must hold
            as many lines as there are intervals.
        filter_protein_coding: Considering genomic landmarks only for protein coding genes
        position_transformer_file: file path; pickled pipeline loaded by
            `DistanceTransformer`. Required.
        use_linecache: if True, intervals are read with `BedToolLinecache`,
            otherwise with a plain `BedTool`.

    Raises:
        ValueError: if `position_transformer_file` is missing or the number
            of targets differs from the number of intervals.
    """

    SEQ_WIDTH = 101

    def __init__(self, intervals_file, fasta_file, gtf_file,
                 target_file=None,
                 filter_protein_coding=True,
                 position_transformer_file=None,
                 use_linecache=True):
        if sys.version_info[0] != 3:
            warnings.warn("Only Python 3 is supported. You are using Python {0}".format(sys.version_info[0]))

        # fail fast, before the (expensive) GTF parsing
        if position_transformer_file is None:
            raise ValueError("position_transformer_file needs to be specified")

        gtf = read_gtf(gtf_file)

        self.filter_protein_coding = filter_protein_coding
        if self.filter_protein_coding:
            if "gene_type" in gtf:
                gtf = gtf[gtf["gene_type"] == "protein_coding"]
            elif "gene_biotype" in gtf:
                gtf = gtf[gtf["gene_biotype"] == "protein_coding"]
            else:
                warnings.warn("Gtf doesn't have the field 'gene_type' or 'gene_biotype'. Considering genomic landmarks " +
                              "of all genes not just protein_coding.")

        if not np.any(gtf.seqname.str.contains("chr")):
            # `assign` builds a new frame instead of writing into a filtered
            # slice; `astype(str)` keeps the concatenation valid for
            # non-object (e.g. categorical) columns.
            gtf = gtf.assign(seqname="chr" + gtf["seqname"].astype(str))
        self.gtf = gtf

        # intervals
        if use_linecache:
            self.bt = BedToolLinecache(intervals_file)
        else:
            self.bt = BedTool(intervals_file)

        # extractors - created on first item access, see __getitem__
        self.fasta_file = fasta_file
        self.seq_extractor = None
        self.dist_extractor = None

        self.dist_transformer = DistanceTransformer(ALL_LANDMARKS, position_transformer_file)

        # target
        if target_file:
            self.target_dataset = TxtDataset(target_file)
            n_targets, n_intervals = len(self.target_dataset), len(self.bt)
            if n_targets != n_intervals:
                raise ValueError("target_file has {0} entries but intervals_file has {1} intervals".
                                 format(n_targets, n_intervals))
        else:
            self.target_dataset = None

    def __len__(self):
        """Number of intervals in the intervals file."""
        return len(self.bt)

    def __getitem__(self, idx):
        """Return the inputs, optional targets and metadata of interval `idx`.

        Returns:
            dict with keys "inputs" ("seq" plus one "dist_<landmark>" entry
            per landmark), "targets" (only when a target file was given)
            and "metadata" ("ranges").

        Raises:
            ValueError: if the interval is not `SEQ_WIDTH` wide.
        """
        if self.seq_extractor is None:
            # Extractors are built on first access rather than in __init__
            # (presumably so each worker process creates its own - confirm).
            self.seq_extractor = FastaExtractor(self.fasta_file)
            self.dist_extractor = DistToClosestLandmarkExtractor(gtf_file=self.gtf,
                                                                 landmarks=ALL_LANDMARKS)

        interval = self.bt[idx]

        if interval.stop - interval.start != self.SEQ_WIDTH:
            raise ValueError("Expected the interval to be {0} wide. Recieved stop - start = {1}".
                             format(self.SEQ_WIDTH, interval.stop - interval.start))

        out = {}

        # input - sequence
        inputs = {"seq": np.squeeze(self.seq_extractor([interval]), axis=0)}

        # input - distance
        dist_dict = self.dist_transformer.transform(self.dist_extractor([interval]))
        # squeeze the batch axis; `update` instead of `{**a, **b}` keeps the
        # module parseable on interpreters older than Python 3.5
        inputs.update({k: np.squeeze(v, axis=0) for k, v in dist_dict.items()})
        out['inputs'] = inputs

        # targets
        if self.target_dataset is not None:
            out["targets"] = np.array([self.target_dataset[idx]])

        # metadata
        out['metadata'] = {}
        out['metadata']['ranges'] = GenomicRanges.from_interval(interval)

        return out
def test_dataset():
    """Smoke-test `SeqDistDataset` on the example files.

    Builds the dataset, fetches two single items and one batch of 32; any
    exception fails the test.
    """
    # File paths
    intervals_file = "example_files/intervals.bed"
    target_file = "example_files/targets.tsv"
    gtf_file = "example_files/gencode.v24.annotation_chr22.gtf"
    fasta_file = "example_files/hg38_chr22.fa"
    # SeqDistDataset raises ValueError without a position transformer.
    # NOTE(review): the pickle location below is an assumption - confirm
    # where the position transformer file is shipped.
    position_transformer_file = os.path.join(DATALOADER_DIR, "dataloader_files",
                                             "position_transformer.pkl")

    ds = SeqDistDataset(intervals_file, fasta_file, gtf_file,
                        target_file=target_file,
                        position_transformer_file=position_transformer_file)
    ds[0]
    ds[10]
    it = ds.batch_iter(32)
    next(it)