Skip to content

Commit

Permalink
feat(ml): torch db stores encoded images instead of tensors
Browse files Browse the repository at this point in the history
fix: corrected passing test labels in image dataset with torch

feat: encoded images with regression targets
  • Loading branch information
beniz authored and mergify[bot] committed Nov 26, 2020
1 parent 7f16490 commit e7f3c19
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 27 deletions.
136 changes: 124 additions & 12 deletions src/backends/torch/torchdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,99 @@ namespace dd
}
}

/**
 * \brief JPEG-encodes an image and stores it in the db along with its
 *        serialized target tensor, under the keys "<index>_data" and
 *        "<index>_target"
 * \param bgr image to encode, handed as-is to cv::imencode
 * \param target target tensor, serialized with torch::save
 *
 * The db and its first transaction are created lazily on the first call.
 * The transaction is committed and renewed every _batches_per_transaction
 * stored images.
 */
void TorchDataset::write_image_to_db(const cv::Mat &bgr,
                                     const torch::Tensor &target)
{
  // serialize image
  std::vector<uint8_t> buffer;
  std::vector<int> param = { cv::IMWRITE_JPEG_QUALITY, 100 };
  if (!cv::imencode(".jpg", bgr, buffer, param))
    {
      // never store a blob that cannot be decoded back at read time
      _logger->error("failed encoding image {} to jpeg, not written to db",
                     _current_index);
      return;
    }
  // build the db value directly from the encoded bytes instead of
  // streaming them one character at a time
  const std::string encoded_img(buffer.begin(), buffer.end());

  // serialize target
  std::ostringstream tstream;
  torch::save(target, tstream);

  // check on db
  if (_dbData == nullptr)
    {
      _dbData = std::shared_ptr<db::DB>(db::GetDB(_backend));
      _dbData->Open(_dbFullName, db::NEW);
      _txn.reset(_dbData->NewTransaction());
    }

  // data & target keys
  const std::string index_str = std::to_string(_current_index);
  const std::string data_key = index_str + "_data";
  const std::string target_key = index_str + "_target";

  // store into db
  _txn->Put(data_key, encoded_img);
  _txn->Put(target_key, tstream.str());

  // should not commit transactions every time;
  if (++_current_index % _batches_per_transaction == 0)
    {
      _txn->Commit();
      _txn.reset(_dbData->NewTransaction());
      _logger->info("Put {} images in db", _current_index);
    }
}

void TorchDataset::read_image_from_db(const std::string &datas,
const std::string &targets,
cv::Mat &bgr, torch::Tensor &targett,
const bool &bw)
{
std::vector<uint8_t> img_data(datas.begin(), datas.end());
bgr = cv::Mat(img_data, true);
bgr = cv::imdecode(bgr,
bw ? CV_LOAD_IMAGE_GRAYSCALE : CV_LOAD_IMAGE_COLOR);
std::stringstream targetstream(targets);
torch::load(targett, targetstream);
}

/**
 * \brief adds an image with an int target to the dataset: as tensors in
 *        memory when no db is used, as an encoded image in the db otherwise
 * \param bgr image to add
 * \param width third positional argument, see note in the body
 * \param height second positional argument, see note in the body
 * \param target int target, converted with target_to_tensor
 */
void TorchDataset::add_image_batch(const cv::Mat &bgr, const int &width,
                                   const int &height, const int &target)
{
  const at::Tensor targett = target_to_tensor(target);

  if (_db)
    {
      // write to db, the image is stored encoded and sizes are not needed
      write_image_to_db(bgr, targett);
      return;
    }

  // NOTE: the add_image_file callers pass (height, width) positionally, so
  // the parameter named `width` actually carries the image height and
  // `height` the image width. image_to_tensor is declared as
  // (bgr, height, width): forward the values in the order they were
  // received so that it gets the real height first.
  // NOTE(review): renaming the parameters in both declaration and
  // definition would make this explicit.
  const at::Tensor imgt = image_to_tensor(bgr, width, height);
  add_batch({ imgt }, { targett });
}

/**
 * \brief adds an image with a vector of regression targets to the dataset:
 *        as tensors in memory when no db is used, as an encoded image in
 *        the db otherwise
 * \param bgr image to add
 * \param width third positional argument, see note in the body
 * \param height second positional argument, see note in the body
 * \param target regression targets, converted with target_to_tensor
 */
