ONNX Pad support (partial) (huggingface#2196)
Vladislav Kuzemchik committed May 20, 2024
1 parent 7ebc354 commit 14b0930
Showing 3 changed files with 273 additions and 0 deletions.
38 changes: 38 additions & 0 deletions candle-core/src/tensor.rs
@@ -2212,6 +2212,44 @@ impl Tensor {
Self::cat(&args, dim)
}

/// Pad the input tensor with the constant value `constant` along dimension `dim`. This adds
/// `left` elements of that value before the input tensor values and `right` elements after.
pub fn pad_with_const<D: Dim>(
    &self,
    dim: D,
    left: usize,
    right: usize,
    constant: f64,
) -> Result<Self> {
    if left == 0 && right == 0 {
        Ok(self.clone())
    } else if left == 0 {
        let dim = dim.to_index(self.shape(), "pad_with_const")?;
        let mut dims = self.dims().to_vec();
        dims[dim] = right;
        // Build a `constant`-valued tensor with the same dtype as `self`: ones * constant + 0.
        let right =
            Tensor::ones(dims.as_slice(), self.dtype, self.device())?.affine(constant, 0.0)?;
        Tensor::cat(&[self, &right], dim)
    } else if right == 0 {
        let dim = dim.to_index(self.shape(), "pad_with_const")?;
        let mut dims = self.dims().to_vec();
        dims[dim] = left;
        let left =
            Tensor::ones(dims.as_slice(), self.dtype, self.device())?.affine(constant, 0.0)?;
        Tensor::cat(&[&left, self], dim)
    } else {
        let dim = dim.to_index(self.shape(), "pad_with_const")?;
        let mut dims = self.dims().to_vec();
        dims[dim] = left;
        let left =
            Tensor::ones(dims.as_slice(), self.dtype, self.device())?.affine(constant, 0.0)?;
        dims[dim] = right;
        let right =
            Tensor::ones(dims.as_slice(), self.dtype, self.device())?.affine(constant, 0.0)?;
        Tensor::cat(&[&left, self, &right], dim)
    }
}

/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
/// input tensor values and `right` elements after.
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
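For reference, a minimal sketch of how the new `pad_with_const` helper behaves on its own (the values here are illustrative; it assumes the CPU device and the public candle-core API):

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // A 2x2 input: [[1, 2], [3, 4]].
    let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[2, 2], &Device::Cpu)?;
    // Pad dimension 1 with 9.0: two elements on the left, one on the right.
    let padded = t.pad_with_const(1, 2, 1, 9.0)?;
    assert_eq!(
        padded.to_vec2::<f32>()?,
        vec![
            vec![9.0, 9.0, 1.0, 2.0, 9.0],
            vec![9.0, 9.0, 3.0, 4.0, 9.0]
        ]
    );
    Ok(())
}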
52 changes: 52 additions & 0 deletions candle-onnx/src/eval.rs
@@ -444,6 +444,58 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), ys);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad
"Pad" => {
let input = get(&node.input[0])?;
let pads_opt = get_attr_opt::<[i64]>(node, "pads")?;
let mode = get_attr_opt::<str>(node, "mode")?;
let constant_value = get_attr_opt::<f32>(node, "constant_value")?;
let axes_opt = get_attr_opt::<[i64]>(node, "axes")?;
if let Some(pads) = pads_opt {
let dims = input.dims();
let num_axes = dims.len();
let all_axes:Vec<usize> = (0..num_axes).collect();


let axes = axes_opt.map(|axes|{
axes.iter().map(|axis|
if *axis < 0 {
num_axes - *axis as usize
} else {
*axis as usize
}
).collect::<Vec<_>>()
}).unwrap_or(all_axes);

// TODO: Support negative padding.
let pads = pads.iter().map(|&v| v as usize).collect::<Vec<_>>();


let output = match mode {
None | Some("constant") => {
axes.iter().enumerate().fold(Ok(input.clone()), |tensor,(idx,axis)| {
if let Some(val) = constant_value {
tensor?.pad_with_const(*axis, pads[idx], pads[idx+ num_axes], *val as f64)
} else {
tensor?.pad_with_zeros(*axis,pads[idx],pads[idx+ num_axes])
}
})
},

Some("edge") => {
axes.iter().enumerate().fold(Ok(input.clone()), |tensor,(idx,axis)| {
tensor?.pad_with_same(*axis,pads[idx],pads[idx+ num_axes])
})
},
// Some("reflect") => (),
// Some("wrap") => (),
Some(s) => bail!("unsupported pad type {s}"),
};
values.insert(node.output[0].clone(), output?);
} else {
values.insert(node.output[0].clone(), input.clone());
}
}
"AveragePool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
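The ONNX `pads` attribute is a flat list with twice as many entries as padded axes: first all of the leading (begin) counts, then all of the trailing (end) ones. A minimal sketch of the split the handler above relies on (the helper name is illustrative, not part of the crate):

// For pads = [0, 2, 0, 0] over axes [0, 1], as in the tests below:
//   axis 0: begin = pads[0] = 0, end = pads[0 + 2] = 0
//   axis 1: begin = pads[1] = 2, end = pads[1 + 2] = 0
fn split_pads(pads: &[usize]) -> (&[usize], &[usize]) {
    let half = pads.len() / 2;
    (&pads[..half], &pads[half..])
}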
183 changes: 183 additions & 0 deletions candle-onnx/tests/ops.rs
@@ -727,6 +727,189 @@ fn test_dropout_operation() -> Result<()> {
Ok(())
}

// "Pad"
#[test]
fn test_pad_const_operation() -> Result<()> {
let att_pads = AttributeProto {
name: "pads".to_string(),
ref_attr_name: "pads".to_string(),
i: 0,
doc_string: "axis".to_string(),
r#type: 7,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![0, 2, 0, 0],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Pad".to_string(),
domain: "".to_string(),
attribute: vec![att_pads.clone()],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
    vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
    &[2, 2],
    &Device::Cpu,
)?;

let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);

let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);

let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

let results = z.to_vec2::<f32>()?;

assert_eq!(
results,
vec![vec![0.0, 0.0, 1.0, 2.0], vec![0.0, 0.0, 3.0, 4.0]]
);

Ok(())
}

// "Pad"
#[test]
fn test_pad_edge_operation() -> Result<()> {
let att_pads = AttributeProto {
name: "pads".to_string(),
ref_attr_name: "pads".to_string(),
i: 0,
doc_string: "pods".to_string(),
r#type: 7,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![0, 2, 0, 0],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};

let att_mode = AttributeProto {
name: "mode".to_string(),
ref_attr_name: "mode".to_string(),
i: 0,
doc_string: "mode".to_string(),
r#type: 3,
f: 0.0,
s: b"edge".to_vec(),
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};

let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Pad".to_string(),
domain: "".to_string(),
attribute: vec![att_pads.clone(), att_mode.clone()],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
    vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
    &[2, 2],
    &Device::Cpu,
)?;

let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);

let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);

let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

let results = z.to_vec2::<f32>()?;

assert_eq!(results, vec![vec![1.0, 1.0, 1.0, 2.0], vec![3.0, 3.0, 3.0, 4.0]]);

Ok(())
}

// "Flatten"
#[test]
fn test_flatten_operation() -> Result<()> {
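Both new tests can be run in isolation with something along the lines of `cargo test -p candle-onnx pad` (the filter string is an assumption; cargo's default test harness matches any substring of the test names, so `test_pad_const_operation` and `test_pad_edge_operation` would both be selected).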
