<a href="https://colab.research.google.com/github/blurred421/LFD473-code/blob/main/notebooks/Chapter15.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chapter 15: Word Embeddings and Text Classification

In [None]:
!pip install transformers evaluate chromadb langchain datasets gensim

## 15.2 Learning Objectives

By the end of this chapter, you should be able to:
- tokenize and encode sentences into their corresponding embeddings
- train a simple model using embeddings as features
- use vector databases to store and search documents
- use a similarity metric to perform zero-shot text classification

## 15.4 AG News Dataset

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch0/data_step1.png)

In this chapter, we'll be primarily using the AG News Dataset. The original [AG](http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html) is a collection of more than 1,000,000 news articles gathered from more than 2,000 news sources.

The version we'll be using here, the [AG News Dataset](https://github.com/mhjabreel/CharCnn_Keras/tree/master/data/ag_news_csv), was constructed by choosing the four largest classes from the original corpus, namely, "world", "sports", "business", and "science and technology". Each class contains 30,000 training and 1,900 testing samples, amounting to a total of 120,000 training and 7,600 testing samples.

The AG News Dataset is also available as a [built-in dataset](https://pytorch.org/text/stable/datasets.html#ag-news) in Torchtext, which downloads the corresponding files directly from the [AG News Dataset](https://github.com/mhjabreel/CharCnn_Keras/tree/master/data/ag_news_csv) repository. Here, we'll download the CSV files ourselves:

In [None]:
!wget https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv
!wget https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv
!wget https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/classes.txt

--2024-09-09 17:13:19--  https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 29470338 (28M) [text/plain]
Saving to: ‘train.csv’


2024-09-09 17:13:23 (11.0 MB/s) - ‘train.csv’ saved [29470338/29470338]

--2024-09-09 17:13:23--  https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1857427 (1.8M) [text/plain]
Saving to: ‘test.csv’


2024-09-09 17:13:24 (9.19 MB/

### 15.4.1 Data Cleaning

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch0/data_step2.png)

In [None]:
import numpy as np

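# Unicode code points of characters that appear as mangled HTML numeric entities
# in AG News (e.g., ' #36;' stands for '$')
chr_codes = np.array([
     36,   151,    38,  8220,   147,   148,   146,   225,   133,    39,  8221,  8212,   232,   149,   145,   233,
  64257,  8217,   163,   160,    91,    93,  8211,  8482,   234,    37,  8364,   153,   195,   169
])
chr_subst = {f' #{c};': chr(c) for c in chr_codes}
# Mangled named entities and leftover HTML tags; the full tags are listed before the
# bare '&lt;'/'&gt;' keys so they are removed whole instead of leaving fragments like 'em' behind
chr_subst.update({' amp;': '&', ' quot;': "'", ' hellip;': '...', ' nbsp;': ' ',
                  '&lt;em&gt;': '', '&lt;/em&gt;': '', '&lt;strong&gt;': '', '&lt;/strong&gt;': '',
                  '&lt;': '', '&gt;': ''})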

In [None]:
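def replace_chars(sent):
    # find which substitutions apply, then apply them in the dictionary's order
    to_replace = [c for c in chr_subst.keys() if c in sent]
    for c in to_replace:
        sent = sent.replace(c, chr_subst[c])
    return sent

def preproc_description(desc):
    # backslashes glue fragments together in AG News (e.g., 'dwindling\\band'), so turn them into spaces
    desc = desc.replace('\\', ' ').strip()
    return replace_chars(desc)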
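To see what the cleaning step does in isolation, here is a self-contained sketch with a small, hypothetical subset of the substitution map. The sample strings are made up, except the first, which mirrors the first training example. It assumes, as the keys in `chr_subst` suggest, that each entity appears with a leading space where the original `&` used to be.

```python
# Hypothetical, reduced version of `chr_subst` (not the full map used in this chapter)
subst = {
    ' #36;': chr(36),   # '$'
    ' quot;': "'",
    ' amp;': '&',
    '&lt;em&gt;': '',   # full tags come before the bare '&lt;'/'&gt;' keys
    '&lt;/em&gt;': '',
    '&lt;': '',
    '&gt;': '',
}

def replace_chars(sent, subst=subst):
    # dicts keep insertion order, so keys are tried in the order listed above
    for key, value in subst.items():
        if key in sent:
            sent = sent.replace(key, value)
    return sent

def preproc_description(desc):
    # backslashes become spaces before the entity substitutions
    desc = desc.replace('\\', ' ').strip()
    return replace_chars(desc)

print(preproc_description("Wall Street's dwindling\\band of ultra-cynics"))
print(preproc_description('He said  quot;no comment quot;'))   # He said 'no comment'
print(preproc_description('Profit rose  #36;2 billion'))       # Profit rose $2 billion
print(preproc_description('&lt;em&gt;Hello&lt;/em&gt; world'))  # Hello world
```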

### 15.4.2 Hugging Face Datasets

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch0/data_step4.png)

In [None]:
from datasets import load_dataset, Split, DatasetDict

colnames = ['topic', 'title', 'news']

train_ds = load_dataset("csv", data_files='train.csv', sep=',', split=Split.ALL, column_names=colnames)
test_ds = load_dataset("csv", data_files='test.csv', sep=',', split=Split.ALL, column_names=colnames)

datasets = DatasetDict({'train': train_ds, 'test': test_ds})
datasets



DatasetDict({
    train: Dataset({
        features: ['topic', 'title', 'news'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['topic', 'title', 'news'],
        num_rows: 7600
    })
})

In [None]:
datasets['train'][0]

{'topic': 3,
 'title': 'Wall St. Bears Claw Back Into the Black (Reuters)',
 'news': "Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."}

In [None]:
# shift labels from 1-4 to 0-3, clean the news text, and drop the title column
datasets = datasets.map(lambda row: {'topic': row['topic']-1,
                                     'news': preproc_description(row['news'])})
datasets = datasets.select_columns(['topic', 'news'])


In [None]:
batch = datasets['train'][:4]
labels, descriptions = batch['topic'], batch['news']
labels, descriptions

([2, 2, 2, 2],
 ["Reuters - Short-sellers, Wall Street's dwindling band of ultra-cynics, are seeing green again.",
  'Reuters - Private investment firm Carlyle Group, which has a reputation for making well-timed and occasionally controversial plays in the defense industry, has quietly placed its bets on another part of the market.',
  'Reuters - Soaring crude prices plus worries about the economy and the outlook for earnings are expected to hang over the stock market next week during the depth of the summer doldrums.',
  'Reuters - Authorities have halted oil export flows from the main pipeline in southern Iraq after intelligence showed a rebel militia could strike infrastructure, an oil official said on Saturday.'])

## 15.5 Tokenization

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch0/data_step3.png)

In [None]:
tokens = descriptions[0].split()
tokens

['Reuters',
 '-',
 'Short-sellers,',
 'Wall',
 "Street's",
 'dwindling',
 'band',
 'of',
 'ultra-cynics,',
 'are',
 'seeing',
 'green',
 'again.']

In [None]:
from gensim.utils import simple_preprocess
tokens = simple_preprocess(descriptions[0])
tokens

['reuters',
 'short',
 'sellers',
 'wall',
 'street',
 'dwindling',
 'band',
 'of',
 'ultra',
 'cynics',
 'are',
 'seeing',
 'green',
 'again']
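`simple_preprocess` lowercases the text, keeps only alphabetic runs, and drops tokens shorter than 2 or longer than 15 characters (its default `min_len` and `max_len`). That is why the lone `s` from `Street's` disappears. A rough regular-expression approximation (our own helper, not gensim's implementation) produces the same 14 tokens for this sentence:

```python
import re

def simple_tokenize(text, min_len=2, max_len=15):
    # lowercase, then keep runs of letters only (no digits or underscores)
    tokens = re.findall(r"[^\W\d_]+", text.lower())
    # drop tokens that are too short or too long, as simple_preprocess does by default
    return [t for t in tokens if min_len <= len(t) <= max_len]

sentence = "Reuters - Short-sellers, Wall Street's dwindling band of ultra-cynics, are seeing green again."
print(simple_tokenize(sentence))
```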

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

tok_obj = tokenizer.backend_tokenizer
tok_obj



<tokenizers.Tokenizer at 0x7f2c4d81ac30>

In [None]:
normalizer = tok_obj.normalizer
normalizer.lowercase, normalizer.clean_text, normalizer.strip_accents

(True, True, None)

In [None]:
normalized = normalizer.normalize_str(descriptions[0])
normalized

"reuters - short-sellers, wall street's dwindling band of ultra-cynics, are seeing green again."

In [None]:
pre_tokenizer = tok_obj.pre_tokenizer
tokens = pre_tokenizer.pre_tokenize_str(normalized)
tokens

[('reuters', (0, 7)),
 ('-', (8, 9)),
 ('short', (10, 15)),
 ('-', (15, 16)),
 ('sellers', (16, 23)),
 (',', (23, 24)),
 ('wall', (25, 29)),
 ('street', (30, 36)),
 ("'", (36, 37)),
 ('s', (37, 38)),
 ('dwindling', (39, 48)),
 ('band', (49, 53)),
 ('of', (54, 56)),
 ('ultra', (57, 62)),
 ('-', (62, 63)),
 ('cynics', (63, 69)),
 (',', (69, 70)),
 ('are', (71, 74)),
 ('seeing', (75, 81)),
 ('green', (82, 87)),
 ('again', (88, 93)),
 ('.', (93, 94))]
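BERT's pre-tokenizer splits the normalized text on whitespace and punctuation, keeping each piece's character offsets. A rough regular-expression approximation reproduces the pieces and offsets above for this sentence. It is our own sketch, not the `tokenizers` implementation; for instance, `\w` treats `_` as a word character, while BERT treats it as punctuation.

```python
import re

def whitespace_punct_pretokenize(text):
    # runs of word characters, or single punctuation marks, with (start, end) offsets
    return [(m.group(), m.span()) for m in re.finditer(r"\w+|[^\w\s]", text)]

normalized = "reuters - short-sellers, wall street's dwindling band of ultra-cynics, are seeing green again."
pieces = whitespace_punct_pretokenize(normalized)
print(pieces[:5])
```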

### 15.5.1 Vocabulary

In [None]:
vocab = tok_obj.get_vocab()
vocab

{'##camp': 26468,
 '[unused625]': 630,
 'raid': 8118,
 'zoological': 26168,
 'guarantee': 11302,
 'auckland': 8666,
 'fai': 26208,
 'overhead': 8964,
 '##hee': 21030,
 'johnston': 10773,
 '##linger': 23101,
 'acting': 3772,
 '##tablished': 28146,
 '[unused9]': 10,
 'considered': 2641,
 'pardon': 14933,
 'greyish': 26916,
 '##54': 27009,
 'caracas': 21675,
 'renumbered': 27855,
 'flowing': 8577,
 '[unused748]': 753,
 'forces': 2749,
 'credited': 5827,
 '##hell': 18223,
 'milk': 6501,
 'deals': 9144,
 '##kle': 19099,
 'christy': 21550,
 'guests': 6368,
 '[unused372]': 377,
 'curse': 8364,
 'alvaro': 24892,
 '##onus': 24891,
 '##千': 30310,
 'corpses': 18113,
 'mollusk': 13269,
 '−': 1597,
 '##graphy': 12565,
 '##lius': 15513,
 'convincing': 13359,
 'clutched': 13514,
 'iraq': 5712,
 '##ggy': 22772,
 'urgency': 19353,
 'executives': 12706,
 'hobart': 14005,
 'telecommunication': 25958,
 'def': 13366,
 'inmates': 13187,
 'twinned': 25901,
 'kathleen': 14559,
 'staircase': 10714,
 'approache

In [None]:
vocab['dwindling']

KeyError: 'dwindling'

In [None]:
tok_obj.get_vocab_size()

30522

### 15.5.2 Tokenizer's Model

In [None]:
tok_obj.model

<tokenizers.models.WordPiece at 0x7f2c4cf87890>

In [None]:
tokens_only = [token[0] for token in tokens]
token_ids = [tok_obj.model.token_to_id(token) for token in tokens_only]
print(tokens_only)
print(token_ids)

['reuters', '-', 'short', '-', 'sellers', ',', 'wall', 'street', "'", 's', 'dwindling', 'band', 'of', 'ultra', '-', 'cynics', ',', 'are', 'seeing', 'green', 'again', '.']
[26665, 1011, 2460, 1011, 19041, 1010, 2813, 2395, 1005, 1055, None, 2316, 1997, 11087, 1011, None, 1010, 2024, 3773, 2665, 2153, 1012]


In [None]:
missing_id = token_ids.index(None)
missing_token = tokens_only[missing_id]
missing_id, missing_token

(10, 'dwindling')

In [None]:
tokenized_word = tok_obj.model.tokenize(missing_token)
[piece.as_tuple() for piece in tokenized_word]

[(1040, 'd', (0, 1)), (11101, '##wind', (1, 5)), (2989, '##ling', (5, 9))]
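When a word such as `dwindling` is missing from the vocabulary, WordPiece falls back to subword pieces using a greedy longest-match-first search: starting at the beginning of the word, take the longest prefix found in the vocabulary, then repeat on the remainder, whose pieces are looked up with the `##` continuation prefix. A minimal sketch, assuming a hypothetical toy vocabulary (BERT's real one has 30,522 entries):

```python
def wordpiece_tokenize(word, vocab, unk_token='[UNK]', prefix='##'):
    """Greedy longest-match-first WordPiece tokenization of a single word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # shrink the candidate from the right until it is found in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = prefix + candidate  # continuation pieces carry the '##' prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece matches: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces

toy_vocab = {'d', '##wind', '##ling', 'cy', '##nic', '##s', 'band'}
print(wordpiece_tokenize('dwindling', toy_vocab))  # ['d', '##wind', '##ling']
print(wordpiece_tokenize('cynics', toy_vocab))     # ['cy', '##nic', '##s']
print(wordpiece_tokenize('xyz', toy_vocab))        # ['[UNK]']
```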

In [None]:
encoded = tok_obj.encode(descriptions[0], add_special_tokens=False)
print(encoded.tokens)
print(encoded.ids)

['reuters', '-', 'short', '-', 'sellers', ',', 'wall', 'street', "'", 's', 'd', '##wind', '##ling', 'band', 'of', 'ultra', '-', 'cy', '##nic', '##s', ',', 'are', 'seeing', 'green', 'again', '.']
[26665, 1011, 2460, 1011, 19041, 1010, 2813, 2395, 1005, 1055, 1040, 11101, 2989, 2316, 1997, 11087, 1011, 22330, 8713, 2015, 1010, 2024, 3773, 2665, 2153, 1012]


### 15.5.3 Special Tokens

In [None]:
post_processor = tok_obj.post_processor
post_encoded = post_processor.process(encoded)
print(post_encoded.tokens)

['[CLS]', 'reuters', '-', 'short', '-', 'sellers', ',', 'wall', 'street', "'", 's', 'd', '##wind', '##ling', 'band', 'of', 'ultra', '-', 'cy', '##nic', '##s', ',', 'are', 'seeing', 'green', 'again', '.', '[SEP]']


In [None]:
print(tok_obj.encode(descriptions[0]).tokens)

['[CLS]', 'reuters', '-', 'short', '-', 'sellers', ',', 'wall', 'street', "'", 's', 'd', '##wind', '##ling', 'band', 'of', 'ultra', '-', 'cy', '##nic', '##s', ',', 'are', 'seeing', 'green', 'again', '.', '[SEP]']


#### 15.5.3.1 `[CLS]`: Classification Token

In [None]:
tokenizer.cls_token, tokenizer.cls_token_id

('[CLS]', 101)

#### 15.5.3.2 `[SEP]`: Separation Token

In [None]:
tokenizer.sep_token, tokenizer.sep_token_id

('[SEP]', 102)

In [None]:
print(tok_obj.encode(*descriptions[:2]).tokens)

['[CLS]', 'reuters', '-', 'short', '-', 'sellers', ',', 'wall', 'street', "'", 's', 'd', '##wind', '##ling', 'band', 'of', 'ultra', '-', 'cy', '##nic', '##s', ',', 'are', 'seeing', 'green', 'again', '.', '[SEP]', 'reuters', '-', 'private', 'investment', 'firm', 'carly', '##le', 'group', ',', 'which', 'has', 'a', 'reputation', 'for', 'making', 'well', '-', 'timed', 'and', 'occasionally', 'controversial', 'plays', 'in', 'the', 'defense', 'industry', ',', 'has', 'quietly', 'placed', 'its', 'bets', 'on', 'another', 'part', 'of', 'the', 'market', '.', '[SEP]']
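BERT's post-processor follows a fixed template: `[CLS] A [SEP]` for a single sequence and `[CLS] A [SEP] B [SEP]` for a pair. It also assigns a segment (token type) id of 0 to the first part and 1 to the second. A minimal sketch of that template on plain lists of ids (`bert_template` is our own helper), using the `[CLS]` and `[SEP]` ids shown above as defaults:

```python
def bert_template(ids_a, ids_b=None, cls_id=101, sep_id=102):
    # single sequence: [CLS] A [SEP], all in segment 0
    input_ids = [cls_id] + list(ids_a) + [sep_id]
    token_type_ids = [0] * len(input_ids)
    if ids_b is not None:
        # pair of sequences: [CLS] A [SEP] B [SEP], with segment id 1 for the B part
        input_ids += list(ids_b) + [sep_id]
        token_type_ids += [1] * (len(ids_b) + 1)
    return input_ids, token_type_ids

print(bert_template([26665, 1011], [26665, 1011, 2797]))
```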


#### 15.5.3.3 `[UNK]`: Unknown Token

In [None]:
tokenizer.unk_token, tokenizer.unk_token_id

('[UNK]', 100)

#### 15.5.3.4 `[PAD]`: Padding Token

In [None]:
tokenizer.pad_token, tokenizer.pad_token_id

('[PAD]', 0)

In [None]:
[len(seq) for seq in tokenizer(descriptions)['input_ids']]

[28, 41, 38, 34]
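These four sequences have different lengths, so they cannot be stacked into a single tensor as they are. Calling the tokenizer with `padding=True` pads every sequence to the length of the longest one in the batch using `[PAD]` (id 0), and the tokenizer also returns an `attention_mask` marking which positions hold real tokens. A minimal sketch of that logic on plain lists (`pad_batch` is our own helper):

```python
def pad_batch(sequences, pad_id=0):
    max_len = max(len(seq) for seq in sequences)
    # append [PAD] ids until every sequence matches the longest one
    input_ids = [list(seq) + [pad_id] * (max_len - len(seq)) for seq in sequences]
    # 1 marks real tokens, 0 marks padding the model should ignore
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return input_ids, attention_mask

ids, mask = pad_batch([[101, 26665, 102], [101, 26665, 1011, 2460, 102]])
print(ids)
print(mask)
```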

In [None]:
padded_token_ids = tokenizer(descriptions, padding=True, return_tensors='pt')['input_ids']
padded_token_ids

tensor([[  101, 26665,  1011,  2460,  1011, 19041,  1010,  2813,  2395,  1005,
          1055,  1040, 11101,  2989,  2316,  1997, 11087,  1011, 22330,  8713,
          2015,  1010,  2024,  3773,  2665,  2153,  1012,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [  101, 26665,  1011,  2797,  5211,  3813, 18431,  2571,  2177,  1010,
          2029,  2038,  1037,  5891,  2005,  2437,  2092,  1011, 22313,  1998,
          5681,  6801,  3248,  1999,  1996,  3639,  3068,  1010,  2038,  5168,
          2872,  2049, 29475,  2006,  2178,  2112,  1997,  1996,  3006,  1012,
           102],
        [  101, 26665,  1011, 23990, 13587,  7597,  4606, 15508,  2055,  1996,
          4610,  1998,  1996, 17680,  2005, 16565,  2024,  3517,  2000,  6865,
          2058,  1996,  4518,  3006,  2279,  2733,  2076,  1996,  5995,  1997,
          1996,  2621,  2079,  6392,  6824,  2015,  1012,   102,     0,     0,
             0],
 

### 15.5.4 Truncation

In [None]:
tokenizer.max_len_single_sentence, tokenizer.model_max_length

(510, 512)
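The two numbers differ by two because BERT reserves room for `[CLS]` and `[SEP]`: out of 512 positions, at most 510 can hold tokens from a single sentence. Calling the tokenizer with `truncation=True` trims longer inputs accordingly. The sketch below (a hypothetical helper working on plain lists) shows the arithmetic for a single sequence:

```python
def truncate_and_wrap(token_ids, model_max_length=512, num_special_tokens=2, cls_id=101, sep_id=102):
    # room left for content once [CLS] and [SEP] are accounted for: 512 - 2 = 510
    max_content = model_max_length - num_special_tokens
    return [cls_id] + list(token_ids[:max_content]) + [sep_id]

wrapped = truncate_and_wrap(list(range(1000, 2000)))  # 1,000 fake token ids
print(len(wrapped))  # 512
```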

## 15.6 Embeddings

In [None]:
import torch.nn as nn

emb_dims = 50
embeddings = nn.Embedding(len(vocab), emb_dims)
embeddings

Embedding(30522, 50)

In [None]:
import torch

idx = torch.as_tensor([vocab['reuters']])
idx, embeddings(idx)

(tensor([26665]),
 tensor([[ 1.8288e-01, -1.3801e+00, -5.3830e-01, -1.4301e-01, -2.0827e-01,
          -1.7362e+00,  1.0018e+00,  4.6152e-01,  3.2680e-02, -5.8854e-01,
           3.1180e-01,  6.0066e-01, -1.2477e-01, -1.1660e+00, -1.2219e+00,
           1.0182e+00, -2.0216e-01, -3.1973e-01, -5.4026e-01, -1.8794e+00,
           2.2819e-01,  2.7748e-01, -1.4689e-01, -9.8170e-01, -2.1549e+00,
           4.9118e-01, -4.7388e-01, -2.3673e-01,  2.1740e-03,  1.1351e-01,
          -1.0422e+00,  1.4274e+00, -9.8884e-02, -7.2925e-01, -3.3722e-01,
           4.6264e-01, -3.8414e-01, -9.5412e-01, -5.6739e-02,  2.9316e+00,
           1.2275e-01,  5.6614e-01,  1.0147e-01, -8.2784e-01,  1.8933e-01,
          -8.5093e-01, -3.1484e-01, -1.7159e+00,  6.4275e-01,  2.7018e+00]],
        grad_fn=<EmbeddingBackward0>))

In [None]:
embeddings.weight[idx]

tensor([[ 1.8288e-01, -1.3801e+00, -5.3830e-01, -1.4301e-01, -2.0827e-01,
         -1.7362e+00,  1.0018e+00,  4.6152e-01,  3.2680e-02, -5.8854e-01,
          3.1180e-01,  6.0066e-01, -1.2477e-01, -1.1660e+00, -1.2219e+00,
          1.0182e+00, -2.0216e-01, -3.1973e-01, -5.4026e-01, -1.8794e+00,
          2.2819e-01,  2.7748e-01, -1.4689e-01, -9.8170e-01, -2.1549e+00,
          4.9118e-01, -4.7388e-01, -2.3673e-01,  2.1740e-03,  1.1351e-01,
         -1.0422e+00,  1.4274e+00, -9.8884e-02, -7.2925e-01, -3.3722e-01,
          4.6264e-01, -3.8414e-01, -9.5412e-01, -5.6739e-02,  2.9316e+00,
          1.2275e-01,  5.6614e-01,  1.0147e-01, -8.2784e-01,  1.8933e-01,
         -8.5093e-01, -3.1484e-01, -1.7159e+00,  6.4275e-01,  2.7018e+00]],
       grad_fn=<IndexBackward0>)
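An embedding layer is just a lookup table: selecting row `idx` of the weight matrix, as above, gives the same result as multiplying a one-hot vector by that matrix, only much more cheaply. A small NumPy sketch with a random, toy-sized matrix standing in for `embeddings.weight`:

```python
import numpy as np

rng = np.random.default_rng(42)
vocab_size, emb_dims = 10, 4
weight = rng.normal(size=(vocab_size, emb_dims))  # plays the role of embeddings.weight

token_ids = np.array([3, 7, 3])
by_lookup = weight[token_ids]            # row selection, like nn.Embedding
one_hot = np.eye(vocab_size)[token_ids]  # shape (3, vocab_size)
by_matmul = one_hot @ weight             # shape (3, emb_dims)

print(np.allclose(by_lookup, by_matmul))  # True
```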

### 15.6.1 Word2Vec

### 15.6.2 Embedding Arithmetic

**KING - MAN + WOMAN = ?**

**KING - MAN + WOMAN = QUEEN**

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch13/embed_arithmetic.png)

**KING - MAN + WOMAN ~ KING**

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch13/synthetic_queen.png)
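The arithmetic itself is simple: compute `king - man + woman` and look for the word whose vector is most similar to the result. The catch, as the `KING - MAN + WOMAN ~ KING` case suggests, is that the result is often still closest to `king` itself, so analogy queries usually exclude the input words from the candidates. gensim's `KeyedVectors.most_similar(positive=['king', 'woman'], negative=['man'])` performs that exclusion. The sketch below uses hand-picked 3-dimensional toy vectors (not real embeddings), chosen so that both outcomes show up:

```python
import numpy as np

# Hand-picked toy 3-D vectors, NOT real GloVe/Word2Vec embeddings
toy = {
    'king':  np.array([0.9, 0.8, 0.1]),
    'man':   np.array([0.1, 0.8, 0.0]),
    'woman': np.array([0.1, 0.2, 0.0]),
    'queen': np.array([0.8, 0.1, 0.9]),
    'apple': np.array([0.0, 0.1, -0.9]),
}

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def analogy(a, b, c, vectors, exclude_inputs=True):
    """Return the word closest (by cosine) to vectors[a] - vectors[b] + vectors[c]."""
    target = vectors[a] - vectors[b] + vectors[c]
    candidates = {w: v for w, v in vectors.items()
                  if not (exclude_inputs and w in (a, b, c))}
    return max(candidates, key=lambda w: cosine(target, candidates[w]))

print(analogy('king', 'man', 'woman', toy, exclude_inputs=False))  # king
print(analogy('king', 'man', 'woman', toy))                        # queen
```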

### 15.6.3 Global Vectors (GloVe)

In [None]:
from gensim import downloader

vec = downloader.load('glove-wiki-gigaword-50')

In [None]:
vec.vectors, vec.vectors.shape

(array([[ 0.418   ,  0.24968 , -0.41242 , ..., -0.18411 , -0.11514 ,
         -0.78581 ],
        [ 0.013441,  0.23682 , -0.16899 , ..., -0.56657 ,  0.044691,
          0.30392 ],
        [ 0.15164 ,  0.30177 , -0.16763 , ..., -0.35652 ,  0.016413,
          0.10216 ],
        ...,
        [-0.51181 ,  0.058706,  1.0913  , ..., -0.25003 , -1.125   ,
          1.5863  ],
        [-0.75898 , -0.47426 ,  0.4737  , ...,  0.78954 , -0.014116,
          0.6448  ],
        [ 0.072617, -0.51393 ,  0.4728  , ..., -0.18907 , -0.59021 ,
          0.55559 ]], dtype=float32),
 (400000, 50))

In [None]:
vec['reuters']

array([-0.13741  , -0.25495  ,  1.8853   ,  0.1476   ,  0.63859  ,
       -0.67678  , -1.1622   , -0.21528  ,  0.2598   , -0.52879  ,
        0.66678  , -0.76747  , -0.52731  ,  0.06657  ,  0.076613 ,
        0.32743  , -0.80251  , -0.4955   , -0.37393  ,  0.11261  ,
        1.1671   ,  1.1508   ,  0.61801  ,  0.079467 ,  0.1269   ,
       -0.072447 , -1.2037   , -0.24622  , -0.77076  ,  0.76699  ,
        1.2745   , -0.12898  ,  0.99892  , -0.26733  , -0.57542  ,
       -1.0151   , -0.14278  , -0.43824  ,  0.76577  , -0.0087715,
        1.2848   ,  0.0030819,  0.1186   , -0.38817  , -0.23516  ,
       -0.92094  , -0.51644  ,  1.5083   ,  0.36456  ,  0.59912  ],
      dtype=float32)

In [None]:
import torch.nn as nn

tensor_glove = torch.as_tensor(vec.vectors).float()
embedding = nn.Embedding.from_pretrained(tensor_glove)
embedding.state_dict()

OrderedDict([('weight',
              tensor([[ 0.4180,  0.2497, -0.4124,  ..., -0.1841, -0.1151, -0.7858],
                      [ 0.0134,  0.2368, -0.1690,  ..., -0.5666,  0.0447,  0.3039],
                      [ 0.1516,  0.3018, -0.1676,  ..., -0.3565,  0.0164,  0.1022],
                      ...,
                      [-0.5118,  0.0587,  1.0913,  ..., -0.2500, -1.1250,  1.5863],
                      [-0.7590, -0.4743,  0.4737,  ...,  0.7895, -0.0141,  0.6448],
                      [ 0.0726, -0.5139,  0.4728,  ..., -0.1891, -0.5902,  0.5556]]))])

In [None]:
idx = vec.key_to_index['reuters']
token = vec.index_to_key[idx]
idx, token

(10851, 'reuters')

In [None]:
vec.key_to_index['zzzzz']

KeyError: 'zzzzz'

In [None]:
def encode_str(key_to_index, tokens, unk_token=-1):
    token_ids = [key_to_index.get(token, unk_token) for token in tokens]
    return token_ids

In [None]:
some_ids = encode_str(vec.key_to_index, ['reuters', 'zzzzz'])
some_ids

[10851, -1]

In [None]:
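def get_embeddings(embedding, token_ids):
    # drop unknown tokens (negative ids) and keep integer indices, even if the list ends up empty
    valid_ids = torch.as_tensor([token_id for token_id in token_ids if token_id >= 0], dtype=torch.long)
    embedded_tokens = embedding(valid_ids)
    return embedded_tokens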

In [None]:
get_embeddings(embedding, some_ids)

tensor([[-0.1374, -0.2549,  1.8853,  0.1476,  0.6386, -0.6768, -1.1622, -0.2153,
          0.2598, -0.5288,  0.6668, -0.7675, -0.5273,  0.0666,  0.0766,  0.3274,
         -0.8025, -0.4955, -0.3739,  0.1126,  1.1671,  1.1508,  0.6180,  0.0795,
          0.1269, -0.0724, -1.2037, -0.2462, -0.7708,  0.7670,  1.2745, -0.1290,
          0.9989, -0.2673, -0.5754, -1.0151, -0.1428, -0.4382,  0.7658, -0.0088,
          1.2848,  0.0031,  0.1186, -0.3882, -0.2352, -0.9209, -0.5164,  1.5083,
          0.3646,  0.5991]])

In [None]:
def func_builder(vec):
    tensor_glove = torch.as_tensor(vec.vectors).float()
    embedding = nn.Embedding.from_pretrained(tensor_glove)

    def get_vecs_by_tokens(tokens):
        token_ids = encode_str(vec.key_to_index, tokens)
        embedded_tokens = get_embeddings(embedding, token_ids)
        return embedded_tokens

    return get_vecs_by_tokens

get_vecs_by_tokens = func_builder(vec)

In [None]:
from gensim.utils import simple_preprocess
tokens = simple_preprocess(descriptions[0])
tokens

['reuters',
 'short',
 'sellers',
 'wall',
 'street',
 'dwindling',
 'band',
 'of',
 'ultra',
 'cynics',
 'are',
 'seeing',
 'green',
 'again']

In [None]:
embedded_tokens = get_vecs_by_tokens(tokens)
embedded_tokens.shape

torch.Size([14, 50])

## 15.7 Vector Databases

### 15.7.1 ChromaDB

[ChromaDB](https://docs.trychroma.com/getting-started) is an open-source embedding database that allows you to store embeddings and metadata, embed documents and queries, and search embeddings.

In this example, we'll store GloVe embeddings of the AG News Dataset in a collection backed by a persistent database, and then query the collection to search for similar items.

Creating a database in ChromaDB follows a short sequence of steps:
- getting a client, which we can configure to persist the data
- creating a collection that will store the embeddings and metadata - you can think of it as a folder or table
- adding documents (to be embedded by ChromaDB itself) or embeddings (as we're doing here) to the collection, along with any corresponding metadata you may wish to add
- querying the collection to get the most similar results back
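Under the hood, these steps boil down to storing vectors alongside their documents, metadata, and ids, and returning the ones nearest to a query. Below is a hypothetical, brute-force, in-memory sketch that mimics the `add`/`query` flow using squared Euclidean distance (an assumption of this sketch). `TinyVectorStore` is our own name, not a ChromaDB class. Real vector databases such as ChromaDB typically rely on approximate nearest-neighbor indexes (HNSW, in ChromaDB's case) to avoid scanning every stored vector.

```python
import numpy as np

class TinyVectorStore:
    """Hypothetical in-memory stand-in for a vector-database collection (not ChromaDB's code)."""

    def __init__(self):
        self.ids, self.documents, self.metadatas = [], [], []
        self.vectors = None

    def add(self, embeddings, documents, metadatas, ids):
        new = np.asarray(embeddings, dtype=np.float32)
        self.vectors = new if self.vectors is None else np.vstack([self.vectors, new])
        self.documents += list(documents)
        self.metadatas += list(metadatas)
        self.ids += list(ids)

    def count(self):
        return len(self.ids)

    def query(self, query_embedding, n_results=5):
        q = np.asarray(query_embedding, dtype=np.float32)
        # exhaustive search: squared Euclidean distance to every stored vector
        dists = ((self.vectors - q) ** 2).sum(axis=1)
        order = np.argsort(dists)[:n_results]
        return {
            'ids': [self.ids[i] for i in order],
            'distances': [float(dists[i]) for i in order],
            'documents': [self.documents[i] for i in order],
            'metadatas': [self.metadatas[i] for i in order],
        }

store = TinyVectorStore()
store.add(embeddings=[[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]],
          documents=['origin', 'diagonal', 'far away'],
          metadatas=[{'label': 0}, {'label': 1}, {'label': 2}],
          ids=['000000', '000001', '000002'])
result = store.query([0.9, 1.2], n_results=2)
print(result['ids'], result['distances'])
```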

Let's get a client and define `agnews_db` as the folder our collection must be saved to:

In [None]:
import chromadb

client = chromadb.PersistentClient(path="./agnews_db")

In [None]:
collection = client.create_collection("agnews_collection")

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch13/vector_db.png)

In [None]:
def tokenize_batch(sentences, tokenizer=None):
    if tokenizer is None:
        tokenizer = simple_preprocess

    return [tokenizer(s) for s in sentences]

def get_bag_of_embeddings(tokens):
    embeddings = torch.cat([get_vecs_by_tokens(s).mean(axis=0).unsqueeze(0) for s in tokens], dim=0)
    return embeddings
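`get_bag_of_embeddings` represents each sentence by the average of its tokens' GloVe vectors, discarding word order. The sketch below mirrors that idea in NumPy with a hypothetical 3-word vocabulary and 2-dimensional vectors. It also makes one design choice the notebook code doesn't: a sentence with no known tokens maps to a zero vector instead of averaging an empty selection.

```python
import numpy as np

# Hypothetical toy vocabulary and 2-D vectors standing in for GloVe
key_to_index = {'stocks': 0, 'rose': 1, 'oil': 2}
vectors = np.array([[1.0, 0.0],
                    [0.0, 1.0],
                    [1.0, 1.0]])

def bag_of_embeddings(tokens, key_to_index, vectors):
    # look up known tokens only; out-of-vocabulary tokens are skipped
    ids = [key_to_index[t] for t in tokens if t in key_to_index]
    if not ids:
        # every token is unknown: fall back to a zero vector
        return np.zeros(vectors.shape[1])
    return vectors[ids].mean(axis=0)

print(bag_of_embeddings(['stocks', 'rose', 'zzzzz'], key_to_index, vectors))  # [0.5 0.5]
print(bag_of_embeddings(['zzzzz'], key_to_index, vectors))                    # [0. 0.]
```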

In [None]:
from torch.utils.data import DataLoader

batch_size = 32
unshuffled_dl = DataLoader(dataset=datasets['train'], batch_size=batch_size, shuffle=False)

for i, batch in enumerate(unshuffled_dl):
    labels, sentences = batch['topic'], batch['news']
    tokens = tokenize_batch(sentences)
    embeddings = get_bag_of_embeddings(tokens)
    ids = [f'{i:06}' for i in np.arange(i*batch_size, i*batch_size+len(sentences))]

    collection.add(embeddings=embeddings.tolist(),
                   documents=sentences,
                   metadatas=[{'label': v} for v in labels.tolist()],
                   ids=ids)

    if i == 300: # 301 batches of 32, i.e., 9,632 docs
        break

In [None]:
collection.count()

9632

### 15.7.2 Similarity Search

In [None]:
query_sentence = 'The company running the Japanese nuclear plant hit by a fatal accident is to close its reactors for safety checks.'
query_tokens = tokenize_batch([query_sentence])
query_embeddings = get_bag_of_embeddings(query_tokens)[0]

query_embeddings

tensor([ 4.0827e-01,  7.9920e-02,  3.1115e-01,  1.8721e-01, -4.7369e-02,
         3.3698e-01, -5.0617e-01, -3.6810e-02,  2.6068e-01, -1.2847e-01,
         2.2948e-01, -5.9424e-02, -3.3787e-01,  2.9188e-02,  2.6071e-01,
         2.0179e-01, -7.7526e-02,  3.3718e-01, -5.2526e-01, -2.7158e-01,
         3.7156e-01, -1.0214e-01, -1.1645e-01, -2.9637e-01,  7.9672e-02,
        -1.6904e+00, -4.1659e-02,  1.0523e-01,  2.8247e-01,  2.9835e-03,
         3.0708e+00, -1.5180e-01, -1.4941e-01, -2.3085e-01,  2.1777e-01,
        -4.0086e-02,  2.0281e-01,  8.9309e-02,  5.5554e-02,  2.6830e-02,
        -2.9215e-01, -1.5104e-01,  2.6449e-01, -1.0034e-01,  1.4842e-01,
         8.8036e-02, -2.3717e-01,  3.1029e-01,  1.2701e-02, -1.4513e-01])

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch13/query_db.png)

In [None]:
query_embeddings = query_embeddings.tolist()
collection.query(query_embeddings=query_embeddings, n_results=5)

{'ids': [['000030', '001046', '004715', '002464', '006905']],
 'distances': [[0.0,
   0.8038501739501953,
   0.9175586104393005,
   0.9644219875335693,
   0.9812381267547607]],
 'metadatas': [[{'label': 2},
   {'label': 0},
   {'label': 0},
   {'label': 0},
   {'label': 2}]],
 'embeddings': None,
 'documents': [['The company running the Japanese nuclear plant hit by a fatal accident is to close its reactors for safety checks.',
   'AP - The operator of a nuclear power plant where a long-neglected cooling pipe burst and killed four workers last week said Monday that four other pipes at its reactors also went unchecked for years.',
   'TOKYO The operators of a Japanese nuclear plant say there was no evidence of danger at the plant before a deadly explosion this month.',
   'Reuters - No more Japanese nuclear reactors need to be closed for inspections, electric power companies said on Wednesday after submitting reports ordered by the government following a reactor accident that killed fou

In [None]:
query_sentence = 'asian stock market'
query_tokens = tokenize_batch([query_sentence])
query_embeddings = get_bag_of_embeddings(query_tokens)[0]
query_embeddings = query_embeddings.tolist()

In [None]:
collection.query(query_embeddings=query_embeddings, n_results=5)

{'ids': [['007389', '006925', '004791', '007014', '006829']],
 'distances': [[5.2573628425598145,
   5.490711688995361,
   5.643772125244141,
   5.697542190551758,
   5.810704231262207]],
 'metadatas': [[{'label': 2},
   {'label': 2},
   {'label': 0},
   {'label': 0},
   {'label': 2}]],
 'embeddings': None,
 'documents': [['Asian stocks rose after oil prices fell from a record on Friday, easing concern higher energy costs will damp consumer spending and corporate profits.',
   'Asian stocks advanced after oil prices fell from a record Friday in New York, easing concern higher energy costs will damp consumer spending and corporate profits.',
   "AP - Tokyo's main stock index ended lower Friday amid profit-taking of technology issues and concerns about soaring oil prices. The U.S. dollar was down against the Japanese yen.",
   'Japanese stocks rose after oil prices fell from a record in New York on Friday, easing concern higher energy costs will damp consumer spending and corporate profi

***
**ASIDE: Cosine Similarity**

If two vectors point in the same direction, their cosine similarity is exactly one. If they are orthogonal (that is, there is a right angle between them), their cosine similarity is zero. If they point in opposite directions, their cosine similarity is minus one. In general, it is the dot product of the two vectors divided by the product of their lengths (Euclidean norms):

$$
\Large
\cos \theta = \frac{\sum_i{x_iy_i}}{\sqrt{\sum_j{x_j^2}}\sqrt{\sum_j{y_j^2}}}
$$
***
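A direct NumPy translation of the cosine similarity formula, checking the three cases described in the aside. The last lines verify a handy identity: for unit-length vectors, $\|u - v\|^2 = 2 - 2\cos\theta$, so ranking by Euclidean distance and ranking by cosine similarity agree once vectors are normalized.

```python
import numpy as np

def cosine_similarity(x, y):
    # dot product divided by the product of the vectors' Euclidean norms
    return float(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)))

a = np.array([1.0, 2.0])
print(cosine_similarity(a, 3 * a))                  # same direction
print(cosine_similarity(a, np.array([-2.0, 1.0])))  # orthogonal
print(cosine_similarity(a, -a))                     # opposite directions

# for unit vectors, squared Euclidean distance equals 2 - 2 * cosine similarity
u = a / np.linalg.norm(a)
v = np.array([3.0, 1.0]) / np.sqrt(10.0)
print(np.isclose(np.sum((u - v) ** 2), 2 - 2 * cosine_similarity(u, v)))  # True
```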


## 15.8 Zero-Shot Text Classification

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch0/model_step5.png)

In [None]:
cand_labels = ["world", "sports", "business", "science and technology"]

cand_emb = torch.vstack([get_vecs_by_tokens(tokens).mean(axis=0) for tokens in tokenize_batch(cand_labels)])
cand_emb.shape

torch.Size([4, 50])

In [None]:
cos = nn.CosineSimilarity(dim=2)

cos(cand_emb.unsqueeze(1), cand_emb.unsqueeze(0))

tensor([[1.0000, 0.6529, 0.6136, 0.6678],
        [0.6529, 1.0000, 0.6410, 0.6171],
        [0.6136, 0.6410, 1.0000, 0.8069],
        [0.6678, 0.6171, 0.8069, 1.0000]])
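Unsqueezing the tensors into shapes `(4, 1, 50)` and `(1, 4, 50)` lets broadcasting compare every row against every other row, producing a 4 x 4 matrix. An equivalent route is to normalize each row to unit length and take a matrix product. A NumPy sketch on random data checks that both agree (note that PyTorch's `nn.CosineSimilarity` additionally guards against division by zero with a small `eps`); taking the `argmax` per row is exactly the zero-shot prediction used next:

```python
import numpy as np

def pairwise_cosine(a, b):
    """Cosine similarity between every row of a (n, d) and every row of b (m, d) -> (n, m)."""
    a_unit = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_unit @ b_unit.T

def pairwise_cosine_broadcast(a, b):
    # same result via broadcasting (n, 1, d) against (1, m, d), like the unsqueeze calls
    num = (a[:, None, :] * b[None, :, :]).sum(axis=2)
    den = np.linalg.norm(a, axis=1)[:, None] * np.linalg.norm(b, axis=1)[None, :]
    return num / den

rng = np.random.default_rng(0)
sentence_embs = rng.normal(size=(3, 5))  # stand-ins for sentence embeddings
label_embs = rng.normal(size=(4, 5))     # stand-ins for candidate-label embeddings
sims = pairwise_cosine(sentence_embs, label_embs)
print(sims.shape)  # (3, 4)
print(np.allclose(sims, pairwise_cosine_broadcast(sentence_embs, label_embs)))  # True
print(sims.argmax(axis=1))  # most similar label for each sentence
```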

In [None]:
batch_size = 32
dataloader = DataLoader(dataset=datasets['test'], batch_size=batch_size, shuffle=False)

batch = next(iter(dataloader))
labels, sentences = batch['topic'], batch['news']
tokens = tokenize_batch(sentences)
embeddings = get_bag_of_embeddings(tokens)
similarities = cos(embeddings.unsqueeze(1), cand_emb.unsqueeze(0))
similarities

tensor([[0.6534, 0.5361, 0.7945, 0.7030],
        [0.7634, 0.6130, 0.7007, 0.7579],
        [0.6658, 0.5156, 0.7244, 0.8548],
        [0.7275, 0.5425, 0.6800, 0.7176],
        [0.7058, 0.5120, 0.7004, 0.7567],
        [0.7184, 0.5626, 0.7428, 0.7919],
        [0.7056, 0.5410, 0.7269, 0.8026],
        [0.6613, 0.5637, 0.7525, 0.8029],
        [0.5893, 0.4821, 0.6439, 0.6403],
        [0.7183, 0.5389, 0.7209, 0.7500],
        [0.6883, 0.6081, 0.8330, 0.8441],
        [0.6728, 0.6266, 0.8165, 0.8166],
        [0.7453, 0.6166, 0.8239, 0.8035],
        [0.7400, 0.4860, 0.6221, 0.7080],
        [0.6582, 0.4579, 0.6692, 0.7232],
        [0.6819, 0.4262, 0.6630, 0.7159],
        [0.6154, 0.4228, 0.5796, 0.6999],
        [0.7636, 0.5894, 0.7239, 0.8140],
        [0.6990, 0.4981, 0.6911, 0.8043],
        [0.6747, 0.4536, 0.7283, 0.6745],
        [0.7615, 0.5422, 0.7594, 0.7675],
        [0.6438, 0.5324, 0.7470, 0.7863],
        [0.5911, 0.4596, 0.7293, 0.8482],
        [0.7202, 0.5606, 0.7810, 0

In [None]:
predicted_class = similarities.argmax(dim=1)
predicted_class

tensor([2, 0, 3, 0, 3, 3, 3, 3, 2, 3, 3, 3, 2, 0, 3, 3, 3, 3, 3, 2, 3, 3, 3, 2,
        2, 3, 0, 3, 3, 0, 3, 0])

In [None]:
(predicted_class == labels).float().mean()

tensor(0.5625)

### 15.8.1 Evaluation

In [None]:
import evaluate

metric1 = evaluate.load('precision', average=None)
metric2 = evaluate.load('recall', average=None)
metric3 = evaluate.load('accuracy')

In [None]:
for batch in dataloader:
    labels, sentences = batch['topic'], batch['news']
    tokens = tokenize_batch(sentences)
    embeddings = get_bag_of_embeddings(tokens)

    # predictions = model(embeddings)
    predictions = cos(embeddings.unsqueeze(1), cand_emb.unsqueeze(0))

    pred_class = predictions.argmax(dim=1).tolist()
    labels = labels.tolist()

    metric1.add_batch(references=labels, predictions=pred_class)
    metric2.add_batch(references=labels, predictions=pred_class)
    metric3.add_batch(references=labels, predictions=pred_class)

In [None]:
metric1.compute(average=None), metric2.compute(average=None), metric3.compute()

({'precision': array([0.33205619, 1.        , 0.67253045, 0.43290471])},
 {'recall': array([4.10526316e-01, 5.26315789e-04, 7.84736842e-01, 6.91052632e-01])},
 {'accuracy': 0.47171052631578947})
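With `average=None`, precision and recall are reported per class. Precision for class `c` is the fraction of predictions of `c` that are correct; recall is the fraction of true `c` examples predicted as `c`. The sports numbers (index 1) fit this: a recall of 5.26e-04 is 1/1900, and a precision of 1.0 means the model predicted "sports" only once, correctly. A from-scratch sketch with made-up labels (our own helper, not the `evaluate` implementation):

```python
import numpy as np

def per_class_precision_recall(references, predictions, num_classes):
    refs, preds = np.asarray(references), np.asarray(predictions)
    precision, recall = np.zeros(num_classes), np.zeros(num_classes)
    for c in range(num_classes):
        tp = np.sum((preds == c) & (refs == c))
        predicted_c = np.sum(preds == c)
        actual_c = np.sum(refs == c)
        # convention used here: report 0 when a class is never predicted / never present
        precision[c] = tp / predicted_c if predicted_c else 0.0
        recall[c] = tp / actual_c if actual_c else 0.0
    return precision, recall

refs = [0, 0, 1, 1, 2, 3]
preds = [0, 3, 1, 3, 2, 3]
p, r = per_class_precision_recall(refs, preds, num_classes=4)
print(p)  # class 3 is over-predicted: low precision
print(r)  # classes 0 and 1 are under-predicted: low recall
```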

## 15.9 Chunking Strategies

In [None]:
text = """
ITEM 1A. RISK FACTORS Our operations and financial results are subject to various risks and uncertainties, including those described below, that could adversely affect our business, financial condition, results of operations, cash flows, and the trading price of our common stock. STRATEGIC AND COMPETITIVE RISKS We face intense competition across all markets for our products and services, which may lead to lower revenue or operating margins.    Competition in the technology sector Our competitors range in size from diversified global companies with significant research and development resources to small, specialized firms whose narrower product lines may let them be more effective in deploying technical, marketing, and financial resources. Barriers to entry in many of our businesses are low and many of the areas in which we compete evolve rapidly with changing and disruptive technologies, shifting user needs, and frequent introductions of new products and services. Our ability to remain competitive depends on our success in making innovative products, devices, and services that appeal to businesses and consumers.    Competition among platform-based ecosystems An important element of our business model has been to create platform-based ecosystems on which many participants can build diverse solutions. A well-established ecosystem creates beneficial network effects among users, application developers, and the platform provider that can accelerate growth. Establishing significant scale in the marketplace is necessary to achieve and maintain attractive margins. We face significant competition from firms that provide competing platforms.
"""

### 15.9.1 Fixed-Length

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=256, chunk_overlap=20)
chunks = text_splitter.create_documents([text])
chunks[:3]

[Document(page_content='ITEM 1A. RISK FACTORS Our operations and financial results are subject to various risks and uncertainties, including those described below, that could adversely affect our business, financial condition, results of operations, cash flows, and the trading', metadata={}),
 Document(page_content='and the trading price of our common stock. STRATEGIC AND COMPETITIVE RISKS We face intense competition across all markets for our products and services, which may lead to lower revenue or operating margins.    Competition in the technology sector Our', metadata={}),
 Document(page_content='sector Our competitors range in size from diversified global companies with significant research and development resources to small, specialized firms whose narrower product lines may let them be more effective in deploying technical, marketing, and', metadata={})]
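`RecursiveCharacterTextSplitter` tries a list of separators in order (by default `"\n\n"`, `"\n"`, `" "`, and `""`) so that chunks end at natural boundaries, which is why the chunks above end on whole words and share a few words of overlap. The simplest fixed-length strategy, sketched below as a hypothetical helper, ignores boundaries altogether and slides a window of `chunk_size` characters, with `chunk_overlap` characters shared between neighbors:

```python
def fixed_length_chunks(text, chunk_size=256, chunk_overlap=20):
    """Character windows of at most chunk_size; consecutive windows share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError('chunk_overlap must be smaller than chunk_size')
    step = chunk_size - chunk_overlap
    chunks = []
    start = 0
    while True:
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
        start += step
    return chunks

print(fixed_length_chunks('abcdefghij', chunk_size=4, chunk_overlap=1))  # ['abcd', 'defg', 'ghij']
```

Unlike the LangChain splitter, this version happily cuts words in half, which is the main reason separator-aware splitters are usually preferred.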

### 15.9.2 Content-Aware

In [None]:
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /home/dvgodoy/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
from nltk.tokenize import sent_tokenize

chunks = sent_tokenize(text)
chunks[:3]

['\nITEM 1A.',
 'RISK FACTORS Our operations and financial results are subject to various risks and uncertainties, including those described below, that could adversely affect our business, financial condition, results of operations, cash flows, and the trading price of our common stock.',
 'STRATEGIC AND COMPETITIVE RISKS We face intense competition across all markets for our products and services, which may lead to lower revenue or operating margins.']
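Once the text is split into sentences, a common content-aware approach is to pack consecutive whole sentences into chunks up to a size limit. For self-containment, the sketch below uses a naive regular-expression sentence splitter instead of punkt. It would also split after `ITEM 1A.`, as punkt did above, but unlike punkt's pretrained model it would stumble on abbreviations such as `U.S.`. Both helpers are our own:

```python
import re

def split_sentences(text):
    # naive rule: a sentence ends at '.', '!' or '?' followed by whitespace
    return [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]

def pack_sentences(sentences, max_chars=300):
    """Greedily group whole sentences into chunks of at most max_chars (a longer sentence stays alone)."""
    chunks, current = [], ''
    for sentence in sentences:
        candidate = f'{current} {sentence}'.strip()
        if current and len(candidate) > max_chars:
            chunks.append(current)  # adding this sentence would overflow: close the chunk
            current = sentence
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks

sample = 'Stocks rose. Oil fell! Did bonds move? Analysts were unsure.'
sentences = split_sentences(sample)
chunks = pack_sentences(sentences, max_chars=30)
print(chunks)  # ['Stocks rose. Oil fell!', 'Did bonds move?', 'Analysts were unsure.']
```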

### 15.9.3 Custom

In [None]:
# split on double spaces; the runs of four spaces in the text also produce empty strings ('')
chunks = text.split('  ')
chunks[:3]

['\nITEM 1A. RISK FACTORS Our operations and financial results are subject to various risks and uncertainties, including those described below, that could adversely affect our business, financial condition, results of operations, cash flows, and the trading price of our common stock. STRATEGIC AND COMPETITIVE RISKS We face intense competition across all markets for our products and services, which may lead to lower revenue or operating margins.',
 '',
 'Competition in the technology sector Our competitors range in size from diversified global companies with significant research and development resources to small, specialized firms whose narrower product lines may let them be more effective in deploying technical, marketing, and financial resources. Barriers to entry in many of our businesses are low and many of the areas in which we compete evolve rapidly with changing and disruptive technologies, shifting user needs, and frequent introductions of new products and services. Our ability