## Importing the libraries

In [3]:
import os
import io
import IPython.display
from PIL import Image
import base64 
import requests 
# Intended to raise the HTTP timeout to 60 seconds.
# NOTE(review): `requests.adapters` does not appear to define or read a
# `DEFAULT_TIMEOUT` attribute, so this assignment may have no effect - confirm.
requests.adapters.DEFAULT_TIMEOUT = 60
import gradio as gr

  from .autonotebook import tqdm as notebook_tqdm


## Setting the path to the institute's centralized database containing the models

In [2]:
# Specify the cache directory from which the already-downloaded models are loaded.
# This directory can hold all the Large Language Models; its size can grow into terabytes.
# Currently, we are at 2.7 TiB.
# Setting this variable sets the Hugging Face Hub cache path - reads and writes default to this path.
# LLM_CACHE_PATH must already be defined in the environment; otherwise this line raises KeyError.
# NOTE(review): libraries that read HUGGINGFACE_HUB_CACHE at import time will only see this
# value if they are imported after this cell runs - confirm the cell execution order.
os.environ['HUGGINGFACE_HUB_CACHE'] = os.environ['LLM_CACHE_PATH']

## Loading the LLM from the path specified above

In [None]:
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSeq2SeqLM

# Hugging Face Hub identifier of the checkpoint used throughout the notebook.
model_id = "tiiuae/falcon-40b-instruct"

# trust_remote_code=True lets transformers run the modelling code shipped with the checkpoint.
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Options for loading the weights: 8-bit quantised, with layers placed
# automatically across the available devices.
# NOTE(review): bfloat16 is requested together with load_in_8bit - confirm the
# dtype is what the 8-bit backend expects.
model_load_options = {
    "trust_remote_code": True,
    "torch_dtype": torch.bfloat16,
    "load_in_8bit": True,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_options)

In [4]:
from transformers import StoppingCriteria, StoppingCriteriaList
class CustomStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded tail of the sequence ends with a stop string.

    The check decodes the last few tokens of the first sequence in the batch
    and compares the stripped text against each stripped stop string with
    ``str.endswith``. Decoding a tail window (rather than only the final
    token) lets stop strings that span several tokens, such as ``"User:"``,
    match; stripping both sides lets entries such as ``"\\nUser:"`` match.

    Args:
        stops: Iterable of stop strings. ``None`` (the default) means no stops.
        tokenizer: Object with a ``decode`` method used to turn token ids into
            text. When omitted, the module-level ``tokenizer`` is looked up at
            call time.
    """

    def __init__(self, stops=None, tokenizer=None):
        super().__init__()
        # Build a fresh list so no default object is shared between instances.
        self.stops = list(stops) if stops is not None else []
        # Kept for backward compatibility; not read anywhere in this class.
        self.ENCOUNTERS = 2
        self._tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the generated text currently ends with any stop string.

        Only ``input_ids[0]`` is inspected; ``score`` and ``kwargs`` are ignored.
        """
        # Empty / whitespace-only stops are dropped: every string ends with "".
        stop_texts = [stop.strip() for stop in self.stops if stop.strip()]
        if not stop_texts:
            return False

        decoder = self._tokenizer if self._tokenizer is not None else tokenizer

        # A stop of k characters is assumed to span at most k tokens, so decoding
        # the last max(k) tokens is enough to see any stop that has just completed.
        tail_length = max(len(stop_text) for stop_text in stop_texts)
        tail_text = decoder.decode(input_ids[0][-tail_length:]).strip()

        return any(tail_text.endswith(stop_text) for stop_text in stop_texts)


stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria(["\nUser:", "User:", "Falcon:", "User"])])

## Integrating memory for the LLM

In [5]:
from langchain import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationSummaryBufferMemory

# Non-streaming text-generation pipeline; wrapped below as the LLM that the
# conversation memory uses to write its summaries.
summary_pipeline_options = {
    "model": model,
    "tokenizer": tokenizer,
    "device_map": "auto",
    "max_new_tokens": 512,
    "pad_token_id": tokenizer.eos_token_id,
    "stopping_criteria": stopping_criteria,
}
pipe = pipeline("text-generation", **summary_pipeline_options)
summary_llm = HuggingFacePipeline(pipeline=pipe)

# Token budget noted by the author:
# 110 (prompt of memory buffer) + 512 (summary words) + 1046 (max length of new generations) < 2046
# NOTE(review): the 1046 and 2046 figures are not derived from any value set in this cell - confirm.
memory = ConversationSummaryBufferMemory(llm=summary_llm, max_token_limit=200)

# Seed the memory with one exchange so the history is not empty on the first turn.
memory.save_context({"input": "Hello"}, {"output": "What's up"})

