# Prediction with the text classification model

In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import Dataset, load_dataset, DatasetDict
import torch.nn as nn
import transformers
from transformers import AutoTokenizer, AutoModel, TrainingArguments
from torch.optim import AdamW

# Tokenizer for the same pretrained checkpoint that PredBERTClass loads as its
# encoder; keep the two checkpoint names in sync.
tokenizer = AutoTokenizer.from_pretrained("ICLbioengNLP/CXR_BioClinicalBERT_chunkedv1")

In [2]:
# Device for inference; this notebook is configured to run on the CPU.
device=torch.device("cpu")

In [3]:
class PredBERTClass(torch.nn.Module):
    """Pretrained BERT encoder with a dropout layer and a linear multi-label head.

    The attribute names (`bert`, `dropout`, `classifier`) are part of the saved
    state-dict layout, so they must not be renamed if existing checkpoints are
    to keep loading.

    Args:
        model_name: Hugging Face checkpoint used for the encoder.
        num_labels: Number of output scores (one per diagnosis label).
        dropout: Dropout probability applied to the pooled encoder output.
    """

    def __init__(self, model_name="ICLbioengNLP/CXR_BioClinicalBERT_chunkedv1",
                 num_labels=13, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = torch.nn.Dropout(dropout)
        # Read the encoder width from its config instead of hard-coding 768 so
        # the head always matches the loaded checkpoint.
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, mask):
        """Return one sigmoid score per label for each input sequence.

        Args:
            input_ids: Token ids, as produced by the tokenizer.
            mask: Attention mask matching `input_ids`.

        Returns:
            Tensor of sigmoid-activated classifier outputs (values in [0, 1]).
        """
        # With return_dict=False the encoder returns a tuple; only its second
        # element (the pooled output) is used here.
        _, pooled_output = self.bert(input_ids=input_ids, attention_mask=mask, return_dict=False)
        logits = self.classifier(self.dropout(pooled_output))
        # Independent sigmoid per label: scores are not normalized across labels.
        return torch.sigmoid(logits)

In [6]:
# Build the model on the configured device *before* creating the optimizer so
# the optimizer tracks parameters that already live on that device.
trained_model = PredBERTClass().to(device)
# The optimizer is recreated only so that its saved state can be restored from
# the checkpoint; nothing in this cell steps it.
trained_optimizer = AdamW(trained_model.parameters(), lr=5e-5)

# Change this path if your checkpoint is stored somewhere else.
best_model_path = 'CXR_BioClinicalBERT_Class/best_model.pt'

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a source you trust.
# map_location follows `device` instead of a second hard-coded "cpu".
checkpoint = torch.load(best_model_path, map_location=device)
trained_model.load_state_dict(checkpoint['state_dict'])
trained_optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
loss = checkpoint['valid_loss_min']

# eval() switches off dropout for prediction; as the last expression it also
# displays the model architecture.
trained_model.eval()

