In [2]:
# --- Dataset files (Kaggle input mount) -------------------------------------
# Each split is a text file read line by line by `Dataset`.
TRAIN_SAMPLE_PATH = '/kaggle/input/scnuai-dataset/SCNUAI-dataset-v2/v2/train.txt'
DEV_SAMPLE_PATH = '/kaggle/input/scnuai-dataset/SCNUAI-dataset-v2/v2/dev.txt'
TEST_SAMPLE_PATH = '/kaggle/input/scnuai-dataset/SCNUAI-dataset-v2/v2/test.txt'

# Class-name file; read below into CLASS_LABELS and again by `get_label()`.
LABEL_PATH = '/kaggle/input/scnuai-dataset/SCNUAI-dataset-v2/v2/class.txt'

# --- Model locations ----------------------------------------------------------
# Local directory passed to `from_pretrained` for both tokenizer and encoder.
BERT_MODEL = '/kaggle/working/bert-base-chinese'
MODEL_DIR = '/kaggle/working/'  # not referenced in the visible cells
# Despite the *_DIR suffix, the next two constants are file paths:
# the saved model handed to `torch.load`, and the CSV handed to `pd.read_csv`.
MODEL_TRAIN_DIR = '/kaggle/input/scnuai-models/model-v3-31.pth'
DEV_PRED_DIR = '/kaggle/input/scnuai-dataset/SCNUAI-dataset-v2/v2/dev_predict.csv'
NUM_CLASSES = 34  # output width of the classifier's final linear layer

# --- Preprocessing / architecture hyper-parameters ---------------------------
BERT_PAD_ID = 0           # token id appended when padding input_ids
TEXT_LEN = 512            # every sample is padded / cut to this many tokens
EMBEDDING_DIM = 768       # width of each convolution kernel in TextCNN
NUM_FILTERS = 512         # output channels of each convolution
FILTER_SIZES = [2, 3, 4]  # kernel heights, one convolution per entry
BATCH_SIZE = 16

# Not referenced in the visible cells; presumably training settings.
EPOCH = 100
LR = 2e-6

# Class names in file order. The file is read explicitly as UTF-8 (the labels
# are Chinese, and `get_label()` reads the same file as UTF-8), so the result
# does not depend on the platform's default encoding. Blank lines are skipped
# so a trailing empty line cannot become an empty label.
CLASS_LABELS = []
with open(LABEL_PATH, 'r', encoding='utf-8') as f:
    CLASS_LABELS += [label for label in (line.strip() for line in f) if label]

import torch

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

if __name__ == '__main__':
    # Sanity check: move a tiny tensor to the selected device and show it.
    probe = torch.tensor([1, 2, 3])
    print(probe.to(DEVICE))

tensor([1, 2, 3], device='cuda:0')


In [3]:
# Download the pretrained `bert-base-chinese` repository from the Hugging Face
# Hub into BERT_MODEL, the directory later passed to `from_pretrained`.
import os

# Guarded so that re-running the cell does not fail with
# "destination path already exists".
if not os.path.isdir(BERT_MODEL):
    # --skip-repo: only install the global LFS filters; the working directory
    # is not a git repository, so per-repo hook setup would print an error.
    !git lfs install --skip-repo
    # Clone to the explicit target so the result does not depend on the
    # notebook's current working directory.
    !git clone https://huggingface.co/bert-base-chinese {BERT_MODEL}

Error: Failed to call git rev-parse --git-dir: exit status 128 
Git LFS initialized.
Cloning into 'bert-base-chinese'...
remote: Enumerating objects: 52, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 52 (delta 0), reused 0 (delta 0), pack-reused 49[K
Unpacking objects: 100% (52/52), 158.38 KiB | 4.40 MiB/s, done.
Filtering content: 100% (4/4), 1.59 GiB | 99.57 MiB/s, done.


In [4]:
from torch.utils import data
import torch
from transformers import BertTokenizer
from sklearn.metrics import classification_report

from transformers import logging
logging.set_verbosity_error()

class Dataset(data.Dataset):
    """Line-based classification dataset tokenized for BERT.

    Every non-blank line of the split file is expected to look like
    ``<text>\\t<integer label>``. Items are returned as a triple of tensors
    ``(input_ids, attention_mask, label)`` where the first two have exactly
    ``TEXT_LEN`` entries.
    """

    def __init__(self, type='train'):
        """Load the lines of one split and create the tokenizer.

        Args:
            type: Which split to read: ``'train'``, ``'dev'`` or ``'test'``.

        Raises:
            ValueError: If ``type`` is not one of the known split names.
        """
        super().__init__()
        split_paths = {
            'train': TRAIN_SAMPLE_PATH,
            'dev': DEV_SAMPLE_PATH,
            'test': TEST_SAMPLE_PATH,
        }
        if type not in split_paths:
            raise ValueError(
                f"unknown dataset type {type!r}; expected one of {sorted(split_paths)}"
            )

        # Context manager closes the file; blank lines are dropped because they
        # cannot be split into text and label.
        with open(split_paths[type], encoding='utf-8') as sample_file:
            self.lines = [line for line in sample_file if line.strip()]
        self.tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.lines)

    def __getitem__(self, index):
        """Tokenize sample ``index`` and pad/cut it to ``TEXT_LEN`` tokens.

        Returns:
            Tuple ``(input_ids, attention_mask, label)`` of ``torch.Tensor``.
        """
        # The label is the last tab-separated field; splitting from the right
        # keeps texts that themselves contain tab characters intact.
        text, label = self.lines[index].rsplit('\t', 1)
        tokened = self.tokenizer(text)

        # NOTE(review): cutting after tokenization may drop the tokenizer's
        # trailing special token for long texts. Kept as-is so the inputs stay
        # identical to the original preprocessing.
        input_ids = tokened['input_ids'][:TEXT_LEN]
        mask = tokened['attention_mask'][:TEXT_LEN]

        # pad_len is 0 for texts that already fill the window.
        pad_len = TEXT_LEN - len(input_ids)
        input_ids = input_ids + [BERT_PAD_ID] * pad_len
        mask = mask + [0] * pad_len

        target = int(label)  # int() tolerates the trailing newline
        return torch.tensor(input_ids), torch.tensor(mask), torch.tensor(target)


def get_label():
    """Read the class names from ``LABEL_PATH``.

    Returns:
        Tuple ``(id2label, label2id)``: the whitespace-separated names from the
        file in order, and a dict mapping each name to its position.
    """
    # Context manager so the file handle is closed deterministically.
    with open(LABEL_PATH, encoding='utf-8') as label_file:
        id2label = label_file.read().split()
    label2id = {name: idx for idx, name in enumerate(id2label)}
    return id2label, label2id


def evaluate(pred, true, target_names=None, output_dict=False):
    """Build a classification report for predictions against ground truth.

    Note the argument order: predictions come first here, but are forwarded
    to ``classification_report`` as ``(true, pred)``.

    Args:
        pred: Predicted class ids.
        true: Ground-truth class ids.
        target_names: Optional display names forwarded to the report.
        output_dict: Forwarded to the report unchanged.

    Returns:
        Whatever ``classification_report`` returns for these arguments.
    """
    report_options = {
        'target_names': target_names,
        'output_dict': output_dict,
        'zero_division': 0,
    }
    return classification_report(true, pred, **report_options)

if __name__ == '__main__':
    # Smoke test: load the default (train) split and show its first batch of two.
    dataset = Dataset()
    loader = data.DataLoader(dataset, batch_size=2)
    first_batch = next(iter(loader))
    print(first_batch)



[tensor([[ 101,  122,  119,  ...,    0,    0,    0],
        [ 101, 3724, 8038,  ...,  711, 1398, 3333]]), tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]]), tensor([18, 22])]


In [5]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import BertModel

class TextCNN(nn.Module):
    """Frozen BERT encoder followed by a TextCNN classification head.

    The encoder's first output is treated as a one-channel image, convolved
    with one kernel height per entry of ``FILTER_SIZES``, max-pooled over the
    whole feature map, concatenated and projected to ``NUM_CLASSES`` logits.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL)
        # Freeze the encoder: only the convolutions and the linear layer train.
        for param in self.bert.parameters():
            param.requires_grad = False
        # Each kernel spans the full EMBEDDING_DIM width.
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, NUM_FILTERS, (size, EMBEDDING_DIM)) for size in FILTER_SIZES]
        )
        # One pooled NUM_FILTERS vector per convolution is concatenated.
        self.linear = nn.Linear(NUM_FILTERS * len(FILTER_SIZES), NUM_CLASSES)

    def conv_and_pool(self, conv, input):
        """Apply ``conv`` + ReLU and global-max-pool the result.

        Args:
            conv: Convolution module to apply.
            input: 4-D tensor accepted by ``conv``.

        Returns:
            Tensor of shape ``(batch, channels)``.
        """
        out = F.relu(conv(input))
        pooled = F.max_pool2d(out, (out.shape[2], out.shape[3]))
        # flatten(1) removes only the two pooled singleton dims. A bare
        # squeeze() would also drop the batch dim when batch size is 1 and
        # make the dim=1 concatenation in forward() fail.
        return pooled.flatten(1)

    def forward(self, input, mask):
        """Return class logits for token ids ``input`` and attention ``mask``."""
        # unsqueeze(1) adds the single input channel expected by Conv2d.
        out = self.bert(input, mask)[0].unsqueeze(1)
        out = torch.cat([self.conv_and_pool(conv, out) for conv in self.convs], dim=1)
        return self.linear(out)


if __name__ == '__main__':
    # Smoke test: push two random token sequences through a fresh model and
    # show the shape of the resulting logits.
    model = TextCNN()
    input = torch.randint(low=0, high=3000, size=(2, TEXT_LEN))
    mask = torch.ones_like(input)
    logits = model(input, mask)
    print(logits.shape)


torch.Size([2, 34])


In [6]:
# test.py
# Evaluate the saved model on the test split and print a per-class report.
if __name__ == '__main__':

    id2label, _ = get_label()

    test_dataset = Dataset('test')
    test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # NOTE: torch.load unpickles the whole model object, which can execute
    # arbitrary code; only load checkpoint files from a trusted source.
    model = torch.load(MODEL_TRAIN_DIR, map_location=DEVICE)
    # Inference mode for layers that behave differently in training (e.g. dropout).
    model.eval()
    loss_fn = nn.CrossEntropyLoss()

    y_pred = []
    y_true = []

    with torch.no_grad():
        for b, (input, mask, target) in enumerate(test_loader):
            input = input.to(DEVICE)
            mask = mask.to(DEVICE)
            target = target.to(DEVICE)

            test_pred = model(input, mask)

            # Collect predictions for EVERY batch. Previously the `continue`
            # used to throttle logging also skipped this step, so the report
            # covered only every 50th batch.
            y_pred += torch.argmax(test_pred, dim=1).tolist()
            y_true += target.tolist()

            # Log the loss of every 50th batch only.
            if b % 50 == 0:
                loss = loss_fn(test_pred, target)
                print('>> batch:', b, 'loss:', round(loss.item(), 5))

    print(evaluate(y_pred, y_true, id2label))

>> batch: 0 loss: 0.10251
>> batch: 50 loss: 0.10044
>> batch: 100 loss: 0.59514
>> batch: 150 loss: 0.16634
>> batch: 200 loss: 0.48439
>> batch: 250 loss: 0.1812
>> batch: 300 loss: 0.54196
>> batch: 350 loss: 0.48736
>> batch: 400 loss: 0.28296
>> batch: 450 loss: 0.56988
>> batch: 500 loss: 0.58866
>> batch: 550 loss: 0.19513
>> batch: 600 loss: 0.43504
>> batch: 650 loss: 0.3032
>> batch: 700 loss: 0.86012
>> batch: 750 loss: 0.57758
>> batch: 800 loss: 0.1243
>> batch: 850 loss: 0.9034
>> batch: 900 loss: 0.24787
>> batch: 950 loss: 0.43722
>> batch: 1000 loss: 0.51042
>> batch: 1050 loss: 0.33841
>> batch: 1100 loss: 0.35089
>> batch: 1150 loss: 1.18072
>> batch: 1200 loss: 0.57395
>> batch: 1250 loss: 0.02866
>> batch: 1300 loss: 0.48989
>> batch: 1350 loss: 0.67886
>> batch: 1400 loss: 0.10679
>> batch: 1450 loss: 0.72558
>> batch: 1500 loss: 0.43295
>> batch: 1550 loss: 0.6534
>> batch: 1600 loss: 0.24394
>> batch: 1650 loss: 0.70025
>> batch: 1700 loss: 0.3475
>> batch: 1750

In [7]:
import pandas as pd
import numpy as np
# Load the CSV of dev queries to classify; later cells read its `q_id` and
# `query` columns.
df2 = pd.read_csv(filepath_or_buffer=DEV_PRED_DIR)
df2

Unnamed: 0,q_id,query
0,10681,原告李娟向本院提出诉讼请求：1、判令被告赔偿原告1000元；2、诉讼费由被告承担。事实及理由...
1,10682,在庭审过程中，原告诉称，2015年9月1日，原告与被告韩世中签订租赁合同，被告韩世中租用原告...
2,10683,经审查，原告通辽市国有资本投资运营有限公司在起诉状中载明的被告郑宝金的身份证号为×××，农业...
3,10684,原告白宜瑛向本院提出诉讼请求：1.由被告支付原告劳动报酬1000元，并按年利率6%计算，支付...
4,10685,半岛公司向本院提出诉讼请求：1.判令半岛公司无须向蔡燕军支付拖欠的剩余工资15000元；2、...
...,...,...
4495,15176,原告张某1向本院提出诉讼请求：要求判令三被告立即将原告送至广州市天河地区养老机构办理入住手续...
4496,15177,原告赵阳贸易（上海）有限公司与被告刘时利、山东巨野双运汽车运输有限公司、石洪振、崔玉魁、台前...
4497,15178,彭世灵向本院提出诉讼请求：判令顺昌县林东和食杂店偿还货款6226元及利息（从2018年9月3...
4498,15179,原告代万华诉称：2013年8月被告万洪华租用原告一台“现代-225”型挖掘机用于石棉县田湾乡...


In [8]:
# predict.py
from tqdm import trange
if __name__ == '__main__':
    # Classify every `query` in df2 in batches of BATCH_SIZE.
    # res1 collects the raw logits and res2 the predicted class names,
    # each as one list per batch.
    id2label, _ = get_label()

    # NOTE: torch.load unpickles the whole model object, which can execute
    # arbitrary code; only load checkpoint files from a trusted source.
    model = torch.load(MODEL_TRAIN_DIR, map_location=DEVICE)
    # Inference mode for layers that behave differently in training (e.g. dropout).
    model.eval()
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

    texts_total = df2['query'].tolist()
    res1 = []
    res2 = []
    # No gradients are needed for prediction; skipping them avoids building
    # the autograd graph for every batch.
    with torch.no_grad():
        for k in trange(0, len(texts_total), BATCH_SIZE):
            # Slicing past the end simply yields the shorter final batch.
            texts = texts_total[k:k + BATCH_SIZE]
            batch_input_ids = []
            batch_mask = []
            for text in texts:
                # Same pad/cut scheme as Dataset.__getitem__.
                tokened = tokenizer(text)
                input_ids = tokened['input_ids'][:TEXT_LEN]
                mask = tokened['attention_mask'][:TEXT_LEN]
                pad_len = TEXT_LEN - len(input_ids)
                batch_input_ids.append(input_ids + [BERT_PAD_ID] * pad_len)
                batch_mask.append(mask + [0] * pad_len)

            batch_input_ids = torch.tensor(batch_input_ids).to(DEVICE)
            batch_mask = torch.tensor(batch_mask).to(DEVICE)
            pred = model(batch_input_ids, batch_mask)
            pred_ = torch.argmax(pred, dim=1)

            res1.append(pred.tolist())
            # One tolist() call instead of indexing with per-element tensors.
            res2.append([id2label[l] for l in pred_.tolist()])

100%|██████████| 282/282 [01:39<00:00,  2.84it/s]


In [9]:
# Flatten the per-batch result lists into one row per query.
# Iterating over the batches directly works for any mix of batch sizes
# (including no batches at all) and replaces the separate handling of
# "all full batches" and "last batch".
res_arr1 = np.array(
    [row for batch in res1 for row in batch], dtype=float
).reshape(-1, NUM_CLASSES)  # (n_queries, NUM_CLASSES) raw logits
res_arr2 = np.array(
    [label for batch in res2 for label in batch]
).reshape(-1, 1)  # (n_queries, 1) predicted class names
res_arr1

array([[ -7.94729662,  -8.15072727,  -9.00383854, ...,  -4.58894825,
         -7.60817862,  -4.11993408],
       [ -7.05258989,  -4.8899622 ,  -9.03686428, ...,  -5.78777075,
        -17.22094536,  -7.80389071],
       [ -2.3674953 ,  -3.44783807,  -2.1272037 , ...,  -2.02328563,
         -4.69382524,  -3.38607168],
       ...,
       [ -7.04855394,  -6.6351409 ,  -5.64733315, ...,  -4.64816618,
        -11.91704655,  -6.10369396],
       [ -7.01562548,  -4.93205833,  -9.24240208, ...,  -1.90189576,
        -15.45891666,  -6.38709068],
       [ -3.84646153,  -6.66759682,  -3.00454998, ...,  -3.0977633 ,
         -9.29241371,  -3.43854809]])

In [11]:
# Assemble the result table side by side: query id, predicted class name,
# then one logit column per class.
predicted_label_frame = pd.DataFrame(res_arr2)
logit_frame = pd.DataFrame(res_arr1)
df_res = pd.concat([df2['q_id'], predicted_label_frame, logit_frame], axis=1)
df_res.columns = ['q_id', 'cat_1_pred', *CLASS_LABELS]
df_res

Unnamed: 0,q_id,cat_1_pred,交通事故-交通责任,公司事务-企业经营,公司事务-公司治理,公司事务-合伙企业,劳动人事-劳动争议,劳动人事-劳动合同,劳动人事-工资福利,合同事务-买卖合同,...,房地产纠纷-房产纠纷,房地产纠纷-房屋买卖,房地产纠纷-物业纠纷,民事纠纷-产品责任,民事纠纷-人格隐私,民事纠纷-人身侵权,民事纠纷-侵权纠纷,民事纠纷-物权纠纷,知识产权-版权软著,金融证券保险-保险纠纷
0,10681,民事纠纷-产品责任,-7.947297,-8.150727,-9.003839,-8.571435,-8.665410,-7.430333,-7.323967,0.239329,...,-9.914097,-5.990611,-5.505841,9.048461,-5.034771,-8.584026,-1.919343,-4.588948,-7.608179,-4.119934
1,10682,合同事务-租赁合同,-7.052590,-4.889962,-9.036864,-5.268051,-8.001605,-6.810842,-6.191727,-5.370970,...,-8.125882,-8.153060,-6.985040,-11.074688,-8.491351,-9.025531,-7.059803,-5.787771,-17.220945,-7.803891
2,10683,合同事务-借款合同,-2.367495,-3.447838,-2.127204,-2.775053,-1.738711,-0.840977,-2.201243,-4.277031,...,-6.262496,-7.048238,-5.485820,-4.381687,-2.775426,-1.655304,-2.124159,-2.023286,-4.693825,-3.386072
3,10684,劳动人事-工资福利,-9.222224,-5.491253,-7.930977,-4.388585,-5.293586,-0.948451,4.612601,-6.626989,...,-8.583509,-8.496793,-7.508375,-10.875101,-7.180912,-4.976589,-3.991175,-4.373614,-14.789743,-9.181959
4,10685,劳动人事-劳动争议,-8.710951,-5.144566,-6.195058,-5.298028,5.277725,5.236399,2.535692,-8.125552,...,-8.208947,-7.628975,-10.236546,-10.317654,-5.920093,-4.352610,-2.808634,-4.946956,-10.591956,-6.981973
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4495,15176,婚姻家庭-老人赡养,-6.990736,-11.639692,-10.618101,-8.893085,-4.872028,-6.438754,-8.683522,-9.330384,...,-6.837704,-10.481118,-8.825538,-9.678265,-6.907429,-5.790136,-5.020528,-3.111784,-13.886804,-6.241533
4496,15177,合同事务-运输合同,0.565251,-3.857473,-2.943740,-4.025819,-5.845245,-5.804104,-7.335069,-3.849895,...,-7.662855,-8.937210,-7.756793,-2.132068,-2.837081,-1.317712,-1.179435,-0.067881,-3.866153,-4.365946
4497,15178,合同事务-买卖合同,-7.048554,-6.635141,-5.647333,-5.084998,-6.936284,-8.205406,-7.163452,0.909135,...,-8.656397,-5.601064,-7.310953,-5.972714,-9.587210,-10.112967,-4.750534,-4.648166,-11.917047,-6.103694
4498,15179,合同事务-租赁合同,-7.015625,-4.932058,-9.242402,-3.776425,-8.350371,-6.560231,-2.682569,-1.749906,...,-6.289149,-7.169786,-6.642865,-8.997999,-8.176384,-8.894166,-4.049396,-1.901896,-15.458917,-6.387091


In [13]:
# Write the result table to an Excel file (without the index column) and
# show a link to it in the notebook output.
excel_path = './dev_pred_res.xlsx'
df_res.to_excel(excel_path, index=None)

from IPython.display import FileLink
FileLink(excel_path)