In [1]:
import sys
sys.path.append('..')

import torch
import pandas as pd
import numpy as np
from torch import nn 
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import Dataset
from torchcrf import CRF
from tqdm.notebook import tqdm_notebook
import logging

from utils import *
from dataset import *
from preprocess import *
from wrapper import *
from models import *
from pipeline import POSTaggingPipeline, map_to_df

torch.cuda.is_available()
# device = torch.device('cpu')
device = torch.device('cuda:0')

In [2]:
df = pd.read_csv('../data/data-org/train.csv', sep='\t').set_index('id')
corpus = df[df.label == 0].drop(columns=['label'])

In [3]:
model_name = "KoichiYasuoka/chinese-bert-wwm-ext-upos"

tagger = POSTaggingPipeline(model_name=model_name)
ds = tagger(texts=corpus, device=device, return_tags=False)

  indexed_value = torch.tensor(value[index]).squeeze()
100%|██████████| 718/718 [01:17<00:00,  9.23it/s]


In [4]:
ds.test = False
ds.train_val_split = 0.8
ds.construct_dataset()

In [5]:
id2label = tagger.model.config.id2label
num_tags = len(id2label)

crf = CRF(num_tags=num_tags, batch_first=True)
if 'cuda' in device.type:
    crf.cuda()

In [49]:
seq_len = ds.maxlength
batch_size = 128

def get_tagging_datasets(ds):
    try:
        return (
            ds.dataset['train'].with_format('pytorch', columns=['emissions', 'attention_mask']).rename_columns({'attention_mask':'mask'}),
            ds.dataset['val'].with_format('pytorch', columns=['emissions', 'attention_mask']).rename_columns({'attention_mask':'mask'}),
        )
    except:
        print('Warning: No dev set.')
        return ds.dataset['train'].with_format('pytorch', columns=['emissions', 'attention_mask']).rename_columns({'attention_mask':'mask'})

def process_batch(batch, device):
    inputs = {}
    if 'input_ids' in batch.keys():
        inputs['input_ids'] = torch.from_numpy(np.array(batch['input_ids'])).to(dtype=torch.int, device=device)
    inputs['mask'] = batch['mask'].to(device=device, dtype=torch.bool)
    inputs['emissions'] = torch.concat([torch.concat([x]).unsqueeze(1) for x in batch['emissions']], dim=1).to(device=device).long()
    inputs['tags'] = inputs['emissions'].argmax(-1).to(device=device, dtype=torch.int).long()
    return inputs

In [None]:
    
train_set, dev_set = get_tagging_datasets(ds)
train_dataloader = DataLoader(
    train_set, 
    batch_size=batch_size, 
    drop_last=True, 
)
dev_dataloader = DataLoader(
    dev_set, 
    batch_size=batch_size, 
    drop_last=True, 
)

In [15]:
from collections import OrderedDict

state_dict = OrderedDict([
    ('start_transitions', torch.load('crf_parameters_4723434/start_transitions.pt')), 
    ('end_transitions', torch.load('crf_parameters_4723434/end_transitions.pt')), 
    ('transitions', torch.load('crf_parameters_4723434/transitions.pt')), 
])

crf.load_state_dict(state_dict)

<All keys matched successfully>

In [87]:
sample_df = df.copy(deep=True)

In [88]:
model_name = "KoichiYasuoka/chinese-bert-wwm-ext-upos"

tagger = POSTaggingPipeline(model_name=model_name)
ds = tagger(texts=sample_df.text, device=device, return_tags=False)

id2label = tagger.model.config.id2label
num_tags = len(id2label)

crf = CRF(num_tags=num_tags, batch_first=True)
if 'cuda' in device.type:
    crf.cuda()

seq_len = ds.maxlength
batch_size_inference = 1

dataset = get_tagging_datasets(ds)
dataloader = DataLoader(
    dataset, 
    batch_size=batch_size_inference, 
    drop_last=True, 
)

logliks = []
for i, batch in enumerate(tqdm_notebook(dataloader)):
    inputs = process_batch(batch, device=device)
    loglik = crf(**inputs)
    logliks.append(loglik.detach().cpu().numpy())

  indexed_value = torch.tensor(value[index]).squeeze()
100%|██████████| 2828/2828 [05:06<00:00,  9.22it/s]




  0%|          | 0/45248 [00:00<?, ?it/s]

In [83]:
sample_df['llh'] = np.exp(np.array(logliks))
sample_df = sample_df[['label', 'llh', 'text']]
with pd.option_context('display.max_rows', None, 'display.max_columns', None, ):
    pd.options.display.max_colwidth = 100
    display(sample_df.sort_values(by='llh'))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  sample_df['llh'] = np.exp(np.array(logliks))


Unnamed: 0_level_0,label,llh,text
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
35786,0,2.1e-05,在近些年中，我们对类似过程都已经不陌生，如小燕子热、金庸热、大辫子热、警察热、大话西游热、哈里*波特热，直至最近的《流星花园》热......包括美女作家、网络文学、雪村的走红乃至唐装热等；而且...
36281,0,0.000127,《千年之约·梦幻龟兹》由浙江援疆资金扶持打造，整台演出以塔里木歌舞团为班底，融合龟兹乐舞、胡旋舞等多种绚丽多姿的舞蹈，重现了我国汉唐时期阿克苏地区所在的龟兹、姑墨、温宿等地的辉煌盛景。
34564,0,0.000654,日前，瑞士瑞信银行发布了一份报告，指出中国成年人人均财富值，在过去十年中从2000年的6000美元增至2010年的18000美元。这一财富数据是否真正反映我国群众的实际财富水平，是一个需要探讨...
30882,1,0.000672,观赏自然风景也是如此，苏东坡有诗云“不识庐山真面目，只缘身在此山中”，就是讲的距离太近而又无法欣赏庐山的自然美这样一种情形。
3855,1,0.00128,瑞士某激进组织提交的一份旨在以实现居民收入平等为目的的修改宪法提案引起了社会的强烈反响，很多政府官员对此提出质疑。
34910,1,0.00144,看涨的共享单车市场和摩拜ofo的双雄争霸并不意味着毫无风险。无论在哪个城市，共享单车均面临着停车难、停车乱，如果没有合理的疏导，很可能“解决最后一公里出行”变成了“阻挡最后五十米交通”。
38102,0,0.001678,我国古代城池的北门常常被称为玄武门，我想这可能是凶为秦始皇统一天下以后，历朝的威胁主要来自北方，所以统治者有意用张牙舞爪的龟蛇状的玄武来威慑外敌。
27372,1,0.001741,《神犬小七第二季》将励志正能量与都市偶像情感融合，加之动物演员的增多，既表现了现代都市剧的“新鲜感”，也将呈现一个“人犬大同”的正能量社会。
11042,1,0.00175,党风廉政建设责任能不能担当起来，关键在于主体责任这个“牛鼻子”抓没抓住，然而各地不同程度地存在管党、治党失之于宽，监督责任落实不到位。
5324,0,0.002043,2008年8月1日，我国《反垄断法》正式实施，“利用市场支配地位”或者“滥用行政权力”来限制竞争的行为将从此被视为违法。


In [85]:
sample_df[sample_df.label == 0].llh.mean(), sample_df[sample_df.label == 1].llh.mean()

(0.03242170438170433, 0.04105721414089203)