# [VQ-VAE](https://arxiv.org/abs/1711.00937) for audio in PyTorch

This notebook is based on 
https://github.com/zalandoresearch/pytorch-vq-vae

## Introduction

The encoder of a Variational Auto Encoder (VAE) can be thought of as doing what all but the last layer of a neural network does, namely feature extraction, i.e. separating out the structure in the data. Thus, given some data, we can use a neural network to generate representations of it.

Recall that the goal of a generative model is to estimate the probability distribution of high dimensional data such as images, videos, audio or even text by learning the underlying structure in the data as well as the dependencies between the different elements of the data. This is very useful since we can then use this representation to generate new data with similar properties. This way we can also learn useful features from the data in an unsupervised fashion.

The VQ-VAE uses a discrete latent representation mostly because many important real-world objects are discrete. For example, in images we might have categories like "Cat", "Car", etc., and it might not make sense to interpolate between these categories. Discrete representations are also easier to model, since each category has a single value, whereas with a continuous latent space we would need to normalize the density function and learn the dependencies between the different variables, which could be very complex.

### Code

I have followed the code from the TensorFlow implementation by the authors, which you can find here: [vqvae.py](https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py) and [vqvae_example.ipynb](https://github.com/deepmind/sonnet/blob/master/sonnet/examples/vqvae_example.ipynb).

Another PyTorch implementation is found at [pytorch-vqvae](https://github.com/ritheshkumar95/pytorch-vqvae).


## Basic Idea

We start by defining a latent embedding space of dimension `[K, D]`, where `K` is the number of embeddings and `D` is the dimensionality of each latent embedding vector $e_i$.

The model takes batches of waveform segments (16126 samples each in our example, one-hot encoded over 256 mu-law classes) and passes them through a ConvNet encoder, making sure that the number of output channels equals the dimensionality `D` of the latent embedding vectors. To compute the discrete latent variable, we find the nearest embedding vector for each encoder output and output its index.

The input to the decoder is the embedding vector corresponding to that index; the decoder (a WaveNet in this notebook, additionally conditioned on the speaker) turns it into the reconstructed audio.

Since the nearest-neighbour lookup has no real gradient, in the backward pass we simply copy the gradients from the decoder input to the encoder output unaltered (a straight-through estimator). The intuition is that since the output representation of the encoder and the input to the decoder share the same `D`-dimensional space, the gradients contain useful information about how the encoder has to change its output to lower the reconstruction loss.
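A minimal NumPy sketch of the lookup and the straight-through backward step, assuming a toy codebook with `K=3` and `D=2`. The helper names `nearest_embedding` and `straight_through_backward` are illustrative, not part of this notebook; in PyTorch the same trick is written `inputs + (quantized - inputs).detach()`, as in the quantizer further down.

```python
import numpy as np


def nearest_embedding(z_e, codebook):
    """Quantize encoder outputs z_e [N, D] with a codebook [K, D]."""
    # squared Euclidean distances ||z||^2 + ||e||^2 - 2 z.e, shape [N, K]
    distances = ((z_e ** 2).sum(axis=1, keepdims=True)
                 + (codebook ** 2).sum(axis=1)
                 - 2 * z_e @ codebook.T)
    indices = distances.argmin(axis=1)  # discrete latent variable
    return codebook[indices], indices   # z_q is the decoder input


def straight_through_backward(grad_z_q):
    # argmin has no useful gradient, so dL/dz_q is copied to dL/dz_e unchanged
    return grad_z_q


codebook = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])  # K=3, D=2
z_e = np.array([[0.9, 1.2], [-0.8, 1.7], [0.1, -0.2]])      # N=3 encoder outputs
z_q, idx = nearest_embedding(z_e, codebook)
print(idx)  # [1 2 0]

# toy reconstruction loss 0.5 * ||z_q - x||^2, whose gradient w.r.t. z_q is z_q - x
x = np.zeros_like(z_q)
grad_z_e = straight_through_backward(z_q - x)
```

The encoder receives exactly the gradient that the decoder computed for `z_q`, even though `z_e` and `z_q` differ.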

## Loss

The total loss is composed of three components:

1. a reconstruction loss, which optimizes the decoder and the encoder;
1. because the gradients bypass the embedding, a dictionary learning term that uses an $l_2$ error to move the embedding vectors $e_i$ towards the encoder output;
1. since the volume of the embedding space is dimensionless, it can grow arbitrarily if the embeddings $e_i$ do not train as fast as the encoder parameters, so we add a commitment loss to make sure that the encoder commits to an embedding.
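Following the VQ-VAE paper, and writing the reconstruction term as a negative log-likelihood so that the whole expression is minimized, the three components read as below, with $z_e(x)$ the encoder output, $z_q(x)$ its quantized version, $e$ the selected embedding, $\operatorname{sg}[\cdot]$ the stop-gradient operator and $\beta$ the commitment cost (`commitment_cost = 0.25` in this notebook):

$$
L = -\log p(x \mid z_q(x)) + \lVert \operatorname{sg}[z_e(x)] - e \rVert_2^2 + \beta \, \lVert z_e(x) - \operatorname{sg}[e] \rVert_2^2
$$

In this notebook the first term is a cross-entropy over the 256 mu-law classes (`recon_loss`), and the two $l_2$ terms are averaged rather than summed (`q_latent_loss` and `e_latent_loss` in the quantizer, where `.detach()` plays the role of $\operatorname{sg}$).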

In [1]:
import os
import subprocess

import math

import matplotlib.pyplot as plt
import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import librosa
import random
from wavenet_vocoder.wavenet import WaveNet
from wavenet_vocoder.wavenet import receptive_field_size
#from vq import VectorQuantizerEMA

In [2]:
import easydict
args = easydict.EasyDict({
    "batch": 1,
    "epochs": 500,
    "training_data": './2_speaker/vctk_train.txt',
    "test_data": './2_speaker/vctk_test.txt',
#    "training_data": './vctk_train.txt',
#    "test_data": './vctk_test.txt',
#    "out": "result",
#    "resume": False,
    "load": 0,
    "load_mid" : 0,
    "seed": 123456789 })

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#torch.cuda.set_device(0)
device

torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

In [4]:
with open(args.training_data, 'r') as f:
    lines = f.read().splitlines()

# Map each speaker directory (first path component of every line) to an integer id
speaker_dic = {}
for line in lines:
    speaker = line.split('/')[0]
    if speaker not in speaker_dic:
        speaker_dic[speaker] = len(speaker_dic)
number_of_speakers = len(speaker_dic)
        

In [5]:
#TO DO: check that weight gets updated
class VectorQuantizerEMA(nn.Module):
    """We will also implement a slightly modified version which will use exponential moving averages
    to update the embedding vectors instead of an auxiliary loss.
    This has the advantage that the embedding updates are independent of the choice of optimizer
    for the encoder, decoder and other parts of the architecture.
    For most experiments the EMA version trains faster than the non-EMA version.
    Note: in the active forward() below the EMA update is commented out, so the embedding
    vectors are currently trained through q_latent_loss instead."""
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        #self._embedding.weight.data.normal_()
        self._embedding.weight.data.uniform_(-1./512, 1./512)
#        self._embedding.weight.data = torch.Tensor([0])
        #self._embedding.weight.data = torch.Tensor(np.zeros(()))
        self._commitment_cost = commitment_cost
        
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()
        
        self._decay = decay
        self._epsilon = epsilon
    '''
    def forward(self, inputs):
        # convert inputs from BCL -> BLC
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.shape
        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)     #[BL, C]
        if (self._embedding.weight.data == 0).all():
            self._embedding.weight.data = flat_input[-self._num_embeddings:].detach()
        # Calculate distances

        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t())) #[BL, num_embeddings]
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) #[BL, 1]
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings).to(device)# [BL, num_embeddings]
        encodings.scatter_(1, encoding_indices, 1)
        #print(encodings.shape) [250, 512]
        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)
            #print(self._ema_cluster_size.shape) [512]
            n = torch.sum(self._ema_cluster_size)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)
            
            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
            
            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
        
        # Quantize and unflatten
        #encodings.shape = [BL, num_embeddings] , weight.shape=[num_embeddings, C]
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        
        # Loss
        e_latent_loss = torch.mean((quantized.detach() - inputs)**2)
        q_latent_loss = torch.mean((quantized - inputs.detach())**2)
#        print(q_latent_loss.item(), 0.25 * e_latent_loss.item())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        # convert quantized from BLC -> BCL
        return loss, quantized.permute(0, 2, 1).contiguous(), perplexity
    '''
    
    def forward(self, inputs):
        # convert inputs from BCL -> BLC
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.shape
        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)     #[BL, C]
        # Euclidean distance to every embedding vector: [BL, num_embeddings]
        # (same argmin as the squared form ||x||^2 + ||e||^2 - 2 x.e)
        distances = torch.norm(flat_input.unsqueeze(1) - self._embedding.weight, dim=2, p=2)
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) #[BL, 1]
        # one-hot encodings on the same device as the inputs: [BL, num_embeddings]
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        #print(encodings.shape) [250, 512]

#         # Use EMA to update the embedding vectors
#         if self.training:
#             self._ema_cluster_size = self._ema_cluster_size * self._decay + \
#                                      (1 - self._decay) * torch.sum(encodings, 0)
#             #print(self._ema_cluster_size.shape) [512]
#             n = torch.sum(self._ema_cluster_size)
#             self._ema_cluster_size = (
#                 (self._ema_cluster_size + self._epsilon)
#                 / (n + self._num_embeddings * self._epsilon) * n)
            
#             dw = torch.matmul(encodings.t(), flat_input)
#             self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
            
#             self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Quantize and unflatten
        #encodings.shape = [BL, num_embeddings] , weight.shape=[num_embeddings, C]
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        # Loss
        e_latent_loss = torch.mean((quantized.detach() - inputs)**2)
        q_latent_loss = torch.mean((quantized - inputs.detach())**2)
#        print(q_latent_loss.item(), 0.25 * e_latent_loss.item())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        # same as torch.exp( entropy loss )
        
        # convert quantized from BLC -> BCL
        return loss, quantized.permute(0, 2, 1).contiguous(), perplexity
#    '''
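The EMA update described in the docstring survives only as commented-out code above. For reference, a standalone NumPy sketch of the same update rule (moving averages of the cluster sizes and of the summed assigned vectors, with Laplace smoothing); the function name `ema_codebook_update` and the toy data are illustrative assumptions, not part of the notebook.

```python
import numpy as np


def ema_codebook_update(flat_input, encodings, ema_cluster_size, ema_w,
                        decay=0.99, epsilon=1e-5):
    """One EMA codebook update (same equations as the commented-out branch).

    flat_input: [N, D] encoder outputs, encodings: [N, K] one-hot assignments.
    Returns the new embedding [K, D] and the updated EMA statistics.
    """
    num_embeddings = encodings.shape[1]
    # moving average of how many encoder outputs were assigned to each code
    ema_cluster_size = decay * ema_cluster_size + (1 - decay) * encodings.sum(axis=0)
    # Laplace smoothing, so that codes that are never used do not divide by zero
    n = ema_cluster_size.sum()
    ema_cluster_size = (ema_cluster_size + epsilon) / (n + num_embeddings * epsilon) * n
    # moving average of the sum of the encoder outputs assigned to each code
    dw = encodings.T @ flat_input
    ema_w = decay * ema_w + (1 - decay) * dw
    # each embedding becomes (approximately) the mean of its assigned outputs
    embedding = ema_w / ema_cluster_size[:, None]
    return embedding, ema_cluster_size, ema_w


# toy example: both outputs are always assigned to code 0 of a K=2 codebook
flat_input = np.array([[1.0, 2.0], [3.0, 4.0]])
encodings = np.array([[1.0, 0.0], [1.0, 0.0]])
ema_cluster_size, ema_w = np.zeros(2), np.zeros((2, 2))
for _ in range(2000):
    embedding, ema_cluster_size, ema_w = ema_codebook_update(
        flat_input, encodings, ema_cluster_size, ema_w)
print(embedding[0])  # close to the mean of the assigned outputs, [2.0, 3.0]
```

Because the statistics are plain arrays updated outside the optimizer, a PyTorch version would normally keep them in buffers (`register_buffer`) rather than re-creating `nn.Parameter` objects inside `forward`.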

## Encoder & Decoder Architecture

In [6]:
class Encoder(nn.Module):
    """Audio encoder
    The vq-vae paper says that the encoder has 6 strided convolutions with stride 2 and window-size 4.
    The number of channels and the nonlinearity are not specified in the paper.
    I tried using ReLU, it didn't work.
    Now I try using tanh, hoping that this will keep my encoded values within the neighborhood of 0,
    so they do not drift too far away from encoding vectors.
    """

    def __init__(self, encoding_channels, in_channels=256):
        super(Encoder, self).__init__()
        # One strided Conv1d + Tanh per entry of encoding_channels
        # (256 -> 512 x 5 -> 64 with the settings used below).
        # The layers are created per instance, not as class attributes,
        # so that two Encoder objects never share weights.
        self._num_layers = 2 * len(encoding_channels)
        self._layers = nn.ModuleList()
        for out_channels in encoding_channels:
            conv = nn.Conv1d(in_channels=in_channels,
                             out_channels=out_channels,
                             stride=2,
                             kernel_size=4,
                             padding=0)
            nn.init.xavier_uniform_(conv.weight)
            self._layers.append(conv)
            self._layers.append(nn.Tanh())
            in_channels = out_channels

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return x
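The six stride-2 convolutions without padding shrink the time axis roughly by a factor of 64. A small pure-Python check using the `nn.Conv1d` output-length formula (dilation 1); `conv1d_output_length` is a hypothetical helper, and 16126 is the segment length used by the datasets below:

```python
def conv1d_output_length(length, kernel_size=4, stride=2, padding=0):
    # output length of nn.Conv1d with dilation 1: floor((L + 2p - k) / s) + 1
    return (length + 2 * padding - kernel_size) // stride + 1


lengths = [16126]                   # segment_length used by the datasets
for _ in range(6):                  # six strided convolutions in Encoder
    lengths.append(conv1d_output_length(lengths[-1]))
print(lengths)  # [16126, 8062, 4030, 2014, 1006, 502, 250]
```

This is why the quantizer's comments mention `[250, 512]` for batch size 1: 250 encoder frames, each compared against 512 embedding vectors.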

In [7]:
class Model(nn.Module):
    def __init__(self,
                 encoding_channels,
                 num_embeddings, 
                 embedding_dim,
                 commitment_cost, 
                 layers,
                 stacks,
                 kernel_size,
                 decay=0):
        super(Model, self).__init__()       
        self._encoder = Encoder(encoding_channels=encoding_channels)
        #I tried adding batch normalization here, because:
        #the distribution of encoded values needs to be similar to the distribution of embedding vectors
        #otherwise we'll see "posterior collapse": all values will be assigned to the same embedding vector,
        #and stay that way (because vectors which do not get assigned anything do not get updated).
        #Batch normalization is a way to fix that. But it didn't work: model
        #reproduced voice correctly, but the words were completely wrong.
        #self._batch_norm = nn.BatchNorm1d(1)
        if decay > 0.0:
#             self._vq_vae = EMVectorQuantizerEMA(num_embeddings, embedding_dim, 
#                                               commitment_cost, decay, 100)
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, 
                                               commitment_cost, decay)

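        else:
            # NOTE: VectorQuantizer (the non-EMA variant) is not defined in this notebook;
            # it has to be provided (e.g. from the referenced implementations) before using decay=0.
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)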
        self._decoder = WaveNet(device, out_channels=256, #dimension of ohe mu-quantized signal
                                layers=layers, #like in original WaveNet
                                stacks=stacks,
                                residual_channels=512,
                                gate_channels=512,
                                skip_out_channels=512,
                                kernel_size=kernel_size, 
                                dropout=1 - 0.95,
                                cin_channels=embedding_dim, #local conditioning channels - on encoder output
                                gin_channels=number_of_speakers, #global conditioning channels - on speaker_id
                                n_speakers=number_of_speakers,
                                weight_normalization=False, 
                                upsample_conditional_features=True, 
                                decoding_channels=encoding_channels[::-1],
                                use_speaker_embedding=False
                               )
        self.recon_loss = torch.nn.CrossEntropyLoss()
        self.receptive_field = receptive_field_size(total_layers=layers, num_cycles=stacks, kernel_size=kernel_size)
#        self.mean = None
#        self.std = None
    def forward(self, x):
        audio, target, speaker_id = x
        assert len(audio.shape) == 3 # B x C x L 
        assert audio.shape[1] == 256
        z = self._encoder(audio)
        #normalize output - subtract mean, divide by standard deviation
        #without this, perplexity goes to 1 almost instantly
#         if self.mean is None:
#             self.mean = z.mean().detach()
#         if self.std is None:
#              self.std = z.std().detach()
#        z = z - self.mean
#        z = z / self.std
        vq_loss, quantized, perplexity = self._vq_vae(z)
#        assert z.shape == quantized.shape
#        print("audio.shape", audio.shape)
#        print("quantized.shape", quantized.shape)
        x_recon = self._decoder(audio, quantized, speaker_id, softmax=False)
        x_recon = x_recon[:, :, self.receptive_field:-1]
        recon_loss_value = self.recon_loss(x_recon, target[:, 1:])
        loss = recon_loss_value + vq_loss
        
        return loss, recon_loss_value, x_recon, perplexity
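The slicing in `forward` pairs each kept decoder output with the next input sample: the dataset's `target` already starts at `receptive_field`, `target[:, 1:]` shifts it by one more step, and `x_recon[:, :, receptive_field:-1]` drops the first `receptive_field` outputs and the last one. A small index check with plain Python lists, using the values from this notebook's settings:

```python
segment_length = 16126   # samples per training segment
receptive_field = 3070   # WaveNet receptive field for 30 layers, 3 cycles, kernel 2

positions = list(range(segment_length))
output_positions = positions[receptive_field:-1]    # x_recon[:, :, receptive_field:-1]
target_positions = positions[receptive_field:][1:]  # target[:, 1:], target = audio[receptive_field:]

assert len(output_positions) == len(target_positions) == 13055
# the decoder output at time t is compared with the sample at time t + 1
assert all(t == p + 1 for p, t in zip(output_positions, target_positions))
```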

# Train

In [8]:
num_training_updates = 39818
#vector quantizer parameters:
embedding_dim = 64 #dimension of each vector
encoding_channels = [512,512,512,512,512,embedding_dim]
num_embeddings = 512 #number of vectors
commitment_cost = 0.25

#wavenet parameters:
kernel_size=2
total_layers=30
num_cycles=3


decay = 0.99
#decay = 0

learning_rate = 1e-3
batch_size=1

In [9]:
receptive_field = receptive_field_size(total_layers=total_layers, num_cycles=num_cycles, kernel_size=kernel_size)
print(receptive_field)

3070
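The value 3070 can be re-derived by hand. The sketch below assumes that, as in `wavenet_vocoder`'s default, the dilation doubles within each cycle and restarts at 1 for every new cycle; `wavenet_receptive_field` is a hypothetical re-implementation, not the library function:

```python
def wavenet_receptive_field(total_layers, num_cycles, kernel_size):
    # dilations restart at 1 in every cycle: 1, 2, 4, ..., 2**(layers_per_cycle - 1)
    layers_per_cycle = total_layers // num_cycles
    dilations = [2 ** (i % layers_per_cycle) for i in range(total_layers)]
    # each dilated causal convolution extends the past context by (kernel_size - 1) * dilation
    return (kernel_size - 1) * sum(dilations) + 1


rf = wavenet_receptive_field(total_layers=30, num_cycles=3, kernel_size=2)
print(rf, rf / 22050)  # 3070 samples, about 0.14 s at 22050 Hz
```

With three cycles of dilations 1 to 512, the sum is 3 × 1023 = 3069, so each prediction sees 3070 past samples.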


## Load data

In [10]:
model = Model(num_embeddings=num_embeddings,
              encoding_channels=encoding_channels,
              embedding_dim=embedding_dim, 
              commitment_cost=commitment_cost, 
              layers=total_layers,
              stacks=num_cycles,
              kernel_size=kernel_size,
              decay=decay).to(device)

In [11]:
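# lr=1 makes LambdaLR's multiplier equal to the actual learning rate.
# The lambda reads the current lr: 1e-3 at epoch 0, then it is decreased by
# (1e-3 - 1e-6)/500 per epoch until it reaches 1e-6 at epoch 500, and kept constant afterwards.
optimizer = optim.Adam(model.parameters(), lr=1, amsgrad=False)

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: 1e-3 if epoch == 0
    else (optimizer.param_groups[0]['lr'] - (1e-3 - 1e-6) / 500) if epoch <= 500
    else optimizer.param_groups[0]['lr'])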
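Because the base learning rate is 1, each `scheduler.step()` sets the lr to the lambda's return value, which itself reads the previous lr. Replaying that recurrence in plain Python (a sketch; the variable names are illustrative) shows the resulting linear schedule:

```python
num_decay_epochs = 500
start_lr, final_lr = 1e-3, 1e-6
decrement = (start_lr - final_lr) / num_decay_epochs

# lr_schedule[e] is the lr used during epoch e (scheduler.step() runs once per epoch)
lr_schedule = [start_lr]
for epoch in range(1, 601):
    previous = lr_schedule[-1]
    lr_schedule.append(previous - decrement if epoch <= num_decay_epochs else previous)

print(lr_schedule[0], lr_schedule[250], lr_schedule[500], lr_schedule[600])
# roughly 1e-3, 5.0e-4, 1e-6, 1e-6
```

Note that `learning_rate = 1e-3` defined above is not used by this cell; the start value is hard-coded in the lambda.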

In [12]:
class TrainingSet(Dataset):
    # VCTK-Corpus Training data set

    def __init__(self, num_speakers,
                 receptive_field,
                 segment_length=16126,
                 chunk_size=1000,
                 classes=256):
        
        self.x_list = self.read_files(args.training_data)
        self.segment_length = segment_length
        self.chunk_size = chunk_size
        self.classes = classes
        self.receptive_field = receptive_field
        self.cached_pt = 0
        self.num_speakers = num_speakers

    def read_files(self, filename):
        print("training data from " + args.training_data)
        with open(filename) as file:
            files = file.readlines()
        return [f.strip() for f in files]

    def __getitem__(self, index):
        # Load the file; if it is shorter than one segment, move on to the next file
        # (wrapping around at the end of the list).
        while True:
            audiofile = self.x_list[index]
            try:
                # librosa.load resamples to 22050 Hz by default
                audio, sr = librosa.load('./VCTK/wav48/' + audiofile)
            except Exception as e:
                print(e, audiofile)
                raise
            if sr != 22050:
                raise ValueError("{} SR of {} not equal to 22050".format(sr, audiofile))
            if audio.shape[0] >= self.segment_length:
                break
            index = (index + 1) % len(self.x_list)

        audio = librosa.util.normalize(audio) #divide max(abs(audio))
        audio = self.quantize_data(audio, self.classes).astype(np.int64)

        # random crop of segment_length samples
        max_audio_start = audio.shape[0] - self.segment_length
        audio_start = random.randint(0, max_audio_start)
        audio = audio[audio_start:audio_start+self.segment_length]

        #divide into input and target
        audio = torch.from_numpy(audio)
        ohe_audio = torch.FloatTensor(self.classes, self.segment_length).zero_()
        ohe_audio.scatter_(0, audio.unsqueeze(0), 1.)
        target = audio[self.receptive_field:]

        speaker_index = speaker_dic[audiofile.split('/')[0]]
        speaker_id = torch.from_numpy(np.array(speaker_index)).unsqueeze(0).unsqueeze(0)
        ohe_speaker = torch.FloatTensor(self.num_speakers, 1).zero_()
        ohe_speaker.scatter_(0, speaker_id, 1.)

        return ohe_audio, target, ohe_speaker
    
    def __len__(self):
        return len(self.x_list)
    
    def quantize_data(self, data, classes):
        mu_x = self.mu_law_encode(data, classes)
        bins = np.linspace(-1, 1, classes)
        quantized = np.digitize(mu_x, bins) - 1
        return quantized

    def mu_law_encode(self, data, mu):
        mu_x = np.sign(data) * np.log(1 + mu * np.abs(data)) / np.log(mu + 1)
        return mu_x
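`quantize_data` maps each normalized sample to one of `classes = 256` mu-law bins. The sketch below repeats that mapping in standalone NumPy and adds an inverse (`mu_law_decode` and `dequantize` are hypothetical helpers, not part of this notebook) to check that every class index lies in 0..255 and that the round-trip error stays small. Note that 8-bit mu-law is usually defined with μ = 255, whereas this notebook passes μ = `classes` = 256.

```python
import numpy as np


def mu_law_encode(x, mu):
    # logarithmic compression of amplitudes in [-1, 1] (same formula as the notebook)
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)


def mu_law_decode(y, mu):
    # inverse of mu_law_encode
    return np.sign(y) * ((1 + mu) ** np.abs(y) - 1) / mu


def quantize(x, classes=256):
    bins = np.linspace(-1, 1, classes)
    return np.digitize(mu_law_encode(x, classes), bins) - 1


def dequantize(q, classes=256):
    # map each class back to the lower edge of its bin, then undo the compression
    bins = np.linspace(-1, 1, classes)
    return mu_law_decode(bins[q], classes)


x = np.linspace(-1, 1, 1001)
q = quantize(x)
x_hat = dequantize(q)
print(q.min(), q.max(), np.abs(x - x_hat).max())
```

Decoding to the lower bin edge keeps the sketch short; the error is largest near full scale, where the mu-law bins are widest in the amplitude domain.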

In [13]:
class TestSet(Dataset):
    # VCTK-Corpus Test data set


    def __init__(self, num_speakers,
                 receptive_field,
                 segment_length=16126,
                 chunk_size=1000,
                 classes=256):
        
        
        self.x_list = self.read_files(args.test_data)
        self.segment_length = segment_length
        self.chunk_size = chunk_size
        self.classes = classes
        self.receptive_field = receptive_field
        self.cached_pt = 0
        self.num_speakers = num_speakers


    def read_files(self, filename):
        print("training data from " + args.test_data)
        with open(filename) as file:
            files = file.readlines()
        return [f.strip() for f in files]

    def __getitem__(self, index):
        # Load the file; if it is shorter than one segment, move on to the next file
        # (wrapping around at the end of the list).
        while True:
            audiofile = self.x_list[index]
            try:
                # librosa.load resamples to 22050 Hz by default
                audio, sr = librosa.load('./VCTK/wav48/' + audiofile)
            except Exception as e:
                print(e, audiofile)
                raise
            if sr != 22050:
                raise ValueError("{} SR of {} not equal to 22050".format(sr, audiofile))
            if audio.shape[0] >= self.segment_length:
                break
            index = (index + 1) % len(self.x_list)

        audio = librosa.util.normalize(audio) #divide max(abs(audio))
        audio = self.quantize_data(audio, self.classes).astype(np.int64)

        # random crop of segment_length samples
        max_audio_start = audio.shape[0] - self.segment_length
        audio_start = random.randint(0, max_audio_start)
        audio = audio[audio_start:audio_start+self.segment_length]

        #divide into input and target
        audio = torch.from_numpy(audio)
        ohe_audio = torch.FloatTensor(self.classes, self.segment_length).zero_()
        ohe_audio.scatter_(0, audio.unsqueeze(0), 1.)
        target = audio[self.receptive_field:]

        speaker_index = speaker_dic[audiofile.split('/')[0]]
        speaker_id = torch.from_numpy(np.array(speaker_index)).unsqueeze(0).unsqueeze(0)
        ohe_speaker = torch.FloatTensor(self.num_speakers, 1).zero_()
        ohe_speaker.scatter_(0, speaker_id, 1.)

        return ohe_audio, target, ohe_speaker

    def __len__(self):
        return len(self.x_list)
        
    def quantize_data(self, data, classes):
        mu_x = self.mu_law_encode(data, classes)
        bins = np.linspace(-1, 1, classes)
        quantized = np.digitize(mu_x, bins) - 1
        return quantized

    def mu_law_encode(self, data, mu):
        mu_x = np.sign(data) * np.log(1 + mu * np.abs(data)) / np.log(mu + 1)
        return mu_x

In [14]:
trainset = TrainingSet(number_of_speakers, receptive_field=receptive_field)
testset = TestSet(number_of_speakers, receptive_field=receptive_field)


training_loader = DataLoader(dataset = trainset,
                           batch_size=batch_size,
                           shuffle=True, 
                           num_workers=1)


validation_loader = DataLoader(dataset = testset,
                           batch_size=batch_size,
                           shuffle=True, 
                           num_workers=1)

training data from ./2_speaker/vctk_train.txt
training data from ./2_speaker/vctk_test.txt


In [15]:
train_res_recon_error = []
train_res_perplexity = []

In [16]:
def train():
    model.train()
    global train_res_recon_error
    global train_res_perplexity
    train_total_loss = []
    train_recon_error = []
    train_perplexity = []
    # with open("errors", "rb") as file:
    #     train_res_recon_error, train_res_perplexity = pickle.load(file)
# num_epochs = 1
# for epoch in range(num_epochs):
    iterator = iter(training_loader)
#     datas0 = []
#     datas1 = []
#     datas2 = []
    for i, data_train in enumerate(iterator):
        data_train = [data_train[0].to(device),
                     data_train[1].to(device),
                     data_train[2].to(device)
                     ]

#         datas0.append(data_train[0])
#         datas1.append(data_train[1])
#         datas2.append(data_train[2])
#         if (i+1) % batch_size == 0:
#             data = [torch.cat(datas0).to(device),
#                    torch.cat(datas1).to(device),
#                    torch.cat(datas2).to(device)]
        optimizer.zero_grad()
        loss, recon_error, data_recon, perplexity = model(data_train)
        loss.backward()
        optimizer.step()
        train_total_loss.append(loss.item())
        train_recon_error.append(recon_error.item())
        train_perplexity.append(perplexity.item())

        if (i+1) % (10 * batch_size) == 0:
            print('%d iterations' % (i+1))
            print('recon_error: %.3f' % np.mean(train_recon_error[-100:]))
            print('perplexity: %.3f' % np.mean(train_perplexity[-100:]))
            print()
    train_res_recon_error.extend(train_recon_error)
    train_res_perplexity.extend(train_perplexity)
    # per-epoch means (train_res_recon_error accumulates over all epochs)
    return np.mean(train_total_loss), np.mean(train_recon_error)
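The perplexity printed during training is the exponential of the entropy of the average code usage, computed exactly as in `VectorQuantizerEMA.forward`. It is 1 when every frame maps to the same embedding (the collapse mentioned in the comments of `Model.forward`) and `K` when all `K` embeddings are used equally often. A NumPy sketch with made-up one-hot assignments (`codebook_perplexity` is a hypothetical helper):

```python
import numpy as np


def codebook_perplexity(encodings):
    """exp(entropy) of the average code usage; encodings is a one-hot [N, K] array."""
    avg_probs = encodings.mean(axis=0)
    return np.exp(-np.sum(avg_probs * np.log(avg_probs + 1e-10)))


K = 512
collapsed = np.zeros((250, K))
collapsed[:, 0] = 1.0     # every frame uses the same code
uniform = np.eye(K)       # every code used exactly once
print(codebook_perplexity(collapsed), codebook_perplexity(uniform))  # ~1.0 and ~512.0
```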

In [17]:
def validation():
    model.eval()
    with torch.no_grad():
        test_total_loss = []
        test_res_recon_error = []
        # with open("errors", "rb") as file:
        #     train_res_recon_error, train_res_perplexity = pickle.load(file)
    # num_epochs = 1
    # for epoch in range(num_epochs):
        iterator = iter(validation_loader)
    #     datas0 = []
    #     datas1 = []
    #     datas2 = []
        for i, data_test in enumerate(iterator):
            data_test = [data_test[0].to(device),
                         data_test[1].to(device),
                         data_test[2].to(device)]
            
            loss, recon_error, data_recon, perplexity = model(data_test)

            test_total_loss.append(loss.item())
            test_res_recon_error.append(recon_error.item())

            if (i+1) % (10 * batch_size) == 0:
                print('%d iterations' % (i+1))
                print('recon_error: %.3f' % np.mean(test_res_recon_error[-100:]))
                print()
    return np.mean(test_total_loss), np.mean(test_res_recon_error)

In [18]:
epochs = args.epochs
training_total_loss_per_epochs = []
training_reconstruction_errors_per_epochs = []
validation_total_loss_per_epochs = []
validation_reconstruction_errors_per_epochs = []

lrs = []

if (args.load != 0):
    model.load_state_dict(torch.load("model_epoch"+str(args.load)))
    optimizer.load_state_dict(torch.load("optim_epoch"+str(args.load)))
    training_total_loss_per_epochs = np.load('training_total_loss_per_epochs'+str(args.load)+'.npy').tolist()
    training_reconstruction_errors_per_epochs = np.load('training_reconstruction_errors_per_epochs'+str(args.load)+'.npy').tolist()
    validation_total_loss_per_epochs = np.load('validation_total_loss_per_epochs'+str(args.load)+'.npy').tolist()
    validation_reconstruction_errors_per_epochs = np.load('validation_reconstruction_errors_per_epochs'+str(args.load)+'.npy').tolist()
    lrs = np.load('lrs.npy').tolist()  # list, so that lrs.append() below works
    
    
if (args.load_mid != 0 and args.load == 0):
    model.load_state_dict(torch.load("model_epoch"+str(args.load_mid)))
    optimizer.load_state_dict(torch.load("optim_epoch"+str(args.load_mid)))


for i in range(1, epochs+1):
    print(str(i)+" epochs ==> training")
    total_loss, reconstruction_loss = train()
    training_total_loss_per_epochs.append(total_loss)
    training_reconstruction_errors_per_epochs.append(reconstruction_loss)
    
    print(str(i)+" epochs ==> validation")
    total_loss, reconstruction_loss = validation()
    validation_total_loss_per_epochs.append(total_loss)
    validation_reconstruction_errors_per_epochs.append(reconstruction_loss)

    
    if (i % 5 == 0):
        torch.save(model.state_dict(), "model_epoch"+str(i+args.load))
        torch.save(optimizer.state_dict(), "optim_epoch"+str(i+args.load))
        
    for param_group in optimizer.param_groups:
        lr = param_group['lr']
    lrs.append(lr)
    np.save('lrs.npy', lrs)
    np.save('training_total_loss_per_epochs'+str(args.epochs + args.load), np.array(training_total_loss_per_epochs))
    np.save('training_reconstruction_errors_per_epochs'+str(args.epochs + args.load), np.array(training_reconstruction_errors_per_epochs))
    np.save('validation_total_loss_per_epochs'+str(args.epochs + args.load), np.array(validation_total_loss_per_epochs))
    np.save('validation_reconstruction_errors_per_epochs'+str(args.epochs + args.load), np.array(validation_reconstruction_errors_per_epochs))
    scheduler.step()

1 epochs ==> training

KeyboardInterrupt: 