In [111]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [112]:
import os
from pathlib import Path
import sys
from typing import Optional, List, Dict, Any, Tuple
if '..' not in sys.path: sys.path.append('..')

from dataclasses import dataclass
from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn, dist
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer, AutoTokenizer, BatchEncoding

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.data.utils import RandomInputTokenizer, RandomInputTokenizerV2, TokensSubset, TokensSubsetV2, tokens_subsets_to_tensors, tokens_subsets_v2_to_tensors
from mllm.exp.args import ENCDEC_GRAPH_BERT_MODEL_CFG_FNAME
from mllm.model.encdec_ranker_hg import EncdecGraphBert
from mllm.config.model import EncdecGraphBertCfg


In [25]:
# DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'
DATA_PATH = Path('Q:/data')
# WIKI_DS_NAME = '20200501.en'
WIKI_DS_NAME = '20220301.en'

TRAIN_ENCDEC_GRAPH_BERT_PATH = DATA_PATH / 'train_mllm_encdec_graph_bert'
encdec_subdir = ''

encdec_train_path = TRAIN_ENCDEC_GRAPH_BERT_PATH / encdec_subdir
encdec_snapshot_fpath = encdec_train_path / 'best.pth'
encdec_graph_model_cfg_fpath = encdec_train_path / ENCDEC_GRAPH_BERT_MODEL_CFG_FNAME

device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


## Wikipeida dataset loading

In [26]:
# dss = load_dataset('wikipedia', WIKI_DS_NAME, beam_runner='DirectRunner', cache_dir=str(DATA_PATH))
dss = load_dataset('wikipedia', WIKI_DS_NAME, cache_dir=str(DATA_PATH), trust_remote_code=True)
ds: Dataset = dss['train']
n_docs = len(ds)
print(f'Wikipedia {WIKI_DS_NAME} docs: {n_docs}')

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

Wikipedia 20220301.en docs: 6458670


## Tokenizer tests

In [27]:
tkz = AutoTokenizer.from_pretrained('bert-base-uncased')
tkz

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [28]:
inds = [0, 10, 20, 30, 40]
ttls = [ds[i]['title'] for i in inds]
txts = [ds[i]['text'] for i in inds]
batch: BatchEncoding = tkz(txts, padding=False, truncation=False)
print(f'CLS: {tkz.cls_token}, {tkz.cls_token_id}. SEP: {tkz.sep_token}, {tkz.sep_token_id}. PAD: {tkz.pad_token}. {tkz.pad_token_id}.')
for i in range(len(inds)):
    print(f'[{inds[i]}] {ttls[i]}')
    print(txts[i][:200].replace('\n', '\\n'))
    tok_ids = batch['input_ids'][i]
    assert tok_ids[0] == tkz.cls_token_id and tok_ids[-1] == tkz.sep_token_id
    print(len(tok_ids), tok_ids[:10], tok_ids[-10:])
    print('---')


Token indices sequence length is longer than the specified maximum sequence length for this model (8351 > 512). Running this sequence through the model will result in indexing errors


