Skip to content

Commit

Permalink
[tegaki-wagomu] Add k-nearest-neighbour classification to Wagomu.
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Burgmer committed Nov 3, 2009
1 parent c0e119c commit 9522d3d
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 30 deletions.
221 changes: 191 additions & 30 deletions tegaki-engines/tegaki-wagomu/wagomu.cpp
Expand Up @@ -566,41 +566,19 @@ Results *Recognizer::recognize(Character *ch, unsigned int n_results) {

/* remove duplicate characters */
if (has_duplicates and n_chars > 1) {
if (k_neighbours == 1) {
n_chars = remove_duplicates(distm, n_chars);
} else {
n_chars = incremental_knn(distm, n_chars, n_results);
}
} else {
/* sort the results with glibc's quicksort */
qsort ((void *) distm,
(size_t) n_chars,
sizeof (CharDist),
(int (*) (const void *, const void*)) char_unicode_cmp);
/* first pass: mark duplicate characters and keep w/ min distance */
for (i=1; i < n_chars; i++) {
if (distm[i-1].unicode == distm[i].unicode) {
/* always shift the minium one to the right */
if (distm[i-1].dist < distm[i].dist)
distm[i].dist = distm[i-1].dist;
}
}
/* second pass: inplace removal */
unsigned int current_pos = 0;
for (i=1; i < n_chars; i++) {
if (distm[i].unicode == distm[i-1].unicode)
/* double characters, the latter one has smaller distance */
continue;
else {
/* store last character */
distm[current_pos] = distm[i-1];
current_pos++;
}
}
distm[current_pos] = distm[n_chars-1];
current_pos++;
n_chars = current_pos;
(int (*) (const void *, const void*)) char_dist_cmp);
}

/* sort the results with glibc's quicksort */
qsort ((void *) distm,
(size_t) n_chars,
sizeof (CharDist),
(int (*) (const void *, const void*)) char_dist_cmp);

size = MIN(n_chars, n_results);

Results *results = new Results(size);
Expand All @@ -611,6 +589,189 @@ Results *Recognizer::recognize(Character *ch, unsigned int n_results) {
return results;
}

unsigned int Recognizer::incremental_knn(CharDist *distm,
unsigned int n_chars,
unsigned int n_results) {
unsigned int i;

/* sort the results with glibc's quicksort */
qsort ((void *) distm,
(size_t) n_chars,
sizeof (CharDist),
(int (*) (const void *, const void*)) char_dist_cmp);

unsigned int result_count = 0;
GSList* nearest_neighbours = NULL;
GSList* cur_neighbour;

/* build nearest neighbour list, insert candidates into result set
incrementally */
for(i=0; i < n_chars; i++) {
cur_neighbour = nearest_neighbours;
GSList* bigger_neighbour = NULL;
GSList* previous_neighbour = NULL;
while (cur_neighbour != NULL) {
Neighbour *cur_data = (Neighbour *) cur_neighbour->data;
/* set the bigger neighbour, where we'll move our match up to */
if (previous_neighbour != NULL) {
if (((Neighbour *) previous_neighbour->data)->count
> cur_data->count) {
/* we found a smaller item */
bigger_neighbour = previous_neighbour;
}
}
if (cur_data->unicode == distm[i].unicode) {
cur_data->count++;
cur_data->akku_dist += distm[i].dist;
/* insert element between next bigger and same size
neighbours. TODO: can't we use any existing function with
O(n)? */
if (bigger_neighbour == NULL) {
/* head */
if (previous_neighbour != NULL) {
/* we're not the first element */
previous_neighbour->next = cur_neighbour->next;
cur_neighbour->next = nearest_neighbours;
nearest_neighbours = cur_neighbour;
}
} else {
if (bigger_neighbour->next != cur_neighbour) {
previous_neighbour->next = cur_neighbour->next;
cur_neighbour->next = bigger_neighbour->next;
bigger_neighbour->next = cur_neighbour;
}
}
break;

} else {
/* not found, try next node */
previous_neighbour = cur_neighbour;
cur_neighbour = g_slist_next(cur_neighbour);
}
} // end while

if (cur_neighbour == NULL) {
/* element not yet contained in neighbour list */
GSList *new_elem = NULL;
/* use glib slice allocator */
Neighbour *new_data = (Neighbour *)
g_slice_alloc(sizeof(Neighbour));
new_data->unicode = distm[i].unicode;
new_data->akku_dist = distm[i].dist;
new_data->count = 1;
new_data->seen = false;

new_elem = g_slist_append(new_elem, (void *) new_data);

if (previous_neighbour != NULL) {
previous_neighbour->next = new_elem;
} else {
/* first element */
nearest_neighbours = new_elem;
}

cur_neighbour = new_elem;
}

/* check if we have a newest nearest neighbour */
if (i == k_neighbours) {
/* first time we have reached k, all beginning elements with same
count go into the result set */
GSList* c = nearest_neighbours;
while (c != NULL
and ((Neighbour *) c->data)->count
== ((Neighbour *) nearest_neighbours->data)->count) {
((Neighbour *) c->data)->seen = true;
distm[result_count].unicode = ((Neighbour *) c->data)->unicode;
/* dist = mean of instances => sorted result list */
distm[result_count].dist = ((Neighbour *) c->data)->akku_dist
/ ((Neighbour *) c->data)->count;
result_count++;
c = g_slist_next(c);
}
} else {
Neighbour *cur_data = (Neighbour *) cur_neighbour->data;
if (i >= k_neighbours and not cur_data->seen
and cur_data->count
== ((Neighbour *) nearest_neighbours->data)->count) {
/* next nearest neighbour */
cur_data->seen = true;
distm[result_count].unicode = cur_data->unicode;
/* dist = mean of instances => sorted result list */
distm[result_count].dist = cur_data->akku_dist
/ cur_data->count;
result_count++;
}
}

}

/* append those that didn't make it until n_result is reached */
// cur_neighbour = nearest_neighbours;
// while (result_count <= n_results and cur_neighbour != NULL) {
// Neighbour *cur_data = ((Neighbour *) cur_neighbour->data);
// if (not cur_data->seen) {
// distm[result_count].unicode = cur_data->unicode;
// /* dist = mean of instances => sorted result list */
// distm[result_count].dist = cur_data->akku_dist
// / cur_data->count;
// result_count++;
// }
// cur_neighbour = g_slist_next(cur_neighbour);
// }

/* clean up */
cur_neighbour = nearest_neighbours;
while (cur_neighbour != NULL) {
g_slice_free1(sizeof(Neighbour), cur_neighbour->data);
cur_neighbour = g_slist_next(cur_neighbour);
}
g_slist_free(nearest_neighbours);

return result_count;
}

