In [14]:
!pip install transformers sentencepiece gradio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [15]:
import torch

import re

import gradio as gr

In [16]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer, T5ForConditionalGeneration

# Checkpoint identifiers passed to from_pretrained().
CHATBOT_CHECKPOINT = "facebook/blenderbot-400M-distill"
SHAKESPEARE_BASE_CHECKPOINT = 't5-base'

# Conversational model that produces the modern-English reply.
chatbot_tokenizer = AutoTokenizer.from_pretrained(CHATBOT_CHECKPOINT)
chatbot_model = AutoModelForSeq2SeqLM.from_pretrained(CHATBOT_CHECKPOINT)

# T5 base model; the EOS token is reused as the padding token on both the
# tokenizer and the model configuration.
shakespeare_tokenizer = T5Tokenizer.from_pretrained(SHAKESPEARE_BASE_CHECKPOINT)
shakespeare_tokenizer.pad_token = shakespeare_tokenizer.eos_token
shakespeare_model = T5ForConditionalGeneration.from_pretrained(
    SHAKESPEARE_BASE_CHECKPOINT,
    pad_token_id=shakespeare_tokenizer.eos_token_id,
)

In [17]:
from google.colab import drive

import shutil

# Mount Google Drive at /content/drive; a later cell reads the fine-tuned
# checkpoint from /content/drive/MyDrive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Restore the best saved weights into the T5 model.
# The tensors are mapped to the CPU on load; the model is moved to the
# selected device in a separate step.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
BEST_MODEL_PATH = '/content/drive/MyDrive/best_t5_base_model.pt'
shakespeare_model.load_state_dict(
    torch.load(BEST_MODEL_PATH, map_location=torch.device('cpu'))
)

In [21]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Both models must live on the same device as the input tensors built later.
chatbot_model = chatbot_model.to(device)
shakespeare_model = shakespeare_model.to(device)

In [22]:
# Get output from pretrained chatbot in modern English

def get_chatbot_output(input_text, temperature):
    """Generate the chatbot's modern-English reply to ``input_text``.

    Args:
        input_text: The user's message as a plain string.
        temperature: Sampling temperature. A positive value samples with that
            temperature. A value of 0 or below switches to deterministic
            (non-sampling) decoding, because sampling requires a strictly
            positive temperature and would otherwise raise.

    Returns:
        The decoded reply (at most 30 newly generated tokens) with special
        tokens removed and surrounding whitespace stripped.
    """
    input_ids = torch.tensor([chatbot_tokenizer.encode(input_text)]).to(device)

    # The UI slider can deliver 0; guard it here instead of crashing in generate().
    if temperature > 0:
        generation_kwargs = {'do_sample': True, 'temperature': temperature}
    else:
        generation_kwargs = {'do_sample': False}

    chatbot_output = chatbot_model.generate(input_ids=input_ids, max_new_tokens=30, **generation_kwargs)

    return chatbot_tokenizer.decode(chatbot_output[0], skip_special_tokens=True).strip()

In [23]:
# Split chatbot output to pass each sentence separately

def split_chatbot_output(chatbot_output_string):
    """Split a reply into sentences, each keeping its terminating punctuation.

    Sentences are delimited by '.', '!' or '?'. Each returned item is the
    stripped sentence text followed by the single punctuation mark that ended
    it. Fragments that are empty or whitespace-only (for example between the
    dots of "...") are dropped together with their punctuation mark.

    A trailing fragment with no terminating punctuation (common when
    generation is cut off by the token limit) is returned as-is instead of
    raising an IndexError.

    Args:
        chatbot_output_string: The full reply as one string.

    Returns:
        A list of sentence strings in their original order; empty if the
        input contains no text.
    """
    # The capturing group makes re.split() keep the delimiters: even indices
    # hold sentence text, odd indices hold the punctuation that followed it.
    pieces = re.split('([.!?])', chatbot_output_string)

    sentences = []
    for idx in range(0, len(pieces), 2):
        text = pieces[idx].strip()
        if not text:
            continue
        # The final piece has no delimiter after it.
        mark = pieces[idx + 1] if idx + 1 < len(pieces) else ''
        sentences.append(text + mark)

    return sentences

In [24]:
# Get Shakespeare output string

def get_shakespeare_output(chatbot_output_list_cleaned, temperature):
    """Pass each sentence through the Shakespeare model and join the results.

    Args:
        chatbot_output_list_cleaned: List of sentence strings; each one is
            passed through the model separately.
        temperature: Sampling temperature. A positive value samples with that
            temperature. A value of 0 or below switches to deterministic
            (non-sampling) decoding, because sampling requires a strictly
            positive temperature and would otherwise raise.

    Returns:
        The per-sentence outputs (at most 50 newly generated tokens each)
        joined with single spaces and stripped. An empty input list yields ''.
    """
    # The UI slider can deliver 0; guard it here instead of crashing in generate().
    if temperature > 0:
        generation_kwargs = {'do_sample': True, 'temperature': temperature}
    else:
        generation_kwargs = {'do_sample': False}

    shakespeare_output_list = []
    for sentence in chatbot_output_list_cleaned:
        input_ids = torch.tensor([shakespeare_tokenizer.encode(sentence)]).to(device)
        shakespeare_output = shakespeare_model.generate(input_ids=input_ids, max_new_tokens=50, **generation_kwargs)
        shakespeare_output_list.append(
            shakespeare_tokenizer.decode(shakespeare_output[0], skip_special_tokens=True).strip()
        )

    return ' '.join(shakespeare_output_list).strip()

### Chatbot Function

In [25]:
def Shakespeare_Chatbot(Your_Input, Chatbot_Temperature, Shakespeare_Temperature):
    """Run the full pipeline: chatbot reply, then its Shakespeare-model rewrite.

    NOTE(review): the capitalised parameter names look chosen so that
    gr.Interface can derive the input labels from them; confirm before renaming.

    Args:
        Your_Input: The user's message.
        Chatbot_Temperature: Temperature forwarded to get_chatbot_output.
        Shakespeare_Temperature: Temperature forwarded to get_shakespeare_output.

    Returns:
        A tuple ``(chatbot_reply, shakespeare_reply)`` of two strings.
    """
    modern_reply = get_chatbot_output(Your_Input, Chatbot_Temperature)

    # The Shakespeare model is fed one sentence at a time.
    shakespeare_reply = get_shakespeare_output(
        split_chatbot_output(modern_reply),
        Shakespeare_Temperature,
    )

    return modern_reply, shakespeare_reply

### Chatbot Interface

In [28]:
# Two read-only result boxes: the plain chatbot reply and its rewrite.
output_1 = gr.Textbox(label="Original Chatbot")
output_2 = gr.Textbox(label="Shakespeare Chatbot")

# Inputs map positionally onto Shakespeare_Chatbot's three parameters.
# Each temperature slider ranges over [0, 2] and starts at 0.1.
demo = gr.Interface(
    fn=Shakespeare_Chatbot,
    inputs=[
        "text",
        gr.Slider(minimum=0, maximum=2, value=0.1),
        gr.Slider(minimum=0, maximum=2, value=0.1),
    ],
    outputs=[output_1, output_2],
)

demo.launch()

Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

