Skip to content

Commit

Permalink
ENH: data.insert supports primary keys (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
CalmDownKarm authored and sanand0 committed Apr 1, 2019
1 parent ce28d86 commit df764df
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 13 deletions.
3 changes: 3 additions & 0 deletions gramex/data.py
Expand Up @@ -388,6 +388,9 @@ def insert(url, meta={}, args=None, engine=None, table=None, ext=None, id=None,
rows = _pop_columns(rows, [col.name for col in cols], meta['ignored'])
if '.' in table:
kwargs['schema'], table = table.rsplit('.', 1)
# pandas does not document engine.dialect.has_table so it might change.
if not engine.dialect.has_table(engine, table) and id:
engine.execute(pd.io.sql.get_schema(rows, name=table, keys=id, con=engine))
rows.to_sql(table, engine, if_exists='append', index=False, **kwargs)
return len(rows)
else:
Expand Down
58 changes: 45 additions & 13 deletions testlib/test_data.py
Expand Up @@ -10,6 +10,7 @@
import gramex.data
import gramex.cache
import pandas as pd
import sqlalchemy as sa
from orderedattrdict import AttrDict
from nose.plugins.skip import SkipTest
from nose.tools import eq_, ok_, assert_raises
Expand All @@ -18,6 +19,11 @@
import dbutils
from . import folder, sales_file

# Database hosts for the integration tests. Overridable via environment
# variables so CI and local runs can point at different MySQL/PostgreSQL
# servers; both default to localhost.
server = AttrDict(
    mysql=os.environ.get('MYSQL_SERVER', 'localhost'),
    postgres=os.environ.get('POSTGRES_SERVER', 'localhost'),
)


def eqframe(actual, expected, **kwargs):
'''Same as assert_frame_equal or afe, but does not compare index'''
Expand All @@ -28,10 +34,6 @@ def eqframe(actual, expected, **kwargs):
class TestFilter(unittest.TestCase):
sales = gramex.cache.open(sales_file, 'xlsx')
db = set()
server = AttrDict(
mysql=os.environ.get('MYSQL_SERVER', 'localhost'),
postgres=os.environ.get('POSTGRES_SERVER', 'localhost'),
)

def test_get_engine(self):
check = gramex.data.get_engine
Expand Down Expand Up @@ -356,11 +358,11 @@ def check_filter_db(self, dbname, url, na_position, sum_na=True):
query='SELECT * FROM {x} WHERE {p} > 0')

def test_mysql(self):
    '''Run the shared filter checks against a MySQL database.'''
    # Use the module-level `server` config; the class-level copy was removed.
    url = dbutils.mysql_create_db(server.mysql, 'test_filter', sales=self.sales)
    self.check_filter_db('mysql', url, na_position='first')

def test_postgres(self):
    '''Run the shared filter checks against a PostgreSQL database.'''
    # 'filter.sales' creates the sales table under the `filter` schema as well,
    # so the schema-qualified table path is exercised below.
    url = dbutils.postgres_create_db(server.postgres, 'test_filter', **{
        'sales': self.sales, 'filter.sales': self.sales})
    self.check_filter_db('postgres', url, na_position='last')
    self.check_filter(url=url, table='filter.sales', na_position='last', sum_na=True)
Expand All @@ -372,9 +374,9 @@ def test_sqlite(self):
@classmethod
def tearDownClass(cls):
    '''Drop only the databases that the filter tests actually created.'''
    # cls.db records which backends ran; module-level `server` holds the hosts.
    if 'mysql' in cls.db:
        dbutils.mysql_drop_db(server.mysql, 'test_filter')
    if 'postgres' in cls.db:
        dbutils.postgres_drop_db(server.postgres, 'test_filter')
    if 'sqlite' in cls.db:
        dbutils.sqlite_drop_db('test_filter.db')

Expand All @@ -392,6 +394,7 @@ def setUpClass(cls):
'sales': ['0', -100],
# Do not add growth column
}
cls.db = set()

def test_insert_frame(self):
    # Placeholder: inserting into an in-memory DataFrame target is not yet
    # covered; skip until those cases are written.
    raise SkipTest('TODO: write insert test cases for DataFrame')
Expand Down Expand Up @@ -434,19 +437,48 @@ def test_insert_new_file(self):
afe(actual, expected, check_like=True)

def test_insert_mysql(self):
    '''Insert rows (with a primary key) into a fresh MySQL database.'''
    url = dbutils.mysql_create_db(server.mysql, 'test_insert')
    self.check_insert_db(url, 'mysql')

def test_insert_postgres(self):
    '''Insert rows (with a primary key) into a fresh PostgreSQL database.'''
    url = dbutils.postgres_create_db(server.postgres, 'test_insert')
    self.check_insert_db(url, 'postgres')

def test_insert_sqlite(self):
    '''Insert rows (with a primary key) into a fresh SQLite database.'''
    url = dbutils.sqlite_create_db('test_insert.db')
    self.check_insert_db(url, 'sqlite')

def check_insert_db(self, url, dbname):
    '''Insert rows with an `id` column into `url` and verify data, primary
    key creation, and duplicate-key rejection. Registers `dbname` in
    cls.db so tearDownClass drops the database.'''
    self.db.add(dbname)
    data = dict(self.insert_rows)
    data['index'] = [1, 2]          # acts as the primary key column
    count = gramex.data.insert(url, args=data, table='test_insert', id='index')
    eq_(count, 2)
    # Read the table back and compare against what was inserted
    result = gramex.data.filter(url, table='test_insert')
    target = pd.DataFrame(data)
    result['sales'] = result['sales'].astype(float)
    target['sales'] = target['sales'].astype(float)
    afe(result, target, check_like=True)
    # The id column must have been created as a primary key
    constraint = sa.inspect(sa.create_engine(url)).get_pk_constraint('test_insert')
    ok_('index' in constraint['constrained_columns'])
    # Re-inserting the same keys must violate the primary key constraint
    with assert_raises(sa.exc.IntegrityError):
        gramex.data.insert(url, args=data, table='test_insert', id='index')

@classmethod
def tearDownClass(cls):
    '''Remove temp files and drop every database the insert tests created.'''
    leftovers = [p for p in cls.tmpfiles if os.path.exists(p)]
    for leftover in leftovers:
        os.remove(leftover)
    # cls.db records which backends actually ran; drop only those.
    if 'mysql' in cls.db:
        dbutils.mysql_drop_db(server.mysql, 'test_insert')
    if 'postgres' in cls.db:
        dbutils.postgres_drop_db(server.postgres, 'test_insert')
    if 'sqlite' in cls.db:
        dbutils.sqlite_drop_db('test_insert.db')


class TestEdit(unittest.TestCase):
Expand Down Expand Up @@ -693,14 +725,14 @@ def test_filtercols_with_filter_unicode_values(self):
eqframe(val, self.unique_of(expected, key))


class TestFilterColsFrame(unittest.TestCase, FilterColsMixin):
    '''Run the FilterColsMixin checks against in-memory DataFrame URLs.'''
    # unittest.TestCase must be a base so the mixin's tests are discovered.
    urls = {
        'sales': {'url': FilterColsMixin.sales},
        'census': {'url': FilterColsMixin.census}
    }


class TestFilterColsDB(FilterColsMixin):
class TestFilterColsDB(unittest.TestCase, FilterColsMixin):
urls = {}

@classmethod
Expand Down

0 comments on commit df764df

Please sign in to comment.