In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding that is added to token embeddings.

    A ``(1, max_len, d_model)`` table is computed once in ``__init__`` and
    registered as a buffer named ``pe``: it is saved/loaded with the module's
    state and moved by ``.to(device)``, but it is not a trainable parameter.

    Args:
        d_model: Embedding dimension (e.g. 128 or 512). Odd values are supported.
        max_len: Longest sequence length the table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        # torch.arange gives [0, 1, ..., max_len - 1]; unsqueeze(1) turns it
        # into a (max_len, 1) column so it broadcasts against the frequencies.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column index: a 1-D vector of ceil(d_model / 2)
        # entries, 10000^(-2i / d_model) computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # (max_len, 1) * (ceil(d_model / 2),) broadcasts to (max_len, ceil(d_model / 2)).
        angles = position * div_term
        # Even columns 0, 2, 4, ... hold the sines.
        pe[:, 0::2] = torch.sin(angles)
        # Odd columns 1, 3, 5, ... hold the cosines. When d_model is odd there
        # is one fewer odd column than even columns, so drop the last frequency.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading batch axis of size 1 -> (1, max_len, d_model).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``, e.g. ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} of the positional encoding"
            )
        # (batch, seq_len, d_model) + (1, seq_len, d_model): the same encoding
        # is added to every sequence in the batch.
        return x + self.pe[:, :seq_len]

# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected with their own linear layer,
    split into ``num_heads`` heads of width ``d_k = d_model // num_heads``,
    attended per head, merged back and passed through an output projection.

    Args:
        d_model: Model / embedding dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Attend ``V`` with weights ``softmax(Q @ K^T / sqrt(d_k))``.

        Args:
            Q: Queries, ``(batch, num_heads, q_len, d_k)``.
            K: Keys, ``(batch, num_heads, k_len, d_k)``.
            V: Values, ``(batch, num_heads, k_len, d_k)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, q_len, k_len)``; positions equal to 0
                are excluded from attention.

        Returns:
            Tensor of shape ``(batch, num_heads, q_len, d_k)``.
        """
        # (..., q_len, d_k) @ (..., d_k, k_len) -> (..., q_len, k_len)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative finite value of the score dtype so masked
            # positions get ~0 probability after softmax. A hard-coded -1e9 is
            # not representable in float16 and makes masked_fill fail there.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)
        # Each row sums to 1: the attention weights of one query over all keys.
        attn_probs = torch.softmax(attn_scores, dim=-1)
        # (..., q_len, k_len) @ (..., k_len, d_k) -> weighted sum of value vectors.
        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        """Reshape ``(batch, seq_len, d_model)`` to ``(batch, num_heads, seq_len, d_k)``."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: back to ``(batch, seq_len, d_model)``."""
        batch_size, _, seq_length, _ = x.size()
        # transpose() leaves a non-contiguous tensor; view() needs contiguous memory.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Run multi-head attention.

        Args:
            Q, K, V: Tensors of shape ``(batch, seq_len, d_model)``; ``K`` and
                ``V`` share a sequence length, which may differ from ``Q``'s.
            mask: Optional attention mask, see ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        # Each projection computes x @ W.weight.T + W.bias and keeps the shape
        # (batch, seq_len, d_model); split_heads then yields
        # (batch, num_heads, seq_len, d_k).
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # (batch, num_heads, q_len, d_k)
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        # Merge heads to (batch, q_len, d_model), then apply the output projection.
        return self.W_o(self.combine_heads(attn_output))

# Feed Forward Network
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Expand to the hidden width, then project back to the model width.
        self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the two-layer network to the last dimension of ``x``."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

# Encoder Layer
class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then feed-forward.

    Each sub-layer is followed by dropout, a residual connection and
    LayerNorm (post-norm arrangement).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # A single dropout module is reused after both sub-layers.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Transform ``x`` (batch, seq_len, d_model), keeping its shape.

        ``mask`` is passed unchanged to the self-attention sub-layer.
        """
        # Sub-layer 1: every position attends over the whole sequence.
        attended = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))
        # Sub-layer 2: position-wise feed-forward network.
        transformed = self.feed_forward(x)
        return self.norm2(x + self.dropout(transformed))

# Decoder Layer
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is followed by dropout, a residual
    connection and its own LayerNorm (post-norm arrangement).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # First attention: the target sequence attends to itself (masked).
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        # Second attention: target queries attend over the encoder output.
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Transform decoder states ``x`` using the encoder output.

        ``tgt_mask`` is applied in the self-attention sub-layer and
        ``src_mask`` in the cross-attention sub-layer.
        """
        self_attended = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attended))

        # Queries come from the decoder; keys and values from the encoder.
        cross_attended = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attended))

        transformed = self.feed_forward(x)
        return self.norm3(x + self.dropout(transformed))

