# search_sharded.py
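#
# Query a local full-text index ("fineweb") for pages matching each BISAC topic
# (top_category / subcategory / subtopic), process one shard of the topic dataset
# per job, and push the retrieved hits for that shard to the Hugging Face Hub.
#
# Assumed invocation (the launcher is not part of this file), e.g. one process per shard:
#   python search_sharded.py --shard 0 --num_shards 8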
import argparse
import json
import sys
import time

import requests
from datasets import load_dataset


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dataset", type=str, default="HuggingFaceTB/bisac_expanded_final"
    )
    parser.add_argument("--n_pages", type=int, default=2000)
    parser.add_argument(
        "--output_dataset",
        type=str,
        default="HuggingFaceTB/bisac_boosted_new_index_2000",
    )
    parser.add_argument("--shard", type=int, required=True)
    parser.add_argument("--num_shards", type=int, required=True)
    return parser.parse_args()
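
# NOTE: the JSON /search endpoint on port 9308 matches Manticore Search's HTTP API;
# this is an inference from the request shape, not something stated in this file.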


# wait until the local search server at 127.0.0.1:9308 responds before querying
while True:
    try:
        requests.post(
            "http://127.0.0.1:9308/search",
            data='{"index": "fineweb", "query": {"match": {"content": "ping"}}}',
        )
        break
    except requests.exceptions.ConnectionError:
        # server not reachable yet; retry after a short pause
        time.sleep(10)

args = get_args()

# load this worker's shard of the topic dataset and keep only the topic columns
data = load_dataset(
    args.input_dataset, split="train", cache_dir="/scratch/cosmo/.cache"
)
data = data.filter(lambda x, i: i % args.num_shards == args.shard, with_indices=True)
data = data.select_columns(["top_category", "subcategory", "subtopic"])
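
# run_query (below) sends an Elasticsearch-style JSON body and expects a response of
# the form {"hits": {"hits": [{"_id": ..., ...}, ...]}}; only "_id" is relied on
# explicitly, the remaining hit fields depend on the index schema.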


def run_query(query, n_pages):
    # query the index, retrying on connection errors and non-200 responses
    while True:
        try:
            max_pages = 4_000
            response = requests.post(
                "http://127.0.0.1:9308/search",
                data=json.dumps(
                    {
                        "index": "fineweb",
                        "size": n_pages,
                        "query": query,
                        "max_matches": max_pages,
                    }
                ),
                timeout=1000,
            )
            if response.status_code != 200:
                print(response.text, file=sys.stderr)
                time.sleep(5)
                continue
            else:
                hits = response.json()["hits"]["hits"]
                return hits
        except requests.exceptions.ConnectionError as e:
            print(e, file=sys.stderr)
            time.sleep(5)
            continue
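
# search_topic is applied with batched=True and batch_size=1 (see the map call at the
# bottom of the file), so `sample` holds a single topic row as lists of length 1, and
# the returned columns may have a different length (one output row per retrieved hit).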


def search_topic(sample):
    top_category = sample["top_category"][0].strip()
    subcategory = sample["subcategory"][0].strip()
    subtopic = sample["subtopic"][0].strip()
    # drop characters that have special meaning in the query_string syntax
    for c in ["!", '"', "$", "'", "(", ")", "/", "<", "@", "\\", "^", "|", "~"]:
        top_category = top_category.replace(c, " ")
        subcategory = subcategory.replace(c, " ")
        subtopic = subtopic.replace(c, " ")
    # boosting the IDF score of subtopic tokens
    boosted_subtopic = " ".join([w + "^2" for w in subtopic.split()])
    match_query = " ".join([top_category, subcategory, subtopic])
    boosted_query = " ".join([top_category, subcategory, boosted_subtopic])
    boosted_hits = run_query({"query_string": boosted_query}, args.n_pages)
    print(f"Boosted hits: {len(boosted_hits)} for {boosted_query}", file=sys.stderr)
    # if the boosted query does not return enough pages, fall back to a plain match query
    if len(boosted_hits) < args.n_pages:
        match_hits = run_query(
            {"match": {"content": match_query}}, args.n_pages + len(boosted_hits)
        )
        print(f"Match hits: {len(match_hits)} for {match_query}", file=sys.stderr)
    else:
        match_hits = []
    # deduplicate by document id, keeping boosted hits first, then cap at n_pages
    hit_ids = set()
    hits = []
    for hit in boosted_hits + match_hits:
        if hit["_id"] not in hit_ids:
            hits.append(hit)
            hit_ids.add(hit["_id"])
    hits = hits[: args.n_pages]
    # replicate the topic columns so each retrieved hit becomes one output row
    results = {
        "top_category": sample["top_category"] * len(hits),
        "subcategory": sample["subcategory"] * len(hits),
        "subtopic": sample["subtopic"] * len(hits),
        "topic_hits": hits,
        "num_hits": [len(hits)] * len(hits),
    }
    return results


# run the search for every topic in this shard (one topic per batch, two worker processes)
data = data.map(search_topic, batched=True, batch_size=1, num_proc=2)

# push this shard's results to its own private dataset repo on the Hub
data.push_to_hub(
    f"{args.output_dataset}_{args.shard}", private=True, max_shard_size="4096MB"
)
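
# Each shard lands in its own repo named f"{output_dataset}_{shard}"; recombining the
# shards into a single dataset is presumably handled by a separate downstream step.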