Skip to content

Commit

Permalink
Rewrite decimal parsing to eliminate buffer overflow
Browse files Browse the repository at this point in the history
This is a fix for GitHub security advisory GHSA-pm6v-h62r-rwx8.  The old code had a hardcoded
buffer of 100 bytes (and a comment asking why it was hardcoded!) and fetching a decimal greater
than 100 digits would cause a buffer overflow.

Author arturxedex128 supplied very simple code to reproduce the error, which was added to the
three PostgreSQL unit test suites as test_large_decimal.  (Thank you arturxedex128!)

Unfortunately the strategy is still that we have to parse decimals, but now Python strings /
Unicode objects are used so there is no arbitrary limit.
  • Loading branch information
mkleehammer committed Apr 11, 2023
1 parent e7fefd8 commit 6b107a2
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 135 deletions.
2 changes: 1 addition & 1 deletion src/cursor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ static bool create_name_map(Cursor* cur, SQLSMALLINT field_count, bool lower)

TRACE("Col %d: type=%s (%d) colsize=%d\n", (i+1), SqlTypeName(nDataType), (int)nDataType, (int)nColSize);

Object name(TextBufferToObject(enc, szName, cbName));
Object name(TextBufferToObject(enc, (byte*)szName, cbName));

if (!name)
goto done;
Expand Down
157 changes: 157 additions & 0 deletions src/decimal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@

#include "pyodbc.h"
#include "wrapper.h"
#include "textenc.h"
#include "decimal.h"

// The Decimal constructor (the "Decimal" attribute of the decimal module), looked up once by
// InitializeDecimal.
static PyObject* decimal = 0;

// The "sub", "compile" and "escape" attributes of the re module, looked up once by
// InitializeDecimal.
static PyObject* re_sub = 0;
static PyObject* re_compile = 0;
static PyObject* re_escape = 0;

// In Python 2.7, the 3 strings below are bytes objects. In 3.x they are Unicode objects.


// A "." object.  When the locale's decimal character differs from ".", DecimalFromText
// substitutes this for the locale character before calling the Decimal constructor.
static PyObject* pDecimalPoint = 0;

// The decimal character used by the locale. This can be overridden by the user through
// SetDecimalPoint and is what GetDecimalPoint returns.
//
// In 2.7 this is a bytes object, otherwise unicode.
static PyObject* pLocaleDecimal = 0;

// A version of pLocaleDecimal escaped (with re.escape) to be used in a regular expression.
// (The character could be something special in regular expressions.) This is zero when
// pLocaleDecimal is ".", indicating no replacement is necessary.
static PyObject* pLocaleDecimalEscaped = 0;

// A compiled regular expression that matches the characters we want to remove before parsing:
// anything that is not a digit, the locale decimal character, or '-'.  It is rebuilt by
// SetDecimalPoint every time the decimal character changes.
static PyObject* pRegExpRemove = 0;


bool InitializeDecimal() {
// This is called when the module is initialized and creates globals.

Object d(PyImport_ImportModule("decimal"));
decimal = PyObject_GetAttrString(d, "Decimal");
if (!decimal)
return 0;
Object re(PyImport_ImportModule("re"));
re_sub = PyObject_GetAttrString(re, "sub");
re_escape = PyObject_GetAttrString(re, "escape");
re_compile = PyObject_GetAttrString(re, "compile");

Object module(PyImport_ImportModule("locale"));
Object ldict(PyObject_CallMethod(module, "localeconv", 0));
Object point(PyDict_GetItemString(ldict, "decimal_point"));

if (!point)
return false;

#if PY_MAJOR_VERSION >= 3
pDecimalPoint = PyUnicode_FromString(".");
#else
pDecimalPoint = PyBytes_FromString(".");
#endif

if (!pDecimalPoint)
return false;

#if PY_MAJOR_VERSION >= 3
if (!SetDecimalPoint(point))
return false;
#else
// In 2.7, we only support non-Unicode right now.
if (PyBytes_Check(point))
if (!SetDecimalPoint(point))
return false;
#endif

return true;
}

