In [20]:
from transformers import pipeline
import numpy as np
import pandas as pd
from sklearn import metrics
import torch
from datasets import Dataset
from transformers.pipelines.pt_utils import KeyDataset
from tqdm.auto import tqdm

In [21]:
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification

# Load the tokenizer and the sequence-classification model from the same
# MNLI checkpoint, then display the model's NLI label -> class-id mapping.
checkpoint_name = 'facebook/bart-large-mnli'
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_name)
model.config.label2id

{'contradiction': 0, 'entailment': 2, 'neutral': 1}

In [25]:
from datasets import load_dataset

# One CSV per split; the split names become the keys of the DatasetDict.
csv_paths = {
    "train": "/root/data/chex_train.csv",
    "val": "/root/data/chex_val.csv",
    "test": "/root/data/chex_test.csv",
}
dataset = load_dataset("csv", data_files=csv_paths)
# Expose the unnamed first CSV column as `id`
# (presumably a pandas index written by to_csv -- confirm against the CSVs).
dataset = dataset.rename_column('Unnamed: 0', 'id')
dataset

Found cached dataset csv (/root/.cache/huggingface/datasets/csv/default-dcb44e19e0ca611f/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'Report Impression', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices', 'No Finding'],
        num_rows: 102304
    })
    val: Dataset({
        features: ['id', 'Report Impression', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices', 'No Finding'],
        num_rows: 29230
    })
    test: Dataset({
        features: ['id', 'Report Impression', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices', 'No Finding'],
        num_rows: 14615
    })
})

In [23]:
# Spot-check a single training row.
# NOTE(review): the saved output still shows the key 'Unnamed: 0', i.e. it was
# produced before the load cell renamed that column to 'id' (this cell's
# execution count is lower than the load cell's). Re-run top to bottom to refresh.
dataset['train'][3] # contains Atelectasis, Pleural Effusion, and Fracture

