# TEST DIFFERENT APPROXIMATIONS FOR KL(logit)

In [3]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import numpy
from genus.util_ml import Grid_DPP
from genus.util_vis import show_batch
import torch.nn.functional as F

def bernoulli_entropy(logit):
    """Return the element-wise entropy (in nats) of Bernoulli(sigmoid(logit)).

    Both outcome probabilities and their logs are computed directly from the
    logit (``sigmoid(+/-logit)`` and ``logsigmoid(+/-logit)``) rather than via
    ``1 - p`` and ``log(p)``.

    Args:
        logit: tensor of logits, any shape.

    Returns:
        Tensor with the same shape as ``logit`` holding
        ``-(p * log p + (1 - p) * log(1 - p))`` for each element.
    """
    neg_logit = -logit
    # Contribution of the "on" outcome: p * log(p).
    on_term = torch.sigmoid(logit) * F.logsigmoid(logit)
    # Contribution of the "off" outcome: (1 - p) * log(1 - p).
    off_term = torch.sigmoid(neg_logit) * F.logsigmoid(neg_logit)
    return -(on_term + off_term)

def logp_bernoulli(c, logit):
    """Return the element-wise log-probability of ``c`` under Bernoulli(sigmoid(logit)).

    Args:
        c: binary configuration tensor. Boolean masks are the intended input;
            integer or float 0/1 masks are also accepted (any non-zero entry
            counts as "on"). No gradient flows through ``c``.
        logit: tensor of logits, broadcastable against ``c``.

    Returns:
        Tensor of the broadcast shape holding ``log p`` where ``c`` is on and
        ``log(1 - p)`` where it is off. Callers sum over the grid dimensions
        to obtain the log-probability of a whole configuration.
    """
    # Normalise the mask to bool first: applying `~` directly to an integer
    # tensor is a bitwise NOT (1 -> -2, 0 -> -1) and silently yields wrong
    # weights, while on a float tensor it raises.
    is_on = c.detach().bool()
    log_p = F.logsigmoid(logit)
    log_one_m_p = F.logsigmoid(-logit)
    # Selecting (instead of mask * value) gives the same values and gradients
    # for finite inputs and avoids 0 * -inf = nan when a logit saturates.
    log_prob_bernoulli = torch.where(is_on, log_p, log_one_m_p)
    return log_prob_bernoulli

def all_configurations(grid_size_x, grid_size_y):
    """Enumerate every binary configuration of a grid_size_x x grid_size_y grid.

    Configuration ``n`` is the binary expansion of ``n`` (most-significant bit
    first), zero-padded to ``grid_size_x * grid_size_y`` digits and laid out
    row-major on the grid.

    Args:
        grid_size_x: size of the first grid dimension.
        grid_size_y: size of the second grid dimension.

    Returns:
        Boolean tensor of shape
        ``(2 ** (grid_size_x * grid_size_y), grid_size_x, grid_size_y)``.
    """
    n_sites = grid_size_x * grid_size_y
    codes = numpy.arange(2 ** n_sites)
    # Shift amounts that bring each bit, from most to least significant,
    # down to position 0.
    shifts = numpy.arange(n_sites - 1, -1, -1)
    bits = (codes[:, None] >> shifts) & 1
    return torch.from_numpy(bits).view(-1, grid_size_x, grid_size_y).bool()

# TRAINING

In [24]:
# Grid geometry and the exhaustive list of binary configurations on it.
grid_size_x, grid_size_y = 3, 3
c_all = all_configurations(grid_size_x, grid_size_y).detach()

# One identically-initialised Grid_DPP per experiment, so that every
# approximation starts from the same DPP hyper-parameters.
DPP_0, DPP_1, DPP_2, DPP_3, DPP_4 = (
    Grid_DPP(length_scale=10, weight=0.01) for _ in range(5)
)

# logit_0 is the frozen random starting point; logit_1..logit_4 are trainable
# copies of it, one per experiment.
logit_0 = torch.rand((grid_size_x, grid_size_y)).requires_grad_(False)
logit_1, logit_2, logit_3, logit_4 = (
    logit_0.clone().requires_grad_(True) for _ in range(4)
)

# Each experiment optimises its own logit together with the parameters of its
# own DPP, using plain SGD with the same learning rate.
params1 = [logit_1, *(p for _, p in DPP_1.named_parameters())]
optimizer1 = torch.optim.SGD([{'params': params1, 'lr': 1E-2}])

