In [1]:
import kipoi_veff, kipoi
import pytest
import os
import sys
import lmdb
from tqdm import tqdm
import kipoi_veff.snv_predict as sp
import pandas as pd
import pyarrow as pa
from kipoi.readers import Reader
from kipoi_veff import analyse_model_preds
from kipoi_veff.scores import Diff, LogitRef
# from kipoi_cadd.writers import LmdbWriter
from kipoi_cadd.utils import variant_id_string
# Logit, LogitRef, LogitAlt, , DeepSEA_effect, RCScore, scoring_options
from kipoi_veff.utils.io import SyncBatchWriter, SyncPredictonsWriter

Using TensorFlow backend.


In [2]:
# All paths used below (tests/models/...) are relative to the kipoi-veff
# repository root. Set KIPOI_VEFF_DIR to point at a different checkout.
os.chdir(os.environ.get("KIPOI_VEFF_DIR",
                        "/data/ouga/home/ag_gagneur/simancas/Projects/kipoi-veff"))
# To run the test through pytest instead: pytest.main(['-k', 'test_other_writers'])

In [6]:
class LmdbWriter(SyncPredictonsWriter):
    """Synchronous predictions writer that stores per-variant scores in LMDB.

    Each variant in ``records`` becomes one LMDB entry:

    * key   -- ``variant_id_string(CHROM, POS, REF, ALT)`` encoded as ASCII;
    * value -- ``pyarrow``-serialized dict mapping every key of ``predictions``
      to the row of ``predictions[key]`` at the variant's position in ``records``.

    The LMDB environment is opened once, on the first call, and reused by
    subsequent calls. Call ``close()`` when done, so the same directory can
    safely be reopened (e.g. by a reader) in this process.

    Args:
        lmdb_dir: directory of the LMDB environment.
        map_size: maximum size of the LMDB memory map in bytes; coerced to
            ``int`` because the default ``10E8`` is a float literal.
    """

    def __init__(self, lmdb_dir, map_size=10E8):
        self.lmdb_dir = lmdb_dir
        self.map_size = int(map_size)
        # Opened lazily in __call__; initialised here so close() is always safe.
        self.env = None

    def _get_env(self):
        """Return the open LMDB environment, opening it on first use."""
        if self.env is None:
            # Single main database (max_dbs=0); lock=False means no
            # inter-process locking, so only one writer may use the directory.
            self.env = lmdb.open(self.lmdb_dir, map_size=self.map_size,
                                 max_dbs=0, lock=False)
        return self.env

    def __call__(self, predictions, records, line_ids=None):
        """Write one LMDB entry per variant in ``records``.

        Args:
            predictions: mapping of score name -> table indexed positionally
                with ``.iloc``; row ``i`` is assumed to belong to ``records[i]``
                (TODO: verify alignment, e.g. using ``line_ids``).
            records: sized sequence of variant records exposing
                ``CHROM``, ``POS``, ``REF`` and ``ALT``.
            line_ids: accepted for interface compatibility; not used.
        """
        import pyarrow as pa

        env = self._get_env()
        with env.begin(write=True) as txn:
            for var_num, var in tqdm(enumerate(records), total=len(records)):
                # NOTE(review): for cyvcf2 records ALT is a list, so the key
                # contains its repr, e.g. "22:21541590:A:['T']".
                variant_id = variant_id_string(var.CHROM, var.POS, var.REF, var.ALT)
                annotations = {p: predictions[p].iloc[var_num, :] for p in predictions}

                # NOTE(review): pa.serialize/pa.deserialize are deprecated since
                # pyarrow 2.0; migrating changes the on-disk format read back
                # with pa.deserialize, so it is left as is here.
                buf = pa.serialize(annotations).to_buffer()
                txn.put(variant_id.encode('ascii'), buf)

    def close(self):
        """Close the LMDB environment if it is open; safe to call repeatedly."""
        if self.env is not None:
            self.env.close()
            self.env = None


