In [2]:
import itertools
import logging
import os
import sys
from typing import Optional

import numpy as np
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

# `os.path.abspath("...")` is "<cwd>/...", so SCRIPT_DIR is the working directory;
# its parent is appended to sys.path so the local `training` package below imports.
SCRIPT_DIR = os.path.dirname(os.path.abspath("..."))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from training.transcript_trainer import PROMPT_FORMAT, RESPONSE_KEY_NL, END_KEY

# Load the golden dataset (not seen during training)

In [5]:

# Golden evaluation dataset; the bare name at the end displays its features and row count.
golden_dataset_path = "/opt/home/bo_ling/dataset/gds_v4_simplify.hf"
test_data = load_from_disk(golden_dataset_path)
test_data

Dataset({
    features: ['instruction', 'input', 'output', 'text'],
    num_rows: 182662
})

In [6]:
# Inspect one golden example (instruction / input / output / text).
# NOTE: rows contain raw ID-document text (PII) -- clear this output before sharing.
first_example = test_data[0]
first_example

{'instruction': 'Extract dob from the following input:',
 'input': 'California DRIVER LICENSE 081202 LN SOTO MUNGUIA FN ALEC PIERRE 9939 VAN RUITEN ST BELLFLOWER, CA 90706 SEX M AM DONOR USA DLY7679862 EXP 08/12/2025 DOB 08/12/2002 AGE 2 IN 7023 DD 08/04/2021606A3/E5FD/25 CLASS C END NONE RSTR NONE 08122002 HAIR BRN EYES HZL HGT 5\'-11" WGT 205 lb FEDERAL LIMITS APPLY ISS 10/14/2021',
 'output': '\n2002-08-12',
 'text': 'Extract dob from the following input:\nCalifornia DRIVER LICENSE 081202 LN SOTO MUNGUIA FN ALEC PIERRE 9939 VAN RUITEN ST BELLFLOWER, CA 90706 SEX M AM DONOR USA DLY7679862 EXP 08/12/2025 DOB 08/12/2002 AGE 2 IN 7023 DD 08/04/2021606A3/E5FD/25 CLASS C END NONE RSTR NONE 08122002 HAIR BRN EYES HZL HGT 5\'-11" WGT 205 lb FEDERAL LIMITS APPLY ISS 10/14/2021\n\n### Response:\n\n2002-08-12\n'}

# Load checkpoint at step 64800 (`checkpoint-64800`; an earlier revision of this heading referred to step 19600, about 12% of the data in the first epoch)

In [8]:
import torch

local_output_dir = "/opt/home/bo_ling/dolly_training/doc_transcript_pii_data_simplify_b1/checkpoint-64800"

# Inference is pinned to CPU. To use a GPU, set this to
# "cuda:0" if torch.cuda.is_available() else "cpu".
device = "cpu"

# Tokenizer and weights both come from the fine-tuned checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(local_output_dir, padding_side="left")
# trust_remote_code=True lets transformers run Python code shipped with the
# checkpoint; only load checkpoints you trust this way.
model = AutoModelForCausalLM.from_pretrained(local_output_dir, trust_remote_code=True)
model = model.to(device)

In [9]:
def generate_transcript_response(
    instruction: str,
    input_text: str,
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    do_sample: bool = True,
    max_new_tokens: int = 256,
    top_p: float = 0.92,
    top_k: int = 0,
    **kwargs,
) -> Optional[str]:
    """Generate the model's answer for one instruction/input pair.

    The prompt is built with the training-time ``PROMPT_FORMAT`` (with an empty
    output), passed to ``model.generate``, and the tokens between the
    response-key token and the following end-key token are decoded and stripped.

    Args:
        instruction: Task instruction, e.g. "Extract dob from the following input:".
        input_text: Document text the instruction applies to.
        model: Causal LM used for generation; input ids are moved to ``model.device``.
        tokenizer: Tokenizer matching ``model``.
        do_sample: Sample (True) or decode greedily (False). ``top_p`` and
            ``top_k`` are forwarded to ``generate`` only when sampling.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (sampling only).
        top_k: Top-k cutoff (sampling only).
        **kwargs: Extra keyword arguments passed through to ``model.generate``.

    Returns:
        The decoded, stripped response; the empty string if the response is empty;
        or ``None`` if the response-key token does not occur in the output.
    """
    prompt = PROMPT_FORMAT.format(instruction=instruction, input_text=input_text, output_text="")
    # Follow the model's device instead of relying on a notebook-global `device`.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # Only the first token of each marker is used. This assumes the trainer registered
    # RESPONSE_KEY_NL / END_KEY so they encode to a single leading token -- TODO confirm
    # in training.transcript_trainer.
    response_key_token_id = tokenizer.encode(RESPONSE_KEY_NL)[0]
    end_key_token_id = tokenizer.encode(END_KEY)[0]

    # top_p / top_k are only meaningful when sampling; recent transformers versions
    # warn when they are set alongside do_sample=False.
    sampling_kwargs = {"top_p": top_p, "top_k": top_k} if do_sample else {}

    gen_tokens = model.generate(
        input_ids,
        pad_token_id=tokenizer.pad_token_id,
        # Stop generating once the end-key token is produced.
        eos_token_id=end_key_token_id,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        **sampling_kwargs,
        **kwargs,
    )[0].cpu()
    token_ids = gen_tokens.tolist()

    # The response key is part of the prompt, so it is expected to be present.
    # Compare against "not found" explicitly: a match at index 0 is still a match.
    try:
        response_pos = token_ids.index(response_key_token_id)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Could not find response key %s in: %s", response_key_token_id, gen_tokens
        )
        return None

    # The answer ends at the first end-key token *after* the response key (generate
    # keeps the eos token in its output). It can be missing when generation hit
    # max_new_tokens; then decode everything to the end.
    try:
        end_pos = token_ids.index(end_key_token_id, response_pos + 1)
    except ValueError:
        end_pos = None

    return tokenizer.decode(gen_tokens[response_pos + 1 : end_pos]).strip()

