# Model

In [None]:
import sys
from typing import Sequence, Type

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

from datasets import N_ATTRS


class MVAE(nn.Module):
    """Multimodal VAE over image and text embeddings.

    Each encoder maps its modality to a Gaussian expert (mu, logvar). The
    experts are combined with a N(0, 1) prior expert by a product of experts,
    so the posterior can be computed with either modality missing.

    @param n_latents: integer
                      dimensionality of the latent space
    @param image_encoder, text_encoder: nn.Module
                      callables returning a (mu, logvar) pair
    @param image_decoder, text_decoder: nn.Module
                      callables mapping a latent sample to a reconstruction
    """

    def __init__(self,
                 n_latents,
                 image_encoder,
                 image_decoder,
                 text_encoder,
                 text_decoder):
        super(MVAE, self).__init__()
        self.image_encoder = image_encoder
        self.image_decoder = image_decoder
        self.text_encoder = text_encoder
        self.text_decoder = text_decoder
        self.product_of_experts = ProductOfExperts()
        self.n_latents = n_latents

    def reparametrize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) in training mode; return mu otherwise."""
        if not self.training:
            return mu  # deterministic: use the posterior mean at inference
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, img_emb=None, text_emb=None):
        """Encode the given modalities, sample z and decode both modalities.

        @return (image_recon, text_recon, mu, logvar)
        """
        mu, logvar = self.forward_encoder(img_emb, text_emb)
        # reparametrization trick so gradients flow through the sample
        z = self.reparametrize(mu, logvar)
        image_recon, text_recon = self.forward_decoder(z)
        return image_recon, text_recon, mu, logvar

    def forward_encoder(self, img_emb=None, text_emb=None):
        """Return the product-of-experts posterior parameters (mu, logvar).

        @raises ValueError: if both img_emb and text_emb are None
        """
        if img_emb is None and text_emb is None:
            raise ValueError('forward_encoder needs img_emb and/or text_emb')
        batch_size = (img_emb if img_emb is not None else text_emb).size(0)

        # Create the prior on the same device type as the model parameters.
        use_cuda = next(self.parameters()).is_cuda
        # Experts are stacked along dim 0, starting with the prior expert.
        mu, logvar = prior_expert((1, batch_size, self.n_latents),
                                  use_cuda=use_cuda)
        if img_emb is not None:
            image_mu, image_logvar = self.image_encoder(img_emb)
            mu = torch.cat((mu, image_mu.unsqueeze(0)), dim=0)
            logvar = torch.cat((logvar, image_logvar.unsqueeze(0)), dim=0)

        if text_emb is not None:
            text_mu, text_logvar = self.text_encoder(text_emb)
            mu = torch.cat((mu, text_mu.unsqueeze(0)), dim=0)
            logvar = torch.cat((logvar, text_logvar.unsqueeze(0)), dim=0)

        # product of experts collapses dim 0 into a single Gaussian
        mu, logvar = self.product_of_experts(mu, logvar)
        return mu, logvar

    def forward_decoder(self, z):
        """Decode z into (image_recon, text_recon); both are always produced."""
        image_recon = self.image_decoder(z)
        text_recon = self.text_decoder(z)
        return image_recon, text_recon


class Swish(nn.Module):
    """Swish activation x * sigmoid(x) (https://arxiv.org/abs/1710.05941)."""

    def forward(self, x):
        # torch.sigmoid replaces the deprecated F.sigmoid; values are identical.
        return x * torch.sigmoid(x)


class Encoder(nn.Module):
    """Parametrizes q(z|x): an MLP producing a Gaussian's (mu, logvar).

    @param in_dims: size of the input feature vector
    @param hidden_dims: widths of the hidden layers, in order (may be empty)
    @param out_dim: latent size; the final layer emits 2 * out_dim values
                    which are split into mu and logvar
    @param last_activation: module class applied to the final layer output
                            [default: nn.Identity, i.e. unconstrained]
    """

    def __init__(self,
                 in_dims: int,
                 hidden_dims: Sequence[int],
                 out_dim: int,
                 last_activation: Type[nn.Module] = nn.Identity,
                 ):
        super(Encoder, self).__init__()
        self.out_dim = out_dim  # read by forward() to split mu / logvar

        widths = [in_dims, *hidden_dims]
        layers = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            # nn.SiLU computes x * sigmoid(x), i.e. the Swish activation.
            layers += [nn.Linear(d_in, d_out), nn.SiLU()]
        layers += [
            nn.Dropout(p=0.1),
            nn.Linear(widths[-1], out_dim * 2),
            last_activation(),
        ]
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        """Return (mu, logvar), each with last dimension out_dim."""
        x = self.encode(x)
        return x[..., :self.out_dim], x[..., self.out_dim:]


class Decoder(nn.Module):
    """Parametrizes p(x|z): an MLP mapping a latent vector to a reconstruction.

    @param in_dims: latent size (n_latents)
    @param hidden_dims: widths of the hidden layers, in order (may be empty)
    @param out_dim: size of the reconstructed feature vector
    @param last_activation: module class applied to the final layer output
                            [default: nn.Identity]
    """

    def __init__(self,
                 in_dims: int,
                 hidden_dims: Sequence[int],
                 out_dim: int,
                 last_activation: Type[nn.Module] = nn.Identity,
                 ):
        super(Decoder, self).__init__()
        widths = [in_dims, *hidden_dims]
        layers = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            # nn.SiLU computes x * sigmoid(x), i.e. the Swish activation.
            layers += [nn.Linear(d_in, d_out), nn.SiLU()]
        # The reconstruction must match the target size: out_dim, not 2 * out_dim
        # (only the encoder emits a mu/logvar pair).
        layers += [nn.Linear(widths[-1], out_dim), last_activation()]
        self.decode = nn.Sequential(*layers)

    def forward(self, z):
        """Return the reconstruction for latent input z (last dim in_dims)."""
        # NOTE: with the default nn.Identity the output is raw logits; no
        # sigmoid is applied here.
        return self.decode(z)


class ProductOfExperts(nn.Module):
    """Combine independent Gaussian experts into one Gaussian.

    Experts are stacked along dim 0. The joint precision is the sum of the
    expert precisions, and the joint mean is the precision-weighted average
    of the expert means. See https://arxiv.org/pdf/1410.7827.pdf.

    @param mu: expert means, experts along dim 0
    @param logvar: expert log-variances, same shape as mu
    @param eps: added to each variance before inverting, to keep it finite
    @return (mu, logvar) of the product, with dim 0 reduced away
    """

    def forward(self, mu, logvar, eps=1e-8):
        precision = 1. / (logvar.exp() + eps)
        total_precision = precision.sum(dim=0)
        joint_mu = (mu * precision).sum(dim=0) / total_precision
        joint_logvar = (1. / total_precision).log()
        return joint_mu, joint_logvar


def prior_expert(size, use_cuda=False):
    """Return (mu, logvar) tensors for a standard-normal N(0, 1) prior expert.

    @param size: int or tuple
                 shape of both returned tensors
    @param use_cuda: boolean [default: False]
                     move the tensors to the GPU
    """
    # N(0, 1): zero mean, and log-variance log(1) == 0.
    mu, logvar = torch.zeros(size), torch.zeros(size)
    if not use_cuda:
        return mu, logvar
    return mu.cuda(), logvar.cuda()


# Train/Test loop

In [None]:
import os
import sys
import shutil
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms

from model import MVAE


def elbo_loss(recon_img_emb, image_emb, recon_text_emb, text_emb, mu, logvar,
              lambda_image=1.0, lambda_text=1.0, annealing_factor=1):
    """Bimodal VAE loss (negative ELBO), averaged over the batch.

    A modality is skipped (its reconstruction term is 0) when either its
    reconstruction or its target is None.

    @param recon_img_emb: torch.Tensor or None
                          reconstructed image embedding (logits)
    @param image_emb: torch.Tensor or None
                      target image embedding
    @param recon_text_emb: torch.Tensor or None
                           reconstructed text embedding (logits)
    @param text_emb: torch.Tensor or None
                     target text embedding
    @param mu: torch.Tensor
               mean of latent distribution, (batch, n_latents)
    @param logvar: torch.Tensor
                   log-variance of latent distribution, same shape as mu
    @param lambda_image: float [default: 1.0]
                         weight for image BCE
    @param lambda_text: float [default: 1.0]
                        weight for text BCE
    @param annealing_factor: number [default: 1]
                             multiplier for KL divergence term
    @return loss: torch.Tensor (scalar)
                  batch mean of the weighted BCE terms plus the annealed KLD
    """
    # Both terms start at 0 so a missing modality contributes nothing.
    image_bce, text_bce = 0, 0

    # Reconstruction terms are summed over all non-batch dims, giving one value
    # per example on the same scale as the per-example KLD below.
    # Assumes batch-first tensors with at least 2 dims — TODO confirm.
    if recon_img_emb is not None and image_emb is not None:
        image_bce = binary_cross_entropy_with_logits(
            recon_img_emb, image_emb).flatten(start_dim=1).sum(dim=1)

    if recon_text_emb is not None and text_emb is not None:
        text_bce = binary_cross_entropy_with_logits(
            recon_text_emb, text_emb).flatten(start_dim=1).sum(dim=1)

    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114  (KL to N(0, 1), per example)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    ELBO = torch.mean(lambda_image * image_bce + lambda_text * text_bce
                      + annealing_factor * KLD)
    return ELBO


def binary_cross_entropy_with_logits(input, target):
    """Element-wise sigmoid + binary cross entropy, computed stably from logits.

    @param input: torch.Tensor (size N)
                  logits
    @param target: torch.Tensor (size N)
                   targets, same size as input
    @return loss: torch.Tensor (size N)
                  unreduced per-element loss
    @raises ValueError: if target and input sizes differ
    """
    if target.size() != input.size():
        raise ValueError("Target size ({}) must be the same as input size ({})".format(
            target.size(), input.size()))

    # max(x, 0) - x*t + log(1 + exp(-|x|)) never overflows exp. log1p keeps
    # precision when exp(-|x|) is tiny, where log(1 + .) would round to 0.
    return (torch.clamp(input, min=0) - input * target
            + torch.log1p(torch.exp(-torch.abs(input))))


class AverageMeter(object):
    """Track the most recent value and a running, count-weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def save_checkpoint(state, is_best, folder='./', filename='checkpoint.pth.tar'):
    """Save `state` to folder/filename; if `is_best`, also copy it to
    folder/model_best.pth.tar.

    `folder` (including any missing parent directories) is created if needed.
    """
    # makedirs handles nested paths and an already-existing directory;
    # os.mkdir raised for a missing parent.
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(folder, 'model_best.pth.tar'))


def load_checkpoint(file_path, use_cuda=False, **modules):
    """Rebuild an MVAE from a checkpoint dict with 'n_latents' and 'state_dict'.

    The checkpoint stores only weights and n_latents, so the sub-networks
    must be supplied by the caller with architectures matching the saved weights.

    @param file_path: path of the checkpoint file
    @param use_cuda: boolean [default: False]
                     load tensors onto their saved devices and move the model
                     to the GPU; otherwise map everything to CPU
    @param modules: keyword arguments forwarded to MVAE (image_encoder,
                    image_decoder, text_encoder, text_decoder)
    @return model: MVAE with the checkpoint weights loaded
    """
    checkpoint = torch.load(file_path, map_location=None if use_cuda else 'cpu')
    # MVAE requires its encoders/decoders; MVAE(n_latents) alone always raised TypeError.
    model = MVAE(checkpoint['n_latents'], **modules)
    model.load_state_dict(checkpoint['state_dict'])
    if use_cuda:
        # load_state_dict copies into the existing (CPU) parameters.
        model = model.cuda()
    return model


def _elbo_over_subsets(img_emb, text_emb, lambda_img_emb, lambda_text_emb,
                       annealing_factor):
    """Sum the loss over the joint, image-only and text-only posteriors.

    Relies on a module-level `model` (not defined in this view). For the
    unimodal terms both modalities are still passed to `elbo_loss`, but the
    unobserved modality's weight is 0, so its reconstruction term drops out.
    """
    # Joint posterior q(z | image, text): reconstruct both modalities.
    recon_img, recon_text, mu, logvar = model(img_emb, text_emb)
    loss = elbo_loss(recon_img, img_emb, recon_text, text_emb, mu, logvar,
                     lambda_image=lambda_img_emb, lambda_text=lambda_text_emb,
                     annealing_factor=annealing_factor)

    # Image-only posterior: only the image reconstruction is scored.
    recon_img, recon_text, mu, logvar = model(img_emb=img_emb)
    loss = loss + elbo_loss(recon_img, img_emb, recon_text, text_emb, mu, logvar,
                            lambda_image=lambda_img_emb, lambda_text=0.,
                            annealing_factor=annealing_factor)

    # Text-only posterior: only the text reconstruction is scored.
    recon_img, recon_text, mu, logvar = model(text_emb=text_emb)
    loss = loss + elbo_loss(recon_img, img_emb, recon_text, text_emb, mu, logvar,
                            lambda_image=0., lambda_text=lambda_text_emb,
                            annealing_factor=annealing_factor)
    return loss


def train(epoch, lambda_img_emb=1.0, lambda_text_emb=1.0):
    """Run one training epoch and return the mean training loss.

    Relies on module-level `model`, `optimizer`, `train_loader` and `args`
    (`args.cuda`, `args.log_interval`), none of which are defined in this view.
    `train_loader` is expected to yield (img_emb, text_emb) batches.
    """
    model.train()
    train_loss_meter = AverageMeter()
    annealing_factor = 1.0  # constant: no KL annealing schedule is applied

    for batch_idx, (img_emb, text_emb) in enumerate(train_loader):
        if args.cuda:
            img_emb = img_emb.cuda()
            text_emb = text_emb.cuda()
        batch_size = len(img_emb)

        optimizer.zero_grad()
        train_loss = _elbo_over_subsets(img_emb, text_emb, lambda_img_emb,
                                        lambda_text_emb, annealing_factor)
        train_loss.backward()
        optimizer.step()
        # .item() stores a plain float so the meter does not keep the graph alive.
        train_loss_meter.update(train_loss.item(), batch_size)

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAnnealing-Factor: {:.3f}'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), train_loss_meter.avg, annealing_factor))

    print('====> Epoch: {}\tLoss: {:.4f}'.format(epoch, train_loss_meter.avg))
    return train_loss_meter.avg


def test(epoch, lambda_img_emb=1.0, lambda_text_emb=1.0):
    """Evaluate on `test_loader` without gradients and return the mean loss.

    Relies on module-level `model`, `test_loader` and `args` (`args.cuda`),
    none of which are defined in this view. `epoch` is accepted for symmetry
    with `train` but is unused.
    """
    model.eval()
    test_loss_meter = AverageMeter()
    annealing_factor = 1.0  # same constant factor as in train()

    with torch.no_grad(), tqdm(total=len(test_loader)) as pbar:
        for img_emb, text_emb in test_loader:
            if args.cuda:
                img_emb = img_emb.cuda()
                text_emb = text_emb.cuda()
            batch_size = len(img_emb)

            test_loss = _elbo_over_subsets(img_emb, text_emb, lambda_img_emb,
                                           lambda_text_emb, annealing_factor)
            test_loss_meter.update(test_loss.item(), batch_size)
            pbar.update()

    print('====> Test Loss: {:.4f}'.format(test_loss_meter.avg))
    return test_loss_meter.avg

In [None]:
# Weights of the image / text reconstruction terms in the loss.
LAMBDA_IMG_EMB = 1.0
LAMBDA_TEXT_EMB = 1.0

best_loss = float('inf')  # sys.maxint does not exist in Python 3
for epoch in range(1, args.epochs + 1):
    train(epoch, LAMBDA_IMG_EMB, LAMBDA_TEXT_EMB)
    loss = test(epoch, LAMBDA_IMG_EMB, LAMBDA_TEXT_EMB)
    is_best = loss < best_loss
    best_loss = min(loss, best_loss)
    # Always save the current model; save_checkpoint also copies it to
    # model_best.pth.tar when is_best is True.
    save_checkpoint({
        'state_dict': model.state_dict(),
        'best_loss': best_loss,
        'n_latents': args.n_latents,
        'optimizer': optimizer.state_dict(),
    }, is_best, folder='./trained_models')