In [98]:
import json
import pandas as pd
from sentence_transformers import SentenceTransformer
import numpy as np
import ast

emb_model = SentenceTransformer('moka-ai/m3e-base')

data_dict_df = pd.read_csv('data/data_dictionary.csv')
db_description_json = json.load(open('data/database-table/database-with_description.json'))

table_info_dict = {}
for db in db_description_json:
    key = db['table_name_en']
    value = {
        'database_name_zh': db['database_name_zh'],
        'database_name_en': db['database_name_en'],
        'table_name_en': db['table_name_en'],
        'table_name_zh': db['table_name_zh'],
        'table_description' : db['table_description'],
    }
    table_info_dict[key] = value 


from tqdm import tqdm


columns_info = []

for i, row in tqdm(data_dict_df.iterrows(), total=data_dict_df.shape[0]):
    table_name_en = row['table_name']
    column_name_en = row['column_name']
    db_name = table_info_dict[table_name_en]['database_name_en']
    column_description = (row['column_description'] if isinstance(row['column_description'], str) else '')
    annotation = (row['注释'] if isinstance(row['注释'], str) else '')
    table_description = table_info_dict[table_name_en]['table_description']
    table_name_zh = table_info_dict[table_name_en]['table_name_zh']

    columns_info.append({
        'db_name': table_info_dict[table_name_en]['database_name_en'],
        'table_name_en': table_name_en,
        'column_name': column_name_en,
        'column_description': column_description,
        'table_name_zh': table_name_zh,
        'table_name_zh_emb': emb_model.encode(table_name_zh),
        'table_desc_emb': emb_model.encode(table_description),
        'col_emb': emb_model.encode(column_description + ' ' + annotation),
        'all_emb': emb_model.encode(table_name_zh + ': ' + table_description + ' ' + column_description + ' ' + annotation),
    })


100%|██████████| 3489/3489 [03:51<00:00, 15.10it/s]


保存embedding为文件

In [99]:
def convert_embeddings(obj):
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_embeddings(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_embeddings(item) for item in obj]
    return obj

# Convert dictionaries with embeddings to JSON serializable format
columns_info_serializable = convert_embeddings(columns_info)

with open('data/database-table/columns_emb.json', 'w', encoding='utf-8') as f:
    json.dump(columns_info_serializable, f, ensure_ascii=False, indent=4)

    

读取保存的文件

In [100]:
def parse_embedding(embedding_str):
    """
    将存储在 JSON 中的嵌入字符串转换为 numpy 数组。
    假设嵌入存储为形如 "[0.1, 0.2, ...]" 的字符串。
    """
    return np.array(ast.literal_eval(embedding_str))

