### install library and download data

In [1]:
# !pip install -qU \
#   lightning \
#   datasets \
#   wandb \
#   gdown \
#   transformers

In [2]:
# import os
# from pprint import pprint as pp
# os.mkdir("./model/")

In [3]:
# library
import os
#from pprint import pprint as pp

import torch
from torch import nn
from transformers import LongformerTokenizer, AutoTokenizer
from torch.utils.data import Dataset, DataLoader

from torch.optim import AdamW
from torch.nn import CrossEntropyLoss

from collections import defaultdict
import random
import json

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
seed = 61

random.seed(seed)

### load data

In [5]:
# data structure:
# First_Phase_Release(Correction)/First_Phase_Text_Dataset/
# First_Phase_Release(Correction)/answer.txt
# Second_Phase_Dataset/Second_Phase_Text_Dataset/
# Second_Phase_Dataset/answer.txt
# validation_dataset/Validation_Release/
# validation_dataset/answer.txt

first_dataset_doc_path = "./dataset/First_Phase_Release(Correction)/First_Phase_Text_Dataset/"
second_dataset_doc_path = "./dataset/Second_Phase_Dataset/Second_Phase_Text_Dataset/"
label_path = ["./dataset/First_Phase_Release(Correction)/answer.txt", "./dataset/Second_Phase_Dataset/answer.txt"]
val_dataset_doc_parh = "./dataset/validation_dataset/Validation_Release/"
val_label_path = "./dataset/validation_dataset/answer.txt"

first_dataset_path = [first_dataset_doc_path + file_path for file_path in os.listdir(first_dataset_doc_path)]
second_dataset_path = [second_dataset_doc_path + file_path for file_path in os.listdir(second_dataset_doc_path)]
train_path = first_dataset_path + second_dataset_path
val_path = [val_dataset_doc_parh + file_path for file_path in os.listdir(val_dataset_doc_parh)]

#check number of data-path
print(len(first_dataset_path)) #1120
print(len(second_dataset_path)) #614
print()
print(len(train_path)) #1734
print(len(val_path)) #560

1121
615

1736
561


In [6]:
# we can use utf-8-sig to solve ufeff problem (need to remove for label)
# Define function to read label

def create_label_dict(label_path):
    label_dict = {}  # y
    with open(label_path, "r", encoding="utf-8-sig") as f:
        file_text = f.read().strip()  

    # (id, label, start, end, query) or (id, label, start, end, query, time_org, timefix)
    for line in file_text.split("\n"):
        sample = line.split("\t")  
        sample[2], sample[3] = int(sample[2]), int(sample[3])

        if sample[0] not in label_dict:
            label_dict[sample[0]] = [sample[1:]]
        else:
            label_dict[sample[0]].append(sample[1:])

    return label_dict

train_label_dict = create_label_dict(label_path[0])
second_dataset_label_dict = create_label_dict(label_path[1])
train_label_dict.update(second_dataset_label_dict)
val_label_dict = create_label_dict(val_label_path)

In [7]:
# Define function to read data

def load_medical_records(paths):
    medical_record_dict = {}
    for data_path in paths:

        if os.path.isfile(data_path):
            file_id = data_path.split("/")[-1].split(".txt")[0]
            with open(data_path, "r", encoding="utf-8") as f:
                file_text = f.read()
                medical_record_dict[file_id] = file_text
    return medical_record_dict

train_medical_record_dict = load_medical_records(train_path)
val_medical_record_dict = load_medical_records(val_path)

In [8]:
#chect the number of data
print(len(list(train_medical_record_dict.keys()))) #1734
print(len(list(train_label_dict.keys()))) #1734
print(len(list(val_medical_record_dict.keys()))) #560
print(len(list(val_label_dict.keys()))) #560

1734
1734
560
560


In [9]:
all_medical_record_dict = {**train_medical_record_dict, **val_medical_record_dict}
all_label_dict = {**train_label_dict, **val_label_dict}

### clean data

