# Pix2Vox 论文的 notebook 版本

## 配置文件

In [None]:
from easydict import EasyDict as eDict

# Global configuration tree; ``cfg`` is the public alias other cells read from.
__C = eDict()
cfg = __C

# Dataset settings
# NOTE(review): TAXONOMY_FILE_PATH sits directly under ./DataSets while the
# other Pix3D paths live under ./DataSets/PiX3D/ -- confirm the on-disk layout
# (and letter case) on case-sensitive filesystems.
__C.DATA_CONFIG = eDict()
__C.DATA_CONFIG.PIX3D = eDict()
__C.DATA_CONFIG.PIX3D.TAXONOMY_FILE_PATH = './DataSets/Pix3D.json'
__C.DATA_CONFIG.PIX3D.ANNOTATION_PATH = './DataSets/PiX3D/pix3d.json'
# The '%s' placeholders are filled in by whoever consumes these templates.
__C.DATA_CONFIG.PIX3D.RENDERING_PATH = './DataSets/PiX3D/img/%s/%s.%s'
__C.DATA_CONFIG.PIX3D.VOXEL_PATH = './DataSets/PiX3D/model/%s/%s/%s.binvox'

__C.DATASET = eDict()
# Normalisation statistics (three values each, presumably one per colour
# channel -- confirm against the transform that consumes them).
__C.DATASET.MEAN = [0.5, 0.5, 0.5]
__C.DATASET.STD = [0.5, 0.5, 0.5]
__C.DATASET.TRAIN_DATASET = 'Pix3D'
__C.DATASET.TEST_DATASET = 'Pix3D'

# File path settings
__C.DIR = eDict()
__C.DIR.OUT_PATH = './output'

# Neural network settings
__C.NETWORK = eDict()
__C.NETWORK.LEAKY_VALUE = .2
__C.NETWORK.CONV_USE_BIAS = False
# The Decoder's transposed convolutions read this key; it was missing, which
# made Decoder construction fail with AttributeError.
__C.NETWORK.TCONV_USE_BIAS = False
__C.NETWORK.USE_REFINER = True
__C.NETWORK.USE_MERGER = True

## 数据集的预处理

In [None]:
import cv2
import numpy as np
import os
import random
import torch

class Compose(object):
    """Chain several transforms into one callable.

    Each transform is applied in order to the output of the previous one.
    Transforms whose class is named ``RandomCrop`` or ``CenterCrop`` also
    receive the bounding box; every other transform receives the images only.
    """

    def __init__(self, transforms):
        """Store the sequence of transform callables to apply in order."""
        self.transforms = transforms

    def __call__(self, rendering_images, bounding_box=None):
        """Run all transforms and return the transformed images.

        Args:
            rendering_images: The images handed to the first transform.
            bounding_box: Optional bounding box forwarded to crop transforms.

        Returns:
            The output of the last transform (the input unchanged when the
            transform list is empty).
        """
        for t in self.transforms:
            # Dispatch on the class name so the crop classes do not have to be
            # importable from here.
            if t.__class__.__name__ in ('RandomCrop', 'CenterCrop'):
                rendering_images = t(rendering_images, bounding_box)
            else:
                rendering_images = t(rendering_images)

        # The pipeline result was previously dropped, so callers got None.
        return rendering_images


## 神经网络模块的定义，暂时不使用refiner

In [None]:
import torchvision.models as models

class Encoder(torch.nn.Module):
    """Per-view image encoder: a truncated VGG16-BN followed by three conv blocks."""

    def __init__(self, config1):
        """Build the encoder.

        Args:
            config1: Configuration object; it is stored on ``self.cfg`` and not
                otherwise read by this class.
        """
        super().__init__()
        self.cfg = config1

        # Network structure: the first 27 feature layers of a pretrained
        # VGG16-BN, then three trainable convolution blocks.
        vgg16_bn = models.vgg16_bn(pretrained=True)
        self.vgg = torch.nn.Sequential(*list(vgg16_bn.features.children()))[:27]
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU()
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
            torch.nn.MaxPool2d(kernel_size=3)
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 256, kernel_size=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ELU()
        )

        # Freeze the VGG weights; self.vgg shares these parameter objects, so
        # only layer1-3 receive gradients.
        for param in vgg16_bn.parameters():
            param.requires_grad = False

    def forward(self, rendering_images):
        """Encode every view independently.

        Args:
            rendering_images: 5-D tensor whose first two axes are swapped so
                that the second axis can be iterated; presumably laid out as
                (batch_size, n_views, channels, height, width) -- confirm
                against the data loader.

        Returns:
            Tensor with the per-view features stacked back along axis 1, i.e.
            (batch_size, n_views, 256, H', W').
        """
        # Move the view axis to the front and split it into single-view chunks.
        rendering_images = rendering_images.permute(1, 0, 2, 3, 4).contiguous()
        rendering_images = torch.split(rendering_images, 1, dim=0)
        image_features = []

        for img in rendering_images:
            features = self.vgg(img.squeeze(dim=0))
            features = self.layer1(features)
            features = self.layer2(features)
            features = self.layer3(features)
            image_features.append(features)

        # Restore the batch-first layout. The debug print that used to sit here
        # wrote to stdout on every forward pass and has been removed.
        image_features = torch.stack(image_features).permute(1, 0, 2, 3, 4).contiguous()
        return image_features
    

