diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 34c7b131fe90e..011ec066a9897 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -376,6 +376,30 @@ def __init__(self, param, cursor, strings_only=False): self.input_size = None +class VariableWrapper(object): + """ + An adapter class for cursor variables that prevents the wrapped object + from being converted into a string when used to instanciate an OracleParam. + This can be used generally for any other object that should be passed into + Cursor.execute as-is. + """ + + def __init__(self, var): + self.var = var + + def bind_parameter(self, cursor): + return self.var + + def __getattr__(self, key): + return getattr(self.var, key) + + def __setattr__(self, key, value): + if key == 'var': + self.__dict__[key] = value + else: + setattr(self.var, key, value) + + class InsertIdVar(object): """ A late-binding cursor variable that can be passed to Cursor.execute @@ -384,7 +408,7 @@ class InsertIdVar(object): """ def bind_parameter(self, cursor): - param = cursor.var(Database.NUMBER) + param = cursor.cursor.var(Database.NUMBER) cursor._insert_id_var = param return param @@ -439,7 +463,7 @@ def execute(self, query, params=None): return self.cursor.execute(query, self._param_generator(params)) except DatabaseError, e: # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400. - if e.args[0].code == 1400 and not isinstance(e, IntegrityError): + if hasattr(e.args[0], 'code') and e.args[0].code == 1400 and not isinstance(e, IntegrityError): e = IntegrityError(e.args[0]) raise e @@ -463,7 +487,7 @@ def executemany(self, query, params=None): [self._param_generator(p) for p in formatted]) except DatabaseError, e: # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400. - if e.args[0].code == 1400 and not isinstance(e, IntegrityError): + if hasattr(e.args[0], 'code') and e.args[0].code == 1400 and not isinstance(e, IntegrityError): e = IntegrityError(e.args[0]) raise e @@ -523,6 +547,12 @@ def _rowfactory(self, row): casted.append(value) return tuple(casted) + def var(self, *args): + return VariableWrapper(self.cursor.var(*args)) + + def arrayvar(self, *args): + return VariableWrapper(self.cursor.arrayvar(*args)) + def __getattr__(self, attr): if attr in self.__dict__: return self.__dict__[attr] diff --git a/tests/regressiontests/backends/tests.py b/tests/regressiontests/backends/tests.py index da3f4ab17c744..28dfd7f208ffe 100644 --- a/tests/regressiontests/backends/tests.py +++ b/tests/regressiontests/backends/tests.py @@ -22,6 +22,16 @@ def test_dbms_session(self): else: return True + def test_cursor_var(self): + # If the backend is Oracle, test that we can pass cursor variables + # as query parameters. + if settings.DATABASE_ENGINE == 'oracle': + cursor = connection.cursor() + var = cursor.var(backend.Database.STRING) + cursor.execute("BEGIN %s := 'X'; END; ", [var]) + self.assertEqual(var.getvalue(), 'X') + + class LongString(unittest.TestCase): def test_long_string(self):