In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules (e.g. the `cebm` package used below) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn

### Data

In [3]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset


# Only ToTensor is applied (torchvision scales pixels to [0, 1]); no normalization,
# which keeps the data in the same range as the Sigmoid-terminated decoder built below.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST is downloaded into this relative directory on first use.
data_path = '../Datasets/mnist/'
dataset = datasets.MNIST(data_path, train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(data_path, train=False, transform=transform, download=True)

# shuffle=True: DataLoader defaults to sequential order, which would present the
# training set in the exact same order on every pass.
train_loader = DataLoader(dataset, batch_size=100, shuffle=True, num_workers=4)

### Model

In [13]:
from cebm.network.network import _cnn, build_decoder

# Architecture of the energy network `p`: four conv layers followed by one hidden
# fully connected layer, all using the same activation.
channels = [32,32,64,64]
num_neurons = [128]
kernels = [3,4,4,4]
strides = [1,2,2,2]
paddings = [3,1,1,1]
activation = 'SiLU'

# `_cnn` returns two values; only the first (the network body) is used here.
# NOTE(review): the leading (28, 28, 1) arguments presumably describe the MNIST input
# as height, width, channels -- confirm against cebm.network.network._cnn.
p_body, _ = _cnn(28, 28, 1, channels=channels, num_neurons=num_neurons, kernels=kernels, strides=strides, paddings=paddings, activation=activation)

# Scalar head on top of the body. Its input width follows the last hidden layer so
# that changing `num_neurons` does not silently break the model. `p[0]` is still the
# body, which the decoder construction below relies on.
p = nn.Sequential(p_body, nn.Linear(num_neurons[-1], 1))

In [14]:
# Display the module tree of the energy network for inspection.
p

Sequential(
  (0): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3))
      (1): SiLU(inplace=True)
      (2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): SiLU(inplace=True)
      (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): SiLU(inplace=True)
      (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=1024, out_features=128, bias=True)
      (2): SiLU(inplace=True)
    )
  )
  (1): Linear(in_features=128, out_features=1, bias=True)
)

In [15]:
# Sanity check: push a random batch of 10 single-channel 28x28 inputs through `p`
# and show the output shape (one scalar per input is expected).
p(torch.rand(10,1,28,28)).shape

torch.Size([10, 1])

In [16]:
# Generator `q`: take the layers `build_decoder` derives from the body of `p`
# (p[0]), skip the first one, and finish with a Sigmoid so outputs lie in (0, 1).
# NOTE(review): why the first derived layer is discarded is not visible here --
# confirm against build_decoder.
q = nn.Sequential(*build_decoder(p[0])[1:], nn.Sigmoid())

In [17]:
# Display the module tree of the generator/decoder for inspection.
q

Sequential(
  (0): Linear(in_features=128, out_features=1024, bias=True)
  (1): SiLU(inplace=True)
  (2): Reshape()
  (3): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (4): SiLU(inplace=True)
  (5): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (6): SiLU(inplace=True)
  (7): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (8): SiLU(inplace=True)
  (9): ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3))
  (10): Sigmoid()
)

In [18]:
# Width of the latent space. It has to equal the decoder's input size: the first
# layer of `q` is Linear(in_features=128, ...).
noise_dim = 128
# Initial standard deviation of the posterior xi(z|z0); its log is what gets stored.
# NOTE(review): this value was not defined anywhere in the notebook (fresh kernels hit
# a NameError); 0.1 is a placeholder choice -- confirm the intended value.
init_post_logsigma = 0.1

# Log standard deviation of the observation model q(x|z); starts at log(0.01).
logsigma_qx = nn.Parameter((torch.ones(1, ) * .01).log())
# Per-dimension log standard deviation of the posterior xi(z|z0).
logsigma_xiz = nn.Parameter((torch.ones(noise_dim,) * init_post_logsigma).log())

### Optimizers

In [None]:
# Adam over the posterior log-scale only (step size 1e-3).
optim_logsigma_xiz = torch.optim.Adam(params=[logsigma_xiz], lr=0.001)
# Adam over all parameters of the energy network `p` (step size 1e-4).
optim_p = torch.optim.Adam(params=p.parameters(), lr=0.0001)

### Objective

In [None]:
import torch.distributions as D

