In [9]:
cd gen/ddsm-visual-primitives/deepminer

/Users/gen/donuts/mit_backup/gen/ddsm-visual-primitives/deepminer


In [11]:
import os, sys

import matplotlib.pyplot as plt
import numpy as np
import joblib
import torch
#import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
#from munch import Munch
from PIL import Image
from torch.autograd import Variable
sys.path.append('heatmaps')
import models.inception
import models.resnet

ModuleNotFoundError: No module named 'models.inception'

In [2]:
# Load the pre-computed lookup tables used when building the reports.
# Both locations can be overridden through environment variables, like the
# other paths configured in this notebook; the defaults are the original
# hard-coded locations.
# NOTE(review): joblib.load unpickles arbitrary objects -- only point these
# variables at files you trust.
meta_data_path = os.environ.get('META_DATA_PATH', '/home/gen/ddsm_meta_data.jbl')
unit_labels_path = os.environ.get('UNIT_LABELS_PATH', '/home/gen/cleaned_unit_labels.jbl')

# meta_data['meta'] is later indexed by image name to fetch ground-truth report entries.
meta_data = joblib.load(meta_data_path)
# unit_labels is later looked up with keys of the form 'unit_XXXX'.
unit_labels = joblib.load(unit_labels_path)

In [3]:
# Munch is required to parse the training config below, but its import in the
# imports cell is commented out; bring it into scope here so the cell runs on
# a fresh kernel.
from munch import Munch

# Every setting below can be overridden through an environment variable; the
# second argument of each os.environ.get call is the default.
data_root = os.environ.get('DATA_ROOT', '/data/vision/torralba/scratch2/jimmywu/ddsm-visual-primitives/data/raw/')
image_list_dir = os.environ.get('IMAGE_LIST_DIR', '/data/vision/torralba/scratch2/jimmywu/ddsm-visual-primitives/data/raw_image_lists/')
config_path = os.environ.get('CONFIG_PATH', '/data/vision/torralba/deeprobotics/deeprobotics/ddsm-visual-primitives/ddsm-visual-primitives/training/logs/normal_benign_cancer/2018-02-18_02-08-28.397731_resnet152_pretrained_lr-0.0001_decay-4/config.yml')
# Epoch number substituted into the checkpoint file name when the model is loaded.
epoch = int(os.environ.get('EPOCH', '5'))
# Index of the class whose probability map and classifier weights are analysed.
class_index = int(os.environ.get('CLASS_INDEX', '2'))
split = os.environ.get('SPLIT', 'val')

# The image list for a split lives at <image_list_dir>/<split>.txt.
image_list_path = os.path.join(image_list_dir, '{}.txt'.format(split))

with open(config_path, 'r') as f:
    cfg = Munch.fromYAML(f)

# A mask directory is only configured for the three-class setup. Keep the name
# bound (to None) in every other case so that code using it can test for a
# missing directory instead of hitting a NameError.
mask_root = None
if cfg.arch.num_classes == 3:
    mask_root = '/data/vision/torralba/scratch2/jimmywu/ddsm-visual-primitives/data/masks/threeclass'

In [4]:
def surgery(model, arch, num_classes):
    """Replace the linear classifier head with an equivalent 1x1 convolution, in place.

    The weights of ``model.module.fc`` are reshaped to
    ``(num_classes, 2048, 1, 1)`` and loaded into a freshly created
    ``nn.Conv2d`` that takes the place of the original layer. The new layer is
    moved to the GPU before returning.

    Args:
        model: wrapper exposing the network as ``model.module`` (the notebook
            passes a ``torch.nn.DataParallel`` instance) whose ``fc`` attribute
            is the classifier to convert.
        arch: architecture name; only ``'inception_v3'`` and ``'resnet152'``
            are supported.
        num_classes: number of outputs of the classifier.

    Raises:
        ValueError: if ``arch`` is not a supported architecture. ``ValueError``
            derives from ``Exception``, so callers catching the previously
            raised bare ``Exception`` keep working.
    """
    supported_archs = ('inception_v3', 'resnet152')
    if arch not in supported_archs:
        raise ValueError(
            "surgery() does not support architecture '{}'; expected one of {}".format(
                arch, ', '.join(supported_archs)))

    # Both supported architectures are built with a 2048-input fc layer.
    num_features = 2048

    # Move the old head to the CPU before snapshotting the weights.
    model.module.fc.cpu()
    state_dict = model.state_dict()
    # A linear weight of shape (num_classes, 2048) becomes a 1x1 conv kernel.
    state_dict['module.fc.weight'] = state_dict['module.fc.weight'].view(
        num_classes, num_features, 1, 1)
    model.module.fc = nn.Conv2d(num_features, num_classes, kernel_size=(1, 1))
    # Loading the full state dict fills the new conv with the reshaped weight
    # (and the unchanged bias).
    model.load_state_dict(state_dict)
    model.module.fc.cuda()
class DDSM(torch.utils.data.Dataset):
    """Dataset of images whose names are listed, one per line, in a text file.

    Indexing returns the pair ``(image_name, image)`` where ``image`` is the
    rescaled, channel-replicated and transformed picture.
    """

    def __init__(self, root, image_list_path, split, patch_size, transform):
        """Read the image list and store the preprocessing settings.

        Args:
            root: directory the image names are relative to.
            image_list_path: text file containing one image name per line.
            split: accepted for interface compatibility; not used by this class.
            patch_size: the shorter image side is rescaled to ``4 * patch_size``.
            transform: callable applied to the NumPy image before it is returned.
        """
        self.root = root
        with open(image_list_path, 'r') as f:
            # Build a real list: on Python 3 ``map`` returns a one-shot
            # iterator that supports neither len() nor indexing, which broke
            # __len__ and __getitem__. Blank lines are skipped so that a
            # trailing newline does not create an empty image name.
            self.image_names = [line.strip() for line in f if line.strip()]
        self.patch_size = patch_size
        self.transform = transform

    def __len__(self):
        """Return the number of listed images."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load, rescale and transform the image at position ``idx``.

        Returns:
            Tuple ``(image_name, image)``.
        """
        image_name = self.image_names[idx]
        image = Image.open(os.path.join(self.root, image_name))
        # Scale so that the shorter side becomes 4 * patch_size pixels,
        # keeping the aspect ratio.
        min_dim = min(image.size)
        ratio = float(4 * self.patch_size) / min_dim
        new_size = (int(ratio * image.size[0]), int(ratio * image.size[1]))
        image = image.resize(new_size, resample=Image.BILINEAR)
        image = np.asarray(image)
        # Replicate the pixel values along a new trailing axis of length 3.
        # NOTE(review): this assumes a single-channel (2-D) image -- confirm
        # that no multi-channel files appear in the image lists.
        image = np.broadcast_to(np.expand_dims(image, 2), image.shape + (3,))
        image = self.transform(image)
        return image_name, image

In [5]:
# Lower-case, non-dunder callables exposed by the local `models` package.
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

# Build the backbone with a fresh linear head sized for this task and remember
# the layer whose activations are captured later on.
print("=> creating model '{}'".format(cfg.arch.model))
if cfg.arch.model == 'inception_v3':
    model = models.inception.inception_v3(use_avgpool=False, transform_input=True)
    model.aux_logits = False
    model.fc = nn.Linear(2048, cfg.arch.num_classes)
    features_layer = model.Mixed_7c
elif cfg.arch.model == 'resnet152':
    model = models.resnet.resnet152(use_avgpool=False)
    model.fc = nn.Linear(2048, cfg.arch.num_classes)
    features_layer = model.layer4
else:
    raise ValueError(
        "unsupported architecture '{}'; expected 'inception_v3' or 'resnet152'".format(
            cfg.arch.model))

model = torch.nn.DataParallel(model).cuda()
#cudnn.benchmark = True

