# Prototype for a more generic dataloader interface

This notebook walks through the design of a generic single-sequence dataloader, with some notes on my understanding of the kipoiseq codebase (I make no guarantees as to the correctness of the latter).

In [1]:
import kipoiseq
import kipoi
import os
import urllib.request
import pandas as pd
import pyranges
from collections import defaultdict

from kipoiseq.dataloaders import *
from kipoiseq.extractors import GenericMultiIntervalSeqExtractor, BaseMultiIntervalFetcher, \
    GTFMultiIntervalFetcher, BaseExtractor, FastaStringExtractor, SingleVariantMatcher, GenericSingleVariantMultiIntervalVCFSeqExtractor, \
    MultiSampleVCF

Get some data to work with

In [116]:
# Make the ExampleFiles directory if it does not exist
os.makedirs("ExampleFiles", exist_ok=True)
    
# Download GTF
urllib.request.urlretrieve("https://zenodo.org/record/1466102/files/example_files-gencode.v24.annotation_chr22.gtf?download=1", 'ExampleFiles/chrom22.gtf')
# Download fasta
urllib.request.urlretrieve("https://zenodo.org/record/1466102/files/example_files-hg38_chr22.fa?download=1", 'ExampleFiles/chrom22.fa')
# Download bed
urllib.request.urlretrieve("https://raw.githubusercontent.com/kipoi/kipoiseq/master/tests/data/intervals_51bp.tsv", "ExampleFiles/example_bed.bed")
# Download fasta that goes along with it
urllib.request.urlretrieve("https://raw.githubusercontent.com/kipoi/kipoiseq/master/tests/data/hg38_chr22_32000000_32300000.fa", "ExampleFiles/example_bed.fasta")

('ExampleFiles/example_bed.fasta', <http.client.HTTPMessage at 0x150e6c9b40f0>)

## Introducing the GenericSingleSeqDataloader in prototype.py

The GenericSingleSeqDataloader extends SampleIterator from kipoi and is composed of three main components:

* An interval fetcher (of type BaseMultiIntervalFetcher from kipoiseq.extractors.multi_interval). This is a generator object that provides intervals defining the location (in some reference) of the sequences of interest. For a given key (e.g. a transcript_id), BaseMultiIntervalFetcher provides a list of intervals (kipoiseq.datatypes.Interval) that correspond to that key (e.g. all exons of that transcript). As a generator, it simply iterates through the keys and returns each key together with its intervals.
    * Currently, the main working implementation of BaseMultiIntervalFetcher is GTFMultiIntervalFetcher (from kipoiseq.extractors.gtf). Despite the name, it is not tied to GTF files: it only requires a pyranges-like pandas dataframe (i.e. a dataframe with at least the columns Chromosome, Start, End, Strand), which can easily be produced from other sources. The key is the dataframe index. If the index is not unique, it returns all intervals that share the same index value (I am not sure this is intended functionality in pandas, but it seems to work well).
    * One could probably also design a fetcher that gets intervals just in time, from a database or similar. This might be better for whole-genome analyses.
* A reference sequence source, usually a FastaExtractor, which given an interval provides the corresponding reference sequence
* A sequence transformer (e.g. OneHot), which given a string sequence provides a transformed sequence as required by a model
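
On the non-unique index point above: selecting a duplicated label with `.loc` is ordinary label-based selection in pandas and returns every matching row. The following is a small sketch with a made-up exon table (the transcript ids `tx1` and `tx2` and the helper `rows_for_key` are hypothetical, not taken from kipoiseq):

```python
import pandas as pd

# Toy exon table in the pyranges column convention.
# The index holds hypothetical transcript ids; "tx1" appears twice.
exons = pd.DataFrame(
    {
        "Chromosome": ["chr22", "chr22", "chr22"],
        "Start": [100, 300, 500],
        "End": [200, 400, 600],
        "Strand": ["+", "+", "-"],
    },
    index=["tx1", "tx1", "tx2"],
)


def rows_for_key(df: pd.DataFrame, key: str) -> pd.DataFrame:
    """Return all rows sharing an index label, always as a DataFrame."""
    # Passing a list keeps the result a DataFrame even if only one row matches.
    return df.loc[[key]]


for key in exons.index.unique():
    print(key, len(rows_for_key(exons, key)))
```

Using a list (`df.loc[[key]]`) rather than a scalar (`df.loc[key]`) avoids the case where a key with a single row comes back as a Series instead of a DataFrame.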

The operation of this class is quite straightforward: it gets keys and intervals from the fetcher, extracts the sequence using the reference sequence source, transforms it, and then returns it in a dict together with metadata.
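
The flow described above can be sketched with toy stand-ins. None of the names below (ToyInterval, ToyFetcher, ToyReference, iterate_samples) come from kipoiseq or prototype.py; they are hypothetical, and an in-memory dict replaces a real fasta file. The actual GenericSingleSeqDataloader may differ in its details.

```python
from typing import Callable, Dict, Iterator, List, NamedTuple, Tuple


class ToyInterval(NamedTuple):
    chrom: str
    start: int  # 0-based, inclusive
    end: int    # 0-based, exclusive
    strand: str = "+"


def reverse_complement(seq: str) -> str:
    return seq.translate(str.maketrans("ACGTacgt", "TGCAtgca"))[::-1]


class ToyFetcher:
    """Maps each key to its list of intervals and iterates over (key, intervals)."""

    def __init__(self, intervals_by_key: Dict[str, List[ToyInterval]]):
        self.intervals_by_key = intervals_by_key

    def __iter__(self) -> Iterator[Tuple[str, List[ToyInterval]]]:
        return iter(self.intervals_by_key.items())


class ToyReference:
    """Stands in for a fasta-backed extractor; sequences live in a dict."""

    def __init__(self, genome: Dict[str, str]):
        self.genome = genome

    def extract(self, interval: ToyInterval) -> str:
        seq = self.genome[interval.chrom][interval.start:interval.end]
        return reverse_complement(seq) if interval.strand == "-" else seq


def iterate_samples(fetcher, reference, transform: Callable[[str], object]):
    """Fetch intervals, extract their sequence, transform it, yield a dict."""
    for key, intervals in fetcher:
        # Assumes the intervals are already ordered 5' to 3' for this key.
        seq = "".join(reference.extract(iv) for iv in intervals)
        yield {
            "inputs": {"seq": transform(seq)},
            "metadata": {"key": key, "ranges": intervals},
        }


genome = {"chrT": "AACCGGTTAACC"}
fetcher = ToyFetcher({
    "a": [ToyInterval("chrT", 0, 4)],
    "b": [ToyInterval("chrT", 4, 8, "-")],
})
samples = list(iterate_samples(fetcher, ToyReference(genome), str.lower))
```

Swapping any one of the three pieces (fetcher, reference, transform) leaves the loop itself unchanged, which is the point of the design.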

With this, I believe it is very easy to build new dataloaders for most standard use cases. All one needs to do is:
* Define a way to import and preprocess interval data and supply them to the fetcher. In most cases, this will be as easy as reading in a dataframe with pyranges (from gtf, bed, tsv, ...), doing some pandas operations on it, and then calling the constructor of the GTFMultiIntervalFetcher (which could maybe be renamed to RangesDataFrameFetcher or something).
* Define a way to load reference sequence data (in 99% of cases this will be a fasta file supplied to a FastaStringExtractor).
* Define some transformations, if necessary.


## Building a gtf based TSS dataloader for Xpresso

We can use this template to design a TSS dataloader for Xpresso:

The main thing we need to write is a way to extract TSS locations from a pyranges-like dataframe:

In [2]:
class TSSFinder:
    """
    Imputes approximate TSS location as 5' end of the gene annotation
    """
    
    def __init__(
        self,
        n_upstream,
        n_downstream
    ):
        self.n_upstream = n_upstream
        self.n_downstream = n_downstream
    
    def __call__(
        self,
        region_df : pd.DataFrame
    ) -> pd.DataFrame:
        region_df = region_df.query('Feature == "gene" and gene_type == "protein_coding"')
        anchor = ((region_df.Start * (region_df.Strand == "+")) 
                  + (region_df.End * (region_df.Strand == "-")))
        region_df["Start"] = (anchor + 
                (-self.n_upstream * (region_df.Strand == "+")) + 
                (-self.n_downstream * (region_df.Strand == "-")))
        region_df["End"] = (anchor + 
                (self.n_downstream * (region_df.Strand == "+")) + 
                (self.n_upstream * (region_df.Strand == "-")))
        return region_df
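
To make the strand arithmetic in TSSFinder concrete, the same window computation can be written for a single gene in plain Python. This is only a sketch: `tss_window` is a hypothetical helper, not part of kipoiseq or prototype.py, and the coordinates are made up.

```python
from typing import Tuple


def tss_window(
    start: int, end: int, strand: str, n_upstream: int, n_downstream: int
) -> Tuple[int, int]:
    """Return the (Start, End) window around the 5' end of a gene."""
    if strand == "+":
        anchor = start  # 5' end of a plus-strand gene is its Start
        return anchor - n_upstream, anchor + n_downstream
    if strand == "-":
        anchor = end    # 5' end of a minus-strand gene is its End
        return anchor - n_downstream, anchor + n_upstream
    raise ValueError(f"unexpected strand: {strand!r}")


# Same gene coordinates on both strands, with the window sizes used below
plus_window = tss_window(10_000, 20_000, "+", 7000, 3500)
minus_window = tss_window(10_000, 20_000, "-", 7000, 3500)
print(plus_window, minus_window)
```

Both windows have length n_upstream + n_downstream; on the minus strand the upstream part lies at higher coordinates. Note that nothing here (or in TSSFinder) clips the window, so a gene close to the chromosome start can yield a negative Start.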

Once we have that, we can easily build a TSS Dataloader simply by extending our GenericSingleSeqDataloader and building the fetcher, extractor and transform:

In [3]:
from kipoiseq.dataloaders.prototype import GenericSingleSeqDataloader, IdentityTransform
class TSSDataloader(GenericSingleSeqDataloader):
    
    def __init__(
        self,
        gtf_file : str,
        fasta_file : str,
        n_upstream : int,
        n_downstream : int,
        interval_attrs = ["gene_id", "gene_type"]
    ):
        self.gtf_file = gtf_file
        self.fasta_file = fasta_file
        self.n_upstream = n_upstream
        self.n_downstream = n_downstream
        self.use_strand = True
        
        # Source interval data from gtf
        df = pyranges.read_gtf(self.gtf_file).df
        # Subset to areas of interest
        df = TSSFinder(
            self.n_upstream,
            self.n_downstream
        )(df)
        # Build the interval fetcher
        interval_source = GTFMultiIntervalFetcher(
            df, 
            keep_attrs=interval_attrs
        )
        # Source reference sequence from fasta
        reference_sequence_source = FastaStringExtractor(
            fasta_file,
            use_strand=self.use_strand
        )
        # Provide sequence transformer
        sequence_transformer = IdentityTransform()
        # Pass all to superclass
        super().__init__(
            interval_source,
            reference_sequence_source,
            sequence_transformer,
            interval_attrs
        )

Let's build it and test it:

In [4]:
tss = TSSDataloader("ExampleFiles/chrom22.gtf",
                   "ExampleFiles/chrom22.fa",
                   7000,
                   3500)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [5]:
gtf_df = pyranges.read_gtf("ExampleFiles/chrom22.gtf").df

# Check that we faithfully recover the original TSS
for sample in tss:
    gene_id = sample["metadata"]["gene_id"]
    gene_strand = sample["metadata"]['ranges'].strand
    gtf_row = gtf_df.query('gene_id == @gene_id')
    if len(gtf_row) == 0:
        print(gene_id)
    if gene_strand == "+":
        implied_TSS = sample["metadata"]['ranges'].start + 7000
        assert(gtf_row.iloc[0].Start == implied_TSS)
    if gene_strand == "-":
        implied_TSS = sample["metadata"]['ranges'].end - 7000
        assert(gtf_row.iloc[0].End == implied_TSS)

## Extracting PolyA sequences

We could also extract PolyA sites:

In [6]:
class PolyAFinder:
    """
    Imputes approximate polyA location as 3' end of the transcript annotation
    """

    def __init__(
        self,
        n_upstream,
        n_downstream
    ):
        self.n_upstream = n_upstream
        self.n_downstream = n_downstream
    
    def __call__(
        self,
        region_df : pd.DataFrame
    ) -> pd.DataFrame:
        region_df = region_df.query('Feature == "transcript" and transcript_type == "protein_coding"')
        anchor = ((region_df.End * (region_df.Strand == "+")) 
                + (region_df.Start * (region_df.Strand == "-")))
        region_df["Start"] = (anchor + 
                (-self.n_upstream * (region_df.Strand == "+")) + 
                (-self.n_downstream * (region_df.Strand == "-")))
        region_df["End"] = (anchor + 
                (self.n_downstream * (region_df.Strand == "+")) + 
                (self.n_upstream * (region_df.Strand == "-")))
        return region_df

All we need to do to achieve this, is to replace the TSSFinder with a PolyAFinder:

In [7]:
class PolyADataloader(GenericSingleSeqDataloader):
    
    def __init__(
        self,
        gtf_file : str,
        fasta_file : str,
        n_upstream : int,
        n_downstream : int,
        interval_attrs = ["gene_id", "transcript_id", "transcript_type"]
    ):
        self.gtf_file = gtf_file
        self.fasta_file = fasta_file
        self.n_upstream = n_upstream
        self.n_downstream = n_downstream
        self.use_strand = True
        
        # Source interval data from gtf
        df = pyranges.read_gtf(self.gtf_file).df
        # Subset to areas of interest
        df = PolyAFinder(
            self.n_upstream,
            self.n_downstream
        )(df)
        # Build the interval fetcher
        interval_source = GTFMultiIntervalFetcher(
            df, 
            keep_attrs=interval_attrs
        )
        # Source reference sequence from fasta
        reference_sequence_source = FastaStringExtractor(
            fasta_file,
            use_strand=self.use_strand
        )
        # Provide sequence transformer
        sequence_transformer = IdentityTransform()
        # Pass all to superclass
        super().__init__(
            interval_source,
            reference_sequence_source,
            sequence_transformer,
            interval_attrs
        )

In [8]:
polyA = PolyADataloader("ExampleFiles/chrom22.gtf",
                   "ExampleFiles/chrom22.fa",
                   102,
                   103)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [9]:
next(polyA)

{'inputs': {'seq': 'TTTCTGGCATTGTCTGCCCAGCTGCTCCAAGCCAGACTGATGAAGGAGGAGTCCCCAGTGGTGAGCTGGAGGTTGGAGCCTGAAGATGGCACAGCTCTGTGATTCATCTTCTGCGGTTGTGGCAGCCACGGTGATGGAGACGGCAGCTCAACAGGAGCAATAGGAGGGTACCCATGGAGGCCAAGTGGTAGGATCCTTGGAGGGT'},
 'metadata': {'gene_id': 'ENSG00000279973.1',
  'transcript_id': 'ENST00000624155.1',
  'transcript_type': 'protein_coding',
  'ranges': GenomicRanges(chr='chr22', start=11067987, end=11068192, id=1, strand='+')}}

## Sourcing data from a bed file

Say we already have a bed file defining the areas of interest. We can easily build a dataloader for this.

In [10]:
class BEDLoader:
    """
    Class that loads a bed as pyranges-like pandas dataframe
    """
    
    def __init__(
        self,
        bed_path : str
    ):
        self.bed_path = bed_path
    
    def load_df(self) -> pd.DataFrame:
        df = pyranges.read_bed(self.bed_path).df
        if "Strand" not in df.keys():
            df["Strand"] = "*"
        return df

In [11]:
class ChrRename:
    
    def __call__(
        self,
        region_df
    ):
        # regex=True is needed for the ^ anchor (pandas >= 2.0 defaults to literal matching)
        region_df["Chromosome"] = region_df["Chromosome"].str.replace("^chr", "", regex=True)
        return region_df

All we need to do to achieve this is to use the BEDLoader rather than reading a GTF:

In [12]:
class BedDataloader(GenericSingleSeqDataloader):
    
    def __init__(
        self,
        bed_file : str,
        fasta_file : str,
        use_strand : bool,
        interval_attrs = ["Name"]
    ):
        self.bed_file = bed_file
        self.fasta_file = fasta_file
        self.use_strand = use_strand
        
        # Source interval data from bed
        df = BEDLoader(self.bed_file).load_df()
        #df = ChrRename()(df)
        # Build the interval fetcher iterator
        interval_source = GTFMultiIntervalFetcher(
            df, 
            keep_attrs=interval_attrs
        )
        # Source reference sequence from fasta
        reference_sequence_source = FastaStringExtractor(
            fasta_file,
            use_strand=self.use_strand
        )
        # Provide sequence transformer
        sequence_transformer = IdentityTransform()
        # Pass all to superclass
        super().__init__(
            interval_source,
            reference_sequence_source,
            sequence_transformer,
            interval_attrs
        )

In [13]:
bed = BedDataloader("ExampleFiles/example_bed.bed",
                   "ExampleFiles/example_bed.fasta",
                   use_strand = False)

Test equivalence to the old StringSeqIntervalDl dataloader:

In [14]:
bed_old = StringSeqIntervalDl("ExampleFiles/example_bed.bed",
                   "ExampleFiles/example_bed.fasta")

seq_new = [x["inputs"]["seq"] for x in bed]
seq_old = [x["inputs"] for x in bed_old]
# all() is needed: asserting a bare generator expression always passes
assert all(x == y for x, y in zip(seq_new, seq_old))

## Extracting all CDS sequences

Exploiting this design, we can very easily also make a dataloader that works on spliced sequences, e.g. coding sequences
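
As a reminder of what "spliced" means in terms of coordinates, here is a minimal sketch on a toy chromosome. The helper `spliced_sequence` and the exon coordinates are hypothetical; kipoiseq's own extractors may order and join intervals differently.

```python
from typing import List, Tuple

_COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")


def reverse_complement(seq: str) -> str:
    return seq.translate(_COMPLEMENT)[::-1]


def spliced_sequence(
    chrom_seq: str, exons: List[Tuple[int, int]], strand: str
) -> str:
    """Join exon sequences (0-based, half-open) into one spliced sequence."""
    # Concatenate exons in genomic order ...
    joined = "".join(chrom_seq[start:end] for start, end in sorted(exons))
    # ... then flip once for minus-strand transcripts.
    return reverse_complement(joined) if strand == "-" else joined


chrom = "AAACCCGGGTTT"
exons = [(6, 9), (0, 3)]  # deliberately unsorted
plus = spliced_sequence(chrom, exons, "+")
minus = spliced_sequence(chrom, exons, "-")
print(plus, minus)
```

Reverse-complementing the joined sequence once is equivalent to reverse-complementing each exon and joining them in descending genomic order.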

In [15]:
class CDSFinder:
    """
    Extracts CDS
    """
    
    def __call__(
        self,
        region_df : pd.DataFrame
    ) -> pd.DataFrame:
        region_df = region_df.query('Feature == "CDS" and transcript_type == "protein_coding"')
        region_df.set_index("transcript_id", inplace=True)
        return region_df

In [16]:
class CDSDataloader(GenericSingleSeqDataloader):
    
    def __init__(
        self,
        gtf_file : str,
        fasta_file : str,
        incl_chr = ["chr1"],
        interval_attrs = ["gene_id", "transcript_type"]
    ):
        self.gtf_file = gtf_file
        self.fasta_file = fasta_file
        self.use_strand = True
        
        # Source interval data from gtf
        df = (pyranges.read_gtf(self.gtf_file)
              .df
              .query('Chromosome in @incl_chr')
             )
        # Subset to areas of interest
        df = CDSFinder()(df)
        # Build the interval fetcher iterator
        interval_source = GTFMultiIntervalFetcher(
            df, 
            keep_attrs=interval_attrs
        )
        # Source reference sequence from fasta
        reference_sequence_source = FastaStringExtractor(
            fasta_file,
            use_strand=self.use_strand
        )
        # Provide sequence transformer
        sequence_transformer = IdentityTransform()
        # Pass all to superclass
        super().__init__(
            interval_source,
            reference_sequence_source,
            sequence_transformer,
            interval_attrs
        )

Get some test data (big files)

In [25]:
# ground truth
urllib.request.urlretrieve("ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_mouse/release_M26/gencode.vM26.pc_transcripts.fa.gz", 
                           "ExampleFiles/mouse_transcripts.fa.gz")
# gtf
urllib.request.urlretrieve("ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_mouse/release_M26/gencode.vM26.basic.annotation.gtf.gz", 
                           "ExampleFiles/mouse.gtf.gz")
# fasta
urllib.request.urlretrieve("ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_mouse/release_M26/GRCm39.primary_assembly.genome.fa.gz", 
                           "ExampleFiles/mouse.fa.gz")

('ExampleFiles/mouse.fa.gz', <email.message.Message at 0x2b21fe7615c0>)

We have to unzip the fasta, because the kipoiseq fasta extractor is **very** slow on gzipped files

In [26]:
! gunzip ExampleFiles/mouse.fa.gz

PyRanges loading of big gtf is also not exactly lightning fast, but it offers a lot of flexibility in return

In [17]:
cds = CDSDataloader("ExampleFiles/mouse.gtf.gz",
                   "ExampleFiles/mouse.fa")

In [18]:
cds_seq = {x["metadata"]["ranges"].id:x["inputs"]["seq"] for x in cds}

We compare to the ground truth to see if it works:

In [19]:
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO
import gzip

with gzip.open("ExampleFiles/mouse_transcripts.fa.gz", "rt") as handle:
    for record in SeqIO.parse(handle, "fasta"):
        record_id = record.id.split("|")[0] 
        if record_id in cds_seq:
            our_seq = cds_seq[record_id]
            # We extract the ground truth CDS position
            for x in record.id.split("|"):
                if x.startswith("CDS"):
                    cds_loc = x.split(":")[1].split("-")
            # For 3' end incomplete transcripts, we get the whole cds
            if str(record.seq)[int(cds_loc[1]) - 3:int(cds_loc[1])] not in ["TAA", "TGA", "TAG"]:
                true_seq = str(record.seq)[int(cds_loc[0]) - 1:int(cds_loc[1])]
            else: # For 3' end complete transcripts, we need to exclude the stop codon
                true_seq = str(record.seq)[int(cds_loc[0]) - 1:int(cds_loc[1]) - 3]
            try:
                assert(our_seq == true_seq)
            except Exception:
                print(record_id)
                print(our_seq)
                print(str(record.seq))
                print(true_seq)
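
The header parsing in the cell above can be isolated into two small helpers, which makes the coordinate convention easier to check. This is a sketch: `parse_cds_range` and `cds_without_stop` are hypothetical names, and the header and sequence below are invented to mimic the pipe-separated layout with a `CDS:start-end` field (1-based, inclusive) that the cell relies on.

```python
from typing import Optional, Tuple

STOP_CODONS = ("TAA", "TAG", "TGA")


def parse_cds_range(header: str) -> Optional[Tuple[int, int]]:
    """Return the (start, end) of the 'CDS:start-end' field, or None if absent."""
    for field in header.split("|"):
        if field.startswith("CDS:"):
            start, end = field[len("CDS:"):].split("-")
            return int(start), int(end)
    return None


def cds_without_stop(transcript_seq: str, cds_range: Tuple[int, int]) -> str:
    """Slice the CDS out of a transcript and drop a trailing stop codon, if any."""
    start, end = cds_range
    cds = transcript_seq[start - 1:end]  # convert 1-based inclusive to a slice
    if cds[-3:] in STOP_CODONS:
        return cds[:-3]
    return cds


# Invented record: 3 nt 5' UTR, 9 nt CDS (including stop), 3 nt 3' UTR
header = "TX1.1|GENE1.1|-|-|Tx-201|Gene|15|UTR5:1-3|CDS:4-12|UTR3:13-15|"
seq = "GGGATGAAATAACCC"
cds_range = parse_cds_range(header)
print(cds_range, cds_without_stop(seq, cds_range))
```

Returning None when no CDS field is present makes a missing annotation explicit, whereas in the loop above `cds_loc` would silently keep the value from the previous record.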

# Variants

The biggest value of kipoiseq lies in the classes that handle extraction and insertion of variants. These provide many useful functions, but they are difficult to use and maintain because:
* They are not easy to understand, at least for me, since there is a large number of objects (matchers, extractors, interval_queries etc.) and it is not obvious which components are needed.
* The documentation is quite limited.
* They are all designed to work on VCF files. Especially for large applications (gnomAD etc.), other sources of variants may be preferable, such as hail, databases, etc.

In [93]:
! cat ExampleFiles/chrom22.fa | sed 's/>chr/>/g' > ExampleFiles/chrom22_nochr.fa

In [None]:
# Download vcf
urllib.request.urlretrieve("http://ftp.1000genomes.ebi.ac.uk/vol1/ftp/pilot_data/paper_data_sets/a_map_of_human_variation/low_coverage/snps/CEU.low_coverage.2010_09.genotypes.vcf.gz", 'ExampleFiles/CEU.low_coverage.2010_09.genotypes.vcf.gz')

In [None]:
! tabix ExampleFiles/CEU.low_coverage.2010_09.genotypes.vcf.gz

## An alternative solution

Just as we can break down the single sequence dataloading task into the three main components of fetching intervals, getting the ref sequence and then transforming, we can also break down the variant dataloading into three main steps:
* Fetching variants belonging to a specific interval. For this we design a VariantFetcher class.
* Defining a strategy for inserting multiple variants (all at the same time, all individually, one alt-seq per chromosome for phased data etc.). This is the ugliest object in the current design, but I am not sure how to get around it. Its interface is: given a list of intervals and corresponding variants, yield a template (e.g. a list of intervals with the desired variants) to build exactly one alternative sequence.
* Inserting the variants into a reference sequence, respecting the strategy above. For this the existing VariantSeqExtractor works very well. For the time being, I have put a wrapper class around it, because the VariantSeqExtractor has two parameters (anchor and fixed_len) which it only gets at function call time rather than init time. If anything depends on setting these parameters dynamically, one could adapt the design to have an anchoring strategy as well, but for the time being I am not sure whether this is ever the case.
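
The strategy interface described in the second bullet can be illustrated with a toy generator that yields one template per variant. This is a sketch under assumptions: ToyInterval, ToyVariant, Template and `single_variant_templates` are hypothetical names, and the code is not the SingleVariantStrategy from prototype.py.

```python
from typing import Iterator, List, NamedTuple, Tuple


class ToyInterval(NamedTuple):
    chrom: str
    start: int
    end: int


class ToyVariant(NamedTuple):
    chrom: str
    pos: int
    ref: str
    alt: str


# One entry per interval, paired with the variants to insert into it
Template = List[Tuple[ToyInterval, List[ToyVariant]]]


def single_variant_templates(intervals_with_variants: Template) -> Iterator[Template]:
    """Yield one template per variant; each template carries exactly one variant."""
    for i, (_, variants) in enumerate(intervals_with_variants):
        for variant in variants:
            # Keep every interval, but attach the variant only to interval i
            yield [
                (interval, [variant] if j == i else [])
                for j, (interval, _) in enumerate(intervals_with_variants)
            ]


exon1 = ToyInterval("22", 100, 200)
exon2 = ToyInterval("22", 300, 400)
matched = [
    (exon1, [ToyVariant("22", 110, "A", "G"), ToyVariant("22", 150, "C", "T")]),
    (exon2, [ToyVariant("22", 350, "G", "A")]),
]
templates = list(single_variant_templates(matched))
print(len(templates))
```

An "all variants at once" strategy would instead yield the input list a single time; only this generator changes, while fetching and insertion stay the same.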

In prototype.py, I have created a prototype for a GenericVariantDataloader, which takes these components to load variants. It is very similar in design to the GenericSingleSeqDataloader, but with the additional steps detailed above. The code is commented to explain the workflow.

I think this template is easier to use and extend than the existing dataloaders. The existing code has a large hierarchy of specialized, coupled objects. For example, there is a SingleVariantUTRDataLoader, which has a GenericSingleVariantMultiIntervalVCFSeqExtractor, which is a BaseMultiIntervalVCFSeqExtractor (which in turn is a GenericMultiIntervalSeqExtractor, which in turn is a BaseMultiIntervalSeqExtractor). It has an interval_fetcher and a variant seq extractor, plus a mixin that determines the insertion strategy and a "matcher" that, as far as I can see, is never used (see appendix). It's a bit of a matryoshka doll that, for me at least, took considerable time to understand, and I think most new users will be scared away by this. It also leads to a proliferation of fairly useless classes. For example, if I want to get variants from a hail table, I have to create a BaseMultiIntervalHailSeqExtractor and then use mixins to create the GenericSingleVariantMultiIntervalHailSeqExtractor, the GenericMultiVariantMultiIntervalHailSeqExtractor, and so on.

I believe the design I propose, by decoupling all of these things, makes it a lot easier to design new dataloaders in a modular fashion (i.e. if I want to load from hail, I just change the fetcher). It also makes it easier to write documentation that shows new users how to build a dataloader by changing just the particular components they need.
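
To illustrate the "just change the fetcher" idea, here is a minimal sketch of a variant source defined by a single method, with an in-memory implementation standing in for a VCF, hail table, or database. All names (VariantSource, InMemoryVariantSource, `fetch_variants`, `count_variants`) are hypothetical and do not reflect the actual VCFVariantFetcher interface; coordinates are treated as 0-based and half-open for simplicity.

```python
from typing import Dict, Iterable, List, NamedTuple, Protocol


class ToyInterval(NamedTuple):
    chrom: str
    start: int  # 0-based, inclusive
    end: int    # 0-based, exclusive


class ToyVariant(NamedTuple):
    chrom: str
    pos: int    # 0-based position (an assumption of this sketch)
    ref: str
    alt: str


class VariantSource(Protocol):
    """Anything that can return the variants overlapping an interval."""

    def fetch_variants(self, interval: ToyInterval) -> List[ToyVariant]:
        ...


class InMemoryVariantSource:
    """Toy backend; a VCF- or database-backed class would expose the same method."""

    def __init__(self, variants: Iterable[ToyVariant]):
        self._by_chrom: Dict[str, List[ToyVariant]] = {}
        for variant in variants:
            self._by_chrom.setdefault(variant.chrom, []).append(variant)

    def fetch_variants(self, interval: ToyInterval) -> List[ToyVariant]:
        return [
            v
            for v in self._by_chrom.get(interval.chrom, [])
            if interval.start <= v.pos < interval.end
        ]


def count_variants(source: VariantSource, intervals: Iterable[ToyInterval]) -> int:
    """Downstream code only depends on fetch_variants, not on the backend."""
    return sum(len(source.fetch_variants(iv)) for iv in intervals)


source = InMemoryVariantSource([
    ToyVariant("22", 105, "A", "G"),
    ToyVariant("22", 250, "C", "T"),
    ToyVariant("21", 105, "G", "A"),
])
n_hits = count_variants(source, [ToyInterval("22", 100, 200), ToyInterval("22", 300, 400)])
print(n_hits)
```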

In [2]:
class CDSFinder:
    """
    Extracts CDS
    """
    
    def __call__(
        self,
        region_df : pd.DataFrame
    ) -> pd.DataFrame:
        region_df = region_df.query('Feature == "CDS" and transcript_type == "protein_coding"')
        region_df.set_index("transcript_id", inplace=True)
        return region_df

class ChrRename:
    
    def __call__(
        self,
        region_df
    ):
        # regex=True is needed for the ^ anchor (pandas >= 2.0 defaults to literal matching)
        region_df["Chromosome"] = region_df["Chromosome"].str.replace("^chr", "", regex=True)
        return region_df

In [3]:
from kipoiseq.dataloaders.prototype import GenericVariantDataloader, VCFVariantFetcher, SingleVariantStrategy,\
    IdentityTransform, VariantSequenceExtractor
class CDSVariantDataloader(GenericVariantDataloader):
    
    def __init__(
        self,
        gtf_file : str,
        fasta_file : str,
        vcf_file : str,
        interval_attrs = ["gene_id", "transcript_type"]
    ):
        self.gtf_file = gtf_file
        self.fasta_file = fasta_file
        self.vcf_file = vcf_file
        self.use_strand = True
        
        # Source interval data from gtf
        df = pyranges.read_gtf(self.gtf_file).df
        # Subset to areas of interest
        df = CDSFinder()(df)
        df = ChrRename()(df)
        # Build the interval fetcher
        interval_source = GTFMultiIntervalFetcher(
            df, 
            keep_attrs=interval_attrs
        )
        # Source reference sequence from fasta
        reference_sequence_source = FastaStringExtractor(
            fasta_file,
            use_strand=self.use_strand
        )
        # Source variants from vcf
        variant_source = VCFVariantFetcher(
            self.vcf_file
        )
        # Build variant sequence extractor
        # I am not super sure what changing the anchor achieves 
        # The GenericSingleVariantMultiIntervalVCFSeqExtractor hardcodes it to 0
        # So I set it by default to 0 too.
        variant_sequence_extractor = VariantSequenceExtractor(
            reference_sequence_source,
            anchor = 0,
            fixed_len = False
        )
        # Provide variants individually
        variant_insertion_strategy = SingleVariantStrategy()
        # Provide sequence transformer
        sequence_transformer = IdentityTransform()
        # Pass all to superclass
        super().__init__(
            interval_source,
            variant_source,
            variant_insertion_strategy,
            reference_sequence_source,
            variant_sequence_extractor,
            sequence_transformer,
            interval_attrs
        )

## Test

In [4]:
cdsvar = CDSVariantDataloader(
    "ExampleFiles/chrom22.gtf",
    "ExampleFiles/chrom22_nochr.fa",
    "ExampleFiles/CEU.low_coverage.2010_09.genotypes.vcf.gz"
)

# Results when using the new dataloader
var_dict_new = defaultdict(dict)
for item in cdsvar:
    transcript_id = item["metadata"]["ranges"].name
    var_dict_new[transcript_id]["ref"] = item["inputs"]["ref_seq"]
    var_dict_new[transcript_id]["alts"] = var_dict_new[transcript_id].get("alts", []) + [item["inputs"]["alt_seq"]]

In [5]:
# get variants using old extractor as ground truth
df = pyranges.read_gtf('ExampleFiles/chrom22.gtf').df
df = CDSFinder()(df)
df = ChrRename()(df)
interval_source = GTFMultiIntervalFetcher(
    df, 
    keep_attrs=["gene_id"]
)
variant_matcher = SingleVariantMatcher(
    "ExampleFiles/CEU.low_coverage.2010_09.genotypes.vcf.gz",
    pranges=pyranges.PyRanges(
        interval_source.df.reset_index()
    )
)
reference_sequence_source = FastaStringExtractor(
    "ExampleFiles/chrom22_nochr.fa",
    use_strand=True
)
multi_sample_VCF = MultiSampleVCF("ExampleFiles/CEU.low_coverage.2010_09.genotypes.vcf.gz")
extractor = GenericSingleVariantMultiIntervalVCFSeqExtractor(
            interval_fetcher=interval_source,
            reference_seq_extractor=reference_sequence_source,
            variant_matcher=variant_matcher,
            multi_sample_VCF=multi_sample_VCF,
)

# Results when using the old extractor
var_dict = defaultdict(dict)
for transcript_id, (ref_seq, alt_seqs) in extractor.items():
    alt_seq_list = [alt_seq[0] for alt_seq in alt_seqs]
    if len(alt_seq_list) > 0:
        var_dict[transcript_id]["ref"] = ref_seq
        var_dict[transcript_id]["alts"] = alt_seq_list

In [6]:
assert sorted(var_dict_new.keys()) == sorted(var_dict.keys())

In [7]:
for transcript_id in var_dict_new.keys():
    assert(var_dict_new[transcript_id]["ref"] == var_dict[transcript_id]["ref"])
    assert(sorted(var_dict_new[transcript_id]["alts"]) == sorted(var_dict[transcript_id]["alts"]))

# Drawbacks and Remaining Issues

Specific issues of this new design:
* The variant insertion strategy is a bit intimidating, as it works with List[Tuple[Interval, List[Variant]]]. One could, of course, make a datatype for this. The main reasons I didn't are (a) I couldn't figure out how to name it (except for ListOfIntervalVariantsTuples, which isn't really much better) and (b) lists are nice to work with and I did not want to write all the boilerplate to make an object implement the list interface.
* I don't really know what the anchor parameter of the VariantSeqExtractor does, but if it actually makes a difference, this is a bit of a leaky abstraction, since one needs to be quite familiar with class internals to set it correctly.
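
On the boilerplate concern in the first bullet: one possible way to get a named datatype without re-implementing the list interface is to subclass `collections.UserList` from the standard library. The sketch below uses hypothetical names (IntervalVariants, ToyInterval, ToyVariant) and is only one option; a plain type alias would be even lighter.

```python
from collections import UserList
from typing import List, NamedTuple


class ToyInterval(NamedTuple):
    chrom: str
    start: int
    end: int


class ToyVariant(NamedTuple):
    chrom: str
    pos: int
    ref: str
    alt: str


class IntervalVariants(UserList):
    """A list of (interval, variants) pairs with a few convenience methods.

    Indexing, iteration, append, len etc. are inherited from UserList.
    """

    def n_variants(self) -> int:
        return sum(len(variants) for _, variants in self.data)

    def intervals(self) -> List[ToyInterval]:
        return [interval for interval, _ in self.data]


pairs = IntervalVariants([
    (ToyInterval("22", 100, 200), [ToyVariant("22", 110, "A", "G")]),
])
pairs.append((ToyInterval("22", 300, 400), []))
print(len(pairs), pairs.n_variants())
```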

More general issues:
* Generalizing to multi-input or, worse, multimodal models (e.g. ones that use sequence + experimental tracks) is not entirely straightforward
* There are complex variant events (e.g. a variant deleting a stop, causing a CDS to extend (on the spliced mRNA!) until the next stop is found), which are really hard to handle in a generic way

# Appendix

An example of the somewhat confusing nature of the current variant objects: the GenericSingleVariantMultiIntervalVCFSeqExtractor (I love OOP names) wants a VariantMatcher but, as far as I can see, never uses it in any capacity.

In [6]:
df = pyranges.read_gtf('ExampleFiles/chrom22.gtf').df
df = CDSFinder()(df)
df = ChrRename()(df)
interval_source = GTFMultiIntervalFetcher(
    df, 
    keep_attrs=["gene_id"]
)
variant_matcher = SingleVariantMatcher(
    "ExampleFiles/CEU.low_coverage.2010_09.genotypes.vcf.gz",
    pranges=pyranges.PyRanges(
        interval_source.df.reset_index()
    )
)
reference_sequence_source = FastaStringExtractor(
    "ExampleFiles/chrom22_nochr.fa",
    use_strand=True
)
multi_sample_VCF = MultiSampleVCF("ExampleFiles/CEU.low_coverage.2010_09.genotypes.vcf.gz")
extractor = GenericSingleVariantMultiIntervalVCFSeqExtractor(
            interval_fetcher=interval_source,
            reference_seq_extractor=reference_sequence_source,
            variant_matcher=variant_matcher,
            multi_sample_VCF=multi_sample_VCF,
)

In [7]:
# Results when using the matcher
var_dict = defaultdict(list)
for transcript_id, (ref_seq, alt_seqs) in extractor.items():
    alt_seq_list = [alt_seq for alt_seq in alt_seqs]
    for alt_seq in alt_seq_list: 
        alt_seq[1]["source"] = None # because cyvcf variant objects are incomparable
    if len(alt_seq_list) > 0:
        var_dict[transcript_id].append((ref_seq, alt_seq_list))

In [128]:
extractor = GenericSingleVariantMultiIntervalVCFSeqExtractor(
            interval_fetcher=interval_source,
            reference_seq_extractor=reference_sequence_source,
            variant_matcher=None,
            multi_sample_VCF=multi_sample_VCF,
)

In [129]:
# Results when matcher is none
var_dict_nomatcher = defaultdict(list)
for transcript_id, (ref_seq, alt_seqs) in extractor.items():
    alt_seq_list = [alt_seq for alt_seq in alt_seqs]
    for alt_seq in alt_seq_list: 
        alt_seq[1]["source"] = None # because cyvcf variant objects are incomparable
    if len(alt_seq_list) > 0:
        var_dict_nomatcher[transcript_id].append((ref_seq, alt_seq_list))

In [130]:
assert(var_dict.keys() == var_dict_nomatcher.keys())

In [132]:
for key in var_dict.keys():
    try:
        assert(var_dict[key] == var_dict_nomatcher[key])
    except Exception:
        a = var_dict[key]
        b = var_dict_nomatcher[key]
        break

In [9]:
x = var_dict[next(iter(var_dict))]

In [14]:
x[0][1]

[('ATGGTGGCTGAGGCTGGTTCAATGCCGGCTGCCTCCTCTGTGAAGAAGCCATTTGGTCTCAGAAGCAAGATGGGCAAGTGGTGCCGCCACTGCTTCGCCTGGTGCAGGGGGAGCGGCAAGAGCAACGTGGGCACTTCTGGAGACCACGACGATTCTGCTATGAAGACACTCAGGAGCAAGATGGGCAAGTGGTGCTGCCACTGCTTCCCCTGGTGCAGGGGGAGCGGCAAGAGCAACGTGGGCACTTCTGGAGACCACGACGATTCTGCTATGAAGACACTCAGGAGCAAGATGGGCAAGTGGTGCTGCCACTGCTTCCCCTGCTGCAGGGGGAGCGGCAAGAGCAACGTGGGCACTTCTGGAGACCACGACGACTCTGCTATGAAGACACTCAGGAGCAAGATGGGCAACTGGTGCTGCCACTGCTTCCCCTGCTGCAGGGGGAGCGGCAAGAACAAAGTGGGCCCTTGGGGAGACTACGACGACAGCGCTTTCATGGAGCCGAGGTACCACGTCCGTCGAGAAGATCTGGACAAGCTCCACAGAGCTGCCTGGTGGGGTAAAGTCCCCAGAAAGGATCTCATCGTCATGCTCAAGGACACTGACATGAACAAGAAGGACAAGCAAAAGAGGACTGCTCTACATCTGGCCTCTGCCAATGGAAATTCAGAAGTAGTAAAACTCCTGCTGGACAGACGATGTCAACTTAATATCCTTGACAACAAAAAGAGGACAGCTCTGACAAAGGCCGTACAATGCCAGGAAGATGAATGTGCGTTAATGTTGCTGGAACATGGCACTGATCCGAATATTCCAGATGAGTATGGAAATACCGCTCTACACTATGCTATCTACAATGAAGATAAATTAATGGCCAAAGCACTGCTCTTATACGGTGCTGATATCGAATCAAAAAACAAGCATGGCCTCACACCACTGTTACTTGGTGTACATGAGCAAAAACAGCAAGTGGTGAAATTTTTAATCAAGAAAAAAG