Markdown cell


In [9]:
#import statements

import numpy as np
import matplotlib as plt


In [10]:
def entropy(pmf: np.array) -> float:
    """This function returns the entropy of the given pmf 

    The formula used is : -sum p(x) log p(x)

    Example
    --------
    >>> p = np.array([0.5, 0.5])
    >>> entropy(p)
    1.0
    """
    return -np.sum(pmf*np.log2(pmf))

def KL_divergence(p: np.array, q: np.array) -> float:
    """This function returns the KL divergence between the two given pmfs
    
    The formula used is : sum p(x) log p(x)/q(x) 
    
    Example
    -------
    >>> p = np.array([0.5, 0.5])
    >>> q = np.array([0.6, 0.4])
    >>> KL_divergence(p, q)
    0.029446844526784283
    """
    idx = np.argwhere( q == 0.0 )
    q[idx]+=1e-5
    idx = np.argwhere( p == 0.0 )
    p[idx]+=1e-5
    return np.sum(p * np.log2(p/q)) 

def cross_entropy(p: np.array, q:np.array) -> float:
    """This function returns the cross entropy betwen the two given pmfs
    
    The formula used is : H(p) + D(p||q)
    
    Example
    -------
    >>> p = np.array([0.5, 0.5])
    >>> q = np.array([0.6, 0.4])
    >>> cross_entropy(p, q)
    1.0294468445267844
    """
    return entropy(p)+KL_divergence(p, q)

def JS_divergence(p: np.array, q:np.array) -> float:
    """This function return the Jenson Shannon divergence between
    the given two pmfs
    
    The formula used is : D(p||m)+D(q||m), where
    m = (p+q)/2
    
    Example
    -------
    >>> p = np.array([0.5, 0.5])
    >>> q = np.array([0.6, 0.4])
    >>> JS_divergence(p, q)
    0.014598313520947925
    """
    m = (p+q)/2
    return KL_divergence(p, m) + KL_divergence(q, m)



p = np.array([0.5, 0.5])
q = np.array([0.6, 0.4])
print(entropy(p))
print(KL_divergence(p, q))
print(cross_entropy(p, q))


1.0
0.029446844526784283
1.0294468445267844


In [11]:

JS_divergence(p,q)


0.014598313520947925