https://github.com/masa-su/pixyz/blob/main/tutorial/Japanese/00-PixyzOverview.ipynb

## 確率分布

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
from pixyz.distributions import Normal, Bernoulli
from pixyz.utils import print_latex

z_dim = 64

#prior
#N(z; 0, 1)
prior = Normal(loc=torch.tensor(0.),
               scale=torch.tensor(1.),
               var=["z"],
               features_shape=[z_dim],
               name="p_{prior}").to(device)
print(prior)
print_latex(prior)

Distribution:
  p_{prior}(z)
Network architecture:
  Normal(
    name=p_{prior}, distribution_name=Normal,
    var=['z'], cond_var=[], input_var=[], features_shape=torch.Size([64])
    (loc): torch.Size([1, 64])
    (scale): torch.Size([1, 64])
  )


<IPython.core.display.Math object>

In [4]:
x_dim = 784

#generative model p(x|z)
class Generator(Bernoulli):
    def __init__(self, x_dim, z_dim):
        super(Generator, self).__init__(var=["x"], cond_var=["z"], name="p")
        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)
    
    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}

p = Generator(x_dim, z_dim).to(device)
print(p)
print_latex(p)

Distribution:
  p(x|z)
Network architecture:
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )


<IPython.core.display.Math object>

In [5]:
#interface model p(z|x)
class Interface(Normal):
    def __init__(self, x_dim, z_dim):
        super(Interface, self).__init__(var=["z"], cond_var=["x"], name="q")
        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3_1 = nn.Linear(512, z_dim)
        self.fc3_2 = nn.Linear(512, z_dim)
    
    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc3_1(h), "scale": F.softplus(self.fc3_2(h))}

q = Interface(x_dim, z_dim).to(device)
print(q)
print_latex(q)

Distribution:
  q(z|x)
Network architecture:
  Interface(
    name=q, distribution_name=Normal,
    var=['z'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3_1): Linear(in_features=512, out_features=64, bias=True)
    (fc3_2): Linear(in_features=512, out_features=64, bias=True)
  )


<IPython.core.display.Math object>

In [6]:
#サンプリング
#z ~ p(z)
prior_samples = prior.sample(batch_n=1)
print(prior_samples)
print(prior_samples.keys())
print(prior_samples["z"].shape)

{'z': tensor([[-0.8692,  0.3627,  1.6204, -0.2100, -2.3270, -1.7125,  0.9110, -1.4538,
         -0.3629,  1.4393,  2.2257,  0.4443,  0.2845,  1.1481,  0.8815, -0.5648,
         -0.9252, -0.4958, -0.6896,  2.4310,  0.6378, -1.0607,  1.7778, -0.1394,
         -0.3230,  0.9172,  1.7727,  1.2028, -0.3246,  0.1105, -0.2817,  1.6480,
         -0.4157,  0.1428, -0.6948,  1.1176,  0.6853,  0.5598,  1.3857,  1.8068,
          0.4076, -0.4483,  0.1025, -0.8067,  1.6540,  0.1772, -0.0590,  1.1351,
          0.0423,  2.1853,  0.1270,  0.3917,  1.3915,  1.9126,  0.5511, -0.5458,
         -2.0836,  1.5341, -0.6056,  0.1172,  0.4461, -0.2969, -1.5324, -0.4157]],
       device='cuda:0')}
dict_keys(['z'])
torch.Size([1, 64])


In [7]:
#同時分布の定義
#p_(x,z) = p(x|z)p(z)
p_joint = p * prior
print(p_joint)
print_latex(p_joint)

Distribution:
  p(x,z) = p(x|z)p_{prior}(z)
Network architecture:
  p_{prior}(z):
  Normal(
    name=p_{prior}, distribution_name=Normal,
    var=['z'], cond_var=[], input_var=[], features_shape=torch.Size([64])
    (loc): torch.Size([1, 64])
    (scale): torch.Size([1, 64])
  )
  p(x|z):
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )


<IPython.core.display.Math object>

In [8]:
#同時分布からサンプリング
#x,z ~ p(x,z)
p_joint_samples = p_joint.sample(batch_n=1)
print(p_joint_samples)
print(p_joint_samples.keys())
print(p_joint_samples["x"].shape)
print(p_joint_samples["z"].shape)

{'z': tensor([[-1.0338, -1.6246, -1.2903, -0.2704,  1.8970,  0.8257, -0.6334,  0.5154,
          0.7881,  2.0331, -1.1795, -0.6127,  1.2353,  1.7585, -0.6222,  0.6799,
          1.0940, -0.0681, -1.5002, -1.0008, -0.5492,  0.1305,  1.9852, -0.0512,
          1.5764, -0.0717,  0.1816,  1.7187,  2.0722, -0.7183,  0.5219,  0.8634,
          0.4418,  0.5128, -0.9097,  0.0071, -0.6311,  0.4591,  0.8982,  0.5703,
         -1.2073,  0.9992,  0.4002, -0.4339,  0.4324,  0.4580,  0.8162, -0.1648,
         -0.6515,  0.3465, -0.3718, -1.0434, -0.6909, -0.3027, -1.0341,  0.4750,
         -1.7513, -0.5857, -0.5222, -0.1897, -0.0980, -0.8155,  1.6731,  0.3113]],
       device='cuda:0'), 'x': tensor([[0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0.,
         1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 0.,
         0., 1., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.,
         0., 0., 1., 1., 1., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1.

## 目的関数の設計

In [9]:
from pixyz.losses import KullbackLeibler

kl = KullbackLeibler(q, prior)
print_latex(kl)

<IPython.core.display.Math object>

In [10]:
reconst = -p.log_prob().expectation(q)
print_latex(reconst)

<IPython.core.display.Math object>

In [11]:
vae_loss = (kl + reconst).mean()
print_latex(vae_loss)

<IPython.core.display.Math object>

In [12]:
dummy_x = torch.randn([4, 784]).to(device)
vae_loss.eval({"x": dummy_x})

tensor(552.0492, device='cuda:0', grad_fn=<MeanBackward0>)

## モデル，訓練

In [13]:
from pixyz.models import Model
model = Model(loss=vae_loss, distributions=[p, q],
             optimizer=optim.AdamW, optimizer_params={"lr": 1e-3})
print(model)
print_latex(model)

Distributions (for training):
  p(x|z), q(z|x)
Loss function:
  mean \left(D_{KL} \left[q(z|x)||p_{prior}(z) \right] - \mathbb{E}_{q(z|x)} \left[\log p(x|z) \right] \right)
Optimizer:
  AdamW (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      capturable: False
      eps: 1e-08
      foreach: None
      lr: 0.001
      maximize: False
      weight_decay: 0.01
  )


<IPython.core.display.Math object>

In [14]:
dummy_x = torch.randn([10, 784])
def train_dummy(epoch, x):
    x = x.to(device)
    loss = model.train({"x": x})
    print('Epoch: {} Train Loss: {:4f}'.format(epoch, loss))

In [34]:
for epoch in range(10):
    train_dummy(epoch, dummy_x)

Epoch: 0 Train Loss: -4642.977539
Epoch: 1 Train Loss: -4646.271973
Epoch: 2 Train Loss: -4649.668945
Epoch: 3 Train Loss: -4636.995117
Epoch: 4 Train Loss: -4655.151855
Epoch: 5 Train Loss: -4656.193848
Epoch: 6 Train Loss: -4656.672852
Epoch: 7 Train Loss: -4656.166992
Epoch: 8 Train Loss: -4660.292480
Epoch: 9 Train Loss: -4662.324219
