In [16]:
import torch
from typing import Optional, Dict, Tuple, List, Any

#### Question 1: Efficient Attention Mask Creation
You're implementing a decoder-only transformer and need to create causal attention masks efficiently.

Implement a function that creates a causal attention mask for a batch of sequences:

In [5]:
def create_causal_mask(seq_length: int, batch_size: int) -> torch.Tensor:
    """
    Create an additive causal attention mask for decoder-only transformers.

    Args:
        seq_length: Length of the sequence
        batch_size: Number of sequences in the batch

    Returns:
        A float tensor of shape (batch_size, seq_length, seq_length) where:
        - mask[i, j, k] = 0 if position j can attend to position k (k <= j)
        - mask[i, j, k] = -inf if position j cannot attend to position k (k > j)

        The batch dimension is a broadcasted view (``expand``), so all batch
        slices share the same storage; clone the result before writing to it.

    Example:
        For seq_length=3, one slice looks like:
        [[0, -inf, -inf],
         [0,    0, -inf],
         [0,    0,    0]]
    """
    # Fill the whole square with -inf and keep only the strictly-upper
    # triangle (k > j); triu() sets everything on/below the diagonal to 0.
    future_blocked = torch.triu(
        torch.full((seq_length, seq_length), float('-inf')), diagonal=1
    )

    # expand() gives the documented batch dimension without copying the mask.
    return future_blocked.unsqueeze(0).expand(batch_size, seq_length, seq_length)

#### Question 2: KV Cache Management
You're implementing an inference engine for a decoder-only transformer. To avoid recomputing key and value tensors for previous tokens, you need to implement a KV cache.

Implement a function that updates the KV cache during autoregressive generation:

In [8]:
def update_kv_cache(
    cache: Optional[Dict[int, Dict[str, torch.Tensor]]],
    new_keys: torch.Tensor,
    new_values: torch.Tensor,
    layer_idx: int
) -> Dict[int, Dict[str, torch.Tensor]]:
    """
    Append freshly computed keys/values to the per-layer KV cache.

    Args:
        cache: Existing cache, or None to start a new one. Layout:
               {
                   layer_idx: {'keys': tensor, 'values': tensor},
                   ...
               }
        new_keys: Keys to append; concatenated along dim=1
        new_values: Values to append; concatenated along dim=1
        layer_idx: Which transformer layer (0-indexed) the tensors belong to

    Returns:
        The cache dictionary. When a cache was passed in, the same object is
        updated in place and returned; otherwise a new dict is created.
    """
    kv_store = {} if cache is None else cache

    # First time this layer is seen: store the tensors as-is.
    if layer_idx not in kv_store:
        kv_store[layer_idx] = {'keys': new_keys, 'values': new_values}
        return kv_store

    # Otherwise grow each cached tensor along dim=1 (keys first, then values).
    layer_entry = kv_store[layer_idx]
    for slot, fresh in (('keys', new_keys), ('values', new_values)):
        layer_entry[slot] = torch.cat([layer_entry[slot], fresh], dim=1)

    return kv_store

#### Question 3: Weight Quantization
You're building an inference engine and need to implement INT8 quantization to reduce model size and speed up inference.

Implement a function that quantizes FP32 weights to INT8 using symmetric quantization:

