In [10]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [64]:
import faiss

In [65]:
embed_dim = 64
M = 16

index = faiss.IndexHNSWFlat(embed_dim, M)

In [66]:
index

<faiss.swigfaiss_avx512.IndexHNSWFlat; proxy of <Swig Object of type 'faiss::IndexHNSWFlat *' at 0x7f8765f06a60> >

In [67]:
import numpy as np
d = 64                           # dimension
nb = 100000                      # database size
nq = 10000                       # nb of queries
np.random.seed(1234)             # make reproducible
xb = np.random.random((nb, d)).astype('float32')
xb[:, 0] += np.arange(nb) / 1000.
xq = np.random.random((nq, d)).astype('float32')
xq[:, 0] += np.arange(nq) / 1000.

In [68]:
index.add(xb) 
print(index.ntotal)

100000


In [69]:
k = 4                          # we want to see 4 nearest neighbors
D, I = index.search(xb[:1], k) # sanity check
print(I)
print(D)
D, I = index.search(xq, k)     # actual search
print(I[:5])                   # neighbors of the 5 first queries
print(I[-5:])  
index.get_top_docs(xq, k)

[[  0 363  78 924]]
[[0.        7.207629  7.2511625 7.3218946]]
[[ 381  477  588  329]
 [ 526  911  142   72]
 [ 838  527 1290  425]
 [ 196  184  164  599]
 [ 526  377  425  917]]
[[ 9900  9309  9831 10568]
 [11055 10895 10812 11321]
 [11353 11103 10164  9787]
 [10571 10664 10632  9638]
 [ 9554 10036  9582 10304]]


AttributeError: 'IndexHNSWFlat' object has no attribute 'get_top_docs'

In [44]:
from transformers import AutoModel, AutoTokenizer
import torch
from torch.nn import functional as F

retriever_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', model_max_length=8192)

retriever_tokenizer.pad_token = retriever_tokenizer.eos_token
print(retriever_tokenizer.model_max_length)
context = "Context Context Context"
prompt = "Write a summary of the following text in bullet points:"
target = "This is a test"

context_inputs = retriever_tokenizer(context, return_tensors="pt", padding=True, truncation=True)['input_ids']
inputs = retriever_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)['input_ids']
targets = retriever_tokenizer(target, return_tensors="pt", padding=True, truncation=True)['input_ids']
print(context_inputs, inputs)
combined_inputs = [torch.cat((ctx, inp), dim=0) for ctx, inp in zip(context_inputs, inputs)]
print(combined_inputs)
# Prepare attention mask for combined inputs
attention_mask = [torch.ones_like(combined_input) for combined_input in combined_inputs]

# Prepare the targets by shifting the tokens to the left so the model predicts the next token
print("targets", targets)
shifted_targets = [trgt[1:] for trgt in targets]
print("shifted targets", shifted_targets)
# Pad combined inputs and attention masks to the maximum sequence length in the batch
max_len = max([combined_input.shape[0] for combined_input in combined_inputs])
padded_inputs = torch.stack([F.pad(input, (0, max_len - input.shape[0]), value=retriever_tokenizer.pad_token_id) for input in combined_inputs])
print(padded_inputs)
padded_attention_mask = torch.stack([F.pad(mask, (0, max_len - mask.shape[0]), value=0) for mask in attention_mask])

8192
tensor([[    1, 14268, 14268, 14268]]) tensor([[    1, 12018,   264, 14060,   302,   272,  2296,  2245,   297, 17144,
          3569, 28747]])
[tensor([    1, 14268, 14268, 14268,     1, 12018,   264, 14060,   302,   272,
         2296,  2245,   297, 17144,  3569, 28747])]
targets tensor([[   1,  851,  349,  264, 1369]])
shifted targets [tensor([ 851,  349,  264, 1369])]
tensor([[    1, 14268, 14268, 14268,     1, 12018,   264, 14060,   302,   272,
          2296,  2245,   297, 17144,  3569, 28747]])


In [107]:
ctx = torch.ones((4,5,16))
inputs = torch.zeros((4,16))
labels = torch.ones((4,16))

B = ctx.shape[0]  # Extracting batch dimension

# Using list comprehension for a cleaner approach with extracted batch dimension
sources = torch.stack([torch.stack([torch.cat((con, inputs[idx], labels[idx]), dim=0) for con in ctx[idx]]) for idx in range(B)])

ctx.shape, inputs.shape, targets.shape, sources.shape, sources[0][0]

(torch.Size([4, 5, 16]),
 torch.Size([4, 16]),
 torch.Size([1, 3]),
 torch.Size([4, 5, 48]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]))

In [56]:
from torch.utils.data import Dataset
import torch
import transformers
from typing import Dict, Sequence
import logging
import json
from dataclasses import dataclass
import copy

IGNORE_INDEX = -100

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )

INST_START = "[INST]"
INST_END = "[/INST]"

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [f"{INST_START} {s} {INST_END} {t}" for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = self._jload(data_path)
        
        logging.warning("Formatting inputs...")
        sources = [
            example.get("user", "") for example in list_data_dict
        ]
        print(sources[0])
        targets = [f"{example['assistant']}{tokenizer.eos_token}" for example in list_data_dict]
        print(targets[0])

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def _jload(self, path):
        f = open(path, 'r')
        j = json.load(f)
        f.close()
        return j

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_path) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)

