Skip to content

Commit

Permalink
feat(torch): ocr model training and inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob committed May 16, 2022
1 parent 679fb1d commit 3fc2e27
Show file tree
Hide file tree
Showing 21 changed files with 901 additions and 171 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ if (USE_TORCH)
backends/torch/torchinputconns.cc
backends/torch/native/templates/nbeats.cc
backends/torch/native/templates/vit.cc
backends/torch/native/templates/crnn.cc
backends/torch/native/templates/visformer.cc
backends/torch/native/templates/ttransformer.cc
backends/torch/native/templates/ttransformer/tembedder.cc
Expand Down
112 changes: 112 additions & 0 deletions src/backends/torch/native/templates/crnn.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/**
* DeepDetect
* Copyright (c) 2022 Jolibrain
* Author: Louis Jean <louis.jean@jolibrain.com>
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/

#include "crnn.hpp"

#include "../../torchlib.h"

namespace dd
{
void CRNN::get_params(const APIData &ad_params,
const std::vector<long int> &input_dims,
int output_size)
{
if (ad_params.has("timesteps"))
_timesteps = ad_params.get("timesteps").get<int>();
if (ad_params.has("hidden_size"))
_hidden_size = ad_params.get("hidden_size").get<int>();
if (ad_params.has("num_layers"))
_num_layers = ad_params.get("num_layers").get<int>();

if (output_size > 0)
_output_size = output_size;

at::Tensor dummy = torch::zeros(
std::vector<int64_t>(input_dims.begin(), input_dims.end()));
int batch_size = dummy.size(0);
dummy = dummy.reshape({ batch_size, -1, _timesteps });
_input_size = dummy.size(1);
}

/**
 * \brief (re)creates the LSTM from the current hyper-parameters
 *
 * Any previously registered "lstm" submodule is dropped first, so calling
 * this again replaces the module with a freshly constructed one.
 *
 * Size selection:
 *  - a positive _hidden_size that differs from _output_size is used as the
 *    LSTM hidden size, with a projection down to _output_size;
 *  - otherwise the LSTM hidden size is _output_size and no projection is
 *    used (proj_size = 0).
 */
void CRNN::init()
{
  if (_lstm)
    {
      unregister_module("lstm");
      _lstm = nullptr;
    }

  // Default: no projection, the hidden state is directly the output.
  uint32_t hidden_size = _output_size;
  uint32_t proj_size = 0;

  // A projection is only meaningful when an explicit hidden size was
  // requested AND it differs from the output size. libtorch requires
  // proj_size < hidden_size, so hidden_size == output_size must not take the
  // projection branch (it would be rejected at LSTM construction).
  if (_hidden_size > 0 && _hidden_size != _output_size)
    {
      hidden_size = _hidden_size;
      proj_size = _output_size;
    }

  _lstm = register_module(
      "lstm",
      torch::nn::LSTM(torch::nn::LSTMOptions(_input_size, hidden_size)
                          .num_layers(_num_layers)
                          .proj_size(proj_size)));
}

/**
 * \brief changes the output size and rebuilds the LSTM if it differs
 * \param output_size new output size; nothing happens when it equals the
 *        current one
 */
void CRNN::set_output_size(int output_size)
{
  // Nothing to rebuild when the size is unchanged.
  if (_output_size == output_size)
    return;

  _output_size = output_size;
  init();
}

/**
 * \brief runs the LSTM over the input features
 * \param feats input tensor (feature map from the backbone); it is reshaped
 *        to {batch, -1, timesteps}
 * \return first element of the LSTM forward result (the output sequence),
 *         timesteps first
 */
torch::Tensor CRNN::forward(torch::Tensor feats)
{
  int batch_size = feats.size(0);

  // {batch, features, timesteps} -> {timesteps, batch, features}
  torch::Tensor sequence
      = feats.reshape({ batch_size, -1, _timesteps }).permute({ 2, 0, 1 });

  // The LSTM returns (output, (h_n, c_n)); only the output is needed.
  auto lstm_result = _lstm->forward(sequence);
  return std::get<0>(lstm_result);
}

/**
 * \brief model-specific loss, not available for CRNN
 *
 * All arguments are ignored.
 * \throws MLLibInternalException always
 */