params2 = [logit_2, *(p for _, p in DPP_2.named_parameters())]
optimizer2 = torch.optim.SGD([{'params': params2, 'lr': 1E-2}])

params3 = [logit_3, *(p for _, p in DPP_3.named_parameters())]
optimizer3 = torch.optim.SGD([{'params': params3, 'lr': 1E-2}])

params4 = [logit_4, *(p for _, p in DPP_4.named_parameters())]
optimizer4 = torch.optim.SGD([{'params': params4, 'lr': 1E-2}])

In [30]:
# Compare three estimators of the same objective on identical starting points.
# Each loss has the form
#     -H[q] - log p_DPP(c_after_nms) - E_q[log p_DPP(c)] - mse * sigmoid(logit[0, 0])
# where q is the factorised Bernoulli with the given logits; only the way
# E_q[log p_DPP(c)] is estimated differs between the three sections.
mc_samples = 4  # number of Monte Carlo configurations drawn per step
mse = 10.0      # weight of the term that rewards a high probability at cell [0, 0]

for epoch in range(10000):

    # ------------------------------------------------------------------
    # 1) Exact: sum over all 2**(grid_size_x * grid_size_y) configurations.
    # ------------------------------------------------------------------
    # torch.sigmoid replaces F.sigmoid throughout this loop: F.sigmoid is
    # flagged as deprecated by older PyTorch releases, and torch.sigmoid is
    # what bernoulli_entropy already uses.
    prob_1 = torch.sigmoid(logit_1)
    c_before_nms = (torch.rand_like(logit_1) < prob_1)
    # Active cells score 1 + p, inactive cells score p; keeping only the cell
    # that attains the maximum leaves at most the highest-probability active cell.
    score = c_before_nms + prob_1
    c_after_nms = (score == torch.max(score)) * c_before_nms

    entropy_1 = bernoulli_entropy(logit_1).sum()
    # Detached: this term only trains logit_1 (through q_c_all), not DPP_1.
    logp_c_all = DPP_1.log_prob(c_all).detach()
    # q(c) for every configuration: product over cells, computed in log space.
    q_c_all = logp_bernoulli(c_all, logit_1).sum(dim=(-1,-2)).exp()
    # Not detached: this is the term through which DPP_1's parameters get gradients.
    logp_dpp_after = DPP_1.log_prob(c_after_nms).mean()

    loss1 = -entropy_1 - logp_dpp_after -(q_c_all * logp_c_all).sum() - mse * torch.sigmoid(logit_1[0,0])
    optimizer1.zero_grad()
    loss1.backward()
    optimizer1.step()

    # ------------------------------------------------------------------
    # 2) REINFORCE (score-function) estimate from mc_samples draws of q.
    # ------------------------------------------------------------------
    prob_2 = torch.sigmoid(logit_2)
    c_before_nms = (torch.rand_like(logit_2.expand(mc_samples,-1,-1)) < prob_2)
    score = c_before_nms + prob_2
    # NOTE(review): torch.max(score) is the maximum over the whole
    # (mc_samples, x, y) batch, so samples that do not contain the global
    # best cell end up empty after this step. A per-sample maximum may have
    # been intended -- confirm before relying on logp_dpp_after here.
    c_after_nms = (score == torch.max(score)) * c_before_nms

    entropy_2 = bernoulli_entropy(logit_2).sum()
    logp_dpp_before = DPP_2.log_prob(c_before_nms)
    logp_dpp_after = DPP_2.log_prob(c_after_nms)
    logp_ber_before = logp_bernoulli(c_before_nms, logit_2).sum(dim=(-1,-2))
    # Diagnostic only (printed below): mean absolute deviation of the sampled
    # DPP log-probabilities from their batch mean.
    d_tmp = (logp_dpp_before - logp_dpp_before.mean()).abs().mean().detach()
    # Surrogate whose gradient w.r.t. logit_2 is the score-function estimate;
    # the batch mean of the reward is subtracted as a baseline and the whole
    # reward is detached so only log q carries gradient.
    reinforce_2 = logp_ber_before * (logp_dpp_before - logp_dpp_before.mean()).detach()

    loss2 = - entropy_2 - logp_dpp_after.mean() - reinforce_2.mean() - mse * torch.sigmoid(logit_2[0,0])
    optimizer2.zero_grad()
    loss2.backward()
    optimizer2.step()


    # ------------------------------------------------------------------
    # 3) REINFORCE with importance sampling from a clamped proposal.
    # ------------------------------------------------------------------
    prob_3 = torch.sigmoid(logit_3)
    c_before_nms = (torch.rand_like(logit_3) < prob_3)
    score = c_before_nms + prob_3
    c_after_nms = (score == torch.max(score)) * c_before_nms

    # Proposal distribution: same logits clamped to [-3.5, 3.5] and detached,
    # so the proposal probabilities never leave roughly [0.03, 0.97].
    logit_importance = logit_3.clamp(min=-3.5, max=3.5).detach()
    prob_importance = torch.sigmoid(logit_importance)
    c_importance = (torch.rand_like(logit_importance.expand(mc_samples,-1,-1)) < prob_importance)

    entropy_3 = bernoulli_entropy(logit_3).sum()
    logp_dpp_after = DPP_3.log_prob(c_after_nms)
    logp_importance = DPP_3.log_prob(c_importance)
    # log q*(c): proposal samples scored under the true (unclamped) logits.
    logq_importance_star = logp_bernoulli(c_importance, logit_3).sum(dim=(-1,-2))
    # log q_imp(c): the same samples scored under the proposal they came from.
    logq_importance = logp_bernoulli(c_importance, logit_importance.detach()).sum(dim=(-1,-2))
    # Importance weights q*(c) / q_imp(c); exactly 1 while no logit is clamped.
    importance_weights = (logq_importance_star - logq_importance).exp().detach()
    # NOTE(review): the baseline is the mean of the weighted reward and is not
    # itself multiplied by the importance weight; this looks unbiased only
    # while the clamp is inactive -- confirm whether that is intended.
    reinforce_3 = logq_importance_star * (importance_weights*logp_importance - (importance_weights*logp_importance).mean()).detach()

    loss3 = - entropy_3 - logp_dpp_after.mean() - reinforce_3.mean() - mse * torch.sigmoid(logit_3[0,0])
    optimizer3.zero_grad()
    loss3.backward()
    optimizer3.step()


    # Progress report: the three entropies, the REINFORCE reward spread and
    # the mean importance weight (the last two are printed as tensors).
    if epoch % 500 == 0:
        print(epoch,entropy_1.item(),entropy_2.item(),entropy_3.item(),d_tmp,importance_weights.mean())
    

