In [4]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import Tuple

## Converting Image to a Sequence of Patches

In [5]:
class PatchEmbeddings(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A strided convolution whose kernel size and stride both equal `patch_size`
    embeds every non-overlapping `patch_size x patch_size` patch separately.

    Args:
        img_size: Side length used to compute `num_patches`.
        patch_size: Side length of each square patch.
        hidden_dim: Size of each patch embedding (convolution output channels).
        in_channels: Number of channels of the input image (3 for RGB).
    """

    def __init__(
        self,
        img_size: int = 96,
        patch_size: int = 16,
        hidden_dim: int = 512,
        in_channels: int = 3,
    ) -> None:
        super().__init__()
        # Store the input image size, the patch size, the hidden dimension
        # and the number of input channels
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.in_channels = in_channels

        # Calculate the total number of patches
        self.num_patches = (self.img_size // self.patch_size) ** 2

        # Create a convolution to extract patch embeddings
        # in_channels defaults to 3, i.e. an RGB image
        # out_channels=hidden_dim sets the number of output channels to match the hidden dimension
        # kernel_size=patch_size and stride=patch_size ensure each patch is embedded separately
        self.conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.hidden_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Embed the patches of `X`.

        Args:
            X: Image batch of shape (batch_size, in_channels, height, width).

        Returns:
            Tensor of shape (batch_size, number of patches, hidden_dim).
        """
        # Extract patch embeddings from the input image
        # Output shape: (batch_size, hidden_dim, height // patch_size, width // patch_size)
        X = self.conv(X)

        # Flatten the two spatial dimensions into a single patch dimension
        # Output shape: (batch_size, hidden_dim, num_patches)
        X = X.flatten(2)

        # Bring the patch dimension to the second position
        # Output shape: (batch_size, num_patches, hidden_dim)
        X = X.transpose(1, 2)

        return X

In [6]:
# Dummy input dimensions: batch size, channels, height, width
B, C = 1, 3
H = W = 96

# Patch side length and embedding width
patch_size, hidden_dim = 16, 512

# Random image batch (drawn before the layer is built, as in the original cell)
X = torch.randn(size=(B, C, H, W))

patch_embeddings = PatchEmbeddings(H, patch_size, hidden_dim)
patches = patch_embeddings(X)
print(f"Shape of image patches: {patches.shape}")

Shape of image patches: torch.Size([1, 36, 512])


In [7]:
# Expected sequence length: patches per side, squared
num_patches = (H // patch_size) ** 2
# The embedded sequence must be (batch, patches, hidden)
assert tuple(patches.shape) == (B, num_patches, hidden_dim), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Attention Mechanism 
Attention Mechanism across both the vision encoder and language decoder

### The implementation of the Attention Head

In [8]:
class Head(nn.Module):
    """A single head of scaled dot-product self-attention.

    Args:
        n_embed: Size of the last dimension of the input.
        head_size: Size of the key, query and value projections.
        dropout: Dropout probability applied to the attention probabilities.
        is_decoder: If True, a causal mask stops each position from attending
            to later positions.
    """

    def __init__(
        self,
        n_embed: int,
        head_size: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # Linear layer for Key projection
        self.key = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Linear layer for Query projection
        self.query = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Linear layer for Value projection
        self.value = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Dropout layer for regularization to prevent overfitting
        self.dropout = nn.Dropout(p=dropout)

        # Flag indicating whether the head is used as a decoder
        self.is_decoder = is_decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the sequence dimension of `x`.

        Args:
            x: Tensor of shape (B, T, n_embed).

        Returns:
            Tensor of shape (B, T, head_size).
        """
        # Sequence length, needed to build the causal mask
        T = x.shape[-2]

        # Compute Key, Query, and Value projections
        k = self.key(x)  # Shape: (B, T, head_size)
        q = self.query(x)  # Shape: (B, T, head_size)
        v = self.value(x)  # Shape: (B, T, head_size)

        # Scale the query-key dot products by 1 / sqrt(head_size): the dot product
        # sums over head_size terms, so the key width (not the embedding width)
        # is what keeps the scores at a stable magnitude.
        scale = k.shape[-1] ** -0.5
        wei = (q @ k.transpose(-2, -1)) * scale  # Shape: (B, T, T)

        if self.is_decoder:
            # Causal mask: positions above the diagonal are in the future
            tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            wei = wei.masked_fill(mask=~tril, value=float("-inf"))

        # Softmax over the key axis; each row sums to 1
        wei = F.softmax(input=wei, dim=-1)  # Shape: (B, T, T)

        # Apply Dropout to the attention probabilities for regularization
        wei = self.dropout(wei)

        # Weighted aggregation of the values
        out = wei @ v  # Shape: (B, T, head_size)

        return out

In [9]:
# Unpack batch size, sequence length and embedding dimension of the patch sequence
B, T, C = patches.size()
# Size of the attention head
head_size = 16

head = Head(C, head_size)
output = head(patches)
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 16])


In [10]:
# The head keeps the batch and sequence dimensions and emits `head_size` features
assert tuple(output.shape) == (B, T, head_size), "Output shape is incorrect"
print("Test passed!")

Test passed!


### The implementation of Multihead Attention

In [11]:
class MultiHeadAttention(nn.Module):
    """Run several attention heads side by side and merge their outputs.

    The embedding is split evenly: each of the `num_heads` heads produces
    `n_embed // num_heads` features, the results are concatenated back to
    `n_embed` features, linearly projected, and passed through dropout.
    """

    def __init__(
        self,
        n_embed: int,
        num_heads: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # The heads must tile the embedding dimension exactly
        assert n_embed % num_heads == 0, "n_embed must be divisible by num_heads!"
        per_head_dim = n_embed // num_heads

        # Build the heads one at a time, then register them as sub-modules
        attention_heads = []
        for _ in range(num_heads):
            attention_heads.append(
                Head(
                    n_embed=n_embed,
                    head_size=per_head_dim,
                    dropout=dropout,
                    is_decoder=is_decoder,
                )
            )
        self.heads = nn.ModuleList(attention_heads)

        # Output projection applied to the concatenated heads
        self.proj = nn.Linear(n_embed, n_embed)

        # Dropout on the projected output
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One (B, T, head_size) tensor per head
        per_head = tuple(head(x) for head in self.heads)

        # Stitch the heads together along the feature axis -> (B, T, n_embed)
        merged = torch.cat(per_head, dim=-1)

        # Project, then regularize -> (B, T, n_embed)
        return self.dropout(self.proj(merged))

In [12]:
# Two heads over the patch embedding width, with light dropout
num_heads, dropout = 2, 0.1
mha = MultiHeadAttention(C, num_heads, dropout=dropout)

In [13]:
# Run multi-head attention over the patch sequence
output = mha(x=patches)
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 512])


In [14]:
# Multi-head attention must preserve the (batch, sequence, embedding) shape
assert tuple(output.shape) == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Test passed!


### The Multilayer Perceptron

In [15]:
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, activate, project back, dropout.

    The activation is ReLU when `is_decoder` is True and GELU otherwise.
    """

    def __init__(
        self, n_embed: int, dropout: float = 0.1, is_decoder: bool = False
    ) -> None:
        super().__init__()

        # Width of the hidden layer
        inner_dim = 4 * n_embed

        # n_embed -> 4 * n_embed
        expand = nn.Linear(n_embed, inner_dim)

        # Non-linearity depends on where the MLP is used
        if is_decoder:
            activation = nn.ReLU()
        else:
            activation = nn.GELU()

        # 4 * n_embed -> n_embed
        contract = nn.Linear(inner_dim, n_embed)

        # Chain the stages; dropout comes last
        self.net = nn.Sequential(expand, activation, contract, nn.Dropout(dropout))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Feed the input through every stage in order
        return self.net(x)

In [16]:
# Encoder-style MLP (GELU) over the patch embedding width
dropout = 0.1
mlp = MLP(C, dropout)

In [17]:
# Feed the multi-head attention output through the MLP.
# The input is recomputed from `patches` rather than read from the previous
# `output`, so re-running this cell does not stack the MLP on its own result.
output = mlp(mha(patches))
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 512])


In [18]:
# The MLP must return the same (batch, sequence, embedding) shape it was given
assert tuple(output.shape) == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Test passed!


### Transformer Blocks

In [19]:
class Block(nn.Module):
    """Pre-norm transformer block: self-attention followed by a feed-forward network.

    Each sub-layer reads a layer-normalized copy of the residual stream and its
    output is added back onto the un-normalized stream:

        x = x + mhattn(ln1(x))
        x = x + ffn(ln2(x))

    Args:
        n_embed: Embedding dimension of the input and output.
        num_heads: Number of attention heads.
        dropout: Dropout probability passed to the attention module
            (the feed-forward network here has no dropout).
        is_decoder: If True, the attention heads use a causal mask.
    """

    def __init__(
        self,
        n_embed: int,
        num_heads: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # Layer normalization for the input to the attention layer
        self.ln1 = nn.LayerNorm(normalized_shape=n_embed)

        # Multi-head attention module
        self.mhattn = MultiHeadAttention(
            n_embed=n_embed, num_heads=num_heads, dropout=dropout, is_decoder=is_decoder
        )

        # Layer normalization for the input to the FFN
        self.ln2 = nn.LayerNorm(normalized_shape=n_embed)

        # Feed-forward neural network (FFN)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=n_embed, out_features=4 * n_embed),
            nn.GELU(),  # Activation function
            nn.Linear(
                in_features=4 * n_embed, out_features=n_embed
            ),  # Projection back to the original dimension
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to `x` of shape (B, T, n_embed); the shape is preserved."""
        # Attention sub-layer: normalize, attend, add back onto the residual stream
        x = x + self.mhattn(self.ln1(x))

        # Feed-forward sub-layer: the residual is the un-normalized stream, not
        # the output of ln2, so the skip connection stays an identity path
        x = x + self.ffn(self.ln2(x))

        return x

In [20]:
# Encoder-style transformer block over the patch embedding width
num_heads, dropout = 2, 0.1
block = Block(C, num_heads, dropout=dropout)

In [21]:
# Run the transformer block over the patch sequence
output = block(x=patches)
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 512])


In [22]:
# A transformer block must preserve the (batch, sequence, embedding) shape
assert tuple(output.shape) == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Vision Encoder - Vision Transformer (ViT)

Combining the patchification logic and the attention blocks into a ViT.

In [23]:
class ViT(nn.Module):
    """Vision Transformer encoder that summarises an image as its final `[CLS]` embedding.

    Pipeline: patch embeddings -> prepend `[CLS]` token -> add learnable position
    embeddings -> dropout -> transformer blocks -> layer norm of the `[CLS]` slot.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        num_hiddens: int,
        num_heads: int,
        num_blocks: int,
        emb_dropout: float,
        block_dropout: float,
    ) -> None:
        super().__init__()

        # Image -> sequence of patch embeddings
        self.patch_embedding = PatchEmbeddings(
            img_size=img_size, patch_size=patch_size, hidden_dim=num_hiddens
        )

        # Learnable classification token, zero-initialised
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))

        # One position per patch plus one for the classification token
        seq_len = (img_size // patch_size) ** 2 + 1
        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len, num_hiddens))

        # Dropout applied to the summed embeddings
        self.dropout = nn.Dropout(emb_dropout)

        # Non-causal transformer blocks, built one at a time
        encoder_blocks = []
        for _ in range(num_blocks):
            encoder_blocks.append(
                Block(
                    n_embed=num_hiddens,
                    num_heads=num_heads,
                    dropout=block_dropout,
                    is_decoder=False,
                )
            )
        self.blocks = nn.ModuleList(encoder_blocks)

        # Normalization of the final representation
        self.layer_norm = nn.LayerNorm(num_hiddens)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # (B, num_patches, num_hiddens)
        patch_seq = self.patch_embedding(X)

        # Broadcast the classification token across the batch: (B, 1, num_hiddens)
        batch_size = patch_seq.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # Prepend it to the patches: (B, num_patches + 1, num_hiddens)
        tokens = torch.cat([cls_tokens, patch_seq], dim=1)

        # Inject position information, then regularize
        tokens = self.dropout(tokens + self.pos_embedding)

        # Run the encoder stack; the shape is unchanged
        for encoder_block in self.blocks:
            tokens = encoder_block(tokens)

        # Keep only the `[CLS]` slot and normalize it: (B, num_hiddens)
        cls_final = tokens[:, 0]
        return self.layer_norm(cls_final)