torch::Tensor CRNN::loss(std::string loss, torch::Tensor input,
                         torch::Tensor output, torch::Tensor target)
{
  // Silence unused-parameter warnings: the signature is imposed by the
  // overridden interface.
  static_cast<void>(loss);
  static_cast<void>(input);
  static_cast<void>(output);
  static_cast<void>(target);

  throw MLLibInternalException("CRNN::loss not implemented");
}
}
125 changes: 125 additions & 0 deletions src/backends/torch/native/templates/crnn.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/**
* DeepDetect
* Copyright (c) 2022 Jolibrain
* Author: Louis Jean <louis.jean@jolibrain.com>
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef CRNN_H
#define CRNN_H

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include "torch/torch.h"
#pragma GCC diagnostic pop
#include "../../torchinputconns.h"
#include "../native_net.h"

namespace dd
{
/**
 * \brief native torch module wrapping a multi-layer LSTM
 *
 * forward() reshapes its input to {batch, -1, timesteps}, moves the
 * timesteps dimension first and returns the LSTM output sequence.
 * The LSTM itself is (re)built by init() from the hyper-parameters below.
 */
class CRNN : public NativeModuleImpl<CRNN>
{
public:
// LSTM module holder; null until init() creates it and registers it under
// the name "lstm"
torch::nn::LSTM _lstm = nullptr;
// size of the last dimension after the reshape done in forward() /
// get_params(), i.e. the sequence length fed to the LSTM
int _timesteps = 32;
// forwarded to LSTMOptions::num_layers
int _num_layers = 3;
// requested LSTM hidden size; init() uses it together with _output_size to
// pick the LSTM hidden / projection sizes.
// NOTE(review): the in-class default (64) differs from the default of the
// first constructor's hidden_size argument (0); the APIData constructor
// starts from this in-class default. Confirm this is intended.
int _hidden_size = 64;

// dataset / backbone dependent variables
// per-timestep input feature count of the LSTM (recomputed by get_params())
int _input_size = 64;
// size of the LSTM output for each timestep
int _output_size = 2;

public:
/**
 * \brief builds the module from explicit hyper-parameters and creates the
 * LSTM
 */
CRNN(int timesteps = 32, int num_layers = 3, int hidden_size = 0,
int input_size = 64, int output_size = 2)
: _timesteps(timesteps), _num_layers(num_layers),
_hidden_size(hidden_size), _input_size(input_size),
_output_size(output_size)
{
init();
}

/**
 * \brief builds the module from API parameters
 * \param ad_params API parameters, see get_params()
 * \param input_dims dimensions of the tensor that will be given to forward()
 * \param output_size output size, values <= 0 keep the default
 */
CRNN(const APIData &ad_params, const std::vector<long int> &input_dims,
int output_size = -1)
{
get_params(ad_params, input_dims, output_size);
init();
}

/**
 * \brief copies the hyper-parameters of other and calls init()
 *
 * NOTE(review): _lstm is not copied from other; init() constructs a new
 * LSTM. Whether the weights of other are expected to be preserved through
 * the torch::nn::Module base copy should be confirmed.
 */
CRNN(const CRNN &other)
: torch::nn::Module(other), _timesteps(other._timesteps),
_num_layers(other._num_layers), _hidden_size(other._hidden_size),
_input_size(other._input_size), _output_size(other._output_size)
{
init();
}

/**
 * \brief copies the hyper-parameters of other and rebuilds the LSTM
 *
 * There is no self-assignment check: init() runs in every case and
 * replaces the current LSTM module.
 */
CRNN &operator=(const CRNN &other)
{
_timesteps = other._timesteps;
_num_layers = other._num_layers;
_hidden_size = other._hidden_size;

_input_size = other._input_size;
_output_size = other._output_size;
init();
return *this;
}

/**
 * \brief (re)creates and registers the LSTM from the current
 * hyper-parameters
 */
void init();

/**
 * \brief reads hyper-parameters from the API and derives _input_size from
 * input_dims
 * \param ad_params may contain "timesteps", "hidden_size", "num_layers"
 * \param input_dims dimensions of the expected input tensor
 * \param output_size applied only when > 0
 */
void get_params(const APIData &ad_params,
const std::vector<long int> &input_dims, int output_size);

/**
 * \brief updates the output size, calling init() only when the value
 * changes
 */
void set_output_size(int output_size);

/**
 * \brief reshapes x to {batch, -1, timesteps}, puts timesteps first and
 * returns the LSTM output
 */
torch::Tensor forward(torch::Tensor x) override;

/**
 * \brief rebuilds the LSTM by calling init()
 */
void reset() override
{
init();
}

/**
 * \brief layer extraction is not supported: always returns a default
 * constructed tensor
 */
torch::Tensor extract(torch::Tensor x, std::string extract_layer) override
{
(void)x;
(void)extract_layer;
return torch::Tensor();
}

/**
 * \brief no layer is extractable: always returns false
 */
bool extractable(std::string extract_layer) const override
{
(void)extract_layer;
return false;
}

/**
 * \brief returns an empty list, as no layer is extractable
 */
std::vector<std::string> extractable_layers() const override
{
return std::vector<std::string>();
}

/**
 * \brief returns output unchanged
 */
torch::Tensor cleanup_output(torch::Tensor output) override
{
return output;
}

/**
 * \brief not implemented for CRNN: the definition in crnn.cc throws
 * MLLibInternalException
 */
torch::Tensor loss(std::string loss, torch::Tensor input,
torch::Tensor output, torch::Tensor target) override;
};
}