In [57]:
data_module = make_supervised_data_module(
    tokenizer=retriever_tokenizer,
    data_path="rag/data/generated/processed/all_processed_data_cleaned_no_markers_split.json")

from torch.utils.data import DataLoader
dataloader = DataLoader(data_module['train_dataset'], batch_size=16, collate_fn=data_module['data_collator'])



Who are the Fremen in the Dune universe by Frank Herbert?
The Fremen are a group of desert-dwelling people on the planet Arrakis in the Dune universe. They are skilled fighters, adept at surviving the harsh conditions of the desert known as the "Deep Desert." The Fremen have their own unique culture and customs, including the use of stillsuits to reclaim moisture and their reverence for the sandworms that produce the valuable spice melange. They play a significant role in the political and religious conflicts depicted in the Dune series by Frank Herbert.
</s>


In [62]:
loader = iter(dataloader)
batch = next(loader)
print(batch['input_ids'].shape, retriever_tokenizer.batch_decode(sequences=batch['input_ids'])[0])
batch = next(loader)
print(batch['input_ids'][0], batch['labels'][0], batch['attention_mask'][0])
batch = next(loader)
print(batch['input_ids'].shape)

torch.Size([16, 255]) <s> [INST] Who are the Fremen in the Dune universe by Frank Herbert? [/INST] The Fremen are a group of desert-dwelling people on the planet Arrakis in the Dune universe. They are skilled fighters, adept at surviving the harsh conditions of the desert known as the "Deep Desert." The Fremen have their own unique culture and customs, including the use of stillsuits to reclaim moisture and their reverence for the sandworms that produce the valuable spice melange. They play a significant role in the political and religious conflicts depicted in the Dune series by Frank Herbert.
</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></

In [34]:
tokenizer = Tokenizer(
        model_path="_model/tokenizer.model"
    )
tokenizer.decode([2])

''

In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from typing import List
from transformers import AdamW
from torch.utils.data import DataLoader, DistributedSampler
from mistral.model import Transformer as Mistral
from mistral.tokenizer import Tokenizer
import transformers
from transformers import AutoTokenizer, AutoModel
import loralib as lora
from pathlib import Path
from retriever.index import init_index, get_top_docs
from retriever.nomic import mean_pooling
from dataset import init_dataset

generator_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', model_max_length=8192)
retriever_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', model_max_length=8192)

generator_tokenizer.pad_token = generator_tokenizer.eos_token

print(generator_tokenizer.pad_token_id, generator_tokenizer.bos_token_id)

2 1


In [3]:
matryoshka_dim = 768
index_path = "index/dune.index"
index = init_index(matryoshka_dim, index_path)

Initializing index...


In [4]:
from utils import load_documents
documents_path = "data/chunks"
documents = load_documents(documents_path)
len(documents)

749

