In [1]:
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
 
# PyTorch画像用
import torchvision
import torchvision.transforms as transforms
 
# 画像表示用
import matplotlib.pyplot as plt

In [2]:
batch_size = 128
 
# データセットの取得
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
# DataLoaderの作成
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [None]:
image_size = 28 * 28
h_dim = 32
z_dim = 16

In [15]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc = nn.Linear(input_diam, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        # ニューラルネットワークで事後分布の平均・分散を計算する
        h = torch.relu(self.fc(x))
        mu = self.fc_mu(h) # μ
        log_var = self.fc_var(h) # log σ^2

        # 潜在変数を求める
        ## 標準正規乱数を振る
        eps = torch.randn_like(torch.exp(log_var))
        ## 潜在変数の計算 μ + σ・ε
        z = mu + torch.exp(log_var / 2) * eps

In [16]:
class Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, hidden_dim)
        self.fc_output = nn.Linear(hidden_dim, input_dim)

    def forward(self, z):
        h = torch.relu(self.fc(z))
        output = torch.sigmoid(self.fc_output(h))
        return output

In [30]:
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(input_dim, hidden_dim, latent_dim)

    def forward(self, x):
        mu, log_var, z = self.encoder(x) # エンコード
        x_decoded = self.decoder(z) # デコード
        return x_decoded, mu, log_var, z 

In [31]:
def loss_function(label, predict, mu, log_var):
    reconstruction_loss = F.binary_cross_entropy(predict, label, reduction='sum')
    kl_loss = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    vae_loss = reconstruction_loss + kl_loss
    return vae_loss, reconstruction_loss, kl_loss

In [32]:
# デバイスの割り当て
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [33]:
image_size = 28 * 28
h_dim = 32
z_dim = 16
num_epochs = 10
learning_rate = 1e-3
 
model = VAE(image_size, h_dim, z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [34]:
losses = []
model.train()
for epoch in range(num_epochs):
  train_loss = 0
  for i, (x, labels) in enumerate(train_loader):
    # 予測
    x = x.to(device).view(-1, image_size).to(torch.float32)
    x_recon, mu, log_var, z = model(x)
    # 損失関数の計算
    loss, recon_loss, kl_loss = loss_function(x, x_recon, mu, log_var)
     
    # パラメータの更新
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
     
    # 損失の表示
    if (i+1) % 10 == 0:
      print(f'Epoch: {epoch+1}, loss: {loss: 0.4f}, reconstruct loss: {recon_loss: 0.4f}, KL loss: {kl_loss: 0.4f}')
    losses.append(loss)

TypeError: cannot unpack non-iterable NoneType object