In [None]:
#default_exp utils

# Utils

<a href="https://colab.research.google.com/github/butchland/fastai_xla_extensions/blob/master/nbs/01_utils.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


> Utilities used by other modules

In [None]:
#hide
#colab
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 77kB/s 
[K     |████████████████████████████████| 61kB 3.1MB/s 
[?25h

In [None]:
#exporti
# Best-effort import: on environments without torch_xla (e.g. not configured
# for TPU) the ImportError is deliberately ignored; `xla_imported()` can then be
# used to check whether this import succeeded.
try:
    import torch_xla
except ImportError:
    pass



In [None]:
#export
import sys

def xla_imported():
    "Return `True` when the `torch_xla` module is present in `sys.modules` (i.e. it was imported successfully), else `False`"
    loaded_modules = sys.modules
    return 'torch_xla' in loaded_modules

`xla_imported` is a utility function that checks whether the `torch_xla` module has been successfully imported (i.e. whether it is present in `sys.modules`).

In [None]:
#hide
# fake out xla modules on environments not configured for TPU:
# when torch_xla is missing, install a stand-in `met` object whose
# `metrics_report()` produces an empty report.
if not xla_imported():
    from types import SimpleNamespace

    def fake_metrics_report(*args, **kwargs):
        "Placeholder for `met.metrics_report`: accepts any arguments and always reports nothing"
        empty_report = ""
        return empty_report

    met = SimpleNamespace()
    met.metrics_report = fake_metrics_report
    

In [None]:
#exporti
# Bind `met` to the real XLA debug metrics module when torch_xla was imported;
# otherwise this cell leaves `met` undefined.
if xla_imported():
    import torch_xla.debug.metrics as met

In [None]:
#export
def print_aten_ops(report=None):
    """Print the `aten::` operations listed in an XLA debug metrics report.

    Every report line containing `aten::` is printed prefixed with
    "needs lowering:", followed by the report line that comes right after it
    (presumably the counter's value line in the torch_xla metrics report format;
    confirm against your torch_xla version).

    Args:
        report: metrics report text to scan. Defaults to `None`, in which case the
            current report is fetched with `met.metrics_report()`
            (`torch_xla.debug.metrics`).
    """
    if report is None:
        report = met.metrics_report()
    # Explicit membership test: `str.find` returns 0 (falsy) for a match at the
    # very start and -1 (truthy) when there is no match at all.
    if "aten::" not in report:
        return
    print_next = False
    for line in report.splitlines():
        if "aten::" in line:
            print("needs lowering:", line)
            # print the line that follows this aten counter as well
            print_next = True
        elif print_next:
            print(line)
            print_next = False


One problem we have hit while testing different models and transforms is that they sometimes run slower on TPUs than on CPUs. This happens when they use operations that PyTorch XLA has not lowered to the accelerator, so those operations are handled by the CPU instead. Such operations appear as `aten::` entries in the XLA debug metrics report.

`print_aten_ops` reads the report string returned by `torch_xla.debug.metrics.metrics_report()`. It prints every line containing `aten::` (prefixed with "needs lowering:") together with the line that follows it.

In [None]:
#colab
# test that torch_xla has been imported on colab; this assertion fails on any
# environment where the torch_xla import above did not succeed
assert xla_imported()

In [None]:
#hide
#TODO: Add example usage for print_aten_ops