
Keyword Extraction with TextRank: A Scala Implementation

@debugger87


There are many approaches to keyword extraction. This article introduces TextRank, a graph-based keyword extraction algorithm, and presents a Scala implementation.

PageRank

In fact, the basic idea behind TextRank comes from PageRank. To rank web pages, Google proposed PageRank, a graph-based algorithm. In Google's view, the internet can be abstracted as a graph: web pages form the vertices, and the links between pages form the edges. Each page's rank is computed through repeated iteration. In a single iteration, each page contributes r/n to each of its neighbours, where r is the page's current rank and n is its number of neighbours. Every page's rank is then updated from the contributions it received:

{% math-block %}
r_{i}=\frac{\alpha}{N} + (1-\alpha)\sum_{j=1}^{n}{c_j} \qquad (\alpha \in [0,1))
{% endmath-block %}

After a number of iterations, each page ends up with a rank, which can then be used to order the pages.
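
The update above can be sketched in a few lines of plain Scala on a toy link graph (the three-page graph and all names here are invented purely for illustration):

```scala
// Toy link graph: page -> pages it links to (no dangling pages)
val links: Map[String, Seq[String]] =
  Map("a" -> Seq("b", "c"), "b" -> Seq("c"), "c" -> Seq("a"))

val alpha = 0.15
val n = links.size

// Start with a uniform rank of 1/N per page
var ranks: Map[String, Double] = links.keys.map(_ -> 1.0 / n).toMap

for (_ <- 1 to 50) {
  // Each page sends rank / outDegree to every page it links to
  val contribs = links.toSeq
    .flatMap { case (page, outs) => outs.map(_ -> ranks(page) / outs.size) }
    .groupBy(_._1)
    .map { case (page, cs) => page -> cs.map(_._2).sum }
  // r_i = alpha / N + (1 - alpha) * sum of received contributions
  ranks = ranks.map { case (page, _) =>
    page -> (alpha / n + (1 - alpha) * contribs.getOrElse(page, 0.0))
  }
}
```

Because this toy graph has no dangling pages, the total rank mass stays at 1.0 across iterations; pages with more incoming links end up with higher ranks.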

TextRank

TextRank modifies PageRank in a few places, but the basic method is the same. Here is TextRank's formal description:

Let G = (V, E) be a directed graph, where V is the vertex set (here, the set of words) and E is the edge set. Each iteration updates every word's rank with the following formula:

{% math-block %}
WS_{V_{i}}=(1-d) + d\sum_{V_{j}\in{In(V_{i})}}\frac{w_{ji}}{\sum_{V_{k}\in{Out(V_{j})}}w_{jk}}WS_{V_{j}} \qquad (d \in [0,1))
{% endmath-block %}
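
On an undirected co-occurrence graph (each edge seen in both directions), this update can be spelled out directly. The three-word graph and edge weights below are invented for illustration:

```scala
// Symmetric weighted adjacency: word -> (neighbour -> edge weight)
val adj: Map[String, Map[String, Double]] = Map(
  "scala" -> Map("graph" -> 2.0, "rank" -> 1.0),
  "graph" -> Map("scala" -> 2.0, "rank" -> 1.0),
  "rank"  -> Map("scala" -> 1.0, "graph" -> 1.0)
)

val d = 0.85
var ws: Map[String, Double] = adj.keys.map(_ -> 1.0).toMap

for (_ <- 1 to 100) {
  ws = adj.map { case (i, _) =>
    // WS(i) = (1 - d) + d * sum over neighbours j of w_ji / (sum_k w_jk) * WS(j)
    val s = adj.collect {
      case (j, out) if out.contains(i) => out(i) / out.values.sum * ws(j)
    }.sum
    i -> ((1 - d) + d * s)
  }
}
```

The two words joined by the heavier edge end up with identical, higher scores than the third word, which matches the formula: a neighbour's vote is scaled by the weight of the connecting edge relative to that neighbour's total outgoing weight.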

Scala Implementation

Since TextRank is a graph-based algorithm, the first step is to turn the words of the document into a TextGraph. Because keywords are usually nouns or named entities, this implementation uses ansj for Chinese part-of-speech tagging to extract nouns, and Stanford's MaxentTagger to do the same for English. The TextGraph is then built from sentence-level co-occurrence information.

import java.io.StringReader
import java.text.BreakIterator
import java.util.Locale

import akka.event.slf4j.SLF4JLogging
import edu.stanford.nlp.process.{CoreLabelTokenFactory, PTBTokenizer}
import edu.stanford.nlp.tagger.maxent.MaxentTagger
import org.ansj.domain.Term
import org.ansj.splitWord.analysis.ToAnalysis
import org.ansj.util.recognition.NatureRecognition
import org.graphstream.graph.Node
import org.graphstream.graph.implementations.SingleGraph

import scala.collection.mutable
import scala.io.Source

/**
 * Created by yangchaozhong on 12/30/14.
 */
class TextGraph(val graphName: String,
                val doc: String) extends SLF4JLogging {
  import TextGraph._

  // Build the co-occurrence graph eagerly at construction time
  val graph = new SingleGraph(graphName)
  constructTextGraph()

  private def constructTextGraph(): Unit = {
    val bi = BreakIterator.getSentenceInstance(Locale.CHINESE)
    bi.setText(doc)
    var lastIndex = bi.first()
    while (lastIndex != BreakIterator.DONE) {
      val firstIndex = lastIndex
      lastIndex = bi.next()

      if (lastIndex != BreakIterator.DONE &&
          Character.isLetterOrDigit(doc.charAt(firstIndex))) {
        val sentence = doc.substring(firstIndex, lastIndex)
        val wordSet: mutable.HashSet[String] =
          if (containsChinese(sentence)) chinesWordSet(sentence)
          else englishWordSet(sentence)

        val wordList = wordSet.toList
        wordList foreach {
          word => if (graph.getNode(word) == null) graph.addNode(word)
        }

        // Words co-occurring in the same sentence are linked by an (undirected) edge
        wordList.combinations(2).foreach {
          words =>
            if (graph.getEdge(s"${words(0)}-${words(1)}") == null &&
                graph.getEdge(s"${words(1)}-${words(0)}") == null) {
              graph.addEdge(s"${words(0)}-${words(1)}", words(0), words(1))
            }
        }
      }
    }

    graph.getNodeSet.toArray.map(_.asInstanceOf[Node]).foreach {
      node =>
        log.info(s"${node.getId}:${node.getDegree}")
    }
  }

  /** Extracts candidate keywords (nouns and named entities) from a Chinese sentence. */
  private def chinesWordSet(sentence: String) = {
    val terms = ToAnalysis.paser(sentence)
    new NatureRecognition(terms).recogntion()
    val wordSet = new mutable.HashSet[String]()
    terms.toArray.foreach {
      term =>
        val word = term.asInstanceOf[Term].getName
        val nature = term.asInstanceOf[Term].getNatrue.natureStr // getNatrue is ansj's actual method name
        if (nature != "null" && word.length >= 2) {
          val reg = "^[ne]".r
          if (reg.findFirstMatchIn(nature).isDefined && !stopwords.contains(word))
            wordSet.add(word.toLowerCase)
        }
    }

    wordSet
  }

  /** Extracts candidate keywords (nouns) from an English sentence. */
  private def englishWordSet(sentence: String) = {
    val ptbt = new PTBTokenizer(new StringReader(sentence), new CoreLabelTokenFactory, "")
    val wordSet = new mutable.HashSet[String]()
    while (ptbt.hasNext) {
      val label = ptbt.next()
      val tagged = tagger.tagString(label.word())
      val start = tagged.lastIndexOf("_") + 1 // tagString returns "word_TAG"
      val reg = "^[N]".r
      if (reg.findFirstMatchIn(tagged.substring(start)).isDefined &&
          !stopwords.contains(label.word().toLowerCase()) &&
          label.word().toLowerCase.length >= 3)
        wordSet.add(label.word().toLowerCase)
    }

    wordSet
  }

  private def containsChinese(doc: String) =
    doc.exists { ch =>
      val ub = Character.UnicodeBlock.of(ch)
      ub == Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS ||
      ub == Character.UnicodeBlock.CJK_COMPATIBILITY_IDEOGRAPHS ||
      ub == Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS_EXTENSION_A ||
      ub == Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS_EXTENSION_B ||
      ub == Character.UnicodeBlock.CJK_SYMBOLS_AND_PUNCTUATION ||
      ub == Character.UnicodeBlock.HALFWIDTH_AND_FULLWIDTH_FORMS ||
      ub == Character.UnicodeBlock.GENERAL_PUNCTUATION
    }
}

object TextGraph {
  val tagger = new MaxentTagger("taggers/left3words-wsj-0-18.tagger")
  val stopwords = Source.fromURL(getClass.getResource("/stopwords/stopwords-en.txt")).
    getLines().toSet
}

With the TextGraph in hand, running the TextRank iteration extracts the document's keywords. This implementation sets the d in the formula to 0.85 and gives every edge a uniform weight of 1.0.

import org.graphstream.graph.{Edge, Node}

import scala.collection.mutable

/**
 * Created by yangchaozhong on 12/30/14.
 */
object KeywordExtractor {

  def extractKeywords(doc: String) = {
    val graph = new TextGraph("keywords", doc).graph
    val nodes = graph.getNodeSet.toArray.map(_.asInstanceOf[Node])
    val scoreMap = new mutable.HashMap[String, Float]()

    // Initialization
    nodes.foreach(node => scoreMap.put(node.getId, 1.0f))

    // Fixed number of TextRank iterations (no convergence check)
    (1 to 500).foreach {
      _ =>
        nodes.foreach {
          node =>
            val edges = node.getEdgeSet.toArray.map(_.asInstanceOf[Edge])
            var score = 1.0f - 0.85f // (1 - d) with d = 0.85
            edges.foreach {
              edge =>
                val node0 = edge.getNode0.asInstanceOf[Node]
                val node1 = edge.getNode1.asInstanceOf[Node]
                val tempNode = if (node0.getId.equals(node.getId)) node1 else node0
                // With uniform edge weights, a neighbour j contributes WS(j) / degree(j)
                score += 0.85f * (1.0f * scoreMap(tempNode.getId) / tempNode.getDegree)
            }
            scoreMap.put(node.getId, score)
        }
    }

    // Top 20 words by final score
    scoreMap.toList.sortWith(_._2 > _._2).take(20).map(_._1)
  }
}
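
The loop above always runs 500 iterations. A common refinement (not part of the original code) is to stop once scores change by less than a small threshold. A sketch on a toy unweighted graph, using the same d = 0.85 and uniform edge weights:

```scala
// Toy undirected, unweighted adjacency list (invented for illustration)
val neighbours: Map[String, List[String]] =
  Map("a" -> List("b", "c"), "b" -> List("a"), "c" -> List("a"))

val d = 0.85f
val epsilon = 1e-4f
var scores: Map[String, Float] = neighbours.keys.map(_ -> 1.0f).toMap
var iterations = 0
var delta = Float.MaxValue

while (delta > epsilon && iterations < 500) {
  val updated = neighbours.map { case (node, _) =>
    // Unweighted TextRank: each neighbour j contributes score(j) / degree(j)
    val s = neighbours.collect {
      case (j, outs) if outs.contains(node) => scores(j) / outs.size
    }.sum
    node -> ((1 - d) + d * s)
  }
  // Largest per-node change this iteration drives the stopping condition
  delta = updated.map { case (node, s) => math.abs(s - scores(node)) }.max
  scores = updated
  iterations += 1
}
```

On small graphs this typically converges in a few dozen iterations rather than 500, and the node with the most neighbours ends up ranked highest.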

Project repository: https://github.com/debugger87/textrank
