# Week 13 Practice: A simple RNN

Write an RNN from scratch. Finish the forward pass of the network.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three independent
    linear layers. Query/key dot products are scaled by ``embed_dim ** -0.5``,
    normalised with a softmax over the key axis, and used to take a weighted
    sum of the values.
    """

    def __init__(self, embed_dim):
        """
        Build the Q/K/V projections and the score scaling factor.

        Args:
            embed_dim (int): Dimension of input embeddings; also the output
                dimension of every projection.
        """
        super().__init__()

        # Linear projections for Q, K, V (created in this order so that
        # seeded parameter initialisation stays reproducible)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Scaling factor applied to the raw dot-product scores
        self.scaling = embed_dim ** -0.5

    def forward(self, x):
        """
        Apply self-attention over the second-to-last axis of ``x``.

        Args:
            x: Input tensor of shape (..., seq_len, embed_dim). The usual
                (batch_size, seq_len, embed_dim) layout, an unbatched
                (seq_len, embed_dim) input, and inputs with extra leading
                dimensions are all accepted.
        Returns:
            Output tensor with the same shape as ``x``.
        """
        # Project inputs to Q, K, V; each keeps the shape of x
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Compute attention scores
        # (..., seq_len, seq_len) = (..., seq_len, embed_dim) @ (..., embed_dim, seq_len)
        # matmul with negative axis indices is used instead of bmm/transpose(1, 2)
        # so the leading dimensions are not restricted to a single batch axis.
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling

        # Normalise over the key axis so each query's weights sum to 1
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Compute output
        # (..., seq_len, embed_dim) = (..., seq_len, seq_len) @ (..., seq_len, embed_dim)
        output = torch.matmul(attention_weights, v)

        return output

# Example usage
def example_usage():
    """Run one forward pass on random data and print the tensors involved.

    Builds a SingleHeadSelfAttention layer, feeds it a random input, prints
    the input and the output (shape first, then values), and returns the
    output tensor.
    """
    # Demo dimensions: batch size, sequence length, embedding width
    batch_size, seq_len, embed_dim = 2, 4, 8

    # The layer is constructed before the input is sampled
    layer = SingleHeadSelfAttention(embed_dim)
    x = torch.randn(batch_size, seq_len, embed_dim)

    output = layer(x)

    # Report each tensor's shape followed by its contents
    print(f"Input shape: {x.shape}")
    print(x)
    print(f"Output shape: {output.shape}")
    print(output)

    return output

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    example_usage()

Input shape: torch.Size([2, 4, 8])
tensor([[[-0.5433, -0.5050,  0.4409, -0.7617, -1.2403, -0.0746, -0.6283,
          -1.0576],
         [-0.5391,  0.3608,  1.1843,  0.3850,  0.6423,  0.3401,  0.0428,
          -1.4827],
         [ 1.4827,  0.4520,  0.1143,  0.5077,  0.2364,  0.7920,  0.6191,
           1.1689],
         [-1.3448,  0.3344,  2.0104,  0.7628,  2.5714, -0.3169, -0.8335,
           0.3047]],

        [[-1.4552, -1.5145, -1.2081,  0.5568, -0.4187,  1.3677, -0.0801,
          -0.7354],
         [-2.1575, -0.6595, -1.5990,  1.4023, -0.0898, -1.4643,  0.7791,
           0.4537],
         [-0.0755,  1.0204,  1.1901,  0.6441,  0.1853, -1.2675,  0.2914,
           0.1710],
         [ 0.0329, -0.1213, -0.7920, -0.3056, -1.0159,  0.2477,  0.2326,
          -0.5507]]])
Output shape: torch.Size([2, 4, 8])
tensor([[[-0.1493, -0.3819,  0.1333, -0.2392,  0.0547, -0.4591,  0.2525,
          -0.6258],
         [-0.1610, -0.3904,  0.1121, -0.2588,  0.0392, -0.4388,  0.2552,
          -0.60