# 02 - Advanced Tensor Operations for Text Processing

**Duration:** 1-2 hours | **Difficulty:** Beginner

## 🎯 Learning Objectives
- Master tensor indexing, slicing, and reshaping for NLP
- Understand broadcasting in text processing contexts
- Learn batch processing patterns for variable-length sequences
- Implement memory-efficient operations

## 📚 Contents
1. Advanced Indexing and Masking
2. Tensor Reshaping for NLP
3. Broadcasting Patterns
4. Batch Processing
5. Memory Efficiency
6. Exercise: Attention Mechanism

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Seed PyTorch's global RNG so the random tensors generated below are reproducible.
torch.manual_seed(42)

# Prefer a CUDA GPU when one is present, otherwise fall back to the CPU.
# NOTE: the example cells below create their tensors without moving them onto `device`.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

## 1. Advanced Indexing and Masking

Indexing and masking are essential for attention mechanisms and for processing variable-length sequences.

In [None]:
# Toy batch of "tokenized sentences": 3 rows of 6 random ids from a 10-token vocabulary.
batch_size, seq_length, vocab_size = 3, 6, 10
token_ids = torch.randint(0, vocab_size, (batch_size, seq_length))
print(f"Token IDs shape: {token_ids.shape}")
print(f"Token IDs:\n{token_ids}")

# Indexing: one whole row, one column across every row, and a trailing slice.
first_sentence = token_ids[0]                 # (seq_length,)
first_token_per_sentence = token_ids[:, 0]    # (batch_size,)
last_three_tokens = token_ids[:, -3:]         # (batch_size, 3)

print("\n=== Indexing Examples ===")
print(f"First sentence: {first_sentence}")
print(f"First token of each sentence: {first_token_per_sentence}")
print(f"Last 3 tokens: {last_three_tokens}")

# Boolean masking: the comparison produces a bool tensor of the same shape, and
# indexing with it returns a flat 1-D tensor of the selected values.
mask = token_ids > 5
high_tokens = token_ids[mask]
print(f"\nTokens > 5: {high_tokens}")

# Lower-triangular ones: row i keeps columns 0..i only, so no position sees a later one.
causal_mask = torch.ones(seq_length, seq_length).tril()
print(f"\nCausal mask (prevents future attention):\n{causal_mask}")

In [None]:
# Pull individual time steps out of a (batch, seq, hidden) tensor.
batch_size, seq_len, hidden_dim = 2, 5, 4
hidden_states = torch.randn(batch_size, seq_len, hidden_dim)

# Variable-length batch: the last real token of row b sits at index seq_lengths[b] - 1.
seq_lengths = torch.tensor([3, 5])  # Actual lengths
batch_indices = torch.arange(batch_size)
last_indices = seq_lengths - 1

# Two paired integer index tensors pick one (row, position) per batch element.
last_hidden = hidden_states[batch_indices, last_indices]   # (batch, hidden)
print(f"Last hidden states shape: {last_hidden.shape}")
print(f"Last hidden states:\n{last_hidden}")

# gather along dim=1 needs an index with the output's full shape, so each
# position is expanded across the hidden dimension: (batch, k) -> (batch, k, hidden).
positions = torch.tensor([[0, 2], [1, 4]])
gather_index = positions.unsqueeze(-1).expand(-1, -1, hidden_dim)
gathered = hidden_states.gather(1, gather_index)
print(f"\nGathered states shape: {gathered.shape}")

## 2. Tensor Reshaping for NLP

Different layers expect differently shaped inputs; these patterns convert between the most common layouts.

In [None]:
# Common reshaping patterns on a (batch, seq_len, hidden_dim) activation tensor.
batch_size, seq_len, hidden_dim = 2, 4, 8
x = torch.randn(batch_size, seq_len, hidden_dim)
print(f"Original shape: {x.shape}")

# Flatten for a fully connected layer: (B, S, D) -> (B, S * D); -1 lets view infer S * D.
x_flat = x.view(batch_size, -1)
print(f"Flattened: {x_flat.shape}")

# Multi-head attention reshaping: split the hidden dimension into num_heads equal chunks.
num_heads = 2
if hidden_dim % num_heads != 0:
    raise ValueError(
        f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
    )
head_dim = hidden_dim // num_heads
x_multihead = x.view(batch_size, seq_len, num_heads, head_dim)  # (B, S, H, D_h)
print(f"Multi-head: {x_multihead.shape}")

# Transpose for attention computation
x_transposed = x_multihead.transpose(1, 2)  # (B, H, S, D_h)
print(f"Transposed: {x_transposed.shape}")

# Views vs copies, part 1: view() only rewrites shape/stride metadata,
# so the result shares x's storage.
print(f"\nOriginal data pointer: {x.data_ptr()}")
print(f"View data pointer: {x_flat.data_ptr()}")
print(f"Same memory? {x.data_ptr() == x_flat.data_ptr()}")

