# 12/5/19

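Checking whether we have actually reproduced CheXNet: load the best checkpoint from our Emmental training run and run it on the CXR8 validation split.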

In [3]:
%load_ext autoreload
%autoreload 2

import math
import os
import os.path as osp
import json
from functools import partial
from collections import defaultdict
os.chdir('/lfs/1/gangus/repositories/pytorch-classification/Emmental-ChexNet')

import torch
import torch.nn as nn
import torch.nn.functional as F
import sklearn.metrics as skl
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

import emmental
from emmental import Meta
from emmental.data import EmmentalDataLoader
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask

from dataset import CXR8Dataset
from task import get_task
from task_config import CXR8_TASK_NAMES
from transforms import get_data_transforms

In [None]:
# Log directory of the Emmental run under evaluation; the checkpoint path in
# `model_config` below is built by joining a file name onto this directory.
emmental_dir = '/lfs/1/gangus/repositories/pytorch-classification/Emmental-ChexNet/logs/2019_12_04/14_31_34/99ae8f45'

In [None]:
def ce_loss(task_name, immediate_ouput_dict, Y, active):
    """Cross-entropy loss for one task's prediction head.

    Args:
        task_name: Task identifier; the head's outputs are looked up under the
            key ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: Mapping from module name to that module's
            outputs; element ``[0]`` of the head's entry is used as the logits.
        Y: Label tensor. It is flattened and shifted down by one before being
            handed to ``F.cross_entropy`` (presumably labels are 1-indexed —
            confirm against the dataset).
        active: Index/mask applied to both logits and labels to select the
            examples that contribute to the loss.

    Returns:
        The value of ``F.cross_entropy`` over the selected examples.
    """
    head_outputs = immediate_ouput_dict[f"{task_name}_pred_head"]
    selected_logits = head_outputs[0][active]
    zero_based_labels = Y.view(-1) - 1
    selected_labels = zero_based_labels[active]
    return F.cross_entropy(selected_logits, selected_labels)

def output(task_name, immediate_ouput_dict):
    """Turn one task's prediction-head logits into probabilities.

    Args:
        task_name: Task identifier; the head's outputs are looked up under the
            key ``f"{task_name}_pred_head"``.
        immediate_ouput_dict: Mapping from module name to that module's
            outputs; element ``[0]`` of the head's entry is used as the logits.

    Returns:
        Softmax of the logits taken over the last dimension.
    """
    head_key = f"{task_name}_pred_head"
    logits = immediate_ouput_dict[head_key][0]
    return F.softmax(logits, dim=-1)

DATA_NAME = 'CXR8'

image_path = '/lfs/1/jdunnmon/data/nih/images/images'
data_path = '/dfs/scratch1/senwu/mmtl/emmental-tutorials/chexnet/data/nih_labels.csv'

task_names = CXR8_TASK_NAMES
# Each task reads the label stored under its own name.
task_to_label_dict = {t: t for t in task_names}
add_binary_triage_label = False
batch_size = 16

emmental.init()
split = 'val'

model_config = {
    'model_path': osp.join(emmental_dir, 'best_model_model_all_val_loss.pth.pth'),
    'device': 0,
    'dataparallel': True
}
# Register the settings with Emmental's global config. Without this the dict
# above is never used: the checkpoint-loading cell reads
# Meta.config["model_config"]["model_path"], which otherwise still holds the
# value set by emmental.init(), so the trained weights are never loaded.
# This must happen before EmmentalModel is built below.
Meta.update_config(config={'model_config': model_config})

cxr8_transform = get_data_transforms(DATA_NAME)
dataset = CXR8Dataset(
    name=DATA_NAME,
    path_to_images=image_path,
    path_to_labels=data_path,
    split=split,
    transform=cxr8_transform[split],
    sample=0,
    seed=1701,
    add_binary_triage_label=add_binary_triage_label,
)

dataloader = EmmentalDataLoader(
    task_to_label_dict=task_to_label_dict,
    dataset=dataset,
    split=split,
    # Only shuffle when iterating the training split.
    shuffle=(split == "train"),
    batch_size=batch_size,
    num_workers=16,
)

tasks = get_task(task_names)
model = EmmentalModel(name=DATA_NAME, tasks=tasks)

In [None]:
# Restore trained weights when the Emmental config names a checkpoint;
# a falsy model_path leaves the freshly constructed model untouched.
checkpoint_path = Meta.config["model_config"]["model_path"]
if checkpoint_path:
    model.load(checkpoint_path)

In [None]:
# Run the model over every batch of the dataloader. The result is bound to
# `predictions` rather than `output` so it does not overwrite the `output()`
# helper function defined in an earlier cell.
predictions = model.predict(dataloader)

In [None]:
# For every task, print its name followed by the fraction of dataset entries
# whose label equals 1.
for task_name, task_labels in dataset.Y_dict.items():
    num_label_one = sum(task_labels == 1)
    print(task_name)
    print(num_label_one.type(torch.DoubleTensor) / len(dataset))

In [None]:
# Spot-check one example: fetch sample `idx` from the dataset, then print the
# `data_df` row for that image alongside the label(s) the dataset returned.
# NOTE(review): `data_df` is not defined anywhere in this notebook as shown, so
# this cell raises NameError on a fresh kernel. Presumably it was the label
# table read from `data_path` and indexed by image name — confirm and add the
# load explicitly.
idx = 5
x, y = dataset[idx]
image_index = x['image_name']
print(data_df.loc[image_index])
print(y)


# 2.) Determine Drain Prevalence in longitudinal exams

In [4]:
image_dir = '/lfs/1/jdunnmon/data/nih/images/images'

# Both tables live in the same split directory; each is re-indexed by image
# file name after loading.
split_dir = '/lfs/1/gangus/repositories/pytorch-classification/drain_detector/data/chexnet/by-patient-id/split'
preds_df = pd.read_csv(osp.join(split_dir, 'all_v2.csv'), index_col=0).set_index('Image Index')
valid_df = pd.read_csv(osp.join(split_dir, 'valid.csv'), index_col=0).set_index('Image Index')

In [5]:
# Split the table by the Pneumothorax flag.
pos_pneumo = preds_df.loc[preds_df['Pneumothorax'] == 1]
neg_pneumo = preds_df.loc[preds_df['Pneumothorax'] == 0]

subsets = [('pos_pneumo', pos_pneumo), ('neg_pneumo', neg_pneumo)]

# For each subset, collect one summary record per patient: how many images
# they have and how many of those are flagged drain == 1 / drain == 0.
subset_info = defaultdict(list)
for subset_name, subset in subsets:
    patient_id_groups = subset.groupby('Patient ID')
    for patient_id, patient_id_group in tqdm(patient_id_groups):
        drain_flags = patient_id_group['drain']
        # Vectorised counts, computed once each (the built-in sum() would
        # iterate the Series element by element in Python).
        num_drains = int((drain_flags == 1).sum())
        num_normal = int((drain_flags == 0).sum())
        subset_info[subset_name].append({
            'patient_id': patient_id,
            'num_images': len(patient_id_group),
            'num_drains': num_drains,
            'num_normal': num_normal,
            # A patient counts as pairable only with strictly more than one
            # image of each kind.
            'pairable': num_drains > 1 and num_normal > 1,
        })

100%|██████████| 1484/1484 [00:01<00:00, 801.98it/s]
100%|██████████| 30726/30726 [00:31<00:00, 965.75it/s] 


In [20]:
# Aggregate the per-patient records into per-subset totals. Every patient is
# counted under 'patients'; the image/drain/normal totals only include
# patients flagged as pairable.
counts = defaultdict(lambda: defaultdict(int))
summed_fields = ('num_images', 'num_drains', 'num_normal')
for subset, info in subset_info.items():
    for row in info:
        counts[subset]['patients'] += 1
        if not row['pairable']:
            continue
        counts[subset]['pairable'] += 1
        for field in summed_fields:
            counts[subset][field] += row[field]

# Fraction of each subset's patients that are pairable.
for subset_name, _ in subsets:
    subset_counts = counts[subset_name]
    subset_counts['pairable_frac'] = subset_counts['pairable'] / subset_counts['patients']
counts

defaultdict(<function __main__.<lambda>()>,
            {'pos_pneumo': defaultdict(int,
                         {'patients': 1484,
                          'pairable': 208,
                          'num_images': 2240,
                          'num_drains': 1490,
                          'num_normal': 750,
                          'pairable_frac': 0.14016172506738545}),
             'neg_pneumo': defaultdict(int,
                         {'patients': 30726,
                          'pairable': 3570,
                          'num_images': 53326,
                          'num_drains': 23042,
                          'num_normal': 30284,
                          'pairable_frac': 0.11618824448349932})})

In [24]:
# Per subset, the ids of the patients flagged as pairable, in the order they
# appear in subset_info.
pairable_patients = defaultdict(list)
pairable_ids = (
    (subset, row['patient_id'])
    for subset, info in subset_info.items()
    for row in info
    if row['pairable']
)
for subset, patient_id in pairable_ids:
    pairable_patients[subset].append(patient_id)

In [49]:
subset_name = 'neg_pneumo'
idx = 6

# Pick the idx-th pairable patient of the chosen subset and keep only that
# patient's rows whose Pneumothorax flag matches the subset.
pairable_patient = pairable_patients[subset_name][idx]
pneumothorax_label = 1 if subset_name == 'pos_pneumo' else 0
is_target_row = (
    (preds_df['Patient ID'] == pairable_patient)
    & (preds_df['Pneumothorax'] == pneumothorax_label)
)
target_rows = preds_df.loc[is_target_row]
print(len(target_rows))

5


In [None]:
# Rows for the selected patient, restricted to the Pneumothorax value of the
# chosen subset.
target_rows = preds_df.loc[(preds_df['Patient ID'] == pairable_patient)]
target_rows = target_rows.loc[target_rows['Pneumothorax'] == (1 if subset_name == 'pos_pneumo' else 0)]

# y_true_rows = valid_df.loc[valid_df['Patient ID'] == target['patient_id']]

# Lay the images out three per row. Keep at least one row so that an empty
# selection yields a blank grid instead of an error from plt.subplots.
n_cols = 3
n_rows = max(1, math.ceil(len(target_rows) / n_cols))
plt.rcParams['figure.figsize'] = [30, 10 * n_rows]

# squeeze=False always returns a 2-D array of axes, so it can be flattened
# uniformly whether the grid has one row or several.
fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
axs = axs.ravel()

print(np.array(target_rows['drain']))
# print(list(y_true_rows['drain']))

for i, ax in enumerate(axs):
    if i >= len(target_rows):
        # Hide the unused cells of the last grid row.
        ax.set_axis_off()
    else:
        row = target_rows.iloc[i]
        # preds_df is indexed by image file name, so row.name is the file.
        image_path = osp.join(image_dir, row.name)
        # Context manager guarantees the image file handle is released.
        with Image.open(image_path) as img:
            ax.imshow(img, cmap=plt.cm.bone)