# Testing DAG generation

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm.notebook import tqdm
import os

from estimators import to_z, to_b, reinforce, relax, ExactGradientEstimator
from distributions import PlackettLuce
from critics import REBARCritic, RELAXCritic
from utils import make_permutation_matrix

## Plackett-Luce distribution

In [2]:
theta = torch.FloatTensor([10,10,-10,-10])
dist = PlackettLuce(theta)

In [3]:
dist.sample(num_samples=10)

tensor([[0, 1, 2, 3],
        [0, 1, 3, 2],
        [1, 0, 3, 2],
        [1, 0, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 3, 2],
        [0, 1, 3, 2],
        [1, 0, 3, 2],
        [0, 1, 2, 3],
        [1, 0, 3, 2]])

In [4]:
dist.log_prob(samples=torch.LongTensor([[0,1,2,3], [1,0,3,2], [3,2,1,0]]))

tensor([ -1.3863,  -1.3863, -42.0794])

### Gradient estimator

In [92]:
np.random.seed(42)
torch.manual_seed(42)

NUM_VARS = 6
DETERMINISTIC_TARGET = True

if DETERMINISTIC_TARGET:
    target = np.random.permutation(NUM_VARS)
    P_target = torch.zeros(NUM_VARS, NUM_VARS, dtype=torch.float32)
    P_target[torch.arange(NUM_VARS), target] = 1.0

    print("Target", target)
    print("P_target", P_target)

    loss_func = lambda P: torch.norm(P - P_target, p=2)
else:
    P_target = torch.rand(NUM_VARS, NUM_VARS)
    P_target = P_target / P_target.sum(dim=-1, keepdims=True)
    P_target = torch.distributions.utils.clamp_probs(P_target)
    print("Target probs", P_target)
    P_target = P_target.log()
    
    loss_func = lambda P: -(P * P_target).sum() # NLL Loss
    
    min_loss = 10000
    targets = None
    for _ in range(5000):
        t = np.random.permutation(NUM_VARS)
        P = torch.zeros(NUM_VARS, NUM_VARS, dtype=torch.float32)
        P[torch.arange(NUM_VARS), t] = 1.0
        loss = loss_func(P).item()
        if loss < min_loss:
            min_loss, targets = loss, t
    
    print("Minimum loss found:", min_loss)
    print("Targets:", targets)

Target [0 1 5 2 4 3]
P_target tensor([[1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0.]])


In [107]:
log_theta = nn.Parameter(torch.zeros(NUM_VARS) - torch.arange(NUM_VARS, dtype=torch.float32) / NUM_VARS + 0.5)
critic = RELAXCritic(loss_func, NUM_VARS, 32)
# critic = REBARCritic(loss_func)

In [108]:
optim = torch.optim.Adam(list(critic.parameters()) + [log_theta], lr=0.01)

In [113]:
for i in tqdm(range(5000)):
    optim.zero_grad()
    u = torch.distributions.utils.clamp_probs(torch.rand_like(log_theta)) # * 0.0 + 0.5
    v = torch.distributions.utils.clamp_probs(torch.rand_like(log_theta)) # * 0.0 + 0.5
    z = to_z(log_theta, u)
    b = to_b(z)
    f_b = loss_func(make_permutation_matrix(b))
    d_log_theta = relax(fb=f_b, b=b, logits=log_theta, z=z, c=critic, v=v)
    (d_log_theta ** 2).sum().backward()
    log_theta.backward(d_log_theta)
    optim.step()
    print("Loss", f_b.detach().item())

HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))

Loss 2.4494898319244385
Loss 2.4494898319244385
Loss 2.0
Loss 2.8284270763397217
Loss 3.464101552963257
Loss 3.464101552963257
Loss 3.1622776985168457
Loss 3.464101552963257
Loss 2.8284270763397217
Loss 3.464101552963257
Loss 3.464101552963257
Loss 3.464101552963257
Loss 3.1622776985168457
Loss 2.8284270763397217
Loss 3.464101552963257
Loss 3.464101552963257
Loss 2.4494898319244385
Loss 2.4494898319244385
Loss 2.4494898319244385
Loss 2.8284270763397217
Loss 3.1622776985168457
Loss 2.0
Loss 2.8284270763397217
Loss 3.1622776985168457
Loss 2.4494898319244385
Loss 2.0
Loss 3.464101552963257
Loss 2.0
Loss 3.1622776985168457
Loss 2.8284270763397217
Loss 2.4494898319244385
Loss 3.1622776985168457
Loss 3.464101552963257
Loss 2.8284270763397217
Loss 3.464101552963257
Loss 2.8284270763397217
Loss 2.8284270763397217
Loss 2.4494898319244385
Loss 2.8284270763397217
Loss 3.464101552963257
Loss 2.0
Loss 3.464101552963257
Loss 2.8284270763397217
Loss 0.0
Loss 3.1622776985168457
Loss 2.8284270763397217

Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.4494898319244385
Loss 0.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.8284270763397217
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.4494898319244385
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 2.0
Loss 2.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 2.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 2.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 2.0
Loss 2.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
L

Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 2.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0
Loss 0.0



KeyboardInterrupt: 

In [110]:
torch.argsort(-log_theta)

tensor([0, 1, 2, 5, 4, 3])

In [111]:
log_theta

Parameter containing:
tensor([ 0.6473,  0.4222,  0.1056, -0.3461, -0.2975,  0.0303],
       requires_grad=True)

In [112]:
log_theta.grad

tensor([-1.6084e-06, -1.3169e-06, -1.3430e-06,  3.4524e-06,  8.3074e-07,
        -2.0955e-08])

### Sinkhorn test

In [3]:
def sinkhorn(X, L, temp=1.):
    """ Sinkhorn operator """
    if L==0:
        S = torch.exp(X/temp)
    else:
        S = sinkhorn(X, L-1, temp)
        S = S / torch.sum(S, axis=1, keepdims=True) # row normalize
        S = S / torch.sum(S, axis=0, keepdims=True) # column normalize
    return S

In [13]:
S = sinkhorn(X=torch.randn(4, 4), L=20, temp=1.0)
print(S)
print(S.argmax(dim=-1))

tensor([[0.2767, 0.2532, 0.1119, 0.3582],
        [0.4534, 0.1689, 0.3331, 0.0446],
        [0.2371, 0.1726, 0.0997, 0.4906],
        [0.0327, 0.4053, 0.4554, 0.1066]])
tensor([3, 0, 3, 2])
