In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms, datasets
from tqdm import tqdm
from torchsummary import summary
from PIL import Image
import torch.nn.functional as F

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
from torchvision.models.efficientnet import EfficientNet_V2_S_Weights
from torchvision.models.mobilenetv3 import MobileNet_V3_Small_Weights

In [4]:
print("CUDA 可用:", torch.cuda.is_available())
print("PyTorch 版本:", torch.__version__)

CUDA 可用: True
PyTorch 版本: 2.6.0+cu118


In [12]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # 读取类别文件
        classes_path = os.path.join(root_dir, '_classes.txt')
        with open(classes_path, 'r') as f:
            self.classes = [line.strip() for line in f.readlines()]
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # 读取标注文件
        annotations_path = os.path.join(root_dir, '_annotations.txt')
        with open(annotations_path, 'r') as f:
            for line in f.readlines():
                line = line.strip()
                if not line:
                    continue
                parts = line.split()
                if len(parts) < 2:
                    continue
                image_name = parts[0]
                data_part = parts[1].split(',')
                if len(data_part) < 5:
                    continue
                try:
                    target = int(data_part[-1])
                    self.samples.append((image_name, target))
                except Exception as e:
                    print(f"Skipped invalid sample: {line} - Error: {e}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_name, target = self.samples[idx]
        image_path = os.path.join(self.root_dir, image_name)

        try:
            image = Image.open(image_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return self[random.randint(0, len(self) - 1)] 

        if self.transform:
            image = self.transform(image)

        return image, target

In [13]:
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据集
train_dataset = CustomDataset(
    root_dir=r'C:\Users\lenovo\Desktop\archive\train',
    transform=transform
)
test_dataset = CustomDataset(
    root_dir=r'C:\Users\lenovo\Desktop\archive\test',
    transform=transform
)


In [14]:
print("Train dataset length:", len(train_dataset))

# 检查前几个样本
for i in range(5):
    img, label = train_dataset[i]
    print(f"Sample {i}: Image shape={img.shape}, Label={label}")


Train dataset length: 11948
Sample 0: Image shape=torch.Size([3, 224, 224]), Label=3
Sample 1: Image shape=torch.Size([3, 224, 224]), Label=3
Sample 2: Image shape=torch.Size([3, 224, 224]), Label=2
Sample 3: Image shape=torch.Size([3, 224, 224]), Label=0
Sample 4: Image shape=torch.Size([3, 224, 224]), Label=3


In [16]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)

# 加载预训练教师模型（EfficientNetV2）
teacher_model = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
teacher_model.classifier[1] = nn.Linear(teacher_model.classifier[1].in_features, 6)
teacher_model = teacher_model.to(device)
teacher_model.eval()

# 初始化学生模型（MobileNetV3）
student_model = models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
student_model.classifier[3] = nn.Linear(student_model.classifier[3].in_features, 6)
student_model = student_model.to(device)

# 损失函数和优化器
hard_loss = nn.CrossEntropyLoss()
soft_loss = nn.KLDivLoss(reduction="batchmean")
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)

# 蒸馏参数
temp = 7
alpha = 0.3

