<a href="https://colab.research.google.com/github/dotsnangles/Retrieval-Based-Chatbot/blob/main/chatbot_prototype.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Click here and press Shift + Enter 3 times!

In [None]:
#@title
# Install the Python dependencies quietly (-q). folium is pinned to 0.2.1
# (presumably to sidestep a dependency conflict on Colab — TODO confirm).
!pip install -q transformers datasets folium==0.2.1
# Fetch the Poly-Encoder repository and make it the working directory
# (presumably so the `encoder` and `transform` modules imported below resolve — TODO confirm).
!git clone https://github.com/dotsnangles/Poly-Encoder.git
%cd /content/Poly-Encoder

# Download two Google Drive folders by ID into the working directory
# (presumably the model checkpoint folder and the dialogue-data folder
# referenced by the paths used later — TODO confirm which ID is which).
!gdown -q --folder 1Ipr-aNF5ELMY0HTXAmeV26LlgktKUfmG
!gdown -q --folder 1RH7laK4WlucCw68ZeExFvyg7vs-kB_x3

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertPreTrainedModel, BertConfig, BertModel, BertTokenizer, AutoModel
from encoder import PolyEncoder
from transform import SelectionJoinTransform, SelectionSequentialTransform

# Use the first CUDA device when one is available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(device)

# Checkpoint holding the trained poly-encoder weights.
PATH = '/content/Poly-Encoder/chatbot_output/poly_16_pytorch_model.bin'

bert_name = 'klue/bert-base'
bert_config = BertConfig.from_pretrained(bert_name)

tokenizer = BertTokenizer.from_pretrained(bert_name)
# Register the newline character as an additional special token.
tokenizer.add_tokens(['\n'], special_tokens=True)

# Context and response sides use different transforms and maximum lengths.
context_transform = SelectionJoinTransform(tokenizer=tokenizer, max_len=256)
response_transform = SelectionSequentialTransform(tokenizer=tokenizer, max_len=128)

bert = BertModel.from_pretrained(bert_name, config=bert_config)

model = PolyEncoder(bert_config, bert=bert, poly_m=16)
# The tokenizer may have grown by the token added above, so the embedding
# matrix is resized to the tokenizer length before the weights are loaded.
model.resize_token_embeddings(len(tokenizer))
# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint saved from a CUDA session still loads on a CPU-only runtime.
model.load_state_dict(torch.load(PATH, map_location=device))
model.to(device)
model.device

# Toy English sentences used to exercise the helper functions below.
context = [
    'This framework generates embeddings for each input sentence',
    'Sentences are passed as a list of string.',
    'The quick brown fox jumps over the lazy dog.',
]

# The candidate list holds the same three sentences, as a separate list object.
candidates = list(context)

def context_input(context):
    """Tokenise one context and wrap the result as a batch of size one.

    Args:
        context: The value handed to ``context_transform``; every call site
            in this notebook passes a list of strings.

    Returns:
        A ``(token_ids, input_masks)`` tuple of ``torch.long`` tensors placed
        on ``device``, each with one extra leading dimension of size 1.
    """
    token_ids, input_masks = context_transform(context)

    # Nesting each sequence in a one-element list adds the batch dimension.
    ids_batch = torch.tensor([token_ids], dtype=torch.long, device=device)
    masks_batch = torch.tensor([input_masks], dtype=torch.long, device=device)

    return ids_batch, masks_batch
contexts_token_ids_list_batch, contexts_input_masks_list_batch = context_input(context)

def response_input(candidates):
    """Tokenise candidate responses and wrap the result as a batch of size one.

    Args:
        candidates: The value handed to ``response_transform``; every call
            site in this notebook passes a list of strings.

    Returns:
        A ``(token_ids, input_masks)`` tuple of ``torch.long`` tensors placed
        on ``device``, each with one extra leading dimension of size 1.
    """
    token_ids_list, input_masks_list = response_transform(candidates)

    # Nesting each list in a one-element list adds the batch dimension.
    ids_batch = torch.tensor([token_ids_list], dtype=torch.long, device=device)
    masks_batch = torch.tensor([input_masks_list], dtype=torch.long, device=device)

    return ids_batch, masks_batch