In [24]:
# Dummy image batch: batch size, channels, height, width
B, C = 2, 3
H = W = 96
X = torch.randn(size=(B, C, H, W))

# Small ViT: 16x16 patches, 64-d embeddings, 2 heads, 2 blocks
vit = ViT(
    img_size=H,
    patch_size=16,
    num_hiddens=64,
    num_heads=2,
    num_blocks=2,
    emb_dropout=0.1,
    block_dropout=0.1,
)

In [25]:
# Encode the image batch into one embedding per image
output = vit(X=X)
print(f"Output shape: {output.shape}")

Output shape: torch.Size([2, 64])


In [26]:
# One 64-d `[CLS]` embedding per image in the batch
assert tuple(output.shape) == (B, 64), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Vision-Language Projection Module

Unfortunately, we can't directly concatenate the ViT output with the text embeddings. <br>
We first need to project it from the dimensionality of the vision transformer's image embeddings to the dimensionality of the text embeddings.

Why MLP for this part? If you want to train VLM with low resources you can do so by keeping both the pretrained vision encoder and language decoder frozen during the VLM training. Therefore, allocating more parameters to the connection module could enhance the overall VLM's ability to generalize and help in the downstream instruction-tuning process.

In [27]:
class MultiModalProjector(nn.Module):
    """MLP that maps image embeddings into the text embedding space.

    Layout: Linear(img_embed_dim -> 4 * img_embed_dim), GELU,
    Linear(4 * img_embed_dim -> n_embed), Dropout.
    """

    def __init__(
        self,
        n_embed: int,
        img_embed_dim: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # Hidden width of the projector
        widened_dim = 4 * img_embed_dim

        stages = [
            # Widen the image embedding
            nn.Linear(img_embed_dim, widened_dim),
            # Non-linearity
            nn.GELU(),
            # Land in the text embedding dimension
            nn.Linear(widened_dim, n_embed),
            # Regularization
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, img_embed_dim) --> (B, n_embed)
        return self.net(x)

In [28]:
# Batch size, text embedding width, image embedding width
B = 2
n_embed, img_embed_dim = 64, 128

# Dummy image embeddings
X = torch.randn(B, img_embed_dim)

projector = MultiModalProjector(n_embed, img_embed_dim, dropout=0.1)

In [29]:
# Project the image embeddings into the text embedding space
output = projector(x=X)
print(f"Output shape: {output.shape}")

Output shape: torch.Size([2, 64])


In [30]:
# The projector must emit one `n_embed`-wide vector per input row
assert tuple(output.shape) == (B, n_embed), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Building the Decoder Language Model

The only thing that deviates from the original implementation is that the projection module is integrated into the decoder model class here. <br>
In contrast, when using pretrained models with HuggingFace (or any other library), you can directly feed embeddings as input to the model.

In [31]:
class DecoderLanguageModel(nn.Module):
    def __init__(
        self,
        n_embed: int,
        img_embed_dim: int,
        vocab_size: int,
        num_heads: int,
        n_layer: int,
        use_images: bool = False,
    ) -> None:
        super().__init__()

        self.use_images = use_images

        # Token embedding table
        self.token_embedding_table = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=n_embed
        )

        # Position embedding table
        self.position_embedding_table = nn.Embedding(
            num_embeddings=1000, embedding_dim=n_embed
        )

        if use_images:
            # Image projection layer to align image embeddings with text embeddings
            self.image_projection = MultiModalProjector(
                n_embed=n_embed, img_embed_dim=img_embed_dim
            )

        # Stack of transformer decoder blocks
        self.blocks = nn.Sequential(
            *[
                Block(n_embed=n_embed, num_heads=num_heads, is_decoder=True)
                for _ in range(n_layer)
            ]
        )

        # Final layer normalization
        self.ln_f = nn.LayerNorm(normalized_shape=n_embed)

        # Language modeling head
        self.lm_head = nn.Linear(in_features=n_embed, out_features=vocab_size)

    def forward(
        self,
        idx: torch.Tensor,
        img_embeds: torch.Tensor = None,
        targets: torch.Tensor = None,
    ) -> torch.Tensor:
        # Get token embeddings from the input indices
        tok_emb = self.token_embedding_table(idx)

        if self.use_images:
            # Project and concatenate image embeddings with token embeddings
            img_emb = self.image_projection(img_embeds).unsqueeze(1)
            tok_emb = torch.cat([img_emb, tok_emb], dim=1)

        # Get position embeddings
        pos_emb = self.position_embedding_table(
            torch.arange(tok_emb.size(1), device=idx.device)
        )

        # Add position embeddings to token embeddings
        x = tok_emb + pos_emb

        # Pass through the transformer decoder blocks
        x = self.blocks(x)

        # Apply final layer normalization
        x = self.ln_f(x)

        # Get the logits from the language modeling head
        logits = self.lm_head(x)

        if targets is not None:
            if self.use_images and img_embeds is not None:
                # Prepare targets by concatenating a dummy target for the image embedding
                batch_size = idx.size(0)
                targets = torch.cat(
                    [
                        torch.full(
                            (batch_size, 1), -100, dtype=torch.long, device=idx.device
                        ),
                        targets,
                    ],
                    dim=1,
                )

            # Compute the cross-entropy loss
            loss = F.cross_entropy(
                input=logits.view(-1, logits.size(-1)),
                target=targets.view(-1),
                ignore_index=-100,
            )
            return logits, loss

        return logits

    def generate(
        self, idx: torch.Tensor, img_embeds: torch.Tensor, max_new_tokens: int
    ) -> torch.Tensor:
        # Get the batch size and sequence length
        B, T = idx.shape

        # Initialize the generated sequence with the input indices
        generated = idx
        tok_emb = self.token_embedding_table(idx)

        if self.use_images and img_embeds is not None:
            # Project and concatenate image embeddings with token embeddings
            img_emb = self.image_projection(img_embeds).unsqueeze(1)
            current_output = torch.cat([img_emb, tok_emb], dim=1)
        else:
            current_output = tok_emb

        # Generate new tokens iteratevely
        for i in range(max_new_tokens):
            # Get the current sequence length
            T_current = current_output.shape[1]

            # Get position embeddings for the current sequence length
            current_pos_emb = self.position_embedding_table(
                torch.arange(T_current, device=idx.device)
            ).unsqueeze(0)

            # Add position embeddings to the current output
            current_output += current_pos_emb

            # Pass through the transformer decoder blocks
            for block in self.blocks:
                current_output = block(current_output)

            # Get the logits for the last token
            logits = self.lm_head(current_output[:, -1, :])

            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample the next token based on the probability
            idx_next = torch.multinomial(input=probs, num_samples=1)

            # Concatenate the generated token to the generated sequence
            generated = torch.cat([generated, idx_next], dim=1)

            # Get the embeddings for the generated token
            idx_next_emb = self.token_embedding_table(idx_next)

            # Concatenate the generated token embeddings to the current output
            current_output = torch.cat([current_output, idx_next_emb], dim=1)

        return generated