# Complete Transformer
class Transformer(nn.Module):
    """Encoder-decoder Transformer operating on sequences of token ids.

    Token id 0 is treated as padding when the attention masks are built.

    Args:
        src_vocab_size: Number of distinct source token ids.
        tgt_vocab_size: Number of distinct target token ids; also the size of
            the output logits.
        d_model: Embedding / model dimension.
        num_heads: Attention heads per attention module.
        num_layers: Number of encoder layers and of decoder layers.
        d_ff: Hidden width of each feed-forward sub-layer.
        max_seq_length: Longest sequence the positional encoding covers.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 num_layers=6, d_ff=2048, max_seq_length=100, dropout=0.1):
        super().__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        # nn.ModuleList (rather than a plain list) registers every layer as a
        # sub-module, so .parameters(), .to(device), state saving/loading and
        # .train()/.eval() all reach them.
        #
        # Each layer in the stack has its own parameters. Why stack several:
        #   1. Hierarchical features - lower layers tend to capture simple,
        #      local patterns; higher layers increasingly abstract ones
        #      (similar to edges -> shapes -> objects in deep CNNs).
        #   2. Representational power - a single layer has limited
        #      expressiveness; several layers can model more complex functions.
        #   3. Multi-hop dependencies - in every layer each token attends to
        #      all others, so N layers let information be recombined across
        #      the sequence N times.
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        """Build boolean attention masks from token ids (0 = padding).

        Args:
            src: Source token ids, ``(batch, src_len)``.
            tgt: Target token ids, ``(batch, tgt_len)``.

        Returns:
            ``(src_mask, tgt_mask)``:
            ``src_mask`` has shape ``(batch, 1, 1, src_len)`` and is True at
            non-padding source positions.
            ``tgt_mask`` has shape ``(batch, 1, tgt_len, tgt_len)``; entry
            ``[b, 0, i, j]`` is True when target position ``i`` is not padding
            and ``j <= i``, so rows of padded positions are entirely False.
        """
        # (batch, src_len) -> (batch, 1, 1, src_len): broadcasts over heads and queries.
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        # (batch, tgt_len) -> (batch, 1, tgt_len, 1): one flag per query row.
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)

        seq_length = tgt.size(1)
        # Causal ("no-peek") mask: lower triangle including the diagonal, e.g.
        # for seq_length = 4
        #     [[1, 0, 0, 0],    position 0 sees only position 0
        #      [1, 1, 0, 0],    position 1 sees positions 0-1
        #      [1, 1, 1, 0],    position 2 sees positions 0-2
        #      [1, 1, 1, 1]]    position 3 sees all positions
        # Built as bool directly on tgt's device, avoiding a float CPU tensor
        # and a host-to-device copy on every forward pass.
        nopeak_mask = torch.tril(
            torch.ones(1, seq_length, seq_length, dtype=torch.bool, device=tgt.device)
        )
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        """Compute output logits for every target position.

        Args:
            src: Source token ids, ``(batch, src_len)``.
            tgt: Decoder input token ids, ``(batch, tgt_len)``.

        Returns:
            Logits of shape ``(batch, tgt_len, tgt_vocab_size)``.
        """
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        # Every decoder layer attends over the final encoder output.
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        return self.fc(dec_output)

# Toy Dataset: Simple number reversal task
def generate_toy_data(num_samples=1000, seq_length=5):
    """Build random samples for the sequence-reversal toy task.

    Each sample is a tuple ``(source, decoder_input, decoder_target)`` of 1-D
    integer tensors: ``source`` holds ``seq_length`` random ids in 1-9,
    ``decoder_input`` is BOS (10) followed by the reversed source, and
    ``decoder_target`` is the reversed source followed by EOS (11).
    """
    bos = torch.tensor([10])
    eos = torch.tensor([11])

    samples = []
    for _ in range(num_samples):
        # Ids 1-9 only; the upper bound of randint is exclusive and 0 is
        # reserved for padding.
        source = torch.randint(1, 10, (seq_length,))
        reversed_source = torch.flip(source, dims=(0,))
        decoder_input = torch.cat([bos, reversed_source])
        decoder_target = torch.cat([reversed_source, eos])
        samples.append((source, decoder_input, decoder_target))
    return samples

# Training function
def train_model():
    """Train a small Transformer on the reversal toy task and print a spot check.

    Generates 800 training and 200 test samples, trains for 50 epochs with
    Adam, prints the mean batch loss every 10 epochs, then prints the
    predictions for the first 10 test samples together with how many of them
    match the target.
    """
    # Hyperparameters
    src_vocab_size = 12  # ids 0-11: padding (0), digits 1-9, BOS (10), EOS (11)
    tgt_vocab_size = 12
    d_model = 128
    num_heads = 4
    num_layers = 2
    d_ff = 512
    max_seq_length = 20
    batch_size = 32
    num_epochs = 50

    # Generate toy data
    train_data = generate_toy_data(800, seq_length=5)
    test_data = generate_toy_data(200, seq_length=5)

    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads,
                       num_layers, d_ff, max_seq_length).to(device)

    # Targets equal to 0 (padding) do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

    print(f"Training on {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for start in range(0, len(train_data), batch_size):
            batch = train_data[start:start + batch_size]
            src_batch = torch.stack([item[0] for item in batch]).to(device)
            tgt_input_batch = torch.stack([item[1] for item in batch]).to(device)
            tgt_output_batch = torch.stack([item[2] for item in batch]).to(device)

            optimizer.zero_grad()
            output = model(src_batch, tgt_input_batch)
            # Flatten (batch, seq, vocab) logits and (batch, seq) targets so
            # the loss sees one row per token.
            loss = criterion(output.reshape(-1, tgt_vocab_size), tgt_output_batch.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Average over the batches that were actually processed. Dividing by
        # len(train_data) // batch_size would undercount when the last batch
        # is partial and divide by zero for fewer samples than one batch.
        avg_loss = total_loss / max(num_batches, 1)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    # Test the model. The decoder is fed the ground-truth input sequence
    # (teacher forcing); this is not autoregressive generation.
    model.eval()
    correct = 0
    with torch.no_grad():
        for src, tgt_input, tgt_output in test_data[:10]:
            src = src.unsqueeze(0).to(device)
            tgt_input = tgt_input.unsqueeze(0).to(device)

            output = model(src, tgt_input)
            predicted = output.argmax(dim=-1)

            print(f"Input:     {src.squeeze().cpu().tolist()}")
            print(f"Target:    {tgt_output.tolist()}")
            print(f"Predicted: {predicted.squeeze().cpu().tolist()}")
            print()

            # NOTE(review): the final (EOS) position is left out of the
            # comparison - confirm that this is intended.
            if torch.equal(predicted.squeeze()[:-1], tgt_output[:-1].to(device)):
                correct += 1

    print(f"Accuracy on 10 test samples: {correct}/10")

# Run the training demo only when this code is executed directly, not when it
# is imported as a module.
if __name__ == "__main__":
    train_model()

Training on cpu
Model parameters: 930,316
Epoch 10/50, Loss: 0.2702
Epoch 20/50, Loss: 0.0633
Epoch 30/50, Loss: 0.0454
Epoch 40/50, Loss: 0.0230
Epoch 50/50, Loss: 0.0204
Input:     [3, 3, 6, 6, 7]
Target:    [7, 6, 6, 3, 3, 11]
Predicted: [7, 6, 6, 3, 3, 11]

Input:     [1, 6, 7, 2, 8]
Target:    [8, 2, 7, 6, 1, 11]
Predicted: [8, 2, 7, 6, 1, 11]

Input:     [4, 6, 6, 8, 4]
Target:    [4, 8, 6, 6, 4, 11]
Predicted: [4, 8, 6, 6, 4, 11]

Input:     [1, 6, 9, 1, 4]
Target:    [4, 1, 9, 6, 1, 11]
Predicted: [4, 1, 9, 6, 1, 11]

Input:     [8, 9, 7, 8, 5]
Target:    [5, 8, 7, 9, 8, 11]
Predicted: [5, 8, 7, 9, 8, 11]

Input:     [2, 2, 8, 1, 8]
Target:    [8, 1, 8, 2, 2, 11]
Predicted: [8, 1, 8, 2, 2, 11]

Input:     [8, 9, 3, 2, 4]
Target:    [4, 2, 3, 9, 8, 11]
Predicted: [4, 2, 3, 9, 8, 11]

Input:     [8, 3, 6, 1, 8]
Target:    [8, 1, 6, 3, 8, 11]
Predicted: [8, 1, 6, 3, 8, 11]

Input:     [7, 3, 4, 1, 2]
Target:    [2, 1, 4, 3, 7, 11]
Predicted: [2, 1, 4, 3, 7, 11]

Input:     [4, 9, 

In [None]:
# Stand-alone illustration of the causal ("no-peek") mask: start from the
# strictly upper triangle (diagonal=1 keeps only the entries above the main
# diagonal) and invert it, leaving the lower triangle including the diagonal.
seq_length = 6
upper_tri = torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)
nopeak = 1 - upper_tri
# tensor([
#     [[1., 0., 0., 0., 0., 0.],   ← Position 0 can only see position 0
#      [1., 1., 0., 0., 0., 0.],   ← Position 1 can see positions 0-1
#      [1., 1., 1., 0., 0., 0.],   ← Position 2 can see positions 0-2
#      [1., 1., 1., 1., 0., 0.],   ← Position 3 can see positions 0-3
#      [1., 1., 1., 1., 1., 0.],   ← Position 4 can see positions 0-4
#      [1., 1., 1., 1., 1., 1.]]   ← Position 5 can see all positions
# ])


In [None]:
import torch
import numpy as np

# Walk-through, part 1: element-wise multiplication versus matrix
# multiplication on small integer tensors. Everything is printed as it runs.
_RULE = "=" * 80


def _banner(title, lead="\n"):
    """Print *title* framed by two 80-character rules, preceded by *lead*."""
    print(lead + _RULE)
    print(title)
    print(_RULE)


_banner("COMPREHENSIVE GUIDE TO MATRIX MULTIPLICATION AND OPERATIONS", lead="")

# Sample tensors reused by the examples below.
A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]
B = torch.tensor([[7, 8], [9, 10], [11, 12]])  # shape [3, 2]
C = torch.tensor([[2, 3], [4, 5]])  # shape [2, 2]
v = torch.tensor([1, 2, 3])  # shape [3]

print("\nSample Tensors:")
print(f"A (2×3):\n{A}")
print(f"\nB (3×2):\n{B}")
print(f"\nC (2×2):\n{C}")
print(f"\nv (3):\n{v}")

# ----------------------------------------------------------------------------
# 1. Element-wise multiplication (Hadamard product)
# ----------------------------------------------------------------------------
_banner("1. ELEMENT-WISE MULTIPLICATION: * operator")

print("\nRequirement: Shapes must be EXACTLY the same (or broadcastable)")
print("\nC * C (element-wise):")
result = C * C
print(
    f"[[2, 3],   *  [[2, 3],   =  [[2*2, 3*3],   =  {result.tolist()}",
    " [4, 5]]       [4, 5]]       [4*4, 5*5]]",
    sep="\n",
)

print("\nBroadcasting example:")
scalar = torch.tensor([2, 3])  # shape [2], applied to every row of C
print("C * [2, 3] (broadcasts to each row):")
result = C * scalar
print(f"Result:\n{result}")
print(
    "Explanation: [[2,3], * [2,3] = [[2*2, 3*3], = [[4, 9],",
    "              [4,5]]            [4*2, 5*3]]    [8, 15]]",
    sep="\n",
)

print("\n⚠️  ERROR if shapes don't match:")
try:
    result = A * B  # [2, 3] * [3, 2] cannot be broadcast
except RuntimeError as err:
    print(f"A * B fails: {err}")
else:
    print(result)

# ----------------------------------------------------------------------------
# 2. Matrix multiplication: the @ operator and torch.matmul()
# ----------------------------------------------------------------------------
_banner("2. MATRIX MULTIPLICATION: @ operator or torch.matmul()")

print("\nRequirement: Inner dimensions must match: (m×n) @ (n×p) = (m×p)")
print("\nA @ B:")
print("Shape: [2, 3] @ [3, 2] = [2, 2]")
result = A @ B
print(f"Result:\n{result}")

print("\nDetailed calculation:")
print(
    "[[1, 2, 3],   [[7,  8 ],   [[1*7+2*9+3*11,  1*8+2*10+3*12],",
    " [4, 5, 6]] @  [9,  10],  = [4*7+5*9+6*11,  4*8+5*10+6*12]]",
    "               [11, 12]]",
    "             = [[58, 64],",
    "                [139, 154]]",
    sep="\n",
)

print("\ntorch.matmul(A, B) gives same result:")
result2 = torch.matmul(A, B)
print(f"Result:\n{result2}")

print("\n⚠️  ERROR if inner dimensions don't match:")
try:
    result = A @ A  # [2, 3] @ [2, 3]: inner dimensions 3 and 2 differ
except RuntimeError as err:
    print(f"A @ A fails: {err}")


COMPREHENSIVE GUIDE TO MATRIX MULTIPLICATION AND OPERATIONS

Sample Tensors:
A (2×3):
tensor([[1, 2, 3],
        [4, 5, 6]])

B (3×2):
tensor([[ 7,  8],
        [ 9, 10],
        [11, 12]])

C (2×2):
tensor([[2, 3],
        [4, 5]])

v (3):
tensor([1, 2, 3])

1. ELEMENT-WISE MULTIPLICATION: * operator

Requirement: Shapes must be EXACTLY the same (or broadcastable)

C * C (element-wise):
[[2, 3],   *  [[2, 3],   =  [[2*2, 3*3],   =  [[4, 9], [16, 25]]
 [4, 5]]       [4, 5]]       [4*4, 5*5]]

Broadcasting example:
C * [2, 3] (broadcasts to each row):
Result:
tensor([[ 4,  9],
        [ 8, 15]])
Explanation: [[2,3], * [2,3] = [[2*2, 3*3], = [[4, 9],
              [4,5]]            [4*2, 5*3]]    [8, 15]]

⚠️  ERROR if shapes don't match:
A * B fails: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

2. MATRIX MULTIPLICATION: @ operator or torch.matmul()

Requirement: Inner dimensions must match: (m×n) @ (n×p) = (m×p)

A @ B:
Shape: [2, 3] @ [3, 2

In [None]:
# Walk-through, part 2: the specialised multiplication routines, einsum, a
# reference table, common pitfalls and the shapes used in attention.
# Relies on the tensors A, B, C and v created by the earlier part of the guide.
_RULE = "=" * 80


def _banner(title):
    """Print a blank line, then *title* framed by two 80-character rules."""
    print("\n" + _RULE)
    print(title)
    print(_RULE)


# ----------------------------------------------------------------------------
# 3. torch.mm() - matrix-matrix multiplication (2D only)
# ----------------------------------------------------------------------------
_banner("3. torch.mm() - Strictly 2D Matrix Multiplication")

print("\nRequirement: Both inputs must be 2D matrices")
print("\ntorch.mm(A, B):")
result = torch.mm(A, B)
print(f"Result:\n{result}")

print("\n⚠️  ERROR with non-2D tensors:")
try:
    result = torch.mm(A, v)  # v is 1D, so this raises
except RuntimeError as err:
    print(f"torch.mm(A, v) fails: {err}")

# ----------------------------------------------------------------------------
# 4. torch.mv() - matrix-vector multiplication
# ----------------------------------------------------------------------------
_banner("4. torch.mv() - Matrix-Vector Multiplication")

print("\nRequirement: matrix (2D) × vector (1D)")
print("\ntorch.mv(A, v):")
result = torch.mv(A, v)
print("Shape: [2, 3] × [3] = [2]")
print(f"Result: {result}")

print("\nDetailed calculation:")
print(
    "[[1, 2, 3],   [1]     [1*1 + 2*2 + 3*3]   [14]",
    " [4, 5, 6]] × [2]  =  [4*1 + 5*2 + 6*3] = [32]",
    "              [3]",
    sep="\n",
)

# ----------------------------------------------------------------------------
# 5. torch.bmm() - batch matrix multiplication
# ----------------------------------------------------------------------------
_banner("5. torch.bmm() - Batch Matrix Multiplication")

print("\nRequirement: Both inputs 3D with same batch size: (b×m×n) @ (b×n×p) = (b×m×p)")

batch_A = torch.randn(10, 3, 4)  # ten 3×4 matrices
batch_B = torch.randn(10, 4, 5)  # ten 4×5 matrices

result = torch.bmm(batch_A, batch_B)
print(f"Shape: [10, 3, 4] @ [10, 4, 5] = {list(result.shape)}")
print("Performs 10 separate matrix multiplications (one per batch)")

# ----------------------------------------------------------------------------
# 6. Broadcasting with @ and torch.matmul()
# ----------------------------------------------------------------------------
_banner("6. BROADCASTING with @ and torch.matmul()")

print("\ntorch.matmul() supports broadcasting for batch dimensions!")

# Example 1: one matrix shared by every batch entry.
A_batched = torch.randn(5, 2, 3)  # five 2×3 matrices
B_single = torch.randn(3, 4)      # a single 3×4 matrix

result = torch.matmul(A_batched, B_single)
print("\nExample 1 - Broadcast single matrix to all batches:")
print(f"[5, 2, 3] @ [3, 4] = {list(result.shape)}")
print("The [3,4] matrix broadcasts to all 5 batches")

# Example 2: both operands carry batch dimensions that broadcast together.
A_batch = torch.randn(10, 1, 2, 3)  # batch shape [10, 1]
B_batch = torch.randn(1, 5, 3, 4)   # batch shape [1, 5]

result = torch.matmul(A_batch, B_batch)
print("\nExample 2 - Broadcast both batch dimensions:")
print(f"[10, 1, 2, 3] @ [1, 5, 3, 4] = {list(result.shape)}")
print("Batch dims [10,1] and [1,5] broadcast to [10,5]")

# ----------------------------------------------------------------------------
# 7. torch.dot() - dot product (1D only)
# ----------------------------------------------------------------------------
_banner("7. torch.dot() - Dot Product (1D vectors only)")

u = torch.tensor([1, 2, 3])
v = torch.tensor([4, 5, 6])  # rebinds v for this section

result = torch.dot(u, v)
print(f"u: {u}")
print(f"v: {v}")
print(f"torch.dot(u, v) = {result}")
print("Calculation: 1*4 + 2*5 + 3*6 = 32")

print("\n⚠️  ERROR with 2D tensors:")
try:
    result = torch.dot(C, C)  # C is 2D, so this raises
except RuntimeError as err:
    print(f"torch.dot(C, C) fails: {err}")

# ----------------------------------------------------------------------------
# 8. Einstein summation: torch.einsum()
# ----------------------------------------------------------------------------
_banner("8. torch.einsum() - Einstein Summation (Most Flexible!)")

print("\nVery powerful notation for complex operations")

# Plain matrix multiplication.
result = torch.einsum('ij,jk->ik', A, B)
print("\nMatrix mult: 'ij,jk->ik'")
print(f"A @ B:\n{result}")

# Batched matrix multiplication on fresh random batches.
batch_A = torch.randn(10, 3, 4)
batch_B = torch.randn(10, 4, 5)
result = torch.einsum('bij,bjk->bik', batch_A, batch_B)
print("\nBatch matrix mult: 'bij,bjk->bik'")
print(f"Shape: {list(result.shape)}")

# Transpose.
result = torch.einsum('ij->ji', A)
print("\nTranspose: 'ij->ji'")
print(f"Result:\n{result}")

# Trace: sum of the diagonal.
result = torch.einsum('ii->', C)
print("\nTrace (diagonal sum): 'ii->'")
print(f"Result: {result}")

# Element-wise product followed by a full sum.
result = torch.einsum('ij,ij->', C, C)
print("\nElement-wise mult then sum: 'ij,ij->'")
print(f"Result: {result}")

# ----------------------------------------------------------------------------
# 9. Comparison table
# ----------------------------------------------------------------------------
_banner("9. QUICK REFERENCE TABLE")

table = """
┌─────────────────┬──────────────────┬─────────────────────────────────┐
│ Operation       │ Syntax           │ Requirements                    │
├─────────────────┼──────────────────┼─────────────────────────────────┤
│ Element-wise    │ A * B            │ Same shape (or broadcastable)   │
│ multiplication  │                  │                                 │
├─────────────────┼──────────────────┼─────────────────────────────────┤
│ Matrix mult     │ A @ B            │ (m×n) @ (n×p) = (m×p)          │
│                 │ torch.matmul()   │ Supports broadcasting           │
├─────────────────┼──────────────────┼─────────────────────────────────┤
│ Matrix mult     │ torch.mm(A, B)   │ Strictly 2D matrices            │
│ (2D only)       │                  │ No broadcasting                 │
├─────────────────┼──────────────────┼─────────────────────────────────┤
│ Matrix-vector   │ torch.mv(A, v)   │ A: 2D matrix, v: 1D vector     │
├─────────────────┼──────────────────┼─────────────────────────────────┤
│ Batch matrix    │ torch.bmm(A, B)  │ Both 3D, same batch size       │
├─────────────────┼──────────────────┼─────────────────────────────────┤
│ Dot product     │ torch.dot(u, v)  │ Both 1D vectors, same length   │
├─────────────────┼──────────────────┼─────────────────────────────────┤
│ Einstein sum    │ torch.einsum()   │ Flexible notation for any op   │
└─────────────────┴──────────────────┴─────────────────────────────────┘
"""
print(table)

# ----------------------------------------------------------------------------
# 10. Common pitfalls and solutions
# ----------------------------------------------------------------------------
_banner("10. COMMON PITFALLS")

print("\n❌ PITFALL 1: Using * instead of @ for matrix multiplication")
print("Solution: Use @ or torch.matmul() for matrix multiplication")

print("\n❌ PITFALL 2: Dimension mismatch")
print("Solution: Check shapes with .shape before operations")
print(f"Example: A.shape = {A.shape}, B.shape = {B.shape}")

print("\n❌ PITFALL 3: 1D vector as matrix row/column")
v = torch.tensor([1, 2, 3])  # restore the original sample vector
print(f"v.shape = {v.shape} (1D)")
print("To make column: v.unsqueeze(1) →", v.unsqueeze(1).shape)
print("To make row: v.unsqueeze(0) →", v.unsqueeze(0).shape)

print("\n❌ PITFALL 4: Using torch.mm() with batches")
print("Solution: Use torch.bmm() or torch.matmul() for batched operations")

# ----------------------------------------------------------------------------
# 11. Practical examples from Transformers
# ----------------------------------------------------------------------------
_banner("11. TRANSFORMER EXAMPLES")

batch_size = 32
seq_len = 10
d_model = 128
num_heads = 8
d_k = 16

# Per-head query/key/value tensors as used in multi-head attention.
Q = torch.randn(batch_size, num_heads, seq_len, d_k)
K = torch.randn(batch_size, num_heads, seq_len, d_k)
V = torch.randn(batch_size, num_heads, seq_len, d_k)

print("\nAttention scores: Q @ K^T")
print(f"Q shape: {list(Q.shape)}")
print(f"K.transpose(-2, -1) shape: {list(K.transpose(-2, -1).shape)}")

attn_scores = torch.matmul(Q, K.transpose(-2, -1))
print(f"Result shape: {list(attn_scores.shape)}")
print("Performs batch matrix mult: [32,8,10,16] @ [32,8,16,10] = [32,8,10,10]")

print("\nAttention output: attn_probs @ V")
attn_probs = torch.softmax(attn_scores, dim=-1)
output = torch.matmul(attn_probs, V)
print(f"Result shape: {list(output.shape)}")
print("[32,8,10,10] @ [32,8,10,16] = [32,8,10,16]")

_banner("END OF GUIDE")


3. torch.mm() - Strictly 2D Matrix Multiplication

Requirement: Both inputs must be 2D matrices

torch.mm(A, B):
Result:
tensor([[ 58,  64],
        [139, 154]])

⚠️  ERROR with non-2D tensors:
torch.mm(A, v) fails: mat2 must be a matrix

4. torch.mv() - Matrix-Vector Multiplication

Requirement: matrix (2D) × vector (1D)

torch.mv(A, v):
Shape: [2, 3] × [3] = [2]
Result: tensor([14, 32])

Detailed calculation:
[[1, 2, 3],   [1]     [1*1 + 2*2 + 3*3]   [14]
 [4, 5, 6]] × [2]  =  [4*1 + 5*2 + 6*3] = [32]
              [3]

5. torch.bmm() - Batch Matrix Multiplication

Requirement: Both inputs 3D with same batch size: (b×m×n) @ (b×n×p) = (b×m×p)
Shape: [10, 3, 4] @ [10, 4, 5] = [10, 3, 5]
Performs 10 separate matrix multiplications (one per batch)

6. BROADCASTING with @ and torch.matmul()

torch.matmul() supports broadcasting for batch dimensions!

Example 1 - Broadcast single matrix to all batches:
[5, 2, 3] @ [3, 4] = [5, 2, 4]
The [3,4] matrix broadcasts to all 5 batches

Example 2 