# 12/5/19

Checking whether we have actually reproduced CheXNet.

UPDATE: We are able to reproduce CheXNet. Adding a weighted cross-entropy loss function did not change the results.

In [6]:
%load_ext autoreload
%autoreload 2

import math
import os
import os.path as osp
import json
from functools import partial
from collections import defaultdict
os.chdir('/lfs/1/gangus/repositories/pytorch-classification/Emmental-ChexNet')

import torch
import torch.nn as nn
import torch.nn.functional as F
import sklearn.metrics as skl
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

import emmental
from emmental import Meta
from emmental.data import EmmentalDataLoader
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from emmental.utils.utils import str2bool, move_to_device

from dataset import CXR8Dataset
from task import get_task
from task_config import CXR8_TASK_NAMES
from transforms import get_data_transforms

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
# Emmental log directory of the run to evaluate; the checkpoint file is looked up
# inside this directory when the model config is built.
# "old" run (logs/2019_12_04) -- presumably trained before the weighted
# cross-entropy loss was added; TODO confirm.
emmental_dir = '/lfs/1/gangus/repositories/pytorch-classification/Emmental-ChexNet/logs/2019_12_04/14_31_34/99ae8f45'
# "new" run (logs/2019_12_05) -- uncomment to evaluate it instead of the old run.
# emmental_dir = '/lfs/1/gangus/repositories/pytorch-classification/Emmental-ChexNet/logs/2019_12_05/11_22_23/07d2550b/'

In [35]:
def ce_loss(task_name, immediate_ouput_dict, Y, active):
    """Cross-entropy loss for one task, restricted to the active examples.

    Args:
        task_name: Task name; the logits are read from the module output stored
            under ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: Mapping from module name to that module's outputs;
            element ``[0]`` of the prediction head's entry is used as the logits.
        Y: Labels, flattened with ``view(-1)`` and shifted down by one so they
            become zero-based class indices.
        active: Index/mask selecting which examples contribute to the loss.

    Returns:
        The value of ``F.cross_entropy`` over the selected examples.
    """
    head_key = f"{task_name}_pred_head"
    logits = immediate_ouput_dict[head_key][0]
    # Labels are stored one-based; cross_entropy expects zero-based targets.
    zero_based_targets = Y.view(-1) - 1
    return F.cross_entropy(logits[active], zero_based_targets[active])

def output(task_name, immediate_ouput_dict):
    """Turn a task's prediction-head logits into class probabilities.

    Args:
        task_name: Task name; the logits are read from the module output stored
            under ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: Mapping from module name to that module's outputs;
            element ``[0]`` of the prediction head's entry is used as the logits.

    Returns:
        Softmax of the logits taken over the last dimension.
    """
    head_key = f"{task_name}_pred_head"
    logits = immediate_ouput_dict[head_key][0]
    return F.softmax(logits, dim=-1)

# ---- Evaluation setup: config, dataset, class weights, dataloader, model ----
DATA_NAME = 'CXR8'

# Inputs: image directory and the labels CSV.
image_path = '/lfs/1/jdunnmon/data/nih/images/images'
data_path = '/dfs/scratch1/senwu/mmtl/emmental-tutorials/chexnet/data/nih_labels.csv'

task_names = CXR8_TASK_NAMES
# Every task reads the label set that carries its own name.
task_to_label_dict = {t: t for t in task_names}
add_binary_triage_label = False
batch_size = 16

# Dataset split to evaluate.
split = 'val'

emmental.init()

# Point Emmental at the checkpoint stored in the selected run directory.
model_config = {
    'model_path': osp.join(emmental_dir, 'best_model_model_all_val_loss.pth'),
    'device': 0,
    'dataparallel': True
}
Meta.update_config(
    config={
        "meta_config": {"seed": 1701, "device": 0},
        "model_config": model_config
    }
)

cxr8_transform = get_data_transforms(DATA_NAME)
dataset = CXR8Dataset(
    name=DATA_NAME,
    path_to_images=image_path,
    path_to_labels=data_path,
    split=split,
    transform=cxr8_transform[split],
    sample=0,
    seed=1701,
    add_binary_triage_label=add_binary_triage_label,
)

