### read test data and model to predict answer for task1

In [1]:
# library
import os

import torch
from torch import nn
from transformers import LongformerTokenizer, AutoTokenizer


  from .autonotebook import tqdm as notebook_tqdm


### read test data

In [2]:

test_dataset_doc_parh = "./dataset/validation_dataset/Validation_Release/"

test_path = [test_dataset_doc_parh + file_path for file_path in os.listdir(test_dataset_doc_parh)]

#check number of data-path
#print(len(test_path)) #560 for validation

In [3]:
# Define function to read data

def load_medical_records(paths):
    medical_record_dict = {}
    for data_path in paths:

        if os.path.isfile(data_path):
            file_id = data_path.split("/")[-1].split(".txt")[0]
            with open(data_path, "r", encoding="utf-8") as f:
                file_text = f.read()
                medical_record_dict[file_id] = file_text
    return medical_record_dict

test_record_dict = load_medical_records(test_path)

In [4]:
# double check
print(len(list(test_record_dict.keys())))

560


In [5]:
# fix label_type
labels_type_table={'OTHER': 0, 'PATIENT': 1, 'DOCTOR': 2, 'CITY': 3, 'ROOM': 4, 'STREET': 5, 'MEDICALRECORD': 6, 'DEPARTMENT': 7, 'LOCATION-OTHER': 8, 'COUNTRY': 9, 'IDNUM': 10, 'STATE': 11, 'AGE': 12, 'SET': 13, 'HOSPITAL': 14, 'DATE': 15, 'ZIP': 16, 'URL': 17, 'DURATION': 18, 'ORGANIZATION': 19, 'TIME': 20, 'PHONE': 21}
print(labels_type_table)

{'OTHER': 0, 'PATIENT': 1, 'DOCTOR': 2, 'CITY': 3, 'ROOM': 4, 'STREET': 5, 'MEDICALRECORD': 6, 'DEPARTMENT': 7, 'LOCATION-OTHER': 8, 'COUNTRY': 9, 'IDNUM': 10, 'STATE': 11, 'AGE': 12, 'SET': 13, 'HOSPITAL': 14, 'DATE': 15, 'ZIP': 16, 'URL': 17, 'DURATION': 18, 'ORGANIZATION': 19, 'TIME': 20, 'PHONE': 21}


### predict

In [6]:
model_name = "allenai/longformer-base-4096"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

In [7]:
from transformers import LongformerModel
from torchcrf import CRF

