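This notebook provides a basic illustration of how to use different parts of LegalBench: it loads a Hugging Face model with vLLM, evaluates it on the exact-match LegalBench tasks and on the tasks that need task-specific metrics, and appends each task's score to a JSONL file.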

In [3]:
import json
import copy

from tqdm.auto import tqdm
from vllm import LLM, SamplingParams
import datasets
import pandas as pd
import torch
import numpy as np

from tasks import TASKS, ISSUE_TASKS
from utils import generate_prompts
from evaluation import EXACT_MATCH_BALANCED_ACC_TASKS, evaluate, evaluate_sara_numeric_acc, evaluate_successor_liability, evaluate_citation_open, evaluate_definition_extraction, evaluate_ssla

### Hugging Face Model Evaluation

In [None]:
# Run configuration. `model_name` is reused later to name the results file;
# `sampling_params` and `model` are used by the evaluation loops below.
hf_org = "Equall"
model_name = "Saul-7B-Base"
seed = 10
batch_size = 4

# Greedy decoding (temperature 0): generation stops at the first newline and
# is capped at 16 new tokens per prompt.
sampling_params = SamplingParams(
    max_tokens=16,
    stop=["\n"],
    seed=seed,
    temperature=0.0,
)

# Load "<org>/<name>" from the Hugging Face Hub in fp16.
# NOTE: trust_remote_code=True allows executing code shipped with the model repo.
model = LLM(
    model=f"{hf_org}/{model_name}",
    dtype="float16",
    seed=seed,
    max_num_seqs=batch_size,        # cap on sequences vLLM schedules per step
    gpu_memory_utilization=0.95,    # fraction of GPU memory vLLM may reserve
    trust_remote_code=True,
)

In [None]:
# Exact-match (balanced accuracy) tasks evaluated by the next cell: start from
# the full list and drop the entries this loop cannot handle. Filtering into a
# new list leaves the imported EXACT_MATCH_BALANCED_ACC_TASKS untouched, and,
# unlike `list.index` + `del`, does not raise if an entry is already absent.
excluded_exact_match_tasks = {
    "successor_liability",        # needs a separate metric (evaluate_successor_liability)
    "intra_rule_distinguishing",  # this task doesn't exist, so it cannot be loaded
}
available_exact_match_tasks = [
    task for task in EXACT_MATCH_BALANCED_ACC_TASKS
    if task not in excluded_exact_match_tasks
]

In [None]:
# Evaluate every exact-match task:
#   1. fill the task's base prompt template with each test example,
#   2. decode with vLLM and strip each completion,
#   3. score against the gold answers with `evaluate`,
#   4. append one JSON line per task to the results file.
# NOTE: the results file is opened in append mode, so re-running this cell adds
# duplicate rows; delete the file first for a clean run.
results_path = f"{model_name}_evaluations.jsonl"
for task in tqdm(available_exact_match_tasks):
    # Read as UTF-8 explicitly so the prompt text doesn't depend on the
    # platform's default encoding.
    with open(f"tasks/{task}/base_prompt.txt", encoding="utf-8") as in_file:
        prompt_template = in_file.read()

    dataset = datasets.load_dataset("nguha/legalbench", task)
    test_df = dataset["test"].to_pandas()
    answers = test_df["answer"].tolist()

    prompts = generate_prompts(prompt_template=prompt_template, data_df=test_df)
    outputs = model.generate(prompts, sampling_params)
    # Keep only the first completion for each prompt, without surrounding whitespace.
    generations = [o.outputs[0].text.strip() for o in outputs]
    assert len(generations) == len(answers), (
        f"{task}: got {len(generations)} generations for {len(answers)} answers"
    )

    accuracy = evaluate(task, generations, answers)
    # float() turns a NumPy scalar score into a plain float json.dumps can serialize.
    data = {"metric_name": "exact_match_balanced_accuracy", "score": float(accuracy), "task": task}
    with open(results_path, 'a', encoding='utf-8') as f:
        f.write(json.dumps(data, ensure_ascii=False) + '\n')

In [None]:
# Tasks scored with a task-specific metric rather than exact-match balanced
# accuracy. The three SSLA variants are all scored by `evaluate_ssla`.
available_other_tasks = [
    "sara_numeric",
    "successor_liability",
    "citation_prediction_open",
    "definition_extraction",
] + [f"ssla_{suffix}" for suffix in ("company_defendants", "individual_defendants", "plaintiff")]

In [None]:
# Evaluate the tasks that need a task-specific metric. Each task is mapped to
# its scorer up front. An unsupported task raises immediately, instead of
# failing with NameError or silently writing the previous task's `accuracy`.
# NOTE: results are appended, so re-running this cell duplicates rows.
other_task_evaluators = {
    "sara_numeric": evaluate_sara_numeric_acc,
    "successor_liability": evaluate_successor_liability,
    "citation_prediction_open": evaluate_citation_open,
    "definition_extraction": evaluate_definition_extraction,
}
results_path = f"{model_name}_evaluations.jsonl"
for task in tqdm(available_other_tasks):
    # Resolve the scorer before doing any (expensive) generation work.
    if task in other_task_evaluators:
        evaluate_fn = other_task_evaluators[task]
    elif task.startswith("ssla"):
        evaluate_fn = evaluate_ssla  # all ssla_* variants share one scorer
    else:
        raise ValueError(f"No task-specific evaluator registered for task {task!r}")

    # Read as UTF-8 explicitly so the prompt text doesn't depend on the
    # platform's default encoding.
    with open(f"tasks/{task}/base_prompt.txt", encoding="utf-8") as in_file:
        prompt_template = in_file.read()

    dataset = datasets.load_dataset("nguha/legalbench", task)
    test_df = dataset["test"].to_pandas()
    answers = test_df["answer"].tolist()

    prompts = generate_prompts(prompt_template=prompt_template, data_df=test_df)
    outputs = model.generate(prompts, sampling_params)
    # Keep only the first completion for each prompt, without surrounding whitespace.
    generations = [o.outputs[0].text.strip() for o in outputs]

    accuracy = evaluate_fn(generations, answers)

    data = {"metric_name": "task_specific", "score": accuracy, "task": task}
    with open(results_path, 'a', encoding='utf-8') as f:
        f.write(json.dumps(data, ensure_ascii=False) + '\n')