In [4]:
import re
import os
import json
import numpy as np
from tqdm import tqdm
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk

# Client for the local Elasticsearch node. The request timeout is in seconds and is set
# very high, presumably so long-running calls (bulk indexing, delete_by_query) never
# time out on the client side.
ES_URL = "http://localhost:9200"
elastic = Elasticsearch(ES_URL, request_timeout=1_000_000)

# Sanity-check the connection: the response reports node/cluster name and server version.
elastic.info()

ObjectApiResponse({'name': 'ubuntu', 'cluster_name': 'elasticsearch', 'cluster_uuid': 'G7JzeSt3Q5OoTS6zBdAaWQ', 'version': {'number': '8.4.1', 'build_flavor': 'default', 'build_type': 'tar', 'build_hash': '2bd229c8e56650b42e40992322a76e7914258f0c', 'build_date': '2022-08-26T12:11:43.232597118Z', 'build_snapshot': False, 'lucene_version': '9.3.0', 'minimum_wire_compatibility_version': '7.17.0', 'minimum_index_compatibility_version': '7.0.0'}, 'tagline': 'You Know, for Search'})

In [None]:
# Smoke-test kNN search + terms aggregation + post_filter with a random query vector.
# The generator is seeded so that re-running the notebook issues the same query vector.
rng = np.random.default_rng(42)
emb = rng.random(768, dtype=np.float32)
# The "vector" field uses dot_product similarity, which requires unit-length vectors,
# so the query vector is normalised as well.
query_vector = (emb / np.linalg.norm(emb)).tolist()

elastic.search(
    index="wenshu",
    size=2,
    from_=0,
    _source=False,
    fields=["case_name", "content", {"field": "publish_date", "format": "year_month_day"}, "court_name", "case_type"],
    # query={
    #     "combined_fields": {
    #         "query": "测试",
    #         "fields": ["case_name", "content"]
    #     }
    # },
    aggs={
        "agg-court": {
            "terms": {
                "field": "court_name"
            }
        }
    },
    knn={
        "field": "vector",
        "query_vector": query_vector,
        "k": 10,
        "num_candidates": 10,
        # "boost": 1
    },
    # highlight={
    #     "fields": {
    #         "content": {
    #             "pre_tags" : ["<strong>"],
    #             "post_tags": ["</strong>"],
    #             "number_of_fragments": 1,
    #         }
    #     }
    # },
    # NOTE(review): post_filter only narrows the returned hits after kNN retrieval (and does
    # not affect aggregations), so it may leave 0 hits if none of the top-k neighbours come
    # from this court. If the goal is "nearest neighbours within this court", knn.filter
    # restricts the candidates instead — confirm intent.
    post_filter={
        "bool": {
            "filter": [
                {
                    "terms": {"court_name": ["上海市长宁区人民法院"]}
                }
            ]
        }
    }
)["hits"]

In [None]:
file = "p2-1.filtered"
model = "DPR"

# Pre-computed text embeddings, memory-mapped read-only so the matrix is not loaded into RAM
# up front. The flat float32 buffer is viewed as rows of 768 values; rows are expected to
# line up with the lines of the corresponding source text file.
embeddings_path = os.path.join("data/encode", model, "wenshu", file, "text_embeddings.mmp")
embeddings = np.memmap(embeddings_path, dtype=np.float32, mode="r").reshape(-1, 768)

def gendata():
    """Yield one Elasticsearch document per line of ``../../../Data/wenshu/{file}.txt``.

    Each line is parsed as JSON. The fields ``crawl_time``, ``legal_base``, ``tf_content``
    and ``id`` are removed. Row ``i`` of the module-level ``embeddings`` array is attached
    as ``"vector"`` after being scaled to unit length.

    Raises:
        KeyError: if a record lacks one of the removed fields.
        IndexError: if the text file has more lines than ``embeddings`` has rows.
    """
    with open(f"../../../Data/wenshu/{file}.txt", encoding="utf-8") as f:
        for i, line in enumerate(f):
            case = json.loads(line.strip())
            # Fields we do not want stored in the index.
            for unwanted in ("crawl_time", "legal_base", "tf_content", "id"):
                del case[unwanted]

            vec = np.asarray(embeddings[i], dtype=np.float32)
            norm = np.linalg.norm(vec)
            # The index's "vector" field uses dot_product similarity, which only accepts
            # unit-length vectors. Normalising is essentially a no-op for embeddings that are
            # already unit length. A zero vector is left as-is (it cannot be normalised).
            if norm > 0:
                vec = vec / norm
            case["vector"] = vec.tolist()
            yield case

# Refuse to index into a missing index. Elasticsearch would otherwise auto-create "wenshu"
# with dynamic mappings, and "vector" would then most likely not be a dense_vector field.
# The smartcn default analyzer would be missing too.
if not elastic.indices.exists(index="wenshu"):
    raise RuntimeError('Index "wenshu" does not exist; run the indices.create cell first.')

# Send documents in bulk requests (chunks of up to 500 by default) instead of one HTTP
# round-trip per document. bulk() raises BulkIndexError if any document is rejected.
indexed, _ = bulk(
    elastic,
    (
        {"_index": "wenshu", "_source": case}
        for case in tqdm(gendata(), desc="Indexing", total=embeddings.shape[0])
    ),
)
indexed

