In [None]:
from init_notebook import *

In [None]:
class AttentionConvBlock(nn.Module):
    """
    Convolutional attention-style block.

    Three parallel ``nn.Conv2d`` layers produce the ``q``, ``k`` and ``v`` maps.
    ``q @ k`` is a matrix product over the last two (spatial) axes; the result
    is soft-maxed over all spatial positions of each channel and then
    matrix-multiplied with ``v``.

    Because both matrix products act on the last two axes, the convolution
    outputs must be square (height == width).

    :param in_channels: number of channels of the input tensor
    :param out_channels: number of channels produced by the q/k/v convolutions
    :param kernel_size: kernel size of the q/k/v convolutions
    :param padding: padding of the q/k/v convolutions,
        defaults to ``kernel_size // 2``
    :param activation: handed to ``activation_to_callable``; if the result is
        not None it is applied to the attention output
    :param dropout: if > 0, ``nn.Dropout2d`` with this probability is applied
        to the attention map
    :param residual: if True, the block input is added to the output whenever
        both tensors have the same shape
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            padding: Optional[int] = None,
            activation: Union[None, str, Callable] = None,
            dropout: float = 0.,
            residual: bool = True,
    ):
        super().__init__()
        self._act = activation
        self._residual = residual

        if padding is None:
            padding = kernel_size // 2

        self.q = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.k = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.v = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.act = activation_to_callable(activation)
        # Always define the attribute so `forward` can test it directly
        # instead of probing with `hasattr`.
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 else None

    def extra_repr(self):
        return f"residual={self._residual}, activation={repr(self._act)}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: tensor of shape (B, C, H, W)
        :return: tensor with ``out_channels`` channels
        :raises ValueError: if ``x`` is not 4-dimensional
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) input, got shape {tuple(x.shape)}"
            )

        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        attn = q @ k
        # Softmax over all spatial positions of each channel. The shape is
        # taken from `attn` itself so that a padding which does not preserve
        # the input's spatial size is handled as well.
        attn = F.softmax(attn.flatten(2), dim=-1).view_as(attn)
        if self.dropout is not None:
            attn = self.dropout(attn)

        y = attn @ v

        if self.act is not None:
            y = self.act(y)

        # `residual` defaults to True, so the skip connection is only added
        # when shapes match; channel-changing blocks keep working unchanged.
        if self._residual and y.shape == x.shape:
            y = y + x

        return y

    

# Smoke test: build a 3 -> 10 channel block, report its parameter count,
# run one random 16x16 input through it and display the module together
# with the first three output channels converted by `VF.to_pil_image`.
m = AttentionConvBlock(in_channels=3, out_channels=10, activation="gelu")
print(f"params: {num_module_parameters(m):,}")
inp = torch.rand((1, 3, 16, 16))
outp = m(inp)
print(f"{inp.shape} -> {outp.shape}")
display(m, VF.to_pil_image(outp[0, :3]))

In [None]:
class AttentionConvBlock(nn.Module):
    """
    Convolutional attention-style block with an optional input normalization
    and a convolutional refinement of the attention map.

    The (optionally normalized) input is fed to three parallel ``nn.Conv2d``
    layers. ``q`` and ``k`` pass through a ReLU and are matrix-multiplied over
    the last two (spatial) axes. The resulting map is summed with a 3x3 and a
    5x5 convolution of itself and finally matrix-multiplied with ``v``.

    Because both matrix products act on the last two axes, the convolution
    outputs must be square (height == width).

    :param in_channels: number of channels of the input tensor
    :param out_channels: number of channels produced by the q/k/v convolutions
    :param kernel_size: kernel size of the q/k/v convolutions
    :param padding: padding of the q/k/v convolutions,
        defaults to ``kernel_size // 2``
    :param activation: handed to ``activation_to_callable``; if the result is
        not None it is applied to the attention output
    :param dropout: if > 0, ``nn.Dropout2d`` with this probability is applied
        to the attention map
    :param residual: if True, the un-normalized block input is added to the
        output; input and output shapes must then match
    :param norm: handed to ``normalization_to_module``; the returned module
        (if any) is applied to the input before the convolutions
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            padding: Optional[int] = None,
            activation: Union[None, str, Callable] = None,
            dropout: float = 0.,
            residual: bool = False,
            norm: Optional[str] = None,
    ):
        super().__init__()
        self._act = activation
        self._residual = residual

        if padding is None:
            padding = kernel_size // 2

        self.q = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.k = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.v = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.attn_conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.attn_conv5 = nn.Conv2d(out_channels, out_channels, 5, padding=2)
        self.act = activation_to_callable(activation)
        # The normalization runs on the block *input* (see `forward`), so it
        # has to be sized for `in_channels`, not `out_channels`.
        self.norm = normalization_to_module(norm, in_channels)

        # Always define the attribute so `forward` can test it directly
        # instead of probing with `hasattr`.
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 else None

    def extra_repr(self):
        return f"residual={self._residual}, activation={repr(self._act)}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: tensor of shape (B, C, H, W)
        :return: tensor with ``out_channels`` channels
        :raises ValueError: if ``x`` is not 4-dimensional
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) input, got shape {tuple(x.shape)}"
            )

        # keep the un-normalized input for the skip connection
        residual = x

        if self.norm is not None:
            x = self.norm(x)

        q = F.relu(self.q(x))
        k = F.relu(self.k(x))
        v = self.v(x)

        attn = q @ k
        if self.dropout is not None:
            attn = self.dropout(attn)

        # refine the attention map with two convolutions of different
        # receptive field, keeping the raw map as a skip path
        attn = attn + self.attn_conv3(attn) + self.attn_conv5(attn)

        y = attn @ v

        if self.act is not None:
            y = self.act(y)

        if self._residual:
            y = y + residual

        return y


