/
basedb.py
164 lines (133 loc) · 5.76 KB
/
basedb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# vim: set et sw=4 ts=4 sts=4 ff=unix fenc=utf8:
# Author: Binux<i@binux.com>
# http://binux.me
# Created on 2012-08-30 17:43:49
from __future__ import unicode_literals, division, absolute_import
import logging
logger = logging.getLogger('database.basedb')
from six import itervalues
from pyspider.libs import utils
class BaseDB:
    '''
    BaseDB

    Helpers that build simple SQL statements and run them through
    ``self.dbcur``.

    ``dbcur`` should be overridden by subclasses to return a DB-API cursor.
    Subclasses also set ``__tablename__`` (default table used when a method
    gets no ``tablename``), ``placeholder`` (the driver's parameter marker,
    e.g. ``'%s'`` or ``'?'``) and, if needed, ``maxlimit``.
    '''
    __tablename__ = None
    placeholder = '%s'
    # Row count written into "LIMIT offset, count" when only an offset is
    # requested. -1 means "no upper bound" in SQLite. NOTE(review): MySQL
    # rejects a negative LIMIT, so MySQL subclasses presumably override this.
    maxlimit = -1

    @staticmethod
    def escape(string):
        '''
        Quote an identifier (table or column name) with backticks.

        Embedded backticks are doubled, so the name cannot close the quoted
        identifier early and inject SQL.
        '''
        name = '%s' % (string,)
        return '`%s`' % name.replace('`', '``')

    @property
    def dbcur(self):
        '''Return a DB-API cursor; must be provided by subclasses.'''
        raise NotImplementedError

    def _execute(self, sql_query, values=()):
        '''
        Execute ``sql_query`` on the cursor from ``self.dbcur`` and return it.

        :param sql_query: SQL text using ``self.placeholder`` markers.
        :param values: parameters bound to those markers (immutable default).
        '''
        dbcur = self.dbcur
        dbcur.execute(sql_query, values)
        return dbcur

    def _build_select_sql(self, tablename=None, what="*", where="", order=None,
                          offset=0, limit=None):
        '''
        Build (but do not run) a SELECT statement.

        ``what`` may be a raw column expression (str, used verbatim) or a
        list/tuple of column names (each escaped); an empty list/tuple or
        ``None`` selects ``*``. ``where`` and ``order`` are inserted verbatim,
        so they must not contain untrusted input -- bind values through
        placeholders instead.
        '''
        tablename = self.escape(tablename or self.__tablename__)
        if what is None or isinstance(what, (list, tuple)):
            what = ','.join(self.escape(f) for f in what) if what else '*'
        sql_query = "SELECT %s FROM %s" % (what, tablename)
        if where:
            sql_query += " WHERE %s" % where
        if order:
            sql_query += " ORDER BY %s" % order
        if limit:
            sql_query += " LIMIT %d, %d" % (offset, limit)
        elif offset:
            # An offset alone still needs a LIMIT clause; maxlimit fills it.
            sql_query += " LIMIT %d, %d" % (offset, self.maxlimit)
        return sql_query

    def _select(self, tablename=None, what="*", where="", where_values=(),
                offset=0, limit=None, order=None):
        '''
        Yield rows of a SELECT exactly as the cursor returns them.

        ``where_values`` are bound to placeholders in ``where``. ``order`` is
        appended verbatim as an ORDER BY clause, as in ``_select2dic``.
        '''
        sql_query = self._build_select_sql(tablename, what, where, order,
                                           offset, limit)
        logger.debug("<sql: %s>", sql_query)
        for row in self._execute(sql_query, where_values):
            yield row

    def _select2dic(self, tablename=None, what="*", where="", where_values=(),
                    order=None, offset=0, limit=None):
        '''
        Yield rows of a SELECT as dicts keyed by column name.

        Column names are read from ``cursor.description``.
        '''
        sql_query = self._build_select_sql(tablename, what, where, order,
                                           offset, limit)
        logger.debug("<sql: %s>", sql_query)
        dbcur = self._execute(sql_query, where_values)
        # f[0] may return bytes type
        # https://github.com/mysql/mysql-connector-python/pull/37
        fields = [utils.text(f[0]) for f in dbcur.description]
        for row in dbcur:
            yield dict(zip(fields, row))

    def _insert_or_replace(self, verb, tablename, values):
        '''
        Run ``<verb> INTO`` for a single row and return ``cursor.lastrowid``.

        ``verb`` is "INSERT" or "REPLACE". With no ``values`` the row is
        written with ``DEFAULT VALUES``.
        '''
        tablename = self.escape(tablename or self.__tablename__)
        if not values:
            sql_query = "%s INTO %s DEFAULT VALUES" % (verb, tablename)
            logger.debug("<sql: %s>", sql_query)
            return self._execute(sql_query).lastrowid
        # Take column names and bound values from one ordered snapshot so the
        # two lists always line up.
        keys = list(values)
        sql_query = "%s INTO %s (%s) VALUES (%s)" % (
            verb,
            tablename,
            ", ".join(self.escape(k) for k in keys),
            ", ".join([self.placeholder] * len(keys)),
        )
        logger.debug("<sql: %s>", sql_query)
        return self._execute(sql_query, [values[k] for k in keys]).lastrowid

    def _replace(self, tablename=None, **values):
        '''REPLACE one row built from ``values``; return ``cursor.lastrowid``.'''
        return self._insert_or_replace("REPLACE", tablename, values)

    def _insert(self, tablename=None, **values):
        '''INSERT one row built from ``values``; return ``cursor.lastrowid``.'''
        return self._insert_or_replace("INSERT", tablename, values)

    def _update(self, tablename=None, where="1=0", where_values=(), **values):
        '''
        UPDATE the rows matching ``where`` with the given column ``values``.

        The default ``where`` of "1=0" matches no rows. The WHERE clause is
        always emitted, so an empty ``where`` yields invalid SQL rather than
        a whole-table update. At least one column value is required, because
        the SET clause would otherwise be empty. Returns the cursor.
        '''
        tablename = self.escape(tablename or self.__tablename__)
        keys = list(values)
        _key_values = ", ".join(
            "%s = %s" % (self.escape(k), self.placeholder) for k in keys
        )
        sql_query = "UPDATE %s SET %s WHERE %s" % (tablename, _key_values, where)
        logger.debug("<sql: %s>", sql_query)
        # Bound SET values come first, then the WHERE parameters.
        params = [values[k] for k in keys] + list(where_values)
        return self._execute(sql_query, params)

    def _delete(self, tablename=None, where="1=0", where_values=()):
        '''
        DELETE the rows matching ``where`` and return the cursor.

        The default "1=0" matches nothing. A falsy ``where`` omits the WHERE
        clause entirely, which deletes every row of the table.
        '''
        tablename = self.escape(tablename or self.__tablename__)
        sql_query = "DELETE FROM %s" % tablename
        if where:
            sql_query += " WHERE %s" % where
        logger.debug("<sql: %s>", sql_query)
        return self._execute(sql_query, where_values)
