## Implement Causal Attention in PyTorch

In [None]:
import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention.

    Each position attends only to itself and to earlier positions along the
    sequence axis, so an output never depends on later (future) positions.

    Args:
        embed_dim: Feature size of the input; also the size of the query,
            key and value projections and the scaling denominator
            ``sqrt(embed_dim)``.
        dropout_rate: Dropout probability applied to the attention weights
            (``nn.Dropout`` is only active in training mode).
    """

    def __init__(self, embed_dim, dropout_rate=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        # Separate learned projections (nn.Linear, with bias) for queries, keys and values
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(..., seq_len, embed_dim)`` -- typically
                ``(batch_size, seq_len, embed_dim)``; an unbatched
                ``(seq_len, embed_dim)`` input is also accepted.

        Returns:
            Tensor with the same shape as ``x``.
        """
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)

        # Scaled dot-product scores of shape (..., seq_len, seq_len); dividing by
        # sqrt(embed_dim) keeps the logits from growing with the feature size.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.embed_dim ** 0.5)

        # The sequence axis is second-to-last, so this works for batched and
        # unbatched inputs alike (x.size(1) would be embed_dim for 2-D input).
        seq_len = x.size(-2)
        # True strictly above the diagonal: key position > query position (the future).
        future_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(diagonal=1)
        # masked_fill keeps the scores' own dtype (e.g. float16) instead of promoting
        # them through addition with a float32 mask; the result is otherwise identical.
        scores = scores.masked_fill(future_mask, float('-inf'))

        # Softmax over the key axis; the diagonal is never masked, so no row is all -inf.
        attention_probs = torch.softmax(scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of the value vectors.
        return torch.matmul(attention_probs, value)

print("CausalAttention class defined successfully.")

CausalAttention class defined successfully.


## Test Causal Attention in PyTorch


In [None]:
import torch
import torch.nn as nn

# 1. Define an embedding dimension, batch size, and sequence length
embed_dim = 64
batch_size = 2
seq_len = 10

# Seed so the random input and the weight initialisation are reproducible on re-run
torch.manual_seed(0)

# 2. Create a random input tensor x
x = torch.randn(batch_size, seq_len, embed_dim)

# 3. Instantiate the CausalAttention class; eval() disables dropout so that
#    repeated forward passes are deterministic and can be compared exactly
causal_attention_model = CausalAttention(embed_dim)
causal_attention_model.eval()

# 4. Pass the input tensor x through the CausalAttention instance
with torch.no_grad():
    output = causal_attention_model(x)

# 5. Print the shape of the output tensor
print(f"Input tensor shape: {x.shape}")
print(f"Output tensor shape: {output.shape}")

# Assert that the output shape is as expected
assert output.shape == (batch_size, seq_len, embed_dim), f"Expected output shape ({batch_size}, {seq_len}, {embed_dim}), but got {output.shape}"
print("Test passed: Output shape is correct.")

# 6. Causality check: replacing the tokens at positions >= cutoff must leave
#    every output before the cutoff unchanged
cutoff = seq_len // 2
x_future_changed = x.clone()
x_future_changed[:, cutoff:] = torch.randn(batch_size, seq_len - cutoff, embed_dim)
with torch.no_grad():
    output_future_changed = causal_attention_model(x_future_changed)
assert torch.allclose(output[:, :cutoff], output_future_changed[:, :cutoff], atol=1e-6), "Outputs before the cutoff depend on future positions"
print("Test passed: Outputs do not depend on future positions.")

Input tensor shape: torch.Size([2, 10, 64])
Output tensor shape: torch.Size([2, 10, 64])
Test passed: Output shape is correct.


## Implement Causal Attention in TensorFlow




In [None]:
import tensorflow as tf
from tensorflow.keras import layers

class CausalAttentionTF(layers.Layer):
    """Single-head causal (masked) self-attention as a Keras layer.

    Each position attends only to itself and to earlier positions along the
    sequence axis (axis -2 of ``inputs``). The query/key/value projections
    are bias-free ``Dense`` layers.

    Args:
        embed_dim: Size of the query/key/value projections; also the scaling
            denominator ``sqrt(embed_dim)``.
        dropout_rate: Dropout rate applied to the attention probabilities.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, embed_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        # Stored so get_config() can rebuild the layer
        self.dropout_rate = dropout_rate
        # Initialize dense layers for query, key, and value transformations
        self.query_proj = layers.Dense(embed_dim, use_bias=False)
        self.key_proj = layers.Dense(embed_dim, use_bias=False)
        self.value_proj = layers.Dense(embed_dim, use_bias=False)
        self.dropout = layers.Dropout(dropout_rate)

    def call(self, inputs, training=None):
        """Apply causal self-attention.

        Args:
            inputs: Tensor assumed to be ``(batch_size, seq_len, features)``;
                the sequence axis is taken as axis -2.
            training: Forwarded to the dropout layer; accepting it explicitly
                lets ``layer(x, training=True)`` actually enable dropout.

        Returns:
            Tensor of shape ``(..., seq_len, embed_dim)``.
        """
        query = self.query_proj(inputs)
        key = self.key_proj(inputs)
        value = self.value_proj(inputs)

        # Cast the scale to the compute dtype so the division also works under a
        # mixed-precision (float16) policy instead of mixing float16 and float32.
        scale = tf.math.sqrt(tf.cast(self.embed_dim, query.dtype))
        scores = tf.matmul(query, key, transpose_b=True) / scale

        seq_len = tf.shape(inputs)[-2]
        # band_part(..., -1, 0) keeps the lower triangle incl. the diagonal: True = may attend
        allowed = tf.cast(tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0), tf.bool)
        # Replace future scores rather than adding (1 - mask) * -1e9: in float16 the
        # constant overflows to -inf and 0 * -inf would produce NaN at allowed positions.
        # The (seq_len, seq_len) mask broadcasts against the batched scores.
        scores = tf.where(allowed, scores, tf.cast(-1e9, scores.dtype))

        attention_probs = tf.nn.softmax(scores, axis=-1)
        attention_probs = self.dropout(attention_probs, training=training)

        # Weighted sum of the value vectors.
        return tf.matmul(attention_probs, value)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({"embed_dim": self.embed_dim, "dropout_rate": self.dropout_rate})
        return config

print("CausalAttentionTF class defined successfully.")

CausalAttentionTF class defined successfully.


## Test Causal Attention in TensorFlow



In [None]:
import tensorflow as tf

# 1. Define an embedding dimension, batch size, and sequence length
embed_dim = 3
batch_size = 1
seq_len = 5

# Seed so the random input and the weight initialisation are reproducible on re-run
tf.random.set_seed(0)

# 2. Create a random input tensor x
x = tf.random.normal((batch_size, seq_len, embed_dim))

# 3. Instantiate the CausalAttentionTF class
causal_attention_tf_model = CausalAttentionTF(embed_dim)

# 4. Pass the input tensor x through the CausalAttentionTF instance;
#    training=False keeps dropout off so two forward passes can be compared exactly
output_tf = causal_attention_tf_model(x, training=False)

# 5. Print the shape of the output tensor
print(f"Input tensor shape (TF): {x.shape}")
print(f"Output tensor shape (TF): {output_tf.shape}")

# Assert that the output shape is as expected
assert output_tf.shape == (batch_size, seq_len, embed_dim), f"Expected output shape ({batch_size}, {seq_len}, {embed_dim}), but got {output_tf.shape}"
print("Test passed: TensorFlow output shape is correct.")

# 6. Causality check: replacing the tokens at positions >= cutoff must leave
#    every output before the cutoff unchanged
cutoff = seq_len // 2
x_future_changed = tf.concat(
    [x[:, :cutoff], tf.random.normal((batch_size, seq_len - cutoff, embed_dim))], axis=1
)
output_tf_future_changed = causal_attention_tf_model(x_future_changed, training=False)
tf.debugging.assert_near(
    output_tf[:, :cutoff],
    output_tf_future_changed[:, :cutoff],
    atol=1e-6,
    message="TensorFlow outputs before the cutoff depend on future positions",
)
print("Test passed: TensorFlow outputs do not depend on future positions.")

Scores Tensor("truediv:0", shape=(1, 5, 5), dtype=float32)
Causal Mask Tensor("mul:0", shape=(5, 5), dtype=float32)
Attention Probs Tensor("Softmax:0", shape=(1, 5, 5), dtype=float32)
Output Tensor("MatMul_1:0", shape=(1, 5, 3), dtype=float32)
Scores tf.Tensor(
[[[-1.2593594   1.2487633  -0.3013413  -1.3586303   0.10356779]
  [ 1.5379086  -0.2092953  -0.32554933 -0.10354766 -0.05004847]
  [-0.539802   -0.2679707   0.2952769   0.48303255 -0.00340527]
  [-0.8458488   0.22361538  0.11110591  0.0240486   0.04574059]
  [ 0.21569961 -0.01592045 -0.05388021 -0.02061073 -0.00497768]]], shape=(1, 5, 5), dtype=float32)
Causal Mask tf.Tensor(
[[-0.e+00 -1.e+09 -1.e+09 -1.e+09 -1.e+09]
 [-0.e+00 -0.e+00 -1.e+09 -1.e+09 -1.e+09]
 [-0.e+00 -0.e+00 -0.e+00 -1.e+09 -1.e+09]
 [-0.e+00 -0.e+00 -0.e+00 -0.e+00 -1.e+09]
 [-0.e+00 -0.e+00 -0.e+00 -0.e+00 -0.e+00]], shape=(5, 5), dtype=float32)
Attention Probs tf.Tensor(
[[[1.         0.         0.         0.         0.        ]
  [0.8515998  0.14840023 0. 