In [1]:
source_folder = '/content/drive/My Drive/data/bert_insight/'
destination_folder = '/content/drive/My Drive/data/bert_insight/'

In [2]:
# !pip3 install transformers

In [3]:
# !pip3 install torchtext==0.6.0

In [4]:
import pandas as pd
import numpy as np

import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import random_split
from tqdm.notebook import tqdm

from transformers import BertTokenizer, BertForSequenceClassification

## 載入加探訪資料

In [5]:
# df_train = pd.read_csv(source_folder + 'train.csv')
# df_test = pd.read_csv(source_folder + 'test.csv')
df_train = pd.read_csv('archive/train.csv', encoding='utf8')
df_test = pd.read_csv('archive/test.csv', encoding='utf8')

In [6]:
# 資料型式如下：
df_train.head()

Unnamed: 0,id,tid1,tid2,title1_zh,title2_zh,title1_en,title2_en,label
0,0,0,1,2017养老保险又新增两项，农村老人人人可申领，你领到了吗,警方辟谣“鸟巢大会每人领5万” 仍有老人坚持进京,There are two new old-age insurance benefits f...,"Police disprove ""bird's nest congress each per...",unrelated
1,3,2,3,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",深圳GDP首超香港？深圳统计局辟谣：只是差距在缩小,"""If you do not come to Shenzhen, sooner or lat...",Shenzhen's GDP outstrips Hong Kong? Shenzhen S...,unrelated
2,1,2,4,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",GDP首超香港？深圳澄清：还差一点点……,"""If you do not come to Shenzhen, sooner or lat...",The GDP overtopped Hong Kong? Shenzhen clarifi...,unrelated
3,2,2,5,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",去年深圳GDP首超香港？深圳统计局辟谣：还差611亿,"""If you do not come to Shenzhen, sooner or lat...",Shenzhen's GDP topped Hong Kong last year? She...,unrelated
4,9,6,7,"""用大蒜鉴别地沟油的方法,怎么鉴别地沟油",吃了30年食用油才知道，一片大蒜轻松鉴别地沟油,"""How to discriminate oil from gutter oil by me...",It took 30 years of cooking oil to know that o...,agreed


In [7]:
# 只取3個欄位
df_train = df_train[['title1_zh', 'title2_zh', 'label']].dropna(axis=0).reset_index(drop=True)
df_test = df_test[['id', 'title1_zh', 'title2_zh']].dropna(axis=0).reset_index(drop=True)

In [8]:
df_train.head()

