-
Notifications
You must be signed in to change notification settings - Fork 6
/
db_setup.py
189 lines (160 loc) · 6.44 KB
/
db_setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import glob
import os
import networkx as nx
import pandas as pd
import sqlalchemy
import sqlalchemy.ext.asyncio
from sqlalchemy import MetaData
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.ext.declarative import declarative_base
from aidb.config.config_types import python_type_to_sqlalchemy_type
from aidb.utils.constants import BLOB_TABLE_NAMES_TABLE
from sqlalchemy.sql import text
from dataclasses import dataclass
from typing import Optional
@dataclass
class ColumnInfo:
name: str
is_primary_key: bool
refers_to: Optional[tuple] # (table, column)
dtype = None
def extract_column_info(table_name, column_str) -> ColumnInfo:
    """Parse a CSV header of the form ``[pk_]<table>.<column>``.

    A leading ``pk_`` marks a primary key. When the encoded table differs
    from ``table_name``, the column is a foreign key and ``refers_to`` is
    set to ``(referenced_table, referenced_column)``; otherwise it is None.
    """
    is_pk = column_str.startswith("pk_")
    if is_pk:
        column_str = column_str[len("pk_"):]  # strip the pk_ prefix
    ref_table, col_name = column_str.split('.')
    foreign_ref = (ref_table, col_name) if ref_table != table_name else None
    return ColumnInfo(col_name, is_pk, foreign_ref)
async def create_db(db_url: str, db_name: str):
    """Create database ``db_name`` on the server addressed by ``db_url``.

    Only PostgreSQL needs an explicit CREATE DATABASE; SQLite materializes
    the database file lazily on first connection, so nothing is done for it.
    Any other dialect raises NotImplementedError.
    """
    dialect = db_url.split("+")[0]
    if dialect == "postgresql":
        # CREATE DATABASE cannot run inside a transaction, hence AUTOCOMMIT.
        engine = sqlalchemy.ext.asyncio.create_async_engine(db_url, isolation_level='AUTOCOMMIT')
        try:
            async with engine.begin() as conn:
                # NOTE(review): db_name is interpolated into the SQL string
                # because CREATE DATABASE does not accept bound parameters;
                # callers must never pass an untrusted name here.
                await conn.execute(text(f"CREATE DATABASE {db_name}"))
        except sqlalchemy.exc.ProgrammingError:
            # Raised when the database already exists -- treated as success.
            print("Database Already exists")
        finally:
            # Bug fix: the original never disposed the engine, leaking its
            # connection pool on every call.
            await engine.dispose()
    elif dialect == "sqlite":
        # sqlite auto creates, do nothing
        pass
    else:
        raise NotImplementedError
    return
async def drop_all_tables(conn):
    """Reflect every table visible on ``conn`` and drop them all.

    ``conn`` is an AsyncConnection; reflection and DDL are executed through
    ``run_sync`` so the synchronous MetaData API receives a sync connection.
    """
    # Bug fix: the original passed ``bind=conn`` (an AsyncConnection) to
    # MetaData. The bind was never used -- run_sync supplies the sync
    # connection to reflect/drop_all explicitly -- and MetaData's bind
    # argument is removed entirely in SQLAlchemy 2.0.
    metadata = MetaData()
    # Reflect the database to get all table names
    await conn.run_sync(metadata.reflect)
    await conn.run_sync(metadata.drop_all)
    return
async def setup_db(db_url: str, db_name: str, data_dir: str):
    """Create one table per ground-truth CSV in ``data_dir`` and return the engine.

    Each CSV header encodes column metadata (see ``extract_column_info``):
    a ``pk_`` prefix marks primary keys and ``<other_table>.<col>`` marks a
    foreign key into that other table. All existing tables are dropped
    first. Rows are NOT inserted here; see ``insert_data_in_tables``.
    """
    gt_dir = f'{data_dir}/ground_truth'
    gt_csv_fnames = sorted(glob.glob(f'{gt_dir}/*.csv'))
    db_uri = f'{db_url}/{db_name}'
    # Connect to db
    engine = sqlalchemy.ext.asyncio.create_async_engine(db_uri)
    async with engine.begin() as conn:
        await drop_all_tables(conn)
        # Bug fix: dropped the invalid ``bind=conn`` (an AsyncConnection is
        # not a valid bind, and the argument is gone in SQLAlchemy 2.0);
        # create_all receives its connection via run_sync below.
        metadata = MetaData()
        # Create tables
        for csv_fname in gt_csv_fnames:
            table_name = os.path.basename(csv_fname).split('.')[0]
            df = pd.read_csv(csv_fname)
            columns_info = []
            # Group FK columns by referenced table so each referenced table
            # yields one (possibly composite) ForeignKeyConstraint.
            fk_constraints = {}
            for column in df.columns:
                column_info = extract_column_info(table_name, column)
                column_info.dtype = python_type_to_sqlalchemy_type(df[column].dtype)
                columns_info.append(column_info)
                df.rename(columns={column: column_info.name}, inplace=True)
                if column_info.refers_to is not None:
                    ref_table, ref_col = column_info.refers_to
                    constraint = fk_constraints.setdefault(ref_table, {'cols': [], 'cols_refs': []})
                    # both tables will have same column name
                    constraint['cols'].append(column_info.name)
                    constraint['cols_refs'].append(f"{ref_table}.{ref_col}")
            multi_table_fk_constraints = [
                ForeignKeyConstraint(fk_cons['cols'], fk_cons['cols_refs'])
                for fk_cons in fk_constraints.values()
            ]
            _ = sqlalchemy.Table(
                table_name, metadata,
                *[sqlalchemy.Column(c_info.name, c_info.dtype, primary_key=c_info.is_primary_key)
                  for c_info in columns_info],
                *multi_table_fk_constraints)
        # run_sync hands metadata.create_all the sync connection directly;
        # the original wrapped this in a redundant lambda.
        await conn.run_sync(metadata.create_all)
    return engine
