In [2]:
# Data file at https://www.cse.ust.hk/msbd5003/data

lines = sc.textFile('../data/adj_noun_pairs.txt')

In [3]:
lines.count()

3162701

In [4]:
lines.getNumPartitions()

2

In [5]:
lines.take(5)

[u'{\\rtf1\\ansi\\ansicpg1252\\cocoartf1671\\cocoasubrtf600',
 u'{\\fonttbl\\f0\\fmodern\\fcharset0 Courier;}',
 u'{\\colortbl;\\red255\\green255\\blue255;\\red0\\green0\\blue0;}',
 u'{\\*\\expandedcolortbl;;\\cssrgb\\c0\\c0\\c0;}',
 u'\\paperw11900\\paperh16840\\margl1440\\margr1440\\vieww10800\\viewh8400\\viewkind0']

In [6]:
# Converting lines into word pairs. 
# Data is dirty: some lines have more than 2 words, so filter them out.
pairs = lines.map(lambda l: tuple(l.split())).filter(lambda p: len(p)==2)
pairs.cache()

PythonRDD[4] at RDD at PythonRDD.scala:53

In [7]:
pairs.take(5)

[(u'{\\fonttbl\\f0\\fmodern\\fcharset0', u'Courier;}'),
 (u'early', u'radical\\'),
 (u'french', u'revolution\\'),
 (u'pejorative', u'way\\'),
 (u'violent', u'means\\')]

In [8]:
N = pairs.count()

In [9]:
N

3161992

In [10]:
# Compute the frequency of each pair.
# Ignore pairs that not frequent enough
pair_freqs = pairs.map(lambda p: (p,1)).reduceByKey(lambda f1, f2: f1 + f2) \
                  .filter(lambda pf: pf[1] >= 100)

In [11]:
pair_freqs.take(5)

[((u'gay', u'man\\'), 226),
 ((u'legislative', u'election\\'), 167),
 ((u'other', u'group\\'), 771),
 ((u'non-profit', u'organization\\'), 158),
 ((u'manufacture', u'goods\\'), 104)]

In [12]:
# Computing the frequencies of the adjectives and the nouns
a_freqs = pairs.map(lambda p: (p[0],1)).reduceByKey(lambda x,y: x+y)
n_freqs = pairs.map(lambda p: (p[1],1)).reduceByKey(lambda x,y: x+y)

In [13]:
a_freqs.take(5)

[(u'fawn', 2),
 (u'base-paired', 3),
 (u'eicosapentanoic', 1),
 (u'host-cell', 2),
 (u'1,800', 1)]

In [14]:
n_freqs.count()

105898

In [27]:
# Broadcasting the adjective and noun frequencies. 
a_dict = a_freqs.collectAsMap()
n_dict = n_freqs.collectAsMap()
#a_dict = sc.parallelize(a_dict).map(lambda x: x)
#n_dict = sc.parallelize(n_dict).map(lambda x: x)


#n_dict = sc.broadcast(n_freqs.collectAsMap())
#a_dict = sc.broadcast(a_freqs.collectAsMap())
#a_dict.value['violent']

In [28]:
from math import *

# Computing the PMI for a pair.
def pmi_score(pair_freq):
    w1, w2 = pair_freq[0]
    f = pair_freq[1]
    pmi = log(float(f)*N/(a_dict[w1]*n_dict[w2]), 2)
    return pmi, (w1, w2)

In [29]:
# Computing the PMI for all pairs.
scored_pairs = pair_freqs.map(pmi_score)

In [30]:
# Printing the most strongly associated pairs. 
scored_pairs.top(10)

[(14.409877248711565, (u'magna', u'carta\\')),
 (13.07105475194194, (u'polish-lithuanian', u'Commonwealth\\')),
 (12.990286479980359, (u'nitrous', u'oxide\\')),
 (12.649414906359487, (u'latter-day', u'Saints\\')),
 (12.506278238346107, (u'stainless', u'steel\\')),
 (12.482019883934761, (u'pave', u'runway\\')),
 (12.191096080927498, (u'corporal', u'punishment\\')),
 (12.182937557540333, (u'capital', u'punishment\\')),
 (12.146704346809479, (u'rush', u'yard\\')),
 (12.109634657675876, (u'globular', u'cluster\\'))]