In [1]:
import math
import argparse
# --- DExTra / GLT sizing defaults (consumed by base_architecture) ---
# default for delight_{emb,enc,dec}_width_mult
DEFAULT_WIDTH_MULTIPLIER: float = 2.0
# default embedding depth and encoder/decoder minimum depth
DEFAULT_MIN_DEXTRA_LAYERS: int = 4
# default encoder/decoder layer count and maximum depth
DEFAULT_MAX_DEXTRA_LAYERS: int = 8
# not referenced in this notebook
DEFAULT_BASE_GROUPS: int = 16
# delight_emb_out_dim must be a multiple of this; also sets the max GLT group count
MIN_ELEMENTS_PER_GROUP: int = 32
# default for delight_{enc,dec}_ffn_red
DEFAULT_FFN_RED_FACTOR: int = 4

# --- Regularisation defaults ---
# default for most dropout arguments in base_architecture
DEFAULT_DROPOUT: float = 0.1
# not referenced in this notebook
DEFAULT_STD_DROPOUT: float = 0.1
# not referenced in this notebook
ADAPTIVE_SCALE_FACTOR: int = 2

In [2]:

## Default DeLighT architecture arguments



def base_architecture(args):
    """Fill ``args`` in place with the default DeLighT architecture settings.

    Every field is set via ``getattr(args, name, default)``, so values the caller
    already placed on ``args`` are kept and only missing ones get a default.
    ``glt_shuffle`` is always recomputed as ``not no_glt_shuffle``.
    Nothing is returned.

    The maximum GLT group count used as the default for the embedding, encoder
    and decoder is the smallest power of two that is
    ``>= delight_emb_out_dim // MIN_ELEMENTS_PER_GROUP``.

    Args:
        args: an ``argparse.Namespace`` (or any attribute container) to update.

    Raises:
        AssertionError: if ``delight_emb_out_dim`` is not a multiple of
            ``MIN_ELEMENTS_PER_GROUP``.
        ValueError: if ``delight_emb_out_dim`` is not positive.
    """
    # DeLighT Embedding layer
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.delight_emb_map_dim = getattr(args, "delight_emb_map_dim", 128)
    args.delight_emb_out_dim = getattr(args, "delight_emb_out_dim", 128)
    # compute the max groups in GLT
    assert args.delight_emb_out_dim % MIN_ELEMENTS_PER_GROUP == 0, 'remainder({}, {}) should be equal to 0'.format(
        args.delight_emb_out_dim, MIN_ELEMENTS_PER_GROUP)
    groups_needed = int(args.delight_emb_out_dim // MIN_ELEMENTS_PER_GROUP)
    if groups_needed < 1:
        raise ValueError('delight_emb_out_dim must be a positive multiple of {}, got {}'.format(
            MIN_ELEMENTS_PER_GROUP, args.delight_emb_out_dim))
    # Round up to the next power of two with exact integer arithmetic.
    # math.log(x, 2) is evaluated as log(x)/log(2) and can land slightly above an
    # exact power of two, which ceil() would then push one power too high.
    max_groups = 1 << (groups_needed - 1).bit_length()

    args.delight_emb_max_groups = getattr(args, "delight_emb_max_groups", max_groups)
    args.delight_emb_dropout = getattr(args, "delight_emb_dropout", DEFAULT_DROPOUT)
    args.delight_emb_depth = getattr(args, "delight_emb_depth", DEFAULT_MIN_DEXTRA_LAYERS)
    args.delight_emb_width_mult = getattr(args, "delight_emb_width_mult", DEFAULT_WIDTH_MULTIPLIER)

    # Encoder arguments in DeLighT
    args.delight_enc_scaling = getattr(args, "delight_enc_scaling", 'block')
    args.delight_enc_layers = getattr(args, "delight_enc_layers", DEFAULT_MAX_DEXTRA_LAYERS)
    args.delight_enc_min_depth = getattr(args, "delight_enc_min_depth", DEFAULT_MIN_DEXTRA_LAYERS)
    args.delight_enc_max_depth = getattr(args, "delight_enc_max_depth", DEFAULT_MAX_DEXTRA_LAYERS)
    args.delight_enc_width_mult = getattr(args, "delight_enc_width_mult", DEFAULT_WIDTH_MULTIPLIER)
    args.delight_enc_ffn_red = getattr(args, "delight_enc_ffn_red", DEFAULT_FFN_RED_FACTOR)
    args.delight_enc_max_groups = getattr(args, "delight_enc_max_groups", max_groups)

    # Decoder arguments in DeLighT
    args.delight_dec_scaling = getattr(args, "delight_dec_scaling", 'block')
    args.delight_dec_layers = getattr(args, "delight_dec_layers", DEFAULT_MAX_DEXTRA_LAYERS)
    args.delight_dec_min_depth = getattr(args, "delight_dec_min_depth", DEFAULT_MIN_DEXTRA_LAYERS)
    args.delight_dec_max_depth = getattr(args, "delight_dec_max_depth", DEFAULT_MAX_DEXTRA_LAYERS)
    args.delight_dec_width_mult = getattr(args, "delight_dec_width_mult", DEFAULT_WIDTH_MULTIPLIER)
    args.delight_dec_ffn_red = getattr(args, "delight_dec_ffn_red", DEFAULT_FFN_RED_FACTOR)
    args.delight_dec_max_groups = getattr(args, "delight_dec_max_groups", max_groups)

    ## Others
    args.no_glt_shuffle = getattr(args, "no_glt_shuffle", False)
    args.glt_shuffle = not args.no_glt_shuffle
    args.define_iclr = getattr(args, "define_iclr", False)
    # Single source of truth for delight_dropout (effective default: DEFAULT_DROPOUT).
    args.delight_dropout = getattr(args, "delight_dropout", DEFAULT_DROPOUT)

    # normalization and activation layers
    args.norm_type = getattr(args, "norm_type", 'ln')
    args.act_type = getattr(args, "act_type", 'swish')

    # ADAPTIVE INPUT AND OUTPUT PARAMS
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4)
    args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False)
    args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False)

    # Print  stats
    args.print_stats = getattr(args, "print_stats", False)
    args.src_len_ps = getattr(args, "src_len_ps", 20)
    args.tgt_len_ps = getattr(args, "tgt_len_ps", 20)

    # DROPOUTS (delight_dropout is set once, in "Others" above)
    args.attention_dropout = getattr(args, "attention_dropout", DEFAULT_DROPOUT)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.dropout = getattr(args, "dropout", DEFAULT_DROPOUT)
    args.pe_dropout = getattr(args, "pe_dropout", DEFAULT_DROPOUT)
    args.ffn_dropout = getattr(args, "ffn_dropout", DEFAULT_DROPOUT)

    # Other parameters
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)

    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)

