In [1]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

In [2]:
print("CUDA 可用:", torch.cuda.is_available())
print("PyTorch 版本:", torch.__version__)

CUDA 可用: True
PyTorch 版本: 2.6.0+cu118


In [3]:
torch.manual_seed(42)
np.random.seed(42)

In [4]:
config = {
    'data_path': r'E:\archive', 
    'batch_size': 64,
    'lr': 3e-4,
    'epochs': 50,
    'num_classes': 6,
    'temperature': 3,
    'alpha': 0.7,  # 蒸馏损失权重
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

In [5]:
class DrivingBehaviorDataset(Dataset):
    def __init__(self, root_dir, annotations_file, classes_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # 读取类别文件
        with open(classes_file, 'r') as f:
            self.classes = [line.strip() for line in f.readlines()]
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # 读取标注文件
        self.samples = []
        with open(annotations_file, 'r') as f:
            for line in f.readlines():
                parts = line.strip().split(' ')
                if not parts:
                    continue
                try:
                    class_id = int(parts[-1].split(',')[-1])
                    img_name = parts[0]
                    self.samples.append((img_name, class_id))
                except:
                    print(f"跳过无效标注行: {line}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # 最大重试次数
        max_retries = 5
        for attempt in range(max_retries):
            try:
                img_name, target = self.samples[idx]
                img_path = os.path.join(self.root_dir, img_name)
                image = Image.open(img_path).convert('L')  # 加载灰度图像

                if self.transform:
                    image = self.transform(image)

                return image, target  # 成功加载，返回结果

            except Exception as e:
                print(f"加载图像失败 (尝试 {attempt+1}/{max_retries}): {img_path} - {e}")
                idx = (idx + 1) % len(self)  # 移动到下一个索引

        # 达到最大重试次数仍未找到有效图像
        raise RuntimeError("达到最大重试次数，无法加载有效图像。请检查数据集完整性")

In [5]:
class DrivingBehaviorDataset(Dataset):
    def __init__(self, root_dir, annotations_file, classes_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # 读取类别文件
        with open(classes_file, 'r') as f:
            self.classes = [line.strip() for line in f.readlines()]
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # 读取标注文件
        self.samples = []
        with open(annotations_file, 'r') as f:
            for line in f.readlines():
                line = line.strip()
                if not line:
                    continue

                # 按空格分割，取最后一个元素（包含逗号）
                parts = line.split(' ')
                if len(parts) < 2:
                    continue  # 跳过无效行

                # 提取最后一个字段（包含逗号）
                last_part = parts[-1]

                # 再按逗号分割，取最后一个值作为类别ID
                if ',' in last_part:
                    class_id_str = last_part.split(',')[-1]
                else:
                    class_id_str = last_part

                # 转换为整数
                try:
                    class_id = int(class_id_str)
                except ValueError:
                    print(f"跳过无效标注行: {line}")
                    continue

                # 提取图片名
                img_name = parts[0]

                # 添加到样本列表
                self.samples.append((img_name, class_id))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, target = self.samples[idx]
        img_path = os.path.join(self.root_dir, img_name)

        try:
            image = Image.open(img_path).convert('L')
        except Exception as e:
            print(f"加载图像失败: {img_path} - {e}")
            return self.__getitem__((idx + 1) % len(self))

        if self.transform:
            image = self.transform(image)

        return image, target

In [7]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [8]:
def load_data():
    train_dataset = DrivingBehaviorDataset(
        root_dir=os.path.join(config['data_path'], 'train'),
        annotations_file=os.path.join(config['data_path'], 'train', '_annotations.txt'),
        classes_file=os.path.join(config['data_path'], 'train', '_classes.txt'),
        transform=transform
    )
    val_dataset = DrivingBehaviorDataset(
        root_dir=os.path.join(config['data_path'], 'valid'),
        annotations_file=os.path.join(config['data_path'], 'valid', '_annotations.txt'),
        classes_file=os.path.join(config['data_path'], 'valid', '_classes.txt'),
        transform=transform
    )
    test_dataset = DrivingBehaviorDataset(
        root_dir=os.path.join(config['data_path'], 'test'),
        annotations_file=os.path.join(config['data_path'], 'test', '_annotations.txt'),
        classes_file=os.path.join(config['data_path'], 'test', '_classes.txt'),
        transform=transform
    )
    return (
        DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=4),
        DataLoader(val_dataset, batch_size=config['batch_size'], num_workers=4),
        DataLoader(test_dataset, batch_size=config['batch_size'], num_workers=4)
    )

train_loader, val_loader, test_loader = load_data()

跳过无效标注行: gA_4_s2_ir_face_mp4-440_jpg.rf.f0ab9d03d718ac287cac6fca394783d0.jpg

跳过无效标注行: gA_4_s2_ir_face_mp4-440_jpg.rf.fa7c39ef44d869bdab6c298c89442f19.jpg

跳过无效标注行: gB_9_s1_2019-03-07T16-36-24-01-00_ir_face_mp4-401_jpg.rf.4a274f58e714facd11a3693a3325c3d5.jpg

跳过无效标注行: gB_9_s1_2019-03-07T16-36-24-01-00_ir_face_mp4-401_jpg.rf.c507475339c83c2de42f2400987923d2.jpg



In [9]:
teacher_model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.DEFAULT)
teacher_model.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, bias=False)
teacher_model.classifier = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(1280, config['num_classes']),
    nn.Softmax(dim=1)
)
teacher_model = teacher_model.to(config['device'])

