In [1]:
"""
An implementation of a Deep Markov Model in Pyro based on reference [1].
This is essentially the DKS variant outlined in the paper. The primary difference
between this implementation and theirs is that in our version any KL divergence terms
in the ELBO are estimated via sampling, while they make use of the analytic formulae.
We also illustrate the use of normalizing flows in the variational distribution (in which
case analytic formulae for the KL divergences are in any case unavailable).
Reference:
[1] Structured Inference Networks for Nonlinear State Space Models [arXiv:1609.09869]
    Rahul G. Krishnan, Uri Shalit, David Sontag
"""

import torch
import torch.nn as nn
import numpy as np
import pyro
from pyro.infer import SVI
from pyro.optim import ClippedAdam
import pyro.distributions as dist
from pyro.util import ng_ones
from pyro.distributions.transformed_distribution import InverseAutoregressiveFlow
from pyro.distributions.transformed_distribution import TransformedDistribution
import six.moves.cPickle as pickle
from os.path import exists
import argparse
import time
from util import get_logger


class Emitter(nn.Module):
    """
    Neural network behind the bernoulli observation likelihood `p(x_t | z_t)`.

    The network is a three-layer perceptron: two ReLU hidden layers of width
    `emission_dim` followed by a sigmoid output layer of width `input_dim`.
    """
    def __init__(self, input_dim, z_dim, emission_dim):
        super(Emitter, self).__init__()
        # affine maps: latent -> hidden -> hidden -> observation space
        self.lin_z_to_hidden = nn.Linear(z_dim, emission_dim)
        self.lin_hidden_to_hidden = nn.Linear(emission_dim, emission_dim)
        self.lin_hidden_to_input = nn.Linear(emission_dim, input_dim)
        # activation functions shared by the layers above
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z_t):
        """
        Map the latent `z_t` of a single time step to the vector of probabilities
        that parameterizes the bernoulli distribution `p(x_t | z_t)`.
        """
        hidden = z_t
        # both hidden layers share the same affine + ReLU structure
        for layer in (self.lin_z_to_hidden, self.lin_hidden_to_hidden):
            hidden = self.relu(layer(hidden))
        # squash the final activations into (0, 1)
        return self.sigmoid(self.lin_hidden_to_input(hidden))


class GatedTransition(nn.Module):
    """
    Parameterizes the gaussian latent transition probability `p(z_t | z_{t-1})`
    See section 5 in the reference for comparison.

    Args:
        z_dim: size of the latent vectors z_t
        transition_dim: width of the hidden layers of the gate and proposed-mean networks
    """
    def __init__(self, z_dim, transition_dim):
        super(GatedTransition, self).__init__()
        # initialize the six linear transformations used in the neural network
        self.lin_gate_z_to_hidden = nn.Linear(z_dim, transition_dim)
        self.lin_gate_hidden_to_z = nn.Linear(transition_dim, z_dim)
        self.lin_proposed_mean_z_to_hidden = nn.Linear(z_dim, transition_dim)
        self.lin_proposed_mean_hidden_to_z = nn.Linear(transition_dim, z_dim)
        self.lin_sig = nn.Linear(z_dim, z_dim)
        self.lin_z_to_mu = nn.Linear(z_dim, z_dim)
        # modify the default initialization of lin_z_to_mu
        # so that it starts out as the identity function
        self.lin_z_to_mu.weight.data = torch.eye(z_dim)
        self.lin_z_to_mu.bias.data = torch.zeros(z_dim)
        # initialize the three non-linearities used in the neural network
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

    def forward(self, z_t_1):
        """
        Given the latent `z_{t-1}` corresponding to the time step t-1
        we return the mean and sigma vectors that parameterize the
        (diagonal) gaussian distribution `p(z_t | z_{t-1})`

        Returns:
            a tuple `(mu, sigma)`; both have the same shape as the output of
            `lin_gate_hidden_to_z`, i.e. a trailing dimension of size `z_dim`
        """
        # compute the gating function (each entry lies in (0, 1) thanks to the sigmoid)
        gate_intermediate = self.relu(self.lin_gate_z_to_hidden(z_t_1))
        gate = self.sigmoid(self.lin_gate_hidden_to_z(gate_intermediate))
        # compute the 'proposed mean'
        proposed_mean_intermediate = self.relu(self.lin_proposed_mean_z_to_hidden(z_t_1))
        proposed_mean = self.lin_proposed_mean_hidden_to_z(proposed_mean_intermediate)
        # assemble the actual mean used to sample z_t, which mixes a linear transformation
        # of z_{t-1} with the proposed mean modulated by the gating function;
        # `1.0 - gate` broadcasts the scalar, so no explicit tensor of ones (and no
        # dtype/device matching) is needed
        mu = (1.0 - gate) * self.lin_z_to_mu(z_t_1) + gate * proposed_mean
        # compute the sigma used to sample z_t, using the proposed mean from above as input
        # the softplus ensures that sigma is positive
        sigma = self.softplus(self.lin_sig(self.relu(proposed_mean)))
        # return mu, sigma which can be fed into Normal
        return mu, sigma


class Combiner(nn.Module):
    """
    Building block of the guide (the variational distribution): parameterizes
    `q(z_t | z_{t-1}, x_{t:T})`. The dependence on `x_{t:T}` enters through the
    RNN hidden state that is passed to `forward` (see the `rnn` module in `DMM`).
    """
    def __init__(self, z_dim, rnn_dim):
        super(Combiner, self).__init__()
        # one map lifts z into the rnn hidden space, two maps read out mu and sigma
        self.lin_z_to_hidden = nn.Linear(z_dim, rnn_dim)
        self.lin_hidden_to_mu = nn.Linear(rnn_dim, z_dim)
        self.lin_hidden_to_sigma = nn.Linear(rnn_dim, z_dim)
        # non-linearities
        self.tanh = nn.Tanh()
        self.softplus = nn.Softplus()

    def forward(self, z_t_1, h_rnn):
        """
        Given the latent z at time step t-1 and the RNN hidden state `h(x_{t:T})`,
        return the mean and sigma vectors that parameterize the (diagonal)
        gaussian distribution `q(z_t | z_{t-1}, x_{t:T})`.
        """
        # lift z_{t-1} into the rnn hidden space ...
        z_features = self.tanh(self.lin_z_to_hidden(z_t_1))
        # ... and average it with the rnn hidden state
        h_combined = 0.5 * (z_features + h_rnn)
        # read out the mean directly; softplus keeps sigma positive
        mu = self.lin_hidden_to_mu(h_combined)
        sigma = self.softplus(self.lin_hidden_to_sigma(h_combined))
        # (mu, sigma) can be fed into Normal
        return mu, sigma


class DMM(nn.Module):
    """
    This PyTorch Module encapsulates the model as well as the
    variational distribution (the guide) for the Deep Markov Model

    Args:
        input_dim: size of each observation vector x_t
        z_dim: size of each latent vector z_t
        emission_dim: hidden size of the `Emitter` network
        transition_dim: hidden size of the `GatedTransition` network
        rnn_dim: hidden size of the RNN used by the guide
        rnn_dropout_rate: value forwarded to the `dropout` argument of `nn.RNN`
        num_iafs: number of `InverseAutoregressiveFlow` transforms used in the guide
        iaf_dim: second argument passed to each `InverseAutoregressiveFlow`
        use_cuda: if True, all (sub)modules are moved to the GPU
    """
    def __init__(self, input_dim=88, z_dim=100, emission_dim=100,
                 transition_dim=200, rnn_dim=600, rnn_dropout_rate=0.0,
                 num_iafs=0, iaf_dim=50, use_cuda=False):
        super(DMM, self).__init__()
        # instantiate PyTorch modules used in the model and guide below
        self.emitter = Emitter(input_dim, z_dim, emission_dim)
        self.trans = GatedTransition(z_dim, transition_dim)
        self.combiner = Combiner(z_dim, rnn_dim)
        # NOTE(review): per the PyTorch docs, `dropout` is applied between stacked
        # recurrent layers, so with num_layers=1 a non-zero rate has no effect - confirm
        self.rnn = nn.RNN(input_size=input_dim, hidden_size=rnn_dim, nonlinearity='relu',
                          batch_first=True, bidirectional=False, num_layers=1,
                          dropout=rnn_dropout_rate)

        # if we're using normalizing flows, instantiate those too
        iafs = [InverseAutoregressiveFlow(z_dim, iaf_dim) for _ in range(num_iafs)]
        self.iafs = nn.ModuleList(iafs)

        # define (trainable) parameters z_0 and z_q_0 that help define the probability
        # distributions p(z_1) and q(z_1)
        # (since for t = 1 there are no previous latents to condition on)
        self.z_0 = nn.Parameter(torch.zeros(z_dim))
        self.z_q_0 = nn.Parameter(torch.zeros(z_dim))
        # define a (trainable) parameter for the initial hidden state of the rnn
        self.h_0 = nn.Parameter(torch.zeros(1, 1, rnn_dim))

        self.use_cuda = use_cuda
        # if on gpu cuda-ize all PyTorch (sub)modules
        if use_cuda:
            self.cuda()

    # the model p(x_{1:T} | z_{1:T}) p(z_{1:T})
    def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):
        """
        Generative model: samples z_1..z_T from the gated transition and observes
        x_1..x_T through the emitter.

        Args:
            mini_batch: observations indexed as (sequence, time step, feature)
            mini_batch_reversed: unused here; accepted so that model and guide
                share one signature
            mini_batch_mask: per-(sequence, time step) mask, indexed as `[:, t - 1:t]`
            mini_batch_seq_lengths: unused here; accepted for signature parity
            annealing_factor: multiplier applied to the mask of the latent sites
        """
        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)

        # register all PyTorch (sub)modules with pyro
        # this needs to happen in both the model and guide
        pyro.module("dmm", self)

        # set z_prev = z_0 to setup the recursive conditioning in p(z_t | z_{t-1})
        z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

        # sample the latents z and observed x's one time step at a time
        for t in range(1, T_max + 1):
            # the next three lines of code sample z_t ~ p(z_t | z_{t-1})
            # note that (both here and elsewhere) log_pdf_mask takes care of both
            # (i)  KL annealing; and
            # (ii) raggedness in the observed data (i.e. different sequences
            #      in the mini-batch have different lengths)

            # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
            z_mu, z_sigma = self.trans(z_prev)
            # then sample z_t according to dist.Normal(z_mu, z_sigma)
            z_t = pyro.sample("z_%d" % t,
                              dist.normal,
                              z_mu,
                              z_sigma,
                              log_pdf_mask=annealing_factor * mini_batch_mask[:, t - 1:t])

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = self.emitter(z_t)
            # the next statement instructs pyro to observe x_t according to the
            # bernoulli distribution p(x_t|z_t)
            pyro.observe("obs_x_%d" % t, dist.bernoulli, mini_batch[:, t - 1, :],
                         emission_probs_t,
                         log_pdf_mask=mini_batch_mask[:, t - 1:t])
            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t

    # the guide q(z_{1:T} | x_{1:T}) (i.e. the variational distribution)
    def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):
        """
        Variational distribution: runs the RNN over the time-reversed observations
        and samples z_1..z_T from the combiner (optionally transformed by the IAFs).

        Args:
            mini_batch: observations indexed as (sequence, time step, feature);
                only its sizes are used here
            mini_batch_reversed: time-reversed observations fed to the RNN; its
                output is unpacked with `pad_packed_sequence`, so a packed
                sequence is expected
            mini_batch_mask: per-(sequence, time step) mask, indexed as `[:, t - 1:t]`
            mini_batch_seq_lengths: sequence lengths used to undo the time reversal
            annealing_factor: multiplier applied to the mask of the latent sites
        """
        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)
        # register all PyTorch (sub)modules with pyro
        pyro.module("dmm", self)

        # the rnn expects one initial hidden state per sequence, so broadcast the single
        # learned state over the batch; .contiguous() materializes the expanded view
        # (this used to happen on gpu only, leaving a (1, 1, rnn_dim) state on cpu)
        h_0_contig = self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
        # push the observed x's through the rnn;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        # reverse the time-ordering in the hidden state and un-pack it
        # (pad_and_reverse is the module-level helper of this notebook; it is looked up
        # when the guide runs, so it only has to be defined before training starts)
        rnn_output = pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        # (expanded over the batch, mirroring what the model does with z_0)
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

        # sample the latents z one time step at a time
        for t in range(1, T_max + 1):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_mu, z_sigma = self.combiner(z_prev, rnn_output[:, t - 1, :])
            z_dist = dist.normal

            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(self.iafs) > 0:
                z_dist = TransformedDistribution(z_dist, self.iafs)
            # sample z_t from the distribution z_dist
            z_t = pyro.sample("z_%d" % t,
                              z_dist,
                              z_mu,
                              z_sigma,
                              log_pdf_mask=annealing_factor * mini_batch_mask[:, t - 1:t])
            # the latent sampled at this time step will be conditioned upon in the next time step
            # so keep track of it
            z_prev = z_t



In [2]:
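# NOTE: the command-line options of the original script are kept below for reference only.
# They are commented out, so no `args` object exists in this notebook and the cells
# below hard-code their own settings.
#     parser = argparse.ArgumentParser(description="parse args")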
#     parser.add_argument('-n', '--num-epochs', type=int, default=5000)
#     parser.add_argument('-lr', '--learning-rate', type=float, default=0.0004)
#     parser.add_argument('-b1', '--beta1', type=float, default=0.96)
#     parser.add_argument('-b2', '--beta2', type=float, default=0.999)
#     parser.add_argument('-cn', '--clip-norm', type=float, default=20.0)
#     parser.add_argument('-lrd', '--lr-decay', type=float, default=0.99996)
#     parser.add_argument('-wd', '--weight-decay', type=float, default=0.6)
#     parser.add_argument('-mbs', '--mini-batch-size', type=int, default=20)
#     parser.add_argument('-ae', '--annealing-epochs', type=int, default=1000)
#     parser.add_argument('-maf', '--minimum-annealing-factor', type=float, default=0.1)
#     parser.add_argument('-rdr', '--rnn-dropout-rate', type=float, default=0.1)
#     parser.add_argument('-iafs', '--num-iafs', type=int, default=0)
#     parser.add_argument('-id', '--iaf-dim', type=int, default=100)
#     parser.add_argument('-cf', '--checkpoint-freq', type=int, default=0)
#     parser.add_argument('-lopt', '--load-opt', type=str, default='')
#     parser.add_argument('-lmod', '--load-model', type=str, default='')
#     parser.add_argument('-sopt', '--save-opt', type=str, default='')
#     parser.add_argument('-smod', '--save-model', type=str, default='')
#     parser.add_argument('--cuda', action='store_true')
#     parser.add_argument('-l', '--log', type=str, default='dmm.log')

In [3]:
import torch
import torch.nn as nn
import numpy as np
from observations import jsb_chorales
from os.path import join, exists
import six.moves.cPickle as pickle


# this function processes the raw data; in particular it unsparsifies it
def process_data(base_path, filename, T_max=160, min_note=21, note_range=88):
    """
    Convert the raw JSB chorales into dense arrays and pickle them to disk.

    Nothing is done if `join(base_path, filename)` already exists.

    Args:
        base_path: directory handed to `jsb_chorales` and used for the output file
        filename: name of the pickle file to write inside `base_path`
        T_max: number of time steps allocated per sequence
        min_note: value subtracted from every note to obtain its column index
        note_range: number of columns allocated per time step

    The pickled object is a dict keyed by 'train', 'test' and 'valid'; each entry
    holds 'sequence_lengths' (int32, one entry per sequence) and 'sequences'
    (array of shape (n_seqs, T_max, note_range) with a 1 for every sounding note).
    """
    output = join(base_path, filename)
    if exists(output):
        return

    print("processing raw polyphonic music data...")
    # assumes jsb_chorales returns the splits in (train, test, valid) order - TODO confirm
    data = jsb_chorales(base_path)
    processed_dataset = {}
    for split, data_split in zip(['train', 'test', 'valid'], data):
        n_seqs = len(data_split)
        sequence_lengths = np.zeros((n_seqs), dtype=np.int32)
        sequences = np.zeros((n_seqs, T_max, note_range))
        for seq in range(n_seqs):
            seq_length = len(data_split[seq])
            sequence_lengths[seq] = seq_length
            for t in range(seq_length):
                # shift the notes so that they can be used as column indices
                note_slice = np.array(list(data_split[seq][t])) - min_note
                # an empty time slice yields an empty (float) array that cannot be
                # used as an index, so it is skipped
                if len(note_slice) > 0:
                    sequences[seq, t, note_slice] = 1.0
        processed_dataset[split] = {'sequence_lengths': sequence_lengths,
                                    'sequences': sequences}
    # the context manager guarantees the file is flushed and closed before returning
    with open(output, "wb") as output_file:
        pickle.dump(processed_dataset, output_file)
    print("dumped processed data to %s" % output)


# this runs as soon as the cell is executed (or, if this code lives in a module,
# when that module is imported); process_data returns immediately if the
# processed file already exists
base_path = './data'
process_data(base_path, "jsb_processed.pkl")


# this function takes a numpy mini-batch and reverses each sequence
# (w.r.t. the temporal axis, i.e. axis=1)
def reverse_sequences_numpy(mini_batch, seq_lengths):
    """
    Return a copy of `mini_batch` in which the first `seq_lengths[b]` time steps
    of every sequence `b` are in reverse order; later time steps are left as is.

    Args:
        mini_batch: array indexed as (sequence, time step, feature)
        seq_lengths: number of time steps to reverse for each sequence
    """
    reversed_mini_batch = mini_batch.copy()
    for b in range(mini_batch.shape[0]):
        T = seq_lengths[b]
        # slice first, then flip: unlike `(T - 1)::-1`, which wraps around to the
        # whole sequence when T == 0, this is a no-op for an empty sequence
        reversed_mini_batch[b, 0:T, :] = mini_batch[b, 0:T, :][::-1]
    return reversed_mini_batch


# this function takes a torch mini-batch and reverses each sequence
# (w.r.t. the temporal axis, i.e. axis=1)
# in contrast to `reverse_sequences_numpy`, this function plays
# nice with torch autograd
def reverse_sequences_torch(mini_batch, seq_lengths):
    """
    Return a new tensor in which the first `seq_lengths[b]` time steps of every
    sequence `b` appear in reverse order; all later time steps are zero.

    Args:
        mini_batch: tensor indexed as (sequence, time step, feature)
        seq_lengths: number of valid time steps of each sequence
    """
    reversed_mini_batch = mini_batch.new_zeros(mini_batch.size())
    for b in range(mini_batch.size(0)):
        # int() accepts python ints, numpy integers and 0-dim tensors alike
        T = int(seq_lengths[b])
        # build the index directly on the device of the input instead of inspecting
        # `mini_batch.data.type()` to choose between cpu and cuda tensor classes
        time_slice = torch.arange(T - 1, -1, -1, dtype=torch.long, device=mini_batch.device)
        reversed_sequence = torch.index_select(mini_batch[b, :, :], 0, time_slice)
        reversed_mini_batch[b, 0:T, :] = reversed_sequence
    return reversed_mini_batch


# this function takes the packed hidden states output by the PyTorch rnn,
# unpacks (pads) them, and reverses each sequence temporally
def pad_and_reverse(rnn_output, seq_lengths):
    """Unpack `rnn_output` (batch first) and reverse every sequence in time."""
    # pad_packed_sequence returns (padded tensor, lengths); only the tensor is needed
    padded_output = nn.utils.rnn.pad_packed_sequence(rnn_output, batch_first=True)[0]
    return reverse_sequences_torch(padded_output, seq_lengths)


# this function builds the 0/1 mask for a mini-batch whose sequences have
# the lengths given in `seq_lengths`
def get_mini_batch_mask(mini_batch, seq_lengths):
    """
    Return an array shaped like the first two axes of `mini_batch` whose row `b`
    holds ones in its first `seq_lengths[b]` entries and zeros elsewhere.
    """
    mask = np.zeros(mini_batch.shape[0:2])
    # each `row` is a view into `mask`, so writing to it fills the mask in place
    for b, row in enumerate(mask):
        n_valid = seq_lengths[b]
        row[0:n_valid] = np.ones(n_valid)
    return mask


# this function prepares a mini-batch for training or evaluation.
# it returns a mini-batch in forward temporal order (`mini_batch`) as
# well as a mini-batch in reverse temporal order (`mini_batch_reversed`).
# it also deals with the fact that packed sequences (which are what we
# feed to the PyTorch rnn) need to be sorted by sequence length.
def get_mini_batch(mini_batch_indices, sequences, seq_lengths, cuda=False):
    """
    Assemble the tensors for the sequences selected by `mini_batch_indices`.

    Returns a tuple `(mini_batch, mini_batch_reversed, mini_batch_mask,
    sorted_seq_lengths)`, all ordered from the longest to the shortest sequence;
    `mini_batch_reversed` is a packed sequence.
    """
    def as_default_tensor(array):
        # wrap a numpy array in a tensor of the default tensor type
        return torch.Tensor(array).type(torch.Tensor)

    # lengths of the selected sequences, then the longest-first ordering
    seq_lengths = seq_lengths[mini_batch_indices]
    longest_first = np.argsort(seq_lengths)[::-1]
    sorted_seq_lengths = seq_lengths[longest_first]
    sorted_mini_batch_indices = mini_batch_indices[longest_first]

    # only keep as many time steps as the longest selected sequence needs
    T_max = np.max(seq_lengths)
    mini_batch = sequences[sorted_mini_batch_indices, 0:T_max, :]
    # time-reversed copy and 0/1 mask of the sorted mini-batch
    mini_batch_reversed = reverse_sequences_numpy(mini_batch, sorted_seq_lengths)
    mini_batch_mask = get_mini_batch_mask(mini_batch, sorted_seq_lengths)

    # numpy -> torch
    mini_batch, mini_batch_reversed, mini_batch_mask = [
        as_default_tensor(array)
        for array in (mini_batch, mini_batch_reversed, mini_batch_mask)
    ]

    # move to the gpu before packing
    if cuda:
        mini_batch = mini_batch.cuda()
        mini_batch_mask = mini_batch_mask.cuda()
        mini_batch_reversed = mini_batch_reversed.cuda()

    # pack the reversed sequences for the rnn
    mini_batch_reversed = nn.utils.rnn.pack_padded_sequence(
        mini_batch_reversed, sorted_seq_lengths, batch_first=True)

    return mini_batch, mini_batch_reversed, mini_batch_mask, sorted_seq_lengths

In [4]:
jsb_file_loc = "./data/jsb_processed.pkl"
# ingest training/validation/test data from disk; the context manager closes the
# file as soon as unpickling is done
# NOTE: unpickling can execute arbitrary code - only load a file that was
# produced locally by process_data
with open(jsb_file_loc, "rb") as jsb_file:
    data = pickle.load(jsb_file)
training_seq_lengths = data['train']['sequence_lengths']
training_data_sequences = data['train']['sequences']
test_seq_lengths = data['test']['sequence_lengths']
test_data_sequences = data['test']['sequences']
val_seq_lengths = data['valid']['sequence_lengths']
val_data_sequences = data['valid']['sequences']
N_train_data = len(training_seq_lengths)
N_train_time_slices = np.sum(training_seq_lengths)
# number of training sequences per gradient step; process_minibatch slices the
# shuffled indices in steps of the same size
mini_batch_size = 20
# number of mini-batches per epoch: ceil(N_train_data / mini_batch_size)
N_mini_batches = N_train_data // mini_batch_size + int(N_train_data % mini_batch_size > 0)


# how often we do validation/test evaluation during training
val_test_frequency = 50
# the number of samples we use to do the evaluation
n_eval_samples = 1

# tile val/test data `n_eval_samples` times along the first axis so that several
# evaluation samples can be processed in one vectorized pass
def rep(x):
    """Repeat every entry of `x` `n_eval_samples` times along axis 0."""
    return np.repeat(x, n_eval_samples, axis=0)

# the training loop and the checkpoint helpers report progress through `log`,
# which nothing else in this notebook defines
# NOTE(review): assumes util.get_logger(path) returns a callable taking a message
# string; "dmm.log" is the default of the commented-out --log option - confirm
log = get_logger("dmm.log")

# get the validation/test data ready for the dmm: pack into sequences, etc.
val_seq_lengths = rep(val_seq_lengths)
test_seq_lengths = rep(test_seq_lengths)
val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = get_mini_batch(
    np.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
    val_seq_lengths)
test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = get_mini_batch(
    np.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
    test_seq_lengths)

# instantiate the dmm
dmm = DMM(rnn_dropout_rate=0.1, num_iafs=0,
          iaf_dim=100)

# setup optimizer
adam_params = {"lr": 0.01, "betas": (0.96, 0.996),
               "clip_norm": 10, "lrd": 0.99996,
               "weight_decay": 0}
adam = ClippedAdam(adam_params)

# setup inference algorithm
svi = SVI(dmm.model, dmm.guide, adam, "ELBO", trace_graph=False)

# now we're going to define some functions we need to form the main training loop

# saves the model and optimizer states to disk
def save_checkpoint(model_path=None, opt_path=None):
    """
    Write the state of `dmm` and of the optimizer `adam` to disk.

    Args:
        model_path: destination of the model state dict; falls back to
            `args.save_model` when omitted
        opt_path: destination of the optimizer state; falls back to
            `args.save_opt` when omitted

    The `args` fallbacks keep the original call `save_checkpoint()` working when
    an argparse namespace exists; this notebook does not create one, so pass the
    paths explicitly here.
    """
    if model_path is None:
        model_path = args.save_model
    if opt_path is None:
        opt_path = args.save_opt
    log("saving model to %s..." % model_path)
    torch.save(dmm.state_dict(), model_path)
    log("saving optimizer states to %s..." % opt_path)
    adam.save(opt_path)
    log("done saving model and optimizer checkpoints to disk.")

# loads the model and optimizer states from disk
def load_checkpoint(model_path=None, opt_path=None):
    """
    Restore the state of `dmm` and of the optimizer `adam` from disk.

    Args:
        model_path: file holding the model state dict; falls back to
            `args.load_model` when omitted
        opt_path: file holding the optimizer state; falls back to
            `args.load_opt` when omitted

    Raises:
        AssertionError: if either file does not exist

    The `args` fallbacks keep the original call `load_checkpoint()` working when
    an argparse namespace exists; this notebook does not create one, so pass the
    paths explicitly here.
    """
    if model_path is None:
        model_path = args.load_model
    if opt_path is None:
        opt_path = args.load_opt
    assert exists(opt_path) and exists(model_path), \
        "--load-model and/or --load-opt misspecified"
    log("loading model from %s..." % model_path)
    dmm.load_state_dict(torch.load(model_path))
    log("loading optimizer states from %s..." % opt_path)
    adam.load(opt_path)
    log("done loading model and optimizer states.")

# prepare a mini-batch and take a gradient step to minimize -elbo
def process_minibatch(epoch, which_mini_batch, shuffled_indices, mini_batch_size=20,
                      annealing_epochs=1000, minimum_annealing_factor=0.1):
    """
    Run one SVI step on mini-batch number `which_mini_batch` of the given epoch.

    Args:
        epoch: zero-based index of the current epoch
        which_mini_batch: zero-based index of the mini-batch within the epoch
        shuffled_indices: permutation of the training-sequence indices for this epoch
        mini_batch_size: number of sequences per mini-batch; must agree with the
            size that the global `N_mini_batches` was computed for
        annealing_epochs: number of epochs over which the KL annealing factor is
            ramped up linearly; 0 disables annealing
        minimum_annealing_factor: value the annealing ramp starts from

    Returns:
        the loss reported by `svi.step` for this mini-batch
    """
    if annealing_epochs > 0 and epoch < annealing_epochs:
        # compute the KL annealing factor appropriate for the current mini-batch
        # in the current epoch: a linear ramp over all gradient steps of the
        # annealing period
        min_af = minimum_annealing_factor
        annealing_factor = min_af + (1.0 - min_af) * \
            (float(which_mini_batch + epoch * N_mini_batches + 1) /
             float(annealing_epochs * N_mini_batches))
    else:
        # by default the KL annealing factor is unity
        annealing_factor = 1.0

    # compute which sequences in the training set we should grab; start and end must
    # use the same stride, and the last mini-batch is clipped to the dataset size
    mini_batch_start = which_mini_batch * mini_batch_size
    mini_batch_end = min((which_mini_batch + 1) * mini_batch_size, N_train_data)
    mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
    # grab a fully prepped mini-batch using the helper function in the data loader
    mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
        = get_mini_batch(mini_batch_indices, training_data_sequences,
                         training_seq_lengths)

    # do an actual gradient step and hand the training loss back to the caller
    loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths, annealing_factor)
    return loss

# helper function for doing evaluation
def do_evaluation():
    """
    Evaluate the loss on the pre-packed validation and test batches.

    Returns:
        a tuple `(val_nll, test_nll)`: the loss reported by `svi.evaluate_loss`
        on each batch divided by the total number of time steps in that batch
    """
    # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
    dmm.rnn.eval()
    try:
        # the inference object created in this notebook is `svi`; no object named
        # `elbo` is ever defined
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / np.sum(val_seq_lengths)
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / np.sum(test_seq_lengths)
    finally:
        # put the RNN back into training mode (i.e. turn on drop-out if applicable),
        # even if the evaluation raised
        dmm.rnn.train()
    return val_nll, test_nll


In [5]:
# no checkpoint is loaded here: load_checkpoint() is defined but never called,
# so training always starts from the freshly initialised model

#################
# TRAINING LOOP #
#################
times = [time.time()]
for epoch in range(5000):
    # draw a fresh random ordering of the training sequences for this epoch
    shuffled_indices = np.arange(N_train_data)
    np.random.shuffle(shuffled_indices)

    # running total of the loss (i.e. -elbo) over the gradient steps of this epoch
    epoch_nll = 0.0
    for which_mini_batch in range(N_mini_batches):
        epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

    # report training diagnostics (loss per training time slice and wall-clock time)
    times.append(time.time())
    epoch_time = times[-1] - times[-2]
    log("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
        (epoch, epoch_nll / N_train_time_slices, epoch_time))

    # validation/test evaluation only happens every `val_test_frequency` epochs
    # (never in epoch 0, and never at all if the frequency is not positive)
    if val_test_frequency <= 0 or epoch <= 0 or epoch % val_test_frequency != 0:
        continue
    val_nll, test_nll = do_evaluation()
    log("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))


RuntimeError: cannot call .data on a torch.Tensor: did you intend to use autograd.Variable?