In [1]:
import csv
import os, sys
import random

import numpy as np
import joblib
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from munch import Munch
from PIL import Image
from tqdm import tqdm_notebook as tqdm

import datasets
import fcn_resnet

In [2]:
# Input locations, given relative to the directory the notebook runs from.
data_root = '../data/ddsm_raw/'
image_list_dir = '../data/ddsm_raw_image_lists/'
mask_root = '../data/ddsm_masks/3class'

# Training config of the model under study and the checkpoint epoch to load.
config_path = '../training/pretrained/resnet152_3class/config.yml'
epoch = 5

# Dataset split to process; its image list lives at "<image_list_dir>/<split>.txt".
split = 'test'
image_list_path = os.path.join(image_list_dir, split + '.txt')

In [3]:
# Parse the YAML training config into a Munch so that settings can be read
# with attribute access (e.g. cfg.arch.model).
with open(config_path) as config_file:
    cfg = Munch.fromYAML(config_file)

In [4]:
# Build the network named in the config and remember the layer whose
# activations are captured later through a forward hook.
print("=> creating model '{}'".format(cfg.arch.model))
if cfg.arch.model == 'resnet152':
    model = fcn_resnet.resnet152(num_classes=cfg.arch.num_classes)
    features_layer = model.layer4
else:
    # Fail with an actionable message instead of a bare, message-less Exception.
    raise ValueError(
        "unsupported architecture '{}': this notebook only handles 'resnet152'".format(cfg.arch.model)
    )

# The model is wrapped in DataParallel and moved to the GPU, so a CUDA device
# is required from this point on; the bare network stays reachable as model.module.
model = torch.nn.DataParallel(model).cuda()

=> creating model 'resnet152'


In [5]:
# Derive the checkpoint path for `epoch` from the path recorded in the config.
# The configured path is expected to end in "<8-digit epoch>.pth.tar", i.e. the
# epoch field occupies characters [-16:-8]. The new number is spliced in by
# position: str.replace() would rewrite *every* occurrence of those eight
# characters and corrupt the path if they also appeared in a directory name.
configured_resume = cfg.training.resume
resume_path = configured_resume[:-16] + '{:08}'.format(epoch) + configured_resume[-8:]
resume_path = os.path.join('../training', resume_path)
if os.path.isfile(resume_path):
    print("=> loading checkpoint '{}'".format(resume_path))
    # torch.load unpickles the file: only point this at trusted checkpoints.
    checkpoint = torch.load(resume_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    # The notebook only runs inference, so switch to evaluation mode.
    model.eval()
    print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch']))
else:
    # FileNotFoundError is still an Exception, so existing handlers keep working.
    raise FileNotFoundError("=> no checkpoint found at '{}'".format(resume_path))

=> loading checkpoint '../training/pretrained/resnet152_3class/checkpoint_00000005.pth.tar'
=> loaded checkpoint '../training/pretrained/resnet152_3class/checkpoint_00000005.pth.tar' (epoch 5)


In [6]:
# Call the network's own `surgery()` method; `.module` is the model wrapped by
# DataParallel. NOTE(review): presumably this rewires the classifier head for
# dense (fully convolutional) prediction -- confirm in fcn_resnet.
model.module.surgery()

In [7]:
# Input pipeline: convert each image to a tensor, then normalise per channel.
# The mean/std values match the commonly used ImageNet statistics.
patch_size = 227
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
preprocess = transforms.Compose([transforms.ToTensor(), normalize])
dataset = datasets.DDSM(data_root, image_list_path, split, patch_size, preprocess)

In [8]:
# extract features and max activations
features = []
def feature_hook(module, input, output):
    features.extend(output.data.cpu().numpy())
features_layer._forward_hooks.clear()
features_layer.register_forward_hook(feature_hook)
prob_maps = []
max_class_probs = []
count = 0
skipcount = 0
extracted_dataset = []
with torch.no_grad():
    for im_name, image in tqdm(dataset):
        # skipping normal cases for text explanation benchmark
        if im_name.startswith('normal'):
            skipcount += 1
            continue
        try:
            count += 1
            input = image.unsqueeze(0)
            input = input.cuda()
            output = model(input)
            output = output.transpose(1, 3).contiguous()
            size = output.size()[:3]
            output = output.view(-1, output.size(3))
            prob = nn.Softmax(dim=1)(output)
            prob = prob.view(size[0], size[1], size[2], -1)
            prob = prob.transpose(1, 3)
            prob = prob.cpu().numpy()
            prob_map = prob[0]
            prob_maps.append(prob_map)
            max_class_probs.append(prob_map.max(axis=(1, 2)))
            extracted_dataset.append((im_name, image))
        except:
            skipcount += 1
            continue
print(f"Skipped {skipcount} images in dataset corresponding to normal cases")
max_class_probs = np.array(max_class_probs)
image_indices = np.argsort(-max_class_probs, axis=0)

HBox(children=(IntProgress(value=0, max=1044), HTML(value='')))


Skipped 276 images in dataset corresponding to normal cases


In [9]:
# Rank the processed images per class: sorting the negated peak probabilities
# along axis 0 puts, in each class column, the most confident image first.
max_class_probs = np.array(max_class_probs)
image_indices = (-max_class_probs).argsort(axis=0)

In [10]:
# Rank the feature units of every image, separately for each class, by their
# peak activation scaled with that class's weight.
num_top_units = 20

# Second-to-last parameter tensor of the model, reduced to 2-D by dropping its
# two trailing singleton axes. NOTE(review): presumably the weight of the final
# 1x1 classifier convolution (the last parameter being its bias) -- confirm.
params = list(model.parameters())
weight_softmax = params[-2].data.cpu().numpy().squeeze(axis=(2, 3))

# Peak value of each unit over the two trailing axes of its feature map,
# with an extra axis inserted so it broadcasts against the per-class weights.
unit_peaks = [fmap.max(axis=(1, 2)) for fmap in features]
max_activations = np.array(unit_peaks)[:, np.newaxis]
weighted_max_activations = max_activations * weight_softmax

# Unit indices in descending order of weighted activation; keep the best few.
ranked_units = np.argsort(-weighted_max_activations, axis=2)
unit_indices = ranked_units[:, :, :num_top_units]

In [11]:
# Load the lookup tables used to build the benchmark:
#   meta_data   -- its 'meta' entry is indexed by "<prefix>s/<image name>" to
#                  fetch the ground-truth report of an image
#   unit_labels -- maps "unit_<1-based index>" to that unit's text label
# joblib.load unpickles its input, so only load files from a trusted source.
# NOTE(review): these paths start at 'data/' while the other inputs of this
# notebook use '../data/' -- confirm that this is intentional.
meta_data = joblib.load('data/ddsm_meta_data.jbl')
unit_labels = joblib.load('data/cleaned_unit_labels.jbl')

In [12]:
# Class column used to rank images and units when building the benchmark.
class_index = 2  # cancer

In [13]:
# DeepMiner Benchmark: Text only explanations,
# question is "Does this explanation indicate cancer or benign?"

# Respondents to this benchmark are expected to label cancer/benign
# for each case based only on the unit explanations.

# Number of unit explanations shown per case. It has its own name so that the
# `num_top_units` used to build `unit_indices` is not overwritten by this cell.
num_explanation_units = 3

explanations_benign = []
explanations_cancer = []
diagnosis = []

# Images are visited from the highest to the lowest peak probability of
# `class_index`; `case_num` is that rank and is used as the case id everywhere.
for case_num, image_index in enumerate(image_indices[:, class_index]):
    image_name = extracted_dataset[image_index][0]
    # Text before the first '-' with its last three characters removed; it is
    # compared against 'cancer' below and also names the metadata folder.
    prefix = image_name.split('-')[0][:-3]
    try:
        gt_report = meta_data['meta'][f'{prefix}s/{image_name}']
    except KeyError:
        # No ground-truth report for this image: leave it out of the benchmark.
        continue
    # NOTE(review): gt_report[0][-3] is written to the 'GT Diagnosis' column of
    # the answer key; the layout of the report is not visible here -- confirm.
    diagnosis.append((case_num, image_name, gt_report, max_class_probs[image_index][class_index], gt_report[0][-3]))

    # Walk the units in ranked order and keep the labels of the first
    # `num_explanation_units` units that have an entry in `unit_labels`.
    deepminer_explanation = []
    for unit_index in unit_indices[image_index][class_index]:
        unit_name = 'unit_' + str(unit_index + 1)  # unit_labels keys are 1-based
        try:
            unit_label = unit_labels[unit_name]
        except KeyError:
            continue
        deepminer_explanation.append(unit_label)
        if len(deepminer_explanation) == num_explanation_units:
            break

    # Every prefix other than 'cancer' is filed under benign.
    if prefix == 'cancer':
        explanations_cancer.append((case_num, deepminer_explanation))
    else:
        explanations_benign.append((case_num, deepminer_explanation))

print(f"total questions: {len(explanations_cancer) + len(explanations_benign)}, total answers: {len(diagnosis)}")

total questions: 416, total answers: 416


In [14]:
# Generate csv for entering text-only benchmark responses
import csv

# Expects a list of (case_fname, answer_value) tuples
def write_csv(csv_fname, header, row_list):
    """Write a header row followed by data rows to ``<csv_fname>.csv``.

    Args:
        csv_fname: Output path without the ``.csv`` extension (it is appended).
        header: Sequence of column names, written as the first row.
        row_list: Iterable of rows, each a sequence of cell values.
    """
    # newline='' hands line-ending control to the csv module; without it,
    # platforms that translate "\n" end up with a blank line after every row.
    with open(csv_fname + '.csv', mode='w', newline='') as csv_file:
        writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(header)
        writer.writerows(row_list)

In [15]:
# Randomly sample same number of cases per class
num_cases_per_class = 165

# A private, seeded generator makes the benchmark sheet reproducible across
# "Restart & Run All" without touching the global random state.
sampling_seed = 0
sampling_rng = random.Random(sampling_seed)

# Never ask for more cases than the smaller class can supply, so that the two
# classes really contribute the same number of cases.
cases_per_class = min(num_cases_per_class, len(explanations_cancer), len(explanations_benign))

# sample() draws without replacement and leaves the source lists untouched,
# which keeps this cell idempotent when it is re-run.
sampled_cancer = sampling_rng.sample(explanations_cancer, cases_per_class)
sampled_benign = sampling_rng.sample(explanations_benign, cases_per_class)

# Mix both classes so the position in the sheet does not reveal the label.
explanations = sampled_cancer + sampled_benign
sampling_rng.shuffle(explanations)

In [16]:
# Save to csv: the question sheet handed to respondents and its answer key.
# Each record is paired with the column names in order; DictWriter leaves any
# column without a value empty, which yields the blank answer column of the
# question sheet.
csv_outputs = (
    (
        'explanations-benchmark.csv',
        ['Case #', 'DeepMiner Explanation', 'Cancer or Benign (C/B)?'],
        explanations,
    ),
    (
        'explanations-benchmark-answer-key.csv',
        ['Case #', 'Image Name', 'GT Report', 'DeepMiner Cancer Likelihood Prediction', 'GT Diagnosis'],
        diagnosis,
    ),
)
for csv_path, fieldnames, records in csv_outputs:
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for record in records:
            writer.writerow(dict(zip(fieldnames, record)))

In [2]:
# Score responses
# Case id -> respondent's answer. Answers are typed by hand, so surrounding
# whitespace and letter case are normalised, and a missing answer is treated
# as an empty (hence wrong) response instead of crashing on None.
with open('explanations-benchmark-responses.csv', newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    responses = {
        row['Case #'].strip(): (row['Cancer or Benign (C/B)?'] or '').strip().upper()
        for row in reader
    }
# Case id -> ground truth, derived from the image name: names starting with
# 'cancer' are 'C', everything else is 'B'.
with open('explanations-benchmark-answer-key.csv', newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    labels = {
        row['Case #'].strip(): 'C' if row['Image Name'].startswith('cancer') else 'B'
        for row in reader
    }
# Number of correct responses. A response whose case id is absent from the
# answer key raises KeyError on purpose: it means the two files do not match.
np.sum([labels[case_num] == response for case_num, response in responses.items()])

182