Support for instructur/instructor-xl models #125

Open · BBC-Esq opened this issue Mar 2, 2024 · 9 comments

BBC-Esq commented Mar 2, 2024

Can you please support the instructor models here?

https://github.com/xlang-ai/instructor-embedding

These are arguably the best models for their sizes.

michaelfeil (Owner) commented

Can you provide a minimal code example using just huggingface transformers and sentence-transformers?
What does it take to run forward passes with them?

The instructor models and the package are slightly outdated - I would consider accepting a PR if this is easy to maintain & does not come with extra dependencies.

BBC-Esq commented Mar 4, 2024

> Can you provide a minimal code example using just huggingface transformers and sentence-transformers? What does it take to run forward passes with them?
>
> The instructor models and the package are slightly outdated - I would consider accepting a PR if this is easy to maintain & does not come with extra dependencies.

I don't use huggingface transformers or sentence-transformers directly, but rather the "instructor" integration with Langchain. I can give specific examples of that if you want. At any rate, the source code clearly uses the transformers and sentence-transformers libraries.
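
For reference, my LangChain usage looks roughly like this (a sketch, not my exact code; depending on the LangChain version the class is imported from langchain.embeddings or langchain_community.embeddings):

# Sketch of the LangChain integration (class names/defaults may differ by version).
from langchain_community.embeddings import HuggingFaceInstructEmbeddings

embedder = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-xl",
    model_kwargs={"device": "cpu"},  # or "cuda"
    encode_kwargs={"normalize_embeddings": True},
    embed_instruction="Represent the document for retrieval:",
    query_instruction="Represent the question for retrieving supporting documents:",
)

doc_vectors = embedder.embed_documents(
    ["3D ActionSLAM: wearable person tracking in multi-floor environments"]
)
query_vector = embedder.embed_query("wearable person tracking")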

The wheel is here; open it with 7-Zip and look at instructor.py:

https://files.pythonhosted.org/packages/6c/fc/64375441f43cc9ddc81f76a1a8f516e6d63f5b6ecb67fffdcddc0445f0d3/InstructorEmbedding-1.0.1-py2.py3-none-any.whl

"Instructor-xl" is better than any of the sentence-transformers and/or huggingface models IMHO, although I haven't tried the recent ones mixing Mistral, etc. with embedding models...I'm referring to the traditional "gtr" etc.

Instructor also has instructions on their GitHub for quantizing the model; perhaps that's a viable alternative if we can't get it implemented with plain transformers/sentence-transformers/flash-attention2 and all that. Here's how they say to do it, but even after spending hours on it I've never managed to get it working.

To [Quantize](https://pytorch.org/docs/stable/quantization.html) the Instructor embedding model, run the following code:

# imports 
import torch
from InstructorEmbedding import INSTRUCTOR

# load the model 
model = INSTRUCTOR('hkunlp/instructor-large', device='cpu')  # you can use GPU

# quantize the model 
qmodel = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# Inference 
sentence = "3D ActionSLAM: wearable person tracking in multi-floor environments"
instruction = "Represent the Science title:"

embeddings = qmodel.encode([[instruction, sentence]])
# you can also normalize the embeddings:  normalize_embeddings=True 

print(f"Quantized Embeddings:\n {embeddings}")

Taken from: https://github.com/xlang-ai/instructor-embedding?tab=readme-ov-file#quantization

I would actually pay someone to convert instructor (all three sizes) to ctranslate2; that's how much respect I have for it.

With that being said, I love your repository and appreciate your work on hf-hub-translate before that. Here's where I'm contemplating using all ctranslate2 embedding models in my program:

BBC-Esq/VectorDB-Plugin-for-LM-Studio#143

It's fun to include multiple models, all of which can be seen here:

https://github.com/BBC-Esq/ChromaDB-Plugin-for-LM-Studio/blob/main/src/constants.py

If I do move to Infinity in the near future, it would be lopsided to have every other embedding model running efficiently with ctranslate2 except instructor...

michaelfeil changed the title from "PLEASE support instructur/instructor-xl models" to "Support for instructur/instructor-xl models" on Mar 7, 2024

BBC-Esq commented Mar 12, 2024

I just tried to do this PR but I don't really know what I'm doing; please help. I noticed that you've written converters before.

OpenNMT/CTranslate2#1639

BBC-Esq commented Mar 18, 2024

BTW, I think that sentence-transformers just started supporting instructor models recently!

https://www.sbert.net/docs/pretrained_models.html#instructor-models

BBC-Esq commented Mar 18, 2024

But "reportedly" they're only supported up to sentence-transformers 2.2.2, but I haven't had a chance to test this yet...

UKPLab/sentence-transformers#2474 (comment)
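
If anyone wants to sanity-check their environment, here's a minimal version check (a sketch, assuming the 2.2.2 cutoff reported in that thread holds):

# Sketch: warn if the installed sentence-transformers is newer than the version
# InstructorEmbedding reportedly supports (<= 2.2.2).
import sentence_transformers
from packaging import version

installed = version.parse(sentence_transformers.__version__)
if installed > version.parse("2.2.2"):
    print(f"sentence-transformers {installed} installed; "
          "InstructorEmbedding may fail to load - consider pinning to 2.2.2")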

dharless-eli-lilly commented

I'm facing this issue now. Is there any way to work around it when using instructor from Langchain?

BBC-Esq commented May 9, 2024

I finally resolved it myself but haven't had a chance to upload the script to any fork of any repo.

BBC-Esq commented May 9, 2024

Here's the modified script. Let me know if it works for you, as I'd appreciate the shout-out. ;-)

Name the script "instructor.py" and use it to replace the instructor.py file in the pip-installed package.
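
If you're not sure where the pip-installed copy lives, something like this prints the path to overwrite (a sketch, assuming the standard package layout):

# Sketch: locate the instructor.py that `pip install InstructorEmbedding` put on disk.
from InstructorEmbedding import instructor

print(instructor.__file__)  # replace this file with the revised script below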

Revised instructor.py script:
# This script is based on the modifications from https://github.com/UKPLab/sentence-transformers
import importlib
import json
import os
from collections import OrderedDict
from typing import Union

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer
from torch import Tensor, nn
from tqdm.autonotebook import trange
from transformers import AutoConfig, AutoTokenizer
from sentence_transformers.util import disabled_tqdm
from huggingface_hub import snapshot_download


def batch_to_device(batch, target_device: str):
    for key in batch:
        if isinstance(batch[key], Tensor):
            batch[key] = batch[key].to(target_device)
    return batch


class INSTRUCTORPooling(nn.Module):
    """Performs pooling (max or mean) on the token embeddings.

    Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
    This layer also allows to use the CLS token if it is returned by the underlying word embedding model.
    You can concatenate multiple poolings together.

    :param word_embedding_dimension: Dimensions for the word embeddings
    :param pooling_mode: Can be a string: mean/max/cls. If set, overwrites the other pooling_mode_* settings
    :param pooling_mode_cls_token: Use the first token (CLS token) as text representations
    :param pooling_mode_max_tokens: Use max in each dimension over all tokens.
    :param pooling_mode_mean_tokens: Perform mean-pooling
    :param pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but divide by sqrt(input_length).
    :param pooling_mode_weightedmean_tokens: Perform (position) weighted mean pooling,
    see https://arxiv.org/abs/2202.08904
    :param pooling_mode_lasttoken: Perform last token pooling,
    see https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005
    """

    def __init__(
        self,
        word_embedding_dimension: int,
        pooling_mode: Union[str, None] = None,
        pooling_mode_cls_token: bool = False,
        pooling_mode_max_tokens: bool = False,
        pooling_mode_mean_tokens: bool = True,
        pooling_mode_mean_sqrt_len_tokens: bool = False,
        pooling_mode_weightedmean_tokens: bool = False,
        pooling_mode_lasttoken: bool = False,
    ):
        super().__init__()

        self.config_keys = [
            "word_embedding_dimension",
            "pooling_mode_cls_token",
            "pooling_mode_mean_tokens",
            "pooling_mode_max_tokens",
            "pooling_mode_mean_sqrt_len_tokens",
            "pooling_mode_weightedmean_tokens",
            "pooling_mode_lasttoken",
        ]

        if pooling_mode is not None:  # Set pooling mode by string
            pooling_mode = pooling_mode.lower()
            assert pooling_mode in ["mean", "max", "cls", "weightedmean", "lasttoken"]
            pooling_mode_cls_token = pooling_mode == "cls"
            pooling_mode_max_tokens = pooling_mode == "max"
            pooling_mode_mean_tokens = pooling_mode == "mean"
            pooling_mode_weightedmean_tokens = pooling_mode == "weightedmean"
            pooling_mode_lasttoken = pooling_mode == "lasttoken"

        self.word_embedding_dimension = word_embedding_dimension
        self.pooling_mode_cls_token = pooling_mode_cls_token
        self.pooling_mode_mean_tokens = pooling_mode_mean_tokens
        self.pooling_mode_max_tokens = pooling_mode_max_tokens
        self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens
        self.pooling_mode_weightedmean_tokens = pooling_mode_weightedmean_tokens
        self.pooling_mode_lasttoken = pooling_mode_lasttoken

        pooling_mode_multiplier = sum(
            [
                pooling_mode_cls_token,
                pooling_mode_max_tokens,
                pooling_mode_mean_tokens,
                pooling_mode_mean_sqrt_len_tokens,
                pooling_mode_weightedmean_tokens,
                pooling_mode_lasttoken,
            ]
        )
        self.pooling_output_dimension = (
            pooling_mode_multiplier * word_embedding_dimension
        )

    def __repr__(self):
        return f"Pooling({self.get_config_dict()})"

    def get_pooling_mode_str(self) -> str:
        """
        Returns the pooling mode as string
        """
        modes = []
        if self.pooling_mode_cls_token:
            modes.append("cls")
        if self.pooling_mode_mean_tokens:
            modes.append("mean")
        if self.pooling_mode_max_tokens:
            modes.append("max")
        if self.pooling_mode_mean_sqrt_len_tokens:
            modes.append("mean_sqrt_len_tokens")
        if self.pooling_mode_weightedmean_tokens:
            modes.append("weightedmean")
        if self.pooling_mode_lasttoken:
            modes.append("lasttoken")

        return "+".join(modes)

    def forward(self, features):
        # print(features.keys())
        token_embeddings = features["token_embeddings"]
        attention_mask = features["attention_mask"]

        ## Pooling strategy
        output_vectors = []
        if self.pooling_mode_cls_token:
            cls_token = features.get(
                "cls_token_embeddings", token_embeddings[:, 0]
            )  # Take first token by default
            output_vectors.append(cls_token)
        if self.pooling_mode_max_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            )
            token_embeddings[
                input_mask_expanded == 0
            ] = -1e9  # Set padding tokens to large negative value
            max_over_time = torch.max(token_embeddings, 1)[0]
            output_vectors.append(max_over_time)
        if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            )
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if "token_weights_sum" in features:
                sum_mask = (
                    features["token_weights_sum"]
                    .unsqueeze(-1)
                    .expand(sum_embeddings.size())
                )
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)

            if self.pooling_mode_mean_tokens:
                output_vectors.append(sum_embeddings / sum_mask)
            if self.pooling_mode_mean_sqrt_len_tokens:
                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))
        if self.pooling_mode_weightedmean_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            )
            # token_embeddings shape: bs, seq, hidden_dim
            weights = (
                torch.arange(start=1, end=token_embeddings.shape[1] + 1)
                .unsqueeze(0)
                .unsqueeze(-1)
                .expand(token_embeddings.size())
                .float()
                .to(token_embeddings.device)
            )
            assert weights.shape == token_embeddings.shape == input_mask_expanded.shape
            input_mask_expanded = input_mask_expanded * weights

            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if "token_weights_sum" in features:
                sum_mask = (
                    features["token_weights_sum"]
                    .unsqueeze(-1)
                    .expand(sum_embeddings.size())
                )
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)
            output_vectors.append(sum_embeddings / sum_mask)
        if self.pooling_mode_lasttoken:
            batch_size, _, hidden_dim = token_embeddings.shape
            # attention_mask shape: (bs, seq_len)
            # Get shape [bs] indices of the last token (i.e. the last token for each batch item)
            # argmin gives us the index of the first 0 in the attention mask;
            # We get the last 1 index by subtracting 1
            gather_indices = (
                torch.argmin(attention_mask, 1, keepdim=False) - 1
            )  # Shape [bs]

            # There are empty sequences, where the index would become -1 which will crash
            gather_indices = torch.clamp(gather_indices, min=0)

            # Turn indices from shape [bs] --> [bs, 1, hidden_dim]
            gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim)
            gather_indices = gather_indices.unsqueeze(1)
            assert gather_indices.shape == (batch_size, 1, hidden_dim)

            # Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim)
            # Actually no need for the attention mask as we gather the last token where attn_mask = 1
            # but as we set some indices (which shouldn't be attended to) to 0 with clamp, we
            # use the attention mask to ignore them again
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            )
            embedding = torch.gather(
                token_embeddings * input_mask_expanded, 1, gather_indices
            ).squeeze(dim=1)
            output_vectors.append(embedding)

        output_vector = torch.cat(output_vectors, 1)
        features.update({"sentence_embedding": output_vector})
        return features

    def get_sentence_embedding_dimension(self):
        return self.pooling_output_dimension

    def get_config_dict(self):
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path):
        with open(
            os.path.join(output_path, "config.json"), "w", encoding="UTF-8"
        ) as config_file:
            json.dump(self.get_config_dict(), config_file, indent=2)

    @staticmethod
    def load(input_path):
        with open(
            os.path.join(input_path, "config.json"), encoding="UTF-8"
        ) as config_file:
            config = json.load(config_file)

        return INSTRUCTORPooling(**config)


def import_from_string(dotted_path):
    """
    Import a dotted module path and return the attribute/class designated by the
    last name in the path. Raise ImportError if the import failed.
    """
    try:
        module_path, class_name = dotted_path.rsplit(".", 1)
    except ValueError:
        msg = f"{dotted_path} doesn't look like a module path"
        raise ImportError(msg)

    try:
        module = importlib.import_module(dotted_path)
    except:
        module = importlib.import_module(module_path)

    try:
        return getattr(module, class_name)
    except AttributeError:
        msg = f"Module {module_path} does not define a {class_name} attribute/class"
        raise ImportError(msg)


class INSTRUCTORTransformer(Transformer):
    def __init__(
        self,
        model_name_or_path: str,
        max_seq_length=None,
        model_args=None,
        cache_dir=None,
        tokenizer_args=None,
        do_lower_case: bool = False,
        tokenizer_name_or_path: Union[str, None] = None,
        load_model: bool = True,
    ):
        super().__init__(model_name_or_path)
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        self.config_keys = ["max_seq_length", "do_lower_case"]
        self.do_lower_case = do_lower_case
        self.model_name_or_path = model_name_or_path
        if model_name_or_path == "bi-contriever":
            model_name_or_path = "facebook/contriever"
        if model_name_or_path.startswith("bigtr"):
            model_name_or_path = model_name_or_path.split("#")[1]
        if "bigtr" in model_name_or_path and os.path.isdir(model_name_or_path):
            config = AutoConfig.from_pretrained(
                os.path.join(model_name_or_path, "with_prompt"),
                **model_args,
                cache_dir=cache_dir,
            )
        else:
            config = AutoConfig.from_pretrained(
                model_name_or_path, **model_args, cache_dir=cache_dir
            )

        if load_model:
            self._load_model(self.model_name_or_path, config, cache_dir, **model_args)
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path
            if tokenizer_name_or_path is not None
            else model_name_or_path,
            cache_dir=cache_dir,
            **tokenizer_args,
        )

        if max_seq_length is None:
            if (
                hasattr(self.auto_model, "config")
                and hasattr(self.auto_model.config, "max_position_embeddings")
                and hasattr(self.tokenizer, "model_max_length")
            ):
                max_seq_length = min(
                    self.auto_model.config.max_position_embeddings,
                    self.tokenizer.model_max_length,
                )

        self.max_seq_length = max_seq_length
        if tokenizer_name_or_path is not None:
            self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

    def forward(self, features):
        input_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
        }
        if "token_type_ids" in features:
            input_features["token_type_ids"] = features["token_type_ids"]

        output_states = self.auto_model(**input_features, return_dict=False)
        output_tokens = output_states[0]
        # Zero out the attention mask over the instruction tokens so that the
        # pooling layer only considers the actual input text.
        instruction_mask = features["instruction_mask"]
        attention_mask = features["attention_mask"] * instruction_mask
        features.update(
            {"token_embeddings": output_tokens, "attention_mask": attention_mask}
        )

        if self.auto_model.config.output_hidden_states:
            all_layer_idx = 2
            if (
                len(output_states) < 3
            ):  # Some models only output last_hidden_states and all_hidden_states
                all_layer_idx = 1
            hidden_states = output_states[all_layer_idx]
            features.update({"all_layer_embeddings": hidden_states})

        return features

    @staticmethod
    def load(input_path: str):
        # Old classes used other config names than 'sentence_bert_config.json'
        for config_name in [
            "sentence_bert_config.json",
            "sentence_roberta_config.json",
            "sentence_distilbert_config.json",
            "sentence_camembert_config.json",
            "sentence_albert_config.json",
            "sentence_xlm-roberta_config.json",
            "sentence_xlnet_config.json",
        ]:
            sbert_config_path = os.path.join(input_path, config_name)
            if os.path.exists(sbert_config_path):
                break

        with open(sbert_config_path, encoding="UTF-8") as config_file:
            config = json.load(config_file)
        return INSTRUCTORTransformer(model_name_or_path=input_path, **config)

    def tokenize(self, texts):
        """
        Tokenizes a text and maps tokens to token-ids
        """
        output = {}
        if isinstance(texts[0], str):
            to_tokenize = [texts]
            to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]

            # Lowercase
            if self.do_lower_case:
                to_tokenize = [[s.lower() for s in col] for col in to_tokenize]

            input_features = self.tokenizer(
                *to_tokenize,
                padding="max_length",
                truncation="longest_first",
                return_tensors="pt",
                max_length=self.max_seq_length,
            )

        elif isinstance(texts[0], list):
            assert isinstance(texts[0][1], str)
            assert (
                len(texts[0]) == 2
            ), "The input should have both instruction and input text"

            instructions = []
            instruction_prepended_input_texts = []
            for pair in texts:
                instruction = pair[0].strip()
                text = pair[1].strip()
                if self.do_lower_case:
                    instruction = instruction.lower()
                    text = text.lower()
                instructions.append(instruction)
                instruction_prepended_input_texts.append("".join([instruction, text]))

            input_features = self.tokenize(instruction_prepended_input_texts)
            instruction_features = self.tokenize(instructions)
            input_features = INSTRUCTOR.prepare_input_features(
                input_features, instruction_features
            )
        else:
            raise ValueError("not support other modes")

        output.update(input_features)
        return output


class INSTRUCTOR(SentenceTransformer):
    @staticmethod
    def prepare_input_features(
        input_features, instruction_features, return_data_type: str = "pt"
    ):
        if return_data_type == "np":
            input_features["attention_mask"] = torch.from_numpy(
                input_features["attention_mask"]
            )
            instruction_features["attention_mask"] = torch.from_numpy(
                instruction_features["attention_mask"]
            )

        input_attention_mask_shape = input_features["attention_mask"].shape
        instruction_attention_mask = instruction_features["attention_mask"]

        # reducing the attention length by 1 in order to omit the attention corresponding to the end_token
        instruction_attention_mask = instruction_attention_mask[:, 1:]

        # creating instruction attention matrix equivalent to the size of the input attention matrix
        expanded_instruction_attention_mask = torch.zeros(
            input_attention_mask_shape, dtype=torch.int64
        )
        # assigning the the actual instruction attention matrix to the expanded_instruction_attention_mask
        # eg:
        # instruction_attention_mask: 3x3
        #  [[1,1,1],
        #   [1,1,0],
        #   [1,0,0]]
        # expanded_instruction_attention_mask: 3x4
        #  [[1,1,1,0],
        #   [1,1,0,0],
        #   [1,0,0,0]]
        expanded_instruction_attention_mask[
            : instruction_attention_mask.size(0), : instruction_attention_mask.size(1)
        ] = instruction_attention_mask

        # In the pooling layer we want to consider only the tokens corresponding to the input text
        # and not the instruction. This is achieved by inverting the
        # attention_mask corresponding to the instruction.
        expanded_instruction_attention_mask = 1 - expanded_instruction_attention_mask
        input_features["instruction_mask"] = expanded_instruction_attention_mask
        if return_data_type == "np":
            input_features["attention_mask"] = input_features["attention_mask"].numpy()
            instruction_features["attention_mask"] = instruction_features[
                "attention_mask"
            ].numpy()
        return input_features

    def smart_batching_collate(self, batch):
        num_texts = len(batch[0].texts)
        texts = [[] for _ in range(num_texts)]
        labels = []

        for example in batch:
            for idx, text in enumerate(example.texts):
                texts[idx].append(text)
            labels.append(example.label)

        labels = torch.tensor(labels)
        batched_input_features = []

        for idx in range(num_texts):
            assert isinstance(texts[idx][0], list)
            assert (
                len(texts[idx][0]) == 2
            ), "The input should have both instruction and input text"

            num = len(texts[idx])
            instructions = []
            instruction_prepended_input_texts = []
            for local_idx in range(num):
                assert len(texts[idx][local_idx]) == 2
                instructions.append(texts[idx][local_idx][0])
                instruction_prepended_input_texts.append("".join(texts[idx][local_idx]))
                assert isinstance(instructions[-1], str)
                assert isinstance(instruction_prepended_input_texts[-1], str)

            input_features = self.tokenize(instruction_prepended_input_texts)
            instruction_features = self.tokenize(instructions)
            input_features = INSTRUCTOR.prepare_input_features(
                input_features, instruction_features
            )
            batched_input_features.append(input_features)

        return batched_input_features, labels

    def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False):
        """
        Loads a full sentence-transformers model
        """
        if os.path.isdir(model_path):
            # If model_path is a local directory, load the model directly
            model_path = str(model_path)
        else:
            # If model_path is a Hugging Face repository ID, download the model
            download_kwargs = {
                "repo_id": model_path,
                "revision": revision,
                "library_name": "sentence-transformers",
                "token": token,
                "cache_dir": cache_folder,
                "tqdm_class": disabled_tqdm,
            }
            model_path = snapshot_download(**download_kwargs)

        # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
        config_sentence_transformers_json_path = os.path.join(
            model_path, "config_sentence_transformers.json"
        )
        if os.path.exists(config_sentence_transformers_json_path):
            with open(
                config_sentence_transformers_json_path, encoding="UTF-8"
            ) as config_file:
                self._model_config = json.load(config_file)

        # Check if a readme exists
        model_card_path = os.path.join(model_path, "README.md")
        if os.path.exists(model_card_path):
            try:
                with open(model_card_path, encoding="utf8") as config_file:
                    self._model_card_text = config_file.read()
            except:
                pass

        # Load the modules of sentence transformer
        modules_json_path = os.path.join(model_path, "modules.json")
        with open(modules_json_path, encoding="UTF-8") as config_file:
            modules_config = json.load(config_file)

        modules = OrderedDict()
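        # Modules 0 and 1 (the transformer backbone and the pooling layer) are
        # hard-wired to the custom INSTRUCTOR classes above; any remaining
        # modules (e.g. the Dense projection) are loaded via their declared type.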
        for module_config in modules_config:
            if module_config["idx"] == 0:
                module_class = INSTRUCTORTransformer
            elif module_config["idx"] == 1:
                module_class = INSTRUCTORPooling
            else:
                module_class = import_from_string(module_config["type"])
            module = module_class.load(os.path.join(model_path, module_config["path"]))
            modules[module_config["name"]] = module

        return modules

    def encode(
        self,
        sentences,
        batch_size: int = 32,
        show_progress_bar: Union[bool, None] = None,
        output_value: str = "sentence_embedding",
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        device: Union[str, None] = None,
        normalize_embeddings: bool = False,
    ):
        """
        Computes sentence embeddings

        :param sentences: the sentences to embed
        :param batch_size: the batch size used for the computation
        :param show_progress_bar: Output a progress bar when encode sentences
        :param output_value:  Default sentence_embedding, to get sentence embeddings.
        Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
        :param convert_to_numpy: If true, the output is a list of numpy vectors.
        Else, it is a list of pytorch tensors.
        :param convert_to_tensor: If true, you get one large tensor as return.
        Overwrites any setting from convert_to_numpy
        :param device: Which torch.device to use for the computation
        :param normalize_embeddings: If set to true, returned vectors will have length 1.
        In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.

        :return:
           By default, a list of tensors is returned. If convert_to_tensor,
           a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
        """
        self.eval()
        if show_progress_bar is None:
            show_progress_bar = False

        if convert_to_tensor:
            convert_to_numpy = False

        if output_value != "sentence_embedding":
            convert_to_tensor = False
            convert_to_numpy = False

        input_was_string = False
        if isinstance(sentences, str) or not hasattr(
            sentences, "__len__"
        ):  # Cast an individual sentence to a list with length 1
            sentences = [sentences]
            input_was_string = True

        if device is None:
            device = self._target_device

        self.to(device)

        all_embeddings = []
        if isinstance(sentences[0], list):
            lengths = []
            for sen in sentences:
                lengths.append(-self._text_length(sen[1]))
            length_sorted_idx = np.argsort(lengths)
        else:
            length_sorted_idx = np.argsort(
                [-self._text_length(sen) for sen in sentences]
            )
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        for start_index in trange(
            0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar
        ):
            sentences_batch = sentences_sorted[start_index : start_index + batch_size]
            features = self.tokenize(sentences_batch)
            features = batch_to_device(features, device)

            with torch.no_grad():
                out_features = self.forward(features)

                if output_value == "token_embeddings":
                    embeddings = []
                    for token_emb, attention in zip(
                        out_features[output_value], out_features["attention_mask"]
                    ):
                        last_mask_id = len(attention) - 1
                        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
                            last_mask_id -= 1

                        embeddings.append(token_emb[0 : last_mask_id + 1])
                elif output_value is None:  # Return all outputs
                    embeddings = []
                    for sent_idx in range(len(out_features["sentence_embedding"])):
                        row = {
                            name: out_features[name][sent_idx] for name in out_features
                        }
                        embeddings.append(row)
                else:  # Sentence embeddings
                    embeddings = out_features[output_value]
                    embeddings = embeddings.detach()
                    if normalize_embeddings:
                        embeddings = torch.nn.functional.normalize(
                            embeddings, p=2, dim=1
                        )

                    # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
                    if convert_to_numpy:
                        embeddings = embeddings.cpu()

                all_embeddings.extend(embeddings)

        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        if input_was_string:
            all_embeddings = all_embeddings[0]

        return all_embeddings
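
Usage should be unchanged from the stock InstructorEmbedding API, e.g. (a sketch):

# Sketch: the revised script keeps the same public API as the original package.
from InstructorEmbedding import INSTRUCTOR

model = INSTRUCTOR("hkunlp/instructor-xl", device="cpu")
embeddings = model.encode(
    [["Represent the Science title:",
      "3D ActionSLAM: wearable person tracking in multi-floor environments"]],
    normalize_embeddings=True,
)
print(embeddings.shape)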

dharless-eli-lilly commented

Awesome, thank you, will check this out.
