Skip to content

Commit

Permalink
info about variants
Browse files Browse the repository at this point in the history
  • Loading branch information
Kalin Nonchev committed May 19, 2020
1 parent 2a58b40 commit 9a67c5e
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 29 deletions.
7 changes: 4 additions & 3 deletions kipoiseq/dataloaders/protein.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,21 @@ def _extractor(self):
for transcript_id, seqs in self.protein_vcf_extractor.extract_all():
# reference sequence
ref_seq = self.transcript_extractor.get_protein_seq(transcript_id)
for alt_seq in seqs:
for (alt_seq, variant) in seqs:
yield {
'input': {
'ref_seq': ref_seq,
'alt_seq': alt_seq,
},
'metadata': self.get_metadata(transcript_id)
'metadata': self.get_metadata(transcript_id, variant)
}

def get_metadata(self, transcript_id: str):
def get_metadata(self, transcript_id: str, variant: dict):
    """
    Build the metadata dictionary for one transcript and its variant info.

    The entry of ``self.metadatas`` labelled ``transcript_id`` is looked up
    once and converted with ``to_dict()``; two extra keys are then added.
    NOTE(review): ``self.metadatas`` is presumably a pandas DataFrame
    indexed by transcript id - confirm against the constructor.

    :param transcript_id: label used for the ``.loc`` lookup
    :param variant: variant information, stored as-is under ``'variants'``
    :return: dict with the looked-up values plus ``'transcript_id'``
        (the ``name`` of the looked-up entry) and ``'variants'``
    """
    # single lookup; the same row provides both the values and its label
    row = self.metadatas.loc[transcript_id]
    metadata = row.to_dict()
    metadata['transcript_id'] = row.name
    metadata['variants'] = variant
    return metadata
29 changes: 21 additions & 8 deletions kipoiseq/extractors/protein.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _get_cds_from_gtf(df):
.query("{} == 'protein_coding'".format(biotype_str))
.query("(Feature == 'CDS') | (Feature == 'CCDS')")
)
df = df[df['tag'].notna()] #grch37 have ccds without tags
df = df[df['tag'].notna()] # grch37 have ccds without tags
return df[df["tag"].str.contains("basic|CCDS")].set_index('transcript_id')

@staticmethod
Expand Down Expand Up @@ -233,7 +233,7 @@ def __init__(self, gtf_file, fasta_file, vcf_file):
# match variant with transcript_id
self.single_variant_matcher = SingleVariantMatcher(
self.vcf_file, pranges=pr_cds)

self.fasta = FastaStringExtractor(self.fasta_file)
self.multi_sample_VCF = MultiSampleVCF(self.vcf_file)
self.variant_seq_extractor = VariantSeqExtractor(self.fasta_file)
Expand All @@ -245,6 +245,18 @@ def _unstrand(intervals: List[Interval]):
"""
return [i.unstrand() for i in intervals]

@staticmethod
def _prepare_variants(variants: list) -> dict:
    """
    Convert variant objects into plain dictionaries of their attributes.

    Every entry of a variant's ``__dict__`` is copied, with all underscores
    removed from the attribute name (e.g. ``_pos`` becomes ``pos``).

    :param variants: list of objects exposing ``__dict__``
    :return: ``{}`` when ``variants`` is empty; the attribute dict itself
        when there is exactly one variant; otherwise a dict mapping each
        variant's position in ``variants`` to its attribute dict
    """
    # one attribute dict per variant, keyed by its position in the list
    variants_dict = {
        index: {key.replace('_', ''): value
                for key, value in vars(variant).items()}
        for index, variant in enumerate(variants)
    }
    # if single variant, unpack dict
    if len(variants_dict) == 1:
        return variants_dict[0]
    return variants_dict

def extract_cds(self, cds: List[Interval], sample_id=None):
"""
Extract cds with variants in their dna sequence. It depends on the
Expand All @@ -261,9 +273,9 @@ def extract_cds(self, cds: List[Interval], sample_id=None):

iter_seqs = self.extract_query(variant_interval_queryable,
sample_id=sample_id)

for seqs in iter_seqs:
yield ProteinSeqExtractor._prepare_seq(seqs, cds[0].strand, cds[0].attrs['tag'])
# 1st seq, 2nd variant info
yield ProteinSeqExtractor._prepare_seq(seqs[0], cds[0].strand, cds[0].attrs['tag']), seqs[1]

def extract_all(self):
"""
Expand All @@ -277,7 +289,6 @@ def extract_all(self):
yield transcript_id, self.extract(transcript_id)
else:
print('No matched variants with transcript_ids.')


def extract_list(self, list_with_transcript_id: List[str]):
"""
Expand Down Expand Up @@ -328,15 +339,18 @@ def _extract_query(self, variant_interval_queryable, sample_id=None):
"""
seqs = []
flag = True
variants_info = list()
for variants, interval in variant_interval_queryable.variant_intervals:
variants = list(self._filter_snv(variants))
if len(variants) > 0:
flag = False
variants_info.extend(variants)
seqs.append(self.variant_seq_extractor.extract(
interval, variants, anchor=0))
if flag:
seqs = []
yield "".join(seqs)
variants_info = []
yield "".join(seqs), self._prepare_variants(variants_info)

def extract_query(self, variant_interval_queryable, sample_id=None):
"""
Expand Down Expand Up @@ -369,10 +383,9 @@ def extract_query(self, variant_interval_queryable, sample_id=None):
variant_interval_queryable.variant_intervals):
variants = self._filter_snv(variants)
for variant in variants:

yield [
*ref_cds_seq[:i],
self.variant_seq_extractor.extract(
interval, [variant], anchor=0),
*ref_cds_seq[(i+1):],
]
], self._prepare_variants([variant])
2 changes: 1 addition & 1 deletion tests/dataloaders/test_protein_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ def test_single_variant_protein_dataLoader(single_variant_protein_dataLoader):
assert type(units[2]['metadata']) == dict
assert len(units[2]) == 2
assert len(units[2]['input']) == 2
assert len(units[2]['metadata']) == 17 # number of columns
assert len(units[2]['metadata']) == 18 # number of columns
32 changes: 15 additions & 17 deletions tests/extractors/test_protein.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ def test_ProteinVCFSeqExtractor__unstrand():
def protein_vcf_seq(mocker):
extractor = ProteinVCFSeqExtractor(gtf_file, fasta_file, vcf_file)
extractor.extract_query = mocker.MagicMock(
return_value=iter((['ATC', 'GATG'], ['CATC', 'GAT'])))
return_value=iter(([['ATC', 'GATG'], ['Var_Mutation_Mock']], [['CATC', 'GAT'], ['Var_Mutation_Mock']])))
return extractor


def test_ProteinVCFSeqExtractor_extract_cds(protein_vcf_seq):
protein_seqs = list(protein_vcf_seq.extract_cds(intervals))

assert protein_seqs[0] == 'ID'
assert protein_seqs[1] == 'HR'
assert protein_seqs[0][0] == 'ID'
assert protein_seqs[1][0] == 'HR'

query = list(protein_vcf_seq.extract_query
.call_args[0][0].variant_intervals)
Expand All @@ -167,8 +167,8 @@ def test_ProteinVCFSeqExtractor_extract_cds(protein_vcf_seq):
def test_ProteinVCFSeqExtractor_extract(protein_vcf_seq):
transcript_id = 'enst_test2'
protein_seqs = list(protein_vcf_seq.extract(transcript_id))
assert protein_seqs[0] == 'HR'
assert protein_seqs[1] == 'ID'
assert protein_seqs[0][0] == 'HR'
assert protein_seqs[1][0] == 'ID'


@pytest.fixture
Expand All @@ -179,7 +179,7 @@ def single_seq_protein():

def test_SingleSeqProteinVCFSeqExtractor_extract(single_seq_protein, transcript_seq_extractor):
transcript_id = 'enst_test2'
seq = single_seq_protein.extract(transcript_id)
seq, info = single_seq_protein.extract(transcript_id)
txt_file = 'tests/data/Output_singleSeq_vcf_enst_test2.txt'
expected_seq = open(txt_file).readline()
assert seq == expected_seq
Expand All @@ -189,7 +189,7 @@ def test_SingleSeqProteinVCFSeqExtractor_extract(single_seq_protein, transcript_
single_seq_protein = SingleSeqProteinVCFSeqExtractor(
gtf_file, fasta_file, vcf_file)

seq = single_seq_protein.extract(transcript_id)
seq, info = single_seq_protein.extract(transcript_id)
ref_seq = transcript_seq_extractor.get_protein_seq(transcript_id)

assert len(seq) == len(ref_seq)
Expand All @@ -201,8 +201,6 @@ def test_SingleSeqProteinVCFSeqExtractor_extract(single_seq_protein, transcript_
gtf_file, fasta_file, vcf_file)
seq = list(single_seq_protein.extract_all())
assert len(seq) == 0




@pytest.fixture
Expand All @@ -224,33 +222,33 @@ def test_SingleVariantProteinVCFSeqExtractor_extract(single_variant_seq, transcr
seqs = list(single_variant_seq.extract(transcript_id))
txt_file = 'tests/data/Output_singleVar_vcf_enst_test2.txt'
expected_seq = open(txt_file).read().splitlines()
assert seqs[0] == expected_seq[0]
assert seqs[1] == expected_seq[1]
assert seqs[2] == expected_seq[2]
assert seqs[0][0] == expected_seq[0]
assert seqs[1][0] == expected_seq[1]
assert seqs[2][0] == expected_seq[2]

seqs = list(single_variant_seq.extract_all())
counter = 0
for tr_id, t_id_seqs in seqs:
t_id_seqs = list(t_id_seqs)
t_id_seqs = [seq for seq, info in list(t_id_seqs)]
counter += len(t_id_seqs)
for i, seq in enumerate(t_id_seqs):
assert seq == expected_seq[i]
assert tr_id == 'enst_test2'
assert counter == 3, 'Number of variants in vcf 3, but # of seq was: ' + \
str(counter)

transcript_id = ['enst_test2', 'enst_test1']
seqs = single_variant_seq.extract_list(transcript_id)
for tr_id, t_id_seqs in seqs:
assert tr_id in ['enst_test2', 'enst_test1'], tr_id



vcf_file = 'tests/data/singleVar_vcf_enst_test1_diff_type_of_variants.vcf.gz'
transcript_id = 'enst_test1'
single_var_protein = SingleVariantProteinVCFSeqExtractor(
gtf_file, fasta_file, vcf_file)

seqs = list(single_var_protein.extract(transcript_id))
seqs = [seq for seq, info in list(
single_var_protein.extract(transcript_id))]
ref_seq = transcript_seq_extractor.get_protein_seq(transcript_id)

assert len(seqs) == 1
Expand Down

0 comments on commit 9a67c5e

Please sign in to comment.