In [1]:
from typing import Literal

from pydantic import BaseModel
import torch
from torch import Tensor, nn

In [2]:
dtype = torch.float32

In [3]:
from typing import TYPE_CHECKING


class Config(BaseModel):
    dim_model: int


class Block(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.in_proj = nn.Linear(config.dim_model, config.dim_model, bias=False)
        self.init_weights()

    def init_weights(self) -> None:
        with torch.no_grad():
            self.in_proj.weight.data.fill_(0.0)

    def forward(self, x: Tensor) -> Tensor:
        return self.in_proj(x)

    if TYPE_CHECKING:
        def __call__(self, x: Tensor) -> Tensor: ...

In [4]:
X = Tensor([1, 2])
Y = Tensor([10, 20])

config = Config(dim_model=2)
block = Block(config)

print(block.in_proj.weight)

Z = block(X)
print(block.in_proj.weight)

loss = (Z - Y).sum()

print(block.in_proj.weight.data.grad)
print(loss)

loss.backward()

print(block.in_proj.weight.data.grad)

Parameter containing:
tensor([[0., 0.],
        [0., 0.]], requires_grad=True)
Parameter containing:
tensor([[0., 0.],
        [0., 0.]], requires_grad=True)
None
tensor(-30., grad_fn=<SumBackward0>)
None
