Skip to content

Commit

Permalink
Uses absl::string_view as much as possible
Browse files Browse the repository at this point in the history
  • Loading branch information
taku910 committed Jun 14, 2022
1 parent 68034f9 commit 631420b
Show file tree
Hide file tree
Showing 13 changed files with 333 additions and 245 deletions.
4 changes: 2 additions & 2 deletions python/src/sentencepiece/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def SampleEncodeAndScoreAsPieces(self, input, num_samples, theta, wor, include_b
def SampleEncodeAndScoreAsIds(self, input, num_samples, theta, wor, include_best):
return _sentencepiece.SentencePieceProcessor_SampleEncodeAndScoreAsIds(self, input, num_samples, theta, wor, include_best)

def CalculateEntropy(self, text, theta):
return _sentencepiece.SentencePieceProcessor_CalculateEntropy(self, text, theta)
def CalculateEntropy(self, *args):
return _sentencepiece.SentencePieceProcessor_CalculateEntropy(self, *args)

def GetPieceSize(self):
return _sentencepiece.SentencePieceProcessor_GetPieceSize(self)
Expand Down
92 changes: 25 additions & 67 deletions python/src/sentencepiece/sentencepiece.i
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class PyInputString {
str_ = nullptr;
}
}
absl::string_view str() const { return absl::string_view(data(), size()); }
const char* data() const { return str_; }
Py_ssize_t size() const { return size_; }
bool IsAvalable() const { return str_ != nullptr; }
Expand Down Expand Up @@ -179,7 +180,7 @@ inline void CheckIds(const std::vector<int> &ids, int num_pieces) {
}
}

inline void CheckIds(const std::vector<std::string> &ids, int num_pieces) {}
inline void CheckIds(const std::vector<absl::string_view> &ids, int num_pieces) {}

class ThreadPool {
public:
Expand Down Expand Up @@ -266,6 +267,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
%ignore sentencepiece::util::Status;
%ignore sentencepiece::util::StatusCode;
%ignore absl::string_view;
%ignore std::string_view;
%ignore sentencepiece::SentencePieceText;
%ignore sentencepiece::NormalizerSpec;
%ignore sentencepiece::TrainerSpec;
Expand Down Expand Up @@ -386,7 +388,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
return $self->DecodeIds(ids);
}

std::string _DecodePieces(const std::vector<std::string> &pieces) const {
std::string _DecodePieces(const std::vector<absl::string_view> &pieces) const {
return $self->DecodePieces(pieces);
}

Expand All @@ -397,7 +399,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
}

sentencepiece::util::bytes _DecodePiecesAsSerializedProto(
const std::vector<std::string> &pieces) const {
const std::vector<absl::string_view> &pieces) const {
CheckIds(pieces, $self->GetPieceSize());
return $self->DecodePiecesAsSerializedProto(pieces);
}
Expand All @@ -416,12 +418,12 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
}

std::vector<std::string> _DecodePiecesBatch(
const std::vector<std::vector<std::string>> &ins, int num_threads) const {
const std::vector<std::vector<absl::string_view>> &ins, int num_threads) const {
DEFINE_DECODE_BATCH_FUNC_IMPL(DecodePieces, std::string, std::string);
}

BytesArray _DecodePiecesAsSerializedProtoBatch(
const std::vector<std::vector<std::string>> &ins, int num_threads) const {
const std::vector<std::vector<absl::string_view>> &ins, int num_threads) const {
DEFINE_DECODE_BATCH_FUNC_IMPL(DecodePiecesAsSerializedProto, std::string,
sentencepiece::util::bytes);
}
Expand Down Expand Up @@ -1029,14 +1031,14 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
%typemap(out) std::vector<int> {
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SetItem($result, i, PyInt_FromLong(static_cast<long>($1[i])));
PyList_SET_ITEM($result, i, PyInt_FromLong(static_cast<long>($1[i])));
}
}

%typemap(out) std::vector<float> {
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SetItem($result, i, PyFloat_FromDouble(static_cast<double>($1[i])));
PyList_SET_ITEM($result, i, PyFloat_FromDouble(static_cast<double>($1[i])));
}
}

Expand All @@ -1045,24 +1047,24 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
for (size_t i = 0; i < $1.size(); ++i) {
PyObject *obj = PyList_New($1[i].size());
for (size_t j = 0; j < $1[i].size(); ++j) {
PyList_SetItem(obj, j, PyInt_FromLong(static_cast<long>($1[i][j])));
PyList_SET_ITEM(obj, j, PyInt_FromLong(static_cast<long>($1[i][j])));
}
PyList_SetItem($result, i, obj);
PyList_SET_ITEM($result, i, obj);
}
}

%typemap(out) std::vector<std::string> {
PyObject *input_type = resultobj;
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SetItem($result, i, MakePyOutputString($1[i], input_type));
PyList_SET_ITEM($result, i, MakePyOutputString($1[i], input_type));
}
}

%typemap(out) BytesArray {
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SetItem($result, i, MakePyOutputBytes($1[i]));
PyList_SET_ITEM($result, i, MakePyOutputBytes($1[i]));
}
}