responses_token_ids_list_batch, responses_input_masks_list_batch = response_input(candidates)

def embs_gen(contexts_token_ids_list_batch, contexts_input_masks_list_batch):
    """Encode a batch of contexts into poly-code context embeddings.

    The contexts are run through ``model.bert``; the ``model.poly_m`` learned
    poly codes then attend over the resulting token states via
    ``model.dot_attention``. Runs in eval mode with gradients disabled.

    Args:
        contexts_token_ids_list_batch: Token-id tensor for the contexts
            (first dimension is the batch).
        contexts_input_masks_list_batch: Matching input-mask tensor, passed
            to ``model.bert`` as its second positional argument.

    Returns:
        The attended context embeddings, shape ``[bs, poly_m, dim]``.
    """
    model.eval()
    with torch.no_grad():
        ctx_out = model.bert(contexts_token_ids_list_batch, contexts_input_masks_list_batch)[0]  # [bs, length, dim]

        # Expand the poly-code ids to the real batch size instead of a
        # hard-coded 1, so the [bs, poly_m] shape holds for any batch.
        batch_size = ctx_out.shape[0]
        poly_code_ids = torch.arange(
            model.poly_m, dtype=torch.long, device=contexts_token_ids_list_batch.device
        )
        poly_code_ids = poly_code_ids.unsqueeze(0).expand(batch_size, model.poly_m)  # [bs, poly_m]
        poly_codes = model.poly_code_embeddings(poly_code_ids)  # [bs, poly_m, dim]
        embs = model.dot_attention(poly_codes, ctx_out, ctx_out)  # [bs, poly_m, dim]

        return embs
embs = embs_gen(contexts_token_ids_list_batch, contexts_input_masks_list_batch)

def cand_emb_gen(responses_token_ids_list_batch, responses_input_masks_list_batch):
    """Encode candidate responses into one embedding vector per candidate.

    The inputs are flattened to two dimensions, run through ``model.bert``,
    and the hidden state at sequence position 0 is taken for each candidate.
    Runs in eval mode with gradients disabled.

    Args:
        responses_token_ids_list_batch: Token ids, shape
            ``[bs, res_cnt, seq_length]``.
        responses_input_masks_list_batch: Input masks of the same shape.

    Returns:
        Candidate embeddings reshaped to ``[bs, res_cnt, dim]``.
    """
    model.eval()
    with torch.no_grad():
        # res_cnt is 1 during training
        batch_size, res_cnt, seq_length = responses_token_ids_list_batch.shape

        # Merge the batch and candidate axes so BERT sees a plain 2-D batch.
        flat_ids = responses_token_ids_list_batch.view(-1, seq_length)
        flat_masks = responses_input_masks_list_batch.view(-1, seq_length)

        hidden_states = model.bert(flat_ids, flat_masks)[0]
        first_token_states = hidden_states[:, 0, :]  # [bs * res_cnt, dim]

        return first_token_states.view(batch_size, res_cnt, -1)  # [bs, res_cnt, dim]
cand_emb = cand_emb_gen(responses_token_ids_list_batch, responses_input_masks_list_batch)

def loss(embs, cand_emb, contexts_token_ids_list_batch, responses_token_ids_list_batch):
    """Compute a masked negative log-softmax loss, printing each intermediate.

    Every intermediate tensor is printed so the computation can be traced
    step by step.

    Args:
        embs: Context embeddings, used as keys and values of the attention.
        cand_emb: Candidate embeddings, used as the attention queries.
        contexts_token_ids_list_batch: Only its ``.device`` is read, to place
            the identity mask.
        responses_token_ids_list_batch: Only its 3-D shape is read; the first
            dimension gives the batch size.

    Returns:
        The mean over rows of the negated, masked log-softmax scores.
    """
    batch_size, _res_cnt, _seq_length = responses_token_ids_list_batch.shape

    attended = model.dot_attention(cand_emb, embs, embs)
    # NOTE(review): squeeze() without a dim drops every size-1 axis, so the
    # result's rank depends on the input sizes — confirm this is intended.
    attended = attended.squeeze()

    scores = (attended * cand_emb).sum(-1)
    print(scores)

    identity_mask = torch.eye(batch_size).to(contexts_token_ids_list_batch.device)
    print(identity_mask)

    log_probs = F.log_softmax(scores, dim=-1)
    print(log_probs)

    masked_log_probs = log_probs * identity_mask
    print(masked_log_probs)

    per_row_loss = -masked_log_probs.sum(dim=1)
    print(per_row_loss)

    mean_loss = per_row_loss.mean()
    print(mean_loss)
    return mean_loss
