In [1]:
__author__ = "Jon Ball"
__version__ = "October 2023"

In [2]:
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
import chromadb
from transformers import AutoTokenizer
from tqdm import tqdm
import numpy as np
import torch
import random
import jinja2
import json

In [3]:
# Set random seed 
random.seed(42)
_ = torch.manual_seed(42)

In [4]:
journals = ["se", "aje", "aerj", "asr", "ajs", "bjse", "isse", "random"]

In [5]:
# load local tokenizer
tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")

In [6]:
model_name = "Muennighoff/SGPT-125M-weightedmean-nli-bitfit"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
embedding = SentenceTransformerEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)

In [7]:
# Access local chroma client
persistent_client = chromadb.PersistentClient(path="../chroma")
# Vector db of 10 hand-labeled examples
manualDB = Chroma(
    client=persistent_client,
    collection_name="manual",
    embedding_function=embedding,
    )
# Vector db of 90 examples labeled by GPT-4
gpt4DB = Chroma(
    client=persistent_client,
    collection_name="gpt4",
    embedding_function=embedding,
    )

LLaMA2-Code has a functional context length of 16k tokens.

## Prompt Lengths

In [8]:
class jinjaLoader():
    # jinja2 template renderer
    def __init__(self, template_dir, template_file):
        self.templateLoader = jinja2.FileSystemLoader(searchpath=template_dir)
        self.templateEnv = jinja2.Environment( loader=self.templateLoader )
        self.template = self.templateEnv.get_template( template_file )

    def render(self, templateVars):
        return self.template.render( templateVars )

In [9]:
jinjitsu = jinjaLoader("../prompts", "fewshot.prompt")

In [10]:
def prompt_len(n_examples):
    for journal in journals:
        prompt_lens = []
        path = f"../data/eric/{journal}.json"
        with open(path, "r") as infile:
            d = json.load(infile)
        records = d["response"]["docs"]
        for rec in tqdm(records):
            input = str(rec)
            # Pull most similar manually labeled doc
            examples = [ex.metadata for ex in manualDB.similarity_search(input, 1)]
            if n_examples > 1:
                # Pull most similar GPT-4 labeled docs
                examples += [ex.metadata for ex in gpt4DB.similarity_search(input, n_examples-1)]
            templateVars = {
                "examples": examples,
                "input": input, 
                "output": ""
                }
            PROMPT = jinjitsu.render(templateVars)
            prompt_lens.append(len(tokenizer.tokenize(PROMPT)))
        print(f"Journal: {journal.upper()}, {len(records)} records")
        print(f"Mean: {np.mean(prompt_lens)} | Std: {np.std(prompt_lens)} | Max: {np.max(prompt_lens)} | Min: {np.min(prompt_lens)}", "\n")

### Average one-shot prompt length in tokens for each journal (/w RAG):

In [11]:
prompt_len(1)

  0%|          | 0/1035 [00:00<?, ?it/s]

100%|██████████| 1035/1035 [01:09<00:00, 14.85it/s]


Journal: SE, 1035 records
Mean: 1396.9942028985506 | Std: 168.62513083464847 | Max: 1931 | Min: 1129 



100%|██████████| 789/789 [00:47<00:00, 16.58it/s]


Journal: AJE, 789 records
Mean: 1445.8022813688212 | Std: 166.52439487530037 | Max: 1926 | Min: 1197 



100%|██████████| 1912/1912 [01:42<00:00, 18.58it/s]


Journal: AERJ, 1912 records
Mean: 1407.5261506276152 | Std: 160.29819997264798 | Max: 1913 | Min: 1099 



100%|██████████| 350/350 [00:18<00:00, 18.70it/s]


Journal: ASR, 350 records
Mean: 1341.4714285714285 | Std: 145.04788582972682 | Max: 1774 | Min: 1113 



100%|██████████| 333/333 [00:16<00:00, 19.59it/s]


Journal: AJS, 333 records
Mean: 1264.6606606606606 | Std: 53.20685043203422 | Max: 1417 | Min: 1125 



100%|██████████| 1223/1223 [01:06<00:00, 18.41it/s]


Journal: BJSE, 1223 records
Mean: 1525.9950940310712 | Std: 119.5487152388556 | Max: 2357 | Min: 1177 



