In [1]:
import json
import math
import random
import torch

import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda") if _use_cuda else torch.device("cpu")
DEVICE

device(type='cuda')

In [3]:
def _to_tokens(mapping):
    """Turn a {key: value} dict into "key_value" token strings, skipping falsy values."""
    # NOTE(review): `if value` drops 0 / 0.0 / False as well as None and "".
    # If a field can legitimately be 0, its token (and its position in the
    # sequence) silently disappears -- confirm this is intended.
    return [f"{key}_{value}" for key, value in mapping.items() if value]


def process_data(file_path):
    """Load a JSON list of records and turn each record into token lists.

    Args:
        file_path: path to a JSON file holding a list of objects, each with a
            "user_info" dict and a "parameter" dict.

    Returns:
        A list with one {"user_info": [...], "parameter": [...]} dict per
        record. Each list holds "key_value" strings in the source dict's key
        order, with falsy values skipped.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's default
    # encoding (which is not UTF-8 on many Windows setups).
    with open(file_path, "r", encoding="utf-8") as file:
        records = json.load(file)

    return [
        {
            "user_info": _to_tokens(record["user_info"]),
            "parameter": _to_tokens(record["parameter"]),
        }
        for record in records
    ]

In [4]:
data_path = "../../train_data/type_1/data.json"
dataset = process_data(data_path)

# Special symbols, listed in the index order that the rest of this notebook
# hard-codes (0 = <bos>, 1 = <eos>, 2 = <pad>, 3 = <unk>).
special_symbols = ['<bos>', '<eos>', '<pad>', '<unk>']

# Make the vocabularies: special symbols first at fixed indices, then every
# other token seen in the data, sorted. The specials are not sorted together
# with the data tokens, because then their indices would depend on the data.
# Any token sorting before "<bos>" (e.g. one starting with a digit) would
# shift them.
user_info_tokens = set()
parameter_tokens = set()
for data in dataset:
    user_info_tokens.update(data["user_info"])
    parameter_tokens.update(data["parameter"])

user_info_vocab = special_symbols + sorted(user_info_tokens.difference(special_symbols))
parameter_vocab = special_symbols + sorted(parameter_tokens.difference(special_symbols))

# NOTE(review): a saved checkpoint only loads if these vocabularies have exactly
# the sizes used at training time; consider saving them next to the checkpoint.
len(user_info_vocab), len(parameter_vocab)

(1217, 4691)

In [5]:
# Peek at the first 20 entries of the user-info vocabulary.
for position, token in enumerate(user_info_vocab[:20]):
    print(f"Word: {token} —— Index: {position}")

Word: <bos> —— Index: 0
Word: <eos> —— Index: 1
Word: <pad> —— Index: 2
Word: <unk> —— Index: 3
Word: age_1 —— Index: 4
Word: age_10 —— Index: 5
Word: age_104 —— Index: 6
Word: age_1048 —— Index: 7
Word: age_11 —— Index: 8
Word: age_12 —— Index: 9
Word: age_13 —— Index: 10
Word: age_14 —— Index: 11
Word: age_15 —— Index: 12
Word: age_16 —— Index: 13
Word: age_17 —— Index: 14
Word: age_18 —— Index: 15
Word: age_19 —— Index: 16
Word: age_2 —— Index: 17
Word: age_20 —— Index: 18
Word: age_21 —— Index: 19


In [6]:
# Train / validation splits, converted to "key_value" token strings.
train_path, valid_path = (
    "../../train_data/type_1/train.json",
    "../../train_data/type_1/valid.json",
)
train_dataset, valid_dataset = (process_data(path) for path in (train_path, valid_path))

In [7]:
# Tokenize: map every token string to its vocabulary index.
# Dict lookups replace list.index(), which scanned the whole vocabulary
# (thousands of entries) once per token.
user_info_stoi = {word: index for index, word in enumerate(user_info_vocab)}
parameter_stoi = {word: index for index, word in enumerate(parameter_vocab)}


def tokenize_dataset(samples):
    """Convert samples of token strings into samples of vocabulary indices.

    Args:
        samples: list of {"user_info": [...], "parameter": [...]} dicts of
            token strings, as returned by process_data.

    Returns:
        A new list of dicts with the same keys, holding integer indices.
        Tokens missing from a vocabulary map to that vocabulary's "<unk>" index.
    """
    user_unk = user_info_stoi['<unk>']
    param_unk = parameter_stoi['<unk>']
    return [
        {
            "user_info": [user_info_stoi.get(word, user_unk) for word in sample["user_info"]],
            "parameter": [parameter_stoi.get(word, param_unk) for word in sample["parameter"]],
        }
        for sample in samples
    ]


token_train_dataset = tokenize_dataset(train_dataset)
token_valid_dataset = tokenize_dataset(valid_dataset)

In [8]:
# Inspect the first tokenized training sample (prints nothing if the split is empty).
for data in token_train_dataset[:1]:
    print(data)

{'user_info': [33, 1099, 1006, 889, 933, 956, 1029, 448, 333, 376, 404, 474], 'parameter': [6, 73, 119, 165, 211, 257, 303, 349, 398, 443, 490, 540, 588, 633, 676, 722, 768, 828, 875, 923, 971, 1019, 1067, 1115, 1160, 1199, 1236, 1284, 1328, 1376, 1422, 1470, 1518, 1538, 1560, 1579, 1598, 1617, 1636, 1655, 1674, 1694, 1715, 1736, 1757, 1778, 1799, 1819, 1839, 1847, 1879, 1911, 1955, 2002, 2049, 2096, 2143, 2190, 2237, 2285, 2325, 2371, 2420, 2467, 2514, 2572, 2616, 2663, 2710, 2757, 2804, 2851, 2892, 2935, 2973, 3012, 3067, 3103, 3154, 3200, 3246, 3258, 3282, 3318, 3362, 3409, 3456, 3503, 3550, 3597, 3644, 3692, 3732, 3778, 3827, 3874, 3921, 3979, 4023, 4070, 4117, 4164, 4211, 4258, 4299, 4342, 4380, 4419, 4474, 4510, 4561, 4607, 4653, 4667, 4673, 4676, 4677, 4683, 4688]}


In [9]:
class HearingAidDataset(Dataset):
    """Map-style dataset over pre-tokenized (user_info, parameter) samples.

    Each item is wrapped in begin/end-of-sequence indices and returned as a
    pair of 1-D integer tensors. The stored token lists are never modified, so
    fetching the same item repeatedly (e.g. across epochs) gives the same result.

    Args:
        dataset: sequence of dicts with "user_info" and "parameter" lists of
            token indices.
        bos_idx: index prepended to every sequence (default 0, where "<bos>"
            sits in the sorted vocabularies).
        eos_idx: index appended to every sequence (default 1, "<eos>").
    """

    def __init__(self, dataset, bos_idx=0, eos_idx=1):
        self.dataset = dataset
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Build new lists. insert()/append() on the stored lists would add
        # another BOS/EOS pair every time this item is fetched.
        user_info = [self.bos_idx, *item["user_info"], self.eos_idx]
        param = [self.bos_idx, *item["parameter"], self.eos_idx]
        return torch.tensor(user_info), torch.tensor(param)

In [10]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to embeddings, then apply dropout.

    Even columns hold sin(pos * f_i) and odd columns cos(pos * f_i), with
    f_i = 10000^(-2i / d_model).

    Args:
        d_model: embedding width. Odd widths are supported; the last sine
            column then has no cosine partner.
        max_len: longest sequence length the precomputed table covers.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d_model, max_len=300, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer is saved in state_dict and follows .to(device), but is not a
        # trainable parameter (so it never receives gradients).
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return dropout(x + pe) for x of shape (batch_size, x_len, d_model), x_len <= max_len."""
        return self.dropout(x + self.pe[:, : x.size(1)])

