In [12]:
#Load packages
import numpy as np
import pandas as pd
import random
import torch
import json
from torch import nn
import torch.nn.functional as F
import gluonnlp as nlp
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

device = torch.device("cpu")

In [13]:
with open('labtoinfo.json') as f:
    labtoinfo = json.load(f)
with open('codetolab.json') as f:
    codetolab = json.load(f)

In [14]:
with open('conv.txt') as f:
    lines = f.readlines()
conv = []
for i in lines:
    conv.append(i.strip())

In [15]:
convlab = []
for i in conv:
    convlab.append(str(codetolab[i]))

In [16]:
labpool = []
for i in labtoinfo:
    if i not in convlab:
        labpool.append(i)

In [6]:
bertmodel, vocab = get_pytorch_kobert_model()
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)
indexer = nlp.data.BERTSentenceTransform(tok, max_seq_length=32, pad=True, pair=False)

using cached model
using cached model
using cached model


In [7]:
class BERTClassification(nn.Module):
    def __init__(self, bert, hidden_size = 768, dr_rate = .5):
        super(BERTClassification, self).__init__()
        self.bert = bert
        self.d = nn.Dropout(p = dr_rate)
        self.linear = nn.Linear(hidden_size, 531)
        
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        out = self.d(pooler)
        out = self.linear(out)
        return out

In [8]:
model = BERTClassification(bertmodel).to(device)
model.load_state_dict(torch.load('COVID_Class_BERT4.pth',map_location=device))
model.eval()

BERTClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(8002, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=

In [18]:
def covidbertwans(inp):
    token_ids, valid_length, segment_ids = indexer([inp])
    token_ids = torch.tensor(token_ids).view(1,len(token_ids)).long().to(device)
    segment_ids = torch.tensor(segment_ids).long().to(device)
    valid_length= torch.tensor(valid_length).view(-1)
    result = model(token_ids, valid_length, segment_ids).view(-1).detach().numpy()
    pred = np.argmax(result)
    top10pred = np.argsort(result)[::-1][:10]
    score = F.softmax(torch.tensor(result),dim=0).detach().numpy()\
    [np.argsort(F.softmax(torch.tensor(result),dim=0)).detach().numpy()[::-1][:10]]
    
    #### 답변을 찾지 못할경우 표시
    if score[0] < .5:
        print('죄송합니다ㅠ 질문에 대한 답변을 찾지 못했어요.')
        print(' ')
        print('혹시 이런게 궁금하신가요?:')
        count = 0
        j=1
        i=1
        while count <= 4:
            if str(top10pred[i]) in convlab:
                i+=1
            else :
                print(j,'. ',labtoinfo[str(top10pred[i])]['ques']) #, ' ', score[i-1]
                i+=1
                j+=1
                count+=1
    else :
        #### 답변 표시
        top1 = top10pred[0]
        
        if str(top1) in convlab:
            print('질문: ',labtoinfo[str(top1)]['ques'], ' ', score[0])
            print(' ')
            print('답변:')
            print(labtoinfo[str(top1)]['txtans'])
        else:
            print('질문: ',labtoinfo[str(top1)]['ques'], ' ', round(score[0],4)) #, ' ', score[0]
            print(' ')
            print('답변:')
            if str(labtoinfo[str(top1)]['txtans'])!='nan':
                print(labtoinfo[str(top1)]['txtans'])
            if str(labtoinfo[str(top1)]['imgans'])!='nan':
                direc = str(labtoinfo[str(top1)]['imgans'])
                img = mpimg.imread(direc,0)
                plt.figure(figsize=(10,10))
                imgplot = plt.imshow(img)
                plt.show()
            if str(labtoinfo[str(top1)]['ref'])!='nan':
                print('출처: ',labtoinfo[str(top1)]['ref'])

        #### 추천질문 표시  
        if str(top1) in convlab:
            j=1
            labs = random.choices(labpool,k=5)
            print(' ')
            print('혹시 이런게 궁금하신가요?:')
            for i in labs:
                print(j,'. ',labtoinfo[i]['ques'])
                j+=1
        else:
            count = 0
            i = 1
            j = 1
            print(' ')
            print('비슷한 질문들도 있어요:')
            while count <= 3:
                if str(top10pred[i]) in convlab:
                    i+=1
                else :
                    print(j,'. ',labtoinfo[str(top10pred[i])]['ques'], ' ', round(score[i-1],4)) #, ' ', score[i-1]
                    i+=1
                    j+=1
                    count+=1

In [29]:
inp = input()
covidbertwans(inp)

물건에다가 손 소독제 뿌려도 되나요?
질문:  물건의 소독에 손 소독제를 이용해도 되나요?   0.9737
 
답변:
자주 만지는 표면과 물건의 소독에 손 소독제를 사용하지 마세요.
출처:  https://www.cdc.gov/coronavirus/2019-ncov/faq.html
 
비슷한 질문들도 있어요:
1 .  손소독제 올바른 사용법은 무엇인가요?   0.9737
2 .  손 소독제 사용만으로도 코로나19가 충분히 예방 가능한가요?   0.0074
3 .  COVID-19 예방을 위해 어떻게 손을 씻어야하나요?   0.0022
4 .  COVID-19을 방지하기 위해 어떤 세척제를 사용해야 합니까?   0.0016
