# ICD Tokenize

The data set is a list of ICD labels and their descriptions.  
The goal is to tokenize the descriptions and train a model to predict the ICD labels.

In [1]:
! pip install -U transformers datasets



In [2]:
from datasets import load_dataset
dataset = (load_dataset("eddielin0926/chinese-icd", split="train")
            .train_test_split(train_size=800, test_size=200))
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['year', 'month', 'no', 'death', 'input_code', 'result_code', 'check', 'serial_no', 'catalog', 'inputs', 'results', 'icds', 'encodes'],
        num_rows: 800
    })
    test: Dataset({
        features: ['year', 'month', 'no', 'death', 'input_code', 'result_code', 'check', 'serial_no', 'catalog', 'inputs', 'results', 'icds', 'encodes'],
        num_rows: 200
    })
})


### Configuration
Defining some key variables that will be used later on in the training

In [3]:
MODEL_CHECKPOINT = "bert-base-chinese"
BATCH_SIZE = 128
EPOCHS = 1
MAX_LEN = 80
LEARNING_RATE = 1e-05

## Pre-processing

In [4]:
class_label = dataset['train'].features["encodes"].feature
num_labels = class_label.num_classes
print("number of labels: ", num_labels)
print(class_label)

number of labels:  3568
ClassLabel(names=['L519', 'A523', 'I898', 'A047', 'E144', 'C797', 'C755', 'K831', 'B379', 'S621', 'C672', 'K409', 'D073', 'A179', 'I255', 'K353', 'C029', 'W11', 'D139', 'R944', 'V785', 'T502', 'C921', 'K228', 'S069(TR)', 'K226', 'N501(nTR)', 'D136', 'Q878', 'S610', 'L032', 'T835', 'O699', 'K820', 'V827', 'K256', 'M769', 'C677', 'K920', 'C689', 'T183', 'T327', 'B948', 'T213', 'C160', 'R060', 'T812', 'F104', 'I311', 'I670', 'C112', 'H931', 'K868', 'S158', 'R35', 'L109', 'T115', 'G2009', 'H348', 'P012', 'Q019', 'V878', 'G969', 'H441', 'K099', 'M431', 'X97', 'C773', 'J989', 'F191', 'Q445', 'O691', 'I110', 'K109', 'S121', 'Q069', 'D302', 'K650', 'D447', 'N508', 'V875', 'E702', 'J840', 'T174', 'S360(TR)', 'R798', 'N428', 'C629', 'O690', 'O441', 'M624', 'I519', 'R093', 'I471', 'A78', 'T818', 'X78', 'D133', 'P252', 'N920', 'K627', 'V455', 'B86', 'C509', 'P229', 'V892', 'I350(nRH)', 'H309', 'C944', 'T360', 'S618', 'J9840', 'L409', 'L038', 'B457', 'H431(nTR)', 'R509', 'F9

In [5]:
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, use_fast=True)

def preprocess(ds):
    texts = [" ".join(txt) for txt in ds["inputs"]]
    encoded_data = tokenizer(texts, padding="max_length", truncation=True)

    labels_matrix = []
    # fill numpy array
    for idx, labels in enumerate(ds["encodes"]):
        labels_matrix.append([1.0 if i in labels else 0.0 for i in range(num_labels)])

    encoded_data["labels"] = labels_matrix

    return encoded_data

tokenized_datasets = dataset.map(preprocess, batched=True, remove_columns=dataset['train'].column_names)
print(tokenized_datasets['train'])

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 800
})


In [6]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=num_labels,
    problem_type="multi_label_classification"
)

print(model)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [13]:
import torch

def preprocess_logits_for_metrics(logits: torch.Tensor, labels: torch.Tensor):
    probs = torch.sigmoid(logits)
    preds = torch.where(probs > 0.5, 1, 0)
    return preds

In [10]:
import numpy as np
from transformers import EvalPrediction
from sklearn.metrics import precision_recall_fscore_support

def compute_metrics(eval_pred: EvalPrediction):
    preds, refs = eval_pred

    precision, recall, fbeta_score, support = precision_recall_fscore_support(refs, preds, average="macro", zero_division=0)

    result = {
        'precision': precision,
        'recall': recall,
        'fbeta_score': fbeta_score
    }

    return result

In [7]:
from transformers import TrainingArguments

model_name = MODEL_CHECKPOINT.split("/")[-1]

args = TrainingArguments(
    f"{model_name}-finetuned-icd-{num_labels}",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    save_total_limit=5,
    learning_rate=LEARNING_RATE,
    num_train_epochs=EPOCHS,
    load_best_model_at_end=True
)

In [14]:
from transformers import Trainer

trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    tokenizer=tokenizer,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics
)

In [15]:
trainer.train()

Epoch,Training Loss,Validation Loss,Precision,Recall,Fbeta Score
1,No log,0.313073,7e-06,0.000841,1.4e-05


TrainOutput(global_step=100, training_loss=0.35359668731689453, metrics={'train_runtime': 92.7788, 'train_samples_per_second': 8.623, 'train_steps_per_second': 1.078, 'total_flos': 217228207718400.0, 'train_loss': 0.35359668731689453, 'epoch': 1.0})