100%|██████████| 386/386 [00:21<00:00, 17.59it/s]


Journal: ISSE, 386 records
Mean: 1552.658031088083 | Std: 103.77405283460428 | Max: 1872 | Min: 1251 



100%|██████████| 1022/1022 [00:56<00:00, 18.19it/s]

Journal: RANDOM, 1022 records
Mean: 1363.2093933463796 | Std: 170.18018880948733 | Max: 2069 | Min: 1099 






### Average two-shot prompt length in tokens for each journal (/w RAG):

In [12]:
prompt_len(2)

100%|██████████| 1035/1035 [01:54<00:00,  9.02it/s]


Journal: SE, 1035 records
Mean: 1874.3826086956522 | Std: 207.34145211504472 | Max: 2736 | Min: 1549 



100%|██████████| 789/789 [01:30<00:00,  8.76it/s]


Journal: AJE, 789 records
Mean: 1924.4613434727503 | Std: 197.77167408642066 | Max: 2620 | Min: 1583 



100%|██████████| 1912/1912 [03:37<00:00,  8.80it/s]


Journal: AERJ, 1912 records
Mean: 1881.0475941422594 | Std: 187.48371849663323 | Max: 2503 | Min: 1501 



100%|██████████| 350/350 [00:40<00:00,  8.66it/s]


Journal: ASR, 350 records
Mean: 1816.8771428571429 | Std: 177.61684055567366 | Max: 2510 | Min: 1557 



100%|██████████| 333/333 [00:42<00:00,  7.92it/s]


Journal: AJS, 333 records
Mean: 1721.3783783783783 | Std: 106.62133502285099 | Max: 2118 | Min: 1500 



100%|██████████| 1223/1223 [02:21<00:00,  8.67it/s]


Journal: BJSE, 1223 records
Mean: 2009.5486508585445 | Std: 155.69145445531694 | Max: 2718 | Min: 1624 



100%|██████████| 386/386 [00:44<00:00,  8.69it/s]


Journal: ISSE, 386 records
Mean: 2026.6528497409327 | Std: 127.34146378385451 | Max: 2411 | Min: 1675 



100%|██████████| 1022/1022 [01:55<00:00,  8.86it/s]

Journal: RANDOM, 1022 records
Mean: 1844.1409001956947 | Std: 202.57372292462293 | Max: 2595 | Min: 1477 






### Average three-shot prompt length in tokens for each journal (/w RAG):

In [13]:
prompt_len(3)

100%|██████████| 1035/1035 [01:59<00:00,  8.68it/s]


Journal: SE, 1035 records
Mean: 2355.710144927536 | Std: 241.23078340722608 | Max: 3329 | Min: 1920 



100%|██████████| 789/789 [01:33<00:00,  8.47it/s]


Journal: AJE, 789 records
Mean: 2415.793409378961 | Std: 231.03918904964112 | Max: 3143 | Min: 1995 



100%|██████████| 1912/1912 [03:41<00:00,  8.65it/s]


Journal: AERJ, 1912 records
Mean: 2350.9680962343095 | Std: 220.7690616564291 | Max: 3141 | Min: 1893 



100%|██████████| 350/350 [00:42<00:00,  8.20it/s]


Journal: ASR, 350 records
Mean: 2296.14 | Std: 209.4781962059877 | Max: 3090 | Min: 1979 



100%|██████████| 333/333 [00:37<00:00,  8.96it/s]


Journal: AJS, 333 records
Mean: 2179.6576576576576 | Std: 142.42303619906113 | Max: 2681 | Min: 1907 



100%|██████████| 1223/1223 [02:22<00:00,  8.61it/s]


Journal: BJSE, 1223 records
Mean: 2487.085854456255 | Std: 181.9066503011347 | Max: 3179 | Min: 2050 



100%|██████████| 386/386 [00:45<00:00,  8.55it/s]


Journal: ISSE, 386 records
Mean: 2513.5259067357515 | Std: 168.30473103707692 | Max: 3107 | Min: 2131 



100%|██████████| 1022/1022 [01:56<00:00,  8.76it/s]

Journal: RANDOM, 1022 records
Mean: 2315.8796477495107 | Std: 226.3151349553457 | Max: 3155 | Min: 1915 






