导入包，固定随机数种子

In [None]:
import os
import random
from collections import defaultdict
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F  # 关键导入语句
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
import seaborn as sns
from timm.models.vision_transformer import VisionTransformer
import numpy as np
from tqdm import tqdm
import time
import preprocess

# Seed every RNG this notebook relies on (Python, NumPy, PyTorch) so weight
# initialisation and torch-based sampling are repeatable between runs.
# Bit-exact GPU reproducibility would additionally need deterministic cuDNN settings.
SEED = 456
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

定义数据集类与多任务loss函数

In [None]:
class UTKFaceDataset(Dataset):
    """Map-style dataset yielding ``(image, {'age', 'gender'})`` pairs from a DataFrame.

    Each row is read through ``.iloc`` and must provide a ``pixels`` array
    (converted to a uint8 RGB PIL image), an ``age`` and a ``gender`` value.

    Args:
        dataframe: table of samples, presumably produced by
            ``preprocess.load_data_to_dataframe`` (TODO confirm the schema there).
        transform: optional callable applied to the PIL image.
        max_age: divisor used to scale ``age``; the returned age target is
            ``age / max_age`` as float32.
    """

    def __init__(self, dataframe, transform=None, max_age=116):
        self.dataframe = dataframe
        self.transform = transform
        self.max_age = max_age

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        pixels = record['pixels'].astype('uint8')
        picture = Image.fromarray(pixels, 'RGB')
        if self.transform:
            picture = self.transform(picture)
        # Age is regressed on a normalised scale; gender is a class index.
        targets = {
            'age': torch.tensor(record['age'] / self.max_age, dtype=torch.float32),
            'gender': torch.tensor(record['gender'], dtype=torch.long),
        }
        return picture, targets


class MultiTaskViT(nn.Module):
    """timm ViT backbone with an age-regression head and a 2-way gender head.

    Also owns two learnable scalars, ``log_var_age`` and ``log_var_gender``,
    intended as per-task log-variance weights for an uncertainty-weighted loss.

    Args:
        img_size: input resolution the positional embedding is built for.
        patch_size: side length of each square patch.
        embed_dim: token embedding width (also the head input width).
        depth: number of transformer blocks.
        num_heads: attention heads per block.
    """

    def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        # num_classes=0 disables timm's classifier head, so the backbone's
        # forward returns pooled features instead of logits.
        self.vit = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=3,
            num_classes=0,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=4.0,
            qkv_bias=True
        )
        # Learnable log-variances for loss weighting; 0 means weight exp(0) = 1.
        self.log_var_age = nn.Parameter(torch.tensor(0.0))
        self.log_var_gender = nn.Parameter(torch.tensor(0.0))
        self.age_head = nn.Linear(embed_dim, 1)  # age regression (single output)
        self.gender_head = nn.Linear(embed_dim, 2)  # gender logits (two classes)

    def forward_features(self, x):
        """Return the backbone's pooled feature vector, shape (batch, embed_dim)."""
        # Delegate to timm instead of re-implementing its token pipeline, which
        # varies across timm versions. With timm's default global_pool='token'
        # this is the final-norm CLS token — verify if the pooling config changes.
        return self.vit(x)

    def forward(self, x):
        features = self.forward_features(x)
        return {
            # squeeze(-1) drops only the singleton output dim; a bare squeeze()
            # would also drop the batch dim for a size-1 batch (0-d tensor).
            'age': self.age_head(features).squeeze(-1),
            'gender': self.gender_head(features),
        }


In [None]:
def multi_task_loss(outputs, labels, model):
    """Uncertainty-weighted sum of the age (L1) and gender (cross-entropy) losses.

    For each task with base loss ``L`` and the model's learnable log-variance
    ``s`` the term is ``0.5 * exp(-s) * L + 0.5 * s``. The ``exp(-s)`` factor
    must use the *learned* ``s``. Otherwise the only gradient reaching ``s``
    is the constant 0.5 from the regulariser, which drives ``s`` to -inf and
    the loss without bound.

    Args:
        outputs: dict with 'age' (predictions) and 'gender' (class logits).
        labels: dict with 'age' (float targets, same shape as outputs['age'])
            and 'gender' (int64 class indices).
        model: object exposing scalar tensors ``log_var_age`` and ``log_var_gender``.

    Returns:
        ``(total_loss, {'age': age_term, 'gender': gender_term})`` where the
        dict values are the weighted per-task terms.
    """
    log_var_age = model.log_var_age
    log_var_gender = model.log_var_gender

    age_base = F.l1_loss(outputs['age'], labels['age'])
    gender_base = F.cross_entropy(outputs['gender'], labels['gender'])

    age_loss = 0.5 * torch.exp(-log_var_age) * age_base + 0.5 * log_var_age
    gender_loss = 0.5 * torch.exp(-log_var_gender) * gender_base + 0.5 * log_var_gender

    return age_loss + gender_loss, {'age': age_loss, 'gender': gender_loss}