# Views vs copies, part 2: transpose() is also a view but leaves a non-contiguous
# layout; .contiguous() then has to allocate a fresh, reordered copy. This is why
# attention code calls .contiguous() before .view() when merging heads back.
x_copy = x_transposed.contiguous()
print(f"Transposed is contiguous? {x_transposed.is_contiguous()}")
print(f"Copy data pointer: {x_copy.data_ptr()}")
print(f"Copy shares memory? {x.data_ptr() == x_copy.data_ptr()}")

## 3. Broadcasting Patterns

Essential for efficient operations between tensors of different shapes.

In [None]:
# Broadcasting aligns shapes from the right and stretches size-1 / missing axes.
batch_size, seq_len, vocab_size = 3, 5, 8
logits = torch.randn(batch_size, seq_len, vocab_size)

# (vocab,) against (batch, seq, vocab): one bias per vocabulary entry, shared by
# every batch element and every position.
vocab_bias = torch.randn(vocab_size)
biased_logits = logits + vocab_bias
print(f"Logits: {logits.shape} + Bias: {vocab_bias.shape} = {biased_logits.shape}")

# (seq, vocab) against (batch, seq, vocab): the missing leading batch axis is implied,
# so the same per-position row is added to every batch element.
pos_embeddings = torch.randn(seq_len, vocab_size)
logits_with_pos = logits + pos_embeddings
print(f"With position: {logits.shape} + {pos_embeddings.shape} = {logits_with_pos.shape}")

# A (batch, seq) 0/1 mask needs a trailing singleton axis to line up with
# (batch, seq, vocab). Every vocab slot at a masked position becomes -inf; note that
# a softmax over a row that is entirely -inf produces NaN.
attention_mask = torch.randint(0, 2, (batch_size, seq_len))
is_masked = attention_mask[..., None] == 0   # (batch, seq, 1)
masked_logits = logits.masked_fill(is_masked, -float('inf'))
print(f"Masked logits shape: {masked_logits.shape}")

## 4. Batch Processing

Handling variable-length sequences efficiently.

In [None]:
# Four toy token-id sequences of different lengths (id 0 is used as padding below).
sequences = [
    [1, 2, 3],
    [4, 5, 6, 7, 8],
    [9, 10],
    [11, 12, 13, 14]
]

lengths = torch.tensor([len(seq) for seq in sequences])
max_length = lengths.max().item()
print(f"Lengths: {lengths}, Max: {max_length}")

# Right-pad every sequence with pad_token up to the longest one.
pad_token = 0
padded = [seq + [pad_token] * (max_length - len(seq)) for seq in sequences]

batch_tensor = torch.tensor(padded)
print(f"\nPadded batch:\n{batch_tensor}")

# Attention mask in one vectorised comparison instead of a per-row Python loop:
# position j of row i is real (1) iff j < lengths[i]. Deriving it from the lengths
# (rather than `batch_tensor != pad_token`) stays correct even if the pad id also
# appears as a real token.
attention_mask = (
    torch.arange(max_length).unsqueeze(0) < lengths.unsqueeze(1)
).to(batch_tensor.dtype)
print(f"\nAttention mask:\n{attention_mask}")

# Efficient operations with masking: zero out the embeddings at padded positions.
embedding_dim = 4
embeddings = torch.randn(batch_tensor.shape + (embedding_dim,))
masked_embeddings = embeddings * attention_mask.unsqueeze(-1).float()

# Masked mean: sum over real positions only, then divide by the true length.
# clamp(min=1) keeps a zero-length sequence at 0 instead of producing NaN.
sequence_sums = masked_embeddings.sum(dim=1)
sequence_means = sequence_sums / lengths.clamp(min=1).unsqueeze(-1).float()
print(f"\nSequence means shape: {sequence_means.shape}")

## 5. Memory Efficiency

Techniques for reducing memory usage in large models.

In [None]:
# Out-of-place vs in-place arithmetic, compared through the tensors' storage addresses.
x = torch.randn(100, 100)
original_ptr = x.data_ptr()

# `x + 1.0` allocates a brand-new result tensor and leaves x untouched.
y = x + 1.0
allocated_new = y.data_ptr() != original_ptr
print(f"Regular operation creates new tensor: {allocated_new}")

# Methods ending in `_` mutate their tensor, so x keeps its original storage.
x.add_(1.0)
reused_memory = x.data_ptr() == original_ptr
print(f"In-place operation reuses memory: {reused_memory}")