# 训练函数
def train_kd(epochs):
    best_acc = 0.0

    for epoch in range(epochs):
        # 训练阶段
        student_model.train()
        train_loss = 0.0

        for data, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            data = data.to(device)
            targets = targets.to(device)

            # 教师模型预测
            with torch.no_grad():
                teacher_preds = teacher_model(data)

            # 学生模型预测
            student_preds = student_model(data)

            # 计算损失
            student_loss = hard_loss(student_preds, targets)
            distill_loss = soft_loss(
                F.log_softmax(student_preds / temp, dim=1),
                F.softmax(teacher_preds / temp, dim=1)
            )
            loss = alpha * student_loss + (1 - alpha) * distill_loss

            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * data.size(0)

        # 验证阶段
        student_model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for data, targets in test_loader:
                data = data.to(device)
                targets = targets.to(device)

                outputs = student_model(data)
                _, preds = torch.max(outputs, 1)

                correct += (preds == targets).sum().item()
                total += targets.size(0)

        val_acc = correct / total
        train_loss = train_loss / len(train_dataset)

        print(f"Epoch {epoch+1} - Loss: {train_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

        # 保存最佳模型
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(student_model.state_dict(), 'best_student.pth')

    print(f"Best Validation Accuracy: {best_acc*100:.2f}%")

# 开始训练
if __name__ == '__main__':
    train_kd(epochs=6)

Epoch 1/6: 100%|█████████████████████████████████████████████████████████████████████| 747/747 [01:48<00:00,  6.91it/s]


Epoch 1 - Loss: 0.1728 | Val Acc: 92.59%


Epoch 2/6: 100%|█████████████████████████████████████████████████████████████████████| 747/747 [02:04<00:00,  6.01it/s]


Epoch 2 - Loss: 0.0943 | Val Acc: 95.13%


Epoch 3/6: 100%|█████████████████████████████████████████████████████████████████████| 747/747 [01:58<00:00,  6.32it/s]


Epoch 3 - Loss: 0.0784 | Val Acc: 94.11%


Epoch 4/6: 100%|█████████████████████████████████████████████████████████████████████| 747/747 [01:56<00:00,  6.43it/s]


Epoch 4 - Loss: 0.0689 | Val Acc: 95.43%


Epoch 5/6: 100%|█████████████████████████████████████████████████████████████████████| 747/747 [01:57<00:00,  6.38it/s]


Epoch 5 - Loss: 0.0625 | Val Acc: 95.94%


Epoch 6/6: 100%|█████████████████████████████████████████████████████████████████████| 747/747 [01:58<00:00,  6.33it/s]


Epoch 6 - Loss: 0.0577 | Val Acc: 95.74%
Best Validation Accuracy: 95.94%


In [18]:
teacher_model = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
teacher_model.classifier[1] = nn.Linear(teacher_model.classifier[1].in_features, 6)
teacher_model = teacher_model.to(device)

# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.parameters(), lr=1e-4)

# 训练函数
def train_teacher(epochs):
    best_acc = 0.0
    for epoch in range(epochs):
        teacher_model.train()
        train_loss = 0.0

        for data, targets in tqdm(train_loader, desc=f"Teacher Epoch {epoch+1}/{epochs}"):
            data = data.to(device)
            targets = targets.to(device)

            preds = teacher_model(data)
            loss = criterion(preds, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * data.size(0)

        # 验证阶段
        teacher_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, targets in test_loader:
                data = data.to(device)
                targets = targets.to(device)
                preds = teacher_model(data)
                _, predicted = torch.max(preds, 1)
                correct += (predicted == targets).sum().item()
                total += targets.size(0)
        val_acc = correct / total

        print(f"Teacher Epoch {epoch+1} - Loss: {train_loss/len(train_dataset):.4f} | Val Acc: {val_acc*100:.2f}%")

        # 保存最佳模型
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(teacher_model.state_dict(), 'best_teacher.pth')

# 开始训练
if __name__ == '__main__':
    train_teacher(epochs=6)

Teacher Epoch 1/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [03:33<00:00,  3.49it/s]


Teacher Epoch 1 - Loss: 0.3421 | Val Acc: 95.94%


Teacher Epoch 2/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [04:11<00:00,  2.97it/s]


Teacher Epoch 2 - Loss: 0.1091 | Val Acc: 96.24%


Teacher Epoch 3/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [04:23<00:00,  2.83it/s]


Teacher Epoch 3 - Loss: 0.0744 | Val Acc: 96.95%


Teacher Epoch 4/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [04:29<00:00,  2.77it/s]


Teacher Epoch 4 - Loss: 0.0541 | Val Acc: 96.45%


Teacher Epoch 5/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [04:27<00:00,  2.79it/s]


Teacher Epoch 5 - Loss: 0.0374 | Val Acc: 97.06%


Teacher Epoch 6/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [04:28<00:00,  2.78it/s]


Teacher Epoch 6 - Loss: 0.0348 | Val Acc: 97.77%


In [20]:
# 初始化学生模型
student_model = models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
student_model.classifier[3] = nn.Linear(student_model.classifier[3].in_features, 6)
student_model = student_model.to(device)

# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)

# 训练函数
def train_student_from_scratch(epochs):
    best_acc = 0.0
    for epoch in range(epochs):
        student_model.train()
        train_loss = 0.0

        for data, targets in tqdm(train_loader, desc=f"Student Epoch {epoch+1}/{epochs}"):
            data = data.to(device)
            targets = targets.to(device)

            preds = student_model(data)
            loss = criterion(preds, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * data.size(0)

        # 验证阶段
        student_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, targets in test_loader:
                data = data.to(device)
                targets = targets.to(device)
                preds = student_model(data)
                _, predicted = torch.max(preds, 1)
                correct += (predicted == targets).sum().item()
                total += targets.size(0)
        val_acc = correct / total

        print(f"Student Epoch {epoch+1} - Loss: {train_loss/len(train_dataset):.4f} | Val Acc: {val_acc*100:.2f}%")

        # 保存最佳模型
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(student_model.state_dict(), 'best_student_scratch.pth')

# 开始训练
if __name__ == '__main__':
    train_student_from_scratch(epochs=6)

Student Epoch 1/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [01:38<00:00,  7.60it/s]


Student Epoch 1 - Loss: 0.4927 | Val Acc: 91.78%


Student Epoch 2/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [01:39<00:00,  7.54it/s]


Student Epoch 2 - Loss: 0.1882 | Val Acc: 94.01%


Student Epoch 3/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [01:38<00:00,  7.61it/s]


Student Epoch 3 - Loss: 0.1340 | Val Acc: 93.81%


Student Epoch 4/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [01:38<00:00,  7.57it/s]


Student Epoch 4 - Loss: 0.0956 | Val Acc: 92.39%


Student Epoch 5/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [01:37<00:00,  7.67it/s]


Student Epoch 5 - Loss: 0.0753 | Val Acc: 94.72%


Student Epoch 6/6: 100%|█████████████████████████████████████████████████████████████| 747/747 [01:36<00:00,  7.75it/s]


Student Epoch 6 - Loss: 0.0630 | Val Acc: 94.92%
