In [1]:
import os
import pickle
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import torch
from transformers import AdamW

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
def use_model(model_name, config_file_path, model_file_path, vocab_file_path, num_labels):
    # 選擇模型並加載設定
    if(model_name == 'bert'):
        from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
        
        model_config, model_class, model_tokenizer = (BertConfig, BertForSequenceClassification, BertTokenizer)
        config = model_config.from_pretrained(config_file_path,num_labels = num_labels)
        model = model_class.from_pretrained(model_file_path, from_tf=bool('.ckpt' in 'bert-base-chinese'), config=config)
        tokenizer = model_tokenizer(vocab_file=vocab_file_path)
        return model, tokenizer
    
    elif(model_name == 'albert'):
        from albert.albert_zh import AlbertConfig, AlbertTokenizer, AlbertForSequenceClassification
        
        model_config, model_class, model_tokenizer = (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer)
        config = model_config.from_pretrained(config_file_path, num_labels=num_labels)
        model = model_class.from_pretrained(model_file_path, config=config)
        tokenizer = model_tokenizer.from_pretrained(vocab_file_path)
        return model, tokenizer

In [4]:
# 將文字輸入轉換成對應的id編號
def to_bert_ids(tokenizer, q_input):
    return tokenizer.build_inputs_with_special_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(q_input)))

In [5]:
def split_dataset(full_dataset, split_rate=0.8):  
    train_size = int(split_rate * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    return train_dataset, test_dataset

In [6]:
class DataDic(object):
    def __init__(self, answers):
        self.answers = answers #全部答案(含重複)
        self.answers_norepeat = sorted(list(set(answers))) # 不重複
        self.answers_types = len(self.answers_norepeat) # 總共多少類
        self.ans_list = [] # 用於查找id或是text的list
        self._make_dic() # 製作字典
    
    def _make_dic(self):
        for index_a, a in enumerate(self.answers_norepeat):
            if a != None:
                self.ans_list.append((index_a, a))

    def to_id(self, text):
        for ans_id, ans_text in self.ans_list:
            if text == ans_text:
                return ans_id

    def to_text(self, id):
        for ans_id,ans_text in self.ans_list:
            if id == ans_id:
                return ans_text

    @property
    def types(self):
        return self.answers_types
    
    @property
    def data(self):
        return self.answers

    def __len__(self):
        return len(self.answers)

In [7]:
def convert_data_to_feature(tokenizer, train_data_path):
    df_data = pd.read_csv(train_data_path, encoding='utf-8')
    
    questions = df_data['sentence'].values.tolist()
    answers = df_data['label'].values.tolist()
    
    print(len(questions), len(answers))
    assert len(answers) == len(questions)
    
    ans_dic = DataDic(answers)
    question_dic = DataDic(questions)

    q_tokens = []
    max_seq_len = 0
    for q in question_dic.data:
        bert_ids = to_bert_ids(tokenizer, q)
        if(len(bert_ids) > max_seq_len):
            max_seq_len = len(bert_ids)
        q_tokens.append(bert_ids)
        # print(tokenizer.convert_ids_to_tokens(tokenizer.build_inputs_with_special_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(q)))))
    
    print("最長輸入長度:",max_seq_len)
    assert max_seq_len <= 512 # 小於BERT-base長度限制

    # 補齊長度
    for q in q_tokens:
        while len(q) < max_seq_len:
            q.append(0)
    
    a_labels = []
    for a in ans_dic.data:
        a_labels.append(ans_dic.to_id(a))
        # print (ans_dic.to_id(a))
    
    # BERT input embedding
    answer_lables = a_labels
    input_ids = q_tokens
    input_masks = [[1]*max_seq_len for i in range(len(question_dic))]  # 1
    input_segment_ids = [[0]*max_seq_len for i in range(len(question_dic))]  # 0
    assert len(input_ids) == len(question_dic) and len(input_ids) == len(input_masks) and len(input_ids) == len(input_segment_ids)

    data_features = {
        'input_ids': input_ids,
        'input_masks': input_masks,
        'input_segment_ids': input_segment_ids,
        'answer_lables': answer_lables,
        'question_dic': question_dic,
        'answer_dic': ans_dic
    }
    
    output = open('trained_model/data_features.pkl', 'wb')
    pickle.dump(data_features, output)
    return data_features

In [8]:
model_setting = {
    "model_name":"bert", 
    "config_file_path":"bert-base-chinese", 
    "model_file_path":"bert-base-chinese", 
    "vocab_file_path":"bert-base-chinese-vocab.txt",
    "num_labels":2  # 分幾類 
}

# model_setting = {
#     "model_name":"albert", 
#     "config_file_path":"albert/albert_tiny/config.json", 
#     "model_file_path":"albert/albert_tiny/pytorch_model.bin", 
#     "vocab_file_path":"albert/albert_tiny/vocab.txt",
#     "num_labels":2 # 分幾類
# }

model, tokenizer = use_model(**model_setting)

# setting device    
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("using device", device)
model.to(device)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

using device cuda


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [9]:
data_feature = convert_data_to_feature(tokenizer, './dataset_classification.csv')
input_ids = data_feature['input_ids']
input_masks = data_feature['input_masks']  # 1
input_segment_ids = data_feature['input_segment_ids']  # 0
answer_lables = data_feature['answer_lables']

108745 108745
最長輸入長度: 72


In [10]:
def make_dataset(input_ids, input_masks, input_segment_ids, answer_lables):
    all_input_ids = torch.tensor([input_id for input_id in input_ids], dtype=torch.long)
    all_input_masks = torch.tensor([input_mask for input_mask in input_masks], dtype=torch.long)
    all_input_segment_ids = torch.tensor([input_segment_id for input_segment_id in input_segment_ids], dtype=torch.long)
    all_answer_lables = torch.tensor([answer_lable for answer_lable in answer_lables], dtype=torch.long)    
    
    return TensorDataset(all_input_ids, all_input_masks, all_input_segment_ids, all_answer_lables)

In [11]:
batch_size = 64

In [12]:
full_dataset = make_dataset(input_ids=input_ids, input_masks=input_masks, input_segment_ids=input_segment_ids, answer_lables=answer_lables)
train_dataset, test_dataset = split_dataset(full_dataset, 0.8)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) 

