In [1]:
import torch
from torch.nn import NLLLoss

def nansum(x):
  """Return the sum of all elements of tensor `x`, ignoring NaN entries.

  Thin wrapper around the built-in `torch.nansum`, which treats NaNs as
  zero. This avoids building a boolean-masked copy of `x` just to drop the
  NaN entries before summing.

  Args:
    x: A torch tensor.

  Returns:
    A 0-dimensional tensor holding the sum of the non-NaN elements
    (0 if every element is NaN).
  """
  return torch.nansum(x)

def self_information(p):
  """Return the self-information -log2(p), as a Python float.

  Args:
    p: A single number; it is wrapped in a tensor before taking the
      base-2 logarithm.

  Returns:
    The negated base-2 logarithm of `p`, extracted with `.item()`.
  """
  return -torch.log2(torch.tensor(p)).item()

# Demo call at cell level. It used to sit inside the function body after the
# `return`, where it could never execute.
self_information(1 / 64)

In [2]:
def entropy(p):
  """Sum of -p * log2(p) over all elements of tensor `p`, skipping NaN terms.

  Args:
    p: A tensor of values to take the base-2 entropy terms of.

  Returns:
    Whatever `nansum` returns for the elementwise terms.
  """
  # Elementwise terms; an element equal to 0 gives 0 * log2(0) = NaN, which
  # `nansum` leaves out of the total.
  per_element = torch.neg(p) * torch.log2(p)
  return nansum(per_element)

entropy(torch.tensor([0.1, 0.5, 0.1, 0.3]))

tensor(1.6855)

In [3]:
def joint_entropy(p_xy):
  """Sum of -p_xy * log2(p_xy) over every entry of `p_xy`, skipping NaN terms.

  Args:
    p_xy: A tensor of joint values (a 2-D table in the example call below).

  Returns:
    Whatever `nansum` returns for the elementwise terms.
  """
  # One term per entry; NaN terms are dropped by `nansum`.
  per_entry = torch.neg(p_xy) * torch.log2(p_xy)
  return nansum(per_entry)

joint_entropy(torch.tensor([[0.1, 0.5], [0.1, 0.3]]))

tensor(1.6855)

In [4]:
def conditional_entropy(p_xy, p_x):
  """Sum of -p_xy * log2(p_xy / p_x) over all entries, skipping NaN terms.

  Args:
    p_xy: A tensor of joint values.
    p_x: A tensor divided into `p_xy` using ordinary broadcasting.

  Returns:
    Whatever `nansum` returns for the elementwise terms.
  """
  # The division relies on broadcasting `p_x` against `p_xy`.
  ratio = p_xy / p_x
  weighted_log = torch.neg(p_xy) * torch.log2(ratio)
  # NaN terms are dropped by `nansum`.
  return nansum(weighted_log)
conditional_entropy(torch.tensor([[0.1, 0.5], [0.2, 0.3]]),
                    torch.tensor([0.2, 0.8]))


tensor(0.8635)

In [5]:
def mutual_information(p_xy, p_x, p_y):
  """Sum of p_xy * log2(p_xy / (p_x * p_y)) over all entries, skipping NaN terms.

  Args:
    p_xy: A tensor of joint values.
    p_x: A tensor multiplied with `p_y` using ordinary broadcasting.
    p_y: A tensor multiplied with `p_x` using ordinary broadcasting.

  Returns:
    Whatever `nansum` returns for the elementwise terms.
  """
  # NOTE(review): in the example call below p_x has shape (2,) and p_y has
  # shape (1, 2), so this product is elementwise with shape (1, 2), not an
  # outer product -- confirm this pairing of the marginals is intended.
  marginal_product = p_x * p_y
  weighted_log = p_xy * torch.log2(p_xy / marginal_product)
  # NaN terms are dropped by `nansum`.
  return nansum(weighted_log)
mutual_information(torch.tensor([[0.1, 0.5], [0.1, 0.3]]),
torch.tensor([0.2, 0.8]), torch.tensor([[0.75, 0.25]]))


tensor(0.7195)

In [6]:
def kl_divergence(p, q):
  """Absolute value of the NaN-skipping sum of p * log2(p / q), as a float.

  Args:
    p: A tensor.
    q: A tensor combined elementwise with `p`.

  Returns:
    A Python number: the magnitude of the summed terms.
  """
  log_ratio = torch.log2(p / q)
  # NaN terms are dropped by `nansum`.
  total = nansum(p * log_ratio)
  # The sign of the sum is discarded before converting to a Python number.
  return total.abs().item()

In [7]:
# Fix the global RNG seed so the draws below are repeatable.
torch.manual_seed(1)

tensor_len = 10000

# Draw order (p, then q1, then q2) is kept so the seeded stream is consumed
# in the same sequence. The three samples differ only in their mean.
p = torch.normal(mean=0.0, std=1.0, size=(tensor_len,))
q1 = torch.normal(mean=-1.0, std=1.0, size=(tensor_len,))
q2 = torch.normal(mean=1.0, std=1.0, size=(tensor_len,))

# Replace each sample with its ascending-sorted values.
p, q1, q2 = (torch.sort(sample).values for sample in (p, q1, q2))

In [8]:
# Score the sorted `p` sample against each of the two shifted samples.
kl_pq1 = kl_divergence(p, q1)
kl_pq2 = kl_divergence(p, q2)

# Gap between the two scores, expressed as a percentage of their mean.
kl_gap = abs(kl_pq1 - kl_pq2)
kl_mean = (kl_pq1 + kl_pq2) / 2
similar_percentage = kl_gap / kl_mean * 100

kl_pq1, kl_pq2, similar_percentage

(8582.0341796875, 8828.3095703125, 2.8290698237936858)

In [13]:
# Scratch check: base-2 logarithm of 1/44.
torch.tensor(1/44).log2()

tensor(-5.4594)

In [15]:
# Scratch check: e raised to the power 53.8.
torch.tensor(53.8).exp()

tensor(2.3176e+23)