<a href="https://colab.research.google.com/github/nrimsky/qa/blob/main/openai_only_codebase_qa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
%%capture
!pip install -U faiss-cpu langchain openai tiktoken

In [15]:
import os
import requests
import shutil
import tarfile
import re
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.embeddings import OpenAIEmbeddings
import torch
import os
import requests
import os
import base64
import shutil
from langchain.text_splitter import RecursiveCharacterTextSplitter
import textwrap
from google.colab import drive
from langchain.llms import OpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains import LLMChain, HypotheticalDocumentEmbedder
from langchain.prompts import PromptTemplate


In [3]:
# Mount Google Drive at /content/drive. get_source_from_github stores downloaded
# repositories under '/content/drive/My Drive/' and reads them back from there.
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
def print_wrapped(text, width=100):
    """Print ``text`` wrapped to at most ``width`` columns, keeping its line breaks.

    Each input line is wrapped on its own. Wrapping the whole string in one go
    would turn every newline into a space (``textwrap`` replaces whitespace by
    default), which flattens multi-line content such as quoted code into a
    single paragraph.

    Args:
        text: The string to print. May contain newlines.
        width: Maximum length of each printed line. Defaults to 100.
    """
    wrapper = textwrap.TextWrapper(width=width)
    for line in text.splitlines():
        # wrap() returns [] for blank or whitespace-only lines; print an empty
        # line in that case so paragraph breaks survive.
        for piece in wrapper.wrap(line) or ['']:
            print(piece)

In [None]:
# Read the key with getpass rather than input() so the secret is not echoed into
# the cell output (which is saved with the notebook). Surrounding whitespace from
# copy/paste is stripped before the key is stored in the environment.
from getpass import getpass

os.environ['OPENAI_API_KEY'] = getpass("Paste OpenAI API Key: ").strip()

In [6]:
def get_source_from_github(repo_path, drive_path='/content/drive/My Drive/'):
    """Return the text files of a GitHub repository, using an on-disk cache.

    The repository is cached in ``<drive_path>/<repo_path with '/' -> '_'>``.
    If that directory is missing or empty, the repository is fetched with
    ``download_from_github``; otherwise every file under the directory is read
    back from disk.

    Args:
        repo_path: Repository identifier interpolated into
            ``https://api.github.com/repos/{repo_path}/contents``
            (i.e. ``"owner/name"``).
        drive_path: Directory that holds the per-repository cache folders.
            Defaults to the mounted Google Drive root.

    Returns:
        dict mapping local file path -> file text. Files that cannot be decoded
        as UTF-8 (e.g. binaries) are skipped. A warning is printed when the
        total amount of text is 1000 characters or fewer.
    """
    api_url = f'https://api.github.com/repos/{repo_path}/contents'
    # '/' is not allowed in directory names, so "owner/name" is stored as "owner_name".
    repo_dir = os.path.join(drive_path, repo_path.replace('/', '_'))
    file_contents = {}
    total_size = 0
    if not os.path.exists(repo_dir) or not os.listdir(repo_dir):
        # No usable cache yet: fetch the repository through the GitHub API.
        file_contents, total_size = download_from_github(api_url, repo_dir)
    else:
        for dirpath, _dirnames, filenames in os.walk(repo_dir):
            for filename in filenames:
                filepath = os.path.join(dirpath, filename)
                try:
                    # Explicit encoding: do not depend on the platform's locale default.
                    with open(filepath, 'r', encoding='utf-8') as f:
                        file_content = f.read()
                except UnicodeDecodeError:
                    # Not text (or not UTF-8); deliberately skipped, best effort.
                    continue
                file_contents[filepath] = file_content
                total_size += len(file_content)
    if total_size <= 1000:
        print(f"Failed to extract enough source data - file content size = {total_size} chars")
    return file_contents

def download_from_github(api_url, local_dir):
    """Recursively download a GitHub "contents" listing into ``local_dir``.

    Every entry of type ``'file'`` is downloaded from its ``download_url`` and
    written to ``local_dir/<entry path>``; every entry of type ``'dir'`` is
    traversed by calling this function again on the entry's ``url``. The same
    ``local_dir`` is passed down because each file's location is built from the
    entry's ``path`` field (presumably relative to the repository root - verify
    against the GitHub contents API).

    Args:
        api_url: GitHub contents API URL of the directory to download.
        local_dir: Root directory under which files are stored.

    Returns:
        Tuple ``(file_contents, total_size)``: a dict mapping local file path ->
        text for every downloaded file that decodes as UTF-8, and the total
        number of characters across those texts. Files that are not valid
        UTF-8 are still saved to disk but left out of the dict.
    """
    # Without a timeout a stalled connection would block the notebook indefinitely.
    request_timeout = 30  # seconds
    file_contents = {}
    total_size = 0
    response = requests.get(api_url, timeout=request_timeout)
    if response.status_code != 200:
        print(f'Error: received status code {response.status_code} from GitHub.')
        return file_contents, total_size

    for entry in response.json():
        if entry['type'] == 'file':
            file_response = requests.get(entry['download_url'], timeout=request_timeout)
            if file_response.status_code != 200:
                # Report the miss instead of silently producing an incomplete corpus.
                print(f"Error: received status code {file_response.status_code} while downloading {entry['path']}.")
                continue
            local_file_path = os.path.join(local_dir, entry['path'])
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
            with open(local_file_path, 'wb') as f:
                f.write(file_response.content)
            try:
                # Read the saved file back as text so a fresh download yields the
                # same strings as a later read from the on-disk copy.
                with open(local_file_path, 'r', encoding='utf-8') as f:
                    file_content = f.read()
            except UnicodeDecodeError:
                # Binary / non-UTF-8 file: kept on disk, excluded from the text corpus.
                continue
            file_contents[local_file_path] = file_content
            total_size += len(file_content)
        elif entry['type'] == 'dir':
            sub_dir_contents, sub_dir_size = download_from_github(entry['url'], local_dir)
            file_contents.update(sub_dir_contents)
            total_size += sub_dir_size
    return file_contents, total_size

