In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="2"

import torch
import re
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [2]:
# create Maroon Chat class
class MaroonChat:
    def __init__(self):
        self.base_model_id = "mistralai/Mistral-7B-v0.1"
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        self.base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_id,  # Mistral, same as before
            quantization_config=self.bnb_config,  # Same quantization config as before
            device_map="auto",
            trust_remote_code=True,
            use_auth_token=False
        )
        self.eval_tokenizer = AutoTokenizer.from_pretrained(self.base_model_id, 
                                                            add_bos_token=True, 
                                                            trust_remote_code=True)
        self.model = PeftModel.from_pretrained(self.base_model, 
                                               "mistral-nlp-best/checkpoint-500-r32-alpha64")

    def generate(self, prompt):
        model_input = self.eval_tokenizer(prompt, return_tensors="pt").to("cuda")
        self.model.eval()

        with torch.no_grad():
            result = self.eval_tokenizer.decode(self.model.generate(**model_input, max_new_tokens=100, repetition_penalty=1.15)[0], skip_special_tokens=True)
            
            # filter out the prompt text
            result = re.split(r'(?<!\s\w)(?<!Hon)(?<!Dr)(?<!Mr)(?<!Mrs)(?<!Ms)(?<!No)(?<!R.A)(?<!Prof)(?<!Atty)\.', result)[0]
            
            # remove prompt text from the result
            result = result.replace(prompt, "")

            # remove ## as well
            result = result.replace("#", "")

            # remove preceding spaces
            result = result.strip()

            # remove texts after the first period
            result = re.split(r'\n', result)[0] + "."

            return result

In [3]:
# check number of cuda devices
print(torch.cuda.device_count())

1


In [4]:
# initialize Maroon Chat
mc = MaroonChat()



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# generate response
print(mc.generate("What is Rowel Atienza background?"))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Rowel Atienza is a Filipino professional basketball player who currently plays for the San Miguel Beermen of the Philippine Basketball Association (PBA).


In [6]:
import gradio as gr  

def maroon_chat(prompt):
    return mc.generate(prompt)

iface = gr.Interface(
    fn=maroon_chat, 
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=gr.Textbox(lines=2),
    title="Maroon Chat",
    description="A chatbot trained on University of the Philippiens Data."
)

iface.launch(share=True)


Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://8450083abda63da927.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for o