Fix data type issues with log1p() op (#151)
razarmehr authored and kulinseth committed Nov 5, 2022
1 parent a6b04bc commit e2f5d91
Showing 2 changed files with 15 additions and 44 deletions.
55 changes: 13 additions & 42 deletions aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -195,50 +195,21 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
 
 TORCH_IMPL_FUNC(log1p_out_mps) (const Tensor& self, const Tensor& output)
 {
-  TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support log1p op with int64 input")
-  using namespace mps;
-  if (!output.is_same_size(self)) {
-    output.resize_(self.sizes());
-  }
-  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-  @autoreleasepool {
-    string key = string("log1p_out_mps") + getTensorsStringKey({self});
-    auto cachedGraph = cache_->LookUpAs<MPSUnaryCachedGraph>(key);
-
-    if(!cachedGraph) {
-      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph* () {
-        MPSUnaryCachedGraph *newCachedGraph = nil;
-        @autoreleasepool {
-          MPSGraph* mpsGraph = make_mps_graph();
-          newCachedGraph = new MPSUnaryCachedGraph(mpsGraph);
-          newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self);
-          MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
-                                                             shape:getMPSShape(self)
-                                                          dataType:mps::getMPSDataType(self.scalar_type())];
-          MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:newCachedGraph->inputTensor_
-                                                            secondaryTensor:oneTensor
-                                                                       name:nil];
-          newCachedGraph->outputTensor_ = [mpsGraph logarithmWithTensor:addedTensor
-                                                                   name:nil];
-        }
-        return newCachedGraph;
-      });
-      cachedGraph = tmpCachedGraph->as<MPSUnaryCachedGraph>();
-    }
-
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
-    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
-      selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
-    };
-    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
-      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
-    };
-    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
-  }
+  TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support log1p op with int64 input");
+  mps::unary_op(self, output, "log1p_out_mps",
+                ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
+                  MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
+                                                                  dataType:inputTensor.dataType];
+                  MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:inputTensor
+                                                                    secondaryTensor:oneTensor
+                                                                               name:nil];
+                  return [mpsGraph logarithmWithTensor:addedTensor
+                                                  name:nil];
+                });
 }
 
-TORCH_IMPL_FUNC(frac_out_mps) (const Tensor& self, const Tensor& output) {
+TORCH_IMPL_FUNC(frac_out_mps) (const Tensor& self, const Tensor& output)
+{
   TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types");
   mps::unary_op(self, output, "frac_out_mps",
                 ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
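
For context on the hunk above: the old implementation built its constant-one tensor with mps::getMPSDataType(self.scalar_type()), while the rewritten version derives it from inputTensor.dataType, so the constant always matches the dtype that the shared mps::unary_op helper actually feeds into the graph. A minimal sketch of how the fix surfaces from Python, assuming an MPS-enabled build (this snippet is illustrative and not part of the commit):

import torch

if torch.backends.mps.is_available():
    # Dtypes the updated allowlist covers for log1p: b8, f32, i16, i32, u8.
    for dtype in (torch.bool, torch.float32, torch.int16, torch.int32, torch.uint8):
        cpu = torch.arange(4).to(dtype)
        # After this fix, MPS results should match CPU for non-float inputs too.
        torch.testing.assert_close(torch.log1p(cpu.to("mps")).cpu(), torch.log1p(cpu))

    # int64 is still rejected up front by the TORCH_CHECK in log1p_out_mps.
    try:
        torch.log1p(torch.arange(4, device="mps"))  # arange defaults to int64
    except RuntimeError as err:
        print(err)  # "MPS does not support log1p op with int64 input"
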
4 changes: 2 additions & 2 deletions test/test_mps.py
@@ -7516,6 +7516,7 @@ class TestConsistency(TestCase):
         'linspace': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'log': ['b8', 'f32', 'i16', 'i32', 'u8'],
         'log10': ['b8', 'f32', 'i16', 'i32', 'u8'],
+        'log1p': ['b8', 'f32', 'i16', 'i32', 'u8'],
         'log2': ['b8', 'f32', 'i16', 'i32', 'u8'],
         'log_softmax': ['f32'],
         'logaddexp': ['f32'],
@@ -7738,6 +7739,7 @@ class TestConsistency(TestCase):
         'linspace': ['f16', 'f32'],
         'log': ['f32'],
         'log10': ['f32'],
+        'log1p': ['f32'],
         'log2': ['f32'],
         'log_softmax': ['f32'],
         'logaddexp': ['f32'],
@@ -7858,7 +7860,6 @@ class TestConsistency(TestCase):
         'diag_embed': [torch.uint8],
         'diagonal_scatter': [torch.uint8],
         'index_add': None,
-        'log1p': None,
         'long': None,
         'nn.functional.avg_pool1d': [torch.int64],
         'nn.functional.avg_pool2d': [torch.int64],
@@ -7980,7 +7981,6 @@ class TestConsistency(TestCase):
         'int': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int64', 'torch.uint8'],
         'linalg.eigvals': ['torch.float32'],
         'linalg.multi_dot': ['torch.float32'],
-        'log1p': ['torch.bool', 'torch.int16', 'torch.int32', 'torch.uint8'],
         'logical_and': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
         'logical_or': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
         'logical_xor': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
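
The test updates move log1p out of the two failure lists and into both consistency allowlists: the first hunk adds forward coverage for b8/f32/i16/i32/u8, and the second adds f32 to what appears to be the gradient allowlist, since gradients are only checked for floating dtypes. A rough sketch of the gradient comparison that entry implies, again illustrative rather than quoted from test_mps.py:

import torch

# Hedged sketch of a CPU-vs-MPS gradient check for log1p on f32; the real
# TestConsistency harness drives this through OpInfo samples instead.
x_cpu = torch.rand(8, requires_grad=True)
x_mps = x_cpu.detach().to("mps").requires_grad_()

torch.log1p(x_cpu).sum().backward()
torch.log1p(x_mps).sum().backward()

# The f32 allowlist entry asserts that these two gradients agree.
torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad)
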
