In [1]:
import os
import asyncio
import logging
from typing import List
import yaml
import torch
import gc

# 假设这些导入是可用的，如果不是，您可能需要安装相应的包
from Knowledge_based_async import KnowledgeBase
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

# Configure logging for this notebook session.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load configuration. The encoding is pinned to UTF-8 so the file is decoded
# the same way regardless of the platform's default locale encoding.
with open("config.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Initialise the embedding model from the configured model directory.
model_kwargs = {"device": config['settings']['device']}
encode_kwargs = {
    "batch_size": config['settings']['batch_size'],
    "normalize_embeddings": config['settings']['normalize_embeddings']
}
embeddings = HuggingFaceBgeEmbeddings(
    model_name=config['paths']['model_dir'],
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)

class State:
    """Mutable holder for the knowledge base currently selected in this session."""

    def __init__(self):
        # Nothing is selected yet; update_global_state() fills these in.
        self.current_kb_name = self.kb = self.kb_vectordb = self.unfilter_context = None


# Single shared instance read and written by the coroutines in this notebook.
state = State()

async def update_vectordb(kb_name: str, file_paths: List[str]):
    """Add files to a knowledge base one at a time, then select it as current.

    Each file is passed to ``KnowledgeBase.update_vectordb`` on its own so that
    a failure in one file does not stop the remaining files from being
    processed.

    Args:
        kb_name: Name of the knowledge base; also the sub-directory created
            under ``config['paths']['kb_dir']``.
        file_paths: Paths of the files to add.

    Returns:
        The string ``"All files processed"``.

    Raises:
        Exception: Anything raised while constructing the knowledge base or
            refreshing the global state is logged and re-raised. Per-file
            errors are logged (with traceback) and skipped, not raised.
    """
    kb_dir = os.path.join(config['paths']['kb_dir'], kb_name)
    os.makedirs(kb_dir, exist_ok=True)

    try:
        kb = KnowledgeBase(kb_name, embeddings)

        total_files = len(file_paths)
        failed_files = []
        for i, file_path in enumerate(file_paths, 1):
            try:
                # Update one file per call so errors stay isolated to that file.
                await kb.update_vectordb([file_path])
                logger.info(f"Processed file {i}/{total_files}: {file_path}")
            except Exception as e:
                failed_files.append(file_path)
                # The exception is swallowed here, so keep its traceback in the log.
                logger.exception(f"Error processing file {file_path}: {str(e)}")
            finally:
                # Clear the GPU cache after every file, whether it succeeded or not.
                clean_gpu_cache()

        # Refresh the shared state so it points at the updated knowledge base.
        await update_global_state(kb_name, kb)

        if failed_files:
            # Do not claim success when some files were skipped.
            logger.warning(
                f"Knowledge base '{kb_name}' updated with "
                f"{len(failed_files)}/{total_files} failed file(s): {failed_files}"
            )
        else:
            logger.info(f"Knowledge base '{kb_name}' updated successfully")
        return "All files processed"

    except Exception as e:
        logger.error(f"Error occurred while updating knowledge base '{kb_name}': {e}")
        raise

async def update_global_state(kb_name, kb):
    """Make ``kb`` the currently selected knowledge base in the shared ``state``.

    The vector store is loaded before ``state`` is touched, so if loading
    raises, ``state`` still describes the previously selected knowledge base
    instead of being left half-updated.

    Args:
        kb_name: Name to record as the current knowledge base.
        kb: Knowledge base object exposing an awaitable ``load_vectordb()``.
    """
    vectordb = await kb.load_vectordb()
    # NOTE(review): reads the docstore's private ``_dict`` mapping; presumably an
    # in-memory docstore — confirm before swapping the docstore implementation.
    documents = list(vectordb.docstore._dict.values())

    # Commit all fields together now that nothing else can fail.
    state.kb = kb
    state.kb_vectordb = vectordb
    state.current_kb_name = kb_name
    state.unfilter_context = documents
    print("重新选择向量库完成")
    logger.info(f"Knowledge base '{kb_name}' selected successfully")

def clean_gpu_cache():
    """Run garbage collection and empty the CUDA cache; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    # Collect first so tensors that are no longer referenced can be freed.
    gc.collect()
    torch.cuda.empty_cache()
    print("GPU cache cleared")

def get_all_files(directory):
    """Return the paths of all files found under ``directory``, recursively.

    Each path is the walked folder joined with the file name, in the order
    ``os.walk`` yields them. A directory that does not exist yields an empty
    list.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(directory)
        for name in names
    ]

