In [None]:
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

Select the device for inference: GPU 0 if CUDA is available, otherwise the CPU (`-1`).

In [None]:
# Device index handed to the transformers pipeline: 0 selects the first CUDA
# GPU, -1 selects the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

In [None]:
# Maps the label shown in the UI dropdown to the Hugging Face Hub repo id.
# Every repo lives under the "meta-llama" organisation, so only the repo name
# is listed here. Insertion order is the order shown in the dropdown.
llama_models = {
    label: f"meta-llama/{repo_name}"
    for label, repo_name in (
        ("Meta-Llama 3 70B Instruct", "Meta-Llama-3-70B-Instruct"),
        ("Meta-Llama 3 8B Instruct", "Meta-Llama-3-8B-Instruct"),
        ("Llama 3.1 70B Instruct", "Llama-3.1-70B-Instruct"),
        ("Llama 3.1 8B Instruct", "Llama-3.1-8B-Instruct"),
    )
}

In [None]:
def load_model(model_name):
    """Load a Llama model and wrap it in a text-generation pipeline.

    Args:
        model_name: Hugging Face Hub repo id (or local path) passed to
            ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        A ``transformers`` ``text-generation`` pipeline placed on the
        module-level ``device`` (0 = first CUDA GPU, -1 = CPU).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Without an explicit dtype the weights are materialised in float32, which
    # doubles the memory needed on the GPU. Use half precision there; stay in
    # float32 on CPU, where half-precision kernels are slow or unsupported.
    dtype = torch.float16 if device >= 0 else torch.float32
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
    return pipeline('text-generation', model=model, tokenizer=tokenizer, device=device)

Cache loaded pipelines, keyed by the model's display name, so each model is loaded at most once per kernel session.

In [None]:
# Dropdown label -> loaded text-generation pipeline. Filled lazily the first
# time a model is requested; re-running this cell empties the cache.
model_cache = dict()

In [None]:
def _history_to_pairs(history):
    """Convert ``(speaker, text)`` turns into ``[user_msg, assistant_msg]`` pairs."""
    pairs = []
    for speaker, text in history:
        if speaker == "User":
            # A user turn opens a new row; the reply is filled in afterwards.
            pairs.append([text, None])
        elif pairs and pairs[-1][1] is None:
            pairs[-1][1] = text
        else:
            # Assistant turn with no pending user message.
            pairs.append([None, text])
    return pairs


def chatbot_interface(user_input, history, model_choice):
    """Generate a chatbot response using the selected Llama model.

    Args:
        user_input: The message the user just submitted.
        history: Conversation so far as a list of ``(speaker, text)`` turns,
            where speaker is ``"User"`` or ``"Assistant"``. ``None`` is treated
            as an empty conversation. The list passed in is not mutated.
        model_choice: A key of ``llama_models`` naming the model to use. The
            pipeline is loaded on first use and reused from ``model_cache``.

    Returns:
        A ``(chat_pairs, history)`` tuple. ``chat_pairs`` is a list of
        ``[user_message, assistant_message]`` rows for display in a
        ``gr.Chatbot``; ``history`` is the updated speaker-tagged turn list to
        store as session state and pass back in on the next call.

    Raises:
        KeyError: If ``model_choice`` is not a key of ``llama_models``.
    """
    if model_choice not in model_cache:
        model_cache[model_choice] = load_model(llama_models[model_choice])
    generator = model_cache[model_choice]

    # Work on a copy so the caller's list is left untouched.
    history = list(history) if history else []
    history.append(("User", user_input))
    prompt = "\n".join(f"{speaker}: {text}" for speaker, text in history) + "\nAssistant:"

    # max_new_tokens bounds only the reply. max_length also counts the prompt,
    # so a long conversation would leave no room to generate anything.
    # return_full_text=False returns just the continuation, so the reply does
    # not have to be recovered by splitting on "Assistant:" (which breaks as
    # soon as any message contains that string).
    outputs = generator(
        prompt,
        max_new_tokens=512,
        return_full_text=False,
        pad_token_id=generator.tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    completion = outputs[0]['generated_text']

    # The plain-text prompt has no stop sequence, so the model may go on to
    # write the next "User:" turn itself; keep only the assistant's own turn.
    assistant_reply = completion.split("\nUser:")[0].strip()
    history.append(("Assistant", assistant_reply))

    return _history_to_pairs(history), history

Gradio interface

In [None]:
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Chat with Llama Models</center></h1>")
    model_choice = gr.Dropdown(list(llama_models.keys()), label="Select Llama Model")
    chatbot = gr.Chatbot()
    # Conversation history carried between calls to `respond`.
    state = gr.State([])
    txt_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    submit_btn = gr.Button("Submit")

    def respond(user_input, history, model_choice):
        """Handle one submitted message.

        Returns the chatbot display value, the updated history state, and an
        empty string that clears the input textbox.
        """
        # The dropdown starts with no selection; without this guard the
        # backend fails with a bare KeyError instead of a readable message.
        if not model_choice:
            raise gr.Error("Please select a Llama model before sending a message.")
        chat_display, new_history = chatbot_interface(user_input, history, model_choice)
        return chat_display, new_history, ""

    # Enter in the textbox and the Submit button trigger the same handler.
    event_inputs = [txt_input, state, model_choice]
    event_outputs = [chatbot, state, txt_input]
    txt_input.submit(respond, event_inputs, event_outputs)
    submit_btn.click(respond, event_inputs, event_outputs)

In [None]:
# Launch the Gradio app built in the `gr.Blocks()` cell.
demo.launch()