Skip to content
This repository has been archived by the owner on Jan 30, 2023. It is now read-only.

Issue 57 correctly check max input length #58

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/mprio.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ void Prio_clear();
* (2) the modulus we use for modular arithmetic.
* The default configuration uses an 87-bit modulus.
*
* The value `nFields` must be in the range `0 < nFields <= max`, where `max`
* is the value returned by the function `PrioConfig_maxDataFields()` below.
*
* The `batch_id` field specifies which "batch" of aggregate statistics we are
* computing. For example, if the aggregate statistics are computed every 24
* hours, the `batch_id` might be set to an encoding of the date. The clients
Expand All @@ -93,6 +96,11 @@ PrioConfig PrioConfig_new(int nFields, PublicKey serverA, PublicKey serverB,
void PrioConfig_clear(PrioConfig cfg);
int PrioConfig_numDataFields(const_PrioConfig cfg);

/*
* Return the maximum number of data fields that the implementation supports.
*/
int PrioConfig_maxDataFields(void);

/*
* Create a PrioConfig object with no encryption keys. This routine is
* useful for testing, but PrioClient_encode() will always fail when used with
Expand Down
13 changes: 9 additions & 4 deletions prio/config.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ initialize_roots(MPArray arr, const char values[], bool inverted)
return SECSuccess;
}

int
PrioConfig_maxDataFields(void)
{
const int n_roots = 1 << Generator2Order;
return (n_roots >> 1) - 1;
}

PrioConfig
PrioConfig_new(int n_fields, PublicKey server_a, PublicKey server_b,
const unsigned char* batch_id, unsigned int batch_id_len)
Expand All @@ -71,10 +78,8 @@ PrioConfig_new(int n_fields, PublicKey server_a, PublicKey server_b,
cfg->roots = NULL;
cfg->rootsInv = NULL;

if (cfg->num_data_fields >= cfg->n_roots) {
rv = SECFailure;
goto cleanup;
}
P_CHECKCB(cfg->n_roots > 1);
P_CHECKCB(cfg->num_data_fields <= PrioConfig_maxDataFields());

P_CHECKA(cfg->batch_id = malloc(batch_id_len));
strncpy((char*)cfg->batch_id, (char*)batch_id, batch_id_len);
Expand Down
35 changes: 29 additions & 6 deletions ptest/client_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ mu_test_client__new(void)
}

void
test_client_agg(int nclients)
test_client_agg(int nclients, int nfields, bool config_is_okay)
{
SECStatus rv = SECSuccess;
PublicKey pkA = NULL;
Expand All @@ -74,7 +74,12 @@ test_client_agg(int nclients)

PT_CHECKC(Keypair_new(&skA, &pkA));
PT_CHECKC(Keypair_new(&skB, &pkB));
PT_CHECKA(cfg = PrioConfig_new(133, pkA, pkB, batch_id, batch_id_len));
printf("fields: %d\n", nfields);
P_CHECKA(cfg = PrioConfig_new(nfields, pkA, pkB, batch_id, batch_id_len));
if (!config_is_okay) {
PT_CHECKCB(
(PrioConfig_new(nfields, pkA, pkB, batch_id, batch_id_len) == NULL));
}
PT_CHECKA(sA = PrioServer_new(cfg, 0, skA, seed));
PT_CHECKA(sB = PrioServer_new(cfg, 1, skB, seed));
PT_CHECKA(tA = PrioTotalShare_new());
Expand Down Expand Up @@ -118,7 +123,11 @@ test_client_agg(int nclients)
}

cleanup:
mu_check(rv == SECSuccess);
if (config_is_okay) {
mu_check(rv == SECSuccess);
} else {
mu_check(rv == SECFailure);
}
if (data_items)
free(data_items);
if (output)
Expand Down Expand Up @@ -147,17 +156,31 @@ test_client_agg(int nclients)
void
mu_test_client__agg_1(void)
{
test_client_agg(1);
test_client_agg(1, 133, true);
}

void
mu_test_client__agg_2(void)
{
test_client_agg(2);
test_client_agg(2, 133, true);
}

void
mu_test_client__agg_10(void)
{
test_client_agg(10);
test_client_agg(10, 133, true);
}

void
mu_test_client__agg_max(void)
{
int max = PrioConfig_maxDataFields();
test_client_agg(10, max, true);
}

void
mu_test_client__agg_max_bad(void)
{
int max = PrioConfig_maxDataFields();
test_client_agg(10, max + 1, false);
}