In [1]:
# Data file at https://www.cse.ust.hk/msbd5003/data

from pyspark import SparkContext
import sys

print(sys.version)

# Reuse the active SparkContext when one already exists. A bare SparkContext()
# raises "Cannot run multiple SparkContexts at once" if this cell is re-run or
# if the environment has already created a context.
sc = SparkContext.getOrCreate()

# Load the raw text file as an RDD of lines. The second argument is only a
# *minimum* number of partitions (minPartitions); Spark may create more.
lines = sc.textFile('adj_noun_pairs.txt', 1)

# Number of raw lines in the file (an action: triggers a full read).
lines.count()

2.7.12 (default, Nov 19 2016, 06:48:10) 
[GCC 5.4.0 20160609]


3162692

In [3]:
# Same count as in the first cell. `lines` is never cached, so this action
# reads the input file again.
lines.count()

3162692

In [4]:
# Actual number of partitions of `lines`. textFile was called with 1, but that
# argument is only a lower bound (minPartitions), so the result can be larger.
lines.getNumPartitions()

2

In [5]:
# Peek at the first five raw lines. Judging by the file name, each line is
# meant to hold one adjective and one noun separated by whitespace.
lines.take(5)

[u'early radical',
 u'french revolution',
 u'pejorative way',
 u'violent means',
 u'positive label']

In [6]:
# Tokenise every line on whitespace and keep the tokens as a tuple.
# The input is dirty -- some lines do not contain exactly two words -- so only
# tuples of length 2 are retained.
pairs = (
    lines
    .map(lambda line: tuple(line.split()))
    .filter(lambda tokens: len(tokens) == 2)
)
# `pairs` feeds several later computations, so mark it for in-memory caching.
pairs.cache()

PythonRDD[5] at RDD at PythonRDD.scala:48

In [7]:
# Sanity check: every element should now be a 2-tuple of words.
pairs.take(5)

[(u'early', u'radical'),
 (u'french', u'revolution'),
 (u'pejorative', u'way'),
 (u'violent', u'means'),
 (u'positive', u'label')]

In [8]:
# Total number of well-formed pairs; pmi_score uses it as the normaliser N.
# `pairs` was marked with cache(), so this first action also populates the cache.
N = pairs.count()

In [9]:
# Show N; the gap to the raw line count is the number of malformed lines dropped.
N

3162674

In [10]:
# Count how often each (adjective, noun) pair occurs, then drop the rare
# pairs: only pairs seen at least 100 times are kept.
pair_freqs = (
    pairs
    .map(lambda pair: (pair, 1))
    .reduceByKey(lambda left, right: left + right)
    .filter(lambda entry: entry[1] >= 100)
)

In [11]:
# Sample a few ((word1, word2), count) entries that passed the frequency filter.
pair_freqs.take(5)

[((u'graphic', u'novel'), 117),
 ((u'much', u'debate'), 136),
 ((u'other', u'country'), 1857),
 ((u'other', u'book'), 223),
 ((u'first', u'election'), 263)]

In [12]:
# Marginal word counts over ALL well-formed pairs (not just the frequent ones):
# position 0 of each pair is the adjective, position 1 is the noun.
a_freqs = (
    pairs
    .map(lambda pair: (pair[0], 1))
    .reduceByKey(lambda left, right: left + right)
)
n_freqs = (
    pairs
    .map(lambda pair: (pair[1], 1))
    .reduceByKey(lambda left, right: left + right)
)

In [13]:
# Sample a few (adjective, count) entries.
a_freqs.take(5)

[(u'fawn', 2),
 (u'base-paired', 3),
 (u'eicosapentanoic', 1),
 (u'host-cell', 2),
 (u'1,800', 1)]

In [14]:
# Sample a few (noun, count) entries. Only the last expression of a cell is
# displayed, so no separate count() is issued here: its result would be
# discarded and, since n_freqs is not cached, it would cost a full extra job.
n_freqs.take(5)

[(u'fawn', 7),
 (u'xylem', 10),
 (u'ntsc-uk', 1),
 (u'buckskin', 1),
 (u'homomorphism', 63)]

In [50]:
# Broadcasting the adjective and noun frequencies.
# collectAsMap() brings each frequency table to the driver as a plain dict;
# sc.broadcast() wraps it so the scoring function running on the executors can
# read it through `.value`.
n_dict = sc.broadcast(n_freqs.collectAsMap())  # dict of (noun, freq)
a_dict = sc.broadcast(a_freqs.collectAsMap())  # dict of (adj, freq)

# Spot-check the frequency of one adjective.
a_dict.value['violent']

In [22]:
from math import *

# Computing the PMI for a pair.
def pmi_score(pair_freq, total=None, adj_freqs=None, noun_freqs=None):
    """Return the pointwise mutual information (base 2) of one word pair.

    The score is log2(f * total / (freq(w1) * freq(w2))), where f is the
    number of times the pair (w1, w2) occurs.

    Args:
        pair_freq: A ``((w1, w2), f)`` tuple: the word pair and its count.
        total: Total number of pairs. Defaults to the notebook global ``N``.
        adj_freqs: Mapping from first word to its count. Defaults to the
            broadcast dictionary ``a_dict.value``.
        noun_freqs: Mapping from second word to its count. Defaults to the
            broadcast dictionary ``n_dict.value``.

    Returns:
        A ``(pmi, (w1, w2))`` tuple; the score comes first so that results
        order by score.

    Raises:
        KeyError: If ``w1`` or ``w2`` is missing from its frequency mapping.
    """
    # Fall back to the notebook-level values so that the one-argument call
    # `pair_freqs.map(pmi_score)` keeps working unchanged.
    if total is None:
        total = N
    if adj_freqs is None:
        adj_freqs = a_dict.value
    if noun_freqs is None:
        noun_freqs = n_dict.value

    (w1, w2), f = pair_freq
    # float() forces true division on Python 2, where int / int truncates.
    pmi = log(float(f) * total / (adj_freqs[w1] * noun_freqs[w2]), 2)
    return pmi, (w1, w2)

In [23]:
# Computing the PMI for all pairs.
# Each element of pair_freqs, ((w1, w2), count), becomes (pmi, (w1, w2)).
# map() is a lazy transformation: nothing is computed until an action runs.
scored_pairs = pair_freqs.map(pmi_score)

In [24]:
# Printing the most strongly associated pairs.
# top(10) returns the 10 largest elements in descending order; because each
# element is (pmi, pair), the ordering is by PMI score first.
scored_pairs.top(10)

[(14.41018838546462, (u'magna', u'carta')), (13.071365888694997, (u'polish-lithuanian', u'Commonwealth')), (12.990597616733414, (u'nitrous', u'oxide')), (12.64972604311254, (u'latter-day', u'Saints')), (12.50658937509916, (u'stainless', u'steel')), (12.482331020687814, (u'pave', u'runway')), (12.19140721768055, (u'corporal', u'punishment')), (12.183248694293388, (u'capital', u'punishment')), (12.147015483562537, (u'rush', u'yard')), (12.109945794428935, (u'globular', u'cluster'))]