### Average four-shot prompt length in tokens for each journal (/w RAG):

In [14]:
prompt_len(4)

100%|██████████| 1035/1035 [01:58<00:00,  8.70it/s]


Journal: SE, 1035 records
Mean: 2837.737198067633 | Std: 263.31329316127244 | Max: 3753 | Min: 2321 



100%|██████████| 789/789 [01:32<00:00,  8.56it/s]


Journal: AJE, 789 records
Mean: 2901.71989860583 | Std: 254.6127880921651 | Max: 3607 | Min: 2397 



100%|██████████| 1912/1912 [03:44<00:00,  8.53it/s]


Journal: AERJ, 1912 records
Mean: 2825.0700836820083 | Std: 245.35672796679208 | Max: 3663 | Min: 2283 



100%|██████████| 350/350 [00:40<00:00,  8.57it/s]


Journal: ASR, 350 records
Mean: 2771.614285714286 | Std: 229.97875012749347 | Max: 3554 | Min: 2402 



100%|██████████| 333/333 [00:37<00:00,  8.83it/s]


Journal: AJS, 333 records
Mean: 2647.612612612613 | Std: 172.41814027583425 | Max: 3210 | Min: 2331 



100%|██████████| 1223/1223 [02:23<00:00,  8.51it/s]


Journal: BJSE, 1223 records
Mean: 2970.773507767784 | Std: 211.08485941779585 | Max: 3795 | Min: 2495 



100%|██████████| 386/386 [00:47<00:00,  8.05it/s]


Journal: ISSE, 386 records
Mean: 2997.5932642487046 | Std: 196.21030764232538 | Max: 3720 | Min: 2563 



100%|██████████| 1022/1022 [02:07<00:00,  8.03it/s]

Journal: RANDOM, 1022 records
Mean: 2791.8972602739727 | Std: 248.20198479601603 | Max: 3688 | Min: 2323 






### Average five-shot prompt length in tokens for each journal (/w RAG):

In [15]:
prompt_len(5)

100%|██████████| 1035/1035 [02:01<00:00,  8.51it/s]


Journal: SE, 1035 records
Mean: 3324.623188405797 | Std: 287.98050967912997 | Max: 4500 | Min: 2786 



100%|██████████| 789/789 [01:30<00:00,  8.71it/s]


Journal: AJE, 789 records
Mean: 3383.493029150824 | Std: 282.39425651224406 | Max: 4255 | Min: 2802 



100%|██████████| 1912/1912 [03:36<00:00,  8.81it/s]


Journal: AERJ, 1912 records
Mean: 3309.470711297071 | Std: 272.30757126476175 | Max: 4237 | Min: 2680 



100%|██████████| 350/350 [00:39<00:00,  8.82it/s]


Journal: ASR, 350 records
Mean: 3263.8142857142857 | Std: 256.36490136729174 | Max: 4343 | Min: 2793 



100%|██████████| 333/333 [00:38<00:00,  8.67it/s]


Journal: AJS, 333 records
Mean: 3119.5495495495497 | Std: 203.76953973207063 | Max: 3834 | Min: 2709 



100%|██████████| 1223/1223 [02:23<00:00,  8.53it/s]


Journal: BJSE, 1223 records
Mean: 3457.340147179068 | Std: 232.8480048785883 | Max: 4365 | Min: 2925 



100%|██████████| 386/386 [00:44<00:00,  8.63it/s]


Journal: ISSE, 386 records
Mean: 3473.9093264248704 | Std: 205.10261444427738 | Max: 4187 | Min: 3049 



100%|██████████| 1022/1022 [01:56<00:00,  8.79it/s]

Journal: RANDOM, 1022 records
Mean: 3268.6497064579257 | Std: 269.3551671335779 | Max: 4145 | Min: 2740 






### Correct JSON response length in tokens:

In [16]:
all_false = str({"quantitative": False, "qualitative": False, "primary/secondary": False, "tertiary": False, "inequality": False, "nonstructural": False, "culture": False, "school": False, "state": False, "labor": False, "comparative": False, "methods": False})
print(f"Length of max expected JSON output in tokens: {len(tokenizer.tokenize(all_false))}")

Length of max expected JSON output in tokens: 74