CLS: [CLS], 101. SEP: [SEP], 102. PAD: [PAD]. 0.
[0] Anarchism
Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. Anarchism calls for the abolition of the state, which it holds
8351 [101, 9617, 11140, 2964, 2003, 1037, 2576, 4695, 1998, 2929] [8780, 21615, 2591, 8106, 14649, 2521, 1011, 2187, 4331, 102]
---
[10] Academy Awards
The Academy Awards, popularly known as the Oscars, are awards for artistic and technical merit in the film industry. They are regarded by many as the most prestigious and significant awards in the ent
9434 [101, 1996, 2914, 2982, 1010, 16071, 2124, 2004, 1996, 7436] [3349, 5365, 2381, 1998, 3226, 2137, 2444, 2547, 3065, 102]
---
[20] Anthropology
Anthropology is the scientific study of humanity, concerned with human behavior, human biology, cultures, societies, and linguistics, in both the present and past, including past human species. Social
9831 [101, 12795, 2003, 1996, 4045

In [29]:
batch2 = tkz(ttls, padding=True, truncation=True, max_length=4)
print(type(batch2['input_ids']))
batch2['input_ids']

<class 'list'>


[[101, 9617, 11140, 102],
 [101, 2914, 2982, 102],
 [101, 12795, 102, 0],
 [101, 16951, 15396, 102],
 [101, 15262, 102, 0]]

In [30]:
for toks in batch2['input_ids']:
    print(tkz.convert_ids_to_tokens(toks))

['[CLS]', 'ana', '##rch', '[SEP]']
['[CLS]', 'academy', 'awards', '[SEP]']
['[CLS]', 'anthropology', '[SEP]', '[PAD]']
['[CLS]', 'austro', '##asia', '[SEP]']
['[CLS]', 'ada', '[SEP]', '[PAD]']


## Random tokenzied substring

In [72]:
def print_batch(batch: List[TokensSubset], tkz: PreTrainedTokenizer):
    for i, toks_sub in enumerate(batch):
        toks_src_str = ' '.join(tkz.convert_ids_to_tokens(toks_sub.toks_src[:200]))
        toks_inp_str = ' '.join(tkz.convert_ids_to_tokens(toks_sub.toks_inp))
        toks_inp2_str = ' '.join(tkz.convert_ids_to_tokens(toks_sub.toks_src[toks_sub.inp_beg_ind:toks_sub.inp_end_ind]))
        print(f'[{i}] src ({len(toks_sub.toks_src)}): {toks_src_str}')
        print(f'     inp ({len(toks_sub.toks_inp)}): {toks_inp_str}')
        print(f'     inp2({len(toks_inp2_str)}): {toks_inp2_str}')
        if toks_sub.toks_cite is not None:
            toks_sub_str = ' '.join(tkz.convert_ids_to_tokens(toks_sub.toks_src[toks_sub.cite_beg_ind:toks_sub.cite_end_ind]))
            print(f'     cite ({len(toks_sub_str)}): {toks_sub_str}')
        print('---')

def print_batch_v2(batch: List[TokensSubsetV2], tkz: PreTrainedTokenizer):
    for i, toks_sub in enumerate(batch):
        toks_src_str = ' '.join(tkz.convert_ids_to_tokens(toks_sub.toks_src[:200]))
        toks_inp_str = ' '.join(tkz.convert_ids_to_tokens(toks_sub.toks_inp))
        toks_inp2_str = ' '.join(tkz.convert_ids_to_tokens(toks_sub.toks_src[toks_sub.inp_beg_ind:toks_sub.inp_end_ind]))
        print(f'[{i}] src ({len(toks_sub.toks_src)}): {toks_src_str}')
        print(f'    inp ({len(toks_sub.toks_inp)}): {toks_inp_str}')
        print(f'    inp2({len(toks_inp2_str)}): {toks_inp2_str}')

        toks_sub_str = ' '.join(tkz.convert_ids_to_tokens(toks_sub.toks_src[toks_sub.cite_beg_ind:toks_sub.cite_end_ind]))
        cite_beg = tkz.convert_ids_to_tokens(toks_sub.toks_cite_beg)
        cite_end = tkz.convert_ids_to_tokens(toks_sub.toks_cite_end)
        cite_beg_str, cite_end_str = ' '.join(cite_beg), ' '.join(cite_end)
        print(f'    special tokens begin, end: {toks_sub.toks_cite_beg} = "{cite_beg_str}", {toks_sub.toks_cite_end} = "{cite_end_str}"')
        print(f'    cite ({len(toks_sub_str)}): {toks_sub_str}')
        print(f'    prompt ({len(toks_sub.toks_prompt)}): ' + ' '.join(tkz.convert_ids_to_tokens(toks_sub.toks_prompt)))
        print('---')


In [31]:
tag_beg, tag_end = tkz(['[TAG_BEG]', '[TAG_END]'], add_special_tokens=False).input_ids
print(f'TAG_BEG: {tag_beg}, TAG_END: {tag_end}')
tag_beg, tag_end = tkz(['<cite>', '</cite>'], add_special_tokens=False).input_ids
print(f'TAG_BEG: {tag_beg}, TAG_END: {tag_end}')

TAG_BEG: [1031, 6415, 1035, 11693, 1033], TAG_END: [1031, 6415, 1035, 2203, 1033]
TAG_BEG: [1026, 21893, 1028], TAG_END: [1026, 1013, 21893, 1028]


In [33]:
rand_tkz = RandomInputTokenizer(tkz, max_len=20)
batch = rand_tkz(txts, n_items_to_cite=2)
print_batch(batch, tkz)


[0] src (8349): ana ##rch ##ism is a political philosophy and movement that is sc ##ept ##ical of authority and rejects all involuntary , coe ##rc ##ive forms of hierarchy . ana ##rch ##ism calls for the abolition of the state , which it holds to be unnecessary , und ##es ##ira ##ble , and harmful . as a historically left - wing movement , placed on the far ##thest left of the political spectrum , it is usually described alongside communal ##ism and libertarian marxism as the libertarian wing ( libertarian socialism ) of the socialist movement , and has a strong historical association with anti - capitalism and socialism . humans lived in societies without formal hi ##era ##rch ##ies long before the establishment of formal states , realms , or empires . with the rise of organised hierarchical bodies , sc ##ept ##ici ##sm toward authority also rose . although traces of anarchist thought are found throughout history , modern ana ##rch ##ism emerged from the enlightenment . during the lat

In [34]:
rand_tkz = RandomInputTokenizer(tkz, max_len=12)
batch_ttl = rand_tkz(ttls, n_items_to_cite=3)
print_batch(batch_ttl, tkz)

inp_toks_t, att_mask_t = tokens_subsets_to_tensors(batch_ttl, pad_token_id=tkz.pad_token_id, device=device)
inp_toks_t, att_mask_t

[0] src (3): ana ##rch ##ism
     inp (5): [CLS] ana ##rch ##ism [SEP]
     inp2(15): ana ##rch ##ism
---
[1] src (2): academy awards
     inp (11): [CLS] < cite > academy < / cite > awards [SEP]
     inp2(14): academy awards
     cite (7): academy
---
[2] src (1): anthropology
     inp (3): [CLS] anthropology [SEP]
     inp2(12): anthropology
---
[3] src (4): austro ##asia ##tic languages
     inp (12): [CLS] ##asia < cite > ##tic < / cite > languages [SEP]
     inp2(22): ##asia ##tic languages
     cite (5): ##tic
---
[4] src (1): ada
     inp (10): [CLS] < cite > ada < / cite > [SEP]
     inp2(3): ada
     cite (3): ada
---


(tensor([[  101,  9617, 11140,  2964,   102,     0,     0,     0,     0,     0,
              0,     0],
         [  101,  1026, 21893,  1028,  2914,  1026,  1013, 21893,  1028,  2982,
            102,     0],
         [  101, 12795,   102,     0,     0,     0,     0,     0,     0,     0,
              0,     0],
         [  101, 15396,  1026, 21893,  1028,  4588,  1026,  1013, 21893,  1028,
           4155,   102],
         [  101,  1026, 21893,  1028, 15262,  1026,  1013, 21893,  1028,   102,
              0,     0]]),
 tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]))

