# Lexicap

This notebook indexes transcriptions of [Lex Fridman's podcast](https://www.youtube.com/playlist?list=PLrAXtmErZgOdP_8GztsuKi9nrraNbKKp4) 
for question answering, using [Andrej Karpathy's](https://twitter.com/karpathy) transcriptions produced with [OpenAI's Whisper](https://github.com/openai/whisper/blob/main/model-card.md) 👉️ [Lexicap](https://karpathy.ai/lexicap/).

At the moment this code relies on a private package from MeliorAI, namely [distributed-faiss](https://github.com/MeliorAI/distributed-faiss) for distributed indexing using [FAISS](https://github.com/facebookresearch/faiss) as the underlying index. Every other package is openly available. 

In [71]:
!pip install -qq \
    distributed_faiss \
    nltk \
    tabulate \
    sentence_transformers

## 📚️ Data 

In [11]:
import nltk


# `sent_tokenize` relies on NLTK's pre-trained Punkt sentence models. Older NLTK
# releases load them from "punkt", newer ones from "punkt_tab"; make sure both
# are present so the sentence split works regardless of the installed version.
# (On an NLTK that does not know "punkt_tab", `nltk.download` only reports an
# error and returns False, so this stays harmless.)
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        nltk.download(_punkt_resource)


DATA_DIR = "./data"

In [182]:
import os
import glob
import re
import webvtt

from tabulate import tabulate
from tqdm.auto import tqdm


def gather_transcripts(data_dir: str = DATA_DIR, load_small: bool = False):
    """Return the sorted paths of the WebVTT transcripts in ``<data_dir>/vtt``.

    Transcripts exist in two variants told apart by their filename suffix:
    ``*_large.vtt`` (default) or ``*_small.vtt`` when ``load_small`` is True.
    """
    variant = "small" if load_small else "large"
    pattern = os.path.join(data_dir, "vtt", f"*_{variant}.vtt")
    paths = glob.glob(pattern)
    paths.sort()
    return paths


def gather_episode_data(data_dir: str = DATA_DIR):
    """Parse per-episode metadata from ``<data_dir>/episode_names.txt``.

    For every non-blank line:
      * the episode number is the first run of digits right after a ``#``;
      * a leading ``"<ep_num> "`` prefix, if present, is removed;
      * everything from the first ``" | "`` onwards is dropped;
      * the rest is split on the first ``": "`` into guest and title
        (title is ``""`` when there is no ``": "``).

    Returns:
        dict mapping the episode number (as written after ``#``) to
        ``{"guest": str, "title": str}``.

    Raises:
        ValueError: if a non-blank line has no ``#<digits>`` episode number.
    """
    ep_data = {}
    names_path = os.path.join(data_dir, "episode_names.txt")
    # Explicit encoding: guest names may contain non-ASCII characters and the
    # platform default encoding is not UTF-8 everywhere.
    with open(names_path, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no episode.
                continue

            ep_nums = re.findall(r"(?<=#)\d+", line)
            if not ep_nums:
                raise ValueError(
                    f"No '#<number>' episode id found in line: {line!r}"
                )
            ep_num = ep_nums[0]

            line = re.sub(rf"^{ep_num} ", "", line)
            guest_and_title = line.split(" | ")[0]
            guest, _sep, title = guest_and_title.partition(": ")
            ep_data[ep_num] = {"guest": guest, "title": title}

    return ep_data

    
def sent_split(text: str):
    """Split ``text`` into sentences with NLTK's ``sent_tokenize``.

    Returns:
        list of ``(sentence, (start, end))`` pairs such that
        ``text[start:end] == sentence``. A sentence that cannot be found
        verbatim in ``text`` gets the span ``(None, None)``.
    """
    text_chunks = nltk.tokenize.sent_tokenize(text)

    text_indices = []
    cursor = 0
    for chunk in text_chunks:
        # Search forward from the end of the previous sentence. Searching from
        # the start of the text would give every repeated sentence (e.g.
        # "Yeah.") the span of its first occurrence, and rescanning the whole
        # transcript for each sentence is quadratic.
        start = text.find(chunk, cursor)
        if start == -1:
            # Tokenizer output is expected in order; fall back to a full
            # search just in case it is not.
            start = text.find(chunk)
        if start == -1:
            text_indices.append((None, None))
            continue
        end = start + len(chunk)
        text_indices.append((start, end))
        cursor = max(cursor, end)

    return list(zip(text_chunks, text_indices))


def print_datapoint(datapoint, lim: int = -1):
    """Pretty-print one episode record as two grid tables.

    The first table shows the episode number, guest and title. The second
    lists the episode's sentences next to their ``(start, end)`` spans.

    Args:
        datapoint: dict with keys "ep_num", "guest", "title" and "texts", where
            "texts" is a list of ``(sentence, span)`` pairs.
        lim: maximum number of sentences to show. Any negative value (the
            default) shows all of them.
    """
    keys = ["ep_num", "guest", "title"]
    print(tabulate([[datapoint[k] for k in keys]],
                   headers=keys,
                   tablefmt="grid"))

    texts = datapoint["texts"]
    # A negative limit means "no limit": slicing with ``[:lim]`` for lim=-1
    # would silently drop the last sentence.
    if lim >= 0:
        texts = texts[:lim]
    rows = [(text, str(span)) for text, span in texts]
    print(tabulate(rows,
                   headers=["text", "indices"],
                   tablefmt="grid",
                   maxcolwidths=55))
    
    
def build_dataset():
    """Compile one record per transcript: episode metadata plus its sentences.

    Transcripts come from ``gather_transcripts()`` and metadata from
    ``gather_episode_data()``. Each transcript's captions are joined with
    spaces and split into sentences with ``sent_split``.

    Returns:
        list of dicts containing the episode's "guest" and "title", plus
        "text_ids" (consecutive sentence ids, unique across the whole dataset),
        "texts" (``sent_split`` output) and "ep_num" (episode number without
        leading zeros).

    Raises:
        ValueError: if a transcript filename contains no digits.
        KeyError: if a transcript's episode number is missing from the
            episode names file.
    """
    script_files = gather_transcripts()
    ep_data = gather_episode_data()
    print(f"Total episodes transcripts: {len(script_files)}")

    # Compile a dataset of transcripts with its episode information
    dataset = []
    data_iter = tqdm(script_files)
    sent_id = 0
    for sfile in data_iter:
        fname = os.path.basename(sfile)
        data_iter.set_description(fname)

        # The first run of digits in the filename is the (zero-padded) episode
        # number; int() strips the padding so it matches the ep_data keys.
        match = re.search(r"\d+", fname)
        if match is None:
            raise ValueError(f"Cannot find an episode number in {fname!r}")
        ep_num = str(int(match.group()))
        # Copy so the metadata returned by gather_episode_data is not mutated.
        ep_info = dict(ep_data[ep_num])

        # Read transcript and split in sentences
        text = " ".join(caption.text for caption in webvtt.read(sfile))
        text_chunks = sent_split(text)
        nt = len(text_chunks)

        ep_info["text_ids"] = list(range(sent_id, sent_id + nt))
        ep_info["texts"] = text_chunks
        ep_info["ep_num"] = ep_num

        dataset.append(ep_info)
        sent_id += nt

    return dataset

In [185]:
dataset = build_dataset()

# Preview the first 10 sentences of the second record in the dataset
preview_episode = dataset[1]
print_datapoint(preview_episode, lim=10)

Total episodes transcripts: 113


  0%|          | 0/113 [00:00<?, ?it/s]

+----------+------------------+---------+
|   ep_num | guest            | title   |
|        6 | Guido van Rossum | Python  |
+----------+------------------+---------+
+---------------------------------------------------------+--------------+
| text                                                    | indices      |
| The following is a conversation with Guido van Rossum,  | (0, 353)     |
| creator of Python, one of the most popular  programming |              |
| languages in the world, used in almost any application  |              |
| that involves computers  from web back end development  |              |
| to psychology, neuroscience, computer vision, robotics, |              |
| deep  learning, natural language processing, and almost |              |
| any subfield of AI.                                     |              |
+---------------------------------------------------------+--------------+
| This conversation is part of  MIT course on artificial  | (354, 470)   |
| gener

## 🧪 Semantic Encoding

In [124]:
from sentence_transformers import SentenceTransformer

# Sentence encoder used for both the transcript sentences and the queries.
# The same model is available through 🤗-Transformers:
# https://huggingface.co/sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco
EMBEDDER_MODEL_NAME = 'msmarco-distilbert-base-tas-b'
embedder = SentenceTransformer(EMBEDDER_MODEL_NAME)

In [None]:
data_iter = tqdm(dataset)

for dpoint in data_iter:
    data_iter.set_description(f"{dpoint['ep_num']}: {dpoint['title']}")
    # `texts` holds (sentence, (start, end)) pairs from `sent_split`; encode the
    # sentence strings only, not the tuples (their span would end up in the
    # encoder input).
    sentences = [text for text, _span in dpoint["texts"]]
    dpoint["embeddings"] = embedder.encode(sentences)

  0%|          | 0/113 [00:00<?, ?it/s]

In [None]:
first_episode_embeddings = dataset[0]["embeddings"]
# one row per sentence of the first record
first_episode_embeddings.shape

## 📦️ Indexing: Distributed FAISS

First we need to run the FAISS server:

```bash
python dfaiss_server.py \
    --log-dir ./logs \
    --partition 1 \
    --discovery-config dfaiss.discovery \
    --num-servers 1 \
    --num-servers-per-node 1 \
    --mem-gb 4 \
    --load-index
```

In [66]:
import json


from distributed_faiss.client import IndexClient
from distributed_faiss.index_cfg import IndexCfg


def _validate_metric(self):
    if self.metric not in VALID_METRICS:
        logger.error(
            f"{self.metric} is not a valid Metric. "
            f"Try to choose between {VALID_METRICS}"
        )
        raise ValueError(
            f"{self.metric} is not a valid Metric. "
            f"Try to choose between {VALID_METRICS}"
        )
        
        
def init_client(index_id: str, index_cfg_file: str, discovery_file: str):
    """Connect to the dfaiss servers and make sure ``index_id`` is available.

    The index configuration is read from the JSON file ``index_cfg_file``.
    If the index is not loaded yet, it is loaded from storage. If
    ``load_index`` returns ``False``, a new index is created with that
    configuration.

    Returns:
        the connected ``IndexClient``.
    """
    with open(index_cfg_file, "r") as cfg_fh:
        idx_cfg = IndexCfg(**json.load(cfg_fh))

    client = IndexClient(discovery_file)

    if client.index_loaded(index_id):
        return client

    if client.load_index(index_id, idx_cfg) is False:
        print(f"Index {index_id} hasn't been loaded. Creating new index...")
        client.create_index(index_id, idx_cfg)

    return client

In [176]:
# Accepted index types and similarity metrics
VALID_INDEX_TYPES = ["Flat", "IVF"]
VALID_METRICS = ["l2", "dot"]

# dfaiss client settings: the discovery file (presumably the one passed to the
# server via --discovery-config), the JSON index configuration, and the id of
# the index to work with.
discovery_file = "./dfaiss.discovery"
index_cfg_file = "./idx_cfg.json"
index_id = "local"

In [177]:
# Create the dfaiss client
index_client = init_client(index_id, index_cfg_file, discovery_file)

data_iter = tqdm(dataset)


for dpoint in data_iter:
    data_iter.set_description(f"{dpoint['ep_num']}: {dpoint['title']}")
    embeddings = dpoint["embeddings"]
    n_vec = embeddings.shape[0]
    meta = {dpoint[k] for k in ["ep_num", "guest", "title"]}
    index_client.add_index_data(
        index_id, 
        embeddings, 
        [{"id": tid, **meta} for tid in dpoint["text_ids"]]
    )
    chunk_id += n_vec
    
index_client.sync_train(index_id)

connecting jose-N501VW 12032 AddressFamily.AF_INET
connecting jose-N501VW 12033 AddressFamily.AF_INET
Index local hasn't been loaded. Creating new index...


  0%|          | 0/113 [00:00<?, ?it/s]

In [178]:
# Add the buffered data to the index, then save the index
index_client.add_buffer_to_index(index_id)
index_client.save_index(index_id)

num_servers = index_client.get_num_servers()
print(f"num servers: {num_servers}")

index_states = index_client.get_all_states(index_id)
print(f"Index states: {index_states}")

num servers: 2
Index states: [<IndexState.TRAINED: 4>, <IndexState.TRAINED: 4>]


In [179]:
# Total number of vectors reported by the index
n_indexed_vectors = index_client.get_ntotal(index_id)
n_indexed_vectors

153933

## 🔍️ QA Query

In [181]:
# `all_index_loaded` appears to return one flag per server (cf. the per-server
# list from `get_all_states`). Only report the index as loaded when every
# server has it: `any` would hide a server that failed to load.
index_loaded = all(index_client.all_index_loaded(index_id))
print(f"Index loaded: {index_loaded}")

query = "Who's Guido van Rossum?"
index_client.search(
    embedder.encode([query]),
    top_k=10,
    index_id=index_id,
    return_embeddings=False,
    filter_dict={},
)

Index loaded: True


([array([-109.61733 , -106.72429 , -106.32414 , -105.41661 , -104.84072 ,
         -104.7509  , -104.35072 , -104.229034, -103.04799 , -102.74254 ],
        dtype=float32)],
 [[{'id': 774},
   {'id': 782},
   {'id': 766},
   {'id': 770},
   {'id': 759},
   {'id': 2339},
   {'id': 763},
   {'id': 755},
   {'id': 751},
   {'id': 745}]])