# 3. 深層生成モデルの実装例：VAE

In [12]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms

from tqdm import tqdm

batch_size = 128
epochs = 100
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## 3.1 Tars.modelsを使わない方法

ここでは，2で説明した確率分布クラスの記述方法を使って， 深層生成モデルの一つであるVAEを実装してみましょう．

VAEでは， 次の生成モデルを考えます．

\begin{equation*}
p_{\theta}(x,z)=p_{\theta}(x|z)p(z)
\end{equation*}

ただし，$x$は観測変数，$z$は潜在変数，$\theta$はパラメータです．

VAEでは，$p_{\theta}(x|z)$は深層ニューラルネットワークで定義されるので，$\theta$はニューラルネットワークのパラメータになります．

VAEの目標は，周辺尤度$p(x)=\int p_\theta(x|z)p(z)dz$を最大化するように生成モデルを学習することです． ただし，これは直接最大化することができません．

そこで，代わりに次のような変分下界$\mathcal{L}(x)$を最大化することで，モデルを学習します．

\begin{equation*}
\log p(x) \geq E_{q_\phi(z|x)}[\log \frac{p_\theta(x,z)}{q_\phi(z|x)}] = \mathcal{L}(x)
\end{equation*}

ただし，$q_\phi(z|x)$は事後分布$p(z|x)$の近似分布で，$\phi$はそのパラメータです．VAEではこの分布も深層ニューラルネットワークでモデル化します（amortized inference）．

それぞれの分布について，$q_\phi(ｚ|x)$は$x$を$z$にエンコードするので**エンコーダ**，$p_\theta(x|z)$は$z$から$x$にデコードするので**デコーダ**と呼ばれます． このことから，このモデルは**variational autoencoders（VAE）**と呼ばれます．

ではこれに従って，Tarsの分布クラスで，各分布を設定しましょう．

In [13]:
from Tars.distributions import NormalModel, BernoulliModel

x_dim = 784
z_dim = 64

# inference model q(z|x)
class Inference(NormalModel):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x"], var=["z"])

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), F.softplus(self.fc32(h))

    
# generative model p(x|z)    
class Generator(BernoulliModel):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"])

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return F.sigmoid(self.fc3(h))
    
q = Inference()    
_p = Generator()
    
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = NormalModel(loc=loc, scale=scale, var=["z"], dim=z_dim)

ここでは，エンコーダにガウス分布，デコーダにベルヌーイ分布を用いています.

次に，生成モデルの同時分布を設定します．

In [14]:
p = _p * prior

print(p.prob_text)
print(p.prob_factorized_text)

p(x,z)
p(x|z)p(z)


今回は，GPU計算するので，各分布をGPUに載せます．

In [15]:
p.to(device)
q.to(device)

Inference(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=512, bias=True)
  (fc31): Linear(in_features=512, out_features=64, bias=True)
  (fc32): Linear(in_features=512, out_features=64, bias=True)
)

次に，目的関数を設定します． VAEの目的関数は，次の式で表される変分下界でした． 

\begin{equation*}
\mathcal{L}(x) = E_{q_\phi(z|x)}[\log \frac{p_\theta(x,z)}{q_\phi(z|x)}]
\end{equation*}

この計算は，前回の確率分布クラスの特性を使うことで，以下のように非常にシンプルに書くことができます．

In [16]:
def elbo(x):
    #1. sample from q(z|x) 
    samples = q.sample(x)
    
    #2. caluculate the lower bound (log p(x,z) - log q(z|x))
    lower_bound = p.log_likelihood(samples) - q.log_likelihood(samples)

    loss = -torch.mean(lower_bound)

    return loss

この関数は，入力$x$を受け取って上記の式に従って変分下界を計算し，それを誤差関数として返します．

なおpytorchでは，誤差を最小化するように学習しますので，ここでは変分下界にマイナスを付けたものを誤差関数としています．

また，パラメータや最適化手法を設定します．

In [21]:
q_params = list(q.parameters())
p_params = list(p.parameters())
params = q_params + p_params
print(type(q.parameters()))

optimizer = optim.Adam(params, lr=1e-3)

<class 'generator'>


学習用のデータを読み込みます．今回はMNISTを使います．

In [8]:
batch_size = 128

kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

最後に，trainとtestの関数を設定します．

