In [30]:
import sys
import os
import numpy as np

SCRIPT_DIR = os.path.dirname(os.path.abspath("..."))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    T5ForConditionalGeneration
)

# Load goldern dataset that is unseen in train

In [31]:

test_data = load_from_disk("/opt/home/bo_ling/dataset/gds_v4_t5.hf")
test_data

Dataset({
    features: ['input_text', 'output_text'],
    num_rows: 182662
})

In [32]:
test_data[0]

{'input_text': 'Extract dob from the following input: California DRIVER LICENSE 081202 LN SOTO MUNGUIA FN ALEC PIERRE 9939 VAN RUITEN ST BELLFLOWER, CA 90706 SEX M AM DONOR USA DLY7679862 EXP 08/12/2025 DOB 08/12/2002 AGE 2 IN 7023 DD 08/04/2021606A3/E5FD/25 CLASS C END NONE RSTR NONE 08122002 HAIR BRN EYES HZL HGT 5\'-11" WGT 205 lb FEDERAL LIMITS APPLY ISS 10/14/2021',
 'output_text': '\n2002-08-12'}

# Load checkpoint at step: 19600 (12% of total data in first epoch) 

In [33]:
local_output_dir="/opt/home/bo_ling/dolly_training/t5_doc_transcript/checkpoint-31800"
#model, tokenizer = load_model_tokenizer_for_generate(local_output_dir)
import torch
tokenizer = AutoTokenizer.from_pretrained(local_output_dir)
device = "cuda:2" if torch.cuda.is_available() else "cpu"
model = T5ForConditionalGeneration.from_pretrained(local_output_dir, trust_remote_code=True).to(device)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


In [34]:
def generate_helpdesk_response(
    input_text: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    PAD_TOKEN = "<pad>",
    EOS_TOKEN = "</s>",
    do_sample: bool = True,
    max_new_tokens: int = 20,
    top_p: float = 0.85,
    top_k: int = 0,
    **kwargs,
) -> str:
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    end_key_token_id = tokenizer.encode(EOS_TOKEN)[0]

    gen_tokens = model.generate(
        input_ids,
        pad_token_id=tokenizer.pad_token_id,
        # Ensure generation stops once it generates "### End"
        eos_token_id=end_key_token_id,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        **kwargs,
    )[0].cpu()

    decoded = tokenizer.decode(gen_tokens).strip(PAD_TOKEN).strip(EOS_TOKEN).strip()

    return decoded

# Print some sample output for human to read

In [35]:
count = 0
for d in test_data:
    input_text= d["input_text"]
    generated = generate_helpdesk_response(input_text, model, tokenizer)
    expected = d['output_text']
    print("="*100)
    print("="*100)
    print("\nINPUT:")
    print(input_text)
    print("\nGENERATED:")
    print(generated)
    print("\nEXPECTED:")
    print(expected)
    count += 1
    if count > 50:
        break


INPUT:
Extract dob from the following input: California DRIVER LICENSE 081202 LN SOTO MUNGUIA FN ALEC PIERRE 9939 VAN RUITEN ST BELLFLOWER, CA 90706 SEX M AM DONOR USA DLY7679862 EXP 08/12/2025 DOB 08/12/2002 AGE 2 IN 7023 DD 08/04/2021606A3/E5FD/25 CLASS C END NONE RSTR NONE 08122002 HAIR BRN EYES HZL HGT 5'-11" WGT 205 lb FEDERAL LIMITS APPLY ISS 10/14/2021

GENERATED:
2002-08-12

EXPECTED:

2002-08-12

INPUT:
Extract issue date from the following input: California DRIVER LICENSE 081202 LN SOTO MUNGUIA FN ALEC PIERRE 9939 VAN RUITEN ST BELLFLOWER, CA 90706 SEX M AM DONOR USA DLY7679862 EXP 08/12/2025 DOB 08/12/2002 AGE 2 IN 7023 DD 08/04/2021606A3/E5FD/25 CLASS C END NONE RSTR NONE 08122002 HAIR BRN EYES HZL HGT 5'-11" WGT 205 lb FEDERAL LIMITS APPLY ISS 10/14/2021

