Skip to content

Commit

Permalink
Consolidate CSV reading functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jonnor committed Nov 23, 2018
1 parent 1c48a0d commit 46ece6b
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 141 deletions.
137 changes: 72 additions & 65 deletions emlearn.ino
Expand Up @@ -4,19 +4,83 @@

#include "digits.h"

const int32_t n_features = 64;

const int32_t bytes_per_number = 50; // 32bit integer, plus separator. But better to have too much
const int32_t buffer_length = n_features*bytes_per_number;
char receive_buffer[buffer_length] = {0,};
int32_t receive_idx = 0;
void send_reply(int32_t request, int32_t time_taken,
int32_t prediction, int32_t n_repetitions)
{
Serial.print(request);
Serial.print(";");
Serial.print(time_taken);
Serial.print(";");
Serial.print(prediction);
//Serial.print(";");
//Serial.print(n_repetitions);
//for (int i=0; i<n_features; i++) {
// Serial.print(";");
// Serial.print(values[i]);
//}
Serial.print("\n");
}

void parse_predict_reply(char *buffer, float *values, int32_t values_length)
{
// FIXME: buffer needs to be zero terminated?

// Parse the values to use for prediction
int32_t n_values = -1;
const EmlError e = eml_test_parse_csv_line(buffer, values, values_length, &n_values);
if (e != EmlOk) {
return;
}

const expected_values = n_features + 2;
if (n_values != n_features) {
return;
}

int32_t request = values[0];
int32_t n_repetitions = values[1];
values = values+2;
n_values -= 2;

// Do predictions
volatile int32_t sum = 0; // avoid profiler folding the loop
const long pre = micros();

int32_t prediction = -999;
for (int32_t i=0; i<n_repetitions; i++) {
const int32_t p = digits_predict(values, n_values);
if (prediction != -999 && p != prediction) {
// consistency check, should always be same
prediction = -2;
break;
}
//Serial.print("cl: "); Serial.println(s);
sum += p;
prediction = p;
}

const long post = micros();
const long time_taken = post - pre;

// Send back on parseable format
send_reply(request, time_taken, prediction, n_repetitions);
}


void setup() {
Serial.begin(115200);
}

