diff --git a/emlearn.ino b/emlearn.ino index efaa6e2..bc9a81d 100644 --- a/emlearn.ino +++ b/emlearn.ino @@ -4,19 +4,83 @@ #include "digits.h" -const int32_t n_features = 64; -const int32_t bytes_per_number = 50; // 32bit integer, plus separator. But better to have too much -const int32_t buffer_length = n_features*bytes_per_number; -char receive_buffer[buffer_length] = {0,}; -int32_t receive_idx = 0; +void send_reply(int32_t request, int32_t time_taken, + int32_t prediction, int32_t n_repetitions) +{ + Serial.print(request); + Serial.print(";"); + Serial.print(time_taken); + Serial.print(";"); + Serial.print(prediction); + //Serial.print(";"); + //Serial.print(n_repetitions); + //for (int i=0; i 0) { @@ -30,67 +94,10 @@ void loop() { } if (ch == '\n') { - int32_t request = -3; - int32_t n_repetitions = -1; - const int non_value_fields = 2; - - // Parse the values to use for prediction - int field_no = 0; - char seps[] = ",;"; - char *token = strtok(receive_buffer, seps); - while (token != NULL) - { - long value; - sscanf(token, "%ld", &value); - if (field_no == 0) { - request = value; - } else if (field_no == 1) { - n_repetitions = value; - } else { - values[field_no-non_value_fields] = value; - } - field_no++; - token = strtok(NULL, seps); - } - - if (field_no-non_value_fields != n_features) { - Serial.print("Error, wrong number of features: "); Serial.println(field_no-non_value_fields); - } - + parse_predict_reply(); + receive_idx = 0; memset(receive_buffer, buffer_length, 0); - - // Do predictions - volatile int32_t sum = 0; // avoid profiler folding the loop - const long pre = micros(); - int32_t prediction = -999; - for (int32_t i=0; i - static void classify(const int32_t *values, int length, int row) {{ + static void classify(const float *values, int length, int row) {{ const int32_t class = {func}; printf("%d,%d\\n", row, class); }} diff --git a/emlearn/eml_bayes.h b/emlearn/eml_bayes.h index f1be66f..bde3a57 100644 --- a/emlearn/eml_bayes.h +++ b/emlearn/eml_bayes.h @@ -86,10 +86,13 @@ eml_bayes_logpdf(eml_q16_t x, eml_q16_t mean, eml_q16_t std, eml_q16_t stdlog2) int32_t -eml_bayes_predict(EmlBayesModel *model, const eml_q16_t values[], int32_t values_length) { +eml_bayes_predict(EmlBayesModel *model, const float values[], int32_t values_length) { //printf("predict(%d), classes=%d features=%d\n", // values_length, model->n_classes, model->n_features); + EML_PRECONDITION(model, EmlUninitialized); + EML_PRECONDITION(values, EmlUninitialized); + const int MAX_CLASSES = 10; eml_q16_t class_probabilities[MAX_CLASSES]; @@ -99,7 +102,7 @@ eml_bayes_predict(EmlBayesModel *model, const eml_q16_t values[], int32_t values for (int value_idx = 0; value_idxn_features + value_idx; EmlBayesSummary summary = model->summaries[summary_idx]; - const eml_q16_t val = values[value_idx]; + const eml_q16_t val = EML_Q16_FROMFLOAT(values[value_idx]); const eml_q16_t plog = eml_bayes_logpdf(val, summary.mean, summary.std, summary.stdlog2); class_p += plog; diff --git a/emlearn/eml_net.h b/emlearn/eml_net.h index e5d1ff5..8bb5330 100644 --- a/emlearn/eml_net.h +++ b/emlearn/eml_net.h @@ -315,37 +315,3 @@ eml_net_predict(EmlNet *model, const float *features, int32_t features_length) { return _class; } - - - -#if 0 -EmlError -eml_net_parse_csv_line(const char *buffer, float *values, int32_t values_length) { - - int field_no = 0; - const char seps[] = ",;"; - char *token = strtok(buffer, seps); - - while (token != NULL) - { - long value; - sscanf(token, "%ld", &value); - - if (field_no >= values_length) { - return EmlNetSizeMismatch; - } - - values[field_no] = value; - field_no++; - token = strtok(NULL, seps); - } - - return EmlNetOk; -} - -#endif - - - - - diff --git a/emlearn/eml_test.h b/emlearn/eml_test.h index 2ad069c..ae98934 100644 --- a/emlearn/eml_test.h +++ b/emlearn/eml_test.h @@ -6,60 +6,63 @@ #include #include +#include "eml_common.h" -typedef void (*EmlCsvCallback)(const int32_t *values, int length, int row); +typedef void (*EmlCsvCallback)(const float *values, int length, int row); -void -eml_test_read_csv(FILE *fp, EmlCsvCallback row_callback) { - char buffer[1024]; - int32_t values[256]; - int row_no = 0; - int value_no = 0; +// Return number of values parsed, or -EmlError +int32_t +eml_test_parse_csv_line(char *buffer, float *values, int32_t values_length, + int32_t *values_read_out) +{ + EML_PRECONDITION(buffer, EmlUninitialized); + EML_PRECONDITION(values, EmlUninitialized); - while(fgets(buffer, sizeof buffer, fp)) + int field_no = 0; + const char seps[] = ",;"; + char *token = strtok(buffer, seps); + + while (token != NULL) { - char seps[] = ",;"; - char *token = strtok(buffer, seps); - while (token != NULL) - { - long value; - sscanf(token, "%ld", &value); - values[value_no++] = value; - token = strtok (NULL, seps); + float value; + sscanf(token, "%f", &value); + + if (field_no >= values_length) { + return EmlSizeMismatch; } - row_callback(values, value_no, row_no); - value_no = 0; - row_no++; + + values[field_no] = value; + field_no++; + token = strtok(NULL, seps); } + + if (values_read_out) { + *values_read_out = field_no; + } + + return EmlOk; } -// FIXME: Remove in favor of float/int32 support -#include "eml_fixedpoint.h" -typedef void (*EmBayesCsvCallback)(const eml_q16_t values[], int length, int row); -void -eml_bayes_test_read_csv(FILE *fp, EmBayesCsvCallback row_callback) { - char buffer[1024]; - eml_q16_t values[256]; +EmlError +eml_test_read_csv(FILE *fp, EmlCsvCallback row_callback) { + const int32_t buffer_length = 1024; + char buffer[buffer_length]; + const int32_t values_length = 256; + float values[values_length]; int row_no = 0; - int value_no = 0; - while(fgets(buffer, sizeof buffer, fp)) + while(fgets(buffer, buffer_length, fp)) { - char seps[] = ",;"; - char *token = strtok(buffer, seps); - while (token != NULL) - { - float value; - sscanf(token, "%f", &value); - values[value_no++] = EML_Q16_FROMFLOAT(value); - token = strtok (NULL, seps); - } + int value_no = 0; + const EmlError e = eml_test_parse_csv_line(buffer, values, values_length, &value_no); + EML_CHECK_ERROR(e); row_callback(values, value_no, row_no); - value_no = 0; row_no++; } + return EmlOk; } + #endif // EML_TEST_H