Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional file storage for data persistence #881

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions mongomock/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from mongomock import codec_options as mongomock_codec_options
from mongomock import helpers
from mongomock import read_preferences
from mongomock import store

try:
from pymongo import ReadPreference
Expand Down Expand Up @@ -40,7 +39,7 @@ def __init__(
self.name = name
self._client = client
self._collection_accesses = {}
self._store = _store or store.DatabaseStore()
self._store = _store or getattr(self._client, '_store')[self.name]
self._read_preference = read_preference or _READ_PREFERENCE_PRIMARY
mongomock_codec_options.is_supported(codec_options)
self._codec_options = codec_options or mongomock_codec_options.CodecOptions()
Expand Down
43 changes: 37 additions & 6 deletions mongomock/store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import collections
import datetime
import functools
import weakref
import bson

import mongomock
from mongomock.thread import RWLock
Expand All @@ -9,8 +12,15 @@
class ServerStore(object):
"""Object holding the data for a whole server (many databases)."""

def __init__(self):
self._databases = {}
def __init__(self, filename=None):
self._filename = os.environ.get('MONGOMOCK_SERVERSTORE_FILE', filename)
if self._filename:
with open(self._filename, 'r', encoding='utf-8') as fh:
dct = bson.json_util.loads(fh.read())
self._databases = {k: DatabaseStore.from_dict(v) for k, v in dct.items()}
self._finalizer = weakref.finalize(self, self._to_file)
else:
self._databases = {}

def __getitem__(self, db_name):
try:
Expand All @@ -25,12 +35,19 @@ def __contains__(self, db_name):
def list_created_database_names(self):
return [name for name, db in self._databases.items() if db.is_created]

def to_dict(self):
return {k: v.to_dict() for k, v in self._databases.items()}

def _to_file(self):
with open(self._filename, 'w', encoding='utf-8') as fh:
fh.write(bson.json_util.dumps(self.to_dict()))


class DatabaseStore(object):
"""Object holding the data for a database (many collections)."""

def __init__(self):
self._collections = {}
def __init__(self, _collections=None):
self._collections = _collections or {}

def __getitem__(self, col_name):
try:
Expand Down Expand Up @@ -59,12 +76,19 @@ def rename(self, name, new_name):
def is_created(self):
return any(col.is_created for col in self._collections.values())

def to_dict(self):
return {k: v.to_dict() for k, v in self._collections.items()}

@classmethod
def from_dict(cls, dct):
return cls({k: CollectionStore.from_dict(v) for k, v in dct.items()})


class CollectionStore(object):
"""Object holding the data for a collection."""

def __init__(self, name):
self._documents = collections.OrderedDict()
def __init__(self, name, documents=None):
self._documents = documents or collections.OrderedDict()
self.indexes = {}
self._is_force_created = False
self.name = name
Expand Down Expand Up @@ -172,6 +196,13 @@ def _value_meets_expiry(self, val, expiry, ttl_now):
except TypeError:
return False

def to_dict(self):
return {'name': self.name, 'documents': list(self._documents.items())}

@classmethod
def from_dict(cls, dct):
return cls(dct['name'], collections.OrderedDict(dct['documents']))


def _get_min_datetime_from_value(val):
if not val:
Expand Down
49 changes: 49 additions & 0 deletions tests/test__persistence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""test store persistence"""
import os
import unittest
from tempfile import NamedTemporaryFile
from mongomock import MongoClient
from mongomock.store import ServerStore


class ServerStorePersistenceTest(unittest.TestCase):
"""test server store persistence"""

ref_str_1 = '{"test_db": {"test_coll": {"name": "test_coll", "documents": [[{"$oid": "'
ref_str_2 = '"}, {"test": true, "_id": {"$oid": "'

def setUp(self):
with NamedTemporaryFile(mode='w', prefix='mongodb-', suffix='.json',
encoding='utf-8', delete=False) as fh:
fh.write('{}')
self.filename = fh.name
os.environ['MONGOMOCK_SERVERSTORE_FILE'] = self.filename

def tearDown(self):
os.unlink(self.filename)
del os.environ['MONGOMOCK_SERVERSTORE_FILE']

def test_kwargs_method(self):
"""test by using custom ServerStore with kwargs filename"""
store = ServerStore(filename=self.filename)
client = MongoClient(_store=store)
client.test_db.test_coll.insert_one({'test': True})
finalizer = getattr(store, '_finalizer')
assert finalizer.alive
finalizer()
with open(self.filename, 'r', encoding='utf-8') as fh:
contents = fh.read()
assert self.ref_str_1 in contents
assert self.ref_str_2 in contents

def test_environ_method(self):
"""test by using an environment variable"""
client = MongoClient()
client.test_db.test_coll.insert_one({'test': True})
finalizer = getattr(getattr(client, '_store'), '_finalizer')
assert finalizer.alive
finalizer()
with open(self.filename, 'r', encoding='utf-8') as fh:
contents = fh.read()
assert self.ref_str_1 == contents[:73]
assert self.ref_str_2 == contents[97:133]