In [11]:
def quantize_weights_int8(
    weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize floating-point weights to INT8 using symmetric quantization.

    Symmetric quantization maps the range [-max_abs, max_abs] to [-127, 127]

    Args:
        weights: Floating-point tensor of any shape

    Returns:
        quantized_weights: INT8 tensor (same shape as input)
        scale: Scalar (0-dim) tensor with the dtype of ``weights``, used for
               dequantization. Always a tensor, including for all-zero or
               empty input, where it is 1.0.

    The relationship is: weights ≈ quantized_weights * scale

    Example:
        weights = torch.tensor([[-2.0, 1.0], [3.0, -1.0]])
        quantized, scale = quantize_weights_int8(weights)
        # scale = 3.0 / 127 ≈ 0.0236
        # quantized = [[-85, 42], [127, -42]]

        # Dequantize: quantized * scale ≈ original weights
    """
    q_max = 127

    # max() is undefined on an empty tensor; there is nothing to quantize.
    if weights.numel() == 0:
        return torch.empty_like(weights, dtype=torch.int8), weights.new_ones(())

    max_abs = weights.abs().max()

    # All-zero weights would give scale == 0 and a 0/0 division below; use 1.0
    # instead. torch.where keeps `scale` a tensor on both paths.
    scale = torch.where(max_abs > 0, max_abs / q_max, torch.ones_like(max_abs))

    quantized_weights = torch.clamp(
        torch.round(weights / scale),
        -q_max, q_max
    ).to(torch.int8)

    return quantized_weights, scale

In [None]:
def quantize_weights_int8_per_channel(
    weights: torch.Tensor,  # Shape: (out_features, in_features)
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Per-channel symmetric INT8 quantization: each row gets its own scale.

    Each output channel (row) is mapped from [-max_abs_row, max_abs_row] to
    [-127, 127], so ``weights[i] ≈ quantized_weights[i] * scales[i]``.

    Args:
        weights: Floating-point tensor whose first dimension indexes output
                 channels, normally (out_features, in_features). Any trailing
                 dimensions are treated as part of the channel.

    Returns:
        quantized_weights: INT8 tensor, same shape as ``weights``
        scales: tensor of shape (out_features,) with the dtype of ``weights``,
                one scale per output channel. All-zero channels get scale 1.0.
    """
    q_max = 127

    # Per-row max |w| in one vectorized reduction over all non-channel dims.
    max_abs = weights.abs().flatten(start_dim=1).amax(dim=1)

    # A channel of all zeros would yield scale 0 (and 0/0 below); use 1.0.
    # torch.where keeps every scale a tensor element, so no stacking of mixed
    # float/tensor values is needed.
    scales = torch.where(max_abs > 0, max_abs / q_max, torch.ones_like(max_abs))

    # Reshape scales to (out_features, 1, ..., 1) so they broadcast per row.
    row_scales = scales.view(-1, *([1] * (weights.dim() - 1)))

    quantized_weights = torch.clamp(
        torch.round(weights / row_scales),
        -q_max, q_max
    ).to(torch.int8)

    return quantized_weights, scales

#### Question 4: Batched Attention with Variable Sequence Lengths
You're implementing batched inference where different sequences in the batch have different lengths. You need to compute attention while properly masking padded positions.

Implement a function that computes scaled dot-product attention for a batch with variable-length sequences:

In [None]:
def batched_attention_with_padding(
    query: torch.Tensor,      # (batch_size, seq_len, d_model)
    key: torch.Tensor,        # (batch_size, seq_len, d_model)
    value: torch.Tensor,      # (batch_size, seq_len, d_model)
    seq_lengths: List[int],   # Actual length of each sequence (without padding)
) -> torch.Tensor:
    """
    Compute scaled dot-product attention with a padding mask.

    Args:
        query: Query tensor (batch_size, seq_len, d_model)
        key: Key tensor (batch_size, seq_len, d_model)
        value: Value tensor (batch_size, seq_len, d_model)
        seq_lengths: List of actual sequence lengths for each batch element
                     e.g., [5, 3, 7] means batch has 3 sequences with lengths 5, 3, 7

    Returns:
        output: Attention output (batch_size, seq_len, d_model).
                Rows for padded *query* positions are still computed (over
                the valid keys only); callers should ignore them. A sequence
                of length 0 produces an all-zero output instead of NaN.

    Raises:
        ValueError: if len(seq_lengths) does not match the batch size.

    Note:
    - Padded key positions do not contribute to attention (their scores are
      set to -inf before the softmax)
    - Uses scaled dot-product: softmax(Q @ K^T / sqrt(d_model)) @ V

    Example:
        batch_size=2, seq_len=4, d_model=8
        seq_lengths = [3, 2]  # First sequence has 3 valid tokens, second has 2

        Key validity mask is:
        [[1, 1, 1, 0],   # First sequence: positions 0,1,2 are valid
         [1, 1, 0, 0]]   # Second sequence: positions 0,1 are valid
    """
    batch_size = query.size(0)
    seq_len = key.size(1)
    d_model = query.size(-1)

    if len(seq_lengths) != batch_size:
        raise ValueError(
            f"seq_lengths has {len(seq_lengths)} entries but batch size is {batch_size}"
        )

    # key_is_valid[b, k] is True when position k is a real token of sequence b.
    lengths = torch.as_tensor(seq_lengths, device=query.device)
    positions = torch.arange(seq_len, device=query.device)
    key_is_valid = positions.unsqueeze(0) < lengths.unsqueeze(1)

    scores = query @ key.transpose(-2, -1) / (d_model ** 0.5)

    # Broadcast the (batch, 1, seq_len) key mask over every query row.
    scores = scores.masked_fill(~key_is_valid.unsqueeze(1), float('-inf'))

    attn_weights = torch.softmax(scores, dim=-1)
    # A row with every key masked (length-0 sequence) is softmax of all -inf,
    # which is NaN; treat it as "attend to nothing".
    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)

    return attn_weights @ value

