In [277]:
import numpy as np
import matplotlib.pyplot as plt
import os

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

print('PyTorch version:', torch.__version__)
print('torchvision version:', torchvision.__version__)
print('Is GPU available:', torch.cuda.is_available())

PyTorch version: 0.4.1
torchvision version: 0.2.1
Is GPU available: False


In [278]:
# デバイスとハイパーパラメータ

# デバイスの準備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)

# バッチサイズの指定
batchsize = 128

# 回すエポック数の指定
n_epochs = 100

# 学習率の指定
learning_rate = 0.0005

# 乱数シードの指定
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    
# VAEの潜在空間の次元数の指定
embed_dim = 10

# 画像を保存するディレクトリ名の指定
output_dir = './VAE_' + str(embed_dim)
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

device: cpu


In [279]:
# データセットの準備

# Tensorにして、-1～1の範囲に正規化
tf = transforms.Compose([transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 今回はMNISTを使用
mnist_train = datasets.MNIST(root = '../data', train = True, transform = tf, download = True)
mnist_test = datasets.MNIST(root = '../data', train = False, transform = tf, download = False)

# データローダーを容易
mnist_train_loader = DataLoader(mnist_train, batch_size = batchsize, shuffle = True, num_workers = 4)
mnist_test_loader = DataLoader(mnist_test, batch_size = batchsize, shuffle = False, num_workers = 4)

print('The number of training data:', len(mnist_train))
print('The number of test data:', len(mnist_test))

The number of training data: 60000
The number of test data: 10000


In [280]:
# VAEのencoderを定義
class Encoder(nn.Module):
    def __init__(self, n_out):
        super(Encoder, self).__init__()
        self.cv1 = nn.Conv2d(1,  32, kernel_size = 5, stride = 2, padding = 2)
        self.cv2 = nn.Conv2d(32, 64, kernel_size = 4, stride = 2, padding = 1)
        self.fc3 = nn.Linear(64*7*7, 256)
        self.fc4_mean = nn.Linear(256, n_out)
        self.fc4_logvar = nn.Linear(256, n_out)
        
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm1d(256)
        
    def forward(self, x):
        h = self.cv1(x)
        h = self.bn1(h)
        h = F.leaky_relu(h)
        
        h = self.cv2(h)
        h = self.bn2(h)
        h = F.leaky_relu(h)
        
        h = h.view(h.size(0), -1)
        
        h = self.fc3(h)
        h = self.bn3(h)
        h = F.leaky_relu(h)
        
        out_mean = self.fc4_mean(h)
        out_logvar = self.fc4_logvar(h)
        return out_mean, out_logvar

In [281]:
# VAEのdecoderを定義
class Decoder(nn.Module):
    def __init__(self, n_in):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(n_in, 256)
        self.fc2 = nn.Linear(256, 64*7*7)
        self.tc3 = nn.ConvTranspose2d(64, 32, kernel_size = 4, stride = 2, padding = 1)
        self.tc4 = nn.ConvTranspose2d(32, 1, kernel_size =  4, stride = 2, padding = 1)
        
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(64*7*7)
        self.bn3 = nn.BatchNorm2d(32)
        
    def forward(self, x):
        h = self.fc1(x)
        h = self.bn1(h)
        h = F.leaky_relu(h)

        h = self.fc2(h)
        h = self.bn2(h)
        h = F.leaky_relu(h)
        
        h = h.view(h.size(0), 64, 7, 7)
        
        h = self.tc3(h)
        h = self.bn3(h)
        h = F.leaky_relu(h)
        
        out = self.tc4(h)
        out = torch.tanh(out)
        return out

In [282]:
# encoderとdecoderをまとめ、reparametrization trick等を合わせてまとめてVAEを構成
class VAE(nn.Module):
    def __init__(self, n_emb):
        super(VAE, self).__init__()
        self.n_emb = n_emb
        self.encoder = Encoder(n_emb)
        self.decoder = Decoder(n_emb)
        
    def embed(self, x):
        out_mean, out_logvar = self.encoder(x)
        return out_mean, out_logvar
    
    def sample(self, n_sample):
        z = torch.randn(n_sample. self.n_emb)
        out = self.decoder(z)
        return out
    
    def forward(self, x):
        out_mean, out_logvar = self.embed(x)
        eps = torch.randn(out_mean.size())
        z = (0.5 * out_logvar).exp() * eps + out_mean
        out = self.decoder(z)
        return out, out_mean, out_logvar

In [283]:
# ネットワークを実体化、オプティマイザを定義
net = VAE(embed_dim)
net = net.to(device) # CPU/GPUにモデルを送信

# オプティマイザは取り敢えずAdam, 学習率は上で指定、weight_decayを適当にかける
optimizer = optim.Adam(net.parameters(), lr = learning_rate, weight_decay = 0.0005)

# モデルのtarinableな(勾配を要求する)パラメータの数をカウントする（.numel()で要素数の合計がわかる）
num_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)

# モデルの構造、オプティマイザの表示
print('The number of trainable parameters:', num_trainable_params)
print('\nModel:\n', net)
print('\nOptimizer:\n', optimizer)

The number of trainable parameters: 1691509

Model:
 VAE(
  (encoder): Encoder(
    (cv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (cv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (fc3): Linear(in_features=3136, out_features=256, bias=True)
    (fc4_mean): Linear(in_features=256, out_features=10, bias=True)
    (fc4_logvar): Linear(in_features=256, out_features=10, bias=True)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decoder): Decoder(
    (fc1): Linear(in_features=10, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=3136, bias=True)
    (tc3): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (tc4): ConvTranspose2d(32, 1, kernel_size=(4, 

In [284]:
# VAEのロス関数を定義
def loss_VAE(recon_x, x, out_mean, out_logvar):
    recon_err = F.mse_loss(recon_x, x)
    kldiv_err = -0.5 * torch.sum(1 + out_logvar - out_mean.pow(2) - out_logvar.exp())
    final_err = recon_err + kldiv_err
    
    return final_err, recon_err, kldiv_err

In [285]:
# １エポック分の学習を行う関数
def train(train_loader):
    net.train() # モデルを学習モードにする
    running_final_loss = 0
    running_recon_loss = 0
    running_kldiv_loss = 0
    
    for inputs, _ in train_loader:
        inputs =  inputs.to(device) # モデルとTensorを同じインタフェースでCPU/GPU転送できる 
        recon_inputs, out_mean, out_logvar = net(inputs)
        final_loss, recon_loss, kldiv_loss = loss_VAE(recon_inputs, inputs, out_mean, out_logvar)
        
        optimizer.zero_grad()
        final_loss.backward()
        optimizer.step()
        
        running_final_loss += final_loss.item() # .item()でスカラ値を単要素Tensorから取り出す
        running_recon_loss += recon_loss.item()
        running_kldiv_loss += kldiv_loss.item()
        
    final_loss = running_final_loss / len(train_loader.dataset)
    recon_loss = running_recon_loss / len(train_loader.dataset)
    kldiv_loss = running_kldiv_loss / len(train_loader.dataset)

    return final_loss, recon_loss, kldiv_loss # 戻り値は訓練誤差

In [289]:
# 1ポック分のテスト（実質バリデーションだが）を行う関数
def test(test_loader, epoch):
    net.eval() # モデルを学習モードにする
    running_final_loss = 0
    running_recon_loss = 0
    running_kldiv_loss = 0
    
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs =  inputs.to(device) # モデルとTensorを同じインタフェースでCPU/GPU転送できる 
            recon_inputs, out_mean, out_logvar = net(inputs)
            final_loss, recon_loss, kldiv_loss = loss_VAE(recon_inputs, inputs, out_mean, out_logvar)
        
            running_final_loss += final_loss.item() # .item()でスカラ値を単要素Tensorから取り出す
            running_recon_loss += recon_loss.item()
            running_kldiv_loss += kldiv_loss.item()
        
    final_loss = running_final_loss / len(test_loader.dataset)
    recon_loss = running_recon_loss / len(test_loader.dataset)
    kldiv_loss = running_kldiv_loss / len(test_loader.dataset)
    
    # 10エポックごとに最後のミニバッチの生成画像を保存する
    if epoch % 5 == 0:
        n_save_image = 8        
        comparison = torch.cat([inputs[:n_save_image], recon_inputs[:n_save_image]])
        save_image(comparison.data.cpu(),'{}/reconstruction_{}.png'.format(output_dir, epoch), nrow=n_save_image)
        
    return final_loss, recon_loss, kldiv_loss # 戻り値はテスト（バリデーション）誤差

In [291]:
# 学習の実行と、モデルの保存（学習ログは.npyで、モデル状態は.pthで　←　モデル状態の保存は他にもいろいろある？）
train_loss_list = [[],[],[]]
test_loss_list = [[],[],[]]
for epoch in range(n_epochs):
    train_final_loss, train_recon_loss, train_kldiv_loss = train(mnist_train_loader)
    test_final_loss, test_recon_loss, test_kldiv_loss  = test(mnist_test_loader, epoch)
    
    train_loss_list[0].append(train_final_loss)
    train_loss_list[1].append(train_recon_loss)
    train_loss_list[2].append(train_kldiv_loss)

    test_loss_list[0].append(test_final_loss)
    test_loss_list[1].append(test_recon_loss)
    test_loss_list[2].append(test_kldiv_loss)

    print('epoch[%d/%d] train_loss:%1.4f test_loss:%1.4f' % \
                                (epoch+1, n_epochs, train_final_loss, test_final_loss))


np.save(output_dir + 'train_loss_list.npy', np.array(train_loss_list))
np.save(output_dir + 'validation_loss_list.npy', np.array(test_loss_list))

torch.save(net.state_dict(), output_dir + 'VAE.pth')

epoch[1/5] train_loss:0.0043 test_loss:0.0043


KeyboardInterrupt: 