# Fine-tuning PLM-ICD with synthetic data

Notebook này thực hiện:
- Fine-tuning mô hình PLM-ICD từ checkpoint của Joakim Edin et al. (https://github.com/JoakimEdin/explainable-medical-coding) trên dữ liệu ghi chú y tế tổng hợp

In [None]:
!git clone https://github.com/chancholat/explain-icd.git
!pip install -q python-dotenv==1.0.0
!pip install -q datasets==3.4.1
!pip install -q omegaconf==2.3.0
!pip install -q captum==0.7.0
!pip install -q --force-reinstall transformers==4.38.1
!pip install -q --force-reinstall numpy==2.2.0
!pip install -q hydra-core

## Training on synthetic notes

In [None]:
import os

# Work from the cloned repository root: later cells use paths relative to it
# (e.g. `train_plm.py` and `./models/...` in the training command).
os.chdir("/content/explain-icd")

In [None]:
!mkdir /content/explain-icd/data/processed/augmented_icd9_inpatient_code
!cp /content/explain-icd/data/processed/mdace_icd9_inpatient_code/test.parquet /content/explain-icd/data/processed/augmented_icd9_inpatient_code/
!cp /content/explain-icd/data/processed/mdace_icd9_inpatient_code/val.parquet /content/explain-icd/data/processed/augmented_icd9_inpatient_code/

Remember to upload the synthetic train.parquet, test.parquet, and val.parquet files to the `data/processed/augmented_icd9_inpatient_code` folder.

#### Testing synthetic notes

In [None]:
from datasets import load_dataset
from pathlib import Path

# Load the MDACE inpatient ICD-9 dataset through the loading script shipped in
# the cloned repository. `trust_remote_code=True` allows `datasets` to execute
# that script.
dataset_path = Path("/content/explain-icd/explainable_medical_coding/datasets/mdace_inpatient_icd9_code.py")
mimic = load_dataset(str(dataset_path), trust_remote_code=True)

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [None]:
import pandas as pd
from collections import Counter
import numpy as np

# Per-note diagnosis-code lists of the train and validation splits.
train_dx_codes = mimic['train']['diagnosis_codes']
val_dx_codes = mimic['validation']['diagnosis_codes']

# Only diagnosis codes are synthesised: procedure codes are assigned based on
# how each specific disease is diagnosed.

# One row per note, keeping only the notes that carry a non-empty list of codes.
train_df = pd.DataFrame()
train_df["train_dx_codes"] = [
    note_codes
    for note_codes in train_dx_codes
    if isinstance(note_codes, list) and note_codes
]
val_df = pd.DataFrame()
val_df["val_dx_codes"] = [
    note_codes
    for note_codes in val_dx_codes
    if isinstance(note_codes, list) and note_codes
]

# Flatten the per-note lists, then count the occurrences of every code.
all_train_codes = [
    code
    for note_codes in train_df["train_dx_codes"]
    for code in note_codes
]
all_val_codes = [
    code
    for note_codes in val_df["val_dx_codes"]
    for code in note_codes
]
train_code_freq = Counter(all_train_codes)
val_code_freq = Counter(all_val_codes)

def get_percentiles(freq_values):
    """Print and return selected percentiles and the maximum of ``freq_values``.

    Args:
        freq_values: Array-like of numbers, passed to ``np.percentile`` and
            ``np.max``.

    Returns:
        A 9-tuple in this (non-sorted) order:
        ``(p87.5, p90, p93, p96, p99, max, p75, p50, p25)``.
    """
    # Percentile levels and the labels they are reported under, in print order.
    levels = (25, 50, 75, 87.5, 90, 93, 96, 99)
    labels = (
        "Q1 (25th percentile)",
        "Q2 (Median)",
        "Q3 (75th percentile)",
        "Q87.5 (87.5th percentile)",
        "Q90 (90th percentile)",
        "Q93 (93th percentile)",
        "Q96 (96th percentile)",
        "Q99 (99th percentile)",
    )

    # Compute everything before printing anything.
    values = [np.percentile(freq_values, level) for level in levels]
    peak = np.max(freq_values)

    for label, value in zip(labels, values):
        print(f"{label}: {value}")
    print(f"Q4 (Max): {peak}")
    print()

    first_q, median, third_q, p87, p90, p93, p96, p99 = values
    return p87, p90, p93, p96, p99, peak, third_q, median, first_q

# get_percentiles returns (p87.5, p90, p93, p96, p99, max, p75, p50, p25);
# the unpacking below follows that order.
q87, q90, q93, q96, q99, q4, q3, q2, q1 = get_percentiles(np.array(list(train_code_freq.values())))
_, val_q90, _, _, _, val_q4, val_q3, val_q2, val_q1 = get_percentiles(np.array(list(val_code_freq.values())))

# A code is treated as "rare" when its train frequency is below the train
# 90th percentile while its validation frequency is at least the validation
# 75th percentile.
threshold = q90
val_threshold = val_q3
rare_codes = {code for code, freq in train_code_freq.items() if freq < threshold and val_code_freq[code] >= val_threshold}
print(len(rare_codes))

Q1 (25th percentile): 2.0
Q2 (Median): 5.0
Q3 (75th percentile): 26.0
Q87.5 (87.5th percentile): 90.0
Q90 (90th percentile): 126.0
Q93 (93th percentile): 199.40000000000146
Q96 (96th percentile): 390.7999999999993
Q99 (99th percentile): 1290.6000000000022
Q4 (Max): 17963

Q1 (25th percentile): 1.0
Q2 (Median): 2.0
Q3 (75th percentile): 6.0
Q87.5 (87.5th percentile): 15.0
Q90 (90th percentile): 19.0
Q93 (93th percentile): 29.0
Q96 (96th percentile): 45.279999999999745
Q99 (99th percentile): 134.07000000000016
Q4 (Max): 735

124


In [None]:
import pandas as pd

# Read dataframe from path
# Locations of the synthetic (augmented) and the real (MDACE) training splits.
synth_path = '/content/explain-icd/data/processed/augmented_icd9_inpatient_code/train.parquet'
real_path = '/content/explain-icd/data/processed/mdace_icd9_inpatient_code/train.parquet'

# Load both splits, then report how many synthetic rows there are.
synthesize_df = pd.read_parquet(synth_path)
real_df = pd.read_parquet(real_path)
print(synthesize_df.shape[0])

865


In [None]:
# Preview the first rows of the synthetic training split.
synthesize_df.head()

Unnamed: 0,note_id,subject_id,_id,note_type,note_subtype,text,diagnosis_codes,diagnosis_code_spans,diagnosis_code_type,procedure_codes,procedure_code_spans,procedure_code_type
0,3,3,3,,,Admission Date: [* * 2164-12-14 * *] Discharge...,"[482.0, 530.81, V12.71, 284.1, 571.5, 268.9, 5...",[],icd9cm,[],[],icd9pcs
1,10,10,10,,,Admission Date: [* * 2142-11-25 * *] Discharge...,"[427.89, 282.2, 790.7, 238.72, 715.95, V12.51,...",[],icd9cm,[],[],icd9pcs
2,11,11,11,,,Admission Date: [* * 2139-1-1 * *] Discharge D...,"[512.1, 428.0, 362.01, 733.00, 410.71, 250.60,...",[],icd9cm,[],[],icd9pcs
3,16,16,16,,,Admission Date: [* * 2159-4-20 * *] Discharge ...,"[576.2, 577.8, 070.70, 401.9, 311, 338.29, 493...",[],icd9cm,[],[],icd9pcs
4,23,23,23,,,Admission Date: [* * 2197-11-11 * *] Discharge...,"[401.9, 272.4, V45.71, 569.85, 562.12, 455.8, ...",[],icd9cm,[],[],icd9pcs


In [None]:
# Preview the first rows of the real (MDACE) training split for comparison.
real_df.head()

Unnamed: 0,note_id,subject_id,_id,note_type,note_subtype,text,diagnosis_codes,diagnosis_code_spans,diagnosis_code_type,procedure_codes,procedure_code_spans,procedure_code_type
0,46698,99231,151778,Discharge summary,Report,Admission Date: [**2150-1-10**] ...,"[585.6, 038.9, 427.31, 078.5, 242.90, 710.0, 4...","[[[362, 365]], [[436, 441], [10133, 10138]], [...",icd9cm,[],[[[]]],icd9pcs
1,451263,96960,137513,Physician,Intensivist Note,TSICU\n HPI:\n Pt 2 days s/p gastric bypas...,[327.23],"[[[3959, 3981]]]",icd9cm,[],[[[]]],icd9pcs
2,8174,93578,149623,Discharge summary,Report,Admission Date: [**2106-10-29**] ...,"[348.9, 401.9, 272.0, 600.00, 477.8, V15.82, 2...","[[[281, 295], [616, 634]], [[991, 993]], [[995...",icd9cm,[],[[[]]],icd9pcs
3,18156,92287,106961,Discharge summary,Report,Admission Date: [**2173-8-3**] D...,"[V12.55, V12.51, 707.19, 425.4, 402.91, 571.5,...","[[[451, 458]], [[464, 478]], [[675, 699]], [[1...",icd9cm,[],[[[]]],icd9pcs
4,39335,96381,101173,Discharge summary,Report,Admission Date: [**2120-1-11**] ...,"[V10.11, 198.3, 348.5, 250.00, 401.9, 414.01, ...","[[[850, 886]], [[1403, 1427]], [[1418, 1450]],...",icd9cm,[],[[[]]],icd9pcs


In [None]:
from collections import Counter

def build_code_frequency(sets):
    """Count how often each code occurs across a collection of code groups.

    Args:
        sets: Iterable of code groups; each group is fed to ``Counter.update``.

    Returns:
        A ``Counter`` with the accumulated count of every code.
    """
    totals = Counter()
    for code_group in sets:
        totals.update(code_group)
    return totals

# Per-note diagnosis-code collections of the synthetic training split and the
# resulting per-code counts.
augmented_code_sets = synthesize_df['diagnosis_codes'].tolist()
augmented_counter = build_code_frequency(augmented_code_sets)

# Spot-check a few rare codes. `rare_codes` is a set, so slicing
# `list(rare_codes)` can pick a different sample in every kernel session;
# sorting first keeps the printed sample reproducible.
for code in sorted(rare_codes)[:10]:
    print(f"Code {code} has {augmented_counter[code]} augmented sets")

# Map each rare code that occurs in the synthetic data to its synthetic count.
augmented_rare_code_freq = {code: augmented_counter[code] for code in rare_codes if code in augmented_counter}

# Rank by descending synthetic count; ties are broken by the code itself so the
# ranking does not depend on set iteration order.
augmented_rare_code_freq = dict(sorted(augmented_rare_code_freq.items(), key=lambda item: (-item[1], item[0])))

# Ten rare codes with the highest synthetic counts.
for code, freq in list(augmented_rare_code_freq.items())[:10]:
    print(f"Rare code {code} has {freq} augmented sets, original {train_code_freq[code]} sets, validating {val_code_freq[code]} sets")

print()
# Ten rare codes with the lowest synthetic counts.
for code, freq in list(augmented_rare_code_freq.items())[-10:]:
    print(f"Rare code {code} has {freq} augmented sets, original {train_code_freq[code]} sets, validating {val_code_freq[code]} sets")

# Rare codes that never appear in the synthetic data.
print(f"Code that in rare_code but not in augmented_counter: {len([code for code in rare_codes if code not in augmented_counter])}")

# Synthetic notes that contain no rare code at all.
print(f"Code sets in augmented_code_sets where all codes are not in rare_codes: {len([code_set for code_set in augmented_code_sets if all(code not in rare_codes for code in code_set)])}")
print(f"Total augmented sets {len(augmented_code_sets)} ")

Rare code V85.1 has 31 augmented sets, original 59 sets, validating 9 sets
Rare code V14.0 has 31 augmented sets, original 51 sets, validating 6 sets
Rare code 796.3 has 30 augmented sets, original 58 sets, validating 6 sets
Rare code 789.51 has 30 augmented sets, original 39 sets, validating 7 sets
Rare code 710.2 has 27 augmented sets, original 50 sets, validating 9 sets
Rare code 780.65 has 27 augmented sets, original 49 sets, validating 6 sets
Rare code V46.3 has 25 augmented sets, original 68 sets, validating 7 sets
Rare code 455.2 has 25 augmented sets, original 80 sets, validating 6 sets
Rare code 729.81 has 25 augmented sets, original 53 sets, validating 10 sets
Rare code 429.5 has 23 augmented sets, original 69 sets, validating 6 sets

Rare code 296.50 has 2 augmented sets, original 64 sets, validating 6 sets
Rare code V15.52 has 2 augmented sets, original 40 sets, validating 7 sets
Rare code 304.00 has 2 augmented sets, original 90 sets, validating 6 sets
Rare code V64.41 has

### Checking training data

In [None]:
from datasets import load_dataset
from pathlib import Path

# Load the augmented dataset through its loading script in the cloned
# repository. Note that this rebinds `dataset_path`, which earlier pointed at
# the MDACE loading script.
dataset_path = Path("/content/explain-icd/explainable_medical_coding/datasets/augmented_inpatient_icd9_code.py")
augmented = load_dataset(str(dataset_path), trust_remote_code=True)

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [None]:
# Compare training-set sizes: the augmented dataset first, then the dataset
# loaded from the MDACE script (`mimic`).
print(len(augmented["train"]))
print(len(mimic["train"]))

865
48074


In [None]:
	# One-off download of the RoBERTa-base-PM-M3-Voc checkpoint into `models/`.
	# The commands are kept commented out; uncomment them to fetch and unpack the
	# archive again.
	# !wget https://dl.fbaipublicfiles.com/biolm/RoBERTa-base-PM-M3-Voc-hf.tar.gz -P models
	# !tar -xvzf models/RoBERTa-base-PM-M3-Voc-hf.tar.gz -C models
	# !rm models/RoBERTa-base-PM-M3-Voc-hf.tar.gz
	# !mv models/RoBERTa-base-PM-M3-Voc/RoBERTa-base-PM-M3-Voc-hf models/roberta-base-pm-m3-voc-hf
	# !rm -r models/RoBERTa-base-PM-M3-Voc

--2025-07-24 08:11:01--  https://dl.fbaipublicfiles.com/biolm/RoBERTa-base-PM-M3-Voc-hf.tar.gz
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.96, 3.163.189.14, 3.163.189.51, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.96|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 296574522 (283M) [application/gzip]
Saving to: ‘models/RoBERTa-base-PM-M3-Voc-hf.tar.gz’


2025-07-24 08:11:03 (205 MB/s) - ‘models/RoBERTa-base-PM-M3-Voc-hf.tar.gz’ saved [296574522/296574522]

RoBERTa-base-PM-M3-Voc/RoBERTa-base-PM-M3-Voc-hf/
RoBERTa-base-PM-M3-Voc/RoBERTa-base-PM-M3-Voc-hf/vocab.json
RoBERTa-base-PM-M3-Voc/RoBERTa-base-PM-M3-Voc-hf/config.json
RoBERTa-base-PM-M3-Voc/RoBERTa-base-PM-M3-Voc-hf/pytorch_model.bin
RoBERTa-base-PM-M3-Voc/RoBERTa-base-PM-M3-Voc-hf/merges.txt


### Train

In [None]:
# Fine-tune with train_plm.py, configured through key=value overrides:
# batch size 8, the f1_macro callbacks, GPU 0.
# Disabled alternative: the `plm_icd_supervised` experiment, starting from the
# checkpoint under ./models/supervised/ym0o7co8.
# !python train_plm.py experiment=augmented_icd9_code/plm_icd_supervised load_model=./models/supervised/ym0o7co8  dataloader.max_batch_size=8 dataloader.batch_size=8 callbacks=f1_macro gpu=0
# Active run: the `plm_icd` experiment, starting from the checkpoint under
# ./models/unsupervised/vxrn54op.
!python train_plm.py experiment=augmented_icd9_code/plm_icd load_model=./models/unsupervised/vxrn54op  dataloader.max_batch_size=8 dataloader.batch_size=8 callbacks=f1_macro gpu=0

  diet_gradient_scaler = torch.cuda.amp.GradScaler()
  advesarial_noise_gradient_scaler = torch.cuda.amp.GradScaler()
[2025-07-25 05:15:49,856][infotropy.utils.random][INFO] - Set 'numpy', 'random' and 'torch' random seed to 1337
[32m'Device: cuda'[0m
[32m'CUDA_VISIBLE_DEVICES: 0'[0m
[2025-07-25 05:15:50,511][/content/explain-icd/train_plm.py][INFO] - Loading Tokenizer from model_path
  split2code_indices[split_name] = torch.tensor(target_ids)
[2025-07-25 05:15:50,782][/content/explain-icd/train_plm.py][INFO] - {'num_examples': 5815, 'num_train_examples': 865, 'num_val_examples': 1753, 'num_test_examples': 3197, 'average_words_per_example': np.float64(3714.2958259905226), 'average_targets_per_example': np.float64(17.21477018121796), 'num_classes': 8943, 'num_train_classes': 1951, 'num_val_classes': 3044, 'num_test_classes': 3980, 'vocab_size': 50001, 'pad_token_id': 1, 'pad_target_id': -1, 'sos_target_id': None, 'eos_target_id': None}
[2025-07-25 05:15:50,782][/content/explain-icd/

In [None]:
from huggingface_hub import create_repo, upload_folder
from google.colab import userdata
import os

# Read the Hugging Face token from Colab secrets (secret name: HF_TOKEN)
# instead of pasting it into the notebook, where it would be saved with the
# file. If the secret is missing or this notebook has not been granted access,
# fall back to the HF_TOKEN environment variable; a value of None leaves
# credential lookup to huggingface_hub.
try:
    hf_token = userdata.get("HF_TOKEN")
except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
    hf_token = os.environ.get("HF_TOKEN")

# Define repository details
repo_id = "ChanBeDu/Synthesis-PLM-ICD"
repo_type = "model"
folder_to_upload = "/content/explain-icd/models/pwkz99e4"
path_in_repo = "supervised-seed10-filter-notes-lr1e5"
commit_message = "supervised-seed10-filter-notes-lr1e5"

# Create the repository if it doesn't exist (exist_ok=True makes this a no-op
# when it already does).
try:
    create_repo(repo_id, repo_type=repo_type, token=hf_token, exist_ok=True)
    print(f"Repository '{repo_id}' created or already exists.")
except Exception as e:
    # Best effort: report the problem and still attempt the upload below.
    print(f"Error creating repository: {e}")
    # If creating the repo fails, you might need to check your token and permissions

# Upload the folder into the `path_in_repo` subfolder of the repository
try:
    upload_folder(
        repo_id=repo_id,
        folder_path=folder_to_upload,
        path_in_repo=path_in_repo,
        commit_message=commit_message,
        repo_type=repo_type,
        token=hf_token,
    )
    print(f"Folder '{folder_to_upload}' successfully uploaded to '{repo_id}/{path_in_repo}'.")
except Exception as e:
    print(f"Error uploading folder: {e}")
    # If uploading fails, check the folder path and your token/permissions again

Repository 'ChanBeDu/Synthesis-PLM-ICD' created or already exists.
Folder '/content/explain-icd/models/pwkz99e4' successfully uploaded to 'ChanBeDu/Synthesis-PLM-ICD/supervised-seed10-filter-notes-lr1e5'.
