In [3]:
import logging
import together, os, yaml
from langchain.llms.base import LLM
from pydantic import Extra, Field, root_validator
from typing import Any, Dict, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env

from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
from InstructorEmbedding import INSTRUCTOR
from langchain.document_loaders import TextLoader
from langchain.document_loaders import PyPDFLoader
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

In [4]:
# Load API credentials from a local YAML file and export them as environment
# variables so that later cells can read them from os.environ.
with open('cadentials.yaml') as f:
    # NOTE(review): FullLoader is kept as-is; for a flat mapping of strings
    # yaml.safe_load would be sufficient and stricter — consider switching.
    credentials = yaml.load(f, Loader=yaml.FullLoader)

# An empty file parses to None and a malformed one may parse to a list/str;
# fail early with a readable message instead of an opaque TypeError below.
if not isinstance(credentials, dict):
    raise TypeError(
        "cadentials.yaml must contain a mapping of credential names to values, "
        f"got {type(credentials).__name__}"
    )

# Check all required keys before touching os.environ so a missing key does not
# leave the environment half-populated.
_missing = [
    name
    for name in ('HUGGINGFACEHUB_API_TOKEN', 'TOGETHER_AI_API')
    if not credentials.get(name)
]
if _missing:
    raise KeyError(
        "cadentials.yaml is missing required key(s): " + ", ".join(_missing)
    )

os.environ['HUGGINGFACEHUB_API_TOKEN'] = credentials['HUGGINGFACEHUB_API_TOKEN']
os.environ['TOGETHER_AI_API'] = credentials['TOGETHER_AI_API']

In [5]:
# Configure the Together client with the API key read from the environment.
together.api_key = os.environ["TOGETHER_AI_API"]

# List the available models; only each entry's 'name' field is printed.
models = together.Models.list()
for m in models:
    print(m['name'])

Austism/chronos-hermes-13b
EleutherAI/llemma_7b
EleutherAI/pythia-12b-v0
EleutherAI/pythia-1b-v0
EleutherAI/pythia-2.8b-v0
EleutherAI/pythia-6.9b
Gryphe/MythoMax-L2-13b
HuggingFaceH4/starchat-alpha
NousResearch/Nous-Hermes-13b
NousResearch/Nous-Hermes-Llama2-13b
NousResearch/Nous-Hermes-Llama2-70b
NousResearch/Nous-Hermes-llama-2-7b
NumbersStation/nsql-llama-2-7B
Open-Orca/Mistral-7B-OpenOrca
OpenAssistant/llama2-70b-oasst-sft-v10
OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5
OpenAssistant/stablelm-7b-sft-v7-epoch-3
Phind/Phind-CodeLlama-34B-Python-v1
Phind/Phind-CodeLlama-34B-v2
SG161222/Realistic_Vision_V3.0_VAE
WizardLM/WizardCoder-15B-V1.0
WizardLM/WizardCoder-Python-34B-V1.0
WizardLM/WizardLM-70B-V1.0
bigcode/starcoder
databricks/dolly-v2-3b
databricks/dolly-v2-7b
defog/sqlcoder
garage-bAInd/Platypus2-70B-instruct
huggyllama/llama-13b
huggyllama/llama-30b
huggyllama/llama-65b
huggyllama/llama-7b
lmsys/fastchat-t5-3b-v1.0
lmsys/vicuna-13b-v1.5-16k
lmsys/vicuna-13b-v1.5
lmsys/vicun

In [6]:
# Ask Together to start the given model; the call's return value is shown as the cell output.
# NOTE(review): this starts "llama-2-13b-chat", while the TogetherLLM default and the
# `llm = TogetherLLM(...)` cell both use "togethercomputer/llama-2-70b-chat" — confirm
# which model is actually intended to be started.
together.Models.start("togethercomputer/llama-2-13b-chat")

{'success': True,
 'value': '9f80dbe75ee2d9408b637393a9a3081395fa2dcd007b7ae130c61ebf88aee09e-1802c08ee7b2bdd56c3c7c9853a8ea9272433524cc037566a9a16aff837e285e'}

In [7]:
class TogetherLLM(LLM):
    """LangChain LLM wrapper around the Together text-completion endpoint.

    The API key is taken from the ``together_ai_api`` field when given,
    otherwise from the ``TOGETHER_AI_API`` environment variable at
    instantiation time (see ``validate_environment``).
    """

    model: str = "togethercomputer/llama-2-70b-chat"  # model endpoint to use
    # Together API key. Deliberately has no environment lookup as its default:
    # reading os.environ here would run at class-definition time, raise
    # KeyError when the variable is unset, and bypass the validator below.
    together_ai_api: Optional[str] = None
    temperature: float = 0.7  # What sampling temperature to use.
    max_tokens: int = 512  # The maximum number of tokens to generate in the completion.

    class Config:
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key from the field value or the environment.

        Uses ``get_from_dict_or_env`` with the ``together_ai_api`` field and
        the ``TOGETHER_AI_API`` environment variable, and stores the result
        back on the model.
        """
        api_key = get_from_dict_or_env(
            values, "together_ai_api", "TOGETHER_AI_API"
        )
        values["together_ai_api"] = api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the parameters that identify this LLM configuration."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the Together completion endpoint and return the text.

        Args:
            prompt: The prompt to complete.
            stop: Optional stop sequences; when given, the returned text is
                cut at the first occurrence of any of them.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            The ``text`` of the first choice in the response.
        """
        # The together client reads its key from a module-level attribute.
        together.api_key = self.together_ai_api
        output = together.Complete.create(prompt,
                                          model=self.model,
                                          max_tokens=self.max_tokens,
                                          temperature=self.temperature,
                                          )
        text = output['output']['choices'][0]['text']
        # Truthiness check (not `is not None`): an empty stop list must be a
        # no-op rather than being handed to enforce_stop_tokens.
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text

