Sinusoidal Position Embedding weights somehow get altered #31387

Closed
1 of 4 tasks
ZhiyuanChen opened this issue Jun 12, 2024 · 1 comment

Comments

ZhiyuanChen (Contributor) commented Jun 12, 2024

System Info

  • transformers version: 4.41.2
  • Platform: macOS-15.0-arm64-arm-64bit
  • Python version: 3.10.14
  • Huggingface_hub version: 0.23.3
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.2 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. use the modelling file linked here
  2. use the model weights linked here
  3. build the model with model.from_pretrained (a minimal sketch follows the list)
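
For concreteness, here is a minimal loading sketch for the steps above; the import path and checkpoint name are hypothetical placeholders standing in for the modelling file and weights linked above.

# Hypothetical placeholders for the linked modelling file and checkpoint.
from modeling_custom import CustomModel

model = CustomModel.from_pretrained("path/to/linked-checkpoint")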

Expected behavior

The weights of the Sinusoidal Position Embedding should be static.

But after model.from_pretrained, the weights are changed, even though the weights of the position embeddings do not appear in the state dict.
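
One way to confirm the mismatch (a sketch, assuming model was loaded as in the Reproduction sketch, that the embedding sits at the hypothetical attribute model.embeddings.position_embeddings, and reusing the get_embedding helper from the implementation shown under Notes):

import torch

pos_emb = model.embeddings.position_embeddings  # hypothetical attribute path
# Recompute the deterministic sinusoidal table and compare it with what
# from_pretrained left in the module.
expected = SinusoidalEmbedding.get_embedding(
    pos_emb.weight.size(0), pos_emb.embedding_dim, pos_emb.padding_idx
)
print(torch.allclose(pos_emb.weight.detach().float(), expected))
# Prints False here, which is the reported issue; it should print True.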

Notes

I have the following Sinusoidal Position Embedding implementation:

# MultiMolecule
# Copyright (C) 2024-Present  MultiMolecule
# Copyright (C) 2020 The Facebook AI Research Team Authors
# Copyright (C) 2020 The HuggingFace Inc. team.

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

import math

import torch
import torch.onnx.operators
from torch import Tensor, nn


class SinusoidalEmbedding(nn.Embedding):
    """
    This module produces sinusoidal positional embeddings of any length.

    We don't want to save the weight of this embedding since it's not trained (deterministic) and it can be huge.

    Padding symbols are ignored.

    These embeddings get automatically extended in forward if more positions are needed.
    """

    # _is_hf_initialized = True

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None):
        weight = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        super().__init__(num_embeddings, embedding_dim, padding_idx, _weight=weight.detach(), _freeze=True)

    def update_weight(self, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None):
        weight = self.get_embedding(num_embeddings, embedding_dim, padding_idx).to(
            dtype=self.weight.dtype, device=self.weight.device  # type: ignore[has-type]
        )
        self.weight = nn.Parameter(weight.detach(), requires_grad=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None) -> Tensor:
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000) / (half_dim - 1)))
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    @staticmethod
    def make_positions(tensor, padding_idx: int):
        """
        Replace non-padding symbols with their position numbers.

        Position numbers begin at padding_idx+1. Padding symbols are ignored.
        """
        # The series of casts and type-conversions here are carefully
        # balanced to both work with ONNX export and XLA. In particular XLA
        # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
        # how to handle the dtype kwarg in cumsum.
        mask = tensor.ne(padding_idx).int()
        return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx

    def forward(self, input: Tensor):
        _, seq_len = input.shape[:2]
        max_pos = seq_len
        if self.padding_idx is not None:
            max_pos += self.padding_idx + 1
        if max_pos > self.weight.size(0):
            # expand embeddings if needed
            self.update_weight(max_pos, self.embedding_dim, self.padding_idx)
        positions = self.make_positions(input, self.padding_idx)
        return super().forward(positions)
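
For context, a small usage sketch of the class above (the token ids and padding_idx=1 are arbitrary choices for illustration):

import torch

emb = SinusoidalEmbedding(num_embeddings=10, embedding_dim=16, padding_idx=1)
input_ids = torch.tensor([[5, 7, 3, 1, 1]])  # arbitrary token ids, 1 = padding
out = emb(input_ids)  # shape (1, 5, 16); padding positions map to the zeroed row
# A sequence longer than num_embeddings triggers update_weight and extends the
# table on the fly, as described in the class docstring.
long_ids = torch.randint(2, 9, (1, 32))
out_long = emb(long_ids)  # weight grows to 32 + padding_idx + 1 = 34 rows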

This should be equivalent to the implementation used in msft.

So this issue should also apply to similar methods in the transformers library.

ZhiyuanChen (Contributor, Author) commented:

The current workaround is to override _load_from_state_dict, as from_pretrained does not call load_state_dict; a sketch is below.
I'm still inspecting where the weight gets changed.
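
A minimal sketch of that workaround (one way to do it, not necessarily the final fix): regenerate the deterministic table inside _load_from_state_dict, using the standard nn.Module hook signature. Since the module's only parameter is this weight, skipping the generic loading logic loses nothing.

    # Added inside the SinusoidalEmbedding class shown above.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # The weight is deterministic and intentionally absent from checkpoints,
        # so rebuild it here instead of letting the surrounding loading code
        # initialize or overwrite it.
        weight = self.get_embedding(self.num_embeddings, self.embedding_dim, self.padding_idx)
        self.weight = nn.Parameter(
            weight.detach().to(dtype=self.weight.dtype, device=self.weight.device),
            requires_grad=False,
        )
        # Drop any stale entry so the caller's bookkeeping stays consistent.
        state_dict.pop(prefix + "weight", None)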