In [32]:
# Pick the best available accelerator: CUDA GPU, then Apple-Silicon MPS, else CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

device(type='mps')

Testing

In [33]:
n_embed, img_embed_dim, vocab_size, num_heads, n_layer = 128, 256, 1000, 8, 6
# `n_layer` is the number of decoder transformer blocks; `num_blocks` is used for the vision encoder to avoid confusion
model = DecoderLanguageModel(
    n_embed=n_embed,
    img_embed_dim=img_embed_dim,
    vocab_size=vocab_size,
    num_heads=num_heads,
    n_layer=n_layer,
    use_images=True,
)
model.to(device)


# Dummy input
B, T = 10, 50
idx = torch.randint(low=0, high=vocab_size, size=(B, T)).to(device)
# The width must match the `img_embed_dim` the model was built with
image_embeds = torch.randn(B, img_embed_dim).to(device)

targets = torch.randint(0, vocab_size, (B, T)).to(
    device
)  # Only if you want to compute loss

# Test forward pass
# Check if you need to calculate loss by providing targets
if targets is not None:
    logits, loss = model(idx, image_embeds, targets)
    print(f"Logits shape: {logits.shape}, Loss: {loss}")
else:
    logits = model(idx, image_embeds)  # Call without targets
    print(f"Logits shape: {logits.shape}")