Expand All @@ -1072,9 +1074,9 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
for (size_t i = 0; i < $1.size(); ++i) {
PyObject *obj = PyList_New($1[i].size());
for (size_t j = 0; j < $1[i].size(); ++j) {
PyList_SetItem(obj, j, MakePyOutputString($1[i][j], input_type));
PyList_SET_ITEM(obj, j, MakePyOutputString($1[i][j], input_type));
}
PyList_SetItem($result, i, obj);
PyList_SET_ITEM($result, i, obj);
}
}

Expand Down Expand Up @@ -1118,51 +1120,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
SWIG_fail;
}
resultobj = ustring.input_type();
$1 = absl::string_view(ustring.data(), ustring.size());
}

%typemap(in) const std::vector<std::string>& {
std::vector<std::string> *out = nullptr;
if (PyList_Check($input)) {
const size_t size = PyList_Size($input);
out = new std::vector<std::string>(size);
for (size_t i = 0; i < size; ++i) {
const PyInputString ustring(PyList_GetItem($input, i));
if (ustring.IsAvalable()) {
(*out)[i].assign(ustring.data(), ustring.size());
} else {
PyErr_SetString(PyExc_TypeError, "list must contain strings");
SWIG_fail;
}
resultobj = ustring.input_type();
}
} else {
PyErr_SetString(PyExc_TypeError, "not a list");
SWIG_fail;
}
$1 = out;
}

%typemap(in) const std::vector<absl::string_view>& {
std::vector<absl::string_view> *out = nullptr;
if (PyList_Check($input)) {
const size_t size = PyList_Size($input);
out = new std::vector<std::string>(size);
for (size_t i = 0; i < size; ++i) {
const PyInputString ustring(PyList_GetItem($input, i));
if (ustring.IsAvalable()) {
(*out)[i] = absl::string_view(ustring.data(), ustring.size());
} else {
PyErr_SetString(PyExc_TypeError, "list must contain strings");
SWIG_fail;
}
resultobj = ustring.input_type();
}
} else {
PyErr_SetString(PyExc_TypeError, "not a list");
SWIG_fail;
}
$1 = out;
$1 = ustring.str();
}

%typemap(in) const std::vector<absl::string_view>& {
Expand All @@ -1173,7 +1131,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
for (size_t i = 0; i < size; ++i) {
const PyInputString ustring(PyList_GetItem($input, i));
if (ustring.IsAvalable()) {
(*out)[i] = absl::string_view(ustring.data(), ustring.size());
(*out)[i] = ustring.str();
} else {
PyErr_SetString(PyExc_TypeError, "list must contain strings");
SWIG_fail;
Expand Down Expand Up @@ -1208,11 +1166,11 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
$1 = out;
}

%typemap(in) const std::vector<std::vector<std::string>>& {
std::vector<std::vector<std::string>> *out = nullptr;
%typemap(in) const std::vector<std::vector<absl::string_view>>& {
std::vector<std::vector<absl::string_view>> *out = nullptr;
if (PyList_Check($input)) {
const size_t size = PyList_Size($input);
out = new std::vector<std::vector<std::string>>(size);
out = new std::vector<std::vector<absl::string_view>>(size);
for (size_t i = 0; i < size; ++i) {
PyObject *o = PyList_GetItem($input, i);
if (PyList_Check(o)) {
Expand All @@ -1221,7 +1179,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
for (size_t j = 0; j < size2; ++j) {
const PyInputString ustring(PyList_GetItem(o, j));
if (ustring.IsAvalable()) {
(*out)[i][j].assign(ustring.data(), ustring.size());
(*out)[i][j] = ustring.str();
} else {
PyErr_SetString(PyExc_TypeError,"list must contain integers");
SWIG_fail;
Expand Down Expand Up @@ -1302,9 +1260,9 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
for (size_t i = 0; i < $1.size(); ++i) {
PyObject *obj = PyList_New($1[i].first.size());
for (size_t j = 0; j < $1[i].first.size(); ++j) {
PyList_SetItem(obj, j, MakePyOutputString($1[i].first[j], input_type));
PyList_SET_ITEM(obj, j, MakePyOutputString($1[i].first[j], input_type));
}
PyList_SetItem($result, i, PyTuple_Pack(2, obj, PyFloat_FromDouble(static_cast<double>($1[i].second))));
PyList_SET_ITEM($result, i, PyTuple_Pack(2, obj, PyFloat_FromDouble(static_cast<double>($1[i].second))));
}
}

Expand All @@ -1313,9 +1271,9 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
for (size_t i = 0; i < $1.size(); ++i) {
PyObject *obj = PyList_New($1[i].first.size());
for (size_t j = 0; j < $1[i].first.size(); ++j) {
PyList_SetItem(obj, j, PyInt_FromLong(static_cast<long>($1[i].first[j])));
PyList_SET_ITEM(obj, j, PyInt_FromLong(static_cast<long>($1[i].first[j])));
}
PyList_SetItem($result, i, PyTuple_Pack(2, obj, PyFloat_FromDouble(static_cast<double>($1[i].second))));
PyList_SET_ITEM($result, i, PyTuple_Pack(2, obj, PyFloat_FromDouble(static_cast<double>($1[i].second))));
}
}

Expand Down
Loading

0 comments on commit 631420b

Please sign in to comment.