In [1]:
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import List, Tuple

import math
import torch.nn.functional as F
import numpy as np

from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.modules import (
    Fp32GroupNorm,
    Fp32LayerNorm,
    GradMultiply,
    GumbelVectorQuantizer,
    LayerNorm,
    MultiheadAttention,
    SamePad,
    TransposeLast,
)

from fairseq.modules.transformer_sentence_encoder import init_bert_params

In [2]:
from TransformerSentenceEncoderLayer import TransformerSentenceEncoderLayer, TransformerSentenceEncoderLayerConfig

In [44]:
@dataclass
class StudentTransformerEncoderConfig(FairseqDataclass):
    """Hyper-parameters of ``StudentTransformerEncoder``.

    ``layer_setting`` configures every transformer layer; the remaining fields set the depth,
    the convolutional positional embedding, LayerDrop, and the optional time-reduction layer.
    """

    # default_factory builds a fresh layer config per instance. Passing one shared instance as
    # ``default`` makes it shared mutable state across all configs, and dataclasses on
    # Python 3.11+ reject unhashable default values (such as a non-frozen dataclass instance).
    layer_setting: TransformerSentenceEncoderLayerConfig = field(
        default_factory=TransformerSentenceEncoderLayerConfig,
        metadata={"help": "Default setting of TransformerSentenceEncoderLayerConfig"}
    )

    # TODO: a smaller layer setting for the layers after the time-reduction layer
    # (e.g. encoder_embed_dim=384, encoder_ffn_embed_dim=1536, encoder_attention_heads=6)
    # is planned but not implemented; it would also have to be wired into the encoder class.

    encoder_layers: int = field(
        default=6,
        metadata={"help": "num encoder layers in the transformer"}
    )

    # Kernel size of the positional Conv1d.
    conv_pos: int = field(
        default=128,
        metadata={"help": "number of filters for convolutional positional embeddings"},
    )

    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )

    encoder_layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a transformer layer"}
    )

    # Time-reduction (TR) layer
    able_tr_layer: bool = field(
        default=True,
        metadata={"help": "applying time reduction layer or not"}
    )

    # "fcl" (concatenate frames + Linear) or "conv1d" (strided Conv1d).
    type_of_tr_layer: str = field(
        default="fcl",
        metadata={"help": "type of time reduction layer"}
    )

    # String parsed into a (kernel, stride) pair; only used when type_of_tr_layer == "conv1d".
    tr_conv1d_kernel_stride: str = field(
        default="(2, 2)",
        metadata={"help": "If tr is conv1d, list of kernel and stride for conv1d"}
    )

    # Number of consecutive frames merged into one; only used when type_of_tr_layer == "fcl".
    tr_fcl_output_factor: int = field(
        default=2,
        metadata={"help": "Factor to reduce time length"}
    )

    # Number of transformer layers placed before the TR layer (0 .. encoder_layers).
    tr_layer_floor: int = field(
        default=3,
        metadata={"help": "which floor should time reduction layer put in"}
    )

