# 实战项目2：图像分类 Web 应用

## 项目概述

使用深度学习模型构建一个图像分类 Web 应用，从模型训练到 API 部署。

### 学习目标
- 训练图像分类模型
- 使用 FastAPI 构建 REST API
- 实现模型部署
- 创建简单的前端界面

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

Using device: cpu


## 第一部分：模型训练

### 1. 数据准备

In [2]:
# Preprocessing pipelines.
# Both pipelines finish with the same ToTensor + per-channel Normalize steps.
# NOTE(review): these mean/std values are presumably the usual ImageNet statistics,
# chosen to match the pretrained ResNet18 used later - confirm.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop/flip/rotation/colour jitter for augmentation.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
)

# Evaluation pipeline: deterministic resize followed by a centre crop.
test_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
)

# CIFAR-10 is used as the example dataset; it is downloaded into ./data on first use.
cifar_kwargs = {'root': './data', 'download': True}
train_dataset = datasets.CIFAR10(train=True, transform=train_transform, **cifar_kwargs)
test_dataset = datasets.CIFAR10(train=False, transform=test_transform, **cifar_kwargs)

# Data loaders.
# num_workers note: on Windows or inside Jupyter, set num_workers to 0 if loading hangs.
batch_size = 64
loader_kwargs = {'batch_size': batch_size, 'num_workers': 2}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Human-readable class names, indexed by label id.
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
print(f"类别数: {len(class_names)}")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [1:43:26<00:00, 27470.44it/s]  


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
训练集大小: 50000
测试集大小: 10000
类别数: 10


In [None]:
# Visualise a sample of the training data.
def show_images(dataloader, class_names, n=8):
    """Display the first ``n`` images of one batch in a two-row grid.

    Args:
        dataloader: Iterable whose first item is an ``(images, labels)`` batch.
            The images are un-normalised with the same mean/std used by the
            transforms in this notebook before being shown.
        class_names: Sequence mapping a label id to its display name.
        n: Number of images to show. Clamped to the batch size; odd values
            are supported (the spare grid cell is left blank).

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1")

    images, labels = next(iter(dataloader))

    # Never index past the end of the batch.
    n = min(n, len(images))
    # Ceil-divide so an odd n still gets enough grid cells (the original
    # 2 x (n // 2) grid was one cell short and raised IndexError).
    n_cols = (n + 1) // 2

    _, axes = plt.subplots(2, n_cols, figsize=(15, 6))
    axes = axes.flatten()

    # Hoisted out of the loop: the constants used to undo Normalize.
    std = np.array([0.229, 0.224, 0.225])
    mean = np.array([0.485, 0.456, 0.406])

    # Hide every axis up front so unused cells stay blank.
    for ax in axes:
        ax.axis('off')

    for i in range(n):
        # Reorder the axes from (0, 1, 2) to (1, 2, 0) for imshow, then undo Normalize.
        img = images[i].numpy().transpose((1, 2, 0))
        img = np.clip(img * std + mean, 0, 1)

        axes[i].imshow(img)
        axes[i].set_title(class_names[int(labels[i])])

    plt.tight_layout()
    plt.show()

show_images(train_loader, class_names)

### 2. 模型定义

In [None]:
# 使用预训练的 ResNet18 进行迁移学习
from torchvision.models import resnet18, ResNet18_Weights

def create_model(num_classes, pretrained=True):
    """Build a ResNet18 with a new classification head.

    Every parameter of the loaded backbone is frozen (not just the early
    layers); the replacement ``fc`` head is created afterwards, so its
    parameters remain trainable.

    Args:
        num_classes: Output size of the final linear layer.
        pretrained: When true, load the ``IMAGENET1K_V1`` weights; otherwise
            the backbone is built with ``weights=None``.

    Returns:
        The ResNet18 module with its ``fc`` attribute replaced.
    """
    backbone = resnet18(
        weights=ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    )

    # Freeze all existing parameters of the backbone.
    for p in backbone.parameters():
        p.requires_grad_(False)

    # New head: dropout -> 256-unit hidden layer -> ReLU -> dropout -> logits.
    head_layers = [
        nn.Dropout(0.5),
        nn.Linear(backbone.fc.in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(256, num_classes),
    ]
    backbone.fc = nn.Sequential(*head_layers)

    return backbone

model = create_model(num_classes=len(class_names))
model = model.to(device)
print("模型结构（最后几层）:")
print(model.fc)

### 3. 模型训练

In [None]:
# Loss function, optimizer and learning-rate schedule.
criterion = nn.CrossEntropyLoss()
# Only the parameters of the replaced ``fc`` head are handed to the optimizer.
optimizer = optim.Adam(params=model.fc.parameters(), lr=1e-3)
# Halve the learning rate every 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.5)

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module to optimise; switched to training mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable. It is assumed to return the *mean* loss of
            the batch (the default for ``nn.CrossEntropyLoss``).
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. The loss is weighted
        by batch size, so a short final batch does not skew the average.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_samples = labels.size(0)
        # Undo the per-batch mean so every sample carries equal weight.
        loss_sum += loss.item() * n_samples
        total += n_samples
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        raise ValueError("dataloader yielded no samples")

    return loss_sum / total, correct / total

def evaluate(model, dataloader, criterion, device):
    """Compute loss and accuracy of ``model`` over ``dataloader`` without training.

    Args:
        model: Module to evaluate; switched to eval mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable. It is assumed to return the *mean* loss of
            the batch (the default for ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. The loss is weighted
        by batch size, so a short final batch does not skew the average.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            n_samples = labels.size(0)
            # Undo the per-batch mean so every sample carries equal weight.
            loss_sum += loss.item() * n_samples
            total += n_samples
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        raise ValueError("dataloader yielded no samples")

    return loss_sum / total, correct / total

In [None]:
# Training driver: run a fixed number of epochs, record the metrics of each one
# and keep the weights of the best-scoring epoch on disk.
num_epochs = 10
train_losses, test_losses = [], []
train_accs, test_accs = [], []

best_acc = 0.0

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    # Keep the per-epoch history.
    train_losses += [train_loss]
    test_losses += [test_loss]
    train_accs += [train_acc]
    test_accs += [test_acc]

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    print(
        f'Epoch [{epoch+1}/{num_epochs}]',
        f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}',
        f'  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}',
        sep='\n',
    )

    # Checkpoint only when this epoch beats the best accuracy so far.
    # NOTE(review): the checkpoint is selected on the test set itself, so the
    # reported best accuracy is an optimistic estimate.
    if test_acc <= best_acc:
        continue
    best_acc = test_acc
    torch.save(model.state_dict(), 'best_model.pth')

