## Scratch pad for understanding the Segment Anything (SAM) implementation: https://github.com/facebookresearch/segment-anything/tree/main/segment_anything

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Type

In [18]:
# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# The device name is only queried (for CUDA device index 0) on the CUDA path.
if device.type == "cuda":
    print(f'Device name {torch.cuda.get_device_name(0)}')


Using device: cuda
Device name Tesla T4


In [2]:
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> linear.

    The first linear layer maps ``embedding_dim`` features to ``mlp_dim``
    features; the second maps them back to ``embedding_dim``.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.GELU,
    ):
        """Build the two linear layers and instantiate the activation.

        Args:
            embedding_dim: Feature size of the block's input and output.
            mlp_dim: Feature size of the hidden representation.
            activation: Module class (not an instance); it is called with
                no arguments to create the activation layer.
        """
        super().__init__()
        # Expand to the hidden width, then project back to the input width.
        self.linear_1 = nn.Linear(in_features=embedding_dim, out_features=mlp_dim)
        self.linear_2 = nn.Linear(in_features=mlp_dim, out_features=embedding_dim)
        self.activation = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``linear_1``, the activation, then ``linear_2`` to ``x``."""
        hidden = self.linear_1(x)
        hidden = self.activation(hidden)
        return self.linear_2(hidden)
        

In [3]:
class LayerNorm2d(nn.Module):
    """Normalize a tensor over dimension 1 and apply a per-channel affine map.

    The learnable ``weight`` and ``bias`` have length ``num_channels`` and are
    broadcast over two trailing dimensions, so dimension 1 of the input is
    treated as the channel axis (presumably ``(N, C, H, W)`` input, going by
    the class name and the broadcasting pattern).
    """

    def __init__(self, num_channels: int, eps: float = 1e-6):
        """Create the affine parameters.

        Args:
            num_channels: Length of the ``weight`` (initialised to ones) and
                ``bias`` (initialised to zeros) parameters.
            eps: Constant added to the variance before the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` standardized along dimension 1, then scaled and shifted."""
        channel_mean = x.mean(1, keepdim=True)
        centered = x - channel_mean
        # Biased variance: the mean of squared deviations along dimension 1.
        channel_var = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(channel_var + self.eps)
        # Reshape (C,) -> (C, 1, 1) so the parameters broadcast per channel.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift