In [1]:
import sys
sys.path.append('/source/main')

In [16]:
from itertools import chain

import numpy as np
import pandas as pd
# Show full cell contents: the mentions inspected below are long free-text strings.
# pandas >= 1.0 spells "no limit" as None (-1 is rejected from pandas 2.0 on);
# pandas < 1.0 only accepts ints there, so fall back to the old -1 spelling.
try:
    pd.set_option('display.max_colwidth', None)
except ValueError:
    pd.set_option('display.max_colwidth', -1)
import torch
from naruto_skills.new_voc import Voc
from tqdm import tqdm

from model_def.lstm_attention import LSTMAttention
from data_for_train import get_it_and_train
from data_for_train.positive_dataset import PositiveDataset
from data_for_train.topic_dataset import Topic
from preprocess import preprocessor

# Setup: load the trained model

In [3]:
# `device` must exist before `model.to(device)` runs; fall back to CPU so the
# notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = LSTMAttention(get_it_and_train.voc.get_embedding_weights())
model = model.to(device)

# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load('/source/main/train/output/saved_models/train_with_share/None.pt', map_location=device)
# strict=False tolerates key mismatches, so print the returned report to keep them visible
# (the recorded run reports only an unexpected 'xent.weight').
print(model.load_state_dict(checkpoint['model_state_dict'], strict=False))
model.eval()

IncompatibleKeys(missing_keys=[], unexpected_keys=['xent.weight'])


In [4]:

def docs2input_tensors(docs, device):
    """Convert raw documents into one tensor of word indices on `device`.

    Each doc goes through `preprocessor.infer_preprocess`. The cleaned docs are then
    mapped to indices with the shared vocabulary, passing the length of the longest
    cleaned doc (counted in whitespace-separated tokens) as `equal_length`.
    NOTE(review): presumably `docs2idx` pads shorter docs to that length -- confirm.

    Raises ValueError (from `max`) when `docs` is empty.
    """
    cleaned = [preprocessor.infer_preprocess(raw_doc) for raw_doc in docs]
    longest = max([len(text.split()) for text in cleaned])
    index_rows = get_it_and_train.voc.docs2idx(cleaned, equal_length=longest)
    return torch.from_numpy(np.array(index_rows)).to(device)

def predict_batch(docs):
    """Score one batch of documents with the global `model` and return column 1 of its output.

    Runs without gradient tracking, moves the result to the CPU and returns a NumPy
    array with one value per doc.
    NOTE(review): column 1 is presumably the positive-class score (later cells threshold
    it at 0.5) -- confirm against LSTMAttention's output layout.
    """
    with torch.no_grad():
        batch = docs2input_tensors(docs, device)
        scores = model(batch).cpu().numpy()
    return scores[:, 1]

def predict_docs(docs, batch_size):
    """Score `docs` in chunks of `batch_size` and return a flat list with one score per doc.

    `docs` only needs `len()` and slicing, so both a list and a pandas Series (as used in
    the cells below) work. A tqdm bar tracks the batches.
    """
    batch_scores = [
        predict_batch(docs[start:start + batch_size])
        for start in tqdm(range(0, len(docs), batch_size))
    ]
    return list(chain.from_iterable(batch_scores))

# Make sure the model is in evaluation mode before the scoring cells (eval() == train(False)).
model.train(False)

# Recall (share of `question_milk.csv` mentions predicted positive)

In [29]:
df_milk_price = pd.read_csv('/source/main/data_download/output/question_milk.csv')
df_milk_price['pred'] = predict_docs(df_milk_price['Mention'], batch_size=128)
# This cell treats every row as a positive example (TODO confirm), so recall is the share
# of rows predicted positive. Use the same `>= 0.5` cut-off as the precision cells so the
# two numbers are comparable. Let pandas average the boolean mask (vectorized; NaN rather
# than ZeroDivisionError on an empty file).
print((df_milk_price['pred'] >= 0.5).mean())

100%|██████████| 25/25 [00:01<00:00, 24.19it/s]

0.896887159533074





In [28]:
# Load the vocabulary and the positive-class examples so they can be previewed below.
# NOTE(review): `voc` is not used by any visible cell, and the .pkl extension suggests a
# pickle -- only load trusted files.
voc = Voc.load('/source/main/vocab/output/voc.pkl')
positive_data = PositiveDataset('/source/main/data_for_train/output/positive_class_1.csv')
df_pos = pd.DataFrame(data={'mention': [mention for mention in positive_data]})

In [None]:
df_milk_price.head(5)

In [None]:
df_pos.head(n=5)

# Precision (score topic 6877 and manually inspect the mentions predicted positive)

In [20]:
# Mentions for topic 6877, keeping only rows that actually have mention text.
df = pd.read_csv('/source/main/data_for_train/output/huge_pool/topics/6877.csv').dropna(subset=['mention'])
print(len(df))

107805


In [21]:
df['pred'] = predict_docs(batch_size=128, docs=df['mention'])

100%|██████████| 843/843 [00:52<00:00, 18.07it/s]


In [22]:
df.loc[df['pred'] >= 0.5].shape

(476, 16)

In [25]:
# Random sample of the mentions predicted positive, for manual precision review.
# random_state pins the sample so reruns show the same rows; n is capped so the cell
# still works when fewer than 100 mentions clear the threshold.
df_pred_pos = df.loc[df['pred'] >= 0.5, ['mention', 'topic_id', 'id']]
df_pred_pos.sample(n=min(100, len(df_pred_pos)), random_state=42)

Unnamed: 0,mention,topic_id,id
78445,giá này đắt hơn giá m mua là sao ?,6877,27015e12-7858-5dd6-b80a-d71a0567eba1
49533,giá sao nàng u,6877,d656c761-a0f5-5574-a881-39109d3cf093
78494,danh sách quay thuởg ngay __d__ chua có vay ad,6877,05561929-e183-5ab8-adda-60c49a43daed
41369,gia ban bao nhieu __d__ hop,6877,c5d95db9-c4bb-570b-8b47-450aa13c71c7
74396,"< div class= '' bbcodeblock bbcodequote '' > < div class= '' attribution type '' > < i class= '' fa fa-quote-left '' > < /i > thành trần __d__ nói : < a class= '' attributionlink '' > ↑ < /a > < /div > < blockquote class= '' quotecontainer '' > < div class= '' quote '' > fpt đem so với ibm , oracle ! < img src= '' http : //otofun.netstyles/yahoo/__d__.gif '' class= '' mcesmilie '' alt= '' = ) ) '' title= '' rolling on the floor = ) ) '' > < /div >",6877,46491604-9dd1-5ebf-97c2-5ef13c03bd6c
83035,"nhà em cũng thích uống nescafe nhất , một hộp này giá bao nhiêu đấy c",6877,25c8ae85-c013-572f-9636-75af34c0d1f7
55346,c ơi áo hoa trong video là c mặc size bao nhiêu vậy ạ ?,6877,51e5f421-df59-5d5b-9132-a2ce99658258
84593,cho mình hỏi __d__ sđt & __d__ số cmnd dc nt nhiều mã thẻ cào dc ko ad ? nhưng nếu trúng giải kk thì dc nhiều giải . còn giải đb thì chỉ __d__ . có đúng ko ad ?,6877,74b9c24f-12f9-5719-88af-69ff9971f3fa
37939,loai den bao nhieu,6877,2c150232-fe39-5e03-86c0-ae28a1f62049
78727,bao nhieu __d__ hộp vay j,6877,863cd19a-656f-5c51-b827-c3885adfebfb


In [26]:
# Save every mention predicted positive (not just the displayed sample), without the row index.
df.loc[df['pred'] >= 0.5, ['mention', 'topic_id', 'id']].to_csv('pred_topic_6877.csv', index=False)