In [1]:
import pandas as pd
import copy
import torch
from collections import Counter
from typing import List, Dict, Tuple
from datasets import load_dataset, Dataset
from dotenv import load_dotenv
# Use a pipeline as a high-level helper
from transformers import pipeline
from tqdm.notebook import tqdm

load_dotenv()
pipe = pipeline("token-classification", model="obi/deid_bert_i2b2")

Device set to use mps:0


In [2]:
dataset = load_dataset("mks-logic/SPY", trust_remote_code=True, faker_random_seed=42)

In [3]:
for x in dataset['medical_consultations']:
    # print(x)
    break
print(dataset.column_names)
lst_of_ent_tags = dataset['medical_consultations']['ent_tags']
lst_of_tokens = dataset['medical_consultations']['tokens']
lst_of_trailing_whitespaces = dataset['medical_consultations']['trailing_whitespaces']
lst_of_labels = dataset['medical_consultations']['labels']
print(type(lst_of_ent_tags[0][2]))
print(type(lst_of_tokens[0][2]))
print(type(lst_of_trailing_whitespaces[0][2]))
print(type(lst_of_labels[0][2]))

{'legal_questions': ['tokens', 'trailing_whitespaces', 'labels', 'ent_tags'], 'medical_consultations': ['tokens', 'trailing_whitespaces', 'labels', 'ent_tags']}
<class 'str'>
<class 'str'>
<class 'bool'>
<class 'int'>


In [4]:
print(lst_of_tokens[0])
print(lst_of_tokens[0].index("Christopher"))
print(lst_of_trailing_whitespaces[0])
print(lst_of_labels[0])

['Text', ': ', 'Hi', ', ', 'I ', 'am ', 'Christopher', 'Murillo', ', ', '\n', 'I ', 'am ', 'experiencing ', 'a ', 'sharp ', 'pain ', 'in ', 'the ', 'lower ', 'right ', 'abdomen', ', ', 'occasionally ', 'radiating ', 'to ', 'my ', 'back', '. ', 'This ', 'symptom ', 'has ', 'been ', 'persistent ', 'for ', 'the ', 'past ', '4 ', 'days ', 'now', '. ', 'The ', 'pain ', 'is ', 'worsening ', 'over ', 'time', ', ', 'especially ', 'after ', 'eating ', 'or ', 'engaging ', 'in ', 'physical ', 'activities', '. ', 'I ', 'have ', 'noticed ', 'minor ', 'nausea', ', ', 'but ', 'no ', 'vomiting ', 'or ', 'bleeding', '. ', 'The ', 'pain ', 'is ', 'non', '-', 'specific', ', ', 'meaning ', 'it ', 'does', "n't ", 'seem ', 'to ', 'be ', 'triggered ', 'by ', 'a ', 'specific ', 'food ', 'intake ', 'or ', 'any ', 'other ', 'activity', '. ', '\n\n', 'You ', 'can ', 'reach ', 'me ', 'at ', 'alvarezkenneth@gmail.com', 'or ', '493-290-9635', 'for ', 'any ', 'clarification ', 'or ', 'follow', '-', 'up ', 'questions

In [5]:
from typing import List

def reconstruct_text_from_tokens(tokens: List[str], has_trailing_space: List[bool], labels: List[int]) -> str:
    """
    Reconstructs a text string from a list of tokens and a list of trailing whitespace indicators.

    Args:
        tokens (List[str]): List of token strings.
        has_trailing_space (List[bool]): List of booleans where True means a space should follow the token.

    Returns:
        str: The reconstructed string with appropriate single-space separations.
    
    Raises:
        ValueError: If the lengths of tokens and has_trailing_space do not match.
    """
    if len(tokens) != len(has_trailing_space):
        raise ValueError("Length of tokens and has_trailing_space must be equal.")
    
    pieces = []
    for token, has_space, label in zip(tokens, has_trailing_space, labels):
        pieces.append(token)
        if (label != 14):
            pieces.append(" ")

    return "".join(pieces)

In [6]:
t = reconstruct_text_from_tokens(lst_of_tokens[0],lst_of_trailing_whitespaces[0],lst_of_labels[0])
print(t)

Text: Hi, I am Christopher Murillo , 
I am experiencing a sharp pain in the lower right abdomen, occasionally radiating to my back. This symptom has been persistent for the past 4 days now. The pain is worsening over time, especially after eating or engaging in physical activities. I have noticed minor nausea, but no vomiting or bleeding. The pain is non-specific, meaning it doesn't seem to be triggered by a specific food intake or any other activity. 

You can reach me at alvarezkenneth@gmail.com or 493-290-9635 for any clarification or follow-up questions.

I have a history of appendicitis which was surgically removed when I was 17. I have also had several episodes of stomach ulcers in the past, for which I have taken antacids and antibiotics as prescribed by my previous doctor, Dr. Emily Patel, whose contact information is available at https://facebook.com/healthcareprofessionals. 

Currently, I am not taking any medications, but I have an upcoming doctor's appointment with Dr. Soph

In [7]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
model = AutoModelForTokenClassification.from_pretrained("obi/deid_roberta_i2b2")

In [8]:
lst_of_docs = []
for tok_lst, space_lst, label_lst in zip(lst_of_tokens, lst_of_trailing_whitespaces, lst_of_labels):
    doc = reconstruct_text_from_tokens(tok_lst, space_lst, label_lst)
    lst_of_docs.append(doc)
    
lst_of_tokenized_docs = []
for doc in lst_of_docs:
    tokenized_doc = tokenizer.tokenize(doc)

    if isinstance(tokenized_doc, list):
        lst_of_tokenized_docs.append(tokenized_doc[:512])
    else:
        lst_of_tokenized_docs.append([tokenized_doc])

Token indices sequence length is longer than the specified maximum sequence length for this model (528 > 512). Running this sequence through the model will result in indexing errors


In [9]:
# lst_of_docs = [" ".join(lst)
#     for lst in lst_of_tokens
# ]
# lst_of_docs = [" ".join(doc.split())
#               for doc in lst_of_docs]
# print(lst_of_docs[1])

In [10]:
print(dataset['medical_consultations']['tokens'][0])

['Text', ': ', 'Hi', ', ', 'I ', 'am ', 'Christopher', 'Murillo', ', ', '\n', 'I ', 'am ', 'experiencing ', 'a ', 'sharp ', 'pain ', 'in ', 'the ', 'lower ', 'right ', 'abdomen', ', ', 'occasionally ', 'radiating ', 'to ', 'my ', 'back', '. ', 'This ', 'symptom ', 'has ', 'been ', 'persistent ', 'for ', 'the ', 'past ', '4 ', 'days ', 'now', '. ', 'The ', 'pain ', 'is ', 'worsening ', 'over ', 'time', ', ', 'especially ', 'after ', 'eating ', 'or ', 'engaging ', 'in ', 'physical ', 'activities', '. ', 'I ', 'have ', 'noticed ', 'minor ', 'nausea', ', ', 'but ', 'no ', 'vomiting ', 'or ', 'bleeding', '. ', 'The ', 'pain ', 'is ', 'non', '-', 'specific', ', ', 'meaning ', 'it ', 'does', "n't ", 'seem ', 'to ', 'be ', 'triggered ', 'by ', 'a ', 'specific ', 'food ', 'intake ', 'or ', 'any ', 'other ', 'activity', '. ', '\n\n', 'You ', 'can ', 'reach ', 'me ', 'at ', 'alvarezkenneth@gmail.com', 'or ', '493-290-9635', 'for ', 'any ', 'clarification ', 'or ', 'follow', '-', 'up ', 'questions

In [11]:
df = dataset['medical_consultations'].to_pandas()

In [12]:
df['text'] = df['tokens'].apply(lambda x: "".join(x))
dataset_txt=Dataset.from_pandas(df)

In [13]:
from transformers import AutoTokenizer

tokenizer = pipe.tokenizer  # get tokenizer from your existing pipeline

results = []
for inp in tqdm(lst_of_docs):
    assert isinstance(inp, str), "not a str"
    
    # Truncate the input at token level
    inputs = tokenizer(
        inp,
        truncation=True,
        max_length=512,
        return_tensors='pt',
        return_attention_mask=True
    )

    # Decode back to string after truncation
    truncated_text = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)

    # Pass the truncated string to the pipeline
    out = pipe(truncated_text)
    results.append(out)

  0%|          | 0/4491 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [14]:
print(results[1])

[{'entity': 'B-PATIENT', 'score': np.float32(0.99694294), 'index': 6, 'word': 'Jacqueline', 'start': 13, 'end': 23}, {'entity': 'L-PATIENT', 'score': np.float32(0.9817825), 'index': 7, 'word': 'Adams', 'start': 24, 'end': 29}, {'entity': 'U-STAFF', 'score': np.float32(0.8695738), 'index': 107, 'word': 'ch', 'start': 465, 'end': 467}, {'entity': 'U-PATIENT', 'score': np.float32(0.33063716), 'index': 185, 'word': 'mill', 'start': 788, 'end': 792}, {'entity': 'B-LOC', 'score': np.float32(0.9791631), 'index': 208, 'word': '67', 'start': 895, 'end': 897}, {'entity': 'B-LOC', 'score': np.float32(0.58527845), 'index': 209, 'word': '##7', 'start': 897, 'end': 898}, {'entity': 'I-LOC', 'score': np.float32(0.48647797), 'index': 210, 'word': '##1', 'start': 898, 'end': 899}, {'entity': 'I-LOC', 'score': np.float32(0.9896563), 'index': 211, 'word': 'Johns', 'start': 900, 'end': 905}, {'entity': 'L-LOC', 'score': np.float32(0.8167179), 'index': 212, 'word': 'Shore', 'start': 906, 'end': 911}, {'ent

In [30]:
# lst_of_lst_of_dict = copy.deepcopy(lst_of_ent_tags)
# tn,tp = 0,0
# fn,fp = 0,0
# for i, lst in tqdm(enumerate(results)):
#     for dictionary in lst:
#         start_idx = dictionary['start']
#         end_idx = dictionary ['end']
#         doc_i = lst_of_docs[i]
#         tok = doc_i[start_idx:end_idx+1]
#         ent_tags = lst_of_ent_tags[i]
#         entity = dictionary['entity']
#         doc_tok = lst_of_tokens[i]
#         try:
            
#             tok_idx = doc_tok.index(tok.strip())
#         except ValueError as ve:
#             fp += 1
#             continue
# # print(start_idx)
# # print(end_idx)
# # print(doc_i)
# # print(tok)
# # print(ent_tags)
# # print(entity)
# # print(doc_tok)
# # raise ValueError(f'{start_idx} oh no') from ve

#         try:
            
#             tag = ent_tags[tok_idx]
#         except ValueError as ve:
#             print(doc_tok)
#             print(ent_tags)
#             assert len(doc_tok) == len(ent_tags), f'{len(doc_tok)} != {len(ent_tags)}'
#         is_pii = tag != 'O'
#         if is_pii: 
#             tp += 1 
#         else:
            
#             fp += 1
#         # print(tag)
#         # lst_of_lst_of_dict = 
#         lst_of_lst_of_dict[i].remove(tag)
#     for tag in lst_of_lst_of_dict[i]: 
#         if tag == 'O':
            
#             tn += 1
#         else:
             
#              fn += 1

In [15]:
lst_of_pii = []
for tok_lst, lab_lst in zip(lst_of_tokens, lst_of_labels):
    pii = []
    for tok, lab in zip(tok_lst, lab_lst):
        if lab != 14:
            pii.append(tok)
    lst_of_pii.append(pii)

In [16]:
lst_of_tokenized_pii = []
for lst in lst_of_pii:
    tok_pii = []
    for pii in lst: 
        tok = tokenizer.tokenize(pii)
        if isinstance(tok, list):
            for t in tok: 
                tok_pii.append(t)
        else:
            tok_pii.append(tok)
    lst_of_tokenized_pii.append(tok_pii)
#print(len(lst_of_tokenized_pii))
lst_of_tokenized_pii[0]

['Christopher',
 'Mu',
 '##rill',
 '##o',
 'al',
 '##var',
 '##ez',
 '##ken',
 '##net',
 '##h',
 '@',
 'g',
 '##mail',
 '.',
 'com',
 '49',
 '##3',
 '-',
 '290',
 '-',
 '96',
 '##35',
 '49',
 '##3',
 '-',
 '290',
 '-',
 '96',
 '##35',
 'heat',
 '##her',
 '##mart',
 '##ine',
 '##z',
 '82',
 '##23',
 'Victoria',
 'Row']

In [25]:
# print(lst_of_tokenized_pii[0])
# print()
# print(lst_of_tokenized_docs[0])

In [17]:
def normalize(token):
    return token.lstrip("##").lstrip("Ġ").strip()

In [46]:
tokenized_docs_copy = [Counter([normalize(tok) for tok in doc]) for doc in lst_of_tokenized_docs]
tokenized_pii_copy = [Counter([normalize(tok) for tok in pii]) for pii in lst_of_tokenized_pii]

fn = sum(sum(counter.values()) for counter in tokenized_pii_copy)
tn = sum(sum(counter.values()) for counter in tokenized_docs_copy)

In [47]:
tp, fp = 0, 0
for i, (model_output, true_pii_counter) in enumerate(zip(results, tokenized_pii_copy)):
    for entity in model_output:
        word = normalize(entity['word'])

        if true_pii_counter[word] > 0:
            true_pii_counter[word] -= 1
            tp += 1
        else:
            fp += 1

        if tokenized_docs_copy[i][word] > 0:
            tokenized_docs_copy[i][word] -= 1

In [48]:
fn = 0
for lst in lst_of_tokenized_pii:
    fn += len(lst)  
tn = 0
for lst in lst_of_tokenized_docs:
    tn += len(lst)
tn = tn - fn


In [49]:
len(dataset['medical_consultations']) * 512

2299392

In [52]:
tokenized_docs_copy = [Counter([normalize(tok) for tok in doc]) for doc in lst_of_tokenized_docs]
tokenized_pii_copy = [Counter([normalize(tok) for tok in pii]) for pii in lst_of_tokenized_pii]

tp, fp = 0, 0
for i, (model_output, true_pii_counter) in enumerate(zip(results, tokenized_pii_copy)):
    for entity in model_output:
        word = normalize(entity['word'])

        if true_pii_counter[word] > 0:
            true_pii_counter[word] -= 1
            tp += 1
        else:
            fp += 1

        if tokenized_docs_copy[i][word] > 0:
            tokenized_docs_copy[i][word] -= 1

fn = sum(sum(counter.values()) for counter in tokenized_pii_copy)
tn = sum(sum(counter.values()) for counter in tokenized_docs_copy) - fn

In [53]:
print(f"TP: {tp}\nTN: {tn}\nFP: {fp}\nFN: {fn}")



TP: 98133
TN: 1965164
FP: 79572
FN: 180711


In [51]:
tok

['MD']