Some weights of the model checkpoint at ICLbioengNLP/CXR_BioClinicalBERT_chunkedv1 were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ICLbioengNLP/CXR_BioClinicalBERT_chunkedv1 and are newly initialized: ['bert.poole

PredBERTClass(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True

### Load the test samples

In [70]:
# Report section whose text is fed to the classifier. Must be a key of
# sample_dataset that holds report text, i.e. "impression" or "findings".
section = "impression"

In [71]:
# Load the test samples and drop the index column left over from an earlier
# to_csv. `columns=` is used because passing the axis positionally
# (`drop('Unnamed: 0', 1)`) is no longer accepted by pandas 2.x.
testing_df = pd.read_csv('final_samples_updated.csv').drop(columns='Unnamed: 0')
print(len(testing_df))

# Extract the useful columns into a dict of plain Python lists, keyed by column name.
sample_columns = ["study_id", "diagnosis", "diagnosis_id", "impression", "findings"]
sample_dataset = {column: testing_df[column].tolist() for column in sample_columns}

164


In [72]:
# Per-diagnosis score lists: every key collects one score per sample.
# The prediction loop walks these keys in insertion order and pairs them with
# the model outputs by position, so the key order here must match the order of
# the model's output units.
# TODO: confirm this order against the label order used during training.
diagnoses = {
    diagnosis_name: []
    for diagnosis_name in (
        'Atelectasis',
        'Cardiomegaly',
        'Consolidation',
        'Edema',
        'Enlarged Cardiomediastinum',
        'Fracture',
        'Lung Lesion',
        'Lung Opacity',
        'No Finding',
        'Pleural Effusion',
        'Pleural Other',
        'Pneumonia',
        'Pneumothorax',
    )
}

### Tokenize all samples and make a prediction 

In [73]:
sample_size = len(sample_dataset["study_id"])

# Start from empty score lists so that re-running this cell does not append a
# second copy of every prediction.
diagnoses = {key: [] for key in diagnoses}

# Send the inputs to whichever device the model's weights live on.
model_device = next(trained_model.parameters()).device

# Inference only: no_grad() avoids building the autograd graph for each sample.
with torch.no_grad():
    for i in range(sample_size):
        report = sample_dataset[section][i]

        # Tokenize the report, padded/truncated to a fixed length of 128 tokens.
        new_tokens = tokenizer.encode_plus(report, max_length=128, truncation=True,
                                           padding='max_length', return_tensors='pt')
        t_ids = new_tokens['input_ids'].to(model_device)
        t_mask = new_tokens['attention_mask'].to(model_device)

        # The model returns one row per input sequence; a single report is fed
        # at a time, so row 0 holds its per-diagnosis sigmoid scores.
        test_prediction = trained_model(t_ids, t_mask)
        test_prediction = test_prediction.detach().cpu().numpy().tolist()[0]

        # zip() would silently truncate on a mismatch, so fail loudly instead.
        if len(test_prediction) != len(diagnoses):
            raise ValueError(
                f"model returned {len(test_prediction)} scores but "
                f"{len(diagnoses)} diagnosis labels are defined"
            )

        # Scores are matched to diagnosis keys purely by position.
        for key, score in zip(diagnoses, test_prediction):
            diagnoses[key].append(score)

In [74]:
# Sanity check: every diagnosis must have exactly one score per sample.
# A larger count usually means the prediction cell was run more than once
# without resetting the score lists.
assert all(
    len(scores) == len(sample_dataset["study_id"]) for scores in diagnoses.values()
), "each diagnosis should hold exactly one score per sample"
print(len(diagnoses["Atelectasis"]))

164


In [80]:
# Assemble the results table in one go: identifying/report columns first,
# followed by one score column per diagnosis.
report_columns = {
    "study_id": sample_dataset["study_id"],
    "Findings": sample_dataset["findings"],
    "Impression": sample_dataset["impression"],
    "Diagnosis (GT)": sample_dataset["diagnosis"],
}
result_df = pd.DataFrame({**report_columns, **diagnoses})
result_df.head(n=15)
# Optional export of the full table:
# result_df.to_csv('whole_results_table.csv')

Unnamed: 0,study_id,Findings,Impression,Diagnosis (GT),Atelectasis,Cardiomegaly,Consolidation,Edema,Enlarged Cardiomediastinum,Fracture,Lung Lesion,Lung Opacity,No Finding,Pleural Effusion,Pleural Other,Pneumonia,Pneumothorax
0,s58402174,AP portable semi upright view of the chest.\n ...,Increasing bibasilar atelectasis. Possible mi...,Atelectasis,0.001259,0.997841,0.00021,8.2e-05,0.000141,3.7e-05,1.2e-05,4.4e-05,9.8e-05,3.1e-05,0.000689,7e-06,0.000336
1,s59983953,An endotracheal tube approximately 7 cm from t...,1. Bibasilar and right upper lobe atelectasis...,Atelectasis,0.000112,0.947061,2.8e-05,7e-06,0.000181,9e-06,4.7e-05,5.7e-05,0.000277,0.000508,0.000138,3e-06,0.074356
2,s55481818,Linear opacities of the lung bases bilaterally...,Emphysema and bibasilar atelectasis. No evide...,Atelectasis,0.000211,0.995724,5.5e-05,5.7e-05,0.000302,4.8e-05,6.6e-05,6.4e-05,8.8e-05,8.8e-05,0.000129,1.8e-05,0.006214
3,s51499550,AP portable upright view of the chest. Midli...,Limited exam with given low lung volumes with ...,Atelectasis,0.000181,0.999466,8.5e-05,7.4e-05,0.000114,1.4e-05,1.9e-05,0.000243,0.000131,5.5e-05,0.000247,4e-06,0.000425
4,s51644170,Patient is status post median sternotomy. Rig...,Persistently low lung volumes with streaky rig...,Atelectasis,0.00029,0.999898,7.9e-05,8.3e-05,0.000278,6.1e-05,6.8e-05,0.00052,0.000306,4.5e-05,0.000211,1.6e-05,0.000296
5,s57361873,PA and lateral chest radiograph demonstrate a ...,Overall stable appearance of the chest with lo...,Atelectasis,0.000208,0.999611,2.7e-05,3.9e-05,0.000222,3.1e-05,3.9e-05,0.000249,0.000228,3.8e-05,0.000106,7e-06,0.000602
6,s59735304,AP portable upright view of the chest. Right...,Bibasilar atelectasis. No convincing evidence...,Atelectasis,0.000831,0.981743,6.4e-05,1.5e-05,0.000288,4.2e-05,5.1e-05,3.9e-05,5.4e-05,0.000138,0.002499,9e-06,0.007939
7,s55617591,There is no change in the total right upper lo...,Unchanged total right upper lobe collapse in t...,Atelectasis,7.7e-05,0.839398,2.1e-05,4e-06,6.3e-05,2.2e-05,1.2e-05,6.5e-05,6.6e-05,0.000131,0.000107,2e-06,0.008477
8,s56545860,Right PICC line ends at mid SVC. Left-sided p...,Right PICC line ends at mid SVC. Small bibasi...,Atelectasis,0.000288,0.998632,2.4e-05,7.5e-05,0.000253,1.9e-05,2.4e-05,0.000145,5.3e-05,7.1e-05,0.000104,5e-06,0.000455
9,s54898695,Frontal and lateral views of the chest demonst...,Collapse of the remaining left lung with furth...,Atelectasis,0.000174,0.99957,4.9e-05,8.7e-05,0.000164,4.8e-05,5.1e-05,0.000291,0.000109,4.2e-05,0.000259,1e-05,0.00077


### Calculate the average score for each diagnosis!

In [87]:
# Ground-truth labels, in the same order as the score columns of result_df.
# TODO: confirm this order matches the label order the model was trained with.
labels = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture',
          'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax']

# average_weights[score_column][i] is the mean of that score column over the
# samples whose ground-truth diagnosis is labels[i]. The keys are derived from
# `labels` so the two can never drift apart.
average_weights = {label: [] for label in labels}

labels_without_samples = []
for label in labels:
    # Rows whose ground-truth diagnosis equals this label.
    selected_df = result_df.loc[result_df["Diagnosis (GT)"] == label]
    if selected_df.empty:
        # The mean of an empty selection is NaN; remember the label so the
        # all-NaN entries are explained rather than silent.
        labels_without_samples.append(label)

    # Mean of every score column within this ground-truth group.
    for key in average_weights:
        average_weights[key].append(selected_df[key].mean())

if labels_without_samples:
    print(f"No samples with ground-truth label(s) {labels_without_samples}; their averages are NaN")

print(average_weights)  # a dict with 13 keys, each holding 13 averages
# FIXME (original author's note): wrong order and wrong scores for the columns
# from "No Finding" to "Pneumothorax".
# NOTE(review): in the saved output the ground-truth Atelectasis rows average
# highest in the "Cardiomegaly" column (and vice versa), which suggests the
# key order used for the score columns does not match the model's output
# order -- verify against the training label order.

{'Atelectasis': [0.0003422310784420309, 0.9768397688865662, 0.0010847954370547086, 0.0020742148587790626, 0.0011941368720727041, 0.0006083508686226031, 0.0015046142968155134, 0.0013615388141867393, nan, 0.00036899815507543583, 0.0014836350223049522, 0.001310973448319904, 0.0009271368102277988], 'Cardiomegaly': [0.9487642288208008, 0.00020003838241488363, 0.0007070743900840171, 7.590144741698168e-05, 0.0046086921041326905, 0.0004156980880493806, 0.00041805453687023447, 0.0003910335469602918, nan, 0.059144589285763986, 0.0006436453142669052, 0.0009770321346877608, 0.0767278049833504], 'Consolidation': [5.2760870191074596e-05, 4.062463437245848e-05, 0.8951526621977488, 0.0001688070262995704, 0.25003366806413396, 0.0001405981397097507, 0.0003855338521437564, 0.0002919411167871052, nan, 9.93052981357323e-05, 0.0013279551349114627, 0.00030119566751333576, 0.00012936325401824433], 'Edema': [4.685104953144522e-05, 0.0001489052020284968, 0.00025495432055322453, 0.8830957298477491, 0.00090999755

In [92]:
# Rows: ground-truth diagnosis; columns: mean predicted score per diagnosis.
average_df = pd.DataFrame.from_dict(average_weights)
average_df.insert(0, "Diagnosis (GT)", labels)
# Highlight the highest mean score in each row. `subset=labels` restricts the
# comparison to the numeric score columns so the string "Diagnosis (GT)"
# column is never compared against floats.
# If the label order is correct, this shows a diagonal of yellow cells.
average_df.style.highlight_max(subset=labels, color='yellow', axis=1)

  extrema = data == np.nanmax(data.to_numpy())


Unnamed: 0,Diagnosis (GT),Atelectasis,Cardiomegaly,Consolidation,Edema,Enlarged Cardiomediastinum,Fracture,Lung Lesion,Lung Opacity,No Finding,Pleural Effusion,Pleural Other,Pneumonia,Pneumothorax
0,Atelectasis,0.000342,0.948764,5.3e-05,4.7e-05,0.000221,3.9e-05,4.5e-05,0.000147,0.000219,0.000129,0.000342,9e-06,0.046592
1,Cardiomegaly,0.97684,0.0002,4.1e-05,0.000149,8.1e-05,3e-05,8.5e-05,0.000153,0.000246,0.00012,0.000183,1.6e-05,0.011692
2,Consolidation,0.001085,0.000707,0.895153,0.000255,0.000498,0.000239,0.000524,0.000514,0.000813,0.000319,0.002584,0.000115,0.055591
3,Edema,0.002074,7.6e-05,0.000169,0.883096,6.4e-05,5e-05,8.6e-05,0.000153,0.000208,0.000217,0.000481,5.4e-05,0.04685
4,Enlarged Cardiomediastinum,0.001194,0.004609,0.250034,0.00091,0.756487,0.000494,0.000438,0.330533,0.001258,0.001334,0.000708,0.000265,0.273867
5,Fracture,0.000608,0.000416,0.000141,0.000117,0.000172,0.988651,0.000395,0.000386,0.00088,0.002015,0.000382,0.000158,0.005265
6,Lung Lesion,0.001505,0.000418,0.000386,0.000235,0.000399,0.000182,0.988512,0.000818,0.000724,0.001258,0.078069,0.000208,0.009387
7,Lung Opacity,0.001362,0.000391,0.000292,0.000369,7.3e-05,5e-05,6.9e-05,0.969591,0.000218,0.000103,0.068081,3.5e-05,0.000964
8,No Finding,,,,,,,,,,,,,
9,Pleural Effusion,0.000369,0.059145,9.9e-05,6.9e-05,9.3e-05,6.7e-05,1.6e-05,7.8e-05,0.878658,0.000115,0.05836,1.2e-05,0.069235
