In [5]:
import json
import copy

from transformers import AutoTokenizer
from tqdm.auto import tqdm
from vllm import LLM, SamplingParams
import datasets
import torch
import numpy as np

from tasks import TASKS, ISSUE_TASKS
from utils import generate_prompts
from evaluation import EXACT_MATCH_BALANCED_ACC_TASKS, evaluate

  from .autonotebook import tqdm as notebook_tqdm


INFO 02-28 03:42:45 __init__.py:183] Automatically detected platform cuda.


2025-02-28 03:42:45,379	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [7]:
existing = pd.read_json('fp32_saul-7b_evaluation_em-full.jsonl', lines=True)
running = pd.read_json('Saul-7B-Base_evaluations.jsonl', lines=True)

In [26]:
total = pd.concat((
    running[running.metric_name == 'task_specific'][['metric_name', 'score', 'pred', 'label', 'task']],
    existing
))

In [27]:
available_exact_match_tasks = copy.deepcopy(EXACT_MATCH_BALANCED_ACC_TASKS)
# It needs a separtate metric
index = available_exact_match_tasks.index("successor_liability")
del available_exact_match_tasks[index]

# This task doesn't even exist
index = available_exact_match_tasks.index("intra_rule_distinguishing")
del available_exact_match_tasks[index]

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
other_tasks = [
    'sara_numeric',
    'successor_liability',
    'citation_prediction_open',
    'definition_extraction',
    'ssla_company_defendants',
    'ssla_individual_defendants',
    'ssla_plaintiff',
]
all_tasks = sorted(other_tasks + available_exact_match_tasks)

In [49]:
token_counts = []

for task in tqdm(all_tasks):
    with open(f"tasks/{task}/base_prompt.txt") as in_file:
        prompt_template = in_file.read()

    dataset = datasets.load_dataset("nguha/legalbench", task)
    test_df = dataset["test"].to_pandas()
    answers = test_df["answer"].tolist()
    
    prompts = generate_prompts(prompt_template=prompt_template, data_df=test_df)
    prompt_tokens = [len(tokenizer.encode(text)) for text in prompts]
    answer_tokens = [len(tokenizer.encode(answer)) for answer in answers]
    full_tokens = [a + b for a, b in zip(prompt_tokens, answer_tokens)]
    token_counts.append((task, prompt_tokens, answer_tokens, full_tokens))

token_count_df = pd.DataFrame(token_counts, columns=['task', 'num_prompt_tokens', 'num_answer_tokens', 'num_full_tokens'])
total.sort_values('task', ascending=True, inplace=True)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 161/161 [03:19<00:00,  1.24s/it]


In [50]:
token_count_df

Unnamed: 0,task,num_prompt_tokens,num_answer_tokens,num_full_tokens
0,sara_numeric,"[7892, 7902, 7891, 7892, 7899, 7861, 7915, 786...","[4, 4, 4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, ...","[7896, 7906, 7895, 7897, 7903, 7865, 7919, 787..."
1,successor_liability,"[493, 494, 521, 530, 517, 528, 528, 562, 543, ...","[7, 7, 7, 7, 7, 7, 7, 7, 7, 3, 3, 3, 9, 3, 3, ...","[500, 501, 528, 537, 524, 535, 535, 569, 550, ..."
2,citation_prediction_open,"[196, 178, 208, 184, 200, 182, 179, 208, 232, ...","[9, 13, 14, 6, 7, 7, 14, 9, 16, 13, 10, 20, 13...","[205, 191, 222, 190, 207, 189, 193, 217, 248, ..."
3,definition_extraction,"[728, 691, 729, 803, 753, 716, 768, 799, 767, ...","[3, 2, 3, 4, 4, 3, 4, 3, 2, 3, 2, 6, 3, 2, 3, ...","[731, 693, 732, 807, 757, 719, 772, 802, 769, ..."
4,ssla_company_defendants,"[2052, 1847, 2064, 2077, 2088, 2052, 2101, 202...","[4, 7, 7, 9, 8, 7, 5, 6, 8, 6, 14, 8, 5, 5, 16...","[2056, 1854, 2071, 2086, 2096, 2059, 2106, 203..."
...,...,...,...,...
156,telemarketing_sales_rule,"[450, 422, 439, 438, 468, 432, 443, 436, 453, ...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[452, 424, 441, 440, 470, 434, 445, 438, 455, ..."
157,textualism_tool_dictionaries,"[1356, 1312, 1306, 1416, 1307, 1347, 1319, 156...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1358, 1314, 1308, 1418, 1309, 1349, 1321, 156..."
158,textualism_tool_plain,"[1273, 871, 866, 924, 1063, 974, 1151, 1012, 9...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1275, 873, 868, 926, 1065, 976, 1153, 1014, 9..."
159,ucc_v_common_law,"[352, 354, 349, 346, 349, 358, 347, 350, 346, ...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[355, 357, 352, 349, 352, 361, 350, 353, 349, ..."


In [53]:
final = pd.merge(total, token_count_df, on='task')

In [55]:
final.to_parquet('fp32_saul-7b_legalbench.parquet')