In [3]:
import pytest
import torch

from typing import List

from flair.data import Sentence

def flair_sentence(tokens):
    """Wrap each pre-tokenized word list in a flair ``Sentence``.

    :param tokens: iterable of sentences, each given as a list of word strings
    :return: list of ``Sentence`` objects, one per input sentence, in input order
    """
    wrapped = []
    for words in tokens:
        wrapped.append(Sentence(words))
    return wrapped

In [4]:
# Sanity check: wrap one pre-tokenized sentence as a flair Sentence and display it.
flair_sentence([['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']])

[Sentence[9]: "The quick brown fox jumps over the lazy dog"]

In [6]:
from glirel.modules.token_rep import TokenRepLayer

# Baseline flair-backed glirel layer, without extra vocabulary tokens.
token_rep = TokenRepLayer(
    model_name="microsoft/deberta-v3-small",
    fine_tune=True,
    subtoken_pooling="first",
    hidden_size=768,
    add_tokens=[],
)



In [9]:
# Same baseline layer, now extending the tokenizer with relation/separator markers.
token_rep = TokenRepLayer(
    model_name="microsoft/deberta-v3-small",
    fine_tune=True,
    subtoken_pooling="first",
    hidden_size=768,
    add_tokens=["[REL]", "[SEP]"],
)



In [23]:
# Two sentences of different length so that padding and the mask are exercised.
tokens_batch = [
    ["Hello", "world"],
    ["This", "is", "a", "test"],
]
lengths = torch.tensor(list(map(len, tokens_batch)))

output = token_rep(tokens_batch, lengths)

In [24]:
# Display the layer result: a dict with padded 'embeddings' and a 0/1 'mask' of real tokens.
output

{'embeddings': tensor([[[-0.8532, -0.1729, -0.1481,  ...,  0.2372,  0.0825, -0.3372],
          [ 1.0145, -0.7746, -0.5060,  ..., -0.0447, -0.0466, -0.3296],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[-1.8800, -0.1707, -0.3937,  ...,  0.0134,  0.2015,  0.1256],
          [-1.9091, -0.0522,  0.2174,  ..., -0.0082,  0.3590,  0.0646],
          [-0.1139,  0.1377,  0.2996,  ...,  0.1373,  0.6843,  0.0982],
          [ 0.6009, -0.5112, -0.0265,  ...,  0.5871,  0.5133, -0.2278]]]),
 'mask': tensor([[1, 1, 0, 0],
         [1, 1, 1, 1]])}

In [62]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

########################################################################
# Replicates Flair’s logic
########################################################################

def fill_masked_elements(all_token_embeddings, hidden_states, mask, word_ids, lengths):
    """
    Copy one selected subtoken vector per word into ``all_token_embeddings``.

    Used for 'first' / 'last' pooling (and for each half of 'first_last'): for every
    batch row, the subtokens where ``mask`` is True and the word id is non-negative
    are gathered in order, zero vectors are inserted for words with no selected
    subtoken, and the result is written to the first ``lengths[row]`` slots.

    :param all_token_embeddings: output buffer, written in place (may be a view)
    :param hidden_states: per-subtoken hidden states, indexed by row then subtoken
    :param mask: boolean selector over subtoken positions, same layout as ``word_ids``
    :param word_ids: subtoken -> word index, negative for special/padding positions
    :param lengths: number of words per row
    :return: ``all_token_embeddings`` (the same object that was passed in)
    """
    for row in range(all_token_embeddings.size(0)):
        selected = mask[row] & (word_ids[row] >= 0)
        padded = insert_missing_embeddings(hidden_states[row][selected], word_ids[row], lengths[row])
        all_token_embeddings[row, : lengths[row], :] = padded
    return all_token_embeddings


def insert_missing_embeddings(token_embeddings, word_ids_i, length_i):
    """
    Pad ``token_embeddings`` with zero rows for word positions that have no vector.

    :param token_embeddings: ``[k, dim]`` vectors for the words that do appear in
        ``word_ids_i``, in ascending word order
    :param word_ids_i: subtoken -> word index for one sentence (negatives are ignored)
    :param length_i: number of words in the sentence (int or 0-d tensor)
    :return: a ``[length_i, dim]`` tensor when rows had to be added; otherwise the
        input unchanged (also when it already has ``length_i`` or more rows)
    """
    n_words = int(length_i)
    n_present = token_embeddings.size(0)

    if n_present == 0:
        # Nothing was embedded for this sentence: all-zero result.
        return token_embeddings.new_zeros(n_words, token_embeddings.size(-1))
    if n_present >= n_words:
        return token_embeddings

    # Walk positions left to right so earlier insertions shift later rows into place.
    for position in range(n_words):
        if (word_ids_i == position).any():
            continue
        filler = torch.zeros_like(token_embeddings[:1])
        token_embeddings = torch.cat(
            (token_embeddings[:position], filler, token_embeddings[position:]),
            dim=0,
        )
    return token_embeddings


def fill_mean_token_embeddings(all_token_embeddings, hidden_states, word_ids, token_lengths):
    """
    'mean' subtoken pooling: average all subtoken vectors that belong to each word.

    :param all_token_embeddings: ``[batch, max_tokens, dim]`` output buffer. Its
        existing contents are added to the per-word sums (normally it is all zeros).
        On return it holds the averaged embeddings, so callers that rely on the
        in-place update and callers that use the return value see the same result.
    :param hidden_states: ``[batch, seq_len, dim]`` subtoken hidden states
    :param word_ids: ``[batch, seq_len]`` long tensor, subtoken -> word index;
        negative values mark special/padding positions, which are ignored
    :param token_lengths: ``[batch]`` number of words per row; slots at or beyond it
        are zeroed
    :return: ``all_token_embeddings`` (the same object, filled with the means)
    """
    bsz, max_tokens, emb_dim = all_token_embeddings.shape
    if max_tokens == 0:
        # No word slots to fill; scattering into index 0 would be out of bounds.
        return all_token_embeddings

    valid = word_ids >= 0
    # Special tokens are routed to slot 0 but contribute zeros (masked below).
    index = word_ids.clamp(min=0)

    # Accumulate into a separate tensor so the buffer can be overwritten at the end
    # without clobbering values autograd may need.
    sums = all_token_embeddings.clone()
    sums.scatter_add_(
        1,
        index.unsqueeze(-1).expand(-1, -1, emb_dim),
        # Cast to the buffer's dtype: scatter_add_ requires src.dtype == self.dtype
        # (a float32 mask would otherwise break fp16/bf16/fp64 models).
        (hidden_states * valid.unsqueeze(-1).to(hidden_states.dtype)).to(sums.dtype),
    )

    # Number of subtokens that contributed to each word slot.
    counts = torch.zeros_like(sums[:, :, 0])
    counts.scatter_add_(1, index, valid.to(counts.dtype))
    counts = counts.unsqueeze(-1)

    # clamp(min=1) avoids 0/0 in the unselected branch, which would yield NaN gradients.
    averaged = torch.where(counts > 0, sums / counts.clamp(min=1), torch.zeros_like(sums))

    # Zero out slots beyond each row's word count.
    positions = torch.arange(max_tokens, device=token_lengths.device).unsqueeze(0)
    within = (positions < token_lengths.unsqueeze(1)).unsqueeze(-1)
    averaged = averaged * within.to(averaged.device)

    all_token_embeddings.copy_(averaged)
    return all_token_embeddings

class CustomTransformerWordEmbeddings(nn.Module):
    """
    A flair-free stand-in for Flair's `TransformerWordEmbeddings` (word-level output).

    Pre-split words are tokenized with a fast Hugging Face tokenizer and run through
    the model. The last hidden layer's subtoken vectors are pooled into one vector
    per word, which is attached to each token via ``token.set_embedding(self.name, v)``.
    """
    def __init__(self, model_name: str, fine_tune: bool, subtoken_pooling: str, allow_long_sentences: bool = True):
        """
        :param model_name: Hugging Face model ID or path
        :param fine_tune: Whether to keep the model parameters trainable; if False,
            every model parameter gets ``requires_grad = False``
        :param subtoken_pooling: 'first', 'last', 'mean', or 'first_last'
            ('first_last' concatenates the first and last subtoken vectors, so the
            embedding length is twice the model hidden size)
        :param allow_long_sentences: If True, inputs are not truncated; the whole
            subtoken sequence goes to the model in one piece. NOTE(review): Flair
            processes long inputs in overlapping windows instead; this class does not,
            so inputs longer than the model's maximum length may fail for some models.
        """
        super().__init__()
        self.name = f"CustomTransformerWordEmbeddings({model_name})"
        self.model_name = model_name
        self.fine_tune = fine_tune
        self.subtoken_pooling = subtoken_pooling
        self.allow_long_sentences = allow_long_sentences

        # Load model and tokenizer (fast tokenizer is needed for word_ids()).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.model = AutoModel.from_pretrained(model_name)

        # Freeze the encoder when not fine-tuning.
        if not fine_tune:
            for p in self.model.parameters():
                p.requires_grad = False

        # 'first_last' concatenates two hidden vectors per word.
        hidden_size = self.model.config.hidden_size
        if subtoken_pooling == "first_last":
            self._embedding_length = hidden_size * 2
        else:
            self._embedding_length = hidden_size

    @property
    def embedding_length(self) -> int:
        """Width of each per-token embedding vector produced by :meth:`embed`."""
        return self._embedding_length

    @staticmethod
    def _word_ids_tensor(encoding, batch_size, seq_len):
        """Build a CPU ``[batch_size, seq_len]`` long tensor of word ids (-100 for None)."""
        rows = []
        for i in range(batch_size):
            w_ids = encoding.word_ids(batch_index=i)
            if w_ids is None:
                # defensive fallback: treat every position as special
                w_ids = [None] * seq_len
            rows.append([-100 if w is None else w for w in w_ids])
        return torch.tensor(rows, dtype=torch.long)

    def _pool_subtokens(self, all_token_embeddings, last_hidden, word_ids_tensor, token_lengths_tensor):
        """Fill ``all_token_embeddings`` according to ``self.subtoken_pooling`` and return it."""
        batch_size = word_ids_tensor.size(0)
        device = word_ids_tensor.device
        # True where a subtoken's word id differs from its left neighbour, i.e. a word boundary.
        gain_mask = word_ids_tensor[:, 1:] != word_ids_tensor[:, :-1]
        edge = torch.ones((batch_size, 1), dtype=torch.bool, device=device)
        first_mask = torch.cat([edge, gain_mask], dim=1)  # first subtoken of each word
        last_mask = torch.cat([gain_mask, edge], dim=1)   # last subtoken of each word

        if self.subtoken_pooling == "first":
            fill_masked_elements(all_token_embeddings, last_hidden, first_mask, word_ids_tensor, token_lengths_tensor)
        elif self.subtoken_pooling == "last":
            fill_masked_elements(all_token_embeddings, last_hidden, last_mask, word_ids_tensor, token_lengths_tensor)
        elif self.subtoken_pooling == "first_last":
            # first half <- first subtoken, second half <- last subtoken (slices are views)
            real_hsize = self.model.config.hidden_size
            fill_masked_elements(
                all_token_embeddings[:, :, :real_hsize],
                last_hidden, first_mask, word_ids_tensor, token_lengths_tensor
            )
            fill_masked_elements(
                all_token_embeddings[:, :, real_hsize:],
                last_hidden, last_mask, word_ids_tensor, token_lengths_tensor
            )
        elif self.subtoken_pooling == "mean":
            # Use the returned tensor: it holds the means (not the intermediate sums).
            all_token_embeddings = fill_mean_token_embeddings(
                all_token_embeddings, last_hidden, word_ids_tensor, token_lengths_tensor
            )
        else:
            raise ValueError(f"Unknown subtoken_pooling={self.subtoken_pooling}")
        return all_token_embeddings

    def embed(self, sentences):
        """
        Compute one embedding per token and store it on the tokens.

        Expects a list of "sentence-like" objects. Each "sentence" must have a
        `.tokens` list, and each "token" must have:
          - a .text attribute
          - a .set_embedding(name, vector) method
        Embeddings are stored under ``self.name``; nothing is returned.

        Every token receives a vector. Tokens that produced no subtokens (e.g. because
        the input was truncated) get a zero vector rather than raising IndexError.

        :raises ValueError: if ``self.subtoken_pooling`` is not a supported strategy
        """
        if not sentences:
            return

        # Word strings per sentence; the tokenizer performs the subword split.
        batch_of_lists = [[t.text for t in s.tokens] for s in sentences]

        encoding = self.tokenizer(
            batch_of_lists,
            is_split_into_words=True,
            return_tensors='pt',
            padding=True,
            truncation=not self.allow_long_sentences
        )
        device = next(self.model.parameters()).device  # run on the model's device
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        outputs = self.model(input_ids, attention_mask=attention_mask, return_dict=True)
        last_hidden = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]
        batch_size, seq_len, _ = last_hidden.shape

        # Build on CPU in one go, then move (avoids per-element device writes).
        word_ids_tensor = self._word_ids_tensor(encoding, batch_size, seq_len).to(device)

        # Word counts come from the input sentences, not from word_ids: words that got
        # no subtokens (truncation, or text the tokenizer drops) still need a slot.
        token_lengths = [len(s.tokens) for s in sentences]
        token_lengths_tensor = torch.tensor(token_lengths, device=device, dtype=torch.long)

        all_token_embeddings = torch.zeros(
            (batch_size, max(token_lengths), self.embedding_length),
            device=device, dtype=last_hidden.dtype
        )
        all_token_embeddings = self._pool_subtokens(
            all_token_embeddings, last_hidden, word_ids_tensor, token_lengths_tensor
        )

        # Attach each word's vector to its token.
        for i, sentence in enumerate(sentences):
            for token_idx, token in enumerate(sentence.tokens):
                token.set_embedding(self.name, all_token_embeddings[i, token_idx])

    def __str__(self):
        return self.name


In [63]:
from typing import List
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

class ModifiedTokenRepLayer(nn.Module):
    """
    Token representation layer built on :class:`CustomTransformerWordEmbeddings`.

    Meant to reproduce glirel's flair-backed ``TokenRepLayer`` output (compared in
    this notebook). The tokenizer is extended with ``add_tokens``, the model's
    embedding matrix is resized to the tokenizer length, and a linear projection to
    ``hidden_size`` is added when that differs from the encoder's embedding length.
    """

    def __init__(self, model_name: str, fine_tune: bool, subtoken_pooling: str,
                 hidden_size: int, add_tokens: List[str]):
        super().__init__()

        encoder = CustomTransformerWordEmbeddings(
            model_name,
            fine_tune=fine_tune,
            subtoken_pooling=subtoken_pooling,
            allow_long_sentences=True,
        )
        self.bert_layer = encoder

        # Register extra vocabulary, then grow the embedding matrix to match.
        encoder.tokenizer.add_tokens(add_tokens)
        encoder.model.resize_token_embeddings(len(encoder.tokenizer))

        encoder_dim = encoder.embedding_length
        if encoder_dim != hidden_size:
            self.projection = nn.Linear(encoder_dim, hidden_size)

    def forward(self, tokens: List[List[str]], lengths: torch.Tensor):
        """
        :param tokens: batch of pre-tokenized sentences
        :param lengths: 1-D tensor with the number of words per sentence
        :return: dict with ``"embeddings"`` (padded, projected if configured) and
            ``"mask"`` (long tensor, 1 for positions below each sentence length)
        """
        token_embeddings = self.compute_word_embedding(tokens)
        if hasattr(self, "projection"):
            token_embeddings = self.projection(token_embeddings)

        positions = torch.arange(lengths.max()).unsqueeze(0)
        mask = (positions < lengths.cpu().unsqueeze(1)).long().to(token_embeddings.device)

        return {"embeddings": token_embeddings, "mask": mask}

    def compute_word_embedding(self, tokens):
        """Embed pre-tokenized sentences; returns a batch-first, zero-padded tensor."""
        # Minimal Sentence/Token stand-ins replace flair's classes.
        sentences = [MinimalSentence(words) for words in tokens]
        self.bert_layer.embed(sentences)

        key = self.bert_layer.name
        per_sentence = []
        for sentence in sentences:
            per_sentence.append(torch.stack([tok.get_embedding(key) for tok in sentence.tokens]))
        return pad_sequence(per_sentence, batch_first=True)

class MinimalToken:
    """Stand-in for a flair ``Token``: a word string plus named embedding vectors."""

    def __init__(self, text: str):
        self.text = text
        self._embeddings = dict()  # embedding name -> vector

    def set_embedding(self, name: str, vector: torch.Tensor):
        """Store ``vector`` under ``name``, replacing any previous value."""
        self._embeddings.update({name: vector})

    def get_embedding(self, name: str) -> torch.Tensor:
        """Return the vector stored under ``name``; raises ``KeyError`` if absent."""
        return self._embeddings[name]

class MinimalSentence:
    """Stand-in for a flair ``Sentence``: an ordered list of ``MinimalToken`` objects."""

    def __init__(self, list_of_words: List[str]):
        tokens = []
        for word in list_of_words:
            tokens.append(MinimalToken(word))
        self.tokens = tokens


In [25]:
# Flair-free reimplementation with the same configuration as the baseline layer.
modified_token_rep = ModifiedTokenRepLayer(
    model_name="microsoft/deberta-v3-small",
    fine_tune=True,
    subtoken_pooling="first",
    hidden_size=768,
    add_tokens=["[REL]", "[SEP]"],
)

tokens_batch = [
    ["Hello", "world"],
    ["This", "is", "a", "test"],
]
lengths = torch.tensor(list(map(len, tokens_batch)))

output = modified_token_rep(tokens_batch, lengths)



In [26]:
# Output of the flair-free ModifiedTokenRepLayer for the two-sentence batch.
output

{'embeddings': tensor([[[-0.8532, -0.1729, -0.1481,  ...,  0.2372,  0.0825, -0.3372],
          [ 1.0145, -0.7746, -0.5060,  ..., -0.0447, -0.0466, -0.3296],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[-1.8800, -0.1707, -0.3937,  ...,  0.0134,  0.2015,  0.1256],
          [-1.9091, -0.0522,  0.2174,  ..., -0.0082,  0.3590,  0.0646],
          [-0.1139,  0.1377,  0.2996,  ...,  0.1373,  0.6843,  0.0982],
          [ 0.6009, -0.5112, -0.0265,  ...,  0.5871,  0.5133, -0.2278]]],
        grad_fn=<CopySlices>),
 'mask': tensor([[1, 1, 0, 0],
         [1, 1, 1, 1]])}

In [31]:
# Flair-backed glirel layer, configured identically, for a side-by-side comparison.
token_rep = TokenRepLayer(
    model_name="microsoft/deberta-v3-small",
    fine_tune=True,
    subtoken_pooling="first",
    hidden_size=768,
    add_tokens=["[REL]", "[SEP]"],
)

tokens_batch = [
    ["Hello", "world"],
    ["This", "is", "a", "test"],
]
lengths = torch.tensor(list(map(len, tokens_batch)))

original_output = token_rep(tokens_batch, lengths)



In [32]:
# Output of the flair-backed TokenRepLayer for the same batch.
original_output

{'embeddings': tensor([[[-0.8532, -0.1729, -0.1481,  ...,  0.2372,  0.0825, -0.3372],
          [ 1.0145, -0.7746, -0.5060,  ..., -0.0447, -0.0466, -0.3296],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[-1.8800, -0.1707, -0.3937,  ...,  0.0134,  0.2015,  0.1256],
          [-1.9091, -0.0522,  0.2174,  ..., -0.0082,  0.3590,  0.0646],
          [-0.1139,  0.1377,  0.2996,  ...,  0.1373,  0.6843,  0.0982],
          [ 0.6009, -0.5112, -0.0265,  ...,  0.5871,  0.5133, -0.2278]]]),
 'mask': tensor([[1, 1, 0, 0],
         [1, 1, 1, 1]])}

In [35]:
# Shape check: (batch size, longest sentence length, output width).
original_output["embeddings"].shape

torch.Size([2, 4, 768])

In [36]:
# Shape of the reimplementation's output; should match original_output's.
output["embeddings"].shape

torch.Size([2, 4, 768])

In [42]:
# Raw embeddings from the reimplementation (padded positions are zero vectors).
output["embeddings"]

tensor([[[-0.8532, -0.1729, -0.1481,  ...,  0.2372,  0.0825, -0.3372],
         [ 1.0145, -0.7746, -0.5060,  ..., -0.0447, -0.0466, -0.3296],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-1.8800, -0.1707, -0.3937,  ...,  0.0134,  0.2015,  0.1256],
         [-1.9091, -0.0522,  0.2174,  ..., -0.0082,  0.3590,  0.0646],
         [-0.1139,  0.1377,  0.2996,  ...,  0.1373,  0.6843,  0.0982],
         [ 0.6009, -0.5112, -0.0265,  ...,  0.5871,  0.5133, -0.2278]]],
       grad_fn=<CopySlices>)

In [44]:
# Exact comparison: same size and identical elements.
torch.equal(original_output["embeddings"], output["embeddings"])

True

In [43]:
# Tolerance-based comparison (torch defaults: rtol=1e-05, atol=1e-08).
torch.allclose(original_output["embeddings"], output["embeddings"])

True

In [60]:
# Side-by-side check: the flair-free ModifiedTokenRepLayer should reproduce the
# flair-backed TokenRepLayer for the same configuration.
model_name = "microsoft/deberta-v3-small"
fine_tune = True
subtoken_pooling = "first"  # also try "last", "first_last", "mean"
hidden_size = 768
add_tokens = ["[REL]", "[SEP]"]

modified_token_rep = ModifiedTokenRepLayer(
    model_name=model_name,
    fine_tune=fine_tune,
    subtoken_pooling=subtoken_pooling,
    hidden_size=hidden_size,
    add_tokens=add_tokens,
)

token_rep = TokenRepLayer(
    model_name=model_name,
    fine_tune=fine_tune,
    subtoken_pooling=subtoken_pooling,
    hidden_size=hidden_size,
    add_tokens=add_tokens,
)

# When hidden_size differs from the encoder's word-embedding length (e.g. with
# "first_last", which doubles it), each layer builds its own randomly initialised
# nn.Linear projection. Share one set of weights so the comparison tests the
# subtoken pooling rather than two different random projections.
if hasattr(token_rep, "projection") and hasattr(modified_token_rep, "projection"):
    modified_token_rep.projection.load_state_dict(token_rep.projection.state_dict())

# eval() disables dropout so both forward passes are deterministic.
modified_token_rep.eval()
token_rep.eval()

tokens_batch = [
    ["Hello", "world"],
    ["This", "is", "a", "test"],
]
lengths = torch.tensor([len(seq) for seq in tokens_batch])

output = modified_token_rep(tokens_batch, lengths)
original_output = token_rep(tokens_batch, lengths)

print(torch.equal(original_output["embeddings"], output["embeddings"]))
print(torch.allclose(original_output["embeddings"], output["embeddings"]))
print(torch.equal(original_output["mask"], output["mask"]))

True
True


In [None]:
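# FIXME: the first_last subtoken pooling comparison fails (the outputs below differ).
# Likely cause: first_last doubles the encoder width, so with this hidden_size each layer
# creates its own randomly initialised nn.Linear projection. Note the identical non-zero
# rows at padded positions in both outputs, which is the bias alone. Copy the projection
# weights from one layer to the other before comparing.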

In [54]:
# Flair-backed output from the (presumably first_last) comparison run; 8 columns wide here.
original_output

{'embeddings': tensor([[[-7.4865e-01,  2.2039e-01,  4.4916e-01,  3.0461e-02,  4.6049e-01,
           -5.8071e-01,  3.2968e-01,  8.3875e-01],
          [ 3.8578e-02,  2.7650e-01,  1.8024e-01,  6.8406e-02,  3.4463e-01,
           -3.5767e-01,  1.0805e+00,  2.8995e-01],
          [ 1.1109e-02,  1.9507e-02, -5.8411e-03,  2.1758e-02, -3.4151e-02,
            7.7907e-03, -2.6146e-02,  3.5306e-02],
          [ 1.1109e-02,  1.9507e-02, -5.8411e-03,  2.1758e-02, -3.4151e-02,
            7.7907e-03, -2.6146e-02,  3.5306e-02]],
 
         [[-1.5341e-01,  2.5130e-01, -4.9714e-01,  3.2337e-01,  3.8418e-02,
           -3.8637e-01, -3.1271e-01,  1.1782e+00],
          [-6.3128e-01,  7.4822e-02, -6.0581e-01,  5.1159e-01,  1.4507e-01,
           -5.4399e-01, -8.3792e-01,  6.5109e-01],
          [ 2.0277e-02, -1.7902e-01, -3.0790e-04,  1.2179e-01, -3.9742e-01,
           -7.8828e-02, -1.6313e-01,  9.0789e-01],
          [-4.3891e-01, -7.5439e-01, -1.9608e-01, -1.2052e-01, -5.9269e-01,
            6.7124

In [55]:
# Reimplementation output from the same run; compare with original_output above.
output

{'embeddings': tensor([[[-0.9712,  0.2857, -0.1243,  0.7853, -0.3994, -0.3484, -0.2385,
            0.6714],
          [-0.0792,  0.2863,  0.3726,  1.2731, -0.3714, -0.1627, -0.2779,
            0.4744],
          [ 0.0170,  0.0078, -0.0107, -0.0303,  0.0333,  0.0278, -0.0318,
            0.0191],
          [ 0.0170,  0.0078, -0.0107, -0.0303,  0.0333,  0.0278, -0.0318,
            0.0191]],
 
         [[-0.8322, -0.1813, -0.2201,  0.7731, -0.4976,  0.2493,  0.2634,
           -0.4637],
          [-0.5629,  0.0897,  0.3572,  0.6093, -0.4855, -0.2258,  0.2375,
            0.2057],
          [-0.4416, -0.2267,  0.3529,  0.7161, -0.8860,  0.3869,  0.1543,
            0.1356],
          [ 0.2068,  0.2609,  0.2046,  0.6199, -1.5394, -0.3388,  0.2332,
            0.0019]]], grad_fn=<ViewBackward0>),
 'mask': tensor([[1, 1, 0, 0],
         [1, 1, 1, 1]])}