In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project source files are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import shutil
import zipfile
import pprint
from pathlib import Path
import pickle
import sys

from ipywidgets import interact, interactive, fixed, interact_manual,Button, HBox, VBox
import ipywidgets as widgets
from ipywidgets import TwoByTwoLayout
from ipywidgets import Button, Layout
from termcolor import colored
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration, BlenderbotConfig

# Extend sys.path with the style-transfer paraphrase directories (presumably so
# the nlp_suite imports below can resolve modules that live inside them).
# Entries already present are skipped, so re-running this cell adds nothing.
add_paths = [
    'nlp_suite/chatbot/style_transfer_paraphrase',
    'nlp_suite/chatbot/style_transfer_paraphrase/style_paraphrase',
]
for add_path in add_paths:
    if add_path in sys.path:
        continue
    sys.path.append(add_path)

from nlp_suite import data_preprocessing
from nlp_suite.chatbot.style_transfer_chatbot import StyleTransferChatbot

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
############################ PRE-SETUP ##########################
##################################################################################
# Setup (i.e. setting options, loading data, models, etc.) that can be done beforehand,
# agnostic to the user input data, goes here.

# Pick the device for the chatbot model: use the GPU when CUDA is usable and
# fall back to the CPU otherwise, instead of crashing on machines without CUDA.
# NOTE(review): StyleTransferChatbot is not visible from this notebook -- confirm
# it does not assume a CUDA device on its own.
import torch
chatbot_device = "cuda" if torch.cuda.is_available() else "cpu"
if chatbot_device == "cpu":
    print("CUDA is not available -- loading the chatbot model on CPU (responses may be slow).")

# loading chatbot model
blenderbot_chatbot_model = BlenderbotForConditionalGeneration.from_pretrained('facebook/blenderbot-400M-distill').to(chatbot_device)

# silence tqdm: rebind tqdm.__init__ so that disable=True is passed to every
# progress bar constructed after this point
from tqdm import tqdm
from functools import partialmethod
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

In [4]:
def process_inputs(button_instance):
    """Click handler: extract one user's messages from an uploaded chat-log zip.

    Reads the user name from ``user_name_widget`` and the uploaded archive from
    ``uploader_widget``, unpacks the archive into a scratch directory, runs
    ``data_preprocessing.process_discord_data`` over every ``*.txt`` file found
    in it, and -- when the user appears in the result -- pickles that user's
    messages to ``cached_user_data/<user_name>/user_messages.p``.

    Progress and failures are reported with ``print``; nothing is returned.

    Args:
        button_instance: the widget that fired the click. Unused.

    Side effects:
        Sets the module-level ``user_name`` on every call, and ``user_messages``
        when processing succeeds. Deletes and recreates ``./dashboard_temp_files``.
    """
    global user_name, user_messages

    print("Processing...")
    temp_files_dir = "./dashboard_temp_files"
    user_name = user_name_widget.value

    # FileUpload.value is a dict keyed by file name in ipywidgets 7 and a tuple
    # of per-file records in ipywidgets 8. An empty container of either kind
    # means nothing was uploaded.
    uploads = uploader_widget.value
    if user_name == "" or not uploads:
        print("Data processing failed -- missing username or chat log.")
        return
    if isinstance(uploads, dict):
        first_upload = next(iter(uploads.values()))
    else:
        first_upload = uploads[0]

    # prepping an empty temp folder to hold uploaded data
    if os.path.exists(temp_files_dir):
        shutil.rmtree(temp_files_dir)
    os.makedirs(temp_files_dir)

    # saving uploaded zip
    uploaded_file_path = os.path.join(temp_files_dir, "saved-output.zip")
    with open(uploaded_file_path, "wb") as f:
        f.write(first_upload["content"])

    # unzipping; the saved archive is removed whether or not extraction works
    # NOTE(review): the archive is user-supplied and is extracted without any
    # inspection of its members or their sizes.
    try:
        with zipfile.ZipFile(uploaded_file_path, 'r') as zip_ref:
            zip_ref.extractall(temp_files_dir)
    except zipfile.BadZipFile:
        print("Data processing failed -- uploaded file is not a valid zip archive.")
        return
    finally:
        os.remove(uploaded_file_path)

    # extracting user's messages from text files
    discord_log_paths = [str(path) for path in Path(temp_files_dir).rglob('*.txt')]
    # NOTE(review): the meaning of the second argument (3) is defined by
    # process_discord_data, which is not visible here -- confirm before changing.
    channel_messages, _ = data_preprocessing.process_discord_data(discord_log_paths, 3)

    if user_name not in channel_messages:
        print("Data processing failed -- couldn't find user {} in chat logs.".format(user_name))
        return

    # saving user messages in cache
    cached_user_data_dir = os.path.join("cached_user_data", user_name)
    os.makedirs(cached_user_data_dir, exist_ok=True)
    user_messages_path = os.path.join(cached_user_data_dir, "user_messages.p")
    with open(user_messages_path, "wb") as cache_file:
        pickle.dump(channel_messages[user_name], cache_file)
    # Publish the messages only after the cache file was written successfully.
    user_messages = channel_messages[user_name]
    print("Data processing finished.")
        