In [11]:
class TransformerMultiOutputRegressor(nn.Module):
    """Encoder-decoder Transformer mapping user-info tokens to parameter-token logits.

    Despite the name, the model is used as a token classifier. ``forward``
    returns unnormalized scores over the target vocabulary at every target
    position (this notebook feeds them to CrossEntropyLoss). ``forward``
    assumes the batch-first layout (sequence length on dim 1).

    Args:
        src_vocab_size: size of the user-info vocabulary (source embedding rows).
        tgt_vocab_size: size of the parameter vocabulary (target embedding rows
            and number of output classes).
        d_model, n_heads, n_layers, dropout, batch_first: forwarded to
            ``nn.Transformer``; ``n_layers`` is used for both encoder and decoder.
        pad_idx: token index treated as padding in ``src`` and ``tgt``
            (default 2, where "<pad>" sits in the sorted vocabularies).
        causal_src: if True (default, matching the original implementation and
            therefore any checkpoint trained with it), the encoder also uses a
            causal mask, so each source position only sees earlier positions.
            NOTE(review): user info looks like an unordered set of attributes;
            a bidirectional encoder (``causal_src=False``) is the usual choice,
            but would need retraining.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_heads, n_layers,
                 dropout=0.1, batch_first=True, pad_idx=2, causal_src=True):
        super().__init__()
        self.pad_idx = pad_idx
        self.causal_src = causal_src
        self.src_tok_emb = nn.Embedding(src_vocab_size, d_model)
        self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, d_model)
        # The attribute name "trasformer" (sic) is kept on purpose: it is part
        # of the state_dict keys of checkpoints saved from this model.
        self.trasformer = nn.Transformer(
            d_model=d_model,
            nhead=n_heads,
            num_encoder_layers=n_layers,
            num_decoder_layers=n_layers,
            dropout=dropout,
            batch_first=batch_first
        )
        self.positional_encoding = PositionalEncoding(d_model=d_model, dropout=dropout)
        self.generator = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        """Score every target position.

        Args:
            src: (batch, src_len) integer tensor of user-info token indices.
            tgt: (batch, tgt_len) integer tensor of decoder-input parameter
                token indices (the target sequence shifted right).

        Returns:
            (batch, tgt_len, tgt_vocab_size) tensor of logits.
        """
        src_key_padding_mask = src == self.pad_idx
        tgt_key_padding_mask = tgt == self.pad_idx
        # Boolean causal masks (True = may not attend), created on the inputs'
        # device rather than on a notebook-global DEVICE.
        tgt_mask = self.trasformer.generate_square_subsequent_mask(tgt.size(1)).bool().to(tgt.device)
        src_mask = None
        if self.causal_src:
            src_mask = self.trasformer.generate_square_subsequent_mask(src.size(1)).bool().to(src.device)

        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))

        outs = self.trasformer(
            src_emb, tgt_emb,
            src_mask=src_mask, tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            # Without this, the decoder's cross-attention also attends to the
            # encoder outputs at padded source positions.
            memory_key_padding_mask=src_key_padding_mask,
            tgt_is_causal=True, memory_is_causal=False
        )
        return self.generator(outs)

In [12]:
# Initialize the model
src_vocab_size, tgt_vocab_size = len(user_info_vocab), len(parameter_vocab)
d_model, n_heads, n_layers = 128, 8, 6
dropout = 0.1
batch_first = True

model = TransformerMultiOutputRegressor(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=d_model,
    n_heads=n_heads,
    n_layers=n_layers,
    dropout=dropout,
    batch_first=batch_first,
).to(DEVICE)

In [13]:
# Hyperparameters
batch_size, num_epochs = 64, 60
learning_rate = 1e-4

In [14]:
def collate_fn(batch):
    """Pad a list of (user_info, param) tensor pairs into two batch-first tensors.

    Shorter sequences are right-padded with index 2 (the "<pad>" index).
    Returns the tuple (user_info, param), each shaped (batch, longest length).
    """
    sources, targets = zip(*batch)
    return tuple(
        nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=2)
        for seqs in (sources, targets)
    )

In [15]:
training_dataset = HearingAidDataset(token_train_dataset)
# Use a distinct name: `valid_dataset` holds the raw validation records from
# process_data. Rebinding it to a Dataset makes the tokenization cell fail if
# it is re-run.
validation_dataset = HearingAidDataset(token_valid_dataset)

training_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4, pin_memory=True)
# Evaluation does not need shuffling. A fixed order keeps batch composition,
# and hence the averaged validation loss, identical between runs.
valid_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=4, pin_memory=True)

In [16]:
# Sanity-check tensor shapes on one training batch.
for user_info, param in training_loader:
    user_info = user_info.to(DEVICE)
    param = param.to(DEVICE)
    print(f"user_info shape: {user_info.shape}")
    print(f"param shape: {param.shape}")
    param_input = param[:, :-1]
    param_target = param[:, 1:]

    # This forward pass is only for inspection. Without no_grad it would build
    # an autograd graph that the global `logits` keeps alive afterwards.
    with torch.no_grad():
        logits = model(user_info, param_input)
    print(f"logits size: {logits.size(-1)}")
    print(logits.shape)
    print(logits.reshape(-1, logits.size(-1)).shape)
    print(param_target.shape)
    print(param_target.reshape(-1).shape)
    break

user_info shape: torch.Size([64, 35])
param shape: torch.Size([64, 121])
logits size: 4691
torch.Size([64, 120, 4691])
torch.Size([7680, 4691])
torch.Size([64, 120])
torch.Size([7680])


In [17]:
# Train / evaluate helpers


def _batch_loss(model, loss_fn, user_info, param, device):
    """Teacher-forced loss for one padded batch: predict param[:, 1:] from param[:, :-1]."""
    user_info = user_info.to(device)
    param = param.to(device)
    param_input = param[:, :-1]
    param_target = param[:, 1:]
    logits = model(user_info, param_input)
    return loss_fn(logits.reshape(-1, logits.size(-1)), param_target.reshape(-1))


def train_epoch(model, loss_fn, optimizer, loader=None):
    """Run one training epoch and return the mean of the per-batch losses.

    Args:
        model: module called as model(user_info, param_input) -> logits.
        loss_fn: criterion taking (N, vocab) logits and (N,) targets.
        optimizer: optimizer over model's parameters.
        loader: iterable of (user_info, param) batches; defaults to the
            notebook's `training_loader`.

    Returns:
        Average per-batch loss as a float, or NaN if the loader yields no batches.
    """
    if loader is None:
        loader = training_loader
    device = next(model.parameters()).device
    model.train()
    total_loss = 0.0
    num_batches = 0
    for user_info, param in loader:
        optimizer.zero_grad()
        loss = _batch_loss(model, loss_fn, user_info, param, device)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float("nan")


def evaluate(model, loss_fn, loader=None):
    """Compute the mean per-batch loss without updating the model.

    Args:
        model: module called as model(user_info, param_input) -> logits.
        loss_fn: criterion taking (N, vocab) logits and (N,) targets.
        loader: iterable of (user_info, param) batches; defaults to the
            notebook's `valid_loader`.

    Returns:
        Average per-batch loss as a float, or NaN if the loader yields no batches.
        Leaves the model in eval mode.
    """
    if loader is None:
        loader = valid_loader
    device = next(model.parameters()).device
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for user_info, param in loader:
            total_loss += _batch_loss(model, loss_fn, user_info, param, device).item()
            num_batches += 1
    return total_loss / num_batches if num_batches else float("nan")

In [18]:
# Cross-entropy loss; targets equal to index 2 (padding) are ignored.
loss_fn = nn.CrossEntropyLoss(ignore_index=2)
# Optimizer: Adam
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.98),
    eps=1e-9,
)

In [19]:
# from timeit import default_timer as timer

# training_losses = []
# valid_losses = []
# for epoch in range(num_epochs):
#     start_time = timer()
#     train_loss = train_epoch(model, loss_fn, optimizer)
#     training_losses.append(train_loss)
#     end_time = timer()
#     valid_loss = evaluate(model, loss_fn)
#     valid_losses.append(valid_loss)
#     print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Valid loss: {valid_loss:.3f}, Time: {end_time - start_time:.2f}s")

In [20]:
# Draw the loss curve (needs `import matplotlib.pyplot as plt`, which the imports cell does not include)
# plt.plot(training_losses, label='Training Loss')
# plt.plot(valid_losses, label='Validation Loss')
# plt.title('Training and Validation Loss')
# plt.xlabel('Epochs')
# plt.ylabel('Loss')
# plt.legend()
# plt.show()

In [21]:
# Save the model
# torch.save(model.state_dict(), '../models/transformer_type_1_80.pth')

In [26]:
# Load the model.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors and plain containers
# (the file holds a state_dict), so a tampered file cannot execute code.
# NOTE(review): the recorded run of this cell failed with a size mismatch
# (the checkpoint's src_tok_emb has 1216 rows, the current user_info_vocab 1217).
# The vocabulary rebuilt from the data no longer matches the one used at
# training time, so the vocab needs to be saved with the checkpoint.
CHECKPOINT_PATH = '../models/transformer_type_1_80.pth'
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True))
model.eval()

RuntimeError: Error(s) in loading state_dict for TransformerMultiOutputRegressor:
	size mismatch for src_tok_emb.weight: copying a param with shape torch.Size([1216, 128]) from checkpoint, the shape in current model is torch.Size([1217, 128]).

In [None]:
def predict(test_data):
    """Teacher-forced prediction of the parameter values for one sample.

    The ground-truth parameter tokens (behind <bos>) are fed to the decoder.
    The arg-max token at each position is taken as the prediction for the
    next parameter. NOTE(review): because of this teacher forcing, the result
    measures next-token accuracy given the true prefix, not free-running
    generation, so it is optimistic.

    Args:
        test_data: dict with "user_info" and "parameter" lists of "key_value"
            token strings, as produced by process_data. Every token must be in
            the vocabularies (list.index raises ValueError otherwise).

    Returns:
        (real, pred): lists of value strings (the text after the last "_" of
        each token), aligned by position, each of length
        len(test_data["parameter"]).

    Side effect: puts the global `model` into eval mode.
    """
    user_info = test_data["user_info"]
    param = test_data["parameter"]

    # Wrap both sequences as <bos>=0 ... <eos>=1, as HearingAidDataset does.
    user_info_token = [0, *(user_info_vocab.index(word) for word in user_info), 1]
    param_token = [0, *(parameter_vocab.index(word) for word in param), 1]

    device = next(model.parameters()).device
    user_info_tensor = torch.tensor(user_info_token, device=device).unsqueeze(0)
    param_input = torch.tensor(param_token, device=device).unsqueeze(0)[:, :-1]

    # Inference: disable dropout and don't build an autograd graph.
    model.eval()
    with torch.no_grad():
        logits = model(user_info_tensor, param_input)  # (1, len(param) + 1, tgt_vocab)

    pred_indices = logits.argmax(dim=-1).squeeze(0).tolist()
    # The last position predicts <eos>; drop it so pred lines up with real.
    logits_words = [parameter_vocab[i] for i in pred_indices[:-1]]

    real = [item.split('_')[-1] for item in param]
    pred = [item.split('_')[-1] for item in logits_words]
    return real, pred

In [None]:
# Show every DataFrame column.
pd.set_option('display.max_columns', None)

# Pick one test sample. A seeded private RNG makes the choice reproducible
# without touching the global `random` state.
SAMPLE_SEED = 0
test_path = "../../train_data/type_1/test.json"
test_dataset = process_data(test_path)

sample = random.Random(SAMPLE_SEED).choice(test_dataset)
real, pred = predict(sample)

# Position-wise accuracy for this sample (real and pred are aligned by position).
correct = sum(a == b for a, b in zip(real, pred))
if real:
    print(f"Accuracy: {correct / len(real) * 100:.2f}%")
else:
    print("Accuracy: n/a (sample has no parameters)")


# Columns: "real value" and "predicted value".
results = pd.DataFrame({
    '真实值': real,
    '预测值': pred
})

results

Accuracy: 84.75%


Unnamed: 0,真实值,预测值
0,8,6
1,28,38
2,39,32
3,39,39
4,39,39
5,39,39
6,39,39
7,39,39
8,39,32
9,34,42


In [None]:
# Mean over the test set of each sample's position-wise accuracy (every sample
# weighs the same). Loop variables have their own names so they don't overwrite
# the `real` / `pred` lists from the single-sample cell.
accuracy = []
for item in test_dataset:
    reals, preds = predict(item)
    if not reals:
        continue  # a sample without parameters has nothing to score
    matches = sum(r == p for r, p in zip(reals, preds))
    accuracy.append(matches / len(reals))

if accuracy:
    print(f"Average Accuracy: {sum(accuracy) / len(accuracy) * 100:.2f}%")
else:
    print("Average Accuracy: n/a (no scorable samples)")

Average Accuracy: 81.24%