0 0.9330554008483887 1.4890918731689453 1.5636526346206665 tensor(0.) tensor(1.0249)
500 0.9256253838539124 1.5320041179656982 1.5609056949615479 tensor(0.) tensor(0.9681)
1000 0.9195586442947388 1.528131127357483 1.6069483757019043 tensor(0.) tensor(1.0247)
1500 0.9141694903373718 1.5548080205917358 1.6184022426605225 tensor(1.5625) tensor(1.0255)
2000 0.9091646671295166 1.4935671091079712 1.6164069175720215 tensor(2.6004) tensor(1.0244)
2500 0.9035624265670776 1.4640452861785889 1.5342923402786255 tensor(1.7359) tensor(1.0243)
3000 0.8974437713623047 1.4200860261917114 1.5832723379135132 tensor(1.0515) tensor(1.0001)
3500 0.891935408115387 1.5104097127914429 1.628333330154419 tensor(0.) tensor(1.0217)
4000 0.8872072696685791 1.5085636377334595 1.5596504211425781 tensor(1.2257) tensor(1.0232)
4500 0.8825173377990723 1.5245722532272339 1.6188315153121948 tensor(1.3097) tensor(1.0260)
5000 0.8782775402069092 1.5575309991836548 1.5757675170898438 tensor(2.6220) tensor(1.0263)
5500 0.8745

In [26]:
# Report the starting probabilities and the ones reached by each estimator.
# "nav" is the sum of the per-cell probabilities (presumably the expected
# number of active cells -- confirm the intended meaning of the label).
# torch.sigmoid replaces F.sigmoid, which older PyTorch releases flag as
# deprecated; each sigmoid is now computed once instead of twice.
for _title, _logit in (
    ("INITIAL PROB", logit_0),
    ("FINAL PROB_1", logit_1),
    ("FINAL PROB_2", logit_2),
    ("FINAL PROB_3", logit_3),
):
    _prob = torch.sigmoid(_logit)
    print(_title)
    print("nav ->", _prob.sum().item())
    print(_prob)

