Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

add sqlitedb backend

  • Loading branch information...
commit 681a891cdc2d17463a8f9241d7ab6213b2036396 1 parent cfbd77b
@minrk minrk authored
Showing with 272 additions and 0 deletions.
  1. +272 −0 IPython/zmq/parallel/sqlitedb.py
View
272 IPython/zmq/parallel/sqlitedb.py
@@ -0,0 +1,272 @@
+"""A TaskRecord backend using sqlite3"""
+#-----------------------------------------------------------------------------
+# Copyright (C) 2011 The IPython Development Team
+#
+# Distributed under the terms of the BSD License. The full license is in
+# the file COPYING, distributed as part of this software.
+#-----------------------------------------------------------------------------
+
+import json
+import os
+import cPickle as pickle
+from datetime import datetime
+
+import sqlite3
+
+from IPython.utils.traitlets import CUnicode, CStr, Instance, List
+from .dictdb import BaseDB
+from .util import ISO8601
+
+#-----------------------------------------------------------------------------
+# SQLite operators, adapters, and converters
+#-----------------------------------------------------------------------------
+
+operators = {
+ '$lt' : lambda a,b: "%s < ?",
+ '$gt' : ">",
+ # null is handled weird with ==,!=
+ '$eq' : "IS",
+ '$ne' : "IS NOT",
+ '$lte': "<=",
+ '$gte': ">=",
+ '$in' : ('IS', ' OR '),
+ '$nin': ('IS NOT', ' AND '),
+ # '$all': None,
+ # '$mod': None,
+ # '$exists' : None
+}
+
+def _adapt_datetime(dt):
+ return dt.strftime(ISO8601)
+
+def _convert_datetime(ds):
+ if ds is None:
+ return ds
+ else:
+ return datetime.strptime(ds, ISO8601)
+
+def _adapt_dict(d):
+ return json.dumps(d)
+
+def _convert_dict(ds):
+ if ds is None:
+ return ds
+ else:
+ return json.loads(ds)
+
+def _adapt_bufs(bufs):
+ # this is *horrible*
+ # copy buffers into single list and pickle it:
+ if bufs and isinstance(bufs[0], (bytes, buffer)):
+ return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
+ elif bufs:
+ return bufs
+ else:
+ return None
+
+def _convert_bufs(bs):
+ if bs is None:
+ return []
+ else:
+ return pickle.loads(bytes(bs))
+
+#-----------------------------------------------------------------------------
+# SQLiteDB class
+#-----------------------------------------------------------------------------
+
+class SQLiteDB(BaseDB):
+ """SQLite3 TaskRecord backend."""
+
+ filename = CUnicode('tasks.db', config=True)
+ location = CUnicode('', config=True)
+ table = CUnicode("", config=True)
+
+ _db = Instance('sqlite3.Connection')
+ _keys = List(['msg_id' ,
+ 'header' ,
+ 'content',
+ 'buffers',
+ 'submitted',
+ 'client_uuid' ,
+ 'engine_uuid' ,
+ 'started',
+ 'completed',
+ 'resubmitted',
+ 'result_header' ,
+ 'result_content' ,
+ 'result_buffers' ,
+ 'queue' ,
+ 'pyin' ,
+ 'pyout',
+ 'pyerr',
+ 'stdout',
+ 'stderr',
+ ])
+
+ def __init__(self, **kwargs):
+ super(SQLiteDB, self).__init__(**kwargs)
+ if not self.table:
+ # use session, and prefix _, since starting with # is illegal
+ self.table = '_'+self.session.replace('-','_')
+ if not self.location:
+ if hasattr(self.config.Global, 'cluster_dir'):
+ self.location = self.config.Global.cluster_dir
+ else:
+ self.location = '.'
+ self._init_db()
+
+ def _defaults(self):
+ """create an empty record"""
+ d = {}
+ for key in self._keys:
+ d[key] = None
+ return d
+
+ def _init_db(self):
+ """Connect to the database and get new session number."""
+ # register adapters
+ sqlite3.register_adapter(datetime, _adapt_datetime)
+ sqlite3.register_converter('datetime', _convert_datetime)
+ sqlite3.register_adapter(dict, _adapt_dict)
+ sqlite3.register_converter('dict', _convert_dict)
+ sqlite3.register_adapter(list, _adapt_bufs)
+ sqlite3.register_converter('bufs', _convert_bufs)
+ # connect to the db
+ dbfile = os.path.join(self.location, self.filename)
+ self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES)
+
+ self._db.execute("""CREATE TABLE IF NOT EXISTS %s
+ (msg_id text PRIMARY KEY,
+ header dict text,
+ content dict text,
+ buffers bufs blob,
+ submitted datetime text,
+ client_uuid text,
+ engine_uuid text,
+ started datetime text,
+ completed datetime text,
+ resubmitted datetime text,
+ result_header dict text,
+ result_content dict text,
+ result_buffers bufs blob,
+ queue text,
+ pyin text,
+ pyout text,
+ pyerr text,
+ stdout text,
+ stderr text)
+ """%self.table)
+ # self._db.execute("""CREATE TABLE IF NOT EXISTS %s_buffers
+ # (msg_id text, result integer, buffer blob)
+ # """%self.table)
+ self._db.commit()
+
+ def _dict_to_list(self, d):
+ """turn a mongodb-style record dict into a list."""
+
+ return [ d[key] for key in self._keys ]
+
+ def _list_to_dict(self, line):
+ """Inverse of dict_to_list"""
+ d = self._defaults()
+ for key,value in zip(self._keys, line):
+ d[key] = value
+
+ return d
+
+ def _render_expression(self, check):
+ """Turn a mongodb-style search dict into an SQL query."""
+ expressions = []
+ args = []
+
+ skeys = set(check.keys())
+ skeys.difference_update(set(self._keys))
+ skeys.difference_update(set(['buffers', 'result_buffers']))
+ if skeys:
+ raise KeyError("Illegal testing key(s): %s"%skeys)
+
+ for name,sub_check in check.iteritems():
+ if isinstance(sub_check, dict):
+ for test,value in sub_check.iteritems():
+ try:
+ op = operators[test]
+ except KeyError:
+ raise KeyError("Unsupported operator: %r"%test)
+ if isinstance(op, tuple):
+ op, join = op
+ expr = "%s %s ?"%(name, op)
+ if isinstance(value, (tuple,list)):
+ expr = '( %s )'%( join.join([expr]*len(value)) )
+ args.extend(value)
+ else:
+ args.append(value)
+ expressions.append(expr)
+ else:
+ # it's an equality check
+ expressions.append("%s IS ?"%name)
+ args.append(sub_check)
+
+ expr = " AND ".join(expressions)
+ return expr, args
+
+ def add_record(self, msg_id, rec):
+ """Add a new Task Record, by msg_id."""
+ d = self._defaults()
+ d.update(rec)
+ d['msg_id'] = msg_id
+ line = self._dict_to_list(d)
+ tups = '(%s)'%(','.join(['?']*len(line)))
+ self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
+ self._db.commit()
+
+ def get_record(self, msg_id):
+ """Get a specific Task Record, by msg_id."""
+ cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
+ line = cursor.fetchone()
+ if line is None:
+ raise KeyError("No such msg: %r"%msg_id)
+ return self._list_to_dict(line)
+
+ def update_record(self, msg_id, rec):
+ """Update the data in an existing record."""
+ query = "UPDATE %s SET "%self.table
+ sets = []
+ keys = sorted(rec.keys())
+ values = []
+ for key in keys:
+ sets.append('%s = ?'%key)
+ values.append(rec[key])
+ query += ', '.join(sets)
+ query += ' WHERE msg_id == %r'%msg_id
+ self._db.execute(query, values)
+ self._db.commit()
+
+ def drop_record(self, msg_id):
+ """Remove a record from the DB."""
+ self._db.execute("""DELETE FROM %s WHERE mgs_id==?"""%self.table, (msg_id,))
+ self._db.commit()
+
+ def drop_matching_records(self, check):
+ """Remove a record from the DB."""
+ expr,args = self._render_expression(check)
+ query = "DELETE FROM %s WHERE %s"%(self.table, expr)
+ self._db.execute(query,args)
+ self._db.commit()
+
+ def find_records(self, check, id_only=False):
+ """Find records matching a query dict."""
+ req = 'msg_id' if id_only else '*'
+ expr,args = self._render_expression(check)
+ query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
+ cursor = self._db.execute(query, args)
+ matches = cursor.fetchall()
+ if id_only:
+ return [ m[0] for m in matches ]
+ else:
+ records = {}
+ for line in matches:
+ rec = self._list_to_dict(line)
+ records[rec['msg_id']] = rec
+ return records
+
+__all__ = ['SQLiteDB']
Please sign in to comment.
Something went wrong with that request. Please try again.