In [15]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load DialoGPT-small model and tokenizer
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Check if GPU is available and move model to GPU if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Set chat parameters
max_length = 1000
chat_history_ids = None

def chat_with_bot(user_input):
    global chat_history_ids
    
    # Encode the new user input, add the eos_token and return a tensor in Pytorch
    new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt').to(device)
    
    # Append the new user input tokens to the chat history
    if chat_history_ids is not None:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Generate a response while limiting the total chat history to 1000 tokens
    chat_history_ids = model.generate(
        bot_input_ids, 
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_k=100,
        top_p=0.7,
        temperature=0.8
    )
    
    # Extract the AI's response
    ai_response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
    
    # Optionally, trim conversation history if it gets too long
    if chat_history_ids.shape[1] > 1000:
        chat_history_ids = chat_history_ids[:, -1000:]
    
    return ai_response




In [10]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch

# Load DialoGPT-small model and tokenizer
model_name = "my_awesome_qa_model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# Check if GPU is available and move model to GPU if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Set chat parameters
max_length = 1000
chat_history_ids = None

def answer_question(question, context):
    # Encode the question and context
    inputs = tokenizer.encode_plus(question, context, add_special_tokens=True, return_tensors='pt').to(device)

    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Extract the start and end scores
    start_scores = outputs.start_logits
    end_scores = outputs.end_logits

    # Get the most likely start and end of the answer
    start_index = torch.argmax(start_scores)
    end_index = torch.argmax(end_scores)

    # Convert token indices to the answer string
    answer_tokens = inputs['input_ids'][0][start_index:end_index + 1]
    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

    return answer

# Example usage
context = "Thanksgiving is a national holiday celebrated on various dates in October and November in the United States, Canada, Saint Lucia, Liberia, and unofficially in countries like Brazil, Germany and the Philippines. It is also observed in the Dutch town of Leiden and the Australian territory of Norfolk Island. It began as a day of giving thanks for the blessings of the harvest and of the preceding year. Various similarly named harvest festival holidays occur throughout the world during autumn. Although Thanksgiving has historical roots in religious and cultural traditions, it has long been celebrated as a secular holiday as well."
user_question = "When is thanksgiving celebrated?"
response = answer_question(user_question, context)
print(response)



october and november


In [16]:
from tkinter import *
from tkinter import ttk

user_chat_history = []

def update_gui_chat(*args):
    global user_chat_history
    
    user_input = dialog.get()  # Get the input from the user
    
    # Check for termination commands
    if user_input.lower() in ["quit", "bye"]:
        chat.set("Thank you for chatting! Goodbye!")
        dialog.set("")  # Clear the entry box
        return  # Exit the function early
    
    # Append the user input to the chat history
    user_chat_history.append("User: " + user_input)
    
    # Get the AI's response from the chat_with_bot function
    ai_response = chat_with_bot(user_input)
    
    # Append the AI's response to the chat history
    user_chat_history.append("AI: " + ai_response)

    # Update the chat_history StringVar to display in the UI
    chat.set("\n".join(user_chat_history))
    
    # Clear the input field after sending the message
    dialog.set("")  # Clear the entry box

# Create the main application window
app = Tk()
app.title("AI Dialog Chat")

mainframe = ttk.Frame(app, padding="50 3 20 12")
mainframe.grid(column=0, row=0, sticky=(N, W, E, S))

# Configure the grid to have 5 columns and 5 rows
for i in range(5):
    mainframe.columnconfigure(i, weight=1)  # Allow columns to expand
    mainframe.rowconfigure(i, weight=1)     # Allow rows to expand

# Row 1: Instructions
label = ttk.Label(mainframe, text="Chat with the AI. Type 'exit', 'quit', or 'bye' to finish the conversation.")
label.grid(column=2, row=1, sticky=(W, E))

# Row 2: Chat History
label = ttk.Label(mainframe, text="Chat History:")
label.grid(column=1, row=2, sticky=(W, E))
chat = StringVar()
ttk.Label(mainframe, textvariable=chat).grid(column=2, row=2, sticky=(W, E))

# Row 3: Input Box
label = ttk.Label(mainframe, text="Input: ")
label.grid(column=1, row=3, sticky=(W, E))

dialog = StringVar()
dialog_entry = ttk.Entry(mainframe, width=150, textvariable=dialog)
dialog_entry.grid(column=2, row=3, sticky=(W, E))

# Chat button to trigger the update_gui_chat function
ttk.Button(mainframe, text="Chat", command=update_gui_chat).grid(column=2, row=4, sticky=(W, E))

# Add padding to all child widgets
for child in mainframe.winfo_children(): 
    child.grid_configure(padx=5, pady=5)

# Set focus to the entry field and bind the Enter key to the chat function
dialog_entry.focus()
app.bind("<Return>", update_gui_chat)

# Run the application
app.mainloop()



The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [9]:
import gradio as gr

def respond(message, chat_history):
    bot_message = chat_with_bot(message)
    chat_history.append((message, bot_message))
    
    # Optionally limit chat history to a certain length
    if len(chat_history) > 10:  # Keep only the last 10 exchanges
        chat_history = chat_history[-10:]

    return "", chat_history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

if __name__ == "__main__":
    demo.launch()

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


In [11]:
import gradio as gr
import random
import time

# Function to respond to questions
def respond(user_question, context, chat_history):
    if context.strip() == "":
        return "Please provide context.", chat_history

    answer = answer_question(user_question, context)
    chat_history.append((user_question, answer))  # Append user question and bot answer to chat history
    return "", chat_history

# Create Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("### Question Answering Chat")
    context = gr.Textbox(label="Context", placeholder="Enter the context here...", lines=5)
    question = gr.Textbox(label="Question", placeholder="Ask your question...", lines=1)
    chatbot = gr.Chatbot()  # Create a chat-like interface
    clear = gr.Button("Clear Chat")

    # Handle submit button for questions
    question.submit(respond, inputs=[question, context, chatbot], outputs=[question, chatbot])
    
    # Handle clear button to reset chat
    clear.click(lambda: ("", []), outputs=[context, chatbot])  # Clear context and chat history

if __name__ == "__main__":
    demo.launch()

Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.
