In [1]:
import sys
sys.path.append('../')

import retrieval_bot
from config import corpus_path
from config import cohesion_path
from config import vectorizer_path
from config import graph_path
from config import send_x_path
from config import reply_x_path
from config import evaluator_model_path

## Pair corpus

In [2]:
from retrieval_bot import PairCorpus

# Open the pair corpus located at `corpus_path` and report its size
# as given by len().
corpus = PairCorpus(corpus_path)
num_pairs = len(corpus)
print('num pairs = {}'.format(num_pairs))

num pairs = 5903206


## Cohesion

In [3]:
from retrieval_bot.tokenizer import CohesionScore

# Turn the corpus pair-iteration flag off before cohesion scoring
# (presumably the corpus then yields single sentences rather than
# (send, reply) pairs — TODO confirm against PairCorpus).
corpus.iter_pair = False
cohesion_trainer = CohesionScore(debug=False)
# Kept for reference, disabled: these two lines train on the corpus
# (returning the scores) and save the trainer to `cohesion_path`.
#scores = cohesion_trainer.train_and_scores(corpus)
#cohesion_trainer.save(cohesion_path)
# Restore the trainer from `cohesion_path` and read the scores from it.
cohesion_trainer.load(cohesion_path)
scores = cohesion_trainer.scores()

## Tokenizer

In [4]:
from retrieval_bot.tokenizer import MaxScoreTokenizer

# Build the tokenizer from the cohesion `scores`.
tokenizer = MaxScoreTokenizer(scores=scores)
# Smoke test on an unspaced Korean sentence; the cell displays the
# resulting list of tokens.
tokenizer.tokenize('아니지금어디냐고?왜아직도안와?')

['아니', '지금', '어디', '냐', '고?', '왜', '아직도', '안와?']

## Vectorizer

In [5]:
from retrieval_bot.vectorizer import Vectorizer

# NOTE(review): limit_pairs = -1 presumably means "no limit" — confirm
# against PairCorpus.
corpus.limit_pairs = -1
corpus.iter_pair = False
vectorizer = Vectorizer(tokenizer=tokenizer, min_tf=2)
# Kept for reference, disabled: fit the vectorizer on the corpus and
# save it to `vectorizer_path`.
# vectorizer = vectorizer.fit(corpus)
# vectorizer.save(vectorizer_path)
# Restore the vectorizer state from `vectorizer_path`.
vectorizer.load(vectorizer_path)
# Display the number of entries in the vocabulary.
len(vectorizer.vocabulary_)

17761

## Send2Reply

In [6]:
from retrieval_bot.db import Send2Reply

# Turn the corpus pair-iteration flag back on for the send→reply store
# (presumably the corpus then yields (send, reply) pairs — TODO confirm).
corpus.iter_pair = True
send2reply = Send2Reply()
# Kept for reference, disabled: train the store on the corpus and save
# it to `graph_path`.
# send2reply.train(corpus)
# send2reply.save(graph_path)
# Restore the store from `graph_path`.
send2reply.load(graph_path)
# Display the size of the store as given by len().
len(send2reply)

42025

In [7]:
from scipy.io import mmwrite
from scipy.io import mmread

# Kept for reference, disabled: vectorize the sends and replies held by
# `send2reply` and write both matrices to disk in Matrix Market format.
# send_x = vectorizer.transform(send2reply.sends)
# reply_x = vectorizer.transform(send2reply.replies)
# mmwrite(send_x_path, send_x)
# mmwrite(reply_x_path, reply_x)

# Read the stored matrices back and convert them to CSR format.
send_x = mmread(send_x_path).tocsr()
reply_x = mmread(reply_x_path).tocsr()

## Full search index

In [8]:
from retrieval_bot.db import FullSearchIndexer

# Build the search index over the send-side matrix; it is queried with
# kneighbors() in the next cell.
send_indexer = FullSearchIndexer(send_x)

In [9]:
# Query the index with the first send vector, requesting 10 neighbors
# and passing 0.5 as the maximum distance.
dist, idx = send_indexer.kneighbors(
    query=send_x[0],
    n_neighbors=10,
    max_distance=0.5,
)
# Neighbor indices on the first line, their distances on the second.
print(idx, dist, sep='\n')

[    0 17039   963]
[0.         0.29289322 0.29289322]


## Default Message

In [10]:
from retrieval_bot.engine import DefaultMessage

# Draw five random fallback messages and print each on its own line.
default_message = DefaultMessage()
for _ in range(5):
    sampled_message = default_message.get_random_message()
    print(sampled_message)

음...
음...
어...
무슨 말이에요?
응??


## Evaluator

In [11]:
from retrieval_bot.evaluator import TermPairEvaluator

# Load the term-pair evaluator from `evaluator_model_path`.
evaluator = TermPairEvaluator(evaluator_model_path)

# Score every (send term, reply term) combination and print one line
# per pair.
send_terms = '뭐먹 어디'.split()
reply_terms = '피자 치킨 지하철 사당역'.split()
report_template = '{} - {} : {}'

for send_term in send_terms:
    # Encode the send term with the vectorizer before scoring.
    send_term_idx = vectorizer.encode_a_doc_to_bow(send_term)
    for reply_term in reply_terms:
        reply_term_idx = vectorizer.encode_a_doc_to_bow(reply_term)
        score = evaluator.evaluate(send_term_idx, reply_term_idx)
        print(report_template.format(send_term, reply_term, score))

뭐먹 - 피자 : 3.08618742336944
뭐먹 - 치킨 : 3.852132624607777
뭐먹 - 지하철 : 0
뭐먹 - 사당역 : 0
어디 - 피자 : 0
어디 - 치킨 : 0.4396905736037917
어디 - 지하철 : 2.259871019456051
어디 - 사당역 : 0.2413882193124051


## Engine

In [14]:
from retrieval_bot.engine import InstanceBasedRetrievalEngine

# Switch off the verbose flag on both components before building the
# engine.
vectorizer.verbose = False
send2reply.verbose = False

# Wire the vectorizer, send→reply store, send index, reply matrix and
# evaluator into a single retrieval engine.
engine = InstanceBasedRetrievalEngine(
    vectorizer=vectorizer,
    send2reply=send2reply,
    send_indexer=send_indexer,
    reply_x=reply_x,
    evaluator=evaluator,
)

In [15]:
# Display one random message from the engine's own default-message
# component.
engine.default_message.get_random_message()

'음...'

In [16]:
# Run a sample query through the engine; the cell displays the list of
# replies it returns.
engine.process('지금 어디냥? ')

['냐옹', '가구있엉', '안알랴줌', '집이댜', '카페']