def epoch(x, S=20, gamma=0.1):
    """Run one update of the posterior scale and of the energy network on a batch.

    Despite its name, this processes a single batch. It relies on the notebook
    globals `q` (decoder), `p` (energy network), `logsigma_qx`, `logsigma_xiz`,
    `optim_logsigma_xiz` and `optim_p`.

    Steps:
      1. Draw z0 ~ N(0, I) and a generated batch x_q_z0 = q(z0).
      2. Update `logsigma_xiz` by maximizing an ELBO of the generated batch under
         the posterior xi(z|z0) = N(z0, exp(logsigma_xiz)).
      3. Update `p` with  E[p(x)] - E[p(x_q_z0)] + gamma * ||grad_x p(x)||^2.
      4. Form a self-normalized importance-weighted estimate built from `S`
         posterior samples per generated image (currently computed but not used).

    Args:
        x: batch of real images; assumed shape (B, 1, 28, 28) to match `p` and `q`.
        S: number of posterior samples per generated image in step 4.
        gamma: weight of the squared gradient-norm penalty in the energy loss.

    Returns:
        None. The function only has side effects (optimizer steps).
    """
    B = x.shape[0]
    latent_dim = 128  # input width of the decoder `q`

    # sample z0 ~ q(z) and x ~ q(x|z0).
    # The generated batch is used only as a fixed sample below (no generator update
    # happens here), so it is built without a graph. Otherwise the second backward
    # pass would try to re-traverse the decoder graph freed by the first one.
    z0 = torch.randn(B, latent_dim)
    with torch.no_grad():
        x_q_z0 = q(z0)

    # prior over z, shared by steps 2 and 4
    q_z = D.Normal(0, 1)

    # ---- posterior scale update -------------------------------------------
    # xi(z|z0)
    xi_z_z0 = D.Normal(z0, logsigma_xiz.exp())

    # sample z ~ xi(z|z0)
    z = xi_z_z0.rsample()  # (B, latent_dim)

    # decoder means and observation model q(x|z)
    x_q_z = q(z)
    q_x_z = D.Normal(x_q_z, logsigma_qx.exp())

    # Each term is reduced over its own event dimensions before being added: the
    # pixel term is (B, 1, 28, 28) while the latent terms are (B, latent_dim).
    elbo = (
        q_x_z.log_prob(x_q_z0).flatten(start_dim=1).sum(-1)
        + q_z.log_prob(z).sum(-1)
        + xi_z_z0.entropy().sum(-1)
    )  # (B,)
    logsigma_loss = -elbo.mean()

    # NOTE(review): backward() also accumulates gradients on the parameters of `q`
    # and on `logsigma_qx`; no optimizer in this notebook consumes or zeroes them.
    optim_logsigma_xiz.zero_grad()
    logsigma_loss.backward()
    optim_logsigma_xiz.step()

    # ---- energy network update --------------------------------------------
    # The gradient penalty differentiates w.r.t. the input, so the data batch
    # must itself require grad.
    x_data = x.detach().requires_grad_(True)
    energy_x_data = p(x_data).squeeze(-1)  # (B,)
    energy_x_q = p(x_q_z0).squeeze(-1)     # (B,)
    grad_energy_xd = torch.autograd.grad(
        energy_x_data.sum(), x_data, create_graph=True
    )[0].flatten(start_dim=1).norm(2, 1)   # (B,)

    # Reduce to a scalar so that backward() is well defined.
    p_loss = (energy_x_data - energy_x_q + gamma * grad_energy_xd.pow(2)).mean()

    optim_p.zero_grad()
    p_loss.backward()
    optim_p.step()

    # ---- importance-weighted estimate -------------------------------------
    # Nothing is back-propagated through this part, so it runs without autograd.
    with torch.no_grad():
        sigma_qx = logsigma_qx.exp()
        # Rebuild the posterior so it reflects the scale after the update above.
        xi_z_z0 = D.Normal(z0, logsigma_xiz.exp())

        # sample z ~ xi(z|z0)
        z = xi_z_z0.rsample(sample_shape=(S,))  # (S, B, latent_dim)

        # The decoder expects a single batch axis: fold S into it, then unfold.
        x_q_z = q(z.flatten(0, 1)).reshape(S, B, *x_q_z0.shape[1:])  # (S, B, 1, 28, 28)
        q_x_z = D.Normal(x_q_z, sigma_qx)

        log_w = (
            q_x_z.log_prob(x_q_z0).flatten(start_dim=2).sum(-1)
            + q_z.log_prob(z).sum(-1)
            - xi_z_z0.log_prob(z).sum(-1)
        )  # (S, B)
        log_w = log_w - log_w.logsumexp(0)  # normalize over the S samples

        # NOTE(review): the analytic score of N(x; mu, sigma) w.r.t. x is
        # (mu - x) / sigma^2; this keeps the original (x - mu) / sigma^2 sign --
        # confirm which sign the (not yet written) consumer expects.
        grad_x_log_qx = (x_q_z0[None] - x_q_z) / (sigma_qx ** 2)  # (S, B, 1, 28, 28)
        # Give the weights one trailing singleton axis per image dimension.
        w = log_w.exp().reshape(S, B, *([1] * (x_q_z0.dim() - 1)))
        grad_x_log_qx = (w * grad_x_log_qx).sum(0)  # (B, 1, 28, 28)
    
    
    
    
    
    
    
    

In [1]:
import torch
