By Suchin Gururangan and Swabha Swayamdipta, 2018; Modified by Sam Bowman, 2019; MIT License

Used to produce the PPMI statistics shown in this paper:
https://www.aclweb.org/anthology/N18-2017/

# PPMI Distributions

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import re
from collections import Counter
import string
from IPython.core.debugger import Tracer
import numpy as np
from nltk import PorterStemmer
from nltk.tokenize import word_tokenize
%matplotlib inline

## Load Data

In [None]:
# Load the MNLI training split. Malformed rows are skipped with a warning, so
# a few warnings are expected here due to format mismatches. These should not
# impact the target calculations.
# NOTE: `on_bad_lines='warn'` is the replacement for `error_bad_lines=False`,
# which was deprecated in pandas 1.3 and removed in pandas 2.0.
mnli = pd.read_csv('/Users/srbowman/glue_data/MNLI/train.tsv',
                   sep='\t',
                   on_bad_lines='warn').dropna(subset=['sentence2'])
mnli = mnli[mnli['genre'] == 'government']  # Using only the 'government' genre section of MNLI here
mnli = mnli.iloc[:8500]  # Trimming this to match the size of the other datasets under study.
mnli.head(10)  # Preview the first rows (rendered as the cell output).

## Functions

In [None]:
def tokenize(string):
    """
    Lowercase the given text and split it into a list of tokens
    using NLTK's word_tokenize.
    """
    return word_tokenize(string.lower())

def pmi(df, *, alpha=10.0, classes=('contradiction', 'entailment', 'neutral')):
    """
    Smoothed positive PMI (PPMI) between sentence2 words and gold labels.

    Every token of `sentence2` is paired with the `gold_label` of its row.
    Add-alpha smoothing is applied to each (word, class) cell, so each word
    marginal grows by len(classes) * alpha and each class marginal grows by
    vocabulary_size * alpha.

    Args:
        df -> dataframe (mnli or snli) with 'sentence2' and 'gold_label' columns
        alpha -> positive pseudo-count added to every (word, class) cell
        classes -> labels to score. For a two-class version, map every
                   non-entailment label in df to 'contradiction' beforehand
                   and pass ('contradiction', 'entailment').
    Returns:
        pmis -> dataframe with one row per (word, class) pair, in order of
                first word occurrence, with columns 'word', 'class',
                'pmi' (log ratio clipped at 0), 'count' (raw word count)
                and 'wc_count' (raw joint count). Empty if df has no tokens.
    """
    # Drop duplicate class names while keeping their order, so each
    # (word, class) pair is emitted exactly once.
    classes = list(dict.fromkeys(classes))
    num_classes = len(classes)

    # Raw counts: joint (word, class), word marginal, and class marginal
    # measured in tokens. Counters are filled incrementally so no large
    # intermediate lists are needed; insertion order follows first occurrence.
    joint_counts = Counter()
    word_counts = Counter()
    class_counts = Counter()
    for sentence, label in zip(df['sentence2'], df['gold_label']):
        tokens = tokenize(sentence)
        word_counts.update(tokens)
        joint_counts.update((token, label) for token in tokens)
        class_counts[label] += len(tokens)

    vocab_size = len(word_counts)
    # Smoothed word marginals; their sum is the smoothed grand total.
    smoothed_word_counts = {word: count + num_classes * alpha
                            for word, count in word_counts.items()}
    total_word_count = sum(smoothed_word_counts.values())

    rows = []
    for word, word_count in smoothed_word_counts.items():
        for cls in classes:
            # A class that never occurs in df still gets its smoothing mass
            # (raw count 0) instead of raising KeyError.
            cls_count = class_counts.get(cls, 0) + vocab_size * alpha
            raw_joint = joint_counts.get((word, cls), 0.0)
            z = (float(raw_joint + alpha) * total_word_count
                 / (float(word_count) * float(cls_count)))
            rows.append({'word': word,
                         'class': cls,
                         # ppmi: negative associations are clipped to zero
                         'pmi': max(np.log(z), 0.0),
                         'count': float(word_counts[word]),
                         'wc_count': raw_joint})
    return pd.DataFrame(rows)

## Run PPMI

In [None]:
# Compute the PPMI table for the trimmed MNLI subset loaded above.
mnli_pmi = pmi(mnli)

## Top 50 PPMIs across the NLI dataset (all classes)

In [None]:
# White background without grid lines for any figures that follow.
sns.set_style(style="whitegrid", rc={"axes.grid": False})
# The 50 (word, class) pairs with the highest PPMI, over all classes.
mnli_pmi.sort_values('pmi', ascending=False).head(50)

## MNLI class PPMIs

In [None]:
# Horizontal bar chart of the 20 words with the highest PPMI for 'contradiction'.
sns.set_style(style="whitegrid", rc={"axes.grid": False})
ax = (
    mnli_pmi[mnli_pmi['class'] == 'contradiction']
    .sort_values('pmi', ascending=False)
    .head(20)
    .plot.barh(x='word', y='pmi', legend=False,
               title='top 20 PPMI(word, contradiction) MNLI')
)


In [None]:
# Horizontal bar chart of the 20 words with the highest PPMI for 'entailment'.
sns.set_style(style="whitegrid", rc={"axes.grid": False})
ax = (
    mnli_pmi[mnli_pmi['class'] == 'entailment']
    .sort_values('pmi', ascending=False)
    .head(20)
    .plot.barh(x='word', y='pmi', legend=False,
               title='top 20 PPMI(word, entailment) MNLI')
)


In [None]:
# Horizontal bar chart of the 20 words with the highest PPMI for 'neutral'.
sns.set_style(style="whitegrid", rc={"axes.grid": False})
ax = (
    mnli_pmi[mnli_pmi['class'] == 'neutral']
    .sort_values('pmi', ascending=False)
    .head(20)
    .plot.barh(x='word', y='pmi', legend=False,
               title='top 20 PPMI(word, neutral) MNLI')
)