class MyLongformerModel(nn.Module):
    def __init__(self, num_labels):
        super(MyLongformerModel, self).__init__()

        self.longformer = LongformerModel.from_pretrained('allenai/longformer-base-4096')
        self.dropout = nn.Dropout(p=0.1)
        self.classifier = nn.Linear(768, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.longformer(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss = -self.crf(logits, labels, mask=attention_mask.byte())
            return loss
        else:
            return self.crf.decode(logits, mask=attention_mask.byte())

model = MyLongformerModel(num_labels=22)


In [8]:

model_path = './model/longformer-crf_14_0.9815750423541777'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cuda')))




<All keys matched successfully>

In [9]:
def decode_model_result(model_predict_list, offsets_mapping, labels_type_table):
    id_to_label = {id: label for label, id in labels_type_table.items()}
    predict_y = []
    pre_label_id = 0

    for position_id, label_id in enumerate(model_predict_list):
        if label_id != 0:
            if pre_label_id != label_id:
                start = int(offsets_mapping[position_id][0])
            end = int(offsets_mapping[position_id][1])

        if pre_label_id != label_id and pre_label_id != 0:
            predict_y.append([id_to_label[pre_label_id], start, end])
        pre_label_id = label_id

    if pre_label_id != 0:
        predict_y.append([id_to_label[pre_label_id], start, end])

    return predict_y


def merge_overlapping_predictions(predictions):
    if not predictions:
        return []

    sorted_predictions = sorted(predictions, key=lambda x: x[1])

    merged_predictions = [sorted_predictions[0]]
    for current in sorted_predictions[1:]:
        last = merged_predictions[-1]
        if current[0] == last[0] and current[1] <= last[2]:
            merged_predictions[-1] = (last[0], last[1], max(last[2], current[2]))
        else:
            merged_predictions.append(current)

    return merged_predictions

def predict_text_segments(models, tokenizer, text, max_length, overlap, device):
    all_model_predictions = []

    for model in models:
        model_predictions = []
        offset = 0
        for i in range(0, len(text), max_length - overlap):
            segment = text[i:i+max_length]
            encodings = tokenizer(segment, padding=True, truncation=True, return_tensors="pt", return_offsets_mapping=True)
            encodings["input_ids"] = encodings["input_ids"].to(device)
            encodings["attention_mask"] = encodings["attention_mask"].to(device)

            with torch.no_grad():
                outputs = model(encodings["input_ids"], encodings["attention_mask"])
                model_predict_list = outputs[0]
                predictions = decode_model_result(model_predict_list, encodings["offset_mapping"][0], labels_type_table)

            adjusted_predictions = [(label, start+offset, end+offset) for label, start, end in predictions]
            model_predictions.extend(adjusted_predictions)
            offset = i + max_length - overlap
        all_model_predictions.append(model_predictions)

    final_predictions = merge_and_vote(all_model_predictions)
    return final_predictions







In [10]:
def post_processing(label_name, start, end, text_segment):
    processed_label = label_name.strip()

    if processed_label.endswith('-') or processed_label.endswith('"') or processed_label.endswith("'"):
        processed_label = processed_label[:-1]
        end -= 1

    if processed_label == 'DATE' and text_segment.isdigit() and len(text_segment) > 8:
        end = start + 8 
        text_segment = text_segment[:8]  

    if processed_label == 'STATE':
        if text_segment.endswith('TAS'):
            text_segment = 'TAS'
            start = end - 3  
        elif (len(text_segment) >= 3):
            if text_segment[0].isupper() and text_segment[1].isupper() and text_segment[2].islower():
                if len(text_segment) == 3:
                    text_segment = text_segment[:2]  
                    end -= 1
                else:
                    text_segment = text_segment[1:]  
                    start += 1   

    if processed_label == 'CITY':
        if any(text_segment.endswith(suffix) for suffix in ['ONT', 'LET', 'NET', 'LAT']):
            end -= 1  
        elif any(text_segment.endswith(suffix) for suffix in ['RAS', 'CHS', 'LES']):
            end -= 1  

    return processed_label, start, end, text_segment



In [11]:
def merge_continuous_time_labels(predictions):
    merged_predictions = []
    prev_label = None

    for label_name, start, end, predict_str in predictions:
        if label_name == 'TIME' and prev_label and prev_label['label_name'] == 'TIME':
            if prev_label['end'] + 1 == start:
                prev_label['predict_str'] += ' ' + predict_str
                prev_label['end'] = end
                continue

        if prev_label:
            merged_predictions.append((prev_label['label_name'], prev_label['start'], prev_label['end'], prev_label['predict_str']))

        prev_label = {'label_name': label_name, 'start': start, 'end': end, 'predict_str': predict_str}

    if prev_label:
        merged_predictions.append((prev_label['label_name'], prev_label['start'], prev_label['end'], prev_label['predict_str']))

    return merged_predictions

In [12]:

def predict_for_single_sample(model, tokenizer, sample_id, val_medical_record_dict, device, max_length=4096, overlap=512):
    output_string = ""
    sample_text = val_medical_record_dict[sample_id]
    predictions = predict_text_segments(model, tokenizer, sample_text, max_length, overlap, device)
    final_predictions = merge_overlapping_predictions(predictions)

    extended_predictions = [(label_name, start, end, sample_text[start:end]) for label_name, start, end in final_predictions]

    merged_predictions = merge_continuous_time_labels(extended_predictions)

    for label_name, start, end, predict_str in merged_predictions:
        label_name, start, end, predict_str = post_processing(label_name, start, end, predict_str)
        sample_result_str = f"{sample_id}\t{label_name}\t{start}\t{end}\t{predict_str}\n"
        output_string += sample_result_str

    return output_string

In [13]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)

In [14]:
# test for one sample
sample_id = "file5124"  
print(predict_for_single_sample(model, tokenizer, sample_id, test_record_dict, device))

  score = torch.where(mask[i].unsqueeze(1), next_score, score)


file5124	DATE	12	19	2015083
file5124	HOSPITAL	36	58	COBRAM DISTRICT HEALTH
file5124	PATIENT	71	76	Corle
file5124	IDNUM	85	95	22N444639B
file5124	MEDICALRECORD	102	109	2254446
file5124	AGE	247	249	79
file5124	DOCTOR	892	902	F Koudelka
file5124	DOCTOR	1362	1374	F Wiltberger
file5124	DOCTOR	1988	1996	F Musich
file5124	DOCTOR	4180	4187	F Comee
file5124	DOCTOR	4570	4582	F Blachowski
file5124	DATE	4638	4644	2/6/72
file5124	DOCTOR	4676	4683	F Itani
file5124	DATE	7618	7626	8/6/2071
file5124	TIME	7888	7907	2846-12-08 00:00:00
file5124	PATIENT	7919	7926	Endsley



In [15]:
def predict_for_entire_dataset(model, tokenizer, val_medical_record_dict, device, max_length=4096, overlap=512):
    output_string = ""
    for sample_id, sample_text in val_medical_record_dict.items():
        predictions = predict_text_segments(model, tokenizer, sample_text, max_length, overlap, device)
        final_predictions = merge_overlapping_predictions(predictions)

        extended_predictions = [(label_name, start, end, sample_text[start:end]) for label_name, start, end in final_predictions]

        merged_predictions = merge_continuous_time_labels(extended_predictions)

        for label_name, start, end, predict_str in merged_predictions:
            label_name, start, end, predict_str = post_processing(label_name, start, end, predict_str)
            sample_result_str = f"{sample_id}\t{label_name}\t{start}\t{end}\t{predict_str}\n"
            output_string += sample_result_str

    return output_string


In [16]:
output_string = predict_for_entire_dataset(model, tokenizer, test_record_dict, device)

submission_dir = "./submission"
if not os.path.exists(submission_dir):
    os.mkdir(submission_dir)

with open(os.path.join(submission_dir, "answer.txt"), "w", encoding="utf-8") as f:
    f.write(output_string)