In [10]:
def check_labels(text, labels, record_id, tag=False):
    for i, label in enumerate(labels):  
        extracted_text = text[label[1]:label[2]]
        if extracted_text != label[3]:
            print(f"Error in ID {record_id}, Line {i}: {label[0]}, position: {label[1]}-{label[2]}, "
                  f"label: '{label[3]}', extracted: '{extracted_text}'")
        elif tag:
            print(f"Correct in ID {record_id}, Line {i}: {label[0]}, position: {label[1]}-{label[2]}, extracted: '{extracted_text}'")

def check_all_labels(medical_records, label_dict, tag=False):
    for record_id, text in medical_records.items():
        if record_id in label_dict:
            labels = label_dict[record_id]
            check_labels(text, labels, record_id, tag)
        else:
            print(f"ID: {record_id} has no label")

         

In [11]:
# check training data
check_all_labels(all_medical_record_dict, all_label_dict)   

Error in ID 1139, Line 16: HOSPITAL, position: 2702-2722, label: 'PLANTAGENET HOSPITAL', extracted: 'PLANTAGENE3/9 JENNIE'
Error in ID 1481, Line 21: DEPARTMENT, position: 2390-2403, label: 'SEALS Central', extracted: 'SEAKALBARRI H'
Error in ID file21297, Line 20: ORGANIZATION, position: 6045-6064, label: 'KB Home Los Angeles', extracted: 'KB Home	Los Angeles'


In [12]:
# check 1139, PLANTAGENET 3/9 JENNIE COX CLOSE Pathology ?
print(all_medical_record_dict['1139'][2702:2722])
print(all_label_dict['1139'][16])

# replace it
all_label_dict['1139'][16][3]=all_medical_record_dict['1139'][2702:2722]

PLANTAGENE3/9 JENNIE
['HOSPITAL', 2702, 2722, 'PLANTAGENET HOSPITAL']


In [13]:
# check 1481, there is no DEPARTMENT
print(all_medical_record_dict['1481'][2390:2403])
print(all_label_dict['1481'][21])

# remove it 
all_label_dict['1481'].pop(21)

SEAKALBARRI H
['DEPARTMENT', 2390, 2403, 'SEALS Central']


['DEPARTMENT', 2390, 2403, 'SEALS Central']

In [14]:
# check file21297, index 6047 is '\t'
all_medical_record_dict['file21297'][6045:6064]

# replace it
all_medical_record_dict['file21297'] = val_medical_record_dict['file21297'][:6047] + ' ' + val_medical_record_dict['file21297'][6048:]

In [15]:
all_keys = list(all_medical_record_dict.keys())
random.shuffle(all_keys)


In [16]:
train_size = int(0.8 * len(all_keys))
val_size = len(all_keys) - train_size

train_keys = all_keys[:train_size]
val_keys = all_keys[train_size:]

train_medical_record_dict = {key: all_medical_record_dict[key] for key in train_keys}
train_label_dict = {key: all_label_dict[key] for key in train_keys}

val_medical_record_dict = {key: all_medical_record_dict[key] for key in val_keys}
val_label_dict = {key: all_label_dict[key] for key in val_keys}

print("New Train Set Size:", len(train_medical_record_dict))
print("New Validation Set Size:", len(val_medical_record_dict))

New Train Set Size: 1835
New Validation Set Size: 459


### create label type table

In [17]:
#add special token [other] in label list
labels_type = list(set( [label[0] for labels in train_label_dict.values() for label in labels] ))
labels_type = ["OTHER"] + labels_type 
labels_num = len(labels_type)
# print(labels_type)
# print("The number of labels:", labels_num)
labels_type_table = {label_name:id for id, label_name in enumerate(labels_type)}
print(labels_type_table)

{'OTHER': 0, 'URL': 1, 'TIME': 2, 'PATIENT': 3, 'ZIP': 4, 'LOCATION-OTHER': 5, 'ROOM': 6, 'DOCTOR': 7, 'ORGANIZATION': 8, 'PHONE': 9, 'CITY': 10, 'MEDICALRECORD': 11, 'SET': 12, 'COUNTRY': 13, 'STREET': 14, 'DURATION': 15, 'STATE': 16, 'DEPARTMENT': 17, 'AGE': 18, 'IDNUM': 19, 'DATE': 20, 'HOSPITAL': 21}


