In [None]:
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Display name -> Hugging Face Hub repo id for each CodeLlama Instruct size
# (sizes are in billions of parameters; insertion order is the menu order).
code_llama_models = {
    f"CodeLlama {size}B Instruct": f"meta-llama/CodeLlama-{size}b-Instruct-hf"
    for size in (7, 13, 34, 70)
}

In [None]:
def load_model(model_name):
    """Load a CodeLlama checkpoint together with its tokenizer.

    Args:
        model_name: Model identifier passed straight to ``from_pretrained``
            (e.g. one of the values of ``code_llama_models``).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # device_map='auto' already decides where the weights live (possibly
    # spread over several devices or offloaded). Calling model.to(...) on top
    # of that is warned against by accelerate and raises when modules are
    # offloaded, so the placement is left as-is.
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')
    return model, tokenizer

Cache loaded models so each one is only loaded once per session

In [None]:
# Maps a model display name to its loaded (model, tokenizer) pair.
model_cache = dict()

In [None]:
def generate_code(prompt, model_choice):
    """Generate code from a natural-language description.

    Args:
        prompt: Description of the program the user wants.
        model_choice: Key of ``code_llama_models`` selecting which model to use.

    Returns:
        The text produced by the model (prompt excluded), stripped of
        leading/trailing whitespace.

    Raises:
        gr.Error: If ``model_choice`` is not a key of ``code_llama_models``
            (for example when nothing was selected in the UI).
    """
    if model_choice not in code_llama_models:
        # Surface a readable message in the UI instead of a bare KeyError.
        raise gr.Error("Please select a CodeLlama model before generating code.")

    # Load each model at most once; later calls reuse the cached pair.
    if model_choice not in model_cache:
        model_cache[model_choice] = load_model(code_llama_models[model_choice])
    model, tokenizer = model_cache[model_choice]

    full_prompt = f"Write a program based on the following description:\n\n\"{prompt}\"\n\nCode:\n"
    # Send the inputs to wherever the model actually lives rather than to the
    # global `device`, which may differ from the model's placement.
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    outputs = model.generate(
        **inputs,
        # Budget applies to generated tokens only; max_length would also count
        # the prompt and leave little or no room for long descriptions.
        max_new_tokens=512,
        do_sample=True,
        temperature=0.5,
        top_p=0.9,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens. Slice them off by
    # position instead of splitting on "Code:", which would truncate any
    # completion that itself contains that marker.
    new_tokens = outputs[0][prompt_length:]
    code = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return code

Gradio interface

In [None]:
# Assemble the UI: a model selector, a description box, a code viewer and a
# button that runs generate_code on the current inputs.
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Code Generation with CodeLlama Models</center></h1>")
    model_choice = gr.Dropdown(
        list(code_llama_models.keys()),
        # Preselect the first model so the callback never receives None when
        # the user clicks "Generate Code" without touching the dropdown.
        value=next(iter(code_llama_models)),
        label="Select CodeLlama Model",
    )
    prompt_input = gr.Textbox(label="Describe the code you want", lines=2)
    code_output = gr.Code(language="python", label="Generated Code")
    generate_button = gr.Button("Generate Code")
    generate_button.click(generate_code, [prompt_input, model_choice], code_output)

In [None]:
# Start serving the Blocks app assembled in `demo`.
demo.launch()