# Neural scene representation and rendering
### https://deepmind.com/blog/neural-scene-representation-and-rendering

Datasets: https://github.com/deepmind/gqn-datasets

Datasets translator: https://github.com/l3robot/gqn_datasets_translator

In [None]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import os
import datetime
import random
import math
from tensorboardX import SummaryWriter

from pixyz.distributions import Normal
from pixyz.losses import NLL, KullbackLeibler

from tqdm import tqdm

from shepardmetzler import ShepardMetzler, Scene, transform_viewpoint
from conv_lstm import Conv2dLSTMCell

# Seed torch's global RNG so that weight initialisation and other draws from
# torch's default generator are reproducible across runs.
seed = 1234
torch.manual_seed(seed)

In [None]:
# tower or pool
class Representation(nn.Module):
    """Representation network ``phi``: encode an image and its viewpoint into ``r``.

    Two skip-connected convolutional blocks ("tower" architecture). The
    viewpoint is broadcast over the feature map and concatenated before the
    second block. With ``pool=True`` the result is averaged over its spatial
    dimensions ("pool" variant).

    Args:
        n_channels: channels of the input image ``x``.
        v_dim: viewpoint features; ``v`` is flattened to ``v_dim`` channels.
        r_dim: channels of the representation ``r``.
        pool: if True, average ``r`` spatially (output ``(B, r_dim, 1, 1)``);
            otherwise return the full tower feature map.
    """

    def __init__(self, n_channels, v_dim, r_dim=256, pool=True):
        super(Representation, self).__init__()
        # dimension of r
        self.r_dim = k = r_dim
        # pool: True, tower: False
        self.pool = pool

        # First skip-connected block: image -> k channels at 1/4 resolution
        # (two stride-2 convs on each path).
        self.conv1 = nn.Conv2d(n_channels, k, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(k, k, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(k, k//2, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(k//2, k, kernel_size=2, stride=2)

        # Second skip-connected block: input is features + broadcast viewpoint.
        self.conv5 = nn.Conv2d(k + v_dim, k, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(k + v_dim, k//2, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(k//2, k, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(k, k, kernel_size=1, stride=1)

        # Average over whatever spatial size the tower produces. The previous
        # AvgPool2d(r_dim // 16) only matched the tower output when r_dim
        # happened to be 16x the feature-map size (r_dim=256, 64x64 images).
        # Pooling layers have no parameters, so state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x, v):
        """Return the representation of images ``x`` seen from viewpoints ``v``.

        Args:
            x: images ``(B, n_channels, H, W)``.
            v: viewpoints, ``v_dim`` features per sample.

        Returns:
            ``(B, r_dim, 1, 1)`` if ``self.pool`` else the tower feature map
            ``(B, r_dim, H', W')`` with ``H', W'`` about a quarter of ``H, W``.
        """
        # First skip-connected conv block
        skip_in = F.relu(self.conv1(x))
        skip_out = F.relu(self.conv2(skip_in))

        x = F.relu(self.conv3(skip_in))
        x = F.relu(self.conv4(x)) + skip_out

        # Broadcast the viewpoint to the feature map's actual spatial size.
        # Deriving it from r_dim // 16 (as before) made torch.cat fail for any
        # r_dim / image size combination other than 256 / 64x64.
        v = v.view(v.size(0), -1, 1, 1).expand(-1, -1, x.size(2), x.size(3))

        # Second skip-connected conv block (merged)
        skip_in = torch.cat([x, v], dim=1)
        skip_out = F.relu(self.conv5(skip_in))

        x = F.relu(self.conv6(skip_in))
        x = F.relu(self.conv7(x)) + skip_out

        r = F.relu(self.conv8(x))

        if self.pool:
            r = self.avgpool(r)

        return r

In [None]:
class InferenceCore(nn.Module):
    """Convolutional-LSTM core of the inference (posterior) network.

    Each step concatenates the generator hidden state, the canvas, the query
    image, the query viewpoint and the representation along the channel axis
    and advances the inference LSTM state by one step.
    """

    def __init__(self, x_dim, v_dim, r_dim, h_dim):
        super().__init__()
        # h_g and u each contribute h_dim channels to the concatenated input.
        input_channels = h_dim * 2 + x_dim + v_dim + r_dim
        self.core = Conv2dLSTMCell(
            input_channels, h_dim, kernel_size=5, stride=1, padding=2
        )

    def forward(self, x, v, r, h_g, h_e, c_e, u):
        """Advance the inference LSTM one step and return ``(h_e, c_e)``."""
        step_input = torch.cat([h_g, u, x, v, r], dim=1)
        next_hidden, next_cell = self.core(step_input, [h_e, c_e])
        return next_hidden, next_cell
    
class GeneratorCore(nn.Module):
    """Convolutional-LSTM core of the generator plus its canvas update.

    Each step feeds ``[z, v, r]`` into the LSTM, upsamples the new hidden
    state by ``SCALE`` and adds it to the full-resolution canvas ``u``.
    """

    def __init__(self, v_dim, r_dim, z_dim, h_dim, SCALE):
        super().__init__()
        lstm_in_channels = z_dim + v_dim + r_dim
        self.core = Conv2dLSTMCell(
            lstm_in_channels, h_dim, kernel_size=5, stride=1, padding=2
        )
        # kernel == stride == SCALE with no padding: output is exactly SCALE x larger.
        self.upsample = nn.ConvTranspose2d(
            h_dim, h_dim, kernel_size=SCALE, stride=SCALE, padding=0
        )

    def forward(self, z, v, r, h_g, c_g, u):
        """Advance the generator one step; return ``(h_g, c_g, u)``."""
        step_input = torch.cat([z, v, r], dim=1)
        next_hidden, next_cell = self.core(step_input, [h_g, c_g])
        canvas = u + self.upsample(next_hidden)
        return next_hidden, next_cell, canvas

In [None]:
class Inference(Normal):
    """Posterior ``q(z | h_i)``: a Gaussian parameterised from the inference state."""

    def __init__(self, z_dim, h_dim):
        super().__init__(cond_var=["h_i"], var=["z"])
        self.z_dim = z_dim
        # One conv emits 2 * z_dim channels: first half loc, second half raw scale.
        self.eta_e = nn.Conv2d(h_dim, 2 * z_dim, kernel_size=5, stride=1, padding=2)

    def forward(self, h_i):
        loc, raw_scale = torch.split(self.eta_e(h_i), self.z_dim, dim=1)
        # The second half is mapped straight to a standard deviation with
        # softplus (it is not a log-variance).
        scale = F.softplus(raw_scale)
        return {"loc": loc, "scale": scale}
    
class Prior(Normal):
    """Prior ``pi(z | h_g)``: a Gaussian parameterised from the generator state."""

    def __init__(self, z_dim, h_dim):
        super().__init__(cond_var=["h_g"], var=["z"])
        self.z_dim = z_dim
        # One conv emits 2 * z_dim channels: first half loc, second half raw scale.
        self.eta_pi = nn.Conv2d(h_dim, 2 * z_dim, kernel_size=5, stride=1, padding=2)

    def forward(self, h_g):
        loc, raw_scale = torch.split(self.eta_pi(h_g), self.z_dim, dim=1)
        # softplus turns the raw channels into a positive standard deviation.
        scale = F.softplus(raw_scale)
        return {"loc": loc, "scale": scale}
    
class Generator(Normal):
    """Likelihood ``p(x_q | u, sigma)``: Gaussian with a 1x1-conv mean and given scale."""

    def __init__(self, x_dim, h_dim):
        super().__init__(cond_var=["u", "sigma"], var=["x_q"])
        # 1x1 conv maps the h_dim-channel canvas to x_dim image channels.
        self.eta_g = nn.Conv2d(h_dim, x_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, u, sigma):
        return {"loc": self.eta_g(u), "scale": sigma}

In [None]:
class GQN(nn.Module):
    """Generative Query Network: representation net plus conv-LSTM generator/inference.

    The recurrent states live at ``1/SCALE`` of the image resolution; the
    canvas ``u`` lives at full image resolution and is decoded by ``g``.

    Args:
        x_dim: image channels.
        v_dim: viewpoint features.
        r_dim: representation channels.
        h_dim: LSTM hidden/cell channels (also canvas channels).
        z_dim: latent channels per step.
        L: number of unrolled generation steps.
        SCALE: ratio between image resolution and LSTM state resolution.
    """

    def __init__(self, x_dim, v_dim, r_dim, h_dim, z_dim, L=12, SCALE=4):
        super(GQN, self).__init__()
        self.L = L
        self.h_dim = h_dim
        self.SCALE = SCALE

        self.phi = Representation(x_dim, v_dim, r_dim)
        self.generator_core = GeneratorCore(v_dim, r_dim, z_dim, h_dim, self.SCALE)
        self.inference_core = InferenceCore(x_dim, v_dim, r_dim, h_dim)

        # NOTE: self.upsample is not used by forward/generate (GeneratorCore
        # owns the upsampler that is applied). It is kept so the state_dict
        # layout, and therefore existing checkpoints, stay compatible.
        self.upsample = nn.ConvTranspose2d(h_dim, h_dim, kernel_size=SCALE, stride=SCALE, padding=0)
        # Strided convs bring the query image and the canvas down to the LSTM
        # state resolution for the inference core.
        self.downsample_x = nn.Conv2d(x_dim, x_dim, kernel_size=SCALE, stride=SCALE, padding=0)
        self.downsample_u = nn.Conv2d(h_dim, h_dim, kernel_size=SCALE, stride=SCALE, padding=0)

        # distributions
        self.pi = Prior(z_dim, h_dim)     # prior p(z | h_g)
        self.q = Inference(z_dim, h_dim)  # posterior q(z | h_i)
        self.g = Generator(x_dim, h_dim)  # likelihood p(x_q | u, sigma)

    def _encode_context(self, x, v, v_q):
        """Aggregate context views into ``r`` and tile conditioning inputs.

        Args:
            x: context images ``(B, n_views, C, H, W)``.
            v: context viewpoints ``(B, n_views, ...)``.
            v_q: query viewpoint per scene.

        Returns:
            ``(r, v_q, batch_size, h, w)`` where ``r`` and ``v_q`` are tiled to
            the LSTM state grid ``(h // SCALE, w // SCALE)`` and ``h, w`` are
            the image height and width.
        """
        batch_size, n_views, _, h, w = x.size()

        # merge batch and view dimensions (reshape also accepts non-contiguous input)
        x = x.reshape(-1, *x.shape[2:])
        v = v.reshape(-1, *v.shape[2:])

        # representation of each context view, then summed over views
        r = self.phi(x, v)
        r = r.view(batch_size, n_views, *r.shape[1:]).sum(dim=1)

        state_h, state_w = h // self.SCALE, w // self.SCALE
        v_q = v_q.reshape(batch_size, -1, 1, 1).repeat(1, 1, state_h, state_w)
        # a pooled representation is 1x1: tile it over the state grid
        if r.size(2) != state_h:
            r = r.repeat(1, 1, state_h, state_w)
        return r, v_q, batch_size, h, w

    def forward(self, x, v, v_q, x_q, sigma):
        """Training pass: unroll inference and generator cores for ``L`` steps.

        Args:
            x: context images ``(B, n_views, C, H, W)``.
            v: context viewpoints ``(B, n_views, ...)``.
            v_q: query viewpoint per scene.
            x_q: query images ``(B, C, H, W)``.
            sigma: scale (pixel standard deviation) of the likelihood ``g``.

        Returns:
            ``(nll, kl, x_q_rec)``: negative log-likelihood of ``x_q``, the KL
            terms summed over the ``L`` steps, and the reconstruction (mean of
            ``g``) clamped to ``[0, 1]``.
        """
        r, v_q, batch_size, h, w = self._encode_context(x, v, v_q)
        state_shape = (batch_size, self.h_dim, h // self.SCALE, w // self.SCALE)

        # reset hidden / cell states of both cores
        hidden_g = x_q.new_zeros(state_shape)
        cell_g = x_q.new_zeros(state_shape)
        hidden_i = x_q.new_zeros(state_shape)
        cell_i = x_q.new_zeros(state_shape)
        # canvas at full image resolution
        u = x_q.new_zeros((batch_size, self.h_dim, h, w))

        x_q_down = self.downsample_x(x_q)
        # KL(q || pi) between posterior and prior; built once, not per step
        kl = KullbackLeibler(self.q, self.pi)

        kls = 0
        for _ in range(self.L):
            # Update the inference state *before* sampling z so this step's
            # posterior sees the current generator state and canvas, as in the
            # GQN paper's ELBO algorithm. The original sampled from the
            # previous step's h_i: the first z came from q(. | 0), independent
            # of x_q, and the final inference update was never used.
            hidden_i, cell_i = self.inference_core(
                x_q_down, v_q, r, hidden_g, hidden_i, cell_i, self.downsample_u(u))
            # reparameterised posterior sample so gradients reach q
            z = self.q.sample({"h_i": hidden_i}, reparam=True)["z"]
            # prior uses h_g from before this step's generator update
            kls += kl.estimate({"h_i": hidden_i, "h_g": hidden_g})
            hidden_g, cell_g, u = self.generator_core(z, v_q, r, hidden_g, cell_g, u)

        # reconstruction = mean of the likelihood
        x_q_rec = torch.clamp(self.g.sample_mean({"u": u, "sigma": sigma}), 0, 1)
        # negative log-likelihood of the true query image
        nll_tensor = NLL(self.g).estimate({"u": u, "sigma": sigma, "x_q": x_q})

        return nll_tensor, kls, x_q_rec

    def generate(self, x, v, v_q, sigma=1.0):
        """Render query views by unrolling the generator with prior samples.

        Args:
            x, v, v_q: as in ``forward``.
            sigma: scale passed to the likelihood ``g``. Only ``g``'s mean
                (``eta_g(u)``, independent of ``sigma``) is returned, so the
                value does not change the output. The original body referenced
                a ``sigma`` that was not defined in this scope.

        Returns:
            Generated query images clamped to ``[0, 1]``.
        """
        r, v_q, batch_size, h, w = self._encode_context(x, v, v_q)
        state_shape = (batch_size, self.h_dim, h // self.SCALE, w // self.SCALE)

        # reset generator state and canvas
        hidden_g = x.new_zeros(state_shape)
        cell_g = x.new_zeros(state_shape)
        u = x.new_zeros((batch_size, self.h_dim, h, w))

        for _ in range(self.L):
            z = self.pi.sample({"h_g": hidden_g})["z"]
            hidden_g, cell_g, u = self.generator_core(z, v_q, r, hidden_g, cell_g, u)

        x_q_hat = torch.clamp(self.g.sample_mean({"u": u, "sigma": sigma}), 0, 1)
        return x_q_hat

In [None]:
def arrange_data(x_data, v_data, seed=None):
    """Split each scene's views into context views and one query view.

    A number of context views ``n_views`` in ``[2, m - 1]`` is drawn (``m`` is
    the size of dim 1, the view axis), the views are shuffled, the first
    ``n_views`` become context and the next one becomes the query. The same
    split is applied to every scene in the batch.

    Args:
        x_data: images indexed as ``x_data[:, view]``.
        v_data: viewpoints with the same view axis as ``x_data``.
        seed: if given, both the number of views AND the view permutation are
            reproducible. Previously only ``n_views`` was seeded, so the
            "fixed" test split still changed every call. Private RNGs are used
            so the global ``random`` state is never modified; with
            ``seed=None`` the permutation draws from the global torch RNG.

    Returns:
        ``(x, v, x_q, v_q)``: context images ``(B, n_views, ...)``, context
        viewpoints, query images ``(B, ...)`` and query viewpoints.

    Raises:
        ValueError: if there are fewer than 3 views per scene.
    """
    m = x_data.size(1)
    if m < 3:
        raise ValueError("need at least 3 views per scene, got {}".format(m))

    # sample random number of views (private RNG: no global reseeding)
    n_views = random.Random(seed).randint(2, m - 1)

    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)
    indices = torch.randperm(m, generator=generator)

    representation_idx, query_idx = indices[:n_views], indices[n_views]

    x, v = x_data[:, representation_idx], v_data[:, representation_idx]
    x_q, v_q = x_data[:, query_idx], v_data[:, query_idx]

    return x, v, x_q, v_q

In [None]:
# Training script: configuration, data, model, optimiser and training loop.
from torchvision.utils import make_grid

device = "cuda" if torch.cuda.is_available() else "cpu"

# dataset translator
# https://github.com/l3robot/gqn_datasets_translator

# dataset directory
train_data_dir = '/workspace/dataset/shepard_metzler_7_parts-torch/train'
test_data_dir = '/workspace/dataset/shepard_metzler_7_parts-torch/test'

# number of workers to load data
num_workers = 0

# log
log_interval_num = 100       # steps between image logging / test-batch evaluation
save_interval_num = 100000   # steps between checkpoints
dir_name = str(datetime.datetime.now())
log_dir = '/workspace/logs/' + dir_name
# os.makedirs also creates missing parents (e.g. /workspace/logs on a fresh
# machine), where os.mkdir would raise. It still raises if the run directory
# already exists.
os.makedirs(log_dir + '/models')
os.makedirs(log_dir + '/runs')

# tensorboardX
writer = SummaryWriter(log_dir=log_dir + '/runs')

batch_size = 36
gradient_steps = 2*(10**6)

train_dataset = ShepardMetzler(root_dir=train_data_dir, target_transform=transform_viewpoint)
test_dataset = ShepardMetzler(root_dir=test_data_dir, target_transform=transform_viewpoint)

# model settings
xDim = 3
vDim = 7
rDim = 256
hDim = 128
zDim = 64
L = 12
SCALE = 4  # scale of image generation process

# model
gqn = GQN(xDim, vDim, rDim, hDim, zDim, L, SCALE).to(device)
gqn = nn.DataParallel(gqn)

# learning rate
mu_i, mu_f = 5e-4, 5e-5

# pixel variance
sigma_i, sigma_f = 2.0, 0.7

# initial value
mu, sigma = mu_i, sigma_i

optimizer = torch.optim.Adam(gqn.parameters(), lr=mu, betas=(0.9, 0.999))
kwargs = {'num_workers': num_workers, 'pin_memory': True} if torch.cuda.is_available() else {}

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test_dataset, batch_size=36, shuffle=True, **kwargs)

# One fixed test batch is reused for every evaluation; move it to the device
# once instead of at every log interval.
x_data_test, v_data_test = next(iter(test_loader))
x_data_test = x_data_test.to(device)
v_data_test = v_data_test.to(device)


def _log_image_grid(tag, images, step):
    """Log the first 8 images of a batch to TensorBoard as one grid image.

    ``SummaryWriter.add_image`` expects a single CHW image by default
    (``dataformats='CHW'``), so a raw 4-D batch can be rejected.
    """
    writer.add_image(tag, make_grid(images[:8]), step)


# number of gradient steps
s = 0
try:
    while s < gradient_steps:
        for x_data, v_data in tqdm(train_loader):
            x_data = x_data.to(device)
            v_data = v_data.to(device)
            x, v, x_q, v_q = arrange_data(x_data, v_data)
            nll, kl, x_q_rec = gqn(x, v, v_q, x_q, sigma)
            nll = nll.mean()
            kl = kl.mean()
            loss = nll + kl
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            writer.add_scalar('train_nll', nll.item(), s)
            writer.add_scalar('train_kl', kl.item(), s)
            writer.add_scalar('train_loss', loss.item(), s)

            s += 1

            with torch.no_grad():
                # write logs to tensorboard
                if s % log_interval_num == 0 or s == 1:
                    _log_image_grid('train_ground_truth', x_q, s)
                    _log_image_grid('train_reconstruction', x_q_rec, s)

                    # seed=0 keeps the context/query split of the test batch fixed
                    x_test, v_test, x_q_test, v_q_test = arrange_data(x_data_test, v_data_test, seed=0)
                    nll_test, kl_test, x_q_rec_test = gqn(x_test, v_test, v_q_test, x_q_test, sigma)
                    x_q_hat_test = gqn.module.generate(x_test, v_test, v_q_test)

                    nll_test = nll_test.mean()
                    kl_test = kl_test.mean()
                    loss_test = nll_test + kl_test

                    writer.add_scalar('test_nll', nll_test.item(), s)
                    writer.add_scalar('test_kl', kl_test.item(), s)
                    writer.add_scalar('test_loss', loss_test.item(), s)
                    _log_image_grid('test_ground_truth', x_q_test, s)
                    _log_image_grid('test_reconstruction', x_q_rec_test, s)
                    _log_image_grid('test_generation', x_q_hat_test, s)

                if s % save_interval_num == 0:
                    torch.save(gqn.state_dict(), log_dir + "/models/model-{}.pt".format(s))

            if s >= gradient_steps:
                break

            # Anneal learning rate linearly to mu_f over 1.6M steps
            mu = max(mu_f + (mu_i - mu_f)*(1 - s/(1.6 * 10**6)), mu_f)
            for group in optimizer.param_groups:
                group["lr"] = mu
            # Anneal pixel variance linearly to sigma_f over 200k steps
            sigma = max(sigma_f + (sigma_i - sigma_f)*(1 - s/(2 * 10**5)), sigma_f)

    torch.save(gqn.state_dict(), log_dir + "/models/model-final.pt")
finally:
    # flush and close event files even if training is interrupted
    writer.close()