print(f'\nBest Test Accuracy: {best_acc:.4f}')

In [None]:
# Plot the recorded per-epoch curves: loss on the left, accuracy on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

curve_panels = [
    # (axis, [(series, legend label), ...], y-axis label, title)
    (axes[0], [(train_losses, 'Train Loss'), (test_losses, 'Test Loss')],
     'Loss', 'Training and Test Loss'),
    (axes[1], [(train_accs, 'Train Acc'), (test_accs, 'Test Acc')],
     'Accuracy', 'Training and Test Accuracy'),
]

for ax, curves, y_label, title in curve_panels:
    for series, legend_label in curves:
        ax.plot(series, label=legend_label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 第二部分：API 部署

### 4. 创建推理类

In [None]:
class ImageClassifier:
    """Inference wrapper around the fine-tuned ResNet18 classifier.

    The model is built with ``create_model`` (defined earlier in this file),
    loaded from a saved state dict and put in eval mode. ``predict`` applies
    a resize / centre-crop / normalise pipeline before the forward pass.
    """

    def __init__(self, model_path, class_names, device='cpu'):
        """Load the model weights and set up preprocessing.

        Args:
            model_path: Path to a state dict saved with ``torch.save``.
            class_names: Sequence mapping a class index to its name; its
                length determines the size of the model's output layer.
            device: Device spec accepted by ``torch.device``.
        """
        self.device = torch.device(device)
        self.class_names = class_names

        # Rebuild the architecture without downloading pretrained weights,
        # then restore the trained parameters.
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        self.model = create_model(num_classes=len(class_names), pretrained=False)
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

        # Preprocessing applied to every image before inference.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def predict(self, image, top_k=5):
        """Classify a single image.

        Args:
            image: A file path (``str`` or path-like) or an already opened
                PIL image. The image is converted to RGB in both cases, so
                greyscale or RGBA inputs match the 3-channel normalisation.
            top_k: Maximum number of predictions to return. Clamped to the
                number of classes the model outputs.

        Returns:
            A list of ``{'class': name, 'probability': float}`` dicts sorted
            by descending probability.

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        import os

        if top_k < 1:
            raise ValueError("top_k must be at least 1")

        if isinstance(image, (str, os.PathLike)):
            # Close the file handle as soon as the pixel data is converted.
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        else:
            image = image.convert('RGB')

        # Preprocess and add the batch dimension.
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = self.model(image_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

        # topk raises if k exceeds the number of classes, so clamp it.
        k = min(top_k, probabilities.numel())
        top_prob, top_idx = torch.topk(probabilities, k)

        return [
            {'class': self.class_names[idx], 'probability': prob}
            for prob, idx in zip(top_prob.tolist(), top_idx.tolist())
        ]

# Example usage of the inference class (kept commented out; running it requires
# 'best_model.pth' and 'test_image.jpg' to exist).
# classifier = ImageClassifier('best_model.pth', class_names)
# results = classifier.predict('test_image.jpg')
print("ImageClassifier 类已定义")

### 5. FastAPI 代码

以下是 FastAPI 应用代码，保存为 `app.py`

In [None]:
# Source of the FastAPI application; it is written to app.py below.
# Compared with a naive upload handler, the embedded /predict endpoint
# - tolerates a missing Content-Type on the uploaded part,
# - answers 400 (not 500) when the upload cannot be decoded as an image,
# - derives the head size and the top-k count from CLASS_NAMES.
fastapi_code = '''
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import io
import torch
from torchvision import transforms, models
import torch.nn as nn

app = FastAPI(
    title="Image Classification API",
    description="CIFAR-10 图像分类 API",
    version="1.0.0"
)

# 允许跨域
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# 类别名称
CLASS_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# 加载模型
def load_model():
    model = models.resnet18(weights=None)
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, 256),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(256, len(CLASS_NAMES))
    )
    model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
    model.eval()
    return model

model = load_model()

# 图像预处理
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

@app.get("/")
def root():
    return {"message": "Image Classification API", "status": "running"}

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # 验证文件类型
    # content_type is None when the client omits the header
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # 读取图像
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as exc:
        # An undecodable or truncated upload is a client error
        raise HTTPException(status_code=400, detail="Invalid or corrupted image file") from exc

    # 预处理
    image_tensor = transform(image).unsqueeze(0)

    # 推理
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

    # 获取 Top-5 结果
    top_k = min(5, len(CLASS_NAMES))
    top5_prob, top5_idx = torch.topk(probabilities, top_k)

    predictions = []
    for prob, idx in zip(top5_prob.tolist(), top5_idx.tolist()):
        predictions.append({
            "class": CLASS_NAMES[idx],
            "probability": prob
        })

    return JSONResponse(content={"predictions": predictions})

@app.get("/classes")
def get_classes():
    return {"classes": CLASS_NAMES}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
'''

# Write the application source next to the notebook.
with open('app.py', 'w', encoding='utf-8') as f:
    f.write(fastapi_code)

print("FastAPI 应用代码已保存到 app.py")
print("\n运行方式: python app.py 或 uvicorn app:app --reload")
print("\n运行方式: python app.py 或 uvicorn app:app --reload")

### 6. HTML 前端界面

保存为 `index.html`

In [None]:
html_code = '''
<!DOCTYPE html>
<html lang="zh">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>图像分类器</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
            background: #f5f5f5;
        }
        h1 { color: #333; text-align: center; }
        .container {
            background: white;
            padding: 30px;
            border-radius: 10px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }
        .upload-area {
            border: 2px dashed #ccc;
            padding: 40px;
            text-align: center;
            border-radius: 10px;
            cursor: pointer;
            transition: border-color 0.3s;
        }
        .upload-area:hover { border-color: #007bff; }
        #preview {
            max-width: 300px;
            margin: 20px auto;
            display: none;
        }
        #preview img { width: 100%; border-radius: 10px; }
        button {
            background: #007bff;
            color: white;
            border: none;
            padding: 12px 30px;
            border-radius: 5px;
            cursor: pointer;
            font-size: 16px;
            display: block;
            margin: 20px auto;
        }
        button:hover { background: #0056b3; }
        button:disabled { background: #ccc; cursor: not-allowed; }
        #results {
            margin-top: 20px;
            display: none;
        }
        .prediction {
            display: flex;
            align-items: center;
            margin: 10px 0;
        }
        .class-name { width: 100px; font-weight: bold; }
        .prob-bar {
            flex: 1;
            height: 25px;
            background: #e9ecef;
            border-radius: 5px;
            overflow: hidden;
        }
        .prob-fill {
            height: 100%;
            background: linear-gradient(90deg, #007bff, #00d4ff);
            display: flex;
            align-items: center;
            justify-content: flex-end;
            padding-right: 10px;
            color: white;
            font-size: 12px;
        }
        .loading { display: none; text-align: center; }
    </style>
</head>
<body>
    <h1>CIFAR-10 图像分类器</h1>
    <div class="container">
        <div class="upload-area" onclick="document.getElementById('fileInput').click()">
            <p>点击或拖拽图片到此处上传</p>
            <input type="file" id="fileInput" accept="image/*" style="display:none">
        </div>
        <div id="preview"><img id="previewImage"></div>
        <button id="predictBtn" disabled>开始分类</button>
        <div class="loading" id="loading">分类中...</div>
        <div id="results"></div>
    </div>

    <script>
        const fileInput = document.getElementById("fileInput");
        const preview = document.getElementById("preview");
        const previewImage = document.getElementById("previewImage");
        const predictBtn = document.getElementById("predictBtn");
        const results = document.getElementById("results");
        const loading = document.getElementById("loading");

        // Load the chosen file into the preview pane and enable the classify button.
        function showPreview(file) {
            if (!file) return;
            const reader = new FileReader();
            reader.onload = function(e) {
                previewImage.src = e.target.result;
                preview.style.display = "block";
                predictBtn.disabled = false;
                results.style.display = "none";
            };
            reader.readAsDataURL(file);
        }

        fileInput.onchange = function(e) {
            showPreview(e.target.files[0]);
        };

        // Drag-and-drop support: the upload area's label offers it, so accept dropped images.
        const uploadArea = document.querySelector(".upload-area");
        uploadArea.ondragover = function(e) {
            // Cancelling dragover marks the element as a valid drop target.
            e.preventDefault();
        };
        uploadArea.ondrop = function(e) {
            // Stop the browser from navigating to the dropped file.
            e.preventDefault();
            const dropped = e.dataTransfer.files;
            if (!dropped.length || !dropped[0].type.startsWith("image/")) return;
            // Mirror the drop into the file input, which the classify handler reads from.
            fileInput.files = dropped;
            showPreview(dropped[0]);
        };

        // Upload the selected image to the API and render the returned predictions.
        predictBtn.onclick = async function() {
            const file = fileInput.files[0];
            if (!file) return;

            loading.style.display = "block";
            predictBtn.disabled = true;

            const formData = new FormData();
            formData.append("file", file);

            try {
                const response = await fetch("http://localhost:8000/predict", {
                    method: "POST",
                    body: formData
                });
                // fetch() only rejects on network failure; HTTP error statuses
                // must be checked explicitly before using the body.
                if (!response.ok) {
                    const errBody = await response.json().catch(() => ({}));
                    const detail = typeof errBody.detail === "string"
                        ? errBody.detail
                        : "HTTP " + response.status;
                    throw new Error(detail);
                }
                const data = await response.json();
                displayResults(data.predictions);
            } catch (error) {
                alert("分类失败: " + error.message);
            } finally {
                // Always restore the UI, whatever happened above.
                loading.style.display = "none";
                predictBtn.disabled = false;
            }
        };

        // Render one labelled probability bar per prediction under a heading.
        function displayResults(predictions) {
            results.innerHTML = "<h3>分类结果:</h3>";
            let markup = "";
            for (const { class: label, probability } of predictions) {
                const pct = (probability * 100).toFixed(1);
                markup += `
                    <div class="prediction">
                        <span class="class-name">${label}</span>
                        <div class="prob-bar">
                            <div class="prob-fill" style="width: ${pct}%">${pct}%</div>
                        </div>
                    </div>
                `;
            }
            results.innerHTML += markup;
            results.style.display = "block";
        }
    </script>
</body>
</html>
'''

# Persist the page source as index.html (UTF-8).
with open('index.html', mode='w', encoding='utf-8') as html_file:
    html_file.write(html_code)

print("HTML 前端代码已保存到 index.html")
print("启动 API 后用浏览器打开 index.html 即可使用")

## 7. 项目总结

### 项目结构

```
image-classifier/
├── app.py              # FastAPI 应用
├── best_model.pth      # 训练好的模型
├── index.html          # 前端界面
└── requirements.txt    # 依赖
```

### requirements.txt

```
torch
torchvision
fastapi
uvicorn
python-multipart
pillow
```

### 运行步骤

1. 安装依赖: `pip install -r requirements.txt`
2. 确认 `best_model.pth` 与 `app.py` 位于同一目录（该文件由第一部分的训练循环保存，`app.py` 启动时会加载它）
3. 启动 API: `python app.py`
4. 打开 `index.html` 进行测试

### 扩展方向

1. 添加更多类别
2. 使用更大的模型
3. 部署到云服务器
4. 添加用户认证
5. 优化推理速度