{'Unnamed: 0': 8231,
 'Report Impression': "1.  Extensive cecal wall thickening and inflammatory changes with suspected pneumatosis and evidence of extraluminal mesenteric gas, and trace portal venous gas, in keeping with bowel ischemia. No frank disruption in the bowel contour is seen on noncontrast images. No abscess or drainable fluid collection. 2.  Normal short appendix. 3.  Moderate-sized bilateral pleural effusions with a partially visualized nodular opacity in the right middle lobe, likely representing focal atelectasis. Other less likely etiologies include consolidation or pulmonary nodule, and when the patient's status improves, further assessment with CT chest could be considered. 4.  Compression fracture of L1 with bony retropulsion. This is new from the radiographs of 2/3/2019, but still appears chronic. Correlation with point tenderness recommended. Dr. Li discussed these findings with Dr. Cohen via telephone on 9/19/2020 at 4:10 AM..",
 'Enlarged Cardiomediastinum': None

In [24]:
# Sanity check: the same row carries the value 1 in its 'Fracture' column.
dataset['train'][3]['Fracture'] == 1

True

In [26]:
# List the raw columns: the id, the report text, and one column per finding label.
dataset['train'].column_names

['id',
 'Report Impression',
 'Enlarged Cardiomediastinum',
 'Cardiomegaly',
 'Lung Opacity',
 'Lung Lesion',
 'Edema',
 'Consolidation',
 'Pneumonia',
 'Atelectasis',
 'Pneumothorax',
 'Pleural Effusion',
 'Pleural Other',
 'Fracture',
 'Support Devices',
 'No Finding']

In [40]:
def create_target_sentences_with_labels(labels, label2id=None):
    """Build a batched ``Dataset.map`` function that expands reports into NLI pairs.

    For every report in a batch and every name in ``labels`` the returned
    function emits one output row: the report text (premise), the hypothesis
    ``'This example is {label}.'``, and an NLI class id chosen from the
    report's value in that label column:

        -1          -> label2id['contradiction']
        None or 0   -> label2id['neutral']
        1           -> label2id['entailment']

    NOTE(review): CheXpert-style label files conventionally use 0 for
    "negative" and -1 for "uncertain". If these CSVs follow that convention,
    the -1 and 0 branches look swapped (negative should be the contradiction).
    The mapping is deliberately left as it was -- confirm against the data
    source before changing it, since it alters every generated label.

    Args:
        labels: Iterable of label column names to generate hypotheses for.
        label2id: Optional mapping with the keys 'contradiction', 'neutral'
            and 'entailment'. Defaults to ``model.config.label2id`` of the
            notebook-level ``model``.

    Returns:
        A function ``batch: Dict[str, List] -> Dict[str, List]`` whose output
        has the keys 'target', 'Report Impression', 'labels',
        'original_label' and 'id', with ``len(labels)`` rows per input row.

    Raises:
        ValueError: (from the returned function) when a label column holds a
            value other than -1, 0, 1 or None.
    """
    # Resolve the class ids once, up front. This also keeps the mapped
    # function free of any reference to the global model object.
    if label2id is None:
        label2id = model.config.label2id
    contradiction_id = label2id['contradiction']
    neutral_id = label2id['neutral']
    entailment_id = label2id['entailment']

    text_key = 'Report Impression'
    # Materialise so a one-shot iterable still yields every label for every
    # row, and build each hypothesis sentence only once.
    labels = list(labels)
    hypotheses = [(label, f'This example is {label}.') for label in labels]

    def _to_nli_class(value):
        # Translate one raw label cell into an NLI class id.
        if value is None or value == 0:
            return neutral_id
        if value == -1:
            return contradiction_id
        if value == 1:
            return entailment_id
        raise ValueError(f"invalid value in labels {value}")

    # function(batch: Dict[str, List]) -> Dict[str, List]
    def create_target_sentences(batch):
        out = {'target': [], text_key: [], 'labels': [], 'original_label': [], 'id': []}
        texts = batch[text_key]
        ids = batch['id']

        for row, text in enumerate(texts):
            for label, hypothesis in hypotheses:
                out['original_label'].append(label)
                out['id'].append(ids[row])
                out['target'].append(hypothesis)
                out[text_key].append(text)
                out['labels'].append(_to_nli_class(batch[label][row]))
        return out

    return create_target_sentences
    
# The seven findings used in this notebook, partitioned into disjoint groups:
# five for training, one for validation, one for testing.
all_labels = [
    "Fracture",
    "Edema",
    "Cardiomegaly",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
]
train_labels = [
    "Fracture",
    "Edema",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
]
val_labels = ["Pleural Effusion"]
test_labels = ["Cardiomegaly"]

# Each of `train`, `val` and `test` is a full DatasetDict (all three CSV
# splits) expanded against one label group -- the names refer to the label
# group, not to the CSV split. All original columns are dropped because the
# mapped function changes the number of rows.
train, val, test = (
    dataset.map(
        create_target_sentences_with_labels(label_group),
        batched=True,
        remove_columns=dataset['train'].column_names,
    )
    for label_group in (train_labels, val_labels, test_labels)
)
test

Map:   0%|          | 0/102304 [00:00<?, ? examples/s]

target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impres

Map:   0%|          | 0/29230 [00:00<?, ? examples/s]

target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impres

Map:   0%|          | 0/14615 [00:00<?, ? examples/s]

target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impression 5000
labels 5000
original_label 5000
id 5000
target 5000
Report Impres

Map:   0%|          | 0/102304 [00:00<?, ? examples/s]

target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impres

Map:   0%|          | 0/29230 [00:00<?, ? examples/s]

target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impres

Map:   0%|          | 0/14615 [00:00<?, ? examples/s]

target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impres

Map:   0%|          | 0/102304 [00:00<?, ? examples/s]

target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impres

Map:   0%|          | 0/29230 [00:00<?, ? examples/s]

target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impres

Map:   0%|          | 0/14615 [00:00<?, ? examples/s]

target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impression 1000
labels 1000
original_label 1000
id 1000
target 1000
Report Impres

DatasetDict({
    train: Dataset({
        features: ['id', 'Report Impression', 'target', 'labels', 'original_label'],
        num_rows: 102304
    })
    val: Dataset({
        features: ['id', 'Report Impression', 'target', 'labels', 'original_label'],
        num_rows: 29230
    })
    test: Dataset({
        features: ['id', 'Report Impression', 'target', 'labels', 'original_label'],
        num_rows: 14615
    })
})

In [41]:
# `train` holds every CSV split expanded against train_labels, so each split
# has one row per (report, label) pair.
train

DatasetDict({
    train: Dataset({
        features: ['id', 'Report Impression', 'target', 'labels', 'original_label'],
        num_rows: 511520
    })
    val: Dataset({
        features: ['id', 'Report Impression', 'target', 'labels', 'original_label'],
        num_rows: 146150
    })
    test: Dataset({
        features: ['id', 'Report Impression', 'target', 'labels', 'original_label'],
        num_rows: 73075
    })
})

In [44]:
# Peek at the first ten expanded rows: each report id repeats once per label.
train['train'][0:10]

{'id': [101931,
  101931,
  101931,
  101931,
  101931,
  42061,
  42061,
  42061,
  42061,
  42061],
 'Report Impression': ['1.  Large bilateral layering pleural effusions. 2.  Left mid lung zone nodular opacity could reflect a granuloma or overlying structure. Recommend comparison to prior imaging or follow up.',
  '1.  Large bilateral layering pleural effusions. 2.  Left mid lung zone nodular opacity could reflect a granuloma or overlying structure. Recommend comparison to prior imaging or follow up.',
  '1.  Large bilateral layering pleural effusions. 2.  Left mid lung zone nodular opacity could reflect a granuloma or overlying structure. Recommend comparison to prior imaging or follow up.',
  '1.  Large bilateral layering pleural effusions. 2.  Left mid lung zone nodular opacity could reflect a granuloma or overlying structure. Recommend comparison to prior imaging or follow up.',
  '1.  Large bilateral layering pleural effusions. 2.  Left mid lung zone nodular opacity could refle

In [47]:
# Start from every column of the expanded dataset and strike out the ones
# that must survive tokenization; what is left (the raw text pair) is dropped.
remove_columns = list(train['train'].column_names)
for kept_column in ('id', 'labels', 'original_label'):  # keep these columns!
    remove_columns.remove(kept_column)
print(remove_columns)

# TODO: padding="max_length" pads every pair up front, which may be slow and
# makes the stored rows large; padding per batch at training time is the
# usual alternative -- confirm what the training code expects before changing.
def tokenize_function(examples):
    """Tokenize a batch of (report, hypothesis) pairs for the NLI model.

    The report text is the first sequence and the generated hypothesis the
    second. ``truncation='only_first'`` shortens only the report when a pair
    is too long, and ``padding="max_length"`` pads every pair to a fixed
    length. Uses the notebook-level ``tokenizer``.
    """
    premises = examples["Report Impression"]
    hypotheses = examples["target"]
    return tokenizer(
        text=premises,
        text_pair=hypotheses,
        padding="max_length",
        truncation='only_first',
    )

# Tokenize the three label-group DatasetDicts in the same order as before,
# dropping the raw text columns and keeping id / labels / original_label.
train_tokenized, val_tokenized, test_tokenized = (
    pair_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=remove_columns,
    )
    for pair_dataset in (train, val, test)
)

['Report Impression', 'target']


Map:   0%|          | 0/511520 [00:00<?, ? examples/s]

Map:   0%|          | 0/146150 [00:00<?, ? examples/s]

Map:   0%|          | 0/73075 [00:00<?, ? examples/s]

Map:   0%|          | 0/102304 [00:00<?, ? examples/s]

Map:   0%|          | 0/29230 [00:00<?, ? examples/s]

Map:   0%|          | 0/14615 [00:00<?, ? examples/s]

Map:   0%|          | 0/102304 [00:00<?, ? examples/s]

Map:   0%|          | 0/29230 [00:00<?, ? examples/s]

Map:   0%|          | 0/14615 [00:00<?, ? examples/s]

In [48]:
# Persist the tokenized DatasetDict built from train_labels (all three CSV splits).
# NOTE(review): absolute /root path -- consider a configurable data directory.
train_tokenized.save_to_disk("/root/data/bart_train_labels")

Saving the dataset (0/6 shards):   0%|          | 0/511520 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/146150 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/73075 [00:00<?, ? examples/s]

In [49]:
# Persist the tokenized DatasetDict built from val_labels (all three CSV splits).
# NOTE(review): absolute /root path -- consider a configurable data directory.
val_tokenized.save_to_disk("/root/data/bart_val_labels")

Saving the dataset (0/2 shards):   0%|          | 0/102304 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/29230 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/14615 [00:00<?, ? examples/s]

In [50]:
# Persist the tokenized DatasetDict built from test_labels (all three CSV splits).
# NOTE(review): absolute /root path -- consider a configurable data directory.
test_tokenized.save_to_disk("/root/data/bart_test_labels")

Saving the dataset (0/2 shards):   0%|          | 0/102304 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/29230 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/14615 [00:00<?, ? examples/s]

In [51]:
# Confirm the tokenized schema: kept columns plus input_ids and attention_mask.
train_tokenized

DatasetDict({
    train: Dataset({
        features: ['id', 'labels', 'original_label', 'input_ids', 'attention_mask'],
        num_rows: 511520
    })
    val: Dataset({
        features: ['id', 'labels', 'original_label', 'input_ids', 'attention_mask'],
        num_rows: 146150
    })
    test: Dataset({
        features: ['id', 'labels', 'original_label', 'input_ids', 'attention_mask'],
        num_rows: 73075
    })
})

In [52]:
from datasets import concatenate_datasets
# Merge the 'train' and 'val' CSV splits of the train_labels dataset into a
# single Dataset; the 'test' split is left out.
train_concat = concatenate_datasets([train_tokenized['train'], train_tokenized['val']])
train_concat

Dataset({
    features: ['id', 'labels', 'original_label', 'input_ids', 'attention_mask'],
    num_rows: 657670
})

In [53]:
# Inspect one tokenized example. The long run of trailing 1s in input_ids is
# padding from padding="max_length" (1 is the pad token id for this tokenizer).
train_tokenized['train'][100]

{'id': 12815,
 'labels': 1,
 'original_label': 'Fracture',
 'input_ids': [0,
  134,
  4,
  1437,
  440,
  13827,
  19567,
  1043,
  3917,
  2617,
  36536,
  6948,
  4,
  132,
  4,
  1437,
  440,
  22259,
  50,
  17292,
  8196,
  337,
  15645,
  9,
  5,
  29851,
  20625,
  4,
  155,
  4,
  1437,
  440,
  13827,
  36536,
  6948,
  11,
  5,
  5397,
  4,
  2,
  2,
  713,
  1246,
  16,
  45274,
  2407,
  4,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,