In [None]:
import torch
import torch.nn as nn
import numpy as np
from einops import einsum, rearrange

class Linear(nn.Module):
    """Bias-free linear layer computing ``y = x @ W^T``.

    The weight ``W`` has shape ``(out_features, in_features)``. It is initialized
    from a normal distribution with mean 0 and std ``sqrt(2 / (in_features + out_features))``,
    truncated to ``[-3 * std, 3 * std]``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        device: device on which to allocate the weight.
        dtype: dtype of the weight (``None`` uses torch's default dtype).
    """

    def __init__(self, in_features, out_features, device=None, dtype=None):
        super().__init__()
        std = (2 / (in_features + out_features)) ** 0.5
        weight = torch.empty(out_features, in_features, device=device, dtype=dtype)
        # trunc_normal_ defaults to mean=0, std=1; std must be passed explicitly or the
        # samples come from N(0, 1) clipped to +-3*std instead of N(0, std^2).
        nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
        self.weight = nn.Parameter(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` with shape ``(..., in_features)``; returns ``(..., out_features)``."""
        return einsum(self.weight, x, "d_out d_in, ... d_in -> ... d_out")

In [None]:
class Embedding(nn.Module):
    """Lookup table mapping integer token ids to learned vectors.

    The weight has shape ``(num_embeddings, embedding_dim)``. It is initialized
    from N(0, 1) truncated to ``[-3, 3]``.

    Args:
        num_embeddings: number of rows (vocabulary size).
        embedding_dim: length of each embedding vector.
        device: device on which to allocate the weight.
        dtype: dtype of the weight (``None`` uses torch's default dtype).
    """

    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        weight = torch.empty(num_embeddings, embedding_dim, device=device, dtype=dtype)
        # A single truncated-normal draw; pre-filling with torch.normal would be overwritten anyway.
        nn.init.trunc_normal_(weight, mean=0.0, std=1.0, a=-3.0, b=3.0)
        self.weight = nn.Parameter(weight)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Return the rows selected by ``token_ids``; output shape is ``token_ids.shape + (embedding_dim,)``."""
        return self.weight[token_ids]


In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    ``y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight``

    The computation runs in float32 and the result is cast back to the input dtype.

    Args:
        d_model: size of the last (feature) dimension; also the length of ``weight``.
        eps: added inside the square root for numerical stability.
        device: device on which to allocate the gain.
        dtype: dtype of the gain (``None`` uses torch's default dtype).
    """

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.d_model = d_model
        # Learnable per-feature gain, initialized to 1. torch.ones gives a floating-point
        # tensor even when dtype is None; an integer tensor cannot be an nn.Parameter
        # that requires grad.
        self.weight = nn.Parameter(torch.ones(d_model, device=device, dtype=dtype))

        self.eps = eps
        self.device = device
        self.dtype = dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` of shape ``(..., d_model)``; returns a tensor of the same shape and dtype."""
        in_dtype = x.dtype
        x = x.to(torch.float32)
        # Mean of squares over the feature axis; keepdim so it broadcasts back over x.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / rms * self.weight).to(dtype=in_dtype)