class LmdbReader(Reader):
    """Read-only view of an LMDB environment such as the one written by LmdbWriter.

    The environment and a single read transaction are opened at construction
    time and kept until ``close()`` (or garbage collection).

    Args:
        lmdb_dir: directory of the LMDB environment.
    """

    def __init__(self, lmdb_dir):
        self.lmdb_dir = lmdb_dir
        # Initialised before opening so close()/__del__ work even if
        # lmdb.open raises below.
        self.env = None
        self.txn = None
        self.env = lmdb.open(self.lmdb_dir, readonly=True, lock=False)
        self.txn = self.env.begin()

    def __len__(self):
        """Number of entries in the main (unnamed) database."""
        # Environment.stat() reports on the main database, which is where
        # entries are stored when written without a named db (max_dbs=0).
        return self.env.stat()['entries']

    def single_iter(self):
        """Return an iterator over ``(key, value)`` pairs from a cursor of the read transaction."""
        return iter(self.txn.cursor())

    def close(self):
        """Close the environment (invalidating the transaction); safe to call repeatedly."""
        env = getattr(self, 'env', None)
        if env is not None:
            env.close()
        self.env = None
        self.txn = None

    def __del__(self):
        self.close()

In [4]:
INSTALL_REQ = False  # set True to install the model's requirements before running


def test_other_writers(tmpdir):
    """Predict SNV effects for the var_seqlen example model and store scores in LMDB.

    Scores are computed by ``analyse_model_preds`` with ``Diff("mean")``
    (key ``'diff'``) and ``LogitRef("max")`` (key ``'logitRef'``) and written
    through ``LmdbWriter`` to ``<tmpdir>/lmdb/``. The writer is closed afterwards,
    also when prediction fails, so the directory can be reopened in-process.

    All model paths are relative to the kipoi-veff repository root, which must
    be the current working directory.

    Args:
        tmpdir: output directory; anything ``str()`` turns into a path
            (pytest's ``tmpdir`` fixture or a plain string).

    Returns:
        The return value of ``sp.predict_snvs(..., return_predictions=True)``
        (the notebook unpacks it as ``out, res``).
    """
    if sys.version_info[0] == 2:
        pytest.skip("rbp example not supported on python 2 ")
    model_dir = "tests/models/var_seqlen_model/"
    if INSTALL_REQ:
        # kipoi's public helper; `install_model_requirements` is not imported here.
        kipoi.install_model_requirements(model_dir, source="dir", and_dataloaders=True)

    model = kipoi.get_model(model_dir, source="dir")
    # The preprocessor
    Dataloader = kipoi.get_dataloader_factory(model_dir, source="dir")

    dataloader_arguments = {
        "fasta_file": "example_files/hg38_chr22.fa",
        "preproc_transformer": "dataloader_files/encodeSplines.pkl",
        "gtf_file": "example_files/gencode_v25_chr22.gtf.pkl.gz",
        "intervals_file": "example_files/variant_centered_intervals.tsv"
    }
    # Dataloader files are given relative to the model directory.
    dataloader_arguments = {k: model_dir + v for k, v in dataloader_arguments.items()}
    vcf_path = kipoi_veff.ensure_tabixed_vcf(model_dir + "example_files/variants.vcf")
    model_info = kipoi_veff.ModelInfoExtractor(model, Dataloader)

    # Other sync writers can be swapped in here, e.g.
    # SyncBatchWriter(MultipleBatchWriter([HDF5BatchWriter(...), TsvBatchWriter(...)]))
    # or kipoi_veff.VcfWriter(...).
    lmdb_path = os.path.join(str(tmpdir), 'lmdb/')
    writer = LmdbWriter(lmdb_path)

    vcf_to_region = None
    with pytest.raises(Exception):
        # This has to raise an exception as the sequence length is None.
        vcf_to_region = kipoi_veff.SnvCenteredRg(model_info)
    try:
        output = sp.predict_snvs(model, Dataloader, vcf_path,
                                 dataloader_args=dataloader_arguments,
                                 evaluation_function=analyse_model_preds, batch_size=32,
                                 vcf_to_region=vcf_to_region,
                                 evaluation_function_kwargs={
                                     'diff_types': {'diff': Diff("mean"),
                                                    'logitRef': LogitRef("max")}},
                                 return_predictions=True,
                                 sync_pred_writer=writer)
    finally:
        # Release the LMDB environment before anything reopens lmdb_path.
        writer.close()
    return output