训练函数

In [None]:
def _train_one_epoch(model, loader, device, criterion, optimizer, scaler, use_amp, max_grad_norm=4):
    """Run one optimisation pass; return sample-weighted mean (total, age, gender) losses."""
    model.train()
    sum_total = sum_age = sum_gender = 0.0
    n_seen = 0
    for images, labels in loader:
        images = images.to(device)
        targets = {
            'age': labels['age'].to(device),
            'gender': labels['gender'].to(device),
        }

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            total_loss, loss_dict = criterion(outputs, targets, model)
        scaler.scale(total_loss).backward()
        # Unscale in place first so the clipping threshold applies to the true
        # gradients rather than to the loss-scaled ones.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        batch = images.size(0)
        n_seen += batch
        sum_total += total_loss.item() * batch
        sum_age += float(loss_dict['age']) * batch
        sum_gender += float(loss_dict['gender']) * batch

    n_seen = max(n_seen, 1)  # empty loader: report zeros instead of dividing by zero
    return sum_total / n_seen, sum_age / n_seen, sum_gender / n_seen


def _validate_epoch(model, loader, device, max_age):
    """Evaluate without gradients.

    Returns:
        (age MAE after multiplying back by ``max_age``, gender argmax accuracy,
        numpy array of softmax probabilities for gender class 1,
        numpy array of true gender labels)
    """
    model.eval()
    abs_err_sum = 0.0
    n_correct = 0
    n_total = 0
    probs, true_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            age_labels = labels['age'].to(device).reshape(-1) * max_age
            gender_labels = labels['gender'].to(device)

            outputs = model(images)
            # reshape(-1) guards against 0-d predictions from a size-1 batch.
            pred_ages = outputs['age'].reshape(-1) * max_age
            gender_logits = outputs['gender']

            abs_err_sum += F.l1_loss(pred_ages, age_labels, reduction='sum').item()
            n_correct += (gender_logits.argmax(dim=1) == gender_labels).sum().item()
            n_total += gender_labels.size(0)
            probs.extend(F.softmax(gender_logits, dim=1)[:, 1].cpu().numpy())
            true_labels.extend(gender_labels.cpu().numpy())

    denom = max(n_total, 1)
    return abs_err_sum / denom, n_correct / denom, np.array(probs), np.array(true_labels)


def _save_val_confusion_matrix(true_labels, pred_labels, path='val_confusion_matrix.png'):
    """Save a 2x2 Male/Female confusion-matrix heatmap for the validation split."""
    class_names = ['Male', 'Female']
    # Fixed label set keeps the matrix 2x2 even if one class is absent.
    cm = confusion_matrix(true_labels, pred_labels, labels=[0, 1])
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Validation Confusion Matrix')
    plt.savefig(path)
    plt.close()


