In [1]:
import torch
import torch.nn as nn

In [3]:
class FeedForward_Gemma(nn.Module):
    """Gated feed-forward block with a tanh-approximated GELU gate.

    Two parallel bias-free projections map the input from ``emb_dim`` to
    ``hidden_dim``. The first is passed through GELU and multiplied
    elementwise with the second; a third bias-free projection maps the
    product back to ``emb_dim``.

    Args:
        cfg: Mapping that provides ``'emb_dim'``, ``'hidden_dim'`` and
            ``'dtype'``; the dtype is forwarded to every ``nn.Linear``.
    """

    def __init__(self, cfg):
        super().__init__()

        emb_dim = cfg['emb_dim']
        hidden_dim = cfg['hidden_dim']
        dtype = cfg['dtype']

        # fc1 feeds the GELU gate, fc2 is the plain linear branch,
        # fc3 projects the gated hidden activations back to emb_dim.
        self.fc1 = nn.Linear(emb_dim, hidden_dim, dtype=dtype, bias=False)
        self.fc2 = nn.Linear(emb_dim, hidden_dim, dtype=dtype, bias=False)
        self.fc3 = nn.Linear(hidden_dim, emb_dim, dtype=dtype, bias=False)

    def forward(self, x):
        """Return ``fc3(gelu_tanh(fc1(x)) * fc2(x))``.

        Args:
            x: Tensor whose last dimension equals ``emb_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``emb_dim``.
        """
        gate_in = self.fc1(x)
        linear_branch = self.fc2(x)

        # The gate uses GELU (tanh approximation) rather than SiLU.
        gate = nn.functional.gelu(gate_in, approximate='tanh')
        hidden = gate * linear_branch
        return self.fc3(hidden)

In [4]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The learnable ``scale`` starts at zero and is applied as ``1 + scale``,
    so a freshly constructed layer only normalises. An optional learnable
    ``shift`` (also zero-initialised) is added afterwards.

    Args:
        emb_dim: Size of the last dimension; length of ``scale`` and ``shift``.
        eps: Constant added to the mean of squares before ``rsqrt``.
        bias: When true, create the ``shift`` parameter; otherwise
            ``shift`` is ``None``.
    """

    def __init__(self, emb_dim, eps=1e-6, bias=False):
        super().__init__()

        self.eps = eps
        # Zero-initialised because forward() multiplies by (1 + scale).
        self.scale = nn.Parameter(torch.zeros(emb_dim))
        if bias:
            self.shift = nn.Parameter(torch.zeros(emb_dim))
        else:
            self.shift = None

    def forward(self, x):
        """Normalise ``x`` in float32 and return it in its original dtype.

        Args:
            x: Tensor whose last dimension broadcasts against ``scale``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        original_dtype = x.dtype

        # Do all arithmetic in float32, whatever the input dtype is.
        x32 = x.float()
        mean_square = x32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = x32 * inv_rms

        # The effective gain is (1 + w), not w.
        gain = 1.0 + self.scale.float()
        result = normalised * gain
        if self.shift is not None:
            result = result + self.shift.float()

        return result.to(original_dtype)