In [5]:
############################ USER DATA UPLOAD AND PROCESSING ##########################
##################################################################################
# Shared state written by the process_inputs click handler.
user_name = ""
user_messages = None

print("Please upload a zipped folder containing chat log file(s), and specify a user. \nThen, click the green button to start the data analysis.")

# Input controls: a single-file .zip uploader and a free-text user name field.
uploader_widget = widgets.FileUpload(
    description="Upload Chat",
    accept='.zip',
    multiple=False,
)
user_name_widget = widgets.Text(
    description='User Name:',
    disabled=False,
)

# Button that passes control to process_inputs when clicked.
process_data_button_widget = widgets.Button(
    description='Process Chat Data',
    disabled=False,
    button_style='success',
)
process_data_button_widget.on_click(process_inputs)

# Inputs side by side, with the action button underneath.
display(HBox([uploader_widget, user_name_widget]))
display(process_data_button_widget)

Please upload a zipped folder containing chat log file(s), and specify a user. 
Then, click the green button to start the data analysis.


HBox(children=(FileUpload(value={}, accept='.zip', description='Upload Chat'), Text(value='', description='Use…

Button(button_style='success', description='Process Chat Data', style=ButtonStyle())

Processing...
Data processing finished.


In [7]:
# Sanity check of the processed data: the selected user, how many of their
# messages were extracted, and the first 25 of them.
if user_messages is None:
    # Nothing has been processed yet (button not clicked, or processing failed);
    # say so instead of raising TypeError on len(None).
    print("No user messages loaded -- upload a chat log and click 'Process Chat Data' first.")
else:
    print(user_name)
    print(len(user_messages))
    pprint.pprint(user_messages[:25])

muffins
12391
['Say... Why not have the bot just say the Welcome? Instead of you having to '
 "do it every time, y'know? w",
 'Do you... have a macro for that?',
 'Wait so... for the giveaway... Is it just chosen randomly?',
 'Awh rip tryna rig it, eh?',
 'Hheh... can I add one more to chur friends list?',
 'Is that... squints... is that a tilde',
 'Mhm... Chur doing a good job though',
 'Exynos... I think he got it.',
 'Woah... hello:D Is that... Mod?',
 'I think you can do over 1k!',
 'Plates are not good for smashing. Shards hurt:c',
 'From this moment onward',
 'I would like to be your informally obsessive psychiatrist helper person.',
 'If you will accept me.',
 'Alright... Hello @Victoria. How are you?',
 'Aight. Anything bothering ya?',
 "What's wrong with piano?",
 "That's a very edgy, doubtful time in a persons lifu.",
 "It's kinda edgy... y'know?",
 'For most people... I think.',
 'What command for that one:3',
 'Ah... Daddy we needchu now get outta erp',
 "Is it cuz' chu wan

In [9]:
############################ CHATBOT ##########################
###############################################################

# Conversation bookkeeping shared with the handlers defined further down.
# messages separator based on https://github.com/huggingface/transformers/issues/9365
seperator = "    "
all_conversation_text = ""
# Conversations whose token count reaches this value get their oldest turns dropped.
max_position_embeddings = 128

# NOTE: this takes about 1.5 mins to load. Not sure how to make faster
style_model_dir = "cached_user_data/{}/style_transfer_paraphrase_checkpoint".format(user_name)
chatbot = StyleTransferChatbot(style_model_dir, blenderbot_chatbot_model)
chatbot_tokenizer = chatbot.tokenizer

# Bordered, scrollable output area that the chat transcript is printed into.
out = widgets.Output(
    layout={
        'border': '2px solid black',
        "height": "400px",
        "width": "900px",
        "overflow": 'scroll',
    }
)

@out.capture()
def submit_to_chatbot(next_user_input=""):
    """Send one user message to the chatbot and print both of its replies.

    The message is appended to the running conversation in
    ``all_conversation_text`` (turns are joined with ``seperator``). While the
    tokenized conversation has ``max_position_embeddings`` tokens or more, the
    oldest turn is dropped. The conversation is then passed to
    ``chatbot.get_response``; the original and the stylized reply are printed,
    and the stylized reply is appended to the conversation.

    Output is captured into the ``out`` widget by the decorator.

    Args:
        next_user_input: the text the user typed.
    """
    global all_conversation_text

    def _token_count(text):
        # Number of token ids the chatbot tokenizer produces for `text`.
        return chatbot_tokenizer([text], return_tensors='pt')['input_ids'].shape[1]

    # echo the user's turn, then append it to the conversation
    print(colored('> User', 'blue', attrs=["bold"]) + ": {}".format(next_user_input))
    if all_conversation_text == "":
        all_conversation_text += next_user_input
    else:
        all_conversation_text += seperator + next_user_input

    # drop the oldest turn until the conversation fits under the token limit
    while _token_count(all_conversation_text) >= max_position_embeddings:
        remaining_turns = all_conversation_text.split(seperator)[1:]
        all_conversation_text = seperator.join(remaining_turns)

    # query the chatbot and show both versions of its reply
    stylized_reply, plain_reply = chatbot.get_response(all_conversation_text)
    print(colored('> Bot (original)', 'green', attrs=["bold"]) + ": {}".format(plain_reply))
    print(colored('> Bot (stylized)', 'red', attrs=["bold"]) + ": {}".format(stylized_reply))
    print("")

    # the stylized reply (not the original) is what goes into the history
    all_conversation_text += seperator + stylized_reply

def restart(button_instance):
    """Click handler for the restart button: wipe the chat display and history.

    Clears the ``out`` widget and resets ``all_conversation_text`` to the empty
    string. ``button_instance`` is the clicked widget and is not used.
    """
    global all_conversation_text
    out.clear_output()
    all_conversation_text = ""

# Restart button: clicking it runs restart().
button = widgets.Button(
    description='Restart Chat',
    disabled=False,
    button_style='',
)
button.on_click(restart)

# Message entry wired to submit_to_chatbot. The 'manual' option adds a
# "Submit Message" button so the function runs on click.
text_box = widgets.Text(description="User Chat:")
interact_options = {'manual': True, "manual_name": "Submit Message"}
text_input = interactive(submit_to_chatbot, interact_options, next_user_input=text_box)

# Layout: message entry on the left; transcript with the restart button
# underneath it on the right.
left = VBox([out, button])
display(HBox([text_input, left]))

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


HBox(children=(interactive(children=(Text(value='', description='User Chat:'), Button(description='Submit Mess…