61 changes: 12 additions & 49 deletions .gitignore
@@ -1,60 +1,23 @@
```gitignore
# Python
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.venv/
.ENV
.venv.bak
pip-log.txt
pip-delete-this-directory.txt
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.log
.git/modules
.DS_Store
Thumbs.db

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Pytest
.pytest_cache/

# IDE
.vscode/
.idea/

# Environment
# Dependencies
.venv/
venv/
.env
.env.local
*.env.*
.env.*

# OS
.DS_Store
Thumbs.db
# Logs and temp files
*.log
*.tmp

# Editors
.vscode/
.idea/
```
27 changes: 27 additions & 0 deletions docs/source/api.md
@@ -51,6 +51,24 @@ This section provides detailed documentation of all public modules and classes.
:show-inheritance:
```

### PartialRoPE

```{eval-rst}
.. autoclass:: transformer.pos.PartialRoPE
:members:
:undoc-members:
:show-inheritance:
```

### ALiBi (Attention with Linear Biases)

```{eval-rst}
.. autoclass:: transformer.pos.ALiBi
:members:
:undoc-members:
:show-inheritance:
```

## Feed-Forward Modules

### SwiGLU
@@ -91,3 +109,12 @@ This section provides detailed documentation of all public modules and classes.
:show-inheritance:
:special-members: __init__
```

## Utilities

```{eval-rst}
.. automodule:: transformer.utils
:members:
:undoc-members:
:show-inheritance:
```
12 changes: 6 additions & 6 deletions transformer.egg-info/PKG-INFO
@@ -75,12 +75,12 @@ from transformer import Transformer, TransformerConfig
# Configure the model
config = TransformerConfig(
n_layers = 12,
n_heads: int = 32,
d_model: int = 1536,
attn_qk_norm: bool = False,
tied_weights: bool = False,
seq_len: int = 1024,
max_seq_len: int = 4096,
n_heads = 32,
d_model = 1536,
attn_qk_norm = False,
tied_weights = False,
seq_len = 1024,
max_seq_len = 4096,
)

# Initialize model
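The PKG-INFO quickstart is collapsed after the ``# Initialize model`` comment. A minimal sketch of that step, assuming ``Transformer`` accepts the config object directly (the hidden lines may differ; only ``Transformer`` and ``TransformerConfig`` are confirmed exports):

```python
from transformer import Transformer, TransformerConfig

# Hedged sketch: assumes Transformer takes a TransformerConfig instance,
# mirroring the PretrainedConfig-based setup shown in config.py.
config = TransformerConfig(
    n_layers=12,
    n_heads=32,
    d_model=1536,
    seq_len=1024,
    max_seq_len=4096,
)
model = Transformer(config)
```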
21 changes: 18 additions & 3 deletions transformer/__init__.py
@@ -1,9 +1,24 @@
from .attns import GQA, MHA, CrossAttention
from .config import TransformerConfig
from .ffn import MLP, SwiGLU
from .pos import RoPE
from .pos import RoPE, PartialRoPE, ALiBi
from .transformer import Transformer, TransformerBlock
from .utils import check_type, resolve_layer_config

__all__ = ["TransformerConfig", "GQA", "MHA", "CrossAttention", "RoPE", "SwiGLU", "MLP", "TransformerBlock", "Transformer"]
__all__ = [
"TransformerConfig",
"GQA",
"MHA",
"CrossAttention",
"RoPE",
"PartialRoPE",
"ALiBi",
"SwiGLU",
"MLP",
"TransformerBlock",
"Transformer",
"check_type",
"resolve_layer_config"
]

__version__ = "0.4.0"
__version__ = "0.5.0"
Binary file added transformer/__pycache__/__init__.cpython-312.pyc
Binary file not shown.
Binary file added transformer/__pycache__/attns.cpython-312.pyc
Binary file not shown.
Binary file added transformer/__pycache__/config.cpython-312.pyc
Binary file not shown.
Binary file added transformer/__pycache__/ffn.cpython-312.pyc
Binary file not shown.
Binary file added transformer/__pycache__/pos.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file added transformer/__pycache__/utils.cpython-312.pyc
Binary file not shown.
23 changes: 16 additions & 7 deletions transformer/config.py
@@ -36,7 +36,7 @@ class TransformerConfig(PretrainedConfig):
- If ``str``, one of ``rms_norm`` or ``layer_norm``.
- If ``Type[nn.Module]`` then will be instantiated inside the model.
Should have the same API as a torch Normalization Layer.
- If ``List[Union[Type[nn.Module], str]]`` and len(ffn_class) == n_layers
- If ``List[Union[Type[nn.Module], str]]`` and len(norm_class) == n_layers
then will be instantiated inside the model for the corresponding layers.
:type norm_class: Union[List[Union[Type[nn.Module], str]], Type[nn.Module], str]

@@ -55,9 +55,9 @@ class TransformerConfig(PretrainedConfig):
- If ``Type[nn.Module]`` then will be instantiated inside the model.
Should have the same API as ``transformer.attn.MHA``.
Default ``MHA``
- If ``List[Union[Type[nn.Module], str]]`` and len(ffn_class) == n_layers
- If ``List[Union[Type[nn.Module], str]]`` and len(attn_class) == n_layers
then will be instantiated inside the model for the corresponding layers.
Default ``SwiGLU`` for every layer.
Default ``MHA`` for every layer.
:type attn_class: Union[List[Union[Type[nn.Module], str]], Type[nn.Module], str]

:param block_class: Transformer Block class for every layer. Default: ``None``
@@ -87,11 +87,9 @@ class TransformerConfig(PretrainedConfig):
:type seq_len: int

:param pos_encoding: Positional Encoding for attention.
- If ``List[Union[Type[nn.Module], str]]`` and len(ffn_class) == n_layers
then will be instantiated inside the model for the corresponding layers.
Default ``SwiGLU`` for every layer.
- If ``str``, one of ``RoPE``, ``ALiBi``, or ``PartialRoPE``. Default: ``RoPE``
Note: It is recommended to change the default to ``PartialRoPE``, which is used in SOTA models such as Qwen3-Next-80B-A3B.
- If ``List[str]`` and len(pos_encoding) == n_layers, applies different positional encodings per layer.
:type pos_encoding: Union[List[str], str]

:param rope_base: Base for the Exponential Frequency Calculation in RoPE. Default: ``10000.0``
@@ -100,6 +98,12 @@
:param max_seq_len: Maximum sequence length for positional embeddings.
:type max_seq_len: int

:param use_cache: Whether to use KV cache during generation. Default: ``True``
:type use_cache: bool, optional

:param is_decoder: Whether this is a decoder model. Default: ``True``
:type is_decoder: bool, optional

:param kwargs: Additional keyword arguments passed to `PretrainedConfig`
:type kwargs: dict, optional

@@ -127,9 +131,11 @@ def __init__(
attn_dropout: Optional[float] = 0.0,
tied_weights: bool = False,
seq_len: int = 1024,
pos_encoding: str = "RoPE",
pos_encoding: Union[List[str], str] = "RoPE",
rope_base: float = 10000.0,
max_seq_len: int = 4096,
use_cache: bool = True,
is_decoder: bool = True,
**kwargs: Dict,
):
super().__init__(**kwargs)
@@ -162,3 +168,6 @@ def __init__(
self.pos_encoding = pos_encoding
self.rope_base = rope_base
self.max_seq_len = max_seq_len

self.use_cache = use_cache
self.is_decoder = is_decoder
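Taken together, the docstring and signature changes mean ``pos_encoding`` now accepts a per-layer list as well as a single string, and the config carries ``use_cache`` and ``is_decoder`` flags. A hedged sketch of how the new options might be combined; the particular layer mix below is illustrative, not taken from the PR:

```python
from transformer import TransformerConfig

# Per-layer positional encodings: the list length must equal n_layers,
# as documented for pos_encoding. The mix of encodings here is an example.
config = TransformerConfig(
    n_layers=4,
    n_heads=8,
    d_model=512,
    pos_encoding=["RoPE", "RoPE", "PartialRoPE", "ALiBi"],
    use_cache=True,   # enable KV caching during generation (new in this PR)
    is_decoder=True,  # decoder-only model (new in this PR)
)
```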