In [11]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from underthesea import sent_tokenize, text_normalize
from collections import Counter, defaultdict
from common.writer import Writer
import pandas as pd
import csv

class NEREntityCounter:
    def __init__(self, model_name="NlpHUST/ner-vietnamese-electra-base"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.nlp = pipeline("ner", model=self.model, tokenizer=self.tokenizer)

    def aggregate_entities(self, ner_results):
        entities = []
        current_entity = ""
        current_type = None

        for token in ner_results:
            word = token["word"]
            entity_type = token["entity"]

            if entity_type.startswith("B-"):
                if current_entity:
                    entities.append({"text": current_entity.strip(), "type": current_type})
                current_entity = word
                current_type = entity_type[2:]

            elif entity_type.startswith("I-") and current_type == entity_type[2:]:
                if word.startswith("##"):
                    current_entity += word[2:]
                else:
                    current_entity += " " + word
            else:
                if current_entity:
                    entities.append({"text": current_entity.strip(), "type": current_type})
                    current_entity = ""
                    current_type = None

        if current_entity:
            entities.append({"text": current_entity.strip(), "type": current_type})

        return entities

    def count_entities_in_text(self, text):
        text = text_normalize(text)
        segs = sent_tokenize(text)
        ner_results_batch = self.nlp(segs)

        all_entities = []
        for ner_results in ner_results_batch:
            entities = self.aggregate_entities(ner_results)
            all_entities.extend(entities)

        counter = defaultdict(Counter)
        for entity in all_entities:
            normalized_text = entity["text"].strip()
            entity_type = entity["type"]
            counter[entity_type][normalized_text] += 1

        return counter

    def count_entities_in_dataframe(self, df, content_column="content"):
        total_counter = defaultdict(Counter)

        for idx, row in df.iterrows():
            text = row.get(content_column, "")
            if pd.isna(text) or not text.strip():
                continue
            entity_counter = self.count_entities_in_text(text)
            for entity_type, type_counter in entity_counter.items():
                total_counter[entity_type].update(type_counter)

        # Chuyển từ defaultdict(Counter) → DataFrame
        rows = []
        for entity_type, type_counter in total_counter.items():
            for entity_text, count in type_counter.items():
                rows.append({
                    "entity_text": entity_text,
                    "entity_type": entity_type,
                    "count": count
                })

        df_entities = pd.DataFrame(rows)
        return df_entities

    def export_to_csv(self, entity_counter, output_path="entity_counts.csv"):
        with open(output_path, "w", encoding="utf-8", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["entity_text", "entity_type", "count"])
            for entity_type, type_counter in entity_counter.items():
                for entity_text, count in type_counter.items():
                    writer.writerow([entity_type, entity_text, count])


In [12]:
writer = Writer(filepath="articles.csv")
df = writer.read_as_dataframe()
print(df)

Empty DataFrame
Columns: [crawled_time, published_time, title, content, author, url]
Index: []


In [13]:
ner_counter = NEREntityCounter()
df_result = ner_counter.count_entities_in_dataframe(df.head(10))

df_result_sorted = df_result.sort_values(by="count", ascending=False)

print(df_result_sorted)

df_result_sorted.head(50).to_csv("top_50_entities.csv", sep="|", index=False, encoding="utf-8")


Device set to use mps:0


KeyError: 'count'