void TorchDataset::add_image_batch(const cv::Mat &bgr, const int &width,
                                   const int &height,
                                   const std::vector<double> &target)
{
  const at::Tensor targett = target_to_tensor(target);

  if (_db)
    {
      // write to db, the image is stored encoded and sizes are not needed
      write_image_to_db(bgr, targett);
      return;
    }

  // NOTE: the add_image_file callers pass (height, width) positionally, so
  // the parameter named `width` actually carries the image height and
  // `height` the image width. image_to_tensor is declared as
  // (bgr, height, width): forward the values in the order they were
  // received so that it gets the real height first.
  // NOTE(review): renaming the parameters in both declaration and
  // definition would make this explicit.
  const at::Tensor imgt = image_to_tensor(bgr, width, height);
  add_batch({ imgt }, { targett });
}

void TorchDataset::add_batch(const std::vector<at::Tensor> &data,
const std::vector<at::Tensor> &target)
{
Expand Down Expand Up @@ -309,13 +402,8 @@ namespace dd
_dbData->Get(target_key.str(), targets);
_dbCursor->Next();

std::stringstream datastream(datas);
std::stringstream targetstream(targets);

std::vector<torch::Tensor> d;
std::vector<torch::Tensor> t;
torch::load(d, datastream);
torch::load(t, targetstream);

if (first_iter)
{
Expand All @@ -324,12 +412,39 @@ namespace dd
first_iter = false;
}

if (!_image)
{
std::stringstream datastream(datas);
std::stringstream targetstream(targets);
torch::load(d, datastream);
torch::load(t, targetstream);
}
else
{
ImgTorchInputFileConn *inputc
= reinterpret_cast<ImgTorchInputFileConn *>(_inputc);

cv::Mat bgr;
torch::Tensor targett;
read_image_from_db(datas, targets, bgr, targett, inputc->_bw);

torch::Tensor imgt
= image_to_tensor(bgr, inputc->height(), inputc->width());

d.push_back(imgt);
t.push_back(targett);
}

for (unsigned int i = 0; i < d.size(); ++i)
{
while (i >= data.size())
data.emplace_back();
data[i].push_back(d.at(i));
}
for (unsigned int i = 0; i < t.size(); ++i)
{
while (i >= target.size())
target.emplace_back();
target[i].push_back(t.at(i));
}

Expand Down Expand Up @@ -391,10 +506,7 @@ namespace dd
}
if (dimg._imgs.size() != 0)
{
at::Tensor imgt = image_to_tensor(dimg._imgs[0], height, width);
at::Tensor targett = target_to_tensor(target);

add_batch({ imgt }, { targett });
add_image_batch(dimg._imgs[0], height, width, target);
return 0;
}
else
Expand Down Expand Up @@ -426,10 +538,10 @@ namespace dd
}
if (dimg._imgs.size() != 0)
{
at::Tensor imgt = image_to_tensor(dimg._imgs[0], height, width);
/*at::Tensor imgt = image_to_tensor(dimg._imgs[0], height, width);
at::Tensor targett = target_to_tensor(target);

add_batch({ imgt }, { targett });
add_batch({ imgt }, { targett });*/
add_image_batch(dimg._imgs[0], height, width, target);
return 0;
}
else
Expand Down
49 changes: 48 additions & 1 deletion src/backends/torch/torchdataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ namespace dd
= nullptr; /**< back ptr to input connector. */
bool _classification = true; /**< whether a classification dataset. */

bool _image = false; /**< whether an image dataset. */

/**
* \brief empty constructor
*/
Expand All @@ -96,7 +98,7 @@ namespace dd
_logger(d._logger), _shuffle(d._shuffle), _dbData(d._dbData),
_indices(d._indices), _lfiles(d._lfiles), _batches(d._batches),
_dbFullName(d._dbFullName), _inputc(d._inputc),
_classification(d._classification)
_classification(d._classification), _image(d._image)
{
}

Expand All @@ -112,6 +114,19 @@ namespace dd
void add_batch(const std::vector<at::Tensor> &data,
const std::vector<at::Tensor> &target = {});

/**
* \brief add an encoded image to a batch, with an int target
*/
void add_image_batch(const cv::Mat &bgr, const int &width,
const int &height, const int &target);

/**
* \brief add an encoded image to a batch, with a vector of regression
* targets
*/
void add_image_batch(const cv::Mat &bgr, const int &width,
const int &height, const std::vector<double> &target);

/**
* \brief reset dataset reading status : ie start new epoch
*/
Expand Down Expand Up @@ -237,14 +252,34 @@ namespace dd
}

/*-- image tools --*/

