# Timing PLAID Searching

In [None]:
import timeit
import torch
import random
import wandb
from datasets import load_dataset
from ettcl.modeling import ColBERTModel, ColBERTTokenizer
from ettcl.encoding import ColBERTEncoder
from ettcl.searching.colbert_searcher import ColBERTSearcher, _SearcherSettings
from ettcl.logging import configure_logger
import colbert.search.index_storage as index_storage

# Set ettcl's logging up at INFO level (see ettcl.logging.configure_logger).
configure_logger("INFO")

# Checkpoint under test; its search index sits in the "index" subdirectory.
model_path = "../training/imdb/bert-base-uncased/2023-06-30T09:30:28.027860/checkpoint-7500"
index_path = model_path + "/index"

In [None]:
# IMDb training split; the bare name on the last line displays it in the notebook.
dataset = load_dataset(path="imdb", split="train")
dataset

In [None]:
# Load model and tokenizer from the checkpoint, then build the encoder and a
# searcher over the index at index_path.
model = ColBERTModel.from_pretrained(model_path)
tokenizer = ColBERTTokenizer.from_pretrained(model_path)

encoder = ColBERTEncoder(model, tokenizer)
searcher = ColBERTSearcher(index_path, encoder)

In [None]:
# Number of training texts to encode as benchmark queries.
n = 2_000

# Encode the queries on the GPU. to_cpu=False presumably leaves the embeddings
# on the GPU (the device is displayed below to confirm).
encoder.cuda()
try:
    Q = encoder.encode_queries(dataset.select(range(n))["text"], to_cpu=False)
finally:
    # Move the model back to the CPU even if encoding fails. Otherwise its
    # weights stay resident and inflate the "idle" GPU memory that the
    # profiling cells measure.
    encoder.cpu()

torch.cuda.empty_cache()
print("Memory:", torch.cuda.memory_allocated() / 1e9)

Q[0].shape, Q[0].device

In [None]:
# timeit setup code. It runs once before each repeat, so every repeat
# re-seeds `random` and draws the same sequence of query indices.
setup = "\nimport random\nrandom.seed(12345)\n"

def search(searcher, args, Q, k):
    """Run one dense search for a uniformly random query from *Q*.

    This is the statement timed by ``profile``. The index is drawn from the
    ``random`` module, so ``profile``'s timeit ``setup`` can seed it.

    Args:
        searcher: object providing ``dense_search(q, k=..., args=...)``.
        args: search settings, forwarded unchanged to ``dense_search``.
        Q: non-empty sequence of query embeddings.
        k: number of results to request for the query.

    Returns:
        Whatever ``searcher.dense_search`` returns.
    """
    # randrange excludes its upper bound, so only valid indices are drawn.
    # The previous randint(0, 2000) could return 2000 -> IndexError for 2000 queries.
    idx = random.randrange(len(Q))
    # Use the caller's k rather than a hard-coded value, so the timed search
    # matches the k that profile() logs.
    return searcher.dense_search(Q[idx], k=k, args=args)

def profile(searcher, args, Q, k):
    """Time ``search`` for one search configuration and log the result to W&B.

    Runs ``search`` 200 times in each of 5 timeit repeats. The module-level
    ``setup`` string runs before every repeat. The function prints the best
    per-call time and the peak CUDA memory, and logs both to a new W&B run in
    the "performance-analysis" project.

    Args:
        searcher: searcher forwarded to ``search``.
        args: search settings; its ``__dict__`` is logged as W&B config.
        Q: query embeddings indexed by ``search``.
        k: number of results per query, forwarded to ``search`` and logged.
    """
    # Reset allocator state so max_memory_allocated() reflects only the
    # searches timed below.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    idle_memory = torch.cuda.memory_allocated()

    run = wandb.init(
        project="performance-analysis",
        config={"k": k, **args.__dict__, "idle_memory": idle_memory},
        save_code=True,
    )
    # Finish the run even if timing fails, so a crashed configuration does not
    # leave an active W&B run behind for the next profile() call.
    try:
        repeat, number = 5, 200
        timer = timeit.Timer(
            "search(searcher, args, Q, k)",
            setup=setup,
            globals={"search": search, "searcher": searcher, "args": args, "Q": Q, "k": k},
        )

        # The minimum over repeats is the estimate least affected by noise.
        seconds_per_call = min(timer.repeat(repeat, number)) / number
        print(f"{seconds_per_call * 1000:.3f} ms")

        memory = torch.cuda.max_memory_allocated()
        print(f"Memory: {memory / 1e9:.3f} GB ({idle_memory / 1e9:.3f} GB idle)")

        run.log({
            "execution_time": seconds_per_call,
            "max_memory": memory,
        })
    finally:
        run.finish()

