diff --git a/.gitignore b/.gitignore
index 164864f..fa4ba62 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,60 +1,23 @@
-```
+```gitignore
 # Python
 __pycache__/
 *.pyc
 *.pyo
 *.pyd
-.Python
-env/
-venv/
-.venv/
-.ENV
-.venv.bak
-pip-log.txt
-pip-delete-this-directory.txt
-.tox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.log
-.git/modules
-.DS_Store
-Thumbs.db
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
 *.egg-info/
-.installed.cfg
-*.egg
-
-# Pytest
-.pytest_cache/
-
-# IDE
-.vscode/
-.idea/
 
-# Environment
+# Dependencies
+.venv/
+venv/
 .env
 .env.local
-*.env.*
+.env.*
 
-# OS
-.DS_Store
-Thumbs.db
+# Logs and temp files
+*.log
+*.tmp
+
+# Editors
+.vscode/
+.idea/
 ```
\ No newline at end of file
diff --git a/docs/source/api.md b/docs/source/api.md
index cb9df1e..6c31de9 100644
--- a/docs/source/api.md
+++ b/docs/source/api.md
@@ -51,6 +51,24 @@ This section provides detailed documentation of all public modules and classes.
     :show-inheritance:
 ```
 
+### PartialRoPE
+
+```{eval-rst}
+.. autoclass:: transformer.pos.PartialRoPE
+    :members:
+    :undoc-members:
+    :show-inheritance:
+```
+
+### ALiBi (Attention with Linear Biases)
+
+```{eval-rst}
+.. autoclass:: transformer.pos.ALiBi
+    :members:
+    :undoc-members:
+    :show-inheritance:
+```
+
 ## Feed-Forward Modules
 
 ### SwiGLU
@@ -91,3 +109,12 @@
     :show-inheritance:
     :special-members: __init__
 ```
+
+## Utilities
+
+```{eval-rst}
+.. automodule:: transformer.utils
+    :members:
+    :undoc-members:
+    :show-inheritance:
+```
diff --git a/transformer.egg-info/PKG-INFO b/transformer.egg-info/PKG-INFO
index b4cb916..ba3e76e 100644
--- a/transformer.egg-info/PKG-INFO
+++ b/transformer.egg-info/PKG-INFO
@@ -75,12 +75,12 @@ from transformer import Transformer, TransformerConfig
 # Configure the model
 config = TransformerConfig(
     n_layers = 12,
-    n_heads: int = 32,
-    d_model: int = 1536,
-    attn_qk_norm: bool = False,
-    tied_weights: bool = False,
-    seq_len: int = 1024,
-    max_seq_len: int = 4096,
+    n_heads = 32,
+    d_model = 1536,
+    attn_qk_norm = False,
+    tied_weights = False,
+    seq_len = 1024,
+    max_seq_len = 4096,
 )
 
 # Initialize model
diff --git a/transformer/__init__.py b/transformer/__init__.py
index 35b8ac0..6a03448 100644
--- a/transformer/__init__.py
+++ b/transformer/__init__.py
@@ -1,9 +1,24 @@
 from .attns import GQA, MHA, CrossAttention
 from .config import TransformerConfig
 from .ffn import MLP, SwiGLU
-from .pos import RoPE
+from .pos import RoPE, PartialRoPE, ALiBi
 from .transformer import Transformer, TransformerBlock
+from .utils import check_type, resolve_layer_config
 
-__all__ = ["TransformerConfig", "GQA", "MHA", "CrossAttention", "RoPE", "SwiGLU", "MLP", "TransformerBlock", "Transformer"]
+__all__ = [
+    "TransformerConfig",
+    "GQA",
+    "MHA",
+    "CrossAttention",
+    "RoPE",
+    "PartialRoPE",
+    "ALiBi",
+    "SwiGLU",
+    "MLP",
+    "TransformerBlock",
+    "Transformer",
+    "check_type",
+    "resolve_layer_config"
+]
 
-__version__ = "0.4.0"
+__version__ = "0.5.0"
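The PKG-INFO hunk above fixes the quick-start snippet, which previously used annotation syntax (`n_heads: int = 32`) inside a call expression — a SyntaxError in Python. A minimal sketch of the corrected usage, mirroring the values shown in the diff; the final `Transformer(config)` line is assumed from the surrounding quick-start text rather than quoted from it:

```python
from transformer import Transformer, TransformerConfig

# Configure the model (field values taken from the PKG-INFO example above)
config = TransformerConfig(
    n_layers=12,
    n_heads=32,
    d_model=1536,
    attn_qk_norm=False,
    tied_weights=False,
    seq_len=1024,
    max_seq_len=4096,
)

# Initialize the model from the config
model = Transformer(config)
```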
diff --git a/transformer/__pycache__/__init__.cpython-312.pyc b/transformer/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000..03c10e2
Binary files /dev/null and b/transformer/__pycache__/__init__.cpython-312.pyc differ
diff --git a/transformer/__pycache__/attns.cpython-312.pyc b/transformer/__pycache__/attns.cpython-312.pyc
new file mode 100644
index 0000000..2a7a055
Binary files /dev/null and b/transformer/__pycache__/attns.cpython-312.pyc differ
diff --git a/transformer/__pycache__/config.cpython-312.pyc b/transformer/__pycache__/config.cpython-312.pyc
new file mode 100644
index 0000000..78ab607
Binary files /dev/null and b/transformer/__pycache__/config.cpython-312.pyc differ
diff --git a/transformer/__pycache__/ffn.cpython-312.pyc b/transformer/__pycache__/ffn.cpython-312.pyc
new file mode 100644
index 0000000..9940eb4
Binary files /dev/null and b/transformer/__pycache__/ffn.cpython-312.pyc differ
diff --git a/transformer/__pycache__/pos.cpython-312.pyc b/transformer/__pycache__/pos.cpython-312.pyc
new file mode 100644
index 0000000..0a69223
Binary files /dev/null and b/transformer/__pycache__/pos.cpython-312.pyc differ
diff --git a/transformer/__pycache__/transformer.cpython-312.pyc b/transformer/__pycache__/transformer.cpython-312.pyc
new file mode 100644
index 0000000..85a2556
Binary files /dev/null and b/transformer/__pycache__/transformer.cpython-312.pyc differ
diff --git a/transformer/__pycache__/utils.cpython-312.pyc b/transformer/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000..2c593a9
Binary files /dev/null and b/transformer/__pycache__/utils.cpython-312.pyc differ
diff --git a/transformer/config.py b/transformer/config.py
index a8f261c..1c6bb4b 100644
--- a/transformer/config.py
+++ b/transformer/config.py
@@ -36,7 +36,7 @@ class TransformerConfig(PretrainedConfig):
         - If ``str``, one of ``rms_norm`` or ``layer_norm``.
         - If ``Type[nn.Module]`` then will be instantiated inside the model.
           Should have the same API as a torch Normalization Layer.
-        - If ``List[Union[Type[nn.Module], str]]`` and len(ffn_class) == n_layers
+        - If ``List[Union[Type[nn.Module], str]]`` and len(norm_class) == n_layers
           then will be instantiated inside the model for the corresponding layers.
     :type norm_class: Union[List[Union[Type[nn.Module], str]], Type[nn.Module], str]
 
@@ -55,9 +55,9 @@ class TransformerConfig(PretrainedConfig):
         - If ``Type[nn.Module]`` then will be instantiated inside the model.
           Should have the same API as ``transformer.attn.MHA``.
           Default ``MHA``
-        - If ``List[Union[Type[nn.Module], str]]`` and len(ffn_class) == n_layers
+        - If ``List[Union[Type[nn.Module], str]]`` and len(attn_class) == n_layers
           then will be instantiated inside the model for the corresponding layers.
-          Default ``SwiGLU`` for every layer.
+          Default ``MHA`` for every layer.
     :type attn_class: Union[List[Union[Type[nn.Module], str]], Type[nn.Module], str]
 
     :param block_class: Transformer Block class for every layer. Default: ``None``
@@ -87,11 +87,9 @@ class TransformerConfig(PretrainedConfig):
     :type seq_len: int
 
     :param pos_encoding: Positional Encoding for attention.
-        - If ``List[Union[Type[nn.Module], str]]`` and len(ffn_class) == n_layers
-          then will be instantiated inside the model for the corresponding layers.
-          Default ``SwiGLU`` for every layer.
        - If ``str`` one of ``RoPE``, ``AliBI``, ``PartialRoPE``. Default: ``RoPE``
          Note: Is recommended to change the default to ``PartialRoPE`` which is used in SOTA models like Qwen3-Next-80B-A3B
+        - If ``List[str]`` and len(pos_encoding) == n_layers, applies different positional encodings per layer.
     :type pos_encoding: Union[List[str], str]
     :param rope_base: Base for the Exponential Frequency Calculation in RoPE. Default: ``10000.0``
     :type rope_base: float
@@ -100,6 +98,12 @@ class TransformerConfig(PretrainedConfig):
     :param max_seq_len: Maximum sequence length for positional embeddings.
     :type max_seq_len: int
 
+    :param use_cache: Whether to use KV cache during generation. Default: ``True``
+    :type use_cache: bool, optional
+
+    :param is_decoder: Whether this is a decoder model. Default: ``True``
+    :type is_decoder: bool, optional
+
     :param kwargs: Additional keyword arguments passed to `PretrainedConfig`
     :type kwargs: dict, optional
 
@@ -127,9 +131,11 @@ def __init__(
         attn_dropout: Optional[float] = 0.0,
         tied_weights: bool = False,
         seq_len: int = 1024,
-        pos_encoding: str = "RoPE",
+        pos_encoding: Union[List[str], str] = "RoPE",
         rope_base: float = 10000.0,
         max_seq_len: int = 4096,
+        use_cache: bool = True,
+        is_decoder: bool = True,
         **kwargs: Dict,
     ):
         super().__init__(**kwargs)
@@ -162,3 +168,6 @@ def __init__(
         self.pos_encoding = pos_encoding
         self.rope_base = rope_base
         self.max_seq_len = max_seq_len
+
+        self.use_cache = use_cache
+        self.is_decoder = is_decoder
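For context, a short sketch of the options added to `TransformerConfig` in the hunks above: `use_cache`, `is_decoder`, and `pos_encoding` accepting a per-layer list whose length must equal `n_layers`. The encoding names follow the docstring and the exported classes (`RoPE`, `PartialRoPE`, `ALiBi`); the other field values here are illustrative, not taken from the diff:

```python
from transformer import TransformerConfig

config = TransformerConfig(
    n_layers=4,
    n_heads=8,
    d_model=512,
    # One positional encoding per layer; len(pos_encoding) must equal n_layers.
    pos_encoding=["RoPE", "RoPE", "PartialRoPE", "PartialRoPE"],
    rope_base=10000.0,
    use_cache=True,    # new in this diff: enable the KV cache during generation
    is_decoder=True,   # new in this diff: mark the model as a decoder
)
print(config.pos_encoding, config.use_cache, config.is_decoder)
```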
diff --git a/transformer/transformer.py b/transformer/transformer.py
index 4033540..d5d575d 100644
--- a/transformer/transformer.py
+++ b/transformer/transformer.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
@@ -11,8 +11,8 @@
 from .attns import GQA, MHA, CrossAttention
 from .config import TransformerConfig
 from .ffn import MLP, SwiGLU
-from .pos import PartialRoPE, RoPE
-from .utils import check_type
+from .pos import PartialRoPE, RoPE, ALiBi
+from .utils import check_type, resolve_layer_config
 
 
 class TransformerBlock(GradientCheckpointingLayer):
@@ -49,8 +49,16 @@ def __init__(
         super().__init__()
         self.d_model, self.d_ff, self.n_heads, self.layer_idx = config.d_model, config.d_ff, config.n_heads, layer_idx
         self.norm_design = config.norm_design
-
-        if config.attn_class == "MHA":
+        self.n_layers = config.n_layer
+
+        # Resolve per-layer configurations
+        attn_class = resolve_layer_config(config.attn_class, layer_idx, self.n_layers)
+        ffn_class = resolve_layer_config(config.ffn_class, layer_idx, self.n_layers)
+        norm_class = resolve_layer_config(config.norm_class, layer_idx, self.n_layers)
+        pos_encoding = resolve_layer_config(config.pos_encoding, layer_idx, self.n_layers) if isinstance(config.pos_encoding, list) else config.pos_encoding
+
+        # Create attention module
+        if attn_class == "MHA":
             self.attn = MHA(
                 self.d_model,
                 self.n_heads,
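The hunk above resolves `attn_class`, `ffn_class`, `norm_class`, and `pos_encoding` per layer through `resolve_layer_config` (defined in `transformer/utils.py` later in this diff). A small sketch of the resolution behaviour the block relies on; the class lists below are illustrative:

```python
from transformer.utils import resolve_layer_config

n_layers = 4
ffn_classes = ["SwiGLU", "SwiGLU", "MLP", "MLP"]  # per-layer configuration
norm_class = "rms_norm"                           # uniform configuration

for layer_idx in range(n_layers):
    # A list is indexed by layer_idx; a single value is returned unchanged.
    ffn = resolve_layer_config(ffn_classes, layer_idx, n_layers)
    norm = resolve_layer_config(norm_class, layer_idx, n_layers)
    print(layer_idx, ffn, norm)
```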
@@ -58,11 +66,11 @@ def __init__(
                 attn_bias=config.attn_bias,
                 qk_norm=config.attn_qk_norm,
                 layer_idx=layer_idx,
-                pos_encoding=config.pos_encoding,
+                pos_encoding=pos_encoding,
                 max_seq_len=config.max_seq_len,
                 **attn_kwargs,
             )
-        elif config.attn_class == "GQA":
+        elif attn_class == "GQA":
             self.attn = GQA(
                 self.d_model,
                 self.n_heads,
@@ -71,16 +79,23 @@ def __init__(
                 attn_bias=config.attn_bias,
                 qk_norm=config.attn_qk_norm,
                 layer_idx=layer_idx,
-                pos_encoding=config.pos_encoding,
+                pos_encoding=pos_encoding,
                 max_seq_len=config.max_seq_len,
                 **attn_kwargs,
             )
-        elif config.attn_class == "CrossAttention":
-            raise ValueError(f"Under Development: {config.attn_class}")
-        elif check_type(config.attn_class) == 0:
-            raise ValueError(f"Unknown attention type: {config.attn_class}")
-        elif check_type(config.attn_class) == 1:
-            self.attn = config.attn_class(
+        elif attn_class == "CrossAttention":
+            self.attn = CrossAttention(
+                self.d_model,
+                self.n_heads,
+                dropout=config.attn_dropout,
+                attn_bias=config.attn_bias,
+                qk_norm=config.attn_qk_norm,
+                layer_idx=layer_idx,
+                rope_base=config.rope_base,
+                max_seq_len=config.max_seq_len,
+            )
+        elif check_type(attn_class) == 1:
+            self.attn = attn_class(
                 self.d_model,
                 self.n_heads,
                 dropout=config.attn_dropout,
@@ -88,30 +103,24 @@ def __init__(
                 qk_norm=config.attn_qk_norm,
                 layer_idx=layer_idx,
                 max_seq_len=config.max_seq_len,
-                pos_encoding=config.pos_encoding,
+                pos_encoding=pos_encoding,
                 **attn_kwargs,
             )
         else:
-            raise RuntimeError(
-                f"TransformerConfig.attn_class Should be str or Type[nn.Module] but found: {config.attn_class}"
-            )
+            raise ValueError(f"Unknown attention type: {attn_class}")
 
-        if config.ffn_class == "SwiGLU":
+        # Create feed-forward module
+        if ffn_class == "SwiGLU":
             self.ffn = SwiGLU(self.d_model, self.d_ff, bias=config.ffn_bias, **ffn_kwargs)
-        elif config.ffn_class == "MLP":
+        elif ffn_class == "MLP":
             self.ffn = MLP(self.d_model, self.d_ff, bias=config.ffn_bias, **ffn_kwargs)
-        elif config.ffn_class == "MoE":
-            raise ValueError(f"Under Development: {config.ffn_class}")
-        elif check_type(config.ffn_class) == 0:
-            raise ValueError(f"Unknown ffn class: {config.ffn_class}")
-        elif check_type(config.ffn_class) == 1:
-            self.ffn = config.ffn_class(self.d_model, self.d_ff, bias=config.ffn_bias, **ffn_kwargs)
+        elif check_type(ffn_class) == 1:
+            self.ffn = ffn_class(self.d_model, self.d_ff, bias=config.ffn_bias, **ffn_kwargs)
         else:
-            raise RuntimeError(
-                f"TransformerConfig.ffn_class Should be str or Type[nn.Module] but found: {config.ffn_class}"
-            )
+            raise ValueError(f"Unknown ffn class: {ffn_class}")
 
-        if config.norm_class == "rms_norm":
+        # Create normalization modules
+        if norm_class == "rms_norm":
             if config.norm_design == "pre_norm" or config.norm_design == "post_norm":
                 self.norm_attn, self.norm_ffn = (
                     nn.RMSNorm(self.d_model, **norm_kwargs),
@@ -126,7 +135,7 @@ def __init__(
                 )
             else:
                 raise ValueError(f"Invalid norm_design: {config.norm_design}")
-        elif config.norm_class == "layer_norm":
+        elif norm_class == "layer_norm":
             if config.norm_design == "pre_norm" or config.norm_design == "post_norm":
                 self.norm_attn, self.norm_ffn = (
                     nn.LayerNorm(self.d_model, **norm_kwargs),
@@ -141,27 +150,23 @@ def __init__(
                 )
             else:
                 raise ValueError(f"Invalid norm_design: {config.norm_design}")
-        elif check_type(config.norm_class) == 0:
-            raise ValueError(f"Unknown normalization class: {config.norm_class}")
-        elif check_type(config.norm_class) == 1:
+        elif check_type(norm_class) == 1:
             if config.norm_design == "pre_norm" or config.norm_design == "post_norm":
                 self.norm_attn, self.norm_ffn = (
-                    config.norm_class(self.d_model, **norm_kwargs),
-                    config.norm_class(self.d_model, **norm_kwargs),
+                    norm_class(self.d_model, **norm_kwargs),
+                    norm_class(self.d_model, **norm_kwargs),
                 )
             elif config.norm_design == "both":
                 self.pre_norm_attn, self.pre_norm_ffn, self.post_norm_attn, self.post_norm_ffn = (
-                    config.norm_class(self.d_model, **norm_kwargs),
-                    config.norm_class(self.d_model, **norm_kwargs),
-                    config.norm_class(self.d_model, **norm_kwargs),
-                    config.norm_class(self.d_model, **norm_kwargs),
+                    norm_class(self.d_model, **norm_kwargs),
+                    norm_class(self.d_model, **norm_kwargs),
+                    norm_class(self.d_model, **norm_kwargs),
+                    norm_class(self.d_model, **norm_kwargs),
                 )
             else:
                 raise ValueError(f"Invalid norm_design: {config.norm_design}")
         else:
-            raise RuntimeError(
-                f"TransformerConfig.norm_class Should be str or Type[nn.Module] but found: {config.norm_class}"
-            )
+            raise ValueError(f"Unknown normalization class: {norm_class}")
 
     def forward(
         self,
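Because the block falls back to the `check_type(ffn_class) == 1` branch for `nn.Module` subclasses, a custom feed-forward class with the same `(d_model, d_ff, bias=...)` constructor can be supplied instead of a string. The `GEGLU` class below is a hypothetical illustration, not part of the library; passing it through `TransformerConfig(ffn_class=GEGLU)`, or as one entry of a per-layer list, is assumed to follow the docstring contract quoted earlier:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GEGLU(nn.Module):
    """Hypothetical FFN matching the (d_model, d_ff, bias=...) call used by TransformerBlock."""

    def __init__(self, d_model: int, d_ff: int, bias: bool = False):
        super().__init__()
        self.up = nn.Linear(d_model, 2 * d_ff, bias=bias)   # produces value and gate halves
        self.down = nn.Linear(d_ff, d_model, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.up(x).chunk(2, dim=-1)
        return self.down(value * F.gelu(gate))


# Smoke test of the module on its own.
ffn = GEGLU(d_model=512, d_ff=2048)
print(ffn(torch.randn(2, 16, 512)).shape)  # torch.Size([2, 16, 512])
```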
@@ -259,6 +264,15 @@ class Transformer(PreTrainedModel, GenerationMixin):
     :param norm_kwargs: Additional Keyword Arguments passed to the Normalization Layer.
         Default: ``{}``
     :type norm_kwargs: Dict, optional
+
+    :param patch_size: Patch size for Vision Transformer (ViT) compatibility. If specified, adds a patch embedding layer.
+    :type patch_size: Optional[int], optional
+
+    :param img_size: Image size for ViT compatibility. Used with patch_size to compute number of patches.
+    :type img_size: Optional[Union[int, Tuple[int, int]]], optional
+
+    :param num_channels: Number of input channels for ViT. Default: 3 (RGB).
+    :type num_channels: int, optional
     """
 
     config_class = TransformerConfig
@@ -268,7 +282,7 @@ class Transformer(PreTrainedModel, GenerationMixin):
     _supports_flash_attn = True
     _supports_sdpa = True
 
-    input_modalities = "text"  # Will add "image" for v0.4.0
+    input_modalities = ["text", "image"]
 
     def __init__(
         self,
@@ -277,10 +291,31 @@ def __init__(
         pos_encoding_kwargs: Dict = {},
         ffn_kwargs: Dict = {},
         norm_kwargs: Dict = {},
+        patch_size: Optional[int] = None,
+        img_size: Optional[Union[int, Tuple[int, int]]] = None,
+        num_channels: int = 3,
     ):
         super().__init__(config)
         self.config = config
         self.d_model = config.d_model
+        self.patch_size = patch_size
+        self.img_size = img_size
+
+        # Vision Transformer (ViT) support
+        if patch_size is not None and img_size is not None:
+            if isinstance(img_size, int):
+                img_size = (img_size, img_size)
+            self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+            self.patch_embed = nn.Conv2d(
+                num_channels, config.d_model, kernel_size=patch_size, stride=patch_size
+            )
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.d_model))
+            self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.d_model))
+        else:
+            self.patch_embed = None
+            self.cls_token = None
+            self.pos_embed = None
+            self.num_patches = None
 
         self.emb = nn.Embedding(config.vocab_size, config.d_model)
         block_class = config.block_class if config.block_class is not None else TransformerBlock
@@ -312,6 +347,11 @@ def __init__(
             self.lm_head.weight = self.emb.weight
         else:
            self.lm_head.weight.data.normal_(mean=0.0, std=0.025)
+
+        # Initialize ViT-specific parameters
+        if self.patch_embed is not None:
+            nn.init.normal_(self.cls_token, std=0.02)
+            nn.init.normal_(self.pos_embed, std=0.02)
 
         self.post_init()
 
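A minimal sketch of the new ViT-style constructor arguments added above. Only attribute construction is exercised here; how `Transformer.forward` consumes pixel inputs when `patch_embed` is set is not shown in this diff, so that part is left out. The config values are illustrative and rely on the config's defaults for fields not listed:

```python
from transformer import Transformer, TransformerConfig

config = TransformerConfig(n_layers=6, n_heads=8, d_model=512)

# patch_size / img_size / num_channels are the new __init__ arguments.
model = Transformer(config, patch_size=16, img_size=224, num_channels=3)

# 224 // 16 = 14 patches per side, plus one CLS token in the position table.
assert model.num_patches == 14 * 14
print(model.patch_embed)       # Conv2d(3, 512, kernel_size=(16, 16), stride=(16, 16))
print(model.pos_embed.shape)   # torch.Size([1, 197, 512])
print(model.cls_token.shape)   # torch.Size([1, 1, 512])
```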
diff --git a/transformer/utils.py b/transformer/utils.py
index 7848eff..2d3e57a 100644
--- a/transformer/utils.py
+++ b/transformer/utils.py
@@ -2,7 +2,7 @@
 import os
 import random
 import sys
-from typing import Dict, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
@@ -10,9 +10,46 @@
 def check_type(x: Union[Type[nn.Module], str]) -> int:
+    """
+    Check the type of x and return an integer code.
+
+    :param x: Object to check (should be a string or nn.Module subclass)
+    :type x: Union[Type[nn.Module], str]
+
+    :return: 0 if string, 1 if nn.Module subclass
+    :rtype: int
+
+    :raises TypeError: If x is neither a string nor an nn.Module subclass
+    """
     if isinstance(x, str):
         return 0
     elif isinstance(x, type) and issubclass(x, nn.Module):
         return 1
     else:
         raise TypeError(f"Type not valid: {x}")
+
+
+def resolve_layer_config(config_value: Union[str, Type[nn.Module], List], layer_idx: int, n_layers: int):
+    """
+    Resolve configuration value for a specific layer index.
+    Supports both uniform (single value) and per-layer (list) configurations.
+
+    :param config_value: Configuration value (string, type, or list)
+    :type config_value: Union[str, Type[nn.Module], List]
+
+    :param layer_idx: Index of the current layer
+    :type layer_idx: int
+
+    :param n_layers: Total number of layers
+    :type n_layers: int
+
+    :return: Resolved configuration value for this layer
+    :raises ValueError: If list length doesn't match n_layers
+    """
+    if isinstance(config_value, list):
+        if len(config_value) != n_layers:
+            raise ValueError(
+                f"List configuration must have length {n_layers}, got {len(config_value)}"
+            )
+        return config_value[layer_idx]
+    return config_value
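A short usage sketch of the two helpers documented above, matching the behaviour shown in the hunk:

```python
import torch.nn as nn

from transformer.utils import check_type, resolve_layer_config

assert check_type("SwiGLU") == 0        # strings map to 0
assert check_type(nn.LayerNorm) == 1    # nn.Module subclasses map to 1

# A uniform value is returned unchanged for every layer.
assert resolve_layer_config("rms_norm", layer_idx=3, n_layers=12) == "rms_norm"

# A per-layer list must match n_layers, otherwise ValueError is raised.
try:
    resolve_layer_config(["MHA", "GQA"], layer_idx=0, n_layers=3)
except ValueError as err:
    print(err)  # List configuration must have length 3, got 2
```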