In [1]:
import torch
from torch import tensor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from math import log

In [3]:
# manual methods
listmean = lambda A,B : [(a+b)/2 for a,b in zip(A,B)]

kl = lambda P,Q : sum(p * (log(p/q) if p!=0 and q!=0 else 0) for p,q in zip(P,Q))

In [4]:
#torch methods

from torch.nn import functional as F
from torch.nn.functional import kl_div

kl_torch = lambda p,q : kl_div(tensor([q]).log(), tensor([p]), reduction='batchmean')

In [5]:
q1 = [.05, .2, .4, .35]
q2 = [.3, .3, .3, .1]
y = [0,0,1,0]

In [6]:
def test_kl(p,q):
    print(f'via torch: {kl_torch(p,q)}')
    print(f'via my code: {kl(p,q)}')

In [7]:
test_kl(q1,y)

via torch: inf
via my code: -0.366516292749662


In [8]:
test_kl(y,q1)

via torch: 0.9162907004356384
via my code: 0.9162907318741551


In [9]:
# JS

In [10]:
def js(p1,p2):
    m = listmean(p1,p2)
    return .5*(kl(p1,m) + kl(p2,m))

In [11]:
def js_div(p_1, p_2, reduction = 'batchmean', log_target = False, pi = .5):
    m = pi * p_1 + (1-pi) * p_2; #print(f'    m of {p_1} and {p_2} is {m}')
    pi_kl_input_m = pi * F.kl_div(m.log(), p_1, reduction=reduction, log_target=log_target); #print(f'    kl({p_1}||{m}) times {pi} is {pi_kl_input_m}')
    onepi_kl_target_m = (1-pi) * F.kl_div(m.log(), p_2, reduction=reduction, log_target=log_target); #print(f'    kl({p_2}||{m}) times {1-pi} is {onepi_kl_target_m}')
    return pi_kl_input_m + onepi_kl_target_m

In [12]:
js_torch = lambda p,q : js_div(tensor([p]), tensor([q]))

In [13]:
def test_js(p,q):
    print(f'via torch: {js_torch(p,q)}')
    print(f'via my code: {js(p,q)}')

In [14]:
test_js(q1,q2)

via torch: 0.09492215514183044
via my code: 0.09492217667402233


In [15]:
test_js(q2,q1)

via torch: 0.09492215514183044
via my code: 0.09492217667402233


In [16]:
# GJS

In [17]:
gjs = lambda p1,p2,y : js(y, [(p+q)/2 for p,q in zip(p1,p2)]) + js(p1,p2)/2

In [25]:
from ceUtils import GJSDivLoss

gjs_torch = lambda p1,p2,y : GJSDivLoss()(tensor([p1]), tensor([p2]), tensor([y]))

In [26]:
def test_gjs(q1,q2,y):
    print(f'via torch: {gjs_torch(q1,q2,y)}')
    print(f'via my code: {gjs(q1,q2,y)}')

In [27]:
test_gjs(q1,q2,y)

(p1 + p2)/2 = tensor([[0.1750, 0.2500, 0.3500, 0.2250]])
1-pi = 0.5
    m of tensor([[0, 0, 1, 0]]) and tensor([[0.1750, 0.2500, 0.3500, 0.2250]]) is tensor([[0.0875, 0.1250, 0.6750, 0.1125]])
    kl(tensor([[0, 0, 1, 0]])||tensor([[0.0875, 0.1250, 0.6750, 0.1125]])) times 0.5 is 0.19652128219604492
    kl(tensor([[0.1750, 0.2500, 0.3500, 0.2250]])||tensor([[0.0875, 0.1250, 0.6750, 0.1125]])) times 0.5 is 0.11033640801906586
js_pi(y, (p1+p2)/2) = 0.306857705116272
    m of tensor([[0.0500, 0.2000, 0.4000, 0.3500]]) and tensor([[0.3000, 0.3000, 0.3000, 0.1000]]) is tensor([[0.1750, 0.2500, 0.3500, 0.2250]])
    kl(tensor([[0.0500, 0.2000, 0.4000, 0.3500]])||tensor([[0.1750, 0.2500, 0.3500, 0.2250]])) times 0.5 is 0.05039357393980026
    kl(tensor([[0.3000, 0.3000, 0.3000, 0.1000]])||tensor([[0.1750, 0.2500, 0.3500, 0.2250]])) times 0.5 is 0.04452858120203018
js_.5(p1, p2) * (1-pi) = 0.04746107757091522
via torch: 0.3543187975883484
via my code: 0.35431879720570963


In [24]:
kl_torch([0,0,1,0],[.0875,.125,.675,.1125])

tensor(0.0491)