#### Question 5: Tensor Parallelism - Column-wise Weight Split
You're implementing tensor parallelism to distribute a large linear layer across multiple GPUs. In tensor parallelism, we split weight matrices across devices.

Implement a function that performs a column-parallel linear layer:

In [None]:
def column_parallel_linear(
    input: torch.Tensor,          # (batch_size, seq_len, hidden_size)
    weight: torch.Tensor,         # (output_size, hidden_size) - FULL weight on this rank
    rank: int,                    # Current GPU rank (0, 1, 2, ...)
    world_size: int,              # Total number of GPUs
    bias: Optional[torch.Tensor] = None,  # (output_size,) - FULL bias
) -> torch.Tensor:
    """
    Perform a column-parallel linear layer (output features sharded by rank).

    In column parallelism:
    - Input is the SAME on all GPUs (replicated)
    - The output-feature dimension of the weight is SPLIT across GPUs
    - Each GPU computes a slice of the output features
    - Slices are CONCATENATED along the last dim (via all-gather)

    Because ``weight`` is stored as (output_size, hidden_size), "columns" of
    the mathematical W (y = x @ W) are *rows* of this tensor.

    Example with world_size=2, output_size=8:
        GPU 0 uses weight[0:4, :]  -> computes output[:, :, 0:4]
        GPU 1 uses weight[4:8, :]  -> computes output[:, :, 4:8]

        Then all-gather to get full output of size (batch, seq_len, 8)

    Args:
        input: Input tensor (same on all GPUs)
        weight: FULL weight matrix on this rank; it is sliced here
        rank: Which GPU this is (0-indexed)
        world_size: Total number of GPUs
        bias: Optional full bias vector

    Returns:
        output: Full output tensor (batch_size, seq_len, output_size)

    Raises:
        ValueError: if world_size < 1, rank is outside [0, world_size), or
            output_size is not divisible by world_size (an uneven split would
            otherwise silently drop the trailing output features).

    Note:
        With world_size > 1 this calls torch.distributed.all_gather(), which
        requires an initialized process group. With world_size == 1 no
        communication is needed and the local result is returned directly.
    """
    if world_size < 1:
        raise ValueError(f"world_size must be >= 1, got {world_size}")
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is out of range for world_size {world_size}")

    output_size = weight.shape[0]
    if output_size % world_size != 0:
        raise ValueError(
            f"output_size ({output_size}) must be divisible by world_size ({world_size})"
        )

    # Contiguous slice of output features owned by this rank.
    shard_size = output_size // world_size
    shard = slice(rank * shard_size, (rank + 1) * shard_size)

    local_output = input @ weight[shard, :].T
    if bias is not None:
        local_output = local_output + bias[shard]

    # A single rank already holds every output feature.
    if world_size == 1:
        return local_output

    # all_gather fills the list in rank order, so concatenating along the
    # feature dim restores the original output-feature ordering.
    gathered = [torch.empty_like(local_output) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, local_output)

    return torch.cat(gathered, dim=-1)


#### Question 7: Token Sampling Strategies
You're building an inference engine and need to implement different sampling strategies for text generation.

Implement three sampling functions: greedy, top-k, and top-p (nucleus) sampling:

In [13]:
def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    """
    Greedy sampling - select token with highest probability.

    Args:
        logits: Logits tensor of shape (batch_size, vocab_size)

    Returns:
        next_tokens: Token IDs of shape (batch_size,)
    """
    # Softmax is strictly increasing, so the most probable token is simply the
    # one with the largest logit. Skipping the softmax avoids extra work and
    # avoids floating-point rounding collapsing nearly-equal logits into a tie.
    return torch.argmax(logits, dim=-1)


