# 1. Build a text classification model using a pre-trained BERT model that can be found in Huggingface.

In [29]:
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader, random_split
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset, load_metric
import pandas as pd
import torchmetrics
import numpy as np
import sys

In [30]:
dataset = load_dataset("rungalileo/medical_transcription_40")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")



  0%|          | 0/2 [00:00<?, ?it/s]

In [31]:
max_length = 256  # truncate text

def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True, max_length=max_length)

tokenized_dataset = dataset.map(tokenize, batched=True, batch_size=len(dataset["train"]))



In [32]:
train_df = pd.DataFrame(dataset["train"])
train_df

Unnamed: 0,id,text,label
0,3614,"EXAM: , CT scan of the abdomen and pelvis with...",23
1,488,"PREOPERATIVE DIAGNOSIS: , Fracture dislocation...",25
2,2482,"EARS, NOSE, MOUTH AND THROAT,EARS/NOSE: , The ...",32
3,3552,"PREOPERATIVE DIAGNOSIS: , Refractory dyspepsia...",23
4,3437,"CHIEF COMPLAINT:, Lump in the chest wall.,HIS...",36
...,...,...,...
4494,1492,"GENERAL EVALUATION:,Fetal Cardiac Activity: No...",15
4495,2623,"DELIVERY NOTE: , This is an 18-year-old, G2, P...",38
4496,3601,"EXAM:, CT examination of the abdomen and pelv...",23
4497,2818,"CC:, Progressive lower extremity weakness.,HX:...",6


In [33]:
train, val = random_split(tokenized_dataset["train"], [4000, 499])
test = dataset["test"]

In [34]:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=40)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [35]:
training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="epoch",
    logging_dir="logs",
    learning_rate=2e-5,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=val,
    tokenizer=tokenizer,
)

trainer.train()



Epoch,Training Loss,Validation Loss
1,2.6475,2.202925
2,2.0436,1.871579
3,1.7433,1.790285
4,1.5757,1.783176


TrainOutput(global_step=2000, training_loss=2.002531463623047, metrics={'train_runtime': 870.877, 'train_samples_per_second': 18.372, 'train_steps_per_second': 2.297, 'total_flos': 2105606602752000.0, 'train_loss': 2.002531463623047, 'epoch': 4.0})

# 2. Record the metrics of interest to verify whether the models are high-performing.

In [36]:
model.eval()

def collate_fn(batch):
    text = [i["text"] for i in batch]
    labels = [i["label"] for i in batch]

    input_tensors = tokenizer(text, padding=True, truncation=True, return_tensors='pt')

    return input_tensors, torch.tensor(labels)

test_dataloader = DataLoader(test, batch_size=8, shuffle=False, collate_fn=collate_fn)

accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=40).to('cuda:0')
f1 = torchmetrics.F1Score(task='multiclass', num_classes=40, average='weighted').to('cuda:0')
confusion_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=40).to('cuda:0')
matthews_corrcoef = torchmetrics.MatthewsCorrCoef(task='multiclass', num_classes=40).to('cuda:0')

with torch.no_grad(): # disable gradient calculations to save memory
    for text, labels in test_dataloader:
        text, labels = text.to('cuda:0'), labels.to('cuda:0')

        outputs = model(**text)
        _, predicted = torch.max(outputs.logits, 1)

        accuracy.update(predicted, labels)
        f1.update(predicted, labels)
        confusion_matrix.update(predicted, labels)
        matthews_corrcoef.update(predicted, labels)

accuracy_val = accuracy.compute()
f1_val = f1.compute()
confusion_matrix_val = confusion_matrix.compute()
matthews_corrcoef_val = matthews_corrcoef.compute()

print("Accuracy:", accuracy_val)
print("F1 Score:", f1_val)
print("Confusion Matrix:", confusion_matrix_val)
print("Matthews Correlation Coefficient:", matthews_corrcoef_val)

Accuracy: tensor(0.4960, device='cuda:0')
F1 Score: tensor(0.4465, device='cuda:0')
Confusion Matrix: tensor([[ 7,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        ...,
        [ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0, 10,  0],
        [ 0,  0,  0,  ...,  0,  0,  0]], device='cuda:0')
Matthews Correlation Coefficient: tensor(0.4366, device='cuda:0')


In [37]:
confusion_matrix_np = confusion_matrix_val.cpu().numpy()
np.set_printoptions(threshold=sys.maxsize, linewidth=sys.maxsize)
print(confusion_matrix_np)

[[ 7  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  2  0  0  0  0  0  0  0  0  0  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  1  0  0  0  0  0  0  0  0  2  0  0  0  0  0  0  0  0  0  1  0  1  0  0  0  0  0  0  0  0  0  0  0  0  1  0]
 [ 0  0  0  0  5  0  0  0  0  3  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0 10  0  0  2  0  0  0  5  0  5  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  1  0  0  0  0  0]
 [ 0  0  0  0  0

# 3. Why did the dataset perform the way it did? What would help it to improve?

    1. The main reason the model performed poorly is because the size of the training dataset was very small. Increasing the size of the training dataset would help improve performance.
    2. There is heavy class imbalance for many of the 40 classes. Undersampling and oversampling techniques, for example SMOTE, can help with this problem. 
    3. Grid search or other hyperparameter optimizations
    4. Increase max sequence length or use full text
    5. Try a different model, for example Longformer or BERT large

#4. What if you could not truncate the data? What would be some of the approaches to classifying this data? What are some of the limitations? 

    1. Longformer model: Longformer has an “attention mechanism that scales linearly with sequence length, making it easy to process documents of thousands of tokens or longer.” - Betalgy et al.
        1. Limitation: computationally expensive, takes longer to train
    2. Text segmentation: Instead of truncating, segment the text into shorter sections with the same class
        1. Limitation: Will lose some information by shortening texts