In [2]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import BertModel, BertTokenizerFast


In [3]:
model = BertModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float32, attn_implementation="sdpa")
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [4]:
param_size = 0
for param in model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 417.649MB


In [5]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [6]:
tokenizer

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

## Wikipedia dataset

In [7]:
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'


In [24]:
txt1 = 'Moscow calling'
txt2 = "pekin calling"
txt2 = "Let's go"
txt3 = 'calling moscow'
def get_emb(txt: str):
    toks = tokenizer(txt)
    inp = torch.tensor(toks['input_ids'])
    out = model(inp.unsqueeze(0))
    # key = 'last_hidden_state'
    # return out[key].squeeze()[0]
    key = 'pooler_output'
    return out[key].squeeze()

emb1 = get_emb(txt1)
emb2 = get_emb(txt2)
emb3 = get_emb(txt3)

def dist(x, y):
    # print(x.shape)
    # return torch.sum(x * y) / torch.norm(x) / torch.norm(y)
    d = x - y
    return torch.sqrt(torch.sum(d * d))

print(dist(emb1, emb2))
print(dist(emb1, emb3))
print(dist(emb2, emb3))


tensor(10.4493, grad_fn=<SqrtBackward0>)
tensor(4.7394, grad_fn=<SqrtBackward0>)
tensor(14.1465, grad_fn=<SqrtBackward0>)


RuntimeError: a Tensor with 768 elements cannot be converted to Scalar

In [27]:
inp = torch.tensor(toks['input_ids'])
out = model(inp.unsqueeze(0))
print(out)

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 2.9144e-03,  3.9209e-01,  4.2664e-02,  ..., -1.9043e-01,
           1.9202e-01,  6.1035e-01],
         [ 6.1328e-01,  1.9446e-01,  1.4746e-01,  ...,  1.5125e-01,
           1.1514e+00,  8.3618e-03],
         [-2.0581e-01,  4.9561e-01,  1.1243e-01,  ..., -2.5879e-01,
           4.8877e-01,  4.0283e-01],
         ...,
         [-3.5425e-01,  2.0157e-02,  1.2979e+00,  ...,  6.1719e-01,
           5.4834e-01,  5.4688e-01],
         [-3.6888e-03,  7.3389e-01,  2.8931e-01,  ...,  7.4414e-01,
           1.6663e-01,  1.0357e-03],
         [ 3.7109e-01,  5.2393e-01, -4.7607e-01,  ...,  4.4525e-02,
          -4.9097e-01, -3.6255e-01]]], dtype=torch.float16,
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-0.7397, -0.4351, -0.8726,  0.6460,  0.6240, -0.1565,  0.7266,  0.2549,
         -0.5791, -1.0000, -0.3198,  0.9307,  0.9688,  0.4231,  0.8350, -0.7075,
         -0.2744, -0.5811,  0.3804,  0.1144,  0.55

In [29]:
lhs, pout = out['last_hidden_state'], out['pooler_output']
lhs.shape, pout.shape

(torch.Size([1, 13, 768]), torch.Size([1, 768]))

In [31]:
print(lhs[0, 0, :10])
print(pout[0, :10])

tensor([ 0.0029,  0.3921,  0.0427, -0.1011, -0.3064, -0.0775,  0.3308,  0.2949,
        -0.3630,  0.1326], dtype=torch.float16, grad_fn=<SliceBackward0>)
tensor([-0.7397, -0.4351, -0.8726,  0.6460,  0.6240, -0.1565,  0.7266,  0.2549,
        -0.5791, -1.0000], dtype=torch.float16, grad_fn=<SliceBackward0>)
