12 changes: 6 additions & 6 deletions torchrec/ir/serializer.py
@@ -89,17 +89,17 @@ def ebc_meta_forward(
     features: KeyedJaggedTensor,
 ) -> KeyedTensor:
     batch_size = features.stride()
-    dim = sum(ebc._lengths_per_embedding)
+    dims = ebc._lengths_per_embedding
     arg_list = [
         features.values(),
         features.weights_or_none(),
         features.lengths_or_none(),
         features.offsets_or_none(),
     ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim)
+    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
-        values=output,
+        values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )

@@ -110,17 +110,17 @@ def fpebc_meta_forward(
 ) -> KeyedTensor:
     batch_size = features.stride()
     ebc = fpebc._embedding_bag_collection
-    dim = sum(ebc._lengths_per_embedding)
+    dims = ebc._lengths_per_embedding
     arg_list = [
         features.values(),
         features.weights_or_none(),
         features.lengths_or_none(),
         features.offsets_or_none(),
     ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim)
+    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
-        values=output,
+        values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )

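Note: the following is a minimal sketch (not part of this PR; the dims values are made up) of the shape contract the serializer change relies on. The op now returns one (batch_size, dim) tensor per embedding table, and `torch.cat(outputs, dim=1)` reproduces the single (batch_size, sum(dims)) tensor the old code produced, so the resulting KeyedTensor layout is unchanged.

import torch

batch_size = 4
dims = [8, 16, 32]  # hypothetical per-table embedding dimensions

# What the reworked op returns: one tensor per embedding table.
outputs = [torch.empty(batch_size, d) for d in dims]

# What the serializer now stores in KeyedTensor: per-table outputs
# concatenated along dim=1, matching the old (batch_size, sum(dims)) layout.
values = torch.cat(outputs, dim=1)
assert values.shape == (batch_size, sum(dims))
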
38 changes: 20 additions & 18 deletions torchrec/ir/utils.py
@@ -28,30 +28,32 @@
 logger: logging.Logger = logging.getLogger(__name__)


-@torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
-def ir_custom_op_impl(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
-) -> torch.Tensor:
-    device = None
+def get_device(tensors: List[Optional[torch.Tensor]]) -> Optional[torch.device]:
+    """
+    Returns the device of the first non-None tensor in the list.
+    """
     for t in tensors:
         if t is not None:
-            device = t.device
-            break
-    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim}) {device}")
-    return torch.empty(batch_size, dim, device=device)
+            return t.device
+    return None
+
+
+@torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
+def ir_custom_op_impl(
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
+    device = get_device(tensors)
+    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]


 @torch.library.register_fake("torchrec::ir_custom_op")
 def ir_custom_op_fake(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
-) -> torch.Tensor:
-    device = None
-    for t in tensors:
-        if t is not None:
-            device = t.device
-            break
-    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim}) {device}")
-    return torch.empty(batch_size, dim, device=device)
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
+    device = get_device(tensors)
+    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]


 def encapsulate_ir_modules(
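Note: a self-contained sketch of the same custom-op pattern, for readers unfamiliar with list-returning custom ops. The op name `demo::empty_per_dim` is hypothetical (not part of this PR), and the sketch assumes a PyTorch version where torch.library.custom_op and torch.library.register_fake are available (2.4+).

from typing import List, Optional

import torch


@torch.library.custom_op("demo::empty_per_dim", mutates_args={})
def empty_per_dim(
    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
) -> List[torch.Tensor]:
    # Eager kernel: one (batch_size, dim) tensor per entry in dims, placed
    # on the device of the first non-None input tensor (default if none).
    device = next((t.device for t in tensors if t is not None), None)
    return [torch.empty(batch_size, dim, device=device) for dim in dims]


@torch.library.register_fake("demo::empty_per_dim")
def empty_per_dim_fake(
    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
) -> List[torch.Tensor]:
    # Fake (meta) kernel: same shapes, no real allocation, so export and
    # tracing can propagate per-table output shapes through the op.
    device = next((t.device for t in tensors if t is not None), None)
    return [torch.empty(batch_size, dim, device=device) for dim in dims]


outs = torch.ops.demo.empty_per_dim([torch.ones(2), None], 4, [8, 16])
print([tuple(t.shape) for t in outs])  # [(4, 8), (4, 16)]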