In [18]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from patchify import patchify

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define a simple Vision Transformer model
class VisionTransformer(nn.Module):
    def __init__(self, img_size=256, patch_size=16, num_classes=10, dim=768, depth=6, heads=8, mlp_dim=768, device=None):
        super(VisionTransformer, self).__init__()
        self.device = device   ################################################################################# ADD A DEVICE PARAMETER HERE!!!!!!!!
        
        self.patch_size = patch_size
        assert img_size % patch_size == 0, 'Image size must be divisible by patch size'
        num_patches = (img_size // patch_size) ** 2    #256
        self.patch_dim = 3 * patch_size * patch_size  #768
        self.num_patches = num_patches   #256

        self.embedding = nn.Linear(self.patch_dim, dim)
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, dim))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, heads, mlp_dim, device = self.device),
            num_layers=depth
        )
        self.to_cls_token = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        # Convert image to patches
        # patches = x.unfold(3, self.patch_dim, self.patch_dim).unfold(3, self.patch_dim, self.patch_dim)
        # patches = patches.contiguous().view(x.size(0), self.num_patches, -1)
        
        x = x.to("cpu")
        print (f'x is: {x.shape}')
        image_np = x.permute(1, 2, 0).numpy()  # Convert to (H, W, C)  ## 256, 256, 3
        print (image_np.shape)
        patches = patchify(image_np, (self.patch_size, self.patch_size, image_np.shape[2]), step=self.patch_size)  #
        print (patches.shape)
        patches = patches.reshape(-1, self.patch_size, self.patch_size, image_np.shape[2])  # Flatten patches
        print (patches.shape)
        patches = torch.tensor(patches).permute(0, 3, 1, 2)  # Convert to (num_patches, C, patch_size, patch_size)
        print (patches.shape)
        patches = patches.reshape(-1, self.num_patches, self.patch_dim)  # Flatten patches
        print (patches.shape)
        patches = patches.to(device)
        print (patches.shape)
               
        
        
        # Linear embedding
        x = self.embedding(patches)
        # Add position embeddings
        x += self.position_embeddings
        # Transformer encoding
        print (f'tranformer input is: {x.device}')

        x = self.transformer(x).to(self.device)
        print (f'tranformer output is: {x.device}')
        # Classification token
        x = self.to_cls_token(x.mean(dim=1))
        # MLP head
        return self.mlp_head(x)

# Model parameters
img_size = 256
patch_size = 16
num_classes = 10
dim = 768
depth = 6
heads = 8
mlp_dim = 768

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Instantiate the model and move it to the GPU
model = VisionTransformer(img_size, patch_size, num_classes, dim, depth, heads, mlp_dim, device=device)

# Print the model architecture
print(model)
print (device)
print (model.device)


VisionTransformer(
  (embedding): Linear(in_features=768, out_features=768, bias=True)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=768, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (to_cls_token): Identity()
  (mlp_head): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=10, bias=True)
  )
)
cuda
cuda




In [19]:
# Create dummy input data (batch of images)
# batch_size = 8
img_size = 256
patch_size = 16
num_classes = 10
dim = 768
depth = 2
heads = 8
mlp_dim = 768

model = VisionTransformer(img_size, patch_size, num_classes, dim, depth, heads, mlp_dim).to(device)

T = torch.randn(3, img_size, img_size).to(device)     

# Forward pass
model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient calculation
    output = model(T)
    print("Output shape:", output.shape)



x is: torch.Size([3, 256, 256])
(256, 256, 3)
(16, 16, 1, 16, 16, 3)
(256, 16, 16, 3)
torch.Size([256, 3, 16, 16])
torch.Size([1, 256, 768])
torch.Size([1, 256, 768])
tranformer input is: cuda:0
tranformer output is: cuda:0
Output shape: torch.Size([1, 10])


In [145]:
!pip install torch-summary

Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl.metadata (18 kB)
Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5


In [174]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [15]:
from torchinfo import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img_size = 256
patch_size = 16
num_classes = 10
dim = 768
depth = 6
heads = 8
mlp_dim = 768

model = VisionTransformer(img_size, patch_size, num_classes, dim, depth, heads, mlp_dim).to(device)

model.to(device)

input_size = 3, 256, 256
summary(model, input_size=input_size)

x is: torch.Size([3, 256, 256])
(256, 256, 3)
(16, 16, 1, 16, 16, 3)
(256, 16, 16, 3)
torch.Size([256, 3, 16, 16])
torch.Size([1, 256, 768])
torch.Size([1, 256, 768])


Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             [1, 10]                   196,608
├─Linear: 1-1                                 [1, 256, 768]             590,592
├─TransformerEncoder: 1-2                     [1, 256, 768]             --
│    └─ModuleList: 2-1                        --                        --
│    │    └─TransformerEncoderLayer: 3-1      [1, 256, 768]             3,546,624
│    │    └─TransformerEncoderLayer: 3-2      [1, 256, 768]             3,546,624
│    │    └─TransformerEncoderLayer: 3-3      [1, 256, 768]             3,546,624
│    │    └─TransformerEncoderLayer: 3-4      [1, 256, 768]             3,546,624
│    │    └─TransformerEncoderLayer: 3-5      [1, 256, 768]             3,546,624
│    │    └─TransformerEncoderLayer: 3-6      [1, 256, 768]             3,546,624
├─Identity: 1-3                               [1, 768]                  --
├─Sequential: 1-4                          