In [10]:
student_model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
student_model.classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(1024, config['num_classes']),
    nn.BatchNorm1d(config['num_classes'])
)
student_model = student_model.to(config['device'])

In [11]:
class FeatureHook:
    def __init__(self, model, layer_name):
        self.hook = model._modules.get(layer_name).register_forward_hook(self.hook_fn)
        self.feature = None

    def hook_fn(self, module, input, output):
        self.feature = output

    def close(self):
        self.hook.remove()

teacher_hook = FeatureHook(teacher_model, 'features')
student_hook = FeatureHook(student_model, 'features')

# 动态温度系数（随训练过程递增）
def dynamic_temperature(epoch, max_temp=20):
    return min(config['temperature'] + epoch * 0.5, max_temp)

# 改进的知识蒸馏损失函数（包含中间层蒸馏）
class DistillationLoss(nn.Module):
    def __init__(self, temperature, alpha):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.hard_loss = nn.CrossEntropyLoss()
        self.kl_div = nn.KLDivLoss()
        self.mse_loss = nn.MSELoss()

    def forward(self, y_s, y_t, labels, student_features, teacher_features):
        # 输出层蒸馏
        soft_loss = self.kl_div(
            torch.log_softmax(y_s / self.temperature, dim=1),
            torch.softmax(y_t / self.temperature, dim=1)
        )
        hard_loss = self.hard_loss(y_s, labels)
        
        feature_loss = self.mse_loss(student_features, teacher_features)
        
        return (1 - self.alpha) * soft_loss * (self.temperature**2) + self.alpha * hard_loss + 0.3 * feature_loss

In [12]:
def train_epoch(model, teacher, loader, optimizer, loss_fn, device, epoch):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        with torch.no_grad():
            teacher_logits = teacher(inputs)
            teacher_features = teacher_hook.feature
        
        student_logits = model(inputs)
        student_features = student_hook.feature

        # 动态温度系数
        temp = dynamic_temperature(epoch)
        loss = loss_fn(student_logits, teacher_logits, labels, student_features, teacher_features)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = student_logits.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    return total_loss / len(loader), correct / total

In [13]:
@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)

        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    return correct / total, all_preds, all_labels

In [14]:
def plot_confusion_matrix(y_true, y_pred, class_names):
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.show()
    plt.close()

In [14]:
def show_predictions(model, data_loader, device, class_names, num_images=6):
    model.eval()
    images_so_far = 0
    plt.figure(figsize=(10, 5))

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size(0)):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'Pred: {class_names[preds[j]]} | True: {class_names[labels[j]]}')
                plt.imshow(inputs[j].cpu().permute(1, 2, 0).squeeze(), cmap='gray')
                if images_so_far == num_images:
                    plt.tight_layout()
                    plt.savefig('prediction_examples.png')
                    plt.show()
                    plt.close()
                    return

In [8]:
train_dataset = DrivingBehaviorDataset(
    root_dir=os.path.join(config['data_path'], 'train'),
    annotations_file=os.path.join(config['data_path'], 'train', '_annotations.txt'),
    classes_file=os.path.join(config['data_path'], 'train', '_classes.txt'),
    transform=transform
)

In [9]:
train_dataset

<__main__.DrivingBehaviorDataset at 0x19644364430>

In [10]:
test_dataset = DrivingBehaviorDataset(
    root_dir=os.path.join(config['data_path'], 'test'),
    annotations_file=os.path.join(config['data_path'], 'test', '_annotations.txt'),
    classes_file=os.path.join(config['data_path'], 'test', '_classes.txt'),
    transform=transform
)

In [9]:
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], num_workers=4)

In [10]:
# 教师模型（EfficientNetV2）
teacher_model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.DEFAULT)  # 推荐使用 DEFAULT
teacher_model.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, bias=False)
teacher_model.classifier[1] = nn.Linear(1280, config['num_classes'])
teacher_model = teacher_model.to(config['device'])


