In [1]:
import pyspark
import pyspark.sql
from pyspark.sql import *
from pyspark.sql.functions import *
import json
import urllib
import argparse
%matplotlib inline

# Spark configuration for a local run on all cores ("local[*]").
# NOTE(review): memory limits and the scratch directory are hard-coded for one
# specific host — adjust before running elsewhere.
SCRATCH_DIR = '/scratch/tmp/'
spark_settings = {
    'spark.driver.memory': '240g',
    'spark.driver.maxResultSize': '32G',
    'spark.local.dir': SCRATCH_DIR,
    'spark.yarn.stagingDir': SCRATCH_DIR,
    'spark.sql.warehouse.dir': SCRATCH_DIR,
}
conf = pyspark.SparkConf().setMaster("local[*]").setAll(list(spark_settings.items()))

# Build (or reuse) the session and keep a handle on its SparkContext.
spark = SparkSession.builder.config(conf=conf).getOrCreate()
sc = spark.sparkContext

In [2]:
# Display the active SparkSession as a quick sanity check of the configuration.
spark

In [3]:
# Wiki edition processed by this notebook; used to build all dataset/model paths.
site = "enwiki"

In [4]:
# Plain-text dump for the selected wiki: bz2-compressed JSON-lines files
# (schema shown in the output below: id, text, title, url).
# The path is built from `site` like every other input, instead of hard-coding
# "enwiki" (identical path when site == 'enwiki').
text = spark.read.json(
    "/dlabdata1/piccardi/WikiPDA/datasets/{}/plain_text/*/*.bz2".format(site)
)
text

DataFrame[id: string, text: string, title: string, url: string]

In [5]:
# Qids of the documents that are NOT held out, i.e. the training split.
heldout_path = "datasets/{}/heldout_documents.parquet".format(site)
training_ids = (
    spark.read.parquet(heldout_path)
    .where("heldout = FALSE")
)

training_ids

DataFrame[qid: string, heldout: boolean]

In [6]:
# Per-page anchor/link table carrying the page's qid and the destination's qid.
anchors_path = "datasets/{}/anchors_info_qid.parquet".format(site)
anchors_info_qid = spark.read.parquet(anchors_path)
anchors_info_qid

DataFrame[page_id: bigint, qid: string, title: string, anchor: string, link: string, destination: string, destination_qid: string]

In [7]:
# Unique (page_id, qid) pairs: maps each page to its qid.
page_qid_pairs = anchors_info_qid.select("page_id", "qid")
wikidata_ids = page_qid_pairs.dropDuplicates()
wikidata_ids

DataFrame[page_id: bigint, qid: string]

In [8]:
# Page ids of the training documents, obtained by matching on qid.
training_page_ids = (
    wikidata_ids
    .join(training_ids, on="qid")
    .select("page_id")
)
training_page_ids

DataFrame[page_id: bigint]

In [9]:
# Keep only the articles whose page id belongs to the training split.
# A left-semi join filters `text` without adding any right-hand columns and
# keeps each article at most once. An inner join would emit one copy per
# matching row if a page_id occurred more than once in `training_page_ids`.
# `text.id` is a string while `page_id` is a bigint, so the cast makes the
# comparison type explicit instead of relying on implicit type coercion.
text_training = text.join(
    training_page_ids,
    text.id.cast("bigint") == training_page_ids.page_id,
    "left_semi",
)
text_training

DataFrame[id: string, text: string, title: string, url: string]

In [10]:
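# One-time setup, only needed if the NLTK data is not installed yet
# (the stopwords corpus is used in the next cell):
# nltk.download('all')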

In [11]:
import nltk

# English NLTK stopwords extended with a few extra terms that presumably come
# from Wikipedia boilerplate ("See also", "References", "External links", ...).
stopwordList = set(nltk.corpus.stopwords.words('english')) | {
    'see', 'also', 'references', 'category', 'external', 'links',
}

In [12]:
import re

# Tokens matching this pattern are treated as markup noise, not words:
# image sizes ("250px", "100x80px"), pure numbers, and 6-character lowercase
# hex strings (e.g. colour codes such as "ffcc00").
# Raw strings avoid the invalid "\d"/"\W" escape-sequence warnings.
# NOTE(review): the hex alternative also drops real words made only of the
# letters a-f ("decade", "facade") and misses upper-case codes ("FFCC00").
nan = re.compile(r"^\d+px$|^\d+x\d+px$|^\d+$|^[0-9abcdef]{6}$")


def tokenize_text(raw_text):
    """Split one article's text into filtered word tokens.

    HTML-like tags are stripped, the text is split on runs of non-word
    characters, and a token is kept only if it is longer than one character,
    does not match ``nan`` and is not (case-insensitively) in ``stopwordList``.
    Token case is preserved. A missing text (None) yields an empty list
    instead of crashing the Spark task.
    """
    without_tags = re.sub(r'<[^>]+>', '', raw_text or '')
    return [w for w in re.split(r'\W+', without_tags)
            if len(w) > 1 and not nan.match(w) and w.lower() not in stopwordList]


# One Row(tokens=[...]) per training article.
tokens = spark.createDataFrame(
    text_training.rdd.map(lambda r: Row(tokens=tokenize_text(r.text)))
)

In [13]:
from pyspark.ml.feature import CountVectorizer
from pyspark.ml.clustering import LDA, LocalLDAModel

# Bag-of-words vocabulary: a term is kept only if it appears in at least
# 10 documents (integer minDF = minimum document count).
wordsVector = CountVectorizer(inputCol="tokens", outputCol="features", minDF=10)
transformer = wordsVector.fit(tokens)

# Term-count vectors per document; cached because later cells write and count them.
result = transformer.transform(tokens).cache()
result

DataFrame[tokens: array<string>, features: vector]

In [14]:
# Persist the fitted CountVectorizer model, replacing any previous save.
transformer_path = "models/TextEng/{}/LDA_transformer.model".format(site)
transformer.write().overwrite().save(transformer_path)

In [15]:
# Save only the term-count vectors as the training set (overwriting).
# NOTE: "traning" is misspelled in the path; kept as-is because other code
# may read this exact location.
training_set_path = "models/TextEng/{}/traning_set.parquet".format(site)
(result
    .select("features")
    .write
    .mode("overwrite")
    .parquet(training_set_path))

In [16]:
# Number of vectorized training documents.
n_training_docs = result.count()
n_training_docs

5559197

In [17]:
# NOTE(review): presumably here to stop the kernel (and release the Spark
# driver) after the vectorized training set has been saved, so the disabled
# LDA cell below is never reached on "Run All" — confirm intent.
exit()

In [18]:
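# Disabled: fits one LDA model per topic count on the cached `result` vectors.
# NOTE(review): these models would be saved under models/OnlyText/ without the
# site prefix used by the other outputs (models/TextEng/{site}/); check before
# re-enabling.
# from pyspark.ml.clustering import LDA

# for topics_count in [20, 25, 30, 35, 40, 50, 70, 90, 110, 130, 150, 170, 190, 210]:
#     lda = LDA(k=topics_count, seed=42, maxIter=10)
#     model = lda.fit(result)
#     model.save("models/OnlyText/LDA_model_{}.model".format(topics_count))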