## Random tokenized substring v2

In [36]:
print(len(tkz))
tkz

30522


BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [44]:
i = 990
inds = list(range(i, i + 30))
' '.join(tkz.convert_ids_to_tokens(inds))

'[unused985] [unused986] [unused987] [unused988] [unused989] [unused990] [unused991] [unused992] [unused993] ! " # $ % & \' ( ) * + , - . / 0 1 2 3 4 5'

In [74]:
rand_tkz = RandomInputTokenizerV2(tkz, max_len=30, n_random_toks=3)
batch_ttl = rand_tkz(ttls)
print_batch_v2(batch_ttl, tkz)

[0] src (3): ana ##rch ##ism
    inp (11): [CLS] ##vance ##ₖ remarried ana ##rch ##ism ##ets email originated [SEP]
    inp2(15): ana ##rch ##ism
    special tokens begin, end: [21789, 30094, 19316] = "##vance ##ₖ remarried", [8454, 10373, 7940] = "##ets email originated"
    cite (15): ana ##rch ##ism
    prompt (36): [CLS] cite tag begin : " # # vance # # ₖ remarried " . cite tag end : " # # et ##s email originated " . produce output text between these tags . [SEP]
---
[1] src (2): academy awards
    inp (10): [CLS] mac ##nt tides academy awards technically ##kon contemporaries [SEP]
    inp2(14): academy awards
    special tokens begin, end: [6097, 3372, 22487] = "mac ##nt tides", [10892, 19648, 16682] = "technically ##kon contemporaries"
    cite (14): academy awards
    prompt (34): [CLS] cite tag begin : " mac # # nt tides " . cite tag end : " technically # # ko ##n contemporaries " . produce output text between these tags . [SEP]
---
[2] src (1): anthropology
    inp (9): [CLS] 