# Prompt template for the chat chain. It contains the `{history}` and `{input}`
# placeholders; the trailing backslashes join the first three source lines into
# a single line of the resulting string.
conversation_prompt_template='''Below is an instruction that describes a task, paired with current conversation to provide history of conversation and \
an input that provides further context. \
Write a response that appropriately completes the request.

### Instruction:
You are an AI named Falcon. Answer the questions asked to you in a talkative manner.

### Current conversation:
{history}

### Input:
{input}

### Response:'''


In [6]:
# Show the prompt template the memory uses when summarising the conversation.
# print() is used so the newlines in the template are rendered rather than escaped.
print(memory.prompt.template)

Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.

EXAMPLE
Current summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.

New lines of conversation:
Human: Why do you think artificial intelligence is a force for good?
AI: Because artificial intelligence will help humans reach their full potential.

New summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
END OF EXAMPLE

Current summary:
{summary}

New lines of conversation:
{new_lines}

New summary:


## Setting up the Gradio interface

In [None]:
from transformers import TextIteratorStreamer
from threading import Thread
from transformers import StoppingCriteria, StoppingCriteriaList

# "cuda" when a GPU is visible to torch, otherwise "cpu".
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Iterator-style streamer: yields decoded text as it is generated, leaving out
# the prompt and special tokens. The timeout (timeout=10.) is currently disabled,
# so iterating the streamer waits without limit for the next piece of text.
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)

# Streaming counterpart of the summary pipeline: same model, tokenizer and
# stopping criteria, a larger generation cap, and output routed to `streamer`.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    max_new_tokens=800,
    pad_token_id=tokenizer.eos_token_id,
    stopping_criteria=stopping_criteria,
    streamer=streamer,
)
llm = HuggingFacePipeline(pipeline=pipe)

# Token budget noted by the author:
# 729 (prompt of memory buffer) + 512 (summary words) + 800 (max length of new generations)
conversation = ConversationChain(llm=llm, memory=memory, verbose=False)

# Replace the chain's default prompt text with the custom instruction-style template.
conversation.prompt.template = conversation_prompt_template

def respond(message, chat_history, instruction, temperature=0.7,
            skip_words=("<|endoftext|>", "\nUser:", "User:", "Falcon:")):
    """Stream the model's reply to ``message`` into the chat history.

    ``conversation.predict`` runs on a worker thread so the caller is not
    blocked; the generated text is read from the module-level ``streamer``
    and appended to the newest chat turn piece by piece.

    Args:
        message: The user's prompt for this turn.
        chat_history: Previous turns as a list of ``[user, assistant]`` pairs.
            It is not modified; every yield carries a new list.
        instruction: Accepted for interface compatibility; currently unused.
        temperature: Accepted for interface compatibility; currently unused.
        skip_words: Accepted for interface compatibility; currently unused.
            The default is a tuple so no mutable object is shared between calls.

    Yields:
        ``("", history)`` tuples, where ``history`` is ``chat_history`` plus the
        turn ``[message, reply_so_far]``. Nothing is yielded if the streamer
        produces no text.
    """
    previous_turns = list(chat_history or [])

    # Start generation on a separate thread, so that we don't block the UI.
    # The text is pulled from the streamer in this (calling) thread.
    worker = Thread(target=conversation.predict, kwargs={"input": message})
    worker.start()

    reply_so_far = ""
    try:
        for idx, text_token in enumerate(streamer):
            # Drop the single leading space at the very start of the reply.
            if idx == 0 and text_token.startswith(" "):
                text_token = text_token[1:]

            reply_so_far += text_token
            # Yield a fresh snapshot each time so earlier yields are never mutated.
            yield "", previous_turns + [[message, reply_so_far]]
    finally:
        # Crucial step for summarizing in the memory chain: once the history goes
        # over the allowed token limit the model has to summarize, so wait for
        # predict() to finish - also when the generator is closed or fails early.
        worker.join()

# Close any Gradio apps that are still running before building a new one.
gr.close_all()

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=600)  # just to fit the notebook
    msg = gr.Textbox(label="Prompt")

    with gr.Accordion(label="Advanced options", open=False):
        system = gr.Textbox(
            label="System message",
            lines=2,
            value="A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.",
        )
        # NOTE(review): the two sliders below are displayed but are not passed
        # to `respond` in the event wiring further down.
        temperature = gr.Slider(label="temperature", minimum=0.1, maximum=1, value=0.7, step=0.1)
        max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1000, value=250, step=1, interactive=True)

    btn = gr.Button("Submit")
    clear = gr.ClearButton(components=[msg, chatbot], value="Clear console")

    # Clicking the button and pressing Enter in the textbox run the same handler
    # with the same inputs and outputs.
    for register_event in (btn.click, msg.submit):
        register_event(respond, inputs=[msg, chatbot, system], outputs=[msg, chatbot])

# share=True asks Gradio for a publicly reachable link to this app.
demo.queue().launch(share=True)#, server_port=int(os.environ['PORT4']))