# Test generation
generated = model.generate(idx, image_embeds, max_new_tokens=20)
print(f"Generated sequence shape: {generated.shape}")

Logits shape: torch.Size([10, 51, 1000]), Loss: 7.145832061767578
Generated sequence shape: torch.Size([10, 70])


## Putting everything together: Simple Vision Language Model

In [34]:
class VisionLanguageModel(nn.Module):
    def __init__(
        self,
        n_embed: int,
        img_embed_dim: int,
        vocab_size: int,
        n_layer: int,
        img_size: int,
        patch_size: int,
        num_heads: int,
        num_blocks: int,
        emb_dropout: float,
        block_dropout: float,
    ) -> None:
        super().__init__()

        # Set num_hiddens equal to img_embed_dim
        num_hiddens = img_embed_dim

        # Assert that num_hiddens is divisible by num_heads
        assert num_hiddens % num_heads == 0, ValueError(
            "num_hiddens must be divisible by num_heads!"
        )

        # Initialize the Vision Transformer (ViT) encoder
        self.vision_encoder = ViT(
            img_size=img_size,
            patch_size=patch_size,
            num_hiddens=num_hiddens,
            num_heads=num_heads,
            num_blocks=num_blocks,
            emb_dropout=emb_dropout,
            block_dropout=block_dropout,
        )

        # Initialize the Language Model Decoder (DecoderLanguageModel)
        self.decoder = DecoderLanguageModel(
            n_embed=n_embed,
            img_embed_dim=img_embed_dim,
            vocab_size=vocab_size,
            num_heads=num_heads,
            n_layer=n_layer,
            use_images=True,
        )

    def _check_image_embeddings(self, image_embeds: torch.Tensor) -> None:
        """Chek if image embeddings are valid."""
        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
            raise ValueError(
                "Something is wrong with the ViT model. It's returning an empty tensor or the embedding dimension is empty."
            )

    def forward(
        self, img_array: torch.Tensor, idx: torch.Tensor, targets: torch.Tensor = None
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        # Get the image embeddings from the Vision Encoder
        image_embeds = self.vision_encoder(img_array)

        # Check if image embeddings are valid
        self._check_image_embeddings(image_embeds)

        if targets is not None:
            # If targets are provided, compute the logits and loss
            logits, loss = self.decoder(idx, image_embeds, targets)
            return logits, loss
        else:
            # If targets are not provided, compute only the logits
            logits = self.decoder(idx, image_embeds)
            return logits

    def generate(
        self, img_array: torch.Tensor, idx: torch.Tensor, max_new_tokens: int
    ) -> torch.Tensor:
        # Get the image embeddings from the Vision Encoder
        image_embeds = self.vision_encoder(img_array)

        # Check if image embeddings are valid
        self._check_image_embeddings(image_embeds)

        # Generate new tokens using the Language Model Decoder
        generated_tokens = self.decoder.generate(
            idx=idx, img_embeds=image_embeds, max_new_tokens=max_new_tokens
        )
        return generated_tokens

Testing

In [36]:
# Pick the best available accelerator: CUDA GPU, then Apple-Silicon MPS, else CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Running on device: {device}")

Running on device: mps


In [38]:
# Decoder hyper-parameters (`block_size` is the length of the dummy text prompt)
n_embed, vocab_size, n_layer, block_size = 128, 1000, 8, 32

# Vision encoder hyper-parameters; the image embedding width equals the ViT hidden size
num_hiddens, num_heads, num_blocks = 512, 8, 2
img_size, patch_size = 96, 16
image_embed_dim = num_hiddens

# Initialize the model
model = VisionLanguageModel(
    n_embed=n_embed,
    img_embed_dim=image_embed_dim,
    vocab_size=vocab_size,
    n_layer=n_layer,
    img_size=img_size,
    patch_size=patch_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    emb_dropout=0.1,
    block_dropout=0.1,
)
model.to(device)

# Create dummy data with correct dimensions
# One RGB image of the configured size
dummy_img = torch.randn(1, 3, img_size, img_size).to(device)
# One prompt of `block_size` random token ids
dummy_idx = torch.randint(0, vocab_size, (1, block_size)).to(device)

# Forward pass to initialize all parameters
try:
    output = model(dummy_img, dummy_idx)  # Output for debugging
    print("Output from initialization forward pass:", output)
except RuntimeError as e:
    print(f"Runtime Error during forward pass: {str(e)}")
    print("Check layer configurations and input shapes.")

Output from initialization forward pass: tensor([[[-0.8860, -0.0301,  0.1804,  ...,  0.4420,  0.0458, -1.0360],
         [-0.3543,  0.1468,  0.0166,  ...,  0.8835,  0.1615, -0.4509],
         [-0.3524,  0.2126,  1.6317,  ...,  1.0799,  0.1788, -0.4753],
         ...,
         [ 0.3872,  0.4607,  0.5435,  ..., -0.0981, -0.8271, -0.9599],
         [-0.3231, -0.1148, -0.0585,  ...,  1.1575, -0.6290, -0.6378],
         [ 0.0722, -0.1574, -0.3214,  ...,  0.2728, -0.7626,  0.2860]]],
       device='mps:0', grad_fn=<LinearBackward0>)
