#### Dataset comes from: https://huggingface.co/datasets/ashraq/financial-news-articles
Milestone 2的相关代码，使用Pinecone作为向量数据库，all-mpnet-base-v2为Embedding模型

In [1]:
import numpy as np
import pyarrow.parquet as pq

In [2]:
# Read the parquet shard into a pandas DataFrame; later cells read its
# 'title' and 'text' columns.
df = pq.read_table('data/train-1-of-2.parquet').to_pandas()

In [3]:
# How many leading rows of `df` get their titles embedded.
NUM_TEXTS = 100_000
# Number of titles handed to model.encode() per call.
VEC_BATCH_SIZE = 20

In [4]:
from sentence_transformers import SentenceTransformer
# Embedding model used by every `model.encode(...)` call in this notebook.
# NOTE(review): "models/all-mpnet-base-v2" looks like a local directory holding
# the model files — confirm it exists before running.
model = SentenceTransformer("models/all-mpnet-base-v2")

In [5]:
# Embed the first NUM_TEXTS titles, VEC_BATCH_SIZE at a time; each element of
# `embs` is the array returned by model.encode() for one batch.
embs = [
    model.encode(list(df['title'][start:start + VEC_BATCH_SIZE]))
    for start in range(0, NUM_TEXTS, VEC_BATCH_SIZE)
]
len(embs), embs[0].shape

(5000, (20, 768))

In [6]:
# Stack the per-batch embeddings into one array, then wrap every row in an
# {'id', 'values'} record whose id is 'news-<row position>'.
vecs = np.concatenate(embs, axis=0)
vecs = [
    {'id': f'news-{seq_no}', "values": vec}
    for seq_no, vec in enumerate(vecs)
]
len(vecs), vecs[0]['id']

(100000, 'news-0')

In [7]:
# Shape of the embedding stored under 'values' in the first record.
vecs[0]['values'].shape

(768,)

### 使用Pinecone作为向量数据库

In [None]:
import os

from pinecone import Pinecone, PodSpec

# Position in `vecs` at which the upload starts; a non-zero value skips the
# records before it (e.g. to resume an interrupted upload).
UPLOAD_START = 79900
# Number of records sent per upsert request.
UPLOAD_BATCH_SIZE = 20

# The API key is read from the environment so that it never lands in the
# notebook or its history. Any key that was previously committed in this cell
# should be treated as leaked and revoked.
pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
index = pc.Index("rag4fin")
for i in range(UPLOAD_START, len(vecs), UPLOAD_BATCH_SIZE):
    if i % 1000 == 0:
        print('%s uploaded' % i)
    # Send plain Python floats, matching the `.tolist()` conversion the query
    # cell applies to its vector.
    batch = [
        {'id': rec['id'], 'values': [float(x) for x in rec['values']]}
        for rec in vecs[i:i + UPLOAD_BATCH_SIZE]
    ]
    index.upsert(batch)

In [None]:
# Alternative example questions:
# query = 'Is there any news about Microsoft?'
# query = 'What is the revenue of Microsoft in cloud business?'
query = 'What is the revenue of Microsoft in game?'

# Keep the question text and its embedding under separate names so the cell
# can be re-run without `query` changing type.
query_vec = model.encode([query])[0].tolist()

# Only ids and scores are used below, so the stored vectors are not requested.
results = index.query(vector=query_vec, top_k=5, include_values=False)['matches']

for r in results:
    # Ids have the form 'news-<row>'; the numeric suffix is used as the row
    # label in `df`. `doc_id` is used instead of `id` to avoid shadowing the
    # builtin for the rest of the session.
    doc_id = r['id']
    idx = int(doc_id.split('-')[1])
    print(doc_id, r['score'], df['title'][idx])

In [None]:
from dateutil import parser

# Small experiment: extract a date/time from a free-form sentence.
text = "Let's meet on the 5th of July, 2023, at 3 PM."
# fuzzy=True is passed because the input is a whole sentence rather than an
# isolated date string (presumably non-date words are skipped — see the
# dateutil documentation).
date = parser.parse(text, fuzzy=True)
# Display the parsed value together with its type.
date, type(date)

### 直接用Faiss

In [26]:
import os
import faiss

#### 创建向量数据库

In [8]:
# Per-article line counts (1 for the title + the number of non-blank body
# lines), in article order. Refilled each time get_text_splits() is iterated.
split_nums = []