Unnamed: 0,title1_zh,title2_zh,label
0,2017养老保险又新增两项，农村老人人人可申领，你领到了吗,警方辟谣“鸟巢大会每人领5万” 仍有老人坚持进京,unrelated
1,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",深圳GDP首超香港？深圳统计局辟谣：只是差距在缩小,unrelated
2,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",GDP首超香港？深圳澄清：还差一点点……,unrelated
3,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",去年深圳GDP首超香港？深圳统计局辟谣：还差611亿,unrelated
4,"""用大蒜鉴别地沟油的方法,怎么鉴别地沟油",吃了30年食用油才知道，一片大蒜轻松鉴别地沟油,agreed


In [9]:
df_train.label.value_counts()

unrelated    219313
agreed        92966
disagreed      8266
Name: label, dtype: int64

In [10]:
# 因數據太大 將data/50

df_train_unrelated = df_train[df_train.label=='unrelated'][:int(len(df_train[df_train.label=='unrelated'].label)/50)]
df_train_agreed = df_train[df_train.label=='agreed'][:int(len(df_train[df_train.label=='agreed'].label)/50)]
df_train_disagreed = df_train[df_train.label=='disagreed'][:int(len(df_train[df_train.label=='disagreed'].label)/50)]

df_train = pd.concat([df_train_unrelated,df_train_agreed,df_train_disagreed], axis=0)
df_train.reset_index(inplace=True)
df_train

Unnamed: 0,index,title1_zh,title2_zh,label
0,0,2017养老保险又新增两项，农村老人人人可申领，你领到了吗,警方辟谣“鸟巢大会每人领5万” 仍有老人坚持进京,unrelated
1,1,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",深圳GDP首超香港？深圳统计局辟谣：只是差距在缩小,unrelated
2,2,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",GDP首超香港？深圳澄清：还差一点点……,unrelated
3,3,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",去年深圳GDP首超香港？深圳统计局辟谣：还差611亿,unrelated
4,5,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",深圳GDP首超香港？统计局辟谣：未超但差距再度缩小,unrelated
...,...,...,...,...
6405,9106,2018年驾驶证扣分最新规定,辟谣：超速严重由扣6分调整为扣12分，纯属谣言！,disagreed
6406,9122,2018年驾驶证消分新规定，不知道的你一定要准备,辟谣！朋友圈疯传“2018驾照消分新规”是真的？权威解答在这里,disagreed
6407,9128,2018年驾驶证“替人销分”新规，驾驶证代扣分已经行不通了！,辟谣，“销分新规”存误读,disagreed
6408,9129,2018年高考零分作文《中国式平衡》,“高考满分/零分作文”已被辟谣，为什么还有人会相信？,disagreed


## 將 label 改為 編碼模式

In [11]:
# 分label為unique code
possible_labels = df_train.label.unique()

In [12]:
possible_labels

array(['unrelated', 'agreed', 'disagreed'], dtype=object)

In [13]:
label_dict = {}

for idx, label in enumerate(possible_labels):
    label_dict[label] = idx
  
label_dict

{'unrelated': 0, 'agreed': 1, 'disagreed': 2}

In [14]:
df_train['label_num'] = df_train.label.replace(label_dict)
df_train.head(20)

Unnamed: 0,index,title1_zh,title2_zh,label,label_num
0,0,2017养老保险又新增两项，农村老人人人可申领，你领到了吗,警方辟谣“鸟巢大会每人领5万” 仍有老人坚持进京,unrelated,0
1,1,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",深圳GDP首超香港？深圳统计局辟谣：只是差距在缩小,unrelated,0
2,2,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",GDP首超香港？深圳澄清：还差一点点……,unrelated,0
3,3,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",去年深圳GDP首超香港？深圳统计局辟谣：还差611亿,unrelated,0
4,5,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",深圳GDP首超香港？统计局辟谣：未超但差距再度缩小,unrelated,0
5,6,"""吃榴莲的禁忌,吃错会致命!","榴莲不能和什么一起吃 与咖啡同吃诱发心脏病""""",unrelated,0
6,7,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",深圳GDP首超香港？辟谣：未超但差距再度缩小,unrelated,0
7,8,"""旅行青蛙？居然是一款""""生育意愿测试器”！大家还是玩""""珠宝V课""""吧""",咸宁一家店的蛋糕含有“棉花”？崇阳多部门联合辟谣,unrelated,0
8,14,"""飞机就要起飞，一个男人在机舱口跪下！""这是今天最催泪的一幕……",陈乔恩公开宣布与他分手：有时候该放手就不再留恋,unrelated,0
9,16,"""男人在机舱口跪下！""原来一切都只因为爱！",“父亲跪舱门口”谣言实锤！微博首发者：上了假冒飞行员的当！,unrelated,0


## 資料分割

In [15]:
from sklearn.model_selection import train_test_split

# 用stratify處理資料不平衡的問題
X_train, X_val, y_train, y_label = train_test_split(
    df_train.index.values,
    df_train.label_num.values,
    test_size=0.2,
    random_state=17,
    stratify=df_train.label_num.values
)

In [16]:
# 先將data type 設成not_set模式
df_train['data_type'] = ['not_set']*df_train.shape[0]

