In [None]:
# Install Streamlit and bitsandbytes
!pip install streamlit
!pip install bitsandbytes

In [None]:
# Authenticate with the Hugging Face Hub through the interactive notebook widget;
# the meta-llama checkpoints used below are only downloadable by an authorized account.
import huggingface_hub

huggingface_hub.notebook_login()


In [None]:
# Load the tokenizer and a 4-bit (NF4) quantized model from the Hugging Face Hub.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'

# Use ONE dtype for both the 4-bit matmul compute and the modules bitsandbytes leaves
# unquantized (set via torch_dtype). float16 runs natively on every CUDA GPU, whereas
# Turing cards such as Colab's T4 have no native bfloat16 support.
compute_dtype = torch.float16

tokenizer = AutoTokenizer.from_pretrained(model_name)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=False,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',  # let accelerate place the weights on the available GPU(s)
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
)


In [None]:
# Generate a response from the model for a single chat-formatted prompt.
# Llama 3 *Instruct* is trained on a chat template, so the request is wrapped in
# system/user messages instead of being passed as raw text.
messages = [
    {'role': 'system', 'content': 'You are a physician.'},
    {'role': 'user', 'content': 'Briefly describe pneumonia.'},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant header so the model starts its answer
    return_tensors='pt',
).to(model.device)
# A single unpadded sequence: every position is a real token.
attention_mask = torch.ones_like(input_ids)

# Llama 3 closes an assistant turn with <|eot_id|>; stop on it as well as on the base EOS.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids('<|eot_id|>'),
]

torch.manual_seed(0)  # sampling is stochastic; fix the seed so re-runs are reproducible

with torch.no_grad():
    result = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,  # avoids the "pad_token_id not set" warning
    )

# Drop the prompt tokens and decode only the newly generated continuation.
output = tokenizer.decode(result[0][input_ids.shape[-1]:], skip_special_tokens=True)
print('\n output \n', output)

# Free this cell's GPU tensors.
del input_ids, attention_mask, result
torch.cuda.empty_cache()


In [None]:
# Save the Streamlit application to app.py
%%writefile /content/app.py

# Import necessary libraries
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import streamlit as st

# Load model and tokenizer with quantization configuration
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer(model_name):
    """Load the tokenizer and a 4-bit NF4-quantized causal LM for ``model_name``.

    Wrapped in ``st.cache_resource`` so the model is loaded once per Streamlit
    server process and reused across reruns instead of being reloaded on every
    user interaction.

    Args:
        model_name: Hugging Face Hub model id.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    # One dtype for both the 4-bit matmul compute and the unquantized modules
    # (torch_dtype). float16 runs natively on every CUDA GPU, whereas Turing
    # cards such as the T4 have no native bfloat16 support.
    compute_dtype = torch.float16
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_use_double_quant=False,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map='auto',
        quantization_config=bnb_config,
        torch_dtype=compute_dtype,
    )
    return tokenizer, model

# Main function for the Streamlit chat application
def _build_messages(history_inputs, history_outputs, prompt):
    """Build chat-template messages from completed (user, assistant) turns plus a new prompt."""
    messages = []
    for user_text, assistant_text in zip(history_inputs, history_outputs):
        messages.append({'role': 'user', 'content': user_text})
        messages.append({'role': 'assistant', 'content': assistant_text})
    messages.append({'role': 'user', 'content': prompt})
    return messages


def _generate_reply(tokenizer, model, messages, max_new_tokens=400):
    """Greedily generate the assistant's reply to ``messages`` and return it as plain text."""
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # append the assistant header so the model answers
        return_tensors='pt',
    ).to(model.device)
    # Llama 3 closes an assistant turn with <|eot_id|>; stop on it as well as on the base EOS.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids('<|eot_id|>'),
    ]
    with torch.no_grad():
        response = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),  # single unpadded sequence
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=terminators,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Keep only the tokens generated after the prompt.
    return tokenizer.decode(response[0, input_ids.shape[-1]:], skip_special_tokens=True)


def st_chatllm():
    """Render the ChatLLM page: show the conversation so far and answer a new prompt.

    Completed turns are kept in two parallel session-state lists, ``input``
    (user prompts) and ``output`` (assistant replies), and the whole history is
    sent to the model so it can refer to earlier turns.
    """
    st.title('ChatLLM')

    model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'
    tokenizer, model = load_model_and_tokenizer(model_name)

    # Initialize session state for input and output
    if 'input' not in st.session_state:
        st.session_state.input = []
    if 'output' not in st.session_state:
        st.session_state.output = []

    response_container = st.container(height=500)
    input_container = st.container()

    # Display previous chat messages
    with response_container:
        for user_text, assistant_text in zip(st.session_state.input, st.session_state.output):
            st.chat_message('user').write(user_text)
            st.chat_message('assistant').write(assistant_text)

    # Input prompt and generate response
    with input_container:
        prompt = st.chat_input('Enter your message')
        if prompt:
            # NOTE(review): the full history is sent every turn; very long chats will
            # eventually exceed the model's context window and may need trimming.
            messages = _build_messages(st.session_state.input, st.session_state.output, prompt)
            with st.spinner('Generating response...'):
                output = _generate_reply(tokenizer, model, messages)
            # Record the turn only after generation succeeds, so a failure cannot leave
            # an unanswered prompt that would misalign the two history lists.
            st.session_state.input.append(prompt)
            st.session_state.output.append(output)
            st.rerun()

# Run the Streamlit application
if __name__ == "__main__":
    # `streamlit run` executes this script as __main__ on every rerun of the page.
    st_chatllm()


In [None]:
# Get the public IP address of the current machine.
# The localtunnel page opened in the last cell asks for this address as its password.
!curl ifconfig.me


In [None]:
# Run the Streamlit application and create a tunnel for external access
!streamlit run app.py & sleep 5 && npx localtunnel --port 8501
# Tunnel password should be the IP address
