In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

In [2]:
import os

# Configure the XRT_* environment variables in one shot. The values map
# "CPU:0" to an XLA_CPU device on a "localservice" worker reachable at
# grpc://localhost:40501 (presumably read by torch_xla's XRT backend when the
# device is first requested -- verify against the torch_xla version in use).
os.environ.update({
    'XRT_DEVICE_MAP': "CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0",
    'XRT_WORKERS': "localservice:0;grpc://localhost:40501",
})

In [3]:
# Seed the global RNG so the random inputs -- and therefore the CPU-vs-XLA
# difference computed at the end of the notebook -- are the same on every
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# pred_w is explicitly cast to float64; the other inputs keep torch.rand's
# default dtype (normally float32). This dtype mismatch is what introduces
# mixed precision into the box arithmetic further down.
pred_w = torch.rand(1000, 1).to(torch.float64)
pred_h = torch.rand(1000, 1)
pred_ctr_x = torch.rand(1000, 1)
pred_ctr_y = torch.rand(1000, 1)

# (1000, 4) result buffers. Their random initial contents are placeholders:
# the CPU cell overwrites all four columns of pred_boxes_cpu, and the XLA
# cell rebinds pred_boxes_xla to a fresh zero tensor.
pred_boxes_cpu = torch.rand(1000, 4)
pred_boxes_xla = torch.rand(1000, 4)

In [4]:
# CPU golden results: decode (center, size) into corner coordinates and write
# them column by column into pred_boxes_cpu.
half_w = 0.5 * pred_w
half_h = 0.5 * pred_h

# x1
pred_boxes_cpu[:, 0::4] = pred_ctr_x - half_w
# y1
pred_boxes_cpu[:, 1::4] = pred_ctr_y - half_h
# x2 (the trailing "- 1" is intentional; the asymmetry with x1 is correct)
pred_boxes_cpu[:, 2::4] = pred_ctr_x + half_w - 1
# y2 (the trailing "- 1" is intentional; the asymmetry with y1 is correct)
pred_boxes_cpu[:, 3::4] = pred_ctr_y + half_h - 1

In [4]:
# Obtain the XLA device handle that the XLA cells use as their target device,
# and print it so the recorded output shows which device was selected.
dev = xm.xla_device()
print(dev)

xla:0


In [5]:
# Test case, modelled on:
# https://github.com/asuhan/maskrcnn-benchmark/blob/9063850dc3069dce9d6a8ce9f65f8449b1cd3be7/maskrcnn_benchmark/modeling/box_coder.py#L85
# It is not yet clear why, in the linked code, the 0.5 multiplier ends up as
# FP64 by default. Here the multiplier is built explicitly as a float64 scalar
# tensor on the XLA device to reproduce the mixed-precision error seen earlier.
# In the recorded run this cell ended with the RuntimeError shown in its output.

# Move the inputs to the XLA device (names are rebound to the device copies).
pred_w, pred_h, pred_ctr_x, pred_ctr_y = (
    t.to(dev) for t in (pred_w, pred_h, pred_ctr_x, pred_ctr_y)
)

rel_codes = pred_boxes_cpu.to(dev)
pred_boxes_xla = torch.zeros_like(rel_codes)

# Float64 scalar 0.5 living on the XLA device, shared by all four columns.
half = torch.tensor(0.5, dtype=torch.float64, device=dev)

# x1
pred_boxes_xla[:, 0::4] = pred_ctr_x - half * pred_w
# y1
pred_boxes_xla[:, 1::4] = pred_ctr_y - half * pred_h
# x2 (the trailing "- 1" is intentional; the asymmetry with x1 is correct)
pred_boxes_xla[:, 2::4] = pred_ctr_x + half * pred_w - 1
# y2 (the trailing "- 1" is intentional; the asymmetry with y1 is correct)
pred_boxes_xla[:, 3::4] = pred_ctr_y + half * pred_h - 1

xm.mark_step()

RuntimeError: Internal: From /job:localservice/replica:0/task:0:
2 root error(s) found.
  (0) Internal: Seen floating point types of different precisions in %pad.119 = f64[1000,4]{1,0} pad(f32[1000,1]{1,0} %convert.67, f64[] %constant.118), padding=0_0_0x2_1_1, but mixed precision is disallowed.
	 [[{{node XRTCompile}}]]
	 [[XRTCompile_G3]]
  (1) Internal: Seen floating point types of different precisions in %pad.119 = f64[1000,4]{1,0} pad(f32[1000,1]{1,0} %convert.67, f64[] %constant.118), padding=0_0_0x2_1_1, but mixed precision is disallowed.
	 [[{{node XRTCompile}}]]
0 successful operations.
0 derived errors ignored.
Recent warning and error logs:
  OP_REQUIRES failed at xrt_compile_ops.cc:220 : Internal: Seen floating point types of different precisions in %pad.119 = f64[1000,4]{1,0} pad(f32[1000,1]{1,0} %convert.67, f64[] %constant.118), padding=0_0_0x2_1_1, but mixed precision is disallowed.

In [None]:
# The bug above appears when FP64 is returned to XLA_GPU/CPU.


In [9]:
# No issue occurs when FP32 is returned to XLA_GPU/CPU (author's observation).
# Print the boxes computed on the XLA device for inspection.
print(pred_boxes_xla) 

tensor([[ 5.1625e-02,  1.2382e-01, -8.2171e-04, -4.5915e-02],
        [ 3.3453e-01,  1.9495e-01, -3.9123e-01, -5.0426e-01],
        [ 2.8908e-01,  8.7252e-01,  2.3290e-01,  3.2853e-02],
        ...,
        [-2.9353e-02,  4.6598e-01, -2.8248e-01,  2.5617e-01],
        [ 1.6288e-02,  6.0928e-01, -8.6419e-01, -3.0415e-01],
        [ 3.6718e-01,  6.8817e-01, -6.5811e-03, -2.0635e-01]], device='xla:0')


In [7]:
# XLA vs CPU: largest element-wise absolute difference between the two results.
(pred_boxes_xla - pred_boxes_cpu).abs().max()

tensor(5.9605e-08, device='xla:0')