Skip to content

Commit

Permalink
Moved MLP and Transformer
Browse files Browse the repository at this point in the history
Summary: Moved the MLP and transformer from nerf to a new file to be reused.

Reviewed By: bottler

Differential Revision: D38828150

fbshipit-source-id: 8ff77b18b3aeeda398d90758a7bcb2482edce66f
  • Loading branch information
Darijan Gudelj authored and facebook-github-bot committed Aug 23, 2022
1 parent edee25a commit 898ba5c
Show file tree
Hide file tree
Showing 2 changed files with 321 additions and 302 deletions.
315 changes: 315 additions & 0 deletions pytorch3d/implicitron/models/implicit_function/decoding_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)


class MLPWithInputSkips(torch.nn.Module):
"""
Implements the multi-layer perceptron architecture of the Neural Radiance Field.
As such, `MLPWithInputSkips` is a multi layer perceptron consisting
of a sequence of linear layers with ReLU activations.
Additionally, for a set of predefined layers `input_skips`, the forward pass
appends a skip tensor `z` to the output of the preceding layer.
Note that this follows the architecture described in the Supplementary
Material (Fig. 7) of [1].
References:
[1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
NeRF: Representing Scenes as Neural Radiance Fields for View
Synthesis, ECCV2020
"""

def _make_affine_layer(self, input_dim, hidden_dim):
l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
_xavier_init(l1)
_xavier_init(l2)
return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)

def _apply_affine_layer(self, layer, x, z):
mu_log_std = layer(z)
mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
std = torch.nn.functional.softplus(log_std)
return (x - mu) * std

def __init__(
self,
n_layers: int = 8,
input_dim: int = 39,
output_dim: int = 256,
skip_dim: int = 39,
hidden_dim: int = 256,
input_skips: Tuple[int, ...] = (5,),
skip_affine_trans: bool = False,
no_last_relu=False,
):
"""
Args:
n_layers: The number of linear layers of the MLP.
input_dim: The number of channels of the input tensor.
output_dim: The number of channels of the output.
skip_dim: The number of channels of the tensor `z` appended when
evaluating the skip layers.
hidden_dim: The number of hidden units of the MLP.
input_skips: The list of layer indices at which we append the skip
tensor `z`.
"""
super().__init__()
layers = []
skip_affine_layers = []
for layeri in range(n_layers):
dimin = hidden_dim if layeri > 0 else input_dim
dimout = hidden_dim if layeri + 1 < n_layers else output_dim

if layeri > 0 and layeri in input_skips:
if skip_affine_trans:
skip_affine_layers.append(
self._make_affine_layer(skip_dim, hidden_dim)
)
else:
dimin = hidden_dim + skip_dim

linear = torch.nn.Linear(dimin, dimout)
_xavier_init(linear)
layers.append(
torch.nn.Sequential(linear, torch.nn.ReLU(True))
if not no_last_relu or layeri + 1 < n_layers
else linear
)
self.mlp = torch.nn.ModuleList(layers)
if skip_affine_trans:
self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
self._input_skips = set(input_skips)
self._skip_affine_trans = skip_affine_trans

def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
"""
Args:
x: The input tensor of shape `(..., input_dim)`.
z: The input skip tensor of shape `(..., skip_dim)` which is appended
to layers whose indices are specified by `input_skips`.
Returns:
y: The output tensor of shape `(..., output_dim)`.
"""
y = x
if z is None:
# if the skip tensor is None, we use `x` instead.
z = x
skipi = 0
for li, layer in enumerate(self.mlp):
if li in self._input_skips:
if self._skip_affine_trans:
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
else:
y = torch.cat((y, z), dim=-1)
skipi += 1
y = layer(y)
return y


class TransformerWithInputSkips(torch.nn.Module):
def __init__(
self,
n_layers: int = 8,
input_dim: int = 39,
output_dim: int = 256,
skip_dim: int = 39,
hidden_dim: int = 64,
input_skips: Tuple[int, ...] = (5,),
dim_down_factor: float = 1,
):
"""
Args:
n_layers: The number of linear layers of the MLP.
input_dim: The number of channels of the input tensor.
output_dim: The number of channels of the output.
skip_dim: The number of channels of the tensor `z` appended when
evaluating the skip layers.
hidden_dim: The number of hidden units of the MLP.
input_skips: The list of layer indices at which we append the skip
tensor `z`.
"""
super().__init__()

self.first = torch.nn.Linear(input_dim, hidden_dim)
_xavier_init(self.first)

self.skip_linear = torch.nn.ModuleList()