# Smoke test: build a 3 -> 10 channel block, report its parameter count,
# run one random 16x16 input through it and display the module together
# with the first three output channels converted by `VF.to_pil_image`.
m = AttentionConvBlock(in_channels=3, out_channels=10, activation="gelu")
print(f"params: {num_module_parameters(m):,}")
inp = torch.rand((1, 3, 16, 16))
outp = m(inp)
print(f"{inp.shape} -> {outp.shape}")
display(m, VF.to_pil_image(outp[0, :3]))

In [None]:
class Conv2dDepth(nn.Module):
    """
    Two stacked convolutions: ``depth_conv`` followed by a 1x1 ``point_conv``.

    All extra positional and keyword arguments are forwarded unchanged to the
    first ``nn.Conv2d``; the second one always maps ``out_channels`` to
    ``out_channels`` with a kernel size of 1.

    NOTE(review): ``depth_conv`` is only a depthwise convolution if the caller
    passes ``groups`` through the extra arguments - confirm this is intended.
    """

    def __init__(self, in_channels: int, out_channels: int, *args, **kwargs):
        super().__init__()

        # creation order matters for parameter initialization and repr
        self.depth_conv = nn.Conv2d(in_channels, out_channels, *args, **kwargs)
        self.point_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``depth_conv`` and then ``point_conv`` to ``x``."""
        return self.point_conv(self.depth_conv(x))

    

# Smoke test for Conv2dDepth. The two extra arguments are nn.Conv2d's
# `kernel_size` and `stride`; no padding is passed.
m = Conv2dDepth(3, 10, kernel_size=3, stride=1)
print(f"params: {num_module_parameters(m):,}")
inp = torch.rand((1, 3, 16, 16))
outp = m(inp)
print(f"{inp.shape} -> {outp.shape}")
display(m, VF.to_pil_image(outp[0, :3]))

In [None]:
def show(v):
    """Display ``v`` multiplied by 100 and converted to ``torch.int``."""
    scaled = v * 100
    display(scaled.to(torch.int))
    
# Compare, on a small seeded tensor, softmax over the last axis only with
# softmax over the flattened last two axes.
v = torch.rand((1, 2, 3, 3), generator=torch.Generator().manual_seed(23))
show(v)
show(F.softmax(v, dim=-1))
show(F.softmax(v.flatten(2), dim=-1).view_as(v))

In [None]:
def clip_module_weights(module: nn.Module, max_magnitude: float, verbose: bool = True):
    """
    Clamp every parameter of ``module`` in place to
    ``[-max_magnitude, max_magnitude]``.

    The update runs under ``torch.no_grad()`` so it is not recorded by autograd.

    :param module: module whose parameters are modified in place
    :param max_magnitude: largest absolute value a parameter entry may keep,
        must be >= 0
    :param verbose: if True (default), print the maximum and the shape of each
        non-empty parameter after clamping
    :raises ValueError: if ``max_magnitude`` is negative
    """
    if max_magnitude < 0:
        # clamp(min > max) would silently overwrite every entry with `max`
        raise ValueError(f"max_magnitude must be >= 0, got {max_magnitude}")

    with torch.no_grad():
        for param in module.parameters():
            # In-place clamp also handles 0-dim parameters, for which the
            # slice assignment `param[:] = ...` raises an IndexError.
            param.clamp_(-max_magnitude, max_magnitude)
            # `max()` is undefined for empty tensors, so skip those
            if verbose and param.numel() > 0:
                print(param.max(), param.shape)

# Clamps, in place, the parameters of whichever module is currently bound to `m`.
clip_module_weights(m, max_magnitude=1.)

In [None]:
# Starting at 28, print `size` and halve it for as long as it is above 4
# and its half is still even.
size = 28
while not (size <= 4 or (size // 2) % 2 != 0):
    print(size)
    size = size // 2
    

In [None]:
from src.models.efficientvit.ops import LiteMLA

# Smoke test for LiteMLA: report its parameter count, run one random
# 16-channel 16x16 input through it and display the module together with
# the first three output channels converted by `VF.to_pil_image`.
m = LiteMLA(16, 32)
print(f"params: {num_module_parameters(m):,}")
inp = torch.rand((1, 16, 16, 16))
outp = m(inp)
print(f"{inp.shape} -> {outp.shape}")
display(m, VF.to_pil_image(outp[0, :3]))