In [15]:
from collections import defaultdict
import numpy as np

import nltk
from nltk.tokenize import word_tokenize

In [3]:
ngram_size = 4
test_sentence = "This is a test sentence."
# nltk.ngrams returns a lazy generator; materialise it so the n-grams are
# actually computed and displayed as this cell's output.
ngrams = list(nltk.ngrams(test_sentence.split(), ngram_size))
ngrams

In [26]:
def tokenize(sentence, method="nltk", cased=True):
    """Split ``sentence`` into a list of tokens.

    Args:
        sentence: the input string.
        method: ``"nltk"`` uses ``nltk.word_tokenize`` (word-level tokenizer
            that also separates punctuation); ``"simple"`` splits on
            whitespace only.
        cased: if False, the sentence is lower-cased before tokenizing.

    Returns:
        list of token strings.

    Raises:
        ValueError: if ``method`` is not ``"simple"`` or ``"nltk"``.
    """
    if not cased:
        sentence = sentence.lower()

    if method == "simple":
        return sentence.split()
    if method == "nltk":
        return nltk.word_tokenize(sentence)
    # Without this, an unknown method fell through and failed with a
    # confusing UnboundLocalError on the (never assigned) result variable.
    raise ValueError(f"Unknown tokenization method {method!r}; expected 'simple' or 'nltk'")


def compute_sentence_ngram_overlap(sent1, sent2, n):
    """Return n-gram precision and recall of ``sent1`` with respect to ``sent2``.

    ``sent1`` plays the candidate role and ``sent2`` the reference role:

    * precision = fraction of sent1's n-grams that also occur in sent2
    * recall    = fraction of sent2's n-grams that also occur in sent1

    Each n-gram occurrence on the side being scored is checked for membership
    in the other side, so repeated n-grams are counted per occurrence (no
    clipping).

    Args:
        sent1, sent2: token sequences (e.g. lists of strings from ``tokenize``).
        n: n-gram order; must be >= 1.

    Returns:
        dict with float entries ``'precision'`` and ``'recall'``. A side with no
        n-grams (fewer than ``n`` tokens) scores 0.0 rather than raising
        ZeroDivisionError.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    def _ngrams(tokens):
        # Contiguous, unpadded n-grams: the same tuples nltk.ngrams yields by default.
        tokens = list(tokens)
        return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

    def _fraction_found(grams, other):
        # Fraction of `grams` present in `other`; 0.0 when there is nothing to score.
        if not grams:
            return 0.0
        lookup = set(other)  # O(1) membership instead of scanning a list
        return sum(gram in lookup for gram in grams) / len(grams)

    ngrams1 = _ngrams(sent1)
    ngrams2 = _ngrams(sent2)
    return {
        'precision': _fraction_found(ngrams1, ngrams2),
        'recall': _fraction_found(ngrams2, ngrams1),
    }
    
    
def compute_corpus_ngram_overlap(corp1, corp2, n=1, cased=True, method="nltk"):
    """Compute per-sentence n-gram precision/recall over two aligned corpora.

    Sentence ``i`` of ``corp1`` is compared with sentence ``i`` of ``corp2``
    using ``compute_sentence_ngram_overlap`` (``corp1`` is the candidate side).

    Args:
        corp1, corp2: equal-length sequences of raw sentence strings.
        n: n-gram order.
        cased: if False, sentences are lower-cased before tokenizing.
        method: tokenizer name forwarded to ``tokenize`` (``"nltk"`` or
            ``"simple"``); defaults to the previous behaviour.

    Returns:
        dict mapping each metric name (``'precision'``, ``'recall'``) to a
        numpy array holding one score per sentence pair. Empty corpora
        give ``{}``.

    Raises:
        AssertionError: if the corpora have different numbers of sentences.
    """
    assert len(corp1) == len(corp2), "Different number of sentences in corpora"
    # Keep the raw corpora untouched; tokenized versions get their own names.
    tokens1 = [tokenize(s, method=method, cased=cased) for s in corp1]
    tokens2 = [tokenize(s, method=method, cased=cased) for s in corp2]

    overlaps = defaultdict(list)
    for sent1, sent2 in zip(tokens1, tokens2):
        for metric, score in compute_sentence_ngram_overlap(sent1, sent2, n).items():
            overlaps[metric].append(score)

    return {metric: np.array(scores) for metric, scores in overlaps.items()}

In [31]:
corp1 = ["this is a sentence.", "lol we are hacking"]
corp2 = ["This is a sentence.", "yoooo are we hacking?"]
n = 1
cased = True

# Score each aligned sentence pair, then summarise every metric across the corpus.
results = compute_corpus_ngram_overlap(corp1, corp2, n=n, cased=cased)
for metric_name, scores in results.items():
    score_mean, score_std = np.mean(scores), np.std(scores)
    score_max, score_min = np.max(scores), np.min(scores)
    print(f'{metric_name}: mean {score_mean:.3f}, std {score_std:.3f}')
    print(f'\tmax {score_max:.3f}, min {score_min:.3f}')

precision: mean 0.775, std 0.025
	max 0.800, min 0.750
recall: mean 0.700, std 0.100
	max 0.800, min 0.600
