In [1]:
import pytorch3d

In [1]:
import torch
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertSelfOutput, BertIntermediate, BertOutput

In [2]:
from omegaconf import OmegaConf

In [45]:
config = OmegaConf.load("../configs/model/fastformer.yaml")
print(config)

{'hidden_size_per_step': [80, 40], 'hidden_dropout_prob': 0.2, 'hidden_act': 'gelu', 'num_attention_heads': 8, 'intermediate_size': 80, 'layer_norm_eps': 1e-12, 'initializer_range': 0.02}


In [46]:
config.hidden_size = config.hidden_size_per_step[0]

In [47]:
config

{'hidden_size_per_step': [80, 40], 'hidden_dropout_prob': 0.2, 'hidden_act': 'gelu', 'num_attention_heads': 8, 'intermediate_size': 80, 'layer_norm_eps': 1e-12, 'initializer_range': 0.02, 'hidden_size': 80}

In [4]:
# Fastformer의 핵심인 기존 self-attention의 단순화 버전
# Fastformer 논문에서는 d_model = d(논문) = hidden_size(코드)로 어노테이션 되어 있는 듯
# 나는 인코더 하나씩만 가져와서 쓰는 거니까 이 클래스까지만 쓰면 될 듯?
class FastSelfAttention(nn.Module):
    def __init__(self, config):
        super(FastSelfAttention, self).__init__()

        self.config = config

        # d_model은 h로 나누어 떨어져야 함.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" %
                (config.hidden_size, config.num_attention_heads))
    
        self.attention_head_size = int(config.hidden_size /config.num_attention_heads)      # attention_head_size
        self.num_attention_heads = config.num_attention_heads                               # num_attention_heads = h
        self.all_head_size = self.num_attention_heads * self.attention_head_size            # all_head_size = d_model = d = hidden_size
        self.input_dim = config.hidden_size                                                 # input_dim = d_model = d = hidden_size
        
        # Query Transformation
        # Input >> Query matrix [q_1, q_2, q_3, ..., q_N]
        self.query = nn.Linear(self.input_dim, self.all_head_size)      # (..., N, d) >> (..., N, d)
        self.query_att = nn.Linear(self.all_head_size, self.num_attention_heads)    # (..., N, d) >> (..., N, num_head)

        # Key Transformation
        # Input >> Key Matrix [k_1, k_2, k_3, ..., k_N]
        self.key = nn.Linear(self.input_dim, self.all_head_size)        # (..., N, d) >> (..., N, d)
        self.key_att = nn.Linear(self.all_head_size, self.num_attention_heads)      # (..., N, d) >> (..., N, num_head)
        
        self.transform = nn.Linear(self.all_head_size, self.all_head_size)      # (..., N, d) >> (..., N, d)

        self.softmax = nn.Softmax(dim=-1)
        
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
                
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads,
                                       self.attention_head_size)
        x = x.view(*new_x_shape)

        return x.permute(0, 2, 1, 3)
    
    def forward(self, hidden_states):
        # batch_size, seq_len, num_head * head_dim, batch_size, seq_len
        batch_size, seq_len, _ = hidden_states.shape    # (B, N, d)
        
        # Q: [q_1, q_2, q_3, ..., q_N]
        mixed_query_layer = self.query(hidden_states)
        # K: [k_1, k_2, k_3, ..., k_N]
        mixed_key_layer = self.key(hidden_states)

        # (B, num_head, N)
        query_for_score = self.query_att(mixed_query_layer).transpose(1, 2) / self.attention_head_size**0.5     # 각 q벡터에 곱해지는 alpha 계산 후, d**0.5로 나누어 스케일링
        # query_for_score += attention_mask       # add attention mask
        # (B, num_head, 1, N)
        query_weight = self.softmax(query_for_score).unsqueeze(2)       # 계산한 alpha는 query의 weight가 된다.

        # batch_size, num_head, seq_len, head_dim
        query_layer = self.transpose_for_scores(mixed_query_layer)  # 여기서부터 해독........

        # batch_size, num_head, head_dim, 1
        pooled_query = torch.matmul(query_weight, query_layer).transpose(1, 2).view(-1,1,self.num_attention_heads*self.attention_head_size)
        pooled_query_repeat= pooled_query.repeat(1, seq_len,1)
        # batch_size, num_head, seq_len, head_dim

        # batch_size, num_head, seq_len
        mixed_query_key_layer=mixed_key_layer* pooled_query_repeat
        
        query_key_score=(self.key_att(mixed_query_key_layer)/ self.attention_head_size**0.5).transpose(1, 2)
        
        # query_key_score +=attention_mask      # add attention mask

        # batch_size, num_head, 1, seq_len
        query_key_weight = self.softmax(query_key_score).unsqueeze(2)

        key_layer = self.transpose_for_scores(mixed_query_key_layer)
        pooled_key = torch.matmul(query_key_weight, key_layer)

        # query = value
        weighted_value =(pooled_key * query_layer).transpose(1, 2)
        weighted_value = weighted_value.reshape(
            weighted_value.size()[:-2] + (self.num_attention_heads * self.attention_head_size,))
        weighted_value = self.transform(weighted_value) + mixed_query_layer
      
        return weighted_value