In [18]:
# fix it
labels_type_table={'OTHER': 0, 'PATIENT': 1, 'DOCTOR': 2, 'CITY': 3, 'ROOM': 4, 'STREET': 5, 'MEDICALRECORD': 6, 'DEPARTMENT': 7, 'LOCATION-OTHER': 8, 'COUNTRY': 9, 'IDNUM': 10, 'STATE': 11, 'AGE': 12, 'SET': 13, 'HOSPITAL': 14, 'DATE': 15, 'ZIP': 16, 'URL': 17, 'DURATION': 18, 'ORGANIZATION': 19, 'TIME': 20, 'PHONE': 21}
print(labels_type_table)

{'OTHER': 0, 'PATIENT': 1, 'DOCTOR': 2, 'CITY': 3, 'ROOM': 4, 'STREET': 5, 'MEDICALRECORD': 6, 'DEPARTMENT': 7, 'LOCATION-OTHER': 8, 'COUNTRY': 9, 'IDNUM': 10, 'STATE': 11, 'AGE': 12, 'SET': 13, 'HOSPITAL': 14, 'DATE': 15, 'ZIP': 16, 'URL': 17, 'DURATION': 18, 'ORGANIZATION': 19, 'TIME': 20, 'PHONE': 21}


In [19]:
#check the label_type is enough for validation
val_labels_type = list(set( [label[0] for labels in val_label_dict.values() for label in labels] ))
for val_label_type in val_labels_type:
    if val_label_type not in labels_type:
        print("Special label in validation:", val_label_type)

In [20]:
# Function to count label distribution
def count_label_distribution(label_dict, labels_type_table):
    label_counts = {label: 0 for label in labels_type_table.keys()}
    for labels in label_dict.values():
        for label_info in labels:
            label = label_info[0]  # Extract label name
            if label in label_counts:
                label_counts[label] += 1
    return label_counts

In [21]:
# Calculate label distribution
train_label_distribution = count_label_distribution(train_label_dict, labels_type_table)
val_label_distribution = count_label_distribution(val_label_dict, labels_type_table)

# Print results
print("Train Label Distribution:")
for label, count in train_label_distribution.items():
    print(f"  {label}: {count}")

print("\nValidation Label Distribution:")
for label, count in val_label_distribution.items():
    print(f"  {label}: {count}")

Train Label Distribution:
  OTHER: 0
  PATIENT: 1885
  DOCTOR: 6987
  CITY: 1020
  ROOM: 1
  STREET: 980
  MEDICALRECORD: 1912
  DEPARTMENT: 1135
  LOCATION-OTHER: 7
  COUNTRY: 3
  IDNUM: 3918
  STATE: 952
  AGE: 146
  SET: 11
  HOSPITAL: 1915
  DATE: 5108
  ZIP: 994
  URL: 3
  DURATION: 28
  ORGANIZATION: 113
  TIME: 1256
  PHONE: 9

Validation Label Distribution:
  OTHER: 0
  PATIENT: 479
  DOCTOR: 1745
  CITY: 256
  ROOM: 0
  STREET: 240
  MEDICALRECORD: 475
  DEPARTMENT: 268
  LOCATION-OTHER: 3
  COUNTRY: 2
  IDNUM: 936
  STATE: 233
  AGE: 38
  SET: 3
  HOSPITAL: 478
  DATE: 1285
  ZIP: 244
  URL: 0
  DURATION: 6
  ORGANIZATION: 47
  TIME: 279
  PHONE: 2


### pretrain model

In [22]:
model_name = "yikuan8/Clinical-Longformer"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)



In [23]:
from transformers import LongformerModel

