In [124]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [125]:
import sys
sys.path.append('../')
sys.argv=[''] 
del sys

In [126]:
import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['CUDA_DEVICES_VISIBLE']='1'

In [127]:
cd "/root/SymmNeRF/"

/root/SymmNeRF


In [153]:
from opt import config_parser

parser = config_parser()
args = parser.parse_args()

device = 'cuda'
det = args.det
lindisp = args.lindisp
args.distributed = False
args.local_feature_ch = 32
args.local_rank = 0

In [154]:
from datasets import dataset_dict, create_training_dataset
from torch.utils.data import DataLoader 

from model import model_dict  
from model.sample_ray import RaySampler 
from model.render_ray import render_rays

from utils_lab_fct import * 

import setproctitle
import torch 
import numpy as np 
import cv2
from einops import rearrange 
import matplotlib.pyplot as plt

setproctitle.setproctitle('[Gaetan. - SymmNeRF]')

### Build up a DataLoader 

In [155]:
bs = 2

train_dataset, train_sampler = create_training_dataset(args)
train_loader = DataLoader(train_dataset,batch_size = bs , sampler = train_sampler)

it = iter(train_loader)
train_batch = next(it)

[Info] Training dataset: srns_dataset
[Info] Set used: test
SRNsDataset:  /data/datasets/srn_cars/cars_train


#### Model loading. 

In [156]:
model = model_dict['hypernerf_symm_local'](args,ckpts_folder = "")#/root/SymmNeRF/logs/srns_dataset/cars/srns_cars/ckpts")

[Info] No ckpts found, training from scratch...


#### HyperParameters

In [157]:
## Some constants. 
nb_rays = 256
nb_sampled_points_on_rays = 64

# First 2 dim. of the feature map F 
W_F, H_F = 64, 64

#### Create a ray batch

In [158]:
# RaySampler class instantiation.
ray_sampler = RaySampler(train_batch) 

# Get the corresponding source images and poses.
src_imgs = ray_sampler.src_img     # [B, 3, 128,128]
tgt_imgs = ray_sampler.render_imgs # [B,NV,3,128,128]
poses = ray_sampler.render_poses   # [B, NV, 4, 4 ]

# Build up a batch of ray. 
ray_batch = ray_sampler.random_sample(nb_rays,use_bbox = True)

### Get the latent code from the Source Images. 
The .encode() method encodes the RGB source images. Correspond to the f() network in the main paper. 
Each code is a 256d vector. 

In [159]:
z = model.encode(ray_batch)

### Render the rays. 
Considering the rays that were sampled on this batch, render them according to symmetry plane that is defined (points are expressed in the source camera viewpoint). 
Volume rendering is also perfomed as well as a last final step. 

In [160]:
from model.render_ray import sample_along_camera_ray, run_network, raw2outputs
from model.nerf import run_nerf_symm_local
from model.nerf_helpers import *
from model.render_ray import get_symmetric_points_and_directions


rays_o = ray_batch['rays_o']    # [B,256,3]
rays_d = ray_batch['rays_d']    # [B,256,3]
z_near = ray_batch['z_near']
z_far = ray_batch['z_far']

noise = False 

M = torch.tensor([[-1.,0.,0.],
                  [0.,1.,0.],
                  [0.,0.,1.]]).to(device)

F = model.feature_net.latent
print(F.shape)
#### 1. Get the 3D points ray, the viewdir and z sampled dist and symmetrize it. 
pts, viewdirs, z_vals = sample_along_camera_ray(rays_o=rays_o,  # pts : [B,256,64,3] - viewdirs : [B,256,3] - z_vals : [B,256,64]
                                                rays_d=rays_d,
                                                z_near=z_near,
                                                z_far=z_far,
                                                device=device,
                                                N_samples=nb_sampled_points_on_rays,
                                                lindisp=lindisp, 
                                                det=det)

pts_s,viewdirs_s = get_symmetric_points_and_directions(pts,viewdirs,M)

#### 2. Get the NeRF weights according to the latent code z. 
nerf_coarse_layers = model.hypernetwork(z) 

### 3. Get a set of feature for experimental purpose ( F feature tensor, uv coordinates etc... ) 
ret_features = model.feature_net.index_experimental(pts, 
                                        ray_batch['src_pose'],
                                        ray_batch['intrinsics'], 
                                        ray_batch['image_size'], 
                                        M,
                                        noise)
### 4. Retrieve f ans f_S 
f = ret_features['local_feature']
f_S = ret_features['local_feature_symm']


### 5. Define the NeRF models (one for each image in the batch since they all have a different latent code z). 
nerf_coarse = lambda x: run_nerf_symm_local(x, nerf_layers=nerf_coarse_layers, input_ch=model.input_ch,
                                                input_ch_views=model.input_ch_views,
                                                local_feature_ch=32)


