

# A Revised Interactive KeyClass Tutorial: Text Classification with Label-Descriptions Only


<hr>

***By:*** Laxmi Vijayan, Aganze Mihigo   

***Based On:*** Classifying Unstructured Clinical Notes via Automatic Weak Supervision
**Authors:** Arnab Dey, Chufan Gao, Mononito Goswami, correspondence to &lt;mgoswami@andrew.cmu.edu&gt;




# 1. Introduction

## 1.1 Problem Background & Motivation

The accuracy of International Classification of Diseases (ICD) codes is paramount in the healthcare sector for two main reasons:

1. These codes are integral to standardized billing practices so providers are reimbursed correctly and efficiently for the services they provide <sup><a href="#references"><b>1</b></a></sup>.
2. ICD codes are crucial in epidemiological studies, where they help in tracking and analyzing the prevalence and incidence of diseases across different populations and geographies.<sup><a href="#references"><b>2</b></a>,</sup><sup><a href="#references"><b>3</b></a>,</sup><sup><a href="#references"><b>4</b></a></sup>

ICD codes are assigned by analyzing Electronic Health Records (EHRs): comprehensive, patient-centered digital versions of a patient's paper chart that track the patient’s trajectory through the healthcare system <sup><a href="#references"><b>1</b></a></sup>. These records include:
1. Detailed medical histories
2. Diagnostic information
3. Procedures
4. Medication information

However, the data captured in EHRs are often unstructured and dense with medical jargon, making the assignment of ICD codes a labor-intensive process. Consequently, healthcare providers frequently depend on trained coders or third-party vendors. This process is **expensive**, **labor-intensive**, and **prone to errors**<sup><a href="#references"><b>1</b></a>,</sup><sup><a href="#references"><b>5</b></a>,</sup><sup><a href="#references"><b>6</b></a></sup> due to the subjective interpretation of text and the ever-evolving nature of medical nomenclature and coding systems.

Given the broad impact of ICD codes within the healthcare sector, there is a pressing demand for **accuracy** and **efficiency**, which has spurred interest in leveraging machine learning (ML) technologies to automate the coding process. These systems can potentially reduce the time and cost associated with manual coding and improve accuracy by consistently applying the same rules to the available data, thereby minimizing human error.

However, conventional ML approaches depend heavily on large volumes of manually labeled data, which are themselves costly and labor-intensive to produce<sup><a href="#references"><b>9</b></a>,</sup><sup><a href="#references"><b>10</b></a></sup>. Two factors compound this problem: frequent updates to the ICD, which are not necessarily compatible with previous versions, and large variability in existing labeled diagnostic data arising from institutional processes or physicians' clinical diagnostic practices <sup><a href="#references"><b>1</b></a>,</sup><sup><a href="#references"><b>7</b></a>,</sup><sup><a href="#references"><b>8</b></a></sup>.

The need for a reliable, generalizable, cost- and time-effective automated classification system is clear.

## 1.2 Classifying Unstructured Clinical Notes via Automatic Weak Supervision

In the paper "Classifying Unstructured Clinical Notes via Automatic Weak Supervision," <sup><a href="#references"><b>10</b></a></sup> the authors present a novel framework for text classification, specifically focusing on assigning International Classification of Diseases (ICD) codes to unstructured clinical notes without the need for manually labeled data. This framework, named KeyClass, utilizes a general weakly supervised learning approach, leveraging the linguistic domain knowledge embedded in pre-trained language models and employing a data programming methodology.

### 1.2.1 Innovations and Effectiveness

KeyClass introduces several innovative approaches:
1. **Interpretable Weak Supervision Sources**: It automatically extracts weak supervision sources, such as keywords and phrases from class-label descriptions, allowing the model to learn from these inputs without the need for a human-labeled training set.
2. **Utilization of Pre-trained Language Models**: By integrating pre-trained language models, KeyClass harnesses extensive linguistic knowledge, facilitating the accurate classification of medical terms and phrases found in clinical notes.
3. **Data Programming Integration**: KeyClass employs data programming to generate probabilistic labels for training data, significantly reducing the reliance on manually labeled datasets.

### 1.2.2 Performance
KeyClass demonstrated strong performance across multiple datasets, particularly the MIMIC-III database, where it was tasked with assigning ICD-9 codes to medical notes. It performed comparably to more traditional supervised learning methods such as FasTag <sup><a href="#references"><b>9</b></a></sup>, showcasing its capability to effectively handle real-world, complex text classification tasks without extensive manual data labeling.

### 1.2.3 Contribution to Research
KeyClass significantly contributes to the field by addressing the high costs and labor-intensive processes involved in manual ICD code assignment. It provides a scalable and efficient solution that could be adopted widely in healthcare settings to enhance the accuracy and efficiency of medical coding practices. The model's ability to perform well without labeled data presents a significant advancement in machine learning applications within the healthcare sector, offering a pathway toward more automated and accessible medical record management systems. This approach not only aligns with the ongoing needs for improved data handling in healthcare but also sets a foundation for future research in automated systems that require minimal human intervention while maintaining high accuracy and reliability.

<a id='methodology'></a>
## 2. Methodology

Instead of relying on previously labeled documents, KeyClass combines the linguistic domain expertise of pre-trained models and easily obtained class descriptions to label data.

In the following notebook, we've built on the tutorial provided by the authors of the paper to create an easy way to familiarize yourself with the KeyClass framework using the same IMDb Dataset.

<a id='classdesc'></a>
### 2.1 Find Class Descriptions

The IMDb dataset is commonly used for sentiment analysis of movie reviews. Each review is classified as either `positive` or `negative`.

To begin classification, we provide general descriptions of the two classes. A class description for *positive* can be `good, amazing, exciting, positive, fun`. Similarly, a class description for *negative* can be `terrible, bad, boring, negative`.

These descriptions can be generated by mining sources such as Wikipedia or taken from more official categorizations, such as the ICD-9's long category descriptions.

Class Descriptions used in this tutorial can be found [here](./config_files/config_imdb.yml)
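
As a rough illustration of how class descriptions could be represented and collected, the sketch below uses a plain dictionary standing in for the parsed configuration. The key names `target_0` and `target_1` are assumptions, chosen only because the notebook later gathers every entry whose key contains `target`; the actual keys in `config_imdb.yml` may differ. The negative description is the one printed by the labeling step later in this notebook, and the positive one is the example given above.

```python
# Hypothetical stand-in for the parsed YAML config; the key names are assumptions.
config = {
    "dataset": "imdb",
    "n_classes": 2,
    "target_0": "negative, hate, expensive, bad, poor, broke, waste, horrible, would not recommend",
    "target_1": "good, amazing, exciting, positive, fun",
}

# Collect the descriptions the way the notebook does: every key containing "target".
label_names = [value for key, value in config.items() if "target" in key]

for class_id, description in enumerate(label_names):
    print(class_id, "->", description)
```

Because the class index is implied by the order of the entries, the descriptions need to be listed in class order.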

<a id='keywords'></a>
### 2.2 Find Relevant Keywords / Encoding the Dataset

KeyClass uses these class descriptions to find keywords and phrases of one to three words that are highly suggestive of each class.
Using pre-trained models such as `paraphrase-mpnet-base-v2` or more specialized language models such as BlueBERT, it creates labeling functions that map each keyword or phrase to the class description it is most closely related to. Only the top-k labeling functions per class are kept, to respect computation and memory constraints.
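
The keyword step can be sketched in a few lines, under loud assumptions. The function names below are illustrative and are not part of the KeyClass code base, and `toy_embed` (character-trigram counts) is only a stand-in for a real sentence encoder so that the example runs without downloading a model. The parameters `min_df`, `ngram_range`, and `topk` mirror the argument names passed to the labeler later in this notebook, but their exact semantics there are assumed.

```python
from collections import Counter
import math


def extract_ngrams(text, ngram_range=(1, 3)):
    """Return the set of word n-grams (as strings) found in one document."""
    tokens = text.lower().split()
    lo, hi = ngram_range
    return {
        " ".join(tokens[i:i + n])
        for n in range(lo, hi + 1)
        for i in range(len(tokens) - n + 1)
    }


def candidate_keywords(docs, min_df=2, ngram_range=(1, 3)):
    """Keep n-grams that occur in at least `min_df` documents."""
    doc_freq = Counter()
    for doc in docs:
        doc_freq.update(extract_ngrams(doc, ngram_range))
    return sorted(k for k, count in doc_freq.items() if count >= min_df)


def toy_embed(text):
    """Stand-in for a sentence encoder: character-trigram counts."""
    padded = f" {text.lower()} "
    return Counter(padded[i:i + 3] for i in range(len(padded) - 2))


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[k] * b[k] for k in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def assign_keywords(keywords, class_descriptions, topk=2):
    """Map each keyword to its most similar class, then keep the top-k per class."""
    desc_vecs = [toy_embed(d) for d in class_descriptions]
    per_class = {c: [] for c in range(len(class_descriptions))}
    for kw in keywords:
        sims = [cosine(toy_embed(kw), d) for d in desc_vecs]
        best = max(range(len(sims)), key=sims.__getitem__)
        per_class[best].append((sims[best], kw))
    return {c: [kw for _, kw in sorted(items, reverse=True)[:topk]]
            for c, items in per_class.items()}


docs = [
    "a terrible boring film",
    "terrible acting and a boring plot",
    "an amazing fun ride",
    "amazing cast and fun story",
]
descriptions = ["terrible, bad, boring, negative", "good, amazing, exciting, positive, fun"]
keywords = candidate_keywords(docs, min_df=2)
print(assign_keywords(keywords, descriptions, topk=2))
```

Generic words such as `a` and `and` also pass the document-frequency filter here; they score lowest against both descriptions, which is why a per-class top-k cut-off removes them in this toy example.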

### 2.3 Setting Up the Environment

Run the following cells to install the required packages. If the commands are commented out in your copy, uncomment them first.

In [1]:
!pip install pyyaml



In [2]:
!pip install sentence-transformers

Collecting sentence-transformers
  Downloading sentence_transformers-2.7.0-py3-none-any.whl (171 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.11.0->sentence-transform

In [3]:
!pip install snorkel transformers sentence-transformers cleantext pyhealth gdown

Collecting snorkel
  Downloading snorkel-0.9.9-py3-none-any.whl (103 kB)
Collecting cleantext
  Downloading cleantext-1.1.4-py3-none-any.whl (4.9 kB)
Collecting pyhealth
  Downloading pyhealth-1.1.6-py2.py3-none-any.whl (311 kB)
Collecting munkres>=1.0.6 (from snorkel)
  Downloading munkres-1.1.4-py2.py3-none-any.whl (7.0 kB)
Collecting rdkit>=2022.03.4 (from pyhealth)
  Downloading rdkit-2023.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.9 MB)
Collecting pandas>=1.0.0 (from snorkel)
  Downloading pandas-1.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.1 MB)

### 2.4 Mounting Google Drive

In [12]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [13]:
import sys
base_path = '/content/drive/MyDrive/KeyClass/'
sys.path.append(base_path + 'keyclass/')
sys.path.append(base_path + 'scripts/')

In [14]:
import argparse
import pandas, plotly, matplotlib, seaborn
import label_data, encode_datasets, train_downstream_model
import torch
import pickle
import numpy as np
import os
from os.path import join, exists
from datetime import datetime
import utils
import models
import create_lfs
import train_classifier

In [15]:
# Input arguments
config_file_path = base_path + 'config_files/config_imdb.yml' # Path to the configuration file (base_path already ends with '/')
random_seed = 0 # Random seed for experiments

In [5]:
args = utils.Parser(config_file_path=config_file_path).parse()

if args['use_custom_encoder']:
    model = models.CustomEncoder(pretrained_model_name_or_path=args['base_encoder'],
        device='cuda' if torch.cuda.is_available() else 'cpu')
else:
    model = models.Encoder(model_name=args['base_encoder'],
        device='cuda' if torch.cuda.is_available() else 'cpu')

for split in ['train', 'test']:
    sentences = utils.fetch_data(dataset=args['dataset'], split=split, path=args['data_path'])
    embeddings = model.encode(sentences=sentences, batch_size=args['end_model_batch_size'],
                                show_progress_bar=args['show_progress_bar'],
                                normalize_embeddings=args['normalize_embeddings'])
    with open(join(args['data_path'], args['dataset'], f'{split}_embeddings.pkl'), 'wb') as f:
        pickle.dump(embeddings, f)



<a id='label'></a>
### 2.5 Probabilistically Labeling the Data

*KeyClass* builds a matrix recording how each labeling function votes on each training document. It then uses the open-source label model implementation from the Snorkel Python library to turn the agreements and disagreements among these votes into probabilistic labels: each label comes with a measure of confidence rather than being assigned outright.

This approach helps in handling complex or noisy data where simple labeling might be difficult or unreliable.
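
To make the idea of a vote matrix concrete, here is a minimal sketch. It is not Snorkel's label model, which additionally estimates how reliable each labeling function is; smoothed vote counting is swapped in here purely for illustration. The names `build_vote_matrix`, `soft_labels`, and `ABSTAIN`, as well as the substring matching rule, are assumptions made for this example.

```python
ABSTAIN = -1  # a labeling function abstains when its keyword is absent


def build_vote_matrix(docs, keyword_to_class):
    """One row per document, one column per keyword labeling function."""
    keywords = sorted(keyword_to_class)
    return [
        [keyword_to_class[kw] if kw in doc.lower() else ABSTAIN for kw in keywords]
        for doc in docs
    ]


def soft_labels(vote_matrix, n_classes, smoothing=1.0):
    """Turn votes into per-class probabilities via smoothed vote counts."""
    probs = []
    for row in vote_matrix:
        counts = [smoothing] * n_classes
        for vote in row:
            if vote != ABSTAIN:
                counts[vote] += 1
        total = sum(counts)
        probs.append([c / total for c in counts])
    return probs


docs = ["A terrible, boring film", "Amazing and fun", "Nothing to say"]
keyword_to_class = {"terrible": 0, "boring": 0, "amazing": 1, "fun": 1}

votes = build_vote_matrix(docs, keyword_to_class)
for doc, p in zip(docs, soft_labels(votes, n_classes=2)):
    print(f"{doc!r}: {p}")
```

A document that triggers no keyword ends up with a uniform distribution, which is one reason a later step may prefer to train only on confidently labeled documents. Plain substring matching is crude (`fun` would also match `funeral`); a tokenized match would be more careful.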



In [6]:
# Load training data
train_text = utils.fetch_data(dataset=args['dataset'], path=args['data_path'], split='train')

training_labels_present = False
if exists(join(args['data_path'], args['dataset'], 'train_labels.txt')):
    with open(join(args['data_path'], args['dataset'], 'train_labels.txt'), 'r') as f:
        y_train = f.readlines()
    y_train = np.array([int(i.replace('\n','')) for i in y_train])
    training_labels_present = True
else:
    y_train = None
    training_labels_present = False
    print('No training labels found!')

with open(join(args['data_path'], args['dataset'], 'train_embeddings.pkl'), 'rb') as f:
    X_train = pickle.load(f)

# Print dataset statistics
print(f"Getting labels for the {args['dataset']} data...")
print(f'Size of the data: {len(train_text)}')
if training_labels_present:
    print('Class distribution', np.unique(y_train, return_counts=True))

# Load label names/descriptions
label_names = []
for a in args:
    if 'target' in a: label_names.append(args[a])

# Creating labeling functions
labeler = create_lfs.CreateLabellingFunctions(base_encoder=args['base_encoder'],
                                            device=torch.device(args['device']),
                                            label_model=args['label_model'])
proba_preds = labeler.get_labels(text_corpus=train_text, label_names=label_names, min_df=args['min_df'],
                                ngram_range=args['ngram_range'], topk=args['topk'], y_train=y_train,
                                label_model_lr=args['label_model_lr'], label_model_n_epochs=args['label_model_n_epochs'],
                                verbose=True, n_classes=args['n_classes'])

y_train_pred = np.argmax(proba_preds, axis=1)

# Save the predictions
if not os.path.exists(args['preds_path']): os.makedirs(args['preds_path'])
with open(join(args['preds_path'], f"{args['label_model']}_proba_preds.pkl"), 'wb') as f:
    pickle.dump(proba_preds, f)

# Print statistics
print('Label Model Predictions: Unique value and counts', np.unique(y_train_pred, return_counts=True))
if training_labels_present:
    print('Label Model Training Accuracy', np.mean(y_train_pred==y_train))

    # Log the metrics
    training_metrics_with_gt = utils.compute_metrics(y_preds=y_train_pred, y_true=y_train, average=args['average'])
    utils.log(metrics=training_metrics_with_gt, filename='label_model_with_ground_truth',
        results_dir=args['results_path'], split='train')

Getting labels for the imdb data...
Size of the data: 25000
Class distribution (array([0, 1]), array([12500, 12500]))
Found assigned category counts [6789 9578]
labeler.vocabulary:
 16367
labeler.word_indicator_matrix.shape (25000, 600)
Len keywords 600
assigned_category: Unique and Counts (array([0, 1]), array([300, 300]))
negative, hate, expensive, bad, poor, broke, waste, horrible, would not recommend ['abominable' 'abomination' 'absolute worst' 'absolutely awful'
 'absolutely terrible' 'abuse' 'abused' 'abusive' 'abysmal'
 'acting horrible' 'acting poor' 'acting terrible' 'actors bad'
 'actually bad' 'also bad' 'among worst' 'annoyance' 'annoying' 'appalled'
 'appalling' 'atrocious' 'awful' 'awfully' 'awfulness' 'bad' 'bad actor'
 'bad actors' 'bad actually' 'bad almost' 'bad bad' 'bad could'
 'bad either' 'bad enough' 'bad even' 'bad film' 'bad films' 'bad get'
 'bad horror' 'bad idea' 'bad like' 'bad made' 'bad makes' 'bad many'
 'bad movie' 'bad movies' 'bad music' 'bad one' 'ba

100%|██████████| 100/100 [00:03<00:00, 33.23epoch/s]


Label Model Predictions: Unique value and counts (array([0, 1]), array([ 8914, 16086]))
Label Model Training Accuracy 0.70016
Saving results in /content/drive/MyDrive/KeyClass/results/imdb/train_label_model_with_ground_truth_07-May-2024-18_13_45.txt...


<a id='exp_training'></a>
## 3. Experimentation: Training

<a id='downstream'></a>
### 3.1 Training the Downstream Model

Now we have a probabilistically labeled training dataset that can be used to train our downstream classifier. KeyClass uses the top-*k* documents with the most confident label estimates to train the classifier. The model is saved under `./models/{dataset_name}` as `end_model_` followed by the date.
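
The selection of confidently labeled documents can be sketched as follows. This is an assumption about how such a mask might be built, not the code in `train_downstream_model.load_data`; the function name `top_k_confident_mask` is hypothetical. For each class, it keeps the `k` documents predicted as that class with the highest probability.

```python
def top_k_confident_mask(proba_preds, k):
    """Boolean mask keeping, per class, the k most confident documents."""
    n_classes = len(proba_preds[0])
    # Hard prediction for each document (index of the largest probability)
    preds = [max(range(n_classes), key=p.__getitem__) for p in proba_preds]
    keep = [False] * len(proba_preds)
    for c in range(n_classes):
        members = [i for i, y in enumerate(preds) if y == c]
        members.sort(key=lambda i: proba_preds[i][c], reverse=True)
        for i in members[:k]:
            keep[i] = True
    return keep


proba_preds = [
    [0.90, 0.10],
    [0.60, 0.40],
    [0.80, 0.20],
    [0.30, 0.70],
    [0.05, 0.95],
    [0.45, 0.55],
]
mask = top_k_confident_mask(proba_preds, k=2)
print(mask)
```

Selecting the same number of documents per class yields a training subset that is balanced with respect to the label model's predictions, even when the label model's overall predictions are skewed toward one class.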

In [7]:
args = utils.Parser(config_file_path=config_file_path).parse()

# Set random seeds
random_seed = random_seed
torch.manual_seed(random_seed)
np.random.seed(random_seed)

X_train_embed_masked, y_train_lm_masked, y_train_masked, \
	X_test_embed, y_test, training_labels_present, \
	sample_weights_masked, proba_preds_masked = train_downstream_model.load_data(args)

# Train a downstream classifier

if args['use_custom_encoder']:
	encoder = models.CustomEncoder(pretrained_model_name_or_path=args['base_encoder'], device=args['device'])
else:
	encoder = models.Encoder(model_name=args['base_encoder'], device=args['device'])

classifier = models.FeedForwardFlexible(encoder_model=encoder,
										h_sizes=args['h_sizes'],
										activation=eval(args['activation']),
										device=torch.device(args['device']))
print('\n===== Training the downstream classifier =====\n')
model = train_classifier.train(model=classifier,
							device=torch.device(args['device']),
							X_train=X_train_embed_masked,
							y_train=y_train_lm_masked,
							sample_weights=sample_weights_masked if args['use_noise_aware_loss'] else None,
							epochs=args['end_model_epochs'],
							batch_size=args['end_model_batch_size'],
							criterion=eval(args['criterion']),
							raw_text=False,
							lr=eval(args['end_model_lr']),
							weight_decay=eval(args['end_model_weight_decay']),
							patience=args['end_model_patience'])

# # Saving the model
# if not os.path.exists(args['preds_path']): os.makedirs(args['preds_path'])
# with open(join(args['preds_path'], f"{args['label_model']}_proba_preds.pkl"), 'wb') as f:
#     pickle.dump(proba_preds, f)


# end_model_preds_train = model.predict_proba(torch.from_numpy(X_train_embed_masked), batch_size=512, raw_text=False)
# end_model_preds_test = model.predict_proba(torch.from_numpy(X_test_embed), batch_size=512, raw_text=False)


if not os.path.exists(args['model_path']): os.makedirs(args['model_path'])
current_time = datetime.now()
model_name = f'end_model_{current_time.strftime("%d-%b-%Y")}.pth'
print(f'Saving model {model_name}...')
with open(join(args['model_path'], model_name), 'wb') as f:
		torch.save(model, f)

end_model_preds_train = model.predict_proba(torch.from_numpy(X_train_embed_masked),
                                            batch_size=512, raw_text=False)
end_model_preds_test = model.predict_proba(torch.from_numpy(X_test_embed),
                                           batch_size=512, raw_text=False)

# Save the predictions
with open(join(args['preds_path'], 'end_model_preds_train.pkl'),
					'wb') as f:
		pickle.dump(end_model_preds_train, f)
with open(join(args['preds_path'], 'end_model_preds_test.pkl'), 'wb') as f:
		pickle.dump(end_model_preds_test, f)

# Print statistics
if training_labels_present:
		training_metrics_with_gt = utils.compute_metrics(
				y_preds=np.argmax(end_model_preds_train, axis=1),
				y_true=y_train_masked,
				average=args['average'])
		utils.log(metrics=training_metrics_with_gt,
							filename='end_model_with_ground_truth',
							results_dir=args['results_path'],
							split='train')

training_metrics_with_lm = utils.compute_metrics(
		y_preds=np.argmax(end_model_preds_train, axis=1),
		y_true=y_train_lm_masked,
		average=args['average'])
utils.log(metrics=training_metrics_with_lm,
					filename='end_model_with_label_model',
					results_dir=args['results_path'],
					split='train')

testing_metrics = utils.compute_metrics_bootstrap(
		y_preds=np.argmax(end_model_preds_test, axis=1),
		y_true=y_test,
		average=args['average'],
		n_bootstrap=args['n_bootstrap'],
		n_jobs=args['n_jobs'])
utils.log(metrics=testing_metrics,
					filename='end_model_with_ground_truth',
					results_dir=args['results_path'],
					split='test')

Confidence of least confident data point of class 0: 0.9118952135029044
Confidence of least confident data point of class 1: 0.9999157389438634

==== Data statistics ====
Size of training data: (25000, 768), testing data: (25000, 768)
Size of testing labels: (25000,)
Size of training labels: (25000,)
Training class distribution (ground truth): [0.5 0.5]
Training class distribution (label model predictions): [0.35656 0.64344]

KeyClass only trains on the most confidently labeled data points! Applying mask...

==== Data statistics (after applying mask) ====
Size of training data: (7000, 768)
Size of training labels: (7000,)
Training class distribution (ground truth): [0.55057143 0.44942857]
Training class distribution (label model predictions): [0.5 0.5]

===== Training the downstream classifier =====



Epoch 16:  80%|████████  | 16/20 [00:03<00:00,  4.54batch/s, best_loss=0.543, running_loss=0.549, tolerance_count=2]


Stopping early...
Saving model end_model_07-May-2024.pth...
Saving results in /content/drive/MyDrive/KeyClass/results/imdb/train_end_model_with_ground_truth_07-May-2024-18_14_16.txt...
Saving results in /content/drive/MyDrive/KeyClass/results/imdb/train_end_model_with_label_model_07-May-2024-18_14_16.txt...


  pid = os.fork()
[Parallel(n_jobs=10)]: Using backend LokyBackend with 10 concurrent workers.
[Parallel(n_jobs=10)]: Done  30 tasks      | elapsed:    2.5s


Saving results in /content/drive/MyDrive/KeyClass/results/imdb/test_end_model_with_ground_truth_07-May-2024-18_14_19.txt...


[Parallel(n_jobs=10)]: Done 100 out of 100 | elapsed:    3.0s finished


<a id='self'></a>
### 3.2 Self-Training the Model
Lastly, KeyClass self-trains on the entire training dataset to further refine the end model classifier. The self-trained model is saved to the same location as `end_model_self_trained_` followed by the date.
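
One common recipe for self-training, assumed here rather than taken from the KeyClass source, sharpens the model's current predictions into target distributions, trains toward those targets, and stops once predictions stop changing between updates. The sketch below shows only the target computation and the agreement check; `sharpen_targets` and `agreement` are illustrative names. Reading `q_update_interval` as the refresh frequency of the targets and `self_train_thresh` as the stopping threshold in the cell below is likewise an assumption.

```python
def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=values.__getitem__)


def sharpen_targets(probs):
    """Square each probability, divide by the class's total mass, renormalize.

    Assumes every class has non-zero total mass, as softmax outputs do.
    """
    n_classes = len(probs[0])
    class_mass = [sum(p[c] for p in probs) for c in range(n_classes)]
    targets = []
    for p in probs:
        weights = [p[c] ** 2 / class_mass[c] for c in range(n_classes)]
        total = sum(weights)
        targets.append([w / total for w in weights])
    return targets


def agreement(old_probs, new_probs):
    """Fraction of documents whose predicted class is unchanged."""
    same = sum(argmax(a) == argmax(b) for a, b in zip(old_probs, new_probs))
    return same / len(old_probs)


probs = [[0.6, 0.4], [0.7, 0.3], [0.2, 0.8]]
targets = sharpen_targets(probs)
print(targets)
print(agreement(probs, targets))
```

Sharpening pushes each distribution toward its currently favored class, so the model is gradually encouraged to commit to its own confident predictions on documents that the label model never labeled confidently.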

In [8]:
# Fetching the raw text data for self-training
X_train_text = utils.fetch_data(dataset=args['dataset'], path=args['data_path'], split='train')
X_test_text = utils.fetch_data(dataset=args['dataset'], path=args['data_path'], split='test')

model = train_classifier.self_train(model=model,
									X_train=X_train_text,
									X_val=X_test_text,
									y_val=y_test,
									device=torch.device(args['device']),
									lr=eval(args['self_train_lr']),
									weight_decay=eval(args['self_train_weight_decay']),
									patience=args['self_train_patience'],
									batch_size=args['self_train_batch_size'],
									q_update_interval=args['q_update_interval'],
									self_train_thresh=eval(args['self_train_thresh']),
									print_eval=True)


end_model_preds_test = model.predict_proba(X_test_text, batch_size=args['self_train_batch_size'], raw_text=True)


# Print statistics
testing_metrics = utils.compute_metrics_bootstrap(y_preds=np.argmax(end_model_preds_test, axis=1),
													y_true=y_test,
													average=args['average'],
													n_bootstrap=args['n_bootstrap'],
													n_jobs=args['n_jobs'])


current_time = datetime.now()
model_name = f'end_model_self_trained_{current_time.strftime("%d %b %Y")}.pth'
print(f'Saving model {model_name}...')
with open(join(args['model_path'], model_name), 'wb') as f:
		torch.save(model, f)

end_model_preds_test = model.predict_proba(
		X_test_text, batch_size=args['self_train_batch_size'], raw_text=True)

# Save the predictions
with open(
				join(args['preds_path'], 'end_model_self_trained_preds_test.pkl'),
				'wb') as f:
		pickle.dump(end_model_preds_test, f)

# Print statistics
testing_metrics = utils.compute_metrics_bootstrap(
		y_preds=np.argmax(end_model_preds_test, axis=1),
		y_true=y_test,
		average=args['average'],
		n_bootstrap=args['n_bootstrap'],
		n_jobs=args['n_jobs'])
utils.log(metrics=testing_metrics,
					filename='end_model_with_ground_truth_self_trained',
					results_dir=args['results_path'],
					split='test')

Epoch 8:  13%|█▎        | 8/62 [42:02<4:43:49, 315.35s/batch, self_train_agreement=1, tolerance_count=2, validation_accuracy=0.915]
[Parallel(n_jobs=10)]: Using backend LokyBackend with 10 concurrent workers.
[Parallel(n_jobs=10)]: Done  30 tasks      | elapsed:    2.4s
[Parallel(n_jobs=10)]: Done  81 out of 100 | elapsed:    2.5s remaining:    0.6s
[Parallel(n_jobs=10)]: Done 100 out of 100 | elapsed:    2.8s finished


Saving model end_model_self_trained_07 May 2024.pth...


[Parallel(n_jobs=10)]: Using backend LokyBackend with 10 concurrent workers.
[Parallel(n_jobs=10)]: Done  30 tasks      | elapsed:    2.4s
[Parallel(n_jobs=10)]: Done  81 out of 100 | elapsed:    2.5s remaining:    0.6s


Saving results in /content/drive/MyDrive/KeyClass/results/imdb/test_end_model_with_ground_truth_self_trained_07-May-2024-19_11_23.txt...


[Parallel(n_jobs=10)]: Done 100 out of 100 | elapsed:    2.9s finished


<a id='exp_testing'></a>
## 4. Experimentation: Testing

In [9]:
end_model_path='/content/drive/MyDrive/KeyClass/models/imdb/end_model_07-May-2024.pth'
end_model_self_trained_path='/content/drive/MyDrive/KeyClass/models/imdb/end_model_self_trained_07 May 2024.pth'

args = utils.Parser(config_file_path=config_file_path).parse()

# Set random seeds
random_seed = random_seed
torch.manual_seed(random_seed)
np.random.seed(random_seed)

X_train_embed_masked, y_train_lm_masked, y_train_masked, \
	X_test_embed, y_test, training_labels_present, \
	sample_weights_masked, proba_preds_masked = train_downstream_model.load_data(args)

model = torch.load(end_model_path)

end_model_preds_train = model.predict_proba(torch.from_numpy(X_train_embed_masked), batch_size=512, raw_text=False)
end_model_preds_test = model.predict_proba(torch.from_numpy(X_test_embed), batch_size=512, raw_text=False)

# Print statistics
if training_labels_present:
	training_metrics_with_gt = utils.compute_metrics(y_preds=np.argmax(end_model_preds_train, axis=1),
														y_true=y_train_masked,
														average=args['average'])
	print('training_metrics_with_gt', training_metrics_with_gt)

training_metrics_with_lm = utils.compute_metrics(y_preds=np.argmax(end_model_preds_train, axis=1),
													y_true=y_train_lm_masked,
													average=args['average'])
print('training_metrics_with_lm', training_metrics_with_lm)

testing_metrics = utils.compute_metrics_bootstrap(y_preds=np.argmax(end_model_preds_test, axis=1),
													y_true=y_test,
													average=args['average'],
													n_bootstrap=args['n_bootstrap'],
													n_jobs=args['n_jobs'])
print('testing_metrics', testing_metrics)


print('\n===== Self-training the downstream classifier =====\n')

# Fetching the raw text data for self-training
X_train_text = utils.fetch_data(dataset=args['dataset'], path=args['data_path'], split='train')
X_test_text = utils.fetch_data(dataset=args['dataset'], path=args['data_path'], split='test')

model = torch.load(end_model_self_trained_path)

end_model_preds_test = model.predict_proba(X_test_text, batch_size=args['self_train_batch_size'], raw_text=True)


# Print statistics
testing_metrics = utils.compute_metrics_bootstrap(y_preds=np.argmax(end_model_preds_test, axis=1),
													y_true=y_test,
													average=args['average'],
													n_bootstrap=args['n_bootstrap'],
													n_jobs=args['n_jobs'])
print('testing_metrics after self train', testing_metrics)

utils.log(metrics=testing_metrics,
					filename='end_model_with_ground_truth_self_trained',
					results_dir=args['results_path'],
					split='test')


Confidence of least confident data point of class 0: 0.9118952135029044
Confidence of least confident data point of class 1: 0.9999157389438634

==== Data statistics ====
Size of training data: (25000, 768), testing data: (25000, 768)
Size of testing labels: (25000,)
Size of training labels: (25000,)
Training class distribution (ground truth): [0.5 0.5]
Training class distribution (label model predictions): [0.35656 0.64344]

KeyClass only trains on the most confidently labeled data points! Applying mask...

==== Data statistics (after applying mask) ====
Size of training data: (7000, 768)
Size of training labels: (7000,)
Training class distribution (ground truth): [0.55057143 0.44942857]
Training class distribution (label model predictions): [0.5 0.5]
training_metrics_with_gt [0.9181428571428571, 0.9226160465912984, 0.9181428571428571]
training_metrics_with_lm [0.9218571428571428, 0.9218786672789429, 0.9218571428571428]


[Parallel(n_jobs=10)]: Using backend LokyBackend with 10 concurrent workers.
[Parallel(n_jobs=10)]: Done  40 tasks      | elapsed:    0.1s
[Parallel(n_jobs=10)]: Done 100 out of 100 | elapsed:    0.2s finished


testing_metrics [[0.8474412  0.0019673 ]
 [0.86153452 0.001771  ]
 [0.8474412  0.0019673 ]]

===== Self-training the downstream classifier =====



[Parallel(n_jobs=10)]: Using backend LokyBackend with 10 concurrent workers.
  pid = os.fork()
[Parallel(n_jobs=10)]: Done  30 tasks      | elapsed:    2.5s
[Parallel(n_jobs=10)]: Done  81 out of 100 | elapsed:    2.7s remaining:    0.6s


testing_metrics after self train [[0.9151308  0.00171572]
 [0.91525774 0.00170526]
 [0.9151308  0.00171572]]
Saving results in /content/drive/MyDrive/KeyClass/results/imdb/test_end_model_with_ground_truth_self_trained_07-May-2024-19_20_29.txt...


[Parallel(n_jobs=10)]: Done 100 out of 100 | elapsed:    2.8s finished
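
The test metrics printed above come in pairs, which the plotting code later reads as `(mean, std)` values produced by bootstrap resampling. A minimal sketch of that idea for accuracy follows. It is not the implementation of `utils.compute_metrics_bootstrap`; the function name `bootstrap_accuracy` and its `seed` parameter are assumptions for this example.

```python
import random
import statistics


def bootstrap_accuracy(y_true, y_pred, n_bootstrap=100, seed=0):
    """Mean and standard deviation of accuracy over bootstrap resamples."""
    rng = random.Random(seed)
    n = len(y_true)
    scores = []
    for _ in range(n_bootstrap):
        # Resample document indices with replacement
        idx = [rng.randrange(n) for _ in range(n)]
        scores.append(sum(y_true[i] == y_pred[i] for i in idx) / n)
    return statistics.mean(scores), statistics.pstdev(scores)


y_true = [0, 1] * 50
y_pred = list(y_true)
for i in range(10):  # flip ten predictions so accuracy is 0.9 on the full set
    y_pred[i] = 1 - y_pred[i]

mean_acc, std_acc = bootstrap_accuracy(y_true, y_pred, n_bootstrap=200)
print(f"accuracy: {mean_acc:.3f} +/- {std_acc:.3f}")
```

The standard deviation gives a sense of how much the reported metric would vary if the test set had been drawn differently, which is useful when comparing the end model before and after self-training.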


## 5. Plotting Results

Examine the Accuracy, Precision, and Recall values of the different models.

In [21]:
import json
import plotly.graph_objects as go
import os
args = utils.Parser(config_file_path=config_file_path).parse()

# Results files written by the earlier cells (the timestamps are specific to this run)
results_dir = os.path.join(base_path, args['results_path'])
files = {
    'Train Label Model with Ground Truth': os.path.join(results_dir, 'train_label_model_with_ground_truth_07-May-2024-18_13_45.txt'),
    'Train End Model with Label Model': os.path.join(results_dir, 'train_end_model_with_label_model_07-May-2024-18_14_16.txt'),
    'Train End Model with Ground Truth': os.path.join(results_dir, 'train_end_model_with_ground_truth_07-May-2024-18_14_16.txt'),
    'Test End Model with Ground Truth': os.path.join(results_dir, 'test_end_model_with_ground_truth_07-May-2024-18_14_19.txt'),
    'Test Self-Trained End Model with Ground Truth': os.path.join(results_dir, 'test_end_model_with_ground_truth_self_trained_07-May-2024-19_20_29.txt'),
}

# Initialize lists to store the data
labels = []
accuracy = []
precision = []
recall = []
accuracy_err = []
precision_err = []
recall_err = []

# Read each file and extract the metrics
for label, file_path in files.items():
    with open(file_path, 'r') as f:
        data = json.load(f)
        labels.append(label)

        if 'mean' in str(data):
            # Test models with mean and std
            accuracy.append(data['Accuracy (mean, std)'][0])
            precision.append(data['Precision (mean, std)'][0])
            recall.append(data['Recall (mean, std)'][0])
            accuracy_err.append(data['Accuracy (mean, std)'][1])
            precision_err.append(data['Precision (mean, std)'][1])
            recall_err.append(data['Recall (mean, std)'][1])
        else:
            # Train models without std
            accuracy.append(data['Accuracy'])
            precision.append(data['Precision'])
            recall.append(data['Recall'])
            accuracy_err.append(0)
            precision_err.append(0)
            recall_err.append(0)

# Creating the plot with Plotly
fig = go.Figure()

# Adding Accuracy, Precision, and Recall traces
fig.add_trace(go.Bar(name='Accuracy', x=labels, y=accuracy, error_y=dict(type='data', array=accuracy_err)))
fig.add_trace(go.Bar(name='Precision', x=labels, y=precision, error_y=dict(type='data', array=precision_err)))
fig.add_trace(go.Bar(name='Recall', x=labels, y=recall, error_y=dict(type='data', array=recall_err)))

# Update the layout
fig.update_layout(
    barmode='group',
    title='Performance Metrics Across Different Models',
    xaxis_title='IMDb KeyClass Models',
    yaxis_title='Metric Value',
    legend_title='Metric'
)

# Show the plot
fig.show()
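
The loop above handles two layouts of results file: one with plain `Accuracy`, `Precision`, and `Recall` keys, and one where each key carries a `(mean, std)` suffix and maps to a two-element list. The sketch below isolates that branching in a small helper. The helper name `extract_metric` and the numbers in the two example dictionaries are hypothetical and chosen only for illustration; they are not results produced by KeyClass.

```python
import json


def extract_metric(data, name):
    """Return (value, std) for one metric from a parsed results file.

    Assumes the two layouts read by the plotting cell: either a plain
    key such as "Accuracy", or a key such as "Accuracy (mean, std)"
    that maps to a [mean, std] pair.
    """
    key_with_std = f"{name} (mean, std)"
    if key_with_std in data:
        mean, std = data[key_with_std]
        return float(mean), float(std)
    # No spread is recorded in this layout, so report zero error.
    return float(data[name]), 0.0


# Hypothetical file contents (illustrative values, not real results)
train_example = json.loads(
    '{"Accuracy": 0.5, "Precision": 0.25, "Recall": 0.75}'
)
test_example = json.loads(
    '{"Accuracy (mean, std)": [0.5, 0.125],'
    ' "Precision (mean, std)": [0.25, 0.0625],'
    ' "Recall (mean, std)": [0.75, 0.125]}'
)

for metric in ("Accuracy", "Precision", "Recall"):
    print(metric, extract_metric(train_example, metric), extract_metric(test_example, metric))
```

Checking for the exact key, rather than searching the whole file for the word "mean", keeps the branch from being triggered by unrelated text elsewhere in the file.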

# 5. References

[[1](https://pubmed.ncbi.nlm.nih.gov/16178999/)] O’Malley KJ, Cook KF, Price MD, Wildes KR, Hurdle JF, Ashton CM. Measuring diagnoses: ICD code accuracy. Health Serv Res. 2005 Oct;40(5 Pt 2):1620-39. doi: 10.1111/j.1475-6773.2005.00444.x. PMID: 16178999; PMCID: PMC1361216.

[[2](https://pubmed.ncbi.nlm.nih.gov/12711737/)] Calle EE, Rodriguez C, Walker-Thurmond K, Thun MJ. "Overweight, Obesity, and Mortality from Cancer in a Prospectively Studied Cohort of U.S. Adults." New England Journal of Medicine. 2003;348(17):1625–38.

[[3](https://onlinelibrary.wiley.com/doi/abs/10.1111/j.1475-6773.2005.00444.x)] Charbonneau A, Rosen AK, Ash AS, Owen RR, Kader B, Spiro A, Hankin C, Herz LR, Pugh MJV, Kazis L, Miller DR, Berlowitz DR. "Measuring the Quality of Depression Care in a Large Integrated Health System." Medical Care. 2003;41:669–80.

[[4](https://jamanetwork.com/journals/jama/fullarticle/195992)] Studdert DM, Gresenz CR. "Enrollee Appeals of Preservice Coverage Denials at 2 Health Maintenance Organizations." Journal of the American Medical Association. 2003;289(7):864–70.

[[5](https://n.neurology.org/content/49/3/660.short)] Curtis Benesch, DM Witter, AL Wilder, PW Duncan, GP Samsa, and DB Matchar. Inaccuracy of the International Classification of Diseases (ICD-9-CM) in identifying the diagnosis of ischemic cerebrovascular disease. Neurology, 49(3):660–664, 1997.

[[6](https://doi.org/10.1186/s12911-021-01531-9)] Wabe N, Li L, Lindeman R, et al. Evaluation of the accuracy of diagnostic coding for influenza compared to laboratory results: the availability of test results before hospital discharge facilitates improved coding accuracy. BMC Med Inform Decis Mak 21, 168 (2021).

[[7](https://bmcmedinformdecismak.biomedcentral.com/articles/10.1186/s12911-024-02449-8)] Guo LL, Morse KE, Aftandilian C, Steinberg E, Fries J, Posada J, Fleming SL, Lemmon J, Jessa K, Shah N, Sung L. Characterizing the limitations of using diagnosis codes in the context of machine learning for healthcare. BMC Med Inform Decis Mak. 2024 Feb 14;24(1):51. doi: 10.1186/s12911-024-02449-8. PMID: 38355486; PMCID: PMC10868117.

[[8](https://pubmed.ncbi.nlm.nih.gov/28595574/)] Burles K, Innes G, Senior K, Lang E, McRae A. Limitations of pulmonary embolism ICD-10 codes in emergency department administrative data: let the buyer beware. BMC Med Res Methodol. 2017;17(1):89. doi: 10.1186/s12874-017-0361-1.

[[9](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0234647)] Venkataraman GR, Pineda AL, Bear Don't Walk Iv OJ, Zehnder AM, Ayyar S, Page RL, Bustamante CD, Rivas MA. FasTag: Automatic text classification of unstructured medical narratives. PLoS One. 2020 Jun 22;15(6):e0234647. doi: 10.1371/journal.pone.0234647. PMID: 32569327; PMCID: PMC7307763.

[[10](https://ui.adsabs.harvard.edu/abs/2022arXiv220612088G)] Gao C, Goswami M, Chen J, Dubrawski A. Classifying unstructured clinical notes via automatic weak supervision. arXiv. 2022 Jun [cited 2024 Mar 15]. In: arXiv:2206.12088 [cs.CL]. doi: 10.48550/arXiv.2206.12088.

[[11](https://doi.org/10.13026/C2XW26)] Johnson, A., Pollard, T., & Mark, R. (2016). MIMIC-III Clinical Database (version 1.4). PhysioNet. https://doi.org/10.13026/C2XW26.

[[12](https://www.nature.com/articles/sdata201635)] Johnson, A. E. W., Pollard, T. J., Shen, L., Lehman, L. H., Feng, M., Ghassemi, M., Moody, B., Szolovits, P., Celi, L. A., & Mark, R. G. (2016). MIMIC-III, a freely accessible critical care database. Scientific Data, 3, 160035.

[13] Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C., Mark, R., ... & Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex physiologic signals. Circulation [Online]. 101 (23), pp. e215–e220.

[[14](https://github.com/drobbins/ICD9/tree/master)] Robbins D. ICD9 [Internet]. GitHub; 2013. [updated 2013 Nov 11; cited 2024 Apr 14]. Available from: https://github.com/drobbins/ICD9/tree/master