In [13]:
def compute_accuracy(y_pred, y_target):
    # 計算正確率
    _, y_pred_indices = y_pred.max(dim=1)
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100

In [14]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=5e-6, eps=1e-8)

In [15]:
epochs = 1

In [16]:
model.zero_grad()
for epoch in range(epochs):
    running_loss_val = 0.0
    running_acc = 0.0
    
    for batch_index, batch_dict in enumerate(train_dataloader):
        model.train()
        batch_dict = tuple(t.to(device) for t in batch_dict)
        outputs = model(
            batch_dict[0],
            # attention_mask=batch_dict[1],
            labels = batch_dict[3]
        )
            
        loss, logits = outputs[:2]
        loss.sum().backward()
        optimizer.step()
        # scheduler.step()  # Update learning rate schedule
        model.zero_grad()
            
        # compute the loss
        loss_t = loss.item()
        running_loss_val += (loss_t - running_loss_val) / (batch_index + 1)

        # compute the accuracy
        acc_t = compute_accuracy(logits, batch_dict[3])
        running_acc += (acc_t - running_acc) / (batch_index + 1)

        # log
        print("epoch:%2d batch:%4d train_loss:%2.4f train_acc:%3.4f"%(epoch+1, batch_index+1, running_loss_val, running_acc))
        
    running_loss_val = 0.0
    running_acc = 0.0
    for batch_index, batch_dict in enumerate(test_dataloader):
        model.eval()
        batch_dict = tuple(t.to(device) for t in batch_dict)
        outputs = model(
            batch_dict[0],
            # attention_mask=batch_dict[1],
            labels = batch_dict[3]
        )
        loss,logits = outputs[:2]
            
        # compute the loss
        loss_t = loss.item()
        running_loss_val += (loss_t - running_loss_val) / (batch_index + 1)

        # compute the accuracy
        acc_t = compute_accuracy(logits, batch_dict[3])
        running_acc += (acc_t - running_acc) / (batch_index + 1)

        # log
        print("epoch:%2d batch:%4d test_loss:%2.4f test_acc:%3.4f"%(epoch+1, batch_index+1, running_loss_val, running_acc))

epoch: 1 batch:   1 train_loss:0.6196 train_acc:70.3125
epoch: 1 batch:   2 train_loss:0.6318 train_acc:67.1875
epoch: 1 batch:   3 train_loss:0.6292 train_acc:66.1458
epoch: 1 batch:   4 train_loss:0.6184 train_acc:66.7969
epoch: 1 batch:   5 train_loss:0.6009 train_acc:69.6875
epoch: 1 batch:   6 train_loss:0.5826 train_acc:72.9167
epoch: 1 batch:   7 train_loss:0.5652 train_acc:75.6696
epoch: 1 batch:   8 train_loss:0.5507 train_acc:77.9297
epoch: 1 batch:   9 train_loss:0.5361 train_acc:80.0347
epoch: 1 batch:  10 train_loss:0.5206 train_acc:81.7188
epoch: 1 batch:  11 train_loss:0.5072 train_acc:82.8125
epoch: 1 batch:  12 train_loss:0.4941 train_acc:83.8542
epoch: 1 batch:  13 train_loss:0.4801 train_acc:85.0962
epoch: 1 batch:  14 train_loss:0.4680 train_acc:85.7143
epoch: 1 batch:  15 train_loss:0.4574 train_acc:86.3542
epoch: 1 batch:  16 train_loss:0.4455 train_acc:87.1094
epoch: 1 batch:  17 train_loss:0.4331 train_acc:87.8676
epoch: 1 batch:  18 train_loss:0.4205 train_acc:

epoch: 1 batch: 148 train_loss:0.0871 train_acc:98.4481
epoch: 1 batch: 149 train_loss:0.0868 train_acc:98.4480
epoch: 1 batch: 150 train_loss:0.0863 train_acc:98.4583
epoch: 1 batch: 151 train_loss:0.0857 train_acc:98.4685
epoch: 1 batch: 152 train_loss:0.0852 train_acc:98.4786
epoch: 1 batch: 153 train_loss:0.0847 train_acc:98.4886
epoch: 1 batch: 154 train_loss:0.0842 train_acc:98.4984
epoch: 1 batch: 155 train_loss:0.0837 train_acc:98.5081
epoch: 1 batch: 156 train_loss:0.0832 train_acc:98.5176
epoch: 1 batch: 157 train_loss:0.0827 train_acc:98.5271
epoch: 1 batch: 158 train_loss:0.0822 train_acc:98.5364
epoch: 1 batch: 159 train_loss:0.0818 train_acc:98.5456
epoch: 1 batch: 160 train_loss:0.0813 train_acc:98.5547
epoch: 1 batch: 161 train_loss:0.0808 train_acc:98.5637
epoch: 1 batch: 162 train_loss:0.0804 train_acc:98.5725
epoch: 1 batch: 163 train_loss:0.0799 train_acc:98.5813
epoch: 1 batch: 164 train_loss:0.0795 train_acc:98.5899
epoch: 1 batch: 165 train_loss:0.0790 train_acc:

epoch: 1 batch: 295 train_loss:0.0456 train_acc:99.2108
epoch: 1 batch: 296 train_loss:0.0455 train_acc:99.2135
epoch: 1 batch: 297 train_loss:0.0453 train_acc:99.2161
epoch: 1 batch: 298 train_loss:0.0452 train_acc:99.2188
epoch: 1 batch: 299 train_loss:0.0450 train_acc:99.2214
epoch: 1 batch: 300 train_loss:0.0449 train_acc:99.2240
epoch: 1 batch: 301 train_loss:0.0447 train_acc:99.2265
epoch: 1 batch: 302 train_loss:0.0446 train_acc:99.2291
epoch: 1 batch: 303 train_loss:0.0445 train_acc:99.2316
epoch: 1 batch: 304 train_loss:0.0443 train_acc:99.2342
epoch: 1 batch: 305 train_loss:0.0442 train_acc:99.2367
epoch: 1 batch: 306 train_loss:0.0440 train_acc:99.2392
epoch: 1 batch: 307 train_loss:0.0439 train_acc:99.2417
epoch: 1 batch: 308 train_loss:0.0438 train_acc:99.2441
epoch: 1 batch: 309 train_loss:0.0436 train_acc:99.2466
epoch: 1 batch: 310 train_loss:0.0435 train_acc:99.2490
epoch: 1 batch: 311 train_loss:0.0433 train_acc:99.2514
epoch: 1 batch: 312 train_loss:0.0432 train_acc:

epoch: 1 batch: 442 train_loss:0.0315 train_acc:99.4627
epoch: 1 batch: 443 train_loss:0.0314 train_acc:99.4639
epoch: 1 batch: 444 train_loss:0.0313 train_acc:99.4651
epoch: 1 batch: 445 train_loss:0.0313 train_acc:99.4663
epoch: 1 batch: 446 train_loss:0.0312 train_acc:99.4675
epoch: 1 batch: 447 train_loss:0.0311 train_acc:99.4687
epoch: 1 batch: 448 train_loss:0.0311 train_acc:99.4699
epoch: 1 batch: 449 train_loss:0.0310 train_acc:99.4710
epoch: 1 batch: 450 train_loss:0.0309 train_acc:99.4722
epoch: 1 batch: 451 train_loss:0.0309 train_acc:99.4734
epoch: 1 batch: 452 train_loss:0.0308 train_acc:99.4746
epoch: 1 batch: 453 train_loss:0.0307 train_acc:99.4757
epoch: 1 batch: 454 train_loss:0.0307 train_acc:99.4769
epoch: 1 batch: 455 train_loss:0.0306 train_acc:99.4780
epoch: 1 batch: 456 train_loss:0.0305 train_acc:99.4792
epoch: 1 batch: 457 train_loss:0.0305 train_acc:99.4803
epoch: 1 batch: 458 train_loss:0.0304 train_acc:99.4814
epoch: 1 batch: 459 train_loss:0.0303 train_acc:

epoch: 1 batch: 589 train_loss:0.0238 train_acc:99.5968
epoch: 1 batch: 590 train_loss:0.0238 train_acc:99.5975
epoch: 1 batch: 591 train_loss:0.0237 train_acc:99.5981
epoch: 1 batch: 592 train_loss:0.0237 train_acc:99.5988
epoch: 1 batch: 593 train_loss:0.0236 train_acc:99.5995
epoch: 1 batch: 594 train_loss:0.0236 train_acc:99.6002
epoch: 1 batch: 595 train_loss:0.0236 train_acc:99.6008
epoch: 1 batch: 596 train_loss:0.0235 train_acc:99.6015
epoch: 1 batch: 597 train_loss:0.0235 train_acc:99.6022
epoch: 1 batch: 598 train_loss:0.0234 train_acc:99.6028
epoch: 1 batch: 599 train_loss:0.0234 train_acc:99.6035
epoch: 1 batch: 600 train_loss:0.0234 train_acc:99.6042
epoch: 1 batch: 601 train_loss:0.0233 train_acc:99.6048
epoch: 1 batch: 602 train_loss:0.0233 train_acc:99.6055
epoch: 1 batch: 603 train_loss:0.0232 train_acc:99.6061
epoch: 1 batch: 604 train_loss:0.0232 train_acc:99.6068
epoch: 1 batch: 605 train_loss:0.0232 train_acc:99.6074
epoch: 1 batch: 606 train_loss:0.0231 train_acc:

epoch: 1 batch: 736 train_loss:0.0193 train_acc:99.6709
epoch: 1 batch: 737 train_loss:0.0193 train_acc:99.6714
epoch: 1 batch: 738 train_loss:0.0193 train_acc:99.6718
epoch: 1 batch: 739 train_loss:0.0192 train_acc:99.6723
epoch: 1 batch: 740 train_loss:0.0192 train_acc:99.6727
epoch: 1 batch: 741 train_loss:0.0192 train_acc:99.6732
epoch: 1 batch: 742 train_loss:0.0192 train_acc:99.6736
epoch: 1 batch: 743 train_loss:0.0191 train_acc:99.6740
epoch: 1 batch: 744 train_loss:0.0191 train_acc:99.6745
epoch: 1 batch: 745 train_loss:0.0191 train_acc:99.6749
epoch: 1 batch: 746 train_loss:0.0191 train_acc:99.6754
epoch: 1 batch: 747 train_loss:0.0190 train_acc:99.6758
epoch: 1 batch: 748 train_loss:0.0190 train_acc:99.6762
epoch: 1 batch: 749 train_loss:0.0190 train_acc:99.6767
epoch: 1 batch: 750 train_loss:0.0190 train_acc:99.6771
epoch: 1 batch: 751 train_loss:0.0189 train_acc:99.6775
epoch: 1 batch: 752 train_loss:0.0189 train_acc:99.6779
epoch: 1 batch: 753 train_loss:0.0189 train_acc:

epoch: 1 batch: 883 train_loss:0.0161 train_acc:99.7257
epoch: 1 batch: 884 train_loss:0.0161 train_acc:99.7260
epoch: 1 batch: 885 train_loss:0.0161 train_acc:99.7263
epoch: 1 batch: 886 train_loss:0.0161 train_acc:99.7267
epoch: 1 batch: 887 train_loss:0.0161 train_acc:99.7270
epoch: 1 batch: 888 train_loss:0.0161 train_acc:99.7273
epoch: 1 batch: 889 train_loss:0.0160 train_acc:99.7276
epoch: 1 batch: 890 train_loss:0.0160 train_acc:99.7279
epoch: 1 batch: 891 train_loss:0.0160 train_acc:99.7282
epoch: 1 batch: 892 train_loss:0.0160 train_acc:99.7285
epoch: 1 batch: 893 train_loss:0.0160 train_acc:99.7288
epoch: 1 batch: 894 train_loss:0.0159 train_acc:99.7291
epoch: 1 batch: 895 train_loss:0.0159 train_acc:99.7294
epoch: 1 batch: 896 train_loss:0.0159 train_acc:99.7297
epoch: 1 batch: 897 train_loss:0.0159 train_acc:99.7300
epoch: 1 batch: 898 train_loss:0.0159 train_acc:99.7303
epoch: 1 batch: 899 train_loss:0.0159 train_acc:99.7306
epoch: 1 batch: 900 train_loss:0.0158 train_acc:

epoch: 1 batch:1030 train_loss:0.0139 train_acc:99.7649
epoch: 1 batch:1031 train_loss:0.0139 train_acc:99.7651
epoch: 1 batch:1032 train_loss:0.0138 train_acc:99.7653
epoch: 1 batch:1033 train_loss:0.0138 train_acc:99.7655
epoch: 1 batch:1034 train_loss:0.0138 train_acc:99.7658
epoch: 1 batch:1035 train_loss:0.0138 train_acc:99.7660
epoch: 1 batch:1036 train_loss:0.0138 train_acc:99.7662
epoch: 1 batch:1037 train_loss:0.0138 train_acc:99.7665
epoch: 1 batch:1038 train_loss:0.0138 train_acc:99.7667
epoch: 1 batch:1039 train_loss:0.0138 train_acc:99.7669
epoch: 1 batch:1040 train_loss:0.0137 train_acc:99.7671
epoch: 1 batch:1041 train_loss:0.0137 train_acc:99.7674
epoch: 1 batch:1042 train_loss:0.0137 train_acc:99.7676
epoch: 1 batch:1043 train_loss:0.0137 train_acc:99.7678
epoch: 1 batch:1044 train_loss:0.0137 train_acc:99.7680
epoch: 1 batch:1045 train_loss:0.0137 train_acc:99.7682
epoch: 1 batch:1046 train_loss:0.0137 train_acc:99.7685
epoch: 1 batch:1047 train_loss:0.0136 train_acc:

epoch: 1 batch:1177 train_loss:0.0122 train_acc:99.7942
epoch: 1 batch:1178 train_loss:0.0122 train_acc:99.7944
epoch: 1 batch:1179 train_loss:0.0121 train_acc:99.7946
epoch: 1 batch:1180 train_loss:0.0121 train_acc:99.7948
epoch: 1 batch:1181 train_loss:0.0121 train_acc:99.7949
epoch: 1 batch:1182 train_loss:0.0121 train_acc:99.7951
epoch: 1 batch:1183 train_loss:0.0121 train_acc:99.7953
epoch: 1 batch:1184 train_loss:0.0121 train_acc:99.7954
epoch: 1 batch:1185 train_loss:0.0121 train_acc:99.7956
epoch: 1 batch:1186 train_loss:0.0121 train_acc:99.7958
epoch: 1 batch:1187 train_loss:0.0121 train_acc:99.7960
epoch: 1 batch:1188 train_loss:0.0120 train_acc:99.7961
epoch: 1 batch:1189 train_loss:0.0120 train_acc:99.7963
epoch: 1 batch:1190 train_loss:0.0120 train_acc:99.7965
epoch: 1 batch:1191 train_loss:0.0120 train_acc:99.7967
epoch: 1 batch:1192 train_loss:0.0120 train_acc:99.7968
epoch: 1 batch:1193 train_loss:0.0120 train_acc:99.7970
epoch: 1 batch:1194 train_loss:0.0120 train_acc:

epoch: 1 batch:1324 train_loss:0.0108 train_acc:99.8171
epoch: 1 batch:1325 train_loss:0.0108 train_acc:99.8172
epoch: 1 batch:1326 train_loss:0.0108 train_acc:99.8174
epoch: 1 batch:1327 train_loss:0.0108 train_acc:99.8175
epoch: 1 batch:1328 train_loss:0.0108 train_acc:99.8176
epoch: 1 batch:1329 train_loss:0.0108 train_acc:99.8178
epoch: 1 batch:1330 train_loss:0.0108 train_acc:99.8179
epoch: 1 batch:1331 train_loss:0.0108 train_acc:99.8180
epoch: 1 batch:1332 train_loss:0.0108 train_acc:99.8182
epoch: 1 batch:1333 train_loss:0.0108 train_acc:99.8183
epoch: 1 batch:1334 train_loss:0.0107 train_acc:99.8185
epoch: 1 batch:1335 train_loss:0.0107 train_acc:99.8186
epoch: 1 batch:1336 train_loss:0.0107 train_acc:99.8187
epoch: 1 batch:1337 train_loss:0.0107 train_acc:99.8189
epoch: 1 batch:1338 train_loss:0.0107 train_acc:99.8190
epoch: 1 batch:1339 train_loss:0.0107 train_acc:99.8191
epoch: 1 batch:1340 train_loss:0.0107 train_acc:99.8193
epoch: 1 batch:1341 train_loss:0.0107 train_acc:

epoch: 1 batch: 114 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 115 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 116 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 117 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 118 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 119 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 120 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 121 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 122 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 123 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 124 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 125 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 126 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 127 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 128 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 129 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 130 test_loss:0.0001 test_acc:100.0000
epoch: 1 batch: 131 test_loss:0.0001 test_acc:100.0000
epoch: 1 b

