Names: Jorge Mazariegos & Cameron Knopp

In [38]:
# import statements
import time
import string
import itertools
import operator
import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import nltk
from scipy.stats import iqr
from statistics import median
from collections import defaultdict, OrderedDict, Counter
from bs4 import BeautifulSoup
#from gensim.models import KeyedVectors
from torch.utils.data import Dataset, DataLoader
from nltk import word_tokenize

# Make sure the NLTK stop word corpus is available locally before importing it.
nltk.download('stopwords')
from nltk.corpus import stopwords
# NOTE: the set built here is not assigned to anything; preprocess() loads
# the English stop word list itself.
set(stopwords.words('english'))

# IPython magic (only valid inside a notebook), followed by the plot style.
%matplotlib inline
plt.style.use('seaborn-paper')

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/camknopp/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [39]:
# preprocess takes the sentences extracted from the datasets (.xml) and tokenizes them so they are ready to be used
def preprocess(data):
    """
    Tokenize sentences and collect the vocabulary tokens.

    Args:
        data (list(str)): the sentences to process
    Returns: a list of tokens and a list of tokenized sentences
        tokens: every purely alphabetic, lower-cased word of every sentence
                (in order of appearance, duplicates kept) that is not an
                English stop word
        puns:   one list per input sentence holding the word_tokenize output
                of the lower-cased sentence (punctuation and stop words are
                still present here)
    """
    #######################################################
    # A set gives constant-time stop word lookups; the
    # list returned by NLTK would be scanned once per token.
    #######################################################
    stop = set(stopwords.words('english'))

    #######################################################
    # Given a sentence, tokenize it and append it to a list
    #######################################################
    puns = [word_tokenize(sentence.lower()) for sentence in data]

    #######################################################
    # Grab each individual word to make a vocab out of,
    # filtering out punctuation and stop words.
    #######################################################
    tokens = [
        word
        for sentence in puns
        for word in sentence
        if word.isalpha() and word not in stop
    ]

    return tokens, puns

In [58]:
def data_process(file):
    """
    Read a pun dataset (.xml) and hand its sentences to preprocess.

    Args:
        file (str): path to the .xml dataset; it is read as UTF-8
    Returns: the result of preprocess() on the extracted sentences
        (a list of tokens and a list of tokenized sentences)
    """
    # DATA PROCESSING #
    #######################################################
    # Open the dataset and read its text. The context
    # manager closes the file even if reading fails.
    #######################################################
    with open(file, 'r', encoding='utf8') as f:
        data = f.read()

    #######################################################
    # Using Beautiful Soup we can easily extract the puns
    # from the given datasets.
    #######################################################
    soup = BeautifulSoup(data, 'xml')
    texts = soup.find_all('text')

    #######################################################
    # Create a list of all sentences within the dataset to hand
    # over to our preprocess function
    #######################################################
    sentences = []
    for text in texts:
        sentence = ""
        for child in text:
            for piece in child:
                if piece == '\n':
                    continue
                if piece.isalpha() and sentence != "":
                    # A word that does not start the sentence is separated
                    # from what came before by a single space.
                    sentence = sentence + " " + piece
                else:
                    # First word of the sentence, or punctuation: no space
                    # is put between the previous character and it.
                    sentence = sentence + piece
        sentences.append(sentence)

    #######################################################
    # Create a list of tokens to make a vocabulary of and
    # a list of sentences to make word pairs from.
    #######################################################
    return preprocess(sentences)
    

In [59]:
class Vocabulary:
    """
    Two-way mapping between tokens and integer ids, with occurrence counts.

    w2idx maps token -> id, idx2w maps id -> token and w2cnt counts how
    often each token was passed to add_tokens.
    """

    def __init__(self, special_tokens=None):
        self.w2idx = {}
        self.idx2w = {}
        self.w2cnt = defaultdict(int)
        self.special_tokens = special_tokens
        if special_tokens is not None:
            self.add_tokens(special_tokens)

    def add_tokens(self, tokens):
        """Register every token in `tokens` and count each occurrence."""
        for tok in tokens:
            self.add_token(tok)
            self.w2cnt[tok] += 1

    def add_token(self, token):
        """Give `token` the next free id unless it already has one (no counting)."""
        if token in self.w2idx:
            return
        next_id = len(self)
        self.w2idx[token] = next_id
        self.idx2w[next_id] = token

    def prune(self, min_cnt=2):
        """
        Drop tokens counted fewer than `min_cnt` times and renumber the rest.

        Special tokens are never dropped. Ids are reassigned in the order of
        the remaining entries of w2cnt.
        """
        rare = {tok for tok in self.w2idx if self.w2cnt[tok] < min_cnt}
        if self.special_tokens is not None:
            rare = rare - set(self.special_tokens)

        for tok in rare:
            del self.w2cnt[tok]

        self.w2idx = {tok: new_id for new_id, tok in enumerate(self.w2cnt)}
        self.idx2w = {new_id: tok for tok, new_id in self.w2idx.items()}

    def __contains__(self, item):
        return item in self.w2idx

    def __getitem__(self, item):
        """Look up an id by token (str) or a token by id (int)."""
        if isinstance(item, str):
            return self.w2idx[item]
        if isinstance(item, int):
            return self.idx2w[item]
        raise TypeError("Supported indices are int and str")

    def __len__(self):
        return len(self.w2idx)

In [60]:
#######################################################
# Using skipgrams we can create the wordpairs described
# in the N-Hance research paper.
#######################################################

class SkipGramDataset(Dataset):
    """
    Word pairs built from every sentence with a sliding skip-gram window.

    Attributes:
        vocab: container used for `word in vocab` membership tests
        data: the tokenized sentences (a list of lists of words)
        skip_window (int): how many positions to each side of a word are
            paired with it
        pairs: list of per-sentence word-pair lists, see _generate_pairs
    """

    def __init__(self, data, vocab, skip_window=3):
        super().__init__()

        #######################################################
        # Unlike before, data will be a list of strings handed
        # all at once.
        #######################################################
        self.vocab = vocab
        self.data = data
        # Keep the window size that is actually used to build the pairs.
        # (Previously this stored the longest sentence itself, which is not
        # a window size and failed on an empty data set.)
        self.skip_window = skip_window
        self.pairs = self._generate_pairs(data, skip_window)

    #######################################################
    #
    #######################################################
    def _generate_pairs(self, data, skip_window):
        """
        Args: input data (a list of lists of words for each sentence (i.e, each list of words is a sentence))
              skip_window (int): maximum distance between the two words of a pair
        Returns: a list of lists. The first entry is always an empty list;
                 entry k (k >= 1) holds the word pairs of sentence k - 1 of
                 the input. Both words of a pair must be in the vocabulary,
                 and a pair is skipped when its reverse was already added
                 for the same sentence.
        """
        # NOTE(review): the leading empty list shifts every sentence by one.
        # detect_puns builds its ids from these indices, so the offset is
        # kept as is - confirm whether it is meant to give 1-based ids.
        pairs = [[]]

        for sent in data:
            sentence_pairs = []
            seen = set()  # pairs already added, for fast reverse lookups
            for i, word in enumerate(sent):
                if word not in self.vocab:
                    continue
                for j in range(-skip_window, skip_window + 1):
                    context_idx = i + j
                    if j == 0 or context_idx < 0 or context_idx >= len(sent):
                        continue
                    context = sent[context_idx]
                    if context not in self.vocab:
                        continue

                    # only add this pair if the reverse does not already exist in the list
                    if (context, word) not in seen:
                        sentence_pairs.append((word, context))
                        seen.add((word, context))

            pairs.append(sentence_pairs)

        return pairs

    #######################################################
    #
    #######################################################
    def __getitem__(self, idx):
        """
        Args:
            idx (int): position in self.pairs
        Returns:
            the list of word pairs stored at that position
        """
        return self.pairs[idx]

    #######################################################
    #
    #######################################################
    def __len__(self):
        """
        Returns: the number of entries in self.pairs
        """
        return len(self.pairs)

In [63]:
def detect_puns(file, heterographic):
    """
    Decide for every sentence of a dataset whether it contains a pun.

    create word_pairs for sentences in given file
    calculate pmi scores for all given word_pairs
    calculate the interquartile range for the pmi scores of word_pairs in each sentence
    find the median value of the interquartile ranges across all sentences in the given dataset
    for each sentence, if the highest pmi score - second highest pmi score > median interquartile range ...
    (cont.) then that means that that sentence contains a pun

    Args:
        file (str): path to the .xml dataset
        heterographic (bool): True -> ids start with "het", False -> "hom"
    Returns:
        a list of strings "<prefix><i> 1" (pun) or "<prefix><i> 0" (no pun),
        one per entry of the skip-gram dataset's pairs list, where i is the
        index into that list. As SkipGramDataset builds it, entry 0 is an
        empty placeholder, so the first sentence of the file gets index 1.
    """

    # homographic pun 5 would be referred to as hom5 in the final list (this is based on the N-Hance system's guidelines)
    prefix = "het" if heterographic else "hom"

    # Tokenize dataset and Create a Vocabulary using the tokens
    tokens, pun_list = data_process(file)
    voc = Vocabulary()
    voc.add_tokens(tokens)

    # create skipgram model using vocab and puns
    skipgram = SkipGramDataset(pun_list, voc, skip_window=2)

    # create a Counter object to get counts of individual words
    all_words = list(itertools.chain.from_iterable(skipgram.data))
    word_counts = Counter(all_words)
    total_words = len(all_words)

    # get list of lists of word_pairs for each sentence
    word_pairs = [list(sent) for sent in skipgram.pairs]
    print(word_pairs)

    # create Counter object to get counts for each word_pair
    all_word_pairs = list(itertools.chain.from_iterable(word_pairs))
    total_word_pairs = len(all_word_pairs)
    word_pair_count = Counter(all_word_pairs)

    # now we will calculate the PMI score for each word_pair
    # the formula for PMI score is: log[p(x,y) / (p(x)*p(y))]
    # pmi_scores holds one dictionary per sentence { word_pair : pmi_score }
    pmi_scores = []
    for sentence_pairs in word_pairs:
        scores = {}
        for w_pair in sentence_pairs:
            numerator = word_pair_count[w_pair] / total_word_pairs
            denominator = (word_counts[w_pair[0]] / total_words) * (word_counts[w_pair[1]] / total_words)
            scores[w_pair] = math.log(numerator / denominator)
        pmi_scores.append(scores)

    # now we sort each sentence's PMI scores from highest->lowest
    ranked_scores = [sorted(scores.values(), reverse=True) for scores in pmi_scores]

    # now we need to find the interquartile range for each sentence using iqr from scipy.stats
    # Sentences without any word pair are left out: they have no range to
    # measure, and the nan that iqr would give for them would make the
    # median below unreliable.
    iqr_values = [iqr(values) for values in ranked_scores if values]

    # now we take the median of these iqr values and take that as our iqr value of the current dataset
    # (with no word pairs at all, no sentence can pass the check below, so the value is irrelevant)
    median_iqr = median(iqr_values) if iqr_values else 0.0

    # create a list which will contain "<id> 1" (yes, this sentence contains a pun) or
    # "<id> 0" (no, this sentence does not contain a pun) at each index
    contains_pun = []
    for i, values in enumerate(ranked_scores):
        # if the difference between the highest pmi score and second highest pmi score
        # is greater than the median iqr, then the sentence contains a pun
        is_pun = len(values) > 1 and float(values[0] - values[1]) > median_iqr
        contains_pun.append(prefix + str(i) + (" 1" if is_pun else " 0"))

    return contains_pun