In [17]:
# 將train跟val的資料分別做記號
df_train.loc[X_train, 'data_type'] = 'train'
df_train.loc[X_val, 'data_type'] = 'val'

In [18]:
df_train.groupby(['label', 'label_num', 'data_type']).count()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,index,title1_zh,title2_zh
label,label_num,data_type,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
agreed,1,train,1487,1487,1487
agreed,1,val,372,372,372
disagreed,2,train,132,132,132
disagreed,2,val,33,33,33
unrelated,0,train,3509,3509,3509
unrelated,0,val,877,877,877


## 載入 bert Tokenizer 及 Data

In [19]:
from transformers import BertTokenizer
from torch.utils.data import TensorDataset

In [20]:
# 載入bert tokenizer
MODEL_NAME = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

In [21]:
# 建置數據集
class NewsPairDataset(Dataset):
    def __init__(self, tokenizer, data_df, max_len=512):
        self.tokenizer = tokenizer
        self.data_df = data_df
        self.max_len = max_len

    def __getitem__(self, idx):
        text1 = self.data_df.loc[idx, 'title1_zh']
        text2 = self.data_df.loc[idx, 'title2_zh']
        label_num = self.data_df.loc[idx, 'label_num'] if 'label_num' in self.data_df.columns else None

        # add SEP between two sentence
        text1_tokens = self.tokenizer.tokenize(text1)
        text2_tokens = self.tokenizer.tokenize(text2)
        # add [cls] & [sep]
        len_all_tokens = len(text1_tokens) + len(text2_tokens) + 2

        # 若有句子比我們定義的長將最長砍半分給前後句
        if len_all_tokens > self.max_len:
            limit_num = (self.max_len - 2) // 2
            text1_tokens = text1_tokens[:limit_num]
            text2_tokens = text2_tokens[:limit_num]

        # 定義bert的inputs
        input = {}
        sentence_tokens = ['[CLS]'] + text1_tokens + ['[SEP]'] + text2_tokens
        # 定義input_ids 將句子轉成特定的id
        input['input_ids'] = torch.tensor(self.tokenizer.convert_tokens_to_ids(sentence_tokens),dtype=torch.long)

        # 取得[SEP]的位置
        sep_idx = sentence_tokens.index('[SEP]')
        # 第1句的token_type_ids設為0 第2句的token_type_ids設為1 
        input['token_type_ids'] = torch.tensor([0] * (sep_idx + 1) + [1] * (len(sentence_tokens) - sep_idx - 1), dtype=torch.long)

        # 設定 attention_mask 就是讓BERT只注意句子即可 padding不要注意
        input['attention_mask'] = torch.tensor([1] * len(sentence_tokens),dtype=torch.long)

        if label_num:
            label_num = torch.tensor(label_num)

        return input, label_num

    def __len__(self):
        return len(self.data_df)

In [22]:
# 定義mini batch, mode 設成train/test/predict
def create_mini_batch(samples):
    # 處理batch input
    input_ids = []
    token_type_ids = []
    attention_mask = []
    labels = []

    for idx in range(len(samples)):
        input_ids.append(samples[idx][0]['input_ids'].squeeze(0))
        token_type_ids.append(samples[idx][0]['token_type_ids'].squeeze(0))
        attention_mask.append(samples[idx][0]['attention_mask'].squeeze(0))
        if samples[idx][1] != None:
          labels.append(torch.tensor(samples[idx][1]))


    # zero pad 到同一序列長度
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    token_type_ids = torch.nn.utils.rnn.pad_sequence(token_type_ids, batch_first=True, padding_value=0)
    attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)

    if len(labels) > 0:
        labels = torch.stack(labels)
        return input_ids, token_type_ids, attention_mask, labels
    else:
        return input_ids, token_type_ids, attention_mask

In [23]:
train_batch_size = 32
eval_batch_size = 512