In [None]:
# Fields stored as exact, un-analysed values (usable e.g. in terms filters / aggregations).
KEYWORD_FIELDS = (
    "doc_id",
    "court_name",
    "court_id",
    "court_province",
    "court_city",
    "court_region",
    "court_district",
    "pub_prosecution_org",
    "case_type",
    "cause",
    "trial_round",
)

elastic.indices.create(
    index="wenshu",
    settings={
        # smartcn (from the analysis-smartcn plugin) is registered as the index's "default"
        # analyzer, so text fields that don't name an analyzer use it.
        "analysis": {"analyzer": {"default": {"type": "smartcn"}}},
        "index.mapping.ignore_malformed": True,
    },
    mappings={
        "properties": {
            **{field_name: {"type": "keyword"} for field_name in KEYWORD_FIELDS},
            "content": {"type": "text"},
            "vector": {
                "type": "dense_vector",
                "dims": 768,
                # build an HNSW index so the field can serve kNN queries
                "index": True,
                # dot_product (inner product) only accepts unit-length vectors
                "similarity": "dot_product",
            },
        }
    },
)

In [None]:
# Show the mapping Elasticsearch has stored for the "wenshu" index.
wenshu_mapping = elastic.indices.get_mapping(index="wenshu")
wenshu_mapping

In [None]:
# delete index
elastic.indices.delete(index="wenshu")

In [None]:
# List every index on the cluster together with its aliases (e.g. to confirm "wenshu" exists).
all_aliases = elastic.indices.get_alias(index="*")
all_aliases

In [None]:
# Destructive: deletes every document in "wenshu" but keeps the index and its mapping.
# Guarded so that "Restart & Run All" cannot wipe the data by accident; set the flag to run it.
CONFIRM_DELETE_ALL_DOCS = False

if CONFIRM_DELETE_ALL_DOCS:
    delete_docs_response = elastic.delete_by_query(
        index="wenshu",
        query={"match_all": {}},
    )
else:
    delete_docs_response = "skipped (set CONFIRM_DELETE_ALL_DOCS = True to run)"
delete_docs_response

In [None]:
# Peek at the last document in index order, as a rough "most recently indexed" probe.
# Elasticsearch >= 5.0 has no `_timestamp` meta-field (this cluster reports 8.4.1), so
# sorting on it fails because the field has no mapping. `_doc` is always sortable.
# NOTE(review): `_doc` order is per shard, and segment merges may not preserve insertion
# order, so treat the result as approximate.
elastic.search(
    index="wenshu",
    query={"match_all": {}},
    size=1,
    sort=[{"_doc": {"order": "desc"}}],
)["hits"]

In [5]:
# How many documents the index currently holds.
match_everything = {"match_all": {}}
elastic.count(index="wenshu", query=match_everything)

ObjectApiResponse({'count': 26345, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}})

In [None]:
# Markers that typically introduce the factual part of a judgment. They are tried in this
# order. For the first marker present, everything from its LAST occurrence onward is kept.
# The ASCII-colon variant "认定:" is intentional alongside the full-width "认定：".
TF_CONTENT_MARKERS = (
    "指控：", "指控，", "查明：", "查明，", "诉称：", "诉称，",
    "请求：", "理由：", "认定：", "认定:", "认定，", "查认为，",
)


def extract_tf_content(fields):
    """Return the text to store as ``tf_content`` for one raw case dict, or ``None`` to skip it.

    * If ``basics_text`` is non-empty, it is used with line breaks removed.
    * Otherwise newlines are removed from ``content``, and the text is cut at the last
      occurrence of the first marker in ``TF_CONTENT_MARKERS`` that it contains.
    * ``None`` is returned when ``content`` is missing/empty or contains none of the markers.
    """
    basics = fields.get("basics_text")
    if basics:
        # Strip real newlines as well as the literal "/n" sequence. "/n" looks like a typo
        # for "\n", but it is kept in case the crawled text really contains it.
        return basics.replace("/n", "").replace("\n", "")

    content = fields.get("content")
    if not content:
        return None
    content = content.replace("\n", "")
    for marker in TF_CONTENT_MARKERS:
        idx = content.rfind(marker)
        if idx != -1:
            return content[idx:]
    return None


skipped = 0
file = "p5"
src_path = f"../../../Data/wenshu/{file}.txt"
dst_path = f"../../../Data/wenshu/{file}.filtered.txt"
with open(src_path, encoding="utf-8") as reader, open(dst_path, "w", encoding="utf-8") as writer:
    for line in tqdm(reader, desc="Filtering"):
        line = line.strip()
        if not line:
            # tolerate blank lines (e.g. a trailing newline) instead of failing in json.loads
            continue
        fields = json.loads(line)
        content = extract_tf_content(fields)
        if content is None:
            skipped += 1
            continue
        # write the filtered content to a new field
        fields["tf_content"] = content
        writer.write(json.dumps(fields, ensure_ascii=False) + "\n")

skipped