Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Newer
Older
100644 157 lines (128 sloc) 4.616 kb
fd281cf @coleifer Adding a little helper to introspect postgresql databases and generate
authored
1 from optparse import OptionParser
2 from psycopg2 import OperationalError
3 import re
4 import sys
5
6 from peewee import *
7
# thanks, django
# Maps PostgreSQL type OIDs (the ``type_code`` reported in
# ``cursor.description``) to the peewee field class used for that column.
# Types not listed here are emitted as ``UnknownFieldType``.
reverse_mapping = {
    16: 'BooleanField',     # bool
    20: 'IntegerField',     # int8
    21: 'IntegerField',     # int2
    23: 'IntegerField',     # int4
    25: 'TextField',        # text
    700: 'FloatField',      # float4
    701: 'FloatField',      # float8
    1042: 'CharField',      # bpchar, i.e. fixed-width character(n)
    1043: 'CharField',      # varchar
    1114: 'DateTimeField',  # timestamp
    1184: 'DateTimeField',  # timestamptz
    1700: 'DecimalField',   # numeric
}
22
def get_conn(database, **connect):
    """Open a peewee connection to a PostgreSQL database.

    :param database: database name handed to ``PostgresqlDatabase``.
    :param connect: extra keyword arguments forwarded to ``PostgresqlDatabase``.
    :returns: the connected ``PostgresqlDatabase`` instance.
    :raises OperationalError: re-raised unchanged after reporting which
        database could not be reached.
    """
    conn = PostgresqlDatabase(database, **connect)
    try:
        conn.connect()
    except OperationalError:
        # Name the failing database, then keep the original traceback.
        err('error connecting to %s' % database)
        raise
    else:
        return conn
31
def get_columns(conn, table):
    """Return ``{column name: peewee field class name}`` for ``table``.

    Runs a one-row query only to read ``cursor.description``. Each column's
    ``type_code`` is translated with ``reverse_mapping``, and codes missing
    from that mapping become ``'UnknownFieldType'``.

    :param conn: connected database object exposing ``execute(sql)``.
    :param table: table name exactly as stored in the catalog.
    """
    # Quote the identifier: PostgreSQL folds unquoted names to lower case
    # (so "MyTable" would not be found), and reserved words such as ``user``
    # are parsed as keywords instead of table names. Embedded double quotes
    # are escaped by doubling them.
    quoted = '"%s"' % table.replace('"', '""')
    curs = conn.execute('select * from %s limit 1' % quoted)
    return dict(
        (c.name, reverse_mapping.get(c.type_code, 'UnknownFieldType'))
        for c in curs.description)
35
def get_foreign_keys(conn, table):
    """Return the foreign keys declared on ``table``.

    :param conn: connected database object exposing ``execute(sql, params)``
        whose result is iterable over rows.
    :param table: unqualified table name.
    :returns: list of ``(column, referenced table, referenced column)`` rows.
    """
    # PostgreSQL does not require foreign-key constraint names to be unique
    # within a schema, let alone across schemas. Joining on constraint_name
    # alone can therefore pair up unrelated constraints that share a name.
    # Match on schema as well, and on table for the key columns.
    # NOTE(review): ccu can still over-match if two tables in one schema own
    # identically named FKs; information_schema.referential_constraints
    # would be needed to disambiguate fully.
    sql = '''
        SELECT
            kcu.column_name, ccu.table_name, ccu.column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_schema = kcu.constraint_schema
            AND tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        JOIN information_schema.constraint_column_usage AS ccu
            ON ccu.constraint_schema = tc.constraint_schema
            AND ccu.constraint_name = tc.constraint_name
        WHERE
            tc.constraint_type = 'FOREIGN KEY' AND
            tc.table_name = %s
    '''
    return list(conn.execute(sql, (table,)))