In [9]:
def train(epoch):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        optimizer.zero_grad()
        
        # calculate the ELBO
        loss = elbo({"x": data.view(-1, 784)})
        
        # backprop
        loss.backward()

        # update params
        optimizer.step()             
        
        
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [10]:
def test(epoch):
    p.eval()
    q.eval()    
    
    test_loss = 0
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)
        with torch.no_grad():
            loss = elbo({"x": data.view(-1, 784)})        
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

そして，学習を行います．

In [11]:
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)

100%|██████████| 469/469 [00:05<00:00, 92.48it/s]

Epoch: 1 Train loss: 183.3633



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 146.0240


100%|██████████| 469/469 [00:04<00:00, 94.64it/s]

Epoch: 2 Train loss: 133.4354



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 125.5314


100%|██████████| 469/469 [00:05<00:00, 91.70it/s]

Epoch: 3 Train loss: 120.6294



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 117.8539


100%|██████████| 469/469 [00:05<00:00, 89.69it/s]


Epoch: 4 Train loss: 114.9133


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 113.0229


100%|██████████| 469/469 [00:05<00:00, 92.74it/s]

Epoch: 5 Train loss: 111.5798



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 111.1281


100%|██████████| 469/469 [00:05<00:00, 92.16it/s]

Epoch: 6 Train loss: 109.5248



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 109.0550


100%|██████████| 469/469 [00:05<00:00, 88.19it/s]

Epoch: 7 Train loss: 107.8700



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 107.9416


100%|██████████| 469/469 [00:05<00:00, 87.00it/s]


Epoch: 8 Train loss: 106.5653


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 107.0079


100%|██████████| 469/469 [00:05<00:00, 88.02it/s]


Epoch: 9 Train loss: 105.6236


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 106.4712


100%|██████████| 469/469 [00:05<00:00, 89.90it/s]

Epoch: 10 Train loss: 104.8221



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 105.9208


100%|██████████| 469/469 [00:05<00:00, 89.04it/s]

Epoch: 11 Train loss: 104.1912



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 105.4300


100%|██████████| 469/469 [00:05<00:00, 92.33it/s]

Epoch: 12 Train loss: 103.6142



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.7093


100%|██████████| 469/469 [00:05<00:00, 89.80it/s]


Epoch: 13 Train loss: 103.1702


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.6743


100%|██████████| 469/469 [00:05<00:00, 87.09it/s]

Epoch: 14 Train loss: 102.6833



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.0688


100%|██████████| 469/469 [00:05<00:00, 84.15it/s]

Epoch: 15 Train loss: 102.3161



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.9421


100%|██████████| 469/469 [00:05<00:00, 91.10it/s]

Epoch: 16 Train loss: 101.9974



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.0246


100%|██████████| 469/469 [00:05<00:00, 88.15it/s]


Epoch: 17 Train loss: 101.7260


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.2529


100%|██████████| 469/469 [00:05<00:00, 87.34it/s]

Epoch: 18 Train loss: 101.4162



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.0945


100%|██████████| 469/469 [00:05<00:00, 82.13it/s]

Epoch: 19 Train loss: 101.0964



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.8995


100%|██████████| 469/469 [00:05<00:00, 90.80it/s]

Epoch: 20 Train loss: 100.8562



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.8658


100%|██████████| 469/469 [00:05<00:00, 88.83it/s]


Epoch: 21 Train loss: 100.6708


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.1198


100%|██████████| 469/469 [00:05<00:00, 81.66it/s]

Epoch: 22 Train loss: 100.4945



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.6019


100%|██████████| 469/469 [00:05<00:00, 88.91it/s]


Epoch: 23 Train loss: 100.2737


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.5762


100%|██████████| 469/469 [00:05<00:00, 86.31it/s]

Epoch: 24 Train loss: 100.0818



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.3602


100%|██████████| 469/469 [00:05<00:00, 90.39it/s]

Epoch: 25 Train loss: 99.9354



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.3842


100%|██████████| 469/469 [00:05<00:00, 89.45it/s]


Epoch: 26 Train loss: 99.7782


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.9701


100%|██████████| 469/469 [00:05<00:00, 84.76it/s]

Epoch: 27 Train loss: 99.6358



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.8709


100%|██████████| 469/469 [00:05<00:00, 90.16it/s]

Epoch: 28 Train loss: 99.4667



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.2739


100%|██████████| 469/469 [00:05<00:00, 85.55it/s]


Epoch: 29 Train loss: 99.2601


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.8162


100%|██████████| 469/469 [00:05<00:00, 88.40it/s]


Epoch: 30 Train loss: 99.1618


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.8504


100%|██████████| 469/469 [00:05<00:00, 88.88it/s]

Epoch: 31 Train loss: 99.0199



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.8173


100%|██████████| 469/469 [00:05<00:00, 89.29it/s]


Epoch: 32 Train loss: 98.8815


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.6085


100%|██████████| 469/469 [00:05<00:00, 90.06it/s]

Epoch: 33 Train loss: 98.7860



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.8379


100%|██████████| 469/469 [00:05<00:00, 88.41it/s]


Epoch: 34 Train loss: 98.7053


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.4982


100%|██████████| 469/469 [00:05<00:00, 87.22it/s]

Epoch: 35 Train loss: 98.5123



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.2992


100%|██████████| 469/469 [00:05<00:00, 90.37it/s]

Epoch: 36 Train loss: 98.4653



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.3093


100%|██████████| 469/469 [00:05<00:00, 87.17it/s]

Epoch: 37 Train loss: 98.3306



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.1088


100%|██████████| 469/469 [00:05<00:00, 88.89it/s]


Epoch: 38 Train loss: 98.2318


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.3879


100%|██████████| 469/469 [00:05<00:00, 87.46it/s]

Epoch: 39 Train loss: 98.1755



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.3685


100%|██████████| 469/469 [00:05<00:00, 89.27it/s]

Epoch: 40 Train loss: 98.0797



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.4969


100%|██████████| 469/469 [00:05<00:00, 88.12it/s]


Epoch: 41 Train loss: 98.0319


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.3861


100%|██████████| 469/469 [00:05<00:00, 89.24it/s]


Epoch: 42 Train loss: 97.9281


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.3036


100%|██████████| 469/469 [00:05<00:00, 87.48it/s]

Epoch: 43 Train loss: 97.8774



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.1539


100%|██████████| 469/469 [00:05<00:00, 86.93it/s]

Epoch: 44 Train loss: 97.7513



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.2182


100%|██████████| 469/469 [00:05<00:00, 90.58it/s]

Epoch: 45 Train loss: 97.7026



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.8849


100%|██████████| 469/469 [00:05<00:00, 87.09it/s]

Epoch: 46 Train loss: 97.5989



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.1147


100%|██████████| 469/469 [00:05<00:00, 82.32it/s]

Epoch: 47 Train loss: 97.5350



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.8053


100%|██████████| 469/469 [00:05<00:00, 89.33it/s]


Epoch: 48 Train loss: 97.4056


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.6892


100%|██████████| 469/469 [00:05<00:00, 89.96it/s]


Epoch: 49 Train loss: 97.3909


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.8094


100%|██████████| 469/469 [00:05<00:00, 87.67it/s]

Epoch: 50 Train loss: 97.3662



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.7903


100%|██████████| 469/469 [00:05<00:00, 84.82it/s]

Epoch: 51 Train loss: 97.2863



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.7027


100%|██████████| 469/469 [00:05<00:00, 85.29it/s]


Epoch: 52 Train loss: 97.1847


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.7475


100%|██████████| 469/469 [00:05<00:00, 87.93it/s]

Epoch: 53 Train loss: 97.1609



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.8758


100%|██████████| 469/469 [00:05<00:00, 89.19it/s]


Epoch: 54 Train loss: 97.0746


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.7802


100%|██████████| 469/469 [00:05<00:00, 87.24it/s]

Epoch: 55 Train loss: 97.0767



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.8502


100%|██████████| 469/469 [00:05<00:00, 88.99it/s]

Epoch: 56 Train loss: 96.9856



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.4351


100%|██████████| 469/469 [00:05<00:00, 86.81it/s]

Epoch: 57 Train loss: 96.9419



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5765


100%|██████████| 469/469 [00:05<00:00, 89.82it/s]


Epoch: 58 Train loss: 96.8993


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.8626


100%|██████████| 469/469 [00:05<00:00, 86.54it/s]

Epoch: 59 Train loss: 96.8453



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.7003


100%|██████████| 469/469 [00:05<00:00, 90.37it/s]

Epoch: 60 Train loss: 96.7747



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.6985


100%|██████████| 469/469 [00:05<00:00, 87.42it/s]

Epoch: 61 Train loss: 96.7370



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.9532


100%|██████████| 469/469 [00:05<00:00, 85.39it/s]


Epoch: 62 Train loss: 96.6957


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.8166


