Skip to content

Commit

Permalink
* Map at::ITensorListRef as used by at::cat() in presets for PyTorch (issue #1293)
Browse files Browse the repository at this point in the history
  • Loading branch information
saudet committed Dec 20, 2022
1 parent c0d15db commit 10ddc85
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

* Map `at::ITensorListRef` as used by `at::cat()` in presets for PyTorch ([issue #1293](https://github.com/bytedeco/javacpp-presets/issues/1293))
* Map `torch::data::datasets::ChunkDataReader` and related data loading classes from PyTorch ([issue #1215](https://github.com/bytedeco/javacpp-presets/issues/1215))
* Add missing predefined `AVChannelLayout` in presets for FFmpeg ([issue #1286](https://github.com/bytedeco/javacpp-presets/issues/1286))
* Map `c10::impl::GenericDict` as returned by `c10::IValue::toGenericDict()` in presets for PyTorch
Expand Down
8 changes: 8 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/global/torch.java
Original file line number Diff line number Diff line change
Expand Up @@ -18139,6 +18139,7 @@ public class torch extends org.bytedeco.pytorch.presets.torch {
/** Return the Device of a TensorList, if the list is non-empty and
* the first Tensor is defined. (This function implicitly assumes
* that all tensors in the list have the same device.) */
@Namespace("at") public static native @ByVal DeviceOptional device_of(@ByVal TensorArrayRef t);

// namespace at

Expand Down Expand Up @@ -21293,6 +21294,8 @@ scalar_t sf(scalar_t x, scalar_t y)
// #include <ATen/core/Tensor.h>
// #include <functional>

@Namespace("at") public static native @Cast("bool") boolean has_names(@ByVal TensorArrayRef tensors);

// Converts dim to a positional index. Errors if `dim` cannot be used to
// refer to any dimension of tensor.
@Namespace("at") public static native @Cast("int64_t") long dimname_to_position(@Const @ByRef Tensor tensor, @ByVal Dimname dim);
Expand Down Expand Up @@ -37093,10 +37096,15 @@ scalar_t sf(scalar_t x, scalar_t y)


// aten::cat(Tensor[] tensors, int dim=0) -> Tensor
@Namespace("at") public static native @ByVal Tensor cat(@Const @ByRef TensorArrayRef tensors, @Cast("int64_t") long dim/*=0*/);
@Namespace("at") public static native @ByVal Tensor cat(@Const @ByRef TensorArrayRef tensors);

// aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
@Namespace("at") public static native @ByRef Tensor cat_out(@ByRef Tensor out, @Const @ByRef TensorArrayRef tensors, @Cast("int64_t") long dim/*=0*/);
@Namespace("at") public static native @ByRef Tensor cat_out(@ByRef Tensor out, @Const @ByRef TensorArrayRef tensors);

// aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
@Namespace("at") public static native @ByRef Tensor cat_outf(@Const @ByRef TensorArrayRef tensors, @Cast("int64_t") long dim, @ByRef Tensor out);

// aten::cat.names(Tensor[] tensors, Dimname dim) -> Tensor
@Namespace("at") public static native @ByVal Tensor cat(@ByVal TensorArrayRef tensors, @ByVal Dimname dim);
Expand Down
4 changes: 2 additions & 2 deletions pytorch/src/main/java/org/bytedeco/pytorch/presets/torch.java
Original file line number Diff line number Diff line change
Expand Up @@ -2338,7 +2338,7 @@ public void map(InfoMap infoMap) {
.put(new Info("c10::ArrayRef<at::Dimname>::iterator", "c10::ArrayRef<at::Dimname>::const_iterator").cast().pointerTypes("Dimname"))
.put(new Info("c10::ArrayRef<at::Scalar>", "at::ArrayRef<at::Scalar>").pointerTypes("ScalarArrayRef"))
.put(new Info("c10::ArrayRef<at::Scalar>::iterator", "c10::ArrayRef<at::Scalar>::const_iterator").cast().pointerTypes("Scalar"))
.put(new Info("c10::ArrayRef<at::Tensor>", "at::ArrayRef<at::Tensor>", "at::TensorList").pointerTypes("TensorArrayRef"))
.put(new Info("c10::ArrayRef<at::Tensor>", "at::ArrayRef<at::Tensor>", "at::TensorList", "at::ITensorListRef").pointerTypes("TensorArrayRef"))
.put(new Info("c10::ArrayRef<at::Tensor>(std::vector<at::Tensor,A>&)").javaText(
"public TensorArrayRef(@ByRef TensorVector Vec) { super((Pointer)null); allocate(Vec); }\n"
+ "private native void allocate(@ByRef TensorVector Vec);"))
Expand Down Expand Up @@ -2366,7 +2366,7 @@ public void map(InfoMap infoMap) {
"c10::ArrayRef<at::Tensor>::equals", "c10::ArrayRef<at::indexing::TensorIndex>::equals",
"c10::ArrayRef<c10::optional<at::Tensor> >::equals", "c10::ArrayRef<torch::jit::NamedValue>::equals",
"c10::ArrayRef<torch::autograd::SavedVariable>::equals", "c10::ArrayRef<torch::autograd::SavedVariable>::vec",
"at::ITensorListRef", "std::array<c10::FunctionalityOffsetAndMask,c10::num_functionality_keys>").skip())
"std::array<c10::FunctionalityOffsetAndMask,c10::num_functionality_keys>").skip())
.put(new Info("c10::OptionalArray<int64_t>").pointerTypes("OptionalLongArray"))
.put(new Info("c10::OptionalArray<double>").pointerTypes("OptionalDoubleArray"))
.put(new Info("c10::OptionalArrayRef<int64_t>").pointerTypes("OptionalIntArrayRef"))
Expand Down

0 comments on commit 10ddc85

Please sign in to comment.