async def main(
    kb_name="test_kb",
    file_path="/root/autodl-tmp/project_knowledge/Doc_QA/归档/归档",
):
    """Index every file under a directory into a knowledge base and report.

    Args:
        kb_name: Name of the knowledge base to update. Defaults to the value
            that was previously hard-coded.
        file_path: Directory that is walked recursively for input files.
            Defaults to the path that was previously hard-coded.

    Any exception raised by ``update_vectordb`` is caught and printed rather
    than propagated.
    """
    # Collect every file in the directory tree.
    all_files = get_all_files(file_path)

    print(f"Found {len(all_files)} files in the directory.")

    try:
        result = await update_vectordb(kb_name, all_files)
        print(f"Update result: {result}")

        # Report what the shared state holds after the update.
        print(f"Current KB: {state.current_kb_name}")
        print(f"Vector DB size: {len(state.unfilter_context)}")

    except Exception as e:
        print(f"An error occurred: {str(e)}")

# Run the async function inside the Jupyter notebook (top-level await).
await main()

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from Knowledge_based_async import KnowledgeBase
import yaml
# Load configuration. The encoding is pinned to UTF-8 so the file is decoded
# the same way regardless of the platform's default locale encoding.
with open("config.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Load the embedding model (no reranker is created in this cell).
model_kwargs = {"device": config['settings']['device']}
encode_kwargs = {
    "batch_size": config['settings']['batch_size'],
    "normalize_embeddings": config['settings']['normalize_embeddings']
}
embeddings = HuggingFaceBgeEmbeddings(
    model_name=config['paths']['model_dir'],
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
# Initialise the knowledge base.
kb = KnowledgeBase("1828695663060254721", embeddings)
# Coroutine that loads the vector store and uploaded files, then prints a
# preview of every stored document.
async def load_and_view_documents():
    """Load the knowledge base held in ``kb`` and print each stored document.

    For every docstore entry this prints its ID, the ``file_path`` metadata
    value (``'Unknown'`` when absent) and the first 200 characters of its
    content. A fallback message is printed when there is nothing to show.
    """
    await kb.load_vectordb_and_files()

    vectordb = kb.vectordb
    if not (vectordb and vectordb.docstore._dict):
        print(f"知识库 {kb.kb_name} 中没有文档或加载失败")
        return

    print(f"知识库 {kb.kb_name} 中的文档:")
    separator = "-" * 50
    for doc_id, doc in vectordb.docstore._dict.items():
        print(f"文档ID: {doc_id}")
        source_file = doc.metadata.get('file_path', 'Unknown')
        print(f"文件名: {source_file}")
        # Only the first 200 characters are shown as a preview.
        preview = doc.page_content[:200]
        print(f"内容预览: {preview}...")
        print(separator)

# Run the async function inside the Jupyter notebook (top-level await).
await load_and_view_documents()


知识库 1828695663060254721 的向量数据库和已上传文件加载成功
知识库 1828695663060254721 中的文档:
文档ID: 76342932-5dcc-4b46-a421-d0775c93ee1a
文件名: Unknown
内容预览: Document...
--------------------------------------------------
文档ID: b38e45ef-5579-4d21-b5e7-d302b9e18223
文件名: /root/autodl-tmp/project_knowledge/Doc_QA/Knowledge_based/1828695663060254721/uploads/1828696355741171714.md
内容预览: 你要讲道理吗
我不想讲道理...
--------------------------------------------------
文档ID: c785d4b5-beb3-4e38-bc0e-2b2b43db5670
文件名: /root/autodl-tmp/project_knowledge/Doc_QA/Knowledge_based/1828695663060254721/uploads/1828696355741171714.md
内容预览: 你真的讲道理吗
我真的讲道理...
--------------------------------------------------
文档ID: 21a7573c-96ac-45ba-bb96-0f723bee0706
文件名: /root/autodl-tmp/project_knowledge/Doc_QA/Knowledge_based/1828695663060254721/uploads/1828696355741171714.md
内容预览: 讲道理的视频
降到力...
--------------------------------------------------
文档ID: b174792e-d712-4a7a-91ef-9d07eba5586b
文件名: /root/autodl-tmp/project_knowledge/Doc_QA/Knowledge_based/1828695

In [1]:
import os
from pathlib import Path

def classify_files(files):
    """Group file paths by their lower-cased extension.

    Args:
        files: Iterable of paths (``str`` or ``os.PathLike``).

    Returns:
        Dict mapping each supported extension (without the dot) to the list of
        input paths having that extension, in input order. Every supported
        extension is present as a key, with an empty list when nothing
        matched. Paths with an unsupported or missing extension are omitted.
    """
    file_groups = {
        'docx': [], 'doc': [], 'pdf': [], 'md': [], 'txt': [],
        'pptx': [], 'html': [], 'xlsx': [], 'csv': [], 'jpg': [], 'png': []
    }

    for file in files:
        # splitext returns only the final suffix, so '.doc' and '.docx' are
        # already told apart here and need no special-casing. Not calling
        # str methods on ``file`` itself also lets path-like objects through.
        ext = os.path.splitext(file)[1].lower().lstrip('.')

        group = file_groups.get(ext)
        if group is not None:
            group.append(file)

    return file_groups

# Usage example
files = [
    '/path/to/file1.docx', 
    '/path/to/file2.pdf', 
    '/path/to/file3.jpg', 
    '/path/to/file4.doc',
    '/path/to/file5.unknown'
]
grouped_files = classify_files(files)

# Print the result
for ext, file_list in grouped_files.items():
    if file_list:  # only print groups that are not empty
        print(f"{ext}: {file_list}")

docx: ['/path/to/file1.docx']
doc: ['/path/to/file4.doc']
pdf: ['/path/to/file2.pdf']
jpg: ['/path/to/file3.jpg']


In [2]:
import requests

# Replace with the address of the service that is actually running.
# NOTE(review): this host looks like the project's documentation site rather
# than a live inference endpoint — confirm before relying on the response.
url = "https://huggingface.github.io/text-embeddings-inference/embed"
data = {
    "inputs": ["Hello world", "Test sentence"]
}

# Seconds to wait for the server; without a timeout the call can block forever.
REQUEST_TIMEOUT_S = 30

try:
    response = requests.post(url, json=data, timeout=REQUEST_TIMEOUT_S)
except requests.RequestException as exc:
    # Connection errors and timeouts are reported instead of raising a traceback.
    print(f"Error: request failed: {exc}")
else:
    if response.status_code == 200:
        print(response.json())
    else:
        print(f"Error: {response.status_code}")
        print(response.text)

KeyboardInterrupt: 

In [2]:
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import yaml
# Load configuration. The encoding is pinned to UTF-8 so the file is decoded
# the same way regardless of the platform's default locale encoding.
with open("config.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)
# Load the embedding model (no reranker is created in this cell).
model_kwargs = {"device": config['settings']['device']}
encode_kwargs = {
    "batch_size": config['settings']['batch_size'],
    "normalize_embeddings": config['settings']['normalize_embeddings']
}
embeddings = HuggingFaceBgeEmbeddings(
    model_name=config['paths']['model_dir'],
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)



  from tqdm.autonotebook import tqdm, trange


In [16]:
import numpy as np
result = embeddings.embed_query('The quick brown fox jumps over the lazy dog.')
print(f"Type of result: {type(result)}")
if isinstance(result, list):
    print(f"Length of result: {len(result)}")
    print(f"First few values: {result[:5]}")  # first five values as a sample
else:
    print(f"Result is not a list. Value: {result}")

# Convert to a NumPy array to inspect its shape and dtype.
try:
    np_result = np.array(result)
except (TypeError, ValueError):
    # Only conversion failures are expected here; a bare ``except`` would also
    # swallow KeyboardInterrupt and hide unrelated bugs.
    print("\nCould not convert result to numpy array.")
else:
    print(f"\nNumpy array shape: {np_result.shape}")
    print(f"Numpy array dtype: {np_result.dtype}")

# If the result is a list, check whether every element is a float.
if isinstance(result, list):
    all_float = all(isinstance(x, float) for x in result)
    print(f"\nAll elements are floats: {all_float}")

Type of result: <class 'list'>
Length of result: 1024
First few values: [-0.004752224311232567, -0.03886924311518669, 0.025367779657244682, -0.006612122058868408, 0.025436915457248688]

Numpy array shape: (1024,)
Numpy array dtype: float64

All elements are floats: True


In [17]:
import numpy as np

# Six copies of the same sentence, to inspect batch embedding output.
test_texts = ["The quick brown fox jumps over the lazy dog."] * 6

# Embed the whole batch in one call.
result = embeddings.embed_documents(test_texts)

# Summary of the batch.
print(f"Number of embeddings: {len(result)}")
print(f"Shape of first embedding: {np.array(result[0]).shape}")

# Build the full array once and reuse it for the remaining reports.
np_result = np.array(result)
print(f"Shape of all embeddings: {np_result.shape}")

# Per-embedding details.
for i, embedding in enumerate(result):
    print(f"\nEmbedding {i+1}:")
    print(f"  Length: {len(embedding)}")
    print(f"  First few values: {embedding[:5]}")

# Dimensions and dtype of the stacked array.
print(f"\nFull numpy array shape: {np_result.shape}")
print(f"Full numpy array dtype: {np_result.dtype}")

Number of embeddings: 6
Shape of first embedding: (1024,)
Shape of all embeddings: (6, 1024)

Embedding 1:
  Length: 1024
  First few values: [0.015553509816527367, -0.02696125954389572, -0.014358056709170341, -0.008205627091228962, 0.017061004415154457]

Embedding 2:
  Length: 1024
  First few values: [0.015553509816527367, -0.02696125954389572, -0.014358056709170341, -0.008205627091228962, 0.017061004415154457]

Embedding 3:
  Length: 1024
  First few values: [0.015553509816527367, -0.02696125954389572, -0.014358056709170341, -0.008205627091228962, 0.017061004415154457]

Embedding 4:
  Length: 1024
  First few values: [0.015553509816527367, -0.02696125954389572, -0.014358056709170341, -0.008205627091228962, 0.017061004415154457]

Embedding 5:
  Length: 1024
  First few values: [0.015553509816527367, -0.02696125954389572, -0.014358056709170341, -0.008205627091228962, 0.017061004415154457]

Embedding 6:
  Length: 1024
  First few values: [0.015553509816527367, -0.02696125954389572, -0.

In [8]:
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS

# Size the flat L2 index from the length of one query embedding.
embedding_dim = len(embeddings.embed_query("hello world"))
index = faiss.IndexFlatL2(embedding_dim)

# Start from an empty in-memory docstore and an empty index-to-docstore-id map.
vector_store = FAISS(
    index_to_docstore_id={},
    docstore=InMemoryDocstore(),
    index=index,
    embedding_function=embeddings,
)

In [9]:
from uuid import uuid4

from langchain_core.documents import Document

# Ten sample documents; each carries a "source" metadata tag
# ("tweet", "news" or "website").
document_1 = Document(
    page_content="I had chocalate chip pancakes and scrambled eggs for breakfast this morning.",
    metadata={"source": "tweet"},
)

document_2 = Document(
    page_content="The weather forecast for tomorrow is cloudy and overcast, with a high of 62 degrees.",
    metadata={"source": "news"},
)

document_3 = Document(
    page_content="Building an exciting new project with LangChain - come check it out!",
    metadata={"source": "tweet"},
)

document_4 = Document(
    page_content="Robbers broke into the city bank and stole $1 million in cash.",
    metadata={"source": "news"},
)

document_5 = Document(
    page_content="Wow! That was an amazing movie. I can't wait to see it again.",
    metadata={"source": "tweet"},
)

document_6 = Document(
    page_content="Is the new iPhone worth the price? Read this review to find out.",
    metadata={"source": "website"},
)

document_7 = Document(
    page_content="The top 10 soccer players in the world right now.",
    metadata={"source": "website"},
)

document_8 = Document(
    page_content="LangGraph is the best framework for building stateful, agentic applications!",
    metadata={"source": "tweet"},
)

document_9 = Document(
    page_content="The stock market is down 500 points today due to fears of a recession.",
    metadata={"source": "news"},
)

document_10 = Document(
    page_content="I have a bad feeling I am going to get deleted :(",
    metadata={"source": "tweet"},
)

# Collect the samples in insertion order.
documents = [
    document_1,
    document_2,
    document_3,
    document_4,
    document_5,
    document_6,
    document_7,
    document_8,
    document_9,
    document_10,
]
# One freshly generated UUID string per document, used as its store ID.
uuids = [str(uuid4()) for _ in range(len(documents))]

# Last expression of the cell, so its return value is shown as the cell output.
vector_store.add_documents(documents=documents, ids=uuids)

['70d3ee2c-8992-429c-b1bd-8b3414ccc68c',
 '6c4e73a9-cc1a-42d1-89a9-f775a3fdb952',
 '4dc71f95-4bc8-4b3d-b436-ecac3fd58253',
 '3ab1bbbf-0fd3-469d-8421-039b9d503cf0',
 '50b45280-c992-4cad-937f-f29f7cb4ecf8',
 'fce4a656-c0b8-48ab-9a2d-48ee3eae33c3',
 'bdb223b2-3e3e-4808-a67d-7372eb11d21e',
 '66531eee-fcdf-47f9-868f-d012adadf3d3',
 '5c35c31a-eaea-4d98-8c9c-f5b347d92f33',
 'ef28004c-f749-44fe-80cc-e0c3dee5eb05']