Skip to content

Commit

Permalink
Modifying interaction layer to include 2 MLPs in DLRM (pytorch#1)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/torchsnapshot#1

X-link: facebookresearch/recipes#26

Pull Request resolved: pytorch#382

X-link: facebookresearch/dlrm#242

This diff adds two MLPs to the interaction layer in DLRM for the MLPerf update. The new module, DLRMV2, is enabled with the --dlrmv2 argument. Two additional arguments, --interaction_branch1_layer_sizes and --interaction_branch2_layer_sizes, specify the layer sizes of the two interaction MLPs. The output dimension of each interaction MLP must be a multiple of the embedding dimension.

DLRMTrain now takes in a DLRM/DLRMV2 module at construction time.

Reviewed By: colin2328, samiwilf

Differential Revision: D35861688

fbshipit-source-id: e8d4e7cd45260f4d229553242b6ea48068f5dda9
  • Loading branch information
narayanan2004 authored and facebook-github-bot committed Jun 9, 2022
1 parent 3f2fb54 commit f6d1ab2
Show file tree
Hide file tree
Showing 3 changed files with 612 additions and 25 deletions.
5 changes: 3 additions & 2 deletions examples/ray/train_torchrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.types import ModuleSharder
from torchrec.models.dlrm import DLRMTrain
from torchrec.models.dlrm import DLRM, DLRMTrain
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.optim.keyed import KeyedOptimizerWrapper
Expand Down Expand Up @@ -82,7 +82,7 @@ def train(
)
for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
]
train_model = DLRMTrain(
dlrm_model = DLRM(
embedding_bag_collection=EmbeddingBagCollection(
tables=eb_configs, device=torch.device("meta")
),
Expand All @@ -91,6 +91,7 @@ def train(
over_arch_layer_sizes=over_arch_layer_sizes,
dense_device=device,
)
train_model = DLRMTrain(dlrm_model)

# Enable optimizer fusion
fused_params = {
Expand Down
279 changes: 257 additions & 22 deletions torchrec/models/dlrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,107 @@ def forward(
return torch.cat((dense_features, interactions_flat), dim=1)


class InteractionV2Arch(nn.Module):
    """
    Processes the output of both `SparseArch` (sparse_features) and `DenseArch`
    (dense_features). Returns Y*Z concatenated with the dense features, where Y
    is the output of interaction branch 1 and Z is the output of interaction
    branch 2. Y and Z are of size Bx(F1xD) and Bx(DxF2) respectively for some
    F1 and F2.

    .. note::
        The dimensionality of the `dense_features` (D) is expected to match the
        dimensionality of the `sparse_features` so that the two can be
        concatenated along the feature axis.
        The output dimension of the 2 interaction branches should be a multiple
        of D.

    Args:
        num_sparse_features (int): F.
        interaction_branch1 (nn.Module): MLP module for the first branch of
            interaction layer
        interaction_branch2 (nn.Module): MLP module for the second branch of
            interaction layer

    Example::

        D = 3
        B = 10
        keys = ["f1", "f2"]
        F = len(keys)
        dense_device = torch.device("cpu")
        # Each branch consumes the flattened dense + sparse features, i.e. an
        # input of size F * D + D, and must end in a multiple of D.
        I1 = DenseArch(
            in_features=F * D + D,
            layer_sizes=[4 * D, 4 * D],  # F1 = 4
            device=dense_device,
        )
        I2 = DenseArch(
            in_features=F * D + D,
            layer_sizes=[4 * D, 4 * D],  # F2 = 4
            device=dense_device,
        )
        inter_arch = InteractionV2Arch(
            num_sparse_features=len(keys),
            interaction_branch1=I1,
            interaction_branch2=I2,
        )

        dense_features = torch.rand((B, D))
        sparse_features = torch.rand((B, F, D))

        #  B X (D + F1 * F2)
        concat_dense = inter_arch(dense_features, sparse_features)
    """

    def __init__(
        self,
        num_sparse_features: int,
        interaction_branch1: nn.Module,
        interaction_branch2: nn.Module,
    ) -> None:
        super().__init__()
        self.F: int = num_sparse_features
        self.interaction_branch1 = interaction_branch1
        self.interaction_branch2 = interaction_branch2

    def forward(
        self, dense_features: torch.Tensor, sparse_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            dense_features (torch.Tensor): an input tensor of size B X D.
            sparse_features (torch.Tensor): an input tensor of size B X F X D.

        Returns:
            torch.Tensor: an output tensor of size B X (D + F1 * F2) where
            F1*D and F2*D are the output dimensions of the 2 interaction MLPs.
            When the module was built with `num_sparse_features <= 0`,
            `dense_features` is returned unchanged.
        """
        if self.F <= 0:
            return dense_features
        (B, D) = dense_features.shape

        # Stack the dense vector in front of the sparse ones: B X (F + 1) X D.
        combined_values = torch.cat(
            (dense_features.unsqueeze(1), sparse_features), dim=1
        )
        # Both branches read the same flattened B X ((F + 1) * D) input, so it
        # is computed once and shared.
        flat_values = torch.reshape(combined_values, (B, -1))

        interaction_branch1_out = self.interaction_branch1(flat_values)
        interaction_branch2_out = self.interaction_branch2(flat_values)

        # (B X F1 X D) @ (B X D X F2) -> B X F1 X F2
        interactions = torch.bmm(
            interaction_branch1_out.reshape([B, -1, D]),
            interaction_branch2_out.reshape([B, D, -1]),
        )
        interactions_flat = torch.reshape(interactions, (B, -1))

        return torch.cat((dense_features, interactions_flat), dim=1)


class OverArch(nn.Module):
"""
Final Arch of DLRM - simple MLP over OverArch.
Expand Down Expand Up @@ -371,11 +472,15 @@ def __init__(
layer_sizes=dense_arch_layer_sizes,
device=dense_device,
)
self.inter_arch = InteractionArch(num_sparse_features=num_sparse_features)

self.inter_arch = InteractionArch(
num_sparse_features=num_sparse_features,
)

over_in_features: int = (
embedding_dim + choose(num_sparse_features, 2) + num_sparse_features
)

self.over_arch = OverArch(
in_features=over_in_features,
layer_sizes=over_arch_layer_sizes,
Expand Down Expand Up @@ -404,6 +509,152 @@ def forward(
return logits


class DLRMV2(DLRM):
    """
    Recsys model v2 modified from the original model from "Deep Learning Recommendation
    Model for Personalization and Recommendation Systems"
    (https://arxiv.org/abs/1906.00091). Similar to DLRM module but has
    additional MLPs in the interaction layer (along 2 branches).

    The module assumes all sparse features have the same embedding dimension
    (i.e. each EmbeddingBagConfig uses the same embedding_dim).

    The following notation is used throughout the documentation for the models:

    * F: number of sparse features
    * D: embedding_dimension of sparse features
    * B: batch size
    * num_features: number of dense features

    Args:
        embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags
            used to define `SparseArch`.
        dense_in_features (int): the dimensionality of the dense input features.
        dense_arch_layer_sizes (List[int]): the layer sizes for the `DenseArch`.
        over_arch_layer_sizes (List[int]): the layer sizes for the `OverArch`.
            The output dimension of the `InteractionArch` should not be manually
            specified here.
        interaction_branch1_layer_sizes (List[int]): the layer sizes for first branch of
            interaction layer. The output dimension must be a multiple of D.
        interaction_branch2_layer_sizes (List[int]): the layer sizes for second branch of
            interaction layer. The output dimension must be a multiple of D.
        dense_device (Optional[torch.device]): default compute device.

    Raises:
        ValueError: if either interaction branch has no layers, or if its final
            layer size is not a multiple of the embedding dimension.

    Example::

        B = 2
        D = 8

        eb1_config = EmbeddingBagConfig(
            name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
        )
        eb2_config = EmbeddingBagConfig(
            name="t2",
            embedding_dim=D,
            num_embeddings=100,
            feature_names=["f2"],
        )

        ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
        model = DLRMV2(
            embedding_bag_collection=ebc,
            dense_in_features=100,
            dense_arch_layer_sizes=[20, D],
            interaction_branch1_layer_sizes=[3*D+D, 4*D],
            interaction_branch2_layer_sizes=[3*D+D, 4*D],
            over_arch_layer_sizes=[5, 1],
        )

        features = torch.rand((B, 100))

        #     0       1
        # 0   [1,2] [4,5]
        # 1   [4,3] [2,9]
        # ^
        # feature
        # NOTE(review): the collection above also declares feature "f2";
        # confirm whether it must be present in `keys` for this to run.
        sparse_features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f3"],
            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
            offsets=torch.tensor([0, 2, 4, 6, 8]),
        )

        logits = model(
            dense_features=features,
            sparse_features=sparse_features,
        )
    """

    def __init__(
        self,
        embedding_bag_collection: EmbeddingBagCollection,
        dense_in_features: int,
        dense_arch_layer_sizes: List[int],
        over_arch_layer_sizes: List[int],
        interaction_branch1_layer_sizes: List[int],
        interaction_branch2_layer_sizes: List[int],
        dense_device: Optional[torch.device] = None,
    ) -> None:
        # initialize DLRM
        # sparse arch and dense arch are initialized via DLRM
        super().__init__(
            embedding_bag_collection,
            dense_in_features,
            dense_arch_layer_sizes,
            over_arch_layer_sizes,
            dense_device,
        )

        embedding_dim: int = embedding_bag_collection.embedding_bag_configs[
            0
        ].embedding_dim
        num_sparse_features: int = len(self.sparse_arch.sparse_feature_names)

        # Validate both branches before building any MLP so that a bad
        # configuration fails before further modules are allocated.
        projected_dim_1: int = self._projected_dim(
            "branch1", interaction_branch1_layer_sizes, embedding_dim
        )
        projected_dim_2: int = self._projected_dim(
            "branch2", interaction_branch2_layer_sizes, embedding_dim
        )

        # Each branch consumes the flattened sparse embeddings plus the dense
        # arch output.
        branch_in_features: int = (
            num_sparse_features * embedding_dim + dense_arch_layer_sizes[-1]
        )
        interaction_branch1 = DenseArch(
            in_features=branch_in_features,
            layer_sizes=interaction_branch1_layer_sizes,
            device=dense_device,
        )
        interaction_branch2 = DenseArch(
            in_features=branch_in_features,
            layer_sizes=interaction_branch2_layer_sizes,
            device=dense_device,
        )

        # Replace the interaction and over arches created by DLRM.__init__
        # with their DLRMv2 counterparts.
        self.inter_arch = InteractionV2Arch(
            num_sparse_features=num_sparse_features,
            interaction_branch1=interaction_branch1,
            interaction_branch2=interaction_branch2,
        )

        over_in_features: int = embedding_dim + projected_dim_1 * projected_dim_2

        self.over_arch = OverArch(
            in_features=over_in_features,
            layer_sizes=over_arch_layer_sizes,
            device=dense_device,
        )

    @staticmethod
    def _projected_dim(
        branch_name: str, layer_sizes: List[int], embedding_dim: int
    ) -> int:
        """
        Validate one interaction branch and return its projected dimension,
        i.e. the final layer size divided by the embedding dimension.

        Args:
            branch_name (str): label used in error messages ("branch1"/"branch2").
            layer_sizes (List[int]): layer sizes configured for the branch.
            embedding_dim (int): D.

        Raises:
            ValueError: if `layer_sizes` is empty or its last entry is not a
                multiple of `embedding_dim`.
        """
        if not layer_sizes:
            raise ValueError(
                "Interaction {} layer sizes must not be empty".format(branch_name)
            )
        final_size = layer_sizes[-1]
        if final_size % embedding_dim != 0:
            raise ValueError(
                "Final interaction {} layer size "
                "({}) is not a multiple of embedding size ({})".format(
                    branch_name, final_size, embedding_dim
                )
            )
        return final_size // embedding_dim


class DLRMTrain(nn.Module):
"""
nn.Module to wrap DLRM model to use with train_pipeline.
Expand All @@ -419,42 +670,26 @@ class DLRMTrain(nn.Module):
(i.e, each EmbeddingBagConfig uses the same embedding_dim)
Args:
embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags
used to define SparseArch.
dense_in_features (int): the dimensionality of the dense input features.
dense_arch_layer_sizes (list[int]): the layer sizes for the DenseArch.
over_arch_layer_sizes (list[int]): the layer sizes for the OverArch. NOTE: The
output dimension of the InteractionArch should not be manually specified
here.
dense_device: (Optional[torch.device]).
dlrm_module: DLRM module (DLRM or DLRMV2) to be used in training
Example::
ebc = EmbeddingBagCollection(config=ebc_config)
model = DLRMTrain(
dlrm_module = DLRM(
embedding_bag_collection=ebc,
dense_in_features=100,
dense_arch_layer_sizes=[20],
over_arch_layer_sizes=[5, 1],
)
dlrm_model = DLRMTrain(dlrm_module)
"""

def __init__(
self,
embedding_bag_collection: EmbeddingBagCollection,
dense_in_features: int,
dense_arch_layer_sizes: List[int],
over_arch_layer_sizes: List[int],
dense_device: Optional[torch.device] = None,
dlrm_module: DLRM,
) -> None:
super().__init__()
self.model = DLRM(
embedding_bag_collection=embedding_bag_collection,
dense_in_features=dense_in_features,
dense_arch_layer_sizes=dense_arch_layer_sizes,
over_arch_layer_sizes=over_arch_layer_sizes,
dense_device=dense_device,
)
self.model = dlrm_module
self.loss_fn: nn.Module = nn.BCEWithLogitsLoss()

def forward(
Expand Down

0 comments on commit f6d1ab2

Please sign in to comment.