def train(model, num_epochs, device, train_loader, val_loader, criterion, optimizer, scheduler,
          max_age=None, patience=5, checkpoint_path='best_integrate_model.pth'):
    """Train the multi-task model with mixed precision, validating after every epoch.

    After each epoch the model is scored on ``val_loader`` (age MAE, gender
    accuracy). The state dict is written to ``checkpoint_path`` whenever either
    metric improves. Training stops early once *both* metrics have gone
    ``patience`` consecutive epochs without improving.

    Args:
        model: network returning a dict with 'age' and 'gender' outputs.
        num_epochs: maximum number of epochs.
        device: torch device (or device string); AMP is enabled only on CUDA.
        train_loader, val_loader: loaders yielding ``(images, {'age', 'gender'})``.
        criterion: ``criterion(outputs, labels, model) -> (total_loss, {'age', 'gender'})``.
        optimizer: stepped once per batch.
        scheduler: stepped once per epoch.
        max_age: factor that converts normalised ages back to years. Defaults
            to ``val_loader.dataset.max_age`` when present, otherwise 116.
        patience: epochs without improvement on both tasks before stopping.
        checkpoint_path: where the model state dict is saved.

    Returns:
        dict with the best validation age MAE and gender accuracy seen.
    """
    if max_age is None:
        # Undo the same normalisation the validation dataset applied.
        max_age = getattr(val_loader.dataset, 'max_age', 116)
    device = torch.device(device)
    use_amp = device.type == 'cuda'

    print("开始训练")
    writer = SummaryWriter()  # TensorBoard log of per-epoch losses and metrics
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    start_time = time.time()

    # Early-stopping state must persist across epochs.
    best_val_age_mae = float('inf')
    best_val_gender_acc = float('-inf')
    age_counter = 0
    gender_counter = 0

    try:
        for epoch in tqdm(range(num_epochs)):
            train_loss, train_age_loss, train_gender_loss = _train_one_epoch(
                model, train_loader, device, criterion, optimizer, scaler, use_amp)
            writer.add_scalar('Loss/Train/Total', train_loss, epoch)
            writer.add_scalar('Loss/Train/Age', train_age_loss, epoch)
            writer.add_scalar('Loss/Train/Gender', train_gender_loss, epoch)

            val_age_mae, val_gender_acc, val_gender_probs, val_gender_labels = _validate_epoch(
                model, val_loader, device, max_age)
            writer.add_scalar('Val/AgeMAE', val_age_mae, epoch)
            writer.add_scalar('Val/GenderAcc', val_gender_acc, epoch)

            print(f'Epoch {epoch + 1}: '
                  f'Train Loss {train_loss:.4f}, '
                  f'Val Age MAE {val_age_mae:.2f}, '
                  f'Val Gender Acc {val_gender_acc:.2%}')

            val_gender_preds = (val_gender_probs > 0.5).astype(int)
            print("\nValidation Gender Classification Report:")
            print(classification_report(val_gender_labels, val_gender_preds, labels=[0, 1],
                                        target_names=['Male', 'Female'], zero_division=0))
            _save_val_confusion_matrix(val_gender_labels, val_gender_preds)

            improved = False
            # Age: lower MAE is better.
            if val_age_mae < best_val_age_mae:
                best_val_age_mae = val_age_mae
                age_counter = 0
                improved = True
            else:
                age_counter += 1
            # Gender: higher accuracy is better.
            if val_gender_acc > best_val_gender_acc:
                best_val_gender_acc = val_gender_acc
                gender_counter = 0
                improved = True
            else:
                gender_counter += 1

            # Save whenever either task improves. A checkpoint therefore exists
            # even if early stopping never fires, and it holds improving weights
            # rather than the final ones.
            if improved:
                torch.save(model.state_dict(), checkpoint_path)

            if age_counter >= patience and gender_counter >= patience:
                print("Early stopping triggered")
                break

            scheduler.step()
    finally:
        writer.close()

    cost_time = time.time() - start_time
    print(f'训练总耗时{cost_time:.4f}')
    return {'best_val_age_mae': best_val_age_mae, 'best_val_gender_acc': best_val_gender_acc}

In [None]:
'''性别的评估函数'''

In [None]:
# Gender confusion matrix
def plot_confusion_matrix(true_labels, pred_labels, classes=None, save_path='confusion_matrix.png'):
    """Save a confusion-matrix heatmap for class-index predictions.

    Args:
        true_labels: ground-truth class indices.
        pred_labels: predicted class indices.
        classes: display names, where index ``i`` names class ``i``. Defaults
            to ``['Male', 'Female']``. Labels are assumed to be the integers
            ``0 .. len(classes) - 1``.
        save_path: output image path.
    """
    # None sentinel instead of a mutable list default shared across calls.
    classes = ['Male', 'Female'] if classes is None else list(classes)
    # Pin the label set so the matrix is always len(classes) x len(classes) and
    # lines up with the tick labels, even when a class is missing from the data.
    cm = confusion_matrix(true_labels, pred_labels, labels=list(range(len(classes))))
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(save_path)
    plt.close()


