Skip to content

Commit

Permalink
Add docstrings and types for MambaCache (#30023)
Browse files Browse the repository at this point in the history
* Add docstrings and types for MambaCache

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py

* make fixup

* import copy in generation_whisper

* ruff

* Revert "make fixup"

This reverts commit c4fedd6.
  • Loading branch information
koayon authored and Ita Zaporozhets committed May 14, 2024
1 parent e2ef9fb commit f2ca923
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/transformers/models/mamba/modeling_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,23 @@


class MambaCache:
def __init__(self, config, batch_size, dtype=torch.float16, device=None):
"""
Arguments:
config: MambaConfig
batch_size: int
dtype: torch.dtype
device: torch.device
Attributes:
seqlen_offset: int
dtype: torch.dtype
conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size]
ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
"""

def __init__(
self, config: MambaConfig, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
):
self.seqlen_offset = 0
self.dtype = dtype
intermediate_size = config.intermediate_size
Expand All @@ -86,13 +102,13 @@ class MambaMixer(nn.Module):
and is why Mamba is called **selective** state spaces)
"""

def __init__(self, config, layer_idx):
def __init__(self, config: MambaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = config.time_step_rank
self.time_step_rank = int(config.time_step_rank)
self.layer_idx = layer_idx
self.use_conv_bias = config.use_conv_bias
self.conv1d = nn.Conv1d(
Expand Down

0 comments on commit f2ca923

Please sign in to comment.