async def insert_data_in_tables(engine, data_dir: str, only_blob_data: bool):
    """Load every ground-truth CSV in ``data_dir`` into its table.

    Tables are loaded parents-first (topological order over the FK graph) so
    foreign-key checks pass on insert. When ``only_blob_data`` is True, only
    tables whose names start with 'blobs' are loaded.
    """
    def get_insertion_order(conn, gt_csv_files):
        """Order CSV paths so FK parent tables come before their children."""
        metadata = MetaData()
        metadata.reflect(conn)
        table_graph = nx.DiGraph()
        for table in metadata.sorted_tables:
            # Bug fix: register every table as a node explicitly. The
            # original only added edges, so a table with no FK relationships
            # never entered the graph and its CSV was silently skipped.
            table_graph.add_node(table.name)
            for fk_col in table.foreign_keys:
                parent_table = str(fk_col.column).split('.')[0]
                table_graph.add_edge(parent_table, table.name)
        ordered_csv_files = []
        for table_name in nx.topological_sort(table_graph):
            expected = f"{table_name}.csv"
            for f in gt_csv_files:
                # Bug fix: compare the basename exactly; the original used a
                # substring test, so 'blobs.csv' also matched 'my_blobs.csv'.
                if os.path.basename(f) == expected:
                    ordered_csv_files.append(f)
                    break
        return ordered_csv_files

    gt_dir = f'{data_dir}/ground_truth'
    gt_csv_fnames = glob.glob(f'{gt_dir}/*.csv')
    async with engine.begin() as conn:
        gt_csv_fnames = await conn.run_sync(get_insertion_order, gt_csv_fnames)
        # Insert CSV rows table by table, in dependency order.
        for csv_fname in gt_csv_fnames:
            table_name = os.path.basename(csv_fname).split('.')[0]
            if only_blob_data and not table_name.startswith('blobs'):
                continue
            df = pd.read_csv(csv_fname)
            # Strip the pk_/table-name encoding from the headers so they
            # match the real column names created by setup_db.
            for column in df.columns:
                column_info = extract_column_info(table_name, column)
                df.rename(columns={column: column_info.name}, inplace=True)
            await conn.run_sync(
                lambda sync_conn: df.to_sql(table_name, sync_conn, if_exists='append', index=False))
async def clear_all_tables(engine):
    """Delete every row from every non-blob table (blob tables are kept)."""
    def _truncate_non_blob(sync_conn):
        meta = MetaData()
        meta.reflect(sync_conn)
        non_blob_tables = (t for t in meta.sorted_tables if not t.name.startswith('blobs'))
        for table in non_blob_tables:
            sync_conn.execute(table.delete())

    async with engine.begin() as conn:
        await conn.run_sync(_truncate_non_blob)
async def setup_config_tables(engine):
    """Create the blob-metadata table and register each blob table's keys.

    For every table whose name starts with 'blob', one row per primary-key
    column is inserted into BLOB_TABLE_NAMES_TABLE as (table_name, blob_key).
    """
    def _create_blob_metadata_table(sync_conn):
        Base = declarative_base()

        class BlobTables(Base):
            __tablename__ = BLOB_TABLE_NAMES_TABLE
            # Composite primary key: (table_name, blob_key).
            table_name = sqlalchemy.Column(sqlalchemy.String, primary_key=True)
            blob_key = sqlalchemy.Column(sqlalchemy.String, primary_key=True)

        Base.metadata.create_all(sync_conn)

    def _blob_tables_and_pk_columns(sync_conn):
        meta = MetaData()
        meta.reflect(sync_conn)
        names = [t.name for t in meta.sorted_tables if t.name.startswith('blob')]
        # Primary-key column names per blob table.
        pk_columns = {
            name: [col.name for col in meta.tables[name].columns if col.primary_key]
            for name in names
        }
        return names, pk_columns

    async with engine.begin() as conn:
        await conn.run_sync(_create_blob_metadata_table)
        blob_table_names, columns = await conn.run_sync(_blob_tables_and_pk_columns)
        # The statement is loop-invariant; bindparams returns a fresh clause.
        insert_stmt = text(f'INSERT INTO {BLOB_TABLE_NAMES_TABLE} VALUES (:table_name, :blob_key)')
        for table_name in blob_table_names:
            for column in columns[table_name]:
                # Insert into blob metadata table
                await conn.execute(insert_stmt.bindparams(table_name=table_name, blob_key=column))