#endif // CRNN_H
47 changes: 47 additions & 0 deletions src/backends/torch/torchdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,53 @@ namespace dd
return 0;
}

/**
 * Adds an image with a text target (ocr) to the dataset.
 *
 * The target string is encoded as a zero-initialized int64 tensor of size
 * max_ocr_length, one alphabet index per character, together with a
 * one-element int64 tensor holding the encoded length.
 *
 * Characters missing from the alphabet are appended to it (next free index)
 * unless the dataset is a test set, in which case the position is left at 0
 * and a warning is logged. Targets longer than max_ocr_length are truncated.
 *
 * The stored length is the number of positions actually filled, i.e.
 * min(target.size(), max_ocr_length), so it never exceeds the size of the
 * target tensor even when the sequence has been truncated.
 *
 * Always returns 0.
 */
int TorchDataset::add_image_text_file(
    const std::string &fname, const std::string &target, int height,
    int width, std::unordered_map<uint32_t, int> &alphabet,
    int max_ocr_length)
{
  at::Tensor target_tensor = torch::zeros(
      max_ocr_length, at::TensorOptions().dtype(torch::kInt64));
  int i = 0;

  for (auto &c : target)
    {
      if (i >= max_ocr_length)
        {
          // can happen in test set
          this->_logger->warn("Sequence \"{}\" is exceeding maximum ocr "
                              "length {}. Truncating...",
                              target, max_ocr_length);
          break;
        }
      auto it = alphabet.find(c);

      if (it != alphabet.end())
        {
          target_tensor[i] = it->second;
        }
      else if (!_test)
        {
          // unseen character at train time: extend the alphabet
          this->_logger->info("added {} to alphabet", c);
          int id = alphabet.size();
          alphabet[c] = id;
          target_tensor[i] = id;
        }
      else
        {
          // unseen character at test time: position keeps its initial 0
          this->_logger->warn(
              "Character {} in test set but not in train set", c);
        }

      i++;
    }

  // i is the number of encoded positions; using target.size() here would
  // report a length larger than the tensor for truncated sequences.
  at::Tensor target_length = torch::full(
      1, at::Scalar(i), at::TensorOptions().dtype(torch::kInt64));

  add_image_file(fname, { target_tensor, target_length }, height, width);
  return 0;
}

at::Tensor TorchDataset::image_to_tensor(const cv::Mat &bgr,
const int &height, const int &width,
const bool &target)
Expand Down
12 changes: 12 additions & 0 deletions src/backends/torch/torchdataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,18 @@ namespace dd
const std::string &bboxfname, const int &height,
const int &width);

/**
* \brief adds image from image filename with a text target (ocr)
* \param fname image filename
* \param target text associated with the image; each character is mapped
*        to its index in alphabet, and the sequence is truncated to
*        max_ocr_length
* \param height of preprocessed image
* \param width of preprocessed image
* \param alphabet mapping between characters and logit index; characters
*        not yet present are added to it unless the dataset is a test set
* \param max_ocr_length maximum possible size of the sequence
* \return 0
*/
int add_image_text_file(const std::string &fname,
const std::string &target, int height, int width,
std::unordered_map<uint32_t, int> &alphabet,
int max_ocr_length);

/**
* \brief turns an image into a torch::Tensor
* \param bgr input image
Expand Down
Loading

0 comments on commit 3fc2e27

Please sign in to comment.