# loss_ = loss(embs, cand_emb, contexts_token_ids_list_batch, responses_token_ids_list_batch)

def score(embs, cand_emb):
    """Score every candidate embedding against the context embeddings.

    Each candidate attends over the context embeddings via
    ``model.dot_attention``; the score is the dot product of the attended
    vector with the candidate itself. Runs in eval mode with gradients
    disabled.

    Args:
        embs: Context embeddings, used as keys and values.
        cand_emb: Candidate embeddings, used as queries.

    Returns:
        The element-wise product summed over the last dimension.
    """
    model.eval()
    with torch.no_grad():
        attended_ctx = model.dot_attention(cand_emb, embs, embs)  # [bs, res_cnt, dim]
        return (attended_ctx * cand_emb).sum(-1)
score_ = score(embs, cand_emb)

# forward
# Run the model's own forward pass on the sample tensors, in eval mode and
# without gradient tracking.
model.eval()
with torch.no_grad():
    model_foward = model(
        contexts_token_ids_list_batch,
        contexts_input_masks_list_batch,
        responses_token_ids_list_batch,
        responses_input_masks_list_batch,
    )

### Data validation
import pickle

# Load the pickled training and validation samples from the project folder.
# NOTE(review): pickle.load can execute arbitrary code during deserialisation;
# only run this against files from a source you trust.
# `train` is later iterated as samples indexed by 'context' and 'responses';
# presumably `dev` has the same layout — TODO confirm.
with open('/content/Poly-Encoder/감성대화챗봇데이터/train_data_source.pickle', 'rb') as f:
    train = pickle.load(f)
with open('/content/Poly-Encoder/감성대화챗봇데이터/val_data_source.pickle', 'rb') as f:
    dev = pickle.load(f)
# index = 500
# train[index]['context']
# train[index]['responses']
# dev[index]['context']
# dev[index]['responses']

### Build the chatbot table
# One row per training sample: its context, and its first response wrapped in
# a one-element list.
data = {
    'context': [sample['context'] for sample in train],
    'response': [[sample['responses'][0]] for sample in train],
}
# len(data['context']), len(data['response'])
# idx = 400
# print(data['context'][idx])
# print(data['response'][idx])

import pandas as pd
df = pd.DataFrame(data)

### generate cand_embs & create tensor table
# response_input_srs = df['response'].apply(response_input)
# response_input_lst = response_input_srs.to_list()

# cand_embs_lst = []
# for sample in response_input_lst:
#     cand_embs_lst.append(cand_emb_gen(*sample).to('cpu'))
# df['response embedding'] = cand_embs_lst
# df[['response', 'response embedding']]
# cand_embs = cand_embs_lst[0]
# for idx in range(1, len(cand_embs_lst)):
#     y = cand_embs_lst[idx]
#     cand_embs = torch.cat((cand_embs, y), 1)
# cand_embs = cand_embs.to(device)

import pickle
# Load the pre-computed candidate embeddings instead of regenerating them.
# NOTE(review): pickle.load can execute arbitrary code during deserialisation;
# only run this against a file from a source you trust.
with open('/content/Poly-Encoder/감성대화챗봇데이터/cand_embs.pickle', 'rb') as f:
    cand_embs = pickle.load(f)
# .to() returns a new object rather than moving the tensor in place, so the
# result must be rebound; otherwise cand_embs stays on its original device.
cand_embs = cand_embs.to(device)

### generate context_embs
# Encode a single sample query into context embeddings.
query = ['너무 성급한 결정을 한 것 같아.']
embs = embs_gen(*context_input(query))

