# Experimental: Extract BERT embeddings for each noun in a poem

To see whether subject continuity correlates with the word embeddings of a poem's nouns, I put together some code to lay the groundwork.

First, I implemented a function that takes a poem and outputs the BERT embedding for each token.
Second, I wrote a function that extracts all nouns (and their word indexes) from a poem.

Note that BERT operates on subwords: while most common English words are a single token, less common words (e.g. "embeddings") are split into multiple tokens. In the most extreme case, each character of a word is its own token.

Hence, BERT might output multiple token embeddings for a given word. So that we can collect all token embeddings of a word, the function `get_token_embeddings` returns, besides the embeddings themselves, a mapping that records which embedding belongs to which token and which tokens belong to which word in the input.
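To make the subword splitting and the word-to-token mapping concrete, here is a minimal sketch of greedy longest-match-first splitting in the WordPiece style. The function names `wordpiece_split` and `map_words_to_tokens` and the tiny `toy_vocab` are my own assumptions for illustration. The real BERT vocabulary is far larger, and the real tokenizer also handles lowercasing and punctuation.

```python
def wordpiece_split(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of one word into subword tokens."""
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # "##" marks a continuation token
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk_token]  # nothing in the vocabulary fits
        tokens.append(piece)
        start = end
    return tokens


def map_words_to_tokens(words, vocab):
    """Pair every token with the index of the word it came from."""
    pairs = []
    for word_index, word in enumerate(words):
        for token in wordpiece_split(word, vocab):
            pairs.append((word_index, token))
    return pairs


# Hypothetical toy vocabulary, NOT the real BERT vocabulary
toy_vocab = {"word", "em", "##bed", "##ding", "##s"}

print(map_words_to_tokens("word embeddings".split(), toy_vocab))
# [(0, 'word'), (1, 'em'), (1, '##bed'), (1, '##ding'), (1, '##s')]
```

All four tokens of "embeddings" carry the same word index 1, which is the kind of information the mapping needs to preserve.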

Similarly, the function `get_nouns_in_poem` returns not only the nouns themselves but also their indexes in the original poem, so that we can look up all token embeddings that belong to a given noun.

Here is a useful tutorial on how to get word embeddings from BERT (I adapted a lot of the code from it):
https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/

In [None]:
!pip install transformers
import torch
from transformers import BertTokenizerFast, BertModel
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('treebank')

In [None]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True)

In [81]:
def get_token_embeddings(tokenizer, model, poem):
    """Retrieve the BERT embedding of each token in a poem

    :param tokenizer: BERT Fast tokenizer
    :param model: BERT pre-trained model
    :param poem: poem as a single string
    :return: (list of (token index, (word index, token)), list of embeddings with one entry per token)
    """
    # Tokenize the poem
    marked_poem = "[CLS] " + poem + " [SEP]"
    bert_tokenized_poem = tokenizer.tokenize(marked_poem)

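    # Get the word to token ID mapping
    # tokenizer(poem) adds [CLS] and [SEP] itself, so word_ids() lines up with
    # bert_tokenized_poem (the two special tokens get the word ID None)
    encoded_poem = tokenizer(poem)
    word_ids = encoded_poem.word_ids()
    word_to_token = list(zip(word_ids, bert_tokenized_poem))
    # Drop [CLS] and [SEP], then number the remaining tokens from 0
    word_token_mapping = list(enumerate(word_to_token[1:-1]))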

    # Get all hidden states
    indexed_tokens = tokenizer.convert_tokens_to_ids(bert_tokenized_poem)
    # All-ones attention mask: every token is real input (there is no padding)
    attention_mask = [1] * len(bert_tokenized_poem)

    tokens_tensor = torch.tensor([indexed_tokens])
    mask_tensor = torch.tensor([attention_mask])

    model.eval()
    with torch.no_grad():
        outputs = model(tokens_tensor, attention_mask=mask_tensor)
    # outputs[2] holds the hidden states of every layer; reorder to (token, layer, hidden unit)
    hidden_states = torch.stack(outputs[2], dim=0).squeeze(dim=1).permute(1,0,2)

    # Concat last four layers as final embedding
    embeddings = []
    for token in hidden_states[1:-1]:
        last_four_layers = torch.cat([token[-i] for i in range(1, 5)], dim=0)
        embeddings.append(last_four_layers)
    
    return word_token_mapping, embeddings

## Example 1: Each word consists only of one token

In [94]:
poem_1 = """
although chocolate cookies are sweet
used for jelly, a warm drink or a treat
with some nut, and some spice
and some cream. if it’s nice
it’s the flavor that makes you a treat
"""

In [95]:
word_token_mapping, embeddings = get_token_embeddings(tokenizer, model, poem_1)

print(f"Number of tokens: {len(embeddings)}")
print(f"Shape of each embedding: {embeddings[0].shape}\n")

print("(Token index, (word index, token))")
word_token_mapping

Number of tokens: 41
Shape of each embedding: torch.Size([3072])

(Token index, (word index, token))


[(0, (0, 'although')),
 (1, (1, 'chocolate')),
 (2, (2, 'cookies')),
 (3, (3, 'are')),
 (4, (4, 'sweet')),
 (5, (5, 'used')),
 (6, (6, 'for')),
 (7, (7, 'jelly')),
 (8, (8, ',')),
 (9, (9, 'a')),
 (10, (10, 'warm')),
 (11, (11, 'drink')),
 (12, (12, 'or')),
 (13, (13, 'a')),
 (14, (14, 'treat')),
 (15, (15, 'with')),
 (16, (16, 'some')),
 (17, (17, 'nut')),
 (18, (18, ',')),
 (19, (19, 'and')),
 (20, (20, 'some')),
 (21, (21, 'spice')),
 (22, (22, 'and')),
 (23, (23, 'some')),
 (24, (24, 'cream')),
 (25, (25, '.')),
 (26, (26, 'if')),
 (27, (27, 'it')),
 (28, (28, '’')),
 (29, (29, 's')),
 (30, (30, 'nice')),
 (31, (31, 'it')),
 (32, (32, '’')),
 (33, (33, 's')),
 (34, (34, 'the')),
 (35, (35, 'flavor')),
 (36, (36, 'that')),
 (37, (37, 'makes')),
 (38, (38, 'you')),
 (39, (39, 'a')),
 (40, (40, 'treat'))]

## Example 2: Words that consist of multiple tokens

In [96]:
poem_2 = """
I haven’t switched on my TV for years
we have people like me, and my fears
i talk to the news
i am paying my dues
rarely and loudly despise all my peers
"""

In [97]:
word_token_mapping, embeddings = get_token_embeddings(tokenizer, model, poem_2)

print(f"Number of tokens: {len(embeddings)}")
print(f"Shape of each embedding: {embeddings[0].shape}\n")

print("(Token index, (word index, token))")
word_token_mapping

Number of tokens: 39
Shape of each embedding: torch.Size([3072])

(Token index, (word index, token))


[(0, (0, 'i')),
 (1, (1, 'haven')),
 (2, (2, '’')),
 (3, (3, 't')),
 (4, (4, 'switched')),
 (5, (5, 'on')),
 (6, (6, 'my')),
 (7, (7, 'tv')),
 (8, (8, 'for')),
 (9, (9, 'years')),
 (10, (10, 'we')),
 (11, (11, 'have')),
 (12, (12, 'people')),
 (13, (13, 'like')),
 (14, (14, 'me')),
 (15, (15, ',')),
 (16, (16, 'and')),
 (17, (17, 'my')),
 (18, (18, 'fears')),
 (19, (19, 'i')),
 (20, (20, 'talk')),
 (21, (21, 'to')),
 (22, (22, 'the')),
 (23, (23, 'news')),
 (24, (24, 'i')),
 (25, (25, 'am')),
 (26, (26, 'paying')),
 (27, (27, 'my')),
 (28, (28, 'due')),
 (29, (28, '##s')),
 (30, (29, 'rarely')),
 (31, (30, 'and')),
 (32, (31, 'loudly')),
 (33, (32, 'des')),
 (34, (32, '##pis')),
 (35, (32, '##e')),
 (36, (33, 'all')),
 (37, (34, 'my')),
 (38, (35, 'peers'))]
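In the output above, a `##` prefix marks a token that continues the previous word, and all tokens of one word share the same word index. As a small sketch of how the mapping can be read back, the helper below groups token indexes by word index and glues the pieces together again. The name `rebuild_words` is my own and not part of the code above; the `excerpt` is copied from the Example 2 output.

```python
def rebuild_words(word_token_mapping):
    """Group token indexes by word index and rejoin the subword pieces."""
    words = {}
    for token_index, (word_index, token) in word_token_mapping:
        entry = words.setdefault(word_index, {"token_indexes": [], "word": ""})
        entry["token_indexes"].append(token_index)
        # Strip the "##" continuation marker before joining the pieces
        piece = token[2:] if token.startswith("##") else token
        entry["word"] += piece
    return words


# Rows 32 to 36 of the Example 2 mapping
excerpt = [
    (32, (31, 'loudly')),
    (33, (32, 'des')),
    (34, (32, '##pis')),
    (35, (32, '##e')),
    (36, (33, 'all')),
]

for word_index, entry in rebuild_words(excerpt).items():
    print(word_index, entry["word"], entry["token_indexes"])
# 31 loudly [32]
# 32 despise [33, 34, 35]
# 33 all [36]
```

The token indexes collected per word are exactly the positions to look up in the `embeddings` list.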

# Extracting nouns and their word indexes from a poem

In [98]:
# Adapted from https://stackoverflow.com/questions/33587667/extracting-all-nouns-from-a-text-file-using-nltk
# I thought about using Rami's lexical diversity code, but it seemed overkill for what I needed

def get_nouns_in_poem(poem):
    """Extract all nouns and their word index from a poem

    :param poem: One poem as a single string
    :return: list of tuples: (word index, noun as string)
    """
    nltk_tokenized_poem = nltk.word_tokenize(poem)
    tagged_poem = nltk.pos_tag(nltk_tokenized_poem)
    # Penn Treebank noun tags (NN, NNS, NNP, NNPS) all start with "NN"
    nouns = [(i, word) for i, (word, pos) in enumerate(tagged_poem) if pos.startswith('NN')]
    return nouns

In [99]:
get_nouns_in_poem(poem_1)

[(1, 'chocolate'),
 (2, 'cookies'),
 (11, 'drink'),
 (14, 'treat'),
 (17, 'nut'),
 (21, 'spice'),
 (24, 'cream'),
 (35, 'flavor'),
 (40, 'treat')]
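To connect the two functions, here is one possible way to collect the token embeddings that belong to each noun. Treat it as a sketch under stated assumptions: the helper name `get_noun_embeddings` and the toy inputs are mine, and plain lists stand in for the BERT tensors. It also assumes that NLTK and BERT split the poem into the same words. That matches the `poem_1` outputs here (e.g. `(24, 'cream')` in both), but the two tokenizers follow different rules, so the indexes are not guaranteed to line up for every poem.

```python
from collections import defaultdict


def get_noun_embeddings(word_token_mapping, embeddings, nouns):
    """Group token embeddings by the noun they belong to.

    :param word_token_mapping: list of (token index, (word index, token))
    :param embeddings: one embedding per token, indexed by token index
    :param nouns: list of (word index, noun as string)
    :return: dict mapping (word index, noun) to a list of token embeddings
    """
    noun_by_index = dict(nouns)
    grouped = defaultdict(list)
    for token_index, (word_index, _token) in word_token_mapping:
        if word_index in noun_by_index:
            key = (word_index, noun_by_index[word_index])
            grouped[key].append(embeddings[token_index])
    return dict(grouped)


# Toy inputs (hypothetical): three tokens, where word 1 is split into two tokens
toy_mapping = [(0, (0, 'my')), (1, (1, 'due')), (2, (1, '##s'))]
toy_embeddings = [[0.0], [1.0], [2.0]]  # stand-ins for the 3072-dim tensors
toy_nouns = [(1, 'dues')]

print(get_noun_embeddings(toy_mapping, toy_embeddings, toy_nouns))
# {(1, 'dues'): [[1.0], [2.0]]}
```

A noun that BERT splits into several tokens ends up with several embeddings, so a later step still has to decide how to combine them.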