In [1]:
import torch
from mobilevit import MobileViT
dims=[144,192,240]
channels=[16,32,64,64,96,96,128,128,160,160,640]
model=MobileViT(channels=channels,dims=dims,num_classes=13).to(torch.float16).cuda()
A = torch.rand(1,3,100,224,224,dtype=torch.float16,device='cuda')
model(A)


Lk:16


: 

In [6]:
#count the number of parameters in the model
total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters in the model:", total_params)

Total number of parameters in the model: 7771152


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attention_triton import attention as flash_attention_kernel
import torch.optim as optim


# Define the model classes
class FlashAttentionBlock(nn.Module):
    def __init__(self, d_model, n_heads, flash_attention_kernel):
        super(FlashAttentionBlock, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        assert self.head_dim * n_heads == d_model, "d_model must be divisible by n_heads"

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

        self.scale = self.head_dim ** -0.5
        self.flash_attention_kernel = flash_attention_kernel

    def forward(self, x):
        B, N, D = x.shape

        # Linear transformations
        Q = self.q_linear(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_linear(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_linear(x).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        # FlashAttention using Triton kernel
        attn_output = self.flash_attention_kernel(Q, K, V,False,self.scale)

        # Concatenate heads and put through final linear layer
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, N, D)
        output = self.out_linear(attn_output)

        return output

class ToyModel(nn.Module):
    def __init__(self, d_model, n_heads, flash_attention_kernel):
        super(ToyModel, self).__init__()
        self.attention_block = FlashAttentionBlock(d_model, n_heads, flash_attention_kernel)
        self.relu = nn.ReLU()
        self.output_layer = nn.Linear(d_model, 1)  # Output layer for regression

    def forward(self, x):
        x = self.attention_block(x)
        x = self.relu(x)
        x = x.mean(dim=1)  # Global average pooling across the sequence length
        x = x.view(x.size(0), -1)  # Ensure the correct shape for the output layer
        x = self.output_layer(x)
        return x

# Instantiate the model
d_model = 64
n_heads = 4
model = ToyModel(d_model, n_heads, flash_attention_kernel).to(torch.float16).cuda()

# Generate some dummy data for training
torch.manual_seed(42)
batch_size = 16
seq_length = 256 
num_epochs = 5

X_train = torch.randn(batch_size, seq_length, d_model,dtype=torch.float16).cuda()
y_train = torch.randn(batch_size, 1,dtype=torch.float16).cuda()

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(num_epochs):
    model.train()

    # Forward pass
    outputs = model(X_train)
    loss = criterion(outputs, y_train)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Test the model with dummy input
model.eval()
with torch.no_grad():
    test_input = torch.randn(2, seq_length, d_model).to('cuda').to(torch.float16)  # Batch size of 2, sequence length of 10, model dimension of 64
    test_output = model(test_input)
    print("Test output:", test_output)


Epoch [1/5], Loss: 0.3569
Epoch [2/5], Loss: nan
Epoch [3/5], Loss: nan
Epoch [4/5], Loss: nan
Epoch [5/5], Loss: nan
Test output: tensor([[nan],
        [nan]], device='cuda:0', dtype=torch.float16)