### Score & Retrieve
import time
# Wall-clock timestamps taken around the scoring call.
start = time.time()
s = score(embs, cand_embs)
end = time.time()
# Index of the highest-scoring candidate for the first (only) query.
idx = int(s.argmax(dim=1)[0])
# df['response'][idx]
# df.iloc[idx]['context']

In [None]:
#@title
### Chatbot UI
# Conversation log: one entry is appended to each list per chatbot turn.
consult_context = {key: [] for key in ('num', 'name', 'customer', 'chatbot')}

# Greeting banner, one message per line.
print(
    '안녕하세요. 공감 만땅이~~⭐️ 공감이🍀 입니다.',
    '세상에 완벽한 사람 없고, 완벽하지 않은 게 잘못이 아닌 것처럼❌',
    '공감이도 부족한 면이 있지만 당신의 얘기에 집중할꺼에요~!😎',
    '공감이가 당신을 이해할 수 있도록 당신에 대해 길게 말해주세요. (비밀인데 TMI 좋아해요💕)',
    '공감이는 언제나 당신 편입니다. 🥰',
    sep='\n',
)

# Blank line, 50-dash separator, blank line.
print('', '-'*50, '', sep='\n')

# Outer loop: one pass per consultation session. Entering "종료" at the first
# name prompt leaves the chatbot.
while True:

    print('공감이 : ')
    print('종료를 원한다면 "종료"를 입력해주세요.')
    name = input('성함이 어떻게 되시나요?: ')

    if name == '종료':
        break

    # Keep asking until the user confirms the entered name with "네".
    while True:
        print()
        print('공감이🍀 : ')
        confirm = input(f'{name}님 맞으신가요? (네/아니요): ')
        print()

        if confirm == '네':
            break

        print('공감이🍀 : ')
        name = input('성함을 다시 입력해주세요!: ')

    print('-'*50)
    print()

    print(f'<<< {name}님 만약 상담을 그만두고 싶으시다면 "끝"를 입력해주세요. >>>')

    print()
    print('<<< 공감이는 아직 대화의 맥락을 추적하지는 못 합니다. >>>')
    print('<<< 그저 두서 없이 마음을 털어놓아 보세요. >>>')
    print('<<< 공감이는 따뜻한 말들만을 배웠답니다. >>>')
    print()
    print('<<< 다섯 단어 이상 입력 시 보다 정확한 답변이 가능합니다. >>>')
    print()
    print(f'{name}님의 고민은 무엇인가요?')
    print()

    count = 1

    # The (at most two) most recent replies, used to avoid repeating an answer.
    list_answer = []

    # Inner loop: one pass per user utterance until "끝" is entered.
    while True:

        print(f'{name} : ')
        query = [input()]
        print()

        # Handle the exit command before any model work, so the farewell
        # does not cost an encoder pass and a full scoring run.
        if query == ['끝']:

            print('-'*50)
            print('-'*50)
            print()

            print(f'공감이🍀는 {name}님이 언제나 행복하시길 바랍니다. 감사합니다')
            print()
            print('-'*50)
            print()
            break

        embs = embs_gen(*context_input(query))
        s = score(embs, cand_embs)

        # Sort once per query. sort() is ascending and its last return value
        # holds the indices, so position -1 is the best-scoring candidate.
        ranking = s[0].sort()[-1]

        # Walk down the ranking until the answer is not a recent repeat.
        best_num = -1
        idx = int(ranking[best_num])
        best_answer = df['response'][idx][0]
        while best_answer in list_answer:
            best_num -= 1
            idx = int(ranking[best_num])
            best_answer = df['response'][idx][0]

        consult_context['customer'].append(query[0])

        print('공감이🍀 : ')
        print(best_answer)
        print()

        consult_context['chatbot'].append(best_answer)
        list_answer.append(best_answer)

        # Drop the oldest remembered reply once a third one is added.
        if len(list_answer) == 3:
            del list_answer[0]

        consult_context['num'].append(count)

        consult_context['name'].append(name)

        count += 1

# Collect the conversation log into a DataFrame. This rebinds `context`,
# which earlier held the sample sentence list.
context = pd.DataFrame(consult_context)