# Data Loader

In [8]:
# Load the PDFs matched by the glob under data/new_papers/ using PyPDFLoader,
# then display how many Document objects were loaded.
loader = DirectoryLoader(
    'data/new_papers/',
    glob="./*.pdf",
    loader_cls=PyPDFLoader,
)
documents = loader.load()
len(documents)

142

In [9]:
# Split the loaded documents into overlapping chunks (size 1000, overlap 200)
# and display the resulting chunk count.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
)
texts = text_splitter.split_documents(documents)
len(texts)

659

# Embeddings

In [10]:
# Embedding model used to vectorise the text chunks.
model_name = "BAAI/bge-base-en"

# set True to compute cosine similarity
encode_kwargs = dict(normalize_embeddings=True)

# NOTE(review): the device is hard-coded to 'mps' — presumably this notebook was
# run on Apple-silicon hardware; change it (e.g. 'cuda' / 'cpu') elsewhere.
model_norm = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    encode_kwargs=encode_kwargs,
    model_kwargs=dict(device='mps'),
)

Downloading (…)9a243/.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)1e3c49a243/README.md:   0%|          | 0.00/90.1k [00:00<?, ?B/s]

Downloading (…)3c49a243/config.json:   0%|          | 0.00/719 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

# Chroma DB

In [24]:
persist_directory = 'db/03'

## Here are the new embeddings being used
embedding = model_norm

# Calling Chroma.from_documents against a directory that already holds an index
# re-embeds every chunk and adds it again, so each re-run of this cell would
# duplicate the stored documents. Reuse the persisted index when one exists and
# only build it on the first run.
# To force a rebuild (e.g. after changing the papers or the embedding model),
# delete the persist directory and re-run this cell.
if os.path.isdir(persist_directory) and os.listdir(persist_directory):
    vectordb = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedding,
    )
else:
    vectordb = Chroma.from_documents(
        documents=texts,
        embedding=embedding,
        persist_directory=persist_directory,
    )

In [25]:
# Wrap the vector store as a retriever, passing k=5 as a search kwarg
# (presumably the number of chunks returned per query — the example run at the
# end of the notebook lists five sources).
retriever = vectordb.as_retriever(search_kwargs={"k": 5})

In [27]:
# Instantiate the Together-backed LLM with a low temperature and a larger
# completion budget than the class defaults.
llm = TogetherLLM(
    model="togethercomputer/llama-2-70b-chat",
    temperature=0.1,
    max_tokens=1024,
)

In [28]:
# Build the question-answering chain from the LLM and the retriever.
# return_source_documents=True makes the chain output include the retrieved
# documents, which process_llm_response reads to print the sources.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True,
)


In [29]:
import textwrap

def wrap_text_preserve_newlines(text, width=110):
    """Wrap ``text`` to ``width`` columns while keeping its existing line breaks.

    Each newline-separated line is wrapped independently with
    ``textwrap.fill`` and the results are joined back with newlines.
    """
    return '\n'.join(
        textwrap.fill(segment, width=width) for segment in text.split('\n')
    )

def process_llm_response(llm_response):
    """Print the chain's answer followed by the source of each retrieved document.

    Reads ``llm_response['result']`` (printed wrapped via
    ``wrap_text_preserve_newlines``) and ``llm_response["source_documents"]``,
    printing each document's ``metadata['source']`` on its own line.
    """
    answer = wrap_text_preserve_newlines(llm_response['result'])
    print(answer)
    print('\n\nSources:')
    for doc in llm_response["source_documents"]:
        origin = doc.metadata['source']
        print(origin)

In [30]:
# full example
query = "What is Flash attention?"
llm_response = qa_chain(query)
process_llm_response(llm_response)

Flash attention is a new attention algorithm that computes exact attention with far fewer memory accesses. It
is designed to avoid reading and writing the attention matrix to and from HBM, which reduces the memory
accesses and improves the performance. It uses tiling to split the input into blocks and make several passes
over input blocks, thus incrementally performing the softmax reduction. It also stores the softmax
normalization factor from the forward pass to quickly recompute attention on-chip in the backward pass, which
is faster than the standard approach of reading the intermediate attention matrix from HBM.


Sources:
data/new_papers/Flash-attention.pdf
data/new_papers/Flash-attention.pdf
data/new_papers/Flash-attention.pdf
data/new_papers/Flash-attention.pdf
data/new_papers/Flash-attention.pdf
