Add the functionality to dump MPS ops.
1. Add a DUMP_MPS_OPS environment variable that uses LoggingTensor to dump the dispatched ATen ops.
2. Skip running the ops when EXPECTTEST_ACCEPT is set (while regenerating the expected list), as some tests are still
   seg-faulting.
kulinseth committed Feb 22, 2023
1 parent ae768d1 commit 580fa4e
Showing 2 changed files with 137 additions and 5 deletions.
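For context, a rough sketch of how the two switches introduced by this commit could be driven from a small Python script; the "-k" test filter and the working directory are assumptions, not part of the commit:

import os
import subprocess

# Hypothetical driver, run from the repository root (assumed layout).
env = dict(os.environ)
env["DUMP_MPS_OPS"] = "1"          # wrap MPS sample inputs in LoggingTensor and print the ATen ops
# env["EXPECTTEST_ACCEPT"] = "1"   # alternatively, skip executing the ops while regenerating the lists

subprocess.run(
    ["python", "test/test_mps.py", "-k", "test_output_match"],  # "-k" filter is an assumption
    env=env,
    check=False,
)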
39 changes: 34 additions & 5 deletions test/test_mps.py
@@ -33,7 +33,7 @@
import torch.backends.mps
from torch.distributions import Uniform, Exponential
from functools import partial, reduce

from test_mps_utils import LoggingTensor, capture_logs, tracefunc
from torch.testing._internal.common_methods_invocations import (
op_db,
UnaryUfuncInfo,
@@ -9466,6 +9466,7 @@ class TestConsistency(TestCaseMPS):
'nonzero': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'],
'norm': ['f32', 'f16'],
'normal': ['f16', 'f32'],
'normal_': ['f16', 'f32'],
'ones': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'ones_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'ormqr': ['f32'],
@@ -10543,6 +10544,8 @@ class TestConsistency(TestCaseMPS):
# Failures due to unsupported data types on MPS backend
'bfloat16': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
'chalf': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
# Byte tests are failing
'byte': [torch.float16, torch.float32],
'nn.functional.conv1d': [torch.int64],
'nn.functional.conv2d': [torch.int64],
'nn.functional.conv_transpose1d': [torch.int64],
@@ -10626,12 +10629,14 @@ class TestConsistency(TestCaseMPS):
# Failures due to random output that they generate using
# Philox engine causing mismatch with CPU results
'uniform': [torch.float16, torch.float32],
'randn': [torch.float16, torch.float32],
'rand_like': [torch.float16, torch.float32],
'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
'randn_like': [torch.float16, torch.float32],
'bernoulli': [torch.float32],
'nn.functional.feature_alpha_dropoutwith_train': [torch.float32],
'normal': [torch.float16, torch.float32, torch.float16, torch.float32],
'normal_': [torch.float16, torch.float32],
'normalnumber_mean': [torch.float16, torch.float32],
'nn.functional.alpha_dropout': [torch.float32],
'nn.functional.dropout': [torch.float32],
@@ -10723,6 +10728,7 @@ def compare_with_CUDA(self, op, mps_out, atol, rtol):

@ops(op_db, allowed_dtypes=MPS_DTYPES)
def test_output_match(self, device, dtype, op):
# sys.setprofile(tracefunc)
self.assertEqual(device, "cpu")
if not torch.backends.mps.is_available():
self.skipTest("MPS is not available")
@@ -10777,6 +10783,10 @@ def get_samples():

# TODO: This checks only the function variant. We should also check the method and inplace version
# when they exist

if os.environ.get("DUMP_MPS_OPS", None) == "1":
mps_sample.input = LoggingTensor(mps_sample.input)

cpu_args = [cpu_sample.input] + list(cpu_sample.args)
cpu_kwargs = cpu_sample.kwargs
mps_args = [mps_sample.input] + list(mps_sample.args)
@@ -10786,8 +10796,20 @@ def get_samples():
if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
mps_args[1] = cpu_args[1]

cpu_out = op(*cpu_args, **cpu_kwargs)
mps_out = op(*mps_args, **mps_kwargs)
# Skip running the ops when generating the full expected-test list
if os.environ.get("EXPECTTEST_ACCEPT", None) == "1":
continue

if os.environ.get("DUMP_MPS_OPS", None) == "1":
with capture_logs() as logs:
cpu_out = op(*cpu_args, **cpu_kwargs)
mps_out = op(*mps_args, **mps_kwargs)
print("Forward logs:")
print("\n".join(logs))
else:
cpu_out = op(*cpu_args, **cpu_kwargs)
mps_out = op(*mps_args, **mps_kwargs)


if op.name == "nn.functional.conv2d" or op.name == "linalg.multi_dot" and dtype == torch.float32:
atol = 1e-4
@@ -10867,8 +10889,15 @@ def req_grad(t):
# Compare computed gradients with cpu given random grad_output vector
# Sometimes when the derivative is 0, we just don't bother creating the graph
# allow_unused is needed in those cases.
cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
if os.environ.get("DUMP_MPS_OPS", None) == "1":
with capture_logs() as logs:
cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
print("Backward logs:")
print("\n".join(logs))
else:
cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)

self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
except Exception as e:
103 changes: 103 additions & 0 deletions test/test_mps_utils.py
@@ -0,0 +1,103 @@
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.utils._pytree import tree_map

from typing import Iterator, List
import logging
import contextlib
import itertools

class LoggingTensor(torch.Tensor):
elem: torch.Tensor

__slots__ = ['elem']

@staticmethod
def __new__(cls, elem, *args, **kwargs):
# The wrapping tensor (LoggingTensor) shouldn't hold any
# memory for the class in question, but it should still
# advertise the same device as before
r = torch.Tensor._make_wrapper_subclass(
cls, elem.size(),
# TODO: clone strides and storage aliasing
dtype=elem.dtype, layout=elem.layout,
device=elem.device, requires_grad=elem.requires_grad
)
# ...the real tensor is held as an element on the tensor.
r.elem = elem
return r

def __repr__(self):
return f"LoggingTensor({self.elem})"

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(e):
return e.elem if isinstance(e, LoggingTensor) else e

def wrap(e):
return LoggingTensor(e) if isinstance(e, torch.Tensor) else e

rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
return rs

# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
class LoggingTensorHandler(logging.Handler):

def __init__(self, log_list) -> None:
logging.Handler.__init__(self)
self.log_list = log_list
self.next_shortid = 0

# WARNING: not deterministic over multiple threads, this matters for
# autograd
def _shortid(self, o: object) -> int:
if not hasattr(o, '_shortid'):
o._shortid = self.next_shortid
self.next_shortid += 1
return o._shortid

def _fmt(self, a: object) -> str:
if isinstance(a, LoggingTensor):
return f'${self._shortid(a)}'
elif isinstance(a, torch.nn.Parameter):
return f'Parameter(..., size={tuple(a.size())})'
elif isinstance(a, torch.Tensor):
return f'Tensor(..., size={tuple(a.size())})'
else:
return repr(a)

def emit(self, record):
fmt_args = ", ".join(itertools.chain(
(self._fmt(a) for a in record.args[0]),
(f"{k}={self._fmt(v)}" for k, v in record.args[1].items())
))
fmt_rets = ", ".join(self._fmt(a) for a in record.args[2]) \
if isinstance(record.args[2], (list, tuple)) else self._fmt(record.args[2])
self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')

@contextlib.contextmanager
def capture_logs():
logger = logging.getLogger("LoggingTensor")
log_list = []
handler = LoggingTensorHandler(log_list)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False
try:
yield log_list
finally:
logger.removeHandler(handler)

def tracefunc(frame, event, arg, indent=[0]):
if event == "call":
indent[0] += 2
print("-" * indent[0] + "> call function", frame.f_code.co_name)
elif event == "return":
print("<" + "-" * indent[0], "exit function", frame.f_code.co_name)
indent[0] -= 2
return tracefunc

import sys

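As a usage note, here is a minimal sketch of the helpers above used outside the test suite, mirroring what test_output_match does when DUMP_MPS_OPS=1; the op, the tensor shape, and the assumption that the script runs from the test/ directory (so test_mps_utils is importable) are illustrative only:

import torch
from test_mps_utils import LoggingTensor, capture_logs

# Fall back to CPU when MPS is unavailable; the dump then shows CPU ATen ops instead.
device = "mps" if torch.backends.mps.is_available() else "cpu"
x = LoggingTensor(torch.randn(4, 4, device=device))

with capture_logs() as logs:
    y = torch.relu(x)  # each dispatched ATen op is appended to `logs`
    y.sum()

print("\n".join(logs))  # lines look roughly like "$1 = <aten op>($0)"

tracefunc is independent of the logging path: installing it with sys.setprofile(tracefunc), as the commented-out line in test_output_match suggests, instead prints an indented call/return trace of every Python function.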