In [8]:
# Scratch output directory (trailing slash matters: later cells use tmpdir + "lmdb").
tmpdir = "/tmp/kipoi-veff/"
# The prediction call returns two values, unpacked here.
(out, res) = test_other_writers(tmpdir)

  0%|          | 0/1 [00:00<?, ?it/s]INFO:2019-01-01 17:24:54,478:genomelake] Running landmark extractors..
  ("strand", gtf.strand)])
INFO:2019-01-01 17:24:54,494:genomelake] Done!


pp_line              0
varpos_rel          49
ref                  A
alt                  T
start         21541541
end           21541641
id                   0
do_mutate         True
strand               .
Name: 0, dtype: object
Sequence:
TACCTATTTGGGTTTTCACTAGTAAGCAGTTGGTTTGTAAGCAGTTGGTAATTTTAGTTTGTCTGGGTTTCAGCCATGAATATTCTATTGTAAACTTAATT[0m
pp_line              1
varpos_rel          49
ref                  C
alt                  C
start         21541903
end           21542003
id                   1
do_mutate         True
strand               .
Name: 1, dtype: object
Sequence:
GTAGATACGGGGTTTCAACATGTTGCCCAGGCTGGTCTTGAATTCCTGTCCTCAAGCGATCCACTTGCCTCGCCTCCCAAAGTGCTGAGATTACAAGTATG[0m
pp_line              2
varpos_rel          49
ref                  T
alt                  G
start         30630171
end           30630271
id                   2
do_mutate         True
strand               .
Name: 2, dtype: object
Sequence:
GCCCTCAGACTCCCTTCACCCCAAGGTGTGCCATCCTCTCCATTCCACCTAGGCCTGTCCAGGCCTCG


  0%|          | 0/6 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  4.05it/s]


In [98]:
# Open the LMDB store written by test_other_writers under tmpdir.
lmdb_dir = tmpdir + "lmdb"
reader = LmdbReader(lmdb_dir)

In [99]:
# Iterator over the store's (key, value) pairs.
entry_iterator = reader.single_iter()
it = entry_iterator

In [101]:
# First (key, value) pair from the iterator.
first_entry = next(it)
b = first_entry

In [109]:
# Decode the stored value; as displayed below it is a dict of score name -> Series.
serialized_value = b[1]
pa.deserialize(serialized_value)

{'diff': rbp_prb    0.0
 Name: 1, dtype: float32, 'logitRef': rbp_prb   NaN
 Name: 1, dtype: float32}

In [108]:
# The key is ASCII-encoded bytes; decode it to show the variant id.
b[0].decode("ascii")

"22:21541952:C:['C']"

In [68]:
# Print row 1 of every score table, to compare with the decoded LMDB entry.
for score_table in res.values():
    print(score_table.iloc[1, :])

rbp_prb    0.0
Name: 1, dtype: float32
rbp_prb   NaN
Name: 1, dtype: float32


In [54]:
from cyvcf2 import VCF
from kipoi_cadd.utils import variant_id_string

In [38]:
# Load the first record of the example VCF with cyvcf2 for inspection.
model_dir = "tests/models/var_seqlen_model/"
vcf_path = "{}example_files/variants.vcf".format(model_dir)
var_it = VCF(vcf_path)
var = next(var_it)
var

Variant(chr22:21541590 A/T)

In [43]:
# Plain-text representation of the record.
print(repr(var))

Variant(chr22:21541590 A/T)


In [57]:
# NOTE: var.ALT is a list, so the resulting id contains its repr, e.g. "['T']".
variant_fields = (var.CHROM, var.POS, var.REF, var.ALT)
variant_id_string(*variant_fields)

"22:21541590:A:['T']"

In [53]:
# Strip a leading "chr" prefix; names without it are returned unchanged
# (split('chr')[1] would raise IndexError for e.g. "22").
var.CHROM[len('chr'):] if var.CHROM.startswith('chr') else var.CHROM

'22'