train_df = df_train[df_train.data_type == 'train']
train_df.reset_index(inplace=True)
val_df = df_train[df_train.data_type == 'val']
val_df.reset_index(inplace=True)

dataset_train = NewsPairDataset(tokenizer, train_df)
dataset_val = NewsPairDataset(tokenizer, val_df)

In [24]:
val_df

Unnamed: 0,level_0,index,title1_zh,title2_zh,label,label_num,data_type
0,1,1,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",深圳GDP首超香港？深圳统计局辟谣：只是差距在缩小,unrelated,0,val
1,4,5,"""你不来深圳，早晚你儿子也要来""，不出10年深圳人均GDP将超香港",深圳GDP首超香港？统计局辟谣：未超但差距再度缩小,unrelated,0,val
2,11,18,"""飞机就要起飞，一个男人在机舱口跪下！""这是见过最催泪的一幕……",池州一彩民中30万大奖乐晕倒在地？谣言！但却很暖心...,unrelated,0,val
3,12,19,#健康过大年#还在逗孩子喝酒？儿童喝酒的危害多大你知道吗？,孩子当父母“小尾巴”危害大，父母教育有技巧，孩子健康快乐成长,unrelated,0,val
4,19,26,#吃秀##美好的一天从早餐开始#续集这会灯好多，上个一会黑一会白本来皮肤就黄,不花钱甩肉吃一黑，皮肤水嫩少9斤，润肠通便、排体内3年垃圾！,unrelated,0,val
...,...,...,...,...,...,...,...
1277,6390,7825,2018年新交规：喝酒没开车也算酒驾，同样扣分罚款！别想钻空,辟谣！新交规，酒后躺车内休息也算酒驾？答案是……,disagreed,2,val
1278,6399,8850,2018年起，这5类农民再有钱，也不能建房，很多人都“中招”了！,传言：国家将不允许农民建房！真实情况是这样的！,disagreed,2,val
1279,6402,9050,2018年驾驶证“替人销分”新规，驾驶证代扣分已经行不通了！,辟谣！3月1日驾驶证销分实行新规是误读！烟台司机快别去扎堆了,disagreed,2,val
1280,6403,9052,2018年驾驶证“替人销分”新规，驾驶证代扣分已经行不通了！,紧急辟谣！“2018销分新规”存误读，不必扎堆排队了！,disagreed,2,val


In [25]:
len(dataset_train)

5128

In [26]:
len(dataset_val)

1282

## 定義 fine-tune mode

In [27]:
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_dict),
    output_attentions=False,
    output_hidden_states=False
)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [28]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

train_loader = DataLoader(dataset_train, 
                          sampler=RandomSampler(dataset_train),
                          batch_size=train_batch_size, 
                          collate_fn=create_mini_batch)

valid_loader = DataLoader(dataset_val, 
                          sampler=RandomSampler(dataset_val),
                          batch_size=eval_batch_size, 
                          collate_fn=create_mini_batch)

In [29]:
len(train_loader)

161

## 定義優化器與調整learning rate

In [30]:
from transformers import AdamW, get_linear_schedule_with_warmup

optimizer = AdamW(
    model.parameters(),
    lr=2e-5,
    eps=1e-8
)

In [31]:
epochs = 5

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_loader)*epochs
)

## 設定accuracy, f1 score

In [32]:
from sklearn.metrics import f1_score

In [33]:
def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    
    return f1_score(labels_flat, preds_flat, average='weighted')

In [34]:
def accuracy_per_class(preds, labels):
    label_dict_inverse = {v: k for k, v in label_dict.items()}
    
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    
    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat==label]
        y_true = labels_flat[labels_flat==label]
        print(f'Class: {label_dict_inverse[label]}')
        print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\n')

In [35]:
import random

seed_val = 17
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

In [36]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

print(device)

cpu


