Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kipoified xpresso dataloader #96

Merged
merged 8 commits into from
Jun 22, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
124 changes: 74 additions & 50 deletions kipoiseq/dataloaders/sequence.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
import pandas as pd
import numpy as np
from copy import deepcopy
import pyranges as pr
from kipoi.metadata import GenomicRanges
from kipoi.data import Dataset, kipoi_dataloader
from kipoi_conda.dependencies import Dependencies
from kipoiseq.transforms import ReorderedOneHot
from kipoi.specs import Author
from kipoi_utils.utils import default_kwargs
from kipoiseq.extractors import FastaStringExtractor
from kipoiseq.transforms.functional import resize_interval
from kipoiseq.transforms.functional import resize_interval, one_hot_dna
from kipoiseq.utils import to_scalar, parse_dtype
from kipoiseq.dataclasses import Interval

# general dependencies
# bioconda::genomelake', TODO - add genomelake again once it gets released with pyfaidx to bioconda
deps = Dependencies(conda=['bioconda::pybedtools', 'bioconda::pyfaidx', 'numpy', 'pandas'],
deps = Dependencies(conda=['bioconda::pybedtools', 'bioconda::pyfaidx', 'bioconda::pyranges', 'numpy', 'pandas'],
pip=['kipoiseq'])
package_authors = [Author(name='Ziga Avsec', github='avsecz'),
Author(name='Roman Kreuzhuber', github='krrome')]
# Add Alex here?

# Object exported on import *
__all__ = ['SeqIntervalDl', 'StringSeqIntervalDl', 'BedDataset']
__all__ = ['SeqIntervalDl', 'StringSeqIntervalDl', 'BedDataset', 'AnchoredGTFDl']


class BedDataset(object):
Expand Down Expand Up @@ -383,9 +383,70 @@ def get_output_schema(cls):
output_schema.targets = None

return output_schema


@kipoi_dataloader(override={"dependencies": deps, 'info.authors': Author(name='Alex Karollus', github='Karollus')})
class AnchoredGTFDl(Dataset):

"""
info:
doc: >
Dataloader for a combination of fasta and gtf files. The dataloader extracts fixed length regions
around anchor points. Anchor points are extracted from the gtf based on the anchor parameter.
The sequences corresponding to the region are then extracted from the fasta file and optionally
transformed using a function given by the transform parameter.
args:
gtf_file:
doc: Path to a gtf file (str)
example:
url: https://zenodo.org/record/1466102/files/example_files-gencode.v24.annotation_chr22.gtf
md5: c0d1bf7738f6a307b425e4890621e7d9
fasta_file:
doc: Reference genome FASTA file path (str)
example:
url: https://zenodo.org/record/1466102/files/example_files-hg38_chr22.fa
md5: b0f5cdd4f75186f8a4d2e23378c57b5b
num_upstream:
doc: Number of nt by which interval is extended upstream of the anchor point
num_downstream:
doc: Number of nt by which interval is extended downstream of the anchor point
gtf_filter:
doc: >
Allows filtering the gtf before extracting the anchor points. Can be str, callable
or None. If str, it is interpreted as argument to pandas .query(). If callable,
it is interpreted as function that filters a pandas dataframe and returns the
filtered df.
anchor:
doc: >
Defines the anchor points. Can be str or callable. If it is a callable, it is
treated as function that takes a pandas dataframe and returns a modified version
of the dataframe where each row represents one anchor point, the position of
which is stored in the column called anchor_pos. If it is a string, a predefined function
is loaded. Currently available are tss (anchor is the start of a gene), start_codon
(anchor is the start of the start_codon), stop_codon (anchor is the position right after
the stop_codon), polya (anchor is the position right after the end of a gene).
transform:
doc: Callable (or None) to transform the extracted sequence (e.g. one-hot)
interval_attrs:
doc: Metadata to extract from the gtf, e.g. ["gene_id", "Strand"]
use_strand:
doc: True or False
output_schema:
inputs:
name: seq
shape: (None, 4)
special_type: DNAStringSeq
doc: sequence of the fixed length region extracted around the anchor point
associated_metadata: ranges
metadata:
gene_id:
type: str
doc: gene id
Strand:
type: str
doc: Strand
ranges:
type: GenomicRanges
doc: ranges from which the sequences were extracted
"""
_function_mapping = {
"tss": lambda x: AnchoredGTFDl.anchor_to_feature_start(x, "gene", use_strand=True),
"start_codon": lambda x: AnchoredGTFDl.anchor_to_feature_start(x, "start_codon", use_strand=True),
Expand All @@ -394,51 +455,14 @@ class AnchoredGTFDl(Dataset):
}

def __init__(self, gtf_file, fasta_file,
num_upstream, num_downstream,
gtf_filter, anchor,
transform,
interval_attrs,
num_upstream=7000, num_downstream=3500,
haimasree marked this conversation as resolved.
Show resolved Hide resolved
gtf_filter='gene_type == "protein_coding"',
anchor='tss',
transform=one_hot_dna,
interval_attrs=["gene_id", "Strand"],
use_strand=True):
"""
This dataloader extracts fixed-length sequences
around predefined anchor points.

info:
doc: >
Dataloader for a combination of fasta and gtf files. The dataloader extracts fixed length regions
around anchor points. Anchor points are extracted from the gtf based on the anchor parameter.
The sequences corresponding to the region are then extracted from the fasta file and optionally
transformed using a function given by the transform parameter.
args:
gtf_file:
doc: Path to a gtf file (str)
fasta_file:
doc: Reference genome FASTA file path (str)
num_upstream:
doc: Number of nt by which interval is extended upstream of the anchor point
num_downstream:
doc: Number of nt by which interval is extended downstream of the anchor point
gtf_filter:
doc: >
Allows to filter the gtf before extracting the anchor points. Can be str, callable
or None. If str, it is interpreted as argument to pandas .query(). If callable,
it is interpreted as function that filters a pandas dataframe and returns the
filtered df.
anchor:
doc: >
Defines the anchor points. Can be str or callable. If it is a callable, it is
treated as function that takes a pandas dataframe and returns a modified version
of the dataframe where each row represents one anchor point, the position of
which is stored in the column called anchor_pos. If it is a string, a predefined function
is loaded. Currently available are tss (anchor is the start of a gene), start_codon
(anchor is the start of the start_codon), stop_codon (anchor is the position right after
the stop_codon), polya (anchor is the position right after the end of a gene).
transform:
doc: Callable (or None) to transform the extracted sequence (e.g. one-hot)
interval_attrs:
doc: Metadata to extract from the gtf, e.g. ["gene_id", "Strand"]
"""

import pyranges as pr

# Read and filter gtf
gtf = pr.read_gtf(gtf_file).df
if gtf_filter:
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# "h5py",
"gffutils",
"kipoi-utils>=0.1.1",
"kipoi-conda>=0.1.0"
"kipoi-conda>=0.1.0",
"pyranges"
]

test_requirements = [
Expand Down