In [8]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import findspark
findspark.init()

import re
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql.types import DateType, ArrayType, DoubleType

from pyspark.ml.clustering import LDA, LDAModel

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

In [9]:
result_lda = spark.read.parquet('result_lda.parquet')
result_lda.show()

+--------+--------------------+
|  textID|   topicDistribution|
+--------+--------------------+
|  671240|[6.58155862613258...|
| 3781244|[9.40955320799835...|
|  761241|[0.00146956887930...|
| 1921241|[7.87280666829929...|
| 2241240|[0.00141175238583...|
| 4001241|[6.90012310329306...|
|14091243|[7.49673752511427...|
|  821241|[0.99567634446045...|
|  441241|[0.24788906994921...|
| 1171241|[4.98378122588521...|
| 2251241|[3.46593827223224...|
| 4451244|[7.20856067245151...|
|  711243|[0.19224600963891...|
| 2221244|[5.73554304172097...|
|  311241|[3.87982960614948...|
| 1191240|[3.00254584416535...|
| 1041241|[0.14659772641234...|
|13961244|[0.99420762787174...|
|14231242|[0.21207572174111...|
|14761242|[1.90451862605562...|
+--------+--------------------+
only showing top 20 rows



In [16]:
sources_rdd = sc.textFile('../sample_data/now-samples-sources.txt').zipWithIndex().filter(lambda r: r[1] > 2).keys().map(lambda r: r.split('\t'))

#create schema and change data type for date
sources_schema = sources_rdd.map(lambda r: Row(textID=int(r[0]),nwords=int(r[1]),date=r[2],country=r[3],website=r[4],url=r[5],title=r[6],))
sources = spark.createDataFrame(sources_schema)

In [51]:
#udf to change list of topic distribution into multiple columns
def to_array(col):
    def to_array_(v):
        return v.toArray().tolist()
    return udf(to_array_, ArrayType(DoubleType()))(col)

distribution = (result_lda.withColumn("topic", to_array(col("topicDistribution"))).select(["textID"] + [col("topic")[i] for i in range(10)]))
country = sources.drop("date","nwords","title","url","website")
country = country.select(col("country"), col("textID").alias("c_textID"))
country_dist = distribution.join(country, distribution.textID == country.c_textID).drop("c_textID")
avg_countryTopics = country_dist.sort("textID").groupby("country").mean().drop("avg(textID)")

In [52]:
oldColumns = avg_countryTopics.schema.names[-10:]
newColumns = ['topic0', 'topic1', 'topic2', 'topic3', 'topic4', 'topic5', 'topic6', 'topic7', 'topic8', 'topic9']

In [56]:
for i in range(len(oldColumns)):
    avg_countryTopics = avg_countryTopics.withColumnRenamed(oldColumns[i], newColumns[i])

In [58]:
avg_dateTopics.printSchema()

root
 |-- year(date): integer (nullable = true)
 |-- month(date): integer (nullable = true)
 |-- avg(topic[0]): double (nullable = true)
 |-- avg(topic[1]): double (nullable = true)
 |-- avg(topic[2]): double (nullable = true)
 |-- avg(topic[3]): double (nullable = true)
 |-- avg(topic[4]): double (nullable = true)
 |-- avg(topic[5]): double (nullable = true)
 |-- avg(topic[6]): double (nullable = true)
 |-- avg(topic[7]): double (nullable = true)
 |-- avg(topic[8]): double (nullable = true)
 |-- avg(topic[9]): double (nullable = true)



In [64]:
dates = sources.drop("country","nwords","title","url","website")
dates = dates.select(col("date"), col("textID").alias("d_textID"))
date_dist = distribution.join(dates, distribution.textID == dates.d_textID).drop("d_textID")
avg_dateTopics = date_dist.sort("textID").groupBy(year("date"),month("date")).mean().drop('avg(textID)')

In [65]:
avg_dateTopics.show()

+----------+-----------+-------------------+-----------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-----------------+-------------------+
|year(date)|month(date)|      avg(topic[0])|    avg(topic[1])|      avg(topic[2])|      avg(topic[3])|      avg(topic[4])|      avg(topic[5])|      avg(topic[6])|      avg(topic[7])|    avg(topic[8])|      avg(topic[9])|
+----------+-----------+-------------------+-----------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-----------------+-------------------+
|      null|       null|0.07705842133774261|0.104873906575239|0.11313226969777704|0.05024047759621304|0.13021743878510075|0.15288021279302708|0.09880137527229328|0.07381243261961451|0.082581498775322|0.11640196654767081|
+----------+-----------+-------------------+-----------------+-------------------+-------------------+--------------

In [63]:
spark.read.parquet('avg_dateTopics.parquet').show()

+----+-----+-------------------+-----------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-----------------+-------------------+
|year|month|             topic0|           topic1|             topic2|             topic3|             topic4|             topic5|             topic6|             topic7|           topic8|             topic9|
+----+-----+-------------------+-----------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-----------------+-------------------+
|null| null|0.07705842133774261|0.104873906575239|0.11313226969777704|0.05024047759621304|0.13021743878510075|0.15288021279302708|0.09880137527229328|0.07381243261961451|0.082581498775322|0.11640196654767081|
+----+-----+-------------------+-----------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------

In [18]:
lda_model = LDA(k=10, maxIter=10, optimizeDocConcentration=True).setFeaturesCol('non_norm_features').fit(tfidf)

In [20]:
test = lda_model.transform(tfidf).drop('non_norm_features')

In [21]:
test.write.parquet('result_lda.parquet')

In [27]:
#read the text file and remove the first three rows (zip trick)
wlp_rdd = sc.textFile('../*-*-*.txt').zipWithIndex().filter(lambda r: r[1] > 2).keys()

In [28]:
#we split the elements separated by tabs
lines = wlp_rdd.map(lambda r: r.split('\t'))

#identify the columns
wlp_schema = lines.map(lambda r: Row(textID=int(r[0]),idseq=int(r[1]),word=r[2],lemma=r[3],pos=r[4]))
wlp = spark.createDataFrame(wlp_schema)
wlp.show()

+----------+-----------+-------+-------+-----------+
|     idseq|      lemma|    pos| textID|       word|
+----------+-----------+-------+-------+-----------+
|2654351732|   official|    nn2|1787313|  officials|
|2654351733|       have|    vh0|1787313|       have|
|2654351734|     accuse|    vvn|1787313|    accused|
|2654351735|       mine|vvg_nn1|1787313|     mining|
|2654351736|       firm|    nn2|1787313|      firms|
|2654351737|         of|     io|1787313|         of|
|2654351738|      react|    vvg|1787313|   reacting|
|2654351739|       like|     ii|1787313|       like|
|2654351740|    spoiled|    jj@|1787313|    spoiled|
|2654351741|      child|    nn2|1787313|   children|
|2654351742|           |      ,|1787313|          ,|
|2654351743|        but|    ccb|1787313|        but|
|2654351744|        the|     at|1787313|        the|
|2654351745|  tanzanian|     jj|1787313|  Tanzanian|
|2654351746| government|    nn1|1787313| government|
|2654351747|         's|     ge|1787313|      

In [29]:
wlp.groupBy('textID').count().show()

+-------+-----+
| textID|count|
+-------+-----+
|1787820|  169|
|1787313|  846|
|1787819|  386|
+-------+-----+



In [30]:
pos_remove = ['.',',',"\'",'\"','null']
wlp_nopos = wlp.filter(~wlp['pos'].isin(pos_remove)).filter(~wlp['pos'].startswith('m')).filter(~wlp['pos'].startswith('f')).drop('idseq','pos','word')

In [31]:
#np.save('our_stopwords',stopwords)
stopwords = sc.textFile('../our_stopwords.txt').collect()
print('Number of stopwords: ', len(stopwords))

Number of stopwords:  5639


In [32]:
#filter out stopwords and looking at the frequency of words without them
wlp_nostop = wlp_nopos.filter(~wlp['lemma'].isin(stopwords))
lemma_freq = wlp_nostop.groupBy('lemma').count().sort('count', ascending=False)
lemma_freq.show()

+-----------+-----+
|      lemma|count|
+-----------+-----+
|       mine|   10|
|   tanzania|    9|
|     mining|    8|
| government|    6|
|     cookie|    6|
|        fee|    6|
|    royalty|    6|
|     public|    6|
|legislation|    6|
|   industry|    6|
|     device|    6|
| investment|    6|
|      astro|    5|
|     change|    5|
|   investor|    5|
|     sector|    5|
|        use|    5|
|  tanzanian|    5|
|       high|    4|
|   minister|    4|
+-----------+-----+
only showing top 20 rows



In [33]:
#calculate percentiles and filtering out the lemmas above and below them
[bottom,top] = lemma_freq.approxQuantile('count', [0.8,0.99], 0.01)
lemma_tokeep = lemma_freq.filter(lemma_freq['count']>bottom).filter(lemma_freq['count']<top)
c = lemma_tokeep.count()
print('Number of lemmas left: %d'%c)
print('Percentage of lemmas left: %.2f'%(c/lemma_freq.count()*100))

Number of lemmas left: 46
Percentage of lemmas left: 12.81


In [34]:
#perform sql query and inner join
wlp_nostop.registerTempTable('wlp_nostop')
lemma_tokeep.registerTempTable('lemma_tokeep')

query = """
SELECT wlp_nostop.lemma, wlp_nostop.textID
FROM wlp_nostop
INNER JOIN lemma_tokeep ON wlp_nostop.lemma = lemma_tokeep.lemma
"""

wlp_kept = spark.sql(query)
wlp_bytext = wlp_kept.groupBy('textID').agg(collect_list('lemma'))\
                    .sort('textID')\
                    .withColumnRenamed('collect_list(lemma)','lemma_list')
wlp_bytext.show()

+-------+--------------------+
| textID|          lemma_list|
+-------+--------------------+
|1787313|[subject, subject...|
|1787819|[subject, set, re...|
|1787820|[astro, astro, as...|
+-------+--------------------+

