In [7]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from pathlib import Path
from math import ceil
import numpy as np
from einops import rearrange
import faiss
from autofaiss import build_index

dataset = load_dataset("wikipedia", "20200501.en")

Reusing dataset wikipedia (/home/dashiell/.cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475)


  0%|          | 0/1 [00:00<?, ?it/s]

In [8]:
len(dataset['train'])

6078422

In [4]:
doc_text = dataset[934590]
chunk_size = 64
seq_len = 2048
pad_id = 0
print(doc_text)

{'title': 'Nickel (disambiguation)', 'text': 'Nickel is a chemical element.\n\nNickel may also refer to:\n\nPeople\n Nickel (surname)\n Nickel Ashmeade (born 1990), Jamaican athlete\n Nickel Chand (born 1995), Fijian footballer\n Nickel Hoffmann (1536–1592), German stonemason\n Nickel Leung, Hong Kong educator\n\nCoins and tokens\n Nickel (Canadian coin), a five cent coin introduced in 1922\n Nickel (United States coin), a five cent coin introduced in 1866\n Half dime, a U.S. five cent coin produced in various years in the range 1792–1873 (sometimes called a "nickel" due to its face value)\n Three-cent nickel, a U.S. coin (1865–1889)\n Indian Head cent, a U.S. coin (1859–1864) nicknamed the "nickel"\n\nGames and sports\n Nickel defense, a defense formation in American and Canadian football\n Nickel Trophy, awarded to the winner of the football game between North Dakota State University and the University of North Dakota\n\nOther uses\n Nickel, a shade of gray\n Nickel Theatre, St. John

In [6]:

MODEL = None
TOKENIZER = None

def exists(val):
    return val is not None

def range_chunked(max_value, *, batch_size):
    counter = 0
    while counter < max_value:
        curr = counter + batch_size
        curr = min(curr, max_value)
        yield slice(counter, curr)
        counter = curr

# indexing helper functions

def faiss_read_index(path):
    return faiss.read_index(str(path), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def get_tokenizer():
    global TOKENIZER
    if not exists(TOKENIZER):
        TOKENIZER = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
    return TOKENIZER

def get_bert():
    global MODEL
    if not exists(MODEL):
        MODEL = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
        if torch.cuda.is_available():
            MODEL = MODEL.cuda()
    return MODEL

def tokenize(texts, add_special_tokens = True):
    if not isinstance(texts, (list, tuple)):
        texts = [texts]

    tokenizer = get_tokenizer()

    encoding = tokenizer.batch_encode_plus(
        texts,
        add_special_tokens = add_special_tokens,
        padding = True,
        return_tensors = 'pt'
    )

    token_ids = encoding.input_ids
    return token_ids


In [22]:



assert (seq_len % chunk_size) == 0, 'sequence length must be divisible by chunk size'

ids = tokenize(doc_text)
print(ids)
print(ids.shape)

tensor([[    0, 15523,  2007,  1041,  5076,  5787,  1016, 15523,  2093,  2040,
          6527,  2004,  1028,  2115, 15523,  1010, 11992,  1011, 15523,  6687,
          4172,  9652,  1010,  2145,  2905,  1011,  1014, 17855,  8262, 15523,
          9216,  2098,  1010,  2145,  2790,  1011,  1014, 11468,  2323,  4366,
         15523, 24441,  1010, 16714,  2579,  1520, 18918,  2479,  1011,  1014,
          2450,  2966,  9339,  2243, 15523, 26041,  1014,  4295,  4294, 11494,
          7828,  2002, 19208,  2019, 15523,  1010,  3014,  9230,  1011,  1014,
          1041,  2278,  9362,  9230,  3111,  2003,  4802, 15523,  1010,  2146,
          2167,  9230,  1011,  1014,  1041,  2278,  9362,  9230,  3111,  2003,
          7651,  2435, 27215,  1014,  1041,  1061,  1016,  1059,  1016,  2278,
          9362,  9230,  2554,  2003,  2540,  2090,  2003,  2000,  2850, 13418,
          1520,  7616,  1010,  2827,  2174,  1041,  1004, 15523,  1004,  2353,
          2004,  2053,  2231,  3647,  1011,  2097,  

In [23]:
ids = rearrange(ids, '1 ... -> ...')

print(ids)
ids.shape

tensor([    0, 15523,  2007,  1041,  5076,  5787,  1016, 15523,  2093,  2040,
         6527,  2004,  1028,  2115, 15523,  1010, 11992,  1011, 15523,  6687,
         4172,  9652,  1010,  2145,  2905,  1011,  1014, 17855,  8262, 15523,
         9216,  2098,  1010,  2145,  2790,  1011,  1014, 11468,  2323,  4366,
        15523, 24441,  1010, 16714,  2579,  1520, 18918,  2479,  1011,  1014,
         2450,  2966,  9339,  2243, 15523, 26041,  1014,  4295,  4294, 11494,
         7828,  2002, 19208,  2019, 15523,  1010,  3014,  9230,  1011,  1014,
         1041,  2278,  9362,  9230,  3111,  2003,  4802, 15523,  1010,  2146,
         2167,  9230,  1011,  1014,  1041,  2278,  9362,  9230,  3111,  2003,
         7651,  2435, 27215,  1014,  1041,  1061,  1016,  1059,  1016,  2278,
         9362,  9230,  2554,  2003,  2540,  2090,  2003,  2000,  2850, 13418,
         1520,  7616,  1010,  2827,  2174,  1041,  1004, 15523,  1004,  2353,
         2004,  2053,  2231,  3647,  1011,  2097,  1015,  9362, 

torch.Size([270])

In [24]:
text_len = ids.shape[-1]

# pad to multiple of chunk size with an extra token

padding = chunk_size - ((text_len - 1) % chunk_size)
ids = F.pad(ids, (0, padding))
ids

tensor([    0, 15523,  2007,  1041,  5076,  5787,  1016, 15523,  2093,  2040,
         6527,  2004,  1028,  2115, 15523,  1010, 11992,  1011, 15523,  6687,
         4172,  9652,  1010,  2145,  2905,  1011,  1014, 17855,  8262, 15523,
         9216,  2098,  1010,  2145,  2790,  1011,  1014, 11468,  2323,  4366,
        15523, 24441,  1010, 16714,  2579,  1520, 18918,  2479,  1011,  1014,
         2450,  2966,  9339,  2243, 15523, 26041,  1014,  4295,  4294, 11494,
         7828,  2002, 19208,  2019, 15523,  1010,  3014,  9230,  1011,  1014,
         1041,  2278,  9362,  9230,  3111,  2003,  4802, 15523,  1010,  2146,
         2167,  9230,  1011,  1014,  1041,  2278,  9362,  9230,  3111,  2003,
         7651,  2435, 27215,  1014,  1041,  1061,  1016,  1059,  1016,  2278,
         9362,  9230,  2554,  2003,  2540,  2090,  2003,  2000,  2850, 13418,
         1520,  7616,  1010,  2827,  2174,  1041,  1004, 15523,  1004,  2353,
         2004,  2053,  2231,  3647,  1011,  2097,  1015,  9362, 

In [25]:
ids, last_token = ids[:-1], ids[-1:]
ids = rearrange(ids, '(n c) -> n c', c = chunk_size)
print(ids)
print(last_token)

tensor([[    0, 15523,  2007,  1041,  5076,  5787,  1016, 15523,  2093,  2040,
          6527,  2004,  1028,  2115, 15523,  1010, 11992,  1011, 15523,  6687,
          4172,  9652,  1010,  2145,  2905,  1011,  1014, 17855,  8262, 15523,
          9216,  2098,  1010,  2145,  2790,  1011,  1014, 11468,  2323,  4366,
         15523, 24441,  1010, 16714,  2579,  1520, 18918,  2479,  1011,  1014,
          2450,  2966,  9339,  2243, 15523, 26041,  1014,  4295,  4294, 11494,
          7828,  2002, 19208,  2019],
        [15523,  1010,  3014,  9230,  1011,  1014,  1041,  2278,  9362,  9230,
          3111,  2003,  4802, 15523,  1010,  2146,  2167,  9230,  1011,  1014,
          1041,  2278,  9362,  9230,  3111,  2003,  7651,  2435, 27215,  1014,
          1041,  1061,  1016,  1059,  1016,  2278,  9362,  9230,  2554,  2003,
          2540,  2090,  2003,  2000,  2850, 13418,  1520,  7616,  1010,  2827,
          2174,  1041,  1004, 15523,  1004,  2353,  2004,  2053,  2231,  3647,
          1011

In [29]:
last_token_per_chunk = ids[1:, 0]
last_token_per_chunk


tensor([15523, 15523,  2000, 21563])

In [31]:
all_last_tokens = torch.cat((last_token_per_chunk, last_token), dim = 0)

all_last_tokens = rearrange(all_last_tokens, 'n -> n 1')
all_last_tokens

tensor([[15523],
        [15523],
        [ 2000],
        [21563],
        [    0]])

In [33]:
# append all last tokens to ids for (num_chunks, chunk_size + 1)
chunks_with_extra_token = torch.cat((ids, all_last_tokens), dim = -1)
chunks_with_extra_token.shape

torch.Size([5, 65])

In [35]:
total_chunks = ids.shape[0]
num_chunks_per_seq = seq_len // chunk_size
num_chunks_per_seq

32

In [7]:



MODEL = None
TOKENIZER = None

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def get_tokenizer():
    global TOKENIZER
    if not exists(TOKENIZER):
        TOKENIZER = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
    return TOKENIZER

def get_bert():
    global MODEL
    if not exists(MODEL):
        MODEL = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
        if torch.cuda.is_available():
            MODEL = MODEL.cuda()
    return MODEL

# tokenize

def tokenize(texts, add_special_tokens = True):
    if not isinstance(texts, (list, tuple)):
        texts = [texts]

    tokenizer = get_tokenizer()

    encoding = tokenizer.batch_encode_plus(
        texts,
        add_special_tokens = add_special_tokens,
        padding = True,
        return_tensors = 'pt'
    )

    token_ids = encoding.input_ids
    return token_ids

# text to chunks

def doc_text_to_chunks_and_seq_indices(
    *,
    doc_text,
    chunk_size = 64,
    seq_len = 2048,
    pad_id = 0
):
    assert (seq_len % chunk_size) == 0, 'sequence length must be divisible by chunk size'

    ids = tokenize(doc_text)
    ids = rearrange(ids, '1 ... -> ...')

    text_len = ids.shape[-1]

    # pad to multiple of chunk size with an extra token

    padding = chunk_size - ((text_len - 1) % chunk_size)
    ids = F.pad(ids, (0, padding))

    # split out very last token

    ids, last_token = ids[:-1], ids[-1:]
    ids = rearrange(ids, '(n c) -> n c', c = chunk_size)

    # first tokens of chunk [2:] and on will become the last token of chunk [1:]

    last_token_per_chunk = ids[1:, 0]
    all_last_tokens = torch.cat((last_token_per_chunk, last_token), dim = 0)
    all_last_tokens = rearrange(all_last_tokens, 'n -> n 1') # Transpose sorta? [x, y, z] -> [[x], [y], [z]]

    # append all last tokens to ids for (num_chunks, chunk_size + 1)

    chunks_with_extra_token = torch.cat((ids, all_last_tokens), dim = -1)

    # calculate chunk indices starting at 0, spaced number of chunks of seq len apart

    total_chunks = ids.shape[0]
    num_chunks_per_seq = seq_len // chunk_size
    seq = torch.arange(0, total_chunks, num_chunks_per_seq)
    return chunks_with_extra_token, seq



In [21]:
32 * 64

2048

In [8]:
from tqdm.notebook import tqdm

def text_dataset_to_chunks_(
    *,
    dataset,
    chunks_memmap_path,
    seqs_memmap_path,
    doc_ids_memmap_path,
    chunk_size = 64,
    seq_len = 2048,
    glob = '**/*.txt',
    max_chunks = 1_00_000_000,
    max_seqs = 1_00_000
):

    total_chunks = 0
    total_docs = 0
    total_seqs = 0

    chunks_shape = (max_chunks, chunk_size + 1)
    seqs_shape = (max_seqs,)
    doc_ids_shape = (max_chunks,)
    seq_stack = []
    chunk_stack = []

    with (
        # huge on-disk numpy array of shape (VERY_BIG, 65), writes a set of length-65 rows which made up of token IDs
        memmap(chunks_memmap_path, shape = chunks_shape, dtype = np.int32, mode = 'w+') as chunks_memmap,
        # 
        memmap(seqs_memmap_path, shape = seqs_shape, dtype = np.int32, mode = 'w+') as seqs_memmap,
        memmap(doc_ids_memmap_path, shape = doc_ids_shape, dtype = np.int32, mode = 'w+') as doc_ids_memmap
    ):
        for text in tqdm(dataset):
            chunks, seq = doc_text_to_chunks_and_seq_indices(
                doc_text = text['text'],
                chunk_size = chunk_size,
                seq_len = seq_len
            )
            """
            chunk_stack.append(chunks)
            seq_stack.append(seq)
            if len(chunk_stack) > 50:
                chunk_stack.pop(0)
                seq_stack.pop(0)
            """
            doc_chunk_len = chunks.shape[0]
            doc_seq_len = seq.shape[0]
            chunks_memmap[total_chunks:(total_chunks + doc_chunk_len)] = chunks.numpy()
            try:
                seqs_memmap[total_seqs:(total_seqs + doc_seq_len)] = seq.numpy() + total_chunks
            except ValueError as ve:
                for chunk_hist, seq_hist in zip(chunk_stack, seq_stack):
                    print(seq_hist, chunk_hist)
                raise ve
            doc_ids_memmap[total_chunks:(total_chunks + doc_chunk_len)] = np.full((doc_chunk_len,), total_docs)
            last_seq = seq
            total_chunks += doc_chunk_len
            total_seqs += doc_seq_len
            total_docs += 1

    return dict(
        chunks = total_chunks,
        docs = total_docs,
        seqs = total_seqs
    )

In [9]:
@torch.no_grad()
def bert_embed(
    token_ids,
    return_cls_repr = False,
    eps = 1e-8,
    pad_id = 0.
):
    model = get_bert()
    mask = token_ids != pad_id

    if torch.cuda.is_available():
        token_ids = token_ids.cuda()

    outputs = model(
        input_ids = token_ids,
        attention_mask = mask,
        output_hidden_states = True
    )

    hidden_state = outputs.hidden_states[-1]

    if return_cls_repr:
        return hidden_state[:, 0]               # return [cls] as representation

    if not exists(mask):
        return hidden_state.mean(dim = 1)

    mask = mask[:, 1:]                          # mean all tokens excluding [cls], accounting for length
    mask = rearrange(mask, 'b n -> b n 1')

    numer = (hidden_state[:, 1:] * mask).sum(dim = 1)
    denom = mask.sum(dim = 1)
    masked_mean =  numer / (denom + eps)
    return masked_mean




In [11]:
from retro_pytorch.utils import memmap, reset_folder_

chunk_size = 64
knn = 2
seq_len = 2048
chunks_memmap_path = './train.chunks.dat'
seqs_memmap_path = './train.seq.dat'
doc_ids_memmap_path = './train.doc_ids.dat'
max_chunks = 1_000_000_000
max_seqs = 100_000_000
knn_extra_neighbors = 100
processed_stats_json_path = './processed-stats.json'
faiss_index_filename = 'knn.index'


# constants

SOS_ID = 101
EOS_ID = 102
BERT_MODEL_DIM = 768
BERT_VOCAB_SIZE = 28996

TMP_PATH = Path('./.tmp')
INDEX_FOLDER_PATH = TMP_PATH / '.index'
EMBEDDING_TMP_SUBFOLDER = 'embeddings'

# helper functions

def exists(val):
    return val is not None

def range_chunked(max_value, *, batch_size):
    counter = 0
    while counter < max_value:
        curr = counter + batch_size
        curr = min(curr, max_value)
        yield slice(counter, curr)
        counter = curr

# indexing helper functions

def faiss_read_index(path):
    return faiss.read_index(str(path), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)


stats = text_dataset_to_chunks_(
    dataset = dataset, 
    chunks_memmap_path = chunks_memmap_path,
    seqs_memmap_path = seqs_memmap_path,
    doc_ids_memmap_path = doc_ids_memmap_path,
    chunk_size = chunk_size,
    seq_len = seq_len,
    max_chunks = max_chunks,
    max_seqs = max_seqs
)


num_chunks = stats['chunks']
num_seqs = stats['seqs']

# calculate knn memmap path and get the faiss index
# todo - make sure if faiss_index_filename is found, do not reprocess unless flag is given

knn_memmap_path, faiss_index = chunks_to_precalculated_knn_(
    num_chunks = num_chunks,
    chunk_size = chunk_size,
    chunk_memmap_path = chunks_memmap_path,
    doc_ids_memmap_path = doc_ids_memmap_path,
    num_nearest_neighbors = knn,
    num_extra_neighbors = knn_extra_neighbors,
    index_file = faiss_index_filename,
    force_reprocess = force_reprocess,
)


  0%|          | 0/6078422 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (730 > 512). Running this sequence through the model will result in indexing errors

KeyboardInterrupt

