Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions keras_hub/src/models/mix_transformer/mix_transformer_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
num_layers,
blockwise_num_heads,
blockwise_sr_ratios,
end_value,
max_drop_path_rate,
patch_sizes,
strides,
image_shape=(None, None, 3),
Expand All @@ -45,7 +45,9 @@ def __init__(
ratio to perform for each layer on the sequence before key and
value projections. If set to > 1, a `Conv2D` layer is used to
reduce the length of the sequence.
end_value: The end value of the sequence.
max_drop_path_rate: The final value of the `linspace()` that
defines the drop path rates for the `DropPath` layers of
the `HierarchicalTransformerEncoder` layers.
image_shape: optional shape tuple, defaults to (None, None, 3).
hidden_dims: the embedding dims per hierarchical layer, used as
the levels of the feature pyramid.
Expand Down Expand Up @@ -73,7 +75,7 @@ def __init__(
model.fit(images, labels, epochs=3)
```
"""
dpr = [x for x in np.linspace(0.0, end_value, sum(depths))]
dpr = [x for x in np.linspace(0.0, max_drop_path_rate, sum(depths))]

# === Layers ===
cur = 0
Expand Down Expand Up @@ -136,7 +138,7 @@ def __init__(
self.num_layers = num_layers
self.blockwise_num_heads = blockwise_num_heads
self.blockwise_sr_ratios = blockwise_sr_ratios
self.end_value = end_value
self.max_drop_path_rate = max_drop_path_rate
self.patch_sizes = patch_sizes
self.strides = strides

Expand All @@ -150,7 +152,7 @@ def get_config(self):
"num_layers": self.num_layers,
"blockwise_num_heads": self.blockwise_num_heads,
"blockwise_sr_ratios": self.blockwise_sr_ratios,
"end_value": self.end_value,
"max_drop_path_rate": self.max_drop_path_rate,
"patch_sizes": self.patch_sizes,
"strides": self.strides,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def setUp(self):
"num_layers": 2,
"blockwise_num_heads": [1, 2],
"blockwise_sr_ratios": [8, 4],
"end_value": 0.1,
"max_drop_path_rate": 0.1,
"patch_sizes": [7, 3],
"strides": [4, 2],
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def setUp(self):
num_layers=2,
blockwise_num_heads=[1, 2],
blockwise_sr_ratios=[8, 4],
end_value=0.1,
max_drop_path_rate=0.1,
patch_sizes=[7, 3],
strides=[4, 2],
)
Expand Down