# Per-task class weights, ordered [positive, negative].
# Weighting scheme from paper: w_pos = |N| / (|P| + |N|), w_neg = |P| / (|P| + |N|)
# Label encoding is categorical: [0: abstain, 1: positive, 2: negative].
# NOTE(review): the denominator is len(task_labels), which also counts any abstain (0)
# labels; it equals |P| + |N| only when there are none -- confirm.
weights_device = Meta.config["model_config"]["device"]
task_to_class_weights = {}
for task_name in task_names:
    task_labels = dataset.Y_dict[task_to_label_dict[task_name]]
    # Vectorised counts (the builtin sum() walks the tensor one element at a time).
    # Assumes one label per example (1-D or (N, 1) tensor) -- TODO confirm.
    w_pos = (task_labels == 2).sum().type(torch.FloatTensor) / len(task_labels)
    w_neg = (task_labels == 1).sum().type(torch.FloatTensor) / len(task_labels)
    task_to_class_weights[task_name] = move_to_device(torch.tensor([w_pos, w_neg]), weights_device)

dataloader = EmmentalDataLoader(
    task_to_label_dict=task_to_label_dict,
    dataset=dataset,
    split=split,
    shuffle=(split == "train"),  # only shuffle the training split
    batch_size=batch_size,
    num_workers=16,
)

tasks = get_task(task_names, task_to_class_weights)
model = EmmentalModel(name=DATA_NAME, tasks=tasks)

[2019-12-08 13:30:15,171][INFO] emmental.meta:110 - Logging was already initialized to use /tmp/2019_12_08/13_16_45/0bcdcba2.  To configure logging manually, call emmental.init_logging before initialiting Meta.
[2019-12-08 13:30:15,230][INFO] emmental.meta:60 - Loading Emmental default config from /lfs/1/gangus/repositories/pytorch-classification/emmental/src/emmental/emmental-default-config.yaml.
[2019-12-08 13:30:15,232][INFO] emmental.meta:160 - Updating Emmental config from user provided config.
[2019-12-08 13:30:20,769][INFO] emmental.task:48 - Created task: Atelectasis
[2019-12-08 13:30:20,771][INFO] emmental.task:48 - Created task: Cardiomegaly
[2019-12-08 13:30:20,772][INFO] emmental.task:48 - Created task: Effusion
[2019-12-08 13:30:20,773][INFO] emmental.task:48 - Created task: Infiltration
[2019-12-08 13:30:20,774][INFO] emmental.task:48 - Created task: Mass
[2019-12-08 13:30:20,775][INFO] emmental.task:48 - Created task: Nodule
[2019-12-08 13:30:20,776][INFO] emmental.task:

In [36]:
# Load the checkpoint named in the Emmental config; skipped when no path is configured.
checkpoint_path = Meta.config["model_config"]["model_path"]
if checkpoint_path:
    model.load(checkpoint_path)

[2019-12-08 13:30:21,302][INFO] emmental.model:518 - [CXR8] Model loaded from /lfs/1/gangus/repositories/pytorch-classification/Emmental-ChexNet/logs/2019_12_04/14_31_34/99ae8f45/best_model_model_all_val_loss.pth
[2019-12-08 13:30:21,303][INFO] emmental.model:71 - Moving model to GPU (cuda:0).


In [37]:
# Score the loaded model on `dataloader`; the returned metrics dict is this cell's output.
# NOTE(review): in the recorded output accuracy is 0.004-0.15 and F1 is 0.007-0.27 while
# ROC AUC is 0.68-0.90. Possibly the thresholded predictions are almost all one class or
# the label encoding is flipped for those metrics -- confirm before quoting accuracy/F1.
model.score(dataloader)

100%|██████████| 702/702 [00:40<00:00, 17.54it/s]


