<a href="https://colab.research.google.com/github/jonberliner/mivae/blob/master/mivae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pytorch-ignite tqdm


Collecting pytorch-ignite
[?25l  Downloading https://files.pythonhosted.org/packages/c0/8e/08569347023611e40e62a14162024ca6238d42cb528b2302f84d662a2033/pytorch_ignite-0.4.1-py2.py3-none-any.whl (166kB)
[K     |██                              | 10kB 18.3MB/s eta 0:00:01[K     |████                            | 20kB 6.5MB/s eta 0:00:01[K     |██████                          | 30kB 8.1MB/s eta 0:00:01[K     |███████▉                        | 40kB 8.8MB/s eta 0:00:01[K     |█████████▉                      | 51kB 7.3MB/s eta 0:00:01[K     |███████████▉                    | 61kB 8.3MB/s eta 0:00:01[K     |█████████████▊                  | 71kB 8.6MB/s eta 0:00:01[K     |███████████████▊                | 81kB 8.8MB/s eta 0:00:01[K     |█████████████████▊              | 92kB 8.2MB/s eta 0:00:01[K     |███████████████████▋            | 102kB 8.5MB/s eta 0:00:01[K     |█████████████████████▋          | 112kB 8.5MB/s eta 0:00:01[K     |███████████████████████▋        | 12

In [10]:
import os
import sys
from typing import Optional, List

import torch
from torch import nn
from torch.nn import functional as F
from torch import distributions as D
from torch.utils.data import Dataset, DataLoader
from torchvision.models import squeezenet1_1, resnet18
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import numpy


In [19]:
### define statics
device = 'cuda' if torch.cuda.is_available() else 'cpu'

INPUT_SIZE = 784
DIM_X = 3
SIZE_X = 28
BACKBONE_OUTPUT_SIZE = 256 * 2
NUM_Z_PARTITIONS = 2
Z1_SIZE = 13
Z2_SIZE = 17
DIM_Z = Z1_SIZE + Z2_SIZE

In [12]:
### define encoder (choice of backbone is arbitrary)

class Encoder(nn.Module):
    def __init__(self,
            backbone: nn.Module,
            dim_x: int,
            dim_z: int,
            backbone_output_dim: int) -> None:
        super().__init__()
        self.backbone = backbone
        self.dim_x = dim_x
        self.dim_z = dim_z
        self.backbone_output_dim = backbone_output_dim

        self.readout = nn.Linear(self.backbone_output_dim, self.dim_z)

    def forward(self, 
                inputs: torch.Tensor,
                return_backbone_outputs: Optional[bool]=False) -> torch.Tensor:
        bb_outputs = self.backbone(inputs)
        outputs = bb_outputs
        while len(outputs.shape) > 2:
            outputs = torch.mean(outputs, dim=-1)
        outputs = self.readout(outputs)
        if return_backbone_outputs:
            return maxes, bb_outputs
        else:
            return outputs


In [13]:
### define a decoder (choice of architecture is arbitrary)

class Decoder(nn.Module):
    def __init__(self,
            dim_z: int,
            dim_x: int,
            x_size: int=28) -> None:
        super().__init__()

        self.readin = nn.Conv2d(dim_z, 64, kernel_size=(1, 1), stride=(1, 1))
        self.resnet = resnet18(pretrained=False)
        self.readout = nn.Conv2d(512, dim_x, kernel_size=(1, 1), stride=(1, 1))
        self.model = nn.Sequential(*[
            self.readin,
            self.resnet.bn1,
            self.resnet.relu,
            self.resnet.layer1,
            self.resnet.layer2,
            self.resnet.layer3,
            self.resnet.layer4,
            self.readout])

        self.dim_z = dim_z
        self.dim_x = dim_x
        self.x_size = x_size

    def forward(self, zs: torch.Tensor) -> torch.Tensor:
        inputs = torch.reshape(zs, (-1, self.dim_z, 1, 1))\
                      .expand(-1, self.dim_z, self.x_size*8, self.x_size*8)
        outputs = self.model(inputs)
        assert outputs.shape[2] == self.x_size
        return outputs
            

In [None]:
### prep dataset and data transforms

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        lambda x: x.expand(3, -1, -1),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        transforms.CenterCrop(INPUT_SIZE),
        transforms.ToTensor(),
        lambda x: x.expand(3, -1, -1),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

dataset = FashionMNIST('./FashionMNIST', download=True, transform=data_transforms['train'])
dataloader = DataLoader(dataset)

In [15]:
### init encoders and decoder
squeezenet1 = squeezenet1_1(pretrained=True).features
squeezenet2 = squeezenet1_1(pretrained=True).features

encoders = {}
for ii in range(len(NUM_Z_PARTITIONS)):
    encoders.a
# Z1_SIZE-way normal distr
encoder1 = Encoder(squeezenet1, INPUT_SIZE, Z1_SIZE * 2, BACKBONE_OUTPUT_SIZE)
# Z2_SIZE-way normal distr
encoder2 = Encoder(squeezenet2, INPUT_SIZE, Z2_SIZE * 2, BACKBONE_OUTPUT_SIZE)

decoder = Decoder(DIM_Z, DIM_X, SIZE_X)

encoder1 = encoder1.to(device)
encoder2 = encoder2.to(device)

decoder = decoder.to(device)

In [22]:
### forward pass through VAE

# set priors (can be any distr can call rsample on)
p_x = D.Normal(
            loc=torch.tensor([0.485, 0.456, 0.406]),
            scale=torch.tensor[0.229, 0.224, 0.225]))

p_z1 = D.Normal(
            loc=torch.zeros(Z1_SIZE), 
            scale=torch.ones(Z1_SIZE))
 
p_z2 = D.Normal(
            loc=torch.zeros(Z2_SIZE),
            scale=torch.ones(Z2_SIZE))

inputs = dataset[0][0].unsqueeze(0)

for xx, yy in data_loader:
    inputs = xx.to(device)

    # first pass of inference for VAE loss
    inputs = inputs.to(device)
    batch_size = inputs.shape[0]

    inferred11 = encoder1(inputs)
    inferred21 = encoder2(inputs)

    # draw from p(z|x) for all siblings z1,...,zn that constitute z
    p_z1_given_x_1 = D.Normal(
            loc=inferred11[:, :Z1_SIZE], 
            scale=F.softplus(inferred11[:, Z1_SIZE:]) + 1e-4)

    p_z2_given_x_1 = D.Normal(
            loc=inferred21[:, :Z2_SIZE], 
            scale=F.softplus(inferred21[:, Z2_SIZE:]) + 1e-4)

    z1_given_x_1 = p_z1_given_x_1.rsample()
    z2_given_x_1 = p_z2_given_x_1.rsample()

    # combine z1,...,zn into agg'd z for generative model
    z_given_x_1 = torch.cat([z1_given_x_1, z2_given_x_1], dim=1)
    p_recon_x_given_z = decoder(z_given_x_1)

    # calc standard vae loss
    vae_loss_z1 = D.kl_divergence(p_z1_given_x, p_z1)
    vae_loss_z2 = D.kl_divergence(p_z2_given_x, p_z2)
    vae_loss_x = D.kl_divergence(p_recon_x_given_z, p_x)

    
    # second pass inference for MIVAE Loss

    # choose which zi will be prior and which posterior
    if rng.rand() > 0.5:
        # z1 draws from p(z1|x), z2 from p(z2)
        post = 1
        z1 = z1_given_x_1.rsample()
        z2 = p_z2.rsample(size=batch_size)
    else:
        post = 2
        z1 = p_z1.rsample(size=batch_size)
        z2 = p_z2_given_x.rsample()

    # generate synthetic sample for mutual info loss
    z_given_x_2 = torch.cat([z1, z2])
    
    # get distr for x | z_given_x_2
    p_x2_given_z_logits = decoder(z_given_x_2)
    p_x2_given_z = D.Normal(
        loc=p_x2_given_z_logits[:, :SIZE_X], 
        scale=F.softplus(p_x2_given_z_logits[:, SIZE_X:]) + 1e-4)
    # draw synthetic x2
    x2_given_z = p_x2_given_z.rsample()

    # infer from synthetic sample
    inferred12 = encoder1(x2)        
    inferred22 = encoder2(x2)

    p_z1_given_x2 = D.Normal(
        loc=inferred12[:, :Z2_SIZE], 
        scale=F.softplus(inferred12[:, Z2_SIZE:]) + 1e-4)

    p_z2_given_x2 = D.Normal(
        loc=inferred22[:, :Z2_SIZE], 
        scale=F.softplus(inferred22[:, Z2_SIZE:]) + 1e-4)
    
    # calc mutual info loss
    if post == 1:
        mi_loss_z1 = D.kl_divergence(p_z1_given_x2, p_z1_given_x)
        mi_loss_z2 = D.kl_divergence(p_z2_given_x2, p_z2)
    elif post == 2:
        mi_loss_z1 = D.kl_divergence(p_z1_given_x2, p_z1)
        mi_loss_z2 = D.kl_divergence(p_z2_given_x2, p_z2_given_x)

    mi_loss_x = D.kl_divergence(p_x2_given_z, p_x)

    # clip losses for stability
    vae_loss_z1 = vae_loss_z1.clamp(max=100.)
    vae_loss_z2 = vae_loss_z2.clamp(max=100.)
    vae_loss_x = vae_loss_x.clamp(max=100.)

    mi_loss_z1 = mi_loss_z1.clamp(max=100.)
    mi_loss_z2 = mi_loss_z2.clamp(max=100.)
    mi_loss_x = mi_loss_x.clamp(max=100.)

    # add together losses
    vae_loss = vae_loss_z1 + vae_loss_z2 + vae_loss_x
    mi_loss = mi_loss_z1 + mi_loss_z2 + mi_loss_x

ValueError: ignored