PyObject* GetDecimalPoint() {
    // Hands the caller a new reference to the decimal character currently in effect.
    PyObject* current = pLocaleDecimal;
    Py_INCREF(current);
    return current;
}

bool SetDecimalPoint(PyObject* pNew)
{
    // Makes pNew the decimal character expected in text received from the database and
    // rebuilds the globals that DecimalFromText depends on.
    //
    // pNew is a borrowed reference: a bytes object in 2.7, otherwise unicode.
    //
    // Returns true on success.  On failure returns false with a Python exception set and
    // leaves all globals untouched: everything that can fail is computed first and the
    // globals are only replaced at the end.

    int same = PyObject_RichCompareBool(pNew, pDecimalPoint, Py_EQ);
    if (same < 0)
        return false;   // The comparison itself raised.

    // When the new character is not ".", we need a regular-expression-safe version of it so
    // it can be replaced with "." in DecimalFromText.  When it is ".", this stays null.
    Object escaped(same ? (PyObject*)0 : PyObject_CallFunctionObjArgs(re_escape, pNew, (PyObject*)0));
    if (!same && !escaped)
        return false;

    // Build the "characters to remove" pattern.  The escaped form is used inside the
    // character class so that a separator such as ']' or '^' cannot break the pattern.
    PyObject* pClass = same ? pDecimalPoint : escaped.Get();

#if PY_MAJOR_VERSION >= 3
    Object s(PyUnicode_FromFormat("[^0-9%U-]+", pClass));
#else
    const char* szClass = PyString_AsString(pClass);
    if (!szClass)
        return false;
    Object s(PyBytes_FromFormat("[^0-9%s-]+", szClass));
#endif
    if (!s)
        return false;

    // The argument list must end with a null *pointer*; a bare 0 is an int in a varargs call.
    Object r(PyObject_CallFunctionObjArgs(re_compile, s.Get(), (PyObject*)0));
    if (!r)
        return false;

    // Nothing below can fail, so commit.  Each global is replaced before the old value is
    // released, and the new reference is taken first in case pNew is the current value.
    PyObject* pNewLocale = same ? pDecimalPoint : pNew;
    Py_INCREF(pNewLocale);

    PyObject* old = pLocaleDecimal;
    pLocaleDecimal = pNewLocale;
    Py_XDECREF(old);

    old = pLocaleDecimalEscaped;
    pLocaleDecimalEscaped = escaped.Detach();   // Null when the character is ".".
    Py_XDECREF(old);

    old = pRegExpRemove;
    pRegExpRemove = r.Detach();
    Py_XDECREF(old);

    return true;
}


PyObject* DecimalFromText(const TextEnc& enc, const byte* pb, Py_ssize_t cb)
{
    // Creates a Decimal object from a text buffer.
    //
    // enc, pb and cb are passed straight to TextBufferToObject to decode the buffer.
    // Returns a new reference to the Decimal, or 0 with a Python exception set on failure.

    // The Decimal constructor requires the decimal point to be '.', so we need to convert the
    // locale's decimal to it. We also need to remove non-decimal characters such as thousands
    // separators and currency symbols.
    //
    // Remember that the thousands separator will often be '.', so we have to do this
    // carefully: first remove everything that is not a digit, '-' or the locale's decimal
    // character (pRegExpRemove), and only then replace the locale's decimal character with '.'.

    Object text(TextBufferToObject(enc, pb, cb));
    if (!text)
        return 0;

    Object cleaned(PyObject_CallMethod(pRegExpRemove, "sub", "sO", "", text.Get()));
    if (!cleaned)
        return 0;

    if (pLocaleDecimalEscaped)
    {
        // re.sub(pattern, repl, string): the string being rewritten must be passed as the
        // third argument, and the argument list must end with a null pointer.
        Object c2(PyObject_CallFunctionObjArgs(re_sub, pLocaleDecimalEscaped, pDecimalPoint, cleaned.Get(), (PyObject*)0));
        if (!c2)
            return 0;
        cleaned.Attach(c2.Detach());
    }

    return PyObject_CallFunctionObjArgs(decimal, cleaned.Get(), (PyObject*)0);
}
7 changes: 7 additions & 0 deletions src/decimal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#pragma once