# Memory-efficient attention (chunked)
def chunked_attention(Q, K, V, chunk_size=50):
    """Compute scaled dot-product attention in query chunks to save memory.

    Only a ``(chunk_size, seq_len_k)`` slice of the score matrix exists at a
    time, instead of the full ``(seq_len_q, seq_len_k)`` matrix. The result
    equals ``softmax(Q @ K^T / sqrt(d)) @ V`` computed in one shot.

    Args:
        Q: Queries, shape ``(..., seq_len_q, d)``.
        K: Keys, shape ``(..., seq_len_k, d)``.
        V: Values, shape ``(..., seq_len_k, d_v)``.
        chunk_size: Number of query rows processed per step (must be >= 1).

    Returns:
        Tensor of shape ``(..., seq_len_q, d_v)`` with Q's dtype and device.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    seq_len = Q.shape[-2]
    # Standard 1/sqrt(d) scaling; without it large head dimensions push the
    # softmax towards one-hot weights.
    scale = Q.shape[-1] ** -0.5
    # The output's feature size comes from V, which need not match Q's.
    output = Q.new_zeros(Q.shape[:-1] + (V.shape[-1],))
    k_t = K.transpose(-2, -1)  # loop-invariant, computed once

    for start in range(0, seq_len, chunk_size):
        stop = min(start + chunk_size, seq_len)
        scores = torch.matmul(Q[..., start:stop, :], k_t) * scale
        weights = F.softmax(scores, dim=-1)
        output[..., start:stop, :] = torch.matmul(weights, V)

    return output

# Smoke-test the chunked implementation on 2 batches x 100 positions x 64 features;
# with the default chunk_size=50 the queries are processed in two chunks.
Q, K, V = (torch.randn(2, 100, 64) for _ in range(3))

output = chunked_attention(Q, K, V)
print(f"\nChunked attention output shape: {output.shape}")

## 6. Exercise: Implement Attention Mechanism

Put it all together to implement a complete attention mechanism.

In [None]:
class SimpleAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input into queries, keys and values, splits each into
    ``num_heads`` heads of size ``hidden_dim // num_heads``, applies
    ``softmax(Q K^T / sqrt(head_dim))`` per head, and recombines the heads
    through an output projection.

    Args:
        hidden_dim: Feature size of the input and output; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads (default 1).

    Raises:
        ValueError: If ``num_heads`` is not a positive divisor of ``hidden_dim``.
    """

    def __init__(self, hidden_dim, num_heads=1):
        super().__init__()
        # Fail early with a clear message; otherwise the head split in forward()
        # would raise an opaque view() shape error.
        if num_heads < 1 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"num_heads must be a positive divisor of hidden_dim, "
                f"got hidden_dim={hidden_dim}, num_heads={num_heads}"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, hidden_dim)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``; entries equal to 0
                (or False) are excluded from attention. A causal mask of shape
                ``(1, 1, seq_len, seq_len)`` works as-is. A ``(batch, seq_len)``
                padding mask must first be reshaped to ``(batch, 1, 1, seq_len)``.

        Returns:
            Tuple ``(output, attention_weights)`` of shapes
            ``(batch, seq_len, hidden_dim)`` and
            ``(batch, num_heads, seq_len, seq_len)``.
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(t):
            # (B, S, hidden) -> (B, S, H, D_h) -> (B, H, S, D_h)
            return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Project to Q, K, V and split into heads
        Q = split_heads(self.q_proj(x))
        K = split_heads(self.k_proj(x))
        V = split_heads(self.v_proj(x))

        # Scaled attention scores: (B, H, S, S)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            # Most negative finite value rather than -inf: masked entries still
            # get zero weight, but a query whose keys are all masked gets uniform
            # weights instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)

        # Merge heads: (B, H, S, D_h) -> (B, S, H * D_h). contiguous() is needed
        # because the transpose leaves a layout that view() cannot reshape.
        output = output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.hidden_dim)
        output = self.out_proj(output)

        return output, attention_weights

# Exercise the module: 4 heads over a 64-dim model, applied to 2 sequences of 8 positions.
hidden_dim = 64
attention = SimpleAttention(hidden_dim, num_heads=4)

batch_size, seq_len = 2, 8
x = torch.randn(batch_size, seq_len, hidden_dim)

# Lower-triangular causal mask with two leading singleton axes, shape (1, 1, S, S),
# so it broadcasts over the batch and head dimensions.
causal_mask = torch.ones(seq_len, seq_len).tril()[None, None]

output, weights = attention(x, mask=causal_mask)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")

# Heat map for batch element 0, head 0 (rows = queries, columns = keys).
# detach() drops the autograd graph so the tensor can be converted to NumPy.
first_head_map = weights[0, 0].detach().numpy()
plt.figure(figsize=(8, 6))
plt.imshow(first_head_map, cmap='Blues')
plt.title('Attention Weights (First Head, First Batch)')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.colorbar()
plt.show()

print("\n✅ Attention mechanism implemented successfully!")

## 🎉 Congratulations!

You've mastered advanced tensor operations for text processing:

✅ **Advanced Indexing**: Boolean masking and fancy indexing  
✅ **Tensor Reshaping**: Views, transposes, and multi-head patterns  
✅ **Broadcasting**: Efficient operations across different tensor shapes  
✅ **Batch Processing**: Variable-length sequence handling  
✅ **Memory Efficiency**: In-place operations and chunked computation  
✅ **Complete Example**: Implemented attention mechanism from scratch  

## 🚀 Next Steps

In the next notebook, we'll apply these tensor operations to real text preprocessing:
- Text cleaning and normalization
- Tokenization strategies
- Vocabulary building

**Ready to process text?** Continue to [`03_text_preprocessing.ipynb`](03_text_preprocessing.ipynb)!