# How to Encode a Sentence for StableDiffusion with BERT

This notebook is meant to teach you how transformers are used to encode a sentence for the LatentDiffusion model, the predecessor to StableDiffusion. The most important aspect of transformers, and of all machine learning, is how we encode and decode information. We need to encode the caption in a way that can be provided to the image generation step.

This notebook focuses on how that works. We will walk through the inference pipeline of Huggingface's BERT, explaining each step along the way. BERT was one of the first big successes of the transformer architecture introduced in [Attention is All You Need](https://arxiv.org/pdf/1706.03762.pdf). It is something like a great-uncle of ChatGPT: both descend from that paper's architecture, but BERT keeps only the encoder half while GPT models keep only the decoder half. We're specifically using the implementation by Huggingface. The code is [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py) and the documentation is [here](https://huggingface.co/docs/transformers/v4.25.1/en/model_doc/bert). StableDiffusion uses CLIP's text encoder for this step instead, but the concept is the same and BERT is widely used.

## Primary Components

There are 4 components we need to pay attention to:
- Tokenizer
- Embedder
- Transformer
- Pooler

The tokenizer and embedder work together to produce the input to the transformer, and the transformer and pooler work together to create the output we need for the next step.

First, we install and import the libraries we need.

```python
# Install the Hugging Face transformers library and PyTorch
!pip install transformers
!pip install torch
```

```python
# Import the libraries, then create the tokenizer and model.
from transformers import BertTokenizer, BertModel

model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
```

```text
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
```

This warning is expected here: the unused `cls.*` weights belong to BERT's pretraining heads (masked-word prediction and next sentence prediction), which `BertModel` does not include.


## Tokenizer

The first step is to break apart an English sentence stored as a string. This is the job of the tokenizer, which turns the sentence into a list of 'tokens', each represented by an integer. It first splits the sentence into words and punctuation (and, for `bert-base-uncased`, lowercases everything), then it splits the words into sub-words. The exact tokenizer and how it works doesn't matter that much, so don't worry about the details, but BERT uses [WordPiece](https://huggingface.co/course/chapter6/6?fw=pt). WordPiece tries to learn the components of words so that the vocabulary can stay small and words like "snowboard" and "surfboard" can be inherently related.
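The sub-word step can be sketched with the greedy longest-match-first strategy that BERT's WordPiece tokenizer applies to each word. This is a minimal sketch: the `wordpiece_split` function and the tiny vocabulary are hypothetical stand-ins for BERT's real 30,522-entry vocabulary, and real tokenizers also handle lowercasing, punctuation splitting, and accents before this step.

```python
def wordpiece_split(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of one word into WordPiece tokens.

    Pieces after the first are looked up with a "##" prefix, which marks
    them as continuations of the previous piece.
    """
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it.
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No piece fits: the whole word becomes the unknown token.
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


# A hypothetical toy vocabulary, not BERT's real one.
toy_vocab = {"snow", "surf", "dog", "##board", "##boarding", "##ing", "##s"}

print(wordpiece_split("snowboarding", toy_vocab))  # ['snow', '##boarding']
print(wordpiece_split("surfboard", toy_vocab))     # ['surf', '##board']
print(wordpiece_split("surfing", toy_vocab))       # ['surf', '##ing']
print(wordpiece_split("cat", toy_vocab))           # ['[UNK]']
```

Because the longest match wins, `snowboarding` becomes `snow` + `##boarding` even though `##board` and `##ing` are also in the toy vocabulary.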

Here's the BERT tokenizer in action. You can change the sentence to see how it works.

```python
encoding = tokenizer("Hello, my dog is cute. He likes snowboarding!")
print("Token IDs:", encoding["input_ids"])
# convert_ids_to_tokens maps each ID back to its vocabulary entry.
print("Token Strings:", tokenizer.convert_ids_to_tokens(encoding["input_ids"]))
print("Complete Decode:", tokenizer.decode(encoding["input_ids"]))
```

```text
Token IDs: [101, 7592, 1010, 2026, 3899, 2003, 10140, 1012, 2002, 7777, 4586, 21172, 999, 102]
Token Strings: ['[CLS]', 'hello', ',', 'my', 'dog', 'is', 'cute', '.', 'he', 'likes', 'snow', '##boarding', '!', '[SEP]']
Complete Decode: [CLS] hello, my dog is cute. he likes snowboarding! [SEP]
```


### Semantic and Control Tokens

Note that there are 2 extra tokens that do not correspond to words in our original sentence! These are the control (special) tokens, and they are the secret sauce that makes this all work. `[CLS]` (short for "classification") marks the position whose final state is used for sentence-level tasks, and `[SEP]` ("separator") separates sentences; with a single sentence, it simply marks the end. There are a few more: `[UNK]` means unknown, `[MASK]` means that this token has been masked off, and `[PAD]` fills out shorter sequences in a batch.
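To see where the special tokens go, here is a minimal sketch of how a BERT-style input is assembled for one or two sentences, including `token_type_ids` and padding. The special-token IDs match `bert-base-uncased` (`[PAD]`=0, `[CLS]`=101, `[SEP]`=102), and the word IDs are taken from the tokenizer output above; the `build_bert_input` helper itself is hypothetical. In practice, `tokenizer(sentence_a, sentence_b, padding="max_length", max_length=...)` does all of this for you.

```python
PAD_ID, CLS_ID, SEP_ID = 0, 101, 102  # bert-base-uncased special-token IDs


def build_bert_input(ids_a, ids_b=None, max_length=16):
    """Hypothetical helper: [CLS] A [SEP] (B [SEP]) plus token types and padding."""
    input_ids = [CLS_ID] + ids_a + [SEP_ID]
    token_type_ids = [0] * len(input_ids)         # segment A, incl. [CLS] and its [SEP]
    if ids_b is not None:
        input_ids += ids_b + [SEP_ID]
        token_type_ids += [1] * (len(ids_b) + 1)  # segment B, incl. its [SEP]
    if len(input_ids) > max_length:
        raise ValueError("sequence is longer than max_length")
    attention_mask = [1] * len(input_ids)         # 1 = real token
    n_pad = max_length - len(input_ids)
    input_ids += [PAD_ID] * n_pad
    token_type_ids += [0] * n_pad
    attention_mask += [0] * n_pad                 # 0 = padding, ignored by attention
    return {"input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask}


# Word IDs for "my dog is cute" and "he likes snowboarding" from the output above.
sentence_a = [2026, 3899, 2003, 10140]
sentence_b = [2002, 7777, 4586, 21172]
pair = build_bert_input(sentence_a, sentence_b, max_length=14)
print(pair["input_ids"])
# [101, 2026, 3899, 2003, 10140, 102, 2002, 7777, 4586, 21172, 102, 0, 0, 0]
print(pair["token_type_ids"])
# [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0]
```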

The `[CLS]` token is the most important for us. BERT was pretrained on two tasks. One, next sentence prediction, gives the model two sentences and asks it to decide, from the final state of this token after the embedding, transformer, and pooling steps (which we go into next), whether the second sentence actually follows the first. Fine-tuned versions of BERT use the same token for tasks like natural language inference, for example deciding whether "I have never seen a hummingbird" is entailed by, contradicts, or is neutral with respect to the premise "I have never seen a hummingbird not flying". So, in a sense, this token contains the "meaning" of the sentence. The other pretraining task is to predict words/tokens that were removed, which uses the `[MASK]` token.
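The masked-word task can be sketched as follows. The BERT paper selects 15% of token positions for prediction; of those, 80% are replaced with `[MASK]`, 10% with a random token, and 10% are left unchanged. This is a minimal sketch under those assumptions: the `mask_tokens` helper is hypothetical, it never masks special tokens, and it marks positions that are not scored with `-100` (the default `ignore_index` of PyTorch's cross-entropy loss).

```python
import random

MASK_ID = 103                          # [MASK] in bert-base-uncased
SPECIAL_IDS = {0, 100, 101, 102, 103}  # [PAD], [UNK], [CLS], [SEP], [MASK]
VOCAB_SIZE = 30522


def mask_tokens(input_ids, rng, mask_prob=0.15):
    """Return (masked_ids, labels) for masked language modeling.

    labels[i] holds the original ID where the model must predict a token,
    and -100 everywhere else (positions ignored by the loss).
    """
    masked = list(input_ids)
    labels = [-100] * len(input_ids)
    for i, token in enumerate(input_ids):
        if token in SPECIAL_IDS or rng.random() >= mask_prob:
            continue
        labels[i] = token                          # the model must recover this
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK_ID                    # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
        # remaining 10%: keep the original token
    return masked, labels


rng = random.Random(0)
ids = [101, 7592, 1010, 2026, 3899, 2003, 10140, 1012, 2002, 7777, 4586, 21172, 999, 102]
masked, labels = mask_tokens(ids, rng, mask_prob=0.5)
print(masked)
print(labels)
```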

### Other Tokenizers

OpenAI has a really good tool explaining their favorite tokenizer [here](https://beta.openai.com/tokenizer), which uses [byte-pair encoding](https://en.wikipedia.org/wiki/Byte_pair_encoding). Their byte-level BPE is more flexible than WordPiece (it can encode any string without an `[UNK]` token), but it can use more tokens for the same word. If you look at the OpenAI tokenizer, it splits `snowboarding` into 3 tokens, `_s`, `now`, and `boarding`, while BERT's splits it into only 2 tokens, `snow` and `##boarding` (the `##` means it's a continuation of a word). It took researchers a while to figure out that transformers could work well with a less linguistically informed tokenizer. For natural language, as long as you tokenize sensibly, the exact tokenization doesn't matter that much, for reasons that will become clear later.
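Byte-pair encoding builds its vocabulary bottom-up: start from single characters and repeatedly merge the most frequent adjacent pair of symbols. This is a minimal sketch of that merge-learning loop on a hypothetical toy corpus of word counts. It omits end-of-word markers, and real byte-level BPE (as in GPT-2) works on bytes rather than characters and learns tens of thousands of merges.

```python
from collections import Counter


def learn_bpe(word_counts, num_merges):
    """Learn BPE merges from {word: count}; returns (merges, final corpus)."""
    # Represent every word as a tuple of single-character symbols.
    corpus = {tuple(word): count for word, count in word_counts.items()}
    merges = []
    for _ in range(num_merges):
        # Count how often each adjacent pair of symbols occurs.
        pairs = Counter()
        for symbols, count in corpus.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += count
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent pair; first seen wins ties
        merges.append(best)
        # Rewrite every word, fusing each occurrence of `best` into one symbol.
        new_corpus = {}
        for symbols, count in corpus.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            key = tuple(merged)
            new_corpus[key] = new_corpus.get(key, 0) + count
        corpus = new_corpus
    return merges, corpus


toy_counts = {"board": 5, "boarding": 3, "snow": 2}  # hypothetical word counts
merges, corpus = learn_bpe(toy_counts, num_merges=4)
print(merges)  # [('b', 'o'), ('bo', 'a'), ('boa', 'r'), ('boar', 'd')]
print(corpus)  # {('board',): 5, ('board', 'i', 'n', 'g'): 3, ('s', 'n', 'o', 'w'): 2}
```

After four merges, the frequent chunk `board` has become a single symbol, while the rarer `snow` is still spelled out letter by letter.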

However, for security research, this could matter a lot. Network packets are more structured than language, and a good tokenization could give the transformer a massive leg up when it's trying to learn. If you do it right, you could encode a lot of information about an environment in a token sequence.

## Embedder

The next step is to take these tokens, which are just a list of integers, and give them room to breathe. An integer can encode very little information: there are only about 30,000 of them in BERT's vocabulary, and they're just numbers. We need something that can contain all the information the word "dog" has. You, as a human, have a picture in your mind of a mammal with 4 legs and a head. You may remember pets you grew up with, but an ML system doesn't have any of that. A single number cannot hold that much information, so an embedding layer gives each token 768 floating point numbers instead.

#### A Warning!!!
This is all the system has: a single point in space that is nearly meaningless on its own. There's a strong temptation to anthropomorphize this, which I've done a bit in this writing, but this information encoding is fundamentally different from how humans process information. The size of the embedding has more to do with how many bits of information the token and its role in the transformer need to carry than with giving it a place to store semantic meaning. Think of it more like a location in a space, like a point on a map, than anything else.

The model contains this embedding layer:

```python
model.embeddings
```

```text
BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
```

We can see that the `word_embeddings` table has 30522 rows, one per token in the vocabulary, each holding 768 numbers. The embedder simply looks up a token's row: the `[CLS]` token, whose ID is `101`, becomes row 101 of this table (counting from 0).
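An embedding layer is just a lookup table, so looking up token `101` is the same as reading row 101 of the weight matrix, or equivalently multiplying a one-hot vector by the table. This sketch uses a randomly initialized `torch.nn.Embedding` with the same size as BERT's `word_embeddings` rather than the pretrained weights, so the numbers are meaningless, but the mechanics are the same.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Same shape as BERT's word_embeddings table: 30522 tokens x 768 dims (random weights).
word_embeddings = nn.Embedding(30522, 768, padding_idx=0)

token_ids = torch.tensor([[101, 7592, 1010, 102]])  # [CLS] hello , [SEP]
vectors = word_embeddings(token_ids)
print(vectors.shape)  # torch.Size([1, 4, 768])

# The lookup is plain row indexing into the weight matrix.
print(torch.equal(vectors[0, 0], word_embeddings.weight[101]))  # True

# Equivalent view: a one-hot vector times the table selects one row.
one_hot = nn.functional.one_hot(token_ids, num_classes=30522).float()
via_matmul = one_hot @ word_embeddings.weight
print(torch.allclose(via_matmul, vectors))  # True
```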

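However, there are some other embeddings. The `position_embeddings` encode the position of each token. The `[CLS]` token is in position 0, so we add the position embedding for 0 to that token's embedding. The transformer itself does not care in what order it receives its input sequence; adding this position information to each token is how we fix that. The `token_type_embeddings` mark which of two sentences a token belongs to (0 for the first, 1 for the second). With a single sentence every token gets type 0, so this just adds the same vector to every token and we can ignore it here. Note that the position table has only 512 rows, so BERT can only handle sequences of at most 512 tokens! The three embeddings are summed and then normalized by `LayerNorm`; `dropout` only has an effect during training.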
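Putting the three tables together, here is a minimal sketch of the embedding computation, assuming it mirrors the `BertEmbeddings` module printed above: sum the word, position, and token-type vectors, then apply `LayerNorm` and dropout. The `TinyBertEmbeddings` class is hypothetical and randomly initialized rather than loaded from `bert-base-uncased`.

```python
import torch
import torch.nn as nn


class TinyBertEmbeddings(nn.Module):
    """Hypothetical re-creation of BERT's embedding step, with random weights."""

    def __init__(self, vocab_size=30522, hidden=768, max_positions=512, type_vocab=2):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_positions, hidden)
        self.token_type_embeddings = nn.Embedding(type_vocab, hidden)
        self.LayerNorm = nn.LayerNorm(hidden, eps=1e-12)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, token_type_ids=None):
        seq_len = input_ids.shape[1]
        # Positions 0, 1, ..., seq_len - 1, shared by every sentence in the batch.
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        if token_type_ids is None:
            # A single sentence: every token is type 0.
            token_type_ids = torch.zeros_like(input_ids)
        summed = (
            self.word_embeddings(input_ids)
            + self.position_embeddings(positions)
            + self.token_type_embeddings(token_type_ids)
        )
        return self.dropout(self.LayerNorm(summed))


torch.manual_seed(0)
embedder = TinyBertEmbeddings().eval()  # eval() switches dropout off

ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 1012,
                     2002, 7777, 4586, 21172, 999, 102]])
with torch.no_grad():
    out = embedder(ids)
    # The same token ("my", 2026) at positions 0 and 1 gets different vectors.
    repeated = embedder(torch.tensor([[2026, 2026]]))
print(out.shape)                                       # torch.Size([1, 14, 768])
print(torch.allclose(repeated[0, 0], repeated[0, 1]))  # False
```

Passing more than 512 tokens to this sketch fails, because position 512 has no row in the position table, which is the same reason BERT has a 512-token limit.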

Here's the embedding layer applied to our sentence. It's a giant tensor!

```python
encoding = tokenizer("Hello, my dog is cute. He likes snowboarding!", return_tensors="pt")
model.embeddings(encoding["input_ids"])
```

```text
tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [ 3.7386e-01, -1.5575e-02, -2.4561e-01,  ..., -3.1657e-02,
           5.5144e-01, -5.2406e-01],
         [ 4.6706e-04,  1.6225e-01, -6.4443e-02,  ...,  4.9443e-01,
           6.9413e-01,  3.6286e-01],
         ...,
         [-1.1198e-01, -1.0581e+00, -2.0782e-01,  ..., -6.7460e-01,
          -9.2665e-02,  4.3315e-02],
         [ 4.7513e-01, -1.6158e-01, -3.3946e-01,  ...,  5.6113e-01,
           5.0795e-01,  5.9823e-01],
         [-6.0564e-01,  9.6814e-02,  1.8802e-01,  ..., -2.7735e-01,
           1.8495e-01,  5.7977e-02]]], grad_fn=<NativeLayerNormBackward0>)
```

The important part of this is the shape of the tensor, which we can print out and inspect. We have 1 sentence, it has 14 tokens, and each token is embedded in 768 dimensions. So we get the shape:

```python
model.embeddings(encoding["input_ids"]).shape
```

```text
torch.Size([1, 14, 768])
```

A tensor is just a fancy way of organizing floating point numbers. This is a rank-3, 1x14x768 tensor, which the computer stores as a flat array of 10752 (1 x 14 x 768) floating point numbers, plus the 3 integers 1, 14, and 768 that determine its "shape". You can think of it as a list with 1 element, which is a list of 14 vectors, each with 768 elements.
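The flat-array-plus-shape idea can be checked directly in PyTorch. This sketch uses a tensor of zeros with the same shape as our embeddings instead of the real model output.

```python
import torch

x = torch.zeros(1, 14, 768)
print(x.dim())     # 3, the rank
print(x.numel())   # 10752 == 1 * 14 * 768
print(x.stride())  # (10752, 768, 1): how far to step in the flat array per dimension

flat = x.view(-1)        # same storage, viewed as one long vector
flat[768 * 2 + 5] = 1.0  # write element 5 of token 2 through the flat view...
print(x[0, 2, 5])        # ...and read it back through the 3-D shape: tensor(1.)
```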

## Transformer

The encoder is the transformer, and it's the next step in the process. We're focused on the embeddings here, not the transformer, so we won't go into how this component works, but we will talk about what it does. First, let's look at the output:

```python
encoder_outputs = model.encoder(model.embeddings(encoding["input_ids"]))
print("Output:", encoder_outputs.last_hidden_state)
print("Shape:", encoder_outputs.last_hidden_state.shape)
```

```text
Output: tensor([[[ 0.0233,  0.1726, -0.0896,  ..., -0.2837,  0.3139,  0.6161],
         [ 0.5281,  0.1442,  0.0739,  ...,  0.0331,  1.0601, -0.1347],
         [-0.4106,  0.3789,  0.5610,  ..., -0.4233,  0.7288, -0.0884],
         ...,
         [ 0.9196, -0.1436,  0.0633,  ..., -0.2213, -0.4891, -0.3392],
         [-0.2715, -0.1521,  0.1280,  ...,  0.8118,  0.1872, -0.0361],
         [ 0.5228,  0.3717, -0.0870,  ...,  0.1787, -0.3445, -0.2823]]],
       grad_fn=<NativeLayerNormBackward0>)
Shape: torch.Size([1, 14, 768])
```


When we inspect the shape, we see that it's the same as the embedding's shape! That's a helpful way to think about the name "transformer": it takes a sequence of vectors as input and outputs a transformed sequence with the same shape. These output vectors are called "context embeddings", which we'll talk about at the end of this notebook. The encoder's output object can carry more than this `last_hidden_state` (for example, the hidden states of every layer and the attention weights, when requested), but `last_hidden_state` is the part we use. The pooler in the next step keeps only the first token's vector.
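Two properties mentioned so far, that a transformer layer keeps the sequence shape and that without position embeddings it ignores token order, can be checked with a stand-in. This sketch uses PyTorch's generic `nn.TransformerEncoderLayer`, sized like one BERT-base layer (768 dimensions, 12 heads, 3072-wide feed-forward, GELU) but with random weights; it is not BERT's own layer implementation and has none of its pretrained weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(
    d_model=768, nhead=12, dim_feedforward=3072,
    dropout=0.0, activation="gelu", batch_first=True,
).eval()

x = torch.randn(1, 14, 768)  # stand-in for the embedder's output
with torch.no_grad():
    y = layer(x)
print(y.shape)  # torch.Size([1, 14, 768]): same shape in and out

# Shuffle the 14 tokens. With no position information added, the output is
# the same set of vectors, shuffled the same way (permutation equivariance).
perm = torch.randperm(14)
with torch.no_grad():
    y_shuffled = layer(x[:, perm])
print(torch.allclose(y_shuffled, y[:, perm], atol=1e-4))  # True
```

This order-blindness is exactly why the embedder adds position embeddings before the transformer sees the tokens.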

## Pooler

The last step in the forward pass of BERT for what we're doing is a pooling step. This gives us the final output which can be fed into other models.

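Roughly, BERT was first pretrained on plain, unlabeled text using the two self-supervised tasks described earlier: predicting masked-out tokens and deciding whether one sentence follows another. GPT-style models are instead trained to predict the next token in the sequence. Either way, this is very easy to set up: all you need is a metric boatload of text. GPT-3 was trained on about 300 billion tokens scraped from all over the web. For context, BERT was pretrained on about 3.3 billion words (BooksCorpus plus English Wikipedia, which alone is about 2.5 billion words), and it is a small model by today's standards.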

The pooler's weights come from the next sentence prediction task, which feeds the pooled `[CLS]` vector to a small classifier. Once pretraining is done, BERT can be trained further to do text classification. This is harder to set up, as it needs a curated, labeled dataset, often with sentences that are paired. The most common benchmark is [GLUE](https://gluebenchmark.com/). This secondary training step is called finetuning, and it is usually used to make a large foundation model that was trained on one task do something else. For images, for example, one can finetune a model trained on ImageNet to classify medical scans instead. ChatGPT has extra functionality that probably required a very interesting training pipeline with a few more finetuning steps.

Here's the pooling step in BERT:

```python
bert_outputs = model.pooler(encoder_outputs.last_hidden_state)
print(bert_outputs)
print(bert_outputs.shape)
```

```text
tensor([[-0.8019, -0.3828, -0.8803,  0.4781,  0.6283, -0.1234,  0.7109,  0.2231,
         -0.6607, -0.9999, -0.2453,  0.7718,  0.9811,  0.5380,  0.9193, -0.6514,
         -0.0960, -0.5760,  0.1937,  0.0498,  0.6715,  1.0000,  0.0885,  0.2818,
          0.4310,  0.9341, -0.7703,  0.9159,  0.9443,  0.6812, -0.5385,  0.1731,
         -0.9905, -0.2347, -0.9173, -0.9890,  0.3910, -0.6985,  0.0112,  0.2179,
         -0.8985,  0.2892,  0.9999, -0.5019,  0.2948, -0.3317, -1.0000,  0.2828,
         -0.8458,  0.7528,  0.8369,  0.5738, -0.0594,  0.3531,  0.4162,  0.0565,
         -0.2063, -0.0427, -0.2212, -0.5139, -0.6546,  0.4094, -0.7918, -0.8546,
          0.7591,  0.7553, -0.0725, -0.2874,  0.0137, -0.1024,  0.7832,  0.2025,
         -0.1078, -0.8899,  0.5245,  0.2641, -0.6041,  1.0000, -0.1185, -0.9764,
          0.8874,  0.6552,  0.5250, -0.1935,  0.5401, -1.0000,  0.4399, -0.1549,
         -0.9904,  0.1935,  0.4651, -0.1812,  0.7752,  0.5038, -0.7092, -0.3990,
         -0.1677, -0.8439, ...
(output truncated)
```

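The pooler takes the first token's vector, which corresponds to the `[CLS]` token, and passes it through a single fully connected layer followed by a tanh activation (which is why every value above lies between -1 and 1). The result is one 768-dimensional vector summarizing the whole sentence; the pooler discards the other tokens. Note, though, that StableDiffusion does not feed only this pooled vector to its image generator: its UNet attends to the full sequence of per-token context embeddings from CLIP's text encoder through cross-attention, and LatentDiffusion's text-conditioned model does the same with the output of its own text transformer. So the other tokens matter too, which brings us to the next section.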
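The pooling step can be sketched in a few lines. This assumes the structure of Hugging Face's `BertPooler` (take the first token's vector, apply one 768-to-768 linear layer, then tanh); the weights here are random rather than BERT's trained ones, and the encoder output is a random stand-in.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 768
dense = nn.Linear(hidden, hidden)  # the pooler's single fully connected layer
activation = nn.Tanh()

last_hidden_state = torch.randn(1, 14, hidden)  # stand-in for the encoder output

first_token = last_hidden_state[:, 0]  # the [CLS] position: shape [1, 768]
pooled = activation(dense(first_token))
print(pooled.shape)                    # torch.Size([1, 768])
print(pooled.abs().max().item() <= 1)  # True: tanh squashes values into [-1, 1]

# Changing every token except the first leaves the pooled vector unchanged.
edited = last_hidden_state.clone()
edited[:, 1:] = 0.0
print(torch.equal(activation(dense(edited[:, 0])), pooled))  # True
```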

## The Other Tokens

The encoder's outputs are called "context embeddings": in a sense, the transformer has added the surrounding context to the initial embeddings. This [paper](https://proceedings.neurips.cc/paper/2019/file/159c1ffe5b61b41b3c4d8f4c2150f6c4-Paper.pdf) has an excellent example with the token for "die", which we will talk about.

![German Die vs English die](bert_die_vs_die.png)

In a sense, the final token vectors carry the "in-context" meaning of each token, because the model was trained to fill in masked words from their context. The image is from the paper, which shows that the token "die" is embedded in 4 separate locations depending on whether it's the German article "die" (meaning "the" in English), the English verb "die" in its singular or plural form, or the game piece you roll while playing D&D or craps.
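The core idea, that the same input token ends up in different places depending on its neighbors, can be demonstrated with a tiny randomly initialized stand-in rather than BERT itself. In this sketch, the arbitrary token ID `42` gets exactly the same input embedding in both sequences, but after one transformer layer its output vectors differ because the surrounding tokens differ. The sizes and token IDs are hypothetical; with a trained model such as `bert-base-uncased`, those differences become meaningful, which is what the paper's figure visualizes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed = nn.Embedding(100, 32)  # toy vocabulary of 100 tokens, 32 dims
layer = nn.TransformerEncoderLayer(d_model=32, nhead=4, dim_feedforward=64,
                                   dropout=0.0, batch_first=True).eval()

seq_a = torch.tensor([[5, 42, 7]])  # token 42 with one set of neighbors
seq_b = torch.tensor([[9, 42, 3]])  # the same token with different neighbors

with torch.no_grad():
    in_a, in_b = embed(seq_a)[0, 1], embed(seq_b)[0, 1]
    out_a, out_b = layer(embed(seq_a))[0, 1], layer(embed(seq_b))[0, 1]

print(torch.equal(in_a, in_b))       # True: identical context-free embedding
print(torch.allclose(out_a, out_b))  # False: the context embeddings differ
print(nn.functional.cosine_similarity(out_a, out_b, dim=0).item())
```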