Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
264 lines (242 sloc) 7.78 KB
// Copyright 2013 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>
// Sizing constants for the fixed buffers declared in main().
const long long max_size = 2000; // max length of strings; also the capacity of the query vector `vec` in main()
const long long N = 10; // number of closest words that will be shown
const long long max_w = 50; // max length of vocabulary entries (in the visible code, only commented-out lines reference it)
// Locates the command-line flag `str` among argv[1] .. argv[argc - 1].
// Returns the index of the first match, or -1 when the flag is absent.
// If the match is the very last argument (so no value can follow it), an
// error is printed and the process exits with status 1.
int ArgPos(char *str, int argc, char **argv) {
  int idx = 1;
  while (idx < argc) {
    if (strcmp(str, argv[idx]) == 0) {
      if (idx + 1 == argc) {
        printf("Argument missing for %s\n", str);
        exit(1);
      }
      return idx;
    }
    idx++;
  }
  return -1;
}
// Size in bytes of the per-word buffers allocated in this file; AddWordToVocab
// also clamps its allocation length to this value.
#define MAX_STRING 1000
const int vocab_hash_size = 30000000; // Maximum 30 * 0.7 = 21M words in the vocabulary
// Hash table with vocab_hash_size slots. Each slot holds an index into `vocab`,
// or -1 when empty; collisions are resolved by stepping to the next slot.
int *vocab_hash;
// One vocabulary entry.
struct vocab_word {
  long long cn;   // set to 0 when the entry is created; presumably an occurrence count (not otherwise used in the visible code)
  int *point;     // not referenced in the visible code
  char *word;     // heap-allocated, NUL-terminated spelling of the word
  char *code;     // not referenced in the visible code
  char codelen;   // a plain char (not a pointer); not referenced in the visible code
};
// Heap array of vocabulary entries; AddWordToVocab appends to it and grows it with realloc.
struct vocab_word *vocab;
// vocab_max_size: number of entries currently allocated for `vocab`.
// vocab_size: number of entries in use (also the index of the next free entry).
long long vocab_max_size, vocab_size;
// Returns hash value of a word
//
// The hash is a base-257 polynomial over the bytes of `word`, accumulated in
// an unsigned 64-bit value (wrapping on overflow) and reduced modulo
// vocab_hash_size, so the result is always in [0, vocab_hash_size).
// Bytes are added as plain `char`, so on platforms where char is signed a
// byte >= 0x80 contributes a negative value; that is well defined for the
// unsigned accumulator and is kept as-is so existing hash values do not change.
int GetWordHash(const char *word) {
  unsigned long long a, hash = 0;
  // Measure the string once instead of calling strlen() on every iteration.
  const unsigned long long len = strlen(word);
  for (a = 0; a < len; a++) hash = hash * 257 + word[a];
  hash = hash % vocab_hash_size;
  return hash;
}
// Looks `word` up in the vocabulary hash table.
// Starting at the slot given by GetWordHash(word), slots are examined one after
// another (wrapping at vocab_hash_size) until either an empty slot (-1) is hit,
// meaning the word is absent and -1 is returned, or a slot whose entry spells
// the same word is found, in which case that entry's index in `vocab` is returned.
// The loop has no other exit, so it relies on the table containing an empty slot.
int SearchVocab(char *word) {
  unsigned int slot = GetWordHash(word);
  for (;; slot = (slot + 1) % vocab_hash_size) {
    int entry = vocab_hash[slot];
    if (entry == -1) return -1;
    if (strcmp(word, vocab[entry].word) == 0) return entry;
  }
}
// Adds a word to the vocabulary
//
// Appends a copy of `word` (with cn = 0) to `vocab`, growing the array in
// steps of 1000 entries when it is nearly full, and records the new entry in
// `vocab_hash`. Returns the index of the new entry.
//
// At most MAX_STRING - 1 characters are stored; a longer word is truncated
// rather than written past the end of its buffer. The hash slot is derived
// from the stored spelling so that SearchVocab() can find exactly what was
// stored. No duplicate check is performed.
//
// On allocation failure a message is printed and the process exits with
// status 1 (the same policy ArgPos uses for fatal errors).
int AddWordToVocab(const char *word) {
  unsigned int hash;
  size_t length = strlen(word) + 1;
  char *stored;
  if (length > MAX_STRING) length = MAX_STRING;
  stored = (char *)calloc(length, sizeof(char));
  if (stored == NULL) {
    printf("Cannot allocate memory for vocabulary word\n");
    exit(1);
  }
  // Copy only what fits; calloc already supplied the terminating NUL.
  memcpy(stored, word, length - 1);
  vocab[vocab_size].word = stored;
  vocab[vocab_size].cn = 0;
  vocab_size++;
  // Reallocate memory if needed
  if (vocab_size + 2 >= vocab_max_size) {
    long long grown = vocab_max_size + 1000;
    // Keep the old pointer until realloc is known to have succeeded.
    struct vocab_word *bigger =
        (struct vocab_word *)realloc(vocab, (size_t)grown * sizeof(struct vocab_word));
    if (bigger == NULL) {
      printf("Cannot allocate memory for vocabulary\n");
      exit(1);
    }
    vocab = bigger;
    vocab_max_size = grown;
  }
  hash = GetWordHash(stored);
  // Step forward to the first free slot.
  while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
  vocab_hash[hash] = vocab_size - 1;
  return vocab_size - 1;
}
// Interactive nearest-neighbour query tool over a text-format embedding file.
//
// Command line:
//   -emb <file>   embedding file: a "<words> <size>" header followed by
//                 <words> rows of "<token> <size floats>"
//   -word <file>  optional whitespace-separated word list; when given, only
//                 words from this list are reported as neighbours
//
// Every embedding row is L2-normalised on load. For each line typed on stdin
// (until "EXIT" or end of input) the vectors of the query words are summed and
// normalised, and the N rows with the largest positive dot product against
// that vector are printed, excluding the query words themselves.
//
// Returns 0 on normal exit and after printing usage, -1 on any setup error.
int main(int argc, char **argv) {
  enum { MAX_QUERY_WORDS = 100 };  // capacity of st[] and bi[]
  FILE *f;
  char st1[max_size];
  char bestw[N][max_size];
  char emb_file[max_size], word_file[max_size], st[MAX_QUERY_WORDS][max_size];
  float dist, len, bestd[N], vec[max_size];
  long long words, size, a, b, c, d, cn, bi[MAX_QUERY_WORDS];
  float *M;
  char **full_vocab;
  char word_fmt[32];
  int is_word_file = 0;
  int at_eof = 0;
  int i;
  if (argc == 1) {
    printf("Usage: ./distance\n"); // in the BINARY FORMAT\n");
    printf("\t-emb <file>\n");
    printf("\t\tEmbedding file\n");
    printf("\t-word <file>\n");
    printf("\t\tRestrict to answer within these words in this file\n");
    return 0;
  }
  // Build "%<MAX_STRING - 1>s" so every token read below is bounded by the
  // MAX_STRING-byte buffer it is stored in.
  snprintf(word_fmt, sizeof(word_fmt), "%%%ds", MAX_STRING - 1);
  if ((i = ArgPos((char *)"-word", argc, argv)) > 0) {
    char word[MAX_STRING];
    int count = 0;
    // Bounded copy: an over-long path is truncated instead of overflowing.
    snprintf(word_file, sizeof(word_file), "%s", argv[i + 1]);
    printf("# word_file=%s\n", word_file);
    is_word_file = 1;
    // build vocab
    vocab_size = 0;
    vocab_max_size = 1000;
    vocab = (struct vocab_word *)calloc(vocab_max_size, sizeof(struct vocab_word));
    vocab_hash = (int *)calloc(vocab_hash_size, sizeof(int));
    if (vocab == NULL || vocab_hash == NULL) {
      printf("Cannot allocate memory for the word list\n");
      return -1;
    }
    for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
    f = fopen(word_file, "r");
    if (f == NULL) {
      printf("Word file not found\n");
      return -1;
    }
    // Loop on fscanf's return value rather than feof(): a final word that is
    // not followed by whitespace is read successfully even though the stream
    // is already at EOF, and testing feof() first would drop it.
    while (fscanf(f, word_fmt, word) == 1) {
      AddWordToVocab(word);
      count++;
    }
    printf(" constraint neighbors to %d words\n", count);
    fclose(f);
  }
  // Start from an empty name so that a missing -emb option makes fopen() fail
  // cleanly instead of reading an uninitialised buffer.
  emb_file[0] = 0;
  if ((i = ArgPos((char *)"-emb", argc, argv)) > 0) {
    snprintf(emb_file, sizeof(emb_file), "%s", argv[i + 1]);
    printf("# emb_file=%s\n", emb_file);
  }
  f = fopen(emb_file, "r");
  if (f == NULL) {
    printf("Input file not found\n");
    return -1;
  }
  /** Reading embedding file **/
  if (fscanf(f, "%lld %lld", &words, &size) != 2) {
    printf("Cannot read the embedding file header\n");
    fclose(f);
    return -1;
  }
  printf("Words %lld, size %lld\n", words, size);
  // `vec` holds max_size floats, so a larger dimension cannot be handled.
  if (words <= 0 || size <= 0 || size > max_size) {
    printf("Invalid embedding file header (vector size must be between 1 and %lld)\n", max_size);
    fclose(f);
    return -1;
  }
  // The token table lives on the heap: one pointer per vocabulary word is far
  // too much for a stack array when the vocabulary is large.
  full_vocab = (char **)calloc((size_t)words, sizeof(char *));
  // calloc checks the rows * row-size multiplication for overflow.
  M = (float *)calloc((size_t)words, (size_t)size * sizeof(float));
  if (M == NULL || full_vocab == NULL) {
    printf("Cannot allocate memory: %lld MB %lld %lld\n",
           (long long)((unsigned long long)words * (unsigned long long)size * sizeof(float) / 1048576),
           words, size);
    free(M);
    free(full_vocab);
    fclose(f);
    return -1;
  }
  for (b = 0; b < words; b++) {
    full_vocab[b] = (char *)malloc(MAX_STRING * sizeof(char));
    if (full_vocab[b] == NULL) {
      printf("Cannot allocate memory for vocabulary word %lld\n", b);
      fclose(f);
      return -1;
    }
    if (fscanf(f, word_fmt, full_vocab[b]) != 1) {
      printf("Embedding file is truncated or malformed at entry %lld\n", b);
      fclose(f);
      return -1;
    }
    for (a = 0; a < size; a++) {
      if (fscanf(f, "%f", &M[a + b * size]) != 1) {
        printf("Embedding file is truncated or malformed at entry %lld\n", b);
        fclose(f);
        return -1;
      }
    }
    // Normalise the row to unit length; an all-zero row is left untouched.
    len = 0;
    for (a = 0; a < size; a++) len += M[a + b * size] * M[a + b * size];
    len = sqrt(len);
    if (len > 0) {
      for (a = 0; a < size; a++) M[a + b * size] /= len;
    }
  }
  fclose(f);
  /** Query nearest neighbors **/
  while (!at_eof) {
    printf("Enter word or sentence (EXIT to break): ");
    // Read one line. Input beyond max_size - 1 characters stays in stdin and
    // is consumed as the next query. End of input is treated as end of line,
    // and the loop stops once the current line has been processed.
    a = 0;
    while (1) {
      int ch = fgetc(stdin);
      if (ch == EOF) {
        at_eof = 1;
        ch = '\n';
      }
      if ((ch == '\n') || (a >= max_size - 1)) {
        st1[a] = 0;
        break;
      }
      st1[a] = (char)ch;
      a++;
    }
    if (!strcmp(st1, "EXIT")) break;
    // Split the line on spaces. Runs of spaces and leading/trailing spaces
    // yield no empty tokens; words beyond MAX_QUERY_WORDS are ignored.
    cn = 0;
    c = 0;
    while (cn < MAX_QUERY_WORDS) {
      while (st1[c] == ' ') c++;
      if (st1[c] == 0) break;
      b = 0;
      while (st1[c] != 0 && st1[c] != ' ') st[cn][b++] = st1[c++];
      st[cn][b] = 0;
      cn++;
    }
    if (cn == 0) continue;  // blank line: prompt again (or stop at end of input)
    // Resolve every query word to its row; give up on the first unknown word.
    for (a = 0; a < cn; a++) {
      for (b = 0; b < words; b++) if (!strcmp(full_vocab[b], st[a])) break;
      if (b == words) b = -1;
      bi[a] = b;
      printf("\nWord: %s Position in vocabulary: %lld\n", st[a], bi[a]);
      if (b == -1) {
        printf("Out of dictionary word!\n");
        break;
      }
    }
    if (b == -1) continue;
    printf("\n Word Cosine distance\n------------------------------------------------------------------------\n");
    // Query vector = normalised sum of the query words' rows.
    for (a = 0; a < size; a++) vec[a] = 0;
    for (b = 0; b < cn; b++) {
      for (a = 0; a < size; a++) vec[a] += M[a + bi[b] * size];
    }
    len = 0;
    for (a = 0; a < size; a++) len += vec[a] * vec[a];
    len = sqrt(len);
    if (len > 0) {
      for (a = 0; a < size; a++) vec[a] /= len;
    }
    for (a = 0; a < N; a++) bestd[a] = 0;
    for (a = 0; a < N; a++) bestw[a][0] = 0;
    // go through the set of words
    for (c = 0; c < words; c++) {
      if (is_word_file && SearchVocab(full_vocab[c]) == -1) continue; // not in the list
      // Never report a query word as its own neighbour.
      a = 0;
      for (b = 0; b < cn; b++) if (bi[b] == c) a = 1;
      if (a == 1) continue;
      dist = 0;
      for (a = 0; a < size; a++) dist += vec[a] * M[a + c * size];
      // Insert into the descending top-N list, shifting lower entries down.
      for (a = 0; a < N; a++) {
        if (dist > bestd[a]) {
          for (d = N - 1; d > a; d--) {
            bestd[d] = bestd[d - 1];
            strcpy(bestw[d], bestw[d - 1]);
          }
          bestd[a] = dist;
          strcpy(bestw[a], full_vocab[c]);
          break;
        }
      }
    }
    for (a = 0; a < N; a++) printf("%50s\t\t%f\n", bestw[a], bestd[a]);
  }
  for (b = 0; b < words; b++) free(full_vocab[b]);
  free(full_vocab);
  free(M);
  return 0;
}