# Spark assignment 2: Collocations

In the second part of the assignment, your task is to extract collocations, that is, word combinations that frequently occur together, such as “high school” or “roman empire”.

To find collocations, you will use the NPMI (normalized pointwise mutual information) metric.

The PMI of two words, a and b, is defined as “PMI(a, b) = ln(P(ab) / (P(a) * P(b)))”, where P(ab) is the probability of word b immediately following word a, and P(a) and P(b) are the probabilities of words a and b, respectively.

You will estimate probabilities with occurrence counts, that is “P(a) = # of occurrences of word a / total number of words”, and “P(ab) = # of occurrences of words ‘a b’ / total number of word pairs”.

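To build intuition for the definition, see the Reading material.

In short, PMI is large when two words occur together far more often than their individual frequencies would predict. Therefore, rare words that almost always appear as a pair have a large PMI.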

NPMI is computed as “NPMI(a, b) = PMI(a, b) / (-ln P(ab))”. This normalizes the value to the range [-1, 1].
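
To make the two formulas concrete, here is a minimal sketch that computes PMI and NPMI directly from occurrence counts. It assumes Python 3 (true division) and 0 < P(ab) < 1; the function signatures and all counts below are hypothetical and chosen only for illustration.

```python
import math

def pmi(count_a, count_b, count_ab, total_words, total_pairs):
    """PMI from raw counts, using the count-based probability estimates."""
    p_a = count_a / total_words
    p_b = count_b / total_words
    p_ab = count_ab / total_pairs
    return math.log(p_ab / (p_a * p_b))

def npmi(count_a, count_b, count_ab, total_words, total_pairs):
    """NPMI = PMI / (-ln P(ab)); assumes 0 < P(ab) < 1."""
    p_ab = count_ab / total_pairs
    return pmi(count_a, count_b, count_ab, total_words, total_pairs) / -math.log(p_ab)

# Hypothetical counts: the two words never appear apart, so NPMI is close to 1.
always_together = npmi(10, 10, 10, 1000, 1000)

# Hypothetical counts: the pair is exactly as frequent as independence
# predicts (0.1 * 0.1 = 0.01), so NPMI is close to 0.
independent = npmi(100, 100, 10, 1000, 1000)

print(always_together, independent)
```

Pairs that co-occur less often than independence predicts get a negative score, which is why the normalized range extends down to -1.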

Your task is a bit more complicated now:

- Extract all the words, as in the previous task.
- Filter out stop words using the dictionary (/datasets/stop_words_en.txt); do not forget to convert words to lowercase first
- Compute all bigrams (that is, pairs of consecutive words)
- Keep only bigrams with at least 500 occurrences
- Compute NPMI for every bigram (note: when computing probabilities, you need unpruned counts!)
- Sort word pairs by NPMI in descending order
- Print the top 39 word pairs, with the two words joined by an underscore “_”

For example,

    roman_empire
    south_africa

Part of the result on the sample dataset:
    
    ...
    references_reading
    notes_references
    award_best
    north_america
    new_zealand
    ...
 
**Hint**: if you did everything right, “roman_empire” and “south_africa” are going to be in the result.
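
The steps listed above can be sketched in plain Python on a toy corpus before moving to Spark. This is only an illustration under stated assumptions: Python 3, a simplified `\w+` tokenizer (the notebook below splits on whitespace instead), and a hypothetical helper `top_collocations` with made-up documents, stop words, and threshold.

```python
import math
import re
from collections import Counter

def top_collocations(docs, stop_words, min_count, top_n):
    """Return up to top_n bigrams as 'a_b' strings, ranked by NPMI."""
    word_counts = Counter()
    pair_counts = Counter()
    for text in docs:
        # Lowercase first, then drop stop words, then form bigrams.
        words = [w.lower() for w in re.findall(r"\w+", text)]
        words = [w for w in words if w not in stop_words]
        word_counts.update(words)
        pair_counts.update(zip(words, words[1:]))

    # Totals come from the unpruned counts, as the task requires.
    total_words = sum(word_counts.values())
    total_pairs = sum(pair_counts.values())

    scored = []
    for (a, b), n_ab in pair_counts.items():
        if n_ab < min_count:
            continue  # keep only bigrams with at least min_count occurrences
        if n_ab == total_pairs:
            continue  # degenerate case: -ln P(ab) would be 0
        p_ab = n_ab / total_pairs
        p_a = word_counts[a] / total_words
        p_b = word_counts[b] / total_words
        pmi = math.log(p_ab / (p_a * p_b))
        scored.append((pmi / -math.log(p_ab), a + "_" + b))

    # Highest NPMI first; ties are broken alphabetically.
    scored.sort(key=lambda item: (-item[0], item[1]))
    return [name for _, name in scored[:top_n]]

# Hypothetical toy data, not the course dataset.
docs = [
    "The Roman Empire and the Roman Empire",
    "the empire of the sun",
    "a roman road",
]
stop_words = {"the", "and", "of", "a"}
result = top_collocations(docs, stop_words, min_count=2, top_n=5)
print(result)
```

Because stop words are removed before bigrams are formed, words that were separated only by stop words become neighbours, which is also how the Spark pipeline below behaves.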

If you want to deploy the environment on your own machine, please use the [bigdatateam/spark-course1](https://hub.docker.com/r/bigdatateam/spark-course1/) Docker image.

In [1]:
#! /usr/bin/env python

from pyspark import SparkConf, SparkContext
sc = SparkContext(conf=SparkConf().setAppName("MyApp").setMaster("local[2]"))

import re
import math

In [2]:
stop_file = "/datasets/stop_words_en.txt"
wiki_file = "/data/wiki/en_articles_part/articles-part"
pair_thresh = 500

with open(stop_file, "r") as f:
    # A set makes the membership test in filter_stop O(1) per word
    stop_words = set(f.read().splitlines())
    
stop_words_bcast = sc.broadcast(stop_words)

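def parse_article(line):
    # Each line is "<article_id>\t<text>"; lines without a tab yield no words
    try:
        article_id, text = line.rstrip().split('\t', 1)
        # Raw strings keep the regex escapes (\W, \s) intact
        text = re.sub(r"^\W+|\W+$", "", text, flags=re.UNICODE)
        words = re.split(r"\W*\s+\W*", text, flags=re.UNICODE)
        return words
    except ValueError:
        return []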
    
def lower(words):
    return [word.lower() for word in words]

def filter_stop(words):
    return [word for word in words if word not in stop_words_bcast.value]

def pairs(words):
    out = []
    for w1, w2 in zip(words, words[1:]):
        out.append((w1.lower() + "_" + w2.lower(), 1))
    return out

wiki = (sc.textFile(wiki_file, 16)
         .map(parse_article)  
         .map(lower)
         .map(filter_stop)
        ).cache()

In [3]:
words = (wiki.flatMap(lambda wds : [(word, 1) for word in wds])
         .reduceByKey(lambda x,y: x+y)
        ).cache()

words_total = words.map(lambda value: value[1]).sum()
words_total = sc.broadcast(words_total)

words_count_map = words.collectAsMap()
words_count_map = sc.broadcast(words_count_map)

pairs = (wiki.flatMap(pairs)
         .reduceByKey(lambda x,y : x+y)
        ).cache()

pairs_total = pairs.map(lambda value: value[1]).sum()
pairs_total = sc.broadcast(pairs_total)

In [4]:
def npmi(value):
    pair, count = value
    w1, w2 = pair.split("_")
    w1_count = words_count_map.value[w1]
    w2_count = words_count_map.value[w2]
    
    pair_prob = float(count) / pairs_total.value
    w1_prob = float(w1_count) / words_total.value
    w2_prob = float(w2_count) / words_total.value
    
    pmi = math.log(pair_prob / (w1_prob * w2_prob))
    # Use a distinct local name so it does not shadow the npmi function
    npmi_value = pmi / -math.log(pair_prob)
    return (pair, npmi_value)

npmi = (pairs
        .filter(lambda value: value[1] >= pair_thresh)  # "at least 500 occurrences"
        .map(lambda value: npmi(value))
        .sortBy(lambda value: value[1], ascending=False)
       ).cache()    

In [5]:
for pair, value in npmi.take(39):
    print(pair)

los_angeles
external_links
united_states
prime_minister
san_francisco
et_al
new_york
supreme_court
19th_century
20th_century
references_external
soviet_union
air_force
baseball_player
university_press
roman_catholic
united_kingdom
references_reading
notes_references
award_best
north_america
new_zealand
civil_war
catholic_church
world_war
war_ii
south_africa
took_place
roman_empire
united_nations
american_singer-songwriter
high_school
american_actor
american_actress
american_baseball
york_city
american_football
years_later
north_american