GENERATED:
2021-10-14

EXPECTED:

2021-10-14

INPUT:
Extract expiration date from the following input: California DRIVER LICENSE 081202 LN SOTO MUNGUIA FN ALEC PIERRE 9939 VAN RUITEN ST BELLFLOWER, CA 90706 SEX M AM DONO


INPUT:
Extract issue date from the following input: IDAHO Плисей ашки The Gem State DRIVER'S LICENSE 4d 4a Iss 4b Exp 9 Class D 9a End NONE 12 Rest B 1 DAWSON 2 REBECCA LYNN 3 DOB 8 1797 S MARSH WOOD PL MERIDIAN ID 83642-7456 15 Sex 16 Hot 17Wgt 18 Eyes 19 Hair F 5'-06" 140 lb BLU BLN 5 DD 010212930034 ZE3329031 10/20/2021 10/19/2025 10/19/1980

GENERATED:
2021-10-20

EXPECTED:

2021-10-20

INPUT:
Extract expiration date from the following input: IDAHO Плисей ашки The Gem State DRIVER'S LICENSE 4d 4a Iss 4b Exp 9 Class D 9a End NONE 12 Rest B 1 DAWSON 2 REBECCA LYNN 3 DOB 8 1797 S MARSH WOOD PL MERIDIAN ID 83642-7456 15 Sex 16 Hot 17Wgt 18 Eyes 19 Hair F 5'-06" 140 lb BLU BLN 5 DD 010212930034 ZE3329031 10/20/2021 10/19/2025 10/19/1980

GENERATED:
2025-10-19

EXPECTED:

2025-10-19

INPUT:
Extract first name from the following input: IDAHO Плисей ашки The Gem State DRIVER'S LICENSE 4d 4a Iss 4b Exp 9 Class D 9a End NONE 12 Rest B 1 DAWSON 2 REBECCA LYNN 3 DOB 8 1797 S MARSH WOOD PL MER


INPUT:
Extract issue date from the following input: CALIFORNIA DRIVER LICENSE DL D5327317 EXP 11/02/2022 LN AUGUST FN DANNY TALBERT JR 2882 WISHING WAY CERES, CA 95307 DOB 11/02/1983 RSTR NONE Det toget CLASS C END NONE SEX M HAIR BLK HGT 5-10" WGT 255 lb DD 07/25/2013557RB/BBFD/22 11021983 EYES BRN ISS 01/12/2018

GENERATED:
2018-01-12

EXPECTED:

2018-01-12

INPUT:
Extract expiration date from the following input: CALIFORNIA DRIVER LICENSE DL D5327317 EXP 11/02/2022 LN AUGUST FN DANNY TALBERT JR 2882 WISHING WAY CERES, CA 95307 DOB 11/02/1983 RSTR NONE Det toget CLASS C END NONE SEX M HAIR BLK HGT 5-10" WGT 255 lb DD 07/25/2013557RB/BBFD/22 11021983 EYES BRN ISS 01/12/2018

GENERATED:
2022-11-02

EXPECTED:

2022-11-02

INPUT:
Extract first name from the following input: CALIFORNIA DRIVER LICENSE DL D5327317 EXP 11/02/2022 LN AUGUST FN DANNY TALBERT JR 2882 WISHING WAY CERES, CA 95307 DOB 11/02/1983 RSTR NONE Det toget CLASS C END NONE SEX M HAIR BLK HGT 5-10" WGT 255 lb DD 07/25/201


