# Sentence-LUKE: Japanese sentence embeddings
* https://huggingface.co/sonoisa/sentence-luke-japanese-base-lite

In [1]:
%%capture
!pip install transformers
!pip install sentencepiece

In [2]:
from transformers import MLukeTokenizer, LukeModel
import torch

In [4]:
class SentenceLukeJapanese:
    """Sentence encoder built on a LUKE model with attention-masked mean pooling.

    Loads an ``MLukeTokenizer`` and a ``LukeModel`` from ``model_name_or_path``.
    It turns each sentence into one vector by averaging the first element of
    the model output (the token embeddings) over the non-padding positions.

    Args:
        model_name_or_path: Hub id or local path passed to ``from_pretrained``
            for both the tokenizer and the model.
        device: Torch device spec such as ``"cpu"`` or ``"cuda:0"``. When
            ``None``, uses ``"cuda"`` if CUDA is available, else ``"cpu"``.
    """

    def __init__(self, model_name_or_path, device=None):
        self.tokenizer = MLukeTokenizer.from_pretrained(model_name_or_path)
        self.model = LukeModel.from_pretrained(model_name_or_path)
        # Inference only: eval() switches off training-time behaviour such as dropout.
        self.model.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        # Move the model to the same resolved device that encode() sends inputs to.
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over positions where the mask is non-zero.

        Args:
            model_output: Model output whose first element holds the token
                embeddings (assumed ``(batch, seq_len, hidden)``).
            attention_mask: Mask broadcastable to the embeddings after
                ``unsqueeze(-1)``; presumably ``(batch, seq_len)`` with 1 for
                real tokens and 0 for padding.

        Returns:
            Float tensor of pooled embeddings, one row per batch entry.
        """
        token_embeddings = model_output[0]  # first element: all token embeddings
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp so an all-padding row yields zeros instead of dividing by zero.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    @torch.no_grad()
    def encode(self, sentences, batch_size=8):
        """Encode sentences into mean-pooled embeddings, on the CPU.

        Args:
            sentences: A single string (treated as one sentence) or any
                iterable of strings.
            batch_size: Number of sentences fed to the model per forward pass.
                Must be >= 1.

        Returns:
            A ``(len(sentences), hidden_size)`` tensor on the CPU. An empty
            input yields a ``(0, hidden_size)`` tensor.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        # A bare string would otherwise be sliced and encoded character by character.
        if isinstance(sentences, str):
            sentences = [sentences]
        else:
            sentences = list(sentences)

        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        if not sentences:
            # torch.cat/stack reject an empty list; return a well-shaped empty result.
            return torch.empty(0, self.model.config.hidden_size)

        batch_embeddings = []
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]

            encoded_input = self.tokenizer.batch_encode_plus(
                batch, padding="longest", truncation=True, return_tensors="pt"
            ).to(self.device)
            model_output = self.model(**encoded_input)
            pooled = self._mean_pooling(model_output, encoded_input["attention_mask"])
            # Move each batch off the accelerator as soon as it is computed.
            batch_embeddings.append(pooled.to("cpu"))

        return torch.cat(batch_embeddings, dim=0)


In [5]:
# Hugging Face hub id of the pretrained Japanese Sentence-LUKE model.
MODEL_NAME = "sonoisa/sentence-luke-japanese-base-lite"
model = SentenceLukeJapanese(model_name_or_path=MODEL_NAME)

# Demo input: one Japanese sentence ("a data scientist who uses a quantum computer").
sentences = [
    "量子コンピュータを使うデータサイエンティスト",
]
sentence_embeddings = model.encode(sentences=sentences, batch_size=8)
print("Sentence embeddings:", sentence_embeddings)

Sentence embeddings: tensor([[-5.0329e-01, -6.4833e-01,  2.9950e-01,  1.7696e-01, -1.0518e-01,
          1.4911e-01,  5.9331e-01,  1.1000e-01,  1.5425e-02,  1.0512e-01,
         -8.0088e-01, -2.1464e-01,  1.0730e-01, -3.3763e-01, -6.3525e-01,
          2.0462e-01,  1.7721e-01,  1.4496e-01, -1.3519e-01, -7.5960e-02,
          2.0178e-01,  4.8294e-02, -6.5613e-01, -3.7342e-01,  5.7285e-03,
         -2.4720e-01, -1.0586e+00, -2.8139e-01,  2.3205e-01, -3.7691e-01,
         -7.1272e-02,  2.4130e-01,  8.6934e-01, -4.7366e-01,  1.1664e-01,
         -1.6697e-01, -1.9240e-01,  4.6092e-01,  4.5507e-02, -4.7678e-01,
         -8.3307e-02, -5.0876e-01,  7.7227e-01,  1.8693e-01, -2.6294e-01,
          1.5385e-01,  2.7899e-01, -1.4505e-01,  4.4026e-01, -7.1868e-01,
          2.3138e-01, -1.5273e-01,  3.8154e-01,  1.1373e-01, -3.5916e-01,
          2.2289e-01,  3.7619e-01,  1.5584e-01,  8.4267e-02, -1.3700e-01,
         -1.2963e-01, -3.1573e-01, -1.4394e-01,  1.5690e-01, -4.9301e-01,
         -6.5670e