# Lexicap

This notebook aims to index [Lex Fridman's podcasts](https://www.youtube.com/playlist?list=PLrAXtmErZgOdP_8GztsuKi9nrraNbKKp4) 
transcriptions for Question-Answering using [Andrej Karathy's](https://twitter.com/karpathy) transcriptions produced with [OpenAI's whisper](https://github.com/openai/whisper/blob/main/model-card.md) 👉️ i.e.: [Lexicap](https://karpathy.ai/lexicap/).

At the moment this code relies on a private package from MeliorAI, namely [distributed-faiss](https://github.com/MeliorAI/distributed-faiss) for distributed indexing using [FAISS](https://github.com/facebookresearch/faiss) as the underlying index. Every other package is openly available. 

In [None]:
!pip install -U pip
!pip install -qq \
    'distributed_faiss~=0.1.0' \
    nltk \
    tabulate \
    sentence_transformers \
    webvtt-py

## 📚️ Data 

In [1]:
import nltk


try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")
            

DATA_DIR = "./data"

## 🔧 Helpers

### 🎙️ VTT utils

In [43]:
import os
import glob
import re
import webvtt

from tabulate import tabulate
from tqdm.auto import tqdm
from typing import *


def gather_transcripts(data_dir:str = DATA_DIR, load_small:bool = False):
    mask = "*_small.vtt" if load_small else "*_large.vtt"
    return sorted(
        glob.glob(os.path.join(data_dir, "vtt", mask))
    )


def gather_episode_data(data_dir:str = DATA_DIR):
    ep_data = {}
    with open(os.path.join(data_dir, "episode_names.txt")) as f:
        for line in f.readlines():
            line = line.strip()
            ep_num = re.findall("(?<=#)\d+", line)[0]
            line = re.sub(f"^{ep_num} ", "", line)
            guest_and_title = line.split(" | ")[0]
            ep_data[ep_num] = {
                "guest": guest_and_title.split(": ")[0], 
                "title": ": ".join(guest_and_title.split(": ")[1:])
            }
    
    return ep_data

    
def sent_split(text:str) -> Dict[str, Union[str, int]]:
    """Divides each document section into several chunks based on
    a NLTK **sentence** tokenizer.
    """
    text_chunks = nltk.tokenize.sent_tokenize(text)
    text_indices = list(map(
        lambda m: (m.start(), m.end()) if m else (None, None),
        [re.search(re.escape(text_chunk), text) for text_chunk in text_chunks]
    ))
    
    return [
        {"text": text, "start": indices[0], "end": indices[1]} 
        for text, indices in zip(text_chunks, text_indices)
    ]


def print_datapoint(datapoint, lim:int = -1):
    keys = ["ep_num", "guest", "title"]
    print(tabulate([[datapoint[k] for k in keys]], 
                   headers=keys,
                  tablefmt="grid"))
    
    readout_list = list(map(lambda x: 
                            (x['text'], str(x['start']), str(x['end'])), 
                            datapoint["text_chunks"][:lim])
                       )
    print(tabulate(readout_list, 
                   headers=["text", "start", "end"],
                   tablefmt="grid",
                   maxcolwidths=55))
    
    
def build_dataset(split_by_time:bool = False):
    
    def split_with_timestamp(sfile:str):
        chunks = []
        buffer = []
        start = None
        for caption in webvtt.read(sfile):
            text = caption.text
            if not re.search("\w[\.\?\!]$$", text):
                buffer.append(text)
                if start is None:
                    start = caption.start
            else:
                chunks.append({
                    "start": start or caption.start,
                    "end": caption.end,
                    "text": (" ".join(buffer) + text).strip()
                })
                buffer = []
                start = None

        return chunks
    
    def split_with_text_indices(sfile:str):
        text = " ".join([caption.text for caption in webvtt.read(sfile)])
        return sent_split(text)
        
        
    script_files = gather_transcripts()
    ep_data = gather_episode_data()
    print(f"Total episodes transcripts: {len(script_files)}")

    # Compile a dataset of transcripts with its episode information
    dataset = []
    data_iter = tqdm(script_files)
    sent_id = 0
    for sfile in data_iter:
        fname = os.path.basename(sfile)
        data_iter.set_description(fname)

        # Episode number
        ep_num = str(int(re.findall("\d{1,3}", fname)[0]))  # trim leading 0's
        ep_info = ep_data[ep_num]

        # Read transcript and split in sentences
        if split_by_time:
            chunks = split_with_timestamp(sfile)
        else:
            chunks = split_with_text_indices(sfile)
            
        nt = len(chunks)

        ep_info.update({
            "text_ids": list(range(sent_id, sent_id + nt)),
            "text_chunks": chunks,
            "ep_num": ep_num
        })
        
        dataset.append(ep_info)
        sent_id += nt
        
    return dataset

### ⏱️ Timestamp utils

In [63]:
import copy
import datetime as dt


def cluster_in_time(datapoint, group_secs:int = 60):
    
    def parse_ts(ts:str):
        return dt.datetime.strptime(ts, "%H:%M:%S.%f")


    def delta_secs(t1:str, t2:str):
        return (parse_ts(t2) - parse_ts(t1)).seconds
    
    def add_group():
        groups.append({
            "start": start['start'],
            "end": end['end'],
            "text": " ".join([b['text'] for b in buffer])
        })
        
    chunks = datapoint['text_chunks']
    groups = []
    ini = 0
    fin = 1
    start = chunks[ini]
    end = chunks[fin]
    buffer = [start]
    while fin < len(chunks):
        end = chunks[fin]
        
        if delta_secs(t1=start['start'], t2=end['end']) < group_secs:
            buffer.append(end)
            
        else:
            add_group()
            ini = fin
            start = chunks[ini]
            buffer = [end]
            
        fin += 1
        
    if len(buffer) > 0:
        add_group()
        
    return groups


## 👀 Dataset visualization

In [64]:
dataset = build_dataset(split_by_time=True)

# print one of the elements of the dataset
print_datapoint(dataset[0], lim=5)

# print the same data point with text grouped in ~60 second chunks
dat = copy.deepcopy(dataset[0])
dat.update({"text_chunks": cluster_in_time(dat)})
print_datapoint(dat, lim=15)

Total episodes transcripts: 319


  0%|          | 0/319 [00:00<?, ?it/s]

+----------+-------------+----------+
|   ep_num | guest       | title    |
|        1 | Max Tegmark | Life 3.0 |
+----------+-------------+----------+
+--------------------------------------------------------+--------------+--------------+
| text                                                   | start        | end          |
| As part of MIT course 6S099, Artificial General        | 00:00:00.000 | 00:00:06.600 |
| Intelligence, I've gotten the chance to sit down with  |              |              |
| Max Tegmark.                                           |              |              |
+--------------------------------------------------------+--------------+--------------+
| He is a professor here at MIT.                         | 00:00:06.600 | 00:00:08.680 |
+--------------------------------------------------------+--------------+--------------+
| He's a physicist, spent a large part of his career     | 00:00:08.680 | 00:00:16.960 |
| studying the mysteries of our cosmological un

## 🧪 Semantic Encoding

In [4]:
from sentence_transformers import SentenceTransformer

# Can also be used from 🤗-Transformers:
# https://huggingface.co/sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco
embedder = SentenceTransformer('msmarco-distilbert-base-tas-b')

In [5]:
data_iter = tqdm(dataset)

for dpoint in data_iter:
    data_iter.set_description(f"{dpoint['ep_num']}: {dpoint['title']}")
    dpoint["embeddings"] = embedder.encode(dpoint['texts'])

  0%|          | 0/113 [00:00<?, ?it/s]

In [6]:
dataset[0]["embeddings"].shape

(737, 768)

## 📦️ Indexing: Distributed FAISS

First we need to run the FAISS server (in a separate terminal):

```python
python dfaiss_server.py \
    --log-dir ./logs \
    --partition 1 \
    --discovery-config dfaiss.discovery \
    --num-servers 1 \
    --num-servers-per-node 1 \
    --mem-gb 4 \
    --load-index
```

In [7]:
import json


from distributed_faiss.client import IndexClient
from distributed_faiss.index_cfg import IndexCfg


def _validate_metric(self):
    if self.metric not in VALID_METRICS:
        logger.error(
            f"{self.metric} is not a valid Metric. "
            f"Try to choose between {VALID_METRICS}"
        )
        raise ValueError(
            f"{self.metric} is not a valid Metric. "
            f"Try to choose between {VALID_METRICS}"
        )
        
        
def init_client(index_id:str, index_cfg_file:str, discovery_file:str):
    with open(index_cfg_file, "r") as f:
        cfg = json.load(f)
        idx_cfg = IndexCfg(**cfg)

    index_client = IndexClient(discovery_file)

    if not index_client.index_loaded(index_id):
        idx_loaded = index_client.load_index(index_id, idx_cfg)
        if idx_loaded is False:
            print(f"Index {index_id} hasn't been loaded. Creating new index...")
            index_client.create_index(index_id, idx_cfg)
            
    return index_client

In [8]:
VALID_INDEX_TYPES = ["Flat", "IVF"]
VALID_METRICS = [
    "l2",
    "dot",
]

# configuration and discovery files
discovery_file = "./dfaiss.discovery"
index_cfg_file = "./idx_cfg.json"
index_id = "local"

In [10]:
# Create the dfaiss client
index_client = init_client(index_id, index_cfg_file, discovery_file)

data_iter = tqdm(dataset)

for dpoint in data_iter:
    data_iter.set_description(f"{dpoint['ep_num']}: {dpoint['title']}")
    embeddings = dpoint["embeddings"]
    n_vec = embeddings.shape[0]
    meta = {k: dpoint[k] for k in ["ep_num", "guest", "title"]}
    index_client.add_index_data(
        index_id, 
        embeddings, 
        [{"id": tid, **meta} for tid in dpoint["text_ids"]]
    )
    
index_client.sync_train(index_id)
index_client.add_buffer_to_index(index_id)
index_client.save_index(index_id)

connecting jose-N501VW 12032 AddressFamily.AF_INET


  0%|          | 0/113 [00:00<?, ?it/s]

In [219]:
index_client = init_client(index_id, index_cfg_file, discovery_file)

connecting jose-N501VW 12032 AddressFamily.AF_INET
connecting jose-N501VW 12033 AddressFamily.AF_INET


In [11]:
print(f"num servers: {index_client.get_num_servers()}")
for sn, state in enumerate(index_client.get_all_states(index_id)):
    print(f"Server {sn+1}: {state}")

num servers: 1
Server 1: IndexState.ADD


In [12]:
index_client.get_ntotal(index_id)

76984

## 🔍️ QA Query

In [13]:
index_loaded = any(index_client.all_index_loaded(index_id))
if not index_loaded:
    index_client.load_index(index_id)
    
index_loaded = any(index_client.all_index_loaded(index_id))
print(f"Index loaded: {index_loaded}")

scores, indices = index_client.search(
    embedder.encode(["Which boooks where written by Max Tegmark?"]),
    top_k=10,
    index_id=index_id,
    return_embeddings=False,
    filter_dict={},
)
print(indices)
print(scores)

Index loaded: True
[[{'id': 96511}, {'id': 25577}, {'id': 96600}, {'id': 56700}, {'id': 64434}, {'id': 0}, {'id': 5391}, {'id': 1808}, {'id': 149383}, {'id': 31468}]]
[array([-99.35968 , -98.620026, -97.52474 , -97.056656, -96.73923 ,
       -96.72657 , -95.80153 , -95.267555, -93.477715, -93.45641 ],
      dtype=float32)]


In [14]:
ep_to_dp_idx = {
    dpoint["ep_num"]: i 
    for i, dpoint in enumerate(dataset)
}

tid_to_ep = {
    tid: dpoint["ep_num"]
    for dpoint in dataset
    for tid in dpoint["text_ids"]
}


def tid_to_episode(tid:int):
    return dataset[ep_to_dp_idx[tid_to_ep[tid]]]


def tid_to_text(tid:int):
    epi = tid_to_episode(tid)
    idx = epi["text_ids"].index(tid)
    return epi["texts"][idx]

In [15]:
for ind in indices[0]:
    tid = ind["id"]
    print(f"{tid} --> {tid_to_text(tid)}")

96511 --> ('And I guess I follow Max Tegmark here.', (154099, 154137))
25577 --> ("And by the way, I'm borrowing from Max Tegmark  for some of these metaphors, the physicist.", (93069, 93160))
96600 --> ("I mean, talking about people like Lee Small  and Alan Guth, Max Tegmark, okay, we're really smart.", (163510, 163608))
56700 --> ("Tegmark, I view as a philosopher  who is somehow taking credit for Platonism,  which I don't see any reason for fighting with Max  because I like Max, but if it ever comes time,  I'm putting a post it note that I'm not positive  the mathematical universe hypothesis  is really anything new.", (125870, 126160))
64434 --> ('Like you have like the Max Tegmark,  young version of Max Tegmark,  who knows how to play the role of boring and fitting in.', (80770, 80894))
0 --> (" As part of MIT course 6S099, Artificial General Intelligence,  I've gotten the chance to sit down with Max Tegmark.", (0, 116))
5391 --> ('Actually, a lot of physicists, Max Tegmark,  peopl