# Chapter 15: Word Embeddings and Text Classification

In [3]:
!pip install transformers evaluate portalocker chromadb langchain

## 15.4 AG News Dataset

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch0/data_step1.png)

In [6]:
!wget https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv
!wget https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv
!wget https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/classes.txt

### 15.4.1 Data Cleaning

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch0/data_step2.png)

In [7]:
import numpy as np

chr_codes = np.array([
     36,   151,    38,  8220,   147,   148,   146,   225,   133,    39,  8221,  8212,   232,   149,   145,   233,
  64257,  8217,   163,   160,    91,    93,  8211,  8482,   234,    37,  8364,   153,   195,   169
])
chr_subst = {f' #{c};':chr(c) for c in chr_codes}
chr_subst.update({' amp;': '&', ' quot;': "'", ' hellip;': '...', ' nbsp;': ' ', '&lt;': '', '&gt;': '',
                  '&lt;em&gt;': '', '&lt;/em&gt;': '', '&lt;strong&gt;': '', '&lt;/strong&gt;': ''})

In [8]:
def replace_chars(sent):
    to_replace = [c for c in list(chr_subst.keys()) if c in sent]
    for c in to_replace:
        sent = sent.replace(c, chr_subst[c])
    return sent

def preproc_description(desc):
    desc = desc.replace('\\', ' ').strip()
    return replace_chars(desc)

### 15.4.2 DataPipes

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch0/data_step4.png)

In [9]:
from torchdata.datapipes.iter import FileLister
from torch.utils.data import DataLoader

def create_raw_datapipe(fname):
    datapipe = FileLister(root='.')
    datapipe = datapipe.filter(filter_fn=lambda v: v.endswith(fname))
    datapipe = datapipe.open_files(mode='rt', encoding="utf-8")
    datapipe = datapipe.parse_csv(delimiter=",", skip_lines=0)
    datapipe = datapipe.map(lambda row: (int(row[0])-1, preproc_description(row[2])))
    return datapipe

In [10]:
datapipes = {}
datapipes['train'] = create_raw_datapipe('train.csv')
datapipes['test'] = create_raw_datapipe('test.csv')

In [11]:
batch = next(iter(DataLoader(dataset=datapipes['train'], batch_size=4)))
labels, descriptions = batch
labels, descriptions

(tensor([2, 2, 2, 2]),
 ("Reuters - Short-sellers, Wall Street's dwindling band of ultra-cynics, are seeing green again.",
  'Reuters - Private investment firm Carlyle Group, which has a reputation for making well-timed and occasionally controversial plays in the defense industry, has quietly placed its bets on another part of the market.',
  'Reuters - Soaring crude prices plus worries about the economy and the outlook for earnings are expected to hang over the stock market next week during the depth of the summer doldrums.',
  'Reuters - Authorities have halted oil export flows from the main pipeline in southern Iraq after intelligence showed a rebel militia could strike infrastructure, an oil official said on Saturday.'))

## 15.5 Tokenization

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch0/data_step3.png)

In [12]:
from torchtext.data import get_tokenizer
tokenizer = get_tokenizer("basic_english")

tokens = tokenizer(descriptions[0])
descriptions[0], tokens

("Reuters - Short-sellers, Wall Street's dwindling band of ultra-cynics, are seeing green again.",
 ['reuters',
  '-',
  'short-sellers',
  ',',
  'wall',
  'street',
  "'",
  's',
  'dwindling',
  'band',
  'of',
  'ultra-cynics',
  ',',
  'are',
  'seeing',
  'green',
  'again',
  '.'])

### 15.5.1 Vocabulary

In [13]:
from torchtext.vocab import build_vocab_from_iterator

def yield_tokens(datapipe):
    for label, description in datapipe:
        yield tokenizer(description)

vocab = build_vocab_from_iterator(yield_tokens(datapipes['train']))
vocab

Vocab()

In [14]:
len(vocab)

78147

In [15]:
vocab.lookup_indices(['reuters', 'press', 'washington', 'knicks', 'sox', 'raccoon', 'duck'])

[28, 389, 116, 3836, 356, 43281, 14731]

In [16]:
vocab['reuters']

