## 网络结构设计
1. 提取用户特征和电影特征作为网络的输入，其中：
    - 用户特征包含：性别、年龄和职业
    - 电影特征包含：电影名称、电影类型以及电影海报
2. 提取用户信息：使用Embedding层将各用户特征映射为向量表示，再分别输入各自的全连接层，最后将得到的向量相加
3. 提取电影信息：将电影类型映射为向量表示，电影名称和电影海报使用卷积层得到向量表示，再分别输入各自的全连接层，最后将得到的向量相加
4. 得到用户和电影的向量表示后，计算二者的余弦相似度。最后，以该相似度与用户真实评分之间的均方误差（MSE）作为该回归模型的损失函数。

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

## 1. 用户信息
### 1.1 提取性别特征

In [2]:
# Hand-made batch of two user gender ids (0 and 1) used to exercise the layers.
usr_gender_data = np.asarray([0, 1], dtype='int64')
usr_gender_dict_size = 2
# One learnable 16-d vector per gender id, followed by a 16 -> 16 linear layer.
usr_gender_emb = nn.Embedding(usr_gender_dict_size, 16)
usr_gender_fc = nn.Linear(16, 16)

# Look up the ids, project them, then apply ReLU: result has shape (2, 16).
usr_gender_var = torch.from_numpy(usr_gender_data)
usr_gender_feat = F.relu(usr_gender_fc(usr_gender_emb(usr_gender_var)))
print(usr_gender_feat.shape, usr_gender_feat, sep='\n')

torch.Size([2, 16])
tensor([[0.0000, 0.0000, 0.6966, 0.0000, 0.3730, 0.4355, 0.4415, 1.1237, 0.3944,
         0.3533, 0.0000, 0.0000, 1.0818, 0.0798, 0.0000, 0.0000],
        [0.0000, 0.3916, 0.3843, 0.0000, 0.0998, 0.0000, 0.5639, 0.5535, 0.0000,
         0.0000, 0.5187, 0.3305, 0.5410, 0.0000, 0.1427, 0.0000]],
       grad_fn=<ReluBackward0>)


### 1.2 提取用户年龄特征

In [3]:
# Hand-made batch of two user age ids (1 and 18) used to exercise the layers.
usr_age_data = np.asarray([1, 18], dtype='int64')
# The embedding table covers ids 0..56, hence 56 + 1 rows.
usr_age_dict_size = 56 + 1
usr_age_emb = nn.Embedding(usr_age_dict_size, 16)
usr_age_fc = nn.Linear(16, 16)

# Embed -> linear 16 -> 16 -> ReLU; result has shape (2, 16).
usr_age_var = torch.from_numpy(usr_age_data)
usr_age_feat = F.relu(usr_age_fc(usr_age_emb(usr_age_var)))
print(usr_age_feat.shape, usr_age_feat, sep='\n')

torch.Size([2, 16])
tensor([[ 0.5290, -0.6472, -0.3879,  0.5995, -0.2659,  0.2476, -1.1362,  0.3244,
          1.6538, -0.4319, -0.7275,  0.7412,  0.0623,  0.1928,  0.0296,  1.9320],
        [ 0.4671, -0.8481,  0.1499,  1.2090,  1.2093,  0.3064,  1.6807, -0.2640,
         -0.5385,  0.6502, -0.2269, -0.3856, -0.0081, -0.2660, -0.5461, -1.1172]],
       grad_fn=<AddmmBackward>)


### 1.3 提取用户职业特征

In [4]:
# Hand-made batch of two user job ids (0 and 20); the table covers ids 0..20.
usr_job_data = np.asarray([0, 20], dtype='int64')
usr_job_dict_size = 20 + 1
usr_job_emb = nn.Embedding(usr_job_dict_size, 16)
usr_job_fc = nn.Linear(16, 16)

# Embed -> linear 16 -> 16 -> ReLU; result has shape (2, 16).
usr_job_var = torch.from_numpy(usr_job_data)
usr_job_feat = F.relu(usr_job_fc(usr_job_emb(usr_job_var)))
print(usr_job_feat.shape, usr_job_feat, sep='\n')

torch.Size([2, 16])
tensor([[1.9749, 0.0000, 0.0000, 0.8252, 0.0000, 0.6930, 0.0000, 0.0000, 0.0358,
         0.0000, 0.7563, 0.0000, 0.0000, 0.4580, 0.8911, 0.0455],
        [0.7527, 0.2311, 0.0000, 0.1405, 1.1170, 0.5591, 0.1824, 1.3307, 0.0000,
         0.0000, 0.0399, 0.0000, 0.0000, 0.1492, 0.2513, 0.0000]],
       grad_fn=<ReluBackward0>)


### 1.4 融合用户特征

In [5]:
# Fuse the three user features: project each 16-d feature (gender, age, job)
# to a shared 200-d space with its own linear layer, squash with tanh, and sum
# the projections into a single user vector (shape (2, 200) for the batch above).
# NOTE(review): layer creation order (job, age, gender) is kept as-is since it
# determines how the random initial weights are drawn.
fc_job = nn.Linear(in_features=16, out_features=200)
fc_age = nn.Linear(in_features=16, out_features=200)
fc_gender = nn.Linear(in_features=16, out_features=200)

# torch.tanh is the supported spelling; F.tanh is a deprecated alias.
gender_feat = torch.tanh(fc_gender(usr_gender_feat))
age_feat = torch.tanh(fc_age(usr_age_feat))
job_feat = torch.tanh(fc_job(usr_job_feat))

# Element-wise sum of the three 200-d projections.
usr_feat = gender_feat + age_feat + job_feat
print(usr_feat.shape)
print(usr_feat)

torch.Size([2, 200])
tensor([[-0.4754, -0.5006,  0.7800, -0.0916, -0.4430,  0.9941, -0.2878, -0.3890,
          0.9851,  0.0592,  1.3108, -0.5596,  0.5079,  0.4680, -0.0095, -0.1781,
          0.9574,  0.1006, -0.6925,  1.4402, -0.0202, -0.3206, -0.2330,  0.2378,
          0.4972,  0.3473,  0.9743, -0.2952,  0.3214, -0.1116,  0.6535, -0.3434,
         -0.3431,  0.4105,  0.4813,  0.7363, -0.0193,  0.7338,  0.0827,  0.7957,
         -0.8258, -0.6831,  0.9995, -0.1984, -0.9636, -0.5469,  0.2023,  0.8143,
         -1.2965,  0.7914,  0.2987,  0.1895, -0.1863,  0.2468,  0.2159, -0.1489,
          0.7074, -1.0013, -0.5645, -0.4629,  0.3709, -0.6216,  0.3045, -0.2654,
         -1.0390, -0.2435, -0.3589, -0.1659,  0.2704,  0.0126, -1.1397, -0.2288,
         -0.0535,  0.4737, -1.8212,  0.0198,  0.9539,  0.0860, -1.0539, -1.0696,
          0.9028,  0.1553,  0.1447,  0.4883,  0.4088, -0.5573, -0.2253,  0.6572,
         -0.2522,  0.2880, -0.1599,  0.6082, -0.0648,  0.3248, -0.1293,  0.3353,
       