# Start from an empty namespace so every DeLighT field takes its default;
# base_architecture fills `args` in place and later cells read from it.
# (The DExTraUnit construction that used to sit here as a dead string / commented
# block is exercised for real inside DeLightEncoderLayer further down.)
args = argparse.Namespace()
base_architecture(args)

## Test Embedding

In [3]:
from fake_bert.models.dextra_emb import DExTraEmb
from fake_bert.models.nn_functions import get_embedding_layer
from fake_bert.datasets.dataset import RandomGenerator
from torch.utils.data import DataLoader

# Toy vocabulary for the smoke test; padding_idx is handed to the lookup table.
num_embeddings = 150
padding_idx = 3

# Token-id lookup table sized to the DExTra embedding input width, wrapped by
# the DExTra embedding module configured from `args`.
map_layer = get_embedding_layer(
    num_embeddings=num_embeddings,
    embedding_dim=args.delight_emb_map_dim,
    padding_idx=padding_idx,
)
emb = DExTraEmb(args, map_layer=map_layer)

# Grab the first batch of 10 (no shuffling) and embed element 0 of it
# (presumably the token-id tensor — confirm against RandomGenerator).
dataset = RandomGenerator()
dataloader = DataLoader(dataset, shuffle=False, batch_size=10)
data_sample = next(iter(dataloader))
x = data_sample[0]
out = emb(x)
out.shape


torch.Size([10, 50, 128])

## Test Transformers

In [13]:
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from fake_bert.models.nn_functions import get_embedding_layer
from fake_bert.datasets.dataset import RandomGenerator
from torch.utils.data import DataLoader
from fake_bert.models.dextra_emb import DExTraEmb
from torch import Tensor
from fake_bert.models.dextra_unit import DExTraUnit
# Toy vocabulary settings (same values as the embedding-test cell).
# The throwaway `model` that was built here was never used: every later cell
# rebinds `model` before reading it.
padding_idx = 3
num_embeddings = 150