epoch: 1 batch: 265 test_loss:0.0005 test_acc:99.9941
epoch: 1 batch: 266 test_loss:0.0005 test_acc:99.9941
epoch: 1 batch: 267 test_loss:0.0005 test_acc:99.9941
epoch: 1 batch: 268 test_loss:0.0005 test_acc:99.9942
epoch: 1 batch: 269 test_loss:0.0005 test_acc:99.9942
epoch: 1 batch: 270 test_loss:0.0005 test_acc:99.9942
epoch: 1 batch: 271 test_loss:0.0005 test_acc:99.9942
epoch: 1 batch: 272 test_loss:0.0005 test_acc:99.9943
epoch: 1 batch: 273 test_loss:0.0005 test_acc:99.9943
epoch: 1 batch: 274 test_loss:0.0005 test_acc:99.9943
epoch: 1 batch: 275 test_loss:0.0005 test_acc:99.9943
epoch: 1 batch: 276 test_loss:0.0005 test_acc:99.9943
epoch: 1 batch: 277 test_loss:0.0005 test_acc:99.9944
epoch: 1 batch: 278 test_loss:0.0005 test_acc:99.9944
epoch: 1 batch: 279 test_loss:0.0005 test_acc:99.9944
epoch: 1 batch: 280 test_loss:0.0005 test_acc:99.9944
epoch: 1 batch: 281 test_loss:0.0005 test_acc:99.9944
epoch: 1 batch: 282 test_loss:0.0005 test_acc:99.9945
epoch: 1 batch: 283 test_los

In [17]:
model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
model_to_save.save_pretrained('trained_model')

In [18]:
torch.save(model, "model_classification_gpu_epoch_"+str(epochs)+"_batch_"+str(batch_size))

In [28]:
pkl_file = open('trained_model/data_features.pkl', 'rb')
data_features = pickle.load(pkl_file)
answer_dic = data_features['answer_dic']

In [29]:
# ALBERT
# model_setting = {
#     "model_name":"albert", 
#     "config_file_path":"trained_model/config.json", 
#     "model_file_path":"trained_model/pytorch_model.bin", 
#     "vocab_file_path":"albert/albert_tiny/vocab.txt",
#     "num_labels":2 # 分幾類
# }
model_setting = {
    "model_name":"bert", 
    "config_file_path":"bert-base-chinese", 
    "model_file_path":"bert-base-chinese", 
    "vocab_file_path":"bert-base-chinese-vocab.txt",
    "num_labels":2  # 分幾類 
}

In [30]:
model, tokenizer = use_model(**model_setting)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [33]:
q_inputs = [
    '為何路邊停車格有編號的要收費，無編號的不用收費',
    '債權人可否向稅捐稽徵處申請查調債務人之財產、所得資料',
    '想做大腸癌檢測，不知道轉到哪一個辦事處',
    'Bruce要轉錢給Jack',
    '我想轉帳1000元給老師',
    '轉帳給父親的戶頭從台幣戶3137元'
]

In [34]:
model.eval()
for q_input in q_inputs:
    bert_ids = to_bert_ids(tokenizer, q_input)
    assert len(bert_ids) <= 512
    input_ids = torch.LongTensor(bert_ids).unsqueeze(0)

    # predict
    outputs = model(input_ids)
    predicts = outputs[:2]
    predicts = predicts[0]
    print(predicts)
    max_val = torch.max(predicts)
    label = (predicts == max_val).nonzero().numpy()[0][1]
    ans_label = answer_dic.to_text(label)
        
    print(q_input)
    print("Action: " + ans_label)
    print()

tensor([[-0.2999, -0.5514]], grad_fn=<AddmmBackward>)
為何路邊停車格有編號的要收費，無編號的不用收費
Action: OTHER

tensor([[ 0.1944, -0.5152]], grad_fn=<AddmmBackward>)
債權人可否向稅捐稽徵處申請查調債務人之財產、所得資料
Action: OTHER

tensor([[-0.0056, -0.6463]], grad_fn=<AddmmBackward>)
想做大腸癌檢測，不知道轉到哪一個辦事處
Action: OTHER

tensor([[-0.1361, -0.3262]], grad_fn=<AddmmBackward>)
Bruce要轉錢給Jack
Action: OTHER

tensor([[-0.1290, -0.3830]], grad_fn=<AddmmBackward>)
我想轉帳1000元給老師
Action: OTHER

tensor([[-0.3360, -0.6764]], grad_fn=<AddmmBackward>)
轉帳給父親的戶頭從台幣戶3137元
Action: OTHER