In [11]:
# 学生模型（MobileNetV3 Small）
student_model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
student_model.classifier[3] = nn.Linear(1024, config['num_classes'])
student_model = student_model.to(config['device'])

In [12]:
import torchvision.models as models

In [13]:
print(models.EfficientNet_V2_S_Weights.DEFAULT)

EfficientNet_V2_S_Weights.IMAGENET1K_V1


In [14]:
print(teacher_model.features[0][0].in_channels)

1


In [15]:
class DistillationLoss(nn.Module):
    def __init__(self, temperature, alpha):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.hard_loss = nn.CrossEntropyLoss()
        self.kl_div = nn.KLDivLoss()

    def forward(self, y_s, y_t, labels):
        soft_loss = self.kl_div(
            torch.log_softmax(y_s / self.temperature, dim=1),
            torch.softmax(y_t / self.temperature, dim=1)
        )
        hard_loss = self.hard_loss(y_s, labels)
        return (1 - self.alpha) * soft_loss * (self.temperature**2) + self.alpha * hard_loss

In [16]:
def train_epoch(model, teacher, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        with torch.no_grad():
            teacher_logits = teacher(inputs)
        student_logits = model(inputs)

        loss = loss_fn(student_logits, teacher_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = student_logits.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    return total_loss / len(loader), correct / total

In [17]:
def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)

        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    return correct / total, all_preds, all_labels

In [24]:
%matplotlib inline

In [18]:
def plot_confusion_matrix(y_true, y_pred, class_names):
    try:
        # 确保输入为numpy数组
        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, 
                    yticklabels=class_names,
                    cbar=False)  # 避免颜色条遮挡
        plt.xticks(rotation=45)  # 防止标签重叠
        plt.yticks(rotation=0)
        plt.xlabel('Predicted', fontsize=12)
        plt.ylabel('True', fontsize=12)
        plt.title('Confusion Matrix', fontsize=14)
        plt.tight_layout()
        plt.savefig('confusion_matrix.png', dpi=300)  # 增加保存分辨率
        plt.show(block=True)  # 确保阻塞显示
        plt.close()
    except Exception as e:
        print(f"绘图失败: {str(e)}")

In [19]:
def initialize_plots():
    plt.figure(figsize=(12, 5))
    plt.ion()

def update_plots(epoch, train_loss, train_acc, val_loss, val_acc, losses, accuracies):
    epochs = list(range(1, epoch + 2))
    losses['train'].append(train_loss)
    losses['val'].append(val_loss)
    accuracies['train'].append(train_acc)
    accuracies['val'].append(val_acc)

    plt.clf()
    plt.subplot(1, 2, 1)
    plt.plot(epochs, losses['train'], label='Train Loss')
    plt.plot(epochs, losses['val'], label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, accuracies['train'], label='Train Acc')
    plt.plot(epochs, accuracies['val'], label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.pause(0.01)

In [None]:
def show_predictions(model, data_loader, device, class_names, num_images=6):
    model.eval()
    images_so_far = 0
    plt.figure(figsize=(10, 5))

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size(0)):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'Pred: {class_names[preds[j]]} | True: {class_names[labels[j]]}')
                plt.imshow(inputs[j].cpu().permute(1, 2, 0).squeeze(), cmap='gray')

                if images_so_far == num_images:
                    plt.show()
                    return

# 主训练流程
def main():
    teacher = teacher_model
    student = student_model

    for param in teacher.parameters():
        param.requires_grad = False

    optimizer = optim.AdamW(student.parameters(), lr=config['lr'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
    loss_fn = DistillationLoss(config['temperature'], config['alpha'])

    best_acc = 0
    losses = {'train': [], 'val': []}
    accuracies = {'train': [], 'val': []}
    initialize_plots()

    for epoch in range(config['epochs']):
        train_loss, train_acc = train_epoch(student, teacher, train_loader, optimizer, loss_fn, config['device'])
        val_acc, _, _ = evaluate(student, val_loader, config['device'])
        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
        scheduler.step(train_loss)

        update_plots(epoch, train_loss, train_acc, train_loss, val_acc, losses, accuracies)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(student.state_dict(), 'best_student_model.pth')

    student.load_state_dict(torch.load('best_student_model.pth'))
    test_acc, preds, labels = evaluate(student, test_loader, config['device'])
    print(f"\nTest Accuracy: {test_acc:.4f}")
    print("Classification Report:")
    print(classification_report(labels, preds))

    class_names = ['DangerousDriving', 'Distracted', 'Drinking', 'SafeDriving', 'SleepyDriving', 'Yawn']
    plot_confusion_matrix(labels, preds, class_names)
    show_predictions(student, test_loader, config['device'], class_names, num_images=6)

if __name__ == '__main__':
    main()