In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import pdb
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, './src/')
from target import NN_bernoulli
from kernels import HMC_our, Reverse_kernel
%matplotlib inline

In [2]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

In [3]:
# 'Encoder' - simple matrix
class Encoder(nn.Module):
    def __init__(self, L, z_dim, device='cpu'):
        super(Encoder, self).__init__()
        self.L = L
        self.z_dim = z_dim
        self.e = nn.Linear(in_features=self.L, out_features=2*self.L)
        self.mu = nn.Linear(in_features=2*self.L, out_features=self.z_dim)
        self.sigma = nn.Linear(in_features=2*self.L, out_features=self.z_dim)
        self.h = nn.Linear(in_features=2*self.L, out_features=self.z_dim)
    def forward(self, x):
        hid = torch.relu(self.e(x))
        return self.mu(hid), nn.functional.softplus(self.sigma(hid)), self.h(hid)
    
# 'Decoder' - simple matrix, return logits
class Decoder(nn.Module):
    def __init__(self, L, z_dim, device='cpu'):
        super(Decoder, self).__init__()
        self.L = L
        self.z_dim = z_dim
        self.W = nn.Linear(in_features=self.z_dim, out_features=self.L, bias=False)
    def forward(self, z):
        return [self.W(z)]

In [4]:
L = 2
z_dim = 2
N = 10000
device = "cpu"#"cuda:0" if torch.cuda.is_available() else "cpu"

args = dotdict({})
args.K = 7
args.N = 7
args.z_dim = z_dim
args.torchType = torch.float32
args.device = device
args.learnable_reverse = True
args.num_epoches = 500
args.train_batch_size = 100
args.amortize = False
args.gamma = 0.1 ## Stepsize
args.alpha = 0.5  ## For partial momentum refresh
args.train_only_inference_period = 10
args.train_only_inference_cutoff = 5
args.hoffman_idea = False
args.separate_params = False
args.use_barker = True

In [5]:
enc = Encoder(L=L, z_dim=z_dim, device=device).to(device)
dec = Decoder(L=L, z_dim=z_dim, device=device).to(device)
reverse_kernel = Reverse_kernel(args).to(device)
target = NN_bernoulli({}, dec, device).to(device)
transitions = nn.ModuleList([HMC_our(kwargs=args).to(args.device) for _ in range(args['K'])])

std_normal = torch.distributions.Normal(loc=torch.tensor(0., device=device),
                                                scale=torch.tensor(1., device=device))
args.std_normal = std_normal

In [6]:
true_theta = std_normal.sample((z_dim, L)) + 5
print('True decoder matrix')
print(true_theta)
print('-' * 75)
data_logits = std_normal.sample((N, z_dim)) @ true_theta
data = torch.distributions.Bernoulli(logits=data_logits).sample()
print('Generated data example:')
print(data[:10])

True decoder matrix
tensor([[5.1967, 5.6456],
        [4.1954, 6.4900]])
---------------------------------------------------------------------------
Generated data example:
tensor([[1., 1.],
        [0., 0.],
        [1., 1.],
        [1., 1.],
        [0., 0.],
        [1., 1.],
        [0., 0.],
        [1., 1.],
        [0., 0.],
        [1., 1.]])


In [7]:
dataloader = torch.utils.data.DataLoader(data, batch_size=args.train_batch_size, shuffle=True)

if args.separate_params:
    params = list(enc.parameters()) + list(reverse_kernel.parameters()) + list(transitions.parameters())
    optimizer = torch.optim.Adam(params=target.parameters())
    optimizer_inference = torch.optim.Adam(params=params)
else:
    params = list(enc.parameters()) + list(target.parameters()) + list(reverse_kernel.parameters()) + list(transitions.parameters())
    optimizer = torch.optim.Adam(params=params)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200, 300, 400], gamma=0.5) #torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0, last_epoch=-1)

In [8]:
enc.mu.weight

Parameter containing:
tensor([[ 0.1934,  0.4755,  0.4092,  0.1377],
        [-0.3581, -0.2096,  0.3742,  0.0482]], requires_grad=True)

