# Understanding One-Hot Encoding in PyTorch

This notebook focuses **only** on understanding and practicing *one-hot encoding*.

By the end, you will understand:
- What one-hot encoding means
- Why we use it in classification and recognition
- How to implement it in **PyTorch** step-by-step
- When you do *not* need to use it manually (e.g., with `CrossEntropyLoss`)

## 1️⃣ What is One-Hot Encoding?
In classification tasks, we often have multiple classes, e.g., Cat, Dog, Rabbit.
We need a way to tell the network which class is the correct one.

👉 One-hot encoding represents the label as a **vector of 0s and 1s**, where only the entry for the correct class is 1.

### Example (3 classes)
| Class | One-hot Vector |
|:------|:---------------|
| Cat   | [1, 0, 0] |
| Dog   | [0, 1, 0] |
| Rabbit| [0, 0, 1] |

This is how the model knows *which* class is correct among many options.
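To make the table above concrete, here is a minimal sketch that builds the same vectors by hand. The `classes` list, the `class_to_index` mapping, and the `one_hot_manual` helper are hypothetical names chosen for illustration; they are not part of PyTorch.

```python
import torch

# Hypothetical class list, in the same order as the table above
classes = ["Cat", "Dog", "Rabbit"]
class_to_index = {name: i for i, name in enumerate(classes)}

def one_hot_manual(name):
    """Build a one-hot vector by starting from zeros and setting one entry to 1."""
    vec = torch.zeros(len(classes), dtype=torch.long)
    vec[class_to_index[name]] = 1
    return vec

for name in classes:
    print(name, one_hot_manual(name).tolist())
```

Each vector has as many entries as there are classes, and the position of the single 1 is the class index.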

## 2️⃣ Why Do We Use One-Hot Encoding?
**1. Clarity:** Only one class is marked as correct (1), others are 0.

**2. Compatibility:** Softmax outputs a probability vector with one entry per class, the same size as the one-hot vector, so the two are easy to compare.

**3. Independence:** Class indices (0, 1, 2, 3, ...) are arbitrary labels with no meaningful order or distance. One-hot encoding avoids implying, for example, that class 2 is 'closer' to class 1 than to class 0.

📌 *Summary:* One-hot encoding clearly indicates which class is correct and matches the output format of classification networks.
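The compatibility point can be sketched in a few lines: a softmax output and a one-hot target have the same shape, so they can be compared entry by entry. The logits below are made-up values for the three example classes, and the loss is the standard cross-entropy formula written out by hand.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 0.5, 0.3])   # made-up raw scores for Cat, Dog, Rabbit
probs = F.softmax(logits, dim=0)         # probability vector that sums to 1
target = torch.tensor([1.0, 0.0, 0.0])   # one-hot vector for Cat

# Cross-entropy between the one-hot target and the predicted probabilities.
# Entries for the wrong classes are multiplied by 0, so only the
# correct class contributes to the loss.
loss = -(target * torch.log(probs)).sum()

print('Same shape:', probs.shape == target.shape)
print('Probabilities:', probs)
print('Cross-entropy:', loss.item())
```

Because the target is one-hot, the loss reduces to the negative log-probability the model assigns to the correct class.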

## 3️⃣ Creating One-Hot Encodings in PyTorch
PyTorch provides a built-in function: `torch.nn.functional.one_hot()`

### Example: single label

In [None]:
import torch
import torch.nn.functional as F

# Suppose we have 4 possible classes (0,1,2,3)
label = torch.tensor([2])  # class index 2
one_hot = F.one_hot(label, num_classes=4)
print('Label index:', label)
print('One-hot vector:', one_hot)
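Two details of the result are worth checking: its shape and its dtype. `F.one_hot` appends a class dimension to the input shape and returns integer (`int64`) values, so a cast is usually needed before mixing the result with floating-point tensors. A small sketch, repeating the setup so it runs on its own:

```python
import torch
import torch.nn.functional as F

label = torch.tensor([2])                  # shape (1,), class index 2
one_hot = F.one_hot(label, num_classes=4)

print(one_hot.shape)   # torch.Size([1, 4])
print(one_hot.dtype)   # torch.int64

# Cast to float before combining with float tensors such as probabilities
one_hot_float = one_hot.float()
print(one_hot_float.dtype)   # torch.float32
```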

### Example: batch of labels
Let's encode multiple samples at once.

In [None]:
labels = torch.tensor([0, 1, 2, 3])  # batch of labels
one_hot_batch = F.one_hot(labels, num_classes=4)
print('Labels:', labels)
print('One-hot batch:\n', one_hot_batch)
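One pitfall with batches: if `num_classes` is omitted, `F.one_hot` infers it as the largest label in the tensor plus one. A batch that happens not to contain the highest class then yields vectors that are too short. The sketch below shows this, and also how `argmax` recovers the original indices from one-hot rows.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 1, 2])   # class 3 exists but is absent from this batch

inferred = F.one_hot(labels)                  # width inferred as max label + 1
explicit = F.one_hot(labels, num_classes=4)   # width fixed to the true class count

print('Inferred shape:', inferred.shape)
print('Explicit shape:', explicit.shape)

# argmax along the class dimension maps each one-hot row back to its index
recovered = explicit.argmax(dim=1)
print('Recovered labels:', recovered)
```

Passing `num_classes` explicitly keeps the encoding width consistent across batches.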

## 4️⃣ Visualizing One-Hot Vectors
Let's visualize what one-hot vectors look like for a batch of samples.

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(6, 3))
plt.imshow(one_hot_batch.numpy(), cmap='Greens')
plt.title('Visualization of One-Hot Encodings (Batch)')
plt.xlabel('Class index')
plt.ylabel('Sample index')
plt.show()

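## 5️⃣ When You Don't Need One-Hot Encoding
In **PyTorch**, some loss functions (like `nn.CrossEntropyLoss`) do **not** require one-hot labels.

- By default, `CrossEntropyLoss` expects **class indices** (an integer `LongTensor`), not one-hot vectors. Since PyTorch 1.10 it also accepts floating-point class probabilities as targets, but class indices remain the usual choice.
- Internally, it applies log-softmax to the logits and picks out the log-probability at each target index, which gives the same result as cross-entropy against a one-hot target.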

### Example:

In [None]:
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
outputs = torch.tensor([[2.0, 0.5, 0.3]])  # logits for 3 classes
labels = torch.tensor([0])  # correct class index
loss = criterion(outputs, labels)
print('Loss (no one-hot needed):', loss.item())
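As a sanity check, the same loss can be reproduced by hand with an explicit one-hot target. This is only a sketch of the equivalence, not a description of how PyTorch implements the loss internally; it reuses the logits and label from the example above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

outputs = torch.tensor([[2.0, 0.5, 0.3]])   # logits for 3 classes
labels = torch.tensor([0])                  # correct class index

# Built-in loss, using class indices directly
loss_builtin = nn.CrossEntropyLoss()(outputs, labels)

# Manual version: one-hot target times log-softmax, averaged over the batch
one_hot = F.one_hot(labels, num_classes=3).float()
log_probs = F.log_softmax(outputs, dim=1)
loss_manual = -(one_hot * log_probs).sum(dim=1).mean()

print('Built-in:', loss_builtin.item())
print('Manual:  ', loss_manual.item())
print('Match:', torch.isclose(loss_builtin, loss_manual).item())
```

Both values agree, which is why passing class indices is enough in everyday training code.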

## 🧩 Mini Exercise
Try changing the label values below and observe the new one-hot vectors!

**Task:** Modify the labels tensor and rerun the cell.

In [None]:
# Try different label combinations
labels = torch.tensor([1, 3, 0, 2])
one_hot_batch = F.one_hot(labels, num_classes=4)
print('Labels:', labels)
print('One-hot vectors:\n', one_hot_batch)

plt.imshow(one_hot_batch.numpy(), cmap='Greens')
plt.title('Updated One-Hot Encodings')
plt.show()