# Consistent Generative Query Networks
### https://deepmind.com/research/publications/consistent-generative-query-networks

#### DeepMind has not released an open MNIST Dice dataset, but you can generate one with the renderer linked below.
##### https://github.com/musyoku/gqn-dataset-renderer

#### You can also train CGQN using GQN datasets like Shepard-Metzler.
##### Datasets: https://github.com/deepmind/gqn-datasets
##### Dataset translator: https://github.com/l3robot/gqn_datasets_translator

In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import os
import datetime
import random
import math
from tensorboardX import SummaryWriter

from pixyz.distributions import Normal
from pixyz.losses import KullbackLeibler

from tqdm import tqdm

from shepardmetzler import ShepardMetzler, Scene, transform_viewpoint
from conv_lstm import Conv2dLSTMCell

seed = 1234
torch.manual_seed(seed)

<torch._C.Generator at 0x7fa2ec0f5f90>

In [2]:
class Representation(nn.Module):
    def __init__(self, nf_v, nf_f):
        super(Representation, self).__init__()
        self.conv1 = nn.Conv2d(nf_v+nf_f, 8, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

    def forward(self, v, f):
        # Increase dimensions
        v = v.view(v.size(0), -1, 1, 1)
        v = v.repeat(1, 1, f.size(2), f.size(3))
        
        h = F.relu(self.conv1(torch.cat([v, f], dim=1)))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        r = self.conv4(h)

        return r

In [3]:
class Posterior(Normal):
    def __init__(self, nf_to_hidden=64, nf_z=3):
        super(Posterior, self).__init__(cond_var=["h_e"], var=["z"])
        self.nf_z = nf_z
        self.conv = nn.Conv2d(nf_to_hidden, 2*nf_z, kernel_size=5, stride=1, padding=2)
        
    def forward(self, h_e):
        mu, logvar = torch.split(self.conv(h_e), self.nf_z, dim=1)
        return {"loc": mu, "scale": F.softplus(logvar)}

In [4]:
class Prior(Normal):
    def __init__(self, nf_to_obs=128, nf_z=3):
        super(Prior, self).__init__(cond_var=["h_d"], var=["z"])
        self.nf_z = nf_z
        self.conv = nn.Conv2d(nf_to_obs, 2*nf_z, kernel_size=5, stride=1, padding=2)
        
    def forward(self, h_d):
        mu, logvar = torch.split(self.conv(h_d), self.nf_z, dim=1)
        return {"loc": mu, "scale": F.softplus(logvar)}

In [5]:
class Renderer(Normal):
    def __init__(self, nf_to_obs=128, nf_v=7, nf_f=3):
        super(Renderer, self).__init__(cond_var=["h_d", "v", "var"], var=["f"])
        self.convt = nn.ConvTranspose2d(nf_to_obs+nf_v, nf_f, kernel_size=2, stride=2)
    def forward(self, h_d, v, var):
        # Increase dimensions
        v = v.view(v.size(0), -1, 1, 1)
        v = v.repeat(1, 1, h_d.size(2), h_d.size(3))
        mu = self.convt(torch.cat([h_d, v], dim=1))
        return {"loc": mu, "scale": math.sqrt(var)}

In [6]:
class CGQN(nn.Module):
    def __init__(self, nf_v=7, nf_f=3, nf_r=32, nt=4, nf_to_hidden=64, nf_enc=128, nf_to_obs=128, nf_dec=64, nf_z=3):
        super(CGQN, self).__init__()
        self.nf_f = nf_f
        self.nf_to_hidden = nf_to_hidden
        self.nf_to_obs = nf_to_obs
        self.nt = nt
        
        self.m_theta = Representation(nf_v, nf_f)
        
        self.encoder = nn.Conv2d(nf_f, nf_enc, kernel_size=2, stride=2)
        self.decoder = nn.Conv2d(nf_f, nf_dec, kernel_size=2, stride=2)
        
        self.m_gamma = Renderer(nf_to_obs, nf_v, nf_f)
        

        # Outputs parameters of distributions
        self.posterior = Posterior(nf_to_hidden, nf_z)
        self.prior = Prior(nf_to_obs, nf_z)

        # Recurrent encoder/decoder models
        self.lstm_enc = Conv2dLSTMCell(nf_enc+nf_r+nf_to_obs, nf_to_hidden, kernel_size=5, stride=1, padding=2)
        self.lstm_dec = Conv2dLSTMCell(nf_z+nf_dec+nf_r, nf_to_obs, kernel_size=5, stride=1, padding=2)
        
        self.upsample   = nn.ConvTranspose2d(nf_r, nf_r, kernel_size=2, stride=2, padding=0)

    def forward(self, v, f, v_prime, f_T, var):
        batch_size, m, _, h, w = f.size()
        # num of target
        k = f_T.size(1)
        # merge batch and view dimensions.
        _, _, *v_dims = v.size()
        _, _, *f_dims = f.size()

        v = v.view((-1, *v_dims))
        f = f.view((-1, *f_dims))
        
        v_prime = v_prime.view((-1, *v_dims))
        f_T = f_T.view((-1, *f_dims))
        
        r = self.m_theta(v, f)
        r_T = self.m_theta(v_prime, f_T)
        
        # seperate batch and view dimensions
        _, *r_dims = r.size()
        r = r.view((batch_size, m, *r_dims))
        r_T = r_T.view((batch_size, k, *r_dims))

        # sum over view representations
        r = torch.sum(r, dim=1)
        r_T = torch.sum(r_T, dim=1)
        
        # expand dimensions
        r = r.repeat(1, k, 1, 1)
        r = r.view((-1, *r_dims))
        r_T = r_T.repeat(1, k, 1, 1)
        r_T = r_T.view((-1, *r_dims))

        # hidden states
        h_e = f.new_zeros((batch_size*k, self.nf_to_hidden, h//2, w//2))
        h_d = f.new_zeros((batch_size*k, self.nf_to_obs, h//2, w//2))

        # cell states
        c_e = f.new_zeros((batch_size*k, self.nf_to_hidden, h//2, w//2))
        c_d = f.new_zeros((batch_size*k, self.nf_to_obs, h//2, w//2))
        
        canvas = f.new_zeros((batch_size*k, self.nf_f, h, w))

        r = self.upsample(r)
        r_T = self.upsample(r_T)
        enc_input = self.encoder(f_T)
        
        kl = 0
        for _ in range(self.nt):
            
            # update encoder LSTM states
            h_e, c_e = self.lstm_enc(torch.cat([enc_input, r_T, h_d], dim=1), [h_e, c_e])

            # sample from posterior
            z = self.posterior.sample({"h_e": h_e}, reparam=True)["z"]
            
            # kl divergence between posterior and prior
            _kl = KullbackLeibler(self.posterior, self.prior).mean()
            _kl = _kl.estimate({"h_e": h_e, "h_d": h_d})
            kl += _kl

            dec_input = self.decoder(canvas)

            # update decoder LSTM states
            h_d, c_d = self.lstm_dec(torch.cat([z, dec_input, r], dim=1), [h_d, c_d])

            # refine representation
            canvas = self.m_gamma.sample_mean({"h_d": h_d, "v": v_prime, "var": var})
            
        f_R = torch.clamp(canvas.view((batch_size, k, *f_dims)), 0, 1)
        f_R_noise = self.m_gamma.sample({"h_d": h_d, "v": v_prime, "var": var}, reparam=True)["f"]
        MSE = nn.MSELoss()
        mse = MSE(f_T, f_R_noise)

        return f_R, mse, kl
    
    def sample(self, v, f, v_prime):
        batch_size, m, _, h, w = f.size()
        
        # num of target
        k = v_prime.size(1)
        
        # merge batch and view dimensions.
        _, _, *v_dims = v.size()
        _, _, *f_dims = f.size()

        v = v.view((-1, *v_dims))
        f = f.view((-1, *f_dims))
        
        v_prime = v_prime.view((-1, *v_dims))
        
        r = self.m_theta(v, f)
        
        # seperate batch and view dimensions
        _, *r_dims = r.size()
        r = r.view((batch_size, m, *r_dims))

        # sum over view representations
        r = torch.sum(r, dim=1)
        
        # expand dimensions
        r = r.repeat(1, k, 1, 1)
        r = r.view((-1, *r_dims))

        # hidden states
        h_d = f.new_zeros((batch_size*k, self.nf_to_obs, h//2, w//2))
        # cell states
        c_d = f.new_zeros((batch_size*k, self.nf_to_obs, h//2, w//2))

        canvas = f.new_zeros((batch_size*k, self.nf_f, h, w))
        r = self.upsample(r)

        for _ in range(self.nt):
            # sample from prior
            z = self.prior.sample({"h_d": h_d})["z"]
            
            dec_input = self.decoder(canvas)
            # update decoder LSTM states
            h_d, c_d = self.lstm_dec(torch.cat([z, dec_input, r], dim=1), [h_d, c_d])

            canvas = self.m_gamma.sample_mean({"h_d": h_d, "v": v_prime, "var": 0})
            
        f_R = torch.clamp(canvas.view((batch_size, k, *f_dims)), 0, 1)

        return f_R

In [7]:
def arrange_data(v_data, f_data, seed=None):
    random.seed(seed)
    batch_size, n, *_ = f_data.size()

    # Sample random number of views
    m = random.randint(1, n-2)
#     k = random.randint(m+1, n)
    k = m+2

    indices = torch.randperm(n)
    input_idx, target_idx = indices[:m], indices[m:k]

    v, f = v_data[:, input_idx], f_data[:, input_idx]
    v_prime, f_T = v_data[:, target_idx], f_data[:, target_idx]
    
    return v, f, v_prime, f_T

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# args
train_data_dir = '/workspace/dataset/shepard_metzler_7_parts-torch/train'
test_data_dir = '/workspace/dataset/shepard_metzler_7_parts-torch/test'

# number of workers to load data
num_workers = 0

# for logging
log_interval_num = 100
save_interval_num = 100000
dir_name = str(datetime.datetime.now())
log_dir = '/workspace/logs/'+ dir_name
os.mkdir(log_dir)
os.mkdir(log_dir+'/models')
os.mkdir(log_dir+'/runs')

# tensorboardX
writer = SummaryWriter(log_dir=log_dir+'/runs')

batch_size = 36
gradient_steps = 2*(10**6)

train_dataset = ShepardMetzler(root_dir=train_data_dir, target_transform=transform_viewpoint)
test_dataset = ShepardMetzler(root_dir=test_data_dir, target_transform=transform_viewpoint)


# hyperparameters for traveling salesman dataset
# nf_v=7
# nf_f=3
# nf_r=32
# nt=4
# nf_to_hidden=64
# nf_enc=128
# nf_to_obs=128
# nf_dec=64
# nf_z=3
# alpha, beta = 2.0, 0.5

# var = alpha

# hyperparameters for MNIST Cube 3D scene reconstruction task
nf_v=7
nf_f=3
nf_r=32
nt=6
nf_to_hidden=128
nf_enc=128
nf_to_obs=128
nf_dec=128
nf_z=3

var = 2.0

hyperparam = (nf_v, nf_f, nf_r, nt, nf_to_hidden, nf_enc, nf_to_obs, nf_dec, nf_z)

# model
model = CGQN(*hyperparam).to(device)
model = nn.DataParallel(model, device_ids=[0, 1])

optimizer = torch.optim.Adam(model.parameters())
kwargs = {'num_workers':num_workers, 'pin_memory': True} if torch.cuda.is_available() else {}

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, **kwargs)
    
f_data_test, v_data_test = next(iter(test_loader))

# number of gradient steps
s = 0
while True:
    for f_data, v_data in tqdm(train_loader):
        f_data = f_data.to(device)
        v_data = v_data.to(device)
        v, f, v_prime, f_T = arrange_data(v_data, f_data)
        f_R, mse, kl = model(v, f, v_prime, f_T, var)
        mse = mse.mean()
        kl = kl.mean()
        loss = mse / var + kl
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        writer.add_scalar('train_mse', mse, s)
        writer.add_scalar('train_kl', kl, s)
        writer.add_scalar('train_loss', loss, s)
        
        s += 1
        
        with torch.no_grad():
            # keep a checkpoint every n steps
            if s % log_interval_num == 0 or s == 1:
                writer.add_image('train_ground_truth', f_T[0], s)
                writer.add_image('train_reconstruction', f_R[0], s)
                
                f_data_test = f_data_test.to(device)
                v_data_test = v_data_test.to(device)
                
                v_test, f_test, v_prime_test, f_T_test = arrange_data(v_data_test, f_data_test, seed=0)
                f_R_test, mse_test, kl_test = model(v_test, f_test, v_prime_test, f_T_test, var)
                f_gen_test = model.module.sample(v_test, f_test, v_prime_test)
                
                mse_test = mse_test.mean()
                kl_test = kl_test.mean()
                loss_test = mse_test / var + kl_test
                
                writer.add_scalar('test_mse', mse_test, s)
                writer.add_scalar('test_kl', kl_test, s)
                writer.add_scalar('test_loss', loss_test, s)
                writer.add_image('test_ground_truth', f_T_test[0], s)
                writer.add_image('test_reconstruction', f_R_test[0], s)
                writer.add_image('test_generation', f_gen_test[0], s)
                
            if s % save_interval_num == 0:
                torch.save(model.state_dict(), log_dir + "/models/model-{}.pt".format(s))
                
            if s >= gradient_steps:
                break

            # pixel variance for traveling salesman dataset
            # var = max(alpha - (alpha - beta)*(s/10**5), beta)
            
            # pixel variance for MNIST Cube 3D scene reconstruction task
            if s >= 100000 and s < 150000:
                var = 0.2
            elif s >= 150000 and s < 200000:
                var = 0.4
            elif s >= 200000:
                var = 0.9
        
    if s >= gradient_steps:
        torch.save(model.state_dict(), log_dir + "/models/model-final.pt")
        break
writer.close()

  0%|          | 50/22476 [01:03<7:11:57,  1.16s/it]