In [65]:
# Run pun detection on the heterographic and homographic subtask 1 test sets.
# Each call also prints the word pairs it generated (see the output below).
contains_pun_heterographic = detect_puns('datasets/data/test/subtask1-heterographic-test.xml', True)
contains_pun_homographic = detect_puns('datasets/data/test/subtask1-homographic-test.xml', False)


[[], [('tom', 'alleged')], [('chinese', 'laborer'), ('laborer', 'said'), ('said', 'tom'), ('said', 'coolly'), ('tom', 'coolly')], [('baby', 'oil'), ('come', 'squeezing'), ('squeezing', 'dead'), ('squeezing', 'babies'), ('dead', 'babies')], [('like', 'hard'), ('hard', 'day')], [('evil', 'wildebeests'), ('evil', 'bad'), ('wildebeests', 'bad'), ('wildebeests', 'gnus'), ('bad', 'gnus')], [], [('busy', 'barber'), ('barber', 'quite'), ('quite', 'harried')], [('name', 'avery'), ('raise', 'birds')], [('two', 'construction'), ('two', 'workers'), ('construction', 'workers'), ('stairing', 'contest')], [('horses', 'friesian')], [('heel', 'said'), ('said', 'tom'), ('said', 'archly'), ('tom', 'archly')], [('old', 'electricians'), ('old', 'never'), ('electricians', 'never'), ('electricians', 'die'), ('never', 'die')], [('yesterday', 'accidentally'), ('accidentally', 'swallowed'), ('swallowed', 'food'), ('food', 'coloring'), ('doctor', 'says'), ('feel', 'like'), ('dyed', 'little'), ('little', 'inside'



In [49]:
import torch.nn.functional as F

class SkipGramModel(torch.nn.Module):
    """Skip-gram network: an embedding lookup followed by a linear layer over the vocabulary."""

    def __init__(self, vocab_size, embedding_dim):
        """
        Args:
            vocab_size (int): vocabulary size
            embedding_dim (int): the dimension of word embeddings
        """
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.linear = torch.nn.Linear(embedding_dim, vocab_size)

    def forward(self, inputs):
        """
        Perform the forward pass of the skip-gram model.

        Args:
            inputs (torch.LongTensor): input tensor containing batches of word ids
        Returns:
            outputs (torch.FloatTensor): unnormalized scores over the
                vocabulary; the shape is the input shape with one extra
                trailing dimension of size vocab_size (e.g. [Bx1] -> [Bx1xV])
        """
        embeds = self.embedding(inputs)
        return self.linear(embeds)

    def save_embeddings(self, voc, path):
        """
        Save the embedding matrix to a specified path.

        The first line holds "<vocab_size> <embedding_dim>"; every following
        line holds a token and the values of its embedding, separated by
        single spaces.

        Args:
            voc (Vocabulary): the Vocabulary object for id-to-token mapping
            path (str): the location of the target file
        """
        embeds = self.embedding.weight.data.cpu().numpy()
        # Take the sizes from the matrix itself; they are not stored as
        # attributes and the method must not rely on global variables.
        vocab_size, embedding_dim = embeds.shape

        # The context manager makes sure everything is flushed and the file
        # is closed before success is reported.
        with open(path, 'w', encoding='utf8') as f:
            f.write(str(vocab_size) + ' ' + str(embedding_dim) + '\n')
            for idx, vector in enumerate(embeds):
                word = voc.idx2w[idx]
                embedding = ' '.join(map(str, vector))
                f.write(word + ' ' + embedding + '\n')
        print("Successfuly saved to {}".format(path))