## Searching on GPU, with the approximate index operations (PLAID stages 2–3) on either CPU or GPU

In [None]:
# Put the query embeddings on the current CUDA device and give the searcher an
# IndexScorer built with its default settings.
Q = [query.to("cuda") for query in Q]
searcher.ranker = index_storage.IndexScorer(index_path)

In [None]:
# Stages 2-3 on GPU, stage 3 enabled; centroid_score_threshold=0.8, plaid_num_elem_batch=3e8.
args = _SearcherSettings(
    ncells=1, centroid_score_threshold=0.8, plaid_num_elem_batch=3e8,
    skip_plaid_stage_3=False, plaid_stage_2_3_cpu=False,
)
profile(searcher, args, Q, 256)

In [None]:
# Stages 2-3 on GPU, stage 3 skipped; centroid_score_threshold=0.8, plaid_num_elem_batch=3e8.
args = _SearcherSettings(
    ncells=1, centroid_score_threshold=0.8, plaid_num_elem_batch=3e8,
    skip_plaid_stage_3=True, plaid_stage_2_3_cpu=False,
)
profile(searcher, args, Q, 256)

In [None]:
# Stages 2-3 on GPU, stage 3 enabled; default centroid_score_threshold, plaid_num_elem_batch=3e8.
args = _SearcherSettings(
    ncells=1, plaid_num_elem_batch=3e8,
    skip_plaid_stage_3=False, plaid_stage_2_3_cpu=False,
)
profile(searcher, args, Q, 256)

In [None]:
# Stages 2-3 on GPU, stage 3 skipped; default centroid_score_threshold, plaid_num_elem_batch=3e8.
args = _SearcherSettings(
    ncells=1, plaid_num_elem_batch=3e8,
    skip_plaid_stage_3=True, plaid_stage_2_3_cpu=False,
)
profile(searcher, args, Q, 256)

In [None]:
# Stages 2-3 on GPU, stage 3 enabled; centroid_score_threshold=0.8, plaid_num_elem_batch=3e9.
args = _SearcherSettings(
    ncells=1, centroid_score_threshold=0.8, plaid_num_elem_batch=3e9,
    skip_plaid_stage_3=False, plaid_stage_2_3_cpu=False,
)
profile(searcher, args, Q, 256)

In [None]:
# Stages 2-3 on GPU, stage 3 skipped; centroid_score_threshold=0.8, plaid_num_elem_batch=3e9.
args = _SearcherSettings(
    ncells=1, centroid_score_threshold=0.8, plaid_num_elem_batch=3e9,
    skip_plaid_stage_3=True, plaid_stage_2_3_cpu=False,
)
profile(searcher, args, Q, 256)

In [None]:
# Stages 2-3 on GPU, stage 3 enabled; default centroid_score_threshold, plaid_num_elem_batch=3e9.
args = _SearcherSettings(
    ncells=1, plaid_num_elem_batch=3e9,
    skip_plaid_stage_3=False, plaid_stage_2_3_cpu=False,
)
profile(searcher, args, Q, 256)

In [None]:
# Stages 2-3 on GPU, stage 3 skipped; default centroid_score_threshold, plaid_num_elem_batch=3e9.
args = _SearcherSettings(
    ncells=1, plaid_num_elem_batch=3e9,
    skip_plaid_stage_3=True, plaid_stage_2_3_cpu=False,
)
profile(searcher, args, Q, 256)

In [None]:
# Stages 2-3 on CPU; centroid_score_threshold=0.8, other settings at their defaults.
args = _SearcherSettings(ncells=1, centroid_score_threshold=0.8, plaid_stage_2_3_cpu=True)
profile(searcher, args, Q, 256)

In [None]:
# Stages 2-3 on CPU; all other settings at their defaults.
args = _SearcherSettings(ncells=1, plaid_stage_2_3_cpu=True)
profile(searcher, args, Q, 256)

## Searching entirely on CPU

In [None]:
# Bring the query embeddings back to host memory and switch the searcher to a
# CPU-only IndexScorer.
Q = [query.to("cpu") for query in Q]
searcher.ranker = index_storage.IndexScorer(index_path, use_gpu=False)

