In [1]:
from ttt import TTTMLP, TTTConfig

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from ttt import TTTConfig

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, patch_size, emb_size, img_size):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.projection = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.projection(x)  # [B, emb_size, H/P, W/P]
        x = x.flatten(2)  # [B, emb_size, N]
        x = x.transpose(1, 2)  # [B, N, emb_size]
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, emb_size]
        x = torch.cat((cls_tokens, x), dim=1)  # [B, N+1, emb_size]
        x = x + self.positions  # [B, N+1, emb_size]
        return x

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, emb_size, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.emb_size // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        qk = (q @ k.transpose(-2, -1)) * (self.emb_size ** -0.5)
        attn = qk.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, self.emb_size)
        return self.projection(out)


class TransformerEncoderLayer(nn.Module):
    def __init__(self, emb_size, num_heads, ttt_config, dropout, layer_idx):
        super(TransformerEncoderLayer, self).__init__()
        self.norm1 = nn.LayerNorm(emb_size)
        self.attn = MultiHeadSelfAttention(emb_size, num_heads)
        self.norm2 = nn.LayerNorm(emb_size)
        self.tttmlp = TTTMLP(ttt_config, layer_idx=layer_idx)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x + self.dropout(self.attn(self.norm1(x)))
        x = x + self.dropout(self.tttmlp(self.norm2(x), position_ids=torch.arange(x.shape[1]).unsqueeze(0).repeat(x.shape[0], 1).to(x.device)))
        return x


class ViT(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, num_classes, emb_size, depth, num_heads, ttt_config, dropout):
        super(ViT, self).__init__()
        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        self.transformer = nn.Sequential(*[
            TransformerEncoderLayer(emb_size, num_heads, ttt_config, dropout, layer_idx=idx) for idx in range(depth)
        ])
        self.norm = nn.LayerNorm(emb_size)
        self.fc = nn.Linear(emb_size, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.transformer(x)
        x = self.norm(x[:, 0])
        x = self.fc(x)
        return x

    

# Parameters
in_channels = 3
num_classes = 1000
emb_size = 768
num_heads = 6
num_layers = 6
patch_size = 16
img_size = 224
ff_hidden_size = 3072
dropout = 0.1
mini_batch_size = 4
batch_size = 2
dropout = 0.1
ttt_config = TTTConfig(hidden_size=768)

# Create the model
model = ViT(img_size=img_size, patch_size=patch_size, in_channels=in_channels, num_classes=num_classes, emb_size=emb_size, depth=num_layers, num_heads=num_heads, ttt_config=ttt_config, dropout=dropout).cuda()


# Example forward pass
dummy_input = torch.randn(batch_size, 3, 224, 224).cuda()
output = model(dummy_input)
print(output.shape)  # Should be [1, 10]


torch.Size([2, 1000])


In [3]:
total_size = sum(p.numel() * p.element_size() for p in model.parameters())

In [4]:
total_size * 1e-6

123.74224

In [36]:
tttconfig = TTTConfig()

In [37]:
ttt_mlp = TTTMLP(tttconfig)

Instantiating TTTMLP without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` when creating this class.
