In [6]:
import torch
import torch.nn as nn

In [11]:


class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable element-wise scale.

    The input is divided by ``sqrt(mean(x ** 2) + eps)``, where the mean is
    taken over the trailing ``len(normalized_shape)`` axes, and the result is
    multiplied by ``weight`` (initialised to ones).

    Args:
        normalized_shape: int or sequence of ints describing the trailing
            axes of the input that are normalized together. An int ``d`` is
            treated as ``(d,)``.
        eps: constant added to the mean square before the square root.
    """

    def __init__(self, normalized_shape, eps=1e-8):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        # Store a plain tuple so lists / torch.Size inputs behave the same.
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))

    def forward(self, x):
        """Return ``x`` scaled by the inverse RMS of its normalized axes.

        Args:
            x: tensor whose trailing axes are the ones described by
                ``normalized_shape``.

        Returns:
            Tensor of the same shape as ``x`` (after broadcasting with
            ``weight``).
        """
        # Reduce over every normalized axis, not just the last one; for a
        # 1-D normalized_shape this is exactly dim=(-1,).
        reduce_dims = tuple(range(-len(self.normalized_shape), 0))
        mean_square = x.pow(2).mean(dim=reduce_dims, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        return (x / rms) * self.weight

In [12]:


# Sanity check: the hand-written RMSNorm should agree with torch's nn.RMSNorm
# on a random (batch_size, feature_dim) input.
batch_size, feature_dim = 4, 16

x = torch.randn(batch_size, feature_dim)

custom_rms = RMSNorm(feature_dim)
builtin_rms = nn.RMSNorm(feature_dim)

# Give the built-in layer the same scale parameters as the custom one;
# no_grad keeps the in-place copy out of the autograd graph.
with torch.no_grad():
    builtin_rms.weight.copy_(custom_rms.weight)

out_custom, out_builtin = custom_rms(x), builtin_rms(x)

# NOTE(review): the two layers are built with their own default eps values
# (1e-8 here vs. the nn.RMSNorm default), and no seed is set, so the exact
# deviation printed below can vary slightly between runs.
print("отклонение:", torch.max(torch.abs(out_custom - out_builtin)))
# Excellent result: the deviation is at float32 rounding level.

отклонение: tensor(1.1921e-07, grad_fn=<MaxBackward1>)


In [13]:


class ExpPlusCos(torch.autograd.Function):
    """Custom autograd function computing ``exp(x) + cos(y)`` element-wise.

    Gradients: ``d/dx = exp(x)`` and ``d/dy = -sin(y)``, each multiplied by
    the incoming ``grad_output``.
    """

    @staticmethod
    def forward(ctx, x, y):
        """Compute ``exp(x) + cos(y)`` and stash the inputs for backward.

        Args:
            ctx: autograd context used to pass tensors to ``backward``.
            x: tensor fed to ``torch.exp``.
            y: tensor fed to ``torch.cos``.

        Returns:
            Tensor ``torch.exp(x) + torch.cos(y)``.
        """
        # Save the *inputs* rather than the intermediate exp(x): an
        # intermediate computed inside forward carries no autograd history,
        # so a gradient built from it cannot be differentiated again.
        ctx.save_for_backward(x, y)
        return torch.exp(x) + torch.cos(y)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. ``x`` and ``y``.

        Args:
            ctx: autograd context holding the tensors saved in ``forward``.
            grad_output: gradient of the loss w.r.t. the forward output.

        Returns:
            Tuple ``(grad_x, grad_y)``; an entry is ``None`` when the
            corresponding input does not need a gradient.
        """
        x, y = ctx.saved_tensors
        needs_x, needs_y = ctx.needs_input_grad
        # exp is recomputed from the saved input so that this expression stays
        # connected to x when backward itself is being differentiated.
        grad_x = grad_output * torch.exp(x) if needs_x else None
        grad_y = grad_output * (-torch.sin(y)) if needs_y else None
        return grad_x, grad_y

In [14]:
# Compare gradients from the custom Function with those from plain autograd.
x = torch.randn(3, 3, requires_grad=True)
y = torch.randn(3, 3, requires_grad=True)

# Custom implementation: backward goes through ExpPlusCos.backward.
custom_out = ExpPlusCos.apply(x, y)
torch.sum(custom_out).backward()
# Clone, because .grad is reset below and then refilled by the second pass.
grad_x_custom, grad_y_custom = x.grad.clone(), y.grad.clone()

# Drop the accumulated gradients so the second backward starts from scratch.
x.grad = y.grad = None

# Reference implementation: the same expression built from stock torch ops.
standard_out = torch.exp(x) + torch.cos(y)
torch.sum(standard_out).backward()
grad_x_standard, grad_y_standard = x.grad, y.grad

print("Max abs diff X grad:", torch.max(torch.abs(grad_x_custom - grad_x_standard)).item())
print("Max abs diff Y grad:", torch.max(torch.abs(grad_y_custom - grad_y_standard)).item())
# No difference between the two: everything matches.

Max abs diff X grad: 0.0
Max abs diff Y grad: 0.0