In [55]:
class StudentTransformerEncoder(nn.Module):
    """Transformer encoder with a convolutional positional embedding and optional time reduction.

    The input (B x T x C) gets a convolutional positional embedding added and then passes
    through ``cfg.encoder_layers`` ``TransformerSentenceEncoderLayer``s. When
    ``cfg.able_tr_layer`` is set, a time-reduction (TR) layer is inserted after the first
    ``cfg.tr_layer_floor`` transformer layers:

    * ``"fcl"``: concatenate ``cfg.tr_fcl_output_factor`` consecutive frames along the
      channel axis and project back to the embedding dim with a Linear layer;
    * ``"conv1d"``: a Conv1d whose ``(kernel, stride)`` comes from
      ``cfg.tr_conv1d_kernel_stride``.

    Either variant shortens the time axis seen by all following layers.
    """

    def __init__(self, cfg: "StudentTransformerEncoderConfig"):
        """Build the encoder from a ``StudentTransformerEncoderConfig``.

        Raises:
            ValueError: if ``cfg.type_of_tr_layer`` is unknown, ``cfg.tr_fcl_output_factor``
                is < 1, or ``cfg.tr_layer_floor`` is outside ``[0, cfg.encoder_layers]``.
        """
        super().__init__()

        args = cfg.layer_setting
        # Kept so that max_positions() can read it.
        self.args = args

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        # Convolutional positional embedding: grouped Conv1d (+ weight norm, SamePad, GELU).
        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=cfg.conv_pos,
            padding=cfg.conv_pos // 2,
            groups=cfg.conv_pos_groups,
        )
        dropout = 0  # dropout term of the init formula below, fixed at 0 here
        std = math.sqrt((4 * (1.0 - dropout)) / (cfg.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        # With an even kernel and padding=k//2 the conv emits one extra frame; SamePad trims it.
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(cfg.conv_pos), nn.GELU())

        self.tr_fcl_output_factor = None
        self.tr_layer = self._build_tr_layer(cfg)

        if self.tr_layer is None:
            self.layers = nn.ModuleList(
                [
                    TransformerSentenceEncoderLayer(cfg.layer_setting)
                    for _ in range(cfg.encoder_layers)
                ]
            )
        else:
            if not 0 <= cfg.tr_layer_floor <= cfg.encoder_layers:
                raise ValueError(
                    f"tr_layer_floor must be in [0, {cfg.encoder_layers}], "
                    f"got {cfg.tr_layer_floor}."
                )
            # Three sub-blocks: layers before the TR layer, the TR layer, layers after it.
            # The nesting is kept so checkpoint keys (layers.0.*, layers.1.0.*, layers.2.*)
            # stay the same.
            self.layers = nn.Sequential(
                nn.ModuleList(
                    [
                        TransformerSentenceEncoderLayer(cfg.layer_setting)
                        for _ in range(cfg.tr_layer_floor)
                    ]
                ),
                nn.ModuleList([self.tr_layer]),
                nn.ModuleList(
                    [
                        TransformerSentenceEncoderLayer(cfg.layer_setting)
                        for _ in range(cfg.encoder_layers - cfg.tr_layer_floor)
                    ]
                ),
            )

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = cfg.encoder_layerdrop

        self.apply(init_bert_params)
        # init_bert_params re-initializes every nn.Linear, which would silently discard the
        # Xavier init intended for the "fcl" TR layer, so apply it after init_bert_params.
        if isinstance(self.tr_layer, nn.Linear):
            nn.init.xavier_uniform_(self.tr_layer.weight)

    def _build_tr_layer(self, cfg):
        """Create the time-reduction module described by ``cfg``, or None when disabled."""
        if not cfg.able_tr_layer:
            return None

        if cfg.type_of_tr_layer == 'fcl':
            if cfg.tr_fcl_output_factor < 1:
                raise ValueError(
                    f"tr_fcl_output_factor must be >= 1, got {cfg.tr_fcl_output_factor}."
                )
            self.tr_fcl_output_factor = cfg.tr_fcl_output_factor
            # Input is `factor` frames concatenated channel-wise (see concat_channelwise).
            return nn.Linear(
                self.embedding_dim * self.tr_fcl_output_factor,
                self.embedding_dim
            )

        if cfg.type_of_tr_layer == 'conv1d':
            import ast
            # literal_eval only accepts Python literals; eval() would execute arbitrary code
            # coming from the config.
            kernel, stride = ast.literal_eval(cfg.tr_conv1d_kernel_stride)
            return nn.Conv1d(
                self.embedding_dim,
                self.embedding_dim,
                kernel_size=kernel,
                stride=stride
            )

        raise ValueError(
            f"Wrong type of time reduction layer: {cfg.type_of_tr_layer!r} "
            "(expected 'fcl' or 'conv1d')."
        )

    def forward(self, x, padding_mask=None, layer=None):
        """Encode ``x`` (B x T x C); arguments are as in ``extract_features``.

        In pre-norm mode (``layer_norm_first``) the final LayerNorm is applied here, but only
        when the full stack was run (``layer`` is None).
        """
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        """Run the encoder, optionally stopping after transformer layer ``tgt_layer``.

        Args:
            x: B x T x C input features.
            padding_mask: optional B x T boolean mask, True at padded frames.
            tgt_layer: optional 0-based index of the transformer layer after which to stop.
                Only transformer layers are counted; the TR layer has no index.

        Returns:
            ``(x, layer_results)``. ``x`` is B x T' x C, where T' is shorter than T after the
            TR layer has run. ``layer_results`` holds the ``(x, z)`` pair returned by every
            executed transformer layer, collected only when ``tgt_layer`` is given.
        """
        if padding_mask is not None:
            # Zero the padded frames (out-of-place, so the caller's tensor is untouched).
            x = x.masked_fill(padding_mask.unsqueeze(-1), 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []

        if self.tr_layer is None:
            x, _ = self._run_layers(self.layers, x, padding_mask, tgt_layer, 0, layer_results)
        else:
            pre_layers, post_layers = self.layers[0], self.layers[2]
            x, reached = self._run_layers(
                pre_layers, x, padding_mask, tgt_layer, 0, layer_results
            )
            if not reached:
                x = self._time_reduce(x)
                if padding_mask is not None:
                    # The sequence got shorter; shrink the mask so attention shapes match.
                    padding_mask = self._reduce_padding_mask(padding_mask, x.size(0))
                x, _ = self._run_layers(
                    post_layers, x, padding_mask, tgt_layer, len(pre_layers), layer_results
                )

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results

    def _run_layers(self, layers, x, padding_mask, tgt_layer, offset, layer_results):
        """Run transformer ``layers`` with LayerDrop; return ``(x, reached_tgt_layer)``.

        ``offset`` is the global index of ``layers[0]``, so ``tgt_layer`` counts transformer
        layers across the whole encoder. Results are appended to ``layer_results`` in place.
        """
        for j, layer in enumerate(layers, start=offset):
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
                if tgt_layer is not None:
                    layer_results.append((x, z))
            if j == tgt_layer:
                return x, True
        return x, False

    def _time_reduce(self, x):
        """Apply the TR layer to a T x B x C tensor and return a shorter T' x B x C tensor."""
        if isinstance(self.tr_layer, nn.Conv1d):
            # Conv1d expects B x C x T.
            x = x.permute(1, 2, 0).contiguous()
            x = self.tr_layer(x)
            return x.permute(2, 0, 1).contiguous()
        # "fcl": stack `factor` consecutive frames along channels, then project back to C.
        return self.tr_layer(self.concat_channelwise(x))

    def _reduce_padding_mask(self, padding_mask, new_length):
        """Down-sample a B x T padding mask to the time-reduced length ``new_length``.

        A reduced frame is marked as padding iff the first input frame it covers is padding.
        NOTE(review): this assumes right-padded sequences (padding only at the end); confirm
        for your data pipeline.
        """
        if isinstance(self.tr_layer, nn.Conv1d):
            step = self.tr_layer.stride[0]
        else:
            step = self.tr_fcl_output_factor
        return padding_mask[:, ::step][:, :new_length]

    def max_positions(self):
        """Maximum output length supported by the encoder.

        Returns ``self.args.max_positions`` when the layer config defines it, otherwise None.
        The convolutional positional embedding has no fixed length limit.
        """
        return getattr(self.args, "max_positions", None)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq (no-op here)."""
        return state_dict

    def concat_channelwise(self, x):
        """Stack every ``tr_fcl_output_factor`` consecutive frames along the channel axis.

        Args:
            x: T x B x C tensor.

        Returns:
            ceil(T / factor) x B x (C * factor) tensor on the same device/dtype as ``x``.
            Row ``t`` is the channel-wise concatenation of input frames
            ``t*factor .. t*factor + factor - 1``. Zero frames are appended when T is not a
            multiple of ``factor``.
        """
        factor = self.tr_fcl_output_factor
        time_length, batch, channel = x.size()
        # Zero frames needed to reach a multiple of `factor` (0 when already aligned).
        how_many_pad = -time_length % factor
        if how_many_pad:
            zero_pad = x.new_zeros(how_many_pad, batch, channel)
            x = torch.cat([x, zero_pad], dim=0)
        # (T / factor) x B x (C * factor)
        return torch.cat([x[j::factor] for j in range(factor)], dim=2)

In [56]:
# Encoder config using the frame-concatenation ("fcl") time-reduction layer with a 2x factor.
# Pass able_tr_layer=False to build the plain encoder without time reduction.
config = StudentTransformerEncoderConfig(
    tr_fcl_output_factor=2,
    type_of_tr_layer="fcl",
)

config

StudentTransformerEncoderConfig(_name=None, layer_setting=TransformerSentenceEncoderLayerConfig(_name=None, encoder_embed_dim=768, encoder_ffn_embed_dim=3072, encoder_attention_heads=12, dropout=0.1, attention_dropout=0.1, activation_dropout=0.0, activation_fn='gelu', layer_norm_first=False), encoder_layers=6, conv_pos=128, conv_pos_groups=16, encoder_layerdrop=0.0, able_tr_layer=True, type_of_tr_layer='fcl', tr_conv1d_kernel_stride='(2, 2)', tr_fcl_output_factor=2, tr_layer_floor=3)

In [57]:
# Build the student encoder from the config above and display its module tree.
model = StudentTransformerEncoder(cfg=config)
model

StudentTransformerEncoder(
  (pos_conv): Sequential(
    (0): Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
    (1): SamePad()
    (2): GELU()
  )
  (tr_layer): Linear(in_features=1536, out_features=768, bias=True)
  (layers): ModuleList()
  (layer_norm): FusedLayerNorm(torch.Size([768]), eps=1e-05, elementwise_affine=True)
)

In [58]:
from torchinfo import summary

# Dummy B x T x C input. C must equal the encoder embedding dim, so it is read from the config
# rather than hard-coded (768 only matches the default layer setting).
SUMMARY_BATCH, SUMMARY_FRAMES = 1, 333
summary(
    model,
    (SUMMARY_BATCH, SUMMARY_FRAMES, config.layer_setting.encoder_embed_dim),
    col_names=["input_size", "output_size", "num_params"],
    depth=5,
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
StudentTransformerEncoder                --                        --                        --
├─ModuleList: 1-1                        --                        --                        --
├─Sequential: 1-2                        [1, 768, 333]             [1, 768, 333]             --
│    └─Conv1d: 2-1                       [1, 768, 333]             [1, 768, 334]             4,719,488
│    └─SamePad: 2-2                      [1, 768, 334]             [1, 768, 333]             --
│    └─GELU: 2-3                         [1, 768, 333]             [1, 768, 333]             --
├─FusedLayerNorm: 1-3                    [1, 333, 768]             [1, 333, 768]             1,536
Total params: 4,721,024
Trainable params: 4,721,024
Non-trainable params: 0
Total mult-adds (G): 3.62
Input size (MB): 1.02
Forward/backward pass size (MB): 4.10
Params size (MB): 18.88
Estimated Total Size (MB): 24.0