In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import pdb
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, './src/')
from target import NN_bernoulli
from kernels import HMC_our, Reverse_kernel
%matplotlib inline

In [2]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

In [3]:
# 'Encoder' - simple matrix
class Encoder(nn.Module):
    def __init__(self, L, z_dim, device='cpu'):
        super(Encoder, self).__init__()
        self.L = L
        self.z_dim = z_dim
        self.mu = nn.Linear(in_features=self.L, out_features=self.z_dim)
        self.sigma = nn.Linear(in_features=self.L, out_features=self.z_dim)
        self.h = nn.Linear(in_features=self.L, out_features=self.z_dim)
    def forward(self, x):
        return self.mu(x), nn.functional.softplus(self.sigma(x)), self.h(x)
    
# 'Decoder' - simple matrix, return logits
class Decoder(nn.Module):
    def __init__(self, L, z_dim, device='cpu'):
        super(Decoder, self).__init__()
        self.L = L
        self.z_dim = z_dim
        self.W = nn.Linear(in_features=self.z_dim, out_features=self.L, bias=False)
    def forward(self, z):
        return [self.W(z)]

In [4]:
L = 2
z_dim = 2
N = 10000
device = "cuda:0" if torch.cuda.is_available() else "cpu"

args = dotdict({})
args.K = 1
args.N = 1
args.z_dim = z_dim
args.torchType = torch.float32
args.device = device
args.learnable_reverse = True
args.num_epoches = 500
args.train_batch_size = 100
args.amortize = False
args.gamma = 0.1 ## Stepsize
args.alpha = 0.5  ## For partial momentum refresh
args.train_only_inference_period = 10
args.train_only_inference_cutoff = 5
args.hoffman_idea = True
args.separate_params = True

In [5]:
enc = Encoder(L=L, z_dim=z_dim, device=device).to(device)
dec = Decoder(L=L, z_dim=z_dim, device=device).to(device)
reverse_kernel = Reverse_kernel(args)
target = NN_bernoulli({}, dec, device).to(device)
transitions = nn.ModuleList([HMC_our(kwargs=args).to(args.device) for _ in range(args['K'])])

std_normal = torch.distributions.Normal(loc=torch.tensor(0., device=device),
                                                scale=torch.tensor(1., device=device))
args.std_normal = std_normal

In [6]:
true_theta = std_normal.sample((z_dim, L))
print('True decoder matrix')
print(true_theta)
print('-' * 75)
data_logits = std_normal.sample((N, z_dim)) @ true_theta
data = torch.distributions.Bernoulli(logits=data_logits).sample()
print('Generated data example:')
print(data[:10])

True decoder matrix
tensor([[ 0.1476,  0.7996],
        [-0.9377, -1.8887]])
---------------------------------------------------------------------------
Generated data example:
tensor([[0., 1.],
        [0., 1.],
        [0., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 1.],
        [1., 0.],
        [1., 1.],
        [1., 1.]])


In [7]:
dataloader = torch.utils.data.DataLoader(data, batch_size=args.train_batch_size, shuffle=True)

params = list(enc.parameters()) + list(reverse_kernel.parameters())
optimizer = torch.optim.Adam(params=target.parameters())
optimizer_inference = torch.optim.Adam(params=params)

In [8]:
def compute_loss(z_new, p_new, u, p_old, x, sum_log_alpha, sum_log_jac, sum_log_sigma, mu=None, all_directions=None, h=None):
    if args.learnable_reverse:
        log_r = reverse_kernel(z_fin=z_new.detach(), mu=mu.detach(), a=all_directions)
        log_m = args.std_normal.log_prob(u).sum(1) + args.std_normal.log_prob(p_old).sum(1) - sum_log_jac - sum_log_sigma + sum_log_alpha
    else:
        log_r = 0 #-args.K * torch_log_2
        log_m = args.std_normal.log_prob(u).sum(1) + args.std_normal.log_prob(p_old).sum(1) - sum_log_jac - sum_log_sigma # + sum_log_alpha
        
    log_p = target.get_logdensity(z=z_new, x=x, prior=get_prior, args=args) + args.std_normal.log_prob(p_new.sum(1))
    elbo_full = log_p + log_r - log_m
    grad_elbo = torch.mean(elbo_full + elbo_full.detach() * sum_log_alpha)
    return elbo_full, grad_elbo 

In [9]:
best_elbo = -float("inf")
current_tolerance = 0
print_info_ = 10
# with torch.autograd.detect_anomaly():
for ep in tqdm(range(args.num_epoches)): # cycle over epoches
    for b_num, batch_train in enumerate(dataloader): # cycle over batches
        plt.close()        
        pdb.set_trace()

        mu, sigma, h = enc(batch_train) # sample mu and sigma from encoder
        u = args.std_normal.sample(mu.shape) # sample random tensor for reparametrization trick
        z = mu + sigma * u # reperametrization trick
        sum_log_sigma = torch.sum(torch.log(sigma), 1)

        p_old = args.std_normal.sample(z.shape)
        cond_vectors = [args.std_normal.sample(p_old.shape) for _ in range(args.K)]

        sum_log_alpha = torch.zeros(mu.shape[0], dtype=args.torchType, device=args.device) # for grad log alpha accumulation
        sum_log_jacobian = torch.zeros(mu.shape[0], dtype=args.torchType, device=args.device) # for log_jacobian accumulation
        p = p_old
        if args.learnable_reverse:
            all_directions = torch.tensor([], device=args.device)
        else:
            all_directions = None
        for k in range(args.K):
            # sample alpha - transition probabilities 
            if args.amortize:
#                     pdb.set_trace()
                z, p, log_jac, current_log_alphas, directions, _ = transitions.make_transition(q_old=z, x=batch_train,
                                                    p_old=p, k=cond_vectors[k], target_distr=target, args=args)
            else:
                z, p, log_jac, current_log_alphas, directions, _ = transitions[k].make_transition(q_old=z, x=batch_train,
                                                                    p_old=p, k=cond_vectors[k], target_distr=target, args=args) # sample a_i -- directions
            if ep  % print_info_ == 0 and b_num % (100 * print_info_) == 0:
                print('On batch number {}/{} and on k = {} we have for  0: {} and for +1: {}'.format(b_num + 1,
                                                                        data.shape[0] // args['train_batch_size'],
                                                                           k + 1,
                                                    (directions==0.).to(float).mean(),
                                                                    (directions==1.).to(float).mean()))
                if args.amortize:
                    print('Stepsize {}'.format(np.exp(transitions.gamma.cpu().detach().item())))
                    print('Autoregression coeff {}'.format(torch.sigmoid(transitions.alpha_logit).cpu().detach().item()))
            if args.learnable_reverse:
                all_directions = torch.cat([all_directions, directions.view(-1, 1)], dim=1)
            # Accumulate alphas
            sum_log_alpha = sum_log_alpha + current_log_alphas
            sum_log_jacobian = sum_log_jacobian + log_jac  # refresh log jacobian
        ##############################################
        if args.hoffman_idea:
            if args.learnable_reverse:
                log_r = reverse_kernel(z_fin=z.detach(), h=h.detach(), a=all_directions)
                log_m = args.std_normal.log_prob(u).sum(1) + args.std_normal.log_prob(p_old).sum(1) - sum_log_jacobian - sum_log_sigma + sum_log_alpha
            else:
                log_r = 0 #-args.K * torch_log_2
                log_m = args.std_normal.log_prob(u).sum(1) + args.std_normal.log_prob(p_old).sum(1) - sum_log_jacobian - sum_log_sigma # + sum_log_alpha
            log_p = target.get_logdensity(z=z, x=batch_train, args=args) + args.std_normal.log_prob(p.sum(1))
            elbo_full = log_p + log_r - log_m
    #                 pdb.set_trace()
            ### Gradient of the first objective:
            target.eval()
            obj_1 = torch.mean(elbo_full + elbo_full.detach() * sum_log_alpha)
            (-obj_1).backward(retain_graph=True)
            optimizer_inference.step()
            optimizer_inference.zero_grad()
            optimizer.zero_grad() 

            ### Gradient of the second objective:
            target.train()
            log_p = target.get_logdensity(z=z.detach(), x=batch_train, args=args) + args.std_normal.log_prob(p.detach()).sum(1)
            elbo_full = log_p # - log_m
            obj_2 = torch.mean(elbo_full + elbo_full.detach() * sum_log_alpha)
            (-obj_2).backward()
            optimizer.step()
            optimizer_inference.zero_grad()
            optimizer.zero_grad()
            ###########################################################
        else:
            elbo_full, grad_elbo = compute_loss(z_new=z, p_new=p, u=u, p_old=p_old, x=batch_train, sum_log_alpha=sum_log_alpha,
                                                sum_log_jac=sum_log_jacobian, sum_log_sigma=sum_log_sigma, mu=mu,
                                                all_directions=all_directions, args=args, h=h)
            (-grad_elbo).backward()

                
        if args.separate_params: # if we separate params of inference part and generation part
            optimizer_inference.step() # we always perform step for inference part
            if ep % args.train_only_inference_period > args.train_only_inference_cutoff: # but sometimes for gen
                optimizer.step()
            optimizer.zero_grad()
            optimizer_inference.zero_grad()
        else:
            optimizer.step()
            optimizer.zero_grad()
            
        if ep  % print_info_ == 0 and b_num % (100 * print_info_) == 0:
            print('obj_1:', obj_1)
            print('obj_2:', obj_2)

  0%|          | 0/500 [00:00<?, ?it/s]