In [7]:
def extract_all_text_chunks(repo_path, chunk_size = 500):
    """Split every text file of a repository into chunk documents.

    The files come from ``get_source_from_github(repo_path)``. Each file's text
    is split with a ``RecursiveCharacterTextSplitter`` whose length function is
    ``len`` (so ``chunk_size`` is counted in characters) and with no overlap
    between chunks. Every file is passed along with ``{"source": <file path>}``
    as its metadata.

    Args:
        repo_path: Repository identifier forwarded to ``get_source_from_github``.
        chunk_size: Chunk size handed to the splitter. Defaults to 500.

    Returns:
        The documents produced by the splitter's ``create_documents``.
    """
    splitter_options = {
        "chunk_size": chunk_size,
        "chunk_overlap": 0,
        "length_function": len,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)

    sources = get_source_from_github(repo_path)
    texts = []
    metadatas = []
    for path, content in sources.items():
        texts.append(content)
        metadatas.append({"source": path})

    return splitter.create_documents(texts=texts, metadatas=metadatas)


In [13]:
def cli_ask_questions(repo_path):
    """Run an interactive question/answer loop over a repository's source code.

    Steps:
      1. Chunk the repository with ``extract_all_text_chunks``.
      2. Index the chunks in a FAISS vector store. The embedder is a
         ``HypotheticalDocumentEmbedder`` built from ``OpenAIEmbeddings`` and an
         LLM chain that is prompted to generate a hypothetical code snippet for
         a question.
      3. Build a ``RetrievalQAWithSourcesChain`` ("stuff" chain type) around
         ``gpt-3.5-turbo`` with a custom chat prompt that receives the retrieved
         snippets as ``{context}``.
      4. Prompt for questions until the user types ``quit``, printing each
         question and the chain's ``answer``.

    Errors raised while answering a question are printed and the loop
    continues. If no source text is available the function prints a message
    and returns without building the index.

    Args:
        repo_path: Repository identifier forwarded to ``extract_all_text_chunks``.
    """
    chunks = extract_all_text_chunks(repo_path)
    if not chunks:
        # Nothing was downloaded or readable; there is nothing to index, so stop
        # with an explanation instead of failing inside the vector store build.
        print(f"No source text found for '{repo_path}'; cannot answer questions.")
        return

    base_embeddings = OpenAIEmbeddings()
    hyde_prompt_template = """Generate a hypothetical code snippet to answer this question
    Question: {question}
    Snippet:"""
    prompt = PromptTemplate(input_variables=["question"], template=hyde_prompt_template)
    llm_chain = LLMChain(llm=OpenAI(), prompt=prompt)
    embeddings = HypotheticalDocumentEmbedder(llm_chain=llm_chain, base_embeddings=base_embeddings)
    vectorstore = FAISS.from_documents(chunks, embeddings)
    chain_type_kwargs = {
        "prompt": ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(
                "You are a helpful assistant that answers questions about a codebase given some snippets from the codebase. Whenever useful and possible, you directly quote the code and cite your sources."
            ),
            HumanMessagePromptTemplate.from_template("""
                Some relevant code snippets:

                {context}

                Question: {question}
                Answer:
            """)
        ]),
        # Must match the {context} placeholder in the human message above.
        "document_variable_name": "context"
    }

    qa = RetrievalQAWithSourcesChain.from_chain_type(
        llm=ChatOpenAI(model_name='gpt-3.5-turbo'),
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs=chain_type_kwargs
    )

    while True:
        # Strip so that input such as "quit " or " quit" still ends the session.
        question = input("Enter your question about the codebase (or 'quit' to stop): ").strip()
        if question.lower() == 'quit':
            break
        if not question:
            # Ignore empty input rather than spending an API call on it.
            continue
        try:
            response = qa(question)
            print_wrapped(question)
            print_wrapped(response['answer'])
        except Exception as e:
            # Keep the session alive: report the failure and ask for the next question.
            print("An error occurred while processing your question.")
            print(str(e))


In [None]:
# The repository is addressed as "owner/name" (e.g. "nrimsky/qa"): the value is
# interpolated into https://api.github.com/repos/{repo_path}/contents.
repo_path = input("Enter the relative path of the GitHub repo you want to ask questions about: ").strip()
# Tolerate a pasted repository URL and stray slashes, which would otherwise
# produce an invalid API URL.
github_prefix = 'https://github.com/'
if repo_path.startswith(github_prefix):
    repo_path = repo_path[len(github_prefix):]
repo_path = repo_path.strip('/')
cli_ask_questions(repo_path)