In [1]:
import os
os.environ['PJRT_DEVICE'] = 'TPU'

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def xla_repeat_interleave(input, repeats, dim, report_metrics=True):
    """Repeat each slice of ``input`` along ``dim`` ``repeats`` times, consecutively.

    Matches ``torch.repeat_interleave(input, repeats, dim=dim)`` for an integer
    ``repeats``. It is built only from ``unsqueeze`` / ``expand`` / ``reshape``,
    which is what the metrics report printed in this notebook shows being traced.

    Args:
        input: Tensor with at least one dimension.
        repeats: Non-negative int; number of consecutive copies of each slice.
        dim: Dimension to repeat along; negative values count from the end.
        report_metrics: When True (the default, preserving the original
            behaviour), clear the torch_xla metrics before building the ops
            and print ``met.metrics_report()`` afterwards.

    Returns:
        Tensor whose size along ``dim`` is ``input.shape[dim] * repeats``;
        every other size is unchanged.

    Raises:
        IndexError: if ``dim`` is out of range for ``input``.
        ValueError: if ``repeats`` is negative.
    """
    ndim = input.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"dim {dim} out of range for tensor with {ndim} dimension(s)")
    dim %= ndim  # canonicalise negative dims so dim + 1 below is correct
    if repeats < 0:
        raise ValueError(f"repeats must be non-negative, got {repeats}")

    if report_metrics:
        met.clear_all()

    # Add a size-1 axis right after `dim` and broadcast it to `repeats`.
    # expand() only creates a view (stride 0 on that axis); nothing is copied here.
    expanded = input.unsqueeze(dim + 1)
    expand_shape = list(expanded.shape)
    expand_shape[dim + 1] = repeats
    tiled = expanded.expand(*expand_shape)

    # Merge `dim` with the new axis: each original slice is immediately
    # followed by its copies, which is the "interleave" ordering.
    out_shape = list(input.shape)
    out_shape[dim] *= repeats
    result = tiled.reshape(out_shape)

    if report_metrics:
        print(met.metrics_report())
    return result

device = xm.xla_device()

# Seed and sample on the CPU, then move to the XLA device, so that
# Restart & Run All reproduces the same inputs every time.
SEED = 0
torch.manual_seed(SEED)

# Input tensor
image_embeddings = torch.randn(torch.Size([1, 256, 64, 64])).to(device)
tokens = torch.randn(torch.Size([16, 8, 256])).to(device)

# Using xla_repeat_interleave function
result_xla = xla_repeat_interleave(image_embeddings, tokens.shape[0], dim=0)

# Using torch.repeat_interleave
result_builtin = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)

# Print the results
print("xla_repeat_interleave:")
print(result_xla)
print("torch.repeat_interleave:")
print(result_builtin)

# repeat_interleave only copies values, so compare exactly. torch.equal also
# requires identical sizes, whereas torch.allclose would accept any
# broadcast-compatible pair of shapes.
results_equal = torch.equal(result_xla, result_builtin)
print("Are the results equal?", results_equal)
assert results_equal, "xla_repeat_interleave disagrees with torch.repeat_interleave"

Counter: CreateXlaTensor
  Value: 3
Counter: xla::expand_symint
  Value: 1
Counter: xla::unsqueeze
  Value: 1
Counter: xla::view_symint
  Value: 1

xla_repeat_interleave:
tensor([[[[-0.6783, -0.3352,  1.1862,  ...,  0.5943,  0.6945,  0.2949],
          [ 0.8233,  0.0155, -1.0901,  ..., -0.9990,  0.8104,  0.3706],
          [-0.6425, -0.2450, -0.9264,  ...,  0.0396, -0.9698, -0.8212],
          ...,
          [ 0.1677, -0.2868, -0.5007,  ..., -0.1369, -1.0855, -0.3244],
          [ 2.0562, -1.1668,  0.6644,  ...,  0.9703, -0.6673,  1.2617],
          [-0.3927,  0.7532, -0.0586,  ...,  0.4447,  1.3878, -0.0295]],

         [[ 0.7528, -0.3819, -1.0636,  ..., -0.1932,  0.9681, -0.4283],
          [ 0.2421,  0.0424, -0.4103,  ...,  0.7230,  0.5019,  1.5766],
          [-0.0991,  2.1333, -0.6871,  ...,  1.1626, -0.3252, -0.1060],
          ...,
          [-0.9665,  0.1913, -0.6528,  ..., -1.7360,  0.4690, -1.3952],
          [ 0.6700, -1.9343,  0.2071,  ..., -0.2868,  0.8358, -0.0427],
     