28

### 15.5.2 Transform

In [17]:
from torchtext.transforms import VocabTransform

vocab_transform = VocabTransform(vocab)
vocab_transform

VocabTransform(
  (vocab): Vocab()
)

In [18]:
tokens = tokenizer(descriptions[0])
token_ids = vocab_transform(tokens)
tokens, token_ids

(['reuters',
  '-',
  'short-sellers',
  ',',
  'wall',
  'street',
  "'",
  's',
  'dwindling',
  'band',
  'of',
  'ultra-cynics',
  ',',
  'are',
  'seeing',
  'green',
  'again',
  '.'],
 [28,
  11,
  44045,
  2,
  409,
  323,
  8,
  9,
  10941,
  3065,
  5,
  45321,
  2,
  36,
  3599,
  807,
  412,
  0])

In [19]:
vocab_transform(['anteater'])

RuntimeError: ignored

In [20]:
vocab.set_default_index(-1)

In [21]:
vocab_transform(['anteater', 'zzzzz'])

[-1, -1]

### 15.5.3 Special Tokens

In [22]:
vocab = build_vocab_from_iterator(yield_tokens(datapipes['train']), specials=['<unk>', '<pad>', '<sep>', '<cls>'])
vocab_transform = VocabTransform(vocab)

#### 15.5.3.1 `<UNK>`: Unknown Token

In [23]:
vocab['<unk>']

0

In [24]:
vocab.set_default_index(vocab['<unk>'])

In [25]:
vocab['anteater']

0

#### 15.5.3.2 `<PAD>`: Padding Token

In [26]:
tokens = [tokenizer(desc) for desc in descriptions]
token_ids = vocab_transform(tokens)
[len(t) for t in token_ids]

[18, 36, 33, 32]

In [27]:
vocab['<pad>']

1

In [28]:
from torchtext.transforms import ToTensor

padded_token_ids = ToTensor(padding_value=vocab['<pad>'])(token_ids)
padded_token_ids