class DeLightEncoderLayer(nn.TransformerEncoderLayer):
    """``nn.TransformerEncoderLayer`` preceded by a ``DExTraUnit``.

    ``forward`` first maps ``src`` through ``self.dextra_layer`` and then runs the
    parent encoder layer (self-attention + feed-forward) on the result.

    Args:
        d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps,
        batch_first, norm_first, device, dtype: forwarded by keyword to
            ``nn.TransformerEncoderLayer``.
        delight_args: namespace providing the ``delight_emb_*``,
            ``delight_dropout``, ``act_type``, ``norm_type``, ``define_iclr`` and
            ``glt_shuffle`` fields used to configure the DExTra unit (e.g. one
            filled by ``base_architecture``). When ``None`` the notebook-global
            ``args`` is used, as before.

    NOTE(review): the DExTra unit is sized from the *embedding* settings
    (``delight_emb_map_dim`` -> ``delight_emb_out_dim``), not from ``d_model``;
    presumably its output width must equal ``d_model`` for the attention block
    to accept it — confirm when changing either value.
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 0.00001, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None, delight_args: Optional[argparse.Namespace] = None) -> None:
        # Keyword arguments on purpose: torch 2.1+ inserts a `bias` parameter
        # before `device`, so positional forwarding would bind device -> bias
        # (None -> every Linear/LayerNorm built without bias) and dtype -> device.
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )

        # Backward compatible: fall back to the notebook-level `args`.
        cfg = args if delight_args is None else delight_args
        self.dextra_layer = DExTraUnit(
            in_features=cfg.delight_emb_map_dim,
            in_proj_features=cfg.delight_emb_out_dim // 2,
            out_features=cfg.delight_emb_out_dim,
            width_multiplier=cfg.delight_emb_width_mult,
            dextra_depth=cfg.delight_emb_depth,
            dextra_dropout=cfg.delight_dropout,
            max_glt_groups=cfg.delight_emb_max_groups,
            act_type=cfg.act_type,
            norm_type=cfg.norm_type,
            use_bias=True,
            is_iclr_version=cfg.define_iclr,
            glt_shuffle=cfg.glt_shuffle,
        )

    def forward_dextra(self, x: Tensor) -> Tensor:
        """Apply only the DExTra unit to ``x``."""
        return self.dextra_layer(x)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Run the DExTra unit, then the parent encoder layer.

        Extra keyword arguments (e.g. ``is_causal``, which ``nn.TransformerEncoder``
        passes to its layers on torch >= 2.0) are forwarded to the parent.
        """
        src = self.forward_dextra(src)
        return super().forward(src, src_mask=src_mask,
                               src_key_padding_mask=src_key_padding_mask, **kwargs)


In [14]:
class MyModel(nn.TransformerEncoderLayer):
    """Scaffold subclass of ``nn.TransformerEncoderLayer`` that adds nothing yet.

    All arguments are forwarded to the parent by keyword. ``activation`` defaults
    to ``F.relu`` (the parent's default); the previous ``...`` default would have
    been stored as the activation and crashed on the first forward pass.
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 0.00001, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        # Keyword arguments on purpose: torch 2.1+ has a `bias` parameter before
        # `device`, so positional forwarding would misbind device/dtype.
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )

In [15]:
# `out` comes from the embedding test and is batch-first: dim 0 is the
# DataLoader batch (batch_size=10). nn.TransformerEncoderLayer defaults to
# batch_first=False, which would attend across the 10 samples instead of along
# each sample's sequence, so request batch-first layout explicitly.
model = nn.TransformerEncoderLayer(d_model=128, nhead=1, dim_feedforward=64, batch_first=True)
model(out).shape

torch.Size([10, 50, 128])


torch.Size([10, 50, 128])

In [16]:
# `out` is batch-first (dim 0 = DataLoader batch of 10), so build the layer with
# batch_first=True; the default would mix information across samples.
model = DeLightEncoderLayer(d_model=128, nhead=1, dim_feedforward=64, batch_first=True)
model(out).shape

torch.Size([10, 50, 128])


torch.Size([10, 50, 128])

In [None]:
model  # render the module tree of whatever layer is currently bound to `model`

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (linear1): Linear(in_features=128, out_features=64, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=64, out_features=128, bias=True)
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)