# Print sample outputs for human review

In [10]:
# Eyeball the first 21 golden examples: instruction, input, model output, reference.
separator = "=" * 100
count = 0
for d in test_data:
    instruction, input_text = d["instruction"], d["input"]
    generated = generate_transcript_response(instruction, input_text, model, tokenizer)
    expected = d["output"]
    print(separator, separator, sep="\n")
    sections = (
        ("INSTRUCTION:", instruction),
        ("\nINPUT:", input_text),
        ("\nGENERATED:", generated),
        ("\nEXPECTED:", expected),
    )
    for header, body in sections:
        print(header)
        print(body)
    count += 1
    if count > 20:
        break

INSTRUCTION:
Extract dob from the following input:

INPUT:
California DRIVER LICENSE 081202 LN SOTO MUNGUIA FN ALEC PIERRE 9939 VAN RUITEN ST BELLFLOWER, CA 90706 SEX M AM DONOR USA DLY7679862 EXP 08/12/2025 DOB 08/12/2002 AGE 2 IN 7023 DD 08/04/2021606A3/E5FD/25 CLASS C END NONE RSTR NONE 08122002 HAIR BRN EYES HZL HGT 5'-11" WGT 205 lb FEDERAL LIMITS APPLY ISS 10/14/2021

GENERATED:
2002-08-12-12

EXPECTED:

2002-08-12
INSTRUCTION:
Extract issue date from the following input:

INPUT:
California DRIVER LICENSE 081202 LN SOTO MUNGUIA FN ALEC PIERRE 9939 VAN RUITEN ST BELLFLOWER, CA 90706 SEX M AM DONOR USA DLY7679862 EXP 08/12/2025 DOB 08/12/2002 AGE 2 IN 7023 DD 08/04/2021606A3/E5FD/25 CLASS C END NONE RSTR NONE 08122002 HAIR BRN EYES HZL HGT 5'-11" WGT 205 lb FEDERAL LIMITS APPLY ISS 10/14/2021

GENERATED:
2021-10-14

EXPECTED:

2021-10-14
INSTRUCTION:
Extract expiration date from the following input:

INPUT:
California DRIVER LICENSE 081202 LN SOTO MUNGUIA FN ALEC PIERRE 9939 VAN RU

INSTRUCTION:
Extract issue date from the following input:

INPUT:
IDAHO Плисей ашки The Gem State DRIVER'S LICENSE 4d 4a Iss 4b Exp 9 Class D 9a End NONE 12 Rest B 1 DAWSON 2 REBECCA LYNN 3 DOB 8 1797 S MARSH WOOD PL MERIDIAN ID 83642-7456 15 Sex 16 Hot 17Wgt 18 Eyes 19 Hair F 5'-06" 140 lb BLU BLN 5 DD 010212930034 ZE3329031 10/20/2021 10/19/2025 10/19/1980

GENERATED:
2021-10-20

EXPECTED:

2021-10-20
INSTRUCTION:
Extract expiration date from the following input:

INPUT:
IDAHO Плисей ашки The Gem State DRIVER'S LICENSE 4d 4a Iss 4b Exp 9 Class D 9a End NONE 12 Rest B 1 DAWSON 2 REBECCA LYNN 3 DOB 8 1797 S MARSH WOOD PL MERIDIAN ID 83642-7456 15 Sex 16 Hot 17Wgt 18 Eyes 19 Hair F 5'-06" 140 lb BLU BLN 5 DD 010212930034 ZE3329031 10/20/2021 10/19/2025 10/19/1980

GENERATED:
2025-10-19

EXPECTED:

2025-10-19
INSTRUCTION:
Extract first name from the following input:

INPUT:
IDAHO Плисей ашки The Gem State DRIVER'S LICENSE 4d 4a Iss 4b Exp 9 Class D 9a End NONE 12 Rest B 1 DAWSON 2 REBECC

In [11]:
def normalize_str(s):
    """Canonicalize an answer string for exact-match comparison.

    Deletes every newline character (including interior ones), trims
    surrounding whitespace, and lowercases the result.
    """
    without_newlines = s.translate({ord("\n"): None})
    return without_newlines.strip().lower()


normalize_str("OSCAR") == normalize_str("\n\nOscar ")

True

# Compute per-instruction accuracy on the golden dataset

In [None]:
# Score at most this many golden examples; generation runs one example at a time.
N_EVAL_SAMPLES = 3000

# Per-instruction tallies: {instruction: {"eq": matches, "neq": mismatches}}.
statistics = {}
for d in itertools.islice(test_data, N_EVAL_SAMPLES):
    instruction = d["instruction"]
    # Greedy decoding makes the accuracy reproducible; the function samples by default.
    generated = generate_transcript_response(
        instruction, d["input"], model, tokenizer, do_sample=False
    )
    stat = statistics.setdefault(instruction, {"eq": 0, "neq": 0})
    # `generated` is None when the response marker was not found: count it as a
    # miss instead of crashing inside normalize_str.
    if generated is not None and normalize_str(generated) == normalize_str(d["output"]):
        stat["eq"] += 1
    else:
        stat["neq"] += 1

In [None]:
# Report match rate per instruction; also stores it under "accuracy" in each tally.
print("******************** Performance for doc transcript pii *****************\n")
for task, counts in statistics.items():
    n_total = counts["eq"] + counts["neq"]
    counts["accuracy"] = counts["eq"] / n_total
    print(f"=== The accuracy for task `{task}`: {counts['accuracy']} on {n_total}  samples ===")