53
# Header of the generated module. It is interpolated with
# (database name, repr(connect kwargs)). ``%r`` renders the database name as
# a correctly escaped Python string literal, even when it contains quotes.
frame = '''from peewee import *

database = PostgresqlDatabase(%r, **%s)

class UnknownFieldType(object):
    pass

class BaseModel(Model):
    class Meta:
        database = database
'''
65
def introspect(database, **connect):
    """Print peewee model source for every table in a PostgreSQL database.

    Connects with ``get_conn(database, **connect)`` and reads each table's
    columns and foreign keys. It then writes the ``frame`` header to standard
    output, followed by one ``BaseModel`` subclass per table. A table's
    foreign-key targets are printed before the table itself where the
    dependency graph allows; cycles cannot be fully ordered.

    :param database: name of the database to introspect.
    :param connect: extra keyword arguments for ``PostgresqlDatabase``
        (e.g. host, port, user, password). They are also echoed into the
        generated code.
    """
    conn = get_conn(database, **connect)
    tables = conn.get_tables()

    models = {}          # table -> {column: peewee field class name}
    table_to_model = {}  # table -> generated model class name
    table_fks = {}       # table -> [(column, referenced table, referenced column)]

    # first pass, just raw column names and peewee type; foreign keys are
    # queried once per table and reused below
    for table in tables:
        models[table] = get_columns(conn, table)
        table_to_model[table] = tn(table)
        table_fks[table] = get_foreign_keys(conn, table)

    # Ignore foreign keys whose target was not introspected (e.g. a table in
    # another schema). Looking those targets up below would raise KeyError;
    # the referencing column is emitted as a plain field instead.
    for table in tables:
        table_fks[table] = [fk for fk in table_fks[table] if fk[1] in models]

    # second pass, convert foreign keys, assign primary keys, and mark
    # explicit column names where they don't match the "pythonic" ones
    col_meta = {}
    for table in tables:
        col_meta[table] = {}
        for column, rel_table, rel_pk in table_fks[table]:
            models[table][column] = 'ForeignKeyField'
            models[rel_table][rel_pk] = 'PrimaryKeyField'
            # NOTE(review): a self-referencing FK renders as to=<own class>,
            # which is not yet defined inside its own class body in the
            # generated module -- TODO handle separately.
            col_meta[table][column] = {'to': table_to_model[rel_table]}

        for column in models[table]:
            col_meta[table].setdefault(column, {})
            if column != cn(column):
                col_meta[table][column]['db_column'] = "'%s'" % column

    # write generated code to standard out
    print(frame % (database, repr(connect)))

    def print_model(model, seen):
        """Print ``model``'s class, after first printing the models it references."""
        # Mark before recursing so that self-referencing and mutually
        # referencing tables are each emitted exactly once.
        seen.add(model)
        for _, rel_table, _ in table_fks[model]:
            if rel_table not in seen:
                print_model(rel_table, seen)

        print('class %s(BaseModel):' % table_to_model[model])
        for column, field_class in ds(models[model]):
            # an integer/primary-key ``id`` column is not emitted; presumably
            # peewee supplies that implicit primary key itself
            if column == 'id' and field_class in ('IntegerField', 'PrimaryKeyField'):
                continue

            # sorted so the generated code is deterministic across runs
            field_params = ', '.join(
                '%s=%s' % (k, v) for k, v in ds(col_meta[model][column]))
            print('    %s = %s(%s)' % (cn(column), field_class, field_params))
        print('')

        print('    class Meta:')
        print("        db_table = '%s'" % model)
        print('')

    # print the models
    seen = set()
    for model in sorted(models):
        if model not in seen:
            print_model(model, seen)
127
# misc helpers for turning database names into Python identifiers

def tn(t):
    """Table name -> model class name, e.g. ``'user_profile'`` -> ``'UserProfile'``."""
    return t.title().replace('_', '')


def cn(c):
    """Column name -> attribute name, lower-cased with a trailing ``_id`` removed.

    If the result is a Python keyword (e.g. a column named ``class``, or
    ``from_id`` -> ``from``), an underscore is appended so the generated code
    still compiles.
    """
    import keyword
    name = re.sub(r'_id$', '', c.lower())
    if keyword.iskeyword(name):
        name += '_'
    return name


def ds(d):
    """Return ``d``'s items as a list of ``(key, value)`` pairs sorted by key."""
    return sorted(d.items(), key=lambda item: item[0])
132
def err(msg):
    """Write ``msg`` to standard error, highlighted in red (ANSI escape codes).

    Standard output carries the generated model code, so diagnostics must not
    be interleaved with it.
    """
    sys.stderr.write('\033[91m%s\033[0m\n' % msg)
135
136
if __name__ == '__main__':
    parser = OptionParser(usage='usage: %prog [options] database_name')
    ao = parser.add_option
    ao('-H', '--host', dest='host', help='database server host')
    ao('-p', '--port', dest='port', type='int', help='database server port')
    ao('-u', '--user', dest='user', help='database user name')
    ao('-P', '--password', dest='password', help='database password')

    options, args = parser.parse_args()
    ops = ('host', 'port', 'user', 'password')
    # forward only the connection options that were given a truthy value
    connect = dict((o, getattr(options, o)) for o in ops if getattr(options, o))

    if len(args) < 1:
        # Diagnostics go to stderr: stdout carries the generated code.
        sys.stderr.write('error: missing required parameter "database"\n')
        parser.print_help(sys.stderr)
        sys.exit(1)

    # if several positional arguments are given, the last one is used
    database = args[-1]

    introspect(database, **connect)
Something went wrong with that request. Please try again.