In [1]:
# Model to serve: a local directory (here, the merged fine-tuned checkpoint).
# A Hub repo id such as "gmongaras/Wizard_7B_Squad_v2" can be used instead.
model_path = "outputs_squad/merged_model"

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoTokenizer,
)

# Load in model
def load_model(model_name):
    # Load in the token
    token = ""

    # Load model
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="bfloat16",
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
        load_in_8bit=True,
        quantization_config=bnb_config,
        use_auth_token=token,
        do_sample=True,
        cache_dir="models",
    ).eval()

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        token=token
    )
    class Infer:
        def __init__(self, model, tokenizer):
            self.model = model
            self.tokenizer = tokenizer

        def forward(self, text, limit=128, temp=1.0):
            text = self.tokenizer(text, return_tensors="pt").to("cuda")
            output = self.model.generate(**text, do_sample=True, temperature=temp, max_new_tokens=int(limit), top_p=0.95, top_k=60, pad_token_id=self.tokenizer.pad_token_id)
            return self.tokenizer.decode(output[0], skip_special_tokens=True)

    return Infer(model, tokenizer)
# Build the global generation wrapper that the Gradio callback `predict` uses.
model = load_model(model_name=model_path)

In [None]:
import gradio as gr

def predict(temp, limit, text):
    """Gradio callback: generate a continuation of ``text``.

    Parameter order matches the order of the Interface ``inputs`` components:
    temperature slider, token-limit slider, input textbox.
    """
    # The raw textbox contents serve as the prompt; generation is delegated to
    # the module-level ``model`` wrapper.
    return model.forward(text, limit, temp)

# Web UI: the three input components map, in order, onto predict(temp, limit, text).
pred = gr.Interface(
    predict,
    inputs=[
        gr.Slider(0.001, 10, value=0.1, label="Temperature"),
        # A token budget is a whole number; without an explicit step Gradio may
        # derive a fractional or coarse step from the slider range.
        gr.Slider(1, 1024, value=128, step=1, label="Token Limit"),
        gr.Textbox(
            label="Input",
            lines=1,
            value="#### Human: What's the capital of Australia?#### Assistant: ",
        ),
    ],
    outputs="text",
)

# NOTE: share=True additionally publishes the demo via a public Gradio link;
# anyone with the URL can drive generation on this machine.
pred.launch(share=True)