class FastAttention(nn.Module):
    def __init__(self, config):
        super(FastAttention, self).__init__()

        self.self = FastSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor):
        self_output = self.self(input_tensor)
        attention_output = self.output(self_output, input_tensor)

        return attention_output


class FastformerLayer(nn.Module):
    def __init__(self, config):
        super(FastformerLayer, self).__init__()
        self.attention = FastAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states):
        attention_output = self.attention(hidden_states)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)

        return layer_output

In [5]:
class NerFormerArchitecture(nn.Module):
    def __init__(self, d_z):
        super(NerFormerArchitecture, self).__init__()

        self.d_z = d_z  # Input feature의 차원

        # input: (N_rays(=Batch), N_s, N_src, d_z)
        self.linear_1 = nn.Linear(d_z, 80, bias=False)
        
        # (N_rays, N_s, N_src, 80)
        self.TE_1 = nn.Sequential(
            TransformerEncoder(along_dim="src", feature_dim=80, num_heads=8),          # Pooling transformer encoder
            TransformerEncoder(along_dim="sample", feature_dim=80, num_heads=8)           # Ray transformer encoder
        )
        self.dim_linear_1 = nn.Linear(80, 40)
        # (N_rays, N_s, N_src, 40)
        self.TE_2 = nn.Sequential(
            TransformerEncoder(along_dim="src", feature_dim=40, num_heads=4),          # Pooling transformer encoder
            TransformerEncoder(along_dim="sample", feature_dim=40, num_heads=4)           # Ray transformer encoder
        )
        self.dim_linear_2 = nn.Linear(40, 20)
        # (N_rays, N_s, N_src, 20)

        self.weight_layer = nn.Sequential(
            nn.Linear(20, 1),
            nn.Softmax(dim=-2)      # 특정 sample에서 각 src들에 대한 값의 합이 1이 되도록 차원을 설정
        )

        # color function head
        # Output shape: (N_s, 3)
        self.c_head = nn.Sequential(
            nn.Linear(20, 20),
            nn.ReLU(),
            nn.Linear(20, 3)
        )

        # opacity function head
        # Output shape: (N_s, 1)
        self.f_head = nn.Sequential(
            nn.Linear(20, 1),
            nn.ReLU()
        )


    def forward(self, input_tensor):
        # input_tensor: (N_rays(=Batch), N_s, N_src, D_z)

        x = self.linear_1(input_tensor)     # (N_rays, N_s, N_src, 80)

        x = self.TE_1(x)                    # (N_rays, N_s, N_src, 80)
        x = self.dim_linear_1(x)              # (N_rays, N_s, N_src, 40)

        x = self.TE_2(x)                    # (N_rays, N_s, N_src, 40)
        x = self.dim_linear_2(x)              # (N_rays, N_s, N_src, 20)
        
        # weighted sum along dim 1
        weight = self.weight_layer(x)       # (N_rays, N_s, N_src, 1)
        per_point_features = torch.sum(weight*x, dim=-2)      # (N_rays, N_s, 20)

        # Color function
        ray_colors = self.c_head(per_point_features) # (N_rays, N_s, 3)
        # Opacity function
        ray_densities = self.f_head(per_point_features) # (N_rays, N_s, 1)

        return ray_densities, ray_colors


