#### Directory Structure
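For the notebook to work as intended, match the following directory structure. The file `data_filepath.txt` in the working directory must contain the path to `data/`, including the trailing `/`. The API key is read from `assets/secrets.json` (a JSON object with an `api_key` field), where `assets/` sits next to `data/`: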


    -- data/
        -- pdfs/
            -- Folders of downloaded TCGA pathology reports (PDF files with an upper-case `.PDF` extension)
      
        -- raw_ocr_txt/
            -- OCR output: one `<UUID>.txt` file per PDF
    
        -- clean_ocr_txt/
            -- `<UUID>.txt` files after the LLM pass that corrects misrecognised OCR characters


#### Gather PDF filepaths

In [None]:
from PIL import Image
from pdf2image import convert_from_path
import pytesseract
import os
import time
from tqdm import trange, tqdm

# Point pytesseract at an explicit Tesseract executable (machine-specific Homebrew path).
# On another machine the binary lives at '/opt/homebrew/bin/tesseract' instead (For Matthew).
tesseract_binary = '/opt/homebrew/Cellar/tesseract/5.3.4/bin/tesseract'
pytesseract.pytesseract.tesseract_cmd = tesseract_binary

# data_filepath.txt (read relative to the current working directory) holds the
# path to the data/ directory described at the top of this notebook.
with open('./data_filepath.txt', 'r') as file:
    # Strip surrounding whitespace: a trailing newline in the file would otherwise
    # end up inside every path derived below, and os.walk silently yields nothing
    # for a directory that does not exist.
    base_dir = file.read().strip()
# Sub-folder names are appended by plain string concatenation here and in later
# cells, so make sure the base path ends with a separator.
if base_dir and not base_dir.endswith(('/', os.sep)):
    base_dir += '/'
print(base_dir)
pdfs_base_dir = base_dir + 'pdfs/'

# Map each report's UUID to its PDF path. The UUID is taken to be the last
# dot-separated field of the file name once the '.PDF' suffix is removed.
PDF_SUFFIX = '.PDF'
pdf_paths = {}
for root, dirs, files in os.walk(pdfs_base_dir):
    for name in files:
        if not name.endswith(PDF_SUFFIX):
            continue

        # Slice off the exact suffix: str.rstrip('.PDF') strips any trailing
        # '.', 'P', 'D' or 'F' characters and would truncate UUIDs ending in D/F.
        uuid = name[:-len(PDF_SUFFIX)].split(".")[-1]
        pdf_path = os.path.join(root, name)

        if uuid in pdf_paths:
            raise Exception("UUID already present in pdfs, are there duplicates?")

        pdf_paths[uuid] = pdf_path

#### Convert scanned PDFs to text using OCR

In [None]:
raw_ocr_base_dir = base_dir + 'raw_ocr_txt/'
# Create the output folder if needed so the OCR writes below don't fail on a fresh setup.
os.makedirs(raw_ocr_base_dir, exist_ok=True)

# UUIDs whose OCR text is already on disk; these are re-read rather than re-OCR'd.
# A set gives constant-time membership checks in the conversion loop.
# NOTE(review): the walk is recursive, but the loop rebuilds the path as
# raw_ocr_base_dir + uuid + '.txt', so files are assumed to sit directly in this folder.
already_converted_docs = set()
for root, dirs, files in os.walk(raw_ocr_base_dir):
    for name in files:
        if name.endswith('.txt'):
            # Slice off the exact suffix; str.rstrip('.txt') strips any trailing
            # '.', 't' or 'x' characters and could truncate the UUID.
            uuid = name[:-len('.txt')]
            already_converted_docs.add(uuid)
            

# OCR every PDF (or reuse a previous run's output) and keep the text keyed by UUID.
docs = {}
for uuid, pdf_path in tqdm(pdf_paths.items()):
    txt_file_path = raw_ocr_base_dir + uuid + '.txt'
    t0 = time.time()
    print("Processing UUID", uuid)

    if uuid in already_converted_docs:
        # Reuse OCR output written by an earlier run instead of re-running Tesseract.
        print("UUID already converted - reading in txt file\n")
        with open(txt_file_path, 'r', encoding='utf-8') as file:
            docs[uuid] = file.read()
        continue

    # Render each PDF page to an image, OCR each image, and concatenate the page texts.
    images = convert_from_path(pdf_path)
    pages = [pytesseract.image_to_string(image) for image in images]
    doc = ''.join(pages)

    # Explicit UTF-8 so OCR output containing non-ASCII characters is written
    # (and later re-read) the same way regardless of the platform's default encoding.
    with open(txt_file_path, 'w', encoding='utf-8') as file:
        file.write(doc)

    docs[uuid] = doc
    n_pages = len(pages)
    # Singular only for exactly one page (the old `> 1` check printed "0 page").
    print(f"Doc contained {n_pages} page{'' if n_pages == 1 else 's'}, completed in {round(time.time() - t0,2)} seconds\n")

print("\n----------------------------------")
print("       PROCESSING COMPLETE")
print("----------------------------------")

#### Pass OCR text to a Llama 2 LLM (Llama-2-13b-chat via DeepInfra) to fix spelling and misrecognised characters

In [None]:
# Setup API
from openai import OpenAI
import json

# The API key lives in assets/secrets.json, a sibling of the data/ directory.
api_key_path = f"{base_dir}../assets/secrets.json"
with open(api_key_path, 'r') as file:
    content = file.read()
secrets = json.loads(content)

# The OpenAI client is pointed at DeepInfra's OpenAI-compatible endpoint via base_url.
openai = OpenAI(
    base_url="https://api.deepinfra.com/v1/openai",
    api_key=secrets['api_key'],
)

# Chat model used for the OCR clean-up pass.
model = "meta-llama/Llama-2-13b-chat-hf"

In [None]:
def chunk_split(doc, chunk_size=1024):
    """Split ``doc`` into consecutive pieces of at most ``chunk_size`` words.

    Words are the whitespace-separated tokens from ``str.split()`` and are
    re-joined with single spaces, so newlines and runs of whitespace in the
    original text are collapsed. A document with no words yields ``[]``.
    """
    words = doc.split()
    return [
        ' '.join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]

In [None]:
clean_ocr_base_dir = base_dir + 'clean_ocr_txt/'
# Create the output folder if needed so the writes in the cleaning loop don't fail.
os.makedirs(clean_ocr_base_dir, exist_ok=True)

# UUIDs whose LLM-cleaned text is already on disk; these are re-read, not re-sent.
# A set gives constant-time membership checks in the cleaning loop.
# NOTE(review): the walk is recursive, but the loop rebuilds the path as
# clean_ocr_base_dir + uuid + '.txt', so files are assumed to sit directly in this folder.
already_cleaned_docs = set()
for root, dirs, files in os.walk(clean_ocr_base_dir):
    for name in files:
        if name.endswith('.txt'):
            # Slice off the exact suffix; str.rstrip('.txt') strips any trailing
            # '.', 't' or 'x' characters and could truncate the UUID.
            uuid = name[:-len('.txt')]
            already_cleaned_docs.add(uuid)
            

# Instruction text placed in front of every chunk sent to the LLM.
prompt = "Please spellcheck and correct the following text (a pathology report converted to text from pdf using OCR):\n"

# Response settings: results are streamed, and the completion budget is what
# remains of a 4096-token window after reserving `prompt_tokens` for the input
# (4096 presumably being the model's context length — confirm).
stream = True
prompt_tokens = 2048
max_tokens = 4096 - prompt_tokens

# Text correction and recovery: send each OCR'd document, chunk by chunk, to the
# LLM and persist the corrected text to clean_ocr_txt/<UUID>.txt.


def _stream_clean_chunk(raw_chunk):
    """Send one chunk of OCR text to the LLM and collect the streamed reply.

    Returns ``(text, tokens, failed)``:
    - ``text``: concatenation of all streamed content pieces;
    - ``tokens``: sum of ``total_tokens`` from any usage reports in the stream
      (0 if none were reported);
    - ``failed``: True if any streamed event carried an ``error_type`` attribute.
    """
    cat_prompt = " ".join([prompt, raw_chunk])
    response = openai.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": cat_prompt}],
        stream=stream,
        max_tokens=max_tokens,
    )

    pieces = []
    tokens = 0
    failed = False
    for event in response:
        if hasattr(event, 'error_type'):
            # Errors are signalled inline in the stream by an `error_type` attribute.
            print(event)
            print("_______________________________")
            print("This document failed to process")
            failed = True
            continue

        choices = getattr(event, 'choices', None) or []
        tkn = choices[0].delta.content if choices else None
        if tkn is not None:
            pieces.append(tkn)
            continue

        # Events without content may carry the usage report. Depending on the
        # openai client version it is a plain dict or a usage object, and it can
        # be absent entirely on intermediate events.
        usage = getattr(event, 'usage', None)
        if usage is None:
            continue
        tokens += usage['total_tokens'] if isinstance(usage, dict) else usage.total_tokens

    return ''.join(pieces), tokens, failed


cleaned_docs = {}
failed_docs = []  # UUIDs whose cleaning hit an error; not persisted, so retried next run
chunk_size = 512
total_tokens = 0
total_words = 0
for uuid in tqdm(docs):
    txt_file_path = clean_ocr_base_dir + uuid + '.txt'
    t0 = time.time()
    print("\nProcessing UUID", uuid)

    if uuid in already_cleaned_docs:
        print("UUID already cleaned - reading in txt file\n")
        with open(txt_file_path, 'r', encoding='utf-8') as file:
            cleaned_docs[uuid] = file.read()
        continue

    raw_doc = '\n\n[Begin Document]\n' + docs[uuid] + '[End Document]'
    total_words += len(raw_doc.split())

    clean_chunks = []
    doc_tokens = 0
    doc_failed = False
    for raw_chunk in chunk_split(raw_doc, chunk_size):
        text, tokens, failed = _stream_clean_chunk(raw_chunk)
        clean_chunks.append(text)
        doc_tokens += tokens
        # Once any chunk fails the whole document counts as failed; a later
        # successful chunk must not hide it.
        doc_failed = doc_failed or failed
    total_tokens += doc_tokens

    # Separate per-chunk replies so the last word of one reply is not glued to
    # the first word of the next.
    doc = '\n'.join(clean_chunks)
    cleaned_docs[uuid] = doc

    if doc_failed:
        # Do not persist partial output: the already-cleaned check above would
        # otherwise skip this document on every future run.
        failed_docs.append(uuid)
        print(f"Cleaning failed for UUID {uuid}; output not saved so it will be retried on the next run")
        continue

    with open(txt_file_path, 'w', encoding='utf-8') as file:
        file.write(doc)

    print(f"Cleaning completed in {round(time.time() - t0,2)} seconds, cost {doc_tokens} tokens")
                
print(f"\n\nFinal Cost: {total_tokens} tokens")
# total_words only counts documents sent to the LLM in this run. It is 0 when
# every document was already cleaned on disk, so guard the division.
if total_words:
    print(f"Average token / word: {total_tokens / total_words}")
else:
    print("Average token / word: n/a (no documents were sent to the LLM)")
print("\n----------------------------------")
print("       PROCESSING COMPLETE")
print("----------------------------------")