In [2]:
from dataclasses import dataclass
from transformers import PreTrainedModel, PreTrainedTokenizerFast
from composer.models.huggingface import HuggingFaceModel

from functools import partial
import torch
import torch.nn as nn

In [1]:
@dataclass
class WabiSabiConfig:
    """
    Hyperparameters for the WabiSabi model.

    Important ratios:
    - d_model must be a multiple of n_heads (checked in __post_init__)
    - d_q, d_k, d_v are all equal to d_model / n_heads (available as `d_head`)

    Raises:
        ValueError: if n_heads is not positive or d_model is not a multiple of n_heads.
    """

    d_model: int = 2048  # embedding width; also the LayerNorm size
    n_heads: int = 16  # number of attention heads
    n_layers: int = 24  # number of layers; used to scale residual projections at init
    vocab_size: int = 50368  # number of rows in the token embedding table

    def __post_init__(self):
        # Enforce the ratio documented above so a bad config fails at construction
        # time instead of deep inside the attention code.
        if self.n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {self.n_heads}")
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be a multiple of n_heads ({self.n_heads})"
            )

    @property
    def d_head(self) -> int:
        """Per-head dimension (d_q = d_k = d_v = d_model / n_heads)."""
        return self.d_model // self.n_heads


# Flash attention, multi-query attention
class WabiSabiBlock(nn.Module):
    """One block of the model (work in progress).

    At the moment the block only owns the LayerNorm that runs before attention;
    `n_heads` is accepted but not used yet, and no `forward` is defined.

    Args:
        d_model: size of the feature dimension normalized by the LayerNorm.
        n_heads: number of attention heads (currently unused).
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        pre_attention_norm = nn.LayerNorm(normalized_shape=d_model)
        self.layer_norm_before_attention = pre_attention_norm


class WabiSabiModel(PreTrainedModel):
    """Language-model skeleton: token embedding -> final LayerNorm -> tied output projection.

    The transformer blocks are not wired in yet (see the "blocks" placeholder in
    `forward`). The output projection shares its weight matrix with the token
    embedding (weight tying), and every bias parameter is removed after
    initialization.

    Args:
        config: hyperparameters; `d_model`, `vocab_size` and `n_layers` are read here.
    """

    def __init__(self, config: WabiSabiConfig):
        # nn.Module refuses to register submodules until Module.__init__ has run, so
        # the parent constructor must be called before any layer is assigned.
        # NOTE(review): PreTrainedModel.__init__ expects `config` to be a
        # transformers.PretrainedConfig instance; a plain dataclass config will be
        # rejected there -- confirm and make the config class derive from
        # PretrainedConfig before instantiating this model.
        super().__init__(config)
        self.config = config
        self.tokens_to_embeddings = nn.Embedding(
            num_embeddings=config.vocab_size, embedding_dim=config.d_model
        )
        self.embeddings_to_logits = nn.Linear(
            in_features=config.d_model, out_features=config.vocab_size
        )

        # https://paperswithcode.com/method/weight-tying
        self.embeddings_to_logits.weight = self.tokens_to_embeddings.weight

        # Turned into low precision layernorm by composer
        # https://docs.mosaicml.com/projects/composer/en/latest/method_cards/low_precision_layernorm.html
        self.layer_norm_final = nn.LayerNorm(config.d_model)

        # Embedding fraction: page 7 of GLM-130B paper https://arxiv.org/abs/2210.02414
        # self.embedding_fraction = config.embedding_fraction

        # initialize parameters
        # notes from MPT/nanoGPT/transformers
        # 1. residual projections (e.g. linear layers that project to d_model) are scaled
        # by 1 / sqrt(num_layers), i.e. divided by sqrt(num_layers)
        # 2. layer norm weights are set to one (PyTorch sets this by default; skip)
        # 3. all others are initialized with normal distribution with mean 0 and std 0.02
        # Note: MPT uses kaiming_normal; I'll go for this as well

        # Plain Python float: torch.sqrt only accepts tensors, not ints.
        residual_divisor = config.n_layers**0.5

        def init_weights(module: nn.Module):
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                # Only layers explicitly tagged with `_is_residual_projection = True`
                # are rescaled; untagged nn.Linear layers do not carry the attribute.
                if getattr(module, "_is_residual_projection", False):
                    with torch.no_grad():
                        module.weight.div_(residual_divisor)

            elif isinstance(module, nn.Embedding):
                nn.init.kaiming_normal_(module.weight)

        # disable bias in all modules
        # note for later: if you want to enable bias, should remember to zero out all biases
        # in init_weights
        def disable_bias(module: nn.Module):
            if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                module.register_parameter("bias", None)

        self.apply(init_weights)
        self.apply(disable_bias)

    # TODO: kwargs are for other HuggingFace generate params. Implement if needed.
    def forward(self, input_ids: torch.LongTensor, **kwargs):
        """Map token ids to vocabulary logits.

        Args:
            input_ids: integer token ids fed to the embedding table
                (presumably shaped (batch, sequence) -- TODO confirm).
            **kwargs: accepted and ignored.

        Returns:
            The output of the tied projection: `input_ids`' shape with a trailing
            dimension of size `vocab_size`.
        """
        x = self.tokens_to_embeddings(input_ids)

        # MPT doesn't use embedding fraction
        # x = (x * self.embedding_fraction) + (x.detach() * (1 - self.embedding_fraction))
        # blocks

        x = self.layer_norm_final(x)
        x = self.embeddings_to_logits(x)
        return x

SyntaxError: incomplete input (2965947075.py, line 2)

In [8]:
def apply_fn(module: nn.Module):
    """Callback for `nn.Module.apply`: print a marker line on every invocation.

    The `module` argument is ignored; the function only shows how many times
    `apply` calls it.
    """
    marker = "APPLY"
    print(marker)


class TestModule(nn.Module):
    """Scratch module for comparing `children()`, `modules()` and `apply()`.

    It holds one standalone Linear layer plus a ModuleList of ten Linear layers.
    Constructing it prints the direct children, then every module in the tree,
    and finally runs `apply_fn` through `self.apply`.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.linear_list = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])

        # Same dump for both traversals: a heading, then "<index> <module repr>" lines.
        traversals = (("CHILDREN", self.children), ("MODULES", self.modules))
        for heading, walk in traversals:
            print(heading)
            for position, submodule in enumerate(walk()):
                print(position, submodule)

        self.apply(apply_fn)

    def forward(self, x):
        """Identity: return the input unchanged."""
        return x