### 6. Do raw forward pass - rgb and alpha are not normalized correctly - Volume rendering was not performed yet. 
raw_coarse = run_network(pts,viewdirs,nerf_coarse,model.embed_fn,model.embeddirs_fn,f)

raw_coarse_S = run_network(pts_s,viewdirs_s,nerf_coarse,model.embed_fn,model.embeddirs_fn,f_S)


### 7. Perform the volume rendering twice (also for the symmetric)
outputs_coarse = raw2outputs(raw_coarse,z_vals,ray_batch['rays_d'],device = device, raw_noise_std = 0.,white_bkgd = True)
outputs_coarse_S = raw2outputs(raw_coarse_S,z_vals,ray_batch['rays_d'],device=device,raw_noise_std = 0.,white_bkgd = True)

torch.Size([2, 32, 64, 64])
torch.Size([2, 16384, 32])


### Change the CNN feature network. 

In [118]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.feature_network import _resnet_symm_local, BasicBlock,ResNetSymmLocal

In [66]:
model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=2, padding=2,
                               bias=False))
                    # nn.MaxPool2d(kernel_size=3, stride=2, padding=1))


In [114]:
model = _resnet_symm_local('resnet18',latent_dim = 256,block = BasicBlock, layers= [2,2,2,2], pretrained= False, progress= True)



# 128,128,3 --> Input.
# conv1 : 64,64,32
# layer1 : 32,32,64
# layer2 : 16,16,
# 64,64,32
# 32,32,64
# 16,16,96

# --> 64,64, 128+64 = 192. --> 64,64,64




In [116]:
model = ResNetSymmLocalTest(block = BasicBlock, layers = [2,2,2,2])

x = torch.rand(1,3,128,128).to('cpu')

y = model(x) 
print(64+32+16+8)
print(y.shape)

Shape of x after the first conv: torch.Size([1, 8, 64, 64])
Shape of x after l1: torch.Size([1, 16, 64, 64])
Shape of x after l2: torch.Size([1, 32, 32, 32])
Shape of x after l3: torch.Size([1, 64, 16, 16])
Shape of latent before: torch.Size([1, 120, 64, 64])
Shape of latent after: torch.Size([1, 32, 64, 64])
120
torch.Size([1, 256])


In [123]:
model = ResNetSymmLocal(block = BasicBlock, layers = [2,2,2,2], **kwargs)
x = torch.rand(1,3,128,128).to('cpu')
y = model(x) 

NameError: name 'kwargs' is not defined

In [113]:
class ResNetSymmLocalTest(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1,
                 width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, index_interp='bilinear',
                 index_padding='border', upsample_interp='bilinear', feature_scale=1.0, use_first_pool=False):
        super().__init__()
        # feature_scale factor to scale all latent by. Useful (<1) if image is extremely large, to fit in memory.
        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp

        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        self.register_buffer("latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False)

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 8
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        
        # Last 
        self.conv_last = nn.Conv2d(120, 32, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn_last = norm_layer(32)
    
        
        
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 128, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        #self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.fc = nn.Linear(128,256)
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
    
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        if self.feature_scale != 1.0:
            x = F.interpolate(x, scale_factor=self.feature_scale,
                              mode='bilinear' if self.feature_scale > 1.0 else 'area',
                              align_corners=True if self.feature_scale > 1.0 else False,
                              recompute_scale_factor=True)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        print(f'Shape of x after the first conv: {x.shape}')
        latents = [x]
        if self.use_first_pool:
            x = self.maxpool(x)
        x = self.layer1(x)
        print(f'Shape of x after l1: {x.shape}')
        latents.append(x)

        x = self.layer2(x)
        print(f'Shape of x after l2: {x.shape}')
        latents.append(x)

        x = self.layer3(x)
        print(f"Shape of x after l3: {x.shape}")
        latents.append(x)
        
        x = self.layer4(x)

      
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        align_corners = False if self.index_interp == 'nearest' else True
        latent_sz = latents[0].shape[-2:]
        
        for i in range(len(latents)):
            latents[i] = F.interpolate(latents[i], latent_sz,
                                       mode=self.upsample_interp,
                                       align_corners=align_corners)
            
        latents = torch.cat(latents,dim=1)
        latents = self.conv_last(latents)
        latents = self.bn_last(latents)
        
        self.latent = self.relu(latents)
        
        #self.latent = torch.cat(latents, dim=1)
        print(f'Shape of latent after: {self.latent.shape}')
        self.latent_scaling[0] = self.latent.shape[-1]
        self.latent_scaling[1] = self.latent.shape[-2]
        self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0

        return x

    def forward(self, x):
        return self._forward_impl(x)