100%|██████████| 469/469 [00:05<00:00, 90.10it/s]

Epoch: 63 Train loss: 96.6418



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5635


100%|██████████| 469/469 [00:05<00:00, 78.74it/s]

Epoch: 64 Train loss: 96.6572



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.8069


100%|██████████| 469/469 [00:05<00:00, 86.50it/s]


Epoch: 65 Train loss: 96.5272


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.6438


100%|██████████| 469/469 [00:05<00:00, 88.17it/s]

Epoch: 66 Train loss: 96.5249



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.8428


100%|██████████| 469/469 [00:05<00:00, 88.52it/s]


Epoch: 67 Train loss: 96.4745


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.6602


100%|██████████| 469/469 [00:05<00:00, 85.28it/s]


Epoch: 68 Train loss: 96.4263


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.6775


100%|██████████| 469/469 [00:05<00:00, 81.97it/s]


Epoch: 69 Train loss: 96.4235


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.3792


100%|██████████| 469/469 [00:05<00:00, 90.50it/s]

Epoch: 70 Train loss: 96.3189



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.7504


100%|██████████| 469/469 [00:05<00:00, 88.38it/s]


Epoch: 71 Train loss: 96.3343


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5617


100%|██████████| 469/469 [00:05<00:00, 87.41it/s]

Epoch: 72 Train loss: 96.2387



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.4829


100%|██████████| 469/469 [00:05<00:00, 81.07it/s]

Epoch: 73 Train loss: 96.2754



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.3839


100%|██████████| 469/469 [00:05<00:00, 86.00it/s]


Epoch: 74 Train loss: 96.2269


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.7312


100%|██████████| 469/469 [00:05<00:00, 86.98it/s]


Epoch: 75 Train loss: 96.1733


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5839


100%|██████████| 469/469 [00:05<00:00, 85.43it/s]


Epoch: 76 Train loss: 96.1219


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5242


100%|██████████| 469/469 [00:05<00:00, 86.25it/s]


Epoch: 77 Train loss: 96.1104


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.7335


100%|██████████| 469/469 [00:05<00:00, 88.07it/s]

Epoch: 78 Train loss: 96.0831



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.6883


100%|██████████| 469/469 [00:05<00:00, 86.68it/s]

Epoch: 79 Train loss: 96.0533



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.6282


100%|██████████| 469/469 [00:05<00:00, 85.61it/s]

Epoch: 80 Train loss: 96.0074



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5492


100%|██████████| 469/469 [00:05<00:00, 85.96it/s]

Epoch: 81 Train loss: 95.9857



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.1944


100%|██████████| 469/469 [00:05<00:00, 86.85it/s]


Epoch: 82 Train loss: 95.9357


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5168


100%|██████████| 469/469 [00:05<00:00, 88.06it/s]


Epoch: 83 Train loss: 95.8966


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.2586


100%|██████████| 469/469 [00:05<00:00, 83.52it/s]

Epoch: 84 Train loss: 95.9017



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5553


100%|██████████| 469/469 [00:05<00:00, 82.77it/s]

Epoch: 85 Train loss: 95.9224



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.4574


100%|██████████| 469/469 [00:05<00:00, 81.11it/s]


Epoch: 86 Train loss: 95.8194


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.4493


100%|██████████| 469/469 [00:05<00:00, 88.20it/s]

Epoch: 87 Train loss: 95.8588



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5864


100%|██████████| 469/469 [00:05<00:00, 85.28it/s]


Epoch: 88 Train loss: 95.7388


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.4172


100%|██████████| 469/469 [00:05<00:00, 87.65it/s]

Epoch: 89 Train loss: 95.7378



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.6841


100%|██████████| 469/469 [00:05<00:00, 82.46it/s]


Epoch: 90 Train loss: 95.7522


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.7997


100%|██████████| 469/469 [00:05<00:00, 87.91it/s]

Epoch: 91 Train loss: 95.7350



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5321


100%|██████████| 469/469 [00:05<00:00, 85.95it/s]


Epoch: 92 Train loss: 95.6541


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.4617


100%|██████████| 469/469 [00:05<00:00, 87.43it/s]

Epoch: 93 Train loss: 95.6215



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.3324


100%|██████████| 469/469 [00:05<00:00, 86.34it/s]

Epoch: 94 Train loss: 95.6763



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5712


100%|██████████| 469/469 [00:05<00:00, 91.19it/s]

Epoch: 95 Train loss: 95.5992



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5765