# Gender mPA (mean per-class precision)
def compute_mPA(true_labels, pred_probs, threshold=0.5):
    """Print and return the mean of the per-class precisions for binary gender predictions.

    Args:
        true_labels: ground-truth labels, 0 = 'Male', 1 = 'Female'.
        pred_probs: predicted probability of class 1 (array-like).
        threshold: probabilities strictly above this are predicted as class 1.

    Returns:
        Mean of the two per-class precisions. A class that is never predicted
        contributes 0.
    """
    pred_labels = (np.asarray(pred_probs) > threshold).astype(int)
    # labels=[0, 1] keeps both target_names valid even if one class is absent.
    report = classification_report(true_labels, pred_labels, labels=[0, 1],
                                   target_names=['Male', 'Female'], output_dict=True,
                                   zero_division=0)
    mPA = (report['Male']['precision'] + report['Female']['precision']) / 2
    print(f'mPA (Mean Precision per Class): {mPA:.4f}')
    return mPA


# Gender macro-F1 vs. decision threshold
def plot_f1_confidence_curve(true_labels, pred_probs, num_thresholds=100,
                             save_path='f1_confidence_curve.png'):
    """Plot macro-averaged F1 as a function of the class-1 probability threshold.

    Args:
        true_labels: ground-truth labels, 0 = 'Male', 1 = 'Female'.
        pred_probs: predicted probability of class 1 (array-like).
        num_thresholds: number of evenly spaced thresholds in [0, 1].
        save_path: output image path.

    Returns:
        ``(thresholds, f1_scores)`` as numpy arrays.
    """
    probs = np.asarray(pred_probs)
    thresholds = np.linspace(0, 1, num_thresholds)
    f1_scores = []
    for thresh in thresholds:
        pred_labels = (probs > thresh).astype(int)
        # At extreme thresholds only one class is predicted. The fixed labels and
        # zero_division=0 keep the report well-defined instead of raising or warning.
        report = classification_report(true_labels, pred_labels, labels=[0, 1],
                                       target_names=['Male', 'Female'], output_dict=True,
                                       zero_division=0)
        f1_scores.append(report['macro avg']['f1-score'])
    f1_scores = np.asarray(f1_scores)

    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, f1_scores, label='F1 Score')
    plt.xlabel('Confidence Threshold')
    plt.ylabel('Macro F1 Score')
    plt.title('F1-Confidence Curve')
    plt.legend()
    plt.grid(True)
    plt.savefig(save_path)
    plt.close()
    return thresholds, f1_scores


In [None]:
'''年龄的会使用到的评估函数'''

In [None]:
def plot_age_confusion_matrix(true_groups, pred_groups, age_bins, save_path='age_confusion_matrix.png'):
    """Save a confusion-matrix heatmap over age groups.

    Args:
        true_groups: ground-truth group indices.
        pred_groups: predicted group indices.
        age_bins: bin edges. Group ``i`` spans ``age_bins[i]`` to
            ``age_bins[i + 1]``, so there are ``len(age_bins) - 1`` groups.
        save_path: output image path.

    Group ids outside ``0 .. len(age_bins) - 2`` are ignored by the matrix;
    clip them upstream if they should be counted.
    """
    tick_labels = [f"{lo}-{hi}" for lo, hi in zip(age_bins[:-1], age_bins[1:])]
    # Fix the label set so the matrix always has one row/column per bin and
    # matches the tick labels, even when some groups never occur.
    cm = confusion_matrix(true_groups, pred_groups, labels=list(range(len(tick_labels))))
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.xlabel('Predicted Age Group')
    plt.ylabel('True Age Group')
    plt.title('Age Group Confusion Matrix')
    plt.savefig(save_path)
    plt.close()


def compute_age_mPA(true_groups, pred_groups):
    """Print and return the mean precision over the age groups present in ``true_groups``.

    Precision for a group that is never predicted counts as 0.
    """
    report = classification_report(true_groups, pred_groups, output_dict=True, zero_division=0)
    # Iterate the group ids actually present. range(len(unique)) only works when
    # the ids are exactly 0..k-1, and it misaligns or skips groups otherwise.
    class_precisions = [report[str(group)]['precision']
                        for group in np.unique(true_groups) if str(group) in report]
    mPA = float(np.mean(class_precisions))
    print(f"Age mPA (Mean Precision per Age Group): {mPA:.4f}")
    return mPA


