In [None]:
!pip install torch==1.12.0 transformers==4.21.0 nltk==3.7 numpy==1.21.2 scikit-learn==1.0.2 regex==2022.3.15 stanfordcorenlp==3.9.1.1


In [None]:
import codecs
import json
import os
import re

import nltk
import numpy as np
import torch
from nltk import pos_tag, word_tokenize
from nltk.corpus import stopwords
from stanfordcorenlp import StanfordCoreNLP
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from transformers import T5Tokenizer

# Download the data NLTK needs (run once at import time; requires network access
# the first time the resources are fetched).
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

class MDERank:
    """Rank candidate keyphrases by how much masking them changes the document embedding.

    A candidate whose removal (replacement by ``[MASK]`` tokens) makes the
    masked document *less* similar to the original one is considered more
    important.

    Args:
        model_name: Name or path passed to ``AutoTokenizer`` / ``AutoModel``.
        pooling: ``"max"`` for max pooling over tokens; any other value
            (including ``"avg"``) selects mean pooling.
    """

    def __init__(self, model_name="bert-base-uncased", pooling="max"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.pooling = pooling

    def compute_embedding(self, text):
        """Return a pooled embedding of ``text`` as a 1-D numpy array.

        The input is truncated to 512 tokens, so anything beyond that does
        not influence the embedding.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Take the first (only) sequence of the batch and pool over its tokens.
        hidden_states = outputs.last_hidden_state[0]
        if self.pooling == "max":
            embedding, _ = torch.max(hidden_states, dim=0)
        else:
            # "avg" and any unrecognised pooling name both use mean pooling.
            embedding = torch.mean(hidden_states, dim=0)
        return embedding.numpy()

    def extract_candidates(self, text):
        """Extract candidate phrases from ``text`` using NLTK POS tags.

        A candidate is a maximal run of consecutive tokens whose tag starts
        with ``JJ`` (adjective) or ``NN`` (noun). Duplicates are removed while
        keeping first-occurrence order, so results are reproducible between
        runs.

        Returns:
            list[str]: Unique candidate phrases (tokens joined by one space).
        """
        tagged = pos_tag(word_tokenize(text))
        candidates = []
        current = []
        for word, tag in tagged:
            if tag.startswith(("JJ", "NN")):
                current.append(word)
            elif current:
                candidates.append(" ".join(current))
                current = []
        if current:
            candidates.append(" ".join(current))
        # dict.fromkeys de-duplicates deterministically (a set would yield a
        # hash-dependent order and therefore unstable tie-breaking later on).
        return list(dict.fromkeys(c for c in candidates if len(c.split()) >= 1))

    def mask_text(self, text, candidate):
        """Replace every occurrence of ``candidate`` in ``text`` with ``[MASK]`` tokens.

        One ``[MASK]`` is emitted per whitespace-separated word of the
        candidate. Matching is case-insensitive, only happens on whole words
        (so ``"art"`` is not masked inside ``"start"``), and tolerates any
        run of whitespace between the candidate's words.

        Returns:
            str: The masked text; ``text`` unchanged if the candidate is empty.
        """
        candidate_tokens = candidate.split()
        if not candidate_tokens:
            return text
        mask_token = " ".join(["[MASK]"] * len(candidate_tokens))
        phrase = r"\s+".join(re.escape(token) for token in candidate_tokens)
        # Lookarounds instead of \b so candidates that start or end with a
        # non-word character (e.g. "c++") are still bounded correctly.
        pattern = re.compile(r"(?<!\w)" + phrase + r"(?!\w)", re.IGNORECASE)
        return pattern.sub(mask_token, text)

    def cosine_similarity(self, vec1, vec2):
        """Return the cosine similarity of two vectors (epsilon-guarded against zero norms)."""
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2) + 1e-8)

    def rank_keyphrases(self, text):
        """Score every candidate of ``text`` and return them most-important first.

        For each candidate the text is embedded with the candidate masked and
        compared (cosine similarity) to the embedding of the original text.

        Returns:
            list[tuple[str, float]]: ``(candidate, similarity)`` pairs sorted
            by ascending similarity, i.e. largest information loss first.
        """
        original_embedding = self.compute_embedding(text)
        scores = {}
        for candidate in self.extract_candidates(text):
            masked_embedding = self.compute_embedding(self.mask_text(text, candidate))
            scores[candidate] = self.cosine_similarity(original_embedding, masked_embedding)
        return sorted(scores.items(), key=lambda item: item[1])


In [None]:
# Shared MDERank instance, created on first use so the pretrained model is
# loaded once instead of once per document.
_DEFAULT_RANKER = None


def extract_keyphrases(text, top_k=10):
    """Extract the ``top_k`` most important keyphrases from ``text``.

    The text is lower-cased before ranking. Phrases are returned in ranking
    order: lowest similarity after masking (most important) first.

    Args:
        text: Document to analyse.
        top_k: Maximum number of phrases to return.

    Returns:
        list[str]: At most ``top_k`` keyphrases.
    """
    global _DEFAULT_RANKER
    if _DEFAULT_RANKER is None:
        _DEFAULT_RANKER = MDERank()
    ranked = _DEFAULT_RANKER.rank_keyphrases(text.lower())
    return [phrase for phrase, _score in ranked[:top_k]]

In [None]:


def clean_labels(labels):
    """Split ``;``-joined keyphrases into separate entries.

    Args:
        labels: Mapping of document id to a list of keyphrase strings, where
            a single string may hold several keyphrases separated by ``;``.

    Returns:
        dict: New mapping with the same keys; every value is a flat list in
        which each ``;``-separated part is its own entry. Any number of
        separators per entry is supported.
    """
    cleaned = {}
    for doc_id, keyphrases in labels.items():
        flattened = []
        for keyphrase in keyphrases:
            # split() returns [keyphrase] when there is no ";", so a single
            # extend covers both the plain and the joined case.
            flattened.extend(keyphrase.split(";"))
        cleaned[doc_id] = flattened
    return cleaned


def get_long_data(file_path="data/nus/nus_test.json"):
    """Load a JSON-lines corpus, using abstract plus full text as the document.

    Each non-blank line must be a JSON object with the string fields
    ``name``, ``keywords`` (``;``-separated), ``abstract`` and ``fulltext``.

    Args:
        file_path: Path to the UTF-8 encoded ``.json`` (JSON-lines) file.

    Returns:
        tuple[dict, dict]: ``(data, labels)`` keyed by the record's ``name``;
        ``data`` holds the normalised document text and ``labels`` the list
        of lower-cased keyphrases.

    Raises:
        ValueError: If a line is not valid JSON or lacks a required field.
    """
    data = {}
    labels = {}
    # Built-in open() splits lines on newlines only; a JSON string may legally
    # contain other Unicode line separators that must not break a record.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, line in tqdm(enumerate(f, start=1), desc="Loading Doc ..."):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
                keywords = record['keywords'].lower().split(";")
                doc = ' '.join([record['abstract'], record['fulltext']])
                # Detach sentence/clause punctuation so it becomes its own token.
                doc = re.sub(r'\. ', ' . ', doc)
                doc = re.sub(', ', ' , ', doc)
                doc = doc.replace('\n', ' ')
                data[record['name']] = doc
                labels[record['name']] = keywords
            except (KeyError, TypeError, AttributeError, ValueError) as err:
                raise ValueError(
                    f"{file_path}:{line_no}: malformed record"
                ) from err
    labels = clean_labels(labels)
    return data, labels


def get_duc2001_data(file_path="data/DUC2001"):
    """Load documents and keyphrase annotations from a DUC2001 directory tree.

    Every file named ``annotations.txt`` is read as label lines of the form
    ``<doc id>@<kp1>;<kp2>;...;`` (the part after the last ``;`` is dropped).
    Every other file is treated as a document whose body is the first
    ``<TEXT>...</TEXT>`` section.

    Args:
        file_path: Root directory to walk.

    Returns:
        tuple[dict, dict]: ``data`` maps file name to document text and
        ``labels`` maps document id to its list of keyphrases.

    Raises:
        IndexError: If a document file contains no ``<TEXT>`` section.
        ValueError: If a non-blank annotation line does not contain exactly
            one ``@``.
    """
    pattern = re.compile(r'<TEXT>(.*?)</TEXT>', re.S)
    data = {}
    labels = {}
    for dirname, _dirnames, filenames in os.walk(file_path):
        for fname in filenames:
            infile = os.path.join(dirname, fname)
            # Context manager guarantees the handle is closed for both file kinds.
            with open(infile, 'rb') as f:
                text = f.read().decode('utf8')
            if fname == "annotations.txt":
                for line in text.splitlines():
                    if not line.strip():
                        continue
                    left, right = line.split("@")
                    labels[left] = right.split(";")[:-1]
            else:
                data[fname] = pattern.findall(text)[0]
    labels = clean_labels(labels)
    return data, labels

def get_inspec_data(file_path="data/Inspec"):
    """Load abstracts and uncontrolled keyphrases from an Inspec directory tree.

    ``*.abstr`` files provide the document text (with ``%`` removed) and
    ``*.uncontr`` files provide the ``"; "``-separated keyphrases. Files with
    any other extension, or without one, are ignored.

    Args:
        file_path: Root directory to walk.

    Returns:
        tuple[dict, dict]: ``(data, labels)``, both keyed by the file name
        without its extension.
    """
    data = {}
    labels = {}
    for dirname, _dirnames, filenames in os.walk(file_path):
        for fname in filenames:
            # splitext copes with names that have no dot or several dots.
            stem, ext = os.path.splitext(fname)
            if ext not in (".abstr", ".uncontr"):
                continue
            # Opened with the platform default encoding, as before.
            with open(os.path.join(dirname, fname)) as f:
                text = f.read()
            if ext == ".abstr":
                data[stem] = text.replace("%", '')
            else:
                # Keyphrases may wrap across lines; flatten before splitting.
                text = text.replace("\n\t", ' ')
                text = text.replace("\n", ' ')
                labels[stem] = text.split("; ")
    labels = clean_labels(labels)
    return data, labels

def get_semeval2017_data(data_path="data/SemEval2017/docsutf8",labels_path="data/SemEval2017/keys"):
    """Load SemEval-2017 documents and their keyphrase files.

    Documents are read as UTF-8, stripped of ``%`` and lower-cased. Each
    label file contributes one keyphrase per line.

    Args:
        data_path: Directory tree containing the document files.
        labels_path: Directory tree containing the keyphrase files.

    Returns:
        tuple[dict, dict]: ``(data, labels)``, both keyed by the file name
        without its extension.
    """
    data = {}
    labels = {}
    for dirname, _dirnames, filenames in os.walk(data_path):
        for fname in filenames:
            # splitext copes with names that have no dot or several dots.
            stem = os.path.splitext(fname)[0]
            with codecs.open(os.path.join(dirname, fname), "r", "utf-8") as fi:
                text = fi.read()
            data[stem] = text.replace("%", '').lower()
    for dirname, _dirnames, filenames in os.walk(labels_path):
        for fname in filenames:
            stem = os.path.splitext(fname)[0]
            with open(os.path.join(dirname, fname), 'rb') as f:
                text = f.read().decode('utf8')
            labels[stem] = text.strip().splitlines()
    labels = clean_labels(labels)
    return data, labels

# NOTE(review): the default path spells "kravipin" while get_krapivin_data uses
# "krapivin_test.json" — probably a typo; left unchanged pending confirmation.
def get_short_data(file_path="data/krapivin/kravipin_test.json"):
    """Load a JSON-lines corpus, using only the abstract as the document.

    Each non-blank line must be a JSON object with the string fields
    ``keywords`` (``;``-separated) and ``abstract``.

    Args:
        file_path: Path to the UTF-8 encoded ``.json`` (JSON-lines) file.

    Returns:
        tuple[dict, dict]: ``(data, labels)`` keyed by the zero-based line
        index of each record; ``data`` holds the normalised abstract and
        ``labels`` the list of lower-cased keyphrases.

    Raises:
        ValueError: If a line is not valid JSON or lacks a required field.
    """
    data = {}
    labels = {}
    # Built-in open() splits lines on newlines only; a JSON string may legally
    # contain other Unicode line separators that must not break a record.
    with open(file_path, 'r', encoding='utf-8') as f:
        for i, line in tqdm(enumerate(f), desc="Loading Doc ..."):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
                keywords = record['keywords'].lower().split(";")
                doc = record['abstract']
                # Detach sentence/clause punctuation so it becomes its own token.
                doc = re.sub(r'\. ', ' . ', doc)
                doc = re.sub(', ', ' , ', doc)
                doc = doc.replace('\n', ' ')
                doc = doc.replace('\t', ' ')
                data[i] = doc
                labels[i] = keywords
            except (KeyError, TypeError, AttributeError, ValueError) as err:
                raise ValueError(
                    f"{file_path}:{i + 1}: malformed record"
                ) from err
    labels = clean_labels(labels)
    return data, labels

def get_krapivin_data(file_path="data/krapivin/krapivin_test.json"):
    """Load the Krapivin test file with get_short_data and return its (data, labels)."""
    krapivin = get_short_data(file_path)
    return krapivin

def get_nus_data(file_path="data/nus/nus_test.json"):
    """Load the NUS test file with get_long_data and return its (data, labels)."""
    nus = get_long_data(file_path)
    return nus

def get_semeval2010_data(file_path="data/SemEval2010/semeval_test.json"):
    """Load the SemEval-2010 test file with get_short_data and return its (data, labels)."""
    semeval = get_short_data(file_path)
    return semeval

def get_dataset_data(dataset_name):
    """Load a benchmark dataset by name using its default location.

    Args:
        dataset_name: One of ``"duc2001"``, ``"inspec"``, ``"krapivin"``,
            ``"nus"``, ``"semeval2010"`` or ``"semeval2017"``. The historical
            misspelling ``"sameval2017"`` is still accepted.

    Returns:
        tuple[dict, dict]: The ``(data, labels)`` pair produced by the
        matching loader.

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset.
    """
    if dataset_name == "duc2001":
        return get_duc2001_data()
    if dataset_name == "inspec":
        return get_inspec_data()
    if dataset_name == "krapivin":
        return get_krapivin_data()
    if dataset_name == "nus":
        return get_nus_data()
    if dataset_name == "semeval2010":
        return get_semeval2010_data()
    # "sameval2017" was the only spelling accepted before; keep it working.
    if dataset_name in ("semeval2017", "sameval2017"):
        return get_semeval2017_data()
    # Fail here with a clear message rather than returning None and breaking
    # later when the caller unpacks the result.
    raise ValueError(f"Unknown dataset name: {dataset_name!r}")
    
def calculate_f1(predicted, ground_truth, k)->float:
    """
    Compute F1@K, expressed as a percentage (0-100).

    Precision is the number of exact matches among the first ``k``
    predictions divided by ``k``; recall divides the same count by the
    number of ground-truth keyphrases.

    Parameters:
      predicted (list): Predicted keyphrases, best first.
      ground_truth (list): Reference keyphrases.
      k (int): The cutoff for evaluation.

    Returns:
      float: The F1 score scaled by 100, or 0 when there is no overlap.
    """
    hits = len(set(predicted[:k]).intersection(ground_truth))

    if k > 0:
        precision = hits * 1.0 / k
    else:
        precision = 0

    if ground_truth:
        recall = hits * 1.0 / len(ground_truth)
    else:
        recall = 0

    total = precision + recall
    if not total > 0:
        return 0
    return 200.0 * precision * recall / total


def print_to_json(data_name, k, score):
    """
    Write the averaged evaluation result to ``results/<data_name>_<k>.json``.

    The ``results`` directory (relative to the current working directory) is
    created if it does not exist yet.

    Parameters:
      data_name (str): The name of the dataset.
      k (int): The cutoff for evaluation.
      score (list): Per-document scores; their mean is stored (0 if empty).
    """
    average_score = sum(score) / len(score) if score else 0
    result = {
        "dataset": data_name,
        "top_k": k,
        "average_score": average_score,}
    # Without this the open() below fails on a fresh checkout.
    os.makedirs("results", exist_ok=True)
    with open(f"results/{data_name}_{k}.json", "w") as outfile:
        json.dump(result, outfile)


if __name__ == "__main__":
    # Evaluate each selected dataset and store the mean F1@5/10/15.
    # dataset = ['duc2001', 'inspec', 'krapivin', 'nus', 'semeval2010', 'sameval2017']
    dataset = ['krapivin']
    cutoffs = (5, 10, 15)

    for data_name in dataset:
        data, labels = get_dataset_data(data_name)
        scores_by_cutoff = {cutoff: [] for cutoff in cutoffs}
        for doc_id, document in data.items():
            # One ranking per document is reused for every cutoff.
            keyphrases = extract_keyphrases(document, top_k=15)
            gold = labels[doc_id]
            for cutoff in cutoffs:
                scores_by_cutoff[cutoff].append(calculate_f1(keyphrases, gold, cutoff))
        for cutoff in cutoffs:
            print_to_json(data_name, cutoff, scores_by_cutoff[cutoff])