In [76]:
inp_toks_t, att_mask_t, edge_inds_t = tokens_subsets_v2_to_tensors(batch_ttl, tkz=tkz, device=device)
inp_toks_t, att_mask_t, edge_inds_t

(tensor([[  101, 21789, 30094, 19316,  9617, 11140,  2964,  8454, 10373,  7940,
            102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0],
         [  101,  6097,  3372, 22487,  2914,  2982, 10892, 19648, 16682,   102,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0],
         [  101, 10807, 25900, 13224, 12795,  6868, 13054, 25052,   102,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0],
         [  101, 28777, 27113, 15847, 16951, 15396,  4588,  4155,  8191, 15477,
           8670,   102,     0,     0,    

## Inference

In [None]:
model_cfg = parse_yaml_file_as(EncdecGraphBertCfg, encdec_graph_model_cfg_fpath)
tkz = AutoTokenizer.from_pretrained(model_cfg.enc_bert.pretrained_model_name)
print(model_cfg)
print(tkz)

enc_bert=EncBertCfg(inp_len=128, d_model=768, pad_token_id=0, pretrained_model_name='bert-base-uncased', tokenizer_name='bert-base-uncased', emb_type=<BertEmbType.Cls: 'cls'>, emb2_tok_name='') dec_pyr=DecPyrCfg(d_model=768, n_heads=12, d_k=64, d_v=64, d_inner=3072, inp_len=128, step=2, n_layers=7, dropout_rate=0.0, n_vocab=30522, n_similar_layers=1, enhance_type=<HgEnhanceType.MatmulBeginBias: 'mmbb'>, temperature=0.0)
BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]",

In [None]:
chkpt = torch.load(encdec_snapshot_fpath, map_location=device)
model = EncdecGraphBert(model_cfg).to(device)
strict = True
# strict = False
model.load_state_dict(chkpt['model'], strict=strict)
del chkpt
model.eval()
None

In [85]:
inp_len = model_cfg.enc_bert.inp_len
print('inp_len:', inp_len)

def get_batch_tokens(doc_inds: list[int], randomize: bool = False, mask_toks_len_max: int = 5) -> torch.Tensor:
    docs_toks = np.full((len(doc_inds), inp_len), tkz.pad_token_id)
    for i, doc_ind in enumerate(doc_inds):
        doc = ds[doc_ind]
        title, text = doc['title'], doc['text']
        doc_txt = f'{title} {text}'
        doc_txt = text
        doc_txt = doc_txt.lower()
        doc_toks: list[int] | np.ndarray = tkz(doc_txt)['input_ids']
        n_toks = len(doc_toks)
        if n_toks > inp_len:
            i_off = np.random.randint(1, n_toks - inp_len + 1) if randomize else 1
            doc_toks = np.concatenate([doc_toks[:1], doc_toks[i_off:i_off + inp_len - 2], doc_toks[-1:]])
            # print(doc_toks)

        mask_toks_len = np.random.randint(1, mask_toks_len_max + 1)
        mask_toks_off = np.random.randint(len(doc_toks) - mask_toks_len + 1)
        masked_toks = doc_toks[mask_toks_off:mask_toks_off + mask_toks_len]
        masked_substr = tkz.decode(masked_toks)
        print(f'{doc_ind:03d}. {masked_substr}')
        doc_toks[mask_toks_off:mask_toks_off + mask_toks_len] = tkz.mask_token_id

        docs_toks[i, :len(doc_toks)] = doc_toks
    docs_toks_t = torch.from_numpy(docs_toks).to(device)
    return docs_toks_t


inp_len: 128


In [86]:
doc_inds = np.arange(5)
# doc_inds += 5
doc_inds = [x.item() for x in doc_inds]
for doc_ind in doc_inds:
    doc = ds[doc_ind]
    title, text = doc['title'], doc['text'].replace('\n', '\\n')
    print(f'{doc_ind:03d} "{title}" {text[:400]}')

000 "Anarchism" Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian
001 "Autism" Autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. Parents often notice signs during the first three years of their child's life. These signs often develop gradually, though some autistic children experience regression in their communication and social skills after reaching developmental milest
002 "Albedo" Albedo (; ) is the measure of the diffuse reflection of solar radiation out of the total solar radiation and measured on a scale from 0, corresponding to a 

In [92]:
docs_toks_in = get_batch_tokens(doc_inds)
print(docs_toks_in[0])
print(docs_toks_in.shape)
logits_pred = model(docs_toks_in, docs_toks_in != tkz.pad_token_id)
probs_pred = torch.softmax(logits_pred, dim=-1)
# probs_pred = torch.sigmoid(logits_pred)
print(probs_pred.shape)
docs_toks_out = torch.argmax(probs_pred, dim=-1)
print(docs_toks_out.shape)

000. formal hierarch
001. first
002. albedo (
003. written
004. yellowhammer state,
tensor([  101,  9617, 11140,  2964,  2003,  1037,  2576,  4695,  1998,  2929,
         2008,  2003,  8040, 23606,  7476,  1997,  3691,  1998, 19164,  2035,
        26097,  1010, 24873, 11890,  3512,  3596,  1997, 12571,  1012,  9617,
        11140,  2964,  4455,  2005,  1996, 15766,  1997,  1996,  2110,  1010,
         2029,  2009,  4324,  2000,  2022, 14203,  1010,  6151,  2229,  7895,
         3468,  1010,  1998, 17631,  1012,  2004,  1037,  7145,  2187,  1011,
         3358,  2929,  1010,  2872,  2006,  1996,  2521, 20515,  2187,  1997,
         1996,  2576,  8674,  1010,  2009,  2003,  2788,  2649,  4077, 15029,
         2964,  1998, 19297, 27255,  2004,  1996, 19297,  3358,  1006, 19297,
        14649,  1007,  1997,  1996,  6102,  2929,  1010,  1998,  2038,  1037,
         2844,  3439,  2523,  2007,  3424,  1011, 16498,  1998, 14649,  1012,
         4286,  2973,  1999,  8384,  2302,   103,   103,  

In [93]:
for i, doc_ind in enumerate(doc_inds):
    s = tkz.decode(docs_toks_in[i])
    s = s.replace('\n', '\\n')
    print(f'{doc_ind:03d} {s}')

000 [CLS] anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. as a historically left - wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian marxism as the libertarian wing ( libertarian socialism ) of the socialist movement, and has a strong historical association with anti - capitalism and socialism. humans lived in societies without [MASK] [MASK] [MASK] [MASK]ies long before the establishment of formal states [SEP]
001 [CLS] autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. parents often notice signs during the [MASK] three years of their child ' s life. these signs often develop gradually, though some autistic 

In [94]:
for i, doc_ind in enumerate(doc_inds):
    s = tkz.decode(docs_toks_out[i])
    s = s.replace('\n', '\\n')
    print(f'{doc_ind:03d} {s}')

000 the, andrel [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
001 past two [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

## Encoder embedding evaluation

In [16]:
def get_tokens(txts: list[str]) -> torch.Tensor:
    batch_toks = np.full((len(txts), inp_len), tkz.pad_token_id)
    for i, txt in enumerate(txts):
        toks: list[int] = tkz(txt)['input_ids']
        n_toks = len(toks)
        if n_toks > inp_len:
            i_off = np.random.randint(n_toks - inp_len + 1)
            toks = toks[i_off:i_off + inp_len]
        batch_toks[i, :len(toks)] = toks
    batch_toks_t = torch.from_numpy(batch_toks).to(device)
    return batch_toks_t

model.eval()
None

In [17]:
txts = [
    '"Orana Australia Ltd" Orana Australia Ltd is a not-for-profit organisation that provides a diverse range of training and support services to over 650 people with disabilities and their families in South Australia.\n\nHistory\nThe Mentally Retarded Children’s Society of SA Inc. was established in 1950 by a group of parent',
    'Australia',
    'Orana Australia Ltd',
    'Hello Kitty',
]
batch_toks = get_tokens(txts)
embs = model.enc_bert(batch_toks, batch_toks != tkz.pad_token_id)
# embs = embs.detach().cpu().numpy()
embs = embs.detach().cpu()
print(embs.shape)

torch.Size([4, 768])


In [18]:
for i in range(1, len(embs)):
    cos_dist = F.cosine_similarity(embs[0:1], embs[i:i + 1])
    norm_dist = torch.norm(embs[0] - embs[i])
    print(txts[i], cos_dist.numpy(), norm_dist)

Australia [0.1348315] tensor(9.7678)
Orana Australia Ltd [0.14778207] tensor(10.4330)
Hello Kitty [0.08750954] tensor(10.7561)


## Answering questions

In [40]:
ctx = 'Context3. Like other American research universities, Northwestern was transformed by World War II. Franklyn B. Snyder led the university from 1939 to 1949, when nearly 50,000 military officers and personnel were trained on the Evanston and Chicago campuses. After the war, surging enrollments under the G.I. Bill drove drastic expansion of both campuses. In 1948 prominent anthropologist Melville J. '
q = 'Question: Between 1939 and 1949, how many military officers and personnel were trained on the Evanston and Chicago campuses? Answer: [MASK] [MASK] [MASK] [MASK]'
s = f'{ctx} {q}'
toks = tkz(s).input_ids
len(toks)

109

In [41]:
toks_in = torch.tensor(toks).to(device).unsqueeze(0)
logits_pred = model(toks_in, toks_in != tkz.pad_token_id)
print(logits_pred.shape)
probs_pred = torch.softmax(logits_pred[0], dim=-1)
toks_out = torch.argmax(probs_pred, dim=-1)
toks_out

torch.Size([1, 128, 30522])


tensor([  101,  6123,  2509,  1012,  2066,  2060,  2137,  2470,  5534,  1010,
         7855,  2001,  8590,  2011,  2088,  2162,  2462,  1012, 19597,  2078,
         1038,  1012, 17840,  2419,  1996,  2118,  2013,  3912,  2000,  4085,
         1010,  2043,  3053,  2753,  1010,  2199,  2510,  3738,  1998,  5073,
         2020,  4738,  2006,  1996,  6473,  2669,  1998,  3190, 13696,  1012,
         2044,  1996,  2162,  1010,  7505,  4726, 10316,  2015,  2104,  1996,
         1043,  1012,  1045,  1012,  3021,  6303, 20851,  4935,  1997,  2119,
        13696,  1012,  1999,  3882,  4069, 21571, 20154,  1046,  1012,  3160,
         1024,  2090,  3912,  1998,  4085,  1010,  2129,  2116,  2510,  3738,
         1998,  5073,  2020,  4738,  2006,  1996,  6473,  2669,  1998,  3190,
        13696,  1029,  3437,  1024,  1000,  1024,  2137,  2015,   102,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0])

In [42]:
tkz.decode(toks_out)

'[CLS] context3. like other american research universities, northwestern was transformed by world war ii. franklyn b. snyder led the university from 1939 to 1949, when nearly 50, 000 military officers and personnel were trained on the evanston and chicago campuses. after the war, surging enrollments under the g. i. bill demanded ample expansion of both campuses. in 1948 prominent anthropologist melville j. question : between 1939 and 1949, how many military officers and personnel were trained on the evanston and chicago campuses? answer : " : americans [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [34]:
probs_pred[:10]

tensor([1.0000e+00, 1.1139e-12, 6.2658e-13, 1.4770e-12, 1.5230e-12, 1.4352e-12,
        1.0220e-12, 1.1165e-12, 1.0452e-12, 1.1000e-12],
       grad_fn=<SliceBackward0>)

In [31]:
probs_pred.shape

torch.Size([5, 128, 30522])

In [19]:
t = torch.rand(2, 3, 4)
t

tensor([[[0.6792, 0.0907, 0.6140, 0.9981],
         [0.2409, 0.8605, 0.0702, 0.2925],
         [0.1105, 0.1167, 0.4092, 0.3401]],

        [[0.3772, 0.2479, 0.2412, 0.0046],
         [0.8316, 0.8350, 0.4125, 0.5759],
         [0.0096, 0.6448, 0.9033, 0.0755]]])

In [22]:
t1 = torch.softmax(t, dim=2)
print(t1)
t1.shape

tensor([[[0.2586, 0.1435, 0.2422, 0.3557],
         [0.2103, 0.3908, 0.1773, 0.2215],
         [0.2168, 0.2182, 0.2923, 0.2728]],

        [[0.2906, 0.2554, 0.2537, 0.2002],
         [0.2911, 0.2921, 0.1914, 0.2254],
         [0.1563, 0.2949, 0.3819, 0.1669]]])


torch.Size([2, 3, 4])

In [27]:
t2 = torch.randint(0, 4, (2, 3, 1))
t2

tensor([[[1],
         [1],
         [3]],

        [[2],
         [2],
         [1]]])

In [28]:
torch.gather(t1, dim=2, index=t2)

tensor([[[0.1435],
         [0.3908],
         [0.2728]],

        [[0.2537],
         [0.1914],
         [0.2949]]])