In [None]:
# For tips on running notebooks in Google Colab, see
# https://docs.pytorch.org/tutorials/beginner/colab
%matplotlib inline

Transforms
==========

Data does not always come in the final processed form required for
training machine learning algorithms. We use **transforms** to
manipulate the data and make it suitable for training.

All TorchVision datasets have two parameters, `transform` to modify the
features and `target_transform` to modify the labels, that accept
callables containing the transformation logic. The
[torchvision.transforms](https://pytorch.org/vision/stable/transforms.html)
module offers several commonly used transforms out of the box.

The FashionMNIST features are in PIL Image format, and the labels are
integers. For training, we need the features as normalized tensors, and
the labels as one-hot encoded tensors. To make these transformations, we
use `ToTensor` and `Lambda`.

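# **Explaining scatter_**:

In PyTorch, `scatter_` is a very useful "slot-filling" function. Its core logic is: **given the indices (index) you provide, it scatters the given values (value) into the specified positions of the original tensor.**

The trailing underscore `_` marks an **in-place operation**: it modifies the original tensor directly instead of creating a new one.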
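
To make the in-place behavior concrete, the following minimal sketch contrasts `scatter_` with its out-of-place counterpart `scatter`. The tensor names (`base`, `out`, `untouched`, `result`) are illustrative and do not come from the tutorial.

```python
import torch

base = torch.zeros(5)
index = torch.tensor([1, 3])

# Out-of-place: returns a new tensor and leaves `base` unchanged.
out = base.scatter(0, index, 1.0)
untouched = base.clone()  # snapshot of `base` before the in-place call

# In-place: writes into `base` itself and returns that same tensor.
result = base.scatter_(0, index, 1.0)

print(untouched)       # still all zeros
print(base)            # ones at positions 1 and 3
print(result is base)  # True
```

Because the in-place form overwrites its input, clone the tensor first if the original values are still needed.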

---

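### 1. Function prototype

```python
tensor.scatter_(dim, index, src)    # write values taken from a tensor
tensor.scatter_(dim, index, value)  # write the same scalar everywhere
```

* **`dim`**: the dimension along which to index.
* **`index`**: an index tensor (`int64`) that tells the operation which positions to write to.
* **`src` / `value`**: the data to write, either a tensor (`src`) or a scalar (`value`).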
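
The difference between the tensor form and the scalar form can be sketched as follows. The shapes and numbers are made up for illustration, and `a`, `b`, and `src` are hypothetical names.

```python
import torch

# Each row lists the column positions to write to in that row.
index = torch.tensor([[0, 2], [1, 3]])

# Scalar form: every indexed position receives the same value.
a = torch.zeros(2, 4)
a.scatter_(1, index, 9.0)

# Tensor form: with dim=1, position (i, index[i][j]) receives src[i][j].
b = torch.zeros(2, 4)
src = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b.scatter_(1, index, src)

print(a)  # rows: [9, 0, 9, 0] and [0, 9, 0, 9]
print(b)  # rows: [1, 0, 2, 0] and [0, 3, 0, 4]
```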

---

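### 2. Intuition: the one-hot encoding example

Suppose we want to encode class 3 and there are 10 classes in total.

1. **Prepare the base tensor:** `torch.zeros(10)` -> `[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]`
2. **Set the coordinate:** `torch.tensor(3)`
3. **Call `scatter_`:**
    * `dim=0`: operate along the only dimension of this vector.
    * `index=3`: target the slot at index 3.
    * `value=1`: write 1 into it.

**Result:** `[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]`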
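
The three steps can be sketched in code as below. As an assumption of this sketch, the label is wrapped in a one-element index tensor, `torch.tensor([label])`, rather than the zero-dimensional `torch.tensor(3)` shown in the steps. The names `num_classes` and `label` are illustrative.

```python
import torch

num_classes = 10
label = 3

# Step 1: the all-zero base vector.
one_hot = torch.zeros(num_classes)

# Steps 2 and 3: write 1 at the position named by the label.
one_hot.scatter_(0, torch.tensor([label]), 1.0)

print(one_hot)
# tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
```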

---

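### 3. Mathematical form

For a two-dimensional tensor, calling `self.scatter_(0, index, src)` performs the following assignment for every position $(i, j)$ of `index`:

$$\text{self}[\text{index}[i][j]][j] = \text{src}[i][j]$$

When filling with a fixed value, as the one-hot code in this tutorial does with `value=1`:

$$\text{self}[\text{index}[i][j]][j] = \text{value}$$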
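
The `dim=0` rule for two-dimensional tensors can be checked against a plain Python loop. The helper `scatter_dim0_reference` is a hypothetical name written for this sketch. It is not part of PyTorch and covers only the two-dimensional `dim=0` case.

```python
import torch

def scatter_dim0_reference(target, index, src):
    """Loop-based equivalent of target.scatter_(0, index, src) for 2-D tensors."""
    out = target.clone()
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            # index[i][j] selects the destination row; the column stays j.
            out[index[i, j].item(), j] = src[i, j]
    return out

target = torch.zeros(3, 4)
index = torch.tensor([[2, 0, 1, 2]])
src = torch.tensor([[10.0, 20.0, 30.0, 40.0]])

expected = scatter_dim0_reference(target, index, src)
actual = target.clone().scatter_(0, index, src)

print(torch.equal(expected, actual))  # True
print(actual)
```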


---

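### 4. Why use `scatter_` for one-hot encoding?

In deep learning we often need to turn an integer label (such as 3) into a vector (such as `[0,0,0,1...]`).

* **With a Python loop:** each label is written one at a time, which becomes slow as the batch grows.
* **With `scatter_`:** the whole batch is handled by one optimized tensor operation that can also run on the GPU. Even a batch of 1024 samples is converted to one-hot form with a single line of code.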

---

### 5. Advanced example: batch processing

Suppose you have a batch of labels `labels = torch.tensor([0, 2])` and want to convert it to one-hot form:

```python
import torch

batch_size = 2
num_classes = 5
labels = torch.tensor([0, 2]).view(-1, 1)  # reshape to a column vector [[0], [2]]

# Create an all-zero matrix
one_hot = torch.zeros(batch_size, num_classes)

# Fill along dimension 1 (across each row)
one_hot.scatter_(1, labels, 1)

print(one_hot)
# Output:
# tensor([[1., 0., 0., 0., 0.],  <- index 0 set to 1
#         [0., 0., 1., 0., 0.]]) <- index 2 set to 1
```

---

### Summary

You can think of `scatter_` as a **"precision strike"**: you hold a set of bullets (`value`), aim at the coordinates marked on the map (`index`), and fire them into the target (`self`).



In [None]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()
==========

[ToTensor](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor)
converts a PIL image or NumPy `ndarray` into a `FloatTensor` and scales
the image's pixel intensity values to the range \[0., 1.\].


Lambda Transforms
=================

Lambda transforms apply any user-defined lambda function. Here, we
define a function that turns the integer label into a one-hot encoded
tensor. It first creates a zero tensor of size 10 (the number of labels
in our dataset) and then calls
[scatter\_](https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html),
which assigns `value=1` at the index given by the label `y`.


In [None]:
target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

------------------------------------------------------------------------


Further Reading
===============

-   [torchvision.transforms
    API](https://pytorch.org/vision/stable/transforms.html)