# The configured checkpoint path ends in '<8-digit epoch>.pth.tar'. Swap only
# that 8-character field; str.replace would also rewrite any other occurrence
# of the same digits elsewhere in the path.
resume_template = cfg.training.resume
resume_path = resume_template[:-16] + '{:08}'.format(epoch) + resume_template[-8:]
if os.path.isfile(resume_path):
    print("=> loading checkpoint '{}'".format(resume_path))
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(resume_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch']))
else:
    # Execution continues with the freshly initialised head; results computed
    # from it are not meaningful.
    print("=> no checkpoint found at '{}'".format(resume_path))

# Always run in inference mode, including when no checkpoint was found
# (previously eval() was skipped in that case).
model.eval()

# Turn the linear head into a 1x1 convolution.
surgery(model, cfg.arch.model, cfg.arch.num_classes)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
patch_size = 299 if cfg.arch.model == 'inception_v3' else 227
dataset = DDSM(data_root, image_list_path, split, patch_size, transforms.Compose([
    transforms.ToTensor(),
    normalize,
]))

=> creating model 'resnet152'
=> loading checkpoint '/data/vision/torralba/scratch2/jimmywu/ddsm-visual-primitives/checkpoints/normal_benign_cancer/2018-02-18_02-08-28.397731_resnet152_pretrained_lr-0.0001_decay-4/checkpoint_00000005.pth.tar'
=> loaded checkpoint '/data/vision/torralba/scratch2/jimmywu/ddsm-visual-primitives/checkpoints/normal_benign_cancer/2018-02-18_02-08-28.397731_resnet152_pretrained_lr-0.0001_decay-4/checkpoint_00000005.pth.tar' (epoch 5)


In [6]:
# extract features and max activations
features = []
def feature_hook(module, input, output):
    """Forward hook: append each sample's activation array to `features`."""
    features.extend(output.data.cpu().numpy())
# Drop hooks left over from earlier runs of this cell so that activations are
# not recorded twice. NOTE(review): `_forward_hooks` is a private attribute.
features_layer._forward_hooks.clear()
features_layer.register_forward_hook(feature_hook)
prob_maps = []
max_class_probs = []
# Softmax over the class axis of the flattened (positions, classes) matrix.
# Passing dim=1 explicitly matches what the implicit choice resolved to for a
# 2-D input and avoids the deprecation warning; the module is built once
# instead of once per image.
class_softmax = nn.Softmax(dim=1)
for _, image in dataset:
    # NOTE(review): `volatile=True` is the legacy way of disabling autograd
    # bookkeeping; confirm it is still honoured by the installed PyTorch.
    input_var = Variable(image.unsqueeze(0), volatile=True)
    output = model(input_var)
    # Move the class axis last and flatten every position into a row.
    output = output.transpose(1, 3).contiguous()
    size = output.size()[:3]
    output = output.view(-1, output.size(3))
    prob = class_softmax(output)
    # Restore the original axis order.
    prob = prob.view(size[0], size[1], size[2], -1)
    prob = prob.transpose(1, 3)
    prob = prob.data.cpu().numpy()
    # Probability map of the class of interest for the single image in the batch.
    prob_map = prob[0][class_index]
    prob_maps.append(prob_map)
    max_class_probs.append(prob_map.max())
max_class_probs = np.array(max_class_probs)
# Image indices ordered by decreasing peak probability of the class.
image_indices = np.argsort(-max_class_probs)

  from ipykernel import kernelapp as app


In [10]:
num_units = [1, 2, 4, 8, 20]
predicted_reports = {}

# The ranking of units does not depend on how many of them are kept, so it is
# computed once instead of once per entry of `num_units`.
# Per-image, per-unit maximum over the two trailing axes of each feature map.
max_activations = np.array([feature_map.max(axis=(1, 2)) for feature_map in features])
params = list(model.parameters())
# NOTE(review): params[-2] is assumed to be the weight of the final 1x1 conv
# classifier (with its bias as the last parameter) -- confirm for new models.
weight_softmax = params[-2].data.cpu().numpy().squeeze(3).squeeze(2)
weighted_max_activations = max_activations * weight_softmax[class_index, :]
# Units of every image sorted by decreasing weighted activation.
ranked_unit_indices = np.argsort(-weighted_max_activations, axis=1)

for num_top_units in num_units:
    unit_indices = ranked_unit_indices[:, :num_top_units]

    # reports
    for image_index in image_indices:
        # Only the name is needed; reading it from the list avoids decoding
        # and resizing the image itself.
        image_name = dataset.image_names[image_index]
        try:
            gt_report = [g[1] for g in meta_data['meta'][image_name]]
        except (KeyError, IndexError, TypeError):
            # this image didn't have a (usable) GT report
            continue
        # Labels of the top units, skipping units without a label. Unit names
        # are 1-based.
        unit_names = ['unit_{:04}'.format(unit_index + 1)
                      for unit_index in unit_indices[image_index]]
        unit_report = [unit_labels[name] for name in unit_names if name in unit_labels]
        # One [ground truth, prediction] entry per value in `num_units`.
        predicted_reports.setdefault(image_name, []).append([gt_report, unit_report])

In [12]:
# Name the output after the evaluated split so that results for different
# splits do not overwrite each other (the default split 'val' keeps the
# original file name).
joblib.dump(predicted_reports, 'predicted_reports_{}.jbl'.format(split))

['predicted_reports_val.jbl']

In [None]:
# Pretty Print Reports
# TEST!! only the top-ranked image is shown; raise this to preview more.
num_preview_images = 1
# Number of unit heatmaps drawn per image (the figure originally had 8 panels).
num_display_units = 8

# Slice rather than index: `image_indices[0]` is a scalar and cannot be iterated.
for image_index in image_indices[:num_preview_images]:
    image_name, image = dataset[image_index]

    # Optional mask, resized to the image's (width, height).
    mask = None
    if mask_root is not None:
        mask_path = os.path.join(mask_root, image_name.replace('jpg', 'png'))
        if os.path.exists(mask_path):
            mask = Image.open(mask_path).resize((image.size(2), image.size(1)))
            mask = np.asarray(mask)

    # Undo the per-channel normalisation in place, then convert to an
    # (H, W, C) array for plotting.
    for t, m, s in zip(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):
        t.mul_(s).add_(m)
    image.clamp_(0, 1)
    image = image.numpy().transpose(1, 2, 0)
    # PIL expects (width, height).
    image_size = image.shape[1::-1]

    print('image name: {}'.format(image_name))
    print('class {} prob: {}'.format(class_index, max_class_probs[image_index]))
    prob_map = prob_maps[image_index]

    # Panels: plain image, image + mask overlay, image + class probability heatmap.
    fig, axes = plt.subplots(1, 3, figsize=(3 * 6, 6))
    axes[0].imshow(image)
    axes[1].imshow(image)
    if mask is not None:
        axes[1].imshow(mask == class_index, alpha=0.5, cmap='jet', vmin=0, vmax=1)
    heatmap = np.asarray(Image.fromarray(prob_map).resize(image_size, resample=Image.BILINEAR))
    axes[2].imshow(image)
    axes[2].imshow(heatmap, alpha=0.5, cmap='jet', vmin=0, vmax=1)
    plt.show()

    # `unit_indices` may hold more units per image than there are panels, so
    # cap it; one panel is created per displayed unit.
    indices = unit_indices[image_index][:num_display_units]
    caption = ' '.join(['unit_{:04}'.format(unit_index + 1) for unit_index in indices])
    print('class {} top units: {}'.format(class_index, caption))
    # squeeze=False keeps `axes` two-dimensional even for a single panel.
    fig, axes = plt.subplots(1, len(indices), figsize=(len(indices) * 4, 4), squeeze=False)
    axes = axes[0]

    # Rescale the selected maps jointly to [0, 1]; skip the division when all
    # values are equal to avoid dividing by zero.
    feature_maps = []
    top_feature_maps = features[image_index][indices]
    top_feature_maps = top_feature_maps - top_feature_maps.min()
    peak = top_feature_maps.max()
    if peak > 0:
        top_feature_maps = top_feature_maps / peak
    for j, unit_index in enumerate(indices):
        feature_map = top_feature_maps[j]
        feature_map = np.asarray(Image.fromarray(feature_map).resize(image_size, resample=Image.BILINEAR))
        feature_maps.append(feature_map)
        axes[j].imshow(image)
        axes[j].imshow(feature_map, alpha=0.5, cmap='jet', vmin=0, vmax=1)
    plt.show()
    print('')