### 1. 问题分析

Fashion-MNIST 是一组 28×28 的灰度服装图片，每张图片都标注了服装类别标签（共 10 类）。用这个数据集训练模型即可预测服装类型，这是一个典型的有监督学习案例。

1. 确定训练目标： e.g. 识别服装类型
2. 收集训练数据集： e.g. [_Fashion-MNIST_](https://github.com/zalandoresearch/fashion-mnist)
3. 对数据进行标注： e.g. 为每张图片标注服装类型标签
4. 训练： e.g. 用训练集拟合模型（本例先用神经网络提取特征，再用这些特征训练随机森林分类器）
5. 评估： e.g. 在测试集上进行预测，计算分类准确率


In [5]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

### 2. 数据加载与分析


In [3]:
# ┌────────────────┐
# │ 1. Transforms  │
# └────────────────┘
#
# Define the preprocessing applied to every image. This matters a lot!
#   - transforms.Resize((28, 28)): force every image to 28x28. Fashion-MNIST
#     images are already 28x28, so this is a no-op here; it only guards
#     against inputs of a different size.
#   - transforms.ToTensor(): convert a PIL.Image / NumPy ndarray into a float
#     tensor of shape [C, H, W] with values in [0.0, 1.0]. Fashion-MNIST is
#     grayscale, so C == 1 and each image becomes [1, 28, 28].
#   - transforms.Normalize(mean, std): per channel, compute (x - mean) / std.
#     With the single grayscale channel, mean=(0.5,) and std=(0.5,) map
#     [0.0, 1.0] onto [-1.0, 1.0].
# Note: Normalize operates on tensors, so ToTensor() must run before it and
# cannot simply be removed. If a model needs a different input layout (e.g.
# 3 channels), add a transform for that (such as
# transforms.Grayscale(num_output_channels=3) before ToTensor()) and adjust
# the model's input layer to match.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),  # one grayscale channel
])


In [None]:

# ┌──────────────────────┐
# │ Download / load data │
# └──────────────────────┘
#
# 2. Load the data with datasets.FashionMNIST
#    - root: directory where the data is stored
#    - train: True loads the training split, False loads the test split
#    - download: True downloads the files if they are not already under root
#    - transform: preprocessing applied to each image when it is fetched
#
# Both splits must use the same `transform`. Without it the dataset yields
# PIL images: they have no `.shape`, and DataLoader's default collate
# function cannot batch them.

# Training split (60,000 samples)
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
# Test split (10,000 samples)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

In [6]:

# ┌───────────────────┐
# │ Create DataLoader │
# └───────────────────┘
#
# 3. Wrap each Dataset in a DataLoader so training and inference can iterate
#    over it in mini-batches.
#    - batch_size: number of samples per batch
#    - shuffle: reshuffle the samples at the start of every epoch; enabled
#      for training only, the test order stays fixed
#    - num_workers: number of worker subprocesses used for loading
#      (0 means loading happens in the main process)
batch_size = 64  # a common mini-batch size
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [7]:
# ┌──────────────────┐
# │ Inspect the data │
# └──────────────────┘
#
# Check the training-set size and what a single (transformed) sample looks like.
print(f"Train dataset size: {len(train_dataset)}")
example_image, example_label = train_dataset[0]
print(f"Shape of one image tensor: {example_image.shape}")        # [C, H, W] -> [1, 28, 28] (grayscale)
print(f"Data type of one image tensor: {example_image.dtype}")   # torch.float32
print(f"Value range of one image tensor: [{example_image.min().item()}, {example_image.max().item()}]")  # within [-1.0, 1.0] after Normalize
print(f"Training labels are initially integers (first 5): {train_dataset.targets[:5]}")  # tensor([9, 0, 0, 3, 0])
print(f"FashionMNIST class names:\n{train_dataset.classes}")  # ['T-shirt/top', 'Trouser', ...]

Train dataset size: 60000
Shape of one image tensor: torch.Size([1, 28, 28])
Data type of one image tensor: torch.float32
Value range of one image tensor: [-1.0, 1.0]
Training labels are initially integers (first 5): tensor([9, 0, 0, 3, 0])
FashionMNIST class names:
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


In [8]:
# Check how many batches the DataLoader yields in one epoch.
print(f"\nNumber of batches in train_loader (one epoch): {len(train_loader)}")
print("Example batch shapes from train_loader:")
# next(iter(...)) fetches exactly one batch; there is no need to loop and break.
batch_idx = 0
images, labels = next(iter(train_loader))
print(f"Batch {batch_idx}: images shape: {images.shape}, labels shape: {labels.shape}")  # [B, C, H, W], [B]
print(f"First image in batch min/max: {images[0].min().item():.2f}, {images[0].max().item():.2f}")


