# PubMed bio-article classifier in PySpark



Create the Spark session.

In [1]:
from pyspark.sql import SparkSession

# Reuse the active SparkSession if one exists; otherwise start one with default settings.
spark = (
    SparkSession.builder
    .getOrCreate()
)

Filter that drops blank lines and the `<doc …>` / `</doc>` markup lines that delimit each document.

In [2]:
import re

# Markup lines that open / close a document. Both are tested with `match` against
# the stripped line, so they are anchored at the start of the line.
begin_re = re.compile("<doc.*>$")
end_re = re.compile("^</doc>$")

def is_text(line):
    """Return True when `line` carries article text.

    Blank (whitespace-only) lines and the `<doc ...>` / `</doc>` delimiter
    lines are rejected; every other line is kept.
    """
    content = line.strip()
    if not content:
        return False
    is_markup = begin_re.match(content) is not None or end_re.match(content) is not None
    return not is_markup

In [3]:
from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

class PunctuationStripper(Transformer, HasInputCol, HasOutputCol):
    """Spark ML transformer that deletes punctuation from a string column.

    Every character that is neither a word character nor whitespace is removed
    (not replaced), and the result is written to `outputCol`. Null values in
    `inputCol` are passed through as null.
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(PunctuationStripper, self).__init__()
        kwargs = self._captured_kwargs(self.__init__)
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        """Set the input/output column names; returns the result of `_set` (self)."""
        kwargs = self._captured_kwargs(self.setParams)
        return self._set(**kwargs)

    def _captured_kwargs(self, method):
        """Return the keyword arguments that `@keyword_only` captured for `method`.

        Newer PySpark releases store them on the instance as `_input_kwargs`;
        older releases attached them to the wrapped function instead. Reading
        only the function attribute raises AttributeError on newer releases.
        """
        kwargs = getattr(self, "_input_kwargs", None)
        if kwargs is None:
            kwargs = method._input_kwargs
        return kwargs

    def _transform(self, dataset):
        """Return `dataset` with `outputCol` holding `inputCol` minus punctuation."""
        punct_re = re.compile(r'[^\w\s]', re.UNICODE)

        def strip(s):
            # The input column may be nullable; re.sub(None) would raise a
            # TypeError inside the Python worker and fail the whole job.
            if s is None:
                return None
            return punct_re.sub('', s)

        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        mapper = udf(strip, StringType())

        return dataset.withColumn(out_col, mapper(in_col))

Load the PubMed (`bio`) and non-PubMed (`other`) articles into a DataFrame.

In [4]:
from pyspark.sql.types import StructType, StructField, DoubleType, StringType
# Row layout for the article DataFrames: one line of text plus a numeric class label.
schema = StructType(
    [
        StructField("fullText", StringType(), True),
        StructField("category", DoubleType(), False),
    ]
)

def load_article(category_name, category_id):
    """Read every file under the `category_name` directory into a labelled DataFrame.

    Each non-markup, non-blank line (see `is_text`) becomes one row of
    `schema`: the raw line as `fullText` and `float(category_id)` as `category`.
    """
    sc = spark.sparkContext
    pattern = "{}/*".format(category_name)

    def to_row(text):
        return (text, float(category_id))

    return sc.textFile(pattern).filter(is_text).map(to_row).toDF(schema)

**TODO:** the loading and preprocessing should be turned into Spark tasks.

In [5]:
# Label 0 = PubMed articles in ./bio, label 1 = everything in ./other.
bio_articles, other_articles = (
    load_article("bio", 0),
    load_article("other", 1),
)

In [6]:
# Preview the first rows with long text truncated (these are show()'s defaults).
bio_articles.show(n=20, truncate=True)

+--------------------+--------+
|            fullText|category|
+--------------------+--------+
|Association of ma...|     0.0|
|Pregnancy, delive...|     0.0|
|Spermicidal activ...|     0.0|
|Laboratory tests ...|     0.0|
|The diagnosis of ...|     0.0|
|Preliminary resul...|     0.0|
|Cesarean hysterec...|     0.0|
|This paper review...|     0.0|
|               Skin.|     0.0|
|A reevaluation of...|     0.0|
|Puerperal tubal s...|     0.0|
|A study of 1830 p...|     0.0|
|Cesarean section-...|     0.0|
|Experience with c...|     0.0|
|Adrenalectomy for...|     0.0|
|Problems involved...|     0.0|
|The early diagnos...|     0.0|
|Coarctation of th...|     0.0|
|Haemorrhagic gang...|     0.0|
|The epidemiologic...|     0.0|
+--------------------+--------+
only showing top 20 rows



In [7]:
# Stack both labelled sets into one training frame. `union` is the current name
# for `unionAll` (UNION ALL semantics: no de-duplication, columns matched by
# position, which is safe here because both frames come from the same schema).
df = bio_articles.union(other_articles)

The following is a workaround for PySpark not finding NumPy, [taken from the GitHub issue](https://github.com/jupyter/docker-stacks/issues/109).

In [8]:
import os, sys

# Expose the driver's module search path to child Python processes.
# os.pathsep is ':' on POSIX and ';' on Windows; a hard-coded ':' breaks Windows
# paths, which contain drive-letter colons.
# NOTE(review): this runs after the SparkSession was created in an earlier cell;
# confirm that the workers pick it up in this deployment.
os.environ['PYTHONPATH'] = os.pathsep.join(sys.path)

In [9]:
from pyspark.ml.feature import CountVectorizer, Tokenizer
from pyspark.ml.feature import HashingTF, IDF
from pyspark.ml.classification import NaiveBayes
from pyspark.ml import Pipeline

from pyspark.ml.feature import RegexTokenizer

punctuation_stripper = PunctuationStripper(inputCol="fullText", outputCol="strippedText")
# Tokenizer splits on every single whitespace character, so the runs of spaces
# left behind by punctuation stripping (e.g. "a - b" -> "a  b") yield empty-string
# tokens that HashingTF would count as a real term. RegexTokenizer's defaults
# split on whitespace runs ("\s+", gaps=True), drop tokens shorter than one
# character, and lowercase the text like Tokenizer does.
tokenizer = RegexTokenizer(inputCol="strippedText", outputCol="words")
# CountVectorizer(inputCol="words", outputCol="rawFeatures") can be used instead
# of HashingTF to produce the term-frequency vectors.


Pipeline: term frequency (TF) → inverse document frequency (IDF) → Naive Bayes classifier.

In [10]:
# Term-frequency vectors via the hashing trick, then re-weighted by IDF.
hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures")
idf = IDF(inputCol="rawFeatures", outputCol="features")

# Multinomial Naive Bayes over the TF-IDF vectors, trained against "category".
nb = NaiveBayes(
    featuresCol="features",
    labelCol="category",
    modelType="multinomial",
)

pipeline = Pipeline(stages=[
    punctuation_stripper,
    tokenizer,
    hashingTF,
    idf,
    nb,
])

In [11]:
# Fit every pipeline stage, in order, on the full labelled corpus.
# NOTE(review): nothing is held out here, so this notebook has no accuracy estimate.
model = pipeline.fit(dataset=df)

Try classifying a few basic sentences.

In [12]:
# Hand-written probe sentences: a mix of biomedical and general-domain text.
tf = spark.createDataFrame(
    [(sentence,) for sentence in [
        "Bactibilia has several consequences to human health",
        "Assessing the bile microbiology of patients with biliopancreatic diseases in order to identify bacteria and their possible infectious complications",
        "Thirty bile samples from patients at mean age ≈57.7 years, mostly female (n=18), were assessed. ",
        "Julius Caesar was a Roman general",
        "big data analysis is great",
        "do you know snow crash",
    ]],
    ["fullText"],
)
tf = model.transform(tf)
tf.select("fullText", "prediction").show()

+--------------------+----------+
|            fullText|prediction|
+--------------------+----------+
|Bactibilia has se...|       0.0|
|Assessing the bil...|       0.0|
|Thirty bile sampl...|       0.0|
|Julius Caesar was...|       1.0|
|big data analysis...|       1.0|
|do you know snow ...|       1.0|
+--------------------+----------+

