In [1]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig, pipeline
import time
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training
)

In [2]:
# Locations of the candidate checkpoints. Only `four_bit_llama` is loaded and
# shown in the UI title further down; the other two names are not referenced
# anywhere else in the visible cells.
# Local directory with the 3-bit quantized Llama-2-70B checkpoint.
three_bit_llama = "./models/llama-2-70b-hf-quantized-3bits/"
# Local directory with the 4-bit quantized Llama-2-70B checkpoint.
four_bit_llama = "./models/llama-2-70b-hf-quantized-4bits/"
# Model identifier in "owner/name" form (presumably a Hugging Face Hub repo id -- TODO confirm).
leader_board_model = "rwitz2/go-bruins-v2.1.1"

In [3]:
# Load the tokenizer and the quantized causal LM from the same checkpoint
# directory. Both calls share the "./models" cache directory; the model is
# placed on the available devices via device_map="auto".
tokenizer = AutoTokenizer.from_pretrained(
    four_bit_llama,
    cache_dir="./models",
)
model = AutoModelForCausalLM.from_pretrained(
    four_bit_llama,
    cache_dir="./models",
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

***CHAT***

In [4]:
from IPython.display import display, HTML, clear_output, Markdown
import textwrap
import ipywidgets as widgets
import re
import io
import pandas as pd
import json
import nbformat
import PyPDF2

In [5]:
# Configuration
# Device that generate_response() moves the tokenized prompt to.
# Original author's note: run on GPU (you can't run GPTQ on cpu).
runtimeFlag = "cuda:0"
# Multiplier applied to model.config.max_position_embeddings when max_context is
# computed; 1.0 leaves the model's own context length unchanged. The same value
# is assigned again in the cell that computes max_context.
# NOTE(review): the previous comment here advertised a 16384*6 = 98304 token
# context and a Colab Pro / V100 / A100 requirement; that does not match a
# factor of 1.0 and looks like a leftover from another notebook -- confirm.
scaling_factor = 1.0

In [6]:
# Set the SYSTEM PROMPT
# SYSTEM_PROMPT becomes the first entry of the chat history and is inserted by
# generate_response() when a dialog lacks a system message. Note that the prompt
# template currently used in generate_response() omits the system prompt, so
# changing it here only affects what is displayed in the transcript.
# Alternative prompt kept for reference:
# DEFAULT_SYSTEM_PROMPT = 'You are a helpful pair-coding assistant.'
DEFAULT_SYSTEM_PROMPT = 'You are a helpful assistant.'
SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT

# Echo the active prompt so the cell output records which one is in use.
print(SYSTEM_PROMPT)

You are a helpful assistant.


In [7]:
# Turn delimiters wrapped around every user message when the prompt is built.
# (A "[INST]" / "[/INST]" pair was used here previously.)
B_INST = "Question: "
E_INST = "Answer: "

# System-prompt delimiters; only referenced by the commented-out prompt
# template inside generate_response().
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

# Context budget, derived from the loaded model's position-embedding limit.
scaling_factor = 1.0
max_context = int(scaling_factor * model.config.max_position_embeddings)

# Documents may take up to 75% of the context length.
max_doc_length = int(max_context * 0.75)
max_doc_words = max_doc_length

In [8]:
def _encode_turn(message):
    """Tokenize one chat message and move the ids to ``runtimeFlag``.

    Assistant messages are encoded as their stripped text followed by the
    tokenizer's EOS token, without special tokens. Every other message is
    treated as a user turn and wrapped in ``B_INST`` / ``E_INST`` with special
    tokens enabled.
    """
    content = message["content"].strip()
    if message["role"] == "assistant":
        encoded = tokenizer(
            content + tokenizer.eos_token,
            return_tensors="pt",
            add_special_tokens=False,
        )
    else:
        encoded = tokenizer(
            f"{B_INST} {content} {E_INST}",
            return_tensors="pt",
            add_special_tokens=True,
        )
    return encoded.input_ids.to(runtimeFlag)


def generate_response(dialogs, temperature=0.01, top_p=0.9, logprobs=False):
    """Generate the assistant's next reply for the first dialog in ``dialogs``.

    Args:
        dialogs: Non-empty list of dialogs; each dialog is a list of
            ``{"role": ..., "content": ...}`` messages. Only ``dialogs[0]`` is
            used, because a single response string is returned.
        temperature: Sampling temperature. Values ``<= 0`` switch to greedy
            decoding instead of letting ``model.generate`` reject them.
        top_p: Nucleus-sampling threshold, used only when sampling.
        logprobs: Accepted for interface compatibility; currently unused.

    Returns:
        The decoded, stripped completion, or a fixed notice string when the
        prompt is longer than 85% of ``max_context``.

    Raises:
        ValueError: If ``dialogs`` is empty or the dialog has no non-system
            message to build a prompt from.
    """
    if not dialogs:
        raise ValueError("generate_response() needs at least one dialog.")

    torch.cuda.empty_cache()
    max_prompt_len = int(0.85 * max_context)
    max_gen_len = int(0.10 * max_prompt_len)

    # The prompt template omits the system prompt altogether, so system
    # messages are skipped. Encoding by role rather than by list position
    # yields the same tokens for a strictly alternating history, but no longer
    # raises IndexError when the history is not perfectly alternating.
    turns = [message for message in dialogs[0] if message["role"] != "system"]
    if not turns:
        raise ValueError("The dialog contains no user message to respond to.")

    input_ids = torch.cat([_encode_turn(message) for message in turns], dim=-1)
    prompt_len = input_ids.shape[-1]
    if prompt_len > max_prompt_len:
        return "\n\n **The language model's input limit has been reached. Clear the chat and start afresh!**"

    generation_kwargs = {
        "input_ids": input_ids,
        # Single un-padded sequence: every position is a real token.
        "attention_mask": torch.ones_like(input_ids),
        "max_new_tokens": max_gen_len,
    }
    temperature = float(temperature)
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Sampling requires a strictly positive temperature; fall back to greedy.
        generation_kwargs.update(do_sample=False)

    generation_output = model.generate(**generation_kwargs)

    # Keep only the tokens produced after the prompt.
    new_tokens = generation_output[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

In [9]:
def print_wrapped(text):
    """Render ``text`` as Markdown, emitting fenced code blocks separately.

    Text without any triple-backtick block is displayed as-is in one piece.
    Otherwise the surrounding prose and each fenced block are displayed one
    after another, in order; prose segments are stripped and skipped when
    empty.
    """
    fenced_blocks = list(re.finditer(r'```(.+?)```', text, re.DOTALL))

    if len(fenced_blocks) == 0:
        display(Markdown(text))
        return

    # Gather the pieces first, then render them in document order.
    segments = []
    cursor = 0
    for fenced in fenced_blocks:
        prose = text[cursor:fenced.start()].strip()
        if prose:
            segments.append(prose)
        # group(0) keeps the backtick fences so Markdown renders it as code.
        segments.append(fenced.group(0).strip())
        cursor = fenced.end()

    trailing_prose = text[cursor:].strip()
    if trailing_prose:
        segments.append(trailing_prose)

    for segment in segments:
        display(Markdown(segment))

# Conversation state shared by the widget callbacks; it starts with the system prompt.
dialog_history = [dict(role="system", content=SYSTEM_PROMPT)]

# Input controls.
button = widgets.Button(
    description="Send",
)
upload_button = widgets.Button(
    description="Upload .txt or .pdf",
)
text = widgets.Textarea(
    layout=widgets.Layout(width='800px'),
)

# Area the chat transcript is rendered into.
output_log = widgets.Output()

def on_button_clicked(b):
    """Send the text box content to the model and redraw the conversation.

    The user message is appended to ``dialog_history``, the Send button is
    switched to a disabled "Processing..." state while the model runs, and the
    assistant reply is appended and displayed afterwards. Blank input is
    ignored. If generation fails, the unanswered user turn is removed again,
    the text is restored for a retry, and the error is shown in the chat
    output. The button is always returned to its idle state.

    Args:
        b: The clicked button (unused; required by the ``on_click`` signature).
    """
    user_input = text.value
    if not user_input.strip():
        # Nothing to send.
        return

    def render_history():
        # Redraw the whole transcript inside the chat output widget.
        with output_log:
            clear_output()
            for message in dialog_history:
                print_wrapped(f'**{message["role"].capitalize()}**: {message["content"]}\n')

    dialog_history.append({"role": "user", "content": user_input})
    text.value = ''

    # Change button description and color, and disable it while processing
    button.description = 'Processing...'
    button.style.button_color = '#ff6e00'
    button.disabled = True

    try:
        render_history()
        assistant_response = generate_response([dialog_history])
    except Exception as e:
        # Keep the history consistent (no user turn without a reply) and let
        # the user retry with the same text.
        dialog_history.pop()
        text.value = user_input
        with output_log:
            clear_output()
            print(f"Error generating a response: {e}")
        return
    finally:
        # Re-enable the button whether or not generation succeeded.
        button.description = 'Send'
        button.style.button_color = 'lightgray'
        button.disabled = False

    dialog_history.append({"role": "assistant", "content": assistant_response})
    render_history()

# Run on_button_clicked whenever the Send button is pressed.
button.on_click(on_button_clicked)

# Create an output widget for alerts
alert_out = widgets.Output()

clear_button = widgets.Button(description="Clear Chat")
# The `text` Textarea created earlier in this cell is reused as the input box;
# creating a second, identical widget here only shadowed the first one.

def on_clear_button_clicked(b):
    """Reset the conversation to just the system prompt and wipe the chat log.

    Args:
        b: The clicked button (unused; required by the ``on_click`` signature).
    """
    # Slice assignment empties and refills the shared list in place.
    dialog_history[:] = [{"role": "system", "content": SYSTEM_PROMPT}]

    with output_log:
        clear_output()

# Run on_clear_button_clicked whenever the Clear Chat button is pressed.
clear_button.on_click(on_clear_button_clicked)

In [10]:
# Text field where the user types the path of a .txt or .pdf file.
file_path_input = widgets.Text(
    description="File Path:",
    placeholder="Enter the path of your file here",
    disabled=False,
)

# Button that triggers file processing; it starts out disabled.
process_file_button = widgets.Button(
    disabled=True,
    description="Process File",
)

def is_valid_file_path(file_path):
    """Return True when ``file_path`` is an existing ``.txt`` or ``.pdf`` file.

    The extension is taken with ``os.path.splitext`` so that only a real
    trailing ``.txt`` / ``.pdf`` suffix counts; an extension-less file that is
    merely *named* ``txt`` or ``pdf`` is rejected. The comparison is
    case-sensitive.

    Args:
        file_path: Path string to check.

    Returns:
        bool: Whether the path points to a regular file with a supported
        extension.
    """
    extension = os.path.splitext(file_path)[1]
    return os.path.isfile(file_path) and extension in ('.txt', '.pdf')

def on_file_path_change(change):
    """Enable the Process File button only while the typed path is valid.

    Args:
        change: Change notification from the path text field; its ``new``
            attribute holds the current value.
    """
    path_is_usable = is_valid_file_path(change.new)
    process_file_button.disabled = not path_is_usable

# Re-check the path every time the text field's value changes.
file_path_input.observe(on_file_path_change, names='value')

def _render_dialog_history():
    """Redraw the whole conversation inside the chat output widget."""
    with output_log:
        clear_output()
        for message in dialog_history:
            print_wrapped(f'**{message["role"].capitalize()}**: {message["content"]}\n')


def _read_file_text(file_path):
    """Return the text of a ``.txt`` or ``.pdf`` file.

    Raises:
        ValueError: If the extension is neither ``txt`` nor ``pdf``.
    """
    file_type = file_path.split('.')[-1]
    if file_type == 'txt':
        # Decode explicitly so the result does not depend on the platform's
        # default encoding; undecodable bytes are replaced instead of aborting.
        with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
            return file.read()
    if file_type == 'pdf':
        with open(file_path, 'rb') as file:
            if hasattr(PyPDF2, 'PdfReader'):
                # Current PyPDF2 API; the camelCase names were removed in 3.x.
                pdf_reader = PyPDF2.PdfReader(file)
                return ''.join((page.extract_text() or '') for page in pdf_reader.pages)
            # Legacy PyPDF2 releases that predate PdfReader.
            pdf_reader = PyPDF2.PdfFileReader(file)
            return ''.join(
                pdf_reader.getPage(page).extractText()
                for page in range(pdf_reader.numPages)
            )
    raise ValueError(f"Unsupported file type: {file_type}")


def _shorten_document(content):
    """Cut ``content`` down to at most ``max_doc_words`` tokens.

    Despite its name, ``max_doc_words`` is derived from the context length and
    is the token budget the UI title announces for uploaded files.
    """
    token_ids = tokenizer(content, add_special_tokens=False).input_ids
    if len(token_ids) <= max_doc_words:
        return content
    return tokenizer.decode(token_ids[:max_doc_words], skip_special_tokens=True)


# Function to handle the file processing button click
def on_process_file_button_clicked(b):
    """Read the file named in the path field, send it to the model, show the reply.

    The file content (shortened to the document token budget) is added to
    ``dialog_history`` as a user message, followed by the model's answer. If
    anything fails, the messages added by this call are removed again so the
    history never keeps a user turn without a reply, and the error is shown in
    the chat output. After a valid path was handled, the path field is cleared
    and the button disabled.

    Args:
        b: The clicked button (unused; required by the ``on_click`` signature).
    """
    file_path = file_path_input.value
    if not is_valid_file_path(file_path):
        with output_log:
            clear_output()
            print("Invalid file path or unsupported file type.")
        return

    history_len_before = len(dialog_history)
    try:
        file_content = _shorten_document(_read_file_text(file_path))

        # Add file content to dialog history as user input
        dialog_history.append({"role": "user", "content": file_content})
        _render_dialog_history()

        # Generate response from the model and record it
        assistant_response = generate_response([dialog_history])
        dialog_history.append({"role": "assistant", "content": assistant_response})

        # Display the updated dialog history
        _render_dialog_history()
    except Exception as e:
        # Undo whatever this call added before reporting the problem.
        del dialog_history[history_len_before:]
        with output_log:
            clear_output()
            print(f"Error processing the file: {e}")

    # Clear the file path input for next use
    file_path_input.value = ''
    process_file_button.disabled = True

process_file_button.on_click(on_process_file_button_clicked)

In [11]:
from IPython.display import display, HTML
from ipywidgets import HBox, VBox

# HTML banner: model path plus the context / document limits computed earlier.
title = f"<h1 style='color: #ff6e00;'>{four_bit_llama} 🤖</h1> <p>(Max context of: {max_context}. Uploaded files will be shortened to {max_doc_words} tokens)</p>"

# Send and Clear Chat sit side by side.
first_row = HBox(children=[button, clear_button])
# Path field next to its Process File button.
file_upload_row = HBox(children=[file_path_input, process_file_button])
# Stack everything vertically: transcript, alerts, input box, then the two button rows.
layout = VBox(
    children=[
        output_log,
        alert_out,
        text,
        first_row,
        file_upload_row,
    ]
)

In [12]:
# Show the HTML title first, then the assembled widget layout.
display(HTML(title), layout)

VBox(children=(Output(), Output(), Textarea(value='', layout=Layout(width='800px')), HBox(children=(Button(des…