# (N_s, N_src, D_z) -> (N_s, N_src, D_z)
class TransformerEncoder(nn.Module):
    def __init__(self, along_dim, feature_dim, num_heads):
        super(TransformerEncoder, self).__init__()

        self.along_dim = along_dim
        # Multi-head attention along dim
        # num_heads = 8 (Transformer 논문에서의 세팅)
        self.multi_head_att = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=num_heads)
        self.Q_weights = nn.Linear(feature_dim, feature_dim)
        self.K_weights = nn.Linear(feature_dim, feature_dim)
        self.V_weights = nn.Linear(feature_dim, feature_dim)
        
        self.dropout_1 = nn.Dropout(0.1)
        self.dropout_2 = nn.Dropout(0.1)

        self.layer_norm_1 = nn.LayerNorm(feature_dim)
        self.layer_norm_2 = nn.LayerNorm(feature_dim)

        self.two_layer_MLP = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim)
        )
        

    def forward(self, input_tensor):
        # input_tensor = Z

        # Multi Head Att = MHA(Z, dim=dim)

        # MultiHead(Q,K,V)
        # Q: (sequence length, batch, embedding)
        # K: (sequence length, batch, embedding)
        # V: (sequence length, batch, embedding)

        # Pooling transformer enc
        if self.along_dim == "src":
            # 배치로 들어오는 각 샘플들에 대해, N_src개 소스뷰 시퀀스를 입력으로 줌.
            # (Seq_len, Batch, Features) = (N_src, N_rays*N_s, D_z)
            input_tensor = input_tensor.permute(2, 0, 1, 3)
            shape = input_tensor.shape

            # Pooling transformer의 Batch에 해당하는
            # `N_rays` 차원과 `N_s` 차원을 합쳐준다.
            input_tensor = input_tensor.reshape(shape[0], shape[1]*shape[2], shape[3])

        # Ray transformer enc
        else:
            # 배치로 들어오는 각 소스뷰에 대해, N_s개 샘플 시퀀스를 입력으로 줌.
            # (Seq_len, Batch, Features) = (N_s, N_rays*N_src, D_Z) 
            input_tensor = input_tensor.permute(1, 0, 2, 3)
            shape = input_tensor.shape

            # Ray transformer의 Batch에 해당하는
            # `N_rays` 차원과 `N_src` 차원을 합쳐준다.
            input_tensor = input_tensor.reshape(shape[0], shape[1]*shape[2], shape[3])
        
        query = self.Q_weights(input_tensor)
        key = self.K_weights(input_tensor)
        value = self.V_weights(input_tensor)

        x, _ = self.multi_head_att(query, key, value)
        # Sub-layer MLP
        x_skip = self.layer_norm_1(input_tensor + self.dropout_1(x))    # Skip + LayerNorm  = Z'
        x = self.two_layer_MLP(x_skip)                                  # Two-Layer MLP = MLP(Z')
        x = self.layer_norm_2(x_skip + self.dropout_2(x))               # Skip + LayerNorm = TE^dim(Z)

        x = x.reshape(shape[0], shape[1], shape[2], shape[3])   # N_rays 차원을 분리
        if self.along_dim == "src":
            x = x.permute(1, 2, 0, 3)           # 원래 차원 순서인 (N_rays, N_s, N_src, D_z)로 변환
        else:
            x = x.permute(1, 0, 2, 3)           # 원래 차원 순서인 (N_rays, N_s, N_src, D_z)로 변환

        return x        # shape: (N_rays, N_s, N_src, c_out)

In [6]:
# (batch=ray=1, seq_len=samples=64, sources=3, d_z=160)
input_tensor = torch.randn(1, 32, 3, 160).to("cuda")

In [7]:
nerformer = NerFormerArchitecture(d_z=160).to("cuda")

In [8]:
nerformer_output = nerformer(input_tensor)

print(nerformer_output[0].shape, nerformer_output[1].shape)

torch.Size([1, 32, 1]) torch.Size([1, 32, 3])


In [11]:
config.hidden_size

80

In [40]:
class FastNerFormerArchitecture(nn.Module):
    def __init__(self, d_z, config):
        super(FastNerFormerArchitecture, self).__init__()

        self.d_z = d_z  # Input feature의 차원

        # input: (N_rays(=Batch), N_s, N_src, d_z)
        self.linear_1 = nn.Linear(d_z, 80, bias=False)
        
        # (N_rays, N_s, N_src, 80)
        config.hidden_size = 80
        config.intermediate_size = 80
        self.TE_1 = nn.Sequential(
            FastformerEncoder(along_dim="src", config=config, step=0),          # Pooling transformer encoder
            FastformerEncoder(along_dim="sample", config=config, step=0)           # Ray transformer encoder
        )
        self.TE_1_2 = nn.Sequential(
            FastformerEncoder(along_dim="src", config=config),          # Pooling transformer encoder
            FastformerEncoder(along_dim="sample", config=config)           # Ray transformer encoder
        )
        self.dim_linear_1 = nn.Linear(80, 40)
        # (N_rays, N_s, N_src, 40)
        config.hidden_size = 40
        config.intermediate_size = 40
        self.TE_2 = nn.Sequential(
            FastformerEncoder(along_dim="src", config=config, step=1),          # Pooling transformer encoder
            FastformerEncoder(along_dim="sample", config=config, step=1)           # Ray transformer encoder
        )
        self.TE_2_2 = nn.Sequential(
            FastformerEncoder(along_dim="src", config=config),          # Pooling transformer encoder
            FastformerEncoder(along_dim="sample", config=config)           # Ray transformer encoder
        )
        self.dim_linear_2 = nn.Linear(40, 20)
        # (N_rays, N_s, N_src, 20)

        self.weight_layer = nn.Sequential(
            nn.Linear(20, 1),
            nn.Softmax(dim=-2)      # 특정 sample에서 각 src들에 대한 값의 합이 1이 되도록 차원을 설정
        )

        # color function head
        # Output shape: (N_s, 3)
        self.c_head = nn.Sequential(
            nn.Linear(20, 20),
            nn.ReLU(),
            nn.Linear(20, 3)
        )

        # opacity function head
        # Output shape: (N_s, 1)
        self.f_head = nn.Sequential(
            nn.Linear(20, 1),
            nn.ReLU()
        )


    def forward(self, input_tensor):
        # input_tensor: (N_rays(=Batch), N_s, N_src, D_z)

        x = self.linear_1(input_tensor)     # (N_rays, N_s, N_src, 80)

        x = self.TE_1(x)                    # (N_rays, N_s, N_src, 80)
        x = self.TE_1_2(x)
        x = self.dim_linear_1(x)              # (N_rays, N_s, N_src, 40)

        x = self.TE_2(x)                    # (N_rays, N_s, N_src, 40)
        x = self.TE_2_2(x)
        x = self.dim_linear_2(x)              # (N_rays, N_s, N_src, 20)
        
        # weighted sum along dim 1
        weight = self.weight_layer(x)       # (N_rays, N_s, N_src, 1)
        per_point_features = torch.sum(weight*x, dim=-2)      # (N_rays, N_s, 20)

        # Color function
        ray_colors = self.c_head(per_point_features) # (N_rays, N_s, 3)
        # Opacity function
        ray_densities = self.f_head(per_point_features) # (N_rays, N_s, 1)

        return ray_densities, ray_colors