In [9]:
def compute_loss(z_new, p_new, u, p_old, x, sum_log_alpha, sum_log_jac, sum_log_sigma, mu=None, all_directions=None, h=None):
    if args.learnable_reverse:
        log_r = reverse_kernel(z_fin=z_new.detach(), h=h.detach(), a=all_directions)
        log_m = args.std_normal.log_prob(u).sum(1) + args.std_normal.log_prob(p_old).sum(1) - sum_log_jac - sum_log_sigma + sum_log_alpha
    else:
        log_r = 0 #-args.K * torch_log_2
        log_m = args.std_normal.log_prob(u).sum(1) + args.std_normal.log_prob(p_old).sum(1) - sum_log_jac - sum_log_sigma # + sum_log_alpha
        
    log_p = target.get_logdensity(z=z_new, x=x, args=args) + args.std_normal.log_prob(p_new.sum(1))
    elbo_full = log_p + log_r - log_m
    grad_elbo = torch.mean(elbo_full + elbo_full.detach() * sum_log_alpha)
    return elbo_full, grad_elbo 

In [10]:
best_elbo = -float("inf")
current_tolerance = 0
print_info_ = 10
# with torch.autograd.detect_anomaly():
for ep in tqdm(range(args.num_epoches)): # cycle over epoches
    for b_num, batch_train in enumerate(dataloader): # cycle over batches
#         plt.close()        
#         pdb.set_trace()

        mu, sigma, h = enc(batch_train) # sample mu and sigma from encoder
        u = args.std_normal.sample(mu.shape) # sample random tensor for reparametrization trick
        z = mu + sigma * u # reperametrization trick
        sum_log_sigma = torch.sum(torch.log(sigma), 1)

        p_old = args.std_normal.sample(z.shape)
        cond_vectors = [args.std_normal.sample(p_old.shape) for _ in range(args.K)]

        sum_log_alpha = torch.zeros(mu.shape[0], dtype=args.torchType, device=args.device) # for grad log alpha accumulation
        sum_log_jacobian = torch.zeros(mu.shape[0], dtype=args.torchType, device=args.device) # for log_jacobian accumulation
        p = p_old
        if args.learnable_reverse:
            all_directions = torch.tensor([], device=args.device)
        else:
            all_directions = None
        for k in range(args.K):
            # sample alpha - transition probabilities 
            if args.amortize:
#                     pdb.set_trace()
                z, p, log_jac, current_log_alphas, directions, _ = transitions.make_transition(q_old=z, x=batch_train,
                                                    p_old=p, k=cond_vectors[k], target_distr=target, args=args)
            else:
                z, p, log_jac, current_log_alphas, directions, _ = transitions[k].make_transition(q_old=z, x=batch_train,
                                                                    p_old=p, k=cond_vectors[k], target_distr=target, args=args) # sample a_i -- directions
            if ep  % print_info_ == 0 and b_num % (100 * print_info_) == 0:
                print('On batch number {}/{} and on k = {} we have for  0: {} and for +1: {}'.format(b_num + 1,
                                                                        data.shape[0] // args['train_batch_size'],
                                                                           k + 1,
                                                    (directions==0.).to(float).mean(),
                                                                    (directions==1.).to(float).mean()))
                if args.amortize:
                    print('Stepsize {}'.format(np.exp(transitions.gamma.cpu().detach().item())))
                    print('Autoregression coeff {}'.format(torch.sigmoid(transitions.alpha_logit).cpu().detach().item()))
            if args.learnable_reverse:
                all_directions = torch.cat([all_directions, directions.view(-1, 1)], dim=1)
            # Accumulate alphas
            sum_log_alpha = sum_log_alpha + current_log_alphas
            sum_log_jacobian = sum_log_jacobian + log_jac  # refresh log jacobian
        ##############################################
        if args.hoffman_idea:
            if args.learnable_reverse:
                log_r = reverse_kernel(z_fin=z.detach(), h=h.detach(), a=all_directions)
                log_m = args.std_normal.log_prob(u).sum(1) + args.std_normal.log_prob(p_old).sum(1) - sum_log_jacobian - sum_log_sigma + sum_log_alpha
            else:
                log_r = 0 #-args.K * torch_log_2
                log_m = args.std_normal.log_prob(u).sum(1) + args.std_normal.log_prob(p_old).sum(1) - sum_log_jacobian - sum_log_sigma # + sum_log_alpha
            log_p = target.get_logdensity(z=z, x=batch_train, args=args) + args.std_normal.log_prob(p.sum(1))
            elbo_full = log_p + log_r - log_m
    #                 pdb.set_trace()
            ### Gradient of the first objective:
#             target.eval()
            obj_1 = torch.mean(elbo_full + elbo_full.detach() * sum_log_alpha)
            (-obj_1).backward(retain_graph=True)
            optimizer_inference.step()
            optimizer_inference.zero_grad()
            optimizer.zero_grad() 

            ### Gradient of the second objective:
#             target.train()
            log_p = target.get_logdensity(z=z.detach(), x=batch_train, args=args) + args.std_normal.log_prob(p.detach()).sum(1)
            elbo_full = log_p # - log_m
            obj_2 = torch.mean(elbo_full + elbo_full.detach() * sum_log_alpha)
            (-obj_2).backward()
            optimizer.step()
            optimizer_inference.zero_grad()
            optimizer.zero_grad()
            ###########################################################
        else:
            elbo_full, grad_elbo = compute_loss(z_new=z, p_new=p, u=u, p_old=p_old, x=batch_train, sum_log_alpha=sum_log_alpha,
                                                sum_log_jac=sum_log_jacobian, sum_log_sigma=sum_log_sigma, mu=mu,
                                                all_directions=all_directions, h=h)
            (-grad_elbo).backward()

                
        if args.separate_params: # if we separate params of inference part and generation part
            optimizer_inference.step() # we always perform step for inference part
            if ep % args.train_only_inference_period > args.train_only_inference_cutoff: # but sometimes for gen
                optimizer.step()
            optimizer.zero_grad()
            optimizer_inference.zero_grad()
        else:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            
        if ep  % print_info_ == 0 and b_num % (100 * print_info_) == 0:
            if args.hoffman_idea:
                print('obj_1:', obj_1)
                print('obj_2:', obj_2)
            else:
                print('elbo:', elbo_full.mean().cpu().detach().item())

  0%|          | 0/500 [00:00<?, ?it/s]

