In [48]:
import typing 

import torch
import transformers

In [49]:
class Encoder():

    # information on the model: https://arxiv.org/abs/2209.07562
    def __init__(self, model: str = "Twitter/twhin-bert-base"):
        self.model = dict(
            tokenizer=transformers.AutoTokenizer.from_pretrained(model),
            transformer=transformers.AutoModel.from_pretrained(model),
        )

    def __call__(self, batch: typing.List[str]):
        return self._pool(
            self.model["transformer"](
                **self.model["tokenizer"](
                    batch,
                    padding=True,
                    return_tensors="pt",
                )
            )
            .last_hidden_state
        )
    
    @staticmethod
    def _pool(batch: torch.Tensor, method: typing.Literal["mean", "cls"] = "mean"):
        return dict(
            mean=lambda x: x.mean(dim=1),
            cls=lambda x: x[:, 0, :]
        )[method](batch)

    @property
    def num_dim(self) -> int:
        return self.model["transformer"].config.to_dict()["hidden_size"]

In [50]:
encoder: Encoder = Encoder()
f"Encoded data dimensionality: {encoder.num_dim}"

Some weights of BertModel were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


'Encoded data dimensionality: 768'

In [51]:
history: typing.List[str] = [
    "Democrats scream every day about limiting Americans rights and pretend to care about going after tax cheats, so why are they totally silent when it comes to Hunter Biden?.",
    "If our economy is doing as well as JoeBiden seems to think under his policies, then why is inflation up 16.6%?!  Either he is lying or in la la land. Clearly, his policies have failed hardworking American families."
]
feed: typing.List[str] = [
    "JoeBiden's weak appeasement policies are hurting our country and letting down our allies. Why is he giving evil regimes and thugs a pass?",
    "Donald Trump tried to use Justice Department officials not as independent fact finders, but as partisan surrogates to legitimize his Big Lie.",
    "Uncle Nearest Premium Whiskey Honored As Wine Enthusiast's 2020 Spirit Brand Of The Year",
    "at subway: And just a little lettuce. the guy starts backing a truck full of lettuce toward my sandwich"
]

In [52]:
history_matrix: torch.Tensor = encoder(history)
history_matrix.size()

torch.Size([2, 768])

In [53]:
feed_matrix: torch.Tensor = encoder(feed)
feed_matrix.size()

torch.Size([4, 768])

In [54]:
similarity_score = torch.nn.CosineSimilarity(dim=0)

In [55]:
for n, post_vector in enumerate(feed_matrix):
    display(
        f"similarity(history, feed_{n}) = {similarity_score(history_matrix.mean(), post_vector).item():2.4f}"
    )
    

'similarity(history, feed_0) = 0.0828'

'similarity(history, feed_1) = 0.0712'

'similarity(history, feed_2) = 0.0572'

'similarity(history, feed_3) = 0.0583'