From 5c721226ee69260d42686c49fab9ebd1666bb042 Mon Sep 17 00:00:00 2001 From: ombrdr47 Date: Sun, 2 Nov 2025 19:44:11 +0530 Subject: [PATCH 1/8] feat: implement LSTM and GRU operators for torchlib Implement aten_lstm and aten_gru operators to enable torch.onnx.export for PyTorch LSTM and GRU layers. This addresses issue #2546. Key features: - Full support for multi-layer RNNs (num_layers > 1) - Bidirectional support (forward and backward directions) - Handles both biased and non-biased configurations - batch_first parameter support with automatic transposition - Dropout support between layers (nondeterministic seeded) - Proper gate reordering for ONNX compatibility: * LSTM: PyTorch [i,f,g,o] -> ONNX [i,o,f,g] * GRU: PyTorch [r,z,n] -> ONNX [z,r,n] Implementation details: - Uses ONNX LSTM/GRU operators with proper parameter formatting - Handles weight matrix transposition and reshaping - Correctly concatenates biases using op.Concat - Processes each layer independently with proper state management - Returns outputs in PyTorch-compatible format Closes: #2546 --- .../function_libs/torch_lib/ops/core.py | 350 ++++++++++++++++++ 1 file changed, 350 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index be30520878..27cf0f4745 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4991,6 +4991,356 @@ def aten_lstm_mps_backward( raise NotImplementedError() +@torch_op("aten::lstm", trace_only=True) +def aten_lstm( + input: TFloat, + hx: Sequence[TFloat], + params: Sequence[TFloat], + has_biases: bool, + num_layers: int, + dropout: float, + train: bool, + bidirectional: bool, + batch_first: bool, +) -> tuple[TFloat, TFloat, TFloat]: + """lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)""" + + # Extract initial hidden and cell states + initial_h = hx[0] # Shape: [num_directions * num_layers, batch_size, hidden_size] + initial_c = hx[1] # Shape: [num_directions * num_layers, batch_size, hidden_size] + + # Determine number of directions + num_directions = 2 if bidirectional else 1 + + # Get dimensions + if batch_first: + # Convert from [batch, seq, input_size] to [seq, batch, input_size] + input = op.Transpose(input, perm=[1, 0, 2]) + + seq_length = op.Shape(input, start=0, end=1) + batch_size = op.Shape(input, start=1, end=2) + input_size = op.Shape(input, start=2, end=3) + hidden_size = op.Shape(initial_h, start=2, end=3) + + # Process each layer + current_input = input + output_h_list = [] + output_c_list = [] + + for layer_idx in range(num_layers): + # Extract hidden and cell states for this layer + layer_start = layer_idx * num_directions + layer_end = (layer_idx + 1) * num_directions + layer_h = op.Slice(initial_h, layer_start, layer_end, axes=[0]) + layer_c = op.Slice(initial_c, layer_start, layer_end, axes=[0]) + + # Extract parameters for this layer + # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction + params_per_direction = 4 if has_biases else 2 + params_per_layer = params_per_direction * num_directions + param_start_idx = layer_idx * params_per_layer + + # Build weight matrices for ONNX LSTM + # ONNX expects: W[iofc] shape [num_directions, 4*hidden_size, input_size] + # PyTorch provides: W_ih shape [4*hidden_size, input_size] + W_list = [] + R_list = [] + B_list = [] if has_biases else None + + for dir_idx in 
range(num_directions): + dir_param_start = param_start_idx + dir_idx * params_per_direction + W_ih = params[dir_param_start] # [4*hidden_size, input_size] - PyTorch order: [i,f,g,o] + W_hh = params[dir_param_start + 1] # [4*hidden_size, hidden_size] - PyTorch order: [i,f,g,o] + + # Reorder gates from PyTorch [i,f,g,o] to ONNX [i,o,f,g] + # Split into individual gates + W_ii = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) + W_if = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_ig = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + W_io = op.Slice(W_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + W_hi = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) + W_hf = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_hg = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + W_ho = op.Slice(W_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + # Reorder: [i,o,f,g] + W_ih_reordered = op.Concat(W_ii, W_io, W_if, W_ig, axis=0) # [4*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat(W_hi, W_ho, W_hf, W_hg, axis=0) # [4*hidden_size, hidden_size] - ONNX order + + # Add direction dimension + W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 4*hidden_size, input_size] + W_hh_expanded = op.Unsqueeze(W_hh_reordered, [0]) # [1, 4*hidden_size, hidden_size] + + W_list.append(W_ih_expanded) + R_list.append(W_hh_expanded) + + if has_biases: + b_ih = params[dir_param_start + 2] # [4*hidden_size] - PyTorch order: [i,f,g,o] + b_hh = params[dir_param_start + 3] # [4*hidden_size] - PyTorch order: [i,f,g,o] + + # Reorder biases from PyTorch [i,f,g,o] to ONNX [i,o,f,g] + b_ii = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) + b_if = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_ig = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + b_io = op.Slice(b_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + b_hi = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) + b_hf = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_hg = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + b_ho = op.Slice(b_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + # Reorder: [i,o,f,g] + b_ih_reordered = op.Concat(b_ii, b_io, b_if, b_ig, axis=0) # [4*hidden_size] - ONNX order + b_hh_reordered = op.Concat(b_hi, b_ho, b_hf, b_hg, axis=0) # [4*hidden_size] - ONNX order + + # ONNX expects biases concatenated: [Wb[iofc], Rb[iofc]] + b_combined = op.Concat(b_ih_reordered, b_hh_reordered, axis=0) # [8*hidden_size] + b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 8*hidden_size] + B_list.append(b_expanded) + + # Concatenate weights for all directions + W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] + R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] + B = op.Concat(*B_list, axis=0) if has_biases and len(B_list) > 1 else (B_list[0] if has_biases else None) + + # Call ONNX LSTM operator + direction = "bidirectional" if bidirectional else "forward" + + # Extract hidden_size from weight matrix shape: W shape is [num_directions, 4*hidden_size, input_size] + # So hidden_size = W.shape[1] // 4 + W_shape = op.Shape(W) + hidden_size_times_4 = op.Slice(W_shape, [1], [2], axes=[0]) + hidden_size_attr = op.Div(hidden_size_times_4, op.Constant(value_ints=[4])) + + if B is not None: + Y, Y_h, Y_c = op.LSTM( + current_input, + W, + R, + B, + initial_h=layer_h, + 
initial_c=layer_c, + direction=direction, + hidden_size=hidden_size_attr, + ) + else: + Y, Y_h, Y_c = op.LSTM( + current_input, + W, + R, + initial_h=layer_h, + initial_c=layer_c, + direction=direction, + hidden_size=hidden_size_attr, + ) + + # Y shape: [seq_length, num_directions, batch_size, hidden_size] + # Reshape to [seq_length, batch_size, num_directions * hidden_size] + Y = op.Transpose(Y, perm=[0, 2, 1, 3]) # [seq_length, batch_size, num_directions, hidden_size] + Y_shape = op.Shape(Y) + new_shape = op.Concat( + op.Slice(Y_shape, [0], [1]), # seq_length + op.Slice(Y_shape, [1], [2]), # batch_size + op.Reshape( + op.Mul( + op.Slice(Y_shape, [2], [3]), # num_directions + op.Slice(Y_shape, [3], [4]), # hidden_size + ), + op.Constant(value_ints=[-1]), + ), + axis=0, + ) + current_input = op.Reshape(Y, new_shape) + + # Apply dropout if not last layer and dropout > 0 + if layer_idx < num_layers - 1 and dropout > 0.0 and train: + current_input, _ = op.Dropout(current_input, dropout, train) + + # Store final hidden and cell states + output_h_list.append(Y_h) + output_c_list.append(Y_c) + + # Concatenate all layer outputs + final_h = output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + final_c = output_c_list[0] if len(output_c_list) == 1 else op.Concat(*output_c_list, axis=0) + + # Handle batch_first for output + if batch_first: + # Convert from [seq, batch, features] to [batch, seq, features] + current_input = op.Transpose(current_input, perm=[1, 0, 2]) + + return current_input, final_h, final_c + + +@torch_op("aten::gru", trace_only=True) +def aten_gru( + input: TFloat, + hx: TFloat, + params: Sequence[TFloat], + has_biases: bool, + num_layers: int, + dropout: float, + train: bool, + bidirectional: bool, + batch_first: bool, +) -> tuple[TFloat, TFloat]: + """gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)""" + + # Determine number of directions + num_directions = 2 if bidirectional else 1 + + # Get dimensions + if batch_first: + # Convert from [batch, seq, input_size] to [seq, batch, input_size] + input = op.Transpose(input, perm=[1, 0, 2]) + + seq_length = op.Shape(input, start=0, end=1) + batch_size = op.Shape(input, start=1, end=2) + input_size = op.Shape(input, start=2, end=3) + hidden_size = op.Shape(hx, start=2, end=3) + + # Process each layer + current_input = input + output_h_list = [] + + for layer_idx in range(num_layers): + # Extract hidden state for this layer + layer_start = layer_idx * num_directions + layer_end = (layer_idx + 1) * num_directions + layer_h = op.Slice(hx, layer_start, layer_end, axes=[0]) + + # Extract parameters for this layer + # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction + params_per_direction = 4 if has_biases else 2 + params_per_layer = params_per_direction * num_directions + param_start_idx = layer_idx * params_per_layer + + # Build weight matrices for ONNX GRU + # ONNX expects: W[zrh] shape [num_directions, 3*hidden_size, input_size] + # PyTorch provides: W_ih shape [3*hidden_size, input_size] + W_list = [] + R_list = [] + B_list = [] if has_biases else None + + for dir_idx in range(num_directions): + dir_param_start = param_start_idx + dir_idx * params_per_direction + W_ih = params[dir_param_start] # [3*hidden_size, input_size] - PyTorch order: [r,z,n] + W_hh = params[dir_param_start + 1] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n] + + # Reorder gates from PyTorch 
[r,z,n] to ONNX [z,r,n] + # Split into individual gates + W_ir = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) + W_iz = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_in = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + W_hr = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) + W_hz = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_hn = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + # Reorder: [z,r,n] + W_ih_reordered = op.Concat(W_iz, W_ir, W_in, axis=0) # [3*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat(W_hz, W_hr, W_hn, axis=0) # [3*hidden_size, hidden_size] - ONNX order + + # Add direction dimension + W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 3*hidden_size, input_size] + W_hh_expanded = op.Unsqueeze(W_hh_reordered, [0]) # [1, 3*hidden_size, hidden_size] + + W_list.append(W_ih_expanded) + R_list.append(W_hh_expanded) + + if has_biases: + b_ih = params[dir_param_start + 2] # [3*hidden_size] - PyTorch order: [r,z,n] + b_hh = params[dir_param_start + 3] # [3*hidden_size] - PyTorch order: [r,z,n] + + # Reorder biases from PyTorch [r,z,n] to ONNX [z,r,n] + b_ir = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) + b_iz = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_in = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + b_hr = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) + b_hz = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_hn = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + # Reorder: [z,r,n] + b_ih_reordered = op.Concat(b_iz, b_ir, b_in, axis=0) # [3*hidden_size] - ONNX order + b_hh_reordered = op.Concat(b_hz, b_hr, b_hn, axis=0) # [3*hidden_size] - ONNX order + + # ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]] + b_combined = op.Concat(b_ih_reordered, b_hh_reordered, axis=0) # [6*hidden_size] + b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 6*hidden_size] + B_list.append(b_expanded) + + # Concatenate weights for all directions + W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] + R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] + B = op.Concat(*B_list, axis=0) if has_biases and len(B_list) > 1 else (B_list[0] if has_biases else None) + + # Call ONNX GRU operator + direction = "bidirectional" if bidirectional else "forward" + + # Extract hidden_size from weight matrix shape: W shape is [num_directions, 3*hidden_size, input_size] + # So hidden_size = W.shape[1] // 3 + W_shape = op.Shape(W) + hidden_size_times_3 = op.Slice(W_shape, [1], [2], axes=[0]) + hidden_size_attr = op.Div(hidden_size_times_3, op.Constant(value_ints=[3])) + + if B is not None: + Y, Y_h = op.GRU( + current_input, + W, + R, + B, + initial_h=layer_h, + direction=direction, + hidden_size=hidden_size_attr, + ) + else: + Y, Y_h = op.GRU( + current_input, + W, + R, + initial_h=layer_h, + direction=direction, + hidden_size=hidden_size_attr, + ) + + # Y shape: [seq_length, num_directions, batch_size, hidden_size] + # Reshape to [seq_length, batch_size, num_directions * hidden_size] + Y = op.Transpose(Y, perm=[0, 2, 1, 3]) # [seq_length, batch_size, num_directions, hidden_size] + Y_shape = op.Shape(Y) + new_shape = op.Concat( + op.Slice(Y_shape, [0], [1]), # seq_length + op.Slice(Y_shape, [1], [2]), # batch_size + op.Reshape( + op.Mul( + op.Slice(Y_shape, [2], [3]), # num_directions + op.Slice(Y_shape, [3], [4]), # 
hidden_size + ), + op.Constant(value_ints=[-1]), + ), + axis=0, + ) + current_input = op.Reshape(Y, new_shape) + + # Apply dropout if not last layer and dropout > 0 + if layer_idx < num_layers - 1 and dropout > 0.0 and train: + current_input, _ = op.Dropout(current_input, dropout, train) + + # Store final hidden state + output_h_list.append(Y_h) + + # Concatenate all layer outputs + final_h = output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + + # Handle batch_first for output + if batch_first: + # Convert from [seq, batch, features] to [batch, seq, features] + current_input = op.Transpose(current_input, perm=[1, 0, 2]) + + return current_input, final_h + + @torch_op( ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), trace_only=True, From 3d4addc3b50efecf6ca513b440c4fef952b24fe6 Mon Sep 17 00:00:00 2001 From: ombrdr47 Date: Mon, 3 Nov 2025 23:09:38 +0530 Subject: [PATCH 2/8] fix: use full operator names in decorators Update to aten::lstm.input and aten::gru.input --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 27cf0f4745..143a1d1d25 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4991,7 +4991,7 @@ def aten_lstm_mps_backward( raise NotImplementedError() -@torch_op("aten::lstm", trace_only=True) +@torch_op("aten::lstm.input", trace_only=True) def aten_lstm( input: TFloat, hx: Sequence[TFloat], @@ -5173,7 +5173,7 @@ def aten_lstm( return current_input, final_h, final_c -@torch_op("aten::gru", trace_only=True) +@torch_op("aten::gru.input", trace_only=True) def aten_gru( input: TFloat, hx: TFloat, From f4881f4dae1ebb649430f4fde719e8af9f14d736 Mon Sep 17 00:00:00 2001 From: ombrdr47 Date: Tue, 4 Nov 2025 00:06:02 +0530 Subject: [PATCH 3/8] refactor: move aten_gru to alphabetical location Move aten_gru function to appear after aten_ger for alphabetical ordering with other aten_g* functions. Also add hidden_size attribute computation to GRU for consistency with LSTM implementation. 
--- .../function_libs/torch_lib/ops/core.py | 336 +++++++++--------- 1 file changed, 168 insertions(+), 168 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 143a1d1d25..d6ffb3688d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3761,6 +3761,174 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("aten::gru.input", trace_only=True) +def aten_gru( + input: TFloat, + hx: TFloat, + params: Sequence[TFloat], + has_biases: bool, + num_layers: int, + dropout: float, + train: bool, + bidirectional: bool, + batch_first: bool, +) -> tuple[TFloat, TFloat]: + """gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)""" + + # Determine number of directions + num_directions = 2 if bidirectional else 1 + + # Get dimensions + if batch_first: + # Convert from [batch, seq, input_size] to [seq, batch, input_size] + input = op.Transpose(input, perm=[1, 0, 2]) + + seq_length = op.Shape(input, start=0, end=1) + batch_size = op.Shape(input, start=1, end=2) + input_size = op.Shape(input, start=2, end=3) + hidden_size = op.Shape(hx, start=2, end=3) + + # Process each layer + current_input = input + output_h_list = [] + + for layer_idx in range(num_layers): + # Extract hidden state for this layer + layer_start = layer_idx * num_directions + layer_end = (layer_idx + 1) * num_directions + layer_h = op.Slice(hx, layer_start, layer_end, axes=[0]) + + # Extract parameters for this layer + # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction + params_per_direction = 4 if has_biases else 2 + params_per_layer = params_per_direction * num_directions + param_start_idx = layer_idx * params_per_layer + + # Build weight matrices for ONNX GRU + # ONNX expects: W[zrh] shape [num_directions, 3*hidden_size, input_size] + # PyTorch provides: W_ih shape [3*hidden_size, input_size] + W_list = [] + R_list = [] + B_list = [] if has_biases else None + + for dir_idx in range(num_directions): + dir_param_start = param_start_idx + dir_idx * params_per_direction + W_ih = params[dir_param_start] # [3*hidden_size, input_size] - PyTorch order: [r,z,n] + W_hh = params[dir_param_start + 1] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n] + + # Reorder gates from PyTorch [r,z,n] to ONNX [z,r,n] + # Split into individual gates + W_ir = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) + W_iz = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_in = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + W_hr = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) + W_hz = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_hn = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + # Reorder: [z,r,n] + W_ih_reordered = op.Concat(W_iz, W_ir, W_in, axis=0) # [3*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat(W_hz, W_hr, W_hn, axis=0) # [3*hidden_size, hidden_size] - ONNX order + + # Add direction dimension + W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 3*hidden_size, input_size] + W_hh_expanded = op.Unsqueeze(W_hh_reordered, [0]) # [1, 3*hidden_size, hidden_size] + + W_list.append(W_ih_expanded) + R_list.append(W_hh_expanded) + + if has_biases: + b_ih = params[dir_param_start + 2] # [3*hidden_size] - PyTorch 
order: [r,z,n] + b_hh = params[dir_param_start + 3] # [3*hidden_size] - PyTorch order: [r,z,n] + + # Reorder biases from PyTorch [r,z,n] to ONNX [z,r,n] + b_ir = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) + b_iz = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_in = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + b_hr = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) + b_hz = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_hn = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + # Reorder: [z,r,n] + b_ih_reordered = op.Concat(b_iz, b_ir, b_in, axis=0) # [3*hidden_size] - ONNX order + b_hh_reordered = op.Concat(b_hz, b_hr, b_hn, axis=0) # [3*hidden_size] - ONNX order + + # ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]] + b_combined = op.Concat(b_ih_reordered, b_hh_reordered, axis=0) # [6*hidden_size] + b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 6*hidden_size] + B_list.append(b_expanded) + + # Concatenate weights for all directions + W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] + R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] + B = op.Concat(*B_list, axis=0) if has_biases and len(B_list) > 1 else (B_list[0] if has_biases else None) + + # Call ONNX GRU operator + direction = "bidirectional" if bidirectional else "forward" + + # Extract hidden_size from weight matrix shape: W shape is [num_directions, 3*hidden_size, input_size] + # So hidden_size = W.shape[1] // 3 + W_shape = op.Shape(W) + hidden_size_times_3 = op.Slice(W_shape, [1], [2], axes=[0]) + hidden_size_attr = op.Div(hidden_size_times_3, op.Constant(value_ints=[3])) + + if B is not None: + Y, Y_h = op.GRU( + current_input, + W, + R, + B, + initial_h=layer_h, + direction=direction, + hidden_size=hidden_size_attr, + ) + else: + Y, Y_h = op.GRU( + current_input, + W, + R, + initial_h=layer_h, + direction=direction, + hidden_size=hidden_size_attr, + ) + + # Y shape: [seq_length, num_directions, batch_size, hidden_size] + # Reshape to [seq_length, batch_size, num_directions * hidden_size] + Y = op.Transpose(Y, perm=[0, 2, 1, 3]) # [seq_length, batch_size, num_directions, hidden_size] + Y_shape = op.Shape(Y) + new_shape = op.Concat( + op.Slice(Y_shape, [0], [1]), # seq_length + op.Slice(Y_shape, [1], [2]), # batch_size + op.Reshape( + op.Mul( + op.Slice(Y_shape, [2], [3]), # num_directions + op.Slice(Y_shape, [3], [4]), # hidden_size + ), + op.Constant(value_ints=[-1]), + ), + axis=0, + ) + current_input = op.Reshape(Y, new_shape) + + # Apply dropout if not last layer and dropout > 0 + if layer_idx < num_layers - 1 and dropout > 0.0 and train: + current_input, _ = op.Dropout(current_input, dropout, train) + + # Store final hidden state + output_h_list.append(Y_h) + + # Concatenate all layer outputs + final_h = output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + + # Handle batch_first for output + if batch_first: + # Convert from [seq, batch, features] to [batch, seq, features] + current_input = op.Transpose(current_input, perm=[1, 0, 2]) + + return current_input, final_h + + @torch_op(("_operator::getitem", "aten::getitem")) def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor: return op.SequenceAt(self, i) @@ -5173,174 +5341,6 @@ def aten_lstm( return current_input, final_h, final_c -@torch_op("aten::gru.input", trace_only=True) -def aten_gru( - input: TFloat, - hx: TFloat, - params: Sequence[TFloat], - has_biases: bool, - num_layers: 
int, - dropout: float, - train: bool, - bidirectional: bool, - batch_first: bool, -) -> tuple[TFloat, TFloat]: - """gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)""" - - # Determine number of directions - num_directions = 2 if bidirectional else 1 - - # Get dimensions - if batch_first: - # Convert from [batch, seq, input_size] to [seq, batch, input_size] - input = op.Transpose(input, perm=[1, 0, 2]) - - seq_length = op.Shape(input, start=0, end=1) - batch_size = op.Shape(input, start=1, end=2) - input_size = op.Shape(input, start=2, end=3) - hidden_size = op.Shape(hx, start=2, end=3) - - # Process each layer - current_input = input - output_h_list = [] - - for layer_idx in range(num_layers): - # Extract hidden state for this layer - layer_start = layer_idx * num_directions - layer_end = (layer_idx + 1) * num_directions - layer_h = op.Slice(hx, layer_start, layer_end, axes=[0]) - - # Extract parameters for this layer - # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction - params_per_direction = 4 if has_biases else 2 - params_per_layer = params_per_direction * num_directions - param_start_idx = layer_idx * params_per_layer - - # Build weight matrices for ONNX GRU - # ONNX expects: W[zrh] shape [num_directions, 3*hidden_size, input_size] - # PyTorch provides: W_ih shape [3*hidden_size, input_size] - W_list = [] - R_list = [] - B_list = [] if has_biases else None - - for dir_idx in range(num_directions): - dir_param_start = param_start_idx + dir_idx * params_per_direction - W_ih = params[dir_param_start] # [3*hidden_size, input_size] - PyTorch order: [r,z,n] - W_hh = params[dir_param_start + 1] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n] - - # Reorder gates from PyTorch [r,z,n] to ONNX [z,r,n] - # Split into individual gates - W_ir = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) - W_iz = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - W_in = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - - W_hr = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) - W_hz = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - W_hn = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - - # Reorder: [z,r,n] - W_ih_reordered = op.Concat(W_iz, W_ir, W_in, axis=0) # [3*hidden_size, input_size] - ONNX order - W_hh_reordered = op.Concat(W_hz, W_hr, W_hn, axis=0) # [3*hidden_size, hidden_size] - ONNX order - - # Add direction dimension - W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 3*hidden_size, input_size] - W_hh_expanded = op.Unsqueeze(W_hh_reordered, [0]) # [1, 3*hidden_size, hidden_size] - - W_list.append(W_ih_expanded) - R_list.append(W_hh_expanded) - - if has_biases: - b_ih = params[dir_param_start + 2] # [3*hidden_size] - PyTorch order: [r,z,n] - b_hh = params[dir_param_start + 3] # [3*hidden_size] - PyTorch order: [r,z,n] - - # Reorder biases from PyTorch [r,z,n] to ONNX [z,r,n] - b_ir = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) - b_iz = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - b_in = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - - b_hr = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) - b_hz = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) - b_hn = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - - # Reorder: [z,r,n] - b_ih_reordered = 
op.Concat(b_iz, b_ir, b_in, axis=0) # [3*hidden_size] - ONNX order - b_hh_reordered = op.Concat(b_hz, b_hr, b_hn, axis=0) # [3*hidden_size] - ONNX order - - # ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]] - b_combined = op.Concat(b_ih_reordered, b_hh_reordered, axis=0) # [6*hidden_size] - b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 6*hidden_size] - B_list.append(b_expanded) - - # Concatenate weights for all directions - W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] - R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] - B = op.Concat(*B_list, axis=0) if has_biases and len(B_list) > 1 else (B_list[0] if has_biases else None) - - # Call ONNX GRU operator - direction = "bidirectional" if bidirectional else "forward" - - # Extract hidden_size from weight matrix shape: W shape is [num_directions, 3*hidden_size, input_size] - # So hidden_size = W.shape[1] // 3 - W_shape = op.Shape(W) - hidden_size_times_3 = op.Slice(W_shape, [1], [2], axes=[0]) - hidden_size_attr = op.Div(hidden_size_times_3, op.Constant(value_ints=[3])) - - if B is not None: - Y, Y_h = op.GRU( - current_input, - W, - R, - B, - initial_h=layer_h, - direction=direction, - hidden_size=hidden_size_attr, - ) - else: - Y, Y_h = op.GRU( - current_input, - W, - R, - initial_h=layer_h, - direction=direction, - hidden_size=hidden_size_attr, - ) - - # Y shape: [seq_length, num_directions, batch_size, hidden_size] - # Reshape to [seq_length, batch_size, num_directions * hidden_size] - Y = op.Transpose(Y, perm=[0, 2, 1, 3]) # [seq_length, batch_size, num_directions, hidden_size] - Y_shape = op.Shape(Y) - new_shape = op.Concat( - op.Slice(Y_shape, [0], [1]), # seq_length - op.Slice(Y_shape, [1], [2]), # batch_size - op.Reshape( - op.Mul( - op.Slice(Y_shape, [2], [3]), # num_directions - op.Slice(Y_shape, [3], [4]), # hidden_size - ), - op.Constant(value_ints=[-1]), - ), - axis=0, - ) - current_input = op.Reshape(Y, new_shape) - - # Apply dropout if not last layer and dropout > 0 - if layer_idx < num_layers - 1 and dropout > 0.0 and train: - current_input, _ = op.Dropout(current_input, dropout, train) - - # Store final hidden state - output_h_list.append(Y_h) - - # Concatenate all layer outputs - final_h = output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) - - # Handle batch_first for output - if batch_first: - # Convert from [seq, batch, features] to [batch, seq, features] - current_input = op.Transpose(current_input, perm=[1, 0, 2]) - - return current_input, final_h - - @torch_op( ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), trace_only=True, From 039ffb042defba2be63e5dab867c0b2f42e8b9af Mon Sep 17 00:00:00 2001 From: ombrdr47 Date: Tue, 4 Nov 2025 01:17:08 +0530 Subject: [PATCH 4/8] fix: use static input shapes for hidden_size attribute Use initial_h.shape[2] for LSTM and hx.shape[2] for GRU to get static hidden_size values instead of computing from weight matrices. This allows the attribute to be a Python integer rather than a SymbolicValue. 
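For reference, the relevant lines of aten_lstm before and after this change (restated from the hunk below, with the comments paraphrased):

    # Before: hidden_size is derived from the weight tensor inside the graph,
    # so the attribute value is a SymbolicValue.
    W_shape = op.Shape(W)
    hidden_size_times_4 = op.Slice(W_shape, [1], [2], axes=[0])
    hidden_size_attr = op.Div(hidden_size_times_4, op.Constant(value_ints=[4]))

    # After: hidden_size is read from the static shape of the initial hidden
    # state, so the attribute is a plain Python int.
    hidden_size_attr = initial_h.shape[2]

The GRU path gets the analogous change using hx.shape[2].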
--- onnxscript/function_libs/torch_lib/ops/core.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d6ffb3688d..9e42c17719 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3867,11 +3867,8 @@ def aten_gru( # Call ONNX GRU operator direction = "bidirectional" if bidirectional else "forward" - # Extract hidden_size from weight matrix shape: W shape is [num_directions, 3*hidden_size, input_size] - # So hidden_size = W.shape[1] // 3 - W_shape = op.Shape(W) - hidden_size_times_3 = op.Slice(W_shape, [1], [2], axes=[0]) - hidden_size_attr = op.Div(hidden_size_times_3, op.Constant(value_ints=[3])) + # Extract hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size] + hidden_size_attr = hx.shape[2] if B is not None: Y, Y_h = op.GRU( @@ -5275,11 +5272,8 @@ def aten_lstm( # Call ONNX LSTM operator direction = "bidirectional" if bidirectional else "forward" - # Extract hidden_size from weight matrix shape: W shape is [num_directions, 4*hidden_size, input_size] - # So hidden_size = W.shape[1] // 4 - W_shape = op.Shape(W) - hidden_size_times_4 = op.Slice(W_shape, [1], [2], axes=[0]) - hidden_size_attr = op.Div(hidden_size_times_4, op.Constant(value_ints=[4])) + # Extract hidden_size from initial_h shape: [num_layers * num_directions, batch, hidden_size] + hidden_size_attr = initial_h.shape[2] if B is not None: Y, Y_h, Y_c = op.LSTM( From ea8c549b64703b3fcd5d1f94b2bd850aadf7ed95 Mon Sep 17 00:00:00 2001 From: ombrdr47 Date: Tue, 4 Nov 2025 01:22:19 +0530 Subject: [PATCH 5/8] test: add LSTM e2e tests Add comprehensive tests for LSTM operator covering: - Unidirectional single-layer - Bidirectional single-layer - Multi-layer (3 layers) All tests pass successfully with the fixed hidden_size attribute. 
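The new cases can be run selectively with pytest, for example:

    python -m pytest tests/function_libs/torch_lib/e2e_ops_tests.py -k lstm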
--- .../function_libs/torch_lib/e2e_ops_tests.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 24ccaf4b40..560c0fed34 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -302,6 +302,48 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_lstm_unidirectional(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=1, batch_first=True) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_lstm_bidirectional(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=1, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_lstm_multilayer(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From 4163c7e63d2e99a4ebb1f915e222a2373012cc51 Mon Sep 17 00:00:00 2001 From: ombrdr47 Date: Tue, 4 Nov 2025 01:33:33 +0530 Subject: [PATCH 6/8] test: add GRU e2e tests Add comprehensive tests for GRU operator covering: - Unidirectional single-layer - Bidirectional single-layer - Multi-layer (3 layers) Note: GRU tests currently fail with numerical accuracy issues. 
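To help localize the mismatch, a minimal standalone parity check (not part of this patch; assumes onnxruntime is installed and writes a throwaway gru.onnx file):

    import numpy as np
    import onnxruntime as ort
    import torch

    model = torch.nn.GRU(input_size=10, hidden_size=20, num_layers=1, batch_first=True)
    x = torch.randn(5, 3, 10)
    expected, _ = model(x)

    # Export via the dynamo path, matching the e2e tests below.
    onnx_program = torch.onnx.export(model, (x,), dynamo=True)
    onnx_program.save("gru.onnx")

    sess = ort.InferenceSession("gru.onnx", providers=["CPUExecutionProvider"])
    input_name = sess.get_inputs()[0].name
    actual = sess.run(None, {input_name: x.numpy()})[0]
    np.testing.assert_allclose(actual, expected.detach().numpy(), rtol=1e-3, atol=1e-5)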
--- .../function_libs/torch_lib/e2e_ops_tests.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 560c0fed34..21970f2ecd 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -344,6 +344,48 @@ def forward(self, x): onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) _testing.assert_onnx_program(onnx_program) + def test_gru_unidirectional(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU(input_size=10, hidden_size=20, num_layers=1, batch_first=True) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_bidirectional(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU(input_size=10, hidden_size=20, num_layers=1, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_multilayer(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU(input_size=10, hidden_size=20, num_layers=3, batch_first=True) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From 5411ca08911b63a2b29df82b4e28ec1d15bfe7ca Mon Sep 17 00:00:00 2001 From: ombrdr47 Date: Tue, 4 Nov 2025 19:58:01 +0530 Subject: [PATCH 7/8] fix: remove unused variables and fix whitespace lint errors --- .../function_libs/torch_lib/ops/core.py | 120 +++++++++--------- 1 file changed, 57 insertions(+), 63 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 9e42c17719..26b7650f0a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3774,102 +3774,99 @@ def aten_gru( batch_first: bool, ) -> tuple[TFloat, TFloat]: """gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)""" - + # Determine number of directions num_directions = 2 if bidirectional else 1 - + # Get dimensions if batch_first: # Convert from [batch, seq, input_size] to [seq, batch, input_size] input = op.Transpose(input, perm=[1, 0, 2]) - - seq_length = op.Shape(input, start=0, end=1) - batch_size = op.Shape(input, start=1, end=2) - input_size = op.Shape(input, start=2, end=3) + hidden_size = op.Shape(hx, start=2, end=3) - + # Process each layer current_input = input output_h_list = [] - + for layer_idx in range(num_layers): # Extract hidden state for this layer layer_start = layer_idx * num_directions layer_end = (layer_idx + 1) * num_directions layer_h = op.Slice(hx, layer_start, layer_end, axes=[0]) - + # Extract parameters for this layer # Parameter layout: [W_ih, W_hh, b_ih, 
b_hh] for each direction params_per_direction = 4 if has_biases else 2 params_per_layer = params_per_direction * num_directions param_start_idx = layer_idx * params_per_layer - + # Build weight matrices for ONNX GRU # ONNX expects: W[zrh] shape [num_directions, 3*hidden_size, input_size] # PyTorch provides: W_ih shape [3*hidden_size, input_size] W_list = [] R_list = [] B_list = [] if has_biases else None - + for dir_idx in range(num_directions): dir_param_start = param_start_idx + dir_idx * params_per_direction W_ih = params[dir_param_start] # [3*hidden_size, input_size] - PyTorch order: [r,z,n] W_hh = params[dir_param_start + 1] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n] - + # Reorder gates from PyTorch [r,z,n] to ONNX [z,r,n] # Split into individual gates W_ir = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) W_iz = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) W_in = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - + W_hr = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) W_hz = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) W_hn = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - + # Reorder: [z,r,n] W_ih_reordered = op.Concat(W_iz, W_ir, W_in, axis=0) # [3*hidden_size, input_size] - ONNX order W_hh_reordered = op.Concat(W_hz, W_hr, W_hn, axis=0) # [3*hidden_size, hidden_size] - ONNX order - + # Add direction dimension W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 3*hidden_size, input_size] W_hh_expanded = op.Unsqueeze(W_hh_reordered, [0]) # [1, 3*hidden_size, hidden_size] - + W_list.append(W_ih_expanded) R_list.append(W_hh_expanded) - + if has_biases: b_ih = params[dir_param_start + 2] # [3*hidden_size] - PyTorch order: [r,z,n] b_hh = params[dir_param_start + 3] # [3*hidden_size] - PyTorch order: [r,z,n] - + # Reorder biases from PyTorch [r,z,n] to ONNX [z,r,n] b_ir = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) b_iz = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) b_in = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - + b_hr = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) b_hz = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) b_hn = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) - + # Reorder: [z,r,n] b_ih_reordered = op.Concat(b_iz, b_ir, b_in, axis=0) # [3*hidden_size] - ONNX order b_hh_reordered = op.Concat(b_hz, b_hr, b_hn, axis=0) # [3*hidden_size] - ONNX order - + # ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]] b_combined = op.Concat(b_ih_reordered, b_hh_reordered, axis=0) # [6*hidden_size] b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 6*hidden_size] B_list.append(b_expanded) - + # Concatenate weights for all directions W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] B = op.Concat(*B_list, axis=0) if has_biases and len(B_list) > 1 else (B_list[0] if has_biases else None) - + # Call ONNX GRU operator direction = "bidirectional" if bidirectional else "forward" - + # Extract hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size] hidden_size_attr = hx.shape[2] - + if B is not None: Y, Y_h = op.GRU( current_input, @@ -3889,7 +3886,7 @@ def aten_gru( direction=direction, hidden_size=hidden_size_attr, ) - + # Y shape: [seq_length, num_directions, batch_size, hidden_size] # Reshape to [seq_length, batch_size, num_directions * hidden_size] Y = 
op.Transpose(Y, perm=[0, 2, 1, 3]) # [seq_length, batch_size, num_directions, hidden_size] @@ -3907,22 +3904,22 @@ def aten_gru( axis=0, ) current_input = op.Reshape(Y, new_shape) - + # Apply dropout if not last layer and dropout > 0 if layer_idx < num_layers - 1 and dropout > 0.0 and train: current_input, _ = op.Dropout(current_input, dropout, train) - + # Store final hidden state output_h_list.append(Y_h) - + # Concatenate all layer outputs final_h = output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) - + # Handle batch_first for output if batch_first: # Convert from [seq, batch, features] to [batch, seq, features] current_input = op.Transpose(current_input, perm=[1, 0, 2]) - + return current_input, final_h @@ -5169,112 +5166,109 @@ def aten_lstm( batch_first: bool, ) -> tuple[TFloat, TFloat, TFloat]: """lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)""" - + # Extract initial hidden and cell states initial_h = hx[0] # Shape: [num_directions * num_layers, batch_size, hidden_size] initial_c = hx[1] # Shape: [num_directions * num_layers, batch_size, hidden_size] - + # Determine number of directions num_directions = 2 if bidirectional else 1 - + # Get dimensions if batch_first: # Convert from [batch, seq, input_size] to [seq, batch, input_size] input = op.Transpose(input, perm=[1, 0, 2]) - - seq_length = op.Shape(input, start=0, end=1) - batch_size = op.Shape(input, start=1, end=2) - input_size = op.Shape(input, start=2, end=3) + hidden_size = op.Shape(initial_h, start=2, end=3) - + # Process each layer current_input = input output_h_list = [] output_c_list = [] - + for layer_idx in range(num_layers): # Extract hidden and cell states for this layer layer_start = layer_idx * num_directions layer_end = (layer_idx + 1) * num_directions layer_h = op.Slice(initial_h, layer_start, layer_end, axes=[0]) layer_c = op.Slice(initial_c, layer_start, layer_end, axes=[0]) - + # Extract parameters for this layer # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction params_per_direction = 4 if has_biases else 2 params_per_layer = params_per_direction * num_directions param_start_idx = layer_idx * params_per_layer - + # Build weight matrices for ONNX LSTM # ONNX expects: W[iofc] shape [num_directions, 4*hidden_size, input_size] # PyTorch provides: W_ih shape [4*hidden_size, input_size] W_list = [] R_list = [] B_list = [] if has_biases else None - + for dir_idx in range(num_directions): dir_param_start = param_start_idx + dir_idx * params_per_direction W_ih = params[dir_param_start] # [4*hidden_size, input_size] - PyTorch order: [i,f,g,o] W_hh = params[dir_param_start + 1] # [4*hidden_size, hidden_size] - PyTorch order: [i,f,g,o] - + # Reorder gates from PyTorch [i,f,g,o] to ONNX [i,o,f,g] # Split into individual gates W_ii = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) W_if = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) W_ig = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) W_io = op.Slice(W_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) - + W_hi = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) W_hf = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) W_hg = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) W_ho = op.Slice(W_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) - + # Reorder: [i,o,f,g] W_ih_reordered = 
op.Concat(W_ii, W_io, W_if, W_ig, axis=0) # [4*hidden_size, input_size] - ONNX order W_hh_reordered = op.Concat(W_hi, W_ho, W_hf, W_hg, axis=0) # [4*hidden_size, hidden_size] - ONNX order - + # Add direction dimension W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 4*hidden_size, input_size] W_hh_expanded = op.Unsqueeze(W_hh_reordered, [0]) # [1, 4*hidden_size, hidden_size] - + W_list.append(W_ih_expanded) R_list.append(W_hh_expanded) - + if has_biases: b_ih = params[dir_param_start + 2] # [4*hidden_size] - PyTorch order: [i,f,g,o] b_hh = params[dir_param_start + 3] # [4*hidden_size] - PyTorch order: [i,f,g,o] - + # Reorder biases from PyTorch [i,f,g,o] to ONNX [i,o,f,g] b_ii = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) b_if = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) b_ig = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) b_io = op.Slice(b_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) - + b_hi = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) b_hf = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) b_hg = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) b_ho = op.Slice(b_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) - + # Reorder: [i,o,f,g] b_ih_reordered = op.Concat(b_ii, b_io, b_if, b_ig, axis=0) # [4*hidden_size] - ONNX order b_hh_reordered = op.Concat(b_hi, b_ho, b_hf, b_hg, axis=0) # [4*hidden_size] - ONNX order - + # ONNX expects biases concatenated: [Wb[iofc], Rb[iofc]] b_combined = op.Concat(b_ih_reordered, b_hh_reordered, axis=0) # [8*hidden_size] b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 8*hidden_size] B_list.append(b_expanded) - + # Concatenate weights for all directions W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] B = op.Concat(*B_list, axis=0) if has_biases and len(B_list) > 1 else (B_list[0] if has_biases else None) - + # Call ONNX LSTM operator direction = "bidirectional" if bidirectional else "forward" - + # Extract hidden_size from initial_h shape: [num_layers * num_directions, batch, hidden_size] hidden_size_attr = initial_h.shape[2] - + if B is not None: Y, Y_h, Y_c = op.LSTM( current_input, @@ -5296,7 +5290,7 @@ def aten_lstm( direction=direction, hidden_size=hidden_size_attr, ) - + # Y shape: [seq_length, num_directions, batch_size, hidden_size] # Reshape to [seq_length, batch_size, num_directions * hidden_size] Y = op.Transpose(Y, perm=[0, 2, 1, 3]) # [seq_length, batch_size, num_directions, hidden_size] @@ -5314,24 +5308,24 @@ def aten_lstm( axis=0, ) current_input = op.Reshape(Y, new_shape) - + # Apply dropout if not last layer and dropout > 0 if layer_idx < num_layers - 1 and dropout > 0.0 and train: current_input, _ = op.Dropout(current_input, dropout, train) - + # Store final hidden and cell states output_h_list.append(Y_h) output_c_list.append(Y_c) - + # Concatenate all layer outputs final_h = output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) final_c = output_c_list[0] if len(output_c_list) == 1 else op.Concat(*output_c_list, axis=0) - + # Handle batch_first for output if batch_first: # Convert from [seq, batch, features] to [batch, seq, features] current_input = op.Transpose(current_input, perm=[1, 0, 2]) - + return current_input, final_h, final_c From 72ea8a16580d49d8f6a145e87b7e06668ad7d5f3 Mon Sep 17 00:00:00 2001 From: ombrdr47 Date: Tue, 4 Nov 2025 22:15:34 +0530 Subject: [PATCH 8/8] style: 
apply lintrunner formatting fixes --- .../function_libs/torch_lib/ops/core.py | 104 +++++++++++++----- .../function_libs/torch_lib/e2e_ops_tests.py | 32 +++++- 2 files changed, 105 insertions(+), 31 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 26b7650f0a..96f64bbb8a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3810,8 +3810,12 @@ def aten_gru( for dir_idx in range(num_directions): dir_param_start = param_start_idx + dir_idx * params_per_direction - W_ih = params[dir_param_start] # [3*hidden_size, input_size] - PyTorch order: [r,z,n] - W_hh = params[dir_param_start + 1] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n] + W_ih = params[ + dir_param_start + ] # [3*hidden_size, input_size] - PyTorch order: [r,z,n] + W_hh = params[ + dir_param_start + 1 + ] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n] # Reorder gates from PyTorch [r,z,n] to ONNX [z,r,n] # Split into individual gates @@ -3824,12 +3828,18 @@ def aten_gru( W_hn = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) # Reorder: [z,r,n] - W_ih_reordered = op.Concat(W_iz, W_ir, W_in, axis=0) # [3*hidden_size, input_size] - ONNX order - W_hh_reordered = op.Concat(W_hz, W_hr, W_hn, axis=0) # [3*hidden_size, hidden_size] - ONNX order + W_ih_reordered = op.Concat( + W_iz, W_ir, W_in, axis=0 + ) # [3*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat( + W_hz, W_hr, W_hn, axis=0 + ) # [3*hidden_size, hidden_size] - ONNX order # Add direction dimension W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 3*hidden_size, input_size] - W_hh_expanded = op.Unsqueeze(W_hh_reordered, [0]) # [1, 3*hidden_size, hidden_size] + W_hh_expanded = op.Unsqueeze( + W_hh_reordered, [0] + ) # [1, 3*hidden_size, hidden_size] W_list.append(W_ih_expanded) R_list.append(W_hh_expanded) @@ -3848,18 +3858,28 @@ def aten_gru( b_hn = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) # Reorder: [z,r,n] - b_ih_reordered = op.Concat(b_iz, b_ir, b_in, axis=0) # [3*hidden_size] - ONNX order - b_hh_reordered = op.Concat(b_hz, b_hr, b_hn, axis=0) # [3*hidden_size] - ONNX order + b_ih_reordered = op.Concat( + b_iz, b_ir, b_in, axis=0 + ) # [3*hidden_size] - ONNX order + b_hh_reordered = op.Concat( + b_hz, b_hr, b_hn, axis=0 + ) # [3*hidden_size] - ONNX order # ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]] - b_combined = op.Concat(b_ih_reordered, b_hh_reordered, axis=0) # [6*hidden_size] + b_combined = op.Concat( + b_ih_reordered, b_hh_reordered, axis=0 + ) # [6*hidden_size] b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 6*hidden_size] B_list.append(b_expanded) # Concatenate weights for all directions W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] - B = op.Concat(*B_list, axis=0) if has_biases and len(B_list) > 1 else (B_list[0] if has_biases else None) + B = ( + op.Concat(*B_list, axis=0) + if has_biases and len(B_list) > 1 + else (B_list[0] if has_biases else None) + ) # Call ONNX GRU operator direction = "bidirectional" if bidirectional else "forward" @@ -3889,7 +3909,9 @@ def aten_gru( # Y shape: [seq_length, num_directions, batch_size, hidden_size] # Reshape to [seq_length, batch_size, num_directions * hidden_size] - Y = op.Transpose(Y, perm=[0, 2, 1, 3]) # [seq_length, batch_size, num_directions, hidden_size] + Y = op.Transpose( + Y, perm=[0, 2, 1, 3] + ) # 
[seq_length, batch_size, num_directions, hidden_size] Y_shape = op.Shape(Y) new_shape = op.Concat( op.Slice(Y_shape, [0], [1]), # seq_length @@ -3913,7 +3935,9 @@ def aten_gru( output_h_list.append(Y_h) # Concatenate all layer outputs - final_h = output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + final_h = ( + output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + ) # Handle batch_first for output if batch_first: @@ -5208,8 +5232,12 @@ def aten_lstm( for dir_idx in range(num_directions): dir_param_start = param_start_idx + dir_idx * params_per_direction - W_ih = params[dir_param_start] # [4*hidden_size, input_size] - PyTorch order: [i,f,g,o] - W_hh = params[dir_param_start + 1] # [4*hidden_size, hidden_size] - PyTorch order: [i,f,g,o] + W_ih = params[ + dir_param_start + ] # [4*hidden_size, input_size] - PyTorch order: [i,f,g,o] + W_hh = params[ + dir_param_start + 1 + ] # [4*hidden_size, hidden_size] - PyTorch order: [i,f,g,o] # Reorder gates from PyTorch [i,f,g,o] to ONNX [i,o,f,g] # Split into individual gates @@ -5224,19 +5252,29 @@ def aten_lstm( W_ho = op.Slice(W_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) # Reorder: [i,o,f,g] - W_ih_reordered = op.Concat(W_ii, W_io, W_if, W_ig, axis=0) # [4*hidden_size, input_size] - ONNX order - W_hh_reordered = op.Concat(W_hi, W_ho, W_hf, W_hg, axis=0) # [4*hidden_size, hidden_size] - ONNX order + W_ih_reordered = op.Concat( + W_ii, W_io, W_if, W_ig, axis=0 + ) # [4*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat( + W_hi, W_ho, W_hf, W_hg, axis=0 + ) # [4*hidden_size, hidden_size] - ONNX order # Add direction dimension W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 4*hidden_size, input_size] - W_hh_expanded = op.Unsqueeze(W_hh_reordered, [0]) # [1, 4*hidden_size, hidden_size] + W_hh_expanded = op.Unsqueeze( + W_hh_reordered, [0] + ) # [1, 4*hidden_size, hidden_size] W_list.append(W_ih_expanded) R_list.append(W_hh_expanded) if has_biases: - b_ih = params[dir_param_start + 2] # [4*hidden_size] - PyTorch order: [i,f,g,o] - b_hh = params[dir_param_start + 3] # [4*hidden_size] - PyTorch order: [i,f,g,o] + b_ih = params[ + dir_param_start + 2 + ] # [4*hidden_size] - PyTorch order: [i,f,g,o] + b_hh = params[ + dir_param_start + 3 + ] # [4*hidden_size] - PyTorch order: [i,f,g,o] # Reorder biases from PyTorch [i,f,g,o] to ONNX [i,o,f,g] b_ii = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) @@ -5250,18 +5288,28 @@ def aten_lstm( b_ho = op.Slice(b_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) # Reorder: [i,o,f,g] - b_ih_reordered = op.Concat(b_ii, b_io, b_if, b_ig, axis=0) # [4*hidden_size] - ONNX order - b_hh_reordered = op.Concat(b_hi, b_ho, b_hf, b_hg, axis=0) # [4*hidden_size] - ONNX order + b_ih_reordered = op.Concat( + b_ii, b_io, b_if, b_ig, axis=0 + ) # [4*hidden_size] - ONNX order + b_hh_reordered = op.Concat( + b_hi, b_ho, b_hf, b_hg, axis=0 + ) # [4*hidden_size] - ONNX order # ONNX expects biases concatenated: [Wb[iofc], Rb[iofc]] - b_combined = op.Concat(b_ih_reordered, b_hh_reordered, axis=0) # [8*hidden_size] + b_combined = op.Concat( + b_ih_reordered, b_hh_reordered, axis=0 + ) # [8*hidden_size] b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 8*hidden_size] B_list.append(b_expanded) # Concatenate weights for all directions W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] - B = op.Concat(*B_list, axis=0) if has_biases and len(B_list) > 
1 else (B_list[0] if has_biases else None) + B = ( + op.Concat(*B_list, axis=0) + if has_biases and len(B_list) > 1 + else (B_list[0] if has_biases else None) + ) # Call ONNX LSTM operator direction = "bidirectional" if bidirectional else "forward" @@ -5293,7 +5341,9 @@ def aten_lstm( # Y shape: [seq_length, num_directions, batch_size, hidden_size] # Reshape to [seq_length, batch_size, num_directions * hidden_size] - Y = op.Transpose(Y, perm=[0, 2, 1, 3]) # [seq_length, batch_size, num_directions, hidden_size] + Y = op.Transpose( + Y, perm=[0, 2, 1, 3] + ) # [seq_length, batch_size, num_directions, hidden_size] Y_shape = op.Shape(Y) new_shape = op.Concat( op.Slice(Y_shape, [0], [1]), # seq_length @@ -5318,8 +5368,12 @@ def aten_lstm( output_c_list.append(Y_c) # Concatenate all layer outputs - final_h = output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) - final_c = output_c_list[0] if len(output_c_list) == 1 else op.Concat(*output_c_list, axis=0) + final_h = ( + output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + ) + final_c = ( + output_c_list[0] if len(output_c_list) == 1 else op.Concat(*output_c_list, axis=0) + ) # Handle batch_first for output if batch_first: diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 21970f2ecd..f74dda699d 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -306,7 +306,9 @@ def test_lstm_unidirectional(self): class LSTMModel(torch.nn.Module): def __init__(self): super().__init__() - self.lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=1, batch_first=True) + self.lstm = torch.nn.LSTM( + input_size=10, hidden_size=20, num_layers=1, batch_first=True + ) def forward(self, x): return self.lstm(x) @@ -320,7 +322,13 @@ def test_lstm_bidirectional(self): class LSTMModel(torch.nn.Module): def __init__(self): super().__init__() - self.lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=1, batch_first=True, bidirectional=True) + self.lstm = torch.nn.LSTM( + input_size=10, + hidden_size=20, + num_layers=1, + batch_first=True, + bidirectional=True, + ) def forward(self, x): return self.lstm(x) @@ -334,7 +342,9 @@ def test_lstm_multilayer(self): class LSTMModel(torch.nn.Module): def __init__(self): super().__init__() - self.lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True) + self.lstm = torch.nn.LSTM( + input_size=10, hidden_size=20, num_layers=3, batch_first=True + ) def forward(self, x): return self.lstm(x) @@ -348,7 +358,9 @@ def test_gru_unidirectional(self): class GRUModel(torch.nn.Module): def __init__(self): super().__init__() - self.gru = torch.nn.GRU(input_size=10, hidden_size=20, num_layers=1, batch_first=True) + self.gru = torch.nn.GRU( + input_size=10, hidden_size=20, num_layers=1, batch_first=True + ) def forward(self, x): return self.gru(x) @@ -362,7 +374,13 @@ def test_gru_bidirectional(self): class GRUModel(torch.nn.Module): def __init__(self): super().__init__() - self.gru = torch.nn.GRU(input_size=10, hidden_size=20, num_layers=1, batch_first=True, bidirectional=True) + self.gru = torch.nn.GRU( + input_size=10, + hidden_size=20, + num_layers=1, + batch_first=True, + bidirectional=True, + ) def forward(self, x): return self.gru(x) @@ -376,7 +394,9 @@ def test_gru_multilayer(self): class GRUModel(torch.nn.Module): def __init__(self): super().__init__() - self.gru = torch.nn.GRU(input_size=10, hidden_size=20, 
num_layers=3, batch_first=True) + self.gru = torch.nn.GRU( + input_size=10, hidden_size=20, num_layers=3, batch_first=True + ) def forward(self, x): return self.gru(x)
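
Reviewer note (not part of the patches): a minimal standalone sketch of the gate reordering described in PATCH 1/8, using an arbitrary hidden_size and input width; the split/concat mirrors what the op.Slice/op.Concat calls do on the packed weight rows.

    import torch

    hidden_size = 4  # arbitrary, for illustration only

    # PyTorch packs LSTM W_ih as [i, f, g, o] blocks of hidden_size rows each;
    # ONNX LSTM expects the blocks ordered [i, o, f, g].
    w_ih = torch.randn(4 * hidden_size, 10)
    i, f, g, o = torch.split(w_ih, hidden_size, dim=0)
    w_ih_onnx = torch.cat([i, o, f, g], dim=0)

    # PyTorch packs GRU W_ih as [r, z, n]; ONNX GRU expects [z, r, n].
    w_ih_gru = torch.randn(3 * hidden_size, 10)
    r, z, n = torch.split(w_ih_gru, hidden_size, dim=0)
    w_ih_gru_onnx = torch.cat([z, r, n], dim=0)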