Reduce search multithreading overhead #1892

Open
dranikpg opened this issue Sep 19, 2023 · 35 comments

@dranikpg
Contributor

Currently, document iteration order is shard-sequential: first all documents are listed from shard 0, then from shard 1, and so on.

shard 0      shard 1       shard 2
[0, 1, 2] -> [0, 1, 2, 3] -> [0, 1]

Alternatively, we can interleave the shards: list the first document from every shard, then the second document from every shard, and so on. This allows us to reduce the number of elements we fetch from each shard.

(doc id, shard id)

(0, 0) (0, 1) (0, 2) (1, 0) (1, 1) (1, 2) (2, 0) (2, 1) (3, 1)

Now if we need to fetch two items, we can fetch just one from each shard and already have enough.

Reducing the number of fetched items is probabilistic. For example, with 16 shards we can assume it's very unlikely that even a single shard contains more than half of the documents. If that really is the case, we are forced to perform one more hop. The specific weights need to be tuned.
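
A minimal sketch of the interleaved iteration order described above (the shard layout is the example from this comment; the types and fetch API are placeholders, not Dragonfly's actual shard code):

// Sketch of the interleaved iteration order: take position 0 from every shard,
// then position 1, and so on, stopping as soon as `needed` documents are
// collected. The per-shard lists and ids are placeholders, not real types.
fn interleave(shards: &[Vec<u64>], needed: usize) -> Vec<(u64, usize)> {
    let mut out = Vec::with_capacity(needed);
    let longest = shards.iter().map(|s| s.len()).max().unwrap_or(0);
    for pos in 0..longest {
        for (shard_id, docs) in shards.iter().enumerate() {
            if let Some(&doc) = docs.get(pos) {
                out.push((doc, shard_id));
                if out.len() == needed {
                    return out;
                }
            }
        }
    }
    out
}

fn main() {
    // The example above: shard 0 -> [0, 1, 2], shard 1 -> [0, 1, 2, 3], shard 2 -> [0, 1].
    let shards: Vec<Vec<u64>> = vec![vec![0, 1, 2], vec![0, 1, 2, 3], vec![0, 1]];
    // Fetching two items touches only the first document of the first two shards.
    println!("{:?}", interleave(&shards, 2)); // [(0, 0), (0, 1)]
}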

@dranikpg dranikpg self-assigned this Sep 19, 2023
@dranikpg
Contributor Author

use std::io::BufRead;
use std::iter;

use rand::seq::SliceRandom;
use rand::thread_rng;

// Simulate n shards with m documents each and pick the global top m uniformly
// at random. For each trial, record how many of the top m landed on the most
// loaded shard; then print the cumulative distribution of that maximum.
fn run(n: usize, m: usize, total: usize) {
    // res[i] counts trials where the most loaded shard contributed i + 1 documents.
    let mut res = vec![0; m];
    // basevec holds m copies of every shard id.
    let mut basevec = Vec::new();
    for s in 0..n {
        basevec.extend(iter::repeat(s).take(m));
    }
    for _ in 0..total {
        // A random prefix of length m models which documents made the global top m.
        basevec.shuffle(&mut thread_rng());
        let mut cnts = vec![0; n];
        for e in &basevec[0..m] {
            cnts[*e] += 1;
        }
        res[*cnts.iter().max().unwrap() - 1] += 1;
    }
    // Print the cumulative probability that the maximum per-shard count is <= v + 1.
    let mut prob: f64 = 0.0;
    for (v, c) in res.iter().enumerate() {
        prob += *c as f64 / total as f64;
        println!("{} => {:.5}", v, prob * 100.0);
    }
    println!("=====");
}

fn main() {
    // Read "n m total" triples from stdin, one experiment per line.
    loop {
        let mut buf = String::new();
        std::io::stdin().lock().read_line(&mut buf).ok();
        let v: Vec<_> = buf
            .split(' ')
            .map(|s| s.trim().parse::<usize>().unwrap())
            .collect();
        run(v[0], v[1], v[2]);
    }
}

@dranikpg
Contributor Author

dranikpg commented Sep 20, 2023

n shards, top m are requested

x-axis: number of shards, one plot per m value
blue: the 99th percentile; if we pick that number, we cover at least 99% of queries
orange: my cheap estimation, min(m//n + 1 + max(m//4 + 1, 3), m)

We can see it is tightest for limit 10, which also seems to be one of the most common requests

[plot: per-shard fetch count vs. number of shards, one panel per m; blue = 99th percentile, orange = the estimation above]
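
For reference, a direct transcription of the orange estimation into a plain function (illustrative only, with integer division as in the formula above):

// The cheap per-shard fetch estimate: min(m//n + 1 + max(m//4 + 1, 3), m),
// i.e. how many documents to fetch from each of n shards for a top-m query.
fn per_shard_fetch(n: usize, m: usize) -> usize {
    (m / n + 1 + std::cmp::max(m / 4 + 1, 3)).min(m)
}

fn main() {
    // For the common case of m = 10 on 16 shards, each shard fetches 4 instead of 10.
    println!("{}", per_shard_fetch(16, 10)); // 4
}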

@royjacobson
Contributor

royjacobson commented Sep 21, 2023

I tried to do back-of-the-envelope calculations for this and failed miserably because order statistics is hard :)

However, I did find this, and they suggest using

from math import log, sqrt

p = 0.01
def bound(shards, num_results):
    return num_results / shards + sqrt(num_results / 2 * (log(shards) - log(p)))

It's not a tight bound at all, but I think it's still interesting.

@dranikpg
Contributor Author

@royjacobson if you like math, you're welcome to join me! 🎩 I tried using it (I studied probability theory for a year!! 😆), but it quickly became evident that there is no way to verify your results. And besides, writing an experiment is faster and more reliable.

@dranikpg
Contributor Author

The next optimization is for structured search. I'm approaching this problem very simply and trying to find a very reliable upper bound.

Given a query with LIMIT m on n shards, we will take at most m//n+1 from every shard if all shards have enough hits.

Now, imagine we have found k results on shard x. We can choose to return m//n+1 instead of m if we are 99% sure all shards have at least m//n+1 hits. All we have is this single k (and the shards can't communicate).

Let's be pessimistic and assume our shard x already has the highest number of hits, so all others will be lower. Next, we want to find how max k correlates with min k for a query.

For this I've written a simple script

use std::collections::HashMap;

use rand::{thread_rng, Rng};

// For n shards, repeatedly draw a random total number of results m and assign
// each result to a uniformly random shard. Group the trials by the maximum
// per-shard count (max k) and, for every group, print the 1st percentile of
// the minimum per-shard count (min k): "if one shard saw max k results, all
// shards have at least this many with 99% probability".
fn run(n: u32) {
    let total = 300_000;
    let mut out = HashMap::<u32, Vec<u32>>::new();

    for _ in 0..total {
        // Random total number of results for this trial.
        let m: usize = thread_rng().gen::<usize>() % 2000usize + 1;
        // Assign every result to a random shard and count hits per shard.
        let mut bv = vec![0u32; m];
        for v in &mut bv {
            *v = thread_rng().gen::<u32>() % n;
        }
        let mut cnts = vec![0u32; n as usize];
        for v in bv {
            cnts[v as usize] += 1;
        }
        //let vavg = cnts.iter().sum::<u32>() / cnts.len() as u32;
        let vmax = *cnts.iter().max().unwrap();
        let vmin = *cnts.iter().min().unwrap();

        // We only care about small per-shard counts.
        if vmax > 100 {
            continue;
        }

        out.entry(vmax).or_insert_with(Vec::new).push(vmin);
    }

    // For every observed max k, take the 1st percentile of the min k values.
    let mut ps: Vec<(u32, u32)> = vec![];
    for (k, mut v) in out {
        v.sort();
        let i = (v.len() as f32 * 0.01f32).floor();
        ps.push((k, v[i as usize]));
    }
    ps.sort();

    let vs: Vec<u32> = ps.iter().map(|(_k, v)| *v).collect();
    println!("{:?}", vs);
}

@dranikpg
Contributor Author

The blue line is again the 99th percentile. It means that if we got x results on some shard, we can assume that all shards will have at least y results with 99% prob.

I experimentally deduced a bound again: k*math.log2(k)//(12+n/10)-min(k, 5) (orange)

[plot: min k vs. max k; blue = 99th percentile, orange = the fitted bound]
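
A direct transcription of that fitted bound into a helper function (the signature is illustrative, not part of any real API):

// The experimentally fitted lower bound: k*log2(k)//(12 + n/10) - min(k, 5).
// Given k hits observed on one shard of n, assume (with ~99% probability)
// that every shard has at least this many hits.
fn min_hits_per_shard(k: u64, n: u64) -> u64 {
    if k < 2 {
        return 0;
    }
    let kf = k as f64;
    let scaled = (kf * kf.log2()) / (12.0 + n as f64 / 10.0);
    // Floor division as in the original formula, then subtract min(k, 5), clamping at zero.
    (scaled.floor() as u64).saturating_sub(k.min(5))
}

fn main() {
    // With 30 hits on one of 16 shards, every shard is assumed to have at least this many.
    println!("{}", min_hits_per_shard(30, 16));
}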

@dranikpg
Contributor Author

It's not a tight bound at all, but I think it's still interesting.

True. It's not so much about the tightest formula now, but about what assumptions we can generally make, what techniques we can use, and what logical pitfalls there are.

For example, optimization (1) can only be applied after check (2), because our formulas for (1) only hold if every shard really has at least m values.

@romange
Collaborator

romange commented Sep 22, 2023

  1. I am curious, why did you choose to write the scripts in Rust? Do you find it more convenient?
  2. You may relax the requirement of non-communication. You can make k atomic (i.e. found so far) and improve your estimation with each shard execution (see the sketch below). Assuming that shards do not run at exactly the same time, you may improve your recall. The disadvantage (a big one, I would say) is that it becomes non-deterministic.
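
A minimal sketch of what that shared counter could look like (all names are made up for illustration, this is not Dragonfly's actual coordinator code):

use std::sync::atomic::{AtomicUsize, Ordering};

// Each shard publishes its local hit count into a shared atomic and reads back
// the total found so far, so shards that finish later can base their cutoff on
// more data than just their own count.
struct SharedHits {
    found_so_far: AtomicUsize,
}

impl SharedHits {
    fn new() -> Self {
        Self { found_so_far: AtomicUsize::new(0) }
    }

    // Returns the global number of hits observed so far, including ours.
    fn report(&self, local_hits: usize) -> usize {
        self.found_so_far.fetch_add(local_hits, Ordering::Relaxed) + local_hits
    }
}

As the reply below points out, the first shard to finish still has to decide based only on its own count, and the outcome depends on execution order.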

@romange
Collaborator

romange commented Sep 22, 2023

Your approach is very good! Did you think of it yourself or did you read it in a paper? It seems like something worth publishing in academic literature.

@dranikpg
Contributor Author

I am curious, why did you choose to write the scripts in Rust? Do you find it more convenient?

It would take ages in Python, etc. So I can only write them in Go/C++/Rust, and I miss Rust the most 😭

You may relax the requirement of non-communication. You can make k atomic

That is true and I thought about it, but the first shard to finish has to make the decision on its own. If we improve the estimation for shard k and onward (k reliable data points), it won't help much, because our first shard will keep ruining everything.

Your approach is very good! Did you think of it yourself or did you read it in a paper? It seems like something worth publishing in academic literature.

Thanks 😄 I didn't; I tried searching for quite some time but didn't find much. There certainly must be something on this topic on arXiv. Anyway, I was just thinking about what really obvious optimizations we can apply.

@romange
Collaborator

romange commented Sep 22, 2023

  1. I did not understand the code (and not because I do not understand Rust) - I understand each line but hardly understand the intent. Could you add comments?

  2. You simplify the problem to either return m//n + 1 or m. It may be possible to generalise it to m//n + \delta, where \delta >= 1 and you can tune it.

@dranikpg
Contributor Author

I want those scripts somewhere in our repos! Besides the algorithmic side, they can be helpful to learn (copy paste) when we need to generate graphs.

I can rewrite them. The graphs are a one-liner in Python. Let me finish my research first 🤣

You simplify the problem to either return m//n + 1 or m. It may be possible to generalise it to m//n + \delta, where \delta >= 1 and you can tune it.

To some extent. If the bound is m//n (or m//n-1 when m%n = 0), i.e. one less than we might need, we can take n additional elements to make it a bound for m//n+1 again (not sure if it works for \delta) - that is the worst case, where we're the only savior shard.

I will think about this; it requires a more complicated experiment.

@romange
Collaborator

romange commented Sep 22, 2023

I do not understand your last comment. Suppose we want to receive m=160 via n=10 shards.
The lowest bound you can fetch is 16 and the highest is obviously 160. In your formalisation you ask the question "can I set k to either 16 or 160 based on the number of total results I got on the shard", if I understand correctly.

@dranikpg
Contributor Author

dranikpg commented Sep 22, 2023

I do not understand your last comment. Suppose we want to receive m=160 via n=10 shards.
The lowest bound you can fetch is 16 and the highest is obviously 160. In your formalisation you ask the question "can I set k to either 16 or 160 based on the number of total results I got on the shard", if I understand correctly.

You forgot that k is at most k :)

So for n=16 and m=160, we can be 99% sure that each shard has 16 entries if we found ~40 or more entries (that's the blue line). Hence we care only about values < 40

Let's say we found 30 entries; then with 99% probability each shard has at least ~8 entries, which takes us to a bound of 128 with 99%. Hence shrinking below 30 is already risky. Yet it is also extremely unlikely that every other shard lies on the lower bound - my model just doesn't take that into account.

Instead of calculating the expected 99th percentile min k, the model should calculate the expected 99th percentile sum for a given max k (or better, avg k - I actually have that commented out because I thought of it), so we can estimate how much we actually overshoot.

Regarding the max variance for KNN from experiment (1), I'm not sure the values technically hold for the sum bound. There we expect every shard to have at least m, not their sum to be m. Because if their sum is m, then of course the max number of appearances of a single shard can be higher (though not much, because on average they have the same number of results, around ~m/n).

@dranikpg
Contributor Author

dranikpg commented Sep 22, 2023

That is the 99th percentile bound for avg k; 16 would be reached with k ≈ 30.

But I don't yet fully understand what conclusions we can draw from avg k. It surely implies that with 99% probability we sum up to S = avg k * n, so we overshoot by S - m. Does it mean we can leave out (S - m) / n on each shard?

But how does leaving out a random amount on all shards affect the sum in the end?

Edit: i.e., doesn't the sum estimation imply we will have outliers that balance it out? We need to simulate that one.

[plot: 99th percentile bound for avg k as a function of max k]

@romange
Collaborator

romange commented Sep 22, 2023

I am sorry, I am like 10 steps behind you 😄 "So for n=16 and m=160, we can be 99% sure that each shard has 16 entries if we found ~40 or more entries" - how did you derive this conclusion? I guess it's from the graph, but I am not sure how you built your graph.

@romange
Collaborator

romange commented Sep 22, 2023

I do not know if this helps or not, but I asked this question on Math StackExchange:

https://math.stackexchange.com/questions/4773766/estimate-total-balls-based-on-the-ball-count-in-a-bin

@dranikpg
Contributor Author

dranikpg commented Sep 22, 2023

how did you derive this conclusion?

This is my second script

Pure experiment repeated a few million times. First, randomize the total number of results (namely m), then randomly assign each element (0...m) to one of the shards (0..n). Calculate min k, avg k, max k and store the pair (max k -> min k).

At the end, group by max k (first element of the pair) and sort by min k in each group, so we get a map<int, vector<int>>. Then take the lowest 1% for each max k

P.S. There is a low cap on m (10k), which makes the experiment pessimistic. Also, I assume that m is evenly distributed over the range 0-10k, which is obviously not true: in most cases there will be either very little data (specific queries) or a lot of data (broad queries). I tried generating in the range 0-200 and my lower bound still holds, but it's much tighter.

@romange
Collaborator

romange commented Sep 23, 2023

You are simulating the balls into bins problem:
https://en.wikipedia.org/wiki/Balls_into_bins_problem#Random_allocation

@romange
Collaborator

romange commented Sep 23, 2023

Attaching what I think is a relevant article.

The ‘Balls Into Bins’ Process and Its Poisson Approximation _ by Dr. Robert Kübler _ Cantor’s Paradise.pdf

Specifically, for m balls and n bins, s.t. $m \gg n \ln(n)^3$, the maximum load is

$$ L \approx \frac{m}{n} + \sqrt{\frac{2m \ln n}{n}} $$

Now, you correctly stated that upon sampling a bin (shard) we should be pessimistic, hence we can assume this shard has the highest number of results, therefore for a sample of K results

$$\displaylines{ K \ge \frac{m}{n} + \sqrt{\frac{2m \ln n}{n}} \\\ \text{or, squaring and multiplying by } n^2, \\\ K^2n^2 - 2Kmn + m^2 \ge 2mn \ln n \\\ \text{and finally,} \\\ m^2 - 2n(K + \ln n)m + K^2n^2 \ge 0 }$$

where K, n are constants. Of the two roots of the quadratic, only the smaller one is compatible with the un-squared inequality (which requires $m \le Kn$), so we can estimate m as the smaller root, $m \approx n(K + \ln n) - n\sqrt{\ln n\,(2K + \ln n)}$.
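
A quick numeric sanity check of this inversion (it just transcribes the two formulas above; nothing here is Dragonfly-specific):

// Approximate maximum load when m balls are thrown into n bins (m >> n ln(n)^3).
fn max_load(m: f64, n: f64) -> f64 {
    m / n + (2.0 * m * n.ln() / n).sqrt()
}

// Invert the bound: given K results on the (pessimistically) most loaded shard
// of n, the largest total m consistent with K >= max_load(m, n) is the smaller
// root of the quadratic above.
fn estimate_total(k: f64, n: f64) -> f64 {
    let ln_n = n.ln();
    n * (k + ln_n) - n * (ln_n * (2.0 * k + ln_n)).sqrt()
}

fn main() {
    // Round trip for m = 160, n = 16: the max load comes out around 17.4, and
    // feeding it back recovers m ≈ 160.
    let k = max_load(160.0, 16.0);
    println!("max load ≈ {:.1}, estimated total ≈ {:.1}", k, estimate_total(k, 16.0));
}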

@dranikpg
Contributor Author

dranikpg commented Sep 23, 2023

Orange is the formula, blue is my experiment. You can see the bound is much stricter (but it actually has the same slope, so the formula and the experiment reaffirm each other).

[plot: experimental 99th percentile (blue) vs. the balls-into-bins formula (orange)]

I've also mainly googled the balls into bins problem and the resulting binomial distribution. The problem is that most formulas are limits: with infinitely large input parameters the odds become infinitely small. We neither have infinitely large numbers nor need the odds to be infinitely small - we can risk 1% in exchange for shaving off many more elements. Most of these formulas were probably not meant to be used on small numbers.

Either way, we have found a decent approach. I'll start implementing it; the formula and numbers can be tuned later.

@royjacobson
Contributor

m/n is the variance of a Poisson distribution. I interpret ln n as a correction for going from one bin to the maximum over all bins. And you'd take the square root of the variance to get the same scale as the expected value. I guess we can hand-wave and multiply that root by the number of standard deviations we want to be safe (approx. 2.5-3 for the 99th percentile).
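
That reading checks out numerically; a small illustrative sketch (a restatement of the hand-wave, not anyone's actual code), treating sqrt(2 ln n) as the implied z-score in the balls-into-bins formula:

// A Poisson(m/n) bin count has standard deviation sqrt(m/n); in the formula
// L ≈ m/n + sqrt(2 m ln n / n), the multiplier on that standard deviation is
// sqrt(2 ln n).
fn poisson_style_bound(m: f64, n: f64, z: f64) -> f64 {
    let mean = m / n;
    mean + z * mean.sqrt()
}

fn main() {
    let n: f64 = 16.0;
    // ≈ 2.35 for n = 16, close to the suggested 2.5-3 for the 99th percentile.
    let implied_z = (2.0 * n.ln()).sqrt();
    println!("implied z for n = 16: {:.2}", implied_z);
    println!("bound for m = 160 with z = 2.5: {:.1}", poisson_style_bound(160.0, n, 2.5));
}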

@dranikpg
Contributor Author

dranikpg commented Sep 23, 2023

I can't decide on the following issue: how to re-fetch documents? My first and main goal is to cut pointless serialization time.

The most straightforward approach would be finding all search results but serializing only the first ones, up to the probabilistic bound. Since we already have the ids of the remaining ones, it'd be a pity not to store them as well. With that information, we don't actually need to repeat the query in the unlucky case - we just need to fetch the remaining documents.

However, it's not clear how to perform the second hop atomically. Hint: it's not possible, because we don't know ahead of time whether it's the last hop. We can fetch documents outside of the transactional world, but as a result the search command becomes non-atomic. It could, in theory, return incorrect results.

Alternatively, we need to repeat the whole query, which is not cool. With a fast re-fetch approach we could even decide to lower the probability bound to ~90-95%.

@royjacobson
Contributor

If it's a rare enough scenario, do we really need to fetch the remaining results?

@dranikpg
Contributor Author

Interviewer: "I heard you were extremely quick at math"
Me: "yes, as a matter of fact I am"
Interviewer: "Whats 14x27"
Me: "49"
Interviewer: "that's not even close"
Me: "Yeah, but it was fast"


If it's a rare enough scenario, do we really need to fetch the remaining results?

It's up to 1% likely - that's not zero. I really think we should adhere to correctness. In many search systems there is an agreement that returning fewer items than requested indicates that there are no more - I remember always checking this myself when working with SQL databases.


I have one more idea 💡 Let's introduce write epochs for indices. If no writes happened between the query and the follow-up, it is safe to return more documents from the same result set. Otherwise, we need to perform the full query again.
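
A rough sketch of the write-epoch idea (all type and field names are hypothetical, not existing Dragonfly structures):

// Every mutation of an index bumps its epoch; the follow-up hop compares epochs
// to decide whether the cached tail of the first result set is still valid.
struct ShardIndex {
    write_epoch: u64,
    // ... the actual index data
}

impl ShardIndex {
    fn on_write(&mut self) {
        self.write_epoch += 1;
    }
}

struct CachedResult {
    epoch_at_query: u64,
    remaining_ids: Vec<u64>, // ids we found but did not serialize on the first hop
}

fn follow_up(index: &ShardIndex, cached: &CachedResult) -> Option<Vec<u64>> {
    if index.write_epoch == cached.epoch_at_query {
        // No writes since the first hop: the cached tail is still valid.
        Some(cached.remaining_ids.clone())
    } else {
        // Index changed: the caller must re-run the full query.
        None
    }
}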

@chakaz
Collaborator

chakaz commented Sep 24, 2023

+1 to not returning fewer results than requested unless there aren't enough

@romange
Collaborator

romange commented Sep 24, 2023

@dranikpg if you feel strongly about returning all the results (and I agree with your arguments) then let's pay in latency. We should be fine with 1% of queries incurring twice the latency + CPU cost.

@dranikpg
Contributor Author

Okay, I can keep it as a to-do with the write epochs

@dranikpg
Contributor Author

dranikpg commented Sep 24, 2023

Damn, I still need to return all ids with scores to verify that KNN selected the optimal ones - I can't blindly chop them off 😵

@romange
Collaborator

romange commented Sep 24, 2023

Ah, in your analysis you focused on recall, but you also have the topK requirement. So now the question is how much to return from each shard, somewhere in [K/N, K], so that with high probability you will be able to select the top K, right?

@dranikpg
Contributor Author

Yes - without a top-k query there is no variance: I always select m/n documents from each shard (if there are enough on that shard); with top-k queries I sort them beforehand.

@romange
Collaborator

romange commented Sep 24, 2023

So maybe K/sqrt(N)? Let's not be tight, let's choose some bound that works.

@dranikpg
Contributor Author

My preliminary results with my prototype

DF 8 threads, 300k entries.
Queries: select ALL, limits: 5, 10, 20, 50, 100
Benchmark: memtier with t=6 pipeline=5

Resulting RPS:

Limit:  5    10   20   50   100
New:   78k  72k  62k  43k  28k
Old:   66k  53k  39k  22k  12k

The difference grows with larger limit values. For limits 5 and 10 the gap is suspiciously big 🤔

@dranikpg
Contributor Author

So maybe K/sqrt(N)? Let's not be tight, let's choose some bound that works.

I haven't touched the KNN optimization yet. I need to benchmark it first and see how much sense it makes

  • FLAT runs in N log M time (N total, M selected), so even cutting M in half doesn't make it much faster - see the sketch below
  • HNSW needs to do a full descent regardless of the number of points, but fetching 20 vs 50 must make a big difference for it
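
A generic sketch of the N log M cost model for FLAT (not the actual index code): a single linear scan that keeps only the best M candidates in a bounded heap.

use std::collections::BinaryHeap;

// Each of the N candidates costs at most one O(log M) heap operation, so the
// scan over N dominates and halving M only shrinks the log factor.
fn top_m_nearest(distances: impl IntoIterator<Item = u64>, m: usize) -> Vec<u64> {
    // Max-heap of the M smallest distances seen so far (largest on top).
    let mut heap: BinaryHeap<u64> = BinaryHeap::with_capacity(m);
    for d in distances {
        if heap.len() < m {
            heap.push(d);
        } else if let Some(&worst) = heap.peek() {
            if d < worst {
                // Replace the current worst of the best M.
                heap.pop();
                heap.push(d);
            }
        }
    }
    // Ascending order: nearest first.
    heap.into_sorted_vec()
}

fn main() {
    let dists = vec![42, 7, 19, 3, 88, 15, 27];
    println!("{:?}", top_m_nearest(dists, 3)); // [3, 7, 15]
}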

@dranikpg
Contributor Author

Profile for query with limit 10

Before, we see that serialization takes around 3s; the ratio is 3:5.

[profile screenshot: before the change]

Afterwards, it's only 1s; the ratio is 1:7.

[profile screenshot: after the change]
