# Numerical optimisation of $\alpha_{1:n}$ for the probit model

Assume a model with a fairly informative prior:
$$
\Theta \sim \mathcal{N}(\mu_{\Theta}, \sigma^2_{\Theta}) \\
B \sim \mathcal{U}(a_{B}, b_{B})
$$

In [None]:
# IPython autoreload: re-import edited project modules (src/...) before running code.
%load_ext autoreload
%autoreload 2
from functools import partial
from scipy.stats import norm, bernoulli
import numpy as np
import torch
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
import matplotlib.pyplot as plt
import sys
# Make the top-level `src` package importable; presumably this notebook lives one
# directory below the repository root — TODO confirm.
sys.path.append("..")
# Probit-model helpers used by the cells below (definitions live in src/probit/probit.py).
from src.probit.probit import (
    p,
    h,
    p_y_tilde_seq_given_params,
    entropy_of_avg,
    avg_entropy,
    sample_joint_normal_uniform
)
# Default figure size (width, height) in inches for every plot in this notebook.
plt.rcParams['figure.figsize'] = (12,5)

# Examples with few samples

Let $x_{1:n} = [-2, -1/2, 0, 1/2]$.

$$
B \sim \mathcal{U}(a_B, b_B)\\
\Theta \sim \mathcal{N}(\mu_{\Theta}, \sigma^2_{\Theta})
$$

In [None]:
# Parameter prior: B ~ U(a_B, b_B), Theta ~ N(mu_theta, sigma_sq_theta).
a_B = -1
b_B = 1
mu_theta = 1e6
sigma_sq_theta = 0.1  # prior *variance* of Theta
# torch's Normal is parameterised by the standard deviation (`scale`), not the variance.
theta = Normal(loc=mu_theta, scale=sigma_sq_theta ** 0.5).sample()
beta = Uniform(a_B, b_B).sample()
num_mc_samples = 10

# Specific few-samples example: x_{1:n} = [-2, -1/2, 0, 1/2].
xs = torch.tensor([-2, -1/2, 0, 1/2])
# Monte Carlo sampler over (Theta, B) drawn from the prior above; later cells reuse it.
sampler = partial(sample_joint_normal_uniform, mu_theta, sigma_sq_theta, a_B, b_B, num_mc_samples)

# With the noisy-label probability (2*alpha - 1) * Phi(...) + 1 - alpha, alpha = 1/2
# makes every label a fair coin flip regardless of (theta, beta).
alphas = 1/2 * torch.ones(xs.size())
# Only a cell's last expression is displayed, so print intermediate results explicitly.
print("alphas = 0.5 : entropy_of_avg =", entropy_of_avg(xs, alphas, sampler))

# Label sequence kept for manual experiments with p_y_tilde_seq_given_params.
y_tilde_seq = torch.tensor([1, 0, 0, 0])

# Slightly informative labels: alpha = 0.55 for every point.
alphas = 0.55 * torch.ones(xs.shape)
print("alphas = 0.55: avg_entropy    =", avg_entropy(xs, alphas, sampler))
print("alphas = 0.55: entropy_of_avg =", entropy_of_avg(xs, alphas, sampler))

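## Optimisation problem
Maximise the mutual information
$$
\begin{aligned}
L &= \mathcal{H}[\tilde{Y}_{1:n} \vert x_{1:n}]
 - \mathbb{E}_{(\theta, \beta) \sim p(\Theta, B \vert x_{1:n})} \left[\mathcal{H}[\tilde{Y}_{1:n} \vert \theta, \beta, x_{1:n}] \right] \\
 &= \mathcal{H}\left[\int p_{\alpha_{1:n}}(\tilde{Y}_{1:n} = \tilde{y}_{1:n} \vert x_{1:n}, \theta, \beta)\, \mathcal{N}(\theta; \mu_{\Theta}, \sigma^2_{\Theta})\, \mathcal{U}(\beta; a_B, b_B)\, d\theta\, d\beta\right] \\
 &\quad - \int \sum_{i=1}^{n} h\left((2\alpha_i - 1) \Phi(\theta^{\top}(x_i - \beta)) + 1 - \alpha_i\right) \mathcal{N}(\theta; \mu_{\Theta}, \sigma^2_{\Theta})\, \mathcal{U}(\beta; a_B, b_B)\, d\theta\, d\beta,
\end{aligned}
$$
with respect to $\alpha_i \in [0.5, 1.0]$. Here $h(p) = -p \log p - (1 - p) \log (1 - p)$ is the binary entropy. No labels have been observed, so the posterior over $(\Theta, B)$ given $x_{1:n}$ is just the prior, which is why the second line integrates against $\mathcal{N} \cdot \mathcal{U}$.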

In [None]:
from functools import partial
from scipy.stats import norm, bernoulli
import numpy as np
import torch
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
import matplotlib.pyplot as plt
import sys
sys.path.append("..")
from src.probit.probit import (
    entropy_of_avg,
    avg_entropy,
    sample_joint_normal_uniform
)
from src.probit.opt import alpha_mapping


# Prior hyperparameters: B ~ U(a_B, b_B), Theta ~ N(mu_theta, sigma_sq_theta).
a_B = -1
b_B = 1
mu_theta = 1e6
sigma_sq_theta = 0.1  # prior *variance* of Theta
# torch's Normal is parameterised by the standard deviation (`scale`), not the variance.
theta = Normal(loc=mu_theta, scale=sigma_sq_theta ** 0.5).sample()
beta = Uniform(a_B, b_B).sample()
num_mc_samples = 1000

xs = torch.tensor([-2, -1/2, 0, 1/2])
# Unconstrained optimisation variables; alpha_mapping turns them into the alphas.
weights = torch.ones(xs.size(), requires_grad=True)
alphas = alpha_mapping(weights)
sampling_fn = partial(sample_joint_normal_uniform, mu_theta, sigma_sq_theta, a_B, b_B, num_mc_samples)
# Mutual-information objective: entropy of the averaged predictive minus the averaged entropy.
L = entropy_of_avg(xs, alphas, sampling_fn) - avg_entropy(xs, alphas, sampling_fn)
# Gradient of the MI estimate with respect to the alphas (displayed as the cell output).
torch.autograd.grad(L, alphas)

In [None]:
import torch
from src.probit.opt import alpha_mapping, inverse_alpha_mapping

# Round-trip check: alpha_mapping(inverse_alpha_mapping(alpha)) should give back alpha
# for an alpha drawn uniformly from [1/2, 1).
alpha = torch.distributions.Uniform(1 / 2, 1).sample()
mapped_alpha = alpha_mapping(inverse_alpha_mapping(alpha))
a = torch.allclose(alpha, mapped_alpha)
# Only a cell's last expression is displayed, so report the check explicitly.
print(f"alpha: {alpha}, mapped_alpha: {mapped_alpha}, allclose: {a}")

# Scratch: expression evaluated by hand at a fixed alpha, presumably to compare
# against inverse_alpha_mapping — TODO confirm against src.probit.opt.
alpha = 0.62
1 - 1 / (2 * (alpha - 1 / 2))

In [None]:
%load_ext autoreload
%autoreload 2
from functools import partial
from scipy.stats import norm, bernoulli
import numpy as np
import torch
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
import matplotlib.pyplot as plt
import sys
sys.path.append("..")
from src.probit.probit import (
    entropy_of_avg,
    avg_entropy,
    sample_joint_normal_uniform
)
from src.probit.opt import MiOpt, linear_cost, alpha_mapping

def training_loop(model, optimizer, num_epochs, *, log_grads=True):
    """Maximise the model's mutual information by minimising ``-MI + constraint``.

    Each epoch computes ``mi = model.compute_mi()``, back-propagates
    ``-mi + model.constraint()``, takes one ``optimizer`` step and prints progress.

    Args:
        model: object exposing ``compute_mi()``, ``constraint()``, ``parameters()``
            and ``opt_param`` (as used below; presumably an ``MiOpt`` instance).
        optimizer: torch optimizer over ``model.parameters()``.
        num_epochs: number of optimisation steps.
        log_grads: if True, print the parameter gradients used in each step.

    Returns:
        list[float]: the MI value computed at the start of each epoch.
    """
    mis = []
    for epoch in range(1, num_epochs + 1):
        # Clear gradients first so each step uses only this epoch's gradient,
        # even if the parameters carried stale grads into the loop.
        optimizer.zero_grad()
        mi = model.compute_mi()
        objective = -mi + model.constraint()
        objective.backward()
        if log_grads:
            print("grads: ", [param.grad for param in model.parameters()])
        optimizer.step()
        mis.append(mi.item())
        # Reporting only: no need to build an autograd graph for the alpha sum.
        with torch.no_grad():
            alpha_sum = alpha_mapping(model.opt_param).sum().item()
        print(f"Epoch: {epoch}, MI: {mi.item()}, alpha sum: {alpha_sum}")
    return mis

# Prior hyperparameters: B ~ U(a_B, b_B), Theta ~ N(mu_theta, sigma_sq_theta).
a_B = -1
b_B = 1
mu_theta = 1e6
sigma_sq_theta = 0.1  # prior *variance* of Theta
# torch's Normal is parameterised by the standard deviation (`scale`), not the variance.
theta = Normal(loc=mu_theta, scale=sigma_sq_theta ** 0.5).sample()
beta = Uniform(a_B, b_B).sample()
num_mc_samples = 100
xs = torch.tensor([-2, -1/2, 0, 1/2])
# Initial alphas: 0.55 for every point. A non-uniform alternative start is
# torch.tensor([0.55, 0.55, 0.95, 0.55]).
alphas = 0.55 * torch.ones(xs.shape)

sampling_fn = partial(sample_joint_normal_uniform, mu_theta, sigma_sq_theta, a_B, b_B, num_mc_samples)
# MI optimiser with a linear cost on the alphas under a budget of 3.
m = MiOpt(xs, alphas, cost_fn=linear_cost, budget=3, sampling_fn=sampling_fn)
# Instantiate optimizer
opt = torch.optim.SGD(m.parameters(), lr=10)
# training_loop returns the per-epoch MI values (not losses); the name is kept for compatibility.
losses = training_loop(m, opt, num_epochs=50)

In [None]:
# Inspect the alphas reached by the MiOpt run above (displayed as the cell output).
# NOTE(review): presumably alpha_mapping applied to m's parameters — confirm in MiOpt.
m.alphas()

In [None]:
# Manual gradient ascent on the unconstrained variables zs, with alphas = alpha_mapping(zs).
# Maximises the MI objective L = entropy_of_avg - avg_entropy.
# Relies on `xs` and `sampler` defined in earlier cells.
num_epochs = 20
lr = 0.01

# zs must be a *leaf* tensor: `0.8 * torch.ones(..., requires_grad=True)` is the result
# of a multiplication (a non-leaf), so its .grad would never be populated.
zs = torch.full(xs.size(), 0.8, requires_grad=True)
costs = []  # objective value per epoch; created once so it accumulates across epochs
for epoch in range(num_epochs):
    alphas = alpha_mapping(zs)
    L_1 = entropy_of_avg(xs, alphas, sampler)
    L_2 = avg_entropy(xs, alphas, sampler)
    L = L_1 - L_2
    L.backward()
    costs.append(L.item())
    # Update the leaf zs (alphas is recomputed from zs each epoch, and as a
    # non-leaf its .grad is None). Ascent, because the objective is maximised.
    with torch.no_grad():
        zs += lr * zs.grad
        zs.grad.zero_()

    # Report the objective value and the current alpha sum for this epoch.
    print('Epoch: {} Cost: {}'.format(epoch, L.item()))
    print("alpha sum:", alphas.sum().item())