# (N_s, N_src, D_z) -> (N_s, N_src, D_z)
class FastformerEncoder(nn.Module):
    def __init__(self, along_dim, config, step):
        super(FastformerEncoder, self).__init__()

        self.along_dim = along_dim
        # Multi-head attention along dim
        # num_heads = 8 (Transformer 논문에서의 세팅)
        self.fastformer_layer = FastformerLayer(config=config)
        

    def forward(self, input_tensor):
        # Pooling transformer enc
        if self.along_dim == "src":
            # 배치로 들어오는 각 샘플들에 대해, N_src개 소스뷰 시퀀스를 입력으로 줌.
            # (Batch, Seq_len, Features) = (N_rays*N_s, N_src, D_z)
            shape = input_tensor.shape

            # Pooling transformer의 Batch에 해당하는
            # `N_rays` 차원과 `N_s` 차원을 합쳐준다.
            input_tensor = input_tensor.reshape(shape[0]*shape[1], shape[2], shape[3])

        # Ray transformer enc
        else:
            # 배치로 들어오는 각 소스뷰에 대해, N_s개 샘플 시퀀스를 입력으로 줌.
            # (Batch, Seq_len, Features) = (N_rays*N_src, N_s, D_Z) 
            input_tensor = input_tensor.permute(0, 2, 1, 3)
            shape = input_tensor.shape

            # Ray transformer의 Batch에 해당하는
            # `N_rays` 차원과 `N_src` 차원을 합쳐준다.
            input_tensor = input_tensor.reshape(shape[0]*shape[1], shape[2], shape[3])

        x = self.fastformer_layer(input_tensor)

        x = x.reshape(shape[0], shape[1], shape[2], shape[3])   # N_rays 차원을 분리
        
        # 원래 차원 순서인 (N_rays, N_s, N_src, D_z)로 변환
        if self.along_dim == "sample":
            x = x.permute(0, 2, 1, 3)

        return x        # shape: (N_rays, N_s, N_src, c_out)

In [20]:
fastformer_encoder = FastformerEncoder(config).to("cuda")

In [16]:
tmp_tensor = torch.randn(1, 32, 80).to("cuda")

In [21]:
tmp_output_tensor = fastformer_encoder(tmp_tensor)

In [22]:
tmp_output_tensor.shape

torch.Size([1, 32, 80])

In [41]:
nerformer = NerFormerArchitecture(d_z=160).to("cuda")
fastnerformer = FastNerFormerArchitecture(d_z=160, config=config).to("cuda")

In [28]:
input_tensor.shape

torch.Size([1, 32, 3, 160])

In [31]:
fastnerformer_output = fastnerformer(input_tensor)

torch.Size([32, 3, 80])
torch.Size([3, 32, 80])
torch.Size([32, 3, 40])
torch.Size([3, 32, 40])


In [32]:
import time

In [33]:
time.time()

1645460890.3129923

In [38]:
test_tensor = torch.randn(800, 32, 3, 160).to("cuda")

In [42]:
start = time.time()
nerformer_output = nerformer(test_tensor)
print("nerformer inference : ", time.time() - start)

start = time.time()
fastnerformer_output = fastnerformer(test_tensor)
print("fastnerformer inference : ", time.time() - start)

nerformer inference :  0.019857168197631836
fastnerformer inference :  0.00493621826171875


In [43]:
0.02 / 0.005

4.0