# 读取 JSON 文件并转换回字典
def load_json_to_dict(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    # 递归解析嵌入数据
    def convert_embeddings(obj):
        if isinstance(obj, list) and all(isinstance(x, (int, float)) for x in obj):
            return np.array(obj)
        if isinstance(obj, dict):
            return {key: convert_embeddings(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [convert_embeddings(item) for item in obj]
        return obj
    
    return convert_embeddings(data)

# 加载 JSON 文件到字典
columns_info = load_json_to_dict('data/database-table/columns_emb.json')

以下计算：
score = 0.5 * table_description_cos + 0.25 * (column_name_cos + column_annotation_cos)

In [131]:
import numpy as np
from sentence_transformers import util

def cosine_similarity(vec_a, vec_b):
    return np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b))

def find_top_similar_columns(word, columns_info, emb_model, top_k=20):
    """
    Find the top `top_k` most similar columns to the given word based on cosine similarity.

    Args:
    - word (str): The query word.
    - columns_info (list): list containing column details with embeddings.
    - emb_model (SentenceTransformer): The embedding model.
    - top_k (int): Number of top similar columns to return.

    Returns:
    - List of tuples (column_name, similarity_score, column_description)
    """
    # Encode the input word
    word_embedding = emb_model.encode(word)

    table_similarities = []
    column_similarities = []
    scores1 = []
    scores2 = []

    for col_info in columns_info:
        # Get column name and description
        column_name = col_info['column_name']
        column_description = col_info.get('column_description', '')
        db_name = col_info.get('db_name', '')
        table_name_en = col_info.get('table_name_en', '')
        table_name_zh = col_info.get('table_name_zh', '')


        # Get embeddings (ensure they exist)
        table_name_zh_emb = col_info.get('table_name_zh_emb')
        table_desc_emb = col_info.get('table_desc_emb')
        col_emb = col_info.get('col_emb')
        all_emb = col_info.get('all_emb')

        table_name_similarity = cosine_similarity(word_embedding, table_name_zh_emb) if table_name_zh_emb is not None else 0
        table_desc_similarity = cosine_similarity(word_embedding, table_desc_emb) if table_desc_emb is not None else 0
        col_similarity = cosine_similarity(word_embedding, col_emb) if col_emb is not None else 0
        all_similarity = cosine_similarity(word_embedding, all_emb) if all_emb is not None else 0

        # 四种similarity的计算方法

        table_similarity = 0.5 * table_name_similarity + 0.5 * table_desc_similarity 
        if table_name_en not in [x[2] for x in table_similarities]:
            table_similarities.append((table_similarity, db_name, table_name_en, table_name_zh))

        column_similarities.append((col_similarity, db_name, table_name_en, column_name, column_description))

        score1 = all_similarity
        scores1.append((score1, db_name, table_name_en, column_name, column_description))

        score2 = 0.3 * table_similarity + 0.7 * col_similarity
        scores2.append((score2, db_name, table_name_en, column_name, column_description))


    # Sort by similarity score
    table_similarities = sorted(table_similarities, key=lambda x: x[0], reverse=True)
    column_similarities = sorted(column_similarities, key=lambda x: x[0], reverse=True)
    scores1 = sorted(scores1, key=lambda x: x[0], reverse=True)
    scores2 = sorted(scores2, key=lambda x: x[0], reverse=True)


    # Return top `top_k` results
    return table_similarities[:top_k], column_similarities[:top_k], scores1[:top_k], scores2[:top_k]

entities = ['中文全称', '全称变更', 'A股简称','法人','法律顾问','会计师事务所','董秘', '实控人', '近一个月最高价', '现金流量净额',
            '注册邮箱', '注册地址', '信披网址', '公司电话','硕士及以上学历（硕士+博士）的人员占比', '一级行业', 
            '什么时间上市']


for entity in entities:
    print(f"{entity}: ")
    results = find_top_similar_columns(entity, columns_info, emb_model, top_k=3)
    print("- By Table Only:")
    for result in results[0]:
        print(f'  - {result[1]}.{result[2]} ({result[3]})')
    print("- By Column Only:")
    for result in results[1]:
        print(f'  - {result[1]}.{result[2]}.{result[3]} ({result[4]})')
    print("- By Total Scores 1")
    for result in results[2]:
        print(f'  - {result[1]}.{result[2]}.{result[3]} ({result[4]})')
    print("- By Total Scores 2")
    for result in results[2]:
        print(f'  - {result[1]}.{result[2]}.{result[3]} ({result[4]})')

    print()

中文全称: 
- By Table Only:
  - AStockBasicInfoDB.LC_NameChange (公司名称更改状况)
  - AStockBasicInfoDB.LC_StockArchives (公司概况)
  - PublicFundDB.MF_FundProdName (公募基金产品名称)
- By Column Only:
  - AStockBasicInfoDB.LC_StockArchives.ChiName (中文名称)
  - AStockBasicInfoDB.LC_NameChange.ChiName (中文名称)
  - HKStockDB.HK_StockArchives.ChiName (中文名称)
- By Total Scores 1
  - AStockBasicInfoDB.LC_NameChange.ChiNameAbbr (中文名称缩写)
  - AStockBasicInfoDB.LC_NameChange.ChiName (中文名称)
  - AStockBasicInfoDB.LC_NameChange.EngNameAbbr (英文名称缩写)
- By Total Scores 2
  - AStockBasicInfoDB.LC_NameChange.ChiNameAbbr (中文名称缩写)
  - AStockBasicInfoDB.LC_NameChange.ChiName (中文名称)
  - AStockBasicInfoDB.LC_NameChange.EngNameAbbr (英文名称缩写)

全称变更: 
- By Table Only:
  - AStockBasicInfoDB.LC_NameChange (公司名称更改状况)
  - AStockIndustryDB.LC_ExgIndChange (公司行业变更表)
  - AStockBasicInfoDB.LC_Business (公司经营范围与行业变更)
- By Column Only:
  - AStockBasicInfoDB.LC_Business.ChangeReason (简称变更原因)
  - AStockShareholderDB.LC_ShareStru.ChangeReason (简称变更原因)
