In [None]:
#default_exp utils

# Utils

One of the problems we have hit while testing different models and transforms is that training is sometimes even slower than on CPU. This happens because we hit PyTorch operations that are only handled by the CPU and not by the accelerator (they have not been lowered to XLA). Such operations show up as `aten::` counters in the torch_xla metrics report. `print_aten_ops` reads that report and prints every `aten::` counter it finds, so you can see which operations need lowering.

In [None]:
#colab


In [None]:
#colab
# Colab only: install the Cloud TPU client and a pinned torch_xla 1.7 wheel (cp36 build, per the wheel filename).
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

In [None]:
# When True, the torch_xla import cell warns if the TPU environment cannot be imported.
DEBUG = False

In [None]:
#hide
import warnings
try:
    # Importing torch_xla here (when available) registers it in sys.modules,
    # which is what `xla_imported()` checks.
    import torch_xla
except ImportError:
    # torch_xla is optional; only report its absence when DEBUG is enabled.
    if DEBUG:
        warnings.warn('TPU Environment not available')

In [None]:
#exporti
#hide_output
import sys

def xla_imported():
    "Return whether `torch_xla` has an entry in `sys.modules` (normally: it has been imported)."
    module_name = 'torch_xla'
    return module_name in sys.modules

In [None]:
#exporti
if xla_imported():
    import torch_xla.debug.metrics as met
else:
    from types import SimpleNamespace

    def fake_metrics_report(*args, **kwargs):
        "Stand-in for `metrics_report` when torch_xla is not loaded: always an empty report."
        return ""

    # Minimal object exposing the one attribute the rest of this module uses.
    met = SimpleNamespace(metrics_report=fake_metrics_report)

In [None]:
#export
def print_aten_ops(report=None):
    """Print every `aten::` counter found in the torch_xla metrics report.

    `aten::` counters mark operations that still need lowering to the accelerator.
    Each matching line is printed prefixed with "needs lowering:", followed by the
    line that comes right after it in the report.

    Args:
        report: metrics report text to scan. When None (the default), the report
            is obtained from `met.metrics_report()`.

    Returns:
        None; the results are printed to stdout.
    """
    if report is None:
        report = met.metrics_report()
    # Use `in` rather than `str.find`: find() returns 0 (falsy) for a match at the
    # very start of the report and -1 (truthy) when there is no match at all.
    if "aten::" not in report:
        return
    lines = report.split("\n")
    for idx, line in enumerate(lines):
        if "aten::" in line:
            print("needs lowering:", line)
            # Also print the following line (presumably the counter's value line).
            if idx + 1 < len(lines):
                print(lines[idx + 1])