In [1]:
# default_exp embeddings

# Embeddings
> AdaptNLP Embeddings Module

In [2]:
#hide
from nbdev.showdoc import *
from fastcore.test import test_eq
from fastcore.xtras import is_listy

In [3]:
#export
import logging
from typing import List, Dict, Union
from collections import defaultdict

from fastcore.dispatch import typedispatch
from flair.data import Sentence
from flair.embeddings import (
    Embeddings,
    WordEmbeddings,
    StackedEmbeddings,
    FlairEmbeddings,
    DocumentPoolEmbeddings,
    DocumentRNNEmbeddings,
    TransformerWordEmbeddings,
)

from adaptnlp.model_hub import FlairModelHub, HFModelHub, FlairModelResult, HFModelResult

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
#export
_flair_hub = FlairModelHub()
_hf_hub = HFModelHub()

In [5]:
#export
logger = logging.getLogger(__name__)

In [6]:
#exporti
@typedispatch
def _make_sentences(text:str, as_list=False) -> Union[List[Sentence], Sentence]:
    return [Sentence(text)] if as_list else Sentence(text)

In [7]:
#hide
test_sentences = 'a,b,c'.split(',')
out = _make_sentences(test_sentences)
tst_out = [Sentence('a'), Sentence('b'), Sentence('c')]
for o,t in zip(out, tst_out):
    test_eq(o[0].text, t[0].text)

In [8]:
#exporti
@typedispatch
def _make_sentences(text:list, as_list=False) -> Union[List[Sentence], Sentence]:
    if all(isinstance(t,str) for t in text):
        return [Sentence(t) for t in text]
    elif all(isinstance(t, Sentence) for t in text):
        return text

In [9]:
#hide
test_sentence = 'My name is Zach'
out = _make_sentences(test_sentence, as_list=True)
tst_out = [Sentence(test_sentence)]
for o,t in zip(out, tst_out):
    test_eq(o[0].text, t[0].text)
test_eq(is_listy(out), True)

In [10]:
#exporti
@typedispatch
def _make_sentences(text:Sentence, as_list=False) -> Union[List[Sentence], Sentence]:
    return [text] if as_list else text

In [11]:
#hide
test_sentence = Sentence('Me')
out = _make_sentences(test_sentence)
test_eq(test_sentence[0].text, out[0].text)

In [12]:
#exporti
def _get_embedding_model(model_name_or_path:Union[str, HFModelResult, FlairModelResult]) -> Union[FlairEmbeddings, WordEmbeddings, TransformerWordEmbeddings, Sentence]:
    "Load the proper `Embeddings` model from `model_name_or_path`"
    if isinstance(model_name_or_path, FlairModelResult): 
        nm = model_name_or_path.name
        try:
            return WordEmbeddings(nm.strip('flairNLP/'))
        except:
            return FlairEmbeddings(nm.strip('flairNLP/'))
    
    elif isinstance(model_name_or_path, HFModelResult): return TransformerWordEmbeddings(model_name_or_path.name)
    else:
        res = _flair_hub.search_model_by_name(model_name_or_path, user_uploaded=True)
        if len(res) < 1:
            # No models found
            res = _hf_hub.search_model_by_name(model_name_or_path, user_uploaded=True)
            if len(res) < 1:
                raise ValueError(f'Embeddings not found for the model key: {model_name_or_path}, check documentation or custom model path to verify specified model')
            else:
                return TransformerWordEmbeddings(res[0].name) # Returning the first should always be the non-fast option
        else:
            nm = res[0].name
            try:
                return WordEmbeddings(nm.strip('flairNLP/'))
            except:
                return FlairEmbeddings(nm.strip('flairNLP/'))

In [13]:
#export
class EasyWordEmbeddings:
    "Word embeddings from the latest language models"

    def __init__(self):
        self.models: Dict[Embeddings] = defaultdict(bool)

    def embed_text(
        self,
        text: Union[List[Sentence], Sentence, List[str], str], # Text input, it can be a string or any of Flair's `Sentence` input formats
        model_name_or_path: Union[str, HFModelResult, FlairModelResult] = "bert-base-cased", # The hosted model name key, model path, or an instance of either `HFModelResult` or `FlairModelResult`
    ) -> List[Sentence]:
        """Produces embeddings for text
        
        **Return**:
        * A list of Flair's `Sentence`s
        """
        # Convert into sentences
        sentences = _make_sentences(text)

        # Load correct Embeddings module
        self.models[model_name_or_path] = _get_embedding_model(model_name_or_path)
        embedding = self.models[model_name_or_path]
        return embedding.embed(sentences)

    def embed_all(
        self,
        text: Union[List[Sentence], Sentence, List[str], str], # Text input, it can be a string or any of Flair's `Sentence` input formats
        *model_names_or_paths: str, # A variable input of model names or paths to embed
    ) -> List[Sentence]:
        """Embeds text with all embedding models loaded

        **Return**:
        * A list of Flair's `Sentence`s
        """
        # Convert into sentences
        sentences = _make_sentences(text)

        if model_names_or_paths:
            for embedding_name in model_names_or_paths:
                sentences = self.embed_text(
                    sentences, model_name_or_path=embedding_name
                )
        else:
            for embedding_name in self.models.keys():
                sentences = self.embed_text(
                    sentences, model_name_or_path=embedding_name
                )
        return sentences

