From 580fa4ecd82da104e74cb923bef6bcc98d1de282 Mon Sep 17 00:00:00 2001
From: Kulin Seth
Date: Wed, 22 Feb 2023 12:58:56 -0800
Subject: [PATCH] Add functionality to dump MPS ops.

1. Add a DUMP_MPS_OPS environment variable that uses LoggingTensor to dump
   out the ATen ops exercised by each test.
2. Skip running the tests under EXPECTTEST_ACCEPT (used to generate the full
   expected list), as some tests are still seg-faulting.
---
 test/test_mps.py       |  39 ++++++++++++++--
 test/test_mps_utils.py | 103 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 137 insertions(+), 5 deletions(-)
 create mode 100644 test/test_mps_utils.py

diff --git a/test/test_mps.py b/test/test_mps.py
index 8855f1b0d0cb6..e1c2c987e7579 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -33,7 +33,7 @@
 import torch.backends.mps
 from torch.distributions import Uniform, Exponential
 from functools import partial, reduce
-
+from test_mps_utils import LoggingTensor, capture_logs, tracefunc
 from torch.testing._internal.common_methods_invocations import (
     op_db,
     UnaryUfuncInfo,
@@ -9466,6 +9466,7 @@ class TestConsistency(TestCaseMPS):
         'nonzero': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'],
         'norm': ['f32', 'f16'],
         'normal': ['f16', 'f32'],
+        'normal_': ['f16', 'f32'],
         'ones': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'ones_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'ormqr': ['f32'],
@@ -10543,6 +10544,8 @@ class TestConsistency(TestCaseMPS):
         # Failures due to unsupported data types on MPS backend
         'bfloat16': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'chalf': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        # Byte tests are failing
+        'byte': [torch.float16, torch.float32],
         'nn.functional.conv1d': [torch.int64],
         'nn.functional.conv2d': [torch.int64],
         'nn.functional.conv_transpose1d': [torch.int64],
@@ -10626,12 +10629,14 @@ class TestConsistency(TestCaseMPS):
         # Failures due to random output that they generate using
         # Philox engine causing mismatch with CPU results
         'uniform': [torch.float16, torch.float32],
+        'randn': [torch.float16, torch.float32],
         'rand_like': [torch.float16, torch.float32],
         'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'randn_like': [torch.float16, torch.float32],
         'bernoulli': [torch.float32],
         'nn.functional.feature_alpha_dropoutwith_train': [torch.float32],
         'normal': [torch.float16, torch.float32, torch.float16, torch.float32],
+        'normal_': [torch.float16, torch.float32],
         'normalnumber_mean': [torch.float16, torch.float32],
         'nn.functional.alpha_dropout': [torch.float32],
         'nn.functional.dropout': [torch.float32],
@@ -10723,6 +10728,7 @@ def compare_with_CUDA(self, op, mps_out, atol, rtol):
 
     @ops(op_db, allowed_dtypes=MPS_DTYPES)
     def test_output_match(self, device, dtype, op):
+        # sys.setprofile(tracefunc)
         self.assertEqual(device, "cpu")
         if not torch.backends.mps.is_available():
             self.skipTest("MPS is not available")
@@ -10777,6 +10783,10 @@ def get_samples():
 
                 # TODO: This checks only the function variant. We should also check the method and inplace version
                 # when they exist
+
+                if os.environ.get("DUMP_MPS_OPS", None) == "1":
+                    mps_sample.input = LoggingTensor(mps_sample.input)
+
                 cpu_args = [cpu_sample.input] + list(cpu_sample.args)
                 cpu_kwargs = cpu_sample.kwargs
                 mps_args = [mps_sample.input] + list(mps_sample.args)
@@ -10786,8 +10796,20 @@ def get_samples():
                 if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
                     mps_args[1] = cpu_args[1]
 
-                cpu_out = op(*cpu_args, **cpu_kwargs)
-                mps_out = op(*mps_args, **mps_kwargs)
+                # Skip running the tests to generate full list
+                if os.environ.get("EXPECTTEST_ACCEPT", None) == "1":
+                    continue
+
+                if os.environ.get("DUMP_MPS_OPS", None) == "1":
+                    with capture_logs() as logs:
+                        cpu_out = op(*cpu_args, **cpu_kwargs)
+                        mps_out = op(*mps_args, **mps_kwargs)
+                        print("Forward logs:")
+                        print("\n".join(logs))
+                else:
+                    cpu_out = op(*cpu_args, **cpu_kwargs)
+                    mps_out = op(*mps_args, **mps_kwargs)
+
 
                 if op.name == "nn.functional.conv2d" or op.name == "linalg.multi_dot" and dtype == torch.float32:
                     atol = 1e-4
@@ -10867,8 +10889,15 @@ def req_grad(t):
                 # Compare computed gradients with cpu given random grad_output vector
                 # Sometimes when the derivative is 0, we just don't bother creating the graph
                 # allow_unused is needed in those cases.
-                cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
-                mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
+                if os.environ.get("DUMP_MPS_OPS", None) == "1":
+                    with capture_logs() as logs:
+                        cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
+                        mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
+                        print("Backward logs:")
+                        print("\n".join(logs))
+                else:
+                    cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
+                    mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
 
                 self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
             except Exception as e:
diff --git a/test/test_mps_utils.py b/test/test_mps_utils.py
new file mode 100644
index 0000000000000..e63ef37a5b323
--- /dev/null
+++ b/test/test_mps_utils.py
@@ -0,0 +1,103 @@
+import torch
+from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.utils._pytree import tree_map
+
+from typing import Iterator, List
+import logging
+import contextlib
+import itertools
+
+class LoggingTensor(torch.Tensor):
+    elem: torch.Tensor
+
+    __slots__ = ['elem']
+
+    @staticmethod
+    def __new__(cls, elem, *args, **kwargs):
+        # The wrapping tensor (LoggingTensor) shouldn't hold any
+        # memory for the class in question, but it should still
+        # advertise the same device as before
+        r = torch.Tensor._make_wrapper_subclass(
+            cls, elem.size(),
+            # TODO: clone strides and storage aliasing
+            dtype=elem.dtype, layout=elem.layout,
+            device=elem.device, requires_grad=elem.requires_grad
+        )
+        # ...the real tensor is held as an element on the tensor.
+        r.elem = elem
+        return r
+
+    def __repr__(self):
+        return f"LoggingTensor({self.elem})"
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        def unwrap(e):
+            return e.elem if isinstance(e, LoggingTensor) else e
+
+        def wrap(e):
+            return LoggingTensor(e) if isinstance(e, torch.Tensor) else e
+
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
+        return rs
+
+
+# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
+class LoggingTensorHandler(logging.Handler):
+
+    def __init__(self, log_list) -> None:
+        logging.Handler.__init__(self)
+        self.log_list = log_list
+        self.next_shortid = 0
+
+    # WARNING: not deterministic over multiple threads, this matters for
+    # autograd
+    def _shortid(self, o: object) -> int:
+        if not hasattr(o, '_shortid'):
+            o._shortid = self.next_shortid
+            self.next_shortid += 1
+        return o._shortid
+
+    def _fmt(self, a: object) -> str:
+        if isinstance(a, LoggingTensor):
+            return f'${self._shortid(a)}'
+        elif isinstance(a, torch.nn.Parameter):
+            return f'Parameter(..., size={tuple(a.size())})'
+        elif isinstance(a, torch.Tensor):
+            return f'Tensor(..., size={tuple(a.size())})'
+        else:
+            return repr(a)
+
+    def emit(self, record):
+        fmt_args = ", ".join(itertools.chain(
+            (self._fmt(a) for a in record.args[0]),
+            (f"{k}={self._fmt(v)}" for k, v in record.args[1].items())
+        ))
+        fmt_rets = ", ".join(self._fmt(a) for a in record.args[2]) \
+            if isinstance(record.args[2], (list, tuple)) else self._fmt(record.args[2])
+        self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')
+
+@contextlib.contextmanager
+def capture_logs():
+    logger = logging.getLogger("LoggingTensor")
+    log_list = []
+    handler = LoggingTensorHandler(log_list)
+    logger.addHandler(handler)
+    logger.setLevel(logging.INFO)
+    logger.propagate = False
+    try:
+        yield log_list
+    finally:
+        logger.removeHandler(handler)
+
+def tracefunc(frame, event, arg, indent=[0]):
+    if event == "call":
+        indent[0] += 2
+        print("-" * indent[0] + "> call function", frame.f_code.co_name)
+    elif event == "return":
+        print("<" + "-" * indent[0], "exit function", frame.f_code.co_name)
+        indent[0] -= 2
+    return tracefunc
+
+import sys
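
The helpers introduced above in test/test_mps_utils.py can also be driven by
hand, outside of test_output_match. The snippet below is a minimal sketch,
assuming it is run from the test/ directory so that test_mps_utils is
importable; the tensor shape and the ops used are illustrative only.

    import torch

    from test_mps_utils import LoggingTensor, capture_logs

    # Wrap a plain tensor; every ATen op dispatched on it is recorded by
    # LoggingTensor.__torch_dispatch__ and collected by capture_logs().
    x = LoggingTensor(torch.rand(3))
    with capture_logs() as logs:
        y = (x * x).relu()

    # Each entry names the ATen op and short ids for the tensors involved,
    # e.g. something like "$2 = torch._ops.aten.relu.default($1)".
    print("\n".join(logs))

Within the test suite itself the same logging is switched on per sample by
setting DUMP_MPS_OPS=1 (for example, something like
DUMP_MPS_OPS=1 python test/test_mps.py -v -k test_output_match; the exact test
selection is an assumption), while EXPECTTEST_ACCEPT=1 skips running the ops
altogether so that the full expected list can be regenerated.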