Skip to content

Commit

Permalink
Fixes for Oracle MERGE implementation (DM-21748)
Browse files Browse the repository at this point in the history
The _Merge class has to inherit Executable so that Connection can
execute it, the order of base classes for _Merge is important. Call to
bindparam() in generated MERGE needs a `type_` argument to correctly
interpret parameter data using all registered type decorators.
  • Loading branch information
andy-slac committed Oct 14, 2019
1 parent 0429a71 commit 7da68d1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/lsst/daf/butler/registries/oracleRegistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@
from sqlalchemy import create_engine
from sqlalchemy.ext import compiler
from sqlalchemy.sql import ClauseElement, and_, bindparam, select
from sqlalchemy.sql.expression import Executable

from lsst.daf.butler.core.config import Config
from lsst.daf.butler.core.registryConfig import RegistryConfig

from .sqlRegistry import SqlRegistry, SqlRegistryConfig


class _Merge(ClauseElement):
class _Merge(Executable, ClauseElement):
def __init__(self, table, onConflict):
self.table = table
self.onConflict = onConflict
Expand All @@ -48,7 +49,7 @@ def _merge(merge, compiler, **kw):
pkColumns = [col.name for col in table.primary_key]
nonPkColumns = [col for col in allColumns if col not in pkColumns]

selectColumns = [bindparam(col).label(col) for col in allColumns]
selectColumns = [bindparam(col.name, type_=col.type).label(col.name) for col in table.columns]
selectClause = select(selectColumns)

tableAlias = table.alias("t")
Expand Down
9 changes: 9 additions & 0 deletions tests/test_sqlRegistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sqlalchemy import Table, Column, Integer
from sqlalchemy.schema import MetaData
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql.expression import Executable

import lsst.sphgeom

Expand Down Expand Up @@ -728,6 +729,8 @@ def testInsertConflictSqlite(self):
expect = 'INSERT INTO insert_conflict_test (pk, value, "uniqVal") VALUES (?, ?, ?)' \
' ON CONFLICT (pk) DO NOTHING'
clause = InsertOnConflict(table, onConflict="ignore")
self.assertIsInstance(clause, Executable)
self.assertTrue(clause.supports_execution)
query = clause.compile(dialect=sqlite.dialect())
self.assertEqual(str(query), expect)

Expand All @@ -751,6 +754,8 @@ def testInsertConflictPG(self):
expect = 'INSERT INTO insert_conflict_test (pk, value, \"uniqVal\")' \
' VALUES (%(pk)s, %(value)s, %(uniqVal)s) ON CONFLICT (pk) DO NOTHING'
clause = PostgreSqlRegistry._makeInsertWithConflictImpl(table, onConflict="ignore")
self.assertIsInstance(clause, Executable)
self.assertTrue(clause.supports_execution)
query = clause.compile(dialect=postgresql.dialect())
self.assertEqual(str(query), expect)

Expand All @@ -777,6 +782,8 @@ def testInsertConflictOracle(self):
'ON (t.pk = d.pk)\n' \
'WHEN NOT MATCHED THEN INSERT (pk, value, "uniqVal") VALUES (d.pk, d.value, d."uniqVal")'
clause = _Merge(table, onConflict="ignore")
self.assertIsInstance(clause, Executable)
self.assertTrue(clause.supports_execution)
query = clause.compile(dialect=oracle.dialect())
self.assertEqual(str(query), expect)

Expand All @@ -786,6 +793,8 @@ def testInsertConflictOracle(self):
'WHEN MATCHED THEN UPDATE SET t.value = d.value, t."uniqVal" = d."uniqVal"\n' \
'WHEN NOT MATCHED THEN INSERT (pk, value, "uniqVal") VALUES (d.pk, d.value, d."uniqVal")'
clause = _Merge(table, onConflict="replace")
self.assertIsInstance(clause, Executable)
self.assertTrue(clause.supports_execution)
query = clause.compile(dialect=oracle.dialect())
self.assertEqual(str(query), expect)

Expand Down

0 comments on commit 7da68d1

Please sign in to comment.