In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import pdb
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, './src/')
from target import NN_bernoulli
%matplotlib inline

In [2]:
class dotdict(dict):
    """A dict whose keys can also be read, written and deleted as attributes.

    Attribute lookups that miss normal attribute resolution fall back to
    ``dict.get``, so a missing key yields ``None`` rather than raising.
    Deleting a missing key via ``del obj.key`` raises ``KeyError``.
    """

    def __getattr__(self, key):
        # Only invoked when regular attribute lookup fails.
        return dict.get(self, key)

    def __setattr__(self, key, value):
        # Every attribute assignment is stored as a dictionary item.
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)

In [3]:
# 'Encoder' - simple matrix
class Encoder(nn.Module):
    """Linear stochastic encoder: an affine map of the input plus N(0, 1) noise.

    Args:
        L: size of the last dimension of the input (``in_features`` of the linear map).
        z_dim: size of the last dimension of the output (``out_features``).
        device: device on which the loc/scale tensors of the noise
            distribution are created.

    Attributes:
        W: the ``nn.Linear(L, z_dim)`` layer.
        noise: a standard ``torch.distributions.Normal`` living on ``device``.
    """
    def __init__(self, L, z_dim, device='cpu'):
        super(Encoder, self).__init__()
        self.L = L
        self.z_dim = z_dim
        self.W = nn.Linear(in_features=self.L, out_features=self.z_dim)
        # NOTE: the distribution is a plain attribute, not a registered buffer,
        # so its tensors stay on the `device` given here; pass the same device
        # that the module is later moved to.
        self.noise = torch.distributions.Normal(loc=torch.tensor(0., device=device),
                                                scale=torch.tensor(1., device=device))
    def forward(self, x):
        """Return ``W(x) + eps`` with ``eps ~ N(0, 1)`` drawn independently per output element.

        Args:
            x: tensor whose last dimension has size ``L``.

        Returns:
            Tensor with the same shape as ``W(x)``.
        """
        mean = self.W(x)
        # Sample noise with the shape of the linear output. The previous fixed
        # (L, z_dim) shape only broadcast when the batch size was 1 or L.
        return mean + self.noise.sample(mean.shape)
    
# 'Decoder' - simple matrix, return logits
class Decoder(nn.Module):
    """Linear decoder mapping a ``z_dim``-sized input to ``L`` outputs.

    Args:
        L: size of the output's last dimension (``out_features``).
        z_dim: size of the input's last dimension (``in_features``).
        device: accepted for signature symmetry with ``Encoder``; not used here.
    """
    def __init__(self, L, z_dim, device='cpu'):
        super().__init__()
        self.L, self.z_dim = L, z_dim
        self.W = nn.Linear(z_dim, L)

    def forward(self, z):
        """Apply the linear layer to ``z`` and return the result wrapped in a one-element list."""
        logits = self.W(z)
        return [logits]

In [4]:
# Sizes: L is the encoder input / decoder output width, z_dim the latent width,
# and N the number of rows drawn when the synthetic data is generated.
L, z_dim, N = 10, 2, 1000

# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Encoder is built before the decoder, so parameter initialisation consumes
# the RNG in the same order on every run.
enc = Encoder(L=L, z_dim=z_dim, device=device).to(device)
dec = Decoder(L=L, z_dim=z_dim, device=device).to(device)

# NN_bernoulli comes from ./src/target; the empty dict is presumably its
# keyword/config argument -- TODO confirm against that module.
target = NN_bernoulli({}, dec, device)

# Reuse the encoder's standard-normal distribution as a general N(0, 1) sampler.
std_normal = enc.noise

In [5]:
# Ground-truth decoder matrix of shape (z_dim, L), entries drawn from N(0, 1).
true_theta = std_normal.sample((z_dim, L))

# Draw N latent rows of width z_dim, project them through true_theta to get
# logits, then sample binary observations. The draw order
# (theta -> latents -> Bernoulli) is kept so the random stream is consumed
# exactly as before.
data_logits = torch.matmul(std_normal.sample((N, z_dim)), true_theta)
data = torch.distributions.Bernoulli(logits=data_logits).sample()

print('True decoder matrix', true_theta, '-' * 75, sep='\n')
print('Generated data example:', data[:10], sep='\n')

True decoder matrix
tensor([[ 0.1479, -1.3777, -0.7659, -0.0702, -0.2629,  1.0342,  0.2221,  0.5664,
         -0.0814,  0.6970],
        [ 0.5635, -0.8596,  2.3421,  0.8749,  0.0081, -1.2952,  1.9679,  0.5302,
          0.0713, -0.2374]], device='cuda:0')
---------------------------------------------------------------------------
Generated data example:
tensor([[1., 0., 1., 1., 0., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 1., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 1., 0., 1., 1., 1.],
        [1., 0., 1., 1., 1., 0., 1., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 1.],
        [1., 0., 1., 1., 0., 0., 0., 1., 1., 0.],
        [1., 1., 0., 0., 1., 1., 0., 0., 1., 1.]], device='cuda:0')