On batch number 1/100 and on k = 1 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 2 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 3 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 4 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 5 we have for  0: 0.58 and for +1: 0.42
On batch number 1/100 and on k = 6 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 7 we have for  0: 0.52 and for +1: 0.48
elbo: -10.820435523986816


  2%|▏         | 10/500 [00:31<25:24,  3.11s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 2 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 3 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 4 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 5 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 6 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 7 we have for  0: 0.49 and for +1: 0.51
elbo: -8.738831520080566


  4%|▍         | 20/500 [01:02<24:56,  3.12s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 2 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 3 we have for  0: 0.4 and for +1: 0.6
On batch number 1/100 and on k = 4 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 5 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 6 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 7 we have for  0: 0.46 and for +1: 0.54
elbo: -8.241945266723633


  6%|▌         | 30/500 [01:33<24:28,  3.12s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 2 we have for  0: 0.64 and for +1: 0.36
On batch number 1/100 and on k = 3 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 4 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 5 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 6 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 7 we have for  0: 0.5 and for +1: 0.5
elbo: -7.957027435302734


  8%|▊         | 40/500 [02:04<24:00,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 2 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 3 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 4 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 5 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 6 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 7 we have for  0: 0.57 and for +1: 0.43
elbo: -7.481565952301025


 10%|█         | 50/500 [02:35<23:30,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.58 and for +1: 0.42
On batch number 1/100 and on k = 2 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 3 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 4 we have for  0: 0.63 and for +1: 0.37
On batch number 1/100 and on k = 5 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 6 we have for  0: 0.61 and for +1: 0.39
On batch number 1/100 and on k = 7 we have for  0: 0.51 and for +1: 0.49
elbo: -7.405515193939209


 12%|█▏        | 60/500 [03:06<22:36,  3.08s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 2 we have for  0: 0.56 and for +1: 0.44
On batch number 1/100 and on k = 3 we have for  0: 0.58 and for +1: 0.42
On batch number 1/100 and on k = 4 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 5 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 6 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 7 we have for  0: 0.52 and for +1: 0.48
elbo: -6.96462345123291


 14%|█▍        | 70/500 [03:37<22:02,  3.08s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 2 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 3 we have for  0: 0.38 and for +1: 0.62
On batch number 1/100 and on k = 4 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 5 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 6 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 7 we have for  0: 0.52 and for +1: 0.48
elbo: -6.5585856437683105


 16%|█▌        | 80/500 [04:09<21:54,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 2 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 3 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 4 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 5 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 6 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 7 we have for  0: 0.46 and for +1: 0.54
elbo: -6.372210502624512


 18%|█▊        | 90/500 [04:40<20:59,  3.07s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 2 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 3 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 4 we have for  0: 0.43 and for +1: 0.57
On batch number 1/100 and on k = 5 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 6 we have for  0: 0.38 and for +1: 0.62
On batch number 1/100 and on k = 7 we have for  0: 0.43 and for +1: 0.57
elbo: -5.772932529449463


 20%|██        | 100/500 [05:11<20:49,  3.12s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.43 and for +1: 0.57
On batch number 1/100 and on k = 2 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 3 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 4 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 5 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 6 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 7 we have for  0: 0.52 and for +1: 0.48
elbo: -5.633616924285889


 22%|██▏       | 110/500 [05:42<20:19,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 2 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 3 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 4 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 5 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 6 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 7 we have for  0: 0.53 and for +1: 0.47
elbo: -5.4576873779296875


 24%|██▍       | 120/500 [06:13<19:48,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 2 we have for  0: 0.56 and for +1: 0.44
On batch number 1/100 and on k = 3 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 4 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 5 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 6 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 7 we have for  0: 0.42 and for +1: 0.58
elbo: -5.159465789794922


 26%|██▌       | 130/500 [06:44<19:12,  3.12s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.56 and for +1: 0.44
On batch number 1/100 and on k = 2 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 3 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 4 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 5 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 6 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 7 we have for  0: 0.51 and for +1: 0.49
elbo: -4.856247901916504


 28%|██▊       | 140/500 [07:15<18:37,  3.10s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 2 we have for  0: 0.43 and for +1: 0.57
On batch number 1/100 and on k = 3 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 4 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 5 we have for  0: 0.6 and for +1: 0.4
On batch number 1/100 and on k = 6 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 7 we have for  0: 0.5 and for +1: 0.5
elbo: -4.5913496017456055


 30%|███       | 150/500 [07:46<18:04,  3.10s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 2 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 3 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 4 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 5 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 6 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 7 we have for  0: 0.44 and for +1: 0.56
elbo: -4.8007659912109375


 32%|███▏      | 160/500 [08:17<17:40,  3.12s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 2 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 3 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 4 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 5 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 6 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 7 we have for  0: 0.51 and for +1: 0.49
elbo: -4.1873087882995605


 34%|███▍      | 170/500 [08:49<17:12,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 2 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 3 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 4 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 5 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 6 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 7 we have for  0: 0.44 and for +1: 0.56
elbo: -4.171233177185059


 36%|███▌      | 180/500 [09:20<16:32,  3.10s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 2 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 3 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 4 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 5 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 6 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 7 we have for  0: 0.49 and for +1: 0.51
elbo: -3.8555548191070557


 38%|███▊      | 190/500 [09:51<16:10,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.56 and for +1: 0.44
On batch number 1/100 and on k = 2 we have for  0: 0.64 and for +1: 0.36
On batch number 1/100 and on k = 3 we have for  0: 0.43 and for +1: 0.57
On batch number 1/100 and on k = 4 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 5 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 6 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 7 we have for  0: 0.51 and for +1: 0.49
elbo: -3.7172937393188477


 40%|████      | 200/500 [10:22<15:37,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.58 and for +1: 0.42
On batch number 1/100 and on k = 2 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 3 we have for  0: 0.58 and for +1: 0.42
On batch number 1/100 and on k = 4 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 5 we have for  0: 0.4 and for +1: 0.6
On batch number 1/100 and on k = 6 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 7 we have for  0: 0.49 and for +1: 0.51
elbo: -3.755617380142212


 42%|████▏     | 210/500 [10:53<14:57,  3.10s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 2 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 3 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 4 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 5 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 6 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 7 we have for  0: 0.58 and for +1: 0.42
elbo: -3.260326623916626


 44%|████▍     | 220/500 [11:25<14:46,  3.17s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 2 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 3 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 4 we have for  0: 0.58 and for +1: 0.42
On batch number 1/100 and on k = 5 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 6 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 7 we have for  0: 0.42 and for +1: 0.58
elbo: -3.2211477756500244


 46%|████▌     | 230/500 [11:56<13:58,  3.10s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 2 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 3 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 4 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 5 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 6 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 7 we have for  0: 0.61 and for +1: 0.39
elbo: -3.108839511871338


 48%|████▊     | 240/500 [12:27<13:21,  3.08s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 2 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 3 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 4 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 5 we have for  0: 0.36 and for +1: 0.64
On batch number 1/100 and on k = 6 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 7 we have for  0: 0.56 and for +1: 0.44
elbo: -2.7954483032226562


 50%|█████     | 250/500 [12:58<12:56,  3.11s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 2 we have for  0: 0.56 and for +1: 0.44
On batch number 1/100 and on k = 3 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 4 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 5 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 6 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 7 we have for  0: 0.53 and for +1: 0.47
elbo: -2.660506248474121


 52%|█████▏    | 260/500 [13:29<12:31,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.56 and for +1: 0.44
On batch number 1/100 and on k = 2 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 3 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 4 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 5 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 6 we have for  0: 0.37 and for +1: 0.63
On batch number 1/100 and on k = 7 we have for  0: 0.46 and for +1: 0.54
elbo: -2.913423538208008


 54%|█████▍    | 270/500 [14:00<11:59,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 2 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 3 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 4 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 5 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 6 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 7 we have for  0: 0.55 and for +1: 0.45
elbo: -2.3134162425994873


 56%|█████▌    | 280/500 [14:31<11:29,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.41 and for +1: 0.59
On batch number 1/100 and on k = 2 we have for  0: 0.59 and for +1: 0.41
On batch number 1/100 and on k = 3 we have for  0: 0.41 and for +1: 0.59
On batch number 1/100 and on k = 4 we have for  0: 0.56 and for +1: 0.44
On batch number 1/100 and on k = 5 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 6 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 7 we have for  0: 0.49 and for +1: 0.51
elbo: -2.4382660388946533


 58%|█████▊    | 290/500 [15:02<10:50,  3.10s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 2 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 3 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 4 we have for  0: 0.43 and for +1: 0.57
On batch number 1/100 and on k = 5 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 6 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 7 we have for  0: 0.55 and for +1: 0.45
elbo: -2.383314609527588


 60%|██████    | 300/500 [15:33<10:19,  3.10s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 2 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 3 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 4 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 5 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 6 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 7 we have for  0: 0.6 and for +1: 0.4
elbo: -2.3239591121673584


 62%|██████▏   | 310/500 [16:04<09:43,  3.07s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 2 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 3 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 4 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 5 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 6 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 7 we have for  0: 0.54 and for +1: 0.46
elbo: -2.313925266265869


 64%|██████▍   | 320/500 [16:35<09:21,  3.12s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 2 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 3 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 4 we have for  0: 0.37 and for +1: 0.63
On batch number 1/100 and on k = 5 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 6 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 7 we have for  0: 0.43 and for +1: 0.57
elbo: -2.3885867595672607


 66%|██████▌   | 330/500 [17:06<08:52,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 2 we have for  0: 0.59 and for +1: 0.41
On batch number 1/100 and on k = 3 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 4 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 5 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 6 we have for  0: 0.56 and for +1: 0.44
On batch number 1/100 and on k = 7 we have for  0: 0.55 and for +1: 0.45
elbo: -1.7861998081207275


 68%|██████▊   | 340/500 [17:37<08:09,  3.06s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 2 we have for  0: 0.64 and for +1: 0.36
On batch number 1/100 and on k = 3 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 4 we have for  0: 0.39 and for +1: 0.61
On batch number 1/100 and on k = 5 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 6 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 7 we have for  0: 0.49 and for +1: 0.51
elbo: -2.0529532432556152


 70%|███████   | 350/500 [18:08<07:47,  3.12s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 2 we have for  0: 0.38 and for +1: 0.62
On batch number 1/100 and on k = 3 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 4 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 5 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 6 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 7 we have for  0: 0.41 and for +1: 0.59
elbo: -1.9460071325302124


 72%|███████▏  | 360/500 [18:39<07:15,  3.11s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 2 we have for  0: 0.58 and for +1: 0.42
On batch number 1/100 and on k = 3 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 4 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 5 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 6 we have for  0: 0.6 and for +1: 0.4
On batch number 1/100 and on k = 7 we have for  0: 0.46 and for +1: 0.54
elbo: -2.039198875427246


 74%|███████▍  | 370/500 [19:10<06:42,  3.09s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 2 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 3 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 4 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 5 we have for  0: 0.42 and for +1: 0.58
On batch number 1/100 and on k = 6 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 7 we have for  0: 0.39 and for +1: 0.61
elbo: -1.7087551355361938


 76%|███████▌  | 380/500 [19:41<06:12,  3.10s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 2 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 3 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 4 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 5 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 6 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 7 we have for  0: 0.54 and for +1: 0.46
elbo: -1.4224680662155151


 78%|███████▊  | 390/500 [20:12<05:42,  3.11s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 2 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 3 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 4 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 5 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 6 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 7 we have for  0: 0.51 and for +1: 0.49
elbo: -1.7510042190551758


 80%|████████  | 400/500 [20:44<05:11,  3.11s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.41 and for +1: 0.59
On batch number 1/100 and on k = 2 we have for  0: 0.4 and for +1: 0.6
On batch number 1/100 and on k = 3 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 4 we have for  0: 0.41 and for +1: 0.59
On batch number 1/100 and on k = 5 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 6 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 7 we have for  0: 0.53 and for +1: 0.47
elbo: -1.5231746435165405


 82%|████████▏ | 410/500 [21:15<04:38,  3.10s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 2 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 3 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 4 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 5 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 6 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 7 we have for  0: 0.46 and for +1: 0.54
elbo: -1.3134524822235107


 84%|████████▍ | 420/500 [21:46<04:07,  3.09s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 2 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 3 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 4 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 5 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 6 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 7 we have for  0: 0.49 and for +1: 0.51
elbo: -1.212195634841919


 86%|████████▌ | 430/500 [22:17<03:40,  3.16s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 2 we have for  0: 0.56 and for +1: 0.44
On batch number 1/100 and on k = 3 we have for  0: 0.61 and for +1: 0.39
On batch number 1/100 and on k = 4 we have for  0: 0.64 and for +1: 0.36
On batch number 1/100 and on k = 5 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 6 we have for  0: 0.58 and for +1: 0.42
On batch number 1/100 and on k = 7 we have for  0: 0.42 and for +1: 0.58
elbo: -1.4033719301223755


 88%|████████▊ | 440/500 [22:48<03:08,  3.13s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.44 and for +1: 0.56
On batch number 1/100 and on k = 2 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 3 we have for  0: 0.59 and for +1: 0.41
On batch number 1/100 and on k = 4 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 5 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 6 we have for  0: 0.54 and for +1: 0.46
On batch number 1/100 and on k = 7 we have for  0: 0.52 and for +1: 0.48
elbo: -1.3303240537643433


 90%|█████████ | 450/500 [23:19<02:35,  3.10s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 2 we have for  0: 0.46 and for +1: 0.54
On batch number 1/100 and on k = 3 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 4 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 5 we have for  0: 0.58 and for +1: 0.42
On batch number 1/100 and on k = 6 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 7 we have for  0: 0.5 and for +1: 0.5
elbo: -1.1245832443237305


 92%|█████████▏| 460/500 [23:50<02:02,  3.07s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.49 and for +1: 0.51
On batch number 1/100 and on k = 2 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 3 we have for  0: 0.6 and for +1: 0.4
On batch number 1/100 and on k = 4 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 5 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 6 we have for  0: 0.64 and for +1: 0.36
On batch number 1/100 and on k = 7 we have for  0: 0.46 and for +1: 0.54
elbo: -1.3000516891479492


 94%|█████████▍| 470/500 [24:21<01:32,  3.08s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.58 and for +1: 0.42
On batch number 1/100 and on k = 2 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 3 we have for  0: 0.45 and for +1: 0.55
On batch number 1/100 and on k = 4 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 5 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 6 we have for  0: 0.56 and for +1: 0.44
On batch number 1/100 and on k = 7 we have for  0: 0.49 and for +1: 0.51
elbo: -1.100080132484436


 96%|█████████▌| 480/500 [24:51<01:01,  3.07s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.43 and for +1: 0.57
On batch number 1/100 and on k = 2 we have for  0: 0.6 and for +1: 0.4
On batch number 1/100 and on k = 3 we have for  0: 0.55 and for +1: 0.45
On batch number 1/100 and on k = 4 we have for  0: 0.59 and for +1: 0.41
On batch number 1/100 and on k = 5 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 6 we have for  0: 0.51 and for +1: 0.49
On batch number 1/100 and on k = 7 we have for  0: 0.58 and for +1: 0.42
elbo: -1.0111397504806519


 98%|█████████▊| 490/500 [25:23<00:31,  3.11s/it]

On batch number 1/100 and on k = 1 we have for  0: 0.48 and for +1: 0.52
On batch number 1/100 and on k = 2 we have for  0: 0.57 and for +1: 0.43
On batch number 1/100 and on k = 3 we have for  0: 0.5 and for +1: 0.5
On batch number 1/100 and on k = 4 we have for  0: 0.47 and for +1: 0.53
On batch number 1/100 and on k = 5 we have for  0: 0.53 and for +1: 0.47
On batch number 1/100 and on k = 6 we have for  0: 0.52 and for +1: 0.48
On batch number 1/100 and on k = 7 we have for  0: 0.49 and for +1: 0.51
elbo: -1.1851496696472168


100%|██████████| 500/500 [25:54<00:00,  3.11s/it]


In [11]:
target.decoder.W.weight.T

tensor([[-0.6599, -0.6826],
        [ 0.4753,  0.5219]], grad_fn=<PermuteBackward>)

In [12]:
target.decoder.W.bias

In [13]:
print(true_theta)

tensor([[5.1967, 5.6456],
        [4.1954, 6.4900]])


In [14]:
enc.mu.weight

Parameter containing:
tensor([[ 0.1218,  0.7920,  0.4092, -0.1279],
        [-0.3447, -0.4989,  0.3742,  0.1400]], requires_grad=True)

## VAE

In [15]:
# 'Encoder' - simple matrix
class Encoder_simple(nn.Module):
    def __init__(self, L, z_dim, device='cpu'):
        super(Encoder_simple, self).__init__()
        self.L = L
        self.z_dim = z_dim
        self.e = nn.Linear(in_features=self.L, out_features=2*self.L)
        self.mu = nn.Linear(in_features=2*self.L, out_features=self.z_dim)
        self.sigma = nn.Linear(in_features=2*self.L, out_features=self.z_dim)
    def forward(self, x):
        hid = torch.relu(self.e(x))
        return self.mu(hid), nn.functional.softplus(self.sigma(hid))

In [16]:
enc = Encoder_simple(L=L, z_dim=z_dim, device=device).to(device)
dec = Decoder(L=L, z_dim=z_dim, device=device).to(device)
target = NN_bernoulli({}, dec, device).to(device)

std_normal = torch.distributions.Normal(loc=torch.tensor(0., device=device),
                                                scale=torch.tensor(1., device=device))
args.std_normal = std_normal

In [17]:
params = list(enc.parameters()) + list(target.parameters())
optimizer = torch.optim.Adam(params=params)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200, 300, 400], gamma=0.5) #torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0, last_epoch=-1)

In [18]:
best_elbo = -float("inf")
current_tolerance = 0
print_info_ = 10
# with torch.autograd.detect_anomaly():
for ep in tqdm(range(args.num_epoches)): # cycle over epoches
    for b_num, batch_train in enumerate(dataloader): # cycle over batches
#         pdb.set_trace()
        mu, sigma = enc(batch_train) # sample mu and sigma from encoder
        u = args.std_normal.sample(mu.shape) # sample random tensor for reparametrization trick
        z = mu + sigma * u # reperametrization trick
        sum_log_sigma = torch.sum(torch.log(sigma), 1)
        
        log_q = -sum_log_sigma + std_normal.log_prob(u).sum(1)
        log_p = target.get_logdensity(x=batch_train, z=z)
        elbo = torch.mean(log_p - log_q)
        (-elbo).backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
    if ep % 50 == 0:
        print('elbo:', elbo.cpu().detach().item())

  0%|          | 2/500 [00:00<01:19,  6.27it/s]

elbo: -1.609104871749878


 10%|█         | 52/500 [00:06<00:54,  8.22it/s]

elbo: -1.403254747390747


 20%|██        | 102/500 [00:12<00:48,  8.22it/s]

elbo: -1.3826212882995605


 30%|███       | 152/500 [00:18<00:42,  8.25it/s]

elbo: -1.3923649787902832


 40%|████      | 202/500 [00:24<00:36,  8.27it/s]

elbo: -1.3815406560897827


 50%|█████     | 252/500 [00:30<00:29,  8.28it/s]

elbo: -1.3918979167938232


 60%|██████    | 302/500 [00:36<00:23,  8.26it/s]

elbo: -1.3917657136917114


 70%|███████   | 352/500 [00:42<00:17,  8.28it/s]

elbo: -1.3886233568191528


 80%|████████  | 402/500 [00:48<00:11,  8.20it/s]

elbo: -1.3873263597488403


 90%|█████████ | 452/500 [00:54<00:05,  8.20it/s]

elbo: -1.3879351615905762


100%|██████████| 500/500 [01:00<00:00,  8.23it/s]


In [19]:
target.decoder.W.weight.T

tensor([[ 0.0047, -0.0053],
        [-0.0002,  0.0041]], grad_fn=<PermuteBackward>)

In [20]:
print(true_theta)

tensor([[5.1967, 5.6456],
        [4.1954, 6.4900]])


In [21]:
best_elbo = -float("inf")
current_tolerance = 0
print_info_ = 10
# with torch.autograd.detect_anomaly():

for b_num, batch_train in enumerate(dataloader): # cycle over batches
#         pdb.set_trace()
    mu, sigma = enc(batch_train) # sample mu and sigma from encoder
    u = args.std_normal.sample(mu.shape) # sample random tensor for reparametrization trick
    z = mu + sigma * u # reperametrization trick
    sum_log_sigma = torch.sum(torch.log(sigma), 1)
    break

In [29]:
batch_train[:10]

tensor([[1., 1.],
        [0., 0.],
        [0., 1.],
        [0., 0.],
        [0., 0.],
        [1., 1.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 1.]])

In [31]:
torch.distributions.Bernoulli(logits=target.decoder(z)[0][:10]).sample()

tensor([[0., 0.],
        [0., 0.],
        [0., 1.],
        [0., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [0., 0.],
        [1., 1.],
        [1., 0.]])