-
Notifications
You must be signed in to change notification settings - Fork 1
/
load_reference_data.py
108 lines (89 loc) · 2.9 KB
/
load_reference_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import csv
import openai
import numpy as np
from dotenv import load_dotenv
from redisvl.index import SearchIndex
from typing import List, Dict
import ast
import time
import argparse
from rich import print
# Load secrets from a local .env file so the OpenAI key never lives in code.
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")  # None if unset; API calls will then fail at request time
# Embedding model used everywhere in this script for narration embeddings.
OPENAI_EMBEDDING_MODEL = "text-embedding-ada-002"
def get_embedding(doc):
    """Return the embedding vector for *doc* from the OpenAI embeddings API.

    NOTE(review): uses the pre-1.0 openai SDK surface (``openai.Embedding``) —
    confirm the project pins ``openai<1.0``.
    """
    result = openai.Embedding.create(
        input=doc,
        model=OPENAI_EMBEDDING_MODEL,
        encoding_format="float",
    )
    return result["data"][0]["embedding"]
def get_data(filepath) -> List[Dict]:
    """Read labelled narrations from a CSV and attach OpenAI embeddings.

    The CSV at *filepath* must have "Narration" and "Category" columns.
    Each row becomes a dict with keys ``Narration``, ``Category`` and
    ``Narration_Embedding`` (the raw embedding list from the OpenAI API —
    one API call per row, so this can be slow/costly on large files).
    """
    print(f"Using the input file {filepath} to generate embeddings...")
    # newline="" is the csv-module-documented way to open CSV files.
    with open(filepath, "r", newline="") as f:
        return [
            {
                "Narration": row["Narration"],
                "Category": row["Category"],
                "Narration_Embedding": get_embedding(row["Narration"]),
            }
            for row in csv.DictReader(f)
        ]
def save_embeddings(filepath, dict_list):
    """Write a list of row dicts to *filepath* as CSV.

    Column order follows the keys of the first dict. An empty *dict_list*
    is a no-op (the previous version raised IndexError on ``dict_list[0]``).
    """
    if not dict_list:
        # Nothing to write — avoid indexing an empty list for fieldnames.
        return
    fieldnames = dict_list[0].keys()
    with open(filepath, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(dict_list)
def create_and_load_index(records):
    """(Re)create the Redis vector index defined in index.yaml and load *records*."""
    search_index = SearchIndex.from_yaml("index.yaml")
    search_index.connect("redis://localhost:6379")
    # overwrite=True replaces any existing index of the same name.
    search_index.create(overwrite=True)
    search_index.load(records)
def get_data_from_file(filepath):
    """Load precomputed embeddings from a CSV written by ``save_embeddings``.

    Expects three columns — Narration, Category, Narration_Embedding — where
    the embedding cell holds a Python list literal. Each embedding is parsed
    and packed into float32 bytes (the format Redis vector fields consume).

    Returns a list of dicts keyed Narration / Category / Narration_Embedding.
    """
    print(f"Using the input file {filepath} to load embeddings...")
    records = []
    # newline="" is the csv-module-documented way to open CSV files.
    with open(filepath, "r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; safe on an empty file
        for row in reader:
            # literal_eval safely parses the stored "[0.1, 0.2, ...]" string.
            embedding = ast.literal_eval(row[2])
            records.append(
                {
                    "Narration": row[0],
                    "Category": row[1],
                    "Narration_Embedding": np.array(
                        embedding, dtype=np.float32
                    ).tobytes(),
                }
            )
    return records
def main():
    """CLI entry point: optionally regenerate embeddings, then load them into Redis."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-regen",
        action="store_true",
        help="Flush the target and reload all. Use very carefully, usually -a should suffice",
    )
    cli_args = parser.parse_args()

    if cli_args.regen:
        # Rebuild embeddings from the labelled source data (calls the OpenAI API).
        save_embeddings(
            filepath="data/embeddings.csv",
            dict_list=get_data(filepath="data/labelled.csv"),
        )

    started = time.time()
    embedded_records = get_data_from_file(filepath="data/embeddings.csv")
    create_and_load_index(embedded_records)
    elapsed = round(time.time() - started, 2)
    print(f"Vector Database Loaded! ( {elapsed} seconds )")
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()