Skip to content

Commit

Permalink
Merge branch 'performance'
Browse files Browse the repository at this point in the history
  • Loading branch information
James Casbon committed Jun 12, 2012
2 parents ccecafd + 563caad commit ad6dd99
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 61 deletions.
116 changes: 56 additions & 60 deletions vcf/parser.py
Expand Up @@ -115,6 +115,9 @@ def read_meta(self, meta_string):


class _Call(object):

__slots__ = ['site', 'sample', 'data', 'gt_nums', 'called']

""" A genotype call, a cell entry in a VCF file"""
def __init__(self, site, sample, data):
#: The ``_Record`` for this ``_Call``
Expand Down Expand Up @@ -211,8 +214,8 @@ class _Record(object):
The list of genotype calls is in the ``samples`` property.
"""
def __init__(self, CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO, FORMAT, sample_indexes, samples=None,
gt_bases = None, gt_types = None, gt_phases = None):
def __init__(self, CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO, FORMAT,
sample_indexes, samples=None):
self.CHROM = CHROM
self.POS = POS
self.ID = ID
Expand All @@ -232,11 +235,6 @@ def __init__(self, CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO, FORMAT, sample_
#: list of ``_Calls`` for each sample ordered as in source VCF
self.samples = samples
self._sample_indexes = sample_indexes
# lists of the pre-computed base-wise genotypes ('A/G'), types (0,1,2)
# and phases for each sample.
self.gt_bases = gt_bases
self.gt_types = gt_types
self.gt_phases = gt_phases

def __eq__(self, other):
""" _Records are equal if they describe the same variant (same position, alleles) """
Expand Down Expand Up @@ -555,6 +553,7 @@ def __init__(self, fsock=None, filename=None, compressed=False, prepend_chr=Fals
self._tabix = None
self._prepend_chr = prepend_chr
self._parse_metainfo()
self._format_cache = {}

def __iter__(self):
return self
Expand Down Expand Up @@ -647,14 +646,8 @@ def _parse_info(self, info_str):

return retdict

def _parse_samples(self, samples, samp_fmt, site):
'''Parse a sample entry according to the format specified in the FORMAT
column.'''
samp_data = []# OrderedDict()
gt_bases = []# A/A, A|G, G/G, etc.
gt_types = []# 0, 1, 2, etc.
gt_phases = []# T, F, T, etc.

def _parse_sample_format(self, samp_fmt):
""" Parse the format of the calls in this _Record """
samp_fmt = samp_fmt.split(':')

samp_fmt_types = []
Expand All @@ -672,60 +665,65 @@ def _parse_samples(self, samples, samp_fmt, site):
entry_type = 'String'
samp_fmt_types.append(entry_type)
samp_fmt_nums.append(entry_num)
return samp_fmt, samp_fmt_types, samp_fmt_nums

for name, sample in itertools.izip(self.samples, samples):
sampdict = self._parse_sample(sample, samp_fmt, samp_fmt_types, samp_fmt_nums)
call = _Call(site, name, sampdict)
samp_data.append(call)

bases = call.gt_bases
type = call.gt_type
phase = call.phased
gt_bases.append(bases) if bases is not None else './.'
gt_types.append(type) if type is not None else -1
gt_phases.append(phase) if phase is not None else False
def _parse_samples(self, samples, samp_fmt, site):
'''Parse a sample entry according to the format specified in the FORMAT
column.'''

return _SampleInfo(samp_data, gt_bases, gt_types, gt_phases)
# check whether we already know how to parse this format
if samp_fmt in self._format_cache:
samp_fmt, samp_fmt_types, samp_fmt_nums = \
self._format_cache[samp_fmt]
else:
sf, samp_fmt_types, samp_fmt_nums = self._parse_sample_format(samp_fmt)
self._format_cache[samp_fmt] = (sf, samp_fmt_types, samp_fmt_nums)
samp_fmt = sf

def _parse_sample(self, sample, samp_fmt, samp_fmt_types, samp_fmt_nums):
sampdict = dict([(x, None) for x in samp_fmt])
samp_data = []

for fmt, entry_type, entry_num, vals in itertools.izip(
samp_fmt, samp_fmt_types, samp_fmt_nums, sample.split(':')):
for name, sample in itertools.izip(self.samples, samples):

# short circuit the most common
if vals == '.' or vals == './.':
sampdict[fmt] = None
continue
# parse the data for this sample
sampdict = dict([(x, None) for x in samp_fmt])

# we don't need to split single entries
if entry_num == 1 or ',' not in vals:
for fmt, entry_type, entry_num, vals in itertools.izip(
samp_fmt, samp_fmt_types, samp_fmt_nums, sample.split(':')):

if entry_type == 'Integer':
sampdict[fmt] = int(vals)
elif entry_type == 'Float':
sampdict[fmt] = float(vals)
else:
sampdict[fmt] = vals
# short circuit the most common
if vals == '.' or vals == './.':
sampdict[fmt] = None
continue

if entry_num != 1:
sampdict[fmt] = (sampdict[fmt])
# we don't need to split single entries
if entry_num == 1 or ',' not in vals:

continue
if entry_type == 'Integer':
sampdict[fmt] = int(vals)
elif entry_type == 'Float':
sampdict[fmt] = float(vals)
else:
sampdict[fmt] = vals

if entry_num != 1:
sampdict[fmt] = (sampdict[fmt])

vals = vals.split(',')
continue

if entry_type == 'Integer':
sampdict[fmt] = self._map(int, vals)
elif entry_type == 'Float' or entry_type == 'Numeric':
sampdict[fmt] = self._map(float, vals)
else:
sampdict[fmt] = vals
vals = vals.split(',')

if entry_type == 'Integer':
sampdict[fmt] = self._map(int, vals)
elif entry_type == 'Float' or entry_type == 'Numeric':
sampdict[fmt] = self._map(float, vals)
else:
sampdict[fmt] = vals

return sampdict
# create a call object
call = _Call(site, name, sampdict)
samp_data.append(call)

return samp_data

def next(self):
'''Return the next record in the file.'''
Expand Down Expand Up @@ -762,14 +760,12 @@ def next(self):
except IndexError:
fmt = None

record = _Record(chrom, pos, ID, ref, alt, qual, filt, info, fmt, self._sample_indexes)
record = _Record(chrom, pos, ID, ref, alt, qual, filt,
info, fmt, self._sample_indexes)

if fmt is not None:
sample_info = self._parse_samples(row[9:], fmt, record)
record.samples = sample_info.samples
record.gt_bases = sample_info.gt_bases
record.gt_types = sample_info.gt_types
record.gt_phases = sample_info.gt_phases
samples = self._parse_samples(row[9:], fmt, record)
record.samples = samples

return record

Expand Down
2 changes: 1 addition & 1 deletion vcf/test/prof.py
Expand Up @@ -5,7 +5,7 @@
import sys

def parse_1kg():
for line in vcf.Reader(filename='test/1kg.vcf.gz'):
for line in vcf.Reader(filename='vcf/test/1kg.vcf.gz'):
pass

if len(sys.argv) == 1:
Expand Down

0 comments on commit ad6dd99

Please sign in to comment.