Number of batches in train_loader (one epoch): 938
Example batch shapes from train_loader:
Batch 0: images shape: torch.Size([64, 1, 28, 28]), labels shape: torch.Size([64])
First image in batch min/max: -1.00, 1.00


### 3. 模型定义与特征提取

In [9]:
import torch.nn as nn
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier: flatten -> Linear -> ReLU -> Linear.

    Args:
        in_features: number of input values after flattening
            (default 28*28, one 28x28 Fashion-MNIST image).
        hidden_features: width of the hidden layer (default 128).
        num_classes: number of output logits (default 10 Fashion-MNIST classes).
    """

    def __init__(self, in_features=28 * 28, hidden_features=128, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)  # one logit per class

    def forward(self, x):
        """Return raw (un-normalized) class logits of shape [B, num_classes].

        nn.Flatten keeps dim 0 and flattens the rest, so a batch of
        [B, 1, 28, 28] images becomes [B, 784] before the first linear layer.
        """
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


model = SimpleNN()

In [13]:
import numpy as np

# Feature-extraction helper
def extract_features(model, dataloader):
    """Run `model` over every batch of `dataloader`; collect outputs and labels.

    The model is switched to eval mode (and left in it) and run under
    torch.no_grad(). Inputs are moved to the device of the model's parameters.
    Outputs are moved back to the CPU before NumPy conversion, so this also
    works for a model that lives on a GPU.

    Args:
        model: a torch.nn.Module mapping a batch of inputs to a batch of outputs.
        dataloader: an iterable of (inputs, target) tensor pairs, e.g. a DataLoader.

    Returns:
        (features, labels): NumPy arrays, with all batches concatenated along
        axis 0. The "features" are the model's raw outputs (for SimpleNN, its
        10 logits per sample).

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    model.eval()
    # A parameter-free model has no device to infer from; fall back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    features = []
    labels = []
    with torch.no_grad():
        for inputs, target in dataloader:
            output = model(inputs.to(device))
            features.append(output.cpu().numpy())  # .cpu(): .numpy() needs a CPU tensor
            labels.append(target.cpu().numpy())

    if not features:
        raise ValueError("extract_features: dataloader yielded no batches")
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)

# Extract features for both the training split and the test split.
# NOTE: `model` has not been trained at this point. It was just constructed
# with randomly initialized weights, so these 10-dimensional "features" are a
# fixed random projection of the pixels, not learned representations.
# Training the network first (or feeding raw pixels) would likely give the
# random forest more informative inputs.
train_features, train_labels = extract_features(model, train_loader)
test_features, test_labels = extract_features(model, test_loader)
print(train_features.shape, train_labels.shape, test_features.shape, test_labels.shape)
train_features

(60000, 10) (60000,) (10000, 10) (10000,)


array([[ 0.05564203, -0.14549804,  0.27228755, ..., -0.06906623,
         0.20939602, -0.10270488],
       [ 0.22178867, -0.27685505,  0.07496391, ...,  0.11466201,
         0.18641166, -0.07255948],
       [ 0.0788973 , -0.3209244 ,  0.25007313, ...,  0.08427867,
         0.27098778, -0.16816346],
       ...,
       [ 0.24253392, -0.39807236, -0.01522865, ...,  0.09629457,
         0.25996435,  0.06932414],
       [ 0.10248686, -0.5121202 ,  0.01230931, ...,  0.09224001,
        -0.0522152 ,  0.00888842],
       [-0.18718103, -0.2970221 ,  0.04109677, ...,  0.033646  ,
        -0.15043348, -0.07145831]], shape=(60000, 10), dtype=float32)

### 4. 模型训练

In [14]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# Train a random forest on the extracted features.
# random_state fixes the bootstrap sampling and per-split feature
# subsampling, so the reported test accuracy is reproducible across runs.
rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
rf_model.fit(train_features, train_labels)

Random Forest Accuracy: 66.70%


### 5. 模型评估

In [None]:
# Evaluate on the held-out test split: predict a class for every test sample,
# then report the fraction of predictions that match the true labels.
test_predictions = rf_model.predict(test_features)
accuracy = accuracy_score(y_true=test_labels, y_pred=test_predictions)
print(f"Random Forest Accuracy: {accuracy * 100:.2f}%")