In [None]:
# CPU-only search (gpus=0) with centroid_score_threshold=0.8.
args = _SearcherSettings(gpus=0, ncells=1, centroid_score_threshold=0.8)
profile(searcher, args, Q, 256)

## Benchmark: Searching the whole dataset

In [None]:
import torch
import wandb
from datasets import load_dataset
from ettcl.modeling import ColBERTModel, ColBERTTokenizer
from ettcl.encoding import ColBERTEncoder
from ettcl.searching.colbert_searcher import ColBERTSearcher, ColBERTSearcherConfig
from ettcl.utils.multiprocessing import run_multiprocessed
from ettcl.utils import catchtime

# Checkpoint under test; its search index sits in the "index" subdirectory.
model_path = "../training/imdb/bert-base-uncased/2023-06-30T09:30:28.027860/checkpoint-7500"
index_path = model_path + "/index"

In [None]:
# IMDb training split; the bare name on the last line displays it in the notebook.
dataset = load_dataset(path="imdb", split="train")
dataset

In [None]:
# Load model and tokenizer from the checkpoint and wrap them in an encoder;
# each benchmark run builds its own searcher from this encoder.
model = ColBERTModel.from_pretrained(model_path)
tokenizer = ColBERTTokenizer.from_pretrained(model_path)

encoder = ColBERTEncoder(model, tokenizer)

In [None]:
def search(dataset, config, num_proc, k):
    """Search every ``text`` entry of *dataset* against the index.

    Builds a ``ColBERTSearcher`` for *config* from the module-level
    ``index_path`` and ``encoder``. It then maps ``searcher.search``, wrapped by
    ``run_multiprocessed``, over the dataset in batches of 1,000 texts using
    *num_proc* processes. ``with_rank=True`` passes each process's rank to the
    wrapped function.

    Args:
        dataset: Hugging Face dataset with a ``text`` column.
        config: searcher configuration passed to ``ColBERTSearcher``.
        num_proc: number of processes for ``Dataset.map``.
        k: number of results per query, forwarded to ``searcher.search``.

    Returns:
        The dataset produced by ``dataset.map``.
    """
    searcher = ColBERTSearcher(index_path, encoder, config)
    return dataset.map(
        run_multiprocessed(searcher.search),
        input_columns="text",
        fn_kwargs={"k": k},
        batched=True,
        batch_size=1_000,
        with_rank=True,
        num_proc=num_proc,
        # This is a benchmark: never reuse a cached result from an earlier
        # run, or the timing would measure a cache load instead of a search.
        load_from_cache_file=False,
    )

def profile(dataset, config, num_proc, k):
    """Time a full-dataset ``search`` for one configuration and log it to W&B.

    Args:
        dataset: dataset passed to ``search``.
        config: searcher configuration; its ``__dict__`` is logged as W&B config.
        num_proc: number of processes passed to ``search`` (also logged).
        k: number of results per query passed to ``search`` (also logged).
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    idle_memory = torch.cuda.memory_allocated()

    run = wandb.init(
        project="performance-analysis",
        config={"k": k, "num_proc": num_proc, "idle_memory": idle_memory, **config.__dict__},
        save_code=True,
    )
    # Finish the run even if the search fails, so a crashed configuration does
    # not leave an active W&B run behind for the next profile() call.
    try:
        with catchtime() as time:
            search(dataset, config, num_proc, k)

        # NOTE(review): confirm what ettcl.utils.catchtime yields. If it yields a
        # callable (as in the common perf_counter recipe), log `time()` instead.
        run.log({"execution_time": time})
    finally:
        run.finish()

In [None]:
# Stages 2-3 on CPU, 4 processes.
config = ColBERTSearcherConfig(plaid_stage_2_3_cpu=True)
profile(dataset=dataset, config=config, num_proc=4, k=256)

In [None]:
# Stages 2-3 on CPU, 2 processes.
config = ColBERTSearcherConfig(plaid_stage_2_3_cpu=True)
profile(dataset=dataset, config=config, num_proc=2, k=256)

In [None]:
# Stages 2-3 on GPU, 4 processes.
config = ColBERTSearcherConfig(plaid_stage_2_3_cpu=False)
profile(dataset=dataset, config=config, num_proc=4, k=256)

In [None]:
# Stages 2-3 on GPU, 2 processes.
config = ColBERTSearcherConfig(plaid_stage_2_3_cpu=False)
profile(dataset=dataset, config=config, num_proc=2, k=256)