In [1]:
from math import exp  # exp(x) gives e^x

In [2]:
def grouper(seq, n):
    """Return every contiguous window of length `n` taken from `seq`.

    Windows are produced left to right, so the result preserves the order in
    which the n-grams occur. Each window is a slice of `seq`, so it has the
    same type as `seq` (a list for list input, a str for str input). When
    `seq` is shorter than `n`, no windows exist and an empty list is returned.

    Parameters
    ----------
    seq : sequence
        Token ids or words making up a transcription.
    n : int
        Length of each window (the n-gram order).

    Returns
    -------
    ngrams : list
        All length-`n` slices of `seq`, in order of their starting index.
    """
    last_start = len(seq) - n  # index of the final window's first element
    return [seq[start:start + n] for start in range(last_start + 1)]

In [3]:
# Example word-level transcription; reused later by the n_gram_precision demo.
seq = ["hello", "how", "are", "you", "today"]

In [4]:
grouper(seq, 2)  # bigrams: len(seq) - 2 + 1 == 4 windows

[['hello', 'how'], ['how', 'are'], ['are', 'you'], ['you', 'today']]

In [5]:
grouper(seq, 3)  # trigrams: len(seq) - 3 + 1 == 3 windows

[['hello', 'how', 'are'], ['how', 'are', 'you'], ['are', 'you', 'today']]

In [6]:
seq2 = [0, 1, 2, 3, 4, 5]  # token ids work the same way as words

In [7]:
grouper(seq2, 3)

[[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]]

In [17]:
def n_gram_precision(reference, candidate, n):
    """Compute the precision for a given order of n-gram

    The precision is the fraction of the candidate's n-grams that also occur
    somewhere in the reference. Counts are NOT clipped: a candidate n-gram
    that repeats is counted as a match every time it appears, as long as it
    occurs at least once in the reference.

    Parameters
    ----------
    reference : sequence
        The reference transcription. A sequence of token ids or words.
    candidate : sequence
        The candidate transcription. A sequence of token ids or words
        (whichever is used by `reference`)
    n : int
        The order of n-gram precision to calculate

    Returns
    -------
    p_n : float
        The n-gram precision. In the case that the candidate has no n-grams
        of order `n` (it is empty or shorter than `n` tokens), `p_n` is 0.
    """
    candidate_ngrams = grouper(candidate, n)
    if not candidate_ngrams:
        # No n-grams to score (empty or too-short candidate). Return 0 as the
        # docstring promises instead of dividing by zero.
        return 0.0

    # Build the reference n-grams once, as a set, for O(1) membership tests.
    # The n-grams are converted to tuples because grouper yields list slices,
    # and lists are unhashable.
    reference_ngrams = {tuple(ngram) for ngram in grouper(reference, n)}

    matches = sum(tuple(ngram) in reference_ngrams for ngram in candidate_ngrams)
    return matches / len(candidate_ngrams)

In [18]:
# Reference transcription to compare against `seq`; they share "how are you".
ref = ["how", "are", "you", "doing", "today"]

In [22]:
# 4-grams of seq: [hello how are you], [how are you today].
# 4-grams of ref: [how are you doing], [are you doing today]. No overlap -> 0.0.
n_gram_precision(ref, seq, 4)

0.0

In [23]:
def brevity_penalty(reference, candidate):
    """Calculate the brevity penalty between a reference and candidate

    With reference length ``r`` and candidate length ``c``, the penalty is 1
    when ``c > r`` (no penalty for a longer candidate) and ``exp(1 - r / c)``
    otherwise, which shrinks toward 0 as the candidate gets shorter.

    Parameters
    ----------
    reference : sequence
        The reference transcription. A sequence of token ids or words.
    candidate : sequence
        The candidate transcription. A sequence of token ids or words
        (whichever is used by `reference`)

    Returns
    -------
    BP : float
        The brevity penalty. In the case that the candidate transcription is
        of 0 length, `BP` is 0. An empty reference paired with a non-empty
        candidate incurs no penalty (`BP` is 1).
    """
    c = len(candidate)
    r = len(reference)
    if c == 0:
        # Nothing was transcribed: BP is 0 by convention (and r / c is undefined).
        return 0
    if c > r:
        # Candidate is longer than the reference, so there is no penalty.
        # Compare lengths directly rather than testing r / c == 0. Otherwise
        # an empty reference (r == 0) would be mistaken for an empty candidate.
        return 1
    return exp(1 - r / c)

In [33]:
# Reference sentences; they are whitespace-tokenized into ref1-ref3 below.
s1 = "It is a guide to action that ensures that the military will forever heed Party commands"
s2 = "It is the guiding principle which guarantees the military forces always being under command of the Party"
s3 = "It is the practical guide for the army always to heed the directions of the party"

In [29]:
def seq_gen(sent):
    """Tokenize a sentence into words by splitting on runs of whitespace.

    Parameters
    ----------
    sent : str
        The sentence to tokenize.

    Returns
    -------
    tokens : list of str
        The whitespace-separated words of `sent`, in order. Leading and
        trailing whitespace produce no empty tokens.
    """
    tokens = sent.split()
    return tokens

In [35]:
# Whitespace-tokenize the reference sentences.
ref1 = seq_gen(s1)
ref2 = seq_gen(s2)
ref3 = seq_gen(s3)

In [46]:
# Candidate sentences to be scored against the references.
q1 = "It is a guide to action which ensures that the military always obeys the commands of the party"
q2 = "It is to insure the troops forever hearing the activity guidebook that party direct"

In [47]:
cand1 = seq_gen(q1)
cand2 = seq_gen(q2)

In [48]:
# len(ref2) == 17 and len(cand1) == 18: the candidate is longer, so BP is 1.
brevity_penalty(ref2, cand1)

1

In [50]:
# len(ref3) == 16 and len(cand2) == 14: BP = exp(1 - 16 / 14) ~= 0.8669.
brevity_penalty(ref3, cand2)

0.8668778997501817