In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 36
h = 64
w = 64
epochs = 40
seed = 5

context_num = 10 # length of contexts

torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
from Tars.distributions import Normal, Bernoulli
from Tars.distributions.divergences import KullbackLeibler
from Tars.models import VAE

In [3]:
xDim=3
vDim=7
rDim=256
hDim=128
zDim=64
L=12
SCALE = 4 # Scale of image generation process

In [4]:
# utility
class Conv2dLSTMCell(nn.Module):
    """
    2d convolutional long short-term memory (LSTM) cell.
    Functionally equivalent to nn.LSTMCell with the
    difference being that nn.Kinear layers are replaced
    by nn.Conv2D layers.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: size of image kernel
    :param stride: length of kernel stride
    :param padding: number of pixels to pad with
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(Conv2dLSTMCell, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        self.forget = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.input  = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.output = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.state  = nn.Conv2d(in_channels, out_channels, **kwargs)

    def forward(self, input, states):
        """
        Send input through the cell.

        :param input: input to send through
        :param states: (hidden, cell) pair of internal state
        :return new (hidden, cell) pair
        """
        (hidden, cell) = states

        forget_gate = F.sigmoid(self.forget(input))
        input_gate  = F.sigmoid(self.input(input))
        output_gate = F.sigmoid(self.output(input))
        state_gate  = F.tanh(self.state(input))

        # Update internal cell state
        cell = forget_gate * cell + input_gate * state_gate
        hidden = output_gate * F.tanh(cell)

        return hidden, cell

In [5]:
# Using TowerRepresentation
class Representation(nn.Module):
    def __init__(self, n_channels, v_dim, r_dim=256, pool=True):
        """
        Network that generates a condensed representation
        vector from a joint input of image and viewpoint.

        Employs the tower/pool architecture described in the paper.

        :param n_channels: number of color channels in input image
        :param v_dim: dimensions of the viewpoint vector
        :param r_dim: dimensions of representation
        :param pool: whether to pool representation
        """
        super(Representation, self).__init__()
        # Final representation size
        self.r_dim = k = r_dim
        self.pool = pool

        self.conv1 = nn.Conv2d(n_channels, k, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(k, k, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(k, k//2, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(k//2, k, kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(k + v_dim, k, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(k + v_dim, k//2, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(k//2, k, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(k, k, kernel_size=1, stride=1)

        self.avgpool  = nn.AvgPool2d(k//16)

    def forward(self, x, v):
        """
        Send an (image, viewpoint) pair into the
        network to generate a representation
        :param x: image
        :param v: viewpoint (x, y, z, cos(yaw), sin(yaw), cos(pitch), sin(pitch))
        :return: representation
        """
        # Increase dimensions
        v = v.view(v.size(0), -1, 1, 1)
        v = v.repeat(1, 1, self.r_dim // 16, self.r_dim // 16)

        # First skip-connected conv block
        skip_in  = F.relu(self.conv1(x))
        skip_out = F.relu(self.conv2(skip_in))

        x = F.relu(self.conv3(skip_in))
        x = F.relu(self.conv4(x)) + skip_out

        # Second skip-connected conv block (merged)
        skip_in = torch.cat([x, v], dim=1)
        skip_out  = F.relu(self.conv5(skip_in))

        x = F.relu(self.conv6(skip_in))
        x = F.relu(self.conv7(x)) + skip_out

        r = F.relu(self.conv8(x))

        if self.pool:
            r = self.avgpool(r)

        return r

rep = Representation(xDim, vDim, rDim)

In [18]:
class GeneratorCore(nn.Module):
    def __init__(self, v_dim, r_dim, z_dim=64, h_dim=128):
        super(GeneratorCore, self).__init__()
        self.core = Conv2dLSTMCell(v_dim + r_dim + z_dim, h_dim, kernel_size=5, stride=1, padding=2)
        self.upsample   = nn.ConvTranspose2d(h_dim, h_dim, kernel_size=SCALE, stride=SCALE, padding=0)
        
    def forward(self, h_g, c_g, u):
        h_g, c_g =  self.core(torch.cat([z, v, r], dim=1), [h_g, c_g])
        u = self.upsample(h_g) + u
        return h_g, c_g, u
generator_core = GeneratorCore(vDim, rDim)


class InferenceCore(nn.Module):
    def __init__(self, x_dim, v_dim, r_dim, h_dim=128):
        super(InferenceCore, self).__init__()
        self.core = Conv2dLSTMCell(h_dim + x_dim + v_dim + r_dim, h_dim, kernel_size=5, stride=1, padding=2)
        
    def forward(self, h_g, h_e, c_e):
        h_e, c_e = self.core(torch.cat([h_g, x, v, r], dim=1), [h_e, c_e])
        return h_e, c_e
    
inference_core = InferenceCore(xDim, vDim, rDim, hDim)

In [7]:
# ステップにより変更
sigma = 1

In [8]:
class Generator(Normal):
    def __init__(self):
        super(Generator, self).__init__(conv_var=["z","v_q","r"],var=["x_q"])
        self.eta_g = nn.Conv2d(hDim, xDim, kernel_size=1, stride=1, padding=0)
        
    def forward(self, z, v_q, r):
        mu = F.sigmoid(self.observation_density(u))
        return {"loc":mu, "scale":sigma}

class Prior(Normal):
    def __init__(self):
        super(Prior, self).__init__(conv_var=["v_q","r"],var=["z"])
        self.eta_pi = nn.Conv2d(hDim, 2*zDim, kernel_size=5, stride=1, padding=2)
        
    def forward(self, v_q, r):
        mu, std = torch.split(self.eta_pi(hidden_g), zDim, dim=1)
        return {"loc":mu ,"scale":F.softplus(std)}
    
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(conv_var=["x_q","v_q","r"],var=["z"])
        self.eta_e = nn.Conv2d(hDim, 2*zDim, kernel_size=5, stride=1, padding=2)
        
    def forward(self, x_q, v_q, r):
        mu, std = torch.split(self.eta_e(hidden_e), zDim, dim=1)
        return {"loc":mu, "scale":std}

In [19]:
# reset generater state
hidden_g = torch.zeros((batch_size, hDim, h//SCALE, w//SCALE))
cell_g = torch.zeros((batch_size, hDim, h//SCALE, w//SCALE))
u = torch.zeros((batch_size, hDim, h, w))

# reset inference state
hidden_e = torch.zeros((batch_size, hDim, h//SCALE, w//SCALE))
cell_e = torch.zeros((batch_size, hDim, h//SCALE, w//SCALE))

# autoregressive model
regularizer = []
for _ in range(L):    
    # kl
    pi = Prior().to(device)
    q = Inference().to(device)
    kl = KullbackLeibler(q, pi)
    regularizer.append(kl)
    # update state
    hidden_e, cell_e = inference_core(hidden_g, hidden_e, cell_e)
    hidden_g, cell_g, u = generator_core(hidden_g, cell_g, u)
    
g = Generator().to(device)

p = g * pi


NameError: name 'x' is not defined

In [None]:
kl = KullbackLeibler(q, prior)
# TODO: change optimizer and learning rate
model = VAE(q, g, regularizer=[kl], optimizer=optim.Adam, optimizer_params={"lr":1e-3})

In [None]:
phi = rep(x, v)
_, *phi_dims = phi.size()
phi = phi.view((batch_size, n_views, *phi_dims))
r = torch.sum(phi, dim=1)


In [None]:
def split_context_target(x, v, context_num=100, device="cpu", shuffle_each_example=True):
    x, v = shuffle_dim(x, v, shuffle_each_example=shuffle_each_example)
    x_context = torch.Tensor(x[:context_num]).to(device)
    v_context = torch.Tensor(v[:context_num]).to(device)
    
    x_target = torch.Tensor(x[context_num:]).to(device)
    v_target = torch.Tensor(v[context_num:]).to(device)
    return x_context, v_context, x_target, v_target

In [None]:
def train(epoch):
    train_loss = 0    
    t = tqdm(train_loader)
    for batch_idx, (data, _) in enumerate(t):
        t.set_description('Epoch: {}'.format(epoch))
        
        x, V =
        x, v, x_q, y_q = split_context_target(x, v, context_num=context_num, device=device)
        
        lower_bound, loss = model.train({"x": x, "v": v, "x_q": x_q, "v_q": v_q})
        train_loss += loss
        
        t.set_postfix(loss=loss.item())
        
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss