## Import libraries

In [None]:
import os
import numpy as np
import evaluate
import pickle
import random
import tqdm
import matplotlib.pyplot as plt
from decimal import Decimal, getcontext
getcontext().prec = 64
import warnings
warnings.filterwarnings('ignore')
from scipy.special import rel_entr
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
from transformers import AutoModelForTokenClassification
from transformers import Trainer
from transformers import TrainingArguments
from transformers import get_scheduler
from huggingface_hub import Repository, get_full_repo_name
from accelerate import Accelerator
from tqdm.auto import tqdm
from datasets import *

## Create NER labels

In [None]:
# Entity types to recognise. The spelling 'VECHICLE' is reproduced as-is.
# NOTE(review): presumably it matches the annotation files — confirm before
# "correcting" it.
entity = [
    'PATIENT', 'DOCTOR', 'USERNAME',
    'PROFESSION',
    'ROOM', 'DEPARTMENT', 'HOSPITAL', 'ORGANIZATION', 'STREET', 'CITY',
    'STATE', 'COUNTRY', 'ZIP', 'LOCATION-OTHER',
    'AGE',
    'DATE', 'TIME', 'DURATION', 'SET',
    'PHONE', 'FAX', 'EMAIL', 'URL', 'IPADDR',
    'SSN', 'MEDICALRECORD', 'HEALTHPLAN', 'ACCOUNT', 'LICENSE', 'VECHICLE',
    'DEVICE', 'BIOID', 'IDNUM',
]

# Tag vocabulary: id 0 is 'OTHER', then one (B-, I-) pair per entity type,
# so every B- tag has an odd id and its I- tag is the next (even) id.
label_names = ['OTHER'] + [
    f'{prefix}-{ent}' for ent in entity for prefix in ('B', 'I')
]
entity_names = list(entity)
# Per-entity occurrence counter, one slot per entry of `entity`.
entity_count = [0 for _ in entity]

# Lookup tables in both directions, for BIO tags and for bare entity types.
id2label = dict(enumerate(label_names))
label2id = {tag: idx for idx, tag in id2label.items()}
org_id2label = dict(enumerate(entity_names))
org_label2id = {ent: idx for idx, ent in org_id2label.items()}

## Create dataset

In [None]:
def Spilt2Words(name, f, fa):
    """Split one document into whitespace-delimited words and label each word.

    The document is read from `f` one character at a time. Words are cut on
    space, newline and tab, and additionally at the start/end offsets of the
    entities listed in `fa`, so an entity does not share a word with the text
    around it.

    Args:
        name: Document id. It is compared with field 0 of the next answer line
            to detect that the answers for this document are exhausted.
        f: Open text file of the document. Offsets are counted from 0, so it
            is presumably positioned at the start of the file — confirm with
            the caller.
        fa: Open answer file, positioned at the first answer line of this
            document. Lines are tab-separated; the fields read here are
            [0] document id, [1] entity type, [2] start offset, [3] end offset.
            On return it is positioned at the first line that was not consumed.

    Returns:
        (tok, ner): the words that were read and the label ids (from
        `label2id`) assigned to them.

    Side effects:
        Increments the module-level `entity_count` slot of every processed
        answer line whose type is not 'OTHER'.

    NOTE(review): reading stops once the last entity of the document has been
    consumed, so text after that entity is not tokenized.
    NOTE(review): a word is always appended to `tok`, but no label is appended
    when the word starts after `ent_lidx` and ends after `ent_ridx`; in that
    case the two lists would stop being aligned — confirm this cannot happen
    with the real data.
    """
    tok = []
    ner = []
    # lidx / ridx: start offset and (exclusive) end offset of the current word.
    lidx = 0
    ridx = 0
    while True:
        # remove last '\n'
        ans_info = fa.readline()[:-1].split('\t')
        # remove normalized DATE/TIME
        if (ans_info[1] == 'DATE' or ans_info[1] == 'TIME'): ans_info = ans_info[:-1]
            
        if (ans_info[1] != 'OTHER'): entity_count[org_label2id[ans_info[1]]] += 1
            
        ent_lidx, ent_ridx = int(ans_info[2]), int(ans_info[3])

        # find next ans_info
        # (i.e. emit words until the one that ends exactly at ent_ridx)
        while True:
            word = ''
            # find next word lidx
            # Skip separators; the first non-separator character starts the word.
            # NOTE(review): at end of file read(1) returns '', which is treated
            # as a word character here — confirm EOF cannot be reached.
            while True:
                nxt_char = f.read(1)
                if (nxt_char == ' ' or nxt_char == '\n' or nxt_char == '\t'): 
                    lidx += 1
                else: 
                    word += nxt_char
                    break
            ridx = lidx
            # find next word ridx
            while True:
                char_pos = f.tell()
                nxt_char = f.read(1)
                # Stop at a separator, or force a cut so that the word ends
                # exactly at the entity's end offset.
                if (nxt_char == ' ' or nxt_char == '\n' or nxt_char == '\t' or ridx + 1 == ent_ridx):
                    ridx += 1
                    # Un-read the character so the next word search sees it.
                    f.seek(char_pos)
                    break
                else:
                    ridx += 1
                    word += nxt_char
                
            line_end = 0
            # remove '\n' in last word
            # NOTE(review): `word[:-1] == '\n'` compares everything except the
            # last character with '\n'; `word[-1:]` was probably intended.
            # Either way the loops above never put '\n' into `word`, and
            # `line_end` is not read afterwards.
            if (word[:-1] == '\n'): 
                line_end = 1
                word = word[:-1]
            # truncate beginning of the word if it is an entity word
            # (the dropped leading characters are discarded, not emitted)
            while (lidx < ent_lidx and ridx > ent_lidx and ridx <= ent_ridx):
                lidx += 1
                word = word[1:]
                
            tok.append(word)
            
            # Word before the entity -> OTHER; word starting at the entity's
            # start offset -> B-; later word ending inside the entity -> I-.
            if (lidx < ent_lidx):
                ner.append(label2id['OTHER'])
            elif (lidx == ent_lidx):
                ner.append(label2id['B-' + ans_info[1]])
            elif (ridx <= ent_ridx):
                ner.append(label2id['I-' + ans_info[1]])
            
            lidx = ridx
            
            if (ridx == ent_ridx): # found the last word of entity, move to next answer info
                break
        
        # Peek at the next answer line without consuming it.
        info_pos = fa.tell()
        nxt_info = fa.readline()[:-1].split('\t')
        fa.seek(info_pos)
        # nxt_info is in next file
        if (nxt_info[0] != name): 
            break
        # nxt_info is in current file but has overlap in current info
        # -> consume that single line and ignore it.
        if (int(nxt_info[3]) <= ent_ridx):
            nxt_info = fa.readline()
            
    return tok, ner

In [None]:
def Segmentation(ds_id, ds_tok, ds_ner, id, tok, ner, l):
    """Cut one tokenized document into segments and append them to the dataset.

    While at least `l` labels remain, a cut position is chosen starting at
    index `l`. If the token there is not labelled 'OTHER', the position is
    moved one step at a time — forwards or backwards, picked at random once
    per segment — until it lands on an 'OTHER' token or reaches either end of
    the remaining tokens. The token at the cut position is included in the
    segment; if the search reached index 0 or the end, everything that
    remains becomes one segment.

    Args:
        ds_id, ds_tok, ds_ner: Output lists, extended in place with the
            document id, the words and the label ids of every segment.
        id: Document id stored with each segment.
        tok: Words of the document.
        ner: Label ids, one per word.
        l: Target segment length in words.

    Returns:
        None; results are written to the three output lists.
    """
    while len(ner) >= l:
        cut = l
        # One direction per segment: +1 moves right, -1 moves left.
        step = 1 if random.randint(0, 1) else -1
        while 0 < cut < len(ner) and id2label[ner[cut]] != 'OTHER':
            cut += step

        if cut == 0:
            # Ran off the left edge: keep the whole remainder together.
            cut = len(ner)
        elif cut < len(ner):
            # Include the 'OTHER' token the search stopped on.
            cut += 1

        ds_id.append(id)
        ds_tok.append(tok[:cut])
        ds_ner.append(ner[:cut])
        tok, ner = tok[cut:], ner[cut:]

    # Whatever is left is shorter than `l` and forms the last segment.
    if ner:
        ds_id.append(id)
        ds_tok.append(tok)
        ds_ner.append(ner)

In [None]:
# Accumulates every segment of every document: parallel lists of document id,
# words and label ids.
ds_dict = {'id':[], 'tokens':[], 'ner_tags':[]}

fnames = sorted(os.listdir('./First_Phase_Release(Correction)/First_Phase_Text_Dataset'))

# Target segment length in words; a value <= 0 keeps each document whole.
max_word_length = 80
# Both files are opened with `with` so they are closed even when parsing
# fails (previously the answer file was never closed). Spilt2Words consumes
# the answer file sequentially, which is why the documents are visited in
# sorted order.
with open('./First_Phase_Release(Correction)/answer.txt', 'r') as fa:
    for fname in tqdm(fnames):
        with open(f'./First_Phase_Release(Correction)/First_Phase_Text_Dataset/{fname}', 'r') as f:
            # fname[:-4] drops the 4-character extension to get the document id.
            tok, ner = Spilt2Words(fname[:-4], f, fa)
        if (max_word_length > 0):
            Segmentation(ds_dict['id'], ds_dict['tokens'], ds_dict['ner_tags'], fname[:-4], tok, ner, max_word_length)
        else:
            ds_dict['id'].append(fname[:-4])
            ds_dict['tokens'].append(tok)
            ds_dict['ner_tags'].append(ner)

In [None]:
# Second-phase documents are appended to the same ds_dict.
fnames = sorted(os.listdir('./Second_Phase_Dataset/Second_Phase_Text_Dataset'))

# Target segment length in words; a value <= 0 keeps each document whole.
max_word_length = 80
# Both files are opened with `with` so they are closed even when parsing
# fails (previously the answer file was never closed).
with open('./Second_Phase_Dataset/answer.txt', 'r') as fa:
    for fname in tqdm(fnames):
        with open(f'./Second_Phase_Dataset/Second_Phase_Text_Dataset/{fname}', 'r') as f:
            # fname[:-4] drops the 4-character extension to get the document id.
            tok, ner = Spilt2Words(fname[:-4], f, fa)
        if (max_word_length > 0):
            Segmentation(ds_dict['id'], ds_dict['tokens'], ds_dict['ner_tags'], fname[:-4], tok, ner, max_word_length)
        else:
            ds_dict['id'].append(fname[:-4])
            ds_dict['tokens'].append(tok)
            ds_dict['ner_tags'].append(ner)

## Split train & dev data

In [None]:
def CountSim(train, valid):
    """Compare the entity-type distributions of two label collections.

    Every label that is neither id 0 nor an 'I-' tag counts as one occurrence
    of its entity type. The counts of each side are divided by that side's
    total, and the distance is the sum of squared differences between the two
    resulting proportion vectors.

    Args:
        train: Iterable of label-id sequences (one per example).
        valid: Iterable of label-id sequences (one per example).

    Returns:
        (train_proportions, valid_proportions, dist), the first two being
        lists with one entry per entity type.

    Raises:
        ZeroDivisionError: if either side contains no counted label.
    """
    def count_entities(rows):
        # One slot per entity type, indexed through org_label2id.
        hist = [0] * len(entity)
        for row in rows:
            for tag in row:
                if tag == 0:
                    continue
                tag_name = id2label[tag]
                if tag_name[0] != 'I':
                    # Strip the 2-character 'B-' prefix to get the entity type.
                    hist[org_label2id[tag_name[2:]]] += 1
        return hist

    tcnt = count_entities(train)
    vcnt = count_entities(valid)
    tsum = sum(tcnt)
    vsum = sum(vcnt)

    tpor = []
    vpor = []
    dist = 0
    for t_hits, v_hits in zip(tcnt, vcnt):
        t_share = t_hits / tsum
        v_share = v_hits / vsum
        gap = abs(t_share - v_share)
        dist += gap * gap
        tpor.append(t_share)
        vpor.append(v_share)
    return tpor, vpor, dist

In [None]:
# Build the dataset once; every candidate split below is drawn from this
# object instead of re-creating it from ds_dict on each iteration.
full_ds = Dataset.from_dict(ds_dict)

# Placeholder split, replaced by the first candidate whose distance is below
# best_dist.
# NOTE(review): this uses train_size=0.9 while the candidates use 0.8 —
# confirm which ratio is intended.
best_ds_train_valid = full_ds.train_test_split(train_size=0.9)
best_tpor = [0] * len(entity)
best_vpor = [0] * len(entity)
best_dist = 1
# Keep sampling random splits, try_step at a time, until the train/dev
# entity distributions differ by no more than upper_bound (see CountSim).
# NOTE(review): there is no cap on the number of passes; if upper_bound is
# unreachable for this data the loop does not terminate.
upper_bound = 2e-5
try_step = 1000
while (best_dist > upper_bound):
    for i in tqdm(range(try_step)):
        cur_ds_train_valid = full_ds.train_test_split(train_size=0.8)
        cur_tpor, cur_vpor, cur_dist = CountSim(cur_ds_train_valid['train']['ner_tags'], cur_ds_train_valid['test']['ner_tags'])
        if (cur_dist < best_dist):
            best_ds_train_valid = cur_ds_train_valid
            best_tpor = cur_tpor
            best_vpor = cur_vpor
            best_dist = cur_dist
            print(f'New smallest dist = {best_dist}')
    
# Side-by-side bars of the per-entity proportions of the chosen split.
x = np.arange(len(entity_names))
width = 0.4
plt.figure(figsize=(12.8, 4.8))
plt.bar(x, best_tpor, width, color='green', label='Train')
plt.bar(x + width, best_vpor, width, color='blue', label='Dev')
plt.xticks(x + width / 2, entity_names, rotation='vertical')
plt.ylabel('Porpotion')
plt.title('TrainDev distribution')
plt.legend()
plt.savefig('TrainDev distribution')
plt.show()

In [None]:
# Re-package the chosen split, exposing its 'test' part under the key
# 'validation'.
raw_ds = DatasetDict({'train': best_ds_train_valid['train'],
                  'validation': best_ds_train_valid['test']})

In [None]:
# raw_ds = load_from_disk("./ner_dataset/")

In [None]:
# Notebook display of raw_ds.
raw_ds

## Tokenize data

In [None]:
# Checkpoint id; also reused later as the plot title.
model_name = "hfl/english-pert-large"

# The tokenizer is loaded from the same checkpoint id that the model cells use.
model_checkpoint = model_name
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
def align_labels_with_tokens(labels, word_ids):
    """Expand word-level label ids to one label per sub-word token.

    Args:
        labels: Label ids, one per word.
        word_ids: For every token, the index of the word it came from, or
            None for a token that belongs to no word (special token).

    Returns:
        A list with one entry per token: -100 for tokens without a word, the
        word's label for the first token of a word, and for the following
        tokens of the same word the same label with odd ids bumped by one
        (odd ids are the B- tags, the next id is the matching I- tag).
    """
    aligned = []
    previous = None
    for wid in word_ids:
        if wid is None:
            # Special token: ignored by the loss.
            aligned.append(-100)
        else:
            tag = labels[wid]
            # Continuation of the previous word: B-XXX becomes I-XXX.
            if wid == previous and tag % 2 == 1:
                tag += 1
            aligned.append(tag)
        previous = wid
    return aligned

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split examples and attach token-level labels.

    Args:
        examples: Batch with a 'tokens' column (lists of words) and a
            'ner_tags' column (lists of label ids, one per word).

    Returns:
        The tokenizer output for the batch, with an added 'labels' entry
        produced by align_labels_with_tokens for each example.
    """
    encoded = tokenizer(
        examples['tokens'], truncation=True, is_split_into_words=True
    )
    encoded['labels'] = [
        align_labels_with_tokens(word_labels, encoded.word_ids(row))
        for row, word_labels in enumerate(examples['ner_tags'])
    ]
    return encoded

In [None]:
# Tokenize both splits in batches; the original columns are dropped so only
# the fields returned by tokenize_and_align_labels remain.
tokenized_datasets = raw_ds.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=raw_ds['train'].column_names,
)

In [None]:
# Notebook display of tokenized_datasets.
tokenized_datasets

In [None]:
# Collator used by the data loaders below.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
# Sanity check: collate the first two training examples and display their labels.
batch = data_collator([tokenized_datasets['train'][i] for i in range(2)])
batch['labels']

## Training config

In [None]:
# Directory the training loop saves the model and tokenizer to after every
# epoch; the inference pipeline later loads from the same path.
output_dir = './models/ner/'
#repo = Repository(output_dir, clone_from=repo_name)

In [None]:
# seqeval metric object, used by compute_metrics and by the training loop.
metric = evaluate.load('seqeval')

In [None]:
def compute_metrics(eval_preds):
    """Turn (logits, labels) into overall precision / recall / F1 / accuracy.

    Args:
        eval_preds: Pair (logits, labels). The predicted label of each token
            is the argmax over the last axis of `logits`; positions whose
            label is -100 are ignored.

    Returns:
        Dict with the keys 'precision', 'recall', 'f1' and 'accuracy', taken
        from the 'overall_*' entries computed by `metric`.
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Reference tags, with the ignored positions (-100) removed.
    true_labels = []
    for label_row in labels:
        true_labels.append([label_names[tag] for tag in label_row if tag != -100])

    # Predicted tags at exactly the positions kept above.
    true_predictions = []
    for pred_row, label_row in zip(predictions, labels):
        kept = [label_names[p] for p, tag in zip(pred_row, label_row) if tag != -100]
        true_predictions.append(kept)

    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        key: all_metrics[f'overall_{key}']
        for key in ('precision', 'recall', 'f1', 'accuracy')
    }

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

# Size the classification head from the label vocabulary instead of a
# hard-coded 67, so it stays correct if `entity` changes.
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(label_names),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

In [None]:
# training_args = TrainingArguments(
#     output_dir=output_dir,
#     learning_rate=2e-5,
#     per_device_train_batch_size=8,
#     per_device_eval_batch_size=8,
#     num_train_epochs=10,
#     weight_decay=0.01,
#     evaluation_strategy="epoch",
#     save_strategy="epoch",
#     load_best_model_at_end=True,
# )

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=tokenized_datasets["train"],
#     eval_dataset=tokenized_datasets["validation"],
#     tokenizer=tokenizer,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
# )

# trainer.train()

In [None]:
#labels = raw_ds['train'][0]['ner_tags']
#labels = [label_names[i] for i in labels]
#labels

In [None]:
#predictions = labels.copy()
#predictions[2] = 'OTHER'
#metric.compute(predictions=[predictions], references=[labels])

In [None]:
# Fresh model for the custom training loop; this replaces any `model`
# created by an earlier cell.
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)

# Training batches are reshuffled; evaluation batches keep dataset order.
train_dataloader = DataLoader(
    tokenized_datasets['train'],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=8,
)

eval_dataloader = DataLoader(
    tokenized_datasets['validation'], collate_fn=data_collator, batch_size=8
)

optimizer = AdamW(model.parameters(), lr=2e-5)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# NOTE(review): the model was just moved to `device` (CUDA when available),
# but the Accelerator is created with cpu=True. Presumably prepare() then
# places the model and batches on the CPU, making the move above redundant —
# confirm which device is intended.
# accelerator = Accelerator()
accelerator = Accelerator(cpu=True)

model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# One optimizer step per batch, so the scheduler length is
# epochs * batches-per-epoch (measured on the prepared dataloader).
num_train_epochs = 10
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

# 'linear' schedule with no warm-up steps.
lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
#model_name = 'bert-finetuned-ner-accelerate'
#repo_name = get_full_repo_name(model_name)
#repo_name

In [None]:
def postprocess(predictions, labels):
    """Convert predicted / reference label ids into lists of tag names.

    Args:
        predictions: Tensor of predicted label ids, one row per example.
        labels: Tensor of reference label ids with the same layout;
            positions equal to -100 are ignored.

    Returns:
        (true_labels, true_predictions) — note the order: references first,
        predictions second. Each is a list of tag-name lists with the ignored
        positions removed.
    """
    pred_rows = predictions.detach().cpu().clone().numpy()
    label_rows = labels.detach().cpu().clone().numpy()

    # Reference tags, with the ignored positions (-100) removed.
    true_labels = []
    for label_row in label_rows:
        true_labels.append([label_names[tag] for tag in label_row if tag != -100])

    # Predicted tags at exactly the positions kept above.
    true_predictions = []
    for pred_row, label_row in zip(pred_rows, label_rows):
        kept = [label_names[p] for p, tag in zip(pred_row, label_row) if tag != -100]
        true_predictions.append(kept)

    return true_labels, true_predictions

## Training

In [None]:
progress_bar = tqdm(range(num_training_steps))
# Overall F1 on the validation set after each epoch (plotted later).
f1_score = []

for epoch in range(num_train_epochs):
    # Training
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**batch)

        predictions = outputs.logits.argmax(dim=-1)
        labels = batch['labels']

        # Necessary to pad predictions and labels for being gathered
        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

        predictions_gathered = accelerator.gather(predictions)
        labels_gathered = accelerator.gather(labels)

        # postprocess returns (labels, predictions) in that order. Unpacking
        # them the other way round fed the references to seqeval as
        # predictions, which swaps the reported precision and recall.
        true_labels, true_predictions = postprocess(predictions_gathered, labels_gathered)
        metric.add_batch(predictions=true_predictions, references=true_labels)

    results = metric.compute()
    f1_score.append(results['overall_f1'])
    print(
        f'epoch {epoch}:',
        {
            key: results[f'overall_{key}']
            for key in ['precision', 'recall', 'f1', 'accuracy']
        },
    )

    #Save and upload
    # The checkpoint in output_dir is overwritten after every epoch.
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
        #repo.push_to_hub(
        #    commit_message=f'Training in progress epoch {epoch}', blocking=False
        #)

In [None]:
# Metrics of the last evaluated epoch; the comment below records the output
# of one earlier run.
print(results)
# {'AGE': {'precision': 0.9565217391304348, 'recall': 0.88, 'f1': 0.9166666666666666, 'number': 25}, 'CITY': {'precision': 0.9893048128342246, 'recall': 0.9840425531914894, 'f1': 0.9866666666666667, 'number': 188}, 'DATE': {'precision': 0.9914984059511158, 'recall': 0.9978609625668449, 'f1': 0.9946695095948828, 'number': 935}, 'DEPARTMENT': {'precision': 0.9478672985781991, 'recall': 0.9302325581395349, 'f1': 0.9389671361502349, 'number': 215}, 'DOCTOR': {'precision': 0.9845201238390093, 'recall': 0.9845201238390093, 'f1': 0.9845201238390093, 'number': 1292}, 'DURATION': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'HOSPITAL': {'precision': 0.9827089337175793, 'recall': 0.9798850574712644, 'f1': 0.981294964028777, 'number': 348}, 'IDNUM': {'precision': 0.9876373626373627, 'recall': 0.9930939226519337, 'f1': 0.9903581267217632, 'number': 724}, 'LOCATION-OTHER': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'MEDICALRECORD': {'precision': 0.9971830985915493, 'recall': 0.9943820224719101, 'f1': 0.9957805907172996, 'number': 356}, 'ORGANIZATION': {'precision': 0.9090909090909091, 'recall': 0.8333333333333334, 'f1': 0.8695652173913043, 'number': 24}, 'PATIENT': {'precision': 0.9831460674157303, 'recall': 0.9915014164305949, 'f1': 0.9873060648801127, 'number': 353}, 'PHONE': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'SET': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'STATE': {'precision': 0.9942528735632183, 'recall': 0.9885714285714285, 'f1': 0.991404011461318, 'number': 175}, 'STREET': {'precision': 0.9888888888888889, 'recall': 0.978021978021978, 'f1': 0.9834254143646408, 'number': 182}, 'TIME': {'precision': 0.9772727272727273, 'recall': 0.9641255605381166, 'f1': 0.9706546275395034, 'number': 223}, 'ZIP': {'precision': 1.0, 'recall': 0.9731182795698925, 'f1': 0.9863760217983651, 'number': 186}, 'overall_precision': 0.9833652007648184, 'overall_recall': 0.9835532606616944, 'overall_f1': 0.9834592217229181, 
'overall_accuracy': 0.999238331528254}

## Draw f1 score

In [None]:
# Hard-coded copy of the metrics printed above (per-entity precision / recall /
# f1 / number, plus the four 'overall_*' scores). It is not refreshed from
# `results`, so it must be updated by hand after re-training.
f1_dict = {'AGE': {'precision': 0.9565217391304348, 'recall': 0.88, 'f1': 0.9166666666666666, 'number': 25}, 'CITY': {'precision': 0.9893048128342246, 'recall': 0.9840425531914894, 'f1': 0.9866666666666667, 'number': 188}, 'DATE': {'precision': 0.9914984059511158, 'recall': 0.9978609625668449, 'f1': 0.9946695095948828, 'number': 935}, 'DEPARTMENT': {'precision': 0.9478672985781991, 'recall': 0.9302325581395349, 'f1': 0.9389671361502349, 'number': 215}, 'DOCTOR': {'precision': 0.9845201238390093, 'recall': 0.9845201238390093, 'f1': 0.9845201238390093, 'number': 1292}, 'DURATION': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'HOSPITAL': {'precision': 0.9827089337175793, 'recall': 0.9798850574712644, 'f1': 0.981294964028777, 'number': 348}, 'IDNUM': {'precision': 0.9876373626373627, 'recall': 0.9930939226519337, 'f1': 0.9903581267217632, 'number': 724}, 'LOCATION-OTHER': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'MEDICALRECORD': {'precision': 0.9971830985915493, 'recall': 0.9943820224719101, 'f1': 0.9957805907172996, 'number': 356}, 'ORGANIZATION': {'precision': 0.9090909090909091, 'recall': 0.8333333333333334, 'f1': 0.8695652173913043, 'number': 24}, 'PATIENT': {'precision': 0.9831460674157303, 'recall': 0.9915014164305949, 'f1': 0.9873060648801127, 'number': 353}, 'PHONE': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'SET': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'STATE': {'precision': 0.9942528735632183, 'recall': 0.9885714285714285, 'f1': 0.991404011461318, 'number': 175}, 'STREET': {'precision': 0.9888888888888889, 'recall': 0.978021978021978, 'f1': 0.9834254143646408, 'number': 182}, 'TIME': {'precision': 0.9772727272727273, 'recall': 0.9641255605381166, 'f1': 0.9706546275395034, 'number': 223}, 'ZIP': {'precision': 1.0, 'recall': 0.9731182795698925, 'f1': 0.9863760217983651, 'number': 186}, 'overall_precision': 0.9833652007648184, 'overall_recall': 0.9835532606616944, 'overall_f1': 
0.9834592217229181, 'overall_accuracy': 0.999238331528254}

In [None]:
# pandas is not imported at the top of the file, so bring it into scope here.
import pandas as pd

# One row per entity type: tag name followed by its precision / recall / f1 /
# number, in the order stored in f1_dict.
pert_f1_df = pd.DataFrame({
    'Tag': [],
    'Precision': [],
    'Recall': [],
    'F1': [],
    'Number': []
})

overall_keys = ['overall_precision', 'overall_recall', 'overall_f1', 'overall_accuracy']

for tag, data in f1_dict.items():
    # The 'overall_*' entries are plain floats, not per-entity dicts; they
    # are written as the two trailing rows below.
    if 'overall_' in tag:
        continue

    pert_f1_df.loc[len(pert_f1_df.index)] = [tag, *data.values()]

# Trailing rows: the names of the overall scores, then their values, read
# from f1_dict instead of being typed in a second time.
pert_f1_df.loc[len(pert_f1_df.index)] = ['', *overall_keys]
pert_f1_df.loc[len(pert_f1_df.index)] = ['', *(f1_dict[key] for key in overall_keys)]


In [None]:
# Notebook display of pert_f1_df.
pert_f1_df

In [None]:
# Export the score table.
# NOTE(review): nothing here creates the target directory — confirm it exists.
pert_f1_df.to_csv('./Validation_Dataset/pert_ans/f1.csv')

In [None]:
# Plot the validation F1 recorded after each training epoch.
# NOTE: model_name is overwritten here with '/' replaced by '_', so re-running
# earlier cells that read model_name afterwards sees the modified value.
model_name = model_name.replace('/', '_')
plt.plot(f1_score, label = "f1")
# naming the x axis
plt.xlabel('epoch')
# naming the y axis
plt.ylabel('f1 score')
# giving a title to my graph
title = f'{model_name}'
plt.title(title)
# show a legend on the plot
plt.legend()
# store fig
# plt.savefig(model_name)
# function to show the plot
plt.show()
# store score
# with open(title, "wb") as fp:   #Pickling
#     pickle.dump(f1_score, fp)

## Inference

In [None]:
# Raw text of every validation document, keyed by position: val_docs['id'][i]
# is the id of val_docs['doc'][i].
val_docs = {'id':[], 'doc':[]}
fnames = sorted(os.listdir('./Validation_Dataset/Validation_Release/'))

for fname in tqdm(fnames):
    # `with` closes the file even if reading fails.
    with open(f'./Validation_Dataset/Validation_Release/{fname}', 'r') as f:
        lines = f.read()

    # fname[:-4] drops the 4-character extension to get the document id.
    val_docs['id'].append(fname[:-4])
    val_docs['doc'].append(lines)

In [None]:
# NOTE(review): sent_tokenize is only referenced from commented-out lines in
# the visible part of this file — confirm it is still needed.
import nltk
from nltk.tokenize import sent_tokenize

# Download the sentence tokenizer model (run this once)
nltk.download('punkt')

In [None]:
import re
def split_documents(fnames, words_per_segment,
                    base_dir='./Validation_Dataset/Validation_Release'):
    """Split every document into segments of at most `words_per_segment` words.

    A "word" is whatever lies between two single spaces: the text is split
    with str.split(" "), so newlines and tabs stay inside the words and
    joining a document's segments with ' ' reproduces its text.

    Earlier versions reset the word counter to 0 after a flush while the
    current word was already placed in the new segment, so every segment
    after the first held `words_per_segment + 1` words. All segments are now
    capped at `words_per_segment`.

    Args:
        fnames: File names inside `base_dir`. The last 4 characters (the
            extension) are dropped to form the document id.
        words_per_segment: Maximum number of words per segment; must be >= 1.
        base_dir: Directory containing the files. Defaults to the validation
            release directory used by this notebook.

    Returns:
        Dict mapping '<document id>_<segment number>' (numbered from 1) to
        the segment text.

    Raises:
        ValueError: if `words_per_segment` is smaller than 1.
    """
    if words_per_segment < 1:
        raise ValueError('words_per_segment must be at least 1')

    result_dict = {}

    for fname in tqdm(fnames):
        with open(os.path.join(base_dir, fname), 'r') as file:
            content = file.read()

        # Split on single spaces only (see the docstring).
        words = content.split(" ")
        doc_id = fname[:-4]

        starts = range(0, len(words), words_per_segment)
        for seg_no, start in enumerate(starts, start=1):
            chunk = words[start:start + words_per_segment]
            result_dict[f"{doc_id}_{seg_no}"] = ' '.join(chunk)

    return result_dict

In [None]:
fnames = [f for f in os.listdir('./Validation_Dataset/Validation_Release/')]
fnames.sort()

# Only words_per_segment is passed to split_documents below; the other three
# limits are not used by the call in this cell.
max_lines_per_segment = 10
max_sentences_per_segment = 5
max_characters_per_segment = 100
words_per_segment = 80

# result_segments = split_documents(fnames, max_lines_per_segment, max_sentences_per_segment)

# Maps '<document id>_<segment number>' to the segment text.
val_result_segments = split_documents(fnames, words_per_segment)


In [None]:
# Print the first segment of the first document for demonstration
# Show the third segment key and its text.
# NOTE: index 2 raises IndexError when fewer than three segments exist.
key_example = list(val_result_segments.keys())[2]
print(f"Segment {key_example}:")
print(val_result_segments[key_example])

# val_docs['doc'][0][1855:].count('\n')

In [None]:
# list[val_result_segments.keys()]

In [None]:
# print(len(sent_tokenize(result_segments['650_8'])))
# print(val_result_segments['file21703_12'])

### Load model

In [None]:
from transformers import pipeline

# Replace this with your own checkpoint
# (this is the directory the training loop saved the model and tokenizer to)
model_checkpoint = "./models/ner/"
token_classifier = pipeline(
    "token-classification", model=model_checkpoint, aggregation_strategy="simple"
)

In [None]:
# classify
# classify
# Maps each segment key to the pipeline output for that segment.
val_result_ans_dict = {}

for fid_sid, seg in val_result_segments.items():
    try:
        val_result_ans_dict[fid_sid] = token_classifier(seg)
    except Exception as exc:
        # Stop at the first failing segment, as before, but report why it
        # failed. Catching Exception rather than using a bare `except` lets
        # KeyboardInterrupt / SystemExit propagate normally.
        print(fid_sid)
        print(f'token_classifier failed: {exc!r}')
        break

In [None]:
# val_result_ans_dict['1002_6']

In [None]:
# result_ans_dict_cpy = result_ans_dict.copy()


In [None]:
import re
from word2number import w2n

def Normalize(time_type, org):
    nor = ''
    if (time_type == 'DATE'):
        if (re.match('\d{1,2}(\/|\.| |-|,)\d{1,2}(\/|\.| |-|,)\d{2,4}', org)):
            l = re.split('\/|\.| |-|,', org)
            if (len(l[2]) == 2):
                l[2] = '20' + l[2]
            elif (len(l[2]) == 3):
                l[2] = '2' + l[2]
            if (len(l[1]) == 1):
                l[1] = '0' + l[1]
            if (len(l[0]) == 1):
                l[0] = '0' + l[0]
            nor = l[2] + '-' + l[1] + '-' + l[0]
        elif (re.match('\/\d{1,2}\/(\d{2}|\d{4})', org)):
            l = re.split('\/', org)
            if (len(l[1]) == 1):
                l[1] = '0' + l[1]
            if (len(l[2]) == 2):
                l[2] = '20' + l[2]
            nor = l[2] + '-' + l[1]
        elif (re.match('\d{1,2}\/\d{2,5}', org)):
            l = re.split('\/', org)
            if (len(l[0]) == 1):
                l[0] = '0' + l[0]
            if (len(l[1]) == 2):
                nor = '20' + l[1] + '-' + l[0]
            elif (len(l[1]) == 3):
                nor = '20' + l[1][1:] + '-' + '0' + l[1][0] + '-' + l[0]
            elif (len(l[1]) == 4):
                nor = l[1] + '-' + l[0]
            elif (len(l[1]) == 5):
                nor = l[1][1:] + '-' + '0' + l[1][0] + '-' + l[0]
        elif (re.match('\d{8}', org)):
            nor = org[0:4] + '-' + org[4:6] + '-' + org[6:8]
        elif (re.match('\d{4}', org)):
            nor = org
        elif (re.match('\d{3}', org)):
            nor = '2' + org
        elif (re.match('(\d{2}|)(-|)(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)(-| )\d{2,4}', org)):
            org = org.replace('Jan', '01')
            org = org.replace('Feb', '02')
            org = org.replace('Mar', '03')
            org = org.replace('Apr', '04')
            org = org.replace('May', '05')
            org = org.replace('Jun', '06')
            org = org.replace('Jul', '07')
            org = org.replace('Aug', '08')
            org = org.replace('Sep', '09')
            org = org.replace('Oct', '10')
            org = org.replace('Nov', '11')
            org = org.replace('Dec', '12')
            l = re.split('-| ', org)
            if (len(l) == 2):
                if (len(l[1]) == 2):
                    l[1] = '20' + l[1]
                elif (len(l[1]) == 3):
                    l[1] = '2' + l[1]
                nor = l[1] + '-' + l[0]
            else:
                if (len(l[2]) == 2):
                    l[2] = '20' + l[2]
                elif (len(l[2]) == 3):
                    l[2] = '2' + l[2]
                nor = l[2] + '-' + l[1] + '-' + l[0]
        elif (re.match('\d{1,2}((st)|(nd)|(rd)|(th)) of (January|February|March|April|May|June|July|August|September|October|November|December) \d{4}', org)):
            org = org.replace('January', '01')
            org = org.replace('Feburary', '02')
            org = org.replace('March', '03')
            org = org.replace('April', '04')
            org = org.replace('May', '05')
            org = org.replace('June', '06')
            org = org.replace('July', '07')
            org = org.replace('August', '08')
            org = org.replace('September', '09')
            org = org.replace('October', '10')
            org = org.replace('November', '11')
            org = org.replace('December', '12')
            l = re.split(' ', org)
            nor = l[3] + '-' + l[2] + '-' + l[0][:-2]
        elif (re.match('(\d{1,2}|)( |)(January|February|March|April|May|June|July|August|September|October|November|December) \d{4}', org)):
            if (re.match('\d', org[0]) and re.match('\d', org[1]) == None):
                org = '0' + org
            org = org.replace('January', '01')
            org = org.replace('Feburary', '02')
            org = org.replace('March', '03')
            org = org.replace('April', '04')
            org = org.replace('May', '05')
            org = org.replace('June', '06')
            org = org.replace('July', '07')
            org = org.replace('August', '08')
            org = org.replace('September', '09')
            org = org.replace('October', '10')
            org = org.replace('November', '11')
            org = org.replace('December', '12')
            org = org.replace(' ', '')
            if (len(org) == 6):
                nor = org[2:] + '-' + org[0:2]
            else:    
                nor = org[4:] + '-' + org[2:4] + '-' + org[0:2]
    elif (time_type == 'TIME'):
        if (re.match('(\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}(  | |)|)(at|)( |)\d{1,2}(:|\.)\d{2}(AM|am|PM|pm|Hr|Hrs|hr|hrs|)( on the \d{1,2}((st)|(nd)|(rd)|(th)) of (January|February|March|April|May|June|July|August|September|October|November|December) \d{4}|)', org)):
            tmp = org
            pm = 0
            am = 0
            if (re.search('PM', org, flags=0) != None):
                pm = 1
            if (re.search('pm', org, flags=0) != None):
                pm = 1
            if (re.search('AM', org, flags=0) != None):
                am = 1
            if (re.search('am', org, flags=0) != None):
                am = 1
            get_date = 0
            date = re.search('\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}', org, flags=0)
            if (date != None):
                date = date.group(0)
                org = org.replace(date, '')
                date = re.split('\/|\.', date)
                if (len(date[0]) == 1):
                    date[0] = '0' + date[0]
                if (len(date[1]) == 1):
                    date[1] = '0' + date[1]
                if (len(date[2]) == 2):
                    date[2] = '20' + date[2]
                elif (len(date[2]) == 3):
                    date[2] = '2' + date[2]
                nor = date[2] + '-' + date[1] + '-' + date[0]
                get_date = 1
            yyyy = re.search('\d{4}', org, flags=0)
            if (yyyy != None and get_date == 0):
                yyyy = yyyy.group(0)
                org = org.replace(yyyy, '')
                nor = yyyy + '-'
            mm = re.search('January|February|March|April|May|June|July|August|September|October|November|December', org, flags=0)
            if (mm != None and get_date == 0):
                mm = mm.group(0)
                org = org.replace(mm, '')
                mm = mm.replace('January', '01')
                mm = mm.replace('Feburary', '02')
                mm = mm.replace('March', '03')
                mm = mm.replace('April', '04')
                mm = mm.replace('May', '05')
                mm = mm.replace('June', '06')
                mm = mm.replace('July', '07')
                mm = mm.replace('August', '08')
                mm = mm.replace('September', '09')
                mm = mm.replace('October', '10')
                mm = mm.replace('November', '11')
                mm = mm.replace('December', '12')
                nor = nor + mm + '-'
            dd = re.search('\d{1,2}((st)|(nd)|(rd)|(th))', org, flags=0)
            if (dd != None and get_date == 0):
                dd = dd.group(0)
                org = org.replace(dd, '')
                dd = dd.replace('st', '')
                dd = dd.replace('nd', '')
                dd = dd.replace('rd', '')
                dd = dd.replace('th', '')
                if (len(dd) == 1):
                    dd = '0' + dd
                nor = nor + dd
            get_time = 0
            time = re.search('\d{1,2}(:|\.)\d{1,2}', org, flags=0)
            if (time != None):
                time = time.group(0)
                org = org.replace(time, '')
                time = re.split('\.|:', time)
                if (pm == 1 and int(time[0]) < 12):
                    time[0] = str(int(time[0]) + 12)
                elif (am == 1 and int(time[0]) == 12):
                    time[0] = '00'
                if (len(time[0]) == 1):
                    time[0] = '0' + time[0]
                nor = nor + 'T' + time[0] + ':' + time[1]
                get_time = 1
            pm = 0
            am = 0
            if (re.search('pm', org, flags=0) != None):
                pm = 1
            if (re.search('am', org, flags=0) != None):
                am = 1
            time = re.search('\d{1,4}', org, flags=0)
            if (time != None and get_time == 0):
                time = time.group(0)
                org = org.replace(time, '')
                hh, mm = '00', '00'
                if (len(time) == 4):
                    hh = time[0:2]
                    mm = time[2:]
                elif (len(time) == 3):
                    hh = time[0]
                    mm = time[1:]
                elif (len(time) == 2):
                    hh = time
                elif (len(time) == 1):
                    hh = time
                if (pm == 1 and int(hh) < 12):
                    hh = str(int(hh) + 12)
                elif (am == 1 and int(hh) == 12):
                    hh = '00'
                nor = nor + 'T' + hh + ':' + mm    
            #if (nor != ans):    
                #print(f'1:nor={nor}, ans={ans}, org={tmp}')
        elif (re.match('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}', org)):
            tmp = org
            nor = org.replace(' ', 'T')
            #if (nor != ans):    
                #print(f'2:nor={nor}, ans={ans}, org={tmp}')
        elif (re.match('(at |)(\d{1,2}|)(:|\.|)\d{2}( |)(am|pm|Hr|Hrs|hr|hrs|)( on | )(the |)\d{1,2}(\/|\.)\d{2,4}(\/|\.)\d{1,2}', org)):
            tmp = org
            pm = 0
            am = 0
            if (re.search('pm', org, flags=0) != None):
                pm = 1
            if (re.search('am', org, flags=0) != None):
                am = 1
            date = re.search('\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}', org, flags=0)
            if (date != None):
                date = date.group(0)
                org = org.replace(date, '')
                date = re.split('\/|\.', date)
                if (len(date[0]) == 1):
                    date[0] = '0' + date[0]
                if (len(date[1]) == 1):
                    date[1] = '0' + date[1]
                if (len(date[2]) == 2):
                    date[2] = '20' + date[2]
                elif (len(date[2]) == 3):
                    date[2] = '2' + date[2]
                nor = date[2] + '-' + date[1] + '-' + date[0] + 'T'
            org = org.replace(':', '')
            time = re.search('\d{1,4}', org, flags=0)
            if (time != None):
                time = time.group(0)
                org = org.replace(time, '')
                hh, mm = '00', '00'
                if (len(time) == 4):
                    hh = time[0:2]
                    mm = time[2:]
                elif (len(time) == 3):
                    hh = time[0]
                    mm = time[1:]
                elif (len(time) == 2):
                    hh = time
                elif (len(time) == 1):
                    hh = time
                if (pm == 1 and int(hh) < 12):
                    hh = str(int(hh) + 12)
                elif (am == 1 and int(hh) == 12):
                    hh = '00'
                nor = nor + hh + ':' + mm
            #if (nor != ans):    
                #print(f'3:nor={nor}, ans={ans}, org={tmp}')
        elif (re.match('((\d{1,2}((pm)|(am)))|(\d{4}(Hr|Hrs|hr|hrs|)))(( on )| )\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}', org)):
            tmp = org
            pm = 0
            am = 0
            if (re.search('pm', org, flags=0) != None):
                pm = 1
            if (re.search('am', org, flags=0) != None):
                am = 1
            date = re.search('\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}', org, flags=0)
            if (date != None):
                date = date.group(0)
                org = org.replace(date, '')
                date = re.split('\/|\.', date)
                if (len(date[0]) == 1):
                    date[0] = '0' + date[0]
                if (len(date[1]) == 1):
                    date[1] = '0' + date[1]
                if (len(date[2]) == 2):
                    date[2] = '20' + date[2]
                elif (len(date[2]) == 3):
                    date[2] = '2' + date[2]
                nor = date[2] + '-' + date[1] + '-' + date[0] + 'T'
            hrtime = re.search('\d{4}', org, flags=0)
            if (hrtime != None):
                hrtime = hrtime.group(0)
                org = org.replace(hrtime, '')
                nor = nor + hrtime[0:2] + ':' + hrtime[2:]
            time = re.search('\d{1,2}', org, flags=0)
            if (time != None):
                time = time.group(0)
                org = org.replace(time, '')
                hh = time
                if (pm == 1 and int(hh) < 12):
                    hh = str(int(hh) + 12)
                elif (am == 1 and int(hh) == 12):
                    hh = '00'
                if (len(hh) == 1):
                    hh = '0' + hh
                nor = nor + hh + ':' + '00'
            #if (nor != ans):    
                #print(f'4:nor={nor}, ans={ans}, org={tmp}')
    elif (time_type == 'DURATION'):   
        tmp = org
        org = org.replace('one', '1')
        org = org.replace('two', '2')
        org = org.replace('three', '3')
        org = org.replace('four', '4')
        org = org.replace('five', '5')
        num = ''
        alp = ''
        space_idx = org.find(' ')
        for i in range(len(org)):
            if (org[i] == 'D' or org[i] == 'd' or\
                org[i] == 'W' or org[i] == 'w' or\
                org[i] == 'M' or org[i] == 'm' or\
                org[i] == 'Y' or org[i] == 'y') and i > space_idx:
                alp = org[i]
                org = org[:i]
                break
        # print(org, alp)
        org = re.split('-| ', org)
        try:
            if org[0].isalpha():
                org[0] = w2n.word_to_num(org[0])
            # print(org)
            if (len(org) == 1 or org[1] == ''):
                nor = 'P' + str(org[0]) + alp.upper()
            else:
                nor = 'P' + str((int(org[0]) + int(org[1])) / 2) + alp.upper()
        except:
            nor = tmp
        # if (nor != ans):    
        #     print(f'dur:nor={nor}, ans={ans}, org={tmp}')
    elif (time_type == 'SET'):
        if (re.match('twice', org)):
            nor = 'R2'
    return nor


In [None]:
import pandas as pd

In [None]:
# Assemble the validation answer table: one row per predicted non-OTHER entity,
# with segment-local character offsets mapped back onto the source document.
val_ans_df = pd.DataFrame({
    'file_id': [],
    'PHI_type': [],
    'PHI_start': [],
    'PHI_end': [],
    'PHI_content': [],
    'ISO': []
})

# Entity groups whose row also carries a normalized value in the 'ISO' column.
need_iso = ['DATE', 'TIME', 'DURATION', 'SET']

last_fid = ""
last_idx_of_last_seg = 0
for fid_sid, entities in val_result_ans_dict.items():
    curr_fid = fid_sid.split('_')[0]
    curr_sid = fid_sid.split('_')[1]

    if curr_fid != last_fid:
        # First segment of a new document: load its text and restart the running offset.
        with open(os.path.join('./Validation_Dataset/Validation_Release', curr_fid+'.txt'), 'r') as file:
            content = file.read()
        last_fid = curr_fid
        last_idx_of_last_seg = 0

    for entity in entities:
        if entity['entity_group'] == 'OTHER':
            continue

        # Document offset = segment-local offset + total length of the earlier
        # segments + (segment number - 1).
        # NOTE(review): the last term presumably accounts for the single space
        # dropped between consecutive segments -- confirm against the splitter.
        start_idx = entity['start'] + last_idx_of_last_seg + int(curr_sid) - 1
        end_idx = entity['end'] + last_idx_of_last_seg + int(curr_sid) - 1

        word = content[start_idx:end_idx]

        # Trim leading/trailing non-alphanumeric characters and keep the
        # offsets in sync. The `word and ...` guards stop the loop from
        # indexing into an empty string when the span is all punctuation.
        if len(word) > 1:
            while word and not (word[0].isalnum() and word[-1].isalnum()):
                if not word[0].isalnum():
                    word = word[1:]
                    start_idx += 1
                if word and not word[-1].isalnum():
                    word = word[:-1]
                    end_idx -= 1

        # Nothing usable left (empty slice or punctuation only): skip the
        # entity instead of failing on word[0] below.
        if not word:
            continue

        word = word.replace('\n', ' ')

        label = entity['entity_group']

        # True when the span already looks like "<number><unit>" preceded by a
        # space, or contains a space (e.g. "3 weeks"); such spans are not expanded.
        # str.isalpha must be *called*: the bare attribute is always truthy.
        have_num_or_aplha_desc = (
            (word[0].isdigit() and word[-1].isalpha() and content[start_idx-1] == ' ')
            or word.find(' ') != -1
        )

        if entity['entity_group'] == 'DURATION' and word != 'twice' and not have_num_or_aplha_desc:
            word_cpy = word.lower()
            if word_cpy.find('yr') != -1 or word_cpy.find('ye') != -1 or word_cpy.find('m') != -1 or word_cpy.find('w') != -1:
                # Only a unit was tagged: extend the span backwards to pick up the number.
                last_index = start_idx - 1
                # case 1: no space between number and month, week or year
                if content[last_index].isdigit():
                    while content[last_index].isdigit():
                        last_index -= 1
                    start_idx = last_index + 1
                    word = content[start_idx:end_idx]
                else:
                    # case 2: restart from just after the second-to-last space before the span
                    last_space1 = content.rfind(' ', 0, start_idx)
                    last_space2 = content.rfind(' ', 0, last_space1)
                    start_idx = last_space2 + 1
                    word = content[start_idx:end_idx]
            elif word.isdigit():
                # Only a number was tagged: extend the span forwards to pick up the unit.
                next_index = end_idx + 1
                if content[next_index].isalpha():
                    while content[next_index].isalpha():
                        next_index += 1
                    end_idx = next_index
                    word = content[start_idx:end_idx]
                else:
                    first_space_index = content.find(' ', end_idx)
                    # Find the index of the second space after the target word
                    second_space_index = content.find(' ', first_space_index + 1)
                    end_idx = second_space_index
                    word = content[start_idx:end_idx]
        elif word == 'twice':
            label = 'SET'

        new_row = [curr_fid, label, start_idx, end_idx, word]

        # The ISO decision uses the classifier's group, while Normalize
        # receives the (possibly relabelled) `label`.
        if entity['entity_group'] in need_iso:
            new_row.append(Normalize(label, word))
        else:
            new_row.append('')

        val_ans_df.loc[len(val_ans_df)] = new_row

    # Advance the running offset exactly once per segment. Doing it here
    # rather than on the last entity keeps later offsets correct even when the
    # classifier returned no entities for this segment.
    last_idx_of_last_seg += len(val_result_segments[fid_sid])


In [None]:
# Display the assembled validation answer table.
val_ans_df

In [None]:
# Show only the rows whose PHI_type is DURATION.
val_ans_df.loc[val_ans_df['PHI_type'] == 'DURATION']

In [None]:
# with open('')

In [None]:
# Load the predictions stored in prediction.csv and keep two filtered views:
# the DURATION rows and the LOCATION-OTHER rows.
# NOTE(review): the "crf" in the names suggests these come from a CRF model -- confirm.
df_crf_val = pd.read_csv('./Validation_Dataset/prediction.csv')
df_crf_duration = df_crf_val[df_crf_val['PHI_type'].eq('DURATION')]
df_crf_lo = df_crf_val[df_crf_val['PHI_type'].eq('LOCATION-OTHER')]

In [None]:
# Display the DURATION rows read from prediction.csv.
df_crf_duration

In [None]:
# Save the validation answers as CSV with to_csv defaults (comma-separated).
val_ans_df.to_csv('./Validation_Dataset/pert_ans/pert_answer.csv')

In [None]:
# Save the same table as tab-separated text, without header row or index column.
val_ans_df.to_csv('./Validation_Dataset/pert_ans/pert_answer.txt', sep='\t', header=False, index=False)

## Inference: Test set

In [None]:
# Read every test document into memory.
# test_docs['id'][k] is the file name without its last four characters
# (the extension) and test_docs['doc'][k] is that file's full text.
test_docs = {'id':[], 'doc':[]}
fnames = [f for f in os.listdir('./opendid_test/opendid_test/')]
fnames.sort()

for fname in tqdm(fnames):
    # `with` closes the handle even if reading raises.
    with open(f'./opendid_test/opendid_test/{fname}', 'r') as f:
        lines = f.read()

    test_docs['id'].append(fname[:-4])
    test_docs['doc'].append(lines)

In [None]:
import nltk
from nltk.tokenize import sent_tokenize

# Download the 'punkt' sentence tokenizer model for NLTK.
# This only needs to run once per environment.
nltk.download('punkt')

In [None]:
import re
def split_documents(fnames, words_per_segment):
    """Cut each test document into segments of at most ``words_per_segment`` words.

    A "word" is whatever lies between two single space characters, so
    newlines and tabs stay inside words. Joining the segments of one file
    with a single space therefore reproduces the file content exactly.

    Args:
        fnames: File names (including their 4-character extension) located
            in ``./opendid_test/opendid_test``.
        words_per_segment: Maximum number of space-separated words per
            segment; must be at least 1.

    Returns:
        dict mapping ``"<file name without extension>_<segment number>"``
        (segment numbers start at 1) to the segment text.

    Raises:
        ValueError: If ``words_per_segment`` is smaller than 1.
    """
    if words_per_segment < 1:
        raise ValueError('words_per_segment must be at least 1')

    result_dict = {}

    for fname in tqdm(fnames):
        with open(os.path.join('./opendid_test/opendid_test', fname), 'r') as file:
            content = file.read()

        file_id = fname[:-4]

        # Split on single spaces only (not on arbitrary whitespace) so that
        # character offsets can be rebuilt from the segment lengths.
        words = content.split(" ")

        # Fixed-size chunks: every segment holds exactly words_per_segment
        # words except possibly the last one. (The previous counter reset let
        # every segment after the first grow to words_per_segment + 1 words.)
        chunk_starts = range(0, len(words), words_per_segment)
        for seg_no, first in enumerate(chunk_starts, start=1):
            chunk = words[first:first + words_per_segment]
            result_dict[f"{file_id}_{seg_no}"] = ' '.join(chunk)

    return result_dict

In [None]:
# List the test files in sorted order and split each one into word segments.
fnames = sorted(os.listdir('./opendid_test/opendid_test/'))

# Segmentation settings; only words_per_segment is passed to split_documents here.
max_lines_per_segment = 10
max_sentences_per_segment = 5
max_characters_per_segment = 100
words_per_segment = 80

result_segments = split_documents(fnames, words_per_segment)


In [None]:
# Demonstration: print the key and text of the first entry in result_segments.
key_example = list(result_segments)[0]
print(f"Segment {key_example}:", result_segments[key_example], sep="\n")

# val_docs['doc'][0][1855:].count('\n')

In [None]:
# list[result_segments.keys()]

In [None]:
# print(len(sent_tokenize(result_segments['650_8'])))
# print(result_segments['file21703_12'])

In [None]:
from transformers import pipeline

# Directory of the model checkpoint to load for inference.
# Replace this with your own checkpoint
model_checkpoint = "./models/ner/"
# Build the token-classification pipeline. The cells that consume its output
# read 'entity_group', 'start' and 'end' from each returned item.
token_classifier = pipeline(
    "token-classification", model=model_checkpoint, aggregation_strategy="simple"
)

In [None]:
# Run the token classifier on every segment, keyed by "<file id>_<segment number>".
result_ans_dict = {}

for fid_sid, seg in result_segments.items():
    try:
        result_ans_dict[fid_sid] = token_classifier(seg)
    except Exception as err:
        # Stop at the first failing segment, as before, but also report why
        # it failed. Catching Exception rather than using a bare except lets
        # KeyboardInterrupt interrupt the run normally.
        print(fid_sid)
        print(f'{type(err).__name__}: {err}')
        break

In [None]:
# result_ans_dict['1002_6']

In [None]:
import re
from word2number import w2n

# Month-name lookup tables shared by the DATE and TIME branches of Normalize.
_FULL_MONTHS = (
    ('January', '01'), ('February', '02'), ('March', '03'),
    ('April', '04'), ('May', '05'), ('June', '06'),
    ('July', '07'), ('August', '08'), ('September', '09'),
    ('October', '10'), ('November', '11'), ('December', '12'),
)
_SHORT_MONTHS = tuple((name[:3], number) for name, number in _FULL_MONTHS)
_FULL_MONTH_NUMBER = dict(_FULL_MONTHS)

# day<sep>month<sep>year with '/' or '.' as separator, as used inside TIME strings.
_DMY_PATTERN = r'\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}'


def _months_to_numbers(text, table):
    """Replace every month name listed in *table* by its two-digit number."""
    for name, number in table:
        text = text.replace(name, number)
    return text


def _dmy_to_iso(parts):
    """Turn [day, month, year, ...] strings into 'YYYY-MM-DD'.

    One-digit day/month are zero-padded; a 2-digit year gets the prefix '20'
    and a 3-digit year the prefix '2'.
    """
    day, month, year = parts[0], parts[1], parts[2]
    if len(day) == 1:
        day = '0' + day
    if len(month) == 1:
        month = '0' + month
    if len(year) == 2:
        year = '20' + year
    elif len(year) == 3:
        year = '2' + year
    return year + '-' + month + '-' + day


def _to_24_hour(hh, am, pm):
    """Apply the am/pm flags to an hour string (12am -> '00', 1pm -> '13')."""
    if pm and int(hh) < 12:
        return str(int(hh) + 12)
    if am and int(hh) == 12:
        return '00'
    return hh


def _compact_time(digits):
    """Split a 1-4 digit clock value into (hour, minute) strings.

    4 digits -> HHMM, 3 digits -> HMM, 1-2 digits -> hour only (minute '00').
    """
    hh, minute = '00', '00'
    if len(digits) == 4:
        hh, minute = digits[0:2], digits[2:]
    elif len(digits) == 3:
        hh, minute = digits[0], digits[1:]
    elif len(digits) in (1, 2):
        hh = digits
    return hh, minute


def Normalize(time_type, org):
    """Normalize the surface text of a temporal entity.

    Args:
        time_type: One of 'DATE', 'TIME', 'DURATION' or 'SET'. Any other
            value yields ''.
        org: The entity text as it appears in the document.

    Returns:
        The normalized string:
          * DATE     -> 'YYYY', 'YYYY-MM' or 'YYYY-MM-DD'
          * TIME     -> date and/or clock time joined by 'T' (hours are
                        always two digits)
          * DURATION -> 'P<number><unit letter>', e.g. 'P2W'; the original
                        text is returned when the number cannot be parsed
          * SET      -> 'R2' for text starting with 'twice'
        '' is returned when no pattern matches.
    """
    nor = ''
    if time_type == 'DATE':
        if re.match(r'\d{1,2}(\/|\.| |-|,)\d{1,2}(\/|\.| |-|,)\d{2,4}', org):
            # D<sep>M<sep>Y with any of the separators / . space - ,
            nor = _dmy_to_iso(re.split(r'\/|\.| |-|,', org))
        elif re.match(r'\/\d{1,2}\/(\d{2}|\d{4})', org):
            # "/M/YY" or "/M/YYYY": leading slash, no day.
            l = re.split(r'\/', org)
            month, year = l[1], l[2]
            if len(month) == 1:
                month = '0' + month
            if len(year) == 2:
                year = '20' + year
            nor = year + '-' + month
        elif re.match(r'\d{1,2}\/\d{2,5}', org):
            # "<1-2 digits>/<2-5 digits>": the second field is a year,
            # optionally prefixed by a one-digit month (3 or 5 digits).
            l = re.split(r'\/', org)
            first, second = l[0], l[1]
            if len(first) == 1:
                first = '0' + first
            if len(second) == 2:
                nor = '20' + second + '-' + first
            elif len(second) == 3:
                nor = '20' + second[1:] + '-' + '0' + second[0] + '-' + first
            elif len(second) == 4:
                nor = second + '-' + first
            elif len(second) == 5:
                nor = second[1:] + '-' + '0' + second[0] + '-' + first
        elif re.match(r'\d{8}', org):
            nor = org[0:4] + '-' + org[4:6] + '-' + org[6:8]
        elif re.match(r'\d{4}', org):
            nor = org
        elif re.match(r'\d{3}', org):
            nor = '2' + org
        elif re.match(r'(\d{2}|)(-|)(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)(-| )\d{2,4}', org):
            # "[DD-]Mon-YYYY" with abbreviated month names.
            l = re.split(r'-| ', _months_to_numbers(org, _SHORT_MONTHS))
            year = l[1] if len(l) == 2 else l[2]
            if len(year) == 2:
                year = '20' + year
            elif len(year) == 3:
                year = '2' + year
            if len(l) == 2:
                nor = year + '-' + l[0]
            else:
                nor = year + '-' + l[1] + '-' + l[0]
        elif re.match(r'\d{1,2}((st)|(nd)|(rd)|(th)) of (January|February|March|April|May|June|July|August|September|October|November|December) \d{4}', org):
            # "3rd of February 2023": drop the ordinal suffix and pad the day.
            l = re.split(' ', _months_to_numbers(org, _FULL_MONTHS))
            nor = l[3] + '-' + l[2] + '-' + l[0][:-2].zfill(2)
        elif re.match(r'(\d{1,2}|)( |)(January|February|March|April|May|June|July|August|September|October|November|December) \d{4}', org):
            # "[D ]Month YYYY": pad a one-digit day first so the digit string
            # built below has a fixed layout (MMYYYY or DDMMYYYY).
            if re.match(r'\d', org[0]) and re.match(r'\d', org[1]) is None:
                org = '0' + org
            org = _months_to_numbers(org, _FULL_MONTHS).replace(' ', '')
            if len(org) == 6:
                nor = org[2:] + '-' + org[0:2]
            else:
                nor = org[4:] + '-' + org[2:4] + '-' + org[0:2]
    elif time_type == 'TIME':
        if re.match(r'(\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}(  | |)|)(at|)( |)\d{1,2}(:|\.)\d{2}(AM|am|PM|pm|Hr|Hrs|hr|hrs|)( on the \d{1,2}((st)|(nd)|(rd)|(th)) of (January|February|March|April|May|June|July|August|September|October|November|December) \d{4}|)', org):
            # Clock time with ':' or '.', optionally with a numeric date in
            # front or a spelled-out date ("on the 3rd of May 2020") behind.
            pm = 'PM' in org or 'pm' in org
            am = 'AM' in org or 'am' in org

            # Each recognized piece is removed from `org` so that later
            # searches cannot pick it up again.
            found = re.search(_DMY_PATTERN, org)
            if found is not None:
                date_text = found.group(0)
                org = org.replace(date_text, '')
                nor = _dmy_to_iso(re.split(r'\/|\.', date_text))
            else:
                # No numeric date: assemble year, month name and ordinal day.
                found = re.search(r'\d{4}', org)
                if found is not None:
                    year = found.group(0)
                    org = org.replace(year, '')
                    nor = year + '-'
                found = re.search('January|February|March|April|May|June|July|August|September|October|November|December', org)
                if found is not None:
                    month_name = found.group(0)
                    org = org.replace(month_name, '')
                    nor = nor + _FULL_MONTH_NUMBER[month_name] + '-'
                found = re.search(r'\d{1,2}((st)|(nd)|(rd)|(th))', org)
                if found is not None:
                    ordinal = found.group(0)
                    org = org.replace(ordinal, '')
                    nor = nor + ordinal[:-2].zfill(2)

            found = re.search(r'\d{1,2}(:|\.)\d{1,2}', org)
            if found is not None:
                clock = found.group(0)
                org = org.replace(clock, '')
                clock = re.split(r'\.|:', clock)
                hh = _to_24_hour(clock[0], am, pm).zfill(2)
                nor = nor + 'T' + hh + ':' + clock[1]
            else:
                # Fallback: a bare 1-4 digit clock value. Only lower-case
                # am/pm markers are honoured on this path.
                pm = 'pm' in org
                am = 'am' in org
                found = re.search(r'\d{1,4}', org)
                if found is not None:
                    hh, minute = _compact_time(found.group(0))
                    hh = _to_24_hour(hh, am, pm).zfill(2)
                    nor = nor + 'T' + hh + ':' + minute
        elif re.match(r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}', org):
            # Already ISO-like; only the separator between date and time changes.
            nor = org.replace(' ', 'T')
        elif re.match(r'(at |)(\d{1,2}|)(:|\.|)\d{2}( |)(am|pm|Hr|Hrs|hr|hrs|)( on | )(the |)\d{1,2}(\/|\.)\d{2,4}(\/|\.)\d{1,2}', org):
            # Clock time followed by a numeric date.
            pm = 'pm' in org
            am = 'am' in org
            found = re.search(_DMY_PATTERN, org)
            if found is not None:
                date_text = found.group(0)
                org = org.replace(date_text, '')
                nor = _dmy_to_iso(re.split(r'\/|\.', date_text)) + 'T'
            # Drop ':' so that "9:30" is read as the compact value "930".
            org = org.replace(':', '')
            found = re.search(r'\d{1,4}', org)
            if found is not None:
                hh, minute = _compact_time(found.group(0))
                hh = _to_24_hour(hh, am, pm).zfill(2)
                nor = nor + hh + ':' + minute
        elif re.match(r'((\d{1,2}((pm)|(am)))|(\d{4}(Hr|Hrs|hr|hrs|)))(( on )| )\d{1,2}(\/|\.)\d{1,2}(\/|\.)\d{2,4}', org):
            # "3pm on D/M/Y" or "1430hrs D/M/Y".
            pm = 'pm' in org
            am = 'am' in org
            found = re.search(_DMY_PATTERN, org)
            if found is not None:
                date_text = found.group(0)
                org = org.replace(date_text, '')
                nor = _dmy_to_iso(re.split(r'\/|\.', date_text)) + 'T'
            found = re.search(r'\d{4}', org)
            if found is not None:
                hrtime = found.group(0)
                org = org.replace(hrtime, '')
                nor = nor + hrtime[0:2] + ':' + hrtime[2:]
            found = re.search(r'\d{1,2}', org)
            if found is not None:
                hh = _to_24_hour(found.group(0), am, pm).zfill(2)
                nor = nor + hh + ':' + '00'
    elif time_type == 'DURATION':
        tmp = org
        for number_word, digit in (('one', '1'), ('two', '2'), ('three', '3'),
                                   ('four', '4'), ('five', '5')):
            org = org.replace(number_word, digit)
        alp = ''
        space_idx = org.find(' ')
        # The unit is the first d/w/m/y letter after the first space; the
        # text in front of it is the numeric part.
        for i, ch in enumerate(org):
            if ch in 'DdWwMmYy' and i > space_idx:
                alp = ch
                org = org[:i]
                break
        parts = re.split(r'-| ', org)
        try:
            if parts[0].isalpha():
                parts[0] = w2n.word_to_num(parts[0])
            if len(parts) == 1 or parts[1] == '':
                nor = 'P' + str(parts[0]) + alp.upper()
            else:
                # A range such as "3-4 weeks" becomes the mean of its two ends.
                nor = 'P' + str((int(parts[0]) + int(parts[1])) / 2) + alp.upper()
        except Exception:
            # Best effort: hand back the original text when the number cannot be parsed.
            nor = tmp
    elif time_type == 'SET':
        if re.match('twice', org):
            nor = 'R2'
    return nor


In [None]:
import pandas as pd

In [None]:
# Build the submission table `test_ans_df` from the per-segment predictions.
#
# Inputs (defined in earlier cells):
#   result_ans_dict - maps '<file id>_<segment id>' to a list of entity dicts;
#                     this cell reads the keys 'entity_group', 'start', 'end'.
#   result_segments - maps the same keys to the segment text; only its length
#                     is used, to turn segment-relative offsets into
#                     file-relative ones.
#   Normalize       - produces the ISO value for DATE/TIME/DURATION/SET spans.
#
# Output: `test_ans_df`, one row per kept entity with the columns
# file_id, PHI_type, PHI_start, PHI_end, PHI_content, ISO.


def _trim_span(word, start_idx, end_idx):
    """Strip non-alphanumeric characters from both ends of a predicted span.

    Trimming stops as soon as the span ends with ')' (the closing parenthesis
    is left in place, as in the original logic -- presumably intentional,
    TODO confirm).

    Args:
        word: text of the span, i.e. ``content[start_idx:end_idx]``.
        start_idx: file-relative start offset of ``word``.
        end_idx: file-relative end offset of ``word``.

    Returns:
        ``(word, start_idx, end_idx)`` after trimming.  ``word`` is the empty
        string when the span held no alphanumeric character at all.
    """
    while word and not (word[0].isalnum() and word[-1].isalnum()):
        if not word[0].isalnum():
            word = word[1:]
            start_idx += 1
            if not word:
                # Nothing left to inspect; indexing word[-1] would raise.
                break
        if word[-1].isalnum():
            continue
        if word[-1] == ')':
            break
        # Typically a trailing newline or punctuation mark.
        word = word[:-1]
        end_idx -= 1
    return word, start_idx, end_idx


# Labels whose rows carry a normalized value in the ISO column.
need_iso = ['DATE', 'TIME', 'DURATION', 'SET']

# Rows are collected in a plain list and turned into a DataFrame once at the
# end; growing a DataFrame row by row copies it on every insertion.
answer_rows = []

last_fid = ""
last_idx_of_last_seg = 0
for fid_sid, entities in result_ans_dict.items():
    id_parts = fid_sid.split('_')
    curr_fid = id_parts[0]
    curr_sid = id_parts[1]

    if curr_fid != last_fid:
        # First segment of a new file: load its text and restart the offset.
        with open(os.path.join('./opendid_test/opendid_test', curr_fid+'.txt'), 'r') as file:
            content = file.read()
        last_fid = curr_fid
        last_idx_of_last_seg = 0

    # Shift from segment-relative to file-relative offsets: total length of
    # the previous segments plus (segment id - 1).
    # NOTE(review): the extra term presumably accounts for one separator
    # character dropped between consecutive segments -- confirm against the
    # cell that builds result_segments.
    seg_offset = last_idx_of_last_seg + int(curr_sid) - 1

    # `ent` rather than `entity`, so the notebook-level `entity` label list
    # is not overwritten by this loop.
    for ent in entities:
        if ent['entity_group'] == 'OTHER':
            continue

        start_idx = ent['start'] + seg_offset
        end_idx = ent['end'] + seg_offset
        word = content[start_idx:end_idx]

        # A lone punctuation/whitespace character is not a PHI.
        if len(word) == 1 and not word.isalnum():
            continue
        if len(word) > 1:
            word, start_idx, end_idx = _trim_span(word, start_idx, end_idx)
            if not word:
                # The span consisted only of non-alphanumeric characters.
                continue

        # The answer file is line based, so newlines inside a span become spaces.
        word = word.replace('\n', ' ')

        label = ent['entity_group']

        if label == 'DURATION' and word != 'twice':
            last_index = start_idx - 1
            # Case 1: no space between the number and month, week or year.
            # The `>= 0` guards stop negative indices from wrapping around
            # to the end of the file.
            if last_index >= 0 and content[last_index].isdigit():
                while last_index >= 0 and content[last_index].isdigit():
                    last_index -= 1
                start_idx = last_index + 1
            else:
                # Case 2: extend the span back to the second-to-last space
                # before it, so the preceding word is included.
                last_space1 = content.rfind(' ', 0, start_idx)
                if last_space1 == -1:
                    # No space before the span: rfind(' ', 0, -1) would
                    # search almost the whole file instead of nothing.
                    last_space2 = -1
                else:
                    last_space2 = content.rfind(' ', 0, last_space1)
                start_idx = last_space2 + 1
            # The widened span is re-read from the file, so the newline
            # replacement has to be applied again.
            word = content[start_idx:end_idx].replace('\n', ' ')
        elif word == 'twice':
            label = 'SET'

        iso = Normalize(label, word) if label in need_iso else ''
        answer_rows.append([curr_fid, label, start_idx, end_idx, word, iso])

    # Advance once per segment, also when the segment produced no entities,
    # so the offsets of the following segments stay aligned.
    last_idx_of_last_seg += len(result_segments[fid_sid])

test_ans_df = pd.DataFrame(
    answer_rows,
    columns=['file_id', 'PHI_type', 'PHI_start', 'PHI_end', 'PHI_content', 'ISO'],
)


In [None]:
# Show the answer table as the cell output for a quick visual check.
test_ans_df

### Add in CRF predictions that PERT misses

In [None]:
# Load the CRF model's answers; later cells pick rows from this frame by position.
df_crf = pd.read_csv(filepath_or_buffer='./opendid_test/crf_answer.csv')

In [None]:
# Split the CRF answers into one view per PHI type of interest:
# DURATION, SET and LOCATION-OTHER (in that order).
df_crf_dur, df_crf_set, df_crf_lo = (
    df_crf[df_crf['PHI_type'].eq(phi_type)]
    for phi_type in ('DURATION', 'SET', 'LOCATION-OTHER')
)

In [None]:
# Show the CRF rows whose PHI_type is LOCATION-OTHER as the cell output.
df_crf_lo

In [None]:
# Show every answer row that belongs to file '2465'.
test_ans_df[test_ans_df['file_id'].eq('2465')]

In [None]:
# Splice a single CRF prediction into the answer table at a fixed position.
crf_set_row_idx = 6905   # positional index of the row to copy from df_crf
insert_position = 8619   # positional slot the row takes in test_ans_df

new_row = df_crf.iloc[crf_set_row_idx]
new_series = pd.Series(new_row)

# Series -> one-row frame, placed between the rows before and after the slot;
# ignore_index renumbers the combined frame from 0.
test_ans_df = pd.concat(
    [
        test_ans_df.iloc[:insert_position],
        new_series.to_frame().T,
        test_ans_df.iloc[insert_position:],
    ],
    ignore_index=True,
)


In [None]:
# Splice a second CRF prediction into the answer table at a fixed position.
crf_lo_row_idx = 3401    # positional index of the row to copy from df_crf
insert_position = 4221   # positional slot the row takes in test_ans_df

new_row = df_crf.iloc[crf_lo_row_idx]
new_series = pd.Series(new_row)

# Series -> one-row frame, placed between the rows before and after the slot;
# ignore_index renumbers the combined frame from 0.
test_ans_df = pd.concat(
    [
        test_ans_df.iloc[:insert_position],
        new_series.to_frame().T,
        test_ans_df.iloc[insert_position:],
    ],
    ignore_index=True,
)

# Cell output: the updated table.
test_ans_df

In [None]:
# Index labels of the rows to remove from the answer table.
rows_to_delete = [4222, 4223, 4224]

# drop() returns a new frame without those labels; the remaining index labels
# are left as they are (no renumbering).
test_ans_df = test_ans_df.drop(index=rows_to_delete)

# Cell output: the rows that remain for file '2465'.
test_ans_df[test_ans_df['file_id'] == '2465']

In [None]:
# Save the answers as CSV with pandas' defaults (header row and index column included).
test_ans_df.to_csv(path_or_buf='./opendid_test/test_answer.csv')

In [None]:
# Save the answers as a tab-separated text file without header row or index column.
test_ans_df.to_csv(
    path_or_buf='./opendid_test/test_answer.txt',
    index=False,
    header=False,
    sep='\t',
)