Skip to content

Commit

Permalink
Add support for UUID. Fixes #177
Browse files Browse the repository at this point in the history
Accept UUID objects as parameters.  If pyodbc.native_uuid is True, return SQL_GUID columns
as UUID objects.  If False, the default for backwards compatibility, return them as Unicode
strings.
  • Loading branch information
mkleehammer committed Feb 19, 2017
1 parent 28af951 commit 2ad7a9c
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 30 deletions.
16 changes: 15 additions & 1 deletion src/dbspecific.h
@@ -1,4 +1,3 @@

#ifndef DBSPECIFIC_H
#define DBSPECIFIC_H

Expand Down Expand Up @@ -29,4 +28,19 @@ struct SQL_SS_TIME2_STRUCT
SQLUINTEGER fraction;
};

// The SQLGUID type isn't always available when compiling, so we'll make our own with a
// different name.

struct PYSQLGUID
{
// I was hoping to use uint32_t, etc., but they aren't included in a Python build. I'm not
// going to require that the compilers supply anything beyond that. There is PY_UINT32_T,
// but there is no 16-bit version. We'll stick with Microsoft's WORD and DWORD which I
// believe the ODBC headers will have to supply.
DWORD Data1;
WORD Data2;
WORD Data3;
byte Data4[8];
};

#endif // DBSPECIFIC_H
53 changes: 51 additions & 2 deletions src/getdata.cpp
Expand Up @@ -572,6 +572,33 @@ static PyObject* GetSqlServerTime(Cursor* cur, Py_ssize_t iCol)
return PyTime_FromTime(value.hour, value.minute, value.second, micros);
}

static PyObject* GetUUID(Cursor* cur, Py_ssize_t iCol)
{
// REVIEW: Since GUID is a fixed size, do we need to pass the size or cbFetched?

PYSQLGUID guid;
SQLLEN cbFetched = 0;
SQLRETURN ret;
Py_BEGIN_ALLOW_THREADS
ret = SQLGetData(cur->hstmt, (SQLUSMALLINT)(iCol+1), SQL_GUID, &guid, sizeof(guid), &cbFetched);
Py_END_ALLOW_THREADS

if (!SQL_SUCCEEDED(ret))
return RaiseErrorFromHandle("SQLGetData", cur->cnxn->hdbc, cur->hstmt);

if (cbFetched == SQL_NULL_DATA)
Py_RETURN_NONE;

#if PY_MAJOR_VERSION >= 3
const char* szFmt = "(yyy#)";
#else
const char* szFmt = "(sss#)";
#endif
Object args = Py_BuildValue(szFmt, NULL, NULL, &guid, (int)sizeof(guid));
if (!args)
return 0;
return PyObject_CallObject(uuid_type, args.Get());
}

static PyObject* GetDataTimestamp(Cursor* cur, Py_ssize_t iCol)
{
Expand Down Expand Up @@ -642,7 +669,6 @@ PyObject* PythonTypeFromSqlType(Cursor* cur, SQLSMALLINT type)
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR:
case SQL_GUID:
#if PY_MAJOR_VERSION < 3
if (cur->cnxn->str_enc.ctype == SQL_C_CHAR)
pytype = (PyObject*)&PyString_Type;
Expand All @@ -653,6 +679,24 @@ PyObject* PythonTypeFromSqlType(Cursor* cur, SQLSMALLINT type)
#endif
break;

case SQL_GUID:
if (UseNativeUUID())
{
pytype = uuid_type;
}
else
{
#if PY_MAJOR_VERSION < 3
if (cur->cnxn->str_enc.ctype == SQL_C_CHAR)
pytype = (PyObject*)&PyString_Type;
else
pytype = (PyObject*)&PyUnicode_Type;
#else
pytype = (PyObject*)&PyUnicode_Type;
#endif
}
break;

case SQL_WCHAR:
case SQL_WVARCHAR:
case SQL_WLONGVARCHAR:
Expand Down Expand Up @@ -738,10 +782,15 @@ PyObject* GetData(Cursor* cur, Py_ssize_t iCol)
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR:
case SQL_GUID:
case SQL_SS_XML:
return GetText(cur, iCol);

case SQL_GUID:
if (UseNativeUUID())
return GetUUID(cur, iCol);
return GetText(cur, iCol);
break;

case SQL_BINARY:
case SQL_VARBINARY:
case SQL_LONGVARBINARY:
Expand Down
25 changes: 25 additions & 0 deletions src/params.cpp
Expand Up @@ -421,6 +421,28 @@ static char* CreateDecimalString(long sign, PyObject* digits, long exp)
return pch;
}

static bool GetUUIDInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info)
{
info.ValueType = SQL_C_GUID;
info.ParameterType = SQL_GUID;
info.ColumnSize = 16;

info.allocated = true;
info.ParameterValuePtr = pyodbc_malloc(sizeof(SQLGUID));
if (!info.ParameterValuePtr)
{
PyErr_NoMemory();
return false;
}

// Do we need to use "bytes" on a big endian machine?
Object b(PyObject_GetAttrString(param, "bytes_le"));
if (!b)
return false;
memcpy(info.ParameterValuePtr, PyBytes_AS_STRING(b.Get()), sizeof(SQLGUID));
return true;
}

static bool GetDecimalInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info)
{
// The NUMERIC structure never works right with SQL Server and probably a lot of other drivers. We'll bind as a
Expand Down Expand Up @@ -590,6 +612,9 @@ static bool GetParameterInfo(Cursor* cur, Py_ssize_t index, PyObject* param, Par
if (PyDecimal_Check(param))
return GetDecimalInfo(cur, index, param, info);

if (uuid_type && PyObject_IsInstance(param, uuid_type))
return GetUUIDInfo(cur, index, param, info);

#if PY_VERSION_HEX >= 0x02060000
if (PyByteArray_Check(param))
return GetByteArrayInfo(cur, index, param, info);
Expand Down
47 changes: 35 additions & 12 deletions src/pyodbcmodule.cpp
@@ -1,4 +1,3 @@

// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
// documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
Expand Down Expand Up @@ -128,11 +127,22 @@ static ExcInfo aExcInfos[] = {


PyObject* decimal_type;
PyObject* uuid_type;

bool UseNativeUUID()
{
PyObject* o = PyObject_GetAttrString(pModule, "native_uuid");
// If this fails for some reason, we'll assume false and allow the exception to pop up later.
bool b = o && PyObject_IsTrue(o);
Py_XDECREF(o);
return b;
}

HENV henv = SQL_NULL_HANDLE;

Py_UNICODE chDecimal = '.';


// Initialize the global decimal character and thousands separator character, used when parsing decimal
// objects.
//
Expand Down Expand Up @@ -182,25 +192,37 @@ static bool import_types()
if (!Params_init())
return false;

PyObject* decimalmod = PyImport_ImportModule("cdecimal");
if (!decimalmod)
Object mod(PyImport_ImportModule("cdecimal"));
if (!mod)
{
// Clear the error from the failed import of cdecimal.
PyErr_Clear();
decimalmod = PyImport_ImportModule("decimal");
if (!decimalmod) {
mod.Attach(PyImport_ImportModule("decimal"));
if (!mod)
{
PyErr_SetString(PyExc_RuntimeError, "Unable to import cdecimal or decimal");
return false;
}
}

decimal_type = PyObject_GetAttrString(decimalmod, "Decimal");
Py_DECREF(decimalmod);

if (decimal_type == 0)
Object dec(PyObject_GetAttrString(mod, "Decimal"));
if (!dec)
{
PyErr_SetString(PyExc_RuntimeError, "Unable to import decimal.Decimal.");
return false;
}

mod = PyImport_ImportModule("uuid");
if (!mod)
return false;

Object uuid(PyObject_GetAttrString(mod, "UUID"));
if (!uuid)
return false;

return decimal_type != 0;
decimal_type = dec.Detach();
uuid_type = uuid.Detach();

return true;
}


Expand Down Expand Up @@ -1031,6 +1053,8 @@ initpyodbc(void)
Py_INCREF(Py_True);
PyModule_AddObject(module, "lowercase", Py_False);
Py_INCREF(Py_False);
PyModule_AddObject(module, "native_uuid", Py_False);
Py_INCREF(Py_False);

PyModule_AddObject(module, "Connection", (PyObject*)&ConnectionType);
Py_INCREF((PyObject*)&ConnectionType);
Expand Down Expand Up @@ -1156,4 +1180,3 @@ static PyObject* MakeConnectionString(PyObject* existing, PyObject* parts)

return result;
}

4 changes: 4 additions & 0 deletions src/pyodbcmodule.h
Expand Up @@ -32,6 +32,7 @@ extern PyObject* NotSupportedError;
extern PyObject* null_binary;

extern PyObject* decimal_type;
extern PyObject* uuid_type;

inline bool PyDecimal_Check(PyObject* p)
{
Expand All @@ -53,4 +54,7 @@ inline bool lowercase()

extern Py_UNICODE chDecimal;

bool UseNativeUUID();
// Returns True if pyodbc.native_uuid is true, meaning uuid.UUID objects should be returned.

#endif // _PYPGMODULE_H
30 changes: 23 additions & 7 deletions tests2/sqlservertests.py
Expand Up @@ -27,7 +27,7 @@
2008: DRIVER={SQL Server Native Client 10.0}
"""

import sys, os, re
import sys, os, re, uuid
import unittest
from decimal import Decimal
from datetime import datetime, date, time
Expand Down Expand Up @@ -162,12 +162,28 @@ def test_noscan(self):
self.cursor.noscan = True
self.assertEqual(self.cursor.noscan, True)

def test_guid(self):
self.cursor.execute("create table t1(g1 uniqueidentifier)")
self.cursor.execute("insert into t1 values (newid())")
v = self.cursor.execute("select * from t1").fetchone()[0]
self.assertEqual(type(v), unicode)
self.assertEqual(len(v), 36)
def test_nonnative_uuid(self):
# The default is False meaning we should return a string. Note that
# SQL Server seems to always return uppercase.
value = uuid.uuid4()
self.cursor.execute("create table t1(n uniqueidentifier)")
self.cursor.execute("insert into t1 values (?)", value)

pyodbc.native_uuid = False
result = self.cursor.execute("select n from t1").fetchval()
self.assertEqual(type(result), unicode)
self.assertEqual(result, unicode(value).upper())

def test_native_uuid(self):
# When true, we should return a uuid.UUID object.
value = uuid.uuid4()
self.cursor.execute("create table t1(n uniqueidentifier)")
self.cursor.execute("insert into t1 values (?)", value)

pyodbc.native_uuid = True
result = self.cursor.execute("select n from t1").fetchval()
self.assertIsInstance(result, uuid.UUID)
self.assertEqual(value, result)

def test_nextset(self):
self.cursor.execute("create table t1(i int)")
Expand Down
31 changes: 23 additions & 8 deletions tests3/sqlservertests.py
Expand Up @@ -27,7 +27,7 @@
2008: DRIVER={SQL Server Native Client 10.0}
"""

import sys, os, re
import sys, os, re, uuid
import unittest
from decimal import Decimal
from datetime import datetime, date, time
Expand Down Expand Up @@ -157,12 +157,28 @@ def test_noscan(self):
self.cursor.noscan = True
self.assertEqual(self.cursor.noscan, True)

def test_guid(self):
self.cursor.execute("create table t1(g1 uniqueidentifier)")
self.cursor.execute("insert into t1 values (newid())")
v = self.cursor.execute("select * from t1").fetchone()[0]
self.assertEqual(type(v), str)
self.assertEqual(len(v), 36)
def test_nonnative_uuid(self):
# The default is False meaning we should return a string. Note that
# SQL Server seems to always return uppercase.
value = uuid.uuid4()
self.cursor.execute("create table t1(n uniqueidentifier)")
self.cursor.execute("insert into t1 values (?)", value)

pyodbc.native_uuid = False
result = self.cursor.execute("select n from t1").fetchval()
self.assertEqual(type(result), str)
self.assertEqual(result, str(value).upper())

def test_native_uuid(self):
# When true, we should return a uuid.UUID object.
value = uuid.uuid4()
self.cursor.execute("create table t1(n uniqueidentifier)")
self.cursor.execute("insert into t1 values (?)", value)

pyodbc.native_uuid = True
result = self.cursor.execute("select n from t1").fetchval()
self.assertIsInstance(result, uuid.UUID)
self.assertEqual(value, result)

def test_nextset(self):
self.cursor.execute("create table t1(i int)")
Expand Down Expand Up @@ -633,7 +649,6 @@ def test_negative_float(self):
result = self.cursor.execute("select n from t1").fetchone()[0]
self.assertEqual(value, result)


#
# stored procedures
#
Expand Down

0 comments on commit 2ad7a9c

Please sign in to comment.