In [1]:
import os
import pickle
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import torch
from transformers import AdamW

In [2]:
# Set CUDA_VISIBLE_DEVICES to "0" so that only the first GPU is exposed to
# this process. NOTE(review): CUDA reads this variable when it initialises,
# so this cell should run before any CUDA work — confirm on Restart & Run All.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
def use_model(model_name, config_file_path, model_file_path, vocab_file_path, num_labels):
    """Load a sequence-classification model together with its tokenizer.

    Args:
        model_name: Backend to use, either ``'bert'`` or ``'albert'``.
        config_file_path: Given to the config class's ``from_pretrained``.
        model_file_path: Given to the model class's ``from_pretrained``.
        vocab_file_path: Vocabulary file. For ``'bert'`` it is passed to the
            ``BertTokenizer`` constructor as ``vocab_file``; for ``'albert'``
            it is passed to ``AlbertTokenizer.from_pretrained``.
        num_labels: Number of target classes, forwarded to the config.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not a supported backend.
    """
    # Backend imports stay local so that only the selected backend has to be
    # installed / importable.
    if model_name == 'bert':
        from transformers import BertConfig, BertForSequenceClassification, BertTokenizer

        config = BertConfig.from_pretrained(config_file_path, num_labels=num_labels)
        # Only treat the weights as a TensorFlow checkpoint when the supplied
        # path actually names one (previously a fixed literal was tested, so
        # the flag could never be True).
        model = BertForSequenceClassification.from_pretrained(
            model_file_path,
            from_tf='.ckpt' in str(model_file_path),
            config=config,
        )
        tokenizer = BertTokenizer(vocab_file=vocab_file_path)
        return model, tokenizer

    if model_name == 'albert':
        from albert.albert_zh import AlbertConfig, AlbertTokenizer, AlbertForSequenceClassification

        config = AlbertConfig.from_pretrained(config_file_path, num_labels=num_labels)
        model = AlbertForSequenceClassification.from_pretrained(model_file_path, config=config)
        tokenizer = AlbertTokenizer.from_pretrained(vocab_file_path)
        return model, tokenizer

    # Fail loudly instead of returning None, which would only surface later
    # as a confusing unpacking error at the call site.
    raise ValueError(
        "Unsupported model_name %r; expected 'bert' or 'albert'" % (model_name,)
    )

In [4]:
def to_bert_ids(tokenizer, q_input):
    """Convert input text into its sequence of token ids.

    The text is tokenized, the tokens are mapped to ids, and the tokenizer
    then adds its special tokens around the id sequence.
    """
    tokens = tokenizer.tokenize(q_input)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    return tokenizer.build_inputs_with_special_tokens(token_ids)

