In [1]:
import sys
import os
# Put the parent directory on the import path (presumably so the `models` and
# `utils` packages imported in the next cell resolve -- confirm the repo layout).
# Guarded so that re-running this cell does not stack duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
from models.vq_vae import VQ_VAE
from utils.e4_bvp_datasets import BVPDataset
from utils.e4_bvp_datasets import BVPDatasetValid

In [3]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torch.autograd import Variable
import torch.nn.functional as F 
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
%matplotlib inline
from matplotlib import pyplot as plt
import datetime

In [4]:
# Network-size settings handed to the VQ_VAE constructor.
num_residual_layers = 2
num_hiddens = num_residual_hiddens = 32

# Codebook settings (as the names suggest: number of embeddings and the
# dimensionality of each one).
# embedding_dim = 1  # RJS: initial value
num_embeddings, embedding_dim = 128, 8

# Remaining VQ_VAE constructor arguments -- presumably the commitment-loss
# weight and the codebook EMA decay; confirm against models/vq_vae.py.
commitment_cost, decay = 0.25, 0.99

In [5]:
# Optimisation settings: mini-batch size for both loaders, number of passes
# over the training set, and the Adam learning rate.
batch_size, num_epochs = 32, 500
learning_rate = 1e-3
# NOTE(review): not referenced by the training cell in this notebook section.
num_training_updates = 100

# Load Dataset

In [6]:
# Build the two datasets fed to the training and validation loaders; each one
# gets its own ToTensor transform.
dataset, valid_dataset = (
    BVPDataset(transform=transforms.ToTensor()),
    BVPDatasetValid(transform=transforms.ToTensor()),
)

In [7]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Training batches are reshuffled on every pass; validation order is fixed.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# Train

In [8]:
# Instantiate the VQ-VAE from the hyperparameters defined above (positional
# order as expected by VQ_VAE), then move it onto the selected device.
model = VQ_VAE(
    num_hiddens,
    num_residual_layers,
    num_residual_hiddens,
    num_embeddings,
    embedding_dim,
    commitment_cost,
    decay,
)
model = model.to(device)

In [9]:
%%time
model.train()
train_res_recon_error = []
train_res_perplexity = []
recon_error_s = 0
res_perplexity = 0
data_variance = 1
good_id = 0


optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)

f_name = "bvp_vq_vae-128-8dim"
dt_now = datetime.datetime.now()
now = "{}-{}-{}".format(dt_now.year, dt_now.month, dt_now.day)
name_dir = "pth/{}".format(now)
name = "pth/{}/{}.pth".format(now, f_name)
if not os.path.isdir(name_dir):
    os.makedirs(name_dir)
log_dir = "logs/{}/{}".format(now, f_name)
if not os.path.isdir(log_dir):
    os.makedirs(log_dir)
writer = SummaryWriter(log_dir=log_dir)


# Train for `num_epochs` epochs. After every epoch the model is evaluated on
# the validation loader, both passes are logged to TensorBoard, and the
# weights are checkpointed to `name` whenever the validation reconstruction
# error is at least as good as the best one seen so far.
# The try/finally guarantees the TensorBoard writer is flushed and closed even
# when the run is interrupted (e.g. KeyboardInterrupt) or fails part-way.
try:
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        recon_error_s = 0
        res_perplexity = 0
        for data in train_loader:
            data = data.to(device, dtype=torch.float)
            # Reshape to (batch, 1, -1): one channel, everything else flattened.
            data = data.view(data.size(0), 1, -1)
            optimizer.zero_grad()

            # The model returns (vq_loss, reconstruction, perplexity).
            vq_loss, data_recon, perplexity = model(data)
            recon_error = F.mse_loss(data_recon, data) / data_variance
            loss = recon_error + vq_loss

            loss.backward()
            optimizer.step()

            # NOTE: a multi-GPU variant used to live here as commented-out
            # code; it reduced `loss` and `perplexity` with .mean() before use.
            recon_error_s += recon_error.item()
            res_perplexity += perplexity.item()

        # "train-loss" tracks the mean reconstruction error only (vq_loss is
        # not included in the logged value).
        writer.add_scalar("train-loss", recon_error_s / len(train_loader), epoch)
        writer.add_scalar("train-perplexity", res_perplexity / len(train_loader), epoch)

        # ---- validation pass ----
        recon_error_s = 0
        res_perplexity = 0
        model.eval()
        with torch.no_grad():
            for data in valid_loader:
                data = data.to(device, dtype=torch.float)
                data = data.view(data.size(0), 1, -1)
                vq_loss, data_recon, perplexity = model(data)
                recon_error = F.mse_loss(data_recon, data) / data_variance
                recon_error_s += recon_error.item()
                res_perplexity += perplexity.item()

        valid_recon_error = recon_error_s / len(valid_loader)
        valid_perplexity = res_perplexity / len(valid_loader)
        # These history lists hold validation metrics despite their names.
        train_res_recon_error.append(valid_recon_error)
        train_res_perplexity.append(valid_perplexity)
        writer.add_scalar("valid-loss", valid_recon_error, epoch)
        writer.add_scalar("valid-perplexity", valid_perplexity, epoch)

        # Keep the checkpoint with the lowest validation reconstruction error.
        # The comparison must use the same normalisation (len(valid_loader)) as
        # the stored history; dividing by len(train_loader) here skewed the
        # best-model selection. At epoch 0 the value compares equal to itself,
        # so an initial checkpoint is always written.
        if valid_recon_error <= train_res_recon_error[good_id]:
            torch.save(model.state_dict(), name)
            good_id = epoch
finally:
    writer.close()

KeyboardInterrupt: 