INPUT:
Extract last name from the following input: MASSACHUSETTS Colleen Jilmie REGISTRAR LICENSES NOT FOR FEDERAL IDEA 4a ISS Stypher Watch DRIVER'S 04/12/2022 04/26/2027 29 CLASS 12 REST DM B 4b EXP WEBB 2 STEPHEN H 4d NUMBER 8 35 WICKLOW ST APT 1 MALDEN, MA 02148-6317 18 EYES BLU 15 SEX M 16 HGT 6'-00" 5 DD 04/15/2022 Rev 02/22/2016 S83805821 04/26/1952 9a END NONE DOB 04/26/52

GENERATED:
Webb

EXPECTED:

WEBB

INPUT:
Extract license class from the following input: MASSACHUSETTS Colleen Jilmie REGISTRAR LICENSES NOT FOR FEDERAL IDEA 4a ISS Stypher Watch DRIVER'S 04/12/2022 04/26/2027 29 CLASS 12 REST DM B 4b EXP WEBB 2 STEPHEN H 4d NUMBER 8 35 WICKLOW ST APT 1 MALDEN, MA 02148-6317 18 EYES BLU 15 SEX M 16 HGT 6'-00" 5 DD 04/15/2022 Rev 02/22/2016 S83805821 04/26/1952 9a END NONE DOB 04/26/52

GENERATED:
DM

EXPECTED:

DM

INPUT:
Extract drivers license number from the following input: MASSACHUSETTS Colleen Jilmie REGISTRAR LICENSES NOT FOR FEDERAL IDEA 4a ISS Stypher Watch DRIVER'

In [36]:
def normalize_str(s):
    return s.replace("\n", "").replace(":", "").replace("<pad>", "").replace("</s>", "").strip().lower()
normalize_str("OSCAR") == normalize_str("\n\nOscar ")

True

#  Compute accuracy based on instruction on goldern dataset

In [None]:
count = 0
statistics = {}
for d in test_data:
    input_text= d["input_text"]
    generated = generate_helpdesk_response(input_text, model, tokenizer)
    expected = d['output_text']
    if "input:" in input_text:
        instruction = input_text[:input_text.find("input:")] + "input:"
    else:
        instruction = "Is the driving license valid for identification?"
    if instruction in statistics:
        stat = statistics[instruction]
    else:
        stat = {"eq": 0, "neq": 0}
        statistics[instruction] = stat
    if normalize_str(generated) == normalize_str(expected):
        stat["eq"] += 1
    else:
        stat["neq"] += 1
    count += 1
    if count > 10000:
        break

In [None]:
print(f"******************** Performance for doc transcript pii *****************\n")
for k ,v in statistics.items():
    v["accuracy"] = v["eq"] / (v["eq"] + v['neq'])
    print(f"=== The accuracy for task `{k}`: {v['accuracy']} on {(v['eq'] + v['neq'])}  samples ===")

In [9]:
print(f"******************** Performance for doc transcript pii *****************\n")
for k ,v in statistics.items():
    v["accuracy"] = v["eq"] / (v["eq"] + v['neq'])
    print(f"=== The accuracy for task `{k}`: {v['accuracy']} on {(v['eq'] + v['neq'])}  samples ===")

******************** Performance for doc transcript pii *****************

=== The accuracy for task `Extract dob from the following input:`: 0.9641255605381166 on 223  samples ===
=== The accuracy for task `Extract issue date from the following input:`: 0.8430493273542601 on 223  samples ===
=== The accuracy for task `Extract expiration date from the following input:`: 0.9865470852017937 on 223  samples ===
=== The accuracy for task `Extract first name from the following input:`: 0.8340807174887892 on 223  samples ===
=== The accuracy for task `Extract middle name from the following input:`: 0.8342245989304813 on 187  samples ===
=== The accuracy for task `Extract last name from the following input:`: 0.6771300448430493 on 223  samples ===
=== The accuracy for task `Extract license class from the following input:`: 0.9417040358744395 on 223  samples ===
=== The accuracy for task `Extract drivers license number from the following input:`: 0.7130044843049327 on 223  samples ===
=== The 