Skip to content

Commit

Permalink
[NFC] polish colossalai/nn/_ops/embedding_bag.py code style (#1552)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaruyamaAya committed Sep 8, 2022
1 parent 868c469 commit 08854ea
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions colossalai/nn/_ops/embedding_bag.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,21 @@ def colo_embedding_bag(input_tensor: GeneralTensor,

# Handle different parallel actions.

if not weight.has_compute_spec(): # No Model Parallel Applied
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(
tensor=F.embedding_bag(input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx),
spec=ColoTensorSpec(weight.get_process_group()))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx),
spec=ColoTensorSpec(weight.get_process_group()))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1dcol():
tp_mode = 'col'
else:
Expand Down

0 comments on commit 08854ea

Please sign in to comment.