Skip to content

Commit

Permalink
Merge pull request #96 from kipoi/add-xpresso-dataloader
Browse files Browse the repository at this point in the history
Kipoified xpresso dataloader
  • Loading branch information
haimasree committed Jun 22, 2021
2 parents 3a6732b + dc07e0c commit e985c6e
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 50 deletions.
124 changes: 75 additions & 49 deletions kipoiseq/dataloaders/sequence.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
import pandas as pd
import numpy as np
from copy import deepcopy
import pyranges as pr
from kipoi.metadata import GenomicRanges
from kipoi.data import Dataset, kipoi_dataloader
from kipoi_conda.dependencies import Dependencies
from kipoiseq.transforms import ReorderedOneHot
from kipoi.specs import Author
from kipoi_utils.utils import default_kwargs
from kipoiseq.extractors import FastaStringExtractor
from kipoiseq.transforms.functional import resize_interval
from kipoiseq.transforms.functional import resize_interval, one_hot_dna
from kipoiseq.utils import to_scalar, parse_dtype
from kipoiseq.dataclasses import Interval

# general dependencies
# bioconda::genomelake', TODO - add genomelake again once it gets released with pyfaidx to bioconda
deps = Dependencies(conda=['bioconda::pybedtools', 'bioconda::pyfaidx', 'numpy', 'pandas'],
deps = Dependencies(conda=['bioconda::pybedtools', 'bioconda::pyfaidx', 'bioconda::pyranges', 'numpy', 'pandas'],
pip=['kipoiseq'])
package_authors = [Author(name='Ziga Avsec', github='avsecz'),
Author(name='Roman Kreuzhuber', github='krrome')]
# Add Alex here?

# Object exported on import *
__all__ = ['SeqIntervalDl', 'StringSeqIntervalDl', 'BedDataset']
__all__ = ['SeqIntervalDl', 'StringSeqIntervalDl', 'BedDataset', 'AnchoredGTFDl']


class BedDataset(object):
Expand Down Expand Up @@ -383,9 +383,72 @@ def get_output_schema(cls):
output_schema.targets = None

return output_schema


@kipoi_dataloader(override={"dependencies": deps, 'info.authors': Author(name='Alex Karollus', github='Karollus')})
class AnchoredGTFDl(Dataset):

"""
info:
doc: >
Dataloader for a combination of fasta and gtf files. The dataloader extracts fixed length regions
around anchor points. Anchor points are extracted from the gtf based on the anchor parameter.
The sequences corresponding to the region are then extracted from the fasta file and optionally
trnasformed using a function given by the transform parameter.
args:
gtf_file:
doc: Path to a gtf file (str)
example:
url: https://zenodo.org/record/1466102/files/example_files-gencode.v24.annotation_chr22.gtf
md5: c0d1bf7738f6a307b425e4890621e7d9
fasta_file:
doc: Reference genome FASTA file path (str)
example:
url: https://zenodo.org/record/1466102/files/example_files-hg38_chr22.fa
md5: b0f5cdd4f75186f8a4d2e23378c57b5b
num_upstream:
doc: Number of nt by which interval is extended upstream of the anchor point
example: 7000
num_downstream:
doc: Number of nt by which interval is extended downstream of the anchor point
example: 3500
gtf_filter:
doc: >
Allows to filter the gtf before extracting the anchor points. Can be str, callable
or None. If str, it is interpreted as argument to pandas .query(). If callable,
it is interpreted as function that filters a pandas dataframe and returns the
filtered df.
anchor:
doc: >
Defines the anchor points. Can be str or callable. If it is a callable, it is
treated as function that takes a pandas dataframe and returns a modified version
of the dataframe where each row represents one anchor point, the position of
which is stored in the column called anchor_pos. If it is a string, a predefined function
is loaded. Currently available are tss (anchor is the start of a gene), start_codon
(anchor is the start of the start_codon), stop_codon (anchor is the position right after
the stop_codon), polya (anchor is the position right after the end of a gene).
transform:
doc: Callable (or None) to transform the extracted sequence (e.g. one-hot)
interval_attrs:
doc: Metadata to extract from the gtf, e.g. ["gene_id", "Strand"]
use_strand:
doc: True or False
output_schema:
inputs:
name: seq
shape: (None, 4)
special_type: DNAStringSeq
doc: exon sequence with flanking intronic sequence
associated_metadata: ranges
metadata:
gene_id:
type: str
doc: gene id
Strand:
type: str
doc: Strand
ranges:
type: GenomicRanges
doc: ranges that the sequences were extracted
"""
_function_mapping = {
"tss": lambda x: AnchoredGTFDl.anchor_to_feature_start(x, "gene", use_strand=True),
"start_codon": lambda x: AnchoredGTFDl.anchor_to_feature_start(x, "start_codon", use_strand=True),
Expand All @@ -395,50 +458,13 @@ class AnchoredGTFDl(Dataset):

def __init__(self, gtf_file, fasta_file,
num_upstream, num_downstream,
gtf_filter, anchor,
transform,
interval_attrs,
gtf_filter='gene_type == "protein_coding"',
anchor='tss',
transform=one_hot_dna,
interval_attrs=["gene_id", "Strand"],
use_strand=True):
"""
This dataloader allows to extract fixed length sequences
Around pre defined anchor points.
info:
doc: >
Dataloader for a combination of fasta and gtf files. The dataloader extracts fixed length regions
around anchor points. Anchor points are extracted from the gtf based on the anchor parameter.
The sequences corresponding to the region are then extracted from the fasta file and optionally
trnasformed using a function given by the transform parameter.
args:
gtf_file:
doc: Path to a gtf file (str)
fasta_file:
doc: Reference genome FASTA file path (str)
num_upstream:
doc: Number of nt by which interval is extended upstream of the anchor point
num_downstream:
doc: Number of nt by which interval is extended downstream of the anchor point
gtf_filter:
doc: >
Allows to filter the gtf before extracting the anchor points. Can be str, callable
or None. If str, it is interpreted as argument to pandas .query(). If callable,
it is interpreted as function that filters a pandas dataframe and returns the
filtered df.
anchor:
doc: >
Defines the anchor points. Can be str or callable. If it is a callable, it is
treated as function that takes a pandas dataframe and returns a modified version
of the dataframe where each row represents one anchor point, the position of
which is stored in the column called anchor_pos. If it is a string, a predefined function
is loaded. Currently available are tss (anchor is the start of a gene), start_codon
(anchor is the start of the start_codon), stop_codon (anchor is the position right after
the stop_codon), polya (anchor is the position right after the end of a gene).
transform:
doc: Callable (or None) to transform the extracted sequence (e.g. one-hot)
interval_attrs:
doc: Metadata to extract from the gtf, e.g. ["gene_id", "Strand"]
"""

import pyranges as pr

# Read and filter gtf
gtf = pr.read_gtf(gtf_file).df
if gtf_filter:
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# "h5py",
"gffutils",
"kipoi-utils>=0.1.1",
"kipoi-conda>=0.1.0"
"kipoi-conda>=0.1.0",
"pyranges"
]

test_requirements = [
Expand Down

0 comments on commit e985c6e

Please sign in to comment.