Skip to content

Commit

Permalink
Merge pull request #766 from dmitriy-serdyuk/fix-log
Browse files Browse the repository at this point in the history
Fix log unpickling
  • Loading branch information
dwf committed Jul 22, 2015
2 parents 1e0aca9 + 8c5e083 commit 49a12a3
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 20 deletions.
11 changes: 5 additions & 6 deletions blocks/log/log.py
@@ -1,5 +1,4 @@
"""The event-based main loop of Blocks."""
import sqlite3
from abc import ABCMeta
from collections import defaultdict
from numbers import Integral
Expand Down Expand Up @@ -65,13 +64,13 @@ def __init__(self, uuid=None):
})

@property
def b_uuid(self):
"""Return a buffered version of the UUID bytes.
def h_uuid(self):
"""Return a hexadecimal version of the UUID bytes.
This is necessary to store bytes in an SQLite database.
This is necessary to store ids in an SQLite database.
"""
return sqlite3.Binary(self.uuid.bytes)
return self.uuid.hex

def resume(self):
"""Resume a log by setting a new random UUID.
Expand All @@ -80,7 +79,7 @@ def resume(self):
copies the status of the old log into the new log.
"""
old_uuid = self.b_uuid
old_uuid = self.h_uuid
old_status = dict(self.status)
self.uuid = uuid4()
self.status.update(old_status)
Expand Down
28 changes: 14 additions & 14 deletions blocks/log/sqlite.py
Expand Up @@ -121,14 +121,14 @@ def __init__(self, database=None, **kwargs):
sqlite3.register_adapter(numpy.ndarray, adapt_ndarray)
with self.conn:
self.conn.execute("""CREATE TABLE IF NOT EXISTS entries (
uuid BLOB NOT NULL,
uuid TEXT NOT NULL,
time INT NOT NULL,
"key" TEXT NOT NULL,
value,
PRIMARY KEY(uuid, time, "key")
);""")
self.conn.execute("""CREATE TABLE IF NOT EXISTS status (
uuid BLOB NOT NULL,
uuid TEXT NOT NULL,
"key" text NOT NULL,
value,
PRIMARY KEY(uuid, "key")
Expand Down Expand Up @@ -167,13 +167,13 @@ def __getitem__(self, time):
def __iter__(self):
return map(itemgetter(0), self.conn.execute(
ANCESTORS_QUERY + "SELECT DISTINCT time FROM entries "
"WHERE uuid IN ancestors ORDER BY time ASC", (self.b_uuid,)
"WHERE uuid IN ancestors ORDER BY time ASC", (self.h_uuid,)
))

def __len__(self):
return self.conn.execute(
ANCESTORS_QUERY + "SELECT COUNT(DISTINCT time) FROM entries "
"WHERE uuid IN ancestors ORDER BY time ASC", (self.b_uuid,)
"WHERE uuid IN ancestors ORDER BY time ASC", (self.h_uuid,)
).fetchone()[0]


Expand All @@ -184,7 +184,7 @@ def __init__(self, log):
def __getitem__(self, key):
row = self.log.conn.execute(
"SELECT value FROM status WHERE uuid = ? AND key = ?",
(self.log.b_uuid, key)
(self.log.h_uuid, key)
).fetchone()
return _get_row(row, key)

Expand All @@ -193,25 +193,25 @@ def __setitem__(self, key, value):
with self.log.conn:
self.log.conn.execute(
"INSERT OR REPLACE INTO status VALUES (?, ?, ?)",
(self.log.b_uuid, key, value)
(self.log.h_uuid, key, value)
)

def __delitem__(self, key):
with self.log.conn:
self.log.conn.execute(
"DELETE FROM status WHERE uuid = ? AND key = ?",
(self.log.b_uuid, key)
(self.log.h_uuid, key)
)

def __len__(self):
return self.log.conn.execute(
"SELECT COUNT(*) FROM status WHERE uuid = ?",
(self.log.b_uuid,)
(self.log.h_uuid,)
).fetchone()[0]

def __iter__(self):
return map(itemgetter(0), self.log.conn.execute(
"SELECT key FROM status WHERE uuid = ?", (self.log.b_uuid,)
"SELECT key FROM status WHERE uuid = ?", (self.log.h_uuid,)
))


Expand All @@ -237,7 +237,7 @@ def __getitem__(self, key):
# JOIN statement should sort things so that the latest is returned
"JOIN ancestors ON entries.uuid = ancestors.parent "
"WHERE uuid IN ancestors AND time = ? AND key = ?",
(self.log.b_uuid, self.time, key)
(self.log.h_uuid, self.time, key)
).fetchone()
return _get_row(row, key)

Expand All @@ -246,26 +246,26 @@ def __setitem__(self, key, value):
with self.log.conn:
self.log.conn.execute(
"INSERT OR REPLACE INTO entries VALUES (?, ?, ?, ?)",
(self.log.b_uuid, self.time, key, value)
(self.log.h_uuid, self.time, key, value)
)

def __delitem__(self, key):
with self.log.conn:
self.log.conn.execute(
"DELETE FROM entries WHERE uuid = ? AND time = ? AND key = ?",
(self.log.b_uuid, self.time, key)
(self.log.h_uuid, self.time, key)
)

def __len__(self):
return self.log.conn.execute(
ANCESTORS_QUERY + "SELECT COUNT(*) FROM entries "
"WHERE uuid IN ancestors AND time = ?",
(self.log.b_uuid, self.time,)
(self.log.h_uuid, self.time,)
).fetchone()[0]

def __iter__(self):
return map(itemgetter(0), self.log.conn.execute(
ANCESTORS_QUERY + "SELECT key FROM entries "
"WHERE uuid IN ancestors AND time = ?",
(self.log.b_uuid, self.time,)
(self.log.h_uuid, self.time,)
))
12 changes: 12 additions & 0 deletions tests/test_log.py
Expand Up @@ -3,6 +3,7 @@
from numpy.testing import assert_raises

from blocks.log import TrainingLog
from blocks.serialization import load, dump


def test_training_log():
Expand All @@ -21,3 +22,14 @@ def test_training_log():

# test iteration
assert len(list(log)) == 2


def test_pickle_log():
log1 = TrainingLog()
dump(log1, "log1.pkl")
log2 = load("log1.pkl")
dump(log2, "log2.pkl")
load("log2.pkl") # loading an unresumed log works
log2.resume()
dump(log2, "log3.pkl")
load("log3.pkl") # loading a resumed log does not work

0 comments on commit 49a12a3

Please sign in to comment.