## 解码器构建

In [None]:
class Decoder(torch.nn.Module):
    """Decode per-view image features into 32x32x32 occupancy volumes."""

    def __init__(self, config2):
        """Build the decoder.

        Args:
            config2: Configuration object. ``config2.NETWORK.TCONV_USE_BIAS``
                selects whether the transposed convolutions carry a bias term.
        """
        super().__init__()
        self.config2 = config2

        # Read the flag from the config that was passed in rather than from
        # the module-level ``cfg``. Fall back to False when the key is absent
        # so that an incomplete config does not abort construction.
        use_bias = getattr(config2.NETWORK, 'TCONV_USE_BIAS', False)

        self.layer1 = self._up_block(2048, 512, use_bias)
        self.layer2 = self._up_block(512, 128, use_bias)
        self.layer3 = self._up_block(128, 32, use_bias)
        self.layer4 = self._up_block(32, 8, use_bias)
        self.layer5 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(8, 1, kernel_size=1, bias=use_bias),
            torch.nn.Sigmoid()
        )

    @staticmethod
    def _up_block(in_channels, out_channels, use_bias):
        """Return a transposed-conv + batch-norm + ReLU stage that doubles each spatial size."""
        return torch.nn.Sequential(
            torch.nn.ConvTranspose3d(in_channels, out_channels, kernel_size=4, stride=2,
                                     bias=use_bias, padding=1),
            torch.nn.BatchNorm3d(out_channels),
            torch.nn.ReLU()
        )

    def forward(self, image_features):
        """Decode every view independently.

        Args:
            image_features: 5-D tensor, iterated along its second axis; each
                slice must hold a multiple of 2048 * 2 * 2 * 2 elements
                because it is reshaped to (-1, 2048, 2, 2, 2).

        Returns:
            Tuple ``(raw_features, gen_volumes)``:
            raw_features is (batch_size, n_views, 9, 32, 32, 32) -- the 8
            channels of layer4 concatenated with the 1-channel volume;
            gen_volumes is (batch_size, n_views, 32, 32, 32).
        """
        # Move the view axis to the front and split it into single-view chunks.
        image_features = image_features.permute(1, 0, 2, 3, 4).contiguous()
        image_features = torch.split(image_features, 1, dim=0)
        gen_volumes = []
        raw_features = []

        for features in image_features:
            gen_volume = features.view(-1, 2048, 2, 2, 2)   # [batch_size, 2048, 2, 2, 2]
            gen_volume = self.layer1(gen_volume)            # [batch_size, 512, 4, 4, 4]
            gen_volume = self.layer2(gen_volume)            # [batch_size, 128, 8, 8, 8]
            gen_volume = self.layer3(gen_volume)            # [batch_size, 32, 16, 16, 16]
            gen_volume = self.layer4(gen_volume)            # [batch_size, 8, 32, 32, 32]
            raw_feature = gen_volume
            gen_volume = self.layer5(gen_volume)            # [batch_size, 1, 32, 32, 32]
            # Keep the pre-sigmoid feature maps together with the volume.
            raw_feature = torch.cat((raw_feature, gen_volume), dim=1)  # [batch_size, 9, 32, 32, 32]

            gen_volumes.append(torch.squeeze(gen_volume, dim=1))
            raw_features.append(raw_feature)

        # Restore the batch-first layout.
        gen_volumes = torch.stack(gen_volumes).permute(1, 0, 2, 3, 4).contiguous()
        raw_features = torch.stack(raw_features).permute(1, 0, 2, 3, 4, 5).contiguous()
        return raw_features, gen_volumes

## 融合模块

In [None]:
class Merger(torch.nn.Module):