diff --git a/tests/torch4ms/__init__.py b/tests/torch4ms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/torch4ms/test_simple.py b/tests/torch4ms/test_simple.py new file mode 100644 index 000000000..44f65a0a0 --- /dev/null +++ b/tests/torch4ms/test_simple.py @@ -0,0 +1,16 @@ +import torch4ms +import torch +from torch4ms import MSDispatchMode, MSFunctionMode +import mindspore + +dispatch_mode = MSDispatchMode() +function_mode = MSFunctionMode() + +dispatch_mode.__enter__() +function_mode.__enter__() + +def test_add(): + x = torch.tensor(1) + y = torch.tensor(2) + z = x + y + print(z) diff --git a/torch4ms/__init__.py b/torch4ms/__init__.py index e69de29bb..0e34f418c 100644 --- a/torch4ms/__init__.py +++ b/torch4ms/__init__.py @@ -0,0 +1 @@ +from .tensor import * \ No newline at end of file diff --git a/torch4ms/_op_prim/__init__.py b/torch4ms/_op_prim/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torch4ms/_op_prim/ascend/__init__.py b/torch4ms/_op_prim/ascend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torch4ms/_op_prim/ascend/legacy.py b/torch4ms/_op_prim/ascend/legacy.py new file mode 100644 index 000000000..4c96f0144 --- /dev/null +++ b/torch4ms/_op_prim/ascend/legacy.py @@ -0,0 +1,3511 @@ +from mindspore.ops.operations import * +from mindspore.ops.operations._grad_ops import * +from mindspore.ops._primitive_cache import _get_cache_prim +from mindspore.ops.auto_generate.gen_ops_prim import MaxPoolWithIndices, MaxPoolWithMask + + +a_cos_grad_op = ACosGrad().set_device('Ascend') +def a_cos_grad(*args): + return a_cos_grad_op(*args) + + +abs_grad_op = AbsGrad().set_device('Ascend') +def abs_grad(*args): + return abs_grad_op(*args) + + +acosh_grad_op = AcoshGrad().set_device('Ascend') +def acosh_grad(*args): + return acosh_grad_op(*args) + + +adaptive_avg_pool2_d_grad_op = AdaptiveAvgPool2DGrad().set_device('Ascend') +def adaptive_avg_pool2_d_grad(*args): + return 
adaptive_avg_pool2_d_grad_op(*args) + + +adaptive_avg_pool3_d_grad_op = AdaptiveAvgPool3DGrad().set_device('Ascend') +def adaptive_avg_pool3_d_grad(*args): + return adaptive_avg_pool3_d_grad_op(*args) + + +adaptive_max_pool2_d_grad_op = AdaptiveMaxPool2DGrad().set_device('Ascend') +def adaptive_max_pool2_d_grad(*args): + return adaptive_max_pool2_d_grad_op(*args) + + +adaptive_max_pool3_d_grad_op = AdaptiveMaxPool3DGrad().set_device('Ascend') +def adaptive_max_pool3_d_grad(*args): + return adaptive_max_pool3_d_grad_op(*args) + + +def affine_grid_grad(*args): + op = _get_cache_prim(AffineGridGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +asin_grad_op = AsinGrad().set_device('Ascend') +def asin_grad(*args): + return asin_grad_op(*args) + + +asinh_grad_op = AsinhGrad().set_device('Ascend') +def asinh_grad(*args): + return asinh_grad_op(*args) + + +atan_grad_op = AtanGrad().set_device('Ascend') +def atan_grad(*args): + return atan_grad_op(*args) + + +def avg_pool3_d_grad(*args): + op = _get_cache_prim(AvgPool3DGrad)(*args[-8:]).set_device('Ascend') + return op(*args[:-8]) + + +def avg_pool_grad(*args): + op = _get_cache_prim(AvgPoolGrad)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def avg_pool_grad_ge(*args): + op = _get_cache_prim(AvgPoolGradGe)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def avg_pool_grad_v1(*args): + op = _get_cache_prim(AvgPoolGradV1)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def avg_pool_grad_vm(*args): + op = _get_cache_prim(AvgPoolGradVm)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +def bn_training_reduce_grad(*args): + op = _get_cache_prim(BNTrainingReduceGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def bn_training_update_grad(*args): + op = _get_cache_prim(BNTrainingUpdateGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def basic_lstm_cell_c_state_grad(*args): + op = 
_get_cache_prim(BasicLSTMCellCStateGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def basic_lstm_cell_input_grad(*args): + op = _get_cache_prim(BasicLSTMCellInputGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +basic_lstm_cell_weight_grad_op = BasicLSTMCellWeightGrad().set_device('Ascend') +def basic_lstm_cell_weight_grad(*args): + return basic_lstm_cell_weight_grad_op(*args) + + +def batch_norm_grad(*args): + op = _get_cache_prim(BatchNormGrad)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +def batch_norm_grad_grad(*args): + op = _get_cache_prim(BatchNormGradGrad)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +def bias_add_grad(*args): + op = _get_cache_prim(BiasAddGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def binary_cross_entropy_grad(*args): + op = _get_cache_prim(BinaryCrossEntropyGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +cholesky_grad_op = CholeskyGrad().set_device('Ascend') +def cholesky_grad(*args): + return cholesky_grad_op(*args) + + +def concat_offset(*args): + op = _get_cache_prim(ConcatOffset)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def conv2_d_backprop_filter(*args): + op = _get_cache_prim(Conv2DBackpropFilter)(*args[-10:]).set_device('Ascend') + return op(*args[:-10]) + + +def conv3_d_backprop_filter(*args): + op = _get_cache_prim(Conv3DBackpropFilter)(*args[-9:]).set_device('Ascend') + return op(*args[:-9]) + + +def deformable_offsets_grad(*args): + op = _get_cache_prim(DeformableOffsetsGrad)(*args[-7:]).set_device('Ascend') + return op(*args[:-7]) + + +def depthwise_conv2d_native_backprop_filter(*args): + op = _get_cache_prim(DepthwiseConv2dNativeBackpropFilter)(*args[-9:]).set_device('Ascend') + return op(*args[:-9]) + + +def depthwise_conv2d_native_backprop_input(*args): + op = _get_cache_prim(DepthwiseConv2dNativeBackpropInput)(*args[-9:]).set_device('Ascend') + return op(*args[:-9]) + + +def 
dilation2_d_backprop_filter(*args): + op = _get_cache_prim(Dilation2DBackpropFilter)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def dilation2_d_backprop_input(*args): + op = _get_cache_prim(Dilation2DBackpropInput)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def dropout_grad(*args): + op = _get_cache_prim(DropoutGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def dynamic_gruv2_grad(*args): + op = _get_cache_prim(DynamicGRUV2Grad)(*args[-8:]).set_device('Ascend') + return op(*args[:-8]) + + +def dynamic_rnn_grad(*args): + op = _get_cache_prim(DynamicRNNGrad)(*args[-9:]).set_device('Ascend') + return op(*args[:-9]) + + +def einsum_grad(*args): + op = _get_cache_prim(EinsumGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +elu_grad_op = EluGrad().set_device('Ascend') +def elu_grad(*args): + return elu_grad_op(*args) + + +embedding_lookup_comm_grad_op = EmbeddingLookupCommGrad().set_device('Ascend') +def embedding_lookup_comm_grad(*args): + return embedding_lookup_comm_grad_op(*args) + + +fast_ge_lu_grad_op = FastGeLUGrad().set_device('Ascend') +def fast_ge_lu_grad(*args): + return fast_ge_lu_grad_op(*args) + + +def flash_attention_score_grad(*args): + op = _get_cache_prim(FlashAttentionScoreGrad)(*args[-8:]).set_device('Ascend') + return op(*args[:-8]) + + +flatten_grad_op = FlattenGrad().set_device('Ascend') +def flatten_grad(*args): + return flatten_grad_op(*args) + + +def fractional_avg_pool_grad(*args): + op = _get_cache_prim(FractionalAvgPoolGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def fractional_max_pool3_d_grad_with_fixed_ksize(*args): + op = _get_cache_prim(FractionalMaxPool3DGradWithFixedKsize)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def fractional_max_pool_grad(*args): + op = _get_cache_prim(FractionalMaxPoolGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def fractional_max_pool_grad_with_fixed_ksize(*args): + op = 
_get_cache_prim(FractionalMaxPoolGradWithFixedKsize)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def gruv2_grad(*args): + op = _get_cache_prim(GRUV2Grad)(*args[-6:]).set_device('Ascend') + return op(*args[:-6]) + + +gather_d_grad_v2_op = GatherDGradV2().set_device('Ascend') +def gather_d_grad_v2(*args): + return gather_d_grad_v2_op(*args) + + +ge_lu_grad_op = GeLUGrad().set_device('Ascend') +def ge_lu_grad(*args): + return ge_lu_grad_op(*args) + + +def global_comm(*args): + op = _get_cache_prim(GlobalComm)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def glu_grad(*args): + op = _get_cache_prim(GluGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def grid_sampler2_d_grad(*args): + op = _get_cache_prim(GridSampler2DGrad)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def grid_sampler3_d_grad(*args): + op = _get_cache_prim(GridSampler3DGrad)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def gru_grad_data(*args): + op = _get_cache_prim(GruGradData)(*args[-6:]).set_device('Ascend') + return op(*args[:-6]) + + +def gru_grad_weight(*args): + op = _get_cache_prim(GruGradWeight)(*args[-6:]).set_device('Ascend') + return op(*args[:-6]) + + +def h_shrink_grad(*args): + op = _get_cache_prim(HShrinkGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +h_sigmoid_grad_op = HSigmoidGrad().set_device('Ascend') +def h_sigmoid_grad(*args): + return h_sigmoid_grad_op(*args) + + +h_swish_grad_op = HSwishGrad().set_device('Ascend') +def h_swish_grad(*args): + return h_swish_grad_op(*args) + + +igamma_grad_a_op = IgammaGradA().set_device('Ascend') +def igamma_grad_a(*args): + return igamma_grad_a_op(*args) + + +def instance_norm_grad(*args): + op = _get_cache_prim(InstanceNormGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def instance_norm_v2_grad(*args): + op = _get_cache_prim(InstanceNormV2Grad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +inv_grad_op 
= InvGrad().set_device('Ascend') +def inv_grad(*args): + return inv_grad_op(*args) + + +def kl_div_loss_grad(*args): + op = _get_cache_prim(KLDivLossGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def l2_normalize_grad(*args): + op = _get_cache_prim(L2NormalizeGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def lrn_grad(*args): + op = _get_cache_prim(LRNGrad)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def lstm_grad(*args): + op = _get_cache_prim(LSTMGrad)(*args[-7:]).set_device('Ascend') + return op(*args[:-7]) + + +def lstm_grad_data(*args): + op = _get_cache_prim(LSTMGradData)(*args[-6:]).set_device('Ascend') + return op(*args[:-6]) + + +def lstm_grad_weight(*args): + op = _get_cache_prim(LSTMGradWeight)(*args[-6:]).set_device('Ascend') + return op(*args[:-6]) + + +def layer_norm_grad(*args): + op = _get_cache_prim(LayerNormGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def layer_norm_grad_grad(*args): + op = _get_cache_prim(LayerNormGradGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def log_softmax_grad(*args): + op = _get_cache_prim(LogSoftmaxGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def logit_grad(*args): + op = _get_cache_prim(LogitGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def lu_unpack_grad(*args): + op = _get_cache_prim(LuUnpackGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +map_tensor_get_grad_op = MapTensorGetGrad().set_device('Ascend') +def map_tensor_get_grad(*args): + return map_tensor_get_grad_op(*args) + + +masked_select_grad_op = MaskedSelectGrad().set_device('Ascend') +def masked_select_grad(*args): + return masked_select_grad_op(*args) + + +def max_pool3_d_grad(*args): + op = _get_cache_prim(MaxPool3DGrad)(*args[-5:]).set_device('Ascend') + return op(*args[:-5]) + + +def max_pool3_d_grad_grad(*args): + op = _get_cache_prim(MaxPool3DGradGrad)(*args[-4:]).set_device('Ascend') + 
return op(*args[:-4]) + + +def max_pool3_d_grad_with_argmax(*args): + op = _get_cache_prim(MaxPool3DGradWithArgmax)(*args[-6:]).set_device('Ascend') + return op(*args[:-6]) + + +def max_pool_grad(*args): + op = _get_cache_prim(MaxPoolGrad)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def max_pool_grad_grad(*args): + op = _get_cache_prim(MaxPoolGradGrad)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +def max_pool_grad_grad_with_argmax(*args): + op = _get_cache_prim(MaxPoolGradGradWithArgmax)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +def max_pool_grad_v1(*args): + op = _get_cache_prim(MaxPoolGradV1)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def max_pool_grad_with_argmax(*args): + op = _get_cache_prim(MaxPoolGradWithArgmax)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def max_pool_grad_with_argmax_v2(*args): + op = _get_cache_prim(MaxPoolGradWithArgmaxV2)(*args[-6:]).set_device('Ascend') + return op(*args[:-6]) + + +def max_unpool2_d_grad(*args): + op = _get_cache_prim(MaxUnpool2DGrad)(*args[-5:]).set_device('Ascend') + return op(*args[:-5]) + + +def max_unpool3_d_grad(*args): + op = _get_cache_prim(MaxUnpool3DGrad)(*args[-5:]).set_device('Ascend') + return op(*args[:-5]) + + +def maximum_grad(*args): + op = _get_cache_prim(MaximumGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def maximum_grad_grad(*args): + op = _get_cache_prim(MaximumGradGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def median_grad(*args): + op = _get_cache_prim(MedianGrad)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +def minimum_grad(*args): + op = _get_cache_prim(MinimumGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +minimum_grad_grad_op = MinimumGradGrad().set_device('Ascend') +def minimum_grad_grad(*args): + return minimum_grad_grad_op(*args) + + +def mirror_pad_grad(*args): + op = 
_get_cache_prim(MirrorPadGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def multi_margin_loss_grad(*args): + op = _get_cache_prim(MultiMarginLossGrad)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +def multilabel_margin_loss_grad(*args): + op = _get_cache_prim(MultilabelMarginLossGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def mvlgamma_grad(*args): + op = _get_cache_prim(MvlgammaGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def nll_loss_grad(*args): + op = _get_cache_prim(NLLLossGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def neighbor_exchange_v2_grad(*args): + op = _get_cache_prim(NeighborExchangeV2Grad)(*args[-6:]).set_device('Ascend') + return op(*args[:-6]) + + +p_re_lu_grad_op = PReLUGrad().set_device('Ascend') +def p_re_lu_grad(*args): + return p_re_lu_grad_op(*args) + + +def psroi_pooling_grad(*args): + op = _get_cache_prim(PSROIPoolingGrad)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def pad_v3_grad(*args): + op = _get_cache_prim(PadV3Grad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def parallel_resize_bilinear_grad(*args): + op = _get_cache_prim(ParallelResizeBilinearGrad)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def pdist_grad(*args): + op = _get_cache_prim(PdistGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def primitive(*args): + op = _get_cache_prim(Primitive)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def primitive_with_infer(*args): + op = _get_cache_prim(PrimitiveWithInfer)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def ps_roi_pooling_grad(*args): + op = _get_cache_prim(PsROIPoolingGrad)(*args[-9:]).set_device('Ascend') + return op(*args[:-9]) + + +def roi_align_grad(*args): + op = _get_cache_prim(ROIAlignGrad)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +random_gamma_grad_op = 
RandomGammaGrad().set_device('Ascend') +def random_gamma_grad(*args): + return random_gamma_grad_op(*args) + + +re_lu6_grad_op = ReLU6Grad().set_device('Ascend') +def re_lu6_grad(*args): + return re_lu6_grad_op(*args) + + +reciprocal_grad_op = ReciprocalGrad().set_device('Ascend') +def reciprocal_grad(*args): + return reciprocal_grad_op(*args) + + +ref_to_embed_op = RefToEmbed().set_device('Ascend') +def ref_to_embed(*args): + return ref_to_embed_op(*args) + + +relu_grad_op = ReluGrad().set_device('Ascend') +def relu_grad(*args): + return relu_grad_op(*args) + + +def resize_bicubic_grad(*args): + op = _get_cache_prim(ResizeBicubicGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def resize_bilinear_grad(*args): + op = _get_cache_prim(ResizeBilinearGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def resize_linear1_d_grad(*args): + op = _get_cache_prim(ResizeLinear1DGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def resize_nearest_neighbor_grad(*args): + op = _get_cache_prim(ResizeNearestNeighborGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def resize_nearest_neighbor_v2_grad(*args): + op = _get_cache_prim(ResizeNearestNeighborV2Grad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def resize_v2_grad(*args): + op = _get_cache_prim(ResizeV2Grad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +rms_norm_grad_op = RmsNormGrad().set_device('Ascend') +def rms_norm_grad(*args): + return rms_norm_grad_op(*args) + + +rsqrt_grad_op = RsqrtGrad().set_device('Ascend') +def rsqrt_grad(*args): + return rsqrt_grad_op(*args) + + +def scale_and_translate_grad(*args): + op = _get_cache_prim(ScaleAndTranslateGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +selu_grad_op = SeluGrad().set_device('Ascend') +def selu_grad(*args): + return selu_grad_op(*args) + + +si_lu_grad_op = SiLUGrad().set_device('Ascend') +def si_lu_grad(*args): + return si_lu_grad_op(*args) + 
+ +sigmoid_cross_entropy_with_logits_grad_op = SigmoidCrossEntropyWithLogitsGrad().set_device('Ascend') +def sigmoid_cross_entropy_with_logits_grad(*args): + return sigmoid_cross_entropy_with_logits_grad_op(*args) + + +sigmoid_grad_op = SigmoidGrad().set_device('Ascend') +def sigmoid_grad(*args): + return sigmoid_grad_op(*args) + + +slice_grad_op = SliceGrad().set_device('Ascend') +def slice_grad(*args): + return slice_grad_op(*args) + + +def smooth_l1_loss_grad(*args): + op = _get_cache_prim(SmoothL1LossGrad)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def soft_margin_loss_grad(*args): + op = _get_cache_prim(SoftMarginLossGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def soft_shrink_grad(*args): + op = _get_cache_prim(SoftShrinkGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +softmax_grad_op = SoftmaxGrad().set_device('Ascend') +def softmax_grad(*args): + return softmax_grad_op(*args) + + +softplus_grad_op = SoftplusGrad().set_device('Ascend') +def softplus_grad(*args): + return softplus_grad_op(*args) + + +sparse_fill_empty_rows_grad_op = SparseFillEmptyRowsGrad().set_device('Ascend') +def sparse_fill_empty_rows_grad(*args): + return sparse_fill_empty_rows_grad_op(*args) + + +sparse_segment_mean_grad_op = SparseSegmentMeanGrad().set_device('Ascend') +def sparse_segment_mean_grad(*args): + return sparse_segment_mean_grad_op(*args) + + +sparse_segment_sqrt_n_grad_op = SparseSegmentSqrtNGrad().set_device('Ascend') +def sparse_segment_sqrt_n_grad(*args): + return sparse_segment_sqrt_n_grad_op(*args) + + +sparse_segment_sum_grad_op = SparseSegmentSumGrad().set_device('Ascend') +def sparse_segment_sum_grad(*args): + return sparse_segment_sum_grad_op(*args) + + +sparse_slice_grad_op = SparseSliceGrad().set_device('Ascend') +def sparse_slice_grad(*args): + return sparse_slice_grad_op(*args) + + +sqrt_grad_op = SqrtGrad().set_device('Ascend') +def sqrt_grad(*args): + return sqrt_grad_op(*args) + + +def 
strided_slice_grad(*args): + op = _get_cache_prim(StridedSliceGrad)(*args[-5:]).set_device('Ascend') + return op(*args[:-5]) + + +def sync_batch_norm_grad(*args): + op = _get_cache_prim(SyncBatchNormGrad)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +tanh_grad_op = TanhGrad().set_device('Ascend') +def tanh_grad(*args): + return tanh_grad_op(*args) + + +trace_grad_op = TraceGrad().set_device('Ascend') +def trace_grad(*args): + return trace_grad_op(*args) + + +unique_grad_op = UniqueGrad().set_device('Ascend') +def unique_grad(*args): + return unique_grad_op(*args) + + +upsample_nearest3_d_grad_op = UpsampleNearest3DGrad().set_device('Ascend') +def upsample_nearest3_d_grad(*args): + return upsample_nearest3_d_grad_op(*args) + + +def upsample_trilinear3_d_grad(*args): + op = _get_cache_prim(UpsampleTrilinear3DGrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +wkv_grad_op = WKVGrad().set_device('Ascend') +def wkv_grad(*args): + return wkv_grad_op(*args) + + +a_cos_op = ACos().set_device('Ascend') +def a_cos(*args): + return a_cos_op(*args) + + +abs_op = Abs().set_device('Ascend') +def abs(*args): + return abs_op(*args) + + +accumulate_nv2_op = AccumulateNV2().set_device('Ascend') +def accumulate_nv2(*args): + return accumulate_nv2_op(*args) + + +acosh_op = Acosh().set_device('Ascend') +def acosh(*args): + return acosh_op(*args) + + +def adam(*args): + op = _get_cache_prim(Adam)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def adam_no_update_param(*args): + op = _get_cache_prim(AdamNoUpdateParam)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def adam_weight_decay(*args): + op = _get_cache_prim(AdamWeightDecay)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def adaptive_avg_pool2_d(*args): + op = _get_cache_prim(AdaptiveAvgPool2D)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def adaptive_avg_pool3_d(*args): + op = 
_get_cache_prim(AdaptiveAvgPool3D)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def adaptive_max_pool2_d(*args): + op = _get_cache_prim(AdaptiveMaxPool2D)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +adaptive_max_pool3_d_op = AdaptiveMaxPool3D().set_device('Ascend') +def adaptive_max_pool3_d(*args): + return adaptive_max_pool3_d_op(*args) + + +add_op = Add().set_device('Ascend') +def add(*args): + return add_op(*args) + + +add_n_op = AddN().set_device('Ascend') +def add_n(*args): + return add_n_op(*args) + + +addcdiv_op = Addcdiv().set_device('Ascend') +def addcdiv(*args): + return addcdiv_op(*args) + + +addcmul_op = Addcmul().set_device('Ascend') +def addcmul(*args): + return addcmul_op(*args) + + +adjust_hue_op = AdjustHue().set_device('Ascend') +def adjust_hue(*args): + return adjust_hue_op(*args) + + +adjust_saturation_op = AdjustSaturation().set_device('Ascend') +def adjust_saturation(*args): + return adjust_saturation_op(*args) + + +def affine_grid(*args): + op = _get_cache_prim(AffineGrid)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def all_gather(*args): + op = _get_cache_prim(AllGather)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def all_reduce(*args): + op = _get_cache_prim(AllReduce)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def allto_all(*args): + op = _get_cache_prim(AlltoAll)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def allto_all_v(*args): + op = _get_cache_prim(AlltoAllV)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +angle_op = Angle().set_device('Ascend') +def angle(*args): + return angle_op(*args) + + +apply_ada_max_op = ApplyAdaMax().set_device('Ascend') +def apply_ada_max(*args): + return apply_ada_max_op(*args) + + +apply_adadelta_op = ApplyAdadelta().set_device('Ascend') +def apply_adadelta(*args): + return apply_adadelta_op(*args) + + +def apply_adagrad(*args): + op = 
_get_cache_prim(ApplyAdagrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def apply_adagrad_da(*args): + op = _get_cache_prim(ApplyAdagradDA)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def apply_adagrad_v2(*args): + op = _get_cache_prim(ApplyAdagradV2)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def apply_adam_with_amsgrad(*args): + op = _get_cache_prim(ApplyAdamWithAmsgrad)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def apply_adam_with_amsgrad_v2(*args): + op = _get_cache_prim(ApplyAdamWithAmsgradV2)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +apply_add_sign_op = ApplyAddSign().set_device('Ascend') +def apply_add_sign(*args): + return apply_add_sign_op(*args) + + +def apply_centered_rms_prop(*args): + op = _get_cache_prim(ApplyCenteredRMSProp)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def apply_ftrl(*args): + op = _get_cache_prim(ApplyFtrl)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +apply_gradient_descent_op = ApplyGradientDescent().set_device('Ascend') +def apply_gradient_descent(*args): + return apply_gradient_descent_op(*args) + + +def apply_keras_momentum(*args): + op = _get_cache_prim(ApplyKerasMomentum)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def apply_momentum(*args): + op = _get_cache_prim(ApplyMomentum)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +apply_power_sign_op = ApplyPowerSign().set_device('Ascend') +def apply_power_sign(*args): + return apply_power_sign_op(*args) + + +def apply_proximal_adagrad(*args): + op = _get_cache_prim(ApplyProximalAdagrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +apply_proximal_gradient_descent_op = ApplyProximalGradientDescent().set_device('Ascend') +def apply_proximal_gradient_descent(*args): + return apply_proximal_gradient_descent_op(*args) + + +def apply_rms_prop(*args): + op = 
_get_cache_prim(ApplyRMSProp)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def apply_rotary_pos_emb(*args): + op = _get_cache_prim(ApplyRotaryPosEmb)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def approximate_equal(*args): + op = _get_cache_prim(ApproximateEqual)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def arg_max_with_value(*args): + op = _get_cache_prim(ArgMaxWithValue)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def arg_min_with_value(*args): + op = _get_cache_prim(ArgMinWithValue)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def argmax(*args): + op = _get_cache_prim(Argmax)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def argmin(*args): + op = _get_cache_prim(Argmin)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +asin_op = Asin().set_device('Ascend') +def asin(*args): + return asin_op(*args) + + +asinh_op = Asinh().set_device('Ascend') +def asinh(*args): + return asinh_op(*args) + + +assign_op = Assign().set_device('Ascend') +def assign(*args): + return assign_op(*args) + + +assign_add_op = AssignAdd().set_device('Ascend') +def assign_add(*args): + return assign_add_op(*args) + + +assign_sub_op = AssignSub().set_device('Ascend') +def assign_sub(*args): + return assign_sub_op(*args) + + +atan_op = Atan().set_device('Ascend') +def atan(*args): + return atan_op(*args) + + +atan2_op = Atan2().set_device('Ascend') +def atan2(*args): + return atan2_op(*args) + + +atanh_op = Atanh().set_device('Ascend') +def atanh(*args): + return atanh_op(*args) + + +def avg_pool(*args): + op = _get_cache_prim(AvgPool)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def avg_pool3_d(*args): + op = _get_cache_prim(AvgPool3D)(*args[-8:]).set_device('Ascend') + return op(*args[:-8]) + + +def bce_with_logits_loss(*args): + op = _get_cache_prim(BCEWithLogitsLoss)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def barrier(*args): + op = 
_get_cache_prim(Barrier)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def bartlett_window(*args): + op = _get_cache_prim(BartlettWindow)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def basic_lstm_cell(*args): + op = _get_cache_prim(BasicLSTMCell)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def batch_i_send_i_recv(*args): + op = _get_cache_prim(BatchISendIRecv)(*args[-5:]).set_device('Ascend') + return op(*args[:-5]) + + +def batch_mat_mul(*args): + op = _get_cache_prim(BatchMatMul)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def batch_norm(*args): + op = _get_cache_prim(BatchNorm)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def batch_to_space(*args): + op = _get_cache_prim(BatchToSpace)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def batch_to_space_nd(*args): + op = _get_cache_prim(BatchToSpaceND)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +batch_to_space_ndv2_op = BatchToSpaceNDV2().set_device('Ascend') +def batch_to_space_ndv2(*args): + return batch_to_space_ndv2_op(*args) + + +def bernoulli(*args): + op = _get_cache_prim(Bernoulli)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +bessel_i0_op = BesselI0().set_device('Ascend') +def bessel_i0(*args): + return bessel_i0_op(*args) + + +bessel_i0e_op = BesselI0e().set_device('Ascend') +def bessel_i0e(*args): + return bessel_i0e_op(*args) + + +bessel_i1_op = BesselI1().set_device('Ascend') +def bessel_i1(*args): + return bessel_i1_op(*args) + + +bessel_i1e_op = BesselI1e().set_device('Ascend') +def bessel_i1e(*args): + return bessel_i1e_op(*args) + + +bessel_j0_op = BesselJ0().set_device('Ascend') +def bessel_j0(*args): + return bessel_j0_op(*args) + + +bessel_j1_op = BesselJ1().set_device('Ascend') +def bessel_j1(*args): + return bessel_j1_op(*args) + + +bessel_k0_op = BesselK0().set_device('Ascend') +def bessel_k0(*args): + return bessel_k0_op(*args) + + +bessel_k0e_op = 
BesselK0e().set_device('Ascend')  # tail of an assignment begun before this chunk


# Wrapper shapes used below (every op is pinned with set_device('Ascend')):
#   * `<name>_op = Prim().set_device('Ascend')` creates one module-level op at
#     import time and `<name>(*args)` forwards all of its arguments to it.
#   * `<name>(*args)` with slices hands the last N arguments to
#     `_get_cache_prim(Prim)(...)` and calls the resulting op with the
#     remaining leading arguments.


def bessel_k0e(*args):
    """Forward all positional args to ``bessel_k0e_op``."""
    return bessel_k0e_op(*args)


bessel_k1_op = BesselK1().set_device('Ascend')
def bessel_k1(*args):
    """Forward all positional args to ``bessel_k1_op``."""
    return bessel_k1_op(*args)


bessel_k1e_op = BesselK1e().set_device('Ascend')
def bessel_k1e(*args):
    """Forward all positional args to ``bessel_k1e_op``."""
    return bessel_k1e_op(*args)


bessel_y0_op = BesselY0().set_device('Ascend')
def bessel_y0(*args):
    """Forward all positional args to ``bessel_y0_op``."""
    return bessel_y0_op(*args)


bessel_y1_op = BesselY1().set_device('Ascend')
def bessel_y1(*args):
    """Forward all positional args to ``bessel_y1_op``."""
    return bessel_y1_op(*args)


betainc_op = Betainc().set_device('Ascend')
def betainc(*args):
    """Forward all positional args to ``betainc_op``."""
    return betainc_op(*args)


def bias_add(*args):
    """Call ``BiasAdd``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(BiasAdd)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def binary_cross_entropy(*args):
    """Call ``BinaryCrossEntropy``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(BinaryCrossEntropy)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


bincount_op = Bincount().set_device('Ascend')
def bincount(*args):
    """Forward all positional args to ``bincount_op``."""
    return bincount_op(*args)


bitwise_and_op = BitwiseAnd().set_device('Ascend')
def bitwise_and(*args):
    """Forward all positional args to ``bitwise_and_op``."""
    return bitwise_and_op(*args)


bitwise_or_op = BitwiseOr().set_device('Ascend')
def bitwise_or(*args):
    """Forward all positional args to ``bitwise_or_op``."""
    return bitwise_or_op(*args)


bitwise_xor_op = BitwiseXor().set_device('Ascend')
def bitwise_xor(*args):
    """Forward all positional args to ``bitwise_xor_op``."""
    return bitwise_xor_op(*args)


def blackman_window(*args):
    """Call ``BlackmanWindow``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(BlackmanWindow)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def bounding_box_decode(*args):
    """Call ``BoundingBoxDecode``: last 4 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(BoundingBoxDecode)(*args[-4:])
    return prim.set_device('Ascend')(*args[:-4])


def bounding_box_encode(*args):
    """Call ``BoundingBoxEncode``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(BoundingBoxEncode)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def broadcast(*args):
    """Call ``Broadcast``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(Broadcast)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def broadcast_to(*args):
    """Call ``BroadcastTo``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(BroadcastTo)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def bucketize(*args):
    """Call ``Bucketize``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(Bucketize)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def buffer_append(*args):
    """Call ``BufferAppend``: last 3 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(BufferAppend)(*args[-3:])
    return prim.set_device('Ascend')(*args[:-3])


def buffer_get_item(*args):
    """Call ``BufferGetItem``: last 3 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(BufferGetItem)(*args[-3:])
    return prim.set_device('Ascend')(*args[:-3])


def buffer_sample(*args):
    """Call ``BufferSample``: last 6 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(BufferSample)(*args[-6:])
    return prim.set_device('Ascend')(*args[:-6])


def ctc_greedy_decoder(*args):
    """Call ``CTCGreedyDecoder``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(CTCGreedyDecoder)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def ctc_loss(*args):
    """Call ``CTCLoss``: last 3 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(CTCLoss)(*args[-3:])
    return prim.set_device('Ascend')(*args[:-3])


def ctc_loss_v2(*args):
    """Call ``CTCLossV2``: last 3 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(CTCLossV2)(*args[-3:])
    return prim.set_device('Ascend')(*args[:-3])


cast_op = Cast().set_device('Ascend')
def cast(*args):
    """Forward all positional args to ``cast_op``."""
    return cast_op(*args)


def cauchy(*args):
    """Call ``Cauchy``: last 3 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(Cauchy)(*args[-3:])
    return prim.set_device('Ascend')(*args[:-3])


def cdist(*args):
    """Call ``Cdist``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(Cdist)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def ce_lu(*args):
    """Call ``CeLU``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(CeLU)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


ceil_op = Ceil().set_device('Ascend')
def ceil(*args):
    """Forward all positional args to ``ceil_op``."""
    return ceil_op(*args)


def channel_shuffle(*args):
    """Call ``ChannelShuffle``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(ChannelShuffle)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


check_numerics_op = CheckNumerics().set_device('Ascend')
def check_numerics(*args):
    """Forward all positional args to ``check_numerics_op``."""
    return check_numerics_op(*args)


check_valid_op = CheckValid().set_device('Ascend')
def check_valid(*args):
    """Forward all positional args to ``check_valid_op``."""
    return check_valid_op(*args)


def cholesky(*args):
    """Call ``Cholesky``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(Cholesky)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def cholesky_inverse(*args):
    """Call ``CholeskyInverse``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(CholeskyInverse)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def cholesky_solve(*args):
    """Call ``CholeskySolve``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(CholeskySolve)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


coalesce_op = Coalesce().set_device('Ascend')
def coalesce(*args):
    """Forward all positional args to ``coalesce_op``."""
    return coalesce_op(*args)


def col2_im(*args):
    """Call ``Col2Im``: last 4 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(Col2Im)(*args[-4:])
    return prim.set_device('Ascend')(*args[:-4])


def collective_gather(*args):
    """Call ``CollectiveGather``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(CollectiveGather)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def collective_scatter(*args):
    """Call ``CollectiveScatter``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(CollectiveScatter)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def combined_non_max_suppression(*args):
    """Call ``CombinedNonMaxSuppression``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(CombinedNonMaxSuppression)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


compare_and_bitpack_op = CompareAndBitpack().set_device('Ascend')
def compare_and_bitpack(*args):
    """Forward all positional args to ``compare_and_bitpack_op``."""
    return compare_and_bitpack_op(*args)


# NOTE(review): this definition shadows the builtin `complex` within this module.
complex_op = Complex().set_device('Ascend')
def complex(*args):
    """Forward all positional args to ``complex_op``."""
    return complex_op(*args)


complex_abs_op = ComplexAbs().set_device('Ascend')
def complex_abs(*args):
    """Forward all positional args to ``complex_abs_op``."""
    return complex_abs_op(*args)


def compute_accidental_hits(*args):
    """Call ``ComputeAccidentalHits``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(ComputeAccidentalHits)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def concat(*args):
    """Call ``Concat``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(Concat)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def confusion_matrix(*args):
    """Call ``ConfusionMatrix``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(ConfusionMatrix)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


conj_op = Conj().set_device('Ascend')
def conj(*args):
    """Forward all positional args to ``conj_op``."""
    return conj_op(*args)


conjugate_transpose_op = ConjugateTranspose().set_device('Ascend')
def conjugate_transpose(*args):
    """Forward all positional args to ``conjugate_transpose_op``."""
    return conjugate_transpose_op(*args)


def conv2_d(*args):
    """Call ``Conv2D``: last 9 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(Conv2D)(*args[-9:])
    return prim.set_device('Ascend')(*args[:-9])


def conv2_d_backprop_input(*args):
    """Call ``Conv2DBackpropInput``: last 10 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(Conv2DBackpropInput)(*args[-10:])
    return prim.set_device('Ascend')(*args[:-10])


def conv2_d_transpose(*args):
    """Call ``Conv2DTranspose``: last 10 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(Conv2DTranspose)(*args[-10:])
    return prim.set_device('Ascend')(*args[:-10])


def conv3_d(*args):
    """Call ``Conv3D``: last 9 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(Conv3D)(*args[-9:])
    return prim.set_device('Ascend')(*args[:-9])


def conv3_d_transpose(*args):
    """Call ``Conv3DTranspose``: last 11 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(Conv3DTranspose)(*args[-11:])
    return prim.set_device('Ascend')(*args[:-11])


copy_with_slice_op = CopyWithSlice().set_device('Ascend')
def copy_with_slice(*args):
    """Forward all positional args to ``copy_with_slice_op``."""
    return copy_with_slice_op(*args)


cos_op = Cos().set_device('Ascend')
def cos(*args):
    """Forward all positional args to ``cos_op``."""
    return cos_op(*args)


cosh_op = Cosh().set_device('Ascend')
def cosh(*args):
    """Forward all positional args to ``cosh_op``."""
    return cosh_op(*args)


def count_non_zero(*args):
    """Call ``CountNonZero``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(CountNonZero)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def crop_and_resize(*args):
    """Call ``CropAndResize``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(CropAndResize)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def cross(*args):
    """Call ``Cross``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(Cross)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def cum_prod(*args):
    """Call ``CumProd``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(CumProd)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def cum_sum(*args):
    """Call ``CumSum``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(CumSum)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def cummax(*args):
    """Call ``Cummax``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(Cummax)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def cummin(*args):
    """Call ``Cummin``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(Cummin)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def cumulative_logsumexp(*args):
    """Call ``CumulativeLogsumexp``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(CumulativeLogsumexp)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


d_type_op = DType().set_device('Ascend')
def d_type(*args):
    """Forward all positional args to ``d_type_op``."""
    return d_type_op(*args)


def data_format_dim_map(*args):
    """Call ``DataFormatDimMap``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(DataFormatDimMap)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def data_format_vec_permute(*args):
    """Call ``DataFormatVecPermute``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(DataFormatVecPermute)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def deformable_offsets(*args):
    """Call ``DeformableOffsets``: last 7 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(DeformableOffsets)(*args[-7:])
    return prim.set_device('Ascend')(*args[:-7])


dense_op = Dense().set_device('Ascend')
def dense(*args):
    """Forward all positional args to ``dense_op``."""
    return dense_op(*args)


depend_op = Depend().set_device('Ascend')
def depend(*args):
    """Forward all positional args to ``depend_op``."""
    return depend_op(*args)


def depth_to_space(*args):
    """Call ``DepthToSpace``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(DepthToSpace)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def depthwise_conv2d_native(*args):
    """Call ``DepthwiseConv2dNative``: last 8 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(DepthwiseConv2dNative)(*args[-8:])
    return prim.set_device('Ascend')(*args[:-8])


diag_op = Diag().set_device('Ascend')
def diag(*args):
    """Forward all positional args to ``diag_op``."""
    return diag_op(*args)


diag_part_op = DiagPart().set_device('Ascend')
def diag_part(*args):
    """Forward all positional args to ``diag_part_op``."""
    return diag_part_op(*args)


digamma_op = Digamma().set_device('Ascend')
def digamma(*args):
    """Forward all positional args to ``digamma_op``."""
    return digamma_op(*args)


def dilation2_d(*args):
    """Call ``Dilation2D``: last 4 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(Dilation2D)(*args[-4:])
    return prim.set_device('Ascend')(*args[:-4])


div_op = Div().set_device('Ascend')
def div(*args):
    """Forward all positional args to ``div_op``."""
    return div_op(*args)


div_no_nan_op = DivNoNan().set_device('Ascend')
def div_no_nan(*args):
    """Forward all positional args to ``div_no_nan_op``."""
    return div_no_nan_op(*args)


def dropout(*args):
    """Call ``Dropout``: last 3 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(Dropout)(*args[-3:])
    return prim.set_device('Ascend')(*args[:-3])


def dropout2_d(*args):
    """Call ``Dropout2D``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(Dropout2D)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def dropout3_d(*args):
    """Call ``Dropout3D``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(Dropout3D)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


def dropout_gen_mask(*args):
    """Call ``DropoutGenMask``: last 2 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(DropoutGenMask)(*args[-2:])
    return prim.set_device('Ascend')(*args[:-2])


def dynamic_gruv2(*args):
    """Call ``DynamicGRUV2``: last 10 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(DynamicGRUV2)(*args[-10:])
    return prim.set_device('Ascend')(*args[:-10])


def dynamic_rnn(*args):
    """Call ``DynamicRNN``: last 11 args to the op factory, the rest to the op."""
    prim = _get_cache_prim(DynamicRNN)(*args[-11:])
    return prim.set_device('Ascend')(*args[:-11])


def dynamic_shape(*args):
    """Call ``DynamicShape``: last 1 arg to the op factory, the rest to the op."""
    prim = _get_cache_prim(DynamicShape)(*args[-1:])
    return prim.set_device('Ascend')(*args[:-1])


# Ascend wrappers, `edit_distance` through `matrix_determinant`.
#
# Two shapes are used, and every op is pinned with set_device('Ascend'):
#   * `<name>_op = Prim().set_device('Ascend')` creates one module-level op at
#     import time; `<name>(*args)` forwards all of its arguments to that op.
#   * `<name>(*args)` with slices passes the last N arguments to
#     `_get_cache_prim(Prim)(...)` and calls the resulting op with the
#     remaining leading arguments. Callers must therefore always supply all N
#     trailing arguments; the slices do not validate the argument count.


def edit_distance(*args):
    op = _get_cache_prim(EditDistance)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def eig(*args):
    op = _get_cache_prim(Eig)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def einsum(*args):
    op = _get_cache_prim(Einsum)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def elu(*args):
    op = _get_cache_prim(Elu)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


embedding_lookup_op = EmbeddingLookup().set_device('Ascend')
def embedding_lookup(*args):
    return embedding_lookup_op(*args)


eps_op = Eps().set_device('Ascend')
def eps(*args):
    return eps_op(*args)


equal_op = Equal().set_device('Ascend')
def equal(*args):
    return equal_op(*args)


equal_count_op = EqualCount().set_device('Ascend')
def equal_count(*args):
    return equal_count_op(*args)


erf_op = Erf().set_device('Ascend')
def erf(*args):
    return erf_op(*args)


erfc_op = Erfc().set_device('Ascend')
def erfc(*args):
    return erfc_op(*args)


erfinv_op = Erfinv().set_device('Ascend')
def erfinv(*args):
    return erfinv_op(*args)


def euclidean_norm(*args):
    op = _get_cache_prim(EuclideanNorm)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


exp_op = Exp().set_device('Ascend')
def exp(*args):
    return exp_op(*args)


expand_dims_op = ExpandDims().set_device('Ascend')
def expand_dims(*args):
    return expand_dims_op(*args)


expm1_op = Expm1().set_device('Ascend')
def expm1(*args):
    return expm1_op(*args)


def extract_glimpse(*args):
    op = _get_cache_prim(ExtractGlimpse)(*args[-4:]).set_device('Ascend')
    return op(*args[:-4])


def extract_image_patches(*args):
    op = _get_cache_prim(ExtractImagePatches)(*args[-4:]).set_device('Ascend')
    return op(*args[:-4])


def extract_volume_patches(*args):
    op = _get_cache_prim(ExtractVolumePatches)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


eye_op = Eye().set_device('Ascend')
def eye(*args):
    return eye_op(*args)


def fft_with_size(*args):
    op = _get_cache_prim(FFTWithSize)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


fast_ge_lu_op = FastGeLU().set_device('Ascend')
def fast_ge_lu(*args):
    return fast_ge_lu_op(*args)


fill_op = Fill().set_device('Ascend')
def fill(*args):
    return fill_op(*args)


def fill_diagonal(*args):
    op = _get_cache_prim(FillDiagonal)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


fill_v2_op = FillV2().set_device('Ascend')
def fill_v2(*args):
    return fill_v2_op(*args)


fills_op = Fills().set_device('Ascend')
def fills(*args):
    return fills_op(*args)


flatten_op = Flatten().set_device('Ascend')
def flatten(*args):
    return flatten_op(*args)


float_status_op = FloatStatus().set_device('Ascend')
def float_status(*args):
    return float_status_op(*args)


floor_op = Floor().set_device('Ascend')
def floor(*args):
    return floor_op(*args)


floor_div_op = FloorDiv().set_device('Ascend')
def floor_div(*args):
    return floor_div_op(*args)


floor_mod_op = FloorMod().set_device('Ascend')
def floor_mod(*args):
    return floor_mod_op(*args)


fmax_op = Fmax().set_device('Ascend')
def fmax(*args):
    return fmax_op(*args)


fmin_op = Fmin().set_device('Ascend')
def fmin(*args):
    return fmin_op(*args)


fori_loop_op = ForiLoop().set_device('Ascend')
def fori_loop(*args):
    return fori_loop_op(*args)


def fractional_avg_pool(*args):
    op = _get_cache_prim(FractionalAvgPool)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


def fractional_max_pool(*args):
    op = _get_cache_prim(FractionalMaxPool)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


def fractional_max_pool3_d_with_fixed_ksize(*args):
    op = _get_cache_prim(FractionalMaxPool3DWithFixedKsize)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def fractional_max_pool_with_fixed_ksize(*args):
    op = _get_cache_prim(FractionalMaxPoolWithFixedKsize)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def fused_ada_factor(*args):
    op = _get_cache_prim(FusedAdaFactor)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def fused_ada_factor_with_global_norm(*args):
    op = _get_cache_prim(FusedAdaFactorWithGlobalNorm)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def fused_cast_adam_weight_decay(*args):
    op = _get_cache_prim(FusedCastAdamWeightDecay)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def fused_sparse_adam(*args):
    op = _get_cache_prim(FusedSparseAdam)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def fused_sparse_ftrl(*args):
    op = _get_cache_prim(FusedSparseFtrl)(*args[-5:]).set_device('Ascend')
    return op(*args[:-5])


def fused_sparse_lazy_adam(*args):
    op = _get_cache_prim(FusedSparseLazyAdam)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def fused_sparse_proximal_adagrad(*args):
    op = _get_cache_prim(FusedSparseProximalAdagrad)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


fused_weight_scale_apply_momentum_op = FusedWeightScaleApplyMomentum().set_device('Ascend')
def fused_weight_scale_apply_momentum(*args):
    return fused_weight_scale_apply_momentum_op(*args)


def glu(*args):
    op = _get_cache_prim(GLU)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def gamma(*args):
    op = _get_cache_prim(Gamma)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def gather(*args):
    op = _get_cache_prim(Gather)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


gather_d_op = GatherD().set_device('Ascend')
def gather_d(*args):
    return gather_d_op(*args)


gather_nd_op = GatherNd().set_device('Ascend')
def gather_nd(*args):
    return gather_nd_op(*args)


gcd_op = Gcd().set_device('Ascend')
def gcd(*args):
    return gcd_op(*args)


ge_lu_op = GeLU().set_device('Ascend')
def ge_lu(*args):
    return ge_lu_op(*args)


ge_switch_op = GeSwitch().set_device('Ascend')
def ge_switch(*args):
    return ge_switch_op(*args)


geqrf_op = Geqrf().set_device('Ascend')
def geqrf(*args):
    return geqrf_op(*args)


ger_op = Ger().set_device('Ascend')
def ger(*args):
    return ger_op(*args)


def get_next(*args):
    op = _get_cache_prim(GetNext)(*args[-4:]).set_device('Ascend')
    return op(*args[:-4])


greater_op = Greater().set_device('Ascend')
def greater(*args):
    return greater_op(*args)


greater_equal_op = GreaterEqual().set_device('Ascend')
def greater_equal(*args):
    return greater_equal_op(*args)


def grid_sampler2_d(*args):
    op = _get_cache_prim(GridSampler2D)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def grid_sampler3_d(*args):
    op = _get_cache_prim(GridSampler3D)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


group_topk_op = GroupTopk().set_device('Ascend')
def group_topk(*args):
    return group_topk_op(*args)


hsv_to_rgb_op = HSVToRGB().set_device('Ascend')
def hsv_to_rgb(*args):
    return hsv_to_rgb_op(*args)


def h_shrink(*args):
    op = _get_cache_prim(HShrink)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


h_sigmoid_op = HSigmoid().set_device('Ascend')
def h_sigmoid(*args):
    return h_sigmoid_op(*args)


h_swish_op = HSwish().set_device('Ascend')
def h_swish(*args):
    return h_swish_op(*args)


def hamming_window(*args):
    op = _get_cache_prim(HammingWindow)(*args[-4:]).set_device('Ascend')
    return op(*args[:-4])


heaviside_op = Heaviside().set_device('Ascend')
def heaviside(*args):
    return heaviside_op(*args)


def histogram(*args):
    op = _get_cache_prim(Histogram)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def histogram_fixed_width(*args):
    op = _get_cache_prim(HistogramFixedWidth)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


histogram_summary_op = HistogramSummary().set_device('Ascend')
def histogram_summary(*args):
    return histogram_summary_op(*args)


def hook_backward(*args):
    op = _get_cache_prim(HookBackward)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


hypot_op = Hypot().set_device('Ascend')
def hypot(*args):
    return hypot_op(*args)


def iou(*args):
    op = _get_cache_prim(IOU)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


identity_op = Identity().set_device('Ascend')
def identity(*args):
    return identity_op(*args)


identity_n_op = IdentityN().set_device('Ascend')
def identity_n(*args):
    return identity_n_op(*args)


igamma_op = Igamma().set_device('Ascend')
def igamma(*args):
    return igamma_op(*args)


igammac_op = Igammac().set_device('Ascend')
def igammac(*args):
    return igammac_op(*args)


def im2_col(*args):
    op = _get_cache_prim(Im2Col)(*args[-4:]).set_device('Ascend')
    return op(*args[:-4])


imag_op = Imag().set_device('Ascend')
def imag(*args):
    return imag_op(*args)


image_summary_op = ImageSummary().set_device('Ascend')
def image_summary(*args):
    return image_summary_op(*args)


def in_top_k(*args):
    op = _get_cache_prim(InTopK)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def index_add(*args):
    op = _get_cache_prim(IndexAdd)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


index_fill_op = IndexFill().set_device('Ascend')
def index_fill(*args):
    return index_fill_op(*args)


def index_put(*args):
    op = _get_cache_prim(IndexPut)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def inplace_add(*args):
    op = _get_cache_prim(InplaceAdd)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def inplace_index_add(*args):
    op = _get_cache_prim(InplaceIndexAdd)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def inplace_sub(*args):
    op = _get_cache_prim(InplaceSub)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def inplace_update(*args):
    op = _get_cache_prim(InplaceUpdate)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


inplace_update_v2_op = InplaceUpdateV2().set_device('Ascend')
def inplace_update_v2(*args):
    return inplace_update_v2_op(*args)


def insert_gradient_of(*args):
    op = _get_cache_prim(InsertGradientOf)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


inv_op = Inv().set_device('Ascend')
def inv(*args):
    return inv_op(*args)


invert_op = Invert().set_device('Ascend')
def invert(*args):
    return invert_op(*args)


invert_permutation_op = InvertPermutation().set_device('Ascend')
def invert_permutation(*args):
    return invert_permutation_op(*args)


def is_close(*args):
    op = _get_cache_prim(IsClose)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


is_finite_op = IsFinite().set_device('Ascend')
def is_finite(*args):
    return is_finite_op(*args)


is_inf_op = IsInf().set_device('Ascend')
def is_inf(*args):
    return is_inf_op(*args)


is_nan_op = IsNan().set_device('Ascend')
def is_nan(*args):
    return is_nan_op(*args)


def kl_div_loss(*args):
    op = _get_cache_prim(KLDivLoss)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


l2_loss_op = L2Loss().set_device('Ascend')
def l2_loss(*args):
    return l2_loss_op(*args)


def l2_normalize(*args):
    op = _get_cache_prim(L2Normalize)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def lars_update(*args):
    op = _get_cache_prim(LARSUpdate)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def lrn(*args):
    op = _get_cache_prim(LRN)(*args[-5:]).set_device('Ascend')
    return op(*args[:-5])


def lstm(*args):
    op = _get_cache_prim(LSTM)(*args[-7:]).set_device('Ascend')
    return op(*args[:-7])


def layer_norm(*args):
    op = _get_cache_prim(LayerNorm)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


lcm_op = Lcm().set_device('Ascend')
def lcm(*args):
    return lcm_op(*args)


left_shift_op = LeftShift().set_device('Ascend')
def left_shift(*args):
    return left_shift_op(*args)


lerp_op = Lerp().set_device('Ascend')
def lerp(*args):
    return lerp_op(*args)


lerp_scalar_op = LerpScalar().set_device('Ascend')
def lerp_scalar(*args):
    return lerp_scalar_op(*args)


less_op = Less().set_device('Ascend')
def less(*args):
    return less_op(*args)


less_equal_op = LessEqual().set_device('Ascend')
def less_equal(*args):
    return less_equal_op(*args)


lgamma_op = Lgamma().set_device('Ascend')
def lgamma(*args):
    return lgamma_op(*args)


lin_space_op = LinSpace().set_device('Ascend')
def lin_space(*args):
    return lin_space_op(*args)


def list_diff(*args):
    op = _get_cache_prim(ListDiff)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


log_op = Log().set_device('Ascend')
def log(*args):
    return log_op(*args)


log1p_op = Log1p().set_device('Ascend')
def log1p(*args):
    return log1p_op(*args)


log_matrix_determinant_op = LogMatrixDeterminant().set_device('Ascend')
def log_matrix_determinant(*args):
    return log_matrix_determinant_op(*args)


def log_normal_reverse(*args):
    op = _get_cache_prim(LogNormalReverse)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def log_softmax(*args):
    op = _get_cache_prim(LogSoftmax)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


log_softmax_ext_op = LogSoftmaxExt().set_device('Ascend')
def log_softmax_ext(*args):
    return log_softmax_ext_op(*args)


def log_space(*args):
    op = _get_cache_prim(LogSpace)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def log_uniform_candidate_sampler(*args):
    op = _get_cache_prim(LogUniformCandidateSampler)(*args[-5:]).set_device('Ascend')
    return op(*args[:-5])


logical_and_op = LogicalAnd().set_device('Ascend')
def logical_and(*args):
    return logical_and_op(*args)


logical_not_op = LogicalNot().set_device('Ascend')
def logical_not(*args):
    return logical_not_op(*args)


logical_or_op = LogicalOr().set_device('Ascend')
def logical_or(*args):
    return logical_or_op(*args)


logical_xor_op = LogicalXor().set_device('Ascend')
def logical_xor(*args):
    return logical_xor_op(*args)


def logit(*args):
    op = _get_cache_prim(Logit)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def lower_bound(*args):
    op = _get_cache_prim(LowerBound)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def lp_norm(*args):
    op = _get_cache_prim(LpNorm)(*args[-4:]).set_device('Ascend')
    return op(*args[:-4])


def lstsq(*args):
    op = _get_cache_prim(Lstsq)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


lu_solve_op = LuSolve().set_device('Ascend')
def lu_solve(*args):
    return lu_solve_op(*args)


def lu_unpack(*args):
    op = _get_cache_prim(LuUnpack)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


map_cache_idx_op = MapCacheIdx().set_device('Ascend')
def map_cache_idx(*args):
    return map_cache_idx_op(*args)


map_uniform_op = MapUniform().set_device('Ascend')
def map_uniform(*args):
    return map_uniform_op(*args)


masked_fill_op = MaskedFill().set_device('Ascend')
def masked_fill(*args):
    return masked_fill_op(*args)


masked_scatter_op = MaskedScatter().set_device('Ascend')
def masked_scatter(*args):
    return masked_scatter_op(*args)


masked_select_op = MaskedSelect().set_device('Ascend')
def masked_select(*args):
    return masked_select_op(*args)


def mat_mul(*args):
    op = _get_cache_prim(MatMul)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


matrix_band_part_op = MatrixBandPart().set_device('Ascend')
def matrix_band_part(*args):
    return matrix_band_part_op(*args)


matrix_determinant_op = MatrixDeterminant().set_device('Ascend')
def matrix_determinant(*args):
    return matrix_determinant_op(*args)


# Ascend wrappers, `matrix_diag_part_v3` through `scatter_min`.
#
# Two shapes are used, and every op is pinned with set_device('Ascend'):
#   * `<name>_op = Prim().set_device('Ascend')` creates one module-level op at
#     import time; `<name>(*args)` forwards all of its arguments to that op.
#   * `<name>(*args)` with slices passes the last N arguments to
#     `_get_cache_prim(Prim)(...)` and calls the resulting op with the
#     remaining leading arguments. Callers must therefore always supply all N
#     trailing arguments; the slices do not validate the argument count.


def matrix_diag_part_v3(*args):
    op = _get_cache_prim(MatrixDiagPartV3)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def matrix_diag_v3(*args):
    op = _get_cache_prim(MatrixDiagV3)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


matrix_exp_op = MatrixExp().set_device('Ascend')
def matrix_exp(*args):
    return matrix_exp_op(*args)


def matrix_inverse(*args):
    op = _get_cache_prim(MatrixInverse)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


matrix_logarithm_op = MatrixLogarithm().set_device('Ascend')
def matrix_logarithm(*args):
    return matrix_logarithm_op(*args)


def matrix_power(*args):
    op = _get_cache_prim(MatrixPower)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def matrix_set_diag_v3(*args):
    op = _get_cache_prim(MatrixSetDiagV3)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def matrix_solve(*args):
    op = _get_cache_prim(MatrixSolve)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def matrix_solve_ls(*args):
    op = _get_cache_prim(MatrixSolveLs)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def matrix_triangular_solve(*args):
    op = _get_cache_prim(MatrixTriangularSolve)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def max_pool(*args):
    op = _get_cache_prim(MaxPool)(*args[-4:]).set_device('Ascend')
    return op(*args[:-4])


def max_pool3_d(*args):
    op = _get_cache_prim(MaxPool3D)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


def max_pool3_d_with_argmax(*args):
    op = _get_cache_prim(MaxPool3DWithArgmax)(*args[-7:]).set_device('Ascend')
    return op(*args[:-7])


def max_pool_with_argmax(*args):
    op = _get_cache_prim(MaxPoolWithArgmax)(*args[-4:]).set_device('Ascend')
    return op(*args[:-4])


def max_pool_with_argmax_v2(*args):
    op = _get_cache_prim(MaxPoolWithArgmaxV2)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


def max_unpool2_d(*args):
    op = _get_cache_prim(MaxUnpool2D)(*args[-5:]).set_device('Ascend')
    return op(*args[:-5])


def max_unpool3_d(*args):
    op = _get_cache_prim(MaxUnpool3D)(*args[-5:]).set_device('Ascend')
    return op(*args[:-5])


maximum_op = Maximum().set_device('Ascend')
def maximum(*args):
    return maximum_op(*args)


merge_op = Merge().set_device('Ascend')
def merge(*args):
    return merge_op(*args)


def meshgrid(*args):
    op = _get_cache_prim(Meshgrid)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


minimum_op = Minimum().set_device('Ascend')
def minimum(*args):
    return minimum_op(*args)


def mirror_pad(*args):
    op = _get_cache_prim(MirrorPad)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


mish_op = Mish().set_device('Ascend')
def mish(*args):
    return mish_op(*args)


mod_op = Mod().set_device('Ascend')
def mod(*args):
    return mod_op(*args)


def morph(*args):
    op = _get_cache_prim(Morph)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


move_to_op = MoveTo().set_device('Ascend')
def move_to(*args):
    return move_to_op(*args)


mul_op = Mul().set_device('Ascend')
def mul(*args):
    return mul_op(*args)


mul_no_nan_op = MulNoNan().set_device('Ascend')
def mul_no_nan(*args):
    return mul_no_nan_op(*args)


def multi_margin_loss(*args):
    op = _get_cache_prim(MultiMarginLoss)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def multilabel_margin_loss(*args):
    op = _get_cache_prim(MultilabelMarginLoss)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def multinomial(*args):
    op = _get_cache_prim(Multinomial)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def multinomial_with_replacement(*args):
    op = _get_cache_prim(MultinomialWithReplacement)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def mvlgamma(*args):
    op = _get_cache_prim(Mvlgamma)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def nll_loss(*args):
    op = _get_cache_prim(NLLLoss)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def nms_with_mask(*args):
    op = _get_cache_prim(NMSWithMask)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def nan_to_num(*args):
    op = _get_cache_prim(NanToNum)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


neg_op = Neg().set_device('Ascend')
def neg(*args):
    return neg_op(*args)


def neighbor_exchange(*args):
    op = _get_cache_prim(NeighborExchange)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


def neighbor_exchange_v2(*args):
    op = _get_cache_prim(NeighborExchangeV2)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


next_after_op = NextAfter().set_device('Ascend')
def next_after(*args):
    return next_after_op(*args)


def no_repeat_n_gram(*args):
    op = _get_cache_prim(NoRepeatNGram)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def non_deterministic_ints(*args):
    op = _get_cache_prim(NonDeterministicInts)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


non_max_suppression_v3_op = NonMaxSuppressionV3().set_device('Ascend')
def non_max_suppression_v3(*args):
    return non_max_suppression_v3_op(*args)


non_max_suppression_with_overlaps_op = NonMaxSuppressionWithOverlaps().set_device('Ascend')
def non_max_suppression_with_overlaps(*args):
    return non_max_suppression_with_overlaps_op(*args)


non_zero_op = NonZero().set_device('Ascend')
def non_zero(*args):
    return non_zero_op(*args)


not_equal_op = NotEqual().set_device('Ascend')
def not_equal(*args):
    return not_equal_op(*args)


def nth_element(*args):
    op = _get_cache_prim(NthElement)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def nuclear_norm(*args):
    op = _get_cache_prim(NuclearNorm)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def one_hot(*args):
    op = _get_cache_prim(OneHot)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


ones_op = Ones().set_device('Ascend')
def ones(*args):
    return ones_op(*args)


ones_like_op = OnesLike().set_device('Ascend')
def ones_like(*args):
    return ones_like_op(*args)


orgqr_op = Orgqr().set_device('Ascend')
def orgqr(*args):
    return orgqr_op(*args)


def ormqr(*args):
    op = _get_cache_prim(Ormqr)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


p_re_lu_op = PReLU().set_device('Ascend')
def p_re_lu(*args):
    return p_re_lu_op(*args)


def pack(*args):
    op = _get_cache_prim(Pack)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def pad(*args):
    op = _get_cache_prim(Pad)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def pad_v3(*args):
    op = _get_cache_prim(PadV3)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def padding(*args):
    op = _get_cache_prim(Padding)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def paged_attention(*args):
    op = _get_cache_prim(PagedAttention)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


def paged_attention_mask(*args):
    op = _get_cache_prim(PagedAttentionMask)(*args[-4:]).set_device('Ascend')
    return op(*args[:-4])


parallel_concat_op = ParallelConcat().set_device('Ascend')
def parallel_concat(*args):
    return parallel_concat_op(*args)


def parameterized_truncated_normal(*args):
    op = _get_cache_prim(ParameterizedTruncatedNormal)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


partial_op = Partial().set_device('Ascend')
def partial(*args):
    return partial_op(*args)


def pdist(*args):
    op = _get_cache_prim(Pdist)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def poisson(*args):
    op = _get_cache_prim(Poisson)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


polar_op = Polar().set_device('Ascend')
def polar(*args):
    return polar_op(*args)


polygamma_op = Polygamma().set_device('Ascend')
def polygamma(*args):
    return polygamma_op(*args)


population_count_op = PopulationCount().set_device('Ascend')
def population_count(*args):
    return population_count_op(*args)


# NOTE(review): this definition shadows the builtin `pow` within this module.
pow_op = Pow().set_device('Ascend')
def pow(*args):
    return pow_op(*args)


pull_op = Pull().set_device('Ascend')
def pull(*args):
    return pull_op(*args)


def push(*args):
    op = _get_cache_prim(Push)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


py_execute_op = PyExecute().set_device('Ascend')
def py_execute(*args):
    return py_execute_op(*args)


def py_func(*args):
    op = _get_cache_prim(PyFunc)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


def qr(*args):
    op = _get_cache_prim(Qr)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def quantile(*args):
    op = _get_cache_prim(Quantile)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


rgb_to_hsv_op = RGBToHSV().set_device('Ascend')
def rgb_to_hsv(*args):
    return rgb_to_hsv_op(*args)


def rnnt_loss(*args):
    op = _get_cache_prim(RNNTLoss)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def roi_align(*args):
    op = _get_cache_prim(ROIAlign)(*args[-5:]).set_device('Ascend')
    return op(*args[:-5])


def ragged_range(*args):
    op = _get_cache_prim(RaggedRange)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def random_categorical(*args):
    op = _get_cache_prim(RandomCategorical)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def random_choice_with_mask(*args):
    op = _get_cache_prim(RandomChoiceWithMask)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def random_gamma(*args):
    op = _get_cache_prim(RandomGamma)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def random_poisson(*args):
    op = _get_cache_prim(RandomPoisson)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def random_shuffle(*args):
    op = _get_cache_prim(RandomShuffle)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def randperm(*args):
    op = _get_cache_prim(Randperm)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def randperm_v2(*args):
    op = _get_cache_prim(RandpermV2)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


# NOTE(review): this definition shadows the builtin `range` within this module.
def range(*args):
    op = _get_cache_prim(Range)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


rank_op = Rank().set_device('Ascend')
def rank(*args):
    return rank_op(*args)


re_lu_op = ReLU().set_device('Ascend')
def re_lu(*args):
    return re_lu_op(*args)


re_lu6_op = ReLU6().set_device('Ascend')
def re_lu6(*args):
    return re_lu6_op(*args)


real_op = Real().set_device('Ascend')
def real(*args):
    return real_op(*args)


real_div_op = RealDiv().set_device('Ascend')
def real_div(*args):
    return real_div_op(*args)


def receive(*args):
    op = _get_cache_prim(Receive)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


reciprocal_op = Reciprocal().set_device('Ascend')
def reciprocal(*args):
    return reciprocal_op(*args)


def reduce(*args):
    op = _get_cache_prim(Reduce)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def reduce_all(*args):
    op = _get_cache_prim(ReduceAll)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def reduce_any(*args):
    op = _get_cache_prim(ReduceAny)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def reduce_max(*args):
    op = _get_cache_prim(ReduceMax)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def reduce_mean(*args):
    op = _get_cache_prim(ReduceMean)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def reduce_min(*args):
    op = _get_cache_prim(ReduceMin)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def reduce_prod(*args):
    op = _get_cache_prim(ReduceProd)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def reduce_scatter(*args):
    op = _get_cache_prim(ReduceScatter)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def reduce_std(*args):
    op = _get_cache_prim(ReduceStd)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def reduce_sum(*args):
    op = _get_cache_prim(ReduceSum)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def renorm(*args):
    op = _get_cache_prim(Renorm)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


reshape_op = Reshape().set_device('Ascend')
def reshape(*args):
    return reshape_op(*args)


reshape_and_cache_op = ReshapeAndCache().set_device('Ascend')
def reshape_and_cache(*args):
    return reshape_and_cache_op(*args)


def reshard(*args):
    op = _get_cache_prim(Reshard)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def resize_area(*args):
    op = _get_cache_prim(ResizeArea)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def resize_bicubic(*args):
    op = _get_cache_prim(ResizeBicubic)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def resize_bilinear_v2(*args):
    op = _get_cache_prim(ResizeBilinearV2)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def resize_linear1_d(*args):
    op = _get_cache_prim(ResizeLinear1D)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def resize_nearest_neighbor(*args):
    op = _get_cache_prim(ResizeNearestNeighbor)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def resize_nearest_neighbor_v2(*args):
    op = _get_cache_prim(ResizeNearestNeighborV2)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


reusing_op = Reusing().set_device('Ascend')
def reusing(*args):
    return reusing_op(*args)


def reverse_sequence(*args):
    op = _get_cache_prim(ReverseSequence)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


def reverse_v2(*args):
    op = _get_cache_prim(ReverseV2)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


right_shift_op = RightShift().set_device('Ascend')
def right_shift(*args):
    return right_shift_op(*args)


rint_op = Rint().set_device('Ascend')
def rint(*args):
    return rint_op(*args)


def rms_norm(*args):
    op = _get_cache_prim(RmsNorm)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def roll(*args):
    op = _get_cache_prim(Roll)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


# NOTE(review): this definition shadows the builtin `round` within this module.
round_op = Round().set_device('Ascend')
def round(*args):
    return round_op(*args)


rsqrt_op = Rsqrt().set_device('Ascend')
def rsqrt(*args):
    return rsqrt_op(*args)


def sgd(*args):
    op = _get_cache_prim(SGD)(*args[-3:]).set_device('Ascend')
    return op(*args[:-3])


def stft(*args):
    op = _get_cache_prim(STFT)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


def sample_distorted_bounding_box_v2(*args):
    op = _get_cache_prim(SampleDistortedBoundingBoxV2)(*args[-6:]).set_device('Ascend')
    return op(*args[:-6])


scalar_summary_op = ScalarSummary().set_device('Ascend')
def scalar_summary(*args):
    return scalar_summary_op(*args)


scalar_to_tensor_op = ScalarToTensor().set_device('Ascend')
def scalar_to_tensor(*args):
    return scalar_to_tensor_op(*args)


def scale_and_translate(*args):
    op = _get_cache_prim(ScaleAndTranslate)(*args[-2:]).set_device('Ascend')
    return op(*args[:-2])


scan_op = Scan().set_device('Ascend')
def scan(*args):
    return scan_op(*args)


def scatter_add(*args):
    op = _get_cache_prim(ScatterAdd)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def scatter_add_with_axis(*args):
    op = _get_cache_prim(ScatterAddWithAxis)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def scatter_div(*args):
    op = _get_cache_prim(ScatterDiv)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


def scatter_max(*args):
    op = _get_cache_prim(ScatterMax)(*args[-1:]).set_device('Ascend')
    return op(*args[:-1])


# Truncated at the chunk boundary: the returned expression lies beyond this chunk.
def scatter_min(*args):
    op = _get_cache_prim(ScatterMin)(*args[-1:]).set_device('Ascend')
    return
op(*args[:-1]) + + +def scatter_mul(*args): + op = _get_cache_prim(ScatterMul)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +scatter_nd_op = ScatterNd().set_device('Ascend') +def scatter_nd(*args): + return scatter_nd_op(*args) + + +def scatter_nd_add(*args): + op = _get_cache_prim(ScatterNdAdd)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def scatter_nd_div(*args): + op = _get_cache_prim(ScatterNdDiv)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def scatter_nd_max(*args): + op = _get_cache_prim(ScatterNdMax)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def scatter_nd_min(*args): + op = _get_cache_prim(ScatterNdMin)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def scatter_nd_mul(*args): + op = _get_cache_prim(ScatterNdMul)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def scatter_nd_sub(*args): + op = _get_cache_prim(ScatterNdSub)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def scatter_nd_update(*args): + op = _get_cache_prim(ScatterNdUpdate)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def scatter_sub(*args): + op = _get_cache_prim(ScatterSub)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def scatter_update(*args): + op = _get_cache_prim(ScatterUpdate)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +se_lu_op = SeLU().set_device('Ascend') +def se_lu(*args): + return se_lu_op(*args) + + +def search_sorted(*args): + op = _get_cache_prim(SearchSorted)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +segment_max_op = SegmentMax().set_device('Ascend') +def segment_max(*args): + return segment_max_op(*args) + + +segment_mean_op = SegmentMean().set_device('Ascend') +def segment_mean(*args): + return segment_mean_op(*args) + + +segment_min_op = SegmentMin().set_device('Ascend') +def segment_min(*args): + return segment_min_op(*args) + + +segment_prod_op = SegmentProd().set_device('Ascend') +def 
segment_prod(*args): + return segment_prod_op(*args) + + +segment_sum_op = SegmentSum().set_device('Ascend') +def segment_sum(*args): + return segment_sum_op(*args) + + +select_op = Select().set_device('Ascend') +def select(*args): + return select_op(*args) + + +select_view_op = SelectView().set_device('Ascend') +def select_view(*args): + return select_view_op(*args) + + +def send(*args): + op = _get_cache_prim(Send)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +shape_op = Shape().set_device('Ascend') +def shape(*args): + return shape_op(*args) + + +sigmoid_op = Sigmoid().set_device('Ascend') +def sigmoid(*args): + return sigmoid_op(*args) + + +sigmoid_cross_entropy_with_logits_op = SigmoidCrossEntropyWithLogits().set_device('Ascend') +def sigmoid_cross_entropy_with_logits(*args): + return sigmoid_cross_entropy_with_logits_op(*args) + + +sign_op = Sign().set_device('Ascend') +def sign(*args): + return sign_op(*args) + + +sin_op = Sin().set_device('Ascend') +def sin(*args): + return sin_op(*args) + + +sinc_op = Sinc().set_device('Ascend') +def sinc(*args): + return sinc_op(*args) + + +sinh_op = Sinh().set_device('Ascend') +def sinh(*args): + return sinh_op(*args) + + +size_op = Size().set_device('Ascend') +def size(*args): + return size_op(*args) + + +slice_op = Slice().set_device('Ascend') +def slice(*args): + return slice_op(*args) + + +def smooth_l1_loss(*args): + op = _get_cache_prim(SmoothL1Loss)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def soft_margin_loss(*args): + op = _get_cache_prim(SoftMarginLoss)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def soft_shrink(*args): + op = _get_cache_prim(SoftShrink)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def softmax(*args): + op = _get_cache_prim(Softmax)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +softmax_cross_entropy_with_logits_op = SoftmaxCrossEntropyWithLogits().set_device('Ascend') +def 
softmax_cross_entropy_with_logits(*args): + return softmax_cross_entropy_with_logits_op(*args) + + +softplus_op = Softplus().set_device('Ascend') +def softplus(*args): + return softplus_op(*args) + + +softsign_op = Softsign().set_device('Ascend') +def softsign(*args): + return softsign_op(*args) + + +def sort(*args): + op = _get_cache_prim(Sort)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def space_to_batch(*args): + op = _get_cache_prim(SpaceToBatch)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def space_to_batch_nd(*args): + op = _get_cache_prim(SpaceToBatchND)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def space_to_depth(*args): + op = _get_cache_prim(SpaceToDepth)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def sparse_apply_adadelta(*args): + op = _get_cache_prim(SparseApplyAdadelta)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def sparse_apply_adagrad(*args): + op = _get_cache_prim(SparseApplyAdagrad)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +def sparse_apply_adagrad_v2(*args): + op = _get_cache_prim(SparseApplyAdagradV2)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def sparse_apply_ftrl(*args): + op = _get_cache_prim(SparseApplyFtrl)(*args[-5:]).set_device('Ascend') + return op(*args[:-5]) + + +def sparse_apply_ftrl_v2(*args): + op = _get_cache_prim(SparseApplyFtrlV2)(*args[-6:]).set_device('Ascend') + return op(*args[:-6]) + + +def sparse_apply_proximal_adagrad(*args): + op = _get_cache_prim(SparseApplyProximalAdagrad)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def sparse_apply_rms_prop(*args): + op = _get_cache_prim(SparseApplyRMSProp)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +sparse_gather_v2_op = SparseGatherV2().set_device('Ascend') +def sparse_gather_v2(*args): + return sparse_gather_v2_op(*args) + + +sparse_slice_op = SparseSlice().set_device('Ascend') +def sparse_slice(*args): + return 
sparse_slice_op(*args) + + +def sparse_softmax_cross_entropy_with_logits(*args): + op = _get_cache_prim(SparseSoftmaxCrossEntropyWithLogits)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +sparse_tensor_dense_add_op = SparseTensorDenseAdd().set_device('Ascend') +def sparse_tensor_dense_add(*args): + return sparse_tensor_dense_add_op(*args) + + +def sparse_tensor_dense_matmul(*args): + op = _get_cache_prim(SparseTensorDenseMatmul)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +sparse_to_dense_op = SparseToDense().set_device('Ascend') +def sparse_to_dense(*args): + return sparse_to_dense_op(*args) + + +def split(*args): + op = _get_cache_prim(Split)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def split_v(*args): + op = _get_cache_prim(SplitV)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +sqrt_op = Sqrt().set_device('Ascend') +def sqrt(*args): + return sqrt_op(*args) + + +square_op = Square().set_device('Ascend') +def square(*args): + return square_op(*args) + + +square_sum_all_op = SquareSumAll().set_device('Ascend') +def square_sum_all(*args): + return square_sum_all_op(*args) + + +squared_difference_op = SquaredDifference().set_device('Ascend') +def squared_difference(*args): + return squared_difference_op(*args) + + +def squeeze(*args): + op = _get_cache_prim(Squeeze)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def stack(*args): + op = _get_cache_prim(Stack)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def standard_laplace(*args): + op = _get_cache_prim(StandardLaplace)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def standard_normal(*args): + op = _get_cache_prim(StandardNormal)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +stop_gradient_op = StopGradient().set_device('Ascend') +def stop_gradient(*args): + return stop_gradient_op(*args) + + +def strided_slice(*args): + op = 
_get_cache_prim(StridedSlice)(*args[-5:]).set_device('Ascend') + return op(*args[:-5]) + + +sub_op = Sub().set_device('Ascend') +def sub(*args): + return sub_op(*args) + + +sub_and_filter_op = SubAndFilter().set_device('Ascend') +def sub_and_filter(*args): + return sub_and_filter_op(*args) + + +def svd(*args): + op = _get_cache_prim(Svd)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +tan_op = Tan().set_device('Ascend') +def tan(*args): + return tan_op(*args) + + +tanh_op = Tanh().set_device('Ascend') +def tanh(*args): + return tanh_op(*args) + + +def tensor_dump(*args): + op = _get_cache_prim(TensorDump)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +tensor_scatter_add_op = TensorScatterAdd().set_device('Ascend') +def tensor_scatter_add(*args): + return tensor_scatter_add_op(*args) + + +tensor_scatter_div_op = TensorScatterDiv().set_device('Ascend') +def tensor_scatter_div(*args): + return tensor_scatter_div_op(*args) + + +def tensor_scatter_elements(*args): + op = _get_cache_prim(TensorScatterElements)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +tensor_scatter_max_op = TensorScatterMax().set_device('Ascend') +def tensor_scatter_max(*args): + return tensor_scatter_max_op(*args) + + +tensor_scatter_min_op = TensorScatterMin().set_device('Ascend') +def tensor_scatter_min(*args): + return tensor_scatter_min_op(*args) + + +tensor_scatter_mul_op = TensorScatterMul().set_device('Ascend') +def tensor_scatter_mul(*args): + return tensor_scatter_mul_op(*args) + + +tensor_scatter_sub_op = TensorScatterSub().set_device('Ascend') +def tensor_scatter_sub(*args): + return tensor_scatter_sub_op(*args) + + +tensor_scatter_update_op = TensorScatterUpdate().set_device('Ascend') +def tensor_scatter_update(*args): + return tensor_scatter_update_op(*args) + + +tensor_shape_op = TensorShape().set_device('Ascend') +def tensor_shape(*args): + return tensor_shape_op(*args) + + +tensor_summary_op = TensorSummary().set_device('Ascend') +def 
tensor_summary(*args): + return tensor_summary_op(*args) + + +tile_op = Tile().set_device('Ascend') +def tile(*args): + return tile_op(*args) + + +def top_k(*args): + op = _get_cache_prim(TopK)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +trace_op = Trace().set_device('Ascend') +def trace(*args): + return trace_op(*args) + + +transpose_op = Transpose().set_device('Ascend') +def transpose(*args): + return transpose_op(*args) + + +transpose_ext_view_op = TransposeExtView().set_device('Ascend') +def transpose_ext_view(*args): + return transpose_ext_view_op(*args) + + +transpose_view_op = TransposeView().set_device('Ascend') +def transpose_view(*args): + return transpose_view_op(*args) + + +tridiagonal_mat_mul_op = TridiagonalMatMul().set_device('Ascend') +def tridiagonal_mat_mul(*args): + return tridiagonal_mat_mul_op(*args) + + +def tril(*args): + op = _get_cache_prim(Tril)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def tril_indices(*args): + op = _get_cache_prim(TrilIndices)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def triplet_margin_loss(*args): + op = _get_cache_prim(TripletMarginLoss)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +def triu(*args): + op = _get_cache_prim(Triu)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +def triu_indices(*args): + op = _get_cache_prim(TriuIndices)(*args[-4:]).set_device('Ascend') + return op(*args[:-4]) + + +trunc_op = Trunc().set_device('Ascend') +def trunc(*args): + return trunc_op(*args) + + +truncate_div_op = TruncateDiv().set_device('Ascend') +def truncate_div(*args): + return truncate_div_op(*args) + + +truncate_mod_op = TruncateMod().set_device('Ascend') +def truncate_mod(*args): + return truncate_mod_op(*args) + + +def truncated_normal(*args): + op = _get_cache_prim(TruncatedNormal)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + + +tuple_to_array_op = TupleToArray().set_device('Ascend') +def tuple_to_array(*args): + return 
tuple_to_array_op(*args) + + +def uniform_candidate_sampler(*args): + op = _get_cache_prim(UniformCandidateSampler)(*args[-6:]).set_device('Ascend') + return op(*args[:-6]) + + +def uniform_int(*args): + op = _get_cache_prim(UniformInt)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +def uniform_real(*args): + op = _get_cache_prim(UniformReal)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +unique_op = Unique().set_device('Ascend') +def unique(*args): + return unique_op(*args) + + +def unique_consecutive(*args): + op = _get_cache_prim(UniqueConsecutive)(*args[-3:]).set_device('Ascend') + return op(*args[:-3]) + +def unpack(*args): + op = _get_cache_prim(Unpack)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +unravel_index_op = UnravelIndex().set_device('Ascend') +def unravel_index(*args): + return unravel_index_op(*args) + + +unsorted_segment_max_op = UnsortedSegmentMax().set_device('Ascend') +def unsorted_segment_max(*args): + return unsorted_segment_max_op(*args) + + +unsorted_segment_min_op = UnsortedSegmentMin().set_device('Ascend') +def unsorted_segment_min(*args): + return unsorted_segment_min_op(*args) + + +unsorted_segment_prod_op = UnsortedSegmentProd().set_device('Ascend') +def unsorted_segment_prod(*args): + return unsorted_segment_prod_op(*args) + + +unsorted_segment_sum_op = UnsortedSegmentSum().set_device('Ascend') +def unsorted_segment_sum(*args): + return unsorted_segment_sum_op(*args) + + +def unstack(*args): + op = _get_cache_prim(Unstack)(*args[-2:]).set_device('Ascend') + return op(*args[:-2]) + + +update_state_op = UpdateState().set_device('Ascend') +def update_state(*args): + return update_state_op(*args) + + +def upper_bound(*args): + op = _get_cache_prim(UpperBound)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +upsample_nearest3_d_op = UpsampleNearest3D().set_device('Ascend') +def upsample_nearest3_d(*args): + return upsample_nearest3_d_op(*args) + + +def upsample_trilinear3_d(*args): 
+ op = _get_cache_prim(UpsampleTrilinear3D)(*args[-1:]).set_device('Ascend') + return op(*args[:-1]) + + +while_loop_op = WhileLoop().set_device('Ascend') +def while_loop(*args): + return while_loop_op(*args) + + +xdivy_op = Xdivy().set_device('Ascend') +def xdivy(*args): + return xdivy_op(*args) + + +xlogy_op = Xlogy().set_device('Ascend') +def xlogy(*args): + return xlogy_op(*args) + + +zeros_op = Zeros().set_device('Ascend') +def zeros(*args): + return zeros_op(*args) + + +zeros_like_op = ZerosLike().set_device('Ascend') +def zeros_like(*args): + return zeros_like_op(*args) + + +zeta_op = Zeta().set_device('Ascend') +def zeta(*args): + return zeta_op(*args) + +def max_pool_with_indices(input, kernel_size, strides, padding, dilation, ceil_mode): + max_pool_func_ = _get_cache_prim(MaxPoolWithMask)(kernel_size, strides, padding, dilation, ceil_mode) + return max_pool_func_(input) diff --git a/torch4ms/_op_prim/ascend/pyboost.py b/torch4ms/_op_prim/ascend/pyboost.py new file mode 100644 index 000000000..69e5d48ca --- /dev/null +++ b/torch4ms/_op_prim/ascend/pyboost.py @@ -0,0 +1,877 @@ +from mindspore.ops.auto_generate.gen_ops_prim import * +from mindspore.ops.auto_generate.pyboost_inner_prim import * + +abs_op = Abs().set_device('Ascend') + +acos_ext_op = AcosExt().set_device('Ascend') + +acosh_ext_op = AcoshExt().set_device('Ascend') + +adamw_op = AdamW().set_device('Ascend') + +adaptive_avg_pool1d_op = AdaptiveAvgPool1D().set_device('Ascend') + +adaptive_avg_pool2d_ext_op = AdaptiveAvgPool2DExt().set_device('Ascend') + +adaptive_avg_pool2d_grad_ext_op = AdaptiveAvgPool2DGradExt().set_device('Ascend') + +adaptive_avg_pool3d_ext_op = AdaptiveAvgPool3DExt().set_device('Ascend') + +adaptive_avg_pool3d_grad_ext_op = AdaptiveAvgPool3DGradExt().set_device('Ascend') + +adaptive_max_pool1d_op = AdaptiveMaxPool1D().set_device('Ascend') + +add_op = Add().set_device('Ascend') + +add_ext_op = AddExt().set_device('Ascend') + +add_layer_norm_grad_op = 
AddLayerNormGrad().set_device('Ascend') + +add_layernorm_v2_op = AddLayerNormV2().set_device('Ascend') + +add_rms_norm_op = AddRmsNorm().set_device('Ascend') + +add_scalar_op = AddScalar().set_device('Ascend') + +addbmm_op = Addbmm().set_device('Ascend') + +addcdiv_ext_op = AddcdivExt().set_device('Ascend') + +addcmul_ext_op = AddcmulExt().set_device('Ascend') + +addmm_op = Addmm().set_device('Ascend') + +addmv_op = Addmv().set_device('Ascend') + +all_gather_matmul_op = AllGatherMatmul().set_device('Ascend') + +arange_op = Arange().set_device('Ascend') + +argmax_ext_op = ArgMaxExt().set_device('Ascend') + +argmin_ext_op = ArgMinExt().set_device('Ascend') + +argsort_op = ArgSort().set_device('Ascend') + +as_strided_op = AsStrided().set_device('Ascend') + +asin_ext_op = AsinExt().set_device('Ascend') + +asinh_ext_op = AsinhExt().set_device('Ascend') + +atan2_ext_op = Atan2Ext().set_device('Ascend') + +atan_ext_op = AtanExt().set_device('Ascend') + +atanh_op = Atanh().set_device('Ascend') + +avg_pool1d_op = AvgPool1D().set_device('Ascend') + +avg_pool2d_op = AvgPool2D().set_device('Ascend') + +avg_pool2d_grad_op = AvgPool2DGrad().set_device('Ascend') + +avg_pool3d_ext_op = AvgPool3DExt().set_device('Ascend') + +avg_pool3d_grad_ext_op = AvgPool3DGradExt().set_device('Ascend') + +baddbmm_op = Baddbmm().set_device('Ascend') + +batch_norm_elemt_op = BatchNormElemt().set_device('Ascend') + +batch_norm_elemt_grad_op = BatchNormElemtGrad().set_device('Ascend') + +batch_norm_ext_op = BatchNormExt().set_device('Ascend') + +batch_norm_gather_stats_with_counts_op = BatchNormGatherStatsWithCounts().set_device('Ascend') + +batch_norm_reduce_grad_op = BatchNormReduceGrad().set_device('Ascend') + +batch_norm_stats_op = BatchNormStats().set_device('Ascend') + +bernoulli_ext_op = BernoulliExt().set_device('Ascend') + +binary_cross_entropy_with_logits_backward_op = BinaryCrossEntropyWithLogitsBackward().set_device('Ascend') + +bincount_ext_op = BincountExt().set_device('Ascend') + 
# Module-level primitive instances (bitwise_* .. inplace_clamp_*).
# Every statement instantiates one primitive class with no constructor
# arguments and pins the instance to the Ascend backend through
# ``set_device('Ascend')``; the result is bound to ``<name>_op`` at import
# time.  The classes are presumably resolved through the star-imports at the
# top of this module -- verify when adding a new entry.

bitwise_and_scalar_op = BitwiseAndScalar().set_device('Ascend')

bitwise_and_tensor_op = BitwiseAndTensor().set_device('Ascend')

bitwise_not_op = BitwiseNot().set_device('Ascend')

bitwise_or_scalar_op = BitwiseOrScalar().set_device('Ascend')

bitwise_or_tensor_op = BitwiseOrTensor().set_device('Ascend')

bitwise_xor_scalar_op = BitwiseXorScalar().set_device('Ascend')

bitwise_xor_tensor_op = BitwiseXorTensor().set_device('Ascend')

# NOTE: the variable name does not mirror the class name (BatchMatMulExt).
bmm_ext_op = BatchMatMulExt().set_device('Ascend')

broadcast_to_view_op = BroadcastToView().set_device('Ascend')

ceil_op = Ceil().set_device('Ascend')

chunk_op = Chunk().set_device('Ascend')

chunk_view_op = ChunkView().set_device('Ascend')

clamp_scalar_op = ClampScalar().set_device('Ascend')

clamp_tensor_op = ClampTensor().set_device('Ascend')

clone_op = Clone().set_device('Ascend')

col2im_ext_op = Col2ImExt().set_device('Ascend')

col2im_grad_op = Col2ImGrad().set_device('Ascend')

constant_pad_nd_op = ConstantPadND().set_device('Ascend')

contiguous_op = Contiguous().set_device('Ascend')

conv1d_ext_op = Conv1DExt().set_device('Ascend')

conv1d_padding_op = Conv1DPadding().set_device('Ascend')

conv2d_ext_op = Conv2DExt().set_device('Ascend')

conv2d_padding_op = Conv2DPadding().set_device('Ascend')

conv3d_ext_op = Conv3DExt().set_device('Ascend')

conv3d_padding_op = Conv3DPadding().set_device('Ascend')

conv_transpose2d_op = ConvTranspose2D().set_device('Ascend')

convolution_op = Convolution().set_device('Ascend')

convolution_grad_op = ConvolutionGrad().set_device('Ascend')

convolution_str_op = ConvolutionStr().set_device('Ascend')

convolution_str_grad_op = ConvolutionStrGrad().set_device('Ascend')

copy_op = Copy().set_device('Ascend')

cos_op = Cos().set_device('Ascend')

cosh_op = Cosh().set_device('Ascend')

count_nonzero_op = CountNonZero().set_device('Ascend')

cummin_ext_op = CumminExt().set_device('Ascend')

cumsum_ext_op = CumsumExt().set_device('Ascend')

dense_op = Dense().set_device('Ascend')

diag_ext_op = DiagExt().set_device('Ascend')

# dist_comm_*: instances of the DistComm* primitive classes.
dist_comm_all_gather_op = DistCommAllGather().set_device('Ascend')

dist_comm_all_gather_into_tensor_op = DistCommAllGatherIntoTensor().set_device('Ascend')

dist_comm_all_reduce_op = DistCommAllReduce().set_device('Ascend')

dist_comm_all_to_all_v_op = DistCommAllToAllV().set_device('Ascend')

dist_comm_all_to_all_v_single_op = DistCommAllToAllVSingle().set_device('Ascend')

dist_comm_barrier_op = DistCommBarrier().set_device('Ascend')

dist_comm_batch_isend_irecv_op = DistCommBatchIsendIrecv().set_device('Ascend')

dist_comm_broadcast_op = DistCommBroadcast().set_device('Ascend')

dist_comm_gather_op = DistCommGather().set_device('Ascend')

dist_comm_gather_into_tensor_op = DistCommGatherIntoTensor().set_device('Ascend')

dist_comm_irecv_op = DistCommIrecv().set_device('Ascend')

dist_comm_isend_op = DistCommIsend().set_device('Ascend')

dist_comm_reduce_op = DistCommReduce().set_device('Ascend')

dist_comm_reduce_scatter_op = DistCommReduceScatter().set_device('Ascend')

dist_comm_reduce_scatter_tensor_op = DistCommReduceScatterTensor().set_device('Ascend')

dist_comm_scatter_op = DistCommScatter().set_device('Ascend')

dist_comm_scatter_tensor_op = DistCommScatterTensor().set_device('Ascend')

div_op = Div().set_device('Ascend')

divmod_op = DivMod().set_device('Ascend')

divmods_op = DivMods().set_device('Ascend')

divs_op = Divs().set_device('Ascend')

dot_op = Dot().set_device('Ascend')

dropout_do_mask_ext_op = DropoutDoMaskExt().set_device('Ascend')

dropout_ext_op = DropoutExt().set_device('Ascend')

dropout_gen_mask_ext_op = DropoutGenMaskExt().set_device('Ascend')

dropout_grad_ext_op = DropoutGradExt().set_device('Ascend')

dynamic_quant_ext_op = DynamicQuantExt().set_device('Ascend')

elu_grad_ext_op = EluGradExt().set_device('Ascend')

embedding_op = Embedding().set_device('Ascend')

embedding_dense_backward_op = EmbeddingDenseBackward().set_device('Ascend')

equal_op = Equal().set_device('Ascend')

equal_ext_op = EqualExt().set_device('Ascend')

erf_op = Erf().set_device('Ascend')

erfc_op = Erfc().set_device('Ascend')

erfinv_op = Erfinv().set_device('Ascend')

exp_op = Exp().set_device('Ascend')

exp2_op = Exp2().set_device('Ascend')

expand_as_op = ExpandAs().set_device('Ascend')

expand_dims_op = ExpandDims().set_device('Ascend')

expand_dims_view_op = ExpandDimsView().set_device('Ascend')

expm1_op = Expm1().set_device('Ascend')

eye_op = Eye().set_device('Ascend')

fill_scalar_op = FillScalar().set_device('Ascend')

fill_tensor_op = FillTensor().set_device('Ascend')

flatten_ext_op = FlattenExt().set_device('Ascend')

floor_op = Floor().set_device('Ascend')

floor_div_op = FloorDiv().set_device('Ascend')

floor_div_scalar_op = FloorDivScalar().set_device('Ascend')

fmod_scalar_op = FmodScalar().set_device('Ascend')

fmod_tensor_op = FmodTensor().set_device('Ascend')

frac_op = Frac().set_device('Ascend')

full_like_op = FullLike().set_device('Ascend')

gather_d_op = GatherD().set_device('Ascend')

gather_d_grad_v2_op = GatherDGradV2().set_device('Ascend')

gcd_op = Gcd().set_device('Ascend')

# NOTE: two differently-cased class families coexist here:
# GeLU/GeLUGrad (gelu_op, gelu_grad_op) and GeluExt/GeluGradExt.
gelu_op = GeLU().set_device('Ascend')

gelu_ext_op = GeluExt().set_device('Ascend')

gelu_grad_op = GeLUGrad().set_device('Ascend')

gelu_grad_ext_op = GeluGradExt().set_device('Ascend')

generator_op = Generator().set_device('Ascend')

gmm_op = Gmm().set_device('Ascend')

gmm_backward_op = GmmBackward().set_device('Ascend')

gmm_backward_fusion_op = GmmBackwardFusion().set_device('Ascend')

gmm_v2_op = GmmV2().set_device('Ascend')

gmm_v2_backward_op = GmmV2Backward().set_device('Ascend')

gmm_v2_backward_fusion_op = GmmV2BackwardFusion().set_device('Ascend')

greater_op = Greater().set_device('Ascend')

greater_equal_op = GreaterEqual().set_device('Ascend')

greater_equal_scalar_op = GreaterEqualScalar().set_device('Ascend')

group_norm_op = GroupNorm().set_device('Ascend')

group_norm_grad_op = GroupNormGrad().set_device('Ascend')

grouped_matmul_v2_op = GroupedMatmulV2().set_device('Ascend')

grouped_matmul_v4_op = GroupedMatmulV4().set_device('Ascend')

hardtanh_op = Hardtanh().set_device('Ascend')

hardtanh_grad_op = HardtanhGrad().set_device('Ascend')

histc_ext_op = HistcExt().set_device('Ascend')

hsigmoid_op = HSigmoid().set_device('Ascend')

hsigmoid_grad_op = HSigmoidGrad().set_device('Ascend')

hswish_op = HSwish().set_device('Ascend')

hswish_grad_op = HSwishGrad().set_device('Ascend')

im2col_ext_op = Im2ColExt().set_device('Ascend')

index_op = Index().set_device('Ascend')

index_add_ext_op = IndexAddExt().set_device('Ascend')

index_fill_scalar_op = IndexFillScalar().set_device('Ascend')

index_fill_tensor_op = IndexFillTensor().set_device('Ascend')

index_select_op = IndexSelect().set_device('Ascend')

# inner_comm_*: instances of the InnerComm* primitive classes.
inner_comm_all_gather_op = InnerCommAllGather().set_device('Ascend')

inner_comm_all_reduce_op = InnerCommAllReduce().set_device('Ascend')

inner_comm_all_to_all_v_op = InnerCommAllToAllV().set_device('Ascend')

inner_comm_irecv_op = InnerCommIrecv().set_device('Ascend')

inner_comm_isend_op = InnerCommIsend().set_device('Ascend')

inner_comm_reduce_scatter_op = InnerCommReduceScatter().set_device('Ascend')

inner_index_op = InnerIndex().set_device('Ascend')

inner_inplace_index_put_op = InnerInplaceIndexPut().set_device('Ascend')

inner_non_zero_op = InnerNonZero().set_device('Ascend')

# inplace_*: instances of the Inplace* primitive classes.
inplace_add_ext_op = InplaceAddExt().set_device('Ascend')

inplace_addmm_op = InplaceAddmm().set_device('Ascend')

inplace_adds_ext_op = InplaceAddsExt().set_device('Ascend')

inplace_clamp_scalar_op = InplaceClampScalar().set_device('Ascend')

inplace_clamp_tensor_op = InplaceClampTensor().set_device('Ascend')

# Module-level instances of the Inplace* primitive classes
# (inplace_copy_op .. inplace_scatter_src_reduce_op).  Each class is
# instantiated without constructor arguments, pinned to the Ascend backend via
# ``set_device('Ascend')`` and bound to ``<name>_op`` at import time.

inplace_copy_op = InplaceCopy().set_device('Ascend')

inplace_div_op = InplaceDiv().set_device('Ascend')

inplace_divmod_op = InplaceDivMod().set_device('Ascend')

inplace_divmods_op = InplaceDivMods().set_device('Ascend')

inplace_divs_op = InplaceDivs().set_device('Ascend')

inplace_elu_op = InplaceElu().set_device('Ascend')

inplace_erfinv_op = InplaceErfinv().set_device('Ascend')

inplace_exp_op = InplaceExp().set_device('Ascend')

inplace_exponential_op = InplaceExponential().set_device('Ascend')

inplace_fill_diagonal_op = InplaceFillDiagonal().set_device('Ascend')

inplace_fill_scalar_op = InplaceFillScalar().set_device('Ascend')

inplace_fill_tensor_op = InplaceFillTensor().set_device('Ascend')

inplace_floor_op = InplaceFloor().set_device('Ascend')

inplace_floor_divide_op = InplaceFloorDivide().set_device('Ascend')

inplace_floor_divides_op = InplaceFloorDivides().set_device('Ascend')

inplace_grouped_matmul_add_op = InplaceGroupedMatmulAdd().set_device('Ascend')

inplace_hardtanh_op = InplaceHardtanh().set_device('Ascend')

# NOTE: the variable name omits the ``Ext`` suffix of the class
# (InplaceIndexAddExt).
inplace_index_add_op = InplaceIndexAddExt().set_device('Ascend')

inplace_index_put_op = InplaceIndexPut().set_device('Ascend')

inplace_log_op = InplaceLog().set_device('Ascend')

inplace_masked_fill_scalar_op = InplaceMaskedFillScalar().set_device('Ascend')

inplace_masked_fill_tensor_op = InplaceMaskedFillTensor().set_device('Ascend')

inplace_mul_op = InplaceMul().set_device('Ascend')

inplace_muls_op = InplaceMuls().set_device('Ascend')

inplace_normal_op = InplaceNormal().set_device('Ascend')

inplace_put_op = InplacePut().set_device('Ascend')

inplace_random_op = InplaceRandom().set_device('Ascend')

inplace_relu_op = InplaceReLU().set_device('Ascend')

inplace_scatter_add_op = InplaceScatterAdd().set_device('Ascend')

inplace_scatter_src_op = InplaceScatterSrc().set_device('Ascend')

inplace_scatter_src_reduce_op = InplaceScatterSrcReduce().set_device('Ascend')

# Module-level primitive instances
# (inplace_scatter_value_op .. moe_token_permute_grad_op).  Each class is
# instantiated without constructor arguments, pinned to the Ascend backend via
# ``set_device('Ascend')`` and bound to ``<name>_op`` at import time.

inplace_scatter_value_op = InplaceScatterValue().set_device('Ascend')

inplace_scatter_value_reduce_op = InplaceScatterValueReduce().set_device('Ascend')

inplace_stop_gradient_op = InplaceStopGradient().set_device('Ascend')

inplace_sub_ext_op = InplaceSubExt().set_device('Ascend')

inplace_sub_scalar_op = InplaceSubScalar().set_device('Ascend')

inplace_tanh_op = InplaceTanh().set_device('Ascend')

inplace_threshold_op = InplaceThreshold().set_device('Ascend')

inplace_uniform_op = InplaceUniform().set_device('Ascend')

inplace_zero_op = InplaceZero().set_device('Ascend')

isfinite_op = IsFinite().set_device('Ascend')

isinf_op = IsInf().set_device('Ascend')

isneginf_op = IsNegInf().set_device('Ascend')

kl_div_op = KLDiv().set_device('Ascend')

kl_div_grad_op = KLDivGrad().set_device('Ascend')

kthvalue_op = Kthvalue().set_device('Ascend')

kv_cache_scatter_update_op = KVCacheScatterUpdate().set_device('Ascend')

l1_loss_backward_ext_op = L1LossBackwardExt().set_device('Ascend')

l1_loss_ext_op = L1LossExt().set_device('Ascend')

layer_norm_ext_op = LayerNormExt().set_device('Ascend')

layer_norm_grad_ext_op = LayerNormGradExt().set_device('Ascend')

leaky_relu_ext_op = LeakyReLUExt().set_device('Ascend')

leaky_relu_grad_ext_op = LeakyReLUGradExt().set_device('Ascend')

lerp_op = Lerp().set_device('Ascend')

lerp_scalar_op = LerpScalar().set_device('Ascend')

less_op = Less().set_device('Ascend')

less_equal_op = LessEqual().set_device('Ascend')

lin_space_ext_op = LinSpaceExt().set_device('Ascend')

linalg_qr_op = LinalgQr().set_device('Ascend')

linalg_vector_norm_op = LinalgVectorNorm().set_device('Ascend')

log_op = Log().set_device('Ascend')

log10_op = Log10().set_device('Ascend')

log1p_op = Log1p().set_device('Ascend')

log2_op = Log2().set_device('Ascend')

log_softmax_ext_op = LogSoftmaxExt().set_device('Ascend')

logaddexp_op = LogAddExp().set_device('Ascend')

logaddexp2_op = LogAddExp2().set_device('Ascend')

logical_and_op = LogicalAnd().set_device('Ascend')

logical_not_op = LogicalNot().set_device('Ascend')

logical_or_op = LogicalOr().set_device('Ascend')

logical_xor_op = LogicalXor().set_device('Ascend')

logsigmoid_op = LogSigmoid().set_device('Ascend')

logsigmoid_grad_op = LogSigmoidGrad().set_device('Ascend')

logsumexp_op = LogSumExp().set_device('Ascend')

masked_fill_op = MaskedFill().set_device('Ascend')

masked_select_op = MaskedSelect().set_device('Ascend')

masked_select_grad_op = MaskedSelectGrad().set_device('Ascend')

matmul_allreduce_add_rmsnorm_op = MatmulAllReduceAddRmsNorm().set_device('Ascend')

matmul_ext_op = MatMulExt().set_device('Ascend')

matmul_reduce_scatter_op = MatmulReduceScatter().set_device('Ascend')

matrix_inverse_ext_op = MatrixInverseExt().set_device('Ascend')

max_op = Max().set_device('Ascend')

max_dim_op = MaxDim().set_device('Ascend')

max_unpool2d_ext_op = MaxUnpool2DExt().set_device('Ascend')

maximum_op = Maximum().set_device('Ascend')

mean_ext_op = MeanExt().set_device('Ascend')

median_dim_op = MedianDim().set_device('Ascend')

median_ext_op = MedianExt().set_device('Ascend')

min_op = Min().set_device('Ascend')

min_dim_op = MinDim().set_device('Ascend')

minimum_op = Minimum().set_device('Ascend')

mish_ext_op = MishExt().set_device('Ascend')

mish_grad_ext_op = MishGradExt().set_device('Ascend')

# NOTE: the variable carries an ``_ext`` suffix that the class name (Mm)
# does not have.
mm_ext_op = Mm().set_device('Ascend')

moe_compute_expert_tokens_op = MoeComputeExpertTokens().set_device('Ascend')

moe_finalize_routing_op = MoeFinalizeRouting().set_device('Ascend')

moe_gating_top_k_softmax_op = MoeGatingTopKSoftmax().set_device('Ascend')

moe_init_routing_op = MoeInitRouting().set_device('Ascend')

moe_init_routing_v2_op = MoeInitRoutingV2().set_device('Ascend')

moe_token_permute_op = MoeTokenPermute().set_device('Ascend')

moe_token_permute_grad_op = MoeTokenPermuteGrad().set_device('Ascend')

# Module-level primitive instances (MoeTokenUnpermute .. ProdExt).
# Each name is bound once, at import time, to a primitive built with its
# no-argument constructor and then passed through set_device('Ascend').
moe_token_unpermute_op = MoeTokenUnpermute().set_device('Ascend')
moe_token_unpermute_grad_op = MoeTokenUnpermuteGrad().set_device('Ascend')
mse_loss_ext_op = MSELossExt().set_device('Ascend')
mse_loss_grad_ext_op = MSELossGradExt().set_device('Ascend')
mul_op = Mul().set_device('Ascend')
muls_op = Muls().set_device('Ascend')
multi_scale_deformable_attn_op = MultiScaleDeformableAttn().set_device('Ascend')
multi_scale_deformable_attn_grad_op = MultiScaleDeformableAttnGrad().set_device('Ascend')
multinomial_ext_op = MultinomialExt().set_device('Ascend')
mv_op = Mv().set_device('Ascend')

nansum_op = Nansum().set_device('Ascend')
narrow_op = Narrow().set_device('Ascend')
narrow_view_op = NarrowView().set_device('Ascend')
neg_op = Neg().set_device('Ascend')
new_ones_op = NewOnes().set_device('Ascend')
new_zeros_op = NewZeros().set_device('Ascend')
nllloss_2d_op = NLLLoss2d().set_device('Ascend')
nllloss_2d_grad_op = NLLLoss2dGrad().set_device('Ascend')
non_zero_op = NonZero().set_device('Ascend')
non_zero_ext_op = NonZeroExt().set_device('Ascend')
norm_op = Norm().set_device('Ascend')
normal_float_float_op = NormalFloatFloat().set_device('Ascend')
normal_float_tensor_op = NormalFloatTensor().set_device('Ascend')
normal_tensor_float_op = NormalTensorFloat().set_device('Ascend')
normal_tensor_tensor_op = NormalTensorTensor().set_device('Ascend')
not_equal_op = NotEqual().set_device('Ascend')

ones_like_ext_op = OnesLikeExt().set_device('Ascend')
outer_op = Outer().set_device('Ascend')

pixel_shuffle_op = PixelShuffle().set_device('Ascend')
polar_op = Polar().set_device('Ascend')
pow_op = Pow().set_device('Ascend')
pow_scalar_tensor_op = PowScalarTensor().set_device('Ascend')
pow_tensor_scalar_op = PowTensorScalar().set_device('Ascend')
prelu_op = PReLU().set_device('Ascend')
prelu_grad_op = PReLUGrad().set_device('Ascend')
prod_ext_op = ProdExt().set_device('Ascend')
# Module-level primitive instances (QuantV2 .. UpsampleNearest1DGrad).
# Each name is bound once, at import time, to a primitive built with its
# no-argument constructor and then passed through set_device('Ascend').
quant_v2_op = QuantV2().set_device('Ascend')

rand_ext_op = RandExt().set_device('Ascend')
rand_like_ext_op = RandLikeExt().set_device('Ascend')
randint_op = RandInt().set_device('Ascend')
randint_like_op = RandIntLike().set_device('Ascend')
randn_op = Randn().set_device('Ascend')
randn_like_op = RandnLike().set_device('Ascend')
randperm_ext_op = RandpermExt().set_device('Ascend')
reciprocal_op = Reciprocal().set_device('Ascend')
reflection_pad_1d_op = ReflectionPad1D().set_device('Ascend')
reflection_pad_1d_grad_op = ReflectionPad1DGrad().set_device('Ascend')
reflection_pad_2d_op = ReflectionPad2D().set_device('Ascend')
reflection_pad_2d_grad_op = ReflectionPad2DGrad().set_device('Ascend')
reflection_pad_3d_op = ReflectionPad3D().set_device('Ascend')
reflection_pad_3d_grad_op = ReflectionPad3DGrad().set_device('Ascend')
relu_op = ReLU().set_device('Ascend')
relu_grad_op = ReluGrad().set_device('Ascend')
remainder_scalar_tensor_op = RemainderScalarTensor().set_device('Ascend')
remainder_tensor_scalar_op = RemainderTensorScalar().set_device('Ascend')
remainder_tensor_tensor_op = RemainderTensorTensor().set_device('Ascend')
repeat_op = Repeat().set_device('Ascend')
repeat_interleave_grad_op = RepeatInterleaveGrad().set_device('Ascend')
repeat_interleave_int_op = RepeatInterleaveInt().set_device('Ascend')
repeat_interleave_tensor_op = RepeatInterleaveTensor().set_device('Ascend')
replication_pad_1d_op = ReplicationPad1D().set_device('Ascend')
replication_pad_1d_grad_op = ReplicationPad1DGrad().set_device('Ascend')
replication_pad_2d_op = ReplicationPad2D().set_device('Ascend')
replication_pad_2d_grad_op = ReplicationPad2DGrad().set_device('Ascend')
replication_pad_3d_op = ReplicationPad3D().set_device('Ascend')
replication_pad_3d_grad_op = ReplicationPad3DGrad().set_device('Ascend')
reshape_op = Reshape().set_device('Ascend')
rms_norm_grad_op = RmsNormGrad().set_device('Ascend')
rotary_position_embedding_op = RotaryPositionEmbedding().set_device('Ascend')
rotary_position_embedding_grad_op = RotaryPositionEmbeddingGrad().set_device('Ascend')
round_op = Round().set_device('Ascend')
rsqrt_op = Rsqrt().set_device('Ascend')

scatter_op = Scatter().set_device('Ascend')
scatter_add_ext_op = ScatterAddExt().set_device('Ascend')
scatter_value_op = ScatterValue().set_device('Ascend')
select_op = Select().set_device('Ascend')
select_ext_view_op = SelectExtView().set_device('Ascend')
select_v2_op = SelectV2().set_device('Ascend')
selu_ext_op = SeLUExt().set_device('Ascend')
selu_grad_op = SeluGrad().set_device('Ascend')
sigmoid_op = Sigmoid().set_device('Ascend')
sigmoid_grad_op = SigmoidGrad().set_device('Ascend')
sign_op = Sign().set_device('Ascend')
silent_check_v2_op = SilentCheckV2().set_device('Ascend')
silent_check_v3_op = SilentCheckV3().set_device('Ascend')
silu_op = SiLU().set_device('Ascend')
silu_grad_op = SiLUGrad().set_device('Ascend')
sin_op = Sin().set_device('Ascend')
sinc_op = Sinc().set_device('Ascend')
sinh_op = Sinh().set_device('Ascend')
slice_op = Slice().set_device('Ascend')
slice_ext_op = SliceExt().set_device('Ascend')
slice_ext_view_op = SliceExtView().set_device('Ascend')
softmax_backward_op = SoftmaxBackward().set_device('Ascend')
softplus_ext_op = SoftplusExt().set_device('Ascend')
softplus_grad_ext_op = SoftplusGradExt().set_device('Ascend')
sort_ext_op = SortExt().set_device('Ascend')
speed_fusion_attention_op = SpeedFusionAttention().set_device('Ascend')
speed_fusion_attention_grad_op = SpeedFusionAttentionGrad().set_device('Ascend')
split_tensor_op = SplitTensor().set_device('Ascend')
split_tensor_view_op = SplitTensorView().set_device('Ascend')
split_with_size_op = SplitWithSize().set_device('Ascend')
split_with_size_view_op = SplitWithSizeView().set_device('Ascend')
sqrt_op = Sqrt().set_device('Ascend')
square_op = Square().set_device('Ascend')
std_op = Std().set_device('Ascend')
std_mean_op = StdMean().set_device('Ascend')
sub_op = Sub().set_device('Ascend')
sub_ext_op = SubExt().set_device('Ascend')
sub_scalar_op = SubScalar().set_device('Ascend')
sum_ext_op = SumExt().set_device('Ascend')
swiglu_op = Swiglu().set_device('Ascend')
swiglu_grad_op = SwigluGrad().set_device('Ascend')

t_ext_op = TExt().set_device('Ascend')
take_op = Take().set_device('Ascend')
tan_op = Tan().set_device('Ascend')
tanh_op = Tanh().set_device('Ascend')
tanh_grad_op = TanhGrad().set_device('Ascend')
threshold_op = Threshold().set_device('Ascend')
threshold_grad_op = ThresholdGrad().set_device('Ascend')
topk_ext_op = TopkExt().set_device('Ascend')
trace_ext_op = TraceExt().set_device('Ascend')
transpose_op = Transpose().set_device('Ascend')
transpose_ext_view_op = TransposeExtView().set_device('Ascend')
transpose_view_op = TransposeView().set_device('Ascend')
triangular_solve_op = TriangularSolve().set_device('Ascend')
tril_ext_op = TrilExt().set_device('Ascend')
trunc_op = Trunc().set_device('Ascend')

uniform_ext_op = UniformExt().set_device('Ascend')
unique2_op = Unique2().set_device('Ascend')
unique_dim_op = UniqueDim().set_device('Ascend')
unstack_ext_view_op = UnstackExtView().set_device('Ascend')
upsample_bicubic2d_op = UpsampleBicubic2D().set_device('Ascend')
upsample_bicubic2d_grad_op = UpsampleBicubic2DGrad().set_device('Ascend')
upsample_bilinear2d_op = UpsampleBilinear2D().set_device('Ascend')
upsample_bilinear2d_grad_op = UpsampleBilinear2DGrad().set_device('Ascend')
upsample_linear1d_op = UpsampleLinear1D().set_device('Ascend')
upsample_linear1d_grad_op = UpsampleLinear1DGrad().set_device('Ascend')
upsample_nearest1d_op = UpsampleNearest1D().set_device('Ascend')
upsample_nearest1d_grad_op = UpsampleNearest1DGrad().set_device('Ascend')
+upsample_nearest2d_op = UpsampleNearest2D().set_device('Ascend') + +upsample_nearest2d_grad_op = UpsampleNearest2DGrad().set_device('Ascend') + +upsample_nearest3d_op = UpsampleNearest3D().set_device('Ascend') + +upsample_nearest3d_grad_op = UpsampleNearest3DGrad().set_device('Ascend') + +var_op = Var().set_device('Ascend') + +var_mean_op = VarMean().set_device('Ascend') + +view_as_op = ViewAs().set_device('Ascend') + +xlogy_op = Xlogy().set_device('Ascend') + +xlogy_scalar_other_op = XLogYScalarOther().set_device('Ascend') + +xlogy_scalar_self_op = XLogYScalarSelf().set_device('Ascend') + +zeros_like_ext_op = ZerosLikeExt().set_device('Ascend') + diff --git a/torch4ms/_op_prim/cpu/__init__.py b/torch4ms/_op_prim/cpu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torch4ms/_op_prim/cpu/legacy.py b/torch4ms/_op_prim/cpu/legacy.py new file mode 100644 index 000000000..6ac9862d3 --- /dev/null +++ b/torch4ms/_op_prim/cpu/legacy.py @@ -0,0 +1,3511 @@ +from mindspore.ops.operations import * +from mindspore.ops.operations._grad_ops import * +from mindspore.ops.operations._inner_ops import * +from mindspore.ops._primitive_cache import _get_cache_prim + + +a_cos_grad_op = ACosGrad().set_device('CPU') +def a_cos_grad(*args): + return a_cos_grad_op(*args) + + +abs_grad_op = AbsGrad().set_device('CPU') +def abs_grad(*args): + return abs_grad_op(*args) + + +acosh_grad_op = AcoshGrad().set_device('CPU') +def acosh_grad(*args): + return acosh_grad_op(*args) + + +adaptive_avg_pool2_d_grad_op = AdaptiveAvgPool2DGrad().set_device('CPU') +def adaptive_avg_pool2_d_grad(*args): + return adaptive_avg_pool2_d_grad_op(*args) + + +adaptive_avg_pool3_d_grad_op = AdaptiveAvgPool3DGrad().set_device('CPU') +def adaptive_avg_pool3_d_grad(*args): + return adaptive_avg_pool3_d_grad_op(*args) + + +adaptive_max_pool2_d_grad_op = AdaptiveMaxPool2DGrad().set_device('CPU') +def adaptive_max_pool2_d_grad(*args): + return adaptive_max_pool2_d_grad_op(*args) + + 
# CPU wrappers (AdaptiveMaxPool3DGrad .. AdjustSaturation).
#
# Two wrapper shapes appear below:
#   1. `<name>_op = Prim().set_device('CPU')` followed by `def <name>(*args)`,
#      which forwards every positional argument to that module-level instance.
#   2. `def <name>(*args)` that passes its trailing N positional arguments to
#      the constructor obtained through `_get_cache_prim(Prim)`, applies
#      set_device('CPU'), and calls the result with the remaining leading
#      arguments. N is the slice width written in each function.
adaptive_max_pool3_d_grad_op = AdaptiveMaxPool3DGrad().set_device('CPU')
def adaptive_max_pool3_d_grad(*args):
    return adaptive_max_pool3_d_grad_op(*args)


def affine_grid_grad(*args):
    prim = _get_cache_prim(AffineGridGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


asin_grad_op = AsinGrad().set_device('CPU')
def asin_grad(*args):
    return asin_grad_op(*args)


asinh_grad_op = AsinhGrad().set_device('CPU')
def asinh_grad(*args):
    return asinh_grad_op(*args)


atan_grad_op = AtanGrad().set_device('CPU')
def atan_grad(*args):
    return atan_grad_op(*args)


def avg_pool3_d_grad(*args):
    prim = _get_cache_prim(AvgPool3DGrad)(*args[-8:]).set_device('CPU')
    return prim(*args[:-8])


def avg_pool_grad(*args):
    prim = _get_cache_prim(AvgPoolGrad)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def avg_pool_grad_ge(*args):
    prim = _get_cache_prim(AvgPoolGradGe)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def avg_pool_grad_v1(*args):
    prim = _get_cache_prim(AvgPoolGradV1)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def avg_pool_grad_vm(*args):
    prim = _get_cache_prim(AvgPoolGradVm)(*args[-3:]).set_device('CPU')
    return prim(*args[:-3])


def bn_training_reduce_grad(*args):
    prim = _get_cache_prim(BNTrainingReduceGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def bn_training_update_grad(*args):
    prim = _get_cache_prim(BNTrainingUpdateGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def basic_lstm_cell_c_state_grad(*args):
    prim = _get_cache_prim(BasicLSTMCellCStateGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def basic_lstm_cell_input_grad(*args):
    prim = _get_cache_prim(BasicLSTMCellInputGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


basic_lstm_cell_weight_grad_op = BasicLSTMCellWeightGrad().set_device('CPU')
def basic_lstm_cell_weight_grad(*args):
    return basic_lstm_cell_weight_grad_op(*args)


def batch_norm_grad(*args):
    prim = _get_cache_prim(BatchNormGrad)(*args[-3:]).set_device('CPU')
    return prim(*args[:-3])


def batch_norm_grad_grad(*args):
    prim = _get_cache_prim(BatchNormGradGrad)(*args[-3:]).set_device('CPU')
    return prim(*args[:-3])


def bias_add_grad(*args):
    prim = _get_cache_prim(BiasAddGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def binary_cross_entropy_grad(*args):
    prim = _get_cache_prim(BinaryCrossEntropyGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


cholesky_grad_op = CholeskyGrad().set_device('CPU')
def cholesky_grad(*args):
    return cholesky_grad_op(*args)


def concat_offset(*args):
    prim = _get_cache_prim(ConcatOffset)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def conv2_d_backprop_filter(*args):
    prim = _get_cache_prim(Conv2DBackpropFilter)(*args[-10:]).set_device('CPU')
    return prim(*args[:-10])


def conv3_d_backprop_filter(*args):
    prim = _get_cache_prim(Conv3DBackpropFilter)(*args[-9:]).set_device('CPU')
    return prim(*args[:-9])


def deformable_offsets_grad(*args):
    prim = _get_cache_prim(DeformableOffsetsGrad)(*args[-7:]).set_device('CPU')
    return prim(*args[:-7])


def depthwise_conv2d_native_backprop_filter(*args):
    prim = _get_cache_prim(DepthwiseConv2dNativeBackpropFilter)(*args[-9:]).set_device('CPU')
    return prim(*args[:-9])


def depthwise_conv2d_native_backprop_input(*args):
    prim = _get_cache_prim(DepthwiseConv2dNativeBackpropInput)(*args[-9:]).set_device('CPU')
    return prim(*args[:-9])


def dilation2_d_backprop_filter(*args):
    prim = _get_cache_prim(Dilation2DBackpropFilter)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def dilation2_d_backprop_input(*args):
    prim = _get_cache_prim(Dilation2DBackpropInput)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def dropout_grad(*args):
    prim = _get_cache_prim(DropoutGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def dynamic_gruv2_grad(*args):
    prim = _get_cache_prim(DynamicGRUV2Grad)(*args[-8:]).set_device('CPU')
    return prim(*args[:-8])


def dynamic_rnn_grad(*args):
    prim = _get_cache_prim(DynamicRNNGrad)(*args[-9:]).set_device('CPU')
    return prim(*args[:-9])


def einsum_grad(*args):
    prim = _get_cache_prim(EinsumGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


elu_grad_op = EluGrad().set_device('CPU')
def elu_grad(*args):
    return elu_grad_op(*args)


embedding_lookup_comm_grad_op = EmbeddingLookupCommGrad().set_device('CPU')
def embedding_lookup_comm_grad(*args):
    return embedding_lookup_comm_grad_op(*args)


fast_ge_lu_grad_op = FastGeLUGrad().set_device('CPU')
def fast_ge_lu_grad(*args):
    return fast_ge_lu_grad_op(*args)


def flash_attention_score_grad(*args):
    prim = _get_cache_prim(FlashAttentionScoreGrad)(*args[-8:]).set_device('CPU')
    return prim(*args[:-8])


flatten_grad_op = FlattenGrad().set_device('CPU')
def flatten_grad(*args):
    return flatten_grad_op(*args)


def fractional_avg_pool_grad(*args):
    prim = _get_cache_prim(FractionalAvgPoolGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def fractional_max_pool3_d_grad_with_fixed_ksize(*args):
    prim = _get_cache_prim(FractionalMaxPool3DGradWithFixedKsize)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def fractional_max_pool_grad(*args):
    prim = _get_cache_prim(FractionalMaxPoolGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def fractional_max_pool_grad_with_fixed_ksize(*args):
    prim = _get_cache_prim(FractionalMaxPoolGradWithFixedKsize)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def gruv2_grad(*args):
    prim = _get_cache_prim(GRUV2Grad)(*args[-6:]).set_device('CPU')
    return prim(*args[:-6])


gather_d_grad_v2_op = GatherDGradV2().set_device('CPU')
def gather_d_grad_v2(*args):
    return gather_d_grad_v2_op(*args)


ge_lu_grad_op = GeLUGrad().set_device('CPU')
def ge_lu_grad(*args):
    return ge_lu_grad_op(*args)


# NOTE(review): GlobalComm does not look like an operator primitive; confirm
# that this generated wrapper is meant to be callable.
def global_comm(*args):
    prim = _get_cache_prim(GlobalComm)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def glu_grad(*args):
    prim = _get_cache_prim(GluGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def grid_sampler2_d_grad(*args):
    prim = _get_cache_prim(GridSampler2DGrad)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def grid_sampler3_d_grad(*args):
    prim = _get_cache_prim(GridSampler3DGrad)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def gru_grad_data(*args):
    prim = _get_cache_prim(GruGradData)(*args[-6:]).set_device('CPU')
    return prim(*args[:-6])


def gru_grad_weight(*args):
    prim = _get_cache_prim(GruGradWeight)(*args[-6:]).set_device('CPU')
    return prim(*args[:-6])


def h_shrink_grad(*args):
    prim = _get_cache_prim(HShrinkGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


h_sigmoid_grad_op = HSigmoidGrad().set_device('CPU')
def h_sigmoid_grad(*args):
    return h_sigmoid_grad_op(*args)


h_swish_grad_op = HSwishGrad().set_device('CPU')
def h_swish_grad(*args):
    return h_swish_grad_op(*args)


igamma_grad_a_op = IgammaGradA().set_device('CPU')
def igamma_grad_a(*args):
    return igamma_grad_a_op(*args)


def instance_norm_grad(*args):
    prim = _get_cache_prim(InstanceNormGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def instance_norm_v2_grad(*args):
    prim = _get_cache_prim(InstanceNormV2Grad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


inv_grad_op = InvGrad().set_device('CPU')
def inv_grad(*args):
    return inv_grad_op(*args)


def kl_div_loss_grad(*args):
    prim = _get_cache_prim(KLDivLossGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def l2_normalize_grad(*args):
    prim = _get_cache_prim(L2NormalizeGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def lrn_grad(*args):
    prim = _get_cache_prim(LRNGrad)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def lstm_grad(*args):
    prim = _get_cache_prim(LSTMGrad)(*args[-7:]).set_device('CPU')
    return prim(*args[:-7])


def lstm_grad_data(*args):
    prim = _get_cache_prim(LSTMGradData)(*args[-6:]).set_device('CPU')
    return prim(*args[:-6])


def lstm_grad_weight(*args):
    prim = _get_cache_prim(LSTMGradWeight)(*args[-6:]).set_device('CPU')
    return prim(*args[:-6])


def layer_norm_grad(*args):
    prim = _get_cache_prim(LayerNormGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def layer_norm_grad_grad(*args):
    prim = _get_cache_prim(LayerNormGradGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def log_softmax_grad(*args):
    prim = _get_cache_prim(LogSoftmaxGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def logit_grad(*args):
    prim = _get_cache_prim(LogitGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def lu_unpack_grad(*args):
    prim = _get_cache_prim(LuUnpackGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


map_tensor_get_grad_op = MapTensorGetGrad().set_device('CPU')
def map_tensor_get_grad(*args):
    return map_tensor_get_grad_op(*args)


masked_select_grad_op = MaskedSelectGrad().set_device('CPU')
def masked_select_grad(*args):
    return masked_select_grad_op(*args)


def max_pool3_d_grad(*args):
    prim = _get_cache_prim(MaxPool3DGrad)(*args[-5:]).set_device('CPU')
    return prim(*args[:-5])


def max_pool3_d_grad_grad(*args):
    prim = _get_cache_prim(MaxPool3DGradGrad)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def max_pool3_d_grad_with_argmax(*args):
    prim = _get_cache_prim(MaxPool3DGradWithArgmax)(*args[-6:]).set_device('CPU')
    return prim(*args[:-6])


def max_pool_grad(*args):
    prim = _get_cache_prim(MaxPoolGrad)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def max_pool_grad_grad(*args):
    prim = _get_cache_prim(MaxPoolGradGrad)(*args[-3:]).set_device('CPU')
    return prim(*args[:-3])


def max_pool_grad_grad_with_argmax(*args):
    prim = _get_cache_prim(MaxPoolGradGradWithArgmax)(*args[-3:]).set_device('CPU')
    return prim(*args[:-3])


def max_pool_grad_v1(*args):
    prim = _get_cache_prim(MaxPoolGradV1)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def max_pool_grad_with_argmax(*args):
    prim = _get_cache_prim(MaxPoolGradWithArgmax)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def max_pool_grad_with_argmax_v2(*args):
    prim = _get_cache_prim(MaxPoolGradWithArgmaxV2)(*args[-6:]).set_device('CPU')
    return prim(*args[:-6])


def max_unpool2_d_grad(*args):
    prim = _get_cache_prim(MaxUnpool2DGrad)(*args[-5:]).set_device('CPU')
    return prim(*args[:-5])


def max_unpool3_d_grad(*args):
    prim = _get_cache_prim(MaxUnpool3DGrad)(*args[-5:]).set_device('CPU')
    return prim(*args[:-5])


def maximum_grad(*args):
    prim = _get_cache_prim(MaximumGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def maximum_grad_grad(*args):
    prim = _get_cache_prim(MaximumGradGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def median_grad(*args):
    prim = _get_cache_prim(MedianGrad)(*args[-3:]).set_device('CPU')
    return prim(*args[:-3])


def minimum_grad(*args):
    prim = _get_cache_prim(MinimumGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


minimum_grad_grad_op = MinimumGradGrad().set_device('CPU')
def minimum_grad_grad(*args):
    return minimum_grad_grad_op(*args)


def mirror_pad_grad(*args):
    prim = _get_cache_prim(MirrorPadGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def multi_margin_loss_grad(*args):
    prim = _get_cache_prim(MultiMarginLossGrad)(*args[-3:]).set_device('CPU')
    return prim(*args[:-3])


def multilabel_margin_loss_grad(*args):
    prim = _get_cache_prim(MultilabelMarginLossGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def mvlgamma_grad(*args):
    prim = _get_cache_prim(MvlgammaGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def nll_loss_grad(*args):
    prim = _get_cache_prim(NLLLossGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def neighbor_exchange_v2_grad(*args):
    prim = _get_cache_prim(NeighborExchangeV2Grad)(*args[-6:]).set_device('CPU')
    return prim(*args[:-6])


p_re_lu_grad_op = PReLUGrad().set_device('CPU')
def p_re_lu_grad(*args):
    return p_re_lu_grad_op(*args)


def psroi_pooling_grad(*args):
    prim = _get_cache_prim(PSROIPoolingGrad)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def pad_v3_grad(*args):
    prim = _get_cache_prim(PadV3Grad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def parallel_resize_bilinear_grad(*args):
    prim = _get_cache_prim(ParallelResizeBilinearGrad)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


def pdist_grad(*args):
    prim = _get_cache_prim(PdistGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


# NOTE(review): Primitive and PrimitiveWithInfer look like base classes rather
# than concrete operators; confirm these two generated wrappers are intended.
def primitive(*args):
    prim = _get_cache_prim(Primitive)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def primitive_with_infer(*args):
    prim = _get_cache_prim(PrimitiveWithInfer)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def ps_roi_pooling_grad(*args):
    prim = _get_cache_prim(PsROIPoolingGrad)(*args[-9:]).set_device('CPU')
    return prim(*args[:-9])


def roi_align_grad(*args):
    prim = _get_cache_prim(ROIAlignGrad)(*args[-4:]).set_device('CPU')
    return prim(*args[:-4])


random_gamma_grad_op = RandomGammaGrad().set_device('CPU')
def random_gamma_grad(*args):
    return random_gamma_grad_op(*args)


re_lu6_grad_op = ReLU6Grad().set_device('CPU')
def re_lu6_grad(*args):
    return re_lu6_grad_op(*args)


reciprocal_grad_op = ReciprocalGrad().set_device('CPU')
def reciprocal_grad(*args):
    return reciprocal_grad_op(*args)


ref_to_embed_op = RefToEmbed().set_device('CPU')
def ref_to_embed(*args):
    return ref_to_embed_op(*args)


relu_grad_op = ReluGrad().set_device('CPU')
def relu_grad(*args):
    return relu_grad_op(*args)


def resize_bicubic_grad(*args):
    prim = _get_cache_prim(ResizeBicubicGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def resize_bilinear_grad(*args):
    prim = _get_cache_prim(ResizeBilinearGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def resize_linear1_d_grad(*args):
    prim = _get_cache_prim(ResizeLinear1DGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def resize_nearest_neighbor_grad(*args):
    prim = _get_cache_prim(ResizeNearestNeighborGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def resize_nearest_neighbor_v2_grad(*args):
    prim = _get_cache_prim(ResizeNearestNeighborV2Grad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def resize_v2_grad(*args):
    prim = _get_cache_prim(ResizeV2Grad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


rms_norm_grad_op = RmsNormGrad().set_device('CPU')
def rms_norm_grad(*args):
    return rms_norm_grad_op(*args)


rsqrt_grad_op = RsqrtGrad().set_device('CPU')
def rsqrt_grad(*args):
    return rsqrt_grad_op(*args)


def scale_and_translate_grad(*args):
    prim = _get_cache_prim(ScaleAndTranslateGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


selu_grad_op = SeluGrad().set_device('CPU')
def selu_grad(*args):
    return selu_grad_op(*args)


si_lu_grad_op = SiLUGrad().set_device('CPU')
def si_lu_grad(*args):
    return si_lu_grad_op(*args)


sigmoid_cross_entropy_with_logits_grad_op = SigmoidCrossEntropyWithLogitsGrad().set_device('CPU')
def sigmoid_cross_entropy_with_logits_grad(*args):
    return sigmoid_cross_entropy_with_logits_grad_op(*args)


sigmoid_grad_op = SigmoidGrad().set_device('CPU')
def sigmoid_grad(*args):
    return sigmoid_grad_op(*args)


slice_grad_op = SliceGrad().set_device('CPU')
def slice_grad(*args):
    return slice_grad_op(*args)


def smooth_l1_loss_grad(*args):
    prim = _get_cache_prim(SmoothL1LossGrad)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def soft_margin_loss_grad(*args):
    prim = _get_cache_prim(SoftMarginLossGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def soft_shrink_grad(*args):
    prim = _get_cache_prim(SoftShrinkGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


softmax_grad_op = SoftmaxGrad().set_device('CPU')
def softmax_grad(*args):
    return softmax_grad_op(*args)


softplus_grad_op = SoftplusGrad().set_device('CPU')
def softplus_grad(*args):
    return softplus_grad_op(*args)


sparse_fill_empty_rows_grad_op = SparseFillEmptyRowsGrad().set_device('CPU')
def sparse_fill_empty_rows_grad(*args):
    return sparse_fill_empty_rows_grad_op(*args)


sparse_segment_mean_grad_op = SparseSegmentMeanGrad().set_device('CPU')
def sparse_segment_mean_grad(*args):
    return sparse_segment_mean_grad_op(*args)


sparse_segment_sqrt_n_grad_op = SparseSegmentSqrtNGrad().set_device('CPU')
def sparse_segment_sqrt_n_grad(*args):
    return sparse_segment_sqrt_n_grad_op(*args)


sparse_segment_sum_grad_op = SparseSegmentSumGrad().set_device('CPU')
def sparse_segment_sum_grad(*args):
    return sparse_segment_sum_grad_op(*args)


sparse_slice_grad_op = SparseSliceGrad().set_device('CPU')
def sparse_slice_grad(*args):
    return sparse_slice_grad_op(*args)


sqrt_grad_op = SqrtGrad().set_device('CPU')
def sqrt_grad(*args):
    return sqrt_grad_op(*args)


def strided_slice_grad(*args):
    prim = _get_cache_prim(StridedSliceGrad)(*args[-5:]).set_device('CPU')
    return prim(*args[:-5])


def sync_batch_norm_grad(*args):
    prim = _get_cache_prim(SyncBatchNormGrad)(*args[-3:]).set_device('CPU')
    return prim(*args[:-3])


tanh_grad_op = TanhGrad().set_device('CPU')
def tanh_grad(*args):
    return tanh_grad_op(*args)


trace_grad_op = TraceGrad().set_device('CPU')
def trace_grad(*args):
    return trace_grad_op(*args)


unique_grad_op = UniqueGrad().set_device('CPU')
def unique_grad(*args):
    return unique_grad_op(*args)


upsample_nearest3_d_grad_op = UpsampleNearest3DGrad().set_device('CPU')
def upsample_nearest3_d_grad(*args):
    return upsample_nearest3_d_grad_op(*args)


def upsample_trilinear3_d_grad(*args):
    prim = _get_cache_prim(UpsampleTrilinear3DGrad)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


wkv_grad_op = WKVGrad().set_device('CPU')
def wkv_grad(*args):
    return wkv_grad_op(*args)


a_cos_op = ACos().set_device('CPU')
def a_cos(*args):
    return a_cos_op(*args)


# NOTE: from this point on, `abs` in this module refers to the wrapper below
# rather than the builtin.
abs_op = Abs().set_device('CPU')
def abs(*args):
    return abs_op(*args)


accumulate_nv2_op = AccumulateNV2().set_device('CPU')
def accumulate_nv2(*args):
    return accumulate_nv2_op(*args)


acosh_op = Acosh().set_device('CPU')
def acosh(*args):
    return acosh_op(*args)


def adam(*args):
    prim = _get_cache_prim(Adam)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def adam_no_update_param(*args):
    prim = _get_cache_prim(AdamNoUpdateParam)(*args[-2:]).set_device('CPU')
    return prim(*args[:-2])


def adam_weight_decay(*args):
    prim = _get_cache_prim(AdamWeightDecay)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def adaptive_avg_pool2_d(*args):
    prim = _get_cache_prim(AdaptiveAvgPool2D)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def adaptive_avg_pool3_d(*args):
    prim = _get_cache_prim(AdaptiveAvgPool3D)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


def adaptive_max_pool2_d(*args):
    prim = _get_cache_prim(AdaptiveMaxPool2D)(*args[-1:]).set_device('CPU')
    return prim(*args[:-1])


adaptive_max_pool3_d_op = AdaptiveMaxPool3D().set_device('CPU')
def adaptive_max_pool3_d(*args):
    return adaptive_max_pool3_d_op(*args)


add_op = Add().set_device('CPU')
def add(*args):
    return add_op(*args)


add_n_op = AddN().set_device('CPU')
def add_n(*args):
    return add_n_op(*args)


addcdiv_op = Addcdiv().set_device('CPU')
def addcdiv(*args):
    return addcdiv_op(*args)


addcmul_op = Addcmul().set_device('CPU')
def addcmul(*args):
    return addcmul_op(*args)


adjust_hue_op = AdjustHue().set_device('CPU')
def adjust_hue(*args):
    return adjust_hue_op(*args)


adjust_saturation_op = AdjustSaturation().set_device('CPU')
+def adjust_saturation(*args): + return adjust_saturation_op(*args) + + +def affine_grid(*args): + op = _get_cache_prim(AffineGrid)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def all_gather(*args): + op = _get_cache_prim(AllGather)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def all_reduce(*args): + op = _get_cache_prim(AllReduce)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def allto_all(*args): + op = _get_cache_prim(AlltoAll)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def allto_all_v(*args): + op = _get_cache_prim(AlltoAllV)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +angle_op = Angle().set_device('CPU') +def angle(*args): + return angle_op(*args) + + +apply_ada_max_op = ApplyAdaMax().set_device('CPU') +def apply_ada_max(*args): + return apply_ada_max_op(*args) + + +apply_adadelta_op = ApplyAdadelta().set_device('CPU') +def apply_adadelta(*args): + return apply_adadelta_op(*args) + + +def apply_adagrad(*args): + op = _get_cache_prim(ApplyAdagrad)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def apply_adagrad_da(*args): + op = _get_cache_prim(ApplyAdagradDA)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def apply_adagrad_v2(*args): + op = _get_cache_prim(ApplyAdagradV2)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def apply_adam_with_amsgrad(*args): + op = _get_cache_prim(ApplyAdamWithAmsgrad)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def apply_adam_with_amsgrad_v2(*args): + op = _get_cache_prim(ApplyAdamWithAmsgradV2)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +apply_add_sign_op = ApplyAddSign().set_device('CPU') +def apply_add_sign(*args): + return apply_add_sign_op(*args) + + +def apply_centered_rms_prop(*args): + op = _get_cache_prim(ApplyCenteredRMSProp)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def apply_ftrl(*args): + op = _get_cache_prim(ApplyFtrl)(*args[-1:]).set_device('CPU') + return 
op(*args[:-1]) + + +apply_gradient_descent_op = ApplyGradientDescent().set_device('CPU') +def apply_gradient_descent(*args): + return apply_gradient_descent_op(*args) + + +def apply_keras_momentum(*args): + op = _get_cache_prim(ApplyKerasMomentum)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def apply_momentum(*args): + op = _get_cache_prim(ApplyMomentum)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +apply_power_sign_op = ApplyPowerSign().set_device('CPU') +def apply_power_sign(*args): + return apply_power_sign_op(*args) + + +def apply_proximal_adagrad(*args): + op = _get_cache_prim(ApplyProximalAdagrad)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +apply_proximal_gradient_descent_op = ApplyProximalGradientDescent().set_device('CPU') +def apply_proximal_gradient_descent(*args): + return apply_proximal_gradient_descent_op(*args) + + +def apply_rms_prop(*args): + op = _get_cache_prim(ApplyRMSProp)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def apply_rotary_pos_emb(*args): + op = _get_cache_prim(ApplyRotaryPosEmb)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def approximate_equal(*args): + op = _get_cache_prim(ApproximateEqual)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def arg_max_with_value(*args): + op = _get_cache_prim(ArgMaxWithValue)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def arg_min_with_value(*args): + op = _get_cache_prim(ArgMinWithValue)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def argmax(*args): + op = _get_cache_prim(Argmax)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def argmin(*args): + op = _get_cache_prim(Argmin)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +asin_op = Asin().set_device('CPU') +def asin(*args): + return asin_op(*args) + + +asinh_op = Asinh().set_device('CPU') +def asinh(*args): + return asinh_op(*args) + + +assign_op = Assign().set_device('CPU') +def assign(*args): + return 
assign_op(*args) + + +assign_add_op = AssignAdd().set_device('CPU') +def assign_add(*args): + return assign_add_op(*args) + + +assign_sub_op = AssignSub().set_device('CPU') +def assign_sub(*args): + return assign_sub_op(*args) + + +atan_op = Atan().set_device('CPU') +def atan(*args): + return atan_op(*args) + + +atan2_op = Atan2().set_device('CPU') +def atan2(*args): + return atan2_op(*args) + + +atanh_op = Atanh().set_device('CPU') +def atanh(*args): + return atanh_op(*args) + + +def avg_pool(*args): + op = _get_cache_prim(AvgPool)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def avg_pool3_d(*args): + op = _get_cache_prim(AvgPool3D)(*args[-8:]).set_device('CPU') + return op(*args[:-8]) + + +def bce_with_logits_loss(*args): + op = _get_cache_prim(BCEWithLogitsLoss)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def barrier(*args): + op = _get_cache_prim(Barrier)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def bartlett_window(*args): + op = _get_cache_prim(BartlettWindow)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def basic_lstm_cell(*args): + op = _get_cache_prim(BasicLSTMCell)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def batch_i_send_i_recv(*args): + op = _get_cache_prim(BatchISendIRecv)(*args[-5:]).set_device('CPU') + return op(*args[:-5]) + + +def batch_mat_mul(*args): + op = _get_cache_prim(BatchMatMul)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def batch_norm(*args): + op = _get_cache_prim(BatchNorm)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def batch_to_space(*args): + op = _get_cache_prim(BatchToSpace)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def batch_to_space_nd(*args): + op = _get_cache_prim(BatchToSpaceND)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +batch_to_space_ndv2_op = BatchToSpaceNDV2().set_device('CPU') +def batch_to_space_ndv2(*args): + return batch_to_space_ndv2_op(*args) + + +def bernoulli(*args): + op = 
_get_cache_prim(Bernoulli)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +bessel_i0_op = BesselI0().set_device('CPU') +def bessel_i0(*args): + return bessel_i0_op(*args) + + +bessel_i0e_op = BesselI0e().set_device('CPU') +def bessel_i0e(*args): + return bessel_i0e_op(*args) + + +bessel_i1_op = BesselI1().set_device('CPU') +def bessel_i1(*args): + return bessel_i1_op(*args) + + +bessel_i1e_op = BesselI1e().set_device('CPU') +def bessel_i1e(*args): + return bessel_i1e_op(*args) + + +bessel_j0_op = BesselJ0().set_device('CPU') +def bessel_j0(*args): + return bessel_j0_op(*args) + + +bessel_j1_op = BesselJ1().set_device('CPU') +def bessel_j1(*args): + return bessel_j1_op(*args) + + +bessel_k0_op = BesselK0().set_device('CPU') +def bessel_k0(*args): + return bessel_k0_op(*args) + + +bessel_k0e_op = BesselK0e().set_device('CPU') +def bessel_k0e(*args): + return bessel_k0e_op(*args) + + +bessel_k1_op = BesselK1().set_device('CPU') +def bessel_k1(*args): + return bessel_k1_op(*args) + + +bessel_k1e_op = BesselK1e().set_device('CPU') +def bessel_k1e(*args): + return bessel_k1e_op(*args) + + +bessel_y0_op = BesselY0().set_device('CPU') +def bessel_y0(*args): + return bessel_y0_op(*args) + + +bessel_y1_op = BesselY1().set_device('CPU') +def bessel_y1(*args): + return bessel_y1_op(*args) + + +betainc_op = Betainc().set_device('CPU') +def betainc(*args): + return betainc_op(*args) + + +def bias_add(*args): + op = _get_cache_prim(BiasAdd)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def binary_cross_entropy(*args): + op = _get_cache_prim(BinaryCrossEntropy)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +bincount_op = Bincount().set_device('CPU') +def bincount(*args): + return bincount_op(*args) + + +bitwise_and_op = BitwiseAnd().set_device('CPU') +def bitwise_and(*args): + return bitwise_and_op(*args) + + +bitwise_or_op = BitwiseOr().set_device('CPU') +def bitwise_or(*args): + return bitwise_or_op(*args) + + +bitwise_xor_op = 
BitwiseXor().set_device('CPU') +def bitwise_xor(*args): + return bitwise_xor_op(*args) + + +def blackman_window(*args): + op = _get_cache_prim(BlackmanWindow)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def bounding_box_decode(*args): + op = _get_cache_prim(BoundingBoxDecode)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def bounding_box_encode(*args): + op = _get_cache_prim(BoundingBoxEncode)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def broadcast(*args): + op = _get_cache_prim(Broadcast)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def broadcast_to(*args): + op = _get_cache_prim(BroadcastTo)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def bucketize(*args): + op = _get_cache_prim(Bucketize)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def buffer_append(*args): + op = _get_cache_prim(BufferAppend)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def buffer_get_item(*args): + op = _get_cache_prim(BufferGetItem)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def buffer_sample(*args): + op = _get_cache_prim(BufferSample)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +def ctc_greedy_decoder(*args): + op = _get_cache_prim(CTCGreedyDecoder)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def ctc_loss(*args): + op = _get_cache_prim(CTCLoss)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def ctc_loss_v2(*args): + op = _get_cache_prim(CTCLossV2)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +cast_op = Cast().set_device('CPU') +def cast(*args): + return cast_op(*args) + + +def cauchy(*args): + op = _get_cache_prim(Cauchy)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def cdist(*args): + op = _get_cache_prim(Cdist)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def ce_lu(*args): + op = _get_cache_prim(CeLU)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +ceil_op = 
Ceil().set_device('CPU') +def ceil(*args): + return ceil_op(*args) + + +def channel_shuffle(*args): + op = _get_cache_prim(ChannelShuffle)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +check_numerics_op = CheckNumerics().set_device('CPU') +def check_numerics(*args): + return check_numerics_op(*args) + + +check_valid_op = CheckValid().set_device('CPU') +def check_valid(*args): + return check_valid_op(*args) + + +def cholesky(*args): + op = _get_cache_prim(Cholesky)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def cholesky_inverse(*args): + op = _get_cache_prim(CholeskyInverse)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def cholesky_solve(*args): + op = _get_cache_prim(CholeskySolve)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +coalesce_op = Coalesce().set_device('CPU') +def coalesce(*args): + return coalesce_op(*args) + + +def col2_im(*args): + op = _get_cache_prim(Col2Im)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def collective_gather(*args): + op = _get_cache_prim(CollectiveGather)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def collective_scatter(*args): + op = _get_cache_prim(CollectiveScatter)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def combined_non_max_suppression(*args): + op = _get_cache_prim(CombinedNonMaxSuppression)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +compare_and_bitpack_op = CompareAndBitpack().set_device('CPU') +def compare_and_bitpack(*args): + return compare_and_bitpack_op(*args) + + +complex_op = Complex().set_device('CPU') +def complex(*args): + return complex_op(*args) + + +complex_abs_op = ComplexAbs().set_device('CPU') +def complex_abs(*args): + return complex_abs_op(*args) + + +def compute_accidental_hits(*args): + op = _get_cache_prim(ComputeAccidentalHits)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def concat(*args): + op = _get_cache_prim(Concat)(*args[-1:]).set_device('CPU') + return 
op(*args[:-1]) + + +def confusion_matrix(*args): + op = _get_cache_prim(ConfusionMatrix)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +conj_op = Conj().set_device('CPU') +def conj(*args): + return conj_op(*args) + + +conjugate_transpose_op = ConjugateTranspose().set_device('CPU') +def conjugate_transpose(*args): + return conjugate_transpose_op(*args) + + +def conv2_d(*args): + op = _get_cache_prim(Conv2D)(*args[-9:]).set_device('CPU') + return op(*args[:-9]) + + +def conv2_d_backprop_input(*args): + op = _get_cache_prim(Conv2DBackpropInput)(*args[-10:]).set_device('CPU') + return op(*args[:-10]) + + +def conv2_d_transpose(*args): + op = _get_cache_prim(Conv2DTranspose)(*args[-10:]).set_device('CPU') + return op(*args[:-10]) + + +def conv3_d(*args): + op = _get_cache_prim(Conv3D)(*args[-9:]).set_device('CPU') + return op(*args[:-9]) + + +def conv3_d_transpose(*args): + op = _get_cache_prim(Conv3DTranspose)(*args[-11:]).set_device('CPU') + return op(*args[:-11]) + + +copy_with_slice_op = CopyWithSlice().set_device('CPU') +def copy_with_slice(*args): + return copy_with_slice_op(*args) + + +cos_op = Cos().set_device('CPU') +def cos(*args): + return cos_op(*args) + + +cosh_op = Cosh().set_device('CPU') +def cosh(*args): + return cosh_op(*args) + + +def count_non_zero(*args): + op = _get_cache_prim(CountNonZero)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def crop_and_resize(*args): + op = _get_cache_prim(CropAndResize)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def cross(*args): + op = _get_cache_prim(Cross)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def cum_prod(*args): + op = _get_cache_prim(CumProd)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def cum_sum(*args): + op = _get_cache_prim(CumSum)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def cummax(*args): + op = _get_cache_prim(Cummax)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def cummin(*args): + op = 
_get_cache_prim(Cummin)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def cumulative_logsumexp(*args): + op = _get_cache_prim(CumulativeLogsumexp)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +d_type_op = DType().set_device('CPU') +def d_type(*args): + return d_type_op(*args) + + +def data_format_dim_map(*args): + op = _get_cache_prim(DataFormatDimMap)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def data_format_vec_permute(*args): + op = _get_cache_prim(DataFormatVecPermute)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def deformable_offsets(*args): + op = _get_cache_prim(DeformableOffsets)(*args[-7:]).set_device('CPU') + return op(*args[:-7]) + + +dense_op = Dense().set_device('CPU') +def dense(*args): + return dense_op(*args) + + +depend_op = Depend().set_device('CPU') +def depend(*args): + return depend_op(*args) + + +def depth_to_space(*args): + op = _get_cache_prim(DepthToSpace)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def depthwise_conv2d_native(*args): + op = _get_cache_prim(DepthwiseConv2dNative)(*args[-8:]).set_device('CPU') + return op(*args[:-8]) + + +diag_op = Diag().set_device('CPU') +def diag(*args): + return diag_op(*args) + + +diag_part_op = DiagPart().set_device('CPU') +def diag_part(*args): + return diag_part_op(*args) + + +digamma_op = Digamma().set_device('CPU') +def digamma(*args): + return digamma_op(*args) + + +def dilation2_d(*args): + op = _get_cache_prim(Dilation2D)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +div_op = Div().set_device('CPU') +def div(*args): + return div_op(*args) + + +div_no_nan_op = DivNoNan().set_device('CPU') +def div_no_nan(*args): + return div_no_nan_op(*args) + + +def dropout(*args): + op = _get_cache_prim(Dropout)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def dropout2_d(*args): + op = _get_cache_prim(Dropout2D)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def dropout3_d(*args): + op = 
_get_cache_prim(Dropout3D)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def dropout_gen_mask(*args): + op = _get_cache_prim(DropoutGenMask)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def dynamic_gruv2(*args): + op = _get_cache_prim(DynamicGRUV2)(*args[-10:]).set_device('CPU') + return op(*args[:-10]) + + +def dynamic_rnn(*args): + op = _get_cache_prim(DynamicRNN)(*args[-11:]).set_device('CPU') + return op(*args[:-11]) + + +def dynamic_shape(*args): + op = _get_cache_prim(DynamicShape)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def edit_distance(*args): + op = _get_cache_prim(EditDistance)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def eig(*args): + op = _get_cache_prim(Eig)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def einsum(*args): + op = _get_cache_prim(Einsum)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def elu(*args): + op = _get_cache_prim(Elu)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +embedding_lookup_op = EmbeddingLookup().set_device('CPU') +def embedding_lookup(*args): + return embedding_lookup_op(*args) + + +eps_op = Eps().set_device('CPU') +def eps(*args): + return eps_op(*args) + + +equal_op = Equal().set_device('CPU') +def equal(*args): + return equal_op(*args) + + +equal_count_op = EqualCount().set_device('CPU') +def equal_count(*args): + return equal_count_op(*args) + + +erf_op = Erf().set_device('CPU') +def erf(*args): + return erf_op(*args) + + +erfc_op = Erfc().set_device('CPU') +def erfc(*args): + return erfc_op(*args) + + +erfinv_op = Erfinv().set_device('CPU') +def erfinv(*args): + return erfinv_op(*args) + + +erfinv_op = Erfinv().set_device('CPU') +def erfinv(*args): + return erfinv_op(*args) + + +def euclidean_norm(*args): + op = _get_cache_prim(EuclideanNorm)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +exp_op = Exp().set_device('CPU') +def exp(*args): + return exp_op(*args) + + + +expand_dims_op = 
ExpandDims().set_device('CPU') +def expand_dims(*args): + return expand_dims_op(*args) + + +expm1_op = Expm1().set_device('CPU') +def expm1(*args): + return expm1_op(*args) + + +def extract_glimpse(*args): + op = _get_cache_prim(ExtractGlimpse)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def extract_image_patches(*args): + op = _get_cache_prim(ExtractImagePatches)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def extract_volume_patches(*args): + op = _get_cache_prim(ExtractVolumePatches)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +eye_op = Eye().set_device('CPU') +def eye(*args): + return eye_op(*args) + + +def fft_with_size(*args): + op = _get_cache_prim(FFTWithSize)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +fast_ge_lu_op = FastGeLU().set_device('CPU') +def fast_ge_lu(*args): + return fast_ge_lu_op(*args) + + +fill_op = Fill().set_device('CPU') +def fill(*args): + return fill_op(*args) + + +def fill_diagonal(*args): + op = _get_cache_prim(FillDiagonal)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +fill_v2_op = FillV2().set_device('CPU') +def fill_v2(*args): + return fill_v2_op(*args) + + +fills_op = Fills().set_device('CPU') +def fills(*args): + return fills_op(*args) + + +flatten_op = Flatten().set_device('CPU') +def flatten(*args): + return flatten_op(*args) + + +float_status_op = FloatStatus().set_device('CPU') +def float_status(*args): + return float_status_op(*args) + + +floor_op = Floor().set_device('CPU') +def floor(*args): + return floor_op(*args) + + +floor_div_op = FloorDiv().set_device('CPU') +def floor_div(*args): + return floor_div_op(*args) + + +floor_mod_op = FloorMod().set_device('CPU') +def floor_mod(*args): + return floor_mod_op(*args) + + +fmax_op = Fmax().set_device('CPU') +def fmax(*args): + return fmax_op(*args) + + +fmin_op = Fmin().set_device('CPU') +def fmin(*args): + return fmin_op(*args) + + +fori_loop_op = ForiLoop().set_device('CPU') +def fori_loop(*args): + 
return fori_loop_op(*args) + + +def fractional_avg_pool(*args): + op = _get_cache_prim(FractionalAvgPool)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +def fractional_max_pool(*args): + op = _get_cache_prim(FractionalMaxPool)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +def fractional_max_pool3_d_with_fixed_ksize(*args): + op = _get_cache_prim(FractionalMaxPool3DWithFixedKsize)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def fractional_max_pool_with_fixed_ksize(*args): + op = _get_cache_prim(FractionalMaxPoolWithFixedKsize)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def fused_ada_factor(*args): + op = _get_cache_prim(FusedAdaFactor)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def fused_ada_factor_with_global_norm(*args): + op = _get_cache_prim(FusedAdaFactorWithGlobalNorm)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def fused_cast_adam_weight_decay(*args): + op = _get_cache_prim(FusedCastAdamWeightDecay)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def fused_sparse_adam(*args): + op = _get_cache_prim(FusedSparseAdam)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def fused_sparse_ftrl(*args): + op = _get_cache_prim(FusedSparseFtrl)(*args[-5:]).set_device('CPU') + return op(*args[:-5]) + + +def fused_sparse_lazy_adam(*args): + op = _get_cache_prim(FusedSparseLazyAdam)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def fused_sparse_proximal_adagrad(*args): + op = _get_cache_prim(FusedSparseProximalAdagrad)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +fused_weight_scale_apply_momentum_op = FusedWeightScaleApplyMomentum().set_device('CPU') +def fused_weight_scale_apply_momentum(*args): + return fused_weight_scale_apply_momentum_op(*args) + + +def glu(*args): + op = _get_cache_prim(GLU)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def gamma(*args): + op = _get_cache_prim(Gamma)(*args[-2:]).set_device('CPU') + return 
op(*args[:-2]) + + +def gather(*args): + op = _get_cache_prim(Gather)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +gather_d_op = GatherD().set_device('CPU') +def gather_d(*args): + return gather_d_op(*args) + + +gather_nd_op = GatherNd().set_device('CPU') +def gather_nd(*args): + return gather_nd_op(*args) + + +gcd_op = Gcd().set_device('CPU') +def gcd(*args): + return gcd_op(*args) + + +ge_lu_op = GeLU().set_device('CPU') +def ge_lu(*args): + return ge_lu_op(*args) + + +ge_switch_op = GeSwitch().set_device('CPU') +def ge_switch(*args): + return ge_switch_op(*args) + + +geqrf_op = Geqrf().set_device('CPU') +def geqrf(*args): + return geqrf_op(*args) + + +ger_op = Ger().set_device('CPU') +def ger(*args): + return ger_op(*args) + + +def get_next(*args): + op = _get_cache_prim(GetNext)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +greater_op = Greater().set_device('CPU') +def greater(*args): + return greater_op(*args) + + +greater_equal_op = GreaterEqual().set_device('CPU') +def greater_equal(*args): + return greater_equal_op(*args) + + +def grid_sampler2_d(*args): + op = _get_cache_prim(GridSampler2D)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def grid_sampler3_d(*args): + op = _get_cache_prim(GridSampler3D)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +group_topk_op = GroupTopk().set_device('CPU') +def group_topk(*args): + return group_topk_op(*args) + + +hsv_to_rgb_op = HSVToRGB().set_device('CPU') +def hsv_to_rgb(*args): + return hsv_to_rgb_op(*args) + + +def h_shrink(*args): + op = _get_cache_prim(HShrink)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +h_sigmoid_op = HSigmoid().set_device('CPU') +def h_sigmoid(*args): + return h_sigmoid_op(*args) + + +h_swish_op = HSwish().set_device('CPU') +def h_swish(*args): + return h_swish_op(*args) + + +def hamming_window(*args): + op = _get_cache_prim(HammingWindow)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +heaviside_op = 
Heaviside().set_device('CPU') +def heaviside(*args): + return heaviside_op(*args) + + +def histogram(*args): + op = _get_cache_prim(Histogram)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def histogram_fixed_width(*args): + op = _get_cache_prim(HistogramFixedWidth)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +histogram_summary_op = HistogramSummary().set_device('CPU') +def histogram_summary(*args): + return histogram_summary_op(*args) + + +def hook_backward(*args): + op = _get_cache_prim(HookBackward)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +hypot_op = Hypot().set_device('CPU') +def hypot(*args): + return hypot_op(*args) + + +def iou(*args): + op = _get_cache_prim(IOU)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +identity_op = Identity().set_device('CPU') +def identity(*args): + return identity_op(*args) + + +identity_n_op = IdentityN().set_device('CPU') +def identity_n(*args): + return identity_n_op(*args) + + +igamma_op = Igamma().set_device('CPU') +def igamma(*args): + return igamma_op(*args) + + +igammac_op = Igammac().set_device('CPU') +def igammac(*args): + return igammac_op(*args) + + +def im2_col(*args): + op = _get_cache_prim(Im2Col)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +imag_op = Imag().set_device('CPU') +def imag(*args): + return imag_op(*args) + + +image_summary_op = ImageSummary().set_device('CPU') +def image_summary(*args): + return image_summary_op(*args) + + +def in_top_k(*args): + op = _get_cache_prim(InTopK)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def index_add(*args): + op = _get_cache_prim(IndexAdd)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +index_fill_op = IndexFill().set_device('CPU') +def index_fill(*args): + return index_fill_op(*args) + + +def index_put(*args): + op = _get_cache_prim(IndexPut)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def inplace_add(*args): + op = 
_get_cache_prim(InplaceAdd)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def inplace_index_add(*args): + op = _get_cache_prim(InplaceIndexAdd)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def inplace_sub(*args): + op = _get_cache_prim(InplaceSub)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def inplace_update(*args): + op = _get_cache_prim(InplaceUpdate)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +inplace_update_v2_op = InplaceUpdateV2().set_device('CPU') +def inplace_update_v2(*args): + return inplace_update_v2_op(*args) + + +def insert_gradient_of(*args): + op = _get_cache_prim(InsertGradientOf)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +inv_op = Inv().set_device('CPU') +def inv(*args): + return inv_op(*args) + + +invert_op = Invert().set_device('CPU') +def invert(*args): + return invert_op(*args) + + +invert_permutation_op = InvertPermutation().set_device('CPU') +def invert_permutation(*args): + return invert_permutation_op(*args) + + +def is_close(*args): + op = _get_cache_prim(IsClose)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +is_finite_op = IsFinite().set_device('CPU') +def is_finite(*args): + return is_finite_op(*args) + + +is_inf_op = IsInf().set_device('CPU') +def is_inf(*args): + return is_inf_op(*args) + + +is_nan_op = IsNan().set_device('CPU') +def is_nan(*args): + return is_nan_op(*args) + + +def kl_div_loss(*args): + op = _get_cache_prim(KLDivLoss)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +l2_loss_op = L2Loss().set_device('CPU') +def l2_loss(*args): + return l2_loss_op(*args) + + +def l2_normalize(*args): + op = _get_cache_prim(L2Normalize)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def lars_update(*args): + op = _get_cache_prim(LARSUpdate)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def lrn(*args): + op = _get_cache_prim(LRN)(*args[-5:]).set_device('CPU') + return op(*args[:-5]) + + +def lstm(*args): + op = 
_get_cache_prim(LSTM)(*args[-7:]).set_device('CPU') + return op(*args[:-7]) + + +def layer_norm(*args): + op = _get_cache_prim(LayerNorm)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +lcm_op = Lcm().set_device('CPU') +def lcm(*args): + return lcm_op(*args) + + +left_shift_op = LeftShift().set_device('CPU') +def left_shift(*args): + return left_shift_op(*args) + + +lerp_op = Lerp().set_device('CPU') +def lerp(*args): + return lerp_op(*args) + + +lerp_scalar_op = LerpScalar().set_device('CPU') +def lerp_scalar(*args): + return lerp_scalar_op(*args) + + +less_op = Less().set_device('CPU') +def less(*args): + return less_op(*args) + + +less_equal_op = LessEqual().set_device('CPU') +def less_equal(*args): + return less_equal_op(*args) + + +lgamma_op = Lgamma().set_device('CPU') +def lgamma(*args): + return lgamma_op(*args) + + +lin_space_op = LinSpace().set_device('CPU') +def lin_space(*args): + return lin_space_op(*args) + + +def list_diff(*args): + op = _get_cache_prim(ListDiff)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +log_op = Log().set_device('CPU') +def log(*args): + return log_op(*args) + + +log1p_op = Log1p().set_device('CPU') +def log1p(*args): + return log1p_op(*args) + + +log_matrix_determinant_op = LogMatrixDeterminant().set_device('CPU') +def log_matrix_determinant(*args): + return log_matrix_determinant_op(*args) + + +def log_normal_reverse(*args): + op = _get_cache_prim(LogNormalReverse)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def log_softmax(*args): + op = _get_cache_prim(LogSoftmax)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +log_softmax_ext_op = LogSoftmaxExt().set_device('CPU') +def log_softmax_ext(*args): + return log_softmax_ext_op(*args) + + +def log_space(*args): + op = _get_cache_prim(LogSpace)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def log_uniform_candidate_sampler(*args): + op = _get_cache_prim(LogUniformCandidateSampler)(*args[-5:]).set_device('CPU') + return 
op(*args[:-5]) + + +logical_and_op = LogicalAnd().set_device('CPU') +def logical_and(*args): + return logical_and_op(*args) + + +logical_not_op = LogicalNot().set_device('CPU') +def logical_not(*args): + return logical_not_op(*args) + + +logical_or_op = LogicalOr().set_device('CPU') +def logical_or(*args): + return logical_or_op(*args) + + +logical_xor_op = LogicalXor().set_device('CPU') +def logical_xor(*args): + return logical_xor_op(*args) + + +def logit(*args): + op = _get_cache_prim(Logit)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def lower_bound(*args): + op = _get_cache_prim(LowerBound)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def lp_norm(*args): + op = _get_cache_prim(LpNorm)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def lstsq(*args): + op = _get_cache_prim(Lstsq)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +lu_solve_op = LuSolve().set_device('CPU') +def lu_solve(*args): + return lu_solve_op(*args) + + +def lu_unpack(*args): + op = _get_cache_prim(LuUnpack)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +map_cache_idx_op = MapCacheIdx().set_device('CPU') +def map_cache_idx(*args): + return map_cache_idx_op(*args) + + +map_uniform_op = MapUniform().set_device('CPU') +def map_uniform(*args): + return map_uniform_op(*args) + + +masked_fill_op = MaskedFill().set_device('CPU') +def masked_fill(*args): + return masked_fill_op(*args) + + +masked_scatter_op = MaskedScatter().set_device('CPU') +def masked_scatter(*args): + return masked_scatter_op(*args) + + +masked_select_op = MaskedSelect().set_device('CPU') +def masked_select(*args): + return masked_select_op(*args) + + +def mat_mul(*args): + op = _get_cache_prim(MatMul)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +matrix_band_part_op = MatrixBandPart().set_device('CPU') +def matrix_band_part(*args): + return matrix_band_part_op(*args) + + +matrix_determinant_op = MatrixDeterminant().set_device('CPU') +def 
matrix_determinant(*args): + return matrix_determinant_op(*args) + + +def matrix_diag_part_v3(*args): + op = _get_cache_prim(MatrixDiagPartV3)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def matrix_diag_v3(*args): + op = _get_cache_prim(MatrixDiagV3)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +matrix_exp_op = MatrixExp().set_device('CPU') +def matrix_exp(*args): + return matrix_exp_op(*args) + + +def matrix_inverse(*args): + op = _get_cache_prim(MatrixInverse)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +matrix_logarithm_op = MatrixLogarithm().set_device('CPU') +def matrix_logarithm(*args): + return matrix_logarithm_op(*args) + + +def matrix_power(*args): + op = _get_cache_prim(MatrixPower)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def matrix_set_diag_v3(*args): + op = _get_cache_prim(MatrixSetDiagV3)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def matrix_solve(*args): + op = _get_cache_prim(MatrixSolve)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def matrix_solve_ls(*args): + op = _get_cache_prim(MatrixSolveLs)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def matrix_triangular_solve(*args): + op = _get_cache_prim(MatrixTriangularSolve)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def max_pool(*args): + op = _get_cache_prim(MaxPool)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def max_pool3_d(*args): + op = _get_cache_prim(MaxPool3D)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +def max_pool3_d_with_argmax(*args): + op = _get_cache_prim(MaxPool3DWithArgmax)(*args[-7:]).set_device('CPU') + return op(*args[:-7]) + + +def max_pool_with_argmax(*args): + op = _get_cache_prim(MaxPoolWithArgmax)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def max_pool_with_argmax_v2(*args): + op = _get_cache_prim(MaxPoolWithArgmaxV2)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +def max_unpool2_d(*args): + op = 
_get_cache_prim(MaxUnpool2D)(*args[-5:]).set_device('CPU') + return op(*args[:-5]) + + +def max_unpool3_d(*args): + op = _get_cache_prim(MaxUnpool3D)(*args[-5:]).set_device('CPU') + return op(*args[:-5]) + + +maximum_op = Maximum().set_device('CPU') +def maximum(*args): + return maximum_op(*args) + + +merge_op = Merge().set_device('CPU') +def merge(*args): + return merge_op(*args) + + +def meshgrid(*args): + op = _get_cache_prim(Meshgrid)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +minimum_op = Minimum().set_device('CPU') +def minimum(*args): + return minimum_op(*args) + + +def mirror_pad(*args): + op = _get_cache_prim(MirrorPad)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +mish_op = Mish().set_device('CPU') +def mish(*args): + return mish_op(*args) + + +mod_op = Mod().set_device('CPU') +def mod(*args): + return mod_op(*args) + + +def morph(*args): + op = _get_cache_prim(Morph)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +move_to_op = MoveTo().set_device('CPU') +def move_to(*args): + return move_to_op(*args) + + +mul_op = Mul().set_device('CPU') +def mul(*args): + return mul_op(*args) + + +mul_no_nan_op = MulNoNan().set_device('CPU') +def mul_no_nan(*args): + return mul_no_nan_op(*args) + + +def multi_margin_loss(*args): + op = _get_cache_prim(MultiMarginLoss)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def multilabel_margin_loss(*args): + op = _get_cache_prim(MultilabelMarginLoss)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def multinomial(*args): + op = _get_cache_prim(Multinomial)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def multinomial_with_replacement(*args): + op = _get_cache_prim(MultinomialWithReplacement)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def mvlgamma(*args): + op = _get_cache_prim(Mvlgamma)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def nll_loss(*args): + op = _get_cache_prim(NLLLoss)(*args[-2:]).set_device('CPU') + return 
op(*args[:-2]) + + +def nms_with_mask(*args): + op = _get_cache_prim(NMSWithMask)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def nan_to_num(*args): + op = _get_cache_prim(NanToNum)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +neg_op = Neg().set_device('CPU') +def neg(*args): + return neg_op(*args) + + +def neighbor_exchange(*args): + op = _get_cache_prim(NeighborExchange)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +def neighbor_exchange_v2(*args): + op = _get_cache_prim(NeighborExchangeV2)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +next_after_op = NextAfter().set_device('CPU') +def next_after(*args): + return next_after_op(*args) + + +def no_repeat_n_gram(*args): + op = _get_cache_prim(NoRepeatNGram)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def non_deterministic_ints(*args): + op = _get_cache_prim(NonDeterministicInts)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +non_max_suppression_v3_op = NonMaxSuppressionV3().set_device('CPU') +def non_max_suppression_v3(*args): + return non_max_suppression_v3_op(*args) + + +non_max_suppression_with_overlaps_op = NonMaxSuppressionWithOverlaps().set_device('CPU') +def non_max_suppression_with_overlaps(*args): + return non_max_suppression_with_overlaps_op(*args) + + +non_zero_op = NonZero().set_device('CPU') +def non_zero(*args): + return non_zero_op(*args) + + +not_equal_op = NotEqual().set_device('CPU') +def not_equal(*args): + return not_equal_op(*args) + + +def nth_element(*args): + op = _get_cache_prim(NthElement)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def nuclear_norm(*args): + op = _get_cache_prim(NuclearNorm)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def one_hot(*args): + op = _get_cache_prim(OneHot)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +ones_op = Ones().set_device('CPU') +def ones(*args): + return ones_op(*args) + + +ones_like_op = OnesLike().set_device('CPU') +def 
ones_like(*args): + return ones_like_op(*args) + + +orgqr_op = Orgqr().set_device('CPU') +def orgqr(*args): + return orgqr_op(*args) + + +def ormqr(*args): + op = _get_cache_prim(Ormqr)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +p_re_lu_op = PReLU().set_device('CPU') +def p_re_lu(*args): + return p_re_lu_op(*args) + + +def pack(*args): + op = _get_cache_prim(Pack)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def pad(*args): + op = _get_cache_prim(Pad)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def pad_v3(*args): + op = _get_cache_prim(PadV3)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def padding(*args): + op = _get_cache_prim(Padding)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def paged_attention(*args): + op = _get_cache_prim(PagedAttention)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +def paged_attention_mask(*args): + op = _get_cache_prim(PagedAttentionMask)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +parallel_concat_op = ParallelConcat().set_device('CPU') +def parallel_concat(*args): + return parallel_concat_op(*args) + + +def parameterized_truncated_normal(*args): + op = _get_cache_prim(ParameterizedTruncatedNormal)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +partial_op = Partial().set_device('CPU') +def partial(*args): + return partial_op(*args) + + +def pdist(*args): + op = _get_cache_prim(Pdist)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def poisson(*args): + op = _get_cache_prim(Poisson)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +polar_op = Polar().set_device('CPU') +def polar(*args): + return polar_op(*args) + + +polygamma_op = Polygamma().set_device('CPU') +def polygamma(*args): + return polygamma_op(*args) + + +population_count_op = PopulationCount().set_device('CPU') +def population_count(*args): + return population_count_op(*args) + + +pow_op = Pow().set_device('CPU') +def pow(*args): + return 
pow_op(*args)
+
+
+pull_op = Pull().set_device('CPU')
+def pull(*args):
+    """Call the shared CPU ``Pull`` primitive with ``args``."""
+    return pull_op(*args)
+
+
+def push(*args):
+    """CPU ``Push``: trailing 2 args create the primitive, the rest are passed to it."""
+    op = _get_cache_prim(Push)(*args[-2:]).set_device('CPU')
+    return op(*args[:-2])
+
+
+py_execute_op = PyExecute().set_device('CPU')
+def py_execute(*args):
+    """Call the shared CPU ``PyExecute`` primitive with ``args``."""
+    return py_execute_op(*args)
+
+
+def py_func(*args):
+    """CPU ``PyFunc``: trailing 6 args create the primitive, the rest are passed to it."""
+    op = _get_cache_prim(PyFunc)(*args[-6:]).set_device('CPU')
+    return op(*args[:-6])
+
+
+def qr(*args):
+    """CPU ``Qr``: the trailing arg creates the primitive, the rest are passed to it."""
+    op = _get_cache_prim(Qr)(*args[-1:]).set_device('CPU')
+    return op(*args[:-1])
+
+
+def quantile(*args):
+    """CPU ``Quantile``: trailing 3 args create the primitive, the rest are passed to it."""
+    op = _get_cache_prim(Quantile)(*args[-3:]).set_device('CPU')
+    return op(*args[:-3])
+
+
+rgb_to_hsv_op = RGBToHSV().set_device('CPU')
+def rgb_to_hsv(*args):
+    """Call the shared CPU ``RGBToHSV`` primitive with ``args``."""
+    return rgb_to_hsv_op(*args)
+
+
+def rnnt_loss(*args):
+    """CPU ``RNNTLoss``: the trailing arg creates the primitive, the rest are passed to it."""
+    op = _get_cache_prim(RNNTLoss)(*args[-1:]).set_device('CPU')
+    return op(*args[:-1])
+
+
+def roi_align(*args):
+    """CPU ``ROIAlign``: trailing 5 args create the primitive, the rest are passed to it."""
+    op = _get_cache_prim(ROIAlign)(*args[-5:]).set_device('CPU')
+    return op(*args[:-5])
+
+
+def ragged_range(*args):
+    """CPU ``RaggedRange``: the trailing arg creates the primitive, the rest are passed to it."""
+    op = _get_cache_prim(RaggedRange)(*args[-1:]).set_device('CPU')
+    return op(*args[:-1])
+
+
+def random_categorical(*args):
+    """CPU ``RandomCategorical``: the trailing arg creates the primitive, the rest are passed to it."""
+    op = _get_cache_prim(RandomCategorical)(*args[-1:]).set_device('CPU')
+    return op(*args[:-1])
+
+
+def random_choice_with_mask(*args):
+    """CPU ``RandomChoiceWithMask``: trailing 3 args create the primitive, the rest are passed to it."""
+    op = _get_cache_prim(RandomChoiceWithMask)(*args[-3:]).set_device('CPU')
+    return op(*args[:-3])
+
+
+def random_gamma(*args):
+    """CPU ``RandomGamma``: trailing 2 args create the primitive, the rest are passed to it."""
+    op = _get_cache_prim(RandomGamma)(*args[-2:]).set_device('CPU')
+    return op(*args[:-2])
+
+
+def random_poisson(*args):
+    """CPU ``RandomPoisson``: trailing 3 args create the primitive, the rest are passed to it."""
+    op = _get_cache_prim(RandomPoisson)(*args[-3:]).set_device('CPU')
+    return op(*args[:-3])
+
+
+def random_shuffle(*args):
+    """CPU ``RandomShuffle``: trailing 2 args create the primitive, the rest are passed to it."""
+    op = _get_cache_prim(RandomShuffle)(*args[-2:]).set_device('CPU')
+    return op(*args[:-2])
+
+
+def randperm(*args):
+    """CPU ``Randperm``: trailing 3 args create the primitive, the rest are passed to it."""
+    op = _get_cache_prim(Randperm)(*args[-3:]).set_device('CPU')
+    return op(*args[:-3])
+
+
+def randperm_v2(*args):
+    op = 
_get_cache_prim(RandpermV2)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def range(*args): + op = _get_cache_prim(Range)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +rank_op = Rank().set_device('CPU') +def rank(*args): + return rank_op(*args) + + +re_lu_op = ReLU().set_device('CPU') +def re_lu(*args): + return re_lu_op(*args) + + +re_lu6_op = ReLU6().set_device('CPU') +def re_lu6(*args): + return re_lu6_op(*args) + + +real_op = Real().set_device('CPU') +def real(*args): + return real_op(*args) + + +real_div_op = RealDiv().set_device('CPU') +def real_div(*args): + return real_div_op(*args) + + +def receive(*args): + op = _get_cache_prim(Receive)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +reciprocal_op = Reciprocal().set_device('CPU') +def reciprocal(*args): + return reciprocal_op(*args) + + +def reduce(*args): + op = _get_cache_prim(Reduce)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def reduce_all(*args): + op = _get_cache_prim(ReduceAll)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def reduce_any(*args): + op = _get_cache_prim(ReduceAny)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def reduce_max(*args): + op = _get_cache_prim(ReduceMax)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def reduce_mean(*args): + op = _get_cache_prim(ReduceMean)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def reduce_min(*args): + op = _get_cache_prim(ReduceMin)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def reduce_prod(*args): + op = _get_cache_prim(ReduceProd)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def reduce_scatter(*args): + op = _get_cache_prim(ReduceScatter)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def reduce_std(*args): + op = _get_cache_prim(ReduceStd)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def reduce_sum(*args): + op = _get_cache_prim(ReduceSum)(*args[-2:]).set_device('CPU') + return 
op(*args[:-2]) + + +def renorm(*args): + op = _get_cache_prim(Renorm)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +reshape_op = Reshape().set_device('CPU') +def reshape(*args): + return reshape_op(*args) + + +reshape_and_cache_op = ReshapeAndCache().set_device('CPU') +def reshape_and_cache(*args): + return reshape_and_cache_op(*args) + + +def reshard(*args): + op = _get_cache_prim(Reshard)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def resize_area(*args): + op = _get_cache_prim(ResizeArea)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def resize_bicubic(*args): + op = _get_cache_prim(ResizeBicubic)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def resize_bilinear_v2(*args): + op = _get_cache_prim(ResizeBilinearV2)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def resize_linear1_d(*args): + op = _get_cache_prim(ResizeLinear1D)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def resize_nearest_neighbor(*args): + op = _get_cache_prim(ResizeNearestNeighbor)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +def resize_nearest_neighbor_v2(*args): + op = _get_cache_prim(ResizeNearestNeighborV2)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +reusing_op = Reusing().set_device('CPU') +def reusing(*args): + return reusing_op(*args) + + +def reverse_sequence(*args): + op = _get_cache_prim(ReverseSequence)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def reverse_v2(*args): + op = _get_cache_prim(ReverseV2)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +right_shift_op = RightShift().set_device('CPU') +def right_shift(*args): + return right_shift_op(*args) + + +rint_op = Rint().set_device('CPU') +def rint(*args): + return rint_op(*args) + + +def rms_norm(*args): + op = _get_cache_prim(RmsNorm)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def roll(*args): + op = _get_cache_prim(Roll)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + 
+# Two wrapper shapes follow:
+#   * ``<name>_op`` singletons, instantiated once at import time and pinned to CPU;
+#   * per-call wrappers that obtain their primitive through ``_get_cache_prim``.
+round_op = Round().set_device('CPU')
+# NOTE: this name shadows the builtin ``round`` inside this module.
+def round(*args):
+    """Call the shared CPU ``Round`` primitive with ``args``."""
+    return round_op(*args)
+
+
+rsqrt_op = Rsqrt().set_device('CPU')
+def rsqrt(*args):
+    """Call the shared CPU ``Rsqrt`` primitive with ``args``."""
+    return rsqrt_op(*args)
+
+
+def sgd(*args):
+    """CPU ``SGD``: trailing 3 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-3:], args[:-3]
+    return _get_cache_prim(SGD)(*init_args).set_device('CPU')(*inputs)
+
+
+def stft(*args):
+    """CPU ``STFT``: trailing 6 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-6:], args[:-6]
+    return _get_cache_prim(STFT)(*init_args).set_device('CPU')(*inputs)
+
+
+def sample_distorted_bounding_box_v2(*args):
+    """CPU ``SampleDistortedBoundingBoxV2``: trailing 6 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-6:], args[:-6]
+    return _get_cache_prim(SampleDistortedBoundingBoxV2)(*init_args).set_device('CPU')(*inputs)
+
+
+scalar_summary_op = ScalarSummary().set_device('CPU')
+def scalar_summary(*args):
+    """Call the shared CPU ``ScalarSummary`` primitive with ``args``."""
+    return scalar_summary_op(*args)
+
+
+scalar_to_tensor_op = ScalarToTensor().set_device('CPU')
+def scalar_to_tensor(*args):
+    """Call the shared CPU ``ScalarToTensor`` primitive with ``args``."""
+    return scalar_to_tensor_op(*args)
+
+
+def scale_and_translate(*args):
+    """CPU ``ScaleAndTranslate``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(ScaleAndTranslate)(*init_args).set_device('CPU')(*inputs)
+
+
+scan_op = Scan().set_device('CPU')
+def scan(*args):
+    """Call the shared CPU ``Scan`` primitive with ``args``."""
+    return scan_op(*args)
+
+
+def scatter_add(*args):
+    """CPU ``ScatterAdd``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterAdd)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_add_with_axis(*args):
+    """CPU ``ScatterAddWithAxis``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterAddWithAxis)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_div(*args):
+    """CPU ``ScatterDiv``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterDiv)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_max(*args):
+    """CPU ``ScatterMax``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterMax)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_min(*args):
+    """CPU ``ScatterMin``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterMin)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_mul(*args):
+    """CPU ``ScatterMul``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterMul)(*init_args).set_device('CPU')(*inputs)
+
+
+scatter_nd_op = ScatterNd().set_device('CPU')
+def scatter_nd(*args):
+    """Call the shared CPU ``ScatterNd`` primitive with ``args``."""
+    return scatter_nd_op(*args)
+
+
+def scatter_nd_add(*args):
+    """CPU ``ScatterNdAdd``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterNdAdd)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_nd_div(*args):
+    """CPU ``ScatterNdDiv``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterNdDiv)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_nd_max(*args):
+    """CPU ``ScatterNdMax``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterNdMax)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_nd_min(*args):
+    """CPU ``ScatterNdMin``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterNdMin)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_nd_mul(*args):
+    """CPU ``ScatterNdMul``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterNdMul)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_nd_sub(*args):
+    """CPU ``ScatterNdSub``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterNdSub)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_nd_update(*args):
+    """CPU ``ScatterNdUpdate``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterNdUpdate)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_sub(*args):
+    """CPU ``ScatterSub``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterSub)(*init_args).set_device('CPU')(*inputs)
+
+
+def scatter_update(*args):
+    """CPU ``ScatterUpdate``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(ScatterUpdate)(*init_args).set_device('CPU')(*inputs)
+
+
+se_lu_op = SeLU().set_device('CPU')
+def se_lu(*args):
+    """Call the shared CPU ``SeLU`` primitive with ``args``."""
+    return se_lu_op(*args)
+
+
+def search_sorted(*args):
+    """CPU ``SearchSorted``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(SearchSorted)(*init_args).set_device('CPU')(*inputs)
+
+
+segment_max_op = SegmentMax().set_device('CPU')
+def segment_max(*args):
+    """Call the shared CPU ``SegmentMax`` primitive with ``args``."""
+    return segment_max_op(*args)
+
+
+segment_mean_op = SegmentMean().set_device('CPU')
+def segment_mean(*args):
+    """Call the shared CPU ``SegmentMean`` primitive with ``args``."""
+    return segment_mean_op(*args)
+
+
+segment_min_op = SegmentMin().set_device('CPU')
+def segment_min(*args):
+    """Call the shared CPU ``SegmentMin`` primitive with ``args``."""
+    return segment_min_op(*args)
+
+
+segment_prod_op = SegmentProd().set_device('CPU')
+def segment_prod(*args):
+    """Call the shared CPU ``SegmentProd`` primitive with ``args``."""
+    return segment_prod_op(*args)
+
+
+segment_sum_op = SegmentSum().set_device('CPU')
+def segment_sum(*args):
+    """Call the shared CPU ``SegmentSum`` primitive with ``args``."""
+    return segment_sum_op(*args)
+
+
+select_op = Select().set_device('CPU')
+def select(*args):
+    """Call the shared CPU ``Select`` primitive with ``args``."""
+    return select_op(*args)
+
+
+select_view_op = SelectView().set_device('CPU')
+def select_view(*args):
+    """Call the shared CPU ``SelectView`` primitive with ``args``."""
+    return select_view_op(*args)
+
+
+def send(*args):
+    """CPU ``Send``: trailing 4 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-4:], args[:-4]
+    return _get_cache_prim(Send)(*init_args).set_device('CPU')(*inputs)
+
+
+shape_op = Shape().set_device('CPU')
+def shape(*args):
+    """Call the shared CPU ``Shape`` primitive with ``args``."""
+    return shape_op(*args)
+
+
+sigmoid_op = Sigmoid().set_device('CPU')
+def sigmoid(*args):
+    """Call the shared CPU ``Sigmoid`` primitive with ``args``."""
+    return sigmoid_op(*args)
+
+
+sigmoid_cross_entropy_with_logits_op = SigmoidCrossEntropyWithLogits().set_device('CPU')
+def sigmoid_cross_entropy_with_logits(*args):
+    """Call the shared CPU ``SigmoidCrossEntropyWithLogits`` primitive with ``args``."""
+    return sigmoid_cross_entropy_with_logits_op(*args)
+
+
+sign_op = Sign().set_device('CPU')
+def sign(*args):
+    """Call the shared CPU ``Sign`` primitive with ``args``."""
+    return sign_op(*args)
+
+
+sin_op = Sin().set_device('CPU')
+def sin(*args):
+    """Call the shared CPU ``Sin`` primitive with ``args``."""
+    return sin_op(*args)
+
+
+sinc_op = Sinc().set_device('CPU')
+def sinc(*args):
+    """Call the shared CPU ``Sinc`` primitive with ``args``."""
+    return sinc_op(*args)
+
+
+sinh_op = Sinh().set_device('CPU')
+def sinh(*args):
+    """Call the shared CPU ``Sinh`` primitive with ``args``."""
+    return sinh_op(*args)
+
+
+size_op = Size().set_device('CPU')
+def size(*args):
+    """Call the shared CPU ``Size`` primitive with ``args``."""
+    return size_op(*args)
+
+
+slice_op = Slice().set_device('CPU')
+# NOTE: this name shadows the builtin ``slice`` inside this module.
+def slice(*args):
+    """Call the shared CPU ``Slice`` primitive with ``args``."""
+    return slice_op(*args)
+
+
+def smooth_l1_loss(*args):
+    """CPU ``SmoothL1Loss``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(SmoothL1Loss)(*init_args).set_device('CPU')(*inputs)
+
+
+def soft_margin_loss(*args):
+    """CPU ``SoftMarginLoss``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(SoftMarginLoss)(*init_args).set_device('CPU')(*inputs)
+
+
+def soft_shrink(*args):
+    """CPU ``SoftShrink``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(SoftShrink)(*init_args).set_device('CPU')(*inputs)
+
+
+def softmax(*args):
+    """CPU ``Softmax``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(Softmax)(*init_args).set_device('CPU')(*inputs)
+
+
+softmax_cross_entropy_with_logits_op = SoftmaxCrossEntropyWithLogits().set_device('CPU')
+def softmax_cross_entropy_with_logits(*args):
+    """Call the shared CPU ``SoftmaxCrossEntropyWithLogits`` primitive with ``args``."""
+    return softmax_cross_entropy_with_logits_op(*args)
+
+
+softplus_op = Softplus().set_device('CPU')
+def softplus(*args):
+    """Call the shared CPU ``Softplus`` primitive with ``args``."""
+    return softplus_op(*args)
+
+
+softsign_op = Softsign().set_device('CPU')
+def softsign(*args):
+    """Call the shared CPU ``Softsign`` primitive with ``args``."""
+    return softsign_op(*args)
+
+
+def sort(*args):
+    """CPU ``Sort``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(Sort)(*init_args).set_device('CPU')(*inputs)
+
+
+def space_to_batch(*args):
+    """CPU ``SpaceToBatch``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(SpaceToBatch)(*init_args).set_device('CPU')(*inputs)
+
+
+def space_to_batch_nd(*args):
+    """CPU ``SpaceToBatchND``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(SpaceToBatchND)(*init_args).set_device('CPU')(*inputs)
+
+
+def space_to_depth(*args):
+    """CPU ``SpaceToDepth``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(SpaceToDepth)(*init_args).set_device('CPU')(*inputs)
+
+
+def sparse_apply_adadelta(*args):
+    """CPU ``SparseApplyAdadelta``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(SparseApplyAdadelta)(*init_args).set_device('CPU')(*inputs)
+
+
+def sparse_apply_adagrad(*args):
+    """CPU ``SparseApplyAdagrad``: trailing 3 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-3:], args[:-3]
+    return _get_cache_prim(SparseApplyAdagrad)(*init_args).set_device('CPU')(*inputs)
+
+
+def sparse_apply_adagrad_v2(*args):
+    """CPU ``SparseApplyAdagradV2``: trailing 4 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-4:], args[:-4]
+    return _get_cache_prim(SparseApplyAdagradV2)(*init_args).set_device('CPU')(*inputs)
+
+
+def sparse_apply_ftrl(*args):
+    """CPU ``SparseApplyFtrl``: trailing 5 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-5:], args[:-5]
+    return _get_cache_prim(SparseApplyFtrl)(*init_args).set_device('CPU')(*inputs)
+
+
+def sparse_apply_ftrl_v2(*args):
+    """CPU ``SparseApplyFtrlV2``: trailing 6 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-6:], args[:-6]
+    return _get_cache_prim(SparseApplyFtrlV2)(*init_args).set_device('CPU')(*inputs)
+
+
+def sparse_apply_proximal_adagrad(*args):
+    """CPU ``SparseApplyProximalAdagrad``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(SparseApplyProximalAdagrad)(*init_args).set_device('CPU')(*inputs)
+
+
+def sparse_apply_rms_prop(*args):
+    """CPU ``SparseApplyRMSProp``: trailing 4 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-4:], args[:-4]
+    return _get_cache_prim(SparseApplyRMSProp)(*init_args).set_device('CPU')(*inputs)
+
+
+sparse_gather_v2_op = SparseGatherV2().set_device('CPU')
+def sparse_gather_v2(*args):
+    """Call the shared CPU ``SparseGatherV2`` primitive with ``args``."""
+    return sparse_gather_v2_op(*args)
+
+
+sparse_slice_op = SparseSlice().set_device('CPU')
+def sparse_slice(*args):
+    """Call the shared CPU ``SparseSlice`` primitive with ``args``."""
+    return sparse_slice_op(*args)
+
+
+def sparse_softmax_cross_entropy_with_logits(*args):
+    """CPU ``SparseSoftmaxCrossEntropyWithLogits``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(SparseSoftmaxCrossEntropyWithLogits)(*init_args).set_device('CPU')(*inputs)
+
+
+sparse_tensor_dense_add_op = SparseTensorDenseAdd().set_device('CPU')
+def sparse_tensor_dense_add(*args):
+    """Call the shared CPU ``SparseTensorDenseAdd`` primitive with ``args``."""
+    return sparse_tensor_dense_add_op(*args)
+
+
+def sparse_tensor_dense_matmul(*args):
+    """CPU ``SparseTensorDenseMatmul``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(SparseTensorDenseMatmul)(*init_args).set_device('CPU')(*inputs)
+
+
+sparse_to_dense_op = SparseToDense().set_device('CPU')
+def sparse_to_dense(*args):
+    """Call the shared CPU ``SparseToDense`` primitive with ``args``."""
+    return sparse_to_dense_op(*args)
+
+
+def split(*args):
+    """CPU ``Split``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(Split)(*init_args).set_device('CPU')(*inputs)
+
+
+def split_v(*args):
+    """CPU ``SplitV``: trailing 3 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-3:], args[:-3]
+    return _get_cache_prim(SplitV)(*init_args).set_device('CPU')(*inputs)
+
+
+sqrt_op = Sqrt().set_device('CPU')
+def sqrt(*args):
+    """Call the shared CPU ``Sqrt`` primitive with ``args``."""
+    return sqrt_op(*args)
+
+
+square_op = Square().set_device('CPU')
+def square(*args):
+    """Call the shared CPU ``Square`` primitive with ``args``."""
+    return square_op(*args)
+
+
+square_sum_all_op = SquareSumAll().set_device('CPU')
+def square_sum_all(*args):
+    """Call the shared CPU ``SquareSumAll`` primitive with ``args``."""
+    return square_sum_all_op(*args)
+
+
+squared_difference_op = SquaredDifference().set_device('CPU')
+def squared_difference(*args):
+    """Call the shared CPU ``SquaredDifference`` primitive with ``args``."""
+    return squared_difference_op(*args)
+
+
+def squeeze(*args):
+    """CPU ``Squeeze``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(Squeeze)(*init_args).set_device('CPU')(*inputs)
+
+
+def stack(*args):
+    """CPU ``Stack``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(Stack)(*init_args).set_device('CPU')(*inputs)
+
+
+def standard_laplace(*args):
+    """CPU ``StandardLaplace``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(StandardLaplace)(*init_args).set_device('CPU')(*inputs)
+
+
+def standard_normal(*args):
+    """CPU ``StandardNormal``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(StandardNormal)(*init_args).set_device('CPU')(*inputs)
+
+
+stop_gradient_op = StopGradient().set_device('CPU')
+def stop_gradient(*args):
+    """Call the shared CPU ``StopGradient`` primitive with ``args``."""
+    return stop_gradient_op(*args)
+
+
+def strided_slice(*args):
+    """CPU ``StridedSlice``: trailing 5 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-5:], args[:-5]
+    return _get_cache_prim(StridedSlice)(*init_args).set_device('CPU')(*inputs)
+
+
+sub_op = Sub().set_device('CPU')
+def sub(*args):
+    """Call the shared CPU ``Sub`` primitive with ``args``."""
+    return sub_op(*args)
+
+
+sub_and_filter_op = SubAndFilter().set_device('CPU')
+def sub_and_filter(*args):
+    """Call the shared CPU ``SubAndFilter`` primitive with ``args``."""
+    return sub_and_filter_op(*args)
+
+
+def svd(*args):
+    """CPU ``Svd``: trailing 2 args create the primitive, the rest are passed to it."""
+    init_args, inputs = args[-2:], args[:-2]
+    return _get_cache_prim(Svd)(*init_args).set_device('CPU')(*inputs)
+
+
+tan_op = Tan().set_device('CPU')
+def tan(*args):
+    """Call the shared CPU ``Tan`` primitive with ``args``."""
+    return tan_op(*args)
+
+
+tanh_op = Tanh().set_device('CPU')
+def tanh(*args):
+    """Call the shared CPU ``Tanh`` primitive with ``args``."""
+    return tanh_op(*args)
+
+
+def tensor_dump(*args):
+    """CPU ``TensorDump``: the trailing arg creates the primitive, the rest are passed to it."""
+    init_args, inputs = args[-1:], args[:-1]
+    return _get_cache_prim(TensorDump)(*init_args).set_device('CPU')(*inputs)
+
+
+tensor_scatter_add_op = TensorScatterAdd().set_device('CPU') +def tensor_scatter_add(*args): + return tensor_scatter_add_op(*args) + + +tensor_scatter_div_op = TensorScatterDiv().set_device('CPU') +def tensor_scatter_div(*args): + return tensor_scatter_div_op(*args) + + +def tensor_scatter_elements(*args): + op = _get_cache_prim(TensorScatterElements)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +tensor_scatter_max_op = TensorScatterMax().set_device('CPU') +def tensor_scatter_max(*args): + return tensor_scatter_max_op(*args) + + +tensor_scatter_min_op = TensorScatterMin().set_device('CPU') +def tensor_scatter_min(*args): + return tensor_scatter_min_op(*args) + + +tensor_scatter_mul_op = TensorScatterMul().set_device('CPU') +def tensor_scatter_mul(*args): + return tensor_scatter_mul_op(*args) + + +tensor_scatter_sub_op = TensorScatterSub().set_device('CPU') +def tensor_scatter_sub(*args): + return tensor_scatter_sub_op(*args) + + +tensor_scatter_update_op = TensorScatterUpdate().set_device('CPU') +def tensor_scatter_update(*args): + return tensor_scatter_update_op(*args) + + +tensor_shape_op = TensorShape().set_device('CPU') +def tensor_shape(*args): + return tensor_shape_op(*args) + + +tensor_summary_op = TensorSummary().set_device('CPU') +def tensor_summary(*args): + return tensor_summary_op(*args) + + +tile_op = Tile().set_device('CPU') +def tile(*args): + return tile_op(*args) + + +def top_k(*args): + op = _get_cache_prim(TopK)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +trace_op = Trace().set_device('CPU') +def trace(*args): + return trace_op(*args) + + +transpose_op = Transpose().set_device('CPU') +def transpose(*args): + return transpose_op(*args) + + +transpose_ext_view_op = TransposeExtView().set_device('CPU') +def transpose_ext_view(*args): + return transpose_ext_view_op(*args) + + +transpose_view_op = TransposeView().set_device('CPU') +def transpose_view(*args): + return transpose_view_op(*args) + + +tridiagonal_mat_mul_op = 
TridiagonalMatMul().set_device('CPU') +def tridiagonal_mat_mul(*args): + return tridiagonal_mat_mul_op(*args) + + +def tril(*args): + op = _get_cache_prim(Tril)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def tril_indices(*args): + op = _get_cache_prim(TrilIndices)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def triplet_margin_loss(*args): + op = _get_cache_prim(TripletMarginLoss)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +def triu(*args): + op = _get_cache_prim(Triu)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +def triu_indices(*args): + op = _get_cache_prim(TriuIndices)(*args[-4:]).set_device('CPU') + return op(*args[:-4]) + + +trunc_op = Trunc().set_device('CPU') +def trunc(*args): + return trunc_op(*args) + + +truncate_div_op = TruncateDiv().set_device('CPU') +def truncate_div(*args): + return truncate_div_op(*args) + + +truncate_mod_op = TruncateMod().set_device('CPU') +def truncate_mod(*args): + return truncate_mod_op(*args) + + +def truncated_normal(*args): + op = _get_cache_prim(TruncatedNormal)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + +tuple_to_array_op = TupleToArray().set_device('CPU') +def tuple_to_array(*args): + return tuple_to_array_op(*args) + + +def uniform_candidate_sampler(*args): + op = _get_cache_prim(UniformCandidateSampler)(*args[-6:]).set_device('CPU') + return op(*args[:-6]) + + +def uniform_int(*args): + op = _get_cache_prim(UniformInt)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +def uniform_real(*args): + op = _get_cache_prim(UniformReal)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +unique_op = Unique().set_device('CPU') +def unique(*args): + return unique_op(*args) + + +def unique_consecutive(*args): + op = _get_cache_prim(UniqueConsecutive)(*args[-3:]).set_device('CPU') + return op(*args[:-3]) + + + +def unpack(*args): + op = _get_cache_prim(Unpack)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +unravel_index_op = 
UnravelIndex().set_device('CPU') +def unravel_index(*args): + return unravel_index_op(*args) + + +unsorted_segment_max_op = UnsortedSegmentMax().set_device('CPU') +def unsorted_segment_max(*args): + return unsorted_segment_max_op(*args) + + +unsorted_segment_min_op = UnsortedSegmentMin().set_device('CPU') +def unsorted_segment_min(*args): + return unsorted_segment_min_op(*args) + + +unsorted_segment_prod_op = UnsortedSegmentProd().set_device('CPU') +def unsorted_segment_prod(*args): + return unsorted_segment_prod_op(*args) + + +unsorted_segment_sum_op = UnsortedSegmentSum().set_device('CPU') +def unsorted_segment_sum(*args): + return unsorted_segment_sum_op(*args) + + +def unstack(*args): + op = _get_cache_prim(Unstack)(*args[-2:]).set_device('CPU') + return op(*args[:-2]) + + +update_state_op = UpdateState().set_device('CPU') +def update_state(*args): + return update_state_op(*args) + + +def upper_bound(*args): + op = _get_cache_prim(UpperBound)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +upsample_nearest3_d_op = UpsampleNearest3D().set_device('CPU') +def upsample_nearest3_d(*args): + return upsample_nearest3_d_op(*args) + + +def upsample_trilinear3_d(*args): + op = _get_cache_prim(UpsampleTrilinear3D)(*args[-1:]).set_device('CPU') + return op(*args[:-1]) + + +while_loop_op = WhileLoop().set_device('CPU') +def while_loop(*args): + return while_loop_op(*args) + + +xdivy_op = Xdivy().set_device('CPU') +def xdivy(*args): + return xdivy_op(*args) + + +xlogy_op = Xlogy().set_device('CPU') +def xlogy(*args): + return xlogy_op(*args) + + +zeros_op = Zeros().set_device('CPU') +def zeros(*args): + return zeros_op(*args) + + +zeros_like_op = ZerosLike().set_device('CPU') +def zeros_like(*args): + return zeros_like_op(*args) + + +zeta_op = Zeta().set_device('CPU') +def zeta(*args): + return zeta_op(*args) diff --git a/torch4ms/_op_prim/gpu/__init__.py b/torch4ms/_op_prim/gpu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/torch4ms/_op_prim/gpu/legacy.py b/torch4ms/_op_prim/gpu/legacy.py new file mode 100644 index 000000000..7f71def6e --- /dev/null +++ b/torch4ms/_op_prim/gpu/legacy.py @@ -0,0 +1,3511 @@ +from mindspore.ops.operations import * +from mindspore.ops.operations._grad_ops import * +from mindspore.ops.operations._inner_ops import * +from mindspore.ops._primitive_cache import _get_cache_prim + + +a_cos_grad_op = ACosGrad().set_device('GPU') +def a_cos_grad(*args): + return a_cos_grad_op(*args) + + +abs_grad_op = AbsGrad().set_device('GPU') +def abs_grad(*args): + return abs_grad_op(*args) + + +acosh_grad_op = AcoshGrad().set_device('GPU') +def acosh_grad(*args): + return acosh_grad_op(*args) + + +adaptive_avg_pool2_d_grad_op = AdaptiveAvgPool2DGrad().set_device('GPU') +def adaptive_avg_pool2_d_grad(*args): + return adaptive_avg_pool2_d_grad_op(*args) + + +adaptive_avg_pool3_d_grad_op = AdaptiveAvgPool3DGrad().set_device('GPU') +def adaptive_avg_pool3_d_grad(*args): + return adaptive_avg_pool3_d_grad_op(*args) + + +adaptive_max_pool2_d_grad_op = AdaptiveMaxPool2DGrad().set_device('GPU') +def adaptive_max_pool2_d_grad(*args): + return adaptive_max_pool2_d_grad_op(*args) + + +adaptive_max_pool3_d_grad_op = AdaptiveMaxPool3DGrad().set_device('GPU') +def adaptive_max_pool3_d_grad(*args): + return adaptive_max_pool3_d_grad_op(*args) + + +def affine_grid_grad(*args): + op = _get_cache_prim(AffineGridGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +asin_grad_op = AsinGrad().set_device('GPU') +def asin_grad(*args): + return asin_grad_op(*args) + + +asinh_grad_op = AsinhGrad().set_device('GPU') +def asinh_grad(*args): + return asinh_grad_op(*args) + + +atan_grad_op = AtanGrad().set_device('GPU') +def atan_grad(*args): + return atan_grad_op(*args) + + +def avg_pool3_d_grad(*args): + op = _get_cache_prim(AvgPool3DGrad)(*args[-8:]).set_device('GPU') + return op(*args[:-8]) + + +def avg_pool_grad(*args): + op = _get_cache_prim(AvgPoolGrad)(*args[-4:]).set_device('GPU') 
+ return op(*args[:-4]) + + +def avg_pool_grad_ge(*args): + op = _get_cache_prim(AvgPoolGradGe)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def avg_pool_grad_v1(*args): + op = _get_cache_prim(AvgPoolGradV1)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def avg_pool_grad_vm(*args): + op = _get_cache_prim(AvgPoolGradVm)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def bn_training_reduce_grad(*args): + op = _get_cache_prim(BNTrainingReduceGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def bn_training_update_grad(*args): + op = _get_cache_prim(BNTrainingUpdateGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def basic_lstm_cell_c_state_grad(*args): + op = _get_cache_prim(BasicLSTMCellCStateGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def basic_lstm_cell_input_grad(*args): + op = _get_cache_prim(BasicLSTMCellInputGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +basic_lstm_cell_weight_grad_op = BasicLSTMCellWeightGrad().set_device('GPU') +def basic_lstm_cell_weight_grad(*args): + return basic_lstm_cell_weight_grad_op(*args) + + +def batch_norm_grad(*args): + op = _get_cache_prim(BatchNormGrad)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def batch_norm_grad_grad(*args): + op = _get_cache_prim(BatchNormGradGrad)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def bias_add_grad(*args): + op = _get_cache_prim(BiasAddGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def binary_cross_entropy_grad(*args): + op = _get_cache_prim(BinaryCrossEntropyGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +cholesky_grad_op = CholeskyGrad().set_device('GPU') +def cholesky_grad(*args): + return cholesky_grad_op(*args) + + +def concat_offset(*args): + op = _get_cache_prim(ConcatOffset)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def conv2_d_backprop_filter(*args): + op = 
_get_cache_prim(Conv2DBackpropFilter)(*args[-10:]).set_device('GPU') + return op(*args[:-10]) + + +def conv3_d_backprop_filter(*args): + op = _get_cache_prim(Conv3DBackpropFilter)(*args[-9:]).set_device('GPU') + return op(*args[:-9]) + + +def deformable_offsets_grad(*args): + op = _get_cache_prim(DeformableOffsetsGrad)(*args[-7:]).set_device('GPU') + return op(*args[:-7]) + + +def depthwise_conv2d_native_backprop_filter(*args): + op = _get_cache_prim(DepthwiseConv2dNativeBackpropFilter)(*args[-9:]).set_device('GPU') + return op(*args[:-9]) + + +def depthwise_conv2d_native_backprop_input(*args): + op = _get_cache_prim(DepthwiseConv2dNativeBackpropInput)(*args[-9:]).set_device('GPU') + return op(*args[:-9]) + + +def dilation2_d_backprop_filter(*args): + op = _get_cache_prim(Dilation2DBackpropFilter)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def dilation2_d_backprop_input(*args): + op = _get_cache_prim(Dilation2DBackpropInput)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def dropout_grad(*args): + op = _get_cache_prim(DropoutGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def dynamic_gruv2_grad(*args): + op = _get_cache_prim(DynamicGRUV2Grad)(*args[-8:]).set_device('GPU') + return op(*args[:-8]) + + +def dynamic_rnn_grad(*args): + op = _get_cache_prim(DynamicRNNGrad)(*args[-9:]).set_device('GPU') + return op(*args[:-9]) + + +def einsum_grad(*args): + op = _get_cache_prim(EinsumGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +elu_grad_op = EluGrad().set_device('GPU') +def elu_grad(*args): + return elu_grad_op(*args) + + +embedding_lookup_comm_grad_op = EmbeddingLookupCommGrad().set_device('GPU') +def embedding_lookup_comm_grad(*args): + return embedding_lookup_comm_grad_op(*args) + + +fast_ge_lu_grad_op = FastGeLUGrad().set_device('GPU') +def fast_ge_lu_grad(*args): + return fast_ge_lu_grad_op(*args) + + +def flash_attention_score_grad(*args): + op = 
_get_cache_prim(FlashAttentionScoreGrad)(*args[-8:]).set_device('GPU') + return op(*args[:-8]) + + +flatten_grad_op = FlattenGrad().set_device('GPU') +def flatten_grad(*args): + return flatten_grad_op(*args) + + +def fractional_avg_pool_grad(*args): + op = _get_cache_prim(FractionalAvgPoolGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def fractional_max_pool3_d_grad_with_fixed_ksize(*args): + op = _get_cache_prim(FractionalMaxPool3DGradWithFixedKsize)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def fractional_max_pool_grad(*args): + op = _get_cache_prim(FractionalMaxPoolGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def fractional_max_pool_grad_with_fixed_ksize(*args): + op = _get_cache_prim(FractionalMaxPoolGradWithFixedKsize)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def gruv2_grad(*args): + op = _get_cache_prim(GRUV2Grad)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +gather_d_grad_v2_op = GatherDGradV2().set_device('GPU') +def gather_d_grad_v2(*args): + return gather_d_grad_v2_op(*args) + + +ge_lu_grad_op = GeLUGrad().set_device('GPU') +def ge_lu_grad(*args): + return ge_lu_grad_op(*args) + + +def global_comm(*args): + op = _get_cache_prim(GlobalComm)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def glu_grad(*args): + op = _get_cache_prim(GluGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def grid_sampler2_d_grad(*args): + op = _get_cache_prim(GridSampler2DGrad)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def grid_sampler3_d_grad(*args): + op = _get_cache_prim(GridSampler3DGrad)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def gru_grad_data(*args): + op = _get_cache_prim(GruGradData)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def gru_grad_weight(*args): + op = _get_cache_prim(GruGradWeight)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def h_shrink_grad(*args): + op = 
_get_cache_prim(HShrinkGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +h_sigmoid_grad_op = HSigmoidGrad().set_device('GPU') +def h_sigmoid_grad(*args): + return h_sigmoid_grad_op(*args) + + +h_swish_grad_op = HSwishGrad().set_device('GPU') +def h_swish_grad(*args): + return h_swish_grad_op(*args) + + +igamma_grad_a_op = IgammaGradA().set_device('GPU') +def igamma_grad_a(*args): + return igamma_grad_a_op(*args) + + +def instance_norm_grad(*args): + op = _get_cache_prim(InstanceNormGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def instance_norm_v2_grad(*args): + op = _get_cache_prim(InstanceNormV2Grad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +inv_grad_op = InvGrad().set_device('GPU') +def inv_grad(*args): + return inv_grad_op(*args) + + +def kl_div_loss_grad(*args): + op = _get_cache_prim(KLDivLossGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def l2_normalize_grad(*args): + op = _get_cache_prim(L2NormalizeGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def lrn_grad(*args): + op = _get_cache_prim(LRNGrad)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def lstm_grad(*args): + op = _get_cache_prim(LSTMGrad)(*args[-7:]).set_device('GPU') + return op(*args[:-7]) + + +def lstm_grad_data(*args): + op = _get_cache_prim(LSTMGradData)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def lstm_grad_weight(*args): + op = _get_cache_prim(LSTMGradWeight)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def layer_norm_grad(*args): + op = _get_cache_prim(LayerNormGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def layer_norm_grad_grad(*args): + op = _get_cache_prim(LayerNormGradGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def log_softmax_grad(*args): + op = _get_cache_prim(LogSoftmaxGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def logit_grad(*args): + op = 
_get_cache_prim(LogitGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def lu_unpack_grad(*args): + op = _get_cache_prim(LuUnpackGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +map_tensor_get_grad_op = MapTensorGetGrad().set_device('GPU') +def map_tensor_get_grad(*args): + return map_tensor_get_grad_op(*args) + + +masked_select_grad_op = MaskedSelectGrad().set_device('GPU') +def masked_select_grad(*args): + return masked_select_grad_op(*args) + + +def max_pool3_d_grad(*args): + op = _get_cache_prim(MaxPool3DGrad)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +def max_pool3_d_grad_grad(*args): + op = _get_cache_prim(MaxPool3DGradGrad)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def max_pool3_d_grad_with_argmax(*args): + op = _get_cache_prim(MaxPool3DGradWithArgmax)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def max_pool_grad(*args): + op = _get_cache_prim(MaxPoolGrad)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def max_pool_grad_grad(*args): + op = _get_cache_prim(MaxPoolGradGrad)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def max_pool_grad_grad_with_argmax(*args): + op = _get_cache_prim(MaxPoolGradGradWithArgmax)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def max_pool_grad_v1(*args): + op = _get_cache_prim(MaxPoolGradV1)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def max_pool_grad_with_argmax(*args): + op = _get_cache_prim(MaxPoolGradWithArgmax)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def max_pool_grad_with_argmax_v2(*args): + op = _get_cache_prim(MaxPoolGradWithArgmaxV2)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def max_unpool2_d_grad(*args): + op = _get_cache_prim(MaxUnpool2DGrad)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +def max_unpool3_d_grad(*args): + op = _get_cache_prim(MaxUnpool3DGrad)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +def maximum_grad(*args): + 
op = _get_cache_prim(MaximumGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def maximum_grad_grad(*args): + op = _get_cache_prim(MaximumGradGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def median_grad(*args): + op = _get_cache_prim(MedianGrad)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def minimum_grad(*args): + op = _get_cache_prim(MinimumGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +minimum_grad_grad_op = MinimumGradGrad().set_device('GPU') +def minimum_grad_grad(*args): + return minimum_grad_grad_op(*args) + + +def mirror_pad_grad(*args): + op = _get_cache_prim(MirrorPadGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def multi_margin_loss_grad(*args): + op = _get_cache_prim(MultiMarginLossGrad)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def multilabel_margin_loss_grad(*args): + op = _get_cache_prim(MultilabelMarginLossGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def mvlgamma_grad(*args): + op = _get_cache_prim(MvlgammaGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def nll_loss_grad(*args): + op = _get_cache_prim(NLLLossGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def neighbor_exchange_v2_grad(*args): + op = _get_cache_prim(NeighborExchangeV2Grad)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +p_re_lu_grad_op = PReLUGrad().set_device('GPU') +def p_re_lu_grad(*args): + return p_re_lu_grad_op(*args) + + +def psroi_pooling_grad(*args): + op = _get_cache_prim(PSROIPoolingGrad)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def pad_v3_grad(*args): + op = _get_cache_prim(PadV3Grad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def parallel_resize_bilinear_grad(*args): + op = _get_cache_prim(ParallelResizeBilinearGrad)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def pdist_grad(*args): + op = _get_cache_prim(PdistGrad)(*args[-1:]).set_device('GPU') + return 
op(*args[:-1]) + + +def primitive(*args): + op = _get_cache_prim(Primitive)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def primitive_with_infer(*args): + op = _get_cache_prim(PrimitiveWithInfer)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def ps_roi_pooling_grad(*args): + op = _get_cache_prim(PsROIPoolingGrad)(*args[-9:]).set_device('GPU') + return op(*args[:-9]) + + +def roi_align_grad(*args): + op = _get_cache_prim(ROIAlignGrad)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +random_gamma_grad_op = RandomGammaGrad().set_device('GPU') +def random_gamma_grad(*args): + return random_gamma_grad_op(*args) + + +re_lu6_grad_op = ReLU6Grad().set_device('GPU') +def re_lu6_grad(*args): + return re_lu6_grad_op(*args) + + +reciprocal_grad_op = ReciprocalGrad().set_device('GPU') +def reciprocal_grad(*args): + return reciprocal_grad_op(*args) + + +ref_to_embed_op = RefToEmbed().set_device('GPU') +def ref_to_embed(*args): + return ref_to_embed_op(*args) + + +relu_grad_op = ReluGrad().set_device('GPU') +def relu_grad(*args): + return relu_grad_op(*args) + + +def resize_bicubic_grad(*args): + op = _get_cache_prim(ResizeBicubicGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def resize_bilinear_grad(*args): + op = _get_cache_prim(ResizeBilinearGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def resize_linear1_d_grad(*args): + op = _get_cache_prim(ResizeLinear1DGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def resize_nearest_neighbor_grad(*args): + op = _get_cache_prim(ResizeNearestNeighborGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def resize_nearest_neighbor_v2_grad(*args): + op = _get_cache_prim(ResizeNearestNeighborV2Grad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def resize_v2_grad(*args): + op = _get_cache_prim(ResizeV2Grad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +rms_norm_grad_op = RmsNormGrad().set_device('GPU') +def 
rms_norm_grad(*args): + return rms_norm_grad_op(*args) + + +rsqrt_grad_op = RsqrtGrad().set_device('GPU') +def rsqrt_grad(*args): + return rsqrt_grad_op(*args) + + +def scale_and_translate_grad(*args): + op = _get_cache_prim(ScaleAndTranslateGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +selu_grad_op = SeluGrad().set_device('GPU') +def selu_grad(*args): + return selu_grad_op(*args) + + +si_lu_grad_op = SiLUGrad().set_device('GPU') +def si_lu_grad(*args): + return si_lu_grad_op(*args) + + +sigmoid_cross_entropy_with_logits_grad_op = SigmoidCrossEntropyWithLogitsGrad().set_device('GPU') +def sigmoid_cross_entropy_with_logits_grad(*args): + return sigmoid_cross_entropy_with_logits_grad_op(*args) + + +sigmoid_grad_op = SigmoidGrad().set_device('GPU') +def sigmoid_grad(*args): + return sigmoid_grad_op(*args) + + +slice_grad_op = SliceGrad().set_device('GPU') +def slice_grad(*args): + return slice_grad_op(*args) + + +def smooth_l1_loss_grad(*args): + op = _get_cache_prim(SmoothL1LossGrad)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def soft_margin_loss_grad(*args): + op = _get_cache_prim(SoftMarginLossGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def soft_shrink_grad(*args): + op = _get_cache_prim(SoftShrinkGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +softmax_grad_op = SoftmaxGrad().set_device('GPU') +def softmax_grad(*args): + return softmax_grad_op(*args) + + +softplus_grad_op = SoftplusGrad().set_device('GPU') +def softplus_grad(*args): + return softplus_grad_op(*args) + + +sparse_fill_empty_rows_grad_op = SparseFillEmptyRowsGrad().set_device('GPU') +def sparse_fill_empty_rows_grad(*args): + return sparse_fill_empty_rows_grad_op(*args) + + +sparse_segment_mean_grad_op = SparseSegmentMeanGrad().set_device('GPU') +def sparse_segment_mean_grad(*args): + return sparse_segment_mean_grad_op(*args) + + +sparse_segment_sqrt_n_grad_op = SparseSegmentSqrtNGrad().set_device('GPU') +def 
sparse_segment_sqrt_n_grad(*args): + return sparse_segment_sqrt_n_grad_op(*args) + + +sparse_segment_sum_grad_op = SparseSegmentSumGrad().set_device('GPU') +def sparse_segment_sum_grad(*args): + return sparse_segment_sum_grad_op(*args) + + +sparse_slice_grad_op = SparseSliceGrad().set_device('GPU') +def sparse_slice_grad(*args): + return sparse_slice_grad_op(*args) + + +sqrt_grad_op = SqrtGrad().set_device('GPU') +def sqrt_grad(*args): + return sqrt_grad_op(*args) + + +def strided_slice_grad(*args): + op = _get_cache_prim(StridedSliceGrad)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +def sync_batch_norm_grad(*args): + op = _get_cache_prim(SyncBatchNormGrad)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +tanh_grad_op = TanhGrad().set_device('GPU') +def tanh_grad(*args): + return tanh_grad_op(*args) + + +trace_grad_op = TraceGrad().set_device('GPU') +def trace_grad(*args): + return trace_grad_op(*args) + + +unique_grad_op = UniqueGrad().set_device('GPU') +def unique_grad(*args): + return unique_grad_op(*args) + + +upsample_nearest3_d_grad_op = UpsampleNearest3DGrad().set_device('GPU') +def upsample_nearest3_d_grad(*args): + return upsample_nearest3_d_grad_op(*args) + + +def upsample_trilinear3_d_grad(*args): + op = _get_cache_prim(UpsampleTrilinear3DGrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +wkv_grad_op = WKVGrad().set_device('GPU') +def wkv_grad(*args): + return wkv_grad_op(*args) + + +a_cos_op = ACos().set_device('GPU') +def a_cos(*args): + return a_cos_op(*args) + + +abs_op = Abs().set_device('GPU') +def abs(*args): + return abs_op(*args) + + +accumulate_nv2_op = AccumulateNV2().set_device('GPU') +def accumulate_nv2(*args): + return accumulate_nv2_op(*args) + + +acosh_op = Acosh().set_device('GPU') +def acosh(*args): + return acosh_op(*args) + + +def adam(*args): + op = _get_cache_prim(Adam)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def adam_no_update_param(*args): + op = 
_get_cache_prim(AdamNoUpdateParam)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def adam_weight_decay(*args): + op = _get_cache_prim(AdamWeightDecay)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def adaptive_avg_pool2_d(*args): + op = _get_cache_prim(AdaptiveAvgPool2D)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def adaptive_avg_pool3_d(*args): + op = _get_cache_prim(AdaptiveAvgPool3D)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def adaptive_max_pool2_d(*args): + op = _get_cache_prim(AdaptiveMaxPool2D)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +adaptive_max_pool3_d_op = AdaptiveMaxPool3D().set_device('GPU') +def adaptive_max_pool3_d(*args): + return adaptive_max_pool3_d_op(*args) + + +add_op = Add().set_device('GPU') +def add(*args): + return add_op(*args) + + +add_n_op = AddN().set_device('GPU') +def add_n(*args): + return add_n_op(*args) + + +addcdiv_op = Addcdiv().set_device('GPU') +def addcdiv(*args): + return addcdiv_op(*args) + + +addcmul_op = Addcmul().set_device('GPU') +def addcmul(*args): + return addcmul_op(*args) + + +adjust_hue_op = AdjustHue().set_device('GPU') +def adjust_hue(*args): + return adjust_hue_op(*args) + + +adjust_saturation_op = AdjustSaturation().set_device('GPU') +def adjust_saturation(*args): + return adjust_saturation_op(*args) + + +def affine_grid(*args): + op = _get_cache_prim(AffineGrid)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def all_gather(*args): + op = _get_cache_prim(AllGather)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def all_reduce(*args): + op = _get_cache_prim(AllReduce)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def allto_all(*args): + op = _get_cache_prim(AlltoAll)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def allto_all_v(*args): + op = _get_cache_prim(AlltoAllV)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +angle_op = Angle().set_device('GPU') +def angle(*args): + 
return angle_op(*args) + + +apply_ada_max_op = ApplyAdaMax().set_device('GPU') +def apply_ada_max(*args): + return apply_ada_max_op(*args) + + +apply_adadelta_op = ApplyAdadelta().set_device('GPU') +def apply_adadelta(*args): + return apply_adadelta_op(*args) + + +def apply_adagrad(*args): + op = _get_cache_prim(ApplyAdagrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def apply_adagrad_da(*args): + op = _get_cache_prim(ApplyAdagradDA)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def apply_adagrad_v2(*args): + op = _get_cache_prim(ApplyAdagradV2)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def apply_adam_with_amsgrad(*args): + op = _get_cache_prim(ApplyAdamWithAmsgrad)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def apply_adam_with_amsgrad_v2(*args): + op = _get_cache_prim(ApplyAdamWithAmsgradV2)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +apply_add_sign_op = ApplyAddSign().set_device('GPU') +def apply_add_sign(*args): + return apply_add_sign_op(*args) + + +def apply_centered_rms_prop(*args): + op = _get_cache_prim(ApplyCenteredRMSProp)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def apply_ftrl(*args): + op = _get_cache_prim(ApplyFtrl)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +apply_gradient_descent_op = ApplyGradientDescent().set_device('GPU') +def apply_gradient_descent(*args): + return apply_gradient_descent_op(*args) + + +def apply_keras_momentum(*args): + op = _get_cache_prim(ApplyKerasMomentum)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def apply_momentum(*args): + op = _get_cache_prim(ApplyMomentum)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +apply_power_sign_op = ApplyPowerSign().set_device('GPU') +def apply_power_sign(*args): + return apply_power_sign_op(*args) + + +def apply_proximal_adagrad(*args): + op = _get_cache_prim(ApplyProximalAdagrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + 
+apply_proximal_gradient_descent_op = ApplyProximalGradientDescent().set_device('GPU') +def apply_proximal_gradient_descent(*args): + return apply_proximal_gradient_descent_op(*args) + + +def apply_rms_prop(*args): + op = _get_cache_prim(ApplyRMSProp)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def apply_rotary_pos_emb(*args): + op = _get_cache_prim(ApplyRotaryPosEmb)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def approximate_equal(*args): + op = _get_cache_prim(ApproximateEqual)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def arg_max_with_value(*args): + op = _get_cache_prim(ArgMaxWithValue)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def arg_min_with_value(*args): + op = _get_cache_prim(ArgMinWithValue)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def argmax(*args): + op = _get_cache_prim(Argmax)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def argmin(*args): + op = _get_cache_prim(Argmin)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +asin_op = Asin().set_device('GPU') +def asin(*args): + return asin_op(*args) + + +asinh_op = Asinh().set_device('GPU') +def asinh(*args): + return asinh_op(*args) + + +assign_op = Assign().set_device('GPU') +def assign(*args): + return assign_op(*args) + + +assign_add_op = AssignAdd().set_device('GPU') +def assign_add(*args): + return assign_add_op(*args) + + +assign_sub_op = AssignSub().set_device('GPU') +def assign_sub(*args): + return assign_sub_op(*args) + + +atan_op = Atan().set_device('GPU') +def atan(*args): + return atan_op(*args) + + +atan2_op = Atan2().set_device('GPU') +def atan2(*args): + return atan2_op(*args) + + +atanh_op = Atanh().set_device('GPU') +def atanh(*args): + return atanh_op(*args) + + +def avg_pool(*args): + op = _get_cache_prim(AvgPool)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def avg_pool3_d(*args): + op = _get_cache_prim(AvgPool3D)(*args[-8:]).set_device('GPU') + return op(*args[:-8]) 
+ + +def bce_with_logits_loss(*args): + op = _get_cache_prim(BCEWithLogitsLoss)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def barrier(*args): + op = _get_cache_prim(Barrier)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def bartlett_window(*args): + op = _get_cache_prim(BartlettWindow)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def basic_lstm_cell(*args): + op = _get_cache_prim(BasicLSTMCell)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def batch_i_send_i_recv(*args): + op = _get_cache_prim(BatchISendIRecv)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +def batch_mat_mul(*args): + op = _get_cache_prim(BatchMatMul)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def batch_norm(*args): + op = _get_cache_prim(BatchNorm)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def batch_to_space(*args): + op = _get_cache_prim(BatchToSpace)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def batch_to_space_nd(*args): + op = _get_cache_prim(BatchToSpaceND)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +batch_to_space_ndv2_op = BatchToSpaceNDV2().set_device('GPU') +def batch_to_space_ndv2(*args): + return batch_to_space_ndv2_op(*args) + + +def bernoulli(*args): + op = _get_cache_prim(Bernoulli)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +bessel_i0_op = BesselI0().set_device('GPU') +def bessel_i0(*args): + return bessel_i0_op(*args) + + +bessel_i0e_op = BesselI0e().set_device('GPU') +def bessel_i0e(*args): + return bessel_i0e_op(*args) + + +bessel_i1_op = BesselI1().set_device('GPU') +def bessel_i1(*args): + return bessel_i1_op(*args) + + +bessel_i1e_op = BesselI1e().set_device('GPU') +def bessel_i1e(*args): + return bessel_i1e_op(*args) + + +bessel_j0_op = BesselJ0().set_device('GPU') +def bessel_j0(*args): + return bessel_j0_op(*args) + + +bessel_j1_op = BesselJ1().set_device('GPU') +def bessel_j1(*args): + return bessel_j1_op(*args) + + 
+bessel_k0_op = BesselK0().set_device('GPU') +def bessel_k0(*args): + return bessel_k0_op(*args) + + +bessel_k0e_op = BesselK0e().set_device('GPU') +def bessel_k0e(*args): + return bessel_k0e_op(*args) + + +bessel_k1_op = BesselK1().set_device('GPU') +def bessel_k1(*args): + return bessel_k1_op(*args) + + +bessel_k1e_op = BesselK1e().set_device('GPU') +def bessel_k1e(*args): + return bessel_k1e_op(*args) + + +bessel_y0_op = BesselY0().set_device('GPU') +def bessel_y0(*args): + return bessel_y0_op(*args) + + +bessel_y1_op = BesselY1().set_device('GPU') +def bessel_y1(*args): + return bessel_y1_op(*args) + + +betainc_op = Betainc().set_device('GPU') +def betainc(*args): + return betainc_op(*args) + + +def bias_add(*args): + op = _get_cache_prim(BiasAdd)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def binary_cross_entropy(*args): + op = _get_cache_prim(BinaryCrossEntropy)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +bincount_op = Bincount().set_device('GPU') +def bincount(*args): + return bincount_op(*args) + + +bitwise_and_op = BitwiseAnd().set_device('GPU') +def bitwise_and(*args): + return bitwise_and_op(*args) + + +bitwise_or_op = BitwiseOr().set_device('GPU') +def bitwise_or(*args): + return bitwise_or_op(*args) + + +bitwise_xor_op = BitwiseXor().set_device('GPU') +def bitwise_xor(*args): + return bitwise_xor_op(*args) + + +def blackman_window(*args): + op = _get_cache_prim(BlackmanWindow)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def bounding_box_decode(*args): + op = _get_cache_prim(BoundingBoxDecode)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def bounding_box_encode(*args): + op = _get_cache_prim(BoundingBoxEncode)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def broadcast(*args): + op = _get_cache_prim(Broadcast)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def broadcast_to(*args): + op = _get_cache_prim(BroadcastTo)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) 
+ + +def bucketize(*args): + op = _get_cache_prim(Bucketize)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def buffer_append(*args): + op = _get_cache_prim(BufferAppend)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def buffer_get_item(*args): + op = _get_cache_prim(BufferGetItem)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def buffer_sample(*args): + op = _get_cache_prim(BufferSample)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def ctc_greedy_decoder(*args): + op = _get_cache_prim(CTCGreedyDecoder)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def ctc_loss(*args): + op = _get_cache_prim(CTCLoss)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def ctc_loss_v2(*args): + op = _get_cache_prim(CTCLossV2)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +cast_op = Cast().set_device('GPU') +def cast(*args): + return cast_op(*args) + + +def cauchy(*args): + op = _get_cache_prim(Cauchy)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def cdist(*args): + op = _get_cache_prim(Cdist)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def ce_lu(*args): + op = _get_cache_prim(CeLU)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +ceil_op = Ceil().set_device('GPU') +def ceil(*args): + return ceil_op(*args) + + +def channel_shuffle(*args): + op = _get_cache_prim(ChannelShuffle)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +check_numerics_op = CheckNumerics().set_device('GPU') +def check_numerics(*args): + return check_numerics_op(*args) + + +check_valid_op = CheckValid().set_device('GPU') +def check_valid(*args): + return check_valid_op(*args) + + +def cholesky(*args): + op = _get_cache_prim(Cholesky)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def cholesky_inverse(*args): + op = _get_cache_prim(CholeskyInverse)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def cholesky_solve(*args): + op = 
_get_cache_prim(CholeskySolve)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +coalesce_op = Coalesce().set_device('GPU') +def coalesce(*args): + return coalesce_op(*args) + + +def col2_im(*args): + op = _get_cache_prim(Col2Im)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def collective_gather(*args): + op = _get_cache_prim(CollectiveGather)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def collective_scatter(*args): + op = _get_cache_prim(CollectiveScatter)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def combined_non_max_suppression(*args): + op = _get_cache_prim(CombinedNonMaxSuppression)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +compare_and_bitpack_op = CompareAndBitpack().set_device('GPU') +def compare_and_bitpack(*args): + return compare_and_bitpack_op(*args) + + +complex_op = Complex().set_device('GPU') +def complex(*args): + return complex_op(*args) + + +complex_abs_op = ComplexAbs().set_device('GPU') +def complex_abs(*args): + return complex_abs_op(*args) + + +def compute_accidental_hits(*args): + op = _get_cache_prim(ComputeAccidentalHits)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def concat(*args): + op = _get_cache_prim(Concat)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def confusion_matrix(*args): + op = _get_cache_prim(ConfusionMatrix)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +conj_op = Conj().set_device('GPU') +def conj(*args): + return conj_op(*args) + + +conjugate_transpose_op = ConjugateTranspose().set_device('GPU') +def conjugate_transpose(*args): + return conjugate_transpose_op(*args) + + +def conv2_d(*args): + op = _get_cache_prim(Conv2D)(*args[-9:]).set_device('GPU') + return op(*args[:-9]) + + +def conv2_d_backprop_input(*args): + op = _get_cache_prim(Conv2DBackpropInput)(*args[-10:]).set_device('GPU') + return op(*args[:-10]) + + +def conv2_d_transpose(*args): + op = 
_get_cache_prim(Conv2DTranspose)(*args[-10:]).set_device('GPU') + return op(*args[:-10]) + + +def conv3_d(*args): + op = _get_cache_prim(Conv3D)(*args[-9:]).set_device('GPU') + return op(*args[:-9]) + + +def conv3_d_transpose(*args): + op = _get_cache_prim(Conv3DTranspose)(*args[-11:]).set_device('GPU') + return op(*args[:-11]) + + +copy_with_slice_op = CopyWithSlice().set_device('GPU') +def copy_with_slice(*args): + return copy_with_slice_op(*args) + + +cos_op = Cos().set_device('GPU') +def cos(*args): + return cos_op(*args) + + +cosh_op = Cosh().set_device('GPU') +def cosh(*args): + return cosh_op(*args) + + +def count_non_zero(*args): + op = _get_cache_prim(CountNonZero)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def crop_and_resize(*args): + op = _get_cache_prim(CropAndResize)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def cross(*args): + op = _get_cache_prim(Cross)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def cum_prod(*args): + op = _get_cache_prim(CumProd)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def cum_sum(*args): + op = _get_cache_prim(CumSum)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def cummax(*args): + op = _get_cache_prim(Cummax)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def cummin(*args): + op = _get_cache_prim(Cummin)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def cumulative_logsumexp(*args): + op = _get_cache_prim(CumulativeLogsumexp)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +d_type_op = DType().set_device('GPU') +def d_type(*args): + return d_type_op(*args) + + +def data_format_dim_map(*args): + op = _get_cache_prim(DataFormatDimMap)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def data_format_vec_permute(*args): + op = _get_cache_prim(DataFormatVecPermute)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def deformable_offsets(*args): + op = 
_get_cache_prim(DeformableOffsets)(*args[-7:]).set_device('GPU') + return op(*args[:-7]) + + +dense_op = Dense().set_device('GPU') +def dense(*args): + return dense_op(*args) + + +depend_op = Depend().set_device('GPU') +def depend(*args): + return depend_op(*args) + + +def depth_to_space(*args): + op = _get_cache_prim(DepthToSpace)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def depthwise_conv2d_native(*args): + op = _get_cache_prim(DepthwiseConv2dNative)(*args[-8:]).set_device('GPU') + return op(*args[:-8]) + + +diag_op = Diag().set_device('GPU') +def diag(*args): + return diag_op(*args) + + +diag_part_op = DiagPart().set_device('GPU') +def diag_part(*args): + return diag_part_op(*args) + + +digamma_op = Digamma().set_device('GPU') +def digamma(*args): + return digamma_op(*args) + + +def dilation2_d(*args): + op = _get_cache_prim(Dilation2D)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +div_op = Div().set_device('GPU') +def div(*args): + return div_op(*args) + + +div_no_nan_op = DivNoNan().set_device('GPU') +def div_no_nan(*args): + return div_no_nan_op(*args) + + +def dropout(*args): + op = _get_cache_prim(Dropout)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def dropout2_d(*args): + op = _get_cache_prim(Dropout2D)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def dropout3_d(*args): + op = _get_cache_prim(Dropout3D)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def dropout_gen_mask(*args): + op = _get_cache_prim(DropoutGenMask)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def dynamic_gruv2(*args): + op = _get_cache_prim(DynamicGRUV2)(*args[-10:]).set_device('GPU') + return op(*args[:-10]) + + +def dynamic_rnn(*args): + op = _get_cache_prim(DynamicRNN)(*args[-11:]).set_device('GPU') + return op(*args[:-11]) + + +def dynamic_shape(*args): + op = _get_cache_prim(DynamicShape)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def edit_distance(*args): + op = 
_get_cache_prim(EditDistance)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def eig(*args): + op = _get_cache_prim(Eig)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def einsum(*args): + op = _get_cache_prim(Einsum)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def elu(*args): + op = _get_cache_prim(Elu)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +embedding_lookup_op = EmbeddingLookup().set_device('GPU') +def embedding_lookup(*args): + return embedding_lookup_op(*args) + + +eps_op = Eps().set_device('GPU') +def eps(*args): + return eps_op(*args) + + +equal_op = Equal().set_device('GPU') +def equal(*args): + return equal_op(*args) + + +equal_count_op = EqualCount().set_device('GPU') +def equal_count(*args): + return equal_count_op(*args) + + +erf_op = Erf().set_device('GPU') +def erf(*args): + return erf_op(*args) + + +erfc_op = Erfc().set_device('GPU') +def erfc(*args): + return erfc_op(*args) + + +erfinv_op = Erfinv().set_device('GPU') +def erfinv(*args): + return erfinv_op(*args) + + +erfinv_op = Erfinv().set_device('GPU') +def erfinv(*args): + return erfinv_op(*args) + + +def euclidean_norm(*args): + op = _get_cache_prim(EuclideanNorm)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +exp_op = Exp().set_device('GPU') +def exp(*args): + return exp_op(*args) + + + +expand_dims_op = ExpandDims().set_device('GPU') +def expand_dims(*args): + return expand_dims_op(*args) + + +expm1_op = Expm1().set_device('GPU') +def expm1(*args): + return expm1_op(*args) + + +def extract_glimpse(*args): + op = _get_cache_prim(ExtractGlimpse)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def extract_image_patches(*args): + op = _get_cache_prim(ExtractImagePatches)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def extract_volume_patches(*args): + op = _get_cache_prim(ExtractVolumePatches)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +eye_op = Eye().set_device('GPU') +def eye(*args): + 
return eye_op(*args) + + +def fft_with_size(*args): + op = _get_cache_prim(FFTWithSize)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +fast_ge_lu_op = FastGeLU().set_device('GPU') +def fast_ge_lu(*args): + return fast_ge_lu_op(*args) + + +fill_op = Fill().set_device('GPU') +def fill(*args): + return fill_op(*args) + + +def fill_diagonal(*args): + op = _get_cache_prim(FillDiagonal)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +fill_v2_op = FillV2().set_device('GPU') +def fill_v2(*args): + return fill_v2_op(*args) + + +fills_op = Fills().set_device('GPU') +def fills(*args): + return fills_op(*args) + + +flatten_op = Flatten().set_device('GPU') +def flatten(*args): + return flatten_op(*args) + + +float_status_op = FloatStatus().set_device('GPU') +def float_status(*args): + return float_status_op(*args) + + +floor_op = Floor().set_device('GPU') +def floor(*args): + return floor_op(*args) + + +floor_div_op = FloorDiv().set_device('GPU') +def floor_div(*args): + return floor_div_op(*args) + + +floor_mod_op = FloorMod().set_device('GPU') +def floor_mod(*args): + return floor_mod_op(*args) + + +fmax_op = Fmax().set_device('GPU') +def fmax(*args): + return fmax_op(*args) + + +fmin_op = Fmin().set_device('GPU') +def fmin(*args): + return fmin_op(*args) + + +fori_loop_op = ForiLoop().set_device('GPU') +def fori_loop(*args): + return fori_loop_op(*args) + + +def fractional_avg_pool(*args): + op = _get_cache_prim(FractionalAvgPool)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def fractional_max_pool(*args): + op = _get_cache_prim(FractionalMaxPool)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def fractional_max_pool3_d_with_fixed_ksize(*args): + op = _get_cache_prim(FractionalMaxPool3DWithFixedKsize)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def fractional_max_pool_with_fixed_ksize(*args): + op = _get_cache_prim(FractionalMaxPoolWithFixedKsize)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def 
fused_ada_factor(*args): + op = _get_cache_prim(FusedAdaFactor)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def fused_ada_factor_with_global_norm(*args): + op = _get_cache_prim(FusedAdaFactorWithGlobalNorm)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def fused_cast_adam_weight_decay(*args): + op = _get_cache_prim(FusedCastAdamWeightDecay)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def fused_sparse_adam(*args): + op = _get_cache_prim(FusedSparseAdam)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def fused_sparse_ftrl(*args): + op = _get_cache_prim(FusedSparseFtrl)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +def fused_sparse_lazy_adam(*args): + op = _get_cache_prim(FusedSparseLazyAdam)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def fused_sparse_proximal_adagrad(*args): + op = _get_cache_prim(FusedSparseProximalAdagrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +fused_weight_scale_apply_momentum_op = FusedWeightScaleApplyMomentum().set_device('GPU') +def fused_weight_scale_apply_momentum(*args): + return fused_weight_scale_apply_momentum_op(*args) + + +def glu(*args): + op = _get_cache_prim(GLU)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def gamma(*args): + op = _get_cache_prim(Gamma)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def gather(*args): + op = _get_cache_prim(Gather)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +gather_d_op = GatherD().set_device('GPU') +def gather_d(*args): + return gather_d_op(*args) + + +gather_nd_op = GatherNd().set_device('GPU') +def gather_nd(*args): + return gather_nd_op(*args) + + +gcd_op = Gcd().set_device('GPU') +def gcd(*args): + return gcd_op(*args) + + +ge_lu_op = GeLU().set_device('GPU') +def ge_lu(*args): + return ge_lu_op(*args) + + +ge_switch_op = GeSwitch().set_device('GPU') +def ge_switch(*args): + return ge_switch_op(*args) + + +geqrf_op = Geqrf().set_device('GPU') +def 
geqrf(*args): + return geqrf_op(*args) + + +ger_op = Ger().set_device('GPU') +def ger(*args): + return ger_op(*args) + + +def get_next(*args): + op = _get_cache_prim(GetNext)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +greater_op = Greater().set_device('GPU') +def greater(*args): + return greater_op(*args) + + +greater_equal_op = GreaterEqual().set_device('GPU') +def greater_equal(*args): + return greater_equal_op(*args) + + +def grid_sampler2_d(*args): + op = _get_cache_prim(GridSampler2D)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def grid_sampler3_d(*args): + op = _get_cache_prim(GridSampler3D)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +group_topk_op = GroupTopk().set_device('GPU') +def group_topk(*args): + return group_topk_op(*args) + + +hsv_to_rgb_op = HSVToRGB().set_device('GPU') +def hsv_to_rgb(*args): + return hsv_to_rgb_op(*args) + + +def h_shrink(*args): + op = _get_cache_prim(HShrink)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +h_sigmoid_op = HSigmoid().set_device('GPU') +def h_sigmoid(*args): + return h_sigmoid_op(*args) + + +h_swish_op = HSwish().set_device('GPU') +def h_swish(*args): + return h_swish_op(*args) + + +def hamming_window(*args): + op = _get_cache_prim(HammingWindow)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +heaviside_op = Heaviside().set_device('GPU') +def heaviside(*args): + return heaviside_op(*args) + + +def histogram(*args): + op = _get_cache_prim(Histogram)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def histogram_fixed_width(*args): + op = _get_cache_prim(HistogramFixedWidth)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +histogram_summary_op = HistogramSummary().set_device('GPU') +def histogram_summary(*args): + return histogram_summary_op(*args) + + +def hook_backward(*args): + op = _get_cache_prim(HookBackward)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +hypot_op = Hypot().set_device('GPU') +def hypot(*args): 
+ return hypot_op(*args) + + +def iou(*args): + op = _get_cache_prim(IOU)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +identity_op = Identity().set_device('GPU') +def identity(*args): + return identity_op(*args) + + +identity_n_op = IdentityN().set_device('GPU') +def identity_n(*args): + return identity_n_op(*args) + + +igamma_op = Igamma().set_device('GPU') +def igamma(*args): + return igamma_op(*args) + + +igammac_op = Igammac().set_device('GPU') +def igammac(*args): + return igammac_op(*args) + + +def im2_col(*args): + op = _get_cache_prim(Im2Col)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +imag_op = Imag().set_device('GPU') +def imag(*args): + return imag_op(*args) + + +image_summary_op = ImageSummary().set_device('GPU') +def image_summary(*args): + return image_summary_op(*args) + + +def in_top_k(*args): + op = _get_cache_prim(InTopK)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def index_add(*args): + op = _get_cache_prim(IndexAdd)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +index_fill_op = IndexFill().set_device('GPU') +def index_fill(*args): + return index_fill_op(*args) + + +def index_put(*args): + op = _get_cache_prim(IndexPut)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def inplace_add(*args): + op = _get_cache_prim(InplaceAdd)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def inplace_index_add(*args): + op = _get_cache_prim(InplaceIndexAdd)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def inplace_sub(*args): + op = _get_cache_prim(InplaceSub)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def inplace_update(*args): + op = _get_cache_prim(InplaceUpdate)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +inplace_update_v2_op = InplaceUpdateV2().set_device('GPU') +def inplace_update_v2(*args): + return inplace_update_v2_op(*args) + + +def insert_gradient_of(*args): + op = _get_cache_prim(InsertGradientOf)(*args[-1:]).set_device('GPU') + 
return op(*args[:-1]) + + +inv_op = Inv().set_device('GPU') +def inv(*args): + return inv_op(*args) + + +invert_op = Invert().set_device('GPU') +def invert(*args): + return invert_op(*args) + + +invert_permutation_op = InvertPermutation().set_device('GPU') +def invert_permutation(*args): + return invert_permutation_op(*args) + + +def is_close(*args): + op = _get_cache_prim(IsClose)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +is_finite_op = IsFinite().set_device('GPU') +def is_finite(*args): + return is_finite_op(*args) + + +is_inf_op = IsInf().set_device('GPU') +def is_inf(*args): + return is_inf_op(*args) + + +is_nan_op = IsNan().set_device('GPU') +def is_nan(*args): + return is_nan_op(*args) + + +def kl_div_loss(*args): + op = _get_cache_prim(KLDivLoss)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +l2_loss_op = L2Loss().set_device('GPU') +def l2_loss(*args): + return l2_loss_op(*args) + + +def l2_normalize(*args): + op = _get_cache_prim(L2Normalize)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def lars_update(*args): + op = _get_cache_prim(LARSUpdate)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def lrn(*args): + op = _get_cache_prim(LRN)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +def lstm(*args): + op = _get_cache_prim(LSTM)(*args[-7:]).set_device('GPU') + return op(*args[:-7]) + + +def layer_norm(*args): + op = _get_cache_prim(LayerNorm)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +lcm_op = Lcm().set_device('GPU') +def lcm(*args): + return lcm_op(*args) + + +left_shift_op = LeftShift().set_device('GPU') +def left_shift(*args): + return left_shift_op(*args) + + +lerp_op = Lerp().set_device('GPU') +def lerp(*args): + return lerp_op(*args) + + +lerp_scalar_op = LerpScalar().set_device('GPU') +def lerp_scalar(*args): + return lerp_scalar_op(*args) + + +less_op = Less().set_device('GPU') +def less(*args): + return less_op(*args) + + +less_equal_op = LessEqual().set_device('GPU') 
+def less_equal(*args): + return less_equal_op(*args) + + +lgamma_op = Lgamma().set_device('GPU') +def lgamma(*args): + return lgamma_op(*args) + + +lin_space_op = LinSpace().set_device('GPU') +def lin_space(*args): + return lin_space_op(*args) + + +def list_diff(*args): + op = _get_cache_prim(ListDiff)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +log_op = Log().set_device('GPU') +def log(*args): + return log_op(*args) + + +log1p_op = Log1p().set_device('GPU') +def log1p(*args): + return log1p_op(*args) + + +log_matrix_determinant_op = LogMatrixDeterminant().set_device('GPU') +def log_matrix_determinant(*args): + return log_matrix_determinant_op(*args) + + +def log_normal_reverse(*args): + op = _get_cache_prim(LogNormalReverse)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def log_softmax(*args): + op = _get_cache_prim(LogSoftmax)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +log_softmax_ext_op = LogSoftmaxExt().set_device('GPU') +def log_softmax_ext(*args): + return log_softmax_ext_op(*args) + + +def log_space(*args): + op = _get_cache_prim(LogSpace)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def log_uniform_candidate_sampler(*args): + op = _get_cache_prim(LogUniformCandidateSampler)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +logical_and_op = LogicalAnd().set_device('GPU') +def logical_and(*args): + return logical_and_op(*args) + + +logical_not_op = LogicalNot().set_device('GPU') +def logical_not(*args): + return logical_not_op(*args) + + +logical_or_op = LogicalOr().set_device('GPU') +def logical_or(*args): + return logical_or_op(*args) + + +logical_xor_op = LogicalXor().set_device('GPU') +def logical_xor(*args): + return logical_xor_op(*args) + + +def logit(*args): + op = _get_cache_prim(Logit)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def lower_bound(*args): + op = _get_cache_prim(LowerBound)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def lp_norm(*args): + 
op = _get_cache_prim(LpNorm)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def lstsq(*args): + op = _get_cache_prim(Lstsq)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +lu_solve_op = LuSolve().set_device('GPU') +def lu_solve(*args): + return lu_solve_op(*args) + + +def lu_unpack(*args): + op = _get_cache_prim(LuUnpack)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +map_cache_idx_op = MapCacheIdx().set_device('GPU') +def map_cache_idx(*args): + return map_cache_idx_op(*args) + + +map_uniform_op = MapUniform().set_device('GPU') +def map_uniform(*args): + return map_uniform_op(*args) + + +masked_fill_op = MaskedFill().set_device('GPU') +def masked_fill(*args): + return masked_fill_op(*args) + + +masked_scatter_op = MaskedScatter().set_device('GPU') +def masked_scatter(*args): + return masked_scatter_op(*args) + + +masked_select_op = MaskedSelect().set_device('GPU') +def masked_select(*args): + return masked_select_op(*args) + + +def mat_mul(*args): + op = _get_cache_prim(MatMul)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +matrix_band_part_op = MatrixBandPart().set_device('GPU') +def matrix_band_part(*args): + return matrix_band_part_op(*args) + + +matrix_determinant_op = MatrixDeterminant().set_device('GPU') +def matrix_determinant(*args): + return matrix_determinant_op(*args) + + +def matrix_diag_part_v3(*args): + op = _get_cache_prim(MatrixDiagPartV3)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def matrix_diag_v3(*args): + op = _get_cache_prim(MatrixDiagV3)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +matrix_exp_op = MatrixExp().set_device('GPU') +def matrix_exp(*args): + return matrix_exp_op(*args) + + +def matrix_inverse(*args): + op = _get_cache_prim(MatrixInverse)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +matrix_logarithm_op = MatrixLogarithm().set_device('GPU') +def matrix_logarithm(*args): + return matrix_logarithm_op(*args) + + +def matrix_power(*args): + op = 
_get_cache_prim(MatrixPower)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def matrix_set_diag_v3(*args): + op = _get_cache_prim(MatrixSetDiagV3)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def matrix_solve(*args): + op = _get_cache_prim(MatrixSolve)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def matrix_solve_ls(*args): + op = _get_cache_prim(MatrixSolveLs)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def matrix_triangular_solve(*args): + op = _get_cache_prim(MatrixTriangularSolve)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def max_pool(*args): + op = _get_cache_prim(MaxPool)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def max_pool3_d(*args): + op = _get_cache_prim(MaxPool3D)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def max_pool3_d_with_argmax(*args): + op = _get_cache_prim(MaxPool3DWithArgmax)(*args[-7:]).set_device('GPU') + return op(*args[:-7]) + + +def max_pool_with_argmax(*args): + op = _get_cache_prim(MaxPoolWithArgmax)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def max_pool_with_argmax_v2(*args): + op = _get_cache_prim(MaxPoolWithArgmaxV2)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def max_unpool2_d(*args): + op = _get_cache_prim(MaxUnpool2D)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +def max_unpool3_d(*args): + op = _get_cache_prim(MaxUnpool3D)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +maximum_op = Maximum().set_device('GPU') +def maximum(*args): + return maximum_op(*args) + + +merge_op = Merge().set_device('GPU') +def merge(*args): + return merge_op(*args) + + +def meshgrid(*args): + op = _get_cache_prim(Meshgrid)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +minimum_op = Minimum().set_device('GPU') +def minimum(*args): + return minimum_op(*args) + + +def mirror_pad(*args): + op = _get_cache_prim(MirrorPad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +mish_op = 
Mish().set_device('GPU') +def mish(*args): + return mish_op(*args) + + +mod_op = Mod().set_device('GPU') +def mod(*args): + return mod_op(*args) + + +def morph(*args): + op = _get_cache_prim(Morph)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +move_to_op = MoveTo().set_device('GPU') +def move_to(*args): + return move_to_op(*args) + + +mul_op = Mul().set_device('GPU') +def mul(*args): + return mul_op(*args) + + +mul_no_nan_op = MulNoNan().set_device('GPU') +def mul_no_nan(*args): + return mul_no_nan_op(*args) + + +def multi_margin_loss(*args): + op = _get_cache_prim(MultiMarginLoss)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def multilabel_margin_loss(*args): + op = _get_cache_prim(MultilabelMarginLoss)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def multinomial(*args): + op = _get_cache_prim(Multinomial)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def multinomial_with_replacement(*args): + op = _get_cache_prim(MultinomialWithReplacement)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def mvlgamma(*args): + op = _get_cache_prim(Mvlgamma)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def nll_loss(*args): + op = _get_cache_prim(NLLLoss)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def nms_with_mask(*args): + op = _get_cache_prim(NMSWithMask)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def nan_to_num(*args): + op = _get_cache_prim(NanToNum)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +neg_op = Neg().set_device('GPU') +def neg(*args): + return neg_op(*args) + + +def neighbor_exchange(*args): + op = _get_cache_prim(NeighborExchange)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def neighbor_exchange_v2(*args): + op = _get_cache_prim(NeighborExchangeV2)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +next_after_op = NextAfter().set_device('GPU') +def next_after(*args): + return next_after_op(*args) + + +def 
no_repeat_n_gram(*args): + op = _get_cache_prim(NoRepeatNGram)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def non_deterministic_ints(*args): + op = _get_cache_prim(NonDeterministicInts)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +non_max_suppression_v3_op = NonMaxSuppressionV3().set_device('GPU') +def non_max_suppression_v3(*args): + return non_max_suppression_v3_op(*args) + + +non_max_suppression_with_overlaps_op = NonMaxSuppressionWithOverlaps().set_device('GPU') +def non_max_suppression_with_overlaps(*args): + return non_max_suppression_with_overlaps_op(*args) + + +non_zero_op = NonZero().set_device('GPU') +def non_zero(*args): + return non_zero_op(*args) + + +not_equal_op = NotEqual().set_device('GPU') +def not_equal(*args): + return not_equal_op(*args) + + +def nth_element(*args): + op = _get_cache_prim(NthElement)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def nuclear_norm(*args): + op = _get_cache_prim(NuclearNorm)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def one_hot(*args): + op = _get_cache_prim(OneHot)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +ones_op = Ones().set_device('GPU') +def ones(*args): + return ones_op(*args) + + +ones_like_op = OnesLike().set_device('GPU') +def ones_like(*args): + return ones_like_op(*args) + + +orgqr_op = Orgqr().set_device('GPU') +def orgqr(*args): + return orgqr_op(*args) + + +def ormqr(*args): + op = _get_cache_prim(Ormqr)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +p_re_lu_op = PReLU().set_device('GPU') +def p_re_lu(*args): + return p_re_lu_op(*args) + + +def pack(*args): + op = _get_cache_prim(Pack)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def pad(*args): + op = _get_cache_prim(Pad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def pad_v3(*args): + op = _get_cache_prim(PadV3)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def padding(*args): + op = 
_get_cache_prim(Padding)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def paged_attention(*args): + op = _get_cache_prim(PagedAttention)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def paged_attention_mask(*args): + op = _get_cache_prim(PagedAttentionMask)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +parallel_concat_op = ParallelConcat().set_device('GPU') +def parallel_concat(*args): + return parallel_concat_op(*args) + + +def parameterized_truncated_normal(*args): + op = _get_cache_prim(ParameterizedTruncatedNormal)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +partial_op = Partial().set_device('GPU') +def partial(*args): + return partial_op(*args) + + +def pdist(*args): + op = _get_cache_prim(Pdist)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def poisson(*args): + op = _get_cache_prim(Poisson)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +polar_op = Polar().set_device('GPU') +def polar(*args): + return polar_op(*args) + + +polygamma_op = Polygamma().set_device('GPU') +def polygamma(*args): + return polygamma_op(*args) + + +population_count_op = PopulationCount().set_device('GPU') +def population_count(*args): + return population_count_op(*args) + + +pow_op = Pow().set_device('GPU') +def pow(*args): + return pow_op(*args) + + +pull_op = Pull().set_device('GPU') +def pull(*args): + return pull_op(*args) + + +def push(*args): + op = _get_cache_prim(Push)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +py_execute_op = PyExecute().set_device('GPU') +def py_execute(*args): + return py_execute_op(*args) + + +def py_func(*args): + op = _get_cache_prim(PyFunc)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def qr(*args): + op = _get_cache_prim(Qr)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def quantile(*args): + op = _get_cache_prim(Quantile)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +rgb_to_hsv_op = RGBToHSV().set_device('GPU') +def 
rgb_to_hsv(*args): + return rgb_to_hsv_op(*args) + + +def rnnt_loss(*args): + op = _get_cache_prim(RNNTLoss)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def roi_align(*args): + op = _get_cache_prim(ROIAlign)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +def ragged_range(*args): + op = _get_cache_prim(RaggedRange)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def random_categorical(*args): + op = _get_cache_prim(RandomCategorical)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def random_choice_with_mask(*args): + op = _get_cache_prim(RandomChoiceWithMask)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def random_gamma(*args): + op = _get_cache_prim(RandomGamma)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def random_gamma(*args): + op = _get_cache_prim(RandomGamma)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def random_poisson(*args): + op = _get_cache_prim(RandomPoisson)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def random_shuffle(*args): + op = _get_cache_prim(RandomShuffle)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def randperm(*args): + op = _get_cache_prim(Randperm)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def randperm_v2(*args): + op = _get_cache_prim(RandpermV2)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def range(*args): + op = _get_cache_prim(Range)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +rank_op = Rank().set_device('GPU') +def rank(*args): + return rank_op(*args) + + +re_lu_op = ReLU().set_device('GPU') +def re_lu(*args): + return re_lu_op(*args) + + +re_lu6_op = ReLU6().set_device('GPU') +def re_lu6(*args): + return re_lu6_op(*args) + + +real_op = Real().set_device('GPU') +def real(*args): + return real_op(*args) + + +real_div_op = RealDiv().set_device('GPU') +def real_div(*args): + return real_div_op(*args) + + +def receive(*args): + op = 
_get_cache_prim(Receive)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +reciprocal_op = Reciprocal().set_device('GPU') +def reciprocal(*args): + return reciprocal_op(*args) + + +def reduce(*args): + op = _get_cache_prim(Reduce)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def reduce_all(*args): + op = _get_cache_prim(ReduceAll)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def reduce_any(*args): + op = _get_cache_prim(ReduceAny)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def reduce_max(*args): + op = _get_cache_prim(ReduceMax)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def reduce_mean(*args): + op = _get_cache_prim(ReduceMean)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def reduce_min(*args): + op = _get_cache_prim(ReduceMin)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def reduce_prod(*args): + op = _get_cache_prim(ReduceProd)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def reduce_scatter(*args): + op = _get_cache_prim(ReduceScatter)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def reduce_std(*args): + op = _get_cache_prim(ReduceStd)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def reduce_sum(*args): + op = _get_cache_prim(ReduceSum)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def renorm(*args): + op = _get_cache_prim(Renorm)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +reshape_op = Reshape().set_device('GPU') +def reshape(*args): + return reshape_op(*args) + + +reshape_and_cache_op = ReshapeAndCache().set_device('GPU') +def reshape_and_cache(*args): + return reshape_and_cache_op(*args) + + +def reshard(*args): + op = _get_cache_prim(Reshard)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def resize_area(*args): + op = _get_cache_prim(ResizeArea)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def resize_bicubic(*args): + op = 
_get_cache_prim(ResizeBicubic)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def resize_bilinear_v2(*args): + op = _get_cache_prim(ResizeBilinearV2)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def resize_linear1_d(*args): + op = _get_cache_prim(ResizeLinear1D)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def resize_nearest_neighbor(*args): + op = _get_cache_prim(ResizeNearestNeighbor)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def resize_nearest_neighbor_v2(*args): + op = _get_cache_prim(ResizeNearestNeighborV2)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +reusing_op = Reusing().set_device('GPU') +def reusing(*args): + return reusing_op(*args) + + +def reverse_sequence(*args): + op = _get_cache_prim(ReverseSequence)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def reverse_v2(*args): + op = _get_cache_prim(ReverseV2)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +right_shift_op = RightShift().set_device('GPU') +def right_shift(*args): + return right_shift_op(*args) + + +rint_op = Rint().set_device('GPU') +def rint(*args): + return rint_op(*args) + + +def rms_norm(*args): + op = _get_cache_prim(RmsNorm)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def roll(*args): + op = _get_cache_prim(Roll)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +round_op = Round().set_device('GPU') +def round(*args): + return round_op(*args) + + +rsqrt_op = Rsqrt().set_device('GPU') +def rsqrt(*args): + return rsqrt_op(*args) + + +def sgd(*args): + op = _get_cache_prim(SGD)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def stft(*args): + op = _get_cache_prim(STFT)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def sample_distorted_bounding_box_v2(*args): + op = _get_cache_prim(SampleDistortedBoundingBoxV2)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +scalar_summary_op = ScalarSummary().set_device('GPU') +def scalar_summary(*args): 
+ return scalar_summary_op(*args) + + +scalar_to_tensor_op = ScalarToTensor().set_device('GPU') +def scalar_to_tensor(*args): + return scalar_to_tensor_op(*args) + + +def scale_and_translate(*args): + op = _get_cache_prim(ScaleAndTranslate)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +scan_op = Scan().set_device('GPU') +def scan(*args): + return scan_op(*args) + + +def scatter_add(*args): + op = _get_cache_prim(ScatterAdd)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_add_with_axis(*args): + op = _get_cache_prim(ScatterAddWithAxis)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_div(*args): + op = _get_cache_prim(ScatterDiv)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_max(*args): + op = _get_cache_prim(ScatterMax)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_min(*args): + op = _get_cache_prim(ScatterMin)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_mul(*args): + op = _get_cache_prim(ScatterMul)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +scatter_nd_op = ScatterNd().set_device('GPU') +def scatter_nd(*args): + return scatter_nd_op(*args) + + +def scatter_nd_add(*args): + op = _get_cache_prim(ScatterNdAdd)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_nd_div(*args): + op = _get_cache_prim(ScatterNdDiv)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_nd_max(*args): + op = _get_cache_prim(ScatterNdMax)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_nd_min(*args): + op = _get_cache_prim(ScatterNdMin)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_nd_mul(*args): + op = _get_cache_prim(ScatterNdMul)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_nd_sub(*args): + op = _get_cache_prim(ScatterNdSub)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_nd_update(*args): + op = 
_get_cache_prim(ScatterNdUpdate)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_sub(*args): + op = _get_cache_prim(ScatterSub)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def scatter_update(*args): + op = _get_cache_prim(ScatterUpdate)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +se_lu_op = SeLU().set_device('GPU') +def se_lu(*args): + return se_lu_op(*args) + + +def search_sorted(*args): + op = _get_cache_prim(SearchSorted)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +segment_max_op = SegmentMax().set_device('GPU') +def segment_max(*args): + return segment_max_op(*args) + + +segment_mean_op = SegmentMean().set_device('GPU') +def segment_mean(*args): + return segment_mean_op(*args) + + +segment_min_op = SegmentMin().set_device('GPU') +def segment_min(*args): + return segment_min_op(*args) + + +segment_prod_op = SegmentProd().set_device('GPU') +def segment_prod(*args): + return segment_prod_op(*args) + + +segment_sum_op = SegmentSum().set_device('GPU') +def segment_sum(*args): + return segment_sum_op(*args) + + +select_op = Select().set_device('GPU') +def select(*args): + return select_op(*args) + + +select_view_op = SelectView().set_device('GPU') +def select_view(*args): + return select_view_op(*args) + + +def send(*args): + op = _get_cache_prim(Send)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +shape_op = Shape().set_device('GPU') +def shape(*args): + return shape_op(*args) + + +sigmoid_op = Sigmoid().set_device('GPU') +def sigmoid(*args): + return sigmoid_op(*args) + + +sigmoid_cross_entropy_with_logits_op = SigmoidCrossEntropyWithLogits().set_device('GPU') +def sigmoid_cross_entropy_with_logits(*args): + return sigmoid_cross_entropy_with_logits_op(*args) + + +sign_op = Sign().set_device('GPU') +def sign(*args): + return sign_op(*args) + + +sin_op = Sin().set_device('GPU') +def sin(*args): + return sin_op(*args) + + +sinc_op = Sinc().set_device('GPU') +def sinc(*args): + return 
sinc_op(*args) + + +sinh_op = Sinh().set_device('GPU') +def sinh(*args): + return sinh_op(*args) + + +size_op = Size().set_device('GPU') +def size(*args): + return size_op(*args) + + +slice_op = Slice().set_device('GPU') +def slice(*args): + return slice_op(*args) + + +def smooth_l1_loss(*args): + op = _get_cache_prim(SmoothL1Loss)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def soft_margin_loss(*args): + op = _get_cache_prim(SoftMarginLoss)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def soft_shrink(*args): + op = _get_cache_prim(SoftShrink)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def softmax(*args): + op = _get_cache_prim(Softmax)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +softmax_cross_entropy_with_logits_op = SoftmaxCrossEntropyWithLogits().set_device('GPU') +def softmax_cross_entropy_with_logits(*args): + return softmax_cross_entropy_with_logits_op(*args) + + +softplus_op = Softplus().set_device('GPU') +def softplus(*args): + return softplus_op(*args) + + +softsign_op = Softsign().set_device('GPU') +def softsign(*args): + return softsign_op(*args) + + +def sort(*args): + op = _get_cache_prim(Sort)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def space_to_batch(*args): + op = _get_cache_prim(SpaceToBatch)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def space_to_batch_nd(*args): + op = _get_cache_prim(SpaceToBatchND)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def space_to_depth(*args): + op = _get_cache_prim(SpaceToDepth)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def sparse_apply_adadelta(*args): + op = _get_cache_prim(SparseApplyAdadelta)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def sparse_apply_adagrad(*args): + op = _get_cache_prim(SparseApplyAdagrad)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +def sparse_apply_adagrad_v2(*args): + op = 
_get_cache_prim(SparseApplyAdagradV2)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +def sparse_apply_ftrl(*args): + op = _get_cache_prim(SparseApplyFtrl)(*args[-5:]).set_device('GPU') + return op(*args[:-5]) + + +def sparse_apply_ftrl_v2(*args): + op = _get_cache_prim(SparseApplyFtrlV2)(*args[-6:]).set_device('GPU') + return op(*args[:-6]) + + +def sparse_apply_proximal_adagrad(*args): + op = _get_cache_prim(SparseApplyProximalAdagrad)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +def sparse_apply_rms_prop(*args): + op = _get_cache_prim(SparseApplyRMSProp)(*args[-4:]).set_device('GPU') + return op(*args[:-4]) + + +sparse_gather_v2_op = SparseGatherV2().set_device('GPU') +def sparse_gather_v2(*args): + return sparse_gather_v2_op(*args) + + +sparse_slice_op = SparseSlice().set_device('GPU') +def sparse_slice(*args): + return sparse_slice_op(*args) + + +def sparse_softmax_cross_entropy_with_logits(*args): + op = _get_cache_prim(SparseSoftmaxCrossEntropyWithLogits)(*args[-1:]).set_device('GPU') + return op(*args[:-1]) + + +sparse_tensor_dense_add_op = SparseTensorDenseAdd().set_device('GPU') +def sparse_tensor_dense_add(*args): + return sparse_tensor_dense_add_op(*args) + + +def sparse_tensor_dense_matmul(*args): + op = _get_cache_prim(SparseTensorDenseMatmul)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +sparse_to_dense_op = SparseToDense().set_device('GPU') +def sparse_to_dense(*args): + return sparse_to_dense_op(*args) + + +def split(*args): + op = _get_cache_prim(Split)(*args[-2:]).set_device('GPU') + return op(*args[:-2]) + + +def split_v(*args): + op = _get_cache_prim(SplitV)(*args[-3:]).set_device('GPU') + return op(*args[:-3]) + + +sqrt_op = Sqrt().set_device('GPU') +def sqrt(*args): + return sqrt_op(*args) + + +square_op = Square().set_device('GPU') +def square(*args): + return square_op(*args) + + +square_sum_all_op = SquareSumAll().set_device('GPU') +def square_sum_all(*args): + return square_sum_all_op(*args) + + 
squared_difference_op = SquaredDifference().set_device('GPU')


def squared_difference(*args):
    """Invoke the module-level ``squared_difference_op`` on ``args``."""
    return squared_difference_op(*args)


def squeeze(*args):
    """Pass the last arg to ``_get_cache_prim(Squeeze)``; call the result on the rest."""
    ctor_args, inputs = args[-1:], args[:-1]
    prim = _get_cache_prim(Squeeze)(*ctor_args).set_device('GPU')
    return prim(*inputs)


def stack(*args):
    """Pass the last arg to ``_get_cache_prim(Stack)``; call the result on the rest."""
    ctor_args, inputs = args[-1:], args[:-1]
    prim = _get_cache_prim(Stack)(*ctor_args).set_device('GPU')
    return prim(*inputs)


def standard_laplace(*args):
    """Pass the last 2 args to ``_get_cache_prim(StandardLaplace)``; call the result on the rest."""
    ctor_args, inputs = args[-2:], args[:-2]
    prim = _get_cache_prim(StandardLaplace)(*ctor_args).set_device('GPU')
    return prim(*inputs)


def standard_normal(*args):
    """Pass the last 2 args to ``_get_cache_prim(StandardNormal)``; call the result on the rest."""
    ctor_args, inputs = args[-2:], args[:-2]
    prim = _get_cache_prim(StandardNormal)(*ctor_args).set_device('GPU')
    return prim(*inputs)


stop_gradient_op = StopGradient().set_device('GPU')


def stop_gradient(*args):
    """Invoke the module-level ``stop_gradient_op`` on ``args``."""
    return stop_gradient_op(*args)


def strided_slice(*args):
    """Pass the last 5 args to ``_get_cache_prim(StridedSlice)``; call the result on the rest."""
    ctor_args, inputs = args[-5:], args[:-5]
    prim = _get_cache_prim(StridedSlice)(*ctor_args).set_device('GPU')
    return prim(*inputs)


sub_op = Sub().set_device('GPU')


def sub(*args):
    """Invoke the module-level ``sub_op`` on ``args``."""
    return sub_op(*args)


sub_and_filter_op = SubAndFilter().set_device('GPU')


def sub_and_filter(*args):
    """Invoke the module-level ``sub_and_filter_op`` on ``args``."""
    return sub_and_filter_op(*args)


def svd(*args):
    """Pass the last 2 args to ``_get_cache_prim(Svd)``; call the result on the rest."""
    ctor_args, inputs = args[-2:], args[:-2]
    prim = _get_cache_prim(Svd)(*ctor_args).set_device('GPU')
    return prim(*inputs)


tan_op = Tan().set_device('GPU')


def tan(*args):
    """Invoke the module-level ``tan_op`` on ``args``."""
    return tan_op(*args)


tanh_op = Tanh().set_device('GPU')


def tanh(*args):
    """Invoke the module-level ``tanh_op`` on ``args``."""
    return tanh_op(*args)


def tensor_dump(*args):
    """Pass the last arg to ``_get_cache_prim(TensorDump)``; call the result on the rest."""
    ctor_args, inputs = args[-1:], args[:-1]
    prim = _get_cache_prim(TensorDump)(*ctor_args).set_device('GPU')
    return prim(*inputs)


tensor_scatter_add_op = TensorScatterAdd().set_device('GPU')


def tensor_scatter_add(*args):
    """Invoke the module-level ``tensor_scatter_add_op`` on ``args``."""
    return tensor_scatter_add_op(*args)


tensor_scatter_div_op = TensorScatterDiv().set_device('GPU')


def tensor_scatter_div(*args):
    """Invoke the module-level ``tensor_scatter_div_op`` on ``args``."""
    return tensor_scatter_div_op(*args)


def tensor_scatter_elements(*args):
    """Pass the last 2 args to ``_get_cache_prim(TensorScatterElements)``; call the result on the rest."""
    ctor_args, inputs = args[-2:], args[:-2]
    prim = _get_cache_prim(TensorScatterElements)(*ctor_args).set_device('GPU')
    return prim(*inputs)


tensor_scatter_max_op = TensorScatterMax().set_device('GPU')


def tensor_scatter_max(*args):
    """Invoke the module-level ``tensor_scatter_max_op`` on ``args``."""
    return tensor_scatter_max_op(*args)


tensor_scatter_min_op = TensorScatterMin().set_device('GPU')


def tensor_scatter_min(*args):
    """Invoke the module-level ``tensor_scatter_min_op`` on ``args``."""
    return tensor_scatter_min_op(*args)


tensor_scatter_mul_op = TensorScatterMul().set_device('GPU')


def tensor_scatter_mul(*args):
    """Invoke the module-level ``tensor_scatter_mul_op`` on ``args``."""
    return tensor_scatter_mul_op(*args)


tensor_scatter_sub_op = TensorScatterSub().set_device('GPU')


def tensor_scatter_sub(*args):
    """Invoke the module-level ``tensor_scatter_sub_op`` on ``args``."""
    return tensor_scatter_sub_op(*args)


tensor_scatter_update_op = TensorScatterUpdate().set_device('GPU')


def tensor_scatter_update(*args):
    """Invoke the module-level ``tensor_scatter_update_op`` on ``args``."""
    return tensor_scatter_update_op(*args)


tensor_shape_op = TensorShape().set_device('GPU')


def tensor_shape(*args):
    """Invoke the module-level ``tensor_shape_op`` on ``args``."""
    return tensor_shape_op(*args)


tensor_summary_op = TensorSummary().set_device('GPU')


def tensor_summary(*args):
    """Invoke the module-level ``tensor_summary_op`` on ``args``."""
    return tensor_summary_op(*args)


tile_op = Tile().set_device('GPU')


def tile(*args):
    """Invoke the module-level ``tile_op`` on ``args``."""
    return tile_op(*args)


def top_k(*args):
    """Pass the last arg to ``_get_cache_prim(TopK)``; call the result on the rest."""
    ctor_args, inputs = args[-1:], args[:-1]
    prim = _get_cache_prim(TopK)(*ctor_args).set_device('GPU')
    return prim(*inputs)


trace_op = Trace().set_device('GPU')


def trace(*args):
    """Invoke the module-level ``trace_op`` on ``args``."""
    return trace_op(*args)


transpose_op = Transpose().set_device('GPU')


def transpose(*args):
    """Invoke the module-level ``transpose_op`` on ``args``."""
    return transpose_op(*args)


transpose_ext_view_op = TransposeExtView().set_device('GPU')


def transpose_ext_view(*args):
    """Invoke the module-level ``transpose_ext_view_op`` on ``args``."""
    return transpose_ext_view_op(*args)


transpose_view_op = TransposeView().set_device('GPU')


def transpose_view(*args):
    """Invoke the module-level ``transpose_view_op`` on ``args``."""
    return transpose_view_op(*args)


tridiagonal_mat_mul_op = TridiagonalMatMul().set_device('GPU')


def tridiagonal_mat_mul(*args):
    """Invoke the module-level ``tridiagonal_mat_mul_op`` on ``args``."""
    return tridiagonal_mat_mul_op(*args)


def tril(*args):
    """Pass the last arg to ``_get_cache_prim(Tril)``; call the result on the rest."""
    ctor_args, inputs = args[-1:], args[:-1]
    prim = _get_cache_prim(Tril)(*ctor_args).set_device('GPU')
    return prim(*inputs)


def tril_indices(*args):
    """Pass the last 4 args to ``_get_cache_prim(TrilIndices)``; call the result on the rest."""
    ctor_args, inputs = args[-4:], args[:-4]
    prim = _get_cache_prim(TrilIndices)(*ctor_args).set_device('GPU')
    return prim(*inputs)


def triplet_margin_loss(*args):
    """Pass the last 4 args to ``_get_cache_prim(TripletMarginLoss)``; call the result on the rest."""
    ctor_args, inputs = args[-4:], args[:-4]
    prim = _get_cache_prim(TripletMarginLoss)(*ctor_args).set_device('GPU')
    return prim(*inputs)


def triu(*args):
    """Pass the last arg to ``_get_cache_prim(Triu)``; call the result on the rest."""
    ctor_args, inputs = args[-1:], args[:-1]
    prim = _get_cache_prim(Triu)(*ctor_args).set_device('GPU')
    return prim(*inputs)


def triu_indices(*args):
    """Pass the last 4 args to ``_get_cache_prim(TriuIndices)``; call the result on the rest."""
    ctor_args, inputs = args[-4:], args[:-4]
    prim = _get_cache_prim(TriuIndices)(*ctor_args).set_device('GPU')
    return prim(*inputs)


trunc_op = Trunc().set_device('GPU')


def trunc(*args):
    """Invoke the module-level ``trunc_op`` on ``args``."""
    return trunc_op(*args)


truncate_div_op = TruncateDiv().set_device('GPU')


def truncate_div(*args):
    """Invoke the module-level ``truncate_div_op`` on ``args``."""
    return truncate_div_op(*args)


truncate_mod_op = TruncateMod().set_device('GPU')


def truncate_mod(*args):
    """Invoke the module-level ``truncate_mod_op`` on ``args``."""
    return truncate_mod_op(*args)


def truncated_normal(*args):
    """Pass the last 3 args to ``_get_cache_prim(TruncatedNormal)``; call the result on the rest."""
    ctor_args, inputs = args[-3:], args[:-3]
    prim = _get_cache_prim(TruncatedNormal)(*ctor_args).set_device('GPU')
    return prim(*inputs)


tuple_to_array_op = TupleToArray().set_device('GPU')


def tuple_to_array(*args):
    """Invoke the module-level ``tuple_to_array_op`` on ``args``."""
    return tuple_to_array_op(*args)


def uniform_candidate_sampler(*args):
    """Pass the last 6 args to ``_get_cache_prim(UniformCandidateSampler)``; call the result on the rest."""
    ctor_args, inputs = args[-6:], args[:-6]
    prim = _get_cache_prim(UniformCandidateSampler)(*ctor_args).set_device('GPU')
    return prim(*inputs)


def uniform_int(*args):
    """Pass the last 2 args to ``_get_cache_prim(UniformInt)``; call the result on the rest."""
    ctor_args, inputs = args[-2:], args[:-2]
    prim = _get_cache_prim(UniformInt)(*ctor_args).set_device('GPU')
    return prim(*inputs)


def uniform_real(*args):
    """Pass the last 2 args to ``_get_cache_prim(UniformReal)``; call the result on the rest."""
    ctor_args, inputs = args[-2:], args[:-2]
    prim = _get_cache_prim(UniformReal)(*ctor_args).set_device('GPU')
    return prim(*inputs)


unique_op = Unique().set_device('GPU')


def unique(*args):
    """Invoke the module-level ``unique_op`` on ``args``."""
    return unique_op(*args)


def unique_consecutive(*args):
    """Pass the last 3 args to ``_get_cache_prim(UniqueConsecutive)``; call the result on the rest."""
    ctor_args, inputs = args[-3:], args[:-3]
    prim = _get_cache_prim(UniqueConsecutive)(*ctor_args).set_device('GPU')
    return prim(*inputs)


def unpack(*args):
    """Pass the last arg to ``_get_cache_prim(Unpack)``; call the result on the rest."""
    ctor_args, inputs = args[-1:], args[:-1]
    prim = _get_cache_prim(Unpack)(*ctor_args).set_device('GPU')
    return prim(*inputs)


unravel_index_op = UnravelIndex().set_device('GPU')


def unravel_index(*args):
    """Invoke the module-level ``unravel_index_op`` on ``args``."""
    return unravel_index_op(*args)


unsorted_segment_max_op = UnsortedSegmentMax().set_device('GPU')


def unsorted_segment_max(*args):
    """Invoke the module-level ``unsorted_segment_max_op`` on ``args``."""
    return unsorted_segment_max_op(*args)


unsorted_segment_min_op = UnsortedSegmentMin().set_device('GPU')


def unsorted_segment_min(*args):
    """Invoke the module-level ``unsorted_segment_min_op`` on ``args``."""
    return unsorted_segment_min_op(*args)


unsorted_segment_prod_op = UnsortedSegmentProd().set_device('GPU')


def unsorted_segment_prod(*args):
    """Invoke the module-level ``unsorted_segment_prod_op`` on ``args``."""
    return unsorted_segment_prod_op(*args)


unsorted_segment_sum_op = UnsortedSegmentSum().set_device('GPU')


def unsorted_segment_sum(*args):
    """Invoke the module-level ``unsorted_segment_sum_op`` on ``args``."""
    return unsorted_segment_sum_op(*args)


def unstack(*args):
    """Pass the last 2 args to ``_get_cache_prim(Unstack)``; call the result on the rest."""
    ctor_args, inputs = args[-2:], args[:-2]
    prim = _get_cache_prim(Unstack)(*ctor_args).set_device('GPU')
    return prim(*inputs)


update_state_op = UpdateState().set_device('GPU')


def update_state(*args):
    """Invoke the module-level ``update_state_op`` on ``args``."""
    return update_state_op(*args)


def upper_bound(*args):
    """Pass the last arg to ``_get_cache_prim(UpperBound)``; call the result on the rest."""
    ctor_args, inputs = args[-1:], args[:-1]
    prim = _get_cache_prim(UpperBound)(*ctor_args).set_device('GPU')
    return prim(*inputs)


upsample_nearest3_d_op = UpsampleNearest3D().set_device('GPU')


def upsample_nearest3_d(*args):
    """Invoke the module-level ``upsample_nearest3_d_op`` on ``args``."""
    return upsample_nearest3_d_op(*args)


def upsample_trilinear3_d(*args):
    """Pass the last arg to ``_get_cache_prim(UpsampleTrilinear3D)``; call the result on the rest."""
    ctor_args, inputs = args[-1:], args[:-1]
    prim = _get_cache_prim(UpsampleTrilinear3D)(*ctor_args).set_device('GPU')
    return prim(*inputs)


while_loop_op = WhileLoop().set_device('GPU')


def while_loop(*args):
    """Invoke the module-level ``while_loop_op`` on ``args``."""
    return while_loop_op(*args)


xdivy_op = Xdivy().set_device('GPU')


def xdivy(*args):
    """Invoke the module-level ``xdivy_op`` on ``args``."""
    return xdivy_op(*args)


xlogy_op = Xlogy().set_device('GPU')


def xlogy(*args):
    """Invoke the module-level ``xlogy_op`` on ``args``."""
    return xlogy_op(*args)


zeros_op = Zeros().set_device('GPU')


def zeros(*args):
    """Invoke the module-level ``zeros_op`` on ``args``."""
    return zeros_op(*args)


zeros_like_op = ZerosLike().set_device('GPU')


def zeros_like(*args):
    """Invoke the module-level ``zeros_like_op`` on ``args``."""
    return zeros_like_op(*args)


zeta_op = Zeta().set_device('GPU')


def zeta(*args):
    """Invoke the module-level ``zeta_op`` on ``args``."""
    return zeta_op(*args)
# diff --git a/torch4ms/aten/op_map.py b/torch4ms/aten/op_map.py
# new file mode 100644
# index 000000000..202a50b05
# --- /dev/null
# +++ b/torch4ms/aten/op_map.py
# @@ -0,0 +1,58 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Torch aten ops implemented with MindSpore ``mint`` functions."""

import functools
import math

from mindspore import ops, mint
import torch

from .op_register import register_torch_dispatch_op

# Keys are OpOverload, value is a callable that takes
# Tensor
all_ops = {}


def op(*aten, **kwargs):
    """Return a decorator that registers a function for every op in ``aten``.

    Args:
        *aten: the torch ops the decorated function implements.
        **kwargs: forwarded unchanged to ``register_torch_dispatch_op`` for
            each op.

    Returns:
        A decorator that performs the registrations and hands back the
        decorated function itself, so decorators can be stacked.
    """

    def inner(func):
        for aten_op in aten:
            register_torch_dispatch_op(aten_op, func, **kwargs)
        return func

    return inner


@op(
    torch.ops.aten.view_copy,
    torch.ops.aten.view,
    torch.ops.aten._unsafe_view,
    torch.ops.aten.reshape,
)
def _aten_unsafe_view(x, shape):
    """Implement the view/reshape family with ``mint.reshape(x, shape)``."""
    return mint.reshape(x, shape)


@op(torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar)
def _aten_add(x, y, *, alpha=1):
    """Implement ``aten.add`` (Tensor and Scalar overloads) with ``mint.add``.

    Args:
        x: first operand.
        y: second operand.
        alpha: keyword-only value passed through as ``mint.add``'s ``alpha``.

    Returns:
        Whatever ``mint.add(x, y, alpha=alpha)`` returns.
    """
    return mint.add(x, y, alpha=alpha)
# diff --git a/torch4ms/aten/op_register.py b/torch4ms/aten/op_register.py
# new file mode 100644
# index 000000000..456cc5746
# --- /dev/null
# +++ b/torch4ms/aten/op_register.py
# @@ -0,0 +1,26 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging


# Registry mapping an aten op to the callable registered for it through
# ``register_torch_dispatch_op``.
all_aten_ops = {}


def register_torch_dispatch_op(aten_op, impl_callable):
    """Record ``impl_callable`` as the implementation of ``aten_op``.

    A warning is logged when ``aten_op`` is already registered; the new
    callable then replaces the previous one.

    Args:
        aten_op: registry key (the torch op being implemented).
        impl_callable: the callable to store for that op.

    Returns:
        ``impl_callable``, unchanged.
    """
    if aten_op in all_aten_ops:
        # Lazy %-formatting: the message is only built if it is emitted.
        logging.warning('Duplicate op registration for %s', aten_op)
    all_aten_ops[aten_op] = impl_callable
    return impl_callable
# diff --git a/torch4ms/aten/utils.py b/torch4ms/aten/utils.py
# new file mode 100644
# index 000000000..aaef376f2
# --- /dev/null
# +++ b/torch4ms/aten/utils.py
# @@ -0,0 +1,34 @@
import mindspore
import torch

# torch dtype -> MindSpore dtype. ``None`` maps to ``None``.
PT2MS_DTYPE_MAP = {
    torch.bool: mindspore.bool_,
    torch.int8: mindspore.int8,
    torch.int16: mindspore.int16,
    torch.int32: mindspore.int32,
    torch.int64: mindspore.int64,
    torch.long: mindspore.int64,
    torch.uint8: mindspore.uint8,
    torch.uint16: mindspore.uint16,
    torch.uint32: mindspore.uint32,
    torch.uint64: mindspore.uint64,
    torch.float8_e4m3fn: mindspore.float8_e4m3fn,
    torch.float8_e5m2: mindspore.float8_e5m2,
    torch.bfloat16: mindspore.bfloat16,
    torch.half: mindspore.float16,
    torch.float16: mindspore.float16,
    torch.float32: mindspore.float32,
    torch.float64: mindspore.float64,
    torch.double: mindspore.double,
    torch.complex64: mindspore.complex64,
    torch.complex128: mindspore.complex128,
    None: None,
}

# Inverse table. When several keys above share a value, the entry that
# appears last in ``PT2MS_DTYPE_MAP`` is the one kept here.
MS2PT_DTYPE_MAP = {v: k for k, v in PT2MS_DTYPE_MAP.items()}


def ms2pt_dtype(ms_dtype):
    """Return the torch dtype stored for ``ms_dtype`` in ``MS2PT_DTYPE_MAP``.

    Raises:
        KeyError: if ``ms_dtype`` is not a key of the table.
    """
    return MS2PT_DTYPE_MAP[ms_dtype]


def pt2ms_dtype(pt_dtype):
    """Return the MindSpore dtype stored for ``pt_dtype`` in ``PT2MS_DTYPE_MAP``.

    Raises:
        KeyError: if ``pt_dtype`` is not a key of the table.
    """
    return PT2MS_DTYPE_MAP[pt_dtype]
# diff --git a/torch4ms/distributed/__init__.py b/torch4ms/distributed/__init__.py
# new file mode 100644
# index 000000000..e69de29bb
# diff --git a/torch4ms/tensor.py b/torch4ms/tensor.py
# new file mode 100644
# index 000000000..0f313e866
# --- /dev/null
# +++ b/torch4ms/tensor.py
# @@ -0,0 +1,189 @@
import logging

import torch
import torch.utils._python_dispatch as torch_dispatch
import torch.utils._pytree as torch_pytree

import mindspore
# Imported for its side effect: running the ``@op`` decorators is what fills
# ``all_aten_ops``.
import torch4ms.aten.op_map  # noqa: F401
from torch4ms.aten.op_register import all_aten_ops
from torch4ms.aten.utils import ms2pt_dtype

_logger = logging.getLogger(__name__)


class Tensor(torch.Tensor):
    """``torch.Tensor`` wrapper subclass that carries a ``mindspore.Tensor``.

    The torch-side object is created on the ``meta`` device with the shape and
    (translated) dtype of the wrapped tensor; the wrapped tensor itself is
    stored in ``_data``.
    """

    @staticmethod
    def __new__(cls, data, requires_grad=False):
        dtype = ms2pt_dtype(data.dtype)
        shape = data.shape

        if dtype is None:
            dtype = torch.float32

        # Only floating point and complex tensors may require grad.
        if not (dtype.is_floating_point or dtype.is_complex):
            requires_grad = False

        return torch.Tensor._make_wrapper_subclass(
            cls,
            shape,
            dtype=dtype,
            device="meta",
            requires_grad=requires_grad,
        )

    def __init__(self, data: mindspore.Tensor, requires_grad=False):
        super().__init__()
        self._data = data

    def __str__(self):
        return "Tensor({})".format(self._data)

    __repr__ = __str__

    @property
    def shape(self):
        """Shape of the wrapped tensor as a ``torch.Size``."""
        return torch.Size(self._data.shape)

    @property
    def ndim(self):
        """``ndim`` of the wrapped tensor."""
        return self._data.ndim

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Answer ``prim.device`` queries; refuse every other op.

        Raises:
            AssertionError: for any op other than ``prim.device.default``.
        """
        # TODO(hanq): figure out why is dispatch mode not sufficient
        if func == torch.ops.prim.device.default:
            return torch.device("privateuseone", 0)
        # The previous text named torchax helpers that this package does not
        # provide; point at the mode defined in this module instead.
        raise AssertionError(
            "torch4ms Tensors can only do math while MSDispatchMode is active. "
            "Please wrap your code with `with MSDispatchMode():` before."
        )

    def numpy(self):
        """Return ``asnumpy()`` of the wrapped tensor."""
        return self._data.asnumpy()

    def mindspore(self):
        """Return the wrapped ``mindspore.Tensor``."""
        return self._data

    @property
    def dtype(self):
        """torch dtype corresponding to the wrapped tensor's dtype."""
        return ms2pt_dtype(self._data.dtype)

    def dim(self):
        """Return ``ndim``."""
        return self.ndim

    @property
    def device(self):
        """torch device reported for this tensor.

        Uses the same ``privateuseone`` device that ``__torch_dispatch__``
        returns. NOTE(review): this used to be ``torch.device("ms:0")``;
        nothing in this module registers ``ms`` as a device name, so that
        string is presumably rejected by torch -- confirm before switching back.
        """
        return torch.device("privateuseone", 0)

    @property
    def ms_device(self):
        """``device`` attribute of the wrapped tensor."""
        return self._data.device

    def tolist(self):
        """Return ``tolist()`` of the wrapped tensor."""
        return self._data.tolist()


def _find_impl(func):
    """Return the callable registered for ``func``, or ``None``.

    ``all_aten_ops`` is searched for ``func`` itself first and then, when
    ``func`` exposes an ``overloadpacket`` attribute, for that packet, because
    ``op_map`` registers some implementations under the packet (for example
    ``torch.ops.aten.view``) rather than under a specific overload.
    """
    impl = all_aten_ops.get(func)
    if impl is None:
        packet = getattr(func, "overloadpacket", None)
        if packet is not None:
            impl = all_aten_ops.get(packet)
    return impl


class MSFunctionMode(torch.overrides.TorchFunctionMode):
    """Context manager that dispatches torch function calls to MindSpore."""

    def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor:
        kwargs = kwargs or {}
        # Functions without a registered implementation run unchanged. This
        # replaces a blanket ``except Exception: pass`` that hid every
        # failure of ``dispatch``.
        if _find_impl(func) is None:
            return func(*args, **kwargs)
        return dispatch(func, types, args, kwargs)


class MSDispatchMode(torch_dispatch.TorchDispatchMode):
    """Dispatch mode that routes selected op namespaces through ``dispatch``."""

    # Only functions under these namespaces will be intercepted
    _INTERCEPTED_NAMESPACES = (
        "aten",
        "_c10d_functional",
        "torchvision",
        "xla",
    )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # ``kwargs`` defaults to None and is splatted below.
        kwargs = kwargs or {}
        _logger.debug("%s %s %s %s", func, types, args, kwargs)
        if isinstance(func, torch._ops.OpOverloadPacket):
            with self:
                return func(*args, **kwargs)
        if func.namespace not in self._INTERCEPTED_NAMESPACES:
            return func(*args, **kwargs)
        return dispatch(func, types, args, kwargs)


def dispatch(func, types, args, kwargs):
    """Run ``func`` on MindSpore when a torch4ms ``Tensor`` is involved.

    Args:
        func: the torch function or op being called.
        types: accepted for signature compatibility with the mode hooks;
            not used.
        args: positional arguments of the call.
        kwargs: keyword arguments of the call, or ``None``.

    Returns:
        The result of ``func`` itself when no argument is a torch4ms
        ``Tensor``; otherwise the registered implementation's result with
        every ``mindspore.Tensor`` in it wrapped in ``Tensor``.

    Raises:
        NotImplementedError: a torch4ms ``Tensor`` is among the arguments but
            no implementation is registered for ``func``.
    """
    kwargs = kwargs or {}

    # If the func doesn't act on one of our tensors, skip and let torch
    # handle it.
    leaves = torch_pytree.tree_flatten((args, kwargs))[0]
    if not any(isinstance(leaf, Tensor) for leaf in leaves):
        return func(*args, **kwargs)

    impl = _find_impl(func)
    if impl is None:
        raise NotImplementedError(
            f"torch4ms has no MindSpore implementation registered for {func}"
        )

    # Registered implementations call MindSpore functions directly, so give
    # them the wrapped tensors.
    # NOTE(review): plain ``torch.Tensor`` arguments mixed with torch4ms
    # tensors are passed through untouched -- confirm whether they should be
    # converted first.
    ms_args, ms_kwargs = torch_pytree.tree_map_only(
        Tensor, lambda t: t._data, (args, kwargs)
    )
    result = impl(*ms_args, **ms_kwargs)
    return torch_pytree.tree_map_only(mindspore.Tensor, Tensor, result)