/**
* \brief adds image to batch, with an int target
*/
int add_image_file(const std::string &fname, const int &target,
const int &height, const int &width);

/**
* \brief adds image to batch, with a set of regression targets
*/
int add_image_file(const std::string &fname,
const std::vector<double> &target, const int &height,
const int &width);

/**
* \brief turns an image into a torch::Tensor
*/
at::Tensor image_to_tensor(const cv::Mat &bgr, const int &height,
const int &width);

/**
* \brief turns an int into a torch::Tensor
*/
at::Tensor target_to_tensor(const int &target);

/**
* \brief turns a vector of double into a torch::Tensor
*/
at::Tensor target_to_tensor(const std::vector<double> &target);

private:
Expand All @@ -253,6 +288,18 @@ namespace dd
*/
void write_tensors_to_db(const std::vector<at::Tensor> &data,
const std::vector<at::Tensor> &target);

/**
* \brief writes encoded image to db with a tensor target
*/
void write_image_to_db(const cv::Mat &bgr, const torch::Tensor &target);

/**
* \brief reads an encoded image from db along with its tensor target
*/
void read_image_from_db(const std::string &datas,
const std::string &targets, cv::Mat &bgr,
torch::Tensor &targett, const bool &bw);
};

}
Expand Down
31 changes: 18 additions & 13 deletions src/backends/torch/torchinputconns.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,15 @@ namespace dd
std::vector<std::pair<std::string, int>> &lfiles,
std::unordered_map<int, std::string> &hcorresp,
std::unordered_map<std::string, int> &hcorresp_r,
const std::string &folderPath)
const std::string &folderPath, const bool &test)
{
_logger->info("Reading image folder {}", folderPath);

// TODO Put file parsing from caffe in common files to use it in other
// backends
int cl = 0;

std::unordered_map<std::string, int>::const_iterator hcit;
std::unordered_set<std::string> subdirs;
if (fileops::list_directory(folderPath, false, true, false, subdirs))
throw InputConnectorBadParamException(
Expand All @@ -130,8 +131,18 @@ namespace dd
throw InputConnectorBadParamException(
"failed reading image train data sub-directory " + (*uit));
std::string cls = dd_utils::split((*uit), '/').back();
hcorresp.insert(std::pair<int, std::string>(cl, cls));
hcorresp_r.insert(std::pair<std::string, int>(cls, cl));
if (!test)
{
hcorresp.insert(std::pair<int, std::string>(cl, cls));
hcorresp_r.insert(std::pair<std::string, int>(cls, cl));
}
else
{
if ((hcit = hcorresp_r.find(cls)) != hcorresp_r.end())
cl = (*hcit).second;
else
_logger->warn("unknown class {} in test set", cls);
}
auto fit = subdir_files.begin();
while (
fit
Expand All @@ -140,7 +151,8 @@ namespace dd
lfiles.push_back(std::pair<std::string, int>((*fit), cl));
++fit;
}
++cl;
if (!test)
++cl;
++uit;
}
}
Expand Down Expand Up @@ -249,15 +261,8 @@ namespace dd
read_image_folder(lfiles, hcorresp, hcorresp_r, _uris.at(0));
if (_uris.size() > 1)
{
std::unordered_map<int, std::string>
test_hcorresp; // correspondence class number / class
// name
std::unordered_map<std::string, int>
test_hcorresp_r; // reverse correspondence for test
// set.

read_image_folder(test_lfiles, test_hcorresp,
test_hcorresp_r, _uris.at(1));
read_image_folder(test_lfiles, hcorresp, hcorresp_r,
_uris.at(1), true);
}

if (_dataset._shuffle)
Expand Down
5 changes: 4 additions & 1 deletion src/backends/torch/torchinputconns.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ namespace dd
: ImgInputFileConn(i), TorchInputInterface(i)
{
_dataset._inputc = this;
_dataset._image = true;
_test_dataset._inputc = this;
_test_dataset._image = true;
set_db_transaction_size(TORCH_IMG_TRANSACTION_SIZE);
}

Expand Down Expand Up @@ -241,7 +243,8 @@ namespace dd
void read_image_folder(std::vector<std::pair<std::string, int>> &lfiles,
std::unordered_map<int, std::string> &hcorresp,
std::unordered_map<std::string, int> &hcorresp_r,
const std::string &folderPath);
const std::string &folderPath,
const bool &test = false);

/**
* \brief read images from txt list
Expand Down

0 comments on commit e7f3c19

Please sign in to comment.