## Installs

In [None]:
# Install the LangChain packages, the chromadb client and tiktoken (-q keeps pip output quiet).
!pip install langchain langchain_community langchain-openai langchainhub chromadb tiktoken -q

In [None]:
# Install nbstripout (quiet mode).
# NOTE(review): it is not referenced anywhere else in this notebook — presumably
# used to strip cell outputs before committing; confirm it is still needed.
! pip install nbstripout -q

In [None]:
# Provides the langchain_text_splitters module, from which this notebook imports
# RecursiveCharacterTextSplitter.
! pip install langchain-text-splitters -q

In [None]:
import os
import time
import json
from pprint import pprint
import pandas as pd

import langchain
print("langchain.__version__ ", langchain.__version__)

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document

In [None]:
# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this is where the OpenAI API key used by
# OpenAIEmbeddings comes from — confirm.
from dotenv import load_dotenv
load_dotenv()

## Warm up & Config

In [None]:
# Directory of pre-processed JSON documents. Every file in it is parsed with
# json.load and its 'doc_judgement' entry (a sequence of strings) is joined
# into one text.
raw_docs_base_dir = '../data/processed/p_jsons'

In [None]:
## check the len of each doc
# Collect the word count and character count of every document so that chunk
# sizes can be chosen from the actual length distribution.
all_len = []        # words per document
all_char_len = []   # characters per document
for item in os.listdir(raw_docs_base_dir):
    # JSON is UTF-8 by specification; do not depend on the platform's default
    # text encoding.
    with open(os.path.join(raw_docs_base_dir, item), 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Build the joined text once and reuse it for both measurements.
    doc_text = ' '.join(data['doc_judgement'])
    all_len.append(len(doc_text.split()))
    all_char_len.append(len(doc_text))

pprint(pd.Series(all_len).describe())
pprint(pd.Series(all_char_len).describe())

if all_len:
    long_doc_count = sum(1 for n_words in all_len if n_words > 7000)
    print('Percent len > 7k: ', (long_doc_count / len(all_len)) * 100)
else:
    # With no documents the percentage would divide by zero.
    print('No documents found in', raw_docs_base_dir)

## Utils

In [None]:
def custom_chunker(text: str):
    """Split ``text`` into chunks, choosing the strategy from its length.

    * up to 3000 characters: returned unchanged as a single chunk;
    * more than 3000 and fewer than 12000 characters: one pass with a
      3000-character target size and 300-character overlap;
    * 12000 characters or more: a coarse pass (9000 / 900) followed by the
      same fine pass (3000 / 300) applied to each coarse chunk.

    Both passes use ``RecursiveCharacterTextSplitter`` with the separators
    ``"\\n\\n"``, ``"\\n"``, ``"."`` and ``" "``.

    Args:
        text: The full document text.

    Returns:
        A list of chunk strings, in document order.
    """
    separators = ["\n\n", "\n", ".", " "]

    # based on len of doc, we can set different chunk size
    num_chars = len(text)

    # Inclusive bound: a text of exactly 3000 characters already fits in one
    # chunk. With a strict "<" it matched neither of the first two branches
    # and fell through to the two-stage path meant for long documents.
    if num_chars <= 3000:
        return [text]

    fine_splitter = RecursiveCharacterTextSplitter(chunk_size=3000,
                                                   chunk_overlap=300,
                                                   separators=separators)

    if num_chars < 12000:
        return fine_splitter.split_text(text)

    # Long documents: cut into coarse sections first, then refine each one.
    # NOTE(review): text inside the 900-character coarse overlap is fed to the
    # fine splitter twice, so it can show up in more than one final chunk —
    # confirm this is intended.
    coarse_splitter = RecursiveCharacterTextSplitter(chunk_size=9000,
                                                     chunk_overlap=900,
                                                     separators=separators)
    final_chunks = []
    for coarse_chunk in coarse_splitter.split_text(text):
        final_chunks.extend(fine_splitter.split_text(coarse_chunk))

    return final_chunks

## Chunking

In [None]:
## exec: All chunks Extraction 

# Every entry is a (chunk_text, metadata_dict) tuple.
chunks_all = []
# os.listdir returns names in arbitrary order; sorting keeps the chunk order
# (and the printed log) reproducible between runs.
for item in sorted(os.listdir(raw_docs_base_dir)):
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(os.path.join(raw_docs_base_dir, item), 'r', encoding='utf-8') as f:
        data = json.load(f)
    # The file is fully parsed at this point, so it is closed before chunking.
    doc_text = ' '.join(data['doc_judgement'])
    chunks = custom_chunker(doc_text)
    print(f"Document: {item}, Original Length: {len(doc_text)}, Number of Chunks: {len(chunks)}")
    for idx, chunk in enumerate(chunks):
        # Metadata records which file the chunk came from, its position within
        # that file and the character length of the whole document.
        chunk_metadata = {
            'source_doc': item,
            'chunk_index': idx,
            'original_length': len(doc_text),
        }
        chunks_all.append((chunk, chunk_metadata))
    print("\n")

In [None]:
# Display the full list of (chunk_text, metadata) tuples as the cell output.
chunks_all

In [None]:
import random

# Wrap each (chunk_text, metadata) pair in a LangChain Document.
lc_documents = [
    Document(page_content=chunk_text, metadata=chunk_meta)
    for chunk_text, chunk_meta in chunks_all
]
print(len(lc_documents))

# Spot-check one document picked at random.
print(random.choice(lc_documents))

## Indexing

In [None]:
# initialize the chroma dir
# Vector store for the 'legal_mini_rag' collection, kept in the on-disk
# directory /tmp/chroma_db_test and embedded with OpenAIEmbeddings defaults.
embedding_model = OpenAIEmbeddings()
vector_store_chroma = Chroma(
    collection_name='legal_mini_rag',
    persist_directory='/tmp/chroma_db_test',
    embedding_function=embedding_model,
)

In [None]:
# Give every chunk a stable id built from its source file and its position in
# that file. Without explicit ids each run of this cell stores a fresh copy of
# every chunk in the persisted collection; with stable ids a re-run is expected
# to overwrite the existing entries instead of duplicating them.
chunk_ids = [
    f"{doc.metadata['source_doc']}::{doc.metadata['chunk_index']}"
    for doc in lc_documents
]
# NOTE(review): an explicit vector_store_chroma.persist() call was left disabled
# here; recent Chroma releases are expected to write to persist_directory
# automatically — confirm for the installed version.
vector_store_chroma.add_documents(lc_documents, ids=chunk_ids)

In [None]:
## test collection 
import numpy as np

# _collection is a private attribute of the LangChain wrapper; it is used here
# only for a quick sanity check of what was indexed.
my_collection = vector_store_chroma._collection
print('Total docs indexed: ', my_collection.count())

# Fetch a single stored record together with its embedding vector.
random_embedding = my_collection.get(include=["embeddings"], limit=1)
# NOTE(review): depending on the chromadb version, 'embeddings' appears to come
# back either as nested lists or as a NumPy array. np.asarray provides .shape
# in both cases and leaves the printed value unchanged for arrays.
print('embedding len: ', np.asarray(random_embedding['embeddings']).shape)

## Retrieval

## Generation