In [1]:
import os
from tqdm import tqdm

from transformers import pipeline

from langchain.document_loaders import TextLoader
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter, Language
from langchain.vectorstores import Chroma
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain_core.embeddings import Embeddings

  from .autonotebook import tqdm as notebook_tqdm


### Load code and chunk it 

In [2]:
# Load the files 
all_files = []
for root, dirs, files in os.walk("data/processed/", topdown=True):
   files = [os.path.join(root, f) for f in files if f.endswith('.py') or f.endswith('.md')]
   all_files += files

print(f'Has {len(all_files)} files, e.g.: {all_files[0]}')

# Read the files and create documents
texts = []
for file in all_files: 
    with open(file) as f:
        try:  
            file_content = f.read()
            if file_content:
                texts.append(Document(page_content=file_content, metadata={"filename": file}))
        except Exception as e:
            print(f"Error reading file {file}: {e}")

print(f"Has {len(texts)} documents, e.g. {str(texts[0])[:50]}")

Has 19159 files, e.g.: data/processed/coronavirus-tg-api-master/tasks.py
Has 18116 documents, e.g. page_content='"""\ntasks.py\n--------\nProject inv


In [3]:
text_splitter = CharacterTextSplitter(
    # language=Language.PYTHON, 
    chunk_size=1000, 
    chunk_overlap=100, 
    separator="\n\n", 
    is_separator_regex=False,
)
chunks = text_splitter.split_documents(texts)

approx_tokens = sum([len(t.page_content) for t in chunks])

Created a chunk of size 1117, which is longer than the specified 1000
Created a chunk of size 1207, which is longer than the specified 1000
Created a chunk of size 1070, which is longer than the specified 1000
Created a chunk of size 1103, which is longer than the specified 1000
Created a chunk of size 1394, which is longer than the specified 1000
Created a chunk of size 1096, which is longer than the specified 1000
Created a chunk of size 1086, which is longer than the specified 1000
Created a chunk of size 2883, which is longer than the specified 1000
Created a chunk of size 1357, which is longer than the specified 1000
Created a chunk of size 2357, which is longer than the specified 1000
Created a chunk of size 1428, which is longer than the specified 1000
Created a chunk of size 1366, which is longer than the specified 1000
Created a chunk of size 2642, which is longer than the specified 1000
Created a chunk of size 1039, which is longer than the specified 1000
Created a chunk of s

In [4]:
print(f"Embedding cost with OpenAI: {round(approx_tokens / 1000 * 0.0001, 2)}$")
print(f"Number of chunks: {len(chunks)}")

Embedding cost with OpenAI: 3.64$
Number of chunks: 49536


In [5]:
for i in [0, 1, 100, 101, 1000, 1001]: 
    print(chunks[i].page_content)
    print('-' * 80)

"""
tasks.py
--------
Project invoke tasks

Available commands
  invoke --list
  invoke fmt
  invoke sort
  invoke check
"""
import random

import invoke

TARGETS_DESCRIPTION = "Paths/directories to format. [default: . ]"


@invoke.task(help={"targets": TARGETS_DESCRIPTION})
def sort(ctx, targets="."):
    """Sort module imports."""
    print("sorting imports ...")
    args = ["isort", "-rc", "--atomic", targets]
    ctx.run(" ".join(args))


@invoke.task(pre=[sort], help={"targets": TARGETS_DESCRIPTION})
def fmt(ctx, targets="."):
    """Format python source code & sort imports."""
    print("formatting ...")
    args = ["black", targets]
    ctx.run(" ".join(args))


@invoke.task
def check(ctx, fmt=False, sort=False, diff=False):  # pylint: disable=redefined-outer-name
    """Check code format and import order."""
    if not any([fmt, sort]):
        fmt = True
        sort = True

    fmt_args = ["black", "--check", "."]
    sort_args = ["isort", "-rc", "--check", "."]
-------------

### Embedd it 

In [6]:
import torch
from unixcoder import UniXcoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UniXcoder("microsoft/unixcoder-base")
model.to(device)

UniXcoder(
  (model): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(51416, 768, padding_idx=1)
      (position_embeddings): Embedding(1026, 768, padding_idx=1)
      (token_type_embeddings): Embedding(10, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm

In [7]:
model.config.max_position_embeddings

1026

In [8]:
class UnixcoderEmbeddings(Embeddings):
    """Interface for embedding models."""

    def __init__(self): 
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = UniXcoder("microsoft/unixcoder-base")
        self.model = self.model.to(self.device)

    def embed_documents(self, texts):
        """Embed search docs."""
        print(f"Embedding documents: {len(texts)}")
        embeddings = []
        for text in tqdm(texts):
            embedding = self.embed_query(text) 
            embeddings.append(embedding)

        return embeddings
        
    def embed_query(self, text: str):
        """Embed query text."""
        text = chunks[0].page_content
        input_ids = self.model.tokenizer.encode(text, return_tensors="pt").to(self.device)
        if input_ids.shape[1] > self.model.config.max_position_embeddings:
            input_ids = input_ids[:, :self.model.config.max_position_embeddings]
        _, embedding = self.model(input_ids)
        embedding = embedding.squeeze(0).tolist()
        return embedding


In [12]:
def embedd(embedding_function, chunks, persist_directory):
    db = Chroma.from_documents(chunks, embedding_function, persist_directory=persist_directory)
    return db

db = embedd(embedding_function=UnixcoderEmbeddings(), persist_directory="./unixcoder_embeddings", chunks=chunks)

Embedding documents: 41666


100%|██████████| 41666/41666 [06:58<00:00, 99.48it/s] 


Embedding documents: 7870


100%|██████████| 7870/7870 [01:18<00:00, 100.38it/s]


In [13]:
db

<langchain.vectorstores.chroma.Chroma at 0x7ff5ccc77fd0>

In [20]:
retriever = db.as_retriever(
    search_type="mmr",  # Also test "similarity"
    search_kwargs={"k": 8},
)

for filename in os.listdir("test_data"):
    with open(os.path.join("test_data", filename)) as f:
        text = f.read()

    print(f"Query: {filename}")
    print("Results:")
    for result in db.similarity_search(text):
        print(result)
    print("=" * 80)

Query: serve_model_llm
Results:
page_content='from sqlalchemy.orm import declarative_base\n\nBase = declarative_base()\n\n\nclass BaseDBModel(Base):\n    __abstract__ = True' metadata={'filename': 'data/processed/fastapi-todo-rest-master/application/src/models/base.py'}
page_content='from datetime import datetime\nfrom typing import TYPE_CHECKING\n\nfrom sqlalchemy import Column, DateTime, Integer, String\nfrom sqlalchemy.orm import relationship\nfrom sqlalchemy.sql import func\n\nfrom .base import BaseDBModel\n\nif TYPE_CHECKING:  # pragma: no cover\n    from .todo_item import TodoItem  # noqa: F401\n\n\nclass User(BaseDBModel):\n    __tablename__: str = "users"\n\n    id: int | None = Column(Integer, primary_key=True)\n\n    username: str = Column(String, unique=True, nullable=False)\n    email: str = Column(String, unique=True, nullable=False)\n    full_name: str | None = Column(String, nullable=True)\n\n    hashed_password: str = Column(String, nullable=False)\n\n    # timestamps a