[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m10[00m)<module>()
-> mu, sigma, h = enc(batch_train) # sample mu and sigma from encoder


(Pdb++)  n


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m11[00m)<module>()
-> u = args.std_normal.sample(mu.shape) # sample random tensor for reparametrization trick


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m12[00m)<module>()
-> z = mu + sigma * u # reperametrization trick


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m13[00m)<module>()
-> sum_log_sigma = torch.sum(torch.log(sigma), 1)


(Pdb++)  print(h)


tensor([[ 0.1809, -0.0150],
        [ 0.0026,  0.2459],
        [ 0.0026,  0.2459],
        [ 0.1809, -0.0150],
        [ 0.6491, -0.2483],
        [ 0.1809, -0.0150],
        [ 0.0026,  0.2459],
        [ 0.4707,  0.0126],
        [ 0.4707,  0.0126],
        [ 0.6491, -0.2483],
        [ 0.1809, -0.0150],
        [ 0.4707,  0.0126],
        [ 0.6491, -0.2483],
        [ 0.6491, -0.2483],
        [ 0.0026,  0.2459],
        [ 0.1809, -0.0150],
        [ 0.6491, -0.2483],
        [ 0.0026,  0.2459],
        [ 0.4707,  0.0126],
        [ 0.0026,  0.2459],
        [ 0.0026,  0.2459],
        [ 0.6491, -0.2483],
        [ 0.4707,  0.0126],
        [ 0.1809, -0.0150],
        [ 0.0026,  0.2459],
        [ 0.4707,  0.0126],
        [ 0.6491, -0.2483],
        [ 0.1809, -0.0150],
        [ 0.6491, -0.2483],
        [ 0.4707,  0.0126],
        [ 0.6491, -0.2483],
        [ 0.1809, -0.0150],
        [ 0.6491, -0.2483],
        [ 0.0026,  0.2459],
        [ 0.0026,  0.2459],
        [ 0.0026,  0

(Pdb++)  n


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m15[00m)<module>()
-> p_old = args.std_normal.sample(z.shape)


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m16[00m)<module>()
-> cond_vectors = [args.std_normal.sample(p_old.shape) for _ in range(args.K)]


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m18[00m)<module>()
-> sum_log_alpha = torch.zeros(mu.shape[0], dtype=args.torchType, device=args.device) # for grad log alpha accumulation


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m19[00m)<module>()
-> sum_log_jacobian = torch.zeros(mu.shape[0], dtype=args.torchType, device=args.device) # for log_jacobian accumulation


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m20[00m)<module>()
-> p = p_old


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m21[00m)<module>()
-> if args.learnable_reverse:


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m22[00m)<module>()
-> all_directions = torch.tensor([], device=args.device)


(Pdb++)  l


 [34;01m17[39;49;00m  	
 [34;01m18[39;49;00m  	        sum_log_alpha = torch.zeros(mu.shape[[34;01m0[39;49;00m], dtype=args.torchType, device=args.device) [30;01m# for grad log alpha accumulation[39;49;00m
 [34;01m19[39;49;00m  	        sum_log_jacobian = torch.zeros(mu.shape[[34;01m0[39;49;00m], dtype=args.torchType, device=args.device) [30;01m# for log_jacobian accumulation[39;49;00m
 [34;01m20[39;49;00m  	        p = p_old
 [34;01m21[39;49;00m  	        [34;01mif[39;49;00m args.learnable_reverse:
 [34;01m22[39;49;00m  ->	            all_directions = torch.tensor([], device=args.device)
 [34;01m23[39;49;00m  	        [34;01melse[39;49;00m:
 [34;01m24[39;49;00m  	            all_directions = [36;01mNone[39;49;00m
 [34;01m25[39;49;00m  	        [34;01mfor[39;49;00m k [35;01min[39;49;00m [36;01mrange[39;49;00m(args.K):
 [34;01m26[39;49;00m  	            [30;01m# sample alpha - transition probabilities[39;49;00m
 [34;01m27[39;49;00m  	         

(Pdb++)  n


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m25[00m)<module>()
-> for k in range(args.K):


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m27[00m)<module>()
-> if args.amortize:


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m32[00m)<module>()
-> z, p, log_jac, current_log_alphas, directions, _ = transitions[k].make_transition(q_old=z, x=batch_train,


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m33[00m)<module>()
-> p_old=p, k=cond_vectors[k], target_distr=target, args=args) # sample a_i -- directions


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m34[00m)<module>()
-> if ep  % print_info_ == 0 and b_num % (100 * print_info_) == 0:


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m35[00m)<module>()
-> print('On batch number {}/{} and on k = {} we have for  0: {} and for +1: {}'.format(b_num + 1,


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m36[00m)<module>()
-> data.shape[0] // args['train_batch_size'],


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m37[00m)<module>()
-> k + 1,


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m38[00m)<module>()
-> (directions==0.).to(float).mean(),


(Pdb++)  l


 [34;01m28[39;49;00m  	[30;01m#                     pdb.set_trace()[39;49;00m
 [34;01m29[39;49;00m  	                z, p, log_jac, current_log_alphas, directions, _ = transitions.make_transition(q_old=z, x=batch_train,
 [34;01m30[39;49;00m  	                                                    p_old=p, k=cond_vectors[k], target_distr=target, args=args)
 [34;01m31[39;49;00m  	            [34;01melse[39;49;00m:
 [34;01m32[39;49;00m  	                z, p, log_jac, current_log_alphas, directions, _ = transitions[k].make_transition(q_old=z, x=batch_train,
 [34;01m33[39;49;00m  	                                                                    p_old=p, k=cond_vectors[k], target_distr=target, args=args) [30;01m# sample a_i -- directions[39;49;00m
 [34;01m34[39;49;00m  	            [34;01mif[39;49;00m ep  % print_info_ == [34;01m0[39;49;00m [35;01mand[39;49;00m b_num % ([34;01m100[39;49;00m * print_info_) == [34;01m0[39;49;00m:
 [34;01m35[39;49;00m  	        

(Pdb++)  n


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m39[00m)<module>()
-> (directions==1.).to(float).mean()))


(Pdb++)  


On batch number 1/100 and on k = 1 we have for  0: 0.0 and for +1: 1.0
[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m40[00m)<module>()
-> if args.amortize:


(Pdb++)  


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m43[00m)<module>()
-> if args.learnable_reverse:


(Pdb++)  n


[27] > [33;01m<ipython-input-9-8352655101b0>[00m([36;01m44[00m)<module>()
-> all_directions = torch.cat([all_directions, directions.view(-1, 1)], dim=1)


(Pdb++)  q


BdbQuit: 

In [9]:
target.decoder.W.weight.T

tensor([[-0.1764,  0.1606],
        [ 0.8860, -0.6262]], grad_fn=<PermuteBackward>)

In [10]:
target.decoder.W.bias

In [11]:
print(true_theta)

tensor([[ 0.1634, -1.1773],
        [ 0.3227, -1.2502]])