class MyLongformerModel(nn.Module):

    def __init__(self, num_labels):
        super(MyLongformerModel, self).__init__()

        self.longformer = LongformerModel.from_pretrained('yikuan8/Clinical-Longformer')
        self.dropout = nn.Dropout(p=0.1)
        self.classifier = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask):
        output = self.longformer(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        output = self.dropout(output.last_hidden_state)
        logits = self.classifier(output)

        return logits


model = MyLongformerModel(num_labels=22)  

  return self.fget.__get__(instance, owner)()
Some weights of LongformerModel were not initialized from the model checkpoint at yikuan8/Clinical-Longformer and are newly initialized: ['longformer.pooler.dense.bias', 'longformer.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
BACH_SIZE = 4
#TRAIN_RATIO = 0.9
LEARNING_RATE = 1e-5
EPOCH = 20

In [26]:
num_labels=22

### dataloader

In [27]:
import torch
from torch.utils.data import Dataset, DataLoader
class Privacy_protection_dataset(Dataset):
    def __init__(self, medical_record_dict:dict, medical_record_labels:dict, tokenizer, labels_type_table:dict, mode:str):
        self.max_length = 4096
        self.labels_type_table = labels_type_table
        self.tokenizer = tokenizer
        self.data = []

        for id, text in medical_record_dict.items():
            labels = medical_record_labels.get(id, [])
            self.split_and_add_data(text, labels, id)
    
    def split_and_add_data(self, text, labels, id):
        # Split text into chunks of max_length
        for i in range(0, len(text), self.max_length):
            text_chunk = text[i:i+self.max_length]
            # Adjust labels for this chunk
            chunk_labels = [label for label in labels if label[1] >= i and label[2] <= i+self.max_length]
            chunk_labels = [[label[0], label[1] - i, label[2] - i] for label in chunk_labels]
            self.data.append((text_chunk, chunk_labels, id))

    def __getitem__(self, index):
        text_chunk, chunk_labels, id = self.data[index]
        return text_chunk, chunk_labels, id

    def __len__(self):
        return len(self.data)

    #find the correct labels ids after tokenizer
    def find_token_ids(self, label_start, label_end, offset_mapping):
        encodeing_start = float("inf") #max
        encodeing_end = 0
        for token_id, token_range in enumerate(offset_mapping):
            token_start, token_end = token_range
          
            #if token range one side out of label range, still take the token
            if token_start == 0 and token_end == 0: #special tocken
                continue
                
            if label_start<token_end and label_end>token_start:
                if token_id<encodeing_start:
                    encodeing_start = token_id
                encodeing_end = token_id+1
                
        return encodeing_start, encodeing_end

    def encode_labels_position(self, batch_lables:list, offset_mapping:list):
        #encode the batch_lables's position
        batch_encodeing_labels = []
        for sample_labels, sample_offsets in zip(batch_lables, offset_mapping):
            encodeing_labels = []
            for label in sample_labels:
                encodeing_start, encodeing_end = self.find_token_ids(label[1], label[2], sample_offsets)
                encodeing_labels.append([label[0], encodeing_start, encodeing_end])
            batch_encodeing_labels.append(encodeing_labels)
        return batch_encodeing_labels

    def create_labels_tensor(self, batch_shape:list, batch_labels_position_encoded:list):
        if batch_shape[-1]> self.max_length:
            batch_shape[-1] = self.max_length
        labels_tensor = torch.zeros(batch_shape)

        for sample_id in range(batch_shape[0]):
            for label in batch_labels_position_encoded[sample_id]:
                label_id = self.labels_type_table[label[0]]
                start = label[1]
                end = label[2]
                
                if start >= self.max_length: continue
                elif end >= self.max_length: end = self.max_length
                
                labels_tensor[sample_id][start:end] = label_id
                
        return labels_tensor

    def collate_fn(self, batch_items:list):
        #the calculation process in dataloader iteration
        batch_medical_record = [sample[0] for sample in batch_items]
        batch_labels = [sample[1] for sample in batch_items]
        batch_id_list = [sample[2] for sample in batch_items]
        
        encodings = self.tokenizer(batch_medical_record, padding=True, max_length=self.max_length, truncation=True, return_tensors="pt", return_offsets_mapping="True") # truncation=True

        batch_labels_position_encoded = self.encode_labels_position(batch_labels, encodings["offset_mapping"])
        batch_labels_tensor = self.create_labels_tensor(encodings["input_ids"].shape, batch_labels_position_encoded)
  
        return encodings, batch_labels_tensor, batch_labels

In [28]:
train_id_list = list(train_medical_record_dict.keys())
train_medical_record = {sample_id: train_medical_record_dict[sample_id] for sample_id in train_id_list}
train_labels = {sample_id: train_label_dict[sample_id] for sample_id in train_id_list}

val_id_list = list(val_medical_record_dict.keys())
val_medical_record = {sample_id: val_medical_record_dict[sample_id] for sample_id in val_id_list}
val_labels = {sample_id: val_label_dict[sample_id] for sample_id in val_id_list}

train_dataset = Privacy_protection_dataset(train_medical_record, train_labels, tokenizer, labels_type_table, "train")
val_dataset = Privacy_protection_dataset(val_medical_record, val_labels, tokenizer, labels_type_table, "validation")


train_dataloader = DataLoader( train_dataset, batch_size = BACH_SIZE, shuffle = True, collate_fn = train_dataset.collate_fn)
val_dataloader = DataLoader( val_dataset, batch_size = BACH_SIZE, shuffle = False, collate_fn = val_dataset.collate_fn)

In [29]:
# #check dataloader
# print(len(train_dataset))
# for sample in train_dataset:
#     train_x, train_y,_ = sample
#     #print(train_x)
#     print(train_y)
#     break
    
# print(len(train_dataloader))
# for sample in train_dataloader:
#     train_x, train_y, _ = sample
#     #print(train_x)
#     print(train_y)
#     break

In [30]:
# #show the first batch labels embeddings
# print(labels_type_table)
# for i in range(BACH_SIZE):
#     print(train_y[i].tolist())

### train

In [31]:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device) # Put model on device
optim = AdamW(model.parameters(), lr = LEARNING_RATE)
#if use CRF
loss_fct = CrossEntropyLoss()

In [32]:
def decode_model_result(model_predict_table, offsets_mapping, labels_type_table):
    model_predict_list = model_predict_table.tolist()
    id_to_label = {id: label for label, id in labels_type_table.items()}
    predict_y = []
    pre_label_id = 0
    start = 0

    for position_id, label_id in enumerate(model_predict_list):
        if label_id != 0:
            if pre_label_id != label_id:
                start = int(offsets_mapping[position_id][0])
            end = int(offsets_mapping[position_id][1])

        if pre_label_id != label_id and pre_label_id != 0:
            predict_y.append((id_to_label[pre_label_id], start, end))  # 改為元組
        pre_label_id = label_id

    if pre_label_id != 0:
        predict_y.append((id_to_label[pre_label_id], start, end))  # 同樣改為元組

    return predict_y


def calculate_batch_score(batch_labels, model_predict_sequences, offset_mappings, labels_type_table):
    score_table = defaultdict(lambda: {"TP": 0, "FP": 0, "FN": 0})
    id_to_label = {id: label for label, id in labels_type_table.items()}
    batch_size = len(model_predict_sequences)

    for batch_id in range(batch_size):
        sample_prediction = decode_model_result(model_predict_sequences[batch_id], offset_mappings[batch_id], labels_type_table)
        sample_ground_truth = batch_labels[batch_id]

        # convert ground truth and predictions to sets for comparison
        sample_ground_truth = set([tuple(token) for token in sample_ground_truth])
        sample_prediction = set([tuple(token) for token in sample_prediction])

        # calculate TP, FP, FN for each label
        for label_id in labels_type_table.values():
            label = id_to_label[label_id]
            gt_entities = {x for x in sample_ground_truth if x[0] == label}
            pred_entities = {x for x in sample_prediction if x[0] == label}

            score_table[label]["TP"] += len(gt_entities & pred_entities)
            score_table[label]["FP"] += len(pred_entities - gt_entities)
            score_table[label]["FN"] += len(gt_entities - pred_entities)

    return score_table


In [33]:
def save_metrics_to_file(metrics, filename):
    with open(filename, 'w') as file:
        json.dump(metrics, file, indent=4)

training_stats = []


In [34]:
for epoch in range(EPOCH):
    model.train()
    total_train_loss = 0
    train_score_table = defaultdict(lambda: {"TP": 0, "FP": 0, "FN": 0})

    for batch_x, batch_y, batch_labels in train_dataloader:
        optim.zero_grad()
        input_ids = batch_x["input_ids"].to(device)
        attention_mask = batch_x["attention_mask"].to(device)
        labels = batch_y.long().to(device)
        
        outputs = model(input_ids, attention_mask)

        train_loss = loss_fct(outputs.transpose(-1, -2), labels)
        total_train_loss += train_loss.item()

        train_loss.backward()
        optim.step()

        # Optional: Calculate confusion matrix for training data
        with torch.no_grad():
            outputs = model(input_ids, attention_mask)
            model_predict_tables = torch.argmax(outputs, dim=-1, keepdim=True).squeeze(-1) 
            batch_score_table = calculate_batch_score(batch_labels, model_predict_tables, batch_x["offset_mapping"], labels_type_table)
            for label, scores in batch_score_table.items():
                for key in train_score_table[label]:
                    train_score_table[label][key] += scores[key]

    avg_train_loss = total_train_loss / len(train_dataloader)

    model.eval()
    total_val_loss = 0
    total_val_score_table = defaultdict(lambda: {"TP": 0, "FP": 0, "FN": 0})

    for batch_x, batch_y, batch_labels in val_dataloader:
        input_ids = batch_x["input_ids"].to(device)
        attention_mask = batch_x["attention_mask"].to(device)
        labels = batch_y.long().to(device)
        
        with torch.no_grad():
            # Forward pass
            outputs = model(input_ids, attention_mask)
            # Calculate loss only for the active part of the loss
            val_loss = loss_fct(outputs.transpose(-1, -2), labels)
                        
            total_val_loss += val_loss.item()

            model_predict_tables = torch.argmax(outputs, dim=-1, keepdim=True).squeeze(-1)                
            
            batch_score_table = calculate_batch_score(batch_labels, model_predict_tables, batch_x["offset_mapping"], labels_type_table)
            for label, scores in batch_score_table.items():
                for key in total_val_score_table[label]:
                    total_val_score_table[label][key] += scores[key]

    avg_val_loss = total_val_loss / len(val_dataloader)

    # Storing metrics for each epoch
    epoch_stats = {
        'epoch': epoch,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'train_confusion_matrix': train_score_table,  # Confusion matrix for training data
        'val_confusion_matrix': total_val_score_table  # Confusion matrix for validation data
    }
    training_stats.append(epoch_stats)
    
    # Print training statistics for current epoch
    print(f"Epoch {epoch}")
    print(f"Train Loss: {avg_train_loss}")
    print(f"Validation Loss: {avg_val_loss}")
#     print(f"Train Confusion Matrix: {train_score_table}")
#     print(f"Validation Confusion Matrix: {total_val_score_table}")
    
    # save_model
    model_save_path = f"./model_proceed/Clinical_longformer_epoch_{epoch}.pt"
    torch.save(model.state_dict(), model_save_path)
    
# Save training statistics to a file
save_metrics_to_file(training_stats, 'Clinical_training_stat_longformer.json')  

Input ids are automatically padded to be a multiple of `config.attention_window`: 512


Epoch 0
Train Loss: 0.12367465942023201
Validation Loss: 0.007885276908643714
Epoch 1
Train Loss: 0.007174165593528884
Validation Loss: 0.004311302577609721
Epoch 2
Train Loss: 0.004293867629366531
Validation Loss: 0.003640659365049649
Epoch 3
Train Loss: 0.002847567018353968
Validation Loss: 0.0029496674775647738
Epoch 4
Train Loss: 0.0020949217204277878
Validation Loss: 0.002239813918126386
Epoch 5
Train Loss: 0.0015898151152783037
Validation Loss: 0.0025780160078451362
Epoch 6
Train Loss: 0.001233280033088127
Validation Loss: 0.002394143355308313
Epoch 7
Train Loss: 0.0010409497439452299
Validation Loss: 0.0027142742664568948
Epoch 8
Train Loss: 0.0008732758623004494
Validation Loss: 0.0022930943365992436
Epoch 9
Train Loss: 0.0006063183966925272
Validation Loss: 0.00207638403730664
Epoch 10
Train Loss: 0.000517673560824891
Validation Loss: 0.002175390780148414
Epoch 11
Train Loss: 0.0005439671973218272
Validation Loss: 0.002480827491931502
Epoch 12
Train Loss: 0.0013088208287490085