unsigned int Recognizer::remove_duplicates(CharDist *distm,
unsigned int n_chars) {
unsigned int i;

qsort ((void *) distm,
(size_t) n_chars,
sizeof (CharDist),
(int (*) (const void *, const void*)) char_unicode_cmp);
/* first pass: mark duplicate characters and keep w/ min distance */
for (i=1; i < n_chars; i++) {
if (distm[i-1].unicode == distm[i].unicode) {
/* always shift the minium one to the right */
if (distm[i-1].dist < distm[i].dist)
distm[i].dist = distm[i-1].dist;
}
}
/* second pass: inplace removal */
unsigned int current_pos = 0;
for (i=1; i < n_chars; i++) {
if (distm[i].unicode == distm[i-1].unicode)
/* double characters, the latter one has smaller distance */
continue;
else {
/* store last character */
distm[current_pos] = distm[i-1];
current_pos++;
}
}
distm[current_pos] = distm[n_chars-1];
current_pos++;
n_chars = current_pos;

/* sort the results with glibc's quicksort */
qsort ((void *) distm,
(size_t) n_chars,
sizeof (CharDist),
(int (*) (const void *, const void*)) char_dist_cmp);

return n_chars;
}

char* Recognizer::get_error_message() {
return error_msg;
}
Expand Down
14 changes: 14 additions & 0 deletions tegaki-engines/tegaki-wagomu/wagomu.h
Expand Up @@ -88,6 +88,13 @@ typedef struct {
char pad[4];
} CharacterGroup;

typedef struct {
unsigned int unicode;
unsigned int count;
float akku_dist;
bool seen;
} Neighbour;

#ifdef __SSE__
typedef union {
__m128 v;
Expand All @@ -111,6 +118,9 @@ class Recognizer {
void set_window_size(unsigned int size);
char *get_error_message();

/* for k-nearest-neighbour evaluation */
static const unsigned int k_neighbours = 5;

private:
GMappedFile *file;
char *data;
Expand Down Expand Up @@ -145,6 +155,10 @@ class Recognizer {

unsigned int get_max_n_vectors();

unsigned int remove_duplicates(CharDist *distm, unsigned int n_chars);
unsigned int incremental_knn(CharDist *distm, unsigned int n_chars,
unsigned int n_results);

inline float local_distance(float *v1, float *v2);

inline float dtw(float *s, unsigned int n, float *t, unsigned int m);
Expand Down

0 comments on commit 9522d3d

Please sign in to comment.