Permalink
Browse files

full sqlite

git-svn-id: svn://forre.st/undns@1223 470744a7-cac9-478e-843e-5ec1b25c69e8
  • Loading branch information...
forrest
forrest committed Apr 14, 2011
1 parent 2d4d183 commit c68ce7cb8eb1389cb20ceb080efd293dc85abf6b
Showing with 359 additions and 182 deletions.
  1. +60 −13 db.py
  2. +12 −9 main.py
  3. +78 −0 manage.py
  4. +26 −41 packet.py
  5. +120 −113 server.py
  6. +63 −6 util.py
View
73 db.py
@@ -1,7 +1,48 @@
-import bsddb
import random
-import atexit
-import weakref
+import sqlite3
+
+class SQLiteDict(object):
+ def __init__(self, db, table):
+ self._db = db
+ self._table = table
+
+ self._db.execute('CREATE TABLE IF NOT EXISTS %s(key BLOB PRIMARY KEY NOT NULL, value BLOB NOT NULL)' % (self._table,))
+
+ def __iter__(self):
+ for row in self._db.execute("SELECT key FROM %s" % (self._table,)):
+ yield str(row[0])
+ def iterkeys(self):
+ return iter(self)
+ def keys(self):
+ return list(self)
+
+ def itervalues(self):
+ for row in self._db.execute("SELECT value FROM %s" % (self._table,)):
+ yield str(row[0])
+ def values(self):
+ return list(self.itervalues)
+
+ def iteritems(self):
+ for row in self._db.execute("SELECT key, value FROM %s" % (self._table,)):
+ yield (str(row[0]), str(row[1]))
+ def items(self):
+ return list(self.iteritems())
+
+ def __setitem__(self, key, value):
+ if self._db.execute("SELECT key FROM %s where key=?" % (self._table,), (buffer(key),)).fetchone() is None:
+ self._db.execute('INSERT INTO %s(key, value) VALUES (?, ?)' % (self._table,), (buffer(key), buffer(value)))
+ else:
+ self._db.execute('UPDATE %s SET value=? WHERE key=?' % (self._table,), (buffer(value), buffer(key)))
+
+ def __getitem__(self, key):
+ row = self._db.execute('SELECT value FROM %s WHERE key=?' % (self._table,), (buffer(key),)).fetchone()
+ if row is None:
+ raise KeyError(key)
+ else:
+ return str(row[0])
+
+ def __delitem__(self, key):
+ self._db.execute("DELETE FROM %s WHERE key=?" % (self._table,), (buffer(key),))
class CachingDictWrapper(object):
def __init__(self, inner, cache_size=10000):
@@ -21,7 +62,7 @@ def __getitem__(self, key):
try:
return self._cache[key]
except KeyError:
- print "cache failed for", key
+ print "cache failed for", repr(key)
value = self._inner[key]
self._add_to_cache(key, value)
return value
@@ -108,6 +149,8 @@ def __getitem__(self, key):
return self._decode(key, self._inner[key])
def __setitem__(self, key, value):
self._inner[key] = self._encode(key, value)
+ def __delitem__(self, key):
+ del self._inner[key]
def __contains__(self, key):
return key in self._inner
def __iter__(self):
@@ -118,13 +161,17 @@ def iteritems(self):
for k, v in self._inner.iteritems():
yield k, self._decode(k, v)
-def try_sync(db_weakref):
- db = db_weakref()
- if db is None:
- return
- db.sync()
-def safe_open_db(filename):
- db = bsddb.hashopen(filename)
- atexit.register(try_sync, weakref.ref(db))
- return db
+class JSONValueWrapper(ValueDictWrapper):
+ def _encode(self, addr, content):
+ return json.dumps(content)
+ def _decode(self, addr, data):
+ return json.loads(data)
+
+import cPickle
+
+class PickleValueWrapper(ValueDictWrapper):
+ def _encode(self, addr, content):
+ return cPickle.dumps(content) #, cPickle.HIGHEST_PROTOCOL)
+ def _decode(self, addr, data):
+ return cPickle.loads(data)
View
21 main.py
@@ -3,6 +3,7 @@
import argparse
import random
import sys
+import sqlite3
import twisted.names.common, twisted.names.client, twisted.names.dns, twisted.names.server, twisted.names.error, twisted.names.authority
del twisted
@@ -28,8 +29,8 @@
help="run a TCP+UDP recursive dns server on PORT; you likely do want this - this is for clients",
type=int, action="append", default=[], dest="recursive_dns_ports")
parser.add_argument("-d", "--dht-port", metavar="PORT",
- help="use UDP port PORT to connect to other DHT nodes and listen for connections (if not specified a random high port is chosen)",
- type=int, action="store", default=random.randrange(49152, 65536), dest="dht_port")
+ help="use UDP port PORT to connect to other DHT nodes and listen for connections (default: last used or random if never used)",
+ type=int, action="store", default=None, dest="dht_port")
parser.add_argument("-n", "--node", metavar="ADDR:PORT",
help="connect to existing DHT node at ADDR listening on UDP port PORT",
action="append", default=[], dest="dht_nodes")
@@ -46,22 +47,17 @@
rng = Random.new().read
-print name, "on port", args.dht_port
-
def parse(x):
if ':' not in x:
return ('127.0.0.1', int(x))
ip, port = x.split(':')
return ip, int(port)
knownNodes = map(parse, args.dht_nodes)
-n = server.UnDNSNode(udpPort=args.dht_port, db_prefix=args.config, rng=rng)
+n = server.UnDNSNode(udpPort=args.dht_port, config_db=sqlite3.connect(args.config, isolation_level=None), rng=rng)
n.joinNetwork(knownNodes)
-rpc_factory = protocol.ServerFactory()
-rpc_factory.protocol = server.RPCProtocol
-for port in args.rpc_ports:
- reactor.listenTCP(port, rpc_factory)
+print name, "on port", n.port
resolver = server.UnDNSResolver(n)
@@ -75,4 +71,11 @@ def parse(x):
reactor.listenTCP(port, recursive_dns)
reactor.listenUDP(port, names.dns.DNSDatagramProtocol(recursive_dns))
+rpc_factory = protocol.ServerFactory()
+rpc_factory.protocol = server.RPCProtocol
+rpc_factory.node = n
+rpc_factory.rng = rng
+for port in args.rpc_ports:
+ reactor.listenTCP(port, rpc_factory, interface="127.0.0.1")
+
reactor.run()
View
@@ -0,0 +1,78 @@
+import subprocess
+import os
+import argparse
+import random
+import sys
+import json
+
+from twisted.internet import reactor, protocol, endpoints, defer
+from twisted.protocols import basic
+
+try:
+ __version__ = subprocess.Popen(["svnversion", os.path.dirname(sys.argv[0])], stdout=subprocess.PIPE).stdout.read().strip()
+except IOError:
+ __version__ = "unknown"
+
+name = "UnDNS manager (version %s)" % (__version__,)
+
+parser = argparse.ArgumentParser(description=name)
+parser.add_argument('--version', action='version', version=__version__)
+parser.add_argument("-c", "--connect", metavar="ADDR:PORT",
+ help="connect to server running on ADDR:PORT (default: 127.0.0.1:4000)",
+ type=str, action="store", default="4000", dest="connect")
+
+#config_default = os.path.join(os.path.expanduser('~'), '.undns')
+#parser.add_argument("-c", "--config", metavar="PATH",
+# help="use configuration database at PATH (default: %s)" % (config_default,),
+# action="store", default=config_default, dest="config")
+
+parser.add_argument(metavar="COMMAND",
+ help="command to run",
+ type=str, action="store", dest="command")
+parser.add_argument(metavar="ARGUMENT",
+ help="argument to command",
+ type=str, action="store", nargs="*", dest="arguments")
+
+args = parser.parse_args()
+
+class RPCClient(basic.LineOnlyReceiver):
+ def __init__(self):
+ self.deferreds = []
+ self.queue = []
+
+ def send_queries(self):
+ while self.queue:
+ self.sendLine(self.queue.pop())
+
+ def connectionMade(self):
+ self.send_queries()
+
+ def query(self, command, arguments):
+ self.queue.append(json.dumps([command, arguments]))
+ self.send_queries()
+
+ df = defer.Deferred()
+ self.deferreds.append(df)
+ return df
+
+ def lineReceived(self, line):
+ self.deferreds.pop(0).callback(json.loads(line))
+
+def parse(x):
+ if ':' not in x:
+ return ('127.0.0.1', int(x))
+ ip, port = x.split(':')
+ return ip, int(port)
+
+host, port = parse(args.connect)
+
+def got_response(resp):
+ print repr(resp)
+
+arguments = [open(a[1:]).read() if a.startswith('@') else a for a in args.arguments]
+
+f = protocol.ClientFactory()
+f.protocol = RPCClient
+endpoints.TCP4ClientEndpoint(reactor, host, port).connect(f).addCallback(lambda rpcclient: rpcclient.query(args.command, arguments)).addCallback(got_response).addBoth(lambda _: reactor.stop())
+
+reactor.run()
View
@@ -3,13 +3,13 @@
import zlib
from Crypto.PublicKey import RSA
-from twisted.names import authority
+from twisted.names import authority, common
import util
class BindStringAuthority(authority.BindAuthority):
def __init__(self, contents, origin):
- names.common.ResolverBase.__init__(self)
+ common.ResolverBase.__init__(self)
self.origin = origin
lines = contents.splitlines(True)
lines = self.stripComments(lines)
@@ -22,7 +22,7 @@ class DomainKey(object):
@classmethod
def generate(cls, rng):
- return cls(RSA.generate(1024, rng))
+ return cls(RSA.generate(4096, rng))
@classmethod
def from_binary(cls, x):
@@ -37,8 +37,11 @@ def __init__(self, private_key):
def to_binary(self):
return zlib.compress(json.dumps(util.key_to_tuple(self._private_key)))
- def get_address(self):
- return util.key_to_address(self._private_key.publickey())
+ def get_name(self):
+ return util.key_to_name(self._private_key.publickey())
+
+ def get_name_hash(self):
+ return util.key_to_name_hash(self._private_key.publickey())
def encode(self, record, rng):
return DomainPacket(self._private_key.publickey(), record, self._private_key.sign(record.get_hash(), rng))
@@ -55,27 +58,23 @@ def __init__(self, public_key, record, signature):
if public_key.has_private():
raise ValueError("key not public")
assert isinstance(record, DomainRecord)
- assert isinstance(signature, tuple)
+ signature = tuple(signature)
self._public_key = public_key
self._record = record
self._signature = signature
-
- self._address = util.key_to_address(self._public_key)
- self._address_hash = util.hash_address_hash(self._address).digest()
- self._zone = None
def to_binary(self):
return zlib.compress(json.dumps(dict(public_key=util.key_to_tuple(self._public_key), record=self._record.to_obj(), signature=self._signature)))
def verify_signature(self):
- return public_key.verify(self._record.get_hash(), signature)
+ return self._public_key.verify(self._record.get_hash(), self._signature)
- def get_address(self):
- return self._address
+ def get_name(self):
+ return util.key_to_name(self._public_key)
- def get_address_hash(self):
- return self._address_hash
+ def get_name_hash(self):
+ return util.key_to_name_hash(self._public_key)
def get_record(self):
return self._record
@@ -84,23 +83,21 @@ class DomainRecord(object):
"Information about a domain"
@classmethod
- def from_obj(cls, (zone_file, start_time, end_time)):
- return cls(zone_file, start_time, end_time)
+ def from_obj(cls, (zone_file, end_time)):
+ return cls(zone_file, end_time)
- def __init__(self, zone_file, start_time, end_time):
+ def __init__(self, zone_file, end_time):
assert isinstance(zone_file, unicode)
- assert isinstance(start_time, (int, long))
assert isinstance(end_time, (int, long))
self._zone_file = zone_file
- self._start_file = start_time
self._end_time = end_time
def to_obj(self):
- return (self._zone_file, self._start_time, self._end_time)
+ return (self._zone_file, self._end_time)
def to_binary(self):
- return json.dumps(dict(zone_file=self._zone_file, start_time=self._start_time, end_time=self._end_time))
+ return json.dumps(dict(zone_file=self._zone_file, end_time=self._end_time))
def get_zone_file(self):
return self._zone_file
@@ -109,33 +106,21 @@ def get_zone(self, address):
assert not address.endswith('.')
return BindStringAuthority(self._zone_file.encode('utf8'), address + '.')
- def get_start_time(self):
- return self._start_time
-
def get_end_time(self):
return self._end_time
def get_hash(self):
- return util.hash_sign(self.to_binary()).digest()
+ b = self.to_binary()
+ return util.ripemd160(b).digest() + hashlib.sha512(b).digest()
if __name__ == '__main__':
from Crypto import Random
rng = Random.new().read
- h = "hello, world!"
-
- a = MyIdentity.generate(rng)
+ a = DomainKey.generate(rng)
- print repr(a.to_binary())
- print repr(a.get_id())
- print repr(a.to_binary_public())
+ print a.get_name()
+ print util.name_hash_to_name(a.get_name_hash())
- print
- d = a.sign(h, rng)
- print repr(d)
-
- b = TheirIdentity.from_binary(a.to_binary_public())
-
- print repr(b.get_id())
- print b.verify(h, d)
-
+ print util.name_to_name_hash(a.get_name()).encode('hex')
+ print a.get_name_hash().encode('hex')
Oops, something went wrong.

0 comments on commit c68ce7c

Please sign in to comment.