void loop() {
int32_t values[n_features];
const int32_t n_features = 64;
const int32_t bytes_per_number = 20; // 32bit integer, plus separator. But better to have too much
const int32_t buffer_length = n_features*bytes_per_number;

char receive_buffer[buffer_length] = {0,};
const int32_t values_length = 4+n_features;
float values[values_length];
int32_t receive_idx = 0;

while (Serial.available() > 0) {

Expand All @@ -30,67 +94,10 @@ void loop() {
}

if (ch == '\n') {
int32_t request = -3;
int32_t n_repetitions = -1;
const int non_value_fields = 2;

// Parse the values to use for prediction
int field_no = 0;
char seps[] = ",;";
char *token = strtok(receive_buffer, seps);
while (token != NULL)
{
long value;
sscanf(token, "%ld", &value);
if (field_no == 0) {
request = value;
} else if (field_no == 1) {
n_repetitions = value;
} else {
values[field_no-non_value_fields] = value;
}
field_no++;
token = strtok(NULL, seps);
}

if (field_no-non_value_fields != n_features) {
Serial.print("Error, wrong number of features: "); Serial.println(field_no-non_value_fields);
}

parse_predict_reply();

receive_idx = 0;
memset(receive_buffer, buffer_length, 0);

// Do predictions
volatile int32_t sum = 0; // avoid profiler folding the loop
const long pre = micros();
int32_t prediction = -999;
for (int32_t i=0; i<n_repetitions; i++) {
const int32_t p = digits_predict(values, n_features);
if (prediction != -999 && p != prediction) {
// consistency check, should always be same
prediction = -2;
break;
}
//Serial.print("cl: "); Serial.println(s);
sum += p;
prediction = p;
}
const long post = micros();
const long time_taken = post - pre;

// Send back on parseable format
Serial.print(request);
Serial.print(";");
Serial.print(time_taken);
Serial.print(";");
Serial.print(prediction);
Serial.print(";");
Serial.print(n_repetitions);
for (int i=0; i<n_features; i++) {
Serial.print(";");
Serial.print(values[i]);
}
Serial.print("\n");
}
}
}
2 changes: 1 addition & 1 deletion emlearn/bayes.py
Expand Up @@ -83,7 +83,7 @@ def __init__(self, estimator, method):
name = 'mybayes'
func = 'eml_bayes_predict(&{}_model, values, length)'.format(name)
code = self.save(name=name)
self.classifier = common.CompiledClassifier(code, name=name, call=func, test_function='eml_bayes_test_read_csv')
self.classifier = common.CompiledClassifier(code, name=name, call=func)
elif method == 'inline':
raise NotImplementedError('NaiveBayes does not support inline C code generation')
else:
Expand Down
2 changes: 1 addition & 1 deletion emlearn/common.py
Expand Up @@ -29,7 +29,7 @@ def build_classifier(cmodel, name, temp_dir, include_dir, func=None, compiler=No
#include "{def_file_name}"
#include <eml_test.h>
static void classify(const int32_t *values, int length, int row) {{
static void classify(const float *values, int length, int row) {{
const int32_t class = {func};
printf("%d,%d\\n", row, class);
}}
Expand Down
7 changes: 5 additions & 2 deletions emlearn/eml_bayes.h
Expand Up @@ -86,10 +86,13 @@ eml_bayes_logpdf(eml_q16_t x, eml_q16_t mean, eml_q16_t std, eml_q16_t stdlog2)


int32_t
eml_bayes_predict(EmlBayesModel *model, const eml_q16_t values[], int32_t values_length) {
eml_bayes_predict(EmlBayesModel *model, const float values[], int32_t values_length) {
//printf("predict(%d), classes=%d features=%d\n",
// values_length, model->n_classes, model->n_features);

EML_PRECONDITION(model, EmlUninitialized);
EML_PRECONDITION(values, EmlUninitialized);

const int MAX_CLASSES = 10;
eml_q16_t class_probabilities[MAX_CLASSES];

Expand All @@ -99,7 +102,7 @@ eml_bayes_predict(EmlBayesModel *model, const eml_q16_t values[], int32_t values
for (int value_idx = 0; value_idx<values_length; value_idx++) {
const int32_t summary_idx = class_idx*model->n_features + value_idx;
EmlBayesSummary summary = model->summaries[summary_idx];
const eml_q16_t val = values[value_idx];
const eml_q16_t val = EML_Q16_FROMFLOAT(values[value_idx]);
const eml_q16_t plog = eml_bayes_logpdf(val, summary.mean, summary.std, summary.stdlog2);

class_p += plog;
Expand Down
34 changes: 0 additions & 34 deletions emlearn/eml_net.h
Expand Up @@ -315,37 +315,3 @@ eml_net_predict(EmlNet *model, const float *features, int32_t features_length) {
return _class;
}




#if 0
EmlError
eml_net_parse_csv_line(const char *buffer, float *values, int32_t values_length) {

int field_no = 0;
const char seps[] = ",;";
char *token = strtok(buffer, seps);

while (token != NULL)
{
long value;
sscanf(token, "%ld", &value);

if (field_no >= values_length) {
return EmlNetSizeMismatch;
}

values[field_no] = value;
field_no++;
token = strtok(NULL, seps);
}

return EmlNetOk;
}

#endif





79 changes: 41 additions & 38 deletions emlearn/eml_test.h
Expand Up @@ -6,60 +6,63 @@
#include <stdlib.h>
#include <string.h>

#include "eml_common.h"

typedef void (*EmlCsvCallback)(const int32_t *values, int length, int row);
typedef void (*EmlCsvCallback)(const float *values, int length, int row);

void
eml_test_read_csv(FILE *fp, EmlCsvCallback row_callback) {
char buffer[1024];
int32_t values[256];
int row_no = 0;
int value_no = 0;
// Return number of values parsed, or -EmlError
int32_t
eml_test_parse_csv_line(char *buffer, float *values, int32_t values_length,
int32_t *values_read_out)
{
EML_PRECONDITION(buffer, EmlUninitialized);
EML_PRECONDITION(values, EmlUninitialized);

while(fgets(buffer, sizeof buffer, fp))
int field_no = 0;
const char seps[] = ",;";
char *token = strtok(buffer, seps);

while (token != NULL)
{
char seps[] = ",;";
char *token = strtok(buffer, seps);
while (token != NULL)
{
long value;
sscanf(token, "%ld", &value);
values[value_no++] = value;
token = strtok (NULL, seps);
float value;
sscanf(token, "%f", &value);

if (field_no >= values_length) {
return EmlSizeMismatch;
}
row_callback(values, value_no, row_no);
value_no = 0;
row_no++;

values[field_no] = value;
field_no++;
token = strtok(NULL, seps);
}

if (values_read_out) {
*values_read_out = field_no;
}

return EmlOk;
}

// FIXME: Remove in favor of float/int32 support
#include "eml_fixedpoint.h"
typedef void (*EmBayesCsvCallback)(const eml_q16_t values[], int length, int row);
void
eml_bayes_test_read_csv(FILE *fp, EmBayesCsvCallback row_callback) {
char buffer[1024];
eml_q16_t values[256];
EmlError
eml_test_read_csv(FILE *fp, EmlCsvCallback row_callback) {
const int32_t buffer_length = 1024;
char buffer[buffer_length];
const int32_t values_length = 256;
float values[values_length];
int row_no = 0;
int value_no = 0;

while(fgets(buffer, sizeof buffer, fp))
while(fgets(buffer, buffer_length, fp))
{
char seps[] = ",;";
char *token = strtok(buffer, seps);
while (token != NULL)
{
float value;
sscanf(token, "%f", &value);
values[value_no++] = EML_Q16_FROMFLOAT(value);
token = strtok (NULL, seps);
}
int value_no = 0;
const EmlError e = eml_test_parse_csv_line(buffer, values, values_length, &value_no);
EML_CHECK_ERROR(e);
row_callback(values, value_no, row_no);
value_no = 0;
row_no++;
}
return EmlOk;
}


#endif // EML_TEST_H


0 comments on commit 46ece6b

Please sign in to comment.