Commit
Merge pull request #309 from mlverse/sparsemax
Sparsemax
Showing 16 changed files with 332 additions and 13 deletions.
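For reference, sparsemax (Martins & Astudillo, 2016) is the Euclidean projection of a score vector onto the probability simplex; unlike softmax, it can assign exact zeros. A sketch of the closed form the new C++ code below follows, with $z_{(1)} \ge \dots \ge z_{(K)}$ the sorted scores along the chosen dimension:

$$
\operatorname{sparsemax}(z) = \operatorname*{arg\,min}_{p \in \Delta^{K-1}} \lVert p - z \rVert_2^2,
\qquad
\operatorname{sparsemax}(z)_i = \max\bigl(z_i - \tau(z),\, 0\bigr),
$$

where the support size and threshold are

$$
k(z) = \max\Bigl\{ k : 1 + k\, z_{(k)} > \textstyle\sum_{j \le k} z_{(j)} \Bigr\},
\qquad
\tau(z) = \frac{\sum_{j \le k(z)} z_{(j)} - 1}{k(z)}.
$$

These correspond directly to the `bound`, `is_gt`, `k`, and `taus` tensors in the implementation.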
Among the R changes, the internal Scalar class migrates from R6::R6Class to the package's own R7Class, keeping the constructor arguments shown in the hunk intact:

@@ -1,4 +1,4 @@
-Scalar <- R6::R6Class(
+Scalar <- R7Class(
   classname = "torch_scalar",
   public = list(
The bulk of the additions is a new Lantern C++ file implementing sparsemax as a custom autograd Function:

@@ -0,0 +1,133 @@
#define LANTERN_BUILD
#include "lantern/lantern.h"
#include <torch/torch.h>
#include <string>
#include <iostream>
#include "../utils.hpp"
#include <stdexcept> // std::out_of_range

using namespace torch::autograd;

// Inherit from Function
class SparseMaxFunction : public Function<SparseMaxFunction> {
 public:

  static torch::Tensor forward(AutogradContext *ctx, torch::Tensor input, int dim) {

    auto input_dim = input.dim();
    if (input_dim <= dim || dim < -input_dim)
    {
      throw std::out_of_range("Dimension out of range");
    }

    bool needs_reshaping = input_dim > 2;
    auto original_size = input.sizes().vec();

    if (needs_reshaping)
    {
      // transpose batch and nth dim
      input = input.transpose(0, dim);

      // Flatten all dimensions except nth dim
      input = input.reshape({input.size(0), -1});

      // Transpose flattened dimensions to 0th dim, nth dim to last dim
      input = input.transpose(0, -1);
    }

    // Translate by max for numerical stability
    input = input - std::get<0>(input.max(-1, true)).expand_as(input);

    auto zs = std::get<0>(input.sort(-1, true));
    auto range = torch::arange(1, input.size(-1) + 1);
    range = range.expand_as(input).to(input);

    // Determine sparsity of projection
    auto bound = 1 + range * zs;
    auto is_gt = bound.gt(zs.cumsum(-1)).to(input.dtype());
    auto k = std::get<0>((is_gt * range).max(-1, true));

    // Compute threshold
    auto zs_sparse = is_gt * zs;

    // Compute taus
    auto taus = (zs_sparse.sum(-1, true) - 1) / k;
    taus = taus.expand_as(input);

    auto output = torch::max(torch::zeros_like(input), input - taus);

    // Save context
    ctx->save_for_backward({output});
    ctx->saved_data["needs_reshaping"] = needs_reshaping;
    ctx->saved_data["dim"] = dim;

    if (needs_reshaping)
    {
      // Transpose flattened dim to last dim, nth dim to 0th dim
      output = output.transpose(0, 1);

      // Reshape to original size
      output = output.reshape(original_size);

      // Swap batch dim and nth dim
      output = output.transpose(0, dim);
    }

    return output;
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto output = saved[0];
    auto grad_output = grad_outputs[0];

    bool needs_reshaping = ctx->saved_data["needs_reshaping"].toBool();
    int dim = ctx->saved_data["dim"].toInt();
    auto original_size = grad_output.sizes().vec();

    if (needs_reshaping)
    {
      // transpose batch and nth dim
      grad_output = grad_output.transpose(0, dim);

      // Flatten all dimensions except nth dim
      grad_output = grad_output.reshape({grad_output.size(0), -1});

      // Transpose flattened dimensions to 0th dim, nth dim to last dim
      grad_output = grad_output.transpose(0, -1);
    }

    // Compute gradient
    auto nonzeros = torch::ne(output, 0);
    auto num_nonzeros = nonzeros.sum(-1, true);
    auto sum = (grad_output * nonzeros).sum(-1, true) / num_nonzeros;
    auto grad_input = nonzeros * (grad_output - sum.expand_as(grad_output));

    if (needs_reshaping)
    {
      // Transpose flattened dim to last dim, nth dim to 0th dim
      grad_input = grad_input.transpose(0, 1);

      // Reshape to original size
      grad_input = grad_input.reshape(original_size);

      // Swap batch dim and nth dim
      grad_input = grad_input.transpose(0, dim);
    }

    // forward() took two arguments (input, dim); dim is not a tensor,
    // so its gradient slot is left undefined
    auto o = torch::autograd::variable_list(2);
    o[0] = grad_input;

    return o;
  }
};

void * _lantern_contrib_torch_sparsemax (void * input, int dim)
{
  LANTERN_FUNCTION_START
  torch::Tensor t = reinterpret_cast<LanternObject<torch::Tensor> *>(input)->get();
  torch::Tensor res = SparseMaxFunction::apply(t, dim);
  return (void*) new LanternObject<torch::Tensor>(res);
  LANTERN_FUNCTION_END
}
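The backward pass is the sparsemax Jacobian-vector product. With $s$ the indicator of the nonzero output entries (`nonzeros` above) and $g$ the incoming gradient, the code computes, per slice along the reduction dimension:

$$
\nabla_z = s \odot \Bigl( g - \frac{\sum_{j \in S} g_j}{\lvert S \rvert}\, \mathbf{1} \Bigr),
$$

i.e. the gradient is mean-centered over the support $S$ and zeroed elsewhere, which is exactly the `nonzeros * (grad_output - sum.expand_as(grad_output))` line.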
A thin Rcpp wrapper exposes the Lantern entry point to the R package:

@@ -0,0 +1,9 @@
#include "torch_types.h"
#include "utils.h"

// [[Rcpp::export]]
Rcpp::XPtr<XPtrTorchTensor> cpp_contrib_torch_sparsemax (Rcpp::XPtr<XPtrTorchTensor> input, int dim)
{
  XPtrTorchTensor out = lantern_contrib_torch_sparsemax(input->get(), dim);
  return make_xptr<XPtrTorchTensor>(out);
}
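A minimal usage sketch from the R side. The wrapper name `nnf_contrib_sparsemax()` is an assumption here: the R-level wrappers are among the 16 changed files but are not rendered in this excerpt.

library(torch)

x <- torch_randn(3, 5)
# hypothetical wrapper around cpp_contrib_torch_sparsemax()
y <- nnf_contrib_sparsemax(x, dim = -1)

y$sum(dim = -1)  # each row sums to 1: a valid probability distribution
(y == 0)$sum()   # unlike softmax, some entries are exactly zero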