Commit 6116a0c

Merge pull request #41 from beniz/txt_1dconv

Character-level 1D convolutional deep neural models for text

beniz committed Jan 11, 2016
2 parents e044baa + 93f6306 commit 6116a0c
Showing 8 changed files with 490 additions and 102 deletions.
18 changes: 10 additions & 8 deletions src/caffeinputconns.cc
@@ -672,8 +672,8 @@ namespace dd
LOG(INFO) << "Txt db test file " << testdbfullname << " with " << _db_testbatchsize << " records\n";
}
// XXX: remove in-memory data, whose pre-processing is useless and should be avoided
- _txt.clear();
- _test_txt.clear();
+ destroy_txt_entries(_txt);
+ destroy_txt_entries(_test_txt);

return 0;
}
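The `.clear()` calls are replaced because `_txt` and `_test_txt` now hold heap-allocated `TxtEntry<double>*` pointers rather than `TxtBowEntry` values (see the `write_txt_to_db` signature change below), so the owned entries must be freed explicitly. A minimal sketch of what such a helper can look like under that assumption (the actual `destroy_txt_entries` belongs to the text input connector and may differ):

```cpp
// Hypothetical sketch of destroy_txt_entries: free the connector-owned
// entries, then empty the vector as the previous .clear() call did.
// TxtEntry is DeepDetect's text entry base type.
template <typename T>
void destroy_txt_entries(std::vector<TxtEntry<T>*> &entries)
{
  for (TxtEntry<T> *e : entries)
    delete e;
  entries.clear();
}
```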
@@ -685,11 +685,11 @@ namespace dd

// write to dbs (i.e. train and possibly test)
write_txt_to_db(dbfullname,_txt);
- _txt.clear();
+ destroy_txt_entries(_txt);
if (!_test_txt.empty())
{
write_txt_to_db(testdbfullname,_test_txt);
- _test_txt.clear();
+ destroy_txt_entries(_test_txt);
}

// write corresp file
@@ -706,7 +706,7 @@ namespace dd
}

void TxtCaffeInputFileConn::write_txt_to_db(const std::string &dbfullname,
- std::vector<TxtBowEntry> &txt,
+ std::vector<TxtEntry<double>*> &txt,
const std::string &backend)
{
// Create new DB
@@ -720,10 +720,12 @@ namespace dd
const int kMaxKeyLength = 256;
char key_cstr[kMaxKeyLength];
int n = 0;
- auto hit = txt.cbegin();
- while(hit!=txt.cend())
+ auto hit = txt.begin();
+ while(hit!=txt.end())
{
- datum = to_datum((*hit));
+ if (_characters)
+   datum = to_datum<TxtCharEntry>(static_cast<TxtCharEntry*>((*hit)));
+ else datum = to_datum<TxtBowEntry>(static_cast<TxtBowEntry*>((*hit)));
if (_channels == 0)
_channels = datum.channels();
int length = snprintf(key_cstr,kMaxKeyLength,"%s",std::to_string(n).c_str());
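The `if (_characters)` dispatch above works because both entry kinds share a base class and a cursor-style iteration protocol, visible in the `to_datum` calls further down: `reset()`, `has_elt()`, and `get_next_elt(key,val)`. A sketch of the interface implied by this diff (the concrete `TxtEntry`, `TxtBowEntry`, and `TxtCharEntry` definitions live in the text input connector and may differ):

```cpp
// Interface implied by the calls in to_datum below; a sketch, not the
// verbatim DeepDetect definition. TxtBowEntry would yield (word, count)
// pairs, TxtCharEntry one character per element.
template <typename T>
class TxtEntry
{
 public:
  explicit TxtEntry(const float target) : _target(target) {}
  virtual ~TxtEntry() {}
  virtual void reset() = 0;                                // rewind iteration
  virtual bool has_elt() const = 0;                        // elements left?
  virtual void get_next_elt(std::string &key, T &val) = 0; // yield next element
  float _target = -1; /**< supervised label. */
};
```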
93 changes: 70 additions & 23 deletions src/caffeinputconns.h
@@ -39,7 +39,7 @@ namespace dd
public:
CaffeInputInterface() {}
CaffeInputInterface(const CaffeInputInterface &cii)
- :_dv(cii._dv),_dv_test(cii._dv_test),_ids(cii._ids) {} //,_test_labels(cii._test_labels) {}
+ :_dv(cii._dv),_dv_test(cii._dv_test),_ids(cii._ids),_flat1dconv(cii._flat1dconv) {}
~CaffeInputInterface() {}

/**
@@ -63,6 +63,7 @@ namespace dd
std::vector<caffe::Datum> _dv; /**< main input datum vector, used for training or prediction */
std::vector<caffe::Datum> _dv_test; /**< test input datum vector, when applicable in training mode */
std::vector<std::string> _ids; /**< input ids (e.g. image ids) */
+ bool _flat1dconv = false; /**< whether this is a 1D convolutional model. */
};

/**
@@ -72,9 +73,9 @@ namespace dd
{
public:
ImgCaffeInputFileConn()
- :ImgInputFileConn() {}
+ :ImgInputFileConn() { _db = true; }
ImgCaffeInputFileConn(const ImgCaffeInputFileConn &i)
- :ImgInputFileConn(i),CaffeInputInterface(i) {}
+ :ImgInputFileConn(i),CaffeInputInterface(i) { _db = true; }
~ImgCaffeInputFileConn() {}

// size of each element in Caffe jargon
@@ -511,6 +512,8 @@ namespace dd
void init(const APIData &ad)
{
TxtInputFileConn::init(ad);
+ if (_characters)
+   _flat1dconv = true;
}

int channels() const
@@ -522,11 +525,15 @@ namespace dd

int height() const
{
- return 1;
+ if (_characters)
+   return _sequence;
+ else return 1;
}

int width() const
{
+ if (_characters)
+   return _alphabet.size();
return 1;
}
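With these accessors, a character-level sample is no longer a flat bag-of-words vector but a one-hot matrix: one channel, `_sequence` rows (one per character position), and `_alphabet.size()` columns. Assuming a Zhang & LeCun style setup of 1014 characters over a 69-symbol alphabet (illustrative values, not read from this commit), that is roughly 70k floats per sample:

```cpp
// Back-of-the-envelope Datum size for a character-level sample.
// sequence=1014 and alphabet=69 are assumed values in the style of
// character-level convnets, not defaults taken from this commit.
#include <cstdio>

int main()
{
  const int channels = 1;    // character models use a single channel
  const int sequence = 1014; // height(): characters kept per sample
  const int alphabet = 69;   // width(): number of quantized symbols
  std::printf("%d x %d x %d = %d floats per sample\n",
              channels, sequence, alphabet, channels * sequence * alphabet);
  return 0;
}
```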

@@ -550,7 +557,7 @@ namespace dd
const std::string &backend="lmdb");

void write_txt_to_db(const std::string &dbname,
- std::vector<TxtBowEntry> &txt,
+ std::vector<TxtEntry<double>*> &txt,
const std::string &backend="lmdb");

void transform(const APIData &ad)
@@ -559,7 +566,7 @@ namespace dd
APIData ad_input = ad_param.getobj("input");
if (ad_input.has("db") && ad_input.get("db").get<bool>())
_db = true;

// transform to one-hot vector datum
if (_train && _db)
{
@@ -583,18 +590,22 @@ namespace dd
TxtInputFileConn::transform(ad);

int n = 0;
- auto hit = _txt.cbegin();
- while(hit!=_txt.cend())
+ auto hit = _txt.begin();
+ while(hit!=_txt.end())
{
- _dv.push_back(std::move(to_datum((*hit))));
+ if (_characters)
+   _dv.push_back(std::move(to_datum<TxtCharEntry>(static_cast<TxtCharEntry*>((*hit)))));
+ else _dv.push_back(std::move(to_datum<TxtBowEntry>(static_cast<TxtBowEntry*>((*hit)))));
_ids.push_back(std::to_string(n));
++hit;
++n;
}
- hit = _test_txt.cbegin();
- while(hit!=_test_txt.cend())
+ hit = _test_txt.begin();
+ while(hit!=_test_txt.end())
{
- _dv_test.push_back(std::move(to_datum((*hit))));
+ if (_characters)
+   _dv_test.push_back(std::move(to_datum<TxtCharEntry>(static_cast<TxtCharEntry*>((*hit)))));
+ else _dv_test.push_back(std::move(to_datum<TxtBowEntry>(static_cast<TxtBowEntry*>((*hit)))));
++hit;
++n;
}
@@ -630,23 +641,59 @@ namespace dd
_test_db = std::unique_ptr<caffe::db::DB>();
}

- caffe::Datum to_datum(const TxtBowEntry &tbe)
+ template<class TEntry> caffe::Datum to_datum(TEntry *tbe)
{
std::unordered_map<std::string,Word>::const_iterator wit;
caffe::Datum datum;
- int datum_channels = _vocab.size(); // XXX: may be very large
+ int datum_channels;
+ if (_characters)
+   datum_channels = 1;
+ else datum_channels = _vocab.size(); // XXX: may be very large
datum.set_channels(datum_channels);
- datum.set_height(1);
+ datum.set_height(1);
datum.set_width(1);
- datum.set_label(tbe._target);
- for (int i=0;i<datum_channels;i++) // XXX: expected to be slow
-   datum.add_float_data(0.0);
- auto hit = tbe._v.cbegin();
- while(hit!=tbe._v.cend())
+ datum.set_label(tbe->_target);
+ if (!_characters)
{
-   if ((wit = _vocab.find((*hit).first))!=_vocab.end())
-     datum.set_float_data(_vocab[(*hit).first]._pos,static_cast<float>((*hit).second));
-   ++hit;
+   for (int i=0;i<datum_channels;i++) // XXX: expected to be slow
+     datum.add_float_data(0.0);
+   tbe->reset();
+   while(tbe->has_elt())
+     {
+       std::string key;
+       double val;
+       tbe->get_next_elt(key,val);
+       if ((wit = _vocab.find(key))!=_vocab.end())
+         datum.set_float_data(_vocab[key]._pos,static_cast<float>(val));
+     }
}
+ else // character-level features
+   {
+     tbe->reset();
+     std::vector<int> vals;
+     std::unordered_map<char,int>::const_iterator whit;
+     while(tbe->has_elt())
+       {
+         std::string key;
+         double val = -1.0;
+         tbe->get_next_elt(key,val);
+         if ((whit=_alphabet.find(key[0]))!=_alphabet.end())
+           vals.push_back((*whit).second);
+         else vals.push_back(-1);
+       }
+     std::reverse(vals.begin(),vals.end()); // reverse quantization helps
+     /*if (vals.size() > _sequence)
+       std::cerr << "more characters than sequence / " << vals.size() << " / sequence=" << _sequence << std::endl;*/
+     for (int c=0;c<_sequence;c++)
+       {
+         std::vector<float> v(_alphabet.size(),0.0);
+         if (c<(int)vals.size() && vals[c] != -1)
+           v[vals[c]] = 1.0;
+         for (float f: v)
+           datum.add_float_data(f);
+       }
+     datum.set_height(_sequence);
+     datum.set_width(_alphabet.size());
+   }
return datum;
}
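Stripped of connector plumbing, the character branch above is a fixed-length quantization: map each character to its alphabet index, reverse the order (the "reverse quantization helps" trick), then emit one one-hot row per position up to `_sequence`, leaving all-zero rows for unknown characters and for positions past the end of the text. A self-contained sketch of the same technique, with a toy lowercase alphabet standing in for the connector's `_alphabet`:

```cpp
// Standalone sketch of the character quantization performed by to_datum:
// map characters to alphabet indices, reverse the order, then emit one
// one-hot row of width alphabet.size() per sequence position.
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

std::vector<float> quantize(const std::string &text,
                            const std::unordered_map<char,int> &alphabet,
                            const int sequence)
{
  std::vector<int> vals;
  for (const char c : text)
    {
      auto hit = alphabet.find(c);
      vals.push_back(hit != alphabet.end() ? hit->second : -1); // -1 = unknown
    }
  std::reverse(vals.begin(), vals.end()); // reverse quantization, as above
  std::vector<float> out;
  out.reserve(sequence * alphabet.size());
  for (int c = 0; c < sequence; c++)
    {
      std::vector<float> row(alphabet.size(), 0.0f);
      if (c < (int)vals.size() && vals[c] != -1)
        row[vals[c]] = 1.0f; // single 1 at the character's alphabet index
      // unknown characters and positions past the text stay all-zero rows
      out.insert(out.end(), row.begin(), row.end());
    }
  return out;
}

int main()
{
  std::unordered_map<char,int> alphabet; // assumed toy alphabet: a-z only
  for (char c = 'a'; c <= 'z'; c++)
    alphabet[c] = c - 'a';
  return quantize("so cool", alphabet, 16).size() == 16 * 26 ? 0 : 1;
}
```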