def top_k_sample(
    logits: torch.Tensor,
    k: int,
    temperature: float = 1.0
) -> torch.Tensor:
    """
    Top-K sampling - sample from the k most likely tokens.

    Args:
        logits: Logits tensor of shape (batch_size, vocab_size)
        k: Number of top tokens to consider; values larger than vocab_size
           are clamped to vocab_size
        temperature: Softmax temperature (higher = more random); must be > 0

    Returns:
        next_tokens: Token IDs of shape (batch_size,)

    Raises:
        ValueError: if temperature <= 0 or k < 1.

    Steps:
    1. Apply temperature scaling: logits / temperature
    2. Find top-k logits and their indices
    3. Apply softmax over just those k logits (equivalent to setting all
       other logits to -inf)
    4. Sample a position among the k candidates
    5. Map the sampled position back to its vocabulary index
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    # torch.topk raises when k exceeds the last dimension.
    k = min(k, logits.size(-1))

    topk_values, topk_indices = torch.topk(logits / temperature, k=k, dim=-1)

    probs = torch.softmax(topk_values, dim=-1)
    # (batch_size, 1) position within the top-k list of each row.
    choice = torch.multinomial(probs, num_samples=1)

    # Translate top-k positions to vocab ids and drop the trailing dim so the
    # result is (batch_size,), as documented.
    return topk_indices.gather(-1, choice).squeeze(-1)


def top_p_sample(
    logits: torch.Tensor,
    p: float,
    temperature: float = 1.0
) -> torch.Tensor:
    """
    Top-P (nucleus) sampling - sample from smallest set of tokens whose
    cumulative probability exceeds p.

    Args:
        logits: Logits tensor of shape (batch_size, vocab_size)
        p: Cumulative probability threshold (e.g., 0.9)
        temperature: Softmax temperature; must be > 0

    Returns:
        next_tokens: Token IDs of shape (batch_size,)

    Raises:
        ValueError: if temperature <= 0.

    Steps:
    1. Apply temperature scaling
    2. Convert to probabilities with softmax
    3. Sort probabilities in descending order
    4. Compute cumulative probabilities
    5. Find cutoff where cumsum > p; the token that crosses the threshold is
       kept, everything after it is dropped
    6. Mask out tokens beyond cutoff
    7. Renormalize and sample

    The input ``logits`` tensor is not modified.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

    cumulative = torch.cumsum(sorted_probs, dim=-1)
    crossed = cumulative > p

    # Shift the mask one slot to the right: a token is dropped only if the
    # tokens *before* it already exceed p. This keeps the token that crosses
    # the threshold (so the nucleus mass really exceeds p) and always keeps
    # the single most likely token.
    drop = torch.cat(
        [torch.zeros_like(crossed[..., :1]), crossed[..., :-1]], dim=-1
    )

    nucleus = sorted_probs.masked_fill(drop, 0.0)
    nucleus = nucleus / nucleus.sum(dim=-1, keepdim=True)

    # Sample a position in sorted order, then map back to vocabulary ids and
    # drop the trailing dim so the result is (batch_size,), as documented.
    choice = torch.multinomial(nucleus, num_samples=1)
    return sorted_indices.gather(-1, choice).squeeze(-1)

#### Question 8: Gradient Accumulation
You're training a large model but GPU memory is limited. You need to implement gradient accumulation to simulate a larger batch size.

Implement a training step with gradient accumulation:

In [14]:
from torch.utils.data import DataLoader

In [None]:
def train_step_with_grad_accumulation(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    data_loader: DataLoader,
    accumulation_steps: int,
    device: str = 'cuda'
) -> float:
    """
    Perform one training epoch with gradient accumulation.

    Gradient accumulation allows training with effective batch size of:
        actual_batch_size * accumulation_steps

    Args:
        model: The neural network model (expected to already live on `device`)
        optimizer: Optimizer (e.g., Adam)
        data_loader: Iterable yielding (inputs, targets) batches
        accumulation_steps: Number of batches to accumulate before optimizer step
        device: Device the batches are moved to

    Returns:
        average_loss: Mean of the *unscaled* per-batch cross-entropy losses
        over the epoch; 0.0 if the loader yields no batches.

    Raises:
        ValueError: if accumulation_steps < 1.

    Key points:
    - Gradients accumulate across multiple forward/backward passes
    - Optimizer step happens every accumulation_steps batches
    - The loss is divided by accumulation_steps only for backward(), so the
      accumulated gradient is the mean over the group
    - Gradients are zeroed before the first batch and after every step
    - A trailing group with fewer than accumulation_steps batches still
      triggers a step; its gradient is still scaled by 1/accumulation_steps

    Example:
        If batch_size=8 and accumulation_steps=4:
        - Effective batch size = 32
        - Do 4 forward passes with batch_size=8
        - Accumulate gradients from all 4
        - Then do optimizer.step()
    """
    if accumulation_steps < 1:
        raise ValueError(
            f"accumulation_steps must be >= 1, got {accumulation_steps}"
        )

    model.train()
    criterion = torch.nn.CrossEntropyLoss()

    # Start from a clean slate so stale gradients from a previous call do not
    # leak into the first accumulated step.
    optimizer.zero_grad()

    total_loss = 0.0
    num_batches = 0

    for num_batches, (batch, y) in enumerate(data_loader, start=1):
        batch = batch.to(device)
        y = y.to(device)

        loss = criterion(model(batch), y)
        # Scale only what is back-propagated; the reported loss stays unscaled.
        (loss / accumulation_steps).backward()

        total_loss += loss.item()

        if num_batches % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Flush gradients left over from an incomplete final group.
    if num_batches % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    if num_batches == 0:
        return 0.0

    return total_loss / num_batches

