diff --git a/pyensembl/ensembl_release.py b/pyensembl/ensembl_release.py
index 63cc4ad..c0c0751 100644
--- a/pyensembl/ensembl_release.py
+++ b/pyensembl/ensembl_release.py
@@ -88,3 +88,16 @@ def __eq__(self, other):
 
     def __hash__(self):
         return hash((self.release, self.species))
+
+    def __getstate__(self):
+        """Return Genome's picklable state plus the release and species."""
+        fields = Genome.__getstate__(self)
+        fields["release"] = self.release
+        fields["species"] = self.species
+        return fields
+
+    def __setstate__(self, fields):
+        """Restore the pickled fields by delegating to Genome."""
+        # Genome sets __dict__ equal to all fields, so the release and species
+        # fields are handled as a part of that.
+        Genome.__setstate__(self, fields)
diff --git a/pyensembl/gene.py b/pyensembl/gene.py
index f730893..1f86527 100644
--- a/pyensembl/gene.py
+++ b/pyensembl/gene.py
@@ -73,6 +73,18 @@ def __eq__(self, other):
     def __hash__(self):
         return hash(self.id)
 
+    def __getstate__(self):
+        """Return a copy of __dict__ without the unpicklable db entry."""
+        fields = self.__dict__.copy()
+        # We can't pickle connections; tolerate db being absent already
+        fields.pop("db", None)
+        return fields
+
+    def __setstate__(self, fields):
+        """Restore the pickled fields and re-attach db from the genome."""
+        self.__dict__ = fields
+        self.db = self.genome.db
+
     @memoized_property
     def transcripts(self):
         """
diff --git a/pyensembl/genome.py b/pyensembl/genome.py
index dad3d33..3496233 100644
--- a/pyensembl/genome.py
+++ b/pyensembl/genome.py
@@ -90,12 +90,18 @@ def __init__(
         self.reference_name = reference_name
         self.annotation_name = annotation_name
         self.annotation_version = annotation_version
-
         self.decompress_on_download = decompress_on_download
         self.copy_local_files_to_cache = copy_local_files_to_cache
-
         self.require_ensembl_ids = require_ensembl_ids
+        self.cache_directory_path = cache_directory_path
+        self._gtf_path_or_url = gtf_path_or_url
+        self._transcript_fasta_path_or_url = transcript_fasta_path_or_url
+        self._protein_fasta_path_or_url = protein_fasta_path_or_url
+
+        self._init()
 
+    def _init(self):
+        """Build the download cache, has_* flags, logger and caches."""
         self.download_cache = DownloadCache(
             reference_name=self.reference_name,
             annotation_name=self.annotation_name,
@@ -103,21 +109,17 @@ def __init__(
             decompress_on_download=self.decompress_on_download,
             copy_local_files_to_cache=self.copy_local_files_to_cache,
             install_string_function=self.install_string,
-            cache_directory_path=cache_directory_path)
+            cache_directory_path=self.cache_directory_path)
         self.cache_directory_path = self.download_cache.cache_directory_path
 
-        self._gtf_path_or_url = gtf_path_or_url
-        self.has_gtf = gtf_path_or_url is not None
-
-        self._transcript_fasta_path_or_url = transcript_fasta_path_or_url
-        self.has_transcript_fasta = transcript_fasta_path_or_url is not None
-
-        self._protein_fasta_path_or_url = protein_fasta_path_or_url
-        self.has_protein_fasta = protein_fasta_path_or_url is not None
+        self.has_gtf = self._gtf_path_or_url is not None
+        self.has_transcript_fasta = self._transcript_fasta_path_or_url is not None
+        self.has_protein_fasta = self._protein_fasta_path_or_url is not None
 
         self.logger = logging.getLogger()
         self.logger.setLevel(logging.INFO)
         self.memory_cache = MemoryCache()
+
         self._init_lazy_fields()
 
     def _init_lazy_fields(self):
@@ -1041,3 +1043,28 @@ def protein_ids(self, contig=None, strand=None):
             distinct=True)
         # drop None values
         return [protein_id for protein_id in protein_ids if protein_id]
+
+    def __getstate__(self):
+        """Return the subset of __dict__ that _init needs to rebuild state."""
+        # Not the same as _fields(); these are useful for pickling/unpickling
+        # even if not necessary when checking for Genome equality.
+        field_names = (
+            "reference_name",
+            "annotation_name",
+            "annotation_version",
+            "_gtf_path_or_url",
+            "_transcript_fasta_path_or_url",
+            "_protein_fasta_path_or_url",
+            "decompress_on_download",
+            "copy_local_files_to_cache",
+            "require_ensembl_ids",
+            "cache_directory_path")
+        return {
+            name: value
+            for (name, value) in self.__dict__.items()
+            if name in field_names}
+
+    def __setstate__(self, fields):
+        """Restore the pickled fields, then rebuild the rest via _init."""
+        self.__dict__ = fields
+        self._init()
diff --git a/pyensembl/species.py b/pyensembl/species.py
index c791881..dbdacc4 100644
--- a/pyensembl/species.py
+++ b/pyensembl/species.py
@@ -58,6 +58,20 @@ def __str__(self):
     def __repr__(self):
         return str(self)
 
+    def __eq__(self, other):
+        """Equal when other is exactly a Species with the same fields."""
+        return (
+            other.__class__ is Species and
+            self.latin_name == other.latin_name and
+            self.synonyms == other.synonyms and
+            self.reference_assemblies == other.reference_assemblies)
+
+    def __hash__(self):
+        """Hash the same three fields that __eq__ compares."""
+        return hash((self.latin_name,
+                     tuple(self.synonyms),
+                     frozenset(self.reference_assemblies.items())))
+
 _latin_names_to_species = {}
 _common_names_to_species = {}
 _reference_names_to_species = {}
diff --git a/pyensembl/transcript.py b/pyensembl/transcript.py
index eb7fdde..d8aeb77 100644
--- a/pyensembl/transcript.py
+++ b/pyensembl/transcript.py
@@ -93,6 +93,18 @@ def __hash__(self):
     def gene(self):
         return self.genome.gene_by_id(self.gene_id)
 
+    def __getstate__(self):
+        """Return a copy of __dict__ without the unpicklable db entry."""
+        fields = self.__dict__.copy()
+        # We can't pickle connections; tolerate db being absent already
+        fields.pop("db", None)
+        return fields
+
+    def __setstate__(self, fields):
+        """Restore the pickled fields and re-attach db from the genome."""
+        self.__dict__ = fields
+        self.db = self.genome.db
+
     @memoized_property
     def exons(self):
         # need to look up exon_number alongside ID since each exon may
diff --git a/test/test_pickle.py b/test/test_pickle.py
new file mode 100644
index 0000000..e3d7b9f
--- /dev/null
+++ b/test/test_pickle.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2016. Mount Sinai School of Medicine
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import pickle
+from nose.tools import eq_, ok_
+
+from .common import test_ensembl_releases
+from .data import TP53_gene_id
+
+@test_ensembl_releases()
+def test_gene(ensembl):
+    gene = ensembl.gene_by_id(TP53_gene_id)
+    gene_pickled = pickle.dumps(gene)
+    gene_new = pickle.loads(gene_pickled)
+    eq_(gene, gene_new)
+    # db is dropped when pickling, so check it on the restored copy
+    ok_(gene_new.db is not None)
+
+@test_ensembl_releases()
+def test_transcript(ensembl):
+    gene = ensembl.gene_by_id(TP53_gene_id)
+    transcript = gene.transcripts[0]
+    transcript_pickled = pickle.dumps(transcript)
+    transcript_new = pickle.loads(transcript_pickled)
+    eq_(transcript, transcript_new)
+    # db is dropped when pickling, so check it on the restored copy
+    ok_(transcript_new.db is not None)
+
+@test_ensembl_releases()
+def test_genome(ensembl):
+    gene = ensembl.gene_by_id(TP53_gene_id)
+    genome = gene.genome
+    genome_pickled = pickle.dumps(genome)
+    genome_new = pickle.loads(genome_pickled)
+    eq_(genome, genome_new)
+    # db is not part of the pickled state; check it on the restored copy
+    ok_(genome_new.db is not None)
+
+    # This Genome happens to be an EnsemblRelease; test that too.
+    eq_(genome.release, genome_new.release)
+    eq_(genome.species, genome_new.species)