# print("FINAL PROB_4")
# print("nav ->",F.sigmoid(logit_4).sum().item())
# print(F.sigmoid(logit_4))
# 
# print("FINAL DPP_1")
# print(DPP_1.fingerprint[:2])
# final_dpp_sample_1 = DPP_1.sample(size=logit_1.expand(10000,-1,-1).size())
# nav_final_1 = final_dpp_sample_1.sum(dim=(-1,-2)).float().mean()
# print("nav ->",nav_final_1)
# print(final_dpp_sample_1.float().mean(dim=0))
# 
# print("FINAL DPP_2")
# print(DPP_2.fingerprint[:2])
# final_dpp_sample_2 = DPP_2.sample(size=logit_2.expand(10000,-1,-1).size())
# nav_final_2 = final_dpp_sample_2.sum(dim=(-1,-2)).float().mean()
# print("nav ->",nav_final_2)
# print(final_dpp_sample_2.float().mean(dim=0))
# 
# print("FINAL DPP_3")
# print(DPP_3.fingerprint[:2])
# final_dpp_sample_3 = DPP_3.sample(size=logit_3.expand(10000,-1,-1).size())
# nav_final_3 = final_dpp_sample_3.sum(dim=(-1,-2)).float().mean()
# print("nav ->",nav_final_3)
# print(final_dpp_sample_3.float().mean(dim=0))
# 
# print("FINAL DPP_4")
# print(DPP_4.fingerprint[:2])
# final_dpp_sample_4 = DPP_4.sample(size=logit_4.expand(10000,-1,-1).size())
# nav_final_4 = final_dpp_sample_4.sum(dim=(-1,-2)).float().mean()
# print("nav ->",nav_final_4)
# print(final_dpp_sample_4.float().mean(dim=0))

INITIAL PROB
nav -> 5.736042499542236
tensor([[0.5315, 0.7233, 0.5104],
        [0.6619, 0.6944, 0.6782],
        [0.7123, 0.5970, 0.6271]])
FINAL PROB_1
nav -> 1.1985440254211426
tensor([[0.9975, 0.0104, 0.0282],
        [0.0104, 0.0141, 0.0314],
        [0.0283, 0.0313, 0.0470]], grad_fn=<SigmoidBackward>)
FINAL PROB_2
nav -> 1.3730865716934204
tensor([[0.9974, 0.0232, 0.0518],
        [0.0217, 0.0185, 0.0548],
        [0.0589, 0.0648, 0.0821]], grad_fn=<SigmoidBackward>)
FINAL PROB_3
nav -> 1.3659470081329346
tensor([[0.9885, 0.0237, 0.0553],
        [0.0286, 0.0348, 0.0582],
        [0.0531, 0.0597, 0.0639]], grad_fn=<SigmoidBackward>)


In [27]:
# Show the trained logits of the three estimators, in experiment order.
for _trained_logit in (logit_1, logit_2, logit_3):
    print(_trained_logit)

tensor([[ 5.9691, -4.5568, -3.5381],
        [-4.5598, -4.2485, -3.4295],
        [-3.5352, -3.4310, -3.0103]], requires_grad=True)
tensor([[ 5.9356, -3.7393, -2.9065],
        [-3.8066, -3.9717, -2.8483],
        [-2.7721, -2.6701, -2.4147]], requires_grad=True)
tensor([[ 4.4498, -3.7165, -2.8374],
        [-3.5242, -3.3233, -2.7838],
        [-2.8806, -2.7561, -2.6839]], requires_grad=True)


In [29]:
# Show the gradient left on each logit by the last backward pass.
for _trained_logit in (logit_1, logit_2, logit_3):
    print(_trained_logit.grad)

tensor([[-0.0087,  0.0048,  0.0019],
        [ 0.0047,  0.0037,  0.0017],
        [ 0.0020,  0.0017,  0.0013]])
tensor([[-0.0107, -0.0848, -0.1427],
        [-0.0809, -0.0721,  0.2563],
        [-0.1534,  0.2421, -0.1817]])
tensor([[-0.0634, -0.0861, -0.1482],
        [-0.0980, -0.1115, -0.1525],
        [-0.1448, -0.1547,  0.3806]])