In [5]:
doc_encoder = AutoModel.from_pretrained('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True, safe_serialization=True, rotary_scaling_factor=2)

<All keys matched successfully>


In [14]:
from retriever.index import build_index

build_index(index, documents, doc_encoder, retriever_tokenizer, index_path=index_path)

Building index...
Adding document 749/749 to index...
Index building complete.


In [6]:
from utils import setup_data
    
dataloader, sampler, train_dataset = setup_data(generator_tokenizer, retriever_tokenizer, "data/dune_mistral_instruct.jsonl", 32, True)

In [11]:
top_k = 5

def retrieve(retriever, batch, documents):
    input_ids, labels, mask = batch['input_ids'], batch['labels'], batch['mask']
    retriever_inputs, retriever_attn_mask = batch['retriever_tokens'], batch['retriever_attn_mask']
    B = input_ids.shape[0]

    # embed
    embeded_inputs = retriever(input_ids=retriever_inputs, attention_mask=retriever_attn_mask)

    embeddings = mean_pooling(embeded_inputs, retriever_attn_mask)

    if matryoshka_dim is not None:
        embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
        embeddings = embeddings[:, :matryoshka_dim]
    embeddings_batched = F.normalize(embeddings, p=2, dim=1)

    # retrieve()
    I = []
    vectors_batched = []
    for embeddings in embeddings_batched:
        ids, retrieved_doc_embeds = get_top_docs(index, embeddings, top_k)
        I.extend(ids)
        vectors_batched.extend(retrieved_doc_embeds)
    I = np.array(I)
    vectors_batched = np.array(vectors_batched)
    # get embbeddings from index by I

    retrieved_doc_embeds = torch.tensor(vectors_batched)

    # I = (batch_size, top_k), top_k dimension is the document ids
    # assume dataset.get_document(idx) returns tokenized document context ids
    # return context_ids tensor over batched I, context_ids = (batch_size, top_k, max_length)
    docs = []
    for indicies in I:
        docs.append([documents[idx] for idx in indicies])

    context_input_ids, context_attention_mask, context_labels = process_docs(docs, input_ids, labels, mask, top_k)
    
    # https://github.com/huggingface/transformers/blob/66ce9593fdb8e340df546ddd0774eb444f17a12c/src/transformers/models/rag/modeling_rag.py#L644
    doc_scores = torch.bmm(
        embeddings_batched.unsqueeze(1),
        retrieved_doc_embeds.transpose(1, 2)
    ).squeeze(1)

    return context_input_ids, context_attention_mask, context_labels, doc_scores

retriever = doc_encoder

In [35]:
def process_docs(
    docs: List[List[str]],
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    masks: torch.Tensor,
    n_docs: int
):
    
    # decode masked tokens from labels, do first in batch
    masked_tokens = [token for token, mask in zip(labels[0], masks[0]) if mask]
    print(generator_tokenizer.decode(masked_tokens))
    

    # custom batch encode
    context_inputs = {
        'input_ids': [],
        'masks': [],
        'labels': []
    }
    for i in range(n_docs):
        for j in range(len(docs)):

            
            
            doc_tokens = generator_tokenizer.encode(f"<s> [CONTEXT] {docs[j][i]} [/CONTEXT]\n", add_special_tokens=False)
            doc_tokens = torch.tensor(doc_tokens)
            doc_mask = torch.tensor([False] * (len(doc_tokens)))

            # assert input_ids and labels are the same length
            assert input_ids[i].shape[0] == labels[i].shape[0] == masks[i].shape[0]
            # assert doc_tokens.shape == doc_mask.shape

            # concat to front of input_ids and labels and masks, this wont work!
            # we need to just tokenize everything here, batches are just the text pairs
            # [1,2,3,4,5]
            # [2,3,4,5,pad]
            # [F,F,T,T,F]
            context_tokens = torch.cat((doc_tokens, input_ids[i]))
            context_mask = torch.cat((doc_mask, torch.tensor([False]), masks[i]))

            # assert sequence length is the same
            # assert context_tokens.shape[0] == context_mask.shape[0]

            context_inputs['input_ids'].append(context_tokens[:-1])
            context_inputs['masks'].append(context_mask[1:])
            context_inputs['labels'].append(context_tokens[1:])
        
    # pad sequences with eos token
    for key in context_inputs:
        if key == 'labels':
            context_inputs[key] = torch.nn.utils.rnn.pad_sequence(context_inputs[key], batch_first=True, padding_value=generator_tokenizer.pad_token_id)
        else:
            context_inputs[key] = torch.nn.utils.rnn.pad_sequence(context_inputs[key], batch_first=True, padding_value=False)

    return context_inputs['input_ids'], context_inputs['masks'], context_inputs['labels']
    
epochs_run = 0
steps_per_epoch = len(dataloader)
num_epochs = 10

for epoch in range(epochs_run, num_epochs):

    # sampler.set_epoch(epoch)
    for step, batch in enumerate(dataloader):
        input_ids, labels, mask = batch['input_ids'], batch['labels'], batch['mask']
        retriever_tokens, retriever_attn_mask = batch['retriever_tokens'], batch['retriever_attn_mask']

        # retrieve
        context_input_ids, context_masks, context_labels, doc_scores = retrieve(retriever, batch, documents)
        break
    break


The Fremen are a group of desert-dwelling people on the planet Arrakis in the Dune universe. They are skilled fighters, adept at surviving the harsh conditions of the desert known as the "Deep Desert." The Fremen have their own unique culture and customs, including the use of stillsuits to reclaim moisture and their reverence for the sandworms that produce the valuable spice melange. They play a significant role in the political and religious conflicts depicted in the Dune series by Frank Herbert.
</s>


In [36]:
print(context_input_ids[5][0:100])
print(context_labels[5][0:100])
# print(generator_tokenizer.decode(context_input_ids[5]))

# decode masked tokens only, from context_masks
masked_tokens = [token for token, mask in zip(context_labels[5], context_masks[5]) if mask]
print(generator_tokenizer.decode(masked_tokens))


tensor([    1, 28705,   733, 27181, 28793,  9706, 28725,   272,   908,   628,
        28742, 28713,  8852,  4142,  2622, 12298, 28747, 28705,   345,  5364,
          302,   592, 28687,   369,   400,   659,   750,  4644,   286,  1835,
         3358,   304,   315,   541, 11016,   480,   397,   356,   713, 28723,
        28705,   650,   622,   842,   842, 15192,    13,  2428,  5970, 10832,
          390,   272,  5344,  2061,  2327,   680,   281,  3692,   356,   516,
         2105, 28723,    13, 28739, 28737,   511,   459,  1038,  2823,  6912,
          611, 28705,  2354,  8852, 21028,   272,  7470,  6639,   595,  3085,
          395,  1656, 28723, 28705,   345,  1976,   873,   478,   622,   511,
          813,  1489,   562,   315,  6557,   369,   478, 18619,   395, 15068])
tensor([28705,   733, 27181, 28793,  9706, 28725,   272,   908,   628, 28742,
        28713,  8852,  4142,  2622, 12298, 28747, 28705,   345,  5364,   302,
          592, 28687,   369,   400,   659,   750,  4644,   286,