#### Question 9: Mixed Precision Training (FP16/BF16)
You're implementing mixed precision training to reduce memory usage and speed up training. This involves using FP16/BF16 for forward/backward passes while keeping FP32 master weights.

Implement a training step with mixed precision and gradient scaling:

In [15]:
def train_step_mixed_precision(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    inputs: torch.Tensor,
    labels: torch.Tensor,
    scaler: torch.cuda.amp.GradScaler,
    use_amp: bool = True,
) -> float:
    """
    Perform one training step with automatic mixed precision (AMP).

    Mixed precision training:
    - Forward pass under autocast (FP16 on CUDA, BF16 on other devices)
    - Loss is scaled up before backward to prevent gradient underflow
    - scaler.step() unscales the gradients and skips the optimizer step if
      they contain inf/NaN
    - Optimizer updates the model's own (full-precision) parameters

    Args:
        model: Neural network model
        optimizer: Optimizer
        inputs: Input tensor (batch_size, ...); its device selects the
                autocast device type
        labels: Label tensor (batch_size,)
        scaler: GradScaler for scaling gradients (a disabled scaler makes the
                scale/step/update calls plain pass-throughs)
        use_amp: Whether to run the forward pass under autocast

    Returns:
        loss_value: Scalar (unscaled) loss value

    Key concepts:
    - torch.amp.autocast() for automatic dtype casting
    - GradScaler to prevent gradient underflow in FP16
    - Gradient scaling: scale up before backward, unscale before optimizer step

    Why gradient scaling?
    - FP16 range: ~6e-5 to 65504
    - Small gradients (e.g., 1e-7) underflow to 0 in FP16
    - Solution: Scale gradients by large factor (e.g., 2^16) during backward
    - Unscale before optimizer step
    """
    criterion = torch.nn.CrossEntropyLoss()

    # Each call is one independent step: drop gradients from previous steps.
    optimizer.zero_grad()

    # autocast wants the device *type* string ('cuda', 'cpu', ...) and,
    # separately, the low-precision dtype to cast to.
    device_type = inputs.device.type
    amp_dtype = torch.float16 if device_type == 'cuda' else torch.bfloat16

    with torch.amp.autocast(device_type=device_type, dtype=amp_dtype, enabled=use_amp):
        outputs = model(inputs)
        loss = criterion(outputs, labels)

    # scale() only returns the scaled loss; backward() must be called on it
    # for any gradients to be produced.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    return loss.item()

#### Question 10: Continuous Batching for Inference
You're building an inference server that needs to handle multiple requests efficiently. Traditional batching waits for a fixed batch to fill up, but continuous batching dynamically adds/removes sequences as they complete.

Implement a simplified continuous batching scheduler:

In [None]:
class ContinuousBatchScheduler:
    """
    Continuous batching scheduler for LLM inference.

    Key idea:
    - Don't wait for all sequences to finish
    - As soon as one sequence completes, add a new request
    - Maximizes GPU utilization

    Example timeline:
        Time 0: [Req1, Req2, Req3, Req4] (batch=4, all start)
        Time 5: Req2 finishes (generated EOS token)
        Time 5: [Req1, Req5, Req3, Req4] (Req5 added immediately)
        Time 8: Req3 finishes
        Time 8: [Req1, Req5, Req6, Req4] (Req6 added immediately)

    State:
        active_requests: request_id -> {'tokens': List[int], 'length': int}
        waiting_queue: FIFO list of (request_id, tokens) tuples that did not
            fit into the batch yet
    """

    def __init__(self, max_batch_size: int, max_seq_length: int):
        """
        Args:
            max_batch_size: Maximum number of sequences in a batch
            max_seq_length: Maximum sequence length
        """
        self.max_batch_size = max_batch_size
        self.max_seq_length = max_seq_length
        self.active_requests = {}  # request_id -> request_data
        self.waiting_queue = []    # Queue of pending (request_id, tokens)

    def _activate(self, request_id: str, tokens: List[int]) -> None:
        """Register a request as active, storing a private copy of its tokens."""
        self.active_requests[request_id] = {
            'tokens': list(tokens),
            'length': len(tokens)
        }

    def _admit_waiting(self) -> None:
        """Move queued requests into the batch, oldest first, while there is room."""
        while self.waiting_queue and len(self.active_requests) < self.max_batch_size:
            queued_id, queued_tokens = self.waiting_queue.pop(0)
            self._activate(queued_id, queued_tokens)

    def add_request(self, request_id: str, prompt_tokens: List[int]):
        """
        Add a new request: it becomes active immediately if the batch has
        room, otherwise it is appended to the waiting queue.

        Args:
            request_id: Unique identifier for this request
            prompt_tokens: Input token IDs (copied, the caller's list is not
                           modified later)
        """
        if len(self.active_requests) < self.max_batch_size:
            self._activate(request_id, prompt_tokens)
        else:
            # Keep the id with the tokens so the request can be activated
            # under its own id once a slot frees up.
            self.waiting_queue.append((request_id, list(prompt_tokens)))

    def get_batch(self) -> Dict[str, Any]:
        """
        Get the current batch for inference.

        Returns:
            Dictionary containing:
            - 'request_ids': List of active request IDs
            - 'input_ids': Long tensor of shape (batch_size, seq_len),
              right-padded with 0
            - 'attention_mask': Long tensor of shape (batch_size, seq_len),
              1 for real tokens and 0 for padding
            - 'sequence_lengths': List of current lengths for each sequence

            With no active requests the lists are empty and both tensors have
            shape (0, 0).

        Note: Sequences will have different lengths, so padding is needed.
        """
        if not self.active_requests:
            return {
                'request_ids': [],
                'input_ids': torch.zeros((0, 0), dtype=torch.long),
                'attention_mask': torch.zeros((0, 0), dtype=torch.long),
                'sequence_lengths': []
            }

        request_ids = list(self.active_requests)
        all_tokens = [self.active_requests[rid]['tokens'] for rid in request_ids]
        sequence_lengths = [self.active_requests[rid]['length'] for rid in request_ids]

        batch_size = len(request_ids)
        max_length = max(sequence_lengths)

        # Token ids are integers; keep them in a long tensor so they can be
        # fed straight into an embedding lookup.
        input_ids = torch.zeros((batch_size, max_length), dtype=torch.long)
        attention_mask = torch.zeros((batch_size, max_length), dtype=torch.long)

        for row, tokens in enumerate(all_tokens):
            seq_len = len(tokens)
            input_ids[row, :seq_len] = torch.tensor(tokens, dtype=torch.long)
            attention_mask[row, :seq_len] = 1

        return {
            'request_ids': request_ids,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'sequence_lengths': sequence_lengths
        }

    def update_batch(self, request_id: str, new_token: int, is_finished: bool):
        """
        Update a request with newly generated token.

        Args:
            request_id: Which request to update
            new_token: Newly generated token ID (ignored when is_finished=True)
            is_finished: Whether this sequence is complete (EOS or max_length)

        If is_finished=True, or appending the token brings the sequence to
        max_seq_length, the request is removed from active_requests and
        waiting requests are admitted to refill the batch.

        Raises:
            KeyError: if request_id is not an active request.
        """
        if not is_finished:
            request = self.active_requests[request_id]
            request['tokens'].append(new_token)
            request['length'] += 1

            # Still room to grow: the request stays in the batch.
            if request['length'] < self.max_seq_length:
                return
            # Otherwise the length cap was hit: retire the request below.

        del self.active_requests[request_id]
        self._admit_waiting()

    def can_add_more_requests(self) -> bool:
        """
        Check if we can add more requests to the current batch.

        Returns:
            True if batch is not full and waiting queue has requests
        """
        has_room = len(self.active_requests) < self.max_batch_size
        return has_room and bool(self.waiting_queue)