In [1]:
%cd /workdir
import getpass
import os

from elasticsearch import Elasticsearch, helpers
import yaml
import pandas as pd
import numpy as np
import eland as ed
from glom import glom
from bertopic import BERTopic

/workdir


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the notebook configuration. Decode explicitly as UTF-8 so that loading
# does not depend on the kernel's locale (the file may contain non-ASCII text).
with open("config.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# POS tags that tag_filter_by_pos keeps when building documents.
allow_pos = {*config["allow_pos"]}
# Elasticsearch index (pattern) that every query in this notebook targets.
index = config["index"]
# Topic model reloaded from disk (presumably saved by an earlier training run).
topic_model = BERTopic.load("./topics_model_persistent")



In [4]:
# Elasticsearch connection settings are read from the environment rather than
# hard-coded in the notebook. ES_HOST defaults to http://elasticsearch:9200 and
# ES_USER to "elastic"; when ES_PASSWORD is unset the password is prompted for.
es = Elasticsearch(
    os.environ.get("ES_HOST", "http://elasticsearch:9200"),
    # NOTE(review): certificate verification is disabled. It has no effect on a
    # plain http:// host, but must be enabled if ES_HOST points at https://.
    verify_certs=False,
    basic_auth=(
        os.environ.get("ES_USER", "elastic"),
        os.environ.get("ES_PASSWORD") or getpass.getpass("Elasticsearch password: "),
    ),
)

In [5]:
# pandas-like eland view over the configured index, restricted to these fields.
df = ed.DataFrame(
    es_client=es,
    es_index_pattern=index,
    columns=[
        "title_token",
        "context_token",
        "context_tag",
        "context_vector",
        "date",
    ],
)

In [6]:
# Search term ("台積電", TSMC) matched exactly against the token fields in the
# Elasticsearch query body built in the next cell.
query_word: str = "台積電"

In [28]:
# Inclusive publication-date window for the query below, interpreted in UTC+8.
DATE_GTE = "2024-10-14"
DATE_LTE = "2024-11-14"

# Filter-context query: an exact `term` match of `query_word` on either
# title_token.keyword or context_token.keyword, AND a date inside the window.
# The two `should` clauses sit in their own nested bool so that at least one
# of them is required; placed next to `must`/`filter` clauses in the same bool
# they would become optional.
# NOTE(review): nothing in this section passes `body` to a search call
# (query_vec() builds its own query) — confirm whether this filter should be
# wired into it.
body = {
    "query": {
        "bool": {
            "filter": [
                {
                    "bool": {
                        "should": [
                            {"term": {"title_token.keyword": query_word}},
                            {"term": {"context_token.keyword": query_word}},
                        ],
                    }
                },
                {
                    "range": {
                        "date": {
                            "gte": DATE_GTE,
                            "lte": DATE_LTE,
                            "time_zone": "+08:00",
                            "format": "yyyy-MM-dd",
                        },
                    }
                },
            ]
        }
    }
}

In [None]:
def tag_filter_by_pos(tag_list, pos_list):
    if not tag_list:
        return ""
    cleaned_tag_list = [tag for tag, pos in zip(tag_list, pos_list) if pos in allow_pos]
    return " ".join(cleaned_tag_list)


def query_vec(query_word):
    """Fetch documents matching ``query_word`` and prepare them for topic modelling.

    Runs a ``multi_match`` query (``minimum_should_match`` 50%) over the
    ``title`` and ``context`` fields of ``index``, iterates over every hit with
    ``helpers.scan``, and builds a DataFrame from each hit's ``_source``.

    Args:
        query_word: Free-text query string.

    Returns:
        DataFrame with columns ``context_token``, ``context_tag``,
        ``context_vector`` and ``date``, plus a derived ``text`` column holding
        the POS-filtered tokens joined by spaces. Rows with no tokens or with a
        non-list ``context_vector`` are dropped; the original row index is kept.

    Raises:
        ValueError: if the query matches no documents.
    """
    source_fields = ["context_token", "context_tag", "context_vector", "date"]
    body = {
        "_source": source_fields,
        "query": {
            "multi_match": {
                "query": query_word,
                "fields": ["title", "context"],
                "minimum_should_match": "50%",
            },
        },
    }
    hits = list(helpers.scan(es, query=body, index=index))
    if not hits:
        raise ValueError(f"No documents in index {index!r} match {query_word!r}")

    # Passing ``columns`` guarantees every requested field exists as a column
    # even when no hit contains it; absent values become NaN.
    vec_df = pd.DataFrame([hit["_source"] for hit in hits], columns=source_fields)

    # ``notna`` rejects both None and the NaN pandas inserts for a missing
    # field (an ``is not None`` test lets NaN through).
    has_tokens = vec_df["context_token"].notna()
    has_vector = vec_df["context_vector"].apply(lambda v: isinstance(v, list))
    # .copy() detaches the filtered rows so adding "text" does not write to a
    # view of the unfiltered frame (SettingWithCopyWarning).
    vec_df = vec_df[has_tokens & has_vector].copy()
    vec_df["text"] = [
        tag_filter_by_pos(tokens, tags)
        for tokens, tags in zip(vec_df["context_token"], vec_df["context_tag"])
    ]
    return vec_df


def fit_topic_model(vec_df, model=None, nr_bins=None, datetime_format=None):
    """Fit a topic model on the prepared documents and bin its topics over time.

    Args:
        vec_df: DataFrame with ``text``, ``context_vector`` and ``date`` columns
            (as returned by ``query_vec``).
        model: BERTopic-like object to fit. Defaults to the notebook-level
            ``topic_model``, which is then refit in place.
        nr_bins: Optional number of time bins, forwarded to ``topics_over_time``.
        datetime_format: Optional ``strptime`` format for string dates,
            forwarded to ``topics_over_time``.

    Returns:
        Whatever ``model.topics_over_time`` returns.

    Raises:
        ValueError: if ``vec_df`` has no rows (there is nothing to fit).
    """
    if vec_df.empty:
        raise ValueError("fit_topic_model() received an empty DataFrame")
    if model is None:
        model = topic_model
    docs = vec_df["text"].tolist()
    # Precomputed document embeddings, so the model does not re-embed the texts.
    vecs = np.array(vec_df["context_vector"].tolist())
    timestamps = vec_df["date"].tolist()
    model.fit(docs, vecs)
    return model.topics_over_time(
        docs, timestamps, nr_bins=nr_bins, datetime_format=datetime_format
    )


def gen_dtm(query_word):
    """Build the topics-over-time visualisation for ``query_word``.

    Fetches matching documents with ``query_vec``, fits the topic model on them
    via ``fit_topic_model``, and returns the result of
    ``topic_model.visualize_topics_over_time`` for that table.
    """
    topics_over_time_df = fit_topic_model(query_vec(query_word))
    return topic_model.visualize_topics_over_time(topics_over_time_df)

In [None]:
# Fetch and prepare documents for "美國". Note this rebinds `df`, replacing the
# eland DataFrame assigned in an earlier cell.
df = query_vec(query_word="美國")

In [None]:
# Alias (not a copy) of the query_vec() result from the previous cell, under the
# name the following cells use; mutating one mutates the other.
vec_df = df

In [None]:
# Unpack the prepared frame into the inputs the topic model takes: one text per
# document, an embedding matrix (2-D when all vectors share one length), and
# the per-document dates.
docs = vec_df["text"].tolist()
vecs = np.array(vec_df["context_vector"].tolist())
timestamps = vec_df["date"].tolist()

In [None]:
# Assign each document to a topic of the current topic_model (no refit here),
# reusing the precomputed embeddings. The second returned value (presumably the
# topic probabilities) is discarded.
topics, _ = topic_model.transform(documents=docs, embeddings=vecs)

In [None]:
# Bin the document-topic assignments into 20 time slices.
# np.asarray(...).tolist() accepts `topics` whether transform() returned a list
# or a NumPy array (TODO confirm which for the installed BERTopic version and
# clustering backend), and always yields plain Python ints.
# NOTE(review): the format string assumes dates like "2024-10-14 08:00:00.000000";
# confirm it matches the stored `date` values.
topic_model.topics_over_time(
    docs=docs,
    topics=np.asarray(topics).tolist(),
    timestamps=timestamps,
    nr_bins=20,
    datetime_format="%Y-%m-%d %H:%M:%S.%f",
)