In [1]:
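"""
Multi-output model: predicting social media user profiles

This example builds a multi-task learning model that predicts several
attributes from the text a user has posted:
- Output 1: age (regression)
- Output 2: income bracket (multi-class classification)
- Output 3: gender (binary classification)

Advantages of a multi-output model:
1. Shared representation learning: low-level features are useful to every task
2. Regularization: multi-task learning implicitly constrains the model, which helps reduce overfitting
3. Data efficiency: learning several tasks at once makes better use of each sample

Loss weights:
`loss_weights` balances the relative importance and difficulty of the tasks
"""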

from keras import layers
from keras.layers import Input
from keras.models import Model
import numpy as np
from keras.utils import to_categorical

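# Seed NumPy's RNG so the synthetic data is reproducible
# (this does not seed Keras weight initialization)
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Hyperparameters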
vocabulary_size = 50000
num_income_groups = 10

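# Input: variable-length sequences of word indices from the user's posts
posts_input = Input(shape=(None,), dtype="int32", name="posts")

# Embedding layer: maps word indices to dense vectors
embedded_posts = layers.Embedding(vocabulary_size, 256)(posts_input)

# Feature-extraction backbone: a stack of 1D convolutions
# padding='same' keeps the sequence length unchanged in each convolution,
# so only the pooling layers shorten the sequence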
x = layers.Conv1D(128, 5, activation='relu', padding='same')(embedded_posts)
x = layers.MaxPooling1D(5)(x)
x = layers.Conv1D(256, 5, activation='relu', padding='same')(x)
x = layers.Conv1D(256, 5, activation='relu', padding='same')(x)
x = layers.MaxPooling1D(5)(x)
x = layers.Conv1D(256, 5, activation='relu', padding='same')(x)
x = layers.Conv1D(256, 5, activation='relu', padding='same')(x)

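# Global pooling: turns the variable-length sequence into a fixed-length vector
x = layers.GlobalMaxPooling1D()(x)

# Shared fully connected layer
x = layers.Dense(128, activation='relu')(x)

# Output branch 1: age prediction (regression, linear activation)
age_prediction = layers.Dense(1, name='age')(x)

# Output branch 2: income bracket prediction (multi-class classification)
income_prediction = layers.Dense(
    num_income_groups,
    name='income',
    activation='softmax')(x)

# Output branch 3: gender prediction (binary classification)
gender_prediction = layers.Dense(
    1,
    name='gender',
    activation='sigmoid')(x)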

# Build the multi-output model
model = Model(posts_input, 
              [age_prediction, income_prediction, gender_prediction])

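# Compile the model: specify a loss and a loss weight for each output
model.compile(
    optimizer='adam',
    loss={
        'age': 'mse',                          # age: mean squared error
        'income': 'categorical_crossentropy',  # income: categorical cross-entropy
        'gender': 'binary_crossentropy'        # gender: binary cross-entropy
    },
    loss_weights={
        'age': 0.25,     # age: lowest weight
        'income': 1.0,   # income: medium weight
        'gender': 10.0   # gender: highest weight (prioritized during optimization)
    },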
    metrics={
        'age': 'mae',
        'income': 'accuracy',
        'gender': 'accuracy'
    }
)

print("Multi-output model architecture:")
model.summary()

# Generate synthetic data (inputs and labels are random, so this only
# exercises the pipeline; metrics are expected to stay near chance level)
num_samples = 1000
max_length = 100

# Input: text sequences of random word indices
posts = np.random.randint(1, vocabulary_size, size=(num_samples, max_length))

# Output 1: age (18 to 80 years)
age_targets = np.random.uniform(18, 80, size=(num_samples, 1))

# Output 2: income bracket (one-hot encoding over 10 classes)
income_classes = np.random.randint(0, num_income_groups, size=(num_samples,))
income_targets = to_categorical(income_classes, num_income_groups)

# Output 3: gender (0 or 1)
gender_targets = np.random.randint(0, 2, size=(num_samples, 1)).astype('float32')

print("\nStarting training...")
# Pass the targets as a dict keyed by output layer name
history = model.fit(
    posts,
    {
        'age': age_targets,
        'income': income_targets,
        'gender': gender_targets
    },
    epochs=3,
    batch_size=64,
    validation_split=0.2,
    verbose=1
)

print("\nModel training complete")
print(f"Final losses - age: {history.history['age_loss'][-1]:.4f}, "
      f"income: {history.history['income_loss'][-1]:.4f}, "
      f"gender: {history.history['gender_loss'][-1]:.4f}")
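
A note on the age head: the targets above are raw ages drawn from 18 to 80, so the MSE is measured in years squared and is far larger than either cross-entropy term, even after the 0.25 weight. One common remedy, sketched here as an assumption rather than something the cell above does, is to standardize the regression target before training and map predictions back to years afterwards. The names `age_scaled` and `to_years` are hypothetical, and the sample is regenerated with its own generator so the snippet runs by itself.

```python
import numpy as np

# Assumption: a stand-in for the synthetic ages above (uniform on [18, 80)).
rng = np.random.default_rng(42)
age_targets = rng.uniform(18, 80, size=(1000, 1))

# Standardize the regression target to zero mean and unit variance.
age_mean = age_targets.mean()
age_std = age_targets.std()
age_scaled = (age_targets - age_mean) / age_std


def to_years(scaled):
    """Map standardized predictions back to ages in years (hypothetical helper)."""
    return scaled * age_std + age_mean


# MSE of the trivial "always predict the mean" model, before and after scaling.
baseline_mse_raw = np.mean((age_targets - age_mean) ** 2)  # in years squared
baseline_mse_scaled = np.mean(age_scaled ** 2)             # 1.0 by construction

print(f"baseline MSE on raw ages:    {baseline_mse_raw:.2f}")
print(f"baseline MSE on scaled ages: {baseline_mse_scaled:.2f}")
```

With the target on this scale, a trivial predictor has an MSE of about 1, which is the same order of magnitude as the cross-entropy losses, so the loss weights express task priority rather than compensating for units. Under this assumption, `age_scaled` would be passed as the `'age'` target and `to_years` applied to the model's age predictions.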


Multi-output model architecture:



Starting training...
Epoch 1/3
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 244ms/step - age_loss: 1363.6525 - age_mae: 31.2929 - gender_accuracy: 0.4900 - gender_loss: 0.8660 - income_accuracy: 0.1063 - income_loss: 4.2290 - loss: 359.8237 - val_age_loss: 628.7371 - val_age_mae: 21.4770 - val_gender_accuracy: 0.4800 - val_gender_loss: 0.8517 - val_income_accuracy: 0.0950 - val_income_loss: 2.9930 - val_loss: 182.1337


Epoch 2/3
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - age_loss: 327.7531 - age_mae: 14.9050 - gender_accuracy: 0.4750 - gender_loss: 1.0294 - income_accuracy: 0.1163 - income_loss: 3.0468 - loss: 96.5324 - val_age_loss: 351.2443 - val_age_mae: 16.4968 - val_gender_accuracy: 0.4800 - val_gender_loss: 1.1968 - val_income_accuracy: 0.1300 - val_income_loss: 2.7816 - val_loss: 101.5606


Epoch 3/3
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - age_loss: 108.4334 - age_mae: 8.5929 - gender_accuracy: 0.5025 - gender_loss: 0.9318 - income_accuracy: 0.0975 - income_loss: 2.6537 - loss: 39.4710 - val_age_loss: 353.5485 - val_age_mae: 16.2038 - val_gender_accuracy: 0.4800 - val_gender_loss: 0.7532 - val_income_accuracy: 0.1200 - val_income_loss: 2.3972 - val_loss: 93.4376



Model training complete
Final losses - age: 108.4334, income: 2.6537, gender: 0.9318
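
To read the final losses above, it can help to multiply each one by its loss weight and see how much each task contributes to the objective, and to compare the two cross-entropy values with what a uniform guess would score. The sketch below is plain arithmetic on the numbers printed by this run; the variable names are illustrative. The weighted sum it computes is close to, but not identical to, the `loss: 39.4710` logged for the last epoch, and the reason for that gap is not established here.

```python
import math

# Loss weights from model.compile and the final training losses printed above.
loss_weights = {"age": 0.25, "income": 1.0, "gender": 10.0}
final_losses = {"age": 108.4334, "income": 2.6537, "gender": 0.9318}

# Weighted contribution of each task and its share of the weighted sum.
weighted = {name: loss_weights[name] * final_losses[name] for name in loss_weights}
total = sum(weighted.values())
shares = {name: value / total for name, value in weighted.items()}

# Cross-entropy of a uniform guess: ln(K) for K equally likely classes.
chance_income_loss = math.log(10)  # 10 income brackets
chance_gender_loss = math.log(2)   # 2 gender classes

for name in weighted:
    print(f"{name:>6}: weighted loss = {weighted[name]:8.4f}, share = {shares[name]:.1%}")
print(f"weighted sum = {total:.4f}")
print(f"chance-level income loss = {chance_income_loss:.4f}")
print(f"chance-level gender loss = {chance_gender_loss:.4f}")
```

Even with a weight of 0.25, the age term makes up roughly 69% of the weighted sum, because MSE on raw ages is so much larger than the cross-entropy terms. Both classification losses end above their uniform-guess levels (about 2.30 and 0.69), which is consistent with the labels being random: there is no signal to learn, so the accuracies hover near 10% and 50%.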