def get_text_splits(min_lines=20, frame=None):
    """Yield batches of text lines collected from consecutive articles.

    For every article the title is emitted first, followed by each non-blank
    line of its text. Whole articles are accumulated until the batch holds
    more than ``min_lines`` lines, at which point the batch is yielded and a
    new one is started. Any lines left over after the last article are
    yielded as a final, possibly shorter, batch so that no article is dropped.

    Side effect: ``split_nums`` is cleared when iteration starts and then
    receives one entry per article, so re-running the generator does not
    accumulate stale counts.

    Args:
        min_lines: A batch is yielded once it contains more than this many lines.
        frame: Object with 'title' and 'text' columns to read from; defaults to
            the notebook-level ``df``.

    Yields:
        list of str: the lines of one batch.
    """
    source = df if frame is None else frame
    split_nums.clear()
    lines = []

    # Iterate the two columns in row order rather than indexing by label.
    for title, text in zip(source['title'], source['text']):
        body_lines = [line for line in text.split('\n') if len(line.strip()) > 0]
        lines.append(title)
        lines += body_lines
        split_nums.append(len(body_lines) + 1)
        if len(lines) > min_lines:
            yield lines
            lines = []

    # Flush the remainder that never crossed the min_lines threshold.
    if lines:
        yield lines

In [None]:
def load_ckpt(dir):
    """Placeholder for loading a saved index checkpoint from ``dir``.

    The function had no body, which is a syntax error; until it is written it
    fails loudly instead of breaking the whole notebook.

    Args:
        dir: Directory expected to contain checkpoint files.

    Raises:
        NotImplementedError: always, until the loader is implemented.
    """
    # TODO: implement the loader.
    raise NotImplementedError('load_ckpt is not implemented yet')

In [23]:
def checkpoint(index, num, dir='index/'):
    """Write ``index`` to ``<dir>/index_<num>.idx`` using faiss.write_index.

    Args:
        index: The index object to persist.
        num: Value embedded in the file name (the caller passes the number of
            vectors accumulated so far).
        dir: Target directory; created if it does not exist. A trailing slash
            is optional.
    """
    # Create the directory up front so the write does not fail on a fresh
    # checkout; an empty `dir` means the current directory.
    if dir:
        os.makedirs(dir, exist_ok=True)
    filename = os.path.join(dir, 'index_%s.idx' % num)
    faiss.write_index(index, filename)

In [25]:
# Print a progress dot every `show_progress` batches and write a checkpoint
# every `ckpt_index` batches.
show_progress = 20
ckpt_index = 100
index = faiss.IndexFlatIP(768)
# All embeddings that have been added to `index` so far.
vecs = np.empty([0, 768], dtype=np.float32)
# Embeddings encoded since the last checkpoint and not yet added to `index`.
pending = []

for i, split in enumerate(get_text_splits()):
    # Collect per-batch arrays in a list; concatenating into `vecs` on every
    # iteration would copy the whole array each time.
    pending.append(model.encode(split))

    if i > 0 and i % show_progress == 0:
        print('.', end='')
    if i > 0 and i % ckpt_index == 0:
        new_vecs = np.concatenate(pending)
        pending = []
        # Add only the vectors gathered since the previous checkpoint, so the
        # index never receives the same vector twice.
        index.add(new_vecs)
        vecs = np.concatenate((vecs, new_vecs))
        checkpoint(index, vecs.shape[0])
        print(' %s completed' % i)
        # Stop after the first checkpoint (trial run). Note that batches
        # encoded after the last checkpoint are not added to the index.
        break

..... 100 completed


In [21]:
# Show the element dtype of the `vecs` array.
vecs.dtype

dtype('float64')

In [17]:
# Rebind `vecs` to the per-batch arrays in `embs` stacked along axis 0,
# then show the resulting shape.
vecs = np.concatenate(embs, axis=0)
vecs.shape

(100000, 768)

In [37]:
# Sanity check: embed one known title and look up its nearest neighbours.
query_row = 10
query_title = df['title'][query_row]
# The query must be an embedding, not the raw title string; encode it and
# shape it as a (1, 768) float32 matrix for index.search().
qvec = np.asarray(model.encode([query_title]), dtype=np.float32).reshape(-1, 768)
D, I = index.search(qvec, k=4)
print(query_title)
for i in I[0]:
    # Print the title of each hit (previously every row showed the first hit).
    # NOTE(review): this assumes index positions correspond to rows of `df`,
    # i.e. that `index` was built from the title embeddings — confirm.
    print(i, df['title'][i])

Call center jobs await deported Salvadorans
41778 Call center jobs await deported Salvadorans
41095 Call center jobs await deported Salvadorans
10 Call center jobs await deported Salvadorans
44281 Call center jobs await deported Salvadorans
