diff --git a/org.eclipse.paho.android.service/org.eclipse.paho.android.service/src/main/java/org/eclipse/paho/android/service/DatabaseMessageStore.java b/org.eclipse.paho.android.service/org.eclipse.paho.android.service/src/main/java/org/eclipse/paho/android/service/DatabaseMessageStore.java index 84a71658..6d13174e 100755 --- a/org.eclipse.paho.android.service/org.eclipse.paho.android.service/src/main/java/org/eclipse/paho/android/service/DatabaseMessageStore.java +++ b/org.eclipse.paho.android.service/org.eclipse.paho.android.service/src/main/java/org/eclipse/paho/android/service/DatabaseMessageStore.java @@ -9,6 +9,9 @@ * http://www.eclipse.org/legal/epl-v10.html * and the Eclipse Distribution License is available at * http://www.eclipse.org/org/documents/edl-v10.php. + * + * Contributors: + * James Sutton - Removing SQL Injection vunerability (bug 467378) */ package org.eclipse.paho.android.service; @@ -203,12 +206,23 @@ public String storeArrived(String clientHandle, String topic, } private int getArrivedRowCount(String clientHandle) { - String[] cols = new String[1]; - cols[0] = "COUNT(*)"; - Cursor c = db.query(ARRIVED_MESSAGE_TABLE_NAME, cols, - MqttServiceConstants.CLIENT_HANDLE + "='" + clientHandle + "'", - null, null, null, null); - int count = 0; + int count = 0; + String[] projection = { + MqttServiceConstants.MESSAGE_ID, + }; + String selection = MqttServiceConstants.CLIENT_HANDLE; + String[] selectionArgs = new String[1]; + selectionArgs[0] = clientHandle; + Cursor c = db.query( + ARRIVED_MESSAGE_TABLE_NAME, // Table Name + projection, // The columns to return; + selection, // Columns for WHERE Clause + selectionArgs , // The values for the WHERE Cause + null, //Don't group the rows + null, // Don't filter by row groups + null // The sort order + ); + if (c.moveToFirst()) { count = c.getInt(0); } @@ -234,11 +248,15 @@ public boolean discardArrived(String clientHandle, String id) { traceHandler.traceDebug(TAG, "discardArrived{" + clientHandle + "}, {" + id + "}"); int rows; + String[] selectionArgs = new String[2]; + selectionArgs[0] = id; + selectionArgs[1] = clientHandle; + try { rows = db.delete(ARRIVED_MESSAGE_TABLE_NAME, - MqttServiceConstants.MESSAGE_ID + "='" + id + "' AND " - + MqttServiceConstants.CLIENT_HANDLE + "='" - + clientHandle + "'", null); + MqttServiceConstants.MESSAGE_ID + "=? AND " + + MqttServiceConstants.CLIENT_HANDLE + "=?", + selectionArgs); } catch (SQLException e) { traceHandler.traceException(TAG, "discardArrived", e); throw e; @@ -272,18 +290,30 @@ public Iterator getAllArrivedMessages( return new Iterator() { private Cursor c; private boolean hasNext; + private String[] selectionArgs = { + clientHandle, + }; + { db = mqttDb.getWritableDatabase(); // anonymous initialiser to start a suitable query // and position at the first row, if one exists if (clientHandle == null) { - c = db.query(ARRIVED_MESSAGE_TABLE_NAME, null, null, null, - null, null, "mtimestamp ASC"); + c = db.query(ARRIVED_MESSAGE_TABLE_NAME, + null, + null, + null, + null, + null, + "mtimestamp ASC"); } else { - c = db.query(ARRIVED_MESSAGE_TABLE_NAME, null, - MqttServiceConstants.CLIENT_HANDLE + "='" - + clientHandle + "'", null, null, null, + c = db.query(ARRIVED_MESSAGE_TABLE_NAME, + null, + MqttServiceConstants.CLIENT_HANDLE + "=?", + selectionArgs, + null, + null, "mtimestamp ASC"); } hasNext = c.moveToFirst(); @@ -352,6 +382,8 @@ protected void finalize() throws Throwable { public void clearArrivedMessages(String clientHandle) { db = mqttDb.getWritableDatabase(); + String[] selectionArgs = new String[1]; + selectionArgs[0] = clientHandle; int rows = 0; if (clientHandle == null) { @@ -362,9 +394,10 @@ public void clearArrivedMessages(String clientHandle) { traceHandler.traceDebug(TAG, "clearArrivedMessages: clearing the table of " + clientHandle + " messages"); - rows = db.delete(ARRIVED_MESSAGE_TABLE_NAME, - MqttServiceConstants.CLIENT_HANDLE + "='" + clientHandle - + "'", null); + rows = db.delete(ARRIVED_MESSAGE_TABLE_NAME, + MqttServiceConstants.CLIENT_HANDLE + "=?", + selectionArgs); + } traceHandler.traceDebug(TAG, "clearArrivedMessages: rows affected = " + rows);