def plot_age_f1_confidence_curve(true_ages, pred_ages, max_error_threshold=20,
                                 save_path='age_f1_confidence_curve.png'):
    """Plot the share of age predictions within each allowed error, 0..max_error_threshold years.

    A prediction counts as "correct" when its absolute error is within the
    tolerance. Under that definition precision and recall are the same
    fraction, so their F1 equals it, and it is computed directly here.
    This curve is often called the cumulative score in age estimation.

    Args:
        true_ages: ground-truth ages in years.
        pred_ages: predicted ages in years.
        max_error_threshold: largest tolerance plotted, in whole years.
        save_path: output image path.

    Returns:
        ``(error_thresholds, scores)`` as numpy arrays; scores are 0 for empty input.
    """
    abs_err = np.abs(np.asarray(true_ages, dtype=float).reshape(-1)
                     - np.asarray(pred_ages, dtype=float).reshape(-1))
    error_thresholds = np.arange(0, max_error_threshold + 1, 1)
    if abs_err.size:
        # (n_thresholds, n_samples) boolean grid -> fraction within tolerance per threshold
        f1_scores = (abs_err[np.newaxis, :] <= error_thresholds[:, np.newaxis]).mean(axis=1)
    else:
        f1_scores = np.zeros(len(error_thresholds))

    plt.figure(figsize=(10, 6))
    plt.plot(error_thresholds, f1_scores, label='F1 Score')
    plt.xlabel('Allowed Age Prediction Error (years)')
    plt.ylabel('F1 Score')
    plt.title('Age Prediction F1-Confidence Curve')
    plt.legend()
    plt.grid(True)
    plt.savefig(save_path)
    plt.close()
    return error_thresholds, f1_scores


In [None]:
集成后的测试函数

In [None]:
def _collect_test_predictions(model, device, test_loader, max_age):
    """Run ``model`` over ``test_loader`` without gradients.

    Returns:
        Four 1-D numpy arrays: true ages and predicted ages (both multiplied
        back by ``max_age``), softmax probability of gender class 1, and the
        true gender labels.
    """
    model.eval()
    true_ages, pred_ages, gender_probs, gender_labels = [], [], [], []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            # reshape(-1): a size-1 batch could otherwise give a 0-d array that
            # cannot be concatenated.
            pred_ages.append((outputs['age'].reshape(-1) * max_age).cpu().numpy())
            true_ages.append((labels['age'].reshape(-1) * max_age).cpu().numpy())
            gender_probs.append(F.softmax(outputs['gender'], dim=1)[:, 1].cpu().numpy())
            gender_labels.append(labels['gender'].reshape(-1).cpu().numpy())

    def _concat(chunks):
        return np.concatenate(chunks) if chunks else np.empty(0)

    return _concat(true_ages), _concat(pred_ages), _concat(gender_probs), _concat(gender_labels)


def model_test(model, device, test_loader, max_age):
    """Evaluate ``model`` on ``test_loader`` and save age/gender evaluation plots.

    Uses whatever weights ``model`` currently holds; load a checkpoint first.

    For age, it reports the MAE, then groups ages into 10-year bins whose last
    bin ends at ``max_age``. These bins drive the group confusion matrix, the
    per-group mPA and the error-tolerance curve. For gender, it thresholds the
    class-1 probability at 0.5 and reports accuracy, a confusion matrix, mPA
    and an F1-vs-threshold curve.

    Returns:
        dict with 'age_mae' (years) and 'gender_acc'.
    """
    true_ages, pred_ages, gender_probs, gender_labels = _collect_test_predictions(
        model, device, test_loader, max_age)
    # Ground-truth ages are whole years (the data cell bincounts age // 10).
    # Rounding removes float32 normalise/denormalise error that could turn
    # e.g. 20 into 19.999998 and drop it into the wrong bin.
    true_ages = np.rint(true_ages)

    # ================== Age evaluation ==================
    print("\n==== Age Evaluation ====")
    age_mae = float(np.mean(np.abs(pred_ages - true_ages))) if true_ages.size else float('nan')
    print(f'Test Age MAE: {age_mae:.2f}')
    # Edges 0, 10, 20, ... below max_age, closed by max_age itself. Overwriting
    # the last arange edge with max_age would silently merge the top two decades.
    age_bins = np.unique(np.append(np.arange(0, max_age, 10), max_age))
    n_groups = len(age_bins) - 1
    # digitize yields -1 for negative predictions and n_groups for values
    # >= max_age. Clip them into the outermost groups so every sample is counted.
    true_age_groups = np.clip(np.digitize(true_ages, bins=age_bins) - 1, 0, n_groups - 1)
    pred_age_groups = np.clip(np.digitize(pred_ages, bins=age_bins) - 1, 0, n_groups - 1)
    plot_age_confusion_matrix(true_age_groups, pred_age_groups, age_bins)
    compute_age_mPA(true_age_groups, pred_age_groups)
    plot_age_f1_confidence_curve(true_ages, pred_ages, max_error_threshold=20)

    # ================== Gender evaluation ==================
    print("\n==== Gender Evaluation ====")
    gender_preds = (gender_probs > 0.5).astype(int)
    gender_acc = (float(np.mean(gender_preds == gender_labels))
                  if gender_labels.size else float('nan'))
    print(f'Test Gender Acc: {gender_acc:.2%}')
    plot_confusion_matrix(gender_labels, gender_preds)
    compute_mPA(gender_labels, gender_probs)
    plot_f1_confidence_curve(gender_labels, gender_probs)

    return {'age_mae': age_mae, 'gender_acc': gender_acc}

