# Softmax Layer Exercise (with Batch Support & Numerical Stability)

**Goal:** Implement a numerically stable softmax function that works for both a single vector of logits and a batch of logits.

## 🎯 Learning Objectives

By completing this exercise, you will:  

- **Understand the role of softmax** in neural networks and why it is used to convert raw scores (logits) into probabilities.  
- **Implement the softmax function** in Python using NumPy, both for single vectors (1D arrays) and for batches of data (2D arrays).  
- **Apply numerical stability techniques** by subtracting the maximum logit before exponentiation to prevent overflow errors.  
- **Verify correctness** of your implementation with simple unit tests that check properties like “outputs sum to 1” and “softmax is invariant to constant shifts.”  
- **Practice debugging and testing skills** — interpreting failed tests and fixing implementation mistakes.  
- Build confidence in working with **array operations, broadcasting, and vectorization** in NumPy.  


## Background

In classification tasks, a neural network outputs raw scores called **logits**.  
These values are not probabilities — they can be any real numbers.  

For a vector of logits $z = (z_1, \dots, z_C)$ over $C$ classes, the **softmax function** transforms them into probabilities:

$$
\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}}
$$

Key properties:
- All outputs are **between 0 and 1**.
- The outputs **sum to 1**, so they form a valid probability distribution.
- The largest logit corresponds to the class with the highest probability.
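As a quick illustration of these properties, here is a minimal sketch that applies the formula directly (no stability trick yet) to the logits `[2.0, 1.0, 0.0]`; the variable names are illustrative only. The result is approximately `[0.665, 0.245, 0.090]`.

```python
import numpy as np

z = np.array([2.0, 1.0, 0.0])          # logits for C = 3 classes
p = np.exp(z) / np.sum(np.exp(z))      # direct translation of the softmax formula

print(p)                               # approximately [0.665, 0.245, 0.090]
print(p.sum())                         # 1.0, up to floating-point rounding
print(np.argmax(z) == np.argmax(p))    # the largest logit gets the largest probability
```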

👉 **Numerical stability trick**:  
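If logits are large, exponentials can overflow: in 64-bit floating point, `np.exp(1000.0)` evaluates to `inf`, and the ratio `inf / inf` then becomes `nan`.  
To prevent this, we subtract the maximum logit before exponentiation: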

$$
\text{softmax}(z)_i = \frac{e^{\,z_i - \max(z)}}{\sum_{j=1}^{C} e^{\,z_j - \max(z)}}
$$

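This subtraction does **not** change the result. Writing $m = \max(z)$, each term satisfies $e^{z_i - m} = e^{z_i} e^{-m}$, so the common factor $e^{-m}$ cancels between numerator and denominator. What it does change is the size of the numbers: the largest exponent becomes $e^0 = 1$, so nothing overflows.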
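The contrast between the direct formula and the shifted one can be sketched as follows. The helper names `naive_softmax` and `stable_softmax` are hypothetical, both handle a single 1D vector only, and `np.errstate` is used just to silence the overflow warnings in the demo.

```python
import numpy as np

def naive_softmax(z):
    """Direct formula: exponentiate first, then normalize (can overflow)."""
    exps = np.exp(z)
    return exps / np.sum(exps)

def stable_softmax(z):
    """Subtract the max first, so the largest exponent is e^0 = 1."""
    exps = np.exp(z - np.max(z))
    return exps / np.sum(exps)

z = np.array([1000.0, 1001.0, 1002.0])

# Silence the overflow / invalid-value warnings so the demo output stays readable.
with np.errstate(over="ignore", invalid="ignore"):
    naive = naive_softmax(z)

stable = stable_softmax(z)
print(naive)    # [nan nan nan]: exp(1000) overflows to inf, and inf / inf is nan
print(stable)   # finite probabilities that sum to 1

# Shift invariance: subtracting the same constant from every logit leaves the result unchanged.
print(np.allclose(stable, stable_softmax(z - 1000.0)))   # True
```

The stable version gives the same probabilities as softmax of `[0.0, 1.0, 2.0]`, which is exactly the shift invariance the trick relies on.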



## Your Task

Implement the function `softmax(x)`.

Requirements:
- If `x` is a **1D NumPy array** of shape `(C,)`, return a vector of shape `(C,)`.
- If `x` is a **2D NumPy array** of shape `(N, C)`, return a matrix of shape `(N, C)`.  
  Each row should be a probability distribution that sums to 1.
- Use the **numerical stability trick** (subtract the max per vector or per row).
- **Do not modify the input array in place.**

Hints:
- Use `x.ndim` to check if the input is 1D or 2D.
- Pass `keepdims=True` to `np.max` and `np.sum` when reducing along `axis=1`, so the row-wise results have shape `(N, 1)` and broadcast correctly against the `(N, C)` input.
- Test your function on both **vector** (1D) and **batch** (2D) cases.
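To see why the `keepdims=True` hint matters, here is a small sketch of the shapes involved for a `(2, 3)` batch; the variable names are illustrative. It only demonstrates broadcasting and does not implement the exercise.

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0],
              [0.0, 0.0, 0.0]])                   # shape (N, C) = (2, 3)

row_max_flat = np.max(x, axis=1)                  # shape (2,): the reduced axis is dropped
row_max_col = np.max(x, axis=1, keepdims=True)    # shape (2, 1): one max per row, kept as a column
print(row_max_flat.shape, row_max_col.shape)

# (2, 3) - (2, 1) broadcasts the column across each row: each row minus its own max.
shifted = x - row_max_col                         # rows become [-2, -1, 0] and [0, 0, 0]
print(shifted)

# Without keepdims, (2, 3) - (2,) compares trailing sizes 3 and 2, which do not match.
try:
    x - row_max_flat
except ValueError as err:
    print("broadcast error:", err)
```

Note that the error only appears because `N != C` here. For a square batch (`N == C`), `x - np.max(x, axis=1)` broadcasts without complaint but subtracts the maxima column by column, which silently gives wrong probabilities; `keepdims=True` avoids that trap.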



```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    """Compute numerically stable softmax"""
    # your code here
    raise NotImplementedError("Implement the softmax function")
```


## Unit Tests and Validation

```python
import numpy as np

# Vector case: output keeps shape (C,), sums to 1, and has no negative entries
vec = np.array([2.0, 1.0, 0.0])
out = softmax(vec)
assert out.shape == (3,)
assert abs(out.sum() - 1) < 1e-9
assert np.all(out >= 0), "Probabilities must be non-negative"
print("Vector test passed")

# Batch case: output keeps shape (N, C) and every row sums to 1
batch = np.array([[1, 2, 3],
                  [0, 0, 0]])
outb = softmax(batch)
assert outb.shape == batch.shape
assert np.allclose(outb.sum(axis=1), np.ones(2))
print("Batch test passed")

print("All tests passed ✅")
```
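The tests above do not yet check two properties named in the learning objectives: invariance to constant shifts and leaving the input unmodified. A possible set of extra checks is sketched below as a hypothetical helper, `check_softmax`, which takes any softmax implementation as an argument. To keep the cell self-contained it is demonstrated on a hypothetical `reference_softmax`; in the notebook you would call `check_softmax(softmax)` on your own function instead.

```python
import numpy as np

def reference_softmax(x):
    """Hypothetical reference used only to demo the checks; swap in your own softmax."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

def check_softmax(fn):
    """Extra property checks (hypothetical helper) for a softmax implementation `fn`."""
    z = np.array([[1.0, 2.0, 3.0],
                  [0.5, -1.0, 4.0]])

    # Shift invariance: adding a constant to every logit must not change the output.
    assert np.allclose(fn(z), fn(z + 100.0)), "softmax should be invariant to constant shifts"

    # Numerical stability: huge logits must still give finite probabilities summing to 1.
    out = fn(np.array([1000.0, 1001.0, 1002.0]))
    assert np.all(np.isfinite(out)), "large logits produced inf or nan"
    assert np.isclose(out.sum(), 1.0), "probabilities should sum to 1"

    # No in-place modification: the input must be unchanged after the call.
    original = z.copy()
    fn(z)
    assert np.array_equal(z, original), "input array was modified in place"

    # Ordering: the largest logit in each row gets the largest probability.
    assert np.array_equal(np.argmax(fn(z), axis=1), np.argmax(z, axis=1))
    return True

print(check_softmax(reference_softmax))   # True
```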


### Try it yourself

The unit-test cell in the previous section validates your implementation. It checks:
- Vector case: shape, sum to 1, and non-negative values.
- Batch case: output shape and per-row probabilities that sum to 1.
- (Optional) You can add more tests of your own.

If a test fails, read the error message, fix your implementation, and re-run. Once the tests pass, use the cell below to experiment with your own logits.


```python
# Try it (feel free to change numbers)
logits = np.array([[2.0, 1.0, 0.0],
                   [5.0, -1.0, 1.0]])
probs = softmax(logits)
probs, probs.sum(axis=1)
```


## Instructor's Solution

> This cell would be hidden from learners in a real course.


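```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax for a 1D (C,) or 2D (N, C) array."""
    if x.ndim == 1:
        # Subtract the max so the largest exponent is e^0 = 1 (no overflow)
        shifted = x - np.max(x)
        exps = np.exp(shifted)
        return exps / np.sum(exps)
    elif x.ndim == 2:
        # Per-row max with keepdims=True has shape (N, 1) and broadcasts over each row
        shifted = x - np.max(x, axis=1, keepdims=True)
        exps = np.exp(shifted)
        return exps / np.sum(exps, axis=1, keepdims=True)
    else:
        raise ValueError("Input must be 1D or 2D")
```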
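The branching solution above can also be written without checking `x.ndim`, by always reducing over the last axis (`axis=-1`). The sketch below, under the hypothetical name `softmax_last_axis`, is one possible alternative rather than the official solution: for a `(C,)` vector the last axis is the whole vector, and for an `(N, C)` batch it is each row.

```python
import numpy as np

def softmax_last_axis(x: np.ndarray) -> np.ndarray:
    """Stable softmax over the last axis (hypothetical alternative to the branching version).

    Shape (C,) normalizes the vector; shape (N, C) normalizes each row.
    """
    x = np.asarray(x, dtype=float)                     # accept ints or lists; work in floats
    shifted = x - np.max(x, axis=-1, keepdims=True)    # a new array, so the input is untouched
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

vec = softmax_last_axis(np.array([2.0, 1.0, 0.0]))
batch = softmax_last_axis(np.array([[1, 2, 3], [0, 0, 0]]))
print(vec.shape, batch.shape)    # (3,) (2, 3)
print(batch.sum(axis=1))         # each row sums to 1
```

One design trade-off: this version silently accepts arrays with three or more dimensions (normalizing along the last one), whereas the instructor's solution raises `ValueError` for anything other than 1D or 2D input, which is the stricter match to the exercise requirements.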
