Skip to content

Commit

Permalink
Simplified the tests for the parallel pybgen
Browse files Browse the repository at this point in the history
  • Loading branch information
lemieuxl committed Jul 19, 2017
1 parent fd6d18a commit 0df07d1
Showing 1 changed file with 2 additions and 212 deletions.
214 changes: 2 additions & 212 deletions pybgen/tests/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,20 @@


import os
import shutil
import unittest
from tempfile import mkdtemp

import numpy as np
from pkg_resources import resource_filename

from ..parallel import ParallelPyBGEN
from ..pybgen import HAS_ZSTD
from .truths import truths
from .test_pybgen import ReaderTests


__all__ = ["parallel_reader_tests"]


class ParallelReaderTests(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Creating a temporary directory
cls.tmp_dir = mkdtemp(prefix="pybgen_test_")
class ParallelReaderTests(ReaderTests):

def setUp(self):
# Getting the truth for this file
Expand All @@ -54,209 +47,6 @@ def setUp(self):
bgen_fn = resource_filename(__name__, self.bgen_filename)
self.bgen = ParallelPyBGEN(bgen_fn)

@classmethod
def tearDownClass(cls):
# Cleaning the temporary directory
shutil.rmtree(cls.tmp_dir)

def tearDown(self):
# Closing the object
self.bgen.close()

def _compare_variant(self, expected, observed):
"""Compare two variants."""
self.assertEqual(expected.name, observed.name)
self.assertEqual(expected.chrom, observed.chrom)
self.assertEqual(expected.pos, observed.pos)
self.assertEqual(expected.a1, observed.a1)
self.assertEqual(expected.a2, observed.a2)

def test_repr(self):
"""Tests the __repr__ representation."""
self.assertEqual(
"PyBGEN({:,d} samples; {:,d} variants)".format(
self.truths["nb_samples"], self.truths["nb_variants"],
),
str(self.bgen),
)

def test_nb_samples(self):
"""Tests the number of samples."""
self.assertEqual(self.truths["nb_samples"], self.bgen.nb_samples)

def test_nb_variants(self):
"""Tests the number of variants."""
self.assertEqual(self.truths["nb_variants"], self.bgen.nb_variants)

def test_samples(self):
"""Tests the samples attribute."""
if self.truths["samples"] is None:
self.assertTrue(self.bgen.samples is None)
else:
self.assertEqual(self.truths["samples"], self.bgen.samples)

def test_get_first_variant(self):
"""Tests getting the first variant of the file."""
# The variant to retrieve
name = "RSID_2"

# Getting the results (there should be only one
r = self.bgen.get_variant(name)
self.assertEqual(1, len(r))
variant, dosage = r.pop()

# Checking the variant
self._compare_variant(
self.truths["variants"][name]["variant"],
variant,
)

# Checking the dosage
np.testing.assert_array_almost_equal(
self.truths["variants"][name]["dosage"], dosage,
)

def test_get_middle_variant(self):
"""Tests getting a variant in the middle of the file."""
# The variant to retrieve
name = "RSID_148"

# Getting the results (there should be only one
r = self.bgen.get_variant(name)
self.assertEqual(1, len(r))
variant, dosage = r.pop()

# Checking the variant
self._compare_variant(
self.truths["variants"][name]["variant"],
variant,
)

# Checking the dosage
np.testing.assert_array_almost_equal(
self.truths["variants"][name]["dosage"], dosage,
)

def test_get_last_variant(self):
"""Tests getting the last variant of the file."""
# The variant to retrieve
name = "RSID_200"

# Getting the results (there should be only one
r = self.bgen.get_variant(name)
self.assertEqual(1, len(r))
variant, dosage = r.pop()

# Checking the variant
self._compare_variant(
self.truths["variants"][name]["variant"],
variant,
)

# Checking the dosage
np.testing.assert_array_almost_equal(
self.truths["variants"][name]["dosage"], dosage,
)

def test_get_missing_variant(self):
"""Tests getting a variant which is absent from the BGEN file."""
with self.assertRaises(ValueError) as cm:
self.bgen.get_variant("UNKOWN_VARIANT_NAME")
self.assertEqual(
"UNKOWN_VARIANT_NAME: name not found",
str(cm.exception),
)

def test_iter_all_variants(self):
"""Tests the iteration of all variants."""
seen_variants = set()
for variant, dosage in self.bgen.iter_variants():
# The name of the variant
name = variant.name
seen_variants.add(name)

# Comparing the variant
self._compare_variant(
self.truths["variants"][name]["variant"],
variant,
)

# Comparing the dosage
np.testing.assert_array_almost_equal(
self.truths["variants"][name]["dosage"], dosage,
)

# Checking if we checked all variants
self.assertEqual(seen_variants, self.truths["variant_set"])

def test_as_iterator(self):
"""Tests the module as iterator."""
seen_variants = set()
for variant, dosage in self.bgen:
# The name of the variant
name = variant.name
seen_variants.add(name)

# Comparing the variant
self._compare_variant(
self.truths["variants"][name]["variant"],
variant,
)

# Comparing the dosage
np.testing.assert_array_almost_equal(
self.truths["variants"][name]["dosage"], dosage,
)

# Checking if we checked all variants
self.assertEqual(seen_variants, self.truths["variant_set"])

def test_iter_variant_info(self):
"""Tests the iteration of all variants' information."""
seen_variants = set()
for variant in self.bgen.iter_variant_info():
# The name of the variant
name = variant.name
seen_variants.add(name)

# Comparing the variant
self._compare_variant(
self.truths["variants"][name]["variant"],
variant,
)

# Checking if we checked all variants
self.assertEqual(seen_variants, self.truths["variant_set"])

def test_iter_variants_in_region(self):
"""Tests the iteration of all variants in a genomic region."""
seen_variants = set()
iterator = self.bgen.iter_variants_in_region("01", 67000, 70999)
for variant, dosage in iterator:
# The name of the variant
name = variant.name
seen_variants.add(name)

# Comparing the variant
self._compare_variant(
self.truths["variants"][name]["variant"],
variant,
)

# Comparing the dosage
np.testing.assert_array_almost_equal(
self.truths["variants"][name]["dosage"], dosage,
)

# Checking if we checked all variants
expected = set()
for name in self.truths["variant_set"]:
variant = self.truths["variants"][name]["variant"]
if variant.chrom == "01":
if variant.pos >= 67000 and variant.pos <= 70999:
expected.add(name)
self.assertEqual(seen_variants, expected)


class Test32bits(ParallelReaderTests):
bgen_filename = os.path.join("data", "example.32bits.bgen")
Expand Down

0 comments on commit 0df07d1

Please sign in to comment.