In [None]:
初始化数据

In [None]:

# Dataset location and the age normaliser shared by the datasets, training and testing.
data_dir = './UTKFace'
max_age = 116
df = preprocess.load_data_to_dataframe(data_dir)

# Stratify every split on 10-year age groups.
df['age_group'] = df['age'] // 10
stratify_col = df['age_group']
# train : val : test = 8 : 1 : 1 (adjustable)
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=456, stratify=stratify_col)

# Inverse-frequency sample weights so each age group is drawn about equally often.
age_groups = train_df['age_group'].values
class_counts = np.bincount(age_groups)
# Decades with no training samples get weight 0 instead of triggering a 1/0
# warning. They are never indexed by sample_weights anyway.
class_weights = np.divide(1.0, class_counts, out=np.zeros(len(class_counts)),
                          where=class_counts > 0)
sample_weights = class_weights[age_groups]
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(train_df), replacement=True)

# Split the held-out 20% evenly into validation and test sets.
stratify_temp = temp_df['age_group']
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=456, stratify=stratify_temp)

# Increase batch_size if GPU memory allows.
batch_size = 8

# Normalise with the commonly used ImageNet channel statistics.
# NOTE(review): there is no Resize here, yet the ViT is built for img_size=224.
# This assumes preprocess already yields 224x224 arrays. TODO confirm, and add
# transforms.Resize((224, 224)) otherwise.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Pass max_age explicitly so the datasets normalise with the same value the
# training/testing code uses to convert predictions back to years.
train_dataset = UTKFaceDataset(train_df, transform=train_transform, max_age=max_age)
val_dataset = UTKFaceDataset(val_df, transform=val_transform, max_age=max_age)
test_dataset = UTKFaceDataset(test_df, transform=test_transform, max_age=max_age)
print('已完成数据处理')
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)  # weighted sampling instead of shuffle
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


初始化模型并训练

In [None]:
# Build the ViT at DeiT-Tiny size: 192-dim embeddings, 12 blocks, 3 heads.
# The previous embed_dim=768 with 3 heads matched neither DeiT-Tiny (192)
# nor ViT-Base (12 heads), and a DeiT-Tiny checkpoint would not load into it.
model = MultiTaskViT(
    img_size=224,
    patch_size=16,
    embed_dim=192,  # DeiT-Tiny embedding width
    depth=12,  # transformer blocks
    num_heads=3,  # attention heads (DeiT-Tiny uses 3)
)
# model.load_state_dict(torch.load("./pytorch_model.bin"), strict=False)  # optional pretrained weights
print('模型初始化完成')

num_epochs = 25

# Training configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = multi_task_loss  # uncertainty-weighted multi-task loss
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# Cosine decay from 1e-4 down to 1e-6 over num_epochs scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

train(model, num_epochs, device, train_loader, val_loader, criterion, optimizer, scheduler)

测试

In [None]:
# Evaluate the checkpoint that train() writes to this path on the held-out test split.
checkpoint_path = 'best_integrate_model.pth'
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print(f'Checkpoint {checkpoint_path} not found; evaluating the current in-memory weights.')
model_test(model, device, test_loader, max_age)