In [14]:
#hide
import torch
embeddings = EasyWordEmbeddings()
res = embeddings.embed_text("text you want embeddings for", model_name_or_path="bert-base-cased")
test_eq(res[0][1].get_embedding().shape, torch.Size([768]))

Some weights of the model checkpoint at bert-base-cased-finetuned-mrpc were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
show_doc(EasyWordEmbeddings.embed_text)

<h4 id="EasyWordEmbeddings.embed_text" class="doc_header"><code>EasyWordEmbeddings.embed_text</code><a href="__main__.py#L8" class="source_link" style="float:right">[source]</a></h4>

> <code>EasyWordEmbeddings.embed_text</code>(**`text`**:`Union`\[`List`\[`Sentence`\], `Sentence`, `List`\[`str`\], `str`\], **`model_name_or_path`**:`Union`\[`str`, [`HFModelResult`](/adaptnlp/model_hub.html#HFModelResult), [`FlairModelResult`](/adaptnlp/model_hub.html#FlairModelResult)\]=*`'bert-base-cased'`*)

Produces embeddings for text

**Return**:
* A list of Flair's `Sentence`s

In [16]:
show_doc(EasyWordEmbeddings.embed_all)

<h4 id="EasyWordEmbeddings.embed_all" class="doc_header"><code>EasyWordEmbeddings.embed_all</code><a href="__main__.py#L26" class="source_link" style="float:right">[source]</a></h4>

> <code>EasyWordEmbeddings.embed_all</code>(**`text`**:`Union`\[`List`\[`Sentence`\], `Sentence`, `List`\[`str`\], `str`\], **\*`model_names_or_paths`**:`str`)

Embeds text with all embedding models loaded

**Return**:
* A list of Flair's `Sentence`s

In [17]:
#export
class EasyStackedEmbeddings:
    """Word Embeddings that have been concatenated and "stacked" as specified by flair

    Usage:

    ```python
    >>> embeddings = adaptnlp.EasyStackedEmbeddings("bert-base-cased", "gpt2", "xlnet-base-cased")
    ```

    **Parameters:**

    * `&ast;embeddings` - Non-keyword variable number of strings specifying the embeddings you want to stack
    """

    def __init__(self, *embeddings: str):
        print("May need a couple moments to instantiate...")
        self.embedding_stack = []

        # Load correct Embeddings module
        for model_name_or_path in embeddings:
            self.embedding_stack.append(_get_embedding_model(model_name_or_path))

        assert len(self.embedding_stack) != 0
        self.stacked_embeddings = StackedEmbeddings(embeddings=self.embedding_stack)

    def embed_text(
        self,
        text: Union[List[Sentence], Sentence, List[str], str],
    ) -> List[Sentence]:
        """Stacked embeddings

        **Parameters**:
        * `text` - Text input, it can be a string or any of Flair's `Sentence` input formats

        **Return**:
        * A list of Flair's `Sentence`s
        """
        # Convert into sentences
        sentences = _make_sentences(text, as_list=True)

        # Unlike flair embeddings modules, stacked embeddings do not return a list of sentences
        self.stacked_embeddings.embed(sentences)
        return sentences

In [18]:
#hide
embeddings = EasyStackedEmbeddings("bert-base-cased", "xlnet-base-cased")
sentences = embeddings.embed_text("This is Albert.  My last name is Einstein.  I like physics and atoms.")
test_eq(sentences[0][0].get_embedding().shape, torch.Size([1536]))

May need a couple moments to instantiate...


Some weights of the model checkpoint at bert-base-cased-finetuned-mrpc were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetModel: ['lm_loss.bias', 'lm_loss.weight']
- This IS expected if you are initializing XLNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if y

In [19]:
show_doc(EasyStackedEmbeddings.embed_text)

<h4 id="EasyStackedEmbeddings.embed_text" class="doc_header"><code>EasyStackedEmbeddings.embed_text</code><a href="__main__.py#L27" class="source_link" style="float:right">[source]</a></h4>

> <code>EasyStackedEmbeddings.embed_text</code>(**`text`**:`Union`\[`List`\[`Sentence`\], `Sentence`, `List`\[`str`\], `str`\])

Stacked embeddings

**Parameters**:
* `text` - Text input, it can be a string or any of Flair's `Sentence` input formats

**Return**:
* A list of Flair's `Sentence`s

In [20]:
#export
class EasyDocumentEmbeddings:
    "Document Embeddings generated by pool and rnn methods applied to the word embeddings of text"

    __allowed_methods = ["rnn", "pool"]
    __allowed_configs = ("pool_configs", "rnn_configs")

    def __init__(
        self,
        *embeddings: str, # Non-keyword variable number of strings referring to model names or paths
        methods: List[str] = ["rnn", "pool"], # A list of strings to specify which document embeddings to use i.e. ["rnn", "pool"] (avoids unncessary loading of models if only using one)
        configs: Dict = {
            "pool_configs": {"fine_tune_mode": "linear", "pooling": "mean"},
            "rnn_configs": {
                "hidden_size": 512,
                "rnn_layers": 1,
                "reproject_words": True,
                "reproject_words_dimension": 256,
                "bidirectional": False,
                "dropout": 0.5,
                "word_dropout": 0.0,
                "locked_dropout": 0.0,
                "rnn_type": "GRU",
                "fine_tune": True,
            },
        },
    ):
        print("May need a couple moments to instantiate...")
        self.embedding_stack = []

        # Check methods
        for m in methods:
            assert m in self.__class__.__allowed_methods

        # Set configs for pooling and rnn parameters
        for k, v in configs.items():
            assert k in self.__class__.__allowed_configs
            setattr(self, k, v)

        # Load correct Embeddings module
        for model_name_or_path in embeddings:
            self.embedding_stack.append(_get_embedding_model(model_name_or_path))

        assert len(self.embedding_stack) != 0
        if "pool" in methods:
            self.pool_embeddings = DocumentPoolEmbeddings(
                self.embedding_stack, **self.pool_configs
            )
            print("Pooled embedding loaded")
        if "rnn" in methods:
            self.rnn_embeddings = DocumentRNNEmbeddings(
                self.embedding_stack, **self.rnn_configs
            )
            print("RNN embeddings loaded")

    def embed_pool(
        self,
        text: Union[List[Sentence], Sentence, List[str], str], # Text input, it can be a string or any of Flair's `Sentence` input formats
    ) -> List[Sentence]:
        """Generate stacked embeddings with `DocumentPoolEmbeddings`

        **Return**:
        * A list of Flair's `Sentence`s
        """
        sentences = _make_sentences(text, as_list=True)
        self.pool_embeddings.embed(sentences)
        return sentences

    def embed_rnn(
        self,
        text: Union[List[Sentence], Sentence, List[str], str], # Text input, it can be a string or any of Flair's `Sentence` input formats
    ) -> List[Sentence]: 
        """Generate stacked embeddings with `DocumentRNNEmbeddings`

        **Return**:
        * A list of Flair's `Sentence`s
        """
        sentences = _make_sentences(text, as_list=True)
        self.rnn_embeddings.embed(sentences)
        return sentences

In [21]:
#hide
embeddings = EasyDocumentEmbeddings("bert-base-cased", "xlnet-base-cased")
text = embeddings.embed_pool("This is Albert.  My last name is Einstein.  I like physics and atoms.")
test_eq(text[0].get_embedding().shape, torch.Size([1536]))

May need a couple moments to instantiate...


Some weights of the model checkpoint at bert-base-cased-finetuned-mrpc were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetModel: ['lm_loss.bias', 'lm_loss.weight']
- This IS expected if you are initializing XLNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if y

Pooled embedding loaded
RNN embeddings loaded


In [22]:
#hide
text = embeddings.embed_rnn("This is Albert.  My last name is Einstein.  I like physics and atoms.")
test_eq(text[0].get_embedding().shape, torch.Size([512]))

In [23]:
show_doc(EasyDocumentEmbeddings.embed_pool)

<h4 id="EasyDocumentEmbeddings.embed_pool" class="doc_header"><code>EasyDocumentEmbeddings.embed_pool</code><a href="__main__.py#L56" class="source_link" style="float:right">[source]</a></h4>

> <code>EasyDocumentEmbeddings.embed_pool</code>(**`text`**:`Union`\[`List`\[`Sentence`\], `Sentence`, `List`\[`str`\], `str`\])

Generate stacked embeddings with `DocumentPoolEmbeddings`

**Return**:
* A list of Flair's `Sentence`s

In [24]:
show_doc(EasyDocumentEmbeddings.embed_rnn)

<h4 id="EasyDocumentEmbeddings.embed_rnn" class="doc_header"><code>EasyDocumentEmbeddings.embed_rnn</code><a href="__main__.py#L69" class="source_link" style="float:right">[source]</a></h4>

> <code>EasyDocumentEmbeddings.embed_rnn</code>(**`text`**:`Union`\[`List`\[`Sentence`\], `Sentence`, `List`\[`str`\], `str`\])

Generate stacked embeddings with `DocumentRNNEmbeddings`

**Return**:
* A list of Flair's `Sentence`s