tensor([[   32,    15, 44049,     6,   413,   327,    12,    13, 10945,  3069,
             9, 45325,     6,    40,  3603,   811,   416,     4,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1],
        [   32,    15,   858,   736,   331, 19149,    95,     6,    86,    28,
             7,  3880,    16,   506, 45752,    11, 14956,  1279,  2718,    10,
             5,   527,   220,     6,    28,  3452,  2059,    22,  8228,    14,
           204,   301,     9,     5,   125,     4],
        [   32,    15,  2124,   455,   105,  1671,  1454,    67,     5,   356,
            11,     5,   997,    16,   298,    40,   215,     8,  6362,    41,
             5,   302,   125,   106,    81,   178,     5,  7835,     9,     5,
          1046, 15411,     4,     1,     1,     1],
        [   32,    15,   674,    37,  5100,    83,  3735,  7986,    31,     5,
           743,  3144,    10,   466,    96,    36,  14

#### 15.5.3.3 `<SEP>`: Separation Token

In [29]:
vocab['<sep>']

2

In [30]:
from torchtext.transforms import AddToken

added_token_ids = AddToken(token=vocab['<sep>'], begin=False)(token_ids)
ToTensor(padding_value=vocab['<pad>'])(added_token_ids)

tensor([[   32,    15, 44049,     6,   413,   327,    12,    13, 10945,  3069,
             9, 45325,     6,    40,  3603,   811,   416,     4,     2,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1],
        [   32,    15,   858,   736,   331, 19149,    95,     6,    86,    28,
             7,  3880,    16,   506, 45752,    11, 14956,  1279,  2718,    10,
             5,   527,   220,     6,    28,  3452,  2059,    22,  8228,    14,
           204,   301,     9,     5,   125,     4,     2],
        [   32,    15,  2124,   455,   105,  1671,  1454,    67,     5,   356,
            11,     5,   997,    16,   298,    40,   215,     8,  6362,    41,
             5,   302,   125,   106,    81,   178,     5,  7835,     9,     5,
          1046, 15411,     4,     2,     1,     1,     1],
        [   32,    15,   674,    37,  5100,    83,  3735,  7986,    31,     5,
           743,  3144,    10,   4

#### 15.5.3.4 `<CLS>`: Classification Token

In [31]:
added_token_ids = AddToken(token=vocab['<cls>'], begin=True)(added_token_ids)
ToTensor(padding_value=vocab['<pad>'])(added_token_ids)

tensor([[    3,    32,    15, 44049,     6,   413,   327,    12,    13, 10945,
          3069,     9, 45325,     6,    40,  3603,   811,   416,     4,     2,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1],
        [    3,    32,    15,   858,   736,   331, 19149,    95,     6,    86,
            28,     7,  3880,    16,   506, 45752,    11, 14956,  1279,  2718,
            10,     5,   527,   220,     6,    28,  3452,  2059,    22,  8228,
            14,   204,   301,     9,     5,   125,     4,     2],
        [    3,    32,    15,  2124,   455,   105,  1671,  1454,    67,     5,
           356,    11,     5,   997,    16,   298,    40,   215,     8,  6362,
            41,     5,   302,   125,   106,    81,   178,     5,  7835,     9,
             5,  1046, 15411,     4,     2,     1,     1,     1],
        [    3,    32,    15,   674,    37,  5100,    83,  3735,  7986,    31,
            

### 15.5.4 Truncation

In [32]:
from torchtext.transforms import Truncate

truncated_token_ids = Truncate(max_seq_len=254)(token_ids)

### 15.5.5 Sequential

In [33]:
get_tokenizer('basic_english')

<function torchtext.data.utils._basic_english_normalize(line)>

In [34]:
from torchtext.transforms import Sequential as TextSequential

transform_fn = TextSequential(vocab_transform,
                              Truncate(max_seq_len=254),
                              AddToken(token=vocab['<cls>'], begin=True),
                              AddToken(token=vocab['<sep>'], begin=False))
transform_fn

Sequential(
  (0): VocabTransform(
    (vocab): Vocab()
  )
  (1): Truncate()
  (2): AddToken()
  (3): AddToken()
)

In [35]:
tokens = [tokenizer(d) for d in descriptions]
ToTensor(padding_value=vocab['<pad>'])(transform_fn(tokens))

tensor([[    3,    32,    15, 44049,     6,   413,   327,    12,    13, 10945,
          3069,     9, 45325,     6,    40,  3603,   811,   416,     4,     2,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1],
        [    3,    32,    15,   858,   736,   331, 19149,    95,     6,    86,
            28,     7,  3880,    16,   506, 45752,    11, 14956,  1279,  2718,
            10,     5,   527,   220,     6,    28,  3452,  2059,    22,  8228,
            14,   204,   301,     9,     5,   125,     4,     2],
        [    3,    32,    15,  2124,   455,   105,  1671,  1454,    67,     5,
           356,    11,     5,   997,    16,   298,    40,   215,     8,  6362,
            41,     5,   302,   125,   106,    81,   178,     5,  7835,     9,
             5,  1046, 15411,     4,     2,     1,     1,     1],
        [    3,    32,    15,   674,    37,  5100,    83,  3735,  7986,    31,
            

### 15.5.6 Tokenizers

#### 15.5.6.1 BERTTokenizer

In [36]:
from torchtext.transforms import BERTTokenizer

VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"

tokenizer = BERTTokenizer(vocab_path=VOCAB_FILE, do_lower_case=True, return_tokens=True)

100%|██████████| 232k/232k [00:00<00:00, 17.7MB/s]


In [37]:
tokenizer(descriptions[0])

['reuters',
 '-',
 'short',
 '-',
 'sellers',
 ',',
 'wall',
 'street',
 "'",
 's',
 'd',
 '##wind',
 '##ling',
 'band',
 'of',
 'ultra',
 '-',
 'cy',
 '##nic',
 '##s',
 ',',
 'are',
 'seeing',
 'green',
 'again',
 '.']

In [38]:
import requests

resp = requests.get(VOCAB_FILE)
vocab_txt = resp.content
vocab_list = vocab_txt.decode().split('\n')
len(vocab_list)

30523

#### 15.5.6.2 GPT2BPETokenizer

In [39]:
from torchtext.transforms import GPT2BPETokenizer

# https://github.com/facebookresearch/fairseq/blob/8deb43af8c54d6840e5ba6e057acf715c4491f9c/fairseq/data/encoders/gpt2_bpe.py#L15
VOCAB_FILE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
ENCODER_FILE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
tokenizer = GPT2BPETokenizer(vocab_bpe_path=VOCAB_FILE, encoder_json_path=ENCODER_FILE)

tokenizer(descriptions[0])

1.04MB [00:00, 50.0MB/s]
456kB [00:00, 22.5MB/s]


['12637',
 '532',
 '10073',
 '12',
 '7255',
 '364',
 '11',
 '5007',
 '3530',
 '338',
 '45215',
 '4097',
 '286',
 '14764',
 '12',
 '948',
 '77',
 '873',
 '11',
 '389',
 '4379',
 '4077',
 '757',
 '13']

In [40]:
import torchtext

roberta_base = torchtext.models.ROBERTA_BASE_ENCODER
roberta_transform_fn = roberta_base.transform()
roberta_transform_fn

100%|██████████| 1.04M/1.04M [00:00<00:00, 3.35MB/s]
100%|██████████| 456k/456k [00:00<00:00, 1.83MB/s]
Downloading: "https://download.pytorch.org/models/text/roberta.vocab.pt" to /root/.cache/torch/hub/checkpoints/roberta.vocab.pt
100%|██████████| 726k/726k [00:00<00:00, 19.3MB/s]


Sequential(
  (0): GPT2BPETokenizer()
  (1): VocabTransform(
    (vocab): Vocab()
  )
  (2): Truncate()
  (3): AddToken()
  (4): AddToken()
)

#### 15.5.6.3 CLIPTokenizer

In [41]:
from torchtext.transforms import CLIPTokenizer

MERGES_FILE = "http://download.pytorch.org/models/text/clip_merges.bpe"
ENCODER_FILE = "http://download.pytorch.org/models/text/clip_encoder.json"

tokenizer = CLIPTokenizer(merges_path=MERGES_FILE, encoder_json_path=ENCODER_FILE)

tokenizer(descriptions[0])

100%|██████████| 525k/525k [00:00<00:00, 1.67MB/s]
100%|██████████| 862k/862k [00:00<00:00, 2.63MB/s]


['15569',
 '268',
 '3005',
 '268',
 '16562',
 '267',
 '2569',
 '2012',
 '568',
 '67',
 '6812',
 '1358',
 '1963',
 '539',
 '8118',
 '268',
 '14324',
 '1324',
 '267',
 '631',
 '3214',
 '1901',
 '1495',
 '269']

## 15.6 Embeddings

In [42]:
import torch.nn as nn

emb_dims = 50
embeddings = nn.Embedding(len(vocab), emb_dims)
embeddings

Embedding(78151, 50)

In [43]:
import torch

idx = torch.as_tensor([vocab['reuters']])
idx, embeddings(idx)

(tensor([32]),
 tensor([[-4.4208e-01, -5.5386e-01,  2.6806e-01, -1.4883e+00,  1.1500e+00,
           6.3705e-02,  2.2030e+00, -6.9488e-01, -1.3200e+00, -4.8496e-01,
          -2.3973e+00,  1.5396e+00, -4.2997e-01,  3.4996e-01,  1.4332e+00,
          -7.5786e-01, -7.7373e-01, -6.1642e-01,  1.7552e+00,  7.9116e-01,
           1.1501e+00, -4.9157e-01, -1.0165e+00,  3.8830e-02, -8.3259e-04,
           3.6193e-01,  1.9510e+00,  1.3320e-01, -1.1226e+00,  7.0744e-01,
          -9.0923e-01, -4.0862e-01, -6.0121e-01,  3.4464e-01,  3.8744e+00,
           1.1284e+00,  1.5116e-01, -3.0192e-02,  1.9442e+00,  9.8362e-01,
           4.3002e-01,  1.7529e-01, -1.0816e+00, -1.2829e+00, -1.2346e+00,
           1.4199e+00,  1.4938e+00, -7.9703e-02,  5.7981e-01, -1.0756e-01]],
        grad_fn=<EmbeddingBackward0>))

In [44]:
embeddings.weight[idx]

tensor([[-4.4208e-01, -5.5386e-01,  2.6806e-01, -1.4883e+00,  1.1500e+00,
          6.3705e-02,  2.2030e+00, -6.9488e-01, -1.3200e+00, -4.8496e-01,
         -2.3973e+00,  1.5396e+00, -4.2997e-01,  3.4996e-01,  1.4332e+00,
         -7.5786e-01, -7.7373e-01, -6.1642e-01,  1.7552e+00,  7.9116e-01,
          1.1501e+00, -4.9157e-01, -1.0165e+00,  3.8830e-02, -8.3259e-04,
          3.6193e-01,  1.9510e+00,  1.3320e-01, -1.1226e+00,  7.0744e-01,
         -9.0923e-01, -4.0862e-01, -6.0121e-01,  3.4464e-01,  3.8744e+00,
          1.1284e+00,  1.5116e-01, -3.0192e-02,  1.9442e+00,  9.8362e-01,
          4.3002e-01,  1.7529e-01, -1.0816e+00, -1.2829e+00, -1.2346e+00,
          1.4199e+00,  1.4938e+00, -7.9703e-02,  5.7981e-01, -1.0756e-01]],
       grad_fn=<IndexBackward0>)

### 15.6.2 Embedding Arithmetic

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch13/embed_arithmetic.png)

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch13/synthetic_queen.png)

### 15.6.3 Global Vectors (GloVe)

In [45]:
from torchtext.vocab import GloVe

GloVe.url

{'42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip',
 '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip',
 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip',
 '6B': 'http://nlp.stanford.edu/data/glove.6B.zip'}

In [46]:
import os

new_locations = {key: os.path.join('https://huggingface.co/stanfordnlp/glove/resolve/main',
                                   os.path.split(GloVe.url[key])[-1]) for key in GloVe.url.keys()}
new_locations

{'42B': 'https://huggingface.co/stanfordnlp/glove/resolve/main/glove.42B.300d.zip',
 '840B': 'https://huggingface.co/stanfordnlp/glove/resolve/main/glove.840B.300d.zip',
 'twitter.27B': 'https://huggingface.co/stanfordnlp/glove/resolve/main/glove.twitter.27B.zip',
 '6B': 'https://huggingface.co/stanfordnlp/glove/resolve/main/glove.6B.zip'}

In [47]:
GloVe.url = new_locations

In [48]:
vec = GloVe(name='6B', dim=50)

.vector_cache/glove.6B.zip: 862MB [00:15, 56.2MB/s]                           
100%|█████████▉| 400000/400001 [00:12<00:00, 32234.93it/s]


In [49]:
vec.vectors, vec.vectors.shape

(tensor([[ 0.4180,  0.2497, -0.4124,  ..., -0.1841, -0.1151, -0.7858],
         [ 0.0134,  0.2368, -0.1690,  ..., -0.5666,  0.0447,  0.3039],
         [ 0.1516,  0.3018, -0.1676,  ..., -0.3565,  0.0164,  0.1022],
         ...,
         [-0.7590, -0.4743,  0.4737,  ...,  0.7895, -0.0141,  0.6448],
         [ 0.0726, -0.5139,  0.4728,  ..., -0.1891, -0.5902,  0.5556],
         [ 0.0726, -0.5139,  0.4728,  ..., -0.1891, -0.5902,  0.5556]]),
 torch.Size([400001, 50]))

In [50]:
vec['reuters']

tensor([-0.1374, -0.2549,  1.8853,  0.1476,  0.6386, -0.6768, -1.1622, -0.2153,
         0.2598, -0.5288,  0.6668, -0.7675, -0.5273,  0.0666,  0.0766,  0.3274,
        -0.8025, -0.4955, -0.3739,  0.1126,  1.1671,  1.1508,  0.6180,  0.0795,
         0.1269, -0.0724, -1.2037, -0.2462, -0.7708,  0.7670,  1.2745, -0.1290,
         0.9989, -0.2673, -0.5754, -1.0151, -0.1428, -0.4382,  0.7658, -0.0088,
         1.2848,  0.0031,  0.1186, -0.3882, -0.2352, -0.9209, -0.5164,  1.5083,
         0.3646,  0.5991])

In [51]:
vec.stoi['reuters'], vec.itos[10851]

(10851, 'reuters')

In [52]:
vec['anteater']

tensor([ 1.3244, -0.3380, -0.7163,  2.3814,  0.2372,  1.2824,  0.4650, -0.2310,
         0.0327, -0.5052,  0.0911,  0.6865,  0.5613,  0.6078, -0.2925, -0.3512,
        -0.5515,  1.4059, -0.3041, -0.4593, -1.1025, -0.4290, -0.4530,  0.0071,
        -0.2118,  0.4694,  0.3863,  0.9646, -0.8679, -0.4496, -0.2790, -0.7240,
         1.3138,  0.8487, -0.9294, -0.2259, -0.7488, -0.8090,  0.1210, -0.5639,
         0.0885, -0.5298, -0.2664,  1.6615,  1.0241, -0.8384, -0.0942,  0.6270,
        -0.0036,  0.4311])

In [53]:
vec['zzzzzzz']

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])

In [54]:
embedded_tokens = vec.get_vecs_by_tokens(tokens[0])
embedded_tokens.shape

torch.Size([18, 50])

## 15.7 Vector Databases

### 15.7.1 ChromaDB

In [93]:
import chromadb

client = chromadb.PersistentClient(path="./agnews_db")

In [94]:
collection = client.create_collection("agnews_collection")

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch13/vector_db.png)

In [None]:
from torchtext.data import get_tokenizer

def tokenize_batch(sentences, tokenizer=None):
    if tokenizer is None:
        tokenizer = get_tokenizer('basic_english')
    
    return [tokenizer(s) for s in sentences]

def get_bag_of_embeddings(tokens, vec):
    embeddings = torch.cat([vec.get_vecs_by_tokens(s).mean(axis=0).unsqueeze(0) for s in tokens], dim=0)
    return embeddings

In [95]:
from torch.utils.data import DataLoader

batch_size = 32
unshuffled_dl = DataLoader(dataset=datapipes['train'], batch_size=batch_size, shuffle=False)

for i, batch in enumerate(unshuffled_dl):
    labels, sentences = batch
    tokens = tokenize_batch(sentences)
    embeddings = get_bag_of_embeddings(tokens, vec)
    ids = [f'{i:06}' for i in np.arange(i*batch_size, i*batch_size+len(sentences))]

    collection.add(embeddings=embeddings.tolist(),
                   documents=sentences,
                   metadatas=[{'label': v} for v in labels.tolist()],
                   ids=ids)

    if i == 300: # roughly 10k docs
        break

In [96]:
collection.count()

9632

### 15.7.2 Similarity Search

In [110]:
query_sentence = 'The company running the Japanese nuclear plant hit by a fatal accident is to close its reactors for safety checks.'
query_tokens = tokenize_batch([query_sentence])
query_embeddings = get_bag_of_embeddings(query_tokens, vec)[0]

query_embeddings

tensor([ 3.8694e-01,  1.0883e-01,  2.5127e-01,  1.8260e-01,  2.0508e-02,
         3.5670e-01, -5.0396e-01, -6.0610e-02,  2.2243e-01, -1.2400e-01,
         2.0367e-01, -6.9109e-02, -3.3572e-01,  3.0528e-02,  2.7079e-01,
         1.7575e-01, -7.7128e-02,  3.1292e-01, -5.4482e-01, -2.6725e-01,
         3.4033e-01, -8.0611e-02, -1.1263e-01, -2.7624e-01,  6.8221e-02,
        -1.7065e+00, -6.2789e-02,  1.1811e-01,  2.8549e-01, -3.8258e-04,
         3.1455e+00, -1.7113e-01, -1.7587e-01, -2.1735e-01,  2.1206e-01,
        -3.5065e-02,  2.0723e-01,  9.7491e-02,  9.0475e-02,  5.8104e-03,
        -2.5775e-01, -1.1576e-01,  2.1288e-01, -9.4534e-02,  1.2578e-01,
         7.7879e-02, -2.1271e-01,  2.6167e-01,  3.1803e-02, -1.1789e-01])

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch13/query_db.png)

In [111]:
query_embeddings = query_embeddings.tolist()
collection.query(query_embeddings=query_embeddings, n_results=5)

{'ids': [['000030', '001046', '004715', '002464', '006426']],
 'embeddings': None,
 'documents': [['The company running the Japanese nuclear plant hit by a fatal accident is to close its reactors for safety checks.',
   'AP - The operator of a nuclear power plant where a long-neglected cooling pipe burst and killed four workers last week said Monday that four other pipes at its reactors also went unchecked for years.',
   'TOKYO The operators of a Japanese nuclear plant say there was no evidence of danger at the plant before a deadly explosion this month.',
   'Reuters - No more Japanese nuclear reactors need to be closed for inspections, electric power companies said on Wednesday after submitting reports ordered by the government following a reactor accident that killed four workers last week.',
   'TEHRAN (Reuters) - Iran on Sunday announced a further  substantial delay in the long overdue project to complete its  first nuclear power plant, part of a program which Washington  says co

In [114]:
query_sentence = 'asian stock market'
query_tokens = tokenize_batch([query_sentence])
query_embeddings = get_bag_of_embeddings(query_tokens, vec)[0]
query_embeddings = query_embeddings.tolist()

In [115]:
collection.query(query_embeddings=query_embeddings, n_results=5)

{'ids': [['007389', '006925', '007014', '006829', '008197']],
 'embeddings': None,
 'documents': [['Asian stocks rose after oil prices fell from a record on Friday, easing concern higher energy costs will damp consumer spending and corporate profits.',
   'Asian stocks advanced after oil prices fell from a record Friday in New York, easing concern higher energy costs will damp consumer spending and corporate profits.',
   'Japanese stocks rose after oil prices fell from a record in New York on Friday, easing concern higher energy costs will damp consumer spending and corporate profits.',
   'Japanese stocks may rise after oil prices fell from a record in New York, easing concern higher energy costs will damp consumer spending and corporate profits.',
   "Reuters - Japan's Nikkei average was flat by late morning trade on Tuesday as falls in oil-related stocks offset gains in some exporters including auto makers after a retreat in oil prices eased fears about the global economy."]],
 'me

## 15.8 Zero-Shot Text Classification

![](https://raw.githubusercontent.com/dvgodoy/assets/main/PyTorchInPractice/images/ch0/model_step5.png)

In [160]:
cand_labels = ["world", "sports", "business", "science and technology"]

cand_emb = torch.vstack([vec.get_vecs_by_tokens(tokens).mean(axis=0) for tokens in tokenize_batch(cand_labels)])
cand_emb.shape

torch.Size([4, 50])

In [161]:
cos = nn.CosineSimilarity(dim=2)

cos(cand_emb.unsqueeze(1), cand_emb.unsqueeze(0))

tensor([[1.0000, 0.6529, 0.6136, 0.6678],
        [0.6529, 1.0000, 0.6410, 0.6171],
        [0.6136, 0.6410, 1.0000, 0.8069],
        [0.6678, 0.6171, 0.8069, 1.0000]])

In [None]:
batch_size = 32
dataloader = DataLoader(dataset=datapipes['test'], batch_size=batch_size, shuffle=False)

labels, sentences = next(iter(dataloader))
tokens = tokenize_batch(sentences)
embeddings = get_bag_of_embeddings(tokens, vec)
similarities = cos(embeddings.unsqueeze(1), cand_emb.unsqueeze(0))
similarities

In [163]:
predicted_class = similarities.argmax(dim=1)
predicted_class

tensor([2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3,
        2, 3, 0, 3, 3, 3, 3, 0])

In [164]:
(predicted_class == labels).float().mean()

tensor(0.7188)

### 15.8.1 Evaluation

In [165]:
import evaluate

metric1 = evaluate.load('precision', average=None)
metric2 = evaluate.load('recall', average=None)
metric3 = evaluate.load('accuracy')

In [166]:
for batch in dataloader:
    labels, sentences = batch
    tokens = tokenize_batch(sentences)
    embeddings = get_bag_of_embeddings(tokens, vec)

    # predictions = model(embeddings)
    predictions = cos(embeddings.unsqueeze(1), cand_emb.unsqueeze(0))

    pred_class = predictions.argmax(dim=1).tolist()
    labels = labels.tolist()

    metric1.add_batch(references=labels, predictions=pred_class)
    metric2.add_batch(references=labels, predictions=pred_class)
    metric3.add_batch(references=labels, predictions=pred_class)

In [167]:
metric1.compute(average=None), metric2.compute(average=None), metric3.compute()

  _warn_prf(average, modifier, msg_start, len(result))


({'precision': array([0.29004107, 0.        , 0.74021131, 0.38535741])},
 {'recall': array([0.29736842, 0.        , 0.62684211, 0.82      ])},
 {'accuracy': 0.43605263157894736})

## 15.9 Chunking Strategies

In [1]:
text = """
ITEM 1A. RISK FACTORS Our operations and financial results are subject to various risks and uncertainties, including those described below, that could adversely affect our business, financial condition, results of operations, cash flows, and the trading price of our common stock. STRATEGIC AND COMPETITIVE RISKS We face intense competition across all markets for our products and services, which may lead to lower revenue or operating margins.    Competition in the technology sector Our competitors range in size from diversified global companies with significant research and development resources to small, specialized firms whose narrower product lines may let them be more effective in deploying technical, marketing, and financial resources. Barriers to entry in many of our businesses are low and many of the areas in which we compete evolve rapidly with changing and disruptive technologies, shifting user needs, and frequent introductions of new products and services. Our ability to remain competitive depends on our success in making innovative products, devices, and services that appeal to businesses and consumers.    Competition among platform-based ecosystems An important element of our business model has been to create platform-based ecosystems on which many participants can build diverse solutions. A well-established ecosystem creates beneficial network effects among users, application developers, and the platform provider that can accelerate growth. Establishing significant scale in the marketplace is necessary to achieve and maintain attractive margins. We face significant competition from firms that provide competing platforms.
"""

### 15.9.1 Fixed-Length

In [6]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=256, chunk_overlap=20)
chunks = text_splitter.create_documents([text])
chunks[:3]

[Document(page_content='ITEM 1A. RISK FACTORS Our operations and financial results are subject to various risks and uncertainties, including those described below, that could adversely affect our business, financial condition, results of operations, cash flows, and the trading', metadata={}),
 Document(page_content='and the trading price of our common stock. STRATEGIC AND COMPETITIVE RISKS We face intense competition across all markets for our products and services, which may lead to lower revenue or operating margins.  Competition in the technology sector Our', metadata={}),
 Document(page_content='sector Our competitors range in size from diversified global companies with significant research and development resources to small, specialized firms whose narrower product lines may let them be more effective in deploying technical, marketing, and', metadata={})]

### 15.9.2 Content-Aware

In [7]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [8]:
from nltk.tokenize import sent_tokenize

chunks = sent_tokenize(text)
chunks[:3]

['ITEM 1A.',
 'RISK FACTORS Our operations and financial results are subject to various risks and uncertainties, including those described below, that could adversely affect our business, financial condition, results of operations, cash flows, and the trading price of our common stock.',
 'STRATEGIC AND COMPETITIVE RISKS We face intense competition across all markets for our products and services, which may lead to lower revenue or operating margins.']

### 15.9.3 Custom

In [9]:
chunks = text.split('  ')
chunks[:3]

['ITEM 1A. RISK FACTORS Our operations and financial results are subject to various risks and uncertainties, including those described below, that could adversely affect our business, financial condition, results of operations, cash flows, and the trading price of our common stock. STRATEGIC AND COMPETITIVE RISKS We face intense competition across all markets for our products and services, which may lead to lower revenue or operating margins.',
 'Competition in the technology sector Our competitors range in size from diversified global companies with significant research and development resources to small, specialized firms whose narrower product lines may let them be more effective in deploying technical, marketing, and financial resources. Barriers to entry in many of our businesses are low and many of the areas in which we compete evolve rapidly with changing and disruptive technologies, shifting user needs, and frequent introductions of new products and services. Our ability to rem