layers_pool, layers_ray = [], []
dimout = 0
for layeri in range(n_layers):
dimin = int(round(hidden_dim / (dim_down_factor**layeri)))
dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
logger.info(f"Tr: {dimin} -> {dimout}")
for _i, l in enumerate((layers_pool, layers_ray)):
l.append(
TransformerEncoderLayer(
d_model=[dimin, dimout][_i],
nhead=4,
dim_feedforward=hidden_dim,
dropout=0.0,
d_model_out=dimout,
)
)

if layeri in input_skips:
self.skip_linear.append(torch.nn.Linear(input_dim, dimin))

self.last = torch.nn.Linear(dimout, output_dim)
_xavier_init(self.last)

# pyre-fixme[8]: Attribute has type `Tuple[ModuleList, ModuleList]`; used as
# `ModuleList`.
self.layers_pool, self.layers_ray = (
torch.nn.ModuleList(layers_pool),
torch.nn.ModuleList(layers_ray),
)
self._input_skips = set(input_skips)

def forward(
self,
x: torch.Tensor,
z: Optional[torch.Tensor] = None,
):
"""
Args:
x: The input tensor of shape
`(minibatch, n_pooled_feats, ..., n_ray_pts, input_dim)`.
z: The input skip tensor of shape
`(minibatch, n_pooled_feats, ..., n_ray_pts, skip_dim)`
which is appended to layers whose indices are specified by `input_skips`.
Returns:
y: The output tensor of shape
`(minibatch, 1, ..., n_ray_pts, input_dim)`.
"""

if z is None:
# if the skip tensor is None, we use `x` instead.
z = x

y = self.first(x)

B, n_pool, n_rays, n_pts, dim = y.shape

# y_p in n_pool, n_pts, B x n_rays x dim
y_p = y.permute(1, 3, 0, 2, 4)

skipi = 0
dimh = dim
for li, (layer_pool, layer_ray) in enumerate(
zip(self.layers_pool, self.layers_ray)
):
y_pool_attn = y_p.reshape(n_pool, n_pts * B * n_rays, dimh)
if li in self._input_skips:
z_skip = self.skip_linear[skipi](z)
y_pool_attn = y_pool_attn + z_skip.permute(1, 3, 0, 2, 4).reshape(
n_pool, n_pts * B * n_rays, dimh
)
skipi += 1
# n_pool x B*n_rays*n_pts x dim
y_pool_attn, pool_attn = layer_pool(y_pool_attn, src_key_padding_mask=None)
dimh = y_pool_attn.shape[-1]

y_ray_attn = (
y_pool_attn.view(n_pool, n_pts, B * n_rays, dimh)
.permute(1, 0, 2, 3)
.reshape(n_pts, n_pool * B * n_rays, dimh)
)
# n_pts x n_pool*B*n_rays x dim
y_ray_attn, ray_attn = layer_ray(
y_ray_attn,
src_key_padding_mask=None,
)

y_p = y_ray_attn.view(n_pts, n_pool, B * n_rays, dimh).permute(1, 0, 2, 3)

y = y_p.view(n_pool, n_pts, B, n_rays, dimh).permute(2, 0, 3, 1, 4)

W = torch.softmax(y[..., :1], dim=1)
y = (y * W).sum(dim=1)
y = self.last(y)

return y


class TransformerEncoderLayer(torch.nn.Module):
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
This standard encoder layer is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of intermediate layer, relu or gelu (default=relu).
Examples::
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out = encoder_layer(src)
"""

def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, d_model_out=-1
):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
self.dropout = torch.nn.Dropout(dropout)
d_model_out = d_model if d_model_out <= 0 else d_model_out
self.linear2 = torch.nn.Linear(dim_feedforward, d_model_out)
self.norm1 = torch.nn.LayerNorm(d_model)
self.norm2 = torch.nn.LayerNorm(d_model_out)
self.dropout1 = torch.nn.Dropout(dropout)
self.dropout2 = torch.nn.Dropout(dropout)

self.activation = torch.nn.functional.relu

def forward(self, src, src_mask=None, src_key_padding_mask=None):
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
src2, attn = self.self_attn(
src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
)
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
d_out = src2.shape[-1]
src = src[..., :d_out] + self.dropout2(src2)[..., :d_out]
src = self.norm2(src)
return src, attn


def _xavier_init(linear) -> None:
"""
Performs the Xavier weight initialization of the linear layer `linear`.
"""
torch.nn.init.xavier_uniform_(linear.weight.data)
Loading

0 comments on commit 898ba5c

Please sign in to comment.