{'Atelectasis/CXR8/val/accuracy': 0.07852749799447366,
 'Atelectasis/CXR8/val/f1': 0.15025155623774195,
 'Atelectasis/CXR8/val/roc_auc': 0.7896770012121855,
 'Atelectasis/CXR8/val/loss': 0.6920055356583125,
 'Atelectasis/CXR8/val/average': 0.3394853518148004,
 'Cardiomegaly/CXR8/val/accuracy': 0.016044210713967378,
 'Cardiomegaly/CXR8/val/f1': 0.031875332034708694,
 'Cardiomegaly/CXR8/val/roc_auc': 0.8699177596623857,
 'Cardiomegaly/CXR8/val/loss': 0.44838390986896914,
 'Cardiomegaly/CXR8/val/average': 0.3059457674703539,
 'Effusion/CXR8/val/accuracy': 0.06043319368927712,
 'Effusion/CXR8/val/f1': 0.12008501594048883,
 'Effusion/CXR8/val/roc_auc': 0.8670489035906388,
 'Effusion/CXR8/val/loss': 0.5765145105979367,
 'Effusion/CXR8/val/average': 0.34918903774013493,
 'Infiltration/CXR8/val/accuracy': 0.15393528835012033,
 'Infiltration/CXR8/val/f1': 0.2737576285963383,
 'Infiltration/CXR8/val/roc_auc': 0.6830709248757703,
 'Infiltration/CXR8/val/loss': 0.8972146208419213,
 'Infiltration/C

In [38]:
# Metrics pasted from the recorded `model.score(dataloader)` output.
# Kept under its own name so the complete set stays available after filtering.
val_scores = {'Atelectasis/CXR8/val/accuracy': 0.07852749799447366,
 'Atelectasis/CXR8/val/f1': 0.15025155623774195,
 'Atelectasis/CXR8/val/roc_auc': 0.7896770012121855,
 'Atelectasis/CXR8/val/loss': 0.6920055356583125,
 'Atelectasis/CXR8/val/average': 0.3394853518148004,
 'Cardiomegaly/CXR8/val/accuracy': 0.016044210713967378,
 'Cardiomegaly/CXR8/val/f1': 0.031875332034708694,
 'Cardiomegaly/CXR8/val/roc_auc': 0.8699177596623857,
 'Cardiomegaly/CXR8/val/loss': 0.44838390986896914,
 'Cardiomegaly/CXR8/val/average': 0.3059457674703539,
 'Effusion/CXR8/val/accuracy': 0.06043319368927712,
 'Effusion/CXR8/val/f1': 0.12008501594048883,
 'Effusion/CXR8/val/roc_auc': 0.8670489035906388,
 'Effusion/CXR8/val/loss': 0.5765145105979367,
 'Effusion/CXR8/val/average': 0.34918903774013493,
 'Infiltration/CXR8/val/accuracy': 0.15393528835012033,
 'Infiltration/CXR8/val/f1': 0.2737576285963383,
 'Infiltration/CXR8/val/roc_auc': 0.6830709248757703,
 'Infiltration/CXR8/val/loss': 0.8972146208419213,
 'Infiltration/CXR8/val/average': 0.37025461394074294,
 'Mass/CXR8/val/accuracy': 0.04697388359033782,
 'Mass/CXR8/val/f1': 0.09049540654245729,
 'Mass/CXR8/val/roc_auc': 0.8174934868793657,
 'Mass/CXR8/val/loss': 0.6544761944973775,
 'Mass/CXR8/val/average': 0.3183209256707203,
 'Nodule/CXR8/val/accuracy': 0.050360994741064265,
 'Nodule/CXR8/val/f1': 0.09646576745774288,
 'Nodule/CXR8/val/roc_auc': 0.7140461368937955,
 'Nodule/CXR8/val/loss': 0.9041521511825018,
 'Nodule/CXR8/val/average': 0.2869576330308676,
 'Pneumonia/CXR8/val/accuracy': 0.011854889027542562,
 'Pneumonia/CXR8/val/f1': 0.023431994362226923,
 'Pneumonia/CXR8/val/roc_auc': 0.7300347657887276,
 'Pneumonia/CXR8/val/loss': 0.662257546569,
 'Pneumonia/CXR8/val/average': 0.25510721639283235,
 'Pneumothorax/CXR8/val/accuracy': 0.03788216418575631,
 'Pneumothorax/CXR8/val/f1': 0.07363770250368189,
 'Pneumothorax/CXR8/val/roc_auc': 0.8721789843640054,
 'Pneumothorax/CXR8/val/loss': 0.5403512984842698,
 'Pneumothorax/CXR8/val/average': 0.32789961701781456,
 'Consolidation/CXR8/val/accuracy': 0.039843123273018984,
 'Consolidation/CXR8/val/f1': 0.07663295045431168,
 'Consolidation/CXR8/val/roc_auc': 0.8071199796306774,
 'Consolidation/CXR8/val/loss': 0.8007840376650052,
 'Consolidation/CXR8/val/average': 0.30786535111933605,
 'Edema/CXR8/val/accuracy': 0.0177377662893306,
 'Edema/CXR8/val/f1': 0.03486334968465312,
 'Edema/CXR8/val/roc_auc': 0.8988796624013068,
 'Edema/CXR8/val/loss': 0.46390932216399805,
 'Edema/CXR8/val/average': 0.3171602594584302,
 'Emphysema/CXR8/val/accuracy': 0.015776807202067922,
 'Emphysema/CXR8/val/f1': 0.031200423056583815,
 'Emphysema/CXR8/val/roc_auc': 0.8435683634547272,
 'Emphysema/CXR8/val/loss': 0.5165548232352979,
 'Emphysema/CXR8/val/average': 0.296848531237793,
 'Fibrosis/CXR8/val/accuracy': 0.014796327658436581,
 'Fibrosis/CXR8/val/f1': 0.02916117698726394,
 'Fibrosis/CXR8/val/roc_auc': 0.7506300421081776,
 'Fibrosis/CXR8/val/loss': 0.6405456553269788,
 'Fibrosis/CXR8/val/average': 0.26486251558462603,
 'Pleural_Thickening/CXR8/val/accuracy': 0.03315803547553258,
 'Pleural_Thickening/CXR8/val/f1': 0.06418773186092659,
 'Pleural_Thickening/CXR8/val/roc_auc': 0.7868463209192176,
 'Pleural_Thickening/CXR8/val/loss': 0.9525099653646975,
 'Pleural_Thickening/CXR8/val/average': 0.29473069608522556,
 'Hernia/CXR8/val/accuracy': 0.0036545146626259027,
 'Hernia/CXR8/val/f1': 0.0072824156305506225,
 'Hernia/CXR8/val/roc_auc': 0.8536519906261864,
 'Hernia/CXR8/val/loss': 0.15308212358179019,
 'Hernia/CXR8/val/average': 0.288196306973121,
 'model/all/val/micro_average': 0.3087731302526285,
 'model/all/val/macro_average': 0.3087731302526285,
 'model/all/val/loss': 0.6359101210741468,
 'model/all/all/micro_average': 0.3087731302526285,
 'model/all/all/macro_average': 0.3087731302526285,
 'model/all/all/loss': 0.6359101210741468}
# `d` keeps only the entries whose key contains 'val/roc_auc' (one per task).
d = {k: v for k, v in val_scores.items() if 'val/roc_auc' in k}

In [39]:
# Display the filtered dict: one validation ROC AUC entry per task.
d

{'Atelectasis/CXR8/val/roc_auc': 0.7896770012121855,
 'Cardiomegaly/CXR8/val/roc_auc': 0.8699177596623857,
 'Effusion/CXR8/val/roc_auc': 0.8670489035906388,
 'Infiltration/CXR8/val/roc_auc': 0.6830709248757703,
 'Mass/CXR8/val/roc_auc': 0.8174934868793657,
 'Nodule/CXR8/val/roc_auc': 0.7140461368937955,
 'Pneumonia/CXR8/val/roc_auc': 0.7300347657887276,
 'Pneumothorax/CXR8/val/roc_auc': 0.8721789843640054,
 'Consolidation/CXR8/val/roc_auc': 0.8071199796306774,
 'Edema/CXR8/val/roc_auc': 0.8988796624013068,
 'Emphysema/CXR8/val/roc_auc': 0.8435683634547272,
 'Fibrosis/CXR8/val/roc_auc': 0.7506300421081776,
 'Pleural_Thickening/CXR8/val/roc_auc': 0.7868463209192176,
 'Hernia/CXR8/val/roc_auc': 0.8536519906261864}

In [None]:
# For every label set in the dataset, print its name followed by the number of
# entries equal to 1 divided by the dataset size.
num_examples = len(dataset)
for task_name, task_labels in dataset.Y_dict.items():
    print(task_name)
    positive_count = sum(task_labels == 1).type(torch.DoubleTensor)
    print(positive_count / num_examples)

In [None]:
# Spot-check one dataset example: print the table row for its image name, then its labels.
idx = 5
x, y = dataset[idx]
image_index = x['image_name']
# NOTE(review): `data_df` is not defined in any cell visible here, so on a fresh kernel this
# line raises NameError unless it is defined elsewhere. Presumably it is a table indexed by
# image name -- TODO confirm and define it before this cell.
print(data_df.loc[image_index])
print(y)


# 2.) Determine Drain Prevalence in longitudinal exams

In [4]:
# Directory the image files are read from.
image_dir = '/lfs/1/jdunnmon/data/nih/images/images'

# CSV tables from the by-patient-id split. Each is read with its first column as the
# index and then re-indexed by the 'Image Index' column.
preds_csv = '/lfs/1/gangus/repositories/pytorch-classification/drain_detector/data/chexnet/by-patient-id/split/all_v2.csv'
valid_csv = '/lfs/1/gangus/repositories/pytorch-classification/drain_detector/data/chexnet/by-patient-id/split/valid.csv'

preds_df = pd.read_csv(preds_csv, index_col=0)
preds_df = preds_df.set_index('Image Index')

valid_df = pd.read_csv(valid_csv, index_col=0)
valid_df = valid_df.set_index('Image Index')

In [5]:
# Split the table by Pneumothorax label (1 vs 0).
pos_pneumo = preds_df.loc[(preds_df['Pneumothorax'] == 1)]
neg_pneumo = preds_df.loc[(preds_df['Pneumothorax'] == 0)]

# For each subset, build one summary record per patient:
#   num_images - rows belonging to the patient within the subset
#   num_drains - rows with drain == 1
#   num_normal - rows with drain == 0
#   pairable   - True when the patient has more than one row of each kind
subsets = [('pos_pneumo', pos_pneumo), ('neg_pneumo', neg_pneumo)]
subset_info = defaultdict(list)
for subset_name, subset in subsets:
    patient_id_groups = subset.groupby('Patient ID')
    for patient_id, patient_id_group in tqdm(patient_id_groups):
        drain_flags = patient_id_group['drain']
        # Count each kind once, vectorised; int() keeps plain Python ints in the records.
        num_drains = int((drain_flags == 1).sum())
        num_normal = int((drain_flags == 0).sum())
        subset_info[subset_name].append({
            'patient_id': patient_id,
            'num_images': len(patient_id_group),
            'num_drains': num_drains,
            'num_normal': num_normal,
            # NOTE(review): "> 1" requires at least two images of each kind; if a single
            # drain/no-drain pair should qualify this ought to be ">= 1" -- confirm intent.
            'pairable': num_drains > 1 and num_normal > 1
        })

100%|██████████| 1484/1484 [00:01<00:00, 801.98it/s]
100%|██████████| 30726/30726 [00:31<00:00, 965.75it/s] 


In [20]:
# Aggregate the per-patient records into per-subset totals.
# 'patients' counts every patient; 'pairable', 'num_images', 'num_drains' and
# 'num_normal' are accumulated over pairable patients only.
counts = defaultdict(lambda: defaultdict(int))
for subset_name, patient_rows in subset_info.items():
    subset_counts = counts[subset_name]
    for row in patient_rows:
        subset_counts['patients'] += 1
        if row['pairable']:
            subset_counts['pairable'] += 1
            subset_counts['num_images'] += row['num_images']
            subset_counts['num_drains'] += row['num_drains']
            subset_counts['num_normal'] += row['num_normal']

# Fraction of patients in each subset that are pairable.
for subset_name, _ in subsets:
    subset_counts = counts[subset_name]
    num_pairable = subset_counts['pairable']
    num_patients = subset_counts['patients']
    # A subset with no patients gets 0.0 instead of raising ZeroDivisionError.
    subset_counts['pairable_frac'] = num_pairable / num_patients if num_patients else 0.0
counts

defaultdict(<function __main__.<lambda>()>,
            {'pos_pneumo': defaultdict(int,
                         {'patients': 1484,
                          'pairable': 208,
                          'num_images': 2240,
                          'num_drains': 1490,
                          'num_normal': 750,
                          'pairable_frac': 0.14016172506738545}),
             'neg_pneumo': defaultdict(int,
                         {'patients': 30726,
                          'pairable': 3570,
                          'num_images': 53326,
                          'num_drains': 23042,
                          'num_normal': 30284,
                          'pairable_frac': 0.11618824448349932})})

In [24]:
# Collect, for each subset, the IDs of the patients flagged as pairable
# (in the same order as their summary records).
pairable_patients = defaultdict(list)
for name, patient_rows in subset_info.items():
    pairable_ids = [entry['patient_id'] for entry in patient_rows if entry['pairable']]
    # Only touch the defaultdict when there is something to store, so a subset
    # without pairable patients gets no key.
    if pairable_ids:
        pairable_patients[name].extend(pairable_ids)

In [49]:
# Pick the subset and the position of the pairable patient to inspect.
subset_name = 'neg_pneumo'
idx = 6

pairable_patient = pairable_patients[subset_name][idx]

# Keep this patient's rows whose Pneumothorax label matches the chosen subset
# (1 for 'pos_pneumo', 0 otherwise), then report how many there are.
wanted_label = 1 if subset_name == 'pos_pneumo' else 0
is_patient = preds_df['Patient ID'] == pairable_patient
has_wanted_label = preds_df['Pneumothorax'] == wanted_label
target_rows = preds_df.loc[is_patient & has_wanted_label]
print(len(target_rows))

5


In [None]:
# Re-select the chosen patient's rows with the Pneumothorax label of the chosen subset,
# so this cell only needs `pairable_patient` and `subset_name` from the selection cell.
target_rows = preds_df.loc[(preds_df['Patient ID'] == pairable_patient)]
target_rows = target_rows.loc[target_rows['Pneumothorax'] == (1 if subset_name == 'pos_pneumo' else 0)]

# y_true_rows = valid_df.loc[valid_df['Patient ID'] == target['patient_id']]

# Grid of three images per row. Keep at least one row so that an empty selection
# yields a blank figure instead of plt.subplots raising on zero rows.
num_rows = max(1, math.ceil(len(target_rows) / 3))

# Size this figure directly rather than changing plt.rcParams for every later figure.
# squeeze=False always returns a 2-D array of axes, which is flattened for indexing.
fig, axs = plt.subplots(num_rows, 3, figsize=(30, 10 * num_rows), squeeze=False)
axs = axs.ravel()

# Drain labels of the displayed rows, in display order.
print(np.array(target_rows['drain']))
# print(list(y_true_rows['drain']))

for i, ax in enumerate(axs):
    if i > len(target_rows) - 1:
        # More axes than images: hide the unused ones.
        ax.set_axis_off()
    else:
        row = target_rows.iloc[i]
        # The frame is indexed by image file name, so row.name is the file to open.
        image_path = osp.join(image_dir, row.name)
        # Close the file handle as soon as the pixels have been handed to matplotlib.
        with Image.open(image_path) as img:
            ax.imshow(img, cmap=plt.cm.bone)