In [5]:
def split_dataset(full_dataset, split_rate=0.8):
    """Randomly split a dataset into a train part and a test part.

    Args:
        full_dataset: Dataset to split; must support ``len``.
        split_rate: Fraction of the examples assigned to the train part
            (truncated to a whole number of examples).

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    total = len(full_dataset)
    n_train = int(split_rate * total)
    # Whatever is left after truncation goes to the test part.
    lengths = [n_train, total - n_train]
    parts = torch.utils.data.random_split(full_dataset, lengths)
    return parts[0], parts[1]

In [6]:
class DataDic(object):
    """Two-way mapping between values (e.g. label texts) and integer ids.

    The distinct values of ``answers`` are sorted and numbered from 0 in that
    order. ``to_id`` and ``to_text`` translate between a value and its number,
    returning ``None`` when there is no match.
    """

    def __init__(self, answers):
        self.answers = answers  # every value as given, duplicates included
        self.answers_norepeat = sorted(set(answers))  # distinct values, sorted
        self.answers_types = len(self.answers_norepeat)  # number of distinct values
        self.ans_list = []  # (id, text) pairs used for the lookups
        self._make_dic()

    def _make_dic(self):
        """Fill ``ans_list`` with ``(id, text)`` pairs and build the lookup tables."""
        self.ans_list.extend(
            (index_a, a)
            for index_a, a in enumerate(self.answers_norepeat)
            if a is not None
        )
        self._build_index()

    def _build_index(self):
        """Build dict lookups from ``ans_list`` so queries avoid a linear scan."""
        id_by_text = {}
        text_by_id = {}
        for ans_id, ans_text in self.ans_list:
            # setdefault keeps the first pair, matching a front-to-back scan.
            id_by_text.setdefault(ans_text, ans_id)
            text_by_id.setdefault(ans_id, ans_text)
        self._id_by_text = id_by_text
        self._text_by_id = text_by_id

    def _ensure_index(self):
        """Create the lookup tables if this instance does not have them yet."""
        # Instances restored from a pickle written before the lookup tables
        # existed only carry ``ans_list``; build the tables on first use.
        state = self.__dict__
        if '_id_by_text' not in state or '_text_by_id' not in state:
            self._build_index()

    def to_id(self, text):
        """Return the id assigned to ``text``, or ``None`` if it is unknown."""
        self._ensure_index()
        try:
            return self._id_by_text.get(text)
        except TypeError:
            # Unhashable query: fall back to an equality scan.
            for ans_id, ans_text in self.ans_list:
                if text == ans_text:
                    return ans_id
            return None

    def to_text(self, id):
        """Return the value numbered ``id``, or ``None`` if there is none."""
        self._ensure_index()
        try:
            return self._text_by_id.get(id)
        except TypeError:
            # Unhashable query: fall back to an equality scan.
            for ans_id, ans_text in self.ans_list:
                if id == ans_id:
                    return ans_text
            return None

    @property
    def types(self):
        """Number of distinct values."""
        return self.answers_types

    @property
    def data(self):
        """The values exactly as passed to the constructor."""
        return self.answers

    def __len__(self):
        return len(self.answers)

In [7]:
def convert_data_to_feature(tokenizer, train_data_path, output_path='trained_model/data_features.pkl'):
    """Read a labelled CSV and turn it into padded model input features.

    The CSV must contain a ``sentence`` column (input text) and a ``label``
    column. Every sentence is encoded with ``to_bert_ids`` and right-padded
    with id 0 to the length of the longest encoded sentence.

    Args:
        tokenizer: Tokenizer handed to ``to_bert_ids``.
        train_data_path: Path of the UTF-8 encoded CSV file.
        output_path: Where the resulting feature dict is pickled. The parent
            directory is created when it does not exist.

    Returns:
        A dict with the keys ``input_ids``, ``input_masks``,
        ``input_segment_ids``, ``answer_lables``, ``question_dic`` and
        ``answer_dic``. ``input_masks`` holds 1 for real tokens and 0 for
        padding; ``input_segment_ids`` is all zeros.

    Raises:
        AssertionError: If the two columns differ in length or the longest
            encoded sentence exceeds 512 ids.
    """
    df_data = pd.read_csv(train_data_path, encoding='utf-8')

    questions = df_data['sentence'].values.tolist()
    answers = df_data['label'].values.tolist()

    print(len(questions), len(answers))
    assert len(answers) == len(questions)

    ans_dic = DataDic(answers)
    question_dic = DataDic(questions)

    # Encode every question, then find the longest encoding.
    q_tokens = [to_bert_ids(tokenizer, q) for q in question_dic.data]
    max_seq_len = max((len(ids) for ids in q_tokens), default=0)

    print("最長輸入長度:", max_seq_len)
    assert max_seq_len <= 512  # must fit within the BERT-base length limit

    # Pad to a common length. The mask marks real tokens with 1 and padding
    # with 0 (it used to be all ones, which hid the padding from nobody).
    input_ids = []
    input_masks = []
    for ids in q_tokens:
        pad_len = max_seq_len - len(ids)
        input_ids.append(ids + [0] * pad_len)
        input_masks.append([1] * len(ids) + [0] * pad_len)
    input_segment_ids = [[0] * max_seq_len for _ in q_tokens]  # single segment

    # Map each label to its integer id.
    answer_lables = [ans_dic.to_id(a) for a in ans_dic.data]

    assert len(input_ids) == len(question_dic) and len(input_ids) == len(input_masks) and len(input_ids) == len(input_segment_ids)

    data_features = {
        'input_ids': input_ids,
        'input_masks': input_masks,
        'input_segment_ids': input_segment_ids,
        'answer_lables': answer_lables,
        'question_dic': question_dic,
        'answer_dic': ans_dic
    }

    # Persist the features; the context manager guarantees the file is
    # flushed and closed even if pickling fails.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'wb') as output:
        pickle.dump(data_features, output)
    return data_features

In [8]:
# Alternative configuration (Chinese BERT instead of ALBERT):
# model_setting = dict(
#     model_name="bert",
#     config_file_path="bert-base-chinese",
#     model_file_path="bert-base-chinese",
#     vocab_file_path="bert-base-chinese-vocab.txt",
#     num_labels=2,  # number of target classes
# )

# Active configuration: ALBERT-tiny loaded from local files.
model_setting = dict(
    model_name="albert",
    config_file_path="albert/albert_tiny/config.json",
    model_file_path="albert/albert_tiny/pytorch_model.bin",
    vocab_file_path="albert/albert_tiny/vocab.txt",
    num_labels=2,  # number of target classes
)

model, tokenizer = use_model(**model_setting)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("using device", device)
model.to(device)

using device cuda


AlbertForSequenceClassification(
  (bert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(21128, 128, padding_idx=0)
      (word_embeddings_2): Linear(in_features=128, out_features=312, bias=False)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): AlbertEncoder(
      (layer_shared): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=312, out_features=312, bias=True)
            (key): Linear(in_features=312, out_features=312, bias=True)
            (value): Linear(in_features=312, out_features=312, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=312, out_features=312, bias=True)
          

In [9]:
# Encode the training CSV, then unpack the per-example feature lists.
data_feature = convert_data_to_feature(tokenizer, './dataset1.csv')
input_ids, input_masks, input_segment_ids, answer_lables = (
    data_feature[key]
    for key in ('input_ids', 'input_masks', 'input_segment_ids', 'answer_lables')
)

91392 91392
最長輸入長度: 62


In [10]:
# Sanity check: show the token ids of the first encoded example.
input_ids[0]

[101,
 5959,
 6752,
 2786,
 5183,
 7350,
 1987,
 9076,
 8129,
 1039,
 102,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [20]:
# Sanity check: show the label id of the first example.
answer_lables[0]

1

In [5]:
def make_dataset(input_ids, input_masks, input_segment_ids, answer_lables):
    """Pack the per-example features into a single ``TensorDataset``.

    Each argument is a sequence with one entry per example. All four are
    converted to ``torch.long`` tensors and stored in the order
    (ids, masks, segment ids, labels).
    """
    columns = (input_ids, input_masks, input_segment_ids, answer_lables)
    tensors = [torch.tensor(list(column), dtype=torch.long) for column in columns]
    return TensorDataset(*tensors)

In [11]:
# Build the tensor dataset, hold out 20% for testing, and wrap both parts in
# shuffling loaders with a batch size of 32.
full_dataset = make_dataset(
    input_ids=input_ids,
    input_masks=input_masks,
    input_segment_ids=input_segment_ids,
    answer_lables=answer_lables,
)
train_dataset, test_dataset = split_dataset(full_dataset, 0.8)
train_dataloader, test_dataloader = (
    DataLoader(subset, batch_size=32, shuffle=True)
    for subset in (train_dataset, test_dataset)
)

In [12]:
def compute_accuracy(y_pred, y_target):
    """Return the accuracy of a batch as a percentage (0-100).

    Args:
        y_pred: 2-D tensor of per-class scores, one row per example.
        y_target: 1-D tensor of the true class indices.
    """
    # Index [1] of max(dim=1) is the position of each row's largest score.
    predicted = y_pred.max(dim=1)[1]
    hits = (predicted == y_target).sum().item()
    return hits / len(predicted) * 100

In [13]:
# Build the AdamW optimizer.
# The parameters are split into two groups by name: those whose name contains
# 'bias' or 'LayerNorm.weight', and all the others. Both groups currently use
# weight_decay=0.0, so the split has no effect until one value is changed.
# No learning-rate scheduler is created in this cell.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=5e-6, eps=1e-8)

In [14]:
# Train for one epoch, then evaluate on the held-out split. Running loss and
# accuracy are incremental means over the batches seen so far.
model.zero_grad()
for epoch in range(1):
    # ---- training pass ----
    running_loss_val = 0.0
    running_acc = 0.0
    model.train()  # set once per pass instead of once per batch

    for batch_index, batch_dict in enumerate(train_dataloader):
        batch_dict = tuple(t.to(device) for t in batch_dict)
        # NOTE(review): only the token ids and labels are passed; the
        # attention mask is deliberately left commented out as in the
        # original — confirm whether padding should be masked.
        outputs = model(
            batch_dict[0],
            # attention_mask=batch_dict[1],
            labels=batch_dict[3]
        )

        loss, logits = outputs[:2]
        loss.sum().backward()
        optimizer.step()
        # scheduler.step()  # Update learning rate schedule
        model.zero_grad()

        # Update the running mean of the loss.
        loss_t = loss.item()
        running_loss_val += (loss_t - running_loss_val) / (batch_index + 1)

        # Update the running mean of the accuracy.
        acc_t = compute_accuracy(logits, batch_dict[3])
        running_acc += (acc_t - running_acc) / (batch_index + 1)

        # log
        print("epoch:%2d batch:%4d train_loss:%2.4f train_acc:%3.4f"%(epoch+1, batch_index+1, running_loss_val, running_acc))

    # ---- evaluation pass ----
    running_loss_val = 0.0
    running_acc = 0.0
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every test batch.
    with torch.no_grad():
        for batch_index, batch_dict in enumerate(test_dataloader):
            batch_dict = tuple(t.to(device) for t in batch_dict)
            outputs = model(
                batch_dict[0],
                # attention_mask=batch_dict[1],
                labels=batch_dict[3]
            )
            loss, logits = outputs[:2]

            # Update the running mean of the loss.
            loss_t = loss.item()
            running_loss_val += (loss_t - running_loss_val) / (batch_index + 1)

            # Update the running mean of the accuracy.
            acc_t = compute_accuracy(logits, batch_dict[3])
            running_acc += (acc_t - running_acc) / (batch_index + 1)

            # log
            print("epoch:%2d batch:%4d test_loss:%2.4f test_acc:%3.4f"%(epoch+1, batch_index+1, running_loss_val, running_acc))

epoch: 1 batch:   1 train_loss:0.7166 train_acc:46.8750
epoch: 1 batch:   2 train_loss:0.7050 train_acc:50.0000
epoch: 1 batch:   3 train_loss:0.6934 train_acc:48.9583
epoch: 1 batch:   4 train_loss:0.6824 train_acc:52.3438
epoch: 1 batch:   5 train_loss:0.6802 train_acc:54.3750
epoch: 1 batch:   6 train_loss:0.6773 train_acc:55.7292
epoch: 1 batch:   7 train_loss:0.6718 train_acc:58.4821
epoch: 1 batch:   8 train_loss:0.6697 train_acc:60.1562
epoch: 1 batch:   9 train_loss:0.6638 train_acc:62.8472
epoch: 1 batch:  10 train_loss:0.6595 train_acc:66.2500
epoch: 1 batch:  11 train_loss:0.6562 train_acc:68.1818
epoch: 1 batch:  12 train_loss:0.6513 train_acc:70.3125
epoch: 1 batch:  13 train_loss:0.6471 train_acc:72.1154
epoch: 1 batch:  14 train_loss:0.6431 train_acc:72.9911
epoch: 1 batch:  15 train_loss:0.6401 train_acc:74.1667
epoch: 1 batch:  16 train_loss:0.6389 train_acc:74.2188
epoch: 1 batch:  17 train_loss:0.6361 train_acc:74.4485
epoch: 1 batch:  18 train_loss:0.6326 train_acc:

epoch: 1 batch: 155 train_loss:0.3366 train_acc:95.0202
epoch: 1 batch: 156 train_loss:0.3353 train_acc:95.0521
epoch: 1 batch: 157 train_loss:0.3343 train_acc:95.0836
epoch: 1 batch: 158 train_loss:0.3332 train_acc:95.1147
epoch: 1 batch: 159 train_loss:0.3320 train_acc:95.1454
epoch: 1 batch: 160 train_loss:0.3310 train_acc:95.1562
epoch: 1 batch: 161 train_loss:0.3300 train_acc:95.1863
epoch: 1 batch: 162 train_loss:0.3289 train_acc:95.2160
epoch: 1 batch: 163 train_loss:0.3277 train_acc:95.2454
epoch: 1 batch: 164 train_loss:0.3267 train_acc:95.2744
epoch: 1 batch: 165 train_loss:0.3256 train_acc:95.3030
epoch: 1 batch: 166 train_loss:0.3246 train_acc:95.3313
epoch: 1 batch: 167 train_loss:0.3235 train_acc:95.3593
epoch: 1 batch: 168 train_loss:0.3224 train_acc:95.3869
epoch: 1 batch: 169 train_loss:0.3213 train_acc:95.4142
epoch: 1 batch: 170 train_loss:0.3202 train_acc:95.4412
epoch: 1 batch: 171 train_loss:0.3191 train_acc:95.4678
epoch: 1 batch: 172 train_loss:0.3181 train_acc:

epoch: 1 batch: 310 train_loss:0.2176 train_acc:97.4294
epoch: 1 batch: 311 train_loss:0.2172 train_acc:97.4377
epoch: 1 batch: 312 train_loss:0.2167 train_acc:97.4459
epoch: 1 batch: 313 train_loss:0.2162 train_acc:97.4541
epoch: 1 batch: 314 train_loss:0.2156 train_acc:97.4622
epoch: 1 batch: 315 train_loss:0.2151 train_acc:97.4702
epoch: 1 batch: 316 train_loss:0.2146 train_acc:97.4782
epoch: 1 batch: 317 train_loss:0.2141 train_acc:97.4862
epoch: 1 batch: 318 train_loss:0.2136 train_acc:97.4941
epoch: 1 batch: 319 train_loss:0.2132 train_acc:97.5020
epoch: 1 batch: 320 train_loss:0.2127 train_acc:97.5098
epoch: 1 batch: 321 train_loss:0.2123 train_acc:97.5175
epoch: 1 batch: 322 train_loss:0.2118 train_acc:97.5252
epoch: 1 batch: 323 train_loss:0.2113 train_acc:97.5329
epoch: 1 batch: 324 train_loss:0.2108 train_acc:97.5405
epoch: 1 batch: 325 train_loss:0.2104 train_acc:97.5481
epoch: 1 batch: 326 train_loss:0.2099 train_acc:97.5556
epoch: 1 batch: 327 train_loss:0.2094 train_acc:

epoch: 1 batch: 462 train_loss:0.1618 train_acc:98.2481
epoch: 1 batch: 463 train_loss:0.1616 train_acc:98.2451
epoch: 1 batch: 464 train_loss:0.1614 train_acc:98.2489
epoch: 1 batch: 465 train_loss:0.1611 train_acc:98.2527
epoch: 1 batch: 466 train_loss:0.1608 train_acc:98.2564
epoch: 1 batch: 467 train_loss:0.1606 train_acc:98.2602
epoch: 1 batch: 468 train_loss:0.1603 train_acc:98.2639
epoch: 1 batch: 469 train_loss:0.1600 train_acc:98.2676
epoch: 1 batch: 470 train_loss:0.1598 train_acc:98.2713
epoch: 1 batch: 471 train_loss:0.1595 train_acc:98.2749
epoch: 1 batch: 472 train_loss:0.1593 train_acc:98.2786
epoch: 1 batch: 473 train_loss:0.1590 train_acc:98.2822
epoch: 1 batch: 474 train_loss:0.1588 train_acc:98.2859
epoch: 1 batch: 475 train_loss:0.1585 train_acc:98.2895
epoch: 1 batch: 476 train_loss:0.1582 train_acc:98.2931
epoch: 1 batch: 477 train_loss:0.1580 train_acc:98.2966
epoch: 1 batch: 478 train_loss:0.1577 train_acc:98.3002
epoch: 1 batch: 479 train_loss:0.1575 train_acc:

epoch: 1 batch: 612 train_loss:0.1301 train_acc:98.6622
epoch: 1 batch: 613 train_loss:0.1299 train_acc:98.6644
epoch: 1 batch: 614 train_loss:0.1297 train_acc:98.6665
epoch: 1 batch: 615 train_loss:0.1296 train_acc:98.6636
epoch: 1 batch: 616 train_loss:0.1294 train_acc:98.6658
epoch: 1 batch: 617 train_loss:0.1293 train_acc:98.6679
epoch: 1 batch: 618 train_loss:0.1291 train_acc:98.6701
epoch: 1 batch: 619 train_loss:0.1289 train_acc:98.6723
epoch: 1 batch: 620 train_loss:0.1288 train_acc:98.6744
epoch: 1 batch: 621 train_loss:0.1286 train_acc:98.6765
epoch: 1 batch: 622 train_loss:0.1284 train_acc:98.6787
epoch: 1 batch: 623 train_loss:0.1283 train_acc:98.6808
epoch: 1 batch: 624 train_loss:0.1281 train_acc:98.6829
epoch: 1 batch: 625 train_loss:0.1279 train_acc:98.6850
epoch: 1 batch: 626 train_loss:0.1278 train_acc:98.6871
epoch: 1 batch: 627 train_loss:0.1277 train_acc:98.6842
epoch: 1 batch: 628 train_loss:0.1275 train_acc:98.6863
epoch: 1 batch: 629 train_loss:0.1273 train_acc:

epoch: 1 batch: 764 train_loss:0.1089 train_acc:98.9038
epoch: 1 batch: 765 train_loss:0.1088 train_acc:98.9052
epoch: 1 batch: 766 train_loss:0.1087 train_acc:98.9067
epoch: 1 batch: 767 train_loss:0.1085 train_acc:98.9081
epoch: 1 batch: 768 train_loss:0.1086 train_acc:98.9054
epoch: 1 batch: 769 train_loss:0.1085 train_acc:98.9069
epoch: 1 batch: 770 train_loss:0.1084 train_acc:98.9083
epoch: 1 batch: 771 train_loss:0.1082 train_acc:98.9097
epoch: 1 batch: 772 train_loss:0.1081 train_acc:98.9111
epoch: 1 batch: 773 train_loss:0.1080 train_acc:98.9125
epoch: 1 batch: 774 train_loss:0.1079 train_acc:98.9139
epoch: 1 batch: 775 train_loss:0.1078 train_acc:98.9153
epoch: 1 batch: 776 train_loss:0.1077 train_acc:98.9167
epoch: 1 batch: 777 train_loss:0.1075 train_acc:98.9181
epoch: 1 batch: 778 train_loss:0.1074 train_acc:98.9195
epoch: 1 batch: 779 train_loss:0.1073 train_acc:98.9209
epoch: 1 batch: 780 train_loss:0.1072 train_acc:98.9223
epoch: 1 batch: 781 train_loss:0.1071 train_acc:

epoch: 1 batch: 911 train_loss:0.0941 train_acc:99.0773
epoch: 1 batch: 912 train_loss:0.0940 train_acc:99.0783
epoch: 1 batch: 913 train_loss:0.0939 train_acc:99.0793
epoch: 1 batch: 914 train_loss:0.0939 train_acc:99.0803
epoch: 1 batch: 915 train_loss:0.0938 train_acc:99.0813
epoch: 1 batch: 916 train_loss:0.0937 train_acc:99.0823
epoch: 1 batch: 917 train_loss:0.0936 train_acc:99.0833
epoch: 1 batch: 918 train_loss:0.0935 train_acc:99.0843
epoch: 1 batch: 919 train_loss:0.0934 train_acc:99.0853
epoch: 1 batch: 920 train_loss:0.0933 train_acc:99.0863
epoch: 1 batch: 921 train_loss:0.0932 train_acc:99.0873
epoch: 1 batch: 922 train_loss:0.0931 train_acc:99.0883
epoch: 1 batch: 923 train_loss:0.0931 train_acc:99.0892
epoch: 1 batch: 924 train_loss:0.0930 train_acc:99.0902
epoch: 1 batch: 925 train_loss:0.0929 train_acc:99.0912
epoch: 1 batch: 926 train_loss:0.0928 train_acc:99.0922
epoch: 1 batch: 927 train_loss:0.0927 train_acc:99.0932
epoch: 1 batch: 928 train_loss:0.0926 train_acc:

epoch: 1 batch:1062 train_loss:0.0827 train_acc:99.2026
epoch: 1 batch:1063 train_loss:0.0826 train_acc:99.2033
epoch: 1 batch:1064 train_loss:0.0825 train_acc:99.2041
epoch: 1 batch:1065 train_loss:0.0825 train_acc:99.2048
epoch: 1 batch:1066 train_loss:0.0824 train_acc:99.2056
epoch: 1 batch:1067 train_loss:0.0823 train_acc:99.2063
epoch: 1 batch:1068 train_loss:0.0823 train_acc:99.2070
epoch: 1 batch:1069 train_loss:0.0822 train_acc:99.2078
epoch: 1 batch:1070 train_loss:0.0821 train_acc:99.2085
epoch: 1 batch:1071 train_loss:0.0821 train_acc:99.2093
epoch: 1 batch:1072 train_loss:0.0820 train_acc:99.2100
epoch: 1 batch:1073 train_loss:0.0819 train_acc:99.2107
epoch: 1 batch:1074 train_loss:0.0819 train_acc:99.2115
epoch: 1 batch:1075 train_loss:0.0818 train_acc:99.2122
epoch: 1 batch:1076 train_loss:0.0817 train_acc:99.2129
epoch: 1 batch:1077 train_loss:0.0817 train_acc:99.2137
epoch: 1 batch:1078 train_loss:0.0816 train_acc:99.2144
epoch: 1 batch:1079 train_loss:0.0815 train_acc:

epoch: 1 batch:1216 train_loss:0.0735 train_acc:99.2984
epoch: 1 batch:1217 train_loss:0.0734 train_acc:99.2990
epoch: 1 batch:1218 train_loss:0.0734 train_acc:99.2996
epoch: 1 batch:1219 train_loss:0.0733 train_acc:99.3001
epoch: 1 batch:1220 train_loss:0.0733 train_acc:99.3007
epoch: 1 batch:1221 train_loss:0.0732 train_acc:99.3013
epoch: 1 batch:1222 train_loss:0.0732 train_acc:99.3019
epoch: 1 batch:1223 train_loss:0.0731 train_acc:99.3024
epoch: 1 batch:1224 train_loss:0.0731 train_acc:99.3030
epoch: 1 batch:1225 train_loss:0.0730 train_acc:99.3036
epoch: 1 batch:1226 train_loss:0.0730 train_acc:99.3041
epoch: 1 batch:1227 train_loss:0.0729 train_acc:99.3047
epoch: 1 batch:1228 train_loss:0.0729 train_acc:99.3053
epoch: 1 batch:1229 train_loss:0.0728 train_acc:99.3058
epoch: 1 batch:1230 train_loss:0.0728 train_acc:99.3064
epoch: 1 batch:1231 train_loss:0.0727 train_acc:99.3070
epoch: 1 batch:1232 train_loss:0.0727 train_acc:99.3075
epoch: 1 batch:1233 train_loss:0.0726 train_acc:

epoch: 1 batch:1367 train_loss:0.0663 train_acc:99.3759
epoch: 1 batch:1368 train_loss:0.0662 train_acc:99.3764
epoch: 1 batch:1369 train_loss:0.0662 train_acc:99.3768
epoch: 1 batch:1370 train_loss:0.0662 train_acc:99.3773
epoch: 1 batch:1371 train_loss:0.0661 train_acc:99.3777
epoch: 1 batch:1372 train_loss:0.0661 train_acc:99.3782
epoch: 1 batch:1373 train_loss:0.0660 train_acc:99.3786
epoch: 1 batch:1374 train_loss:0.0660 train_acc:99.3791
epoch: 1 batch:1375 train_loss:0.0659 train_acc:99.3795
epoch: 1 batch:1376 train_loss:0.0659 train_acc:99.3800
epoch: 1 batch:1377 train_loss:0.0658 train_acc:99.3804
epoch: 1 batch:1378 train_loss:0.0658 train_acc:99.3809
epoch: 1 batch:1379 train_loss:0.0658 train_acc:99.3813
epoch: 1 batch:1380 train_loss:0.0657 train_acc:99.3818
epoch: 1 batch:1381 train_loss:0.0657 train_acc:99.3822
epoch: 1 batch:1382 train_loss:0.0656 train_acc:99.3827
epoch: 1 batch:1383 train_loss:0.0656 train_acc:99.3831
epoch: 1 batch:1384 train_loss:0.0656 train_acc:

epoch: 1 batch:1514 train_loss:0.0605 train_acc:99.4365
epoch: 1 batch:1515 train_loss:0.0604 train_acc:99.4369
epoch: 1 batch:1516 train_loss:0.0604 train_acc:99.4373
epoch: 1 batch:1517 train_loss:0.0604 train_acc:99.4376
epoch: 1 batch:1518 train_loss:0.0603 train_acc:99.4380
epoch: 1 batch:1519 train_loss:0.0603 train_acc:99.4384
epoch: 1 batch:1520 train_loss:0.0603 train_acc:99.4387
epoch: 1 batch:1521 train_loss:0.0602 train_acc:99.4391
epoch: 1 batch:1522 train_loss:0.0602 train_acc:99.4395
epoch: 1 batch:1523 train_loss:0.0602 train_acc:99.4398
epoch: 1 batch:1524 train_loss:0.0601 train_acc:99.4402
epoch: 1 batch:1525 train_loss:0.0601 train_acc:99.4406
epoch: 1 batch:1526 train_loss:0.0601 train_acc:99.4409
epoch: 1 batch:1527 train_loss:0.0600 train_acc:99.4413
epoch: 1 batch:1528 train_loss:0.0600 train_acc:99.4417
epoch: 1 batch:1529 train_loss:0.0599 train_acc:99.4420
epoch: 1 batch:1530 train_loss:0.0599 train_acc:99.4424
epoch: 1 batch:1531 train_loss:0.0599 train_acc:

epoch: 1 batch:1663 train_loss:0.0555 train_acc:99.4870
epoch: 1 batch:1664 train_loss:0.0555 train_acc:99.4873
epoch: 1 batch:1665 train_loss:0.0555 train_acc:99.4876
epoch: 1 batch:1666 train_loss:0.0555 train_acc:99.4879
epoch: 1 batch:1667 train_loss:0.0554 train_acc:99.4882
epoch: 1 batch:1668 train_loss:0.0554 train_acc:99.4885
epoch: 1 batch:1669 train_loss:0.0554 train_acc:99.4888
epoch: 1 batch:1670 train_loss:0.0553 train_acc:99.4891
epoch: 1 batch:1671 train_loss:0.0553 train_acc:99.4895
epoch: 1 batch:1672 train_loss:0.0553 train_acc:99.4898
epoch: 1 batch:1673 train_loss:0.0552 train_acc:99.4901
epoch: 1 batch:1674 train_loss:0.0552 train_acc:99.4904
epoch: 1 batch:1675 train_loss:0.0552 train_acc:99.4907
epoch: 1 batch:1676 train_loss:0.0552 train_acc:99.4910
epoch: 1 batch:1677 train_loss:0.0551 train_acc:99.4913
epoch: 1 batch:1678 train_loss:0.0551 train_acc:99.4916
epoch: 1 batch:1679 train_loss:0.0551 train_acc:99.4919
epoch: 1 batch:1680 train_loss:0.0550 train_acc:

epoch: 1 batch:1817 train_loss:0.0512 train_acc:99.5305
epoch: 1 batch:1818 train_loss:0.0512 train_acc:99.5307
epoch: 1 batch:1819 train_loss:0.0512 train_acc:99.5310
epoch: 1 batch:1820 train_loss:0.0512 train_acc:99.5312
epoch: 1 batch:1821 train_loss:0.0511 train_acc:99.5315
epoch: 1 batch:1822 train_loss:0.0511 train_acc:99.5318
epoch: 1 batch:1823 train_loss:0.0511 train_acc:99.5320
epoch: 1 batch:1824 train_loss:0.0510 train_acc:99.5323
epoch: 1 batch:1825 train_loss:0.0510 train_acc:99.5325
epoch: 1 batch:1826 train_loss:0.0510 train_acc:99.5328
epoch: 1 batch:1827 train_loss:0.0510 train_acc:99.5330
epoch: 1 batch:1828 train_loss:0.0509 train_acc:99.5333
epoch: 1 batch:1829 train_loss:0.0509 train_acc:99.5336
epoch: 1 batch:1830 train_loss:0.0509 train_acc:99.5338
epoch: 1 batch:1831 train_loss:0.0509 train_acc:99.5341
epoch: 1 batch:1832 train_loss:0.0508 train_acc:99.5343
epoch: 1 batch:1833 train_loss:0.0508 train_acc:99.5346
epoch: 1 batch:1834 train_loss:0.0508 train_acc:

epoch: 1 batch:1965 train_loss:0.0477 train_acc:99.5658
epoch: 1 batch:1966 train_loss:0.0476 train_acc:99.5661
epoch: 1 batch:1967 train_loss:0.0476 train_acc:99.5663
epoch: 1 batch:1968 train_loss:0.0476 train_acc:99.5665
epoch: 1 batch:1969 train_loss:0.0476 train_acc:99.5667
epoch: 1 batch:1970 train_loss:0.0475 train_acc:99.5669
epoch: 1 batch:1971 train_loss:0.0475 train_acc:99.5672
epoch: 1 batch:1972 train_loss:0.0475 train_acc:99.5674
epoch: 1 batch:1973 train_loss:0.0475 train_acc:99.5676
epoch: 1 batch:1974 train_loss:0.0475 train_acc:99.5678
epoch: 1 batch:1975 train_loss:0.0474 train_acc:99.5680
epoch: 1 batch:1976 train_loss:0.0474 train_acc:99.5683
epoch: 1 batch:1977 train_loss:0.0474 train_acc:99.5685
epoch: 1 batch:1978 train_loss:0.0474 train_acc:99.5687
epoch: 1 batch:1979 train_loss:0.0473 train_acc:99.5689
epoch: 1 batch:1980 train_loss:0.0473 train_acc:99.5691
epoch: 1 batch:1981 train_loss:0.0473 train_acc:99.5693
epoch: 1 batch:1982 train_loss:0.0473 train_acc:

epoch: 1 batch:2116 train_loss:0.0446 train_acc:99.5924
epoch: 1 batch:2117 train_loss:0.0446 train_acc:99.5926
epoch: 1 batch:2118 train_loss:0.0446 train_acc:99.5928
epoch: 1 batch:2119 train_loss:0.0446 train_acc:99.5930
epoch: 1 batch:2120 train_loss:0.0445 train_acc:99.5932
epoch: 1 batch:2121 train_loss:0.0445 train_acc:99.5934
epoch: 1 batch:2122 train_loss:0.0445 train_acc:99.5935
epoch: 1 batch:2123 train_loss:0.0445 train_acc:99.5937
epoch: 1 batch:2124 train_loss:0.0445 train_acc:99.5939
epoch: 1 batch:2125 train_loss:0.0444 train_acc:99.5941
epoch: 1 batch:2126 train_loss:0.0444 train_acc:99.5943
epoch: 1 batch:2127 train_loss:0.0444 train_acc:99.5945
epoch: 1 batch:2128 train_loss:0.0444 train_acc:99.5947
epoch: 1 batch:2129 train_loss:0.0444 train_acc:99.5949
epoch: 1 batch:2130 train_loss:0.0443 train_acc:99.5951
epoch: 1 batch:2131 train_loss:0.0443 train_acc:99.5953
epoch: 1 batch:2132 train_loss:0.0443 train_acc:99.5955
epoch: 1 batch:2133 train_loss:0.0443 train_acc:

epoch: 1 batch:2267 train_loss:0.0419 train_acc:99.6182
epoch: 1 batch:2268 train_loss:0.0419 train_acc:99.6183
epoch: 1 batch:2269 train_loss:0.0419 train_acc:99.6185
epoch: 1 batch:2270 train_loss:0.0418 train_acc:99.6187
epoch: 1 batch:2271 train_loss:0.0418 train_acc:99.6188
epoch: 1 batch:2272 train_loss:0.0418 train_acc:99.6190
epoch: 1 batch:2273 train_loss:0.0418 train_acc:99.6192
epoch: 1 batch:2274 train_loss:0.0418 train_acc:99.6193
epoch: 1 batch:2275 train_loss:0.0418 train_acc:99.6195
epoch: 1 batch:2276 train_loss:0.0417 train_acc:99.6197
epoch: 1 batch:2277 train_loss:0.0417 train_acc:99.6198
epoch: 1 batch:2278 train_loss:0.0417 train_acc:99.6200
epoch: 1 batch:2279 train_loss:0.0417 train_acc:99.6202
epoch: 1 batch:2280 train_loss:0.0417 train_acc:99.6203
epoch: 1 batch:2281 train_loss:0.0416 train_acc:99.6205
epoch: 1 batch:2282 train_loss:0.0416 train_acc:99.6207
epoch: 1 batch:2283 train_loss:0.0416 train_acc:99.6208
epoch: 1 batch:2284 train_loss:0.0416 train_acc:

epoch: 1 batch: 148 test_loss:0.0029 test_acc:99.9789
epoch: 1 batch: 149 test_loss:0.0029 test_acc:99.9790
epoch: 1 batch: 150 test_loss:0.0029 test_acc:99.9792
epoch: 1 batch: 151 test_loss:0.0029 test_acc:99.9793
epoch: 1 batch: 152 test_loss:0.0029 test_acc:99.9794
epoch: 1 batch: 153 test_loss:0.0029 test_acc:99.9796
epoch: 1 batch: 154 test_loss:0.0029 test_acc:99.9797
epoch: 1 batch: 155 test_loss:0.0029 test_acc:99.9798
epoch: 1 batch: 156 test_loss:0.0029 test_acc:99.9800
epoch: 1 batch: 157 test_loss:0.0029 test_acc:99.9801
epoch: 1 batch: 158 test_loss:0.0029 test_acc:99.9802
epoch: 1 batch: 159 test_loss:0.0029 test_acc:99.9803
epoch: 1 batch: 160 test_loss:0.0029 test_acc:99.9805
epoch: 1 batch: 161 test_loss:0.0029 test_acc:99.9806
epoch: 1 batch: 162 test_loss:0.0029 test_acc:99.9807
epoch: 1 batch: 163 test_loss:0.0029 test_acc:99.9808
epoch: 1 batch: 164 test_loss:0.0029 test_acc:99.9809
epoch: 1 batch: 165 test_loss:0.0029 test_acc:99.9811
epoch: 1 batch: 166 test_los

epoch: 1 batch: 321 test_loss:0.0028 test_acc:99.9903
epoch: 1 batch: 322 test_loss:0.0028 test_acc:99.9903
epoch: 1 batch: 323 test_loss:0.0028 test_acc:99.9903
epoch: 1 batch: 324 test_loss:0.0028 test_acc:99.9904
epoch: 1 batch: 325 test_loss:0.0028 test_acc:99.9904
epoch: 1 batch: 326 test_loss:0.0028 test_acc:99.9904
epoch: 1 batch: 327 test_loss:0.0028 test_acc:99.9904
epoch: 1 batch: 328 test_loss:0.0028 test_acc:99.9905
epoch: 1 batch: 329 test_loss:0.0028 test_acc:99.9905
epoch: 1 batch: 330 test_loss:0.0028 test_acc:99.9905
epoch: 1 batch: 331 test_loss:0.0028 test_acc:99.9906
epoch: 1 batch: 332 test_loss:0.0028 test_acc:99.9906
epoch: 1 batch: 333 test_loss:0.0028 test_acc:99.9906
epoch: 1 batch: 334 test_loss:0.0028 test_acc:99.9906
epoch: 1 batch: 335 test_loss:0.0028 test_acc:99.9907
epoch: 1 batch: 336 test_loss:0.0028 test_acc:99.9907
epoch: 1 batch: 337 test_loss:0.0028 test_acc:99.9907
epoch: 1 batch: 338 test_loss:0.0028 test_acc:99.9908
epoch: 1 batch: 339 test_los

epoch: 1 batch: 495 test_loss:0.0028 test_acc:99.9937
epoch: 1 batch: 496 test_loss:0.0028 test_acc:99.9937
epoch: 1 batch: 497 test_loss:0.0028 test_acc:99.9937
epoch: 1 batch: 498 test_loss:0.0028 test_acc:99.9937
epoch: 1 batch: 499 test_loss:0.0028 test_acc:99.9937
epoch: 1 batch: 500 test_loss:0.0028 test_acc:99.9937
epoch: 1 batch: 501 test_loss:0.0028 test_acc:99.9938
epoch: 1 batch: 502 test_loss:0.0028 test_acc:99.9938
epoch: 1 batch: 503 test_loss:0.0028 test_acc:99.9938
epoch: 1 batch: 504 test_loss:0.0028 test_acc:99.9938
epoch: 1 batch: 505 test_loss:0.0028 test_acc:99.9938
epoch: 1 batch: 506 test_loss:0.0028 test_acc:99.9938
epoch: 1 batch: 507 test_loss:0.0028 test_acc:99.9938
epoch: 1 batch: 508 test_loss:0.0028 test_acc:99.9938
epoch: 1 batch: 509 test_loss:0.0028 test_acc:99.9939
epoch: 1 batch: 510 test_loss:0.0028 test_acc:99.9939
epoch: 1 batch: 511 test_loss:0.0028 test_acc:99.9939
epoch: 1 batch: 512 test_loss:0.0028 test_acc:99.9939
epoch: 1 batch: 513 test_los

In [15]:
# When the model is wrapped (distributed/parallel training exposes the real
# model as `.module`), save the wrapped model; otherwise save it directly.
model_to_save = getattr(model, 'module', model)
model_to_save.save_pretrained('trained_model')

In [11]:
# Reload the pickled feature dict to recover the label <-> id mapping.
# The context manager closes the file handle once loading is done.
# NOTE: pickle.load can execute arbitrary code; only load files produced by
# this notebook. The classes stored in the pickle must be defined/importable
# before this cell runs.
with open('trained_model/data_features.pkl', 'rb') as pkl_file:
    data_features = pickle.load(pkl_file)
answer_dic = data_features['answer_dic']

In [12]:
# ALBERT: config and weights come from the fine-tuned output directory,
# the vocabulary from the original ALBERT-tiny files.
model_setting = dict(
    model_name="albert",
    config_file_path="trained_model/config.json",
    model_file_path="trained_model/pytorch_model.bin",
    vocab_file_path="albert/albert_tiny/vocab.txt",
    num_labels=2,  # number of target classes
)

In [13]:
# Rebuild the model and tokenizer from the files named in `model_setting`.
# This rebinds the global `model`; it is not moved to `device` here.
model, tokenizer = use_model(**model_setting)

AlbertForSequenceClassification(
  (bert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(21128, 128, padding_idx=0)
      (word_embeddings_2): Linear(in_features=128, out_features=312, bias=False)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): AlbertEncoder(
      (layer_shared): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=312, out_features=312, bias=True)
            (key): Linear(in_features=312, out_features=312, bias=True)
            (value): Linear(in_features=312, out_features=312, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=312, out_features=312, bias=True)
          

In [14]:
# Classify a few sample sentences with the reloaded model and print the
# predicted action for each one.
q_inputs = [
    '為何路邊停車格有編號的要收費，無編號的不用收費',
    '債權人可否向稅捐稽徵處申請查調債務人之財產、所得資料',
    '想做大腸癌檢測，不知道轉到哪一個辦事處',
    'Bruce要轉錢給Jack',
    '我要轉學費給老師'
]

model.eval()  # inference only
# Run on whichever device the model's weights live on.
model_device = next(model.parameters()).device

with torch.no_grad():  # no gradients needed for prediction
    for q_input in q_inputs:
        bert_ids = to_bert_ids(tokenizer, q_input)
        assert len(bert_ids) <= 512
        # A dedicated name keeps the training features stored in the global
        # `input_ids` from being overwritten by this loop.
        query_ids = torch.LongTensor(bert_ids).unsqueeze(0).to(model_device)

        # predict: the first output holds the class scores for the single
        # input row; the index of the largest score is the predicted label.
        logits = model(query_ids)[0]
        label = int(torch.argmax(logits, dim=1)[0])
        ans_label = answer_dic.to_text(label)

        print(q_input)
        # str() keeps the message printable for non-string or missing labels.
        print("Action: " + str(ans_label))
        print()

為何路邊停車格有編號的要收費，無編號的不用收費
Action: OTHER

債權人可否向稅捐稽徵處申請查調債務人之財產、所得資料
Action: OTHER

想做大腸癌檢測，不知道轉到哪一個辦事處
Action: OTHER

Bruce要轉錢給Jack
Action: TRANSFER

我要轉學費給老師
Action: TRANSFER