In [37]:
def evaluate(dataloader_val):

    model.eval()
    
    loss_val_total = 0
    predictions, true_vals = [], []
    
    for batch in tqdm(dataloader_val):
        
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'token_type_ids': batch[1],
                  'attention_mask': batch[2],
                  'labels':         batch[3],
                 }

        with torch.no_grad():        
            outputs = model(**inputs)
            
        loss = outputs[0]
        logits = outputs[1]
        loss_val_total += loss.item()

        logits = logits.detach().cpu().numpy()
        label_ids = inputs['labels'].cpu().numpy()
        predictions.append(logits)
        true_vals.append(label_ids)
    
    loss_val_avg = loss_val_total/len(dataloader_val) 
    
    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)
            
    return loss_val_avg, predictions, true_vals

In [38]:
for epoch in tqdm(range(1, epochs+1)):
    
    model.train()
    
    loss_train_total = 0

    progress_bar = tqdm(train_loader, 
                        desc='Epoch {:1d}'.format(epoch),
                        leave=False,
                        disable=False)

    for batch in progress_bar:
        model.zero_grad()
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {
            'input_ids'       : batch[0],
            'token_type_ids'  : batch[1],
            'attention_mask'  : batch[2],
            'labels'          : batch[3]
        }
        
        outputs = model(**inputs)
        
        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()
        scheduler.step()
        
        progress_bar.set_postfix({
            'training_loss': '{:.3f}'.format(loss.item()/len(batch))
        })
        
    # torch.save(model.state_dict(), f'Models/BERT_ft_epoch{epoch}.model')
    
    tqdm.write('\nEpoch {epoch}')
    
    loss_train_avg = loss_train_total/len(train_loader)
    tqdm.write(f'Training loss: {loss_train_avg}')
    
    val_loss, predictions, true_vals = evaluate(valid_loader)
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Val loss: {val_loss}')
    tqdm.write(f'f1 score: {val_f1}')
    

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(HTML(value='Epoch 1'), FloatProgress(value=0.0, max=161.0), HTML(value='')))

  



Epoch {epoch}
Training loss: 0.5180314128628428


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))


Val loss: 0.321689635515213
f1 score: 0.8664102301212675


HBox(children=(HTML(value='Epoch 2'), FloatProgress(value=0.0, max=161.0), HTML(value='')))


Epoch {epoch}
Training loss: 0.27697803223540324


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))


Val loss: 0.29343581199645996
f1 score: 0.8892504953110667


HBox(children=(HTML(value='Epoch 3'), FloatProgress(value=0.0, max=161.0), HTML(value='')))


Epoch {epoch}
Training loss: 0.18886023508789745


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))


Val loss: 0.3004259367783864
f1 score: 0.8924422960113068


HBox(children=(HTML(value='Epoch 4'), FloatProgress(value=0.0, max=161.0), HTML(value='')))


Epoch {epoch}
Training loss: 0.12962135054194224


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))


Val loss: 0.2926155875126521
f1 score: 0.9047488962366376


HBox(children=(HTML(value='Epoch 5'), FloatProgress(value=0.0, max=161.0), HTML(value='')))


Epoch {epoch}
Training loss: 0.09334115973263053


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))


Val loss: 0.299183189868927
f1 score: 0.9043072825432508



In [47]:
df_test = df_test[:int(len(df_test.id)/50)]
df_test

Unnamed: 0,id,title1_zh,title2_zh
0,321187,萨拉赫人气爆棚!埃及总统大选未参选获百万选票 现任总统压力山大,辟谣！里昂官方否认费基尔加盟利物浦，难道是价格没谈拢？
1,321190,萨达姆被捕后告诫美国的一句话，发人深思,10大最让美国人相信的荒诞谣言，如蜥蜴人掌控着美国
2,321189,萨达姆此项计划没有此国破坏的话，美国还会对伊拉克发动战争吗,萨达姆被捕后告诫美国的一句话，发人深思
3,321193,萨达姆被捕后告诫美国的一句话，发人深思,被绞刑处死的萨达姆是替身？他的此男人举动击破替身谣言！
4,321191,萨达姆被捕后告诫美国的一句话，发人深思,中国川贝枇杷膏在美国受到热捧？纯属谣言！
...,...,...,...
1597,322791,蓝洁瑛亲口承认：曾遭曾志伟性侵，事后听了某人话不想报警,卓伟又爆新料：“曾志伟性侵蓝洁瑛”
1598,322790,蓝洁瑛亲口承认曾遭曾志伟性侵,网络疯传影片指出他强奸蓝洁瑛，曾志伟严正驳斥谣言
1599,322792,蓝洁瑛亲口承认：曾遭曾志伟性侵，事后听了某人话不想报警,终于开口了，春十三娘“蓝洁瑛”亲指当年“曾志伟”性侵！
1600,322793,蓝洁瑛亲口承认：曾遭曾志伟性侵，事后听了某人话不想报警,蓝洁瑛性侵案告破，比起曾志伟，这位大佬连周润发郑少秋都不敢惹


In [48]:
## for test data

test_dataset = NewsPairDataset(tokenizer, df_test)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=eval_batch_size,
    collate_fn=create_mini_batch)

with torch.no_grad():
    pred = []
    for data in tqdm(test_loader):
        input_ids, token_type_ids, attention_mask = [d.to(device) for d in data]

        outputs = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )
        indexes = outputs.logits.argmax(dim=-1).cpu().tolist()
        pred += [i for i in indexes]

df_result = df_test[['id']].copy()
df_result['pred'] = pred
df_result.to_csv('result.csv', index=None)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4.0), HTML(value='')))




In [62]:
# 處理solution.csv
df_ans = pd.read_csv('archive/solution.csv')
df_ans.rename(columns={'Id': 'id'}, inplace=True)
df_ans = df_ans[['id', 'Expected']].dropna(axis=0).reset_index(drop=True)
df_ans['Expected_num'] = df_ans.Expected.replace(label_dict)
df = df_ans.merge(df_result, on='id')
df = df[~df.pred.isna()]
df

Unnamed: 0,id,Expected,Expected_num,pred
0,322521,unrelated,0,0
1,322678,unrelated,0,0
2,322679,unrelated,0,0
3,322670,unrelated,0,0
4,322671,unrelated,0,0
...,...,...,...,...
1597,322013,unrelated,0,0
1598,322014,unrelated,0,1
1599,322015,agreed,1,0
1600,322016,agreed,1,1


In [63]:
test_acc = np.mean(df['Expected_num'] == df['pred'])
print(f'test accuarcy: {test_acc}')

test accuarcy: 0.7852684144818977


In [68]:
# predict pairs {'unrelated': 0, 'agreed': 1, 'disagreed': 2}

test1 = '今天天氣很好'
test2 = '今天天氣不錯'

data_df = pd.DataFrame({'title1_zh': [test1], 'title2_zh': [test2]})

data_news_pairs = NewsPairDataset(tokenizer, data_df)

with torch.no_grad():
    input_ids = []
    token_type_ids = []
    attention_mask = []

    input_ids.append(data_news_pairs[0][0]['input_ids'].squeeze(0))
    token_type_ids.append(data_news_pairs[0][0]['token_type_ids'].squeeze(0))
    attention_mask.append(data_news_pairs[0][0]['attention_mask'].squeeze(0))

    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    token_type_ids = torch.nn.utils.rnn.pad_sequence(token_type_ids, batch_first=True, padding_value=0)
    attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)

    outputs = model(input_ids=input_ids.to(device), token_type_ids=token_type_ids.to(device), attention_mask=attention_mask.to(device))
    indexes = outputs.logits.argmax(dim=-1).cpu()
    if indexes[0] == 0:
        print('unrelated')
    elif indexes[0] == 1:
        print('agreed')
    else:
        print('disagreed')

agreed