100%|██████████| 469/469 [00:05<00:00, 84.90it/s]

Epoch: 96 Train loss: 95.5718



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.2883


100%|██████████| 469/469 [00:05<00:00, 87.73it/s]

Epoch: 97 Train loss: 95.5808



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.3554


100%|██████████| 469/469 [00:05<00:00, 88.49it/s]

Epoch: 98 Train loss: 95.5843



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.5142


100%|██████████| 469/469 [00:05<00:00, 88.26it/s]

Epoch: 99 Train loss: 95.5165



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 100.8268


100%|██████████| 469/469 [00:05<00:00, 89.84it/s]

Epoch: 100 Train loss: 95.4821





Test loss: 100.3799


テスト誤差がどんどん小さくなっているのが確認できると思います． これは，変分下界が大きくなっていることを示しており，VAEの学習がうまく進んでいることがわかります．

今回は，エンコーダ（ガウス分布）とデコーダ（ベルヌーイ分布）で構成されるシンプルなモデルとしてVAEを考えましたが，Tarsでは，同時分布をモデル化するAPIを採用することで，**同じ実装方法で任意の形の生成モデルを学習することができます**．

また，それぞれの分布クラスは学習中や学習後でも自由にサンプリングできます．したがって，例えば学習後に生成モデル_pからサンプリングすることで，MNIST画像を生成することができます． 

こちらについてはTars/exampleにいくつか例を用意しましたので，実行してみてください．

## 3.2 Tars.modelsを使う方法

3.1では下界を実装しましたが，TarsではTars.models以下にいくつかの学習モデルを用意しています．

変分推論の場合は，任意の生成モデルについて，Tars.models.VIを使うことで簡単に学習できます．

In [16]:
from Tars.models import VI
model = VI(p, q, optimizer=optim.Adam, optimizer_params={"lr":1e-3})

In [17]:
def train(epoch):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        lower_bound, loss = model.train({"x": data.view(-1, 784)})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [18]:
def test(epoch):
    test_loss = 0
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)
        lower_bound, loss = model.test({"x": data.view(-1, 784)})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [None]:
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)

100%|██████████| 469/469 [00:04<00:00, 94.55it/s]


Epoch: 1 Train loss: 103.3703


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.1252


100%|██████████| 469/469 [00:05<00:00, 86.72it/s]


Epoch: 2 Train loss: 102.8292


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.0812


100%|██████████| 469/469 [00:05<00:00, 81.84it/s]

Epoch: 3 Train loss: 102.3836



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.8067


100%|██████████| 469/469 [00:05<00:00, 91.34it/s]


Epoch: 4 Train loss: 102.0076


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.9207


100%|██████████| 469/469 [00:04<00:00, 95.09it/s]

Epoch: 5 Train loss: 101.6657



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.3255


100%|██████████| 469/469 [00:05<00:00, 87.32it/s]


Epoch: 6 Train loss: 101.3492


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.1533


100%|██████████| 469/469 [00:05<00:00, 86.60it/s]

Epoch: 7 Train loss: 101.1122



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.9838


100%|██████████| 469/469 [00:05<00:00, 87.60it/s]

Epoch: 8 Train loss: 100.8648



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.6977


100%|██████████| 469/469 [00:05<00:00, 87.39it/s]


Epoch: 9 Train loss: 100.6251


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.4485


100%|██████████| 469/469 [00:05<00:00, 89.45it/s]


Epoch: 10 Train loss: 100.3216


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.2215


100%|██████████| 469/469 [00:05<00:00, 78.34it/s]

Epoch: 11 Train loss: 100.0790



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.2996


100%|██████████| 469/469 [00:05<00:00, 88.11it/s]

Epoch: 12 Train loss: 99.8513



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.9408


100%|██████████| 469/469 [00:05<00:00, 93.18it/s]

Epoch: 13 Train loss: 99.6467



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.1512


100%|██████████| 469/469 [00:05<00:00, 83.65it/s]


Epoch: 14 Train loss: 99.4583


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 102.1782


100%|██████████| 469/469 [00:05<00:00, 79.74it/s]


Epoch: 15 Train loss: 99.3177


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.8646


100%|██████████| 469/469 [00:05<00:00, 87.89it/s]

Epoch: 16 Train loss: 99.1559



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 101.9316


 31%|███       | 144/469 [00:01<00:03, 92.54it/s]

その他の分布やモデルの例については，Tars/exampleをみてください．