# Build the module so the prints in __init__ run; as the cell's last expression,
# the instance's repr is displayed as well.
TestModule()

CHILDREN
0 Linear(in_features=10, out_features=10, bias=True)
1 ModuleList(
  (0-9): 10 x Linear(in_features=10, out_features=10, bias=True)
)
MODULES
0 TestModule(
  (linear): Linear(in_features=10, out_features=10, bias=True)
  (linear_list): ModuleList(
    (0-9): 10 x Linear(in_features=10, out_features=10, bias=True)
  )
)
1 Linear(in_features=10, out_features=10, bias=True)
2 ModuleList(
  (0-9): 10 x Linear(in_features=10, out_features=10, bias=True)
)
3 Linear(in_features=10, out_features=10, bias=True)
4 Linear(in_features=10, out_features=10, bias=True)
5 Linear(in_features=10, out_features=10, bias=True)
6 Linear(in_features=10, out_features=10, bias=True)
7 Linear(in_features=10, out_features=10, bias=True)
8 Linear(in_features=10, out_features=10, bias=True)
9 Linear(in_features=10, out_features=10, bias=True)
10 Linear(in_features=10, out_features=10, bias=True)
11 Linear(in_features=10, out_features=10, bias=True)
12 Linear(in_features=10, out_features=10, bias=True)
APP

TestModule(
  (linear): Linear(in_features=10, out_features=10, bias=True)
  (linear_list): ModuleList(
    (0-9): 10 x Linear(in_features=10, out_features=10, bias=True)
  )
)