if __name__ == "__main__":
    # Smoke test: exercise the BaseDB helpers against an in-memory SQLite DB.
    import sqlite3

    class DB(BaseDB):
        __tablename__ = "test"
        placeholder = "?"  # sqlite3 uses the qmark parameter style

        def __init__(self):
            self.conn = sqlite3.connect(":memory:")
            cursor = self.conn.cursor()
            cursor.execute(
                '''CREATE TABLE `%s` (id INTEGER PRIMARY KEY AUTOINCREMENT, name, age)'''
                % self.__tablename__
            )

        @property
        def dbcur(self):
            return self.conn.cursor()

    db = DB()
    assert db._insert(db.__tablename__, name="binux", age=23) == 1
    # Use the next() builtin: generator.next() only exists on Python 2.
    assert next(db._select(db.__tablename__, "name, age")) == ("binux", 23)
    assert next(db._select2dic(db.__tablename__, "name, age"))["name"] == "binux"
    assert next(db._select2dic(db.__tablename__, "name, age"))["age"] == 23
    # REPLACE rewrites the whole row, so the unspecified `name` becomes NULL.
    db._replace(db.__tablename__, id=1, age=24)
    assert next(db._select(db.__tablename__, "name, age")) == (None, 24)
    db._update(db.__tablename__, "id = 1", age=16)
    assert next(db._select(db.__tablename__, "name, age")) == (None, 16)
    db._delete(db.__tablename__, "id = 1")
    assert list(db._select(db.__tablename__)) == []