## Use zero shot classification models to make predictions on atrial fibrillation status.

In [1]:
import pandas as pd
import numpy as np
from transformers import pipeline

In [2]:
# initialize a zero shot classification classifier
# here, we use the BART MNLI model, which is the most popular on the transformers library
classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli')

In [3]:
# import the dataset
data = pd.read_csv('processed_afib_data.csv')
print(data)

       index  patient_id  afib  \
0          0         109     0   
1          2         113     0   
2          3         114     0   
3          4         115     0   
4          6         117     0   
...      ...         ...   ...   
29446  46515       97164     0   
29447  46516       97484     0   
29448  46517       97488     1   
29449  46518       97492     0   
29450  46519       97497     1   

                                                   notes  age  \
0      PATIENT/TEST INFORMATION: Indication: Code. As...   25   
1      Sinus rhythm, rate 93. Non-specific ST-T wave ...   35   
2      Normal sinus rhythm, rate 96 Right bundle bran...   48   
3      PATIENT/TEST INFORMATION: Indication: Left ven...   75   
4      PATIENT/TEST INFORMATION: Indication: Murmur. ...   50   
...                                                  ...  ...   
29446  PATIENT/TEST INFORMATION: Indication: Aortic v...   83   
29447  Sinus bradycardia with non-diagnostic repolari...   79   
29448 

In [13]:
# this array stores the 0 and 1 classifier based on if the probability is less than or greater than 0.5
predictions = []
# this array stores the precise probability that the model outputs
probability_predictions = []

for index, row in data.iterrows():
    text = row['notes']
    candidate_label = ['atrial fibrillation']

    output = classifier(text, candidate_label)
    # extract the exact prediction from the output, which is a dictionary with the key 'scores' storing an array
    prediction = output['scores'][0]

    probability_predictions.append(prediction)

    if prediction < 0.5:
        predictions.append(0)
    else:
        predictions.append(1)

KeyboardInterrupt: 

In [14]:
print(len(predictions))
print(np.mean(predictions))
print(probability_predictions)

176
0.36363636363636365
0.36946779515743733
[0.14798565208911896, 0.004260016139596701, 0.007346693426370621, 0.8733732104301453, 0.47856613993644714, 0.00835508480668068, 0.30374228954315186, 0.7978691458702087, 0.3883456587791443, 0.610526442527771, 0.012758253142237663, 0.9786338210105896, 0.15835674107074738, 0.24406136572360992, 0.004112997557967901, 0.5663946270942688, 0.005485554691404104, 0.0028088341932743788, 0.452274888753891, 0.8663113713264465, 0.054318126291036606, 0.17624005675315857, 0.3799442648887634, 0.02826797217130661, 0.0018008373444899917, 0.02108229137957096, 0.8496683835983276, 0.6324237585067749, 0.005130487959831953, 0.7790483832359314, 0.9560039639472961, 0.004688859451562166, 0.7401663064956665, 0.0032245267648249865, 0.026957301422953606, 0.04533727094531059, 0.0026358540635555983, 0.9740086793899536, 0.8190281987190247, 0.021336989477276802, 0.005790151190012693, 0.00065345608163625, 0.004089518915861845, 0.0018263614038005471, 0.5236969590187073, 0.37730