
Commit 51bd6c4
Fix transformer implementation.
PiperOrigin-RevId: 542529738
Change-Id: Ic45ab86839e6af51da6bb047cded0a1d53538140
anianruoss committed Jun 27, 2023
1 parent 3e46a29 commit 51bd6c4
Showing 1 changed file with 4 additions and 3 deletions.
models/transformer.py: 4 additions & 3 deletions
@@ -15,6 +15,7 @@
 
 """Transformer model."""
 
+import dataclasses
 from typing import Callable, Optional
 
 import chex
@@ -51,14 +52,14 @@ class TransformerConfig:
   # The size of the sliding attention window. See MultiHeadDotProductAttention.
   attention_window: Optional[int] = None
   # The positional encoding used with default sin/cos (Vaswani et al., 2017).
-  positional_encodings: pos_encs_lib.PositionalEncodings = (
-      pos_encs_lib.PositionalEncodings.SIN_COS
+  positional_encodings: pos_encs_lib.PositionalEncodings = dataclasses.field(
+      default_factory=lambda: pos_encs_lib.PositionalEncodings.SIN_COS
   )
   # The maximum size of the context (used by the positional encodings).
   max_time: int = 10_000
   # The parameters for the positional encodings, default sin/cos.
   positional_encodings_params: pos_encs_lib.PositionalEncodingsParams = (
-      pos_encs_lib.SinCosParams()
+      dataclasses.field(default_factory=pos_encs_lib.SinCosParams)
   )
   # How much larger the hidden layer of the feedforward network should be
   # compared to the `embedding_dim`.
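For context, a note on why the defaults move into dataclasses.field: assigning an instance such as pos_encs_lib.SinCosParams() directly as a dataclass field default means every TransformerConfig shares that one object, and on Python 3.11+ dataclasses reject such unhashable class-level defaults outright. (The SIN_COS enum default is hashable, so wrapping it in default_factory as well looks like a consistency change rather than a hard error.) The sketch below is not the repository's code; Params, BrokenConfig, and FixedConfig are hypothetical stand-ins, with Params assumed to be an ordinary, non-frozen dataclass like SinCosParams.

# Minimal sketch of the failure mode and the default_factory fix.
import dataclasses


@dataclasses.dataclass
class Params:
  max_time: int = 10_000


# Using an instance as a class-level default is rejected on Python 3.11+,
# because a non-frozen dataclass instance is unhashable:
#
#   @dataclasses.dataclass
#   class BrokenConfig:
#     params: Params = Params()
#   # ValueError: mutable default ... is not allowed: use default_factory
#
# Even on older Pythons, where this is accepted, the single default instance
# would be shared by every config object.


@dataclasses.dataclass
class FixedConfig:
  # default_factory constructs a fresh Params for each FixedConfig instance.
  params: Params = dataclasses.field(default_factory=Params)


config = FixedConfig()
assert config.params == Params()
assert config.params is not FixedConfig().params  # no shared default object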