bool InitializeDecimal();
PyObject* GetDecimalPoint();
bool SetDecimalPoint(PyObject* pNew);

PyObject* DecimalFromText(const TextEnc& enc, const byte* pb, Py_ssize_t cb);
82 changes: 4 additions & 78 deletions src/getdata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "connection.h"
#include "errors.h"
#include "dbspecific.h"
#include "decimal.h"
#include <time.h>
#include <datetime.h>

Expand Down Expand Up @@ -381,86 +382,11 @@ static PyObject* GetDataDecimal(Cursor* cur, Py_ssize_t iCol)
Py_RETURN_NONE;
}

Object result(TextBufferToObject(enc, pbData, cbData));
Object result(DecimalFromText(enc, pbData, cbData));

pyodbc_free(pbData);

if (!result)
return 0;

// Remove non-digits and convert the databases decimal to a '.' (required by decimal ctor).
//
// We are assuming that the decimal point and digits fit within the size of ODBCCHAR.

// If Unicode, convert to UTF-8 and copy the digits and punctuation out. Since these are
// all ASCII characters, we can ignore any multiple-byte characters. Fortunately, if a
// character is multi-byte all bytes will have the high bit set.

char* pch;
Py_ssize_t cch;

#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(result))
{
pch = (char*)PyUnicode_AsUTF8AndSize(result, &cch);
}
else
{
int n = PyBytes_AsStringAndSize(result, &pch, &cch);
if (n < 0)
pch = 0;
}
#else
Object encoded;
if (PyUnicode_Check(result))
{
encoded = PyUnicode_AsUTF8String(result);
if (!encoded)
return 0;
result = encoded.Detach();
}
int n = PyString_AsStringAndSize(result, &pch, &cch);
if (n < 0)
pch = 0;
#endif

if (!pch)
return 0;

// TODO: Why is this limited to 100? Also, can we perform a check on the original and use
// it as-is?
char ascii[100];
size_t asciilen = 0;

const char* pchMax = pch + cch;
while (pch < pchMax)
{
if ((*pch & 0x80) == 0)
{
if (*pch == chDecimal)
{
// Must force it to use '.' since the Decimal class doesn't pay attention to the locale.
ascii[asciilen++] = '.';
}
else if ((*pch >= '0' && *pch <= '9') || *pch == '-')
{
ascii[asciilen++] = (char)(*pch);
}
}
pch++;
}

ascii[asciilen] = 0;

Object str(PyString_FromStringAndSize(ascii, (Py_ssize_t)asciilen));
if (!str)
return 0;
PyObject* decimal_type = GetClassForThread("decimal", "Decimal");
if (!decimal_type)
return 0;
PyObject* decimal = PyObject_CallFunction(decimal_type, "O", str.Get());
Py_DECREF(decimal_type);
return decimal;
return result.Detach();
}

static PyObject* GetDataBit(Cursor* cur, Py_ssize_t iCol)
Expand Down Expand Up @@ -875,4 +801,4 @@ PyObject* GetData(Cursor* cur, Py_ssize_t iCol)

return RaiseErrorV("HY106", ProgrammingError, "ODBC SQL type %d is not yet supported. column-index=%zd type=%d",
(int)pinfo->sql_type, iCol, (int)pinfo->sql_type);
}
}
61 changes: 15 additions & 46 deletions src/pyodbcmodule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "cnxninfo.h"
#include "params.h"
#include "dbspecific.h"
#include "decimal.h"
#include <datetime.h>

#include <time.h>
Expand Down Expand Up @@ -152,9 +153,6 @@ bool UseNativeUUID()

HENV henv = SQL_NULL_HANDLE;

Py_UNICODE chDecimal = '.';


PyObject* GetClassForThread(const char* szModule, const char* szClass)
{
// Returns the given class, specific to the current thread's interpreter. For performance
Expand Down Expand Up @@ -249,36 +247,6 @@ bool IsInstanceForThread(PyObject* param, const char* szModule, const char* szCl
}


// Initialize the global decimal character and thousands separator character, used when parsing decimal
// objects.
//
static void init_locale_info()
{
Object module(PyImport_ImportModule("locale"));
if (!module)
{
PyErr_Clear();
return;
}

Object ldict(PyObject_CallMethod(module, "localeconv", 0));
if (!ldict)
{
PyErr_Clear();
return;
}

PyObject* value = PyDict_GetItemString(ldict, "decimal_point");
if (value)
{
if (PyBytes_Check(value) && PyBytes_Size(value) == 1)
chDecimal = (Py_UNICODE)PyBytes_AS_STRING(value)[0];
if (PyUnicode_Check(value) && PyUnicode_GET_SIZE(value) == 1)
chDecimal = PyUnicode_AS_UNICODE(value)[0];
}
}


static bool import_types()
{
// Note: We can only import types from C extensions since they are shared among all
Expand All @@ -300,6 +268,8 @@ static bool import_types()
GetData_init();
if (!Params_init())
return false;
if (!InitializeDecimal())
return false;

return true;
}
Expand Down Expand Up @@ -708,24 +678,25 @@ static PyObject* mod_timestampfromticks(PyObject* self, PyObject* args)
static PyObject* mod_setdecimalsep(PyObject* self, PyObject* args)
{
UNUSED(self);
if (!PyString_Check(PyTuple_GET_ITEM(args, 0)) && !PyUnicode_Check(PyTuple_GET_ITEM(args, 0)))
return PyErr_Format(PyExc_TypeError, "argument 1 must be a string or unicode object");

PyObject* value = PyUnicode_FromObject(PyTuple_GetItem(args, 0));
if (value)
{
if (PyBytes_Check(value) && PyBytes_Size(value) == 1)
chDecimal = (Py_UNICODE)PyBytes_AS_STRING(value)[0];
if (PyUnicode_Check(value) && PyUnicode_GET_SIZE(value) == 1)
chDecimal = PyUnicode_AS_UNICODE(value)[0];
}
#if PY_MAJOR_VERSION >= 3
const char* type = "U";
#else
const char* type = "S";
#endif

PyObject* p;
if (!PyArg_ParseTuple(args, type, &p))
return 0;
if (!SetDecimalPoint(p))
return 0;
Py_RETURN_NONE;
}

static PyObject* mod_getdecimalsep(PyObject* self)
{
UNUSED(self);
return PyUnicode_FromUnicode(&chDecimal, 1);
return GetDecimalPoint();
}

static char connect_doc[] =
Expand Down Expand Up @@ -1245,8 +1216,6 @@ initpyodbc(void)
if (!module || !import_types() || !CreateExceptions())
return MODRETURN(0);

init_locale_info();

const char* szVersion = TOSTRING(PYODBC_VERSION);
PyModule_AddStringConstant(module, "version", (char*)szVersion);

Expand Down
2 changes: 0 additions & 2 deletions src/pyodbcmodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ inline bool lowercase()
return PyObject_GetAttrString(pModule, "lowercase") == Py_True;
}

extern Py_UNICODE chDecimal;

bool UseNativeUUID();
// Returns True if pyodbc.native_uuid is true, meaning uuid.UUID objects should be returned.

Expand Down
Loading

0 comments on commit 6b107a2

Please sign in to comment.