In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image, ImageDraw
import scipy.ndimage.filters as filters
from scipy.ndimage import binary_dilation
import scipy.ndimage as ndimage
import matplotlib.patches as patches
from collections import OrderedDict
from skimage.measure import label
import cv2
import re
import numpy as np
import pandas as pd
import os,gc
import sys
import shutil
import math
import random
import heapq 
import time
import copy
import itertools  
from sklearn.metrics import confusion_matrix,roc_curve,accuracy_score,auc,roc_auc_score 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import torch.utils.model_zoo as model_zoo
torch.cuda.set_device(0)
print (torch.cuda.current_device())
#os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3,4,5,6,7"

0


In [2]:
#!/usr/bin/env python3
# encoding: utf-8
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
 
__all__ = ['DenseNet', 'Densenet121_AG']
 
model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}

def Densenet121_AG(pretrained=False, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but pervious _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False)),
        self.drop_rate = drop_rate
 
    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)
 
 
class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)
 
 
class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
 
 
class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """
 
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
 
        super(DenseNet, self).__init__()
 
        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))
 
        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2
 
        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
 
        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        self.Sigmoid = nn.Sigmoid()
 
        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out_after_pooling = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)
        out = self.classifier(out_after_pooling)
        out = self.Sigmoid(out)
        return out, features, out_after_pooling


class Fusion_Branch(nn.Module):
    def __init__(self, input_size, output_size):
        super(Fusion_Branch, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.Sigmoid = nn.Sigmoid()

    def forward(self, global_pool, local_pool):
        #fusion = torch.cat((global_pool.unsqueeze(2), local_pool.unsqueeze(2)), 2).cuda()
        #fusion = fusion.max(2)[0]#.squeeze(2).cuda()
        #print(fusion.shape)
        fusion = torch.cat((global_pool,local_pool), 1).cuda()
        fusion_var = torch.autograd.Variable(fusion)
        x = self.fc(fusion_var)
        x = self.Sigmoid(x)

        return x

In [3]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import os


class ChestXrayDataSet(Dataset):
    def __init__(self, data_dir, image_list_file, transform=None):
        """
        Args:
            data_dir: path to image directory.
            image_list_file: path to the file containing images
                with corresponding labels.
            transform: optional transform to be applied on a sample.
        """
        image_names = []
        labels = []
        with open(image_list_file, "r") as f:
            for line in f:
                items = line.split()
                image_name= items[0].split('/')[1]
                label = items[1:]
                label = [int(i) for i in label]
                image_name = os.path.join(data_dir, image_name)
                image_names.append(image_name)
                labels.append(label)

        self.image_names = image_names
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index: the index of item
        Returns:
            image and its labels
        """
        image_name = self.image_names[index]
        image = Image.open(image_name).convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.FloatTensor(label)

    def __len__(self):
        return len(self.image_names)

In [None]:
# encoding: utf-8
"""
Training implementation
Author: Ian Ren
Update time: 08/11/2020
"""
import re
import sys
import os
import cv2
import time
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import lr_scheduler
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from skimage.measure import label
from PIL import Image

#np.set_printoptions(threshold = np.nan)


CKPT_PATH = ''

CKPT_PATH_G = ''#'/data/tmpexec/AGCNN/paper/AG_CNN_Global_epoch_1.pkl' 
CKPT_PATH_L = ''#'/data/tmpexec/AGCNN/paper/AG_CNN_Local_epoch_2.pkl' 
CKPT_PATH_F = ''#'/data/tmpexec/AGCNN/paper/AG_CNN_Fusion_epoch_23.pkl'

N_CLASSES = 14
CLASS_NAMES = [ 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
                'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']

# load with your own dataset path
DATA_DIR = '/data/fjsdata/NIH-CXR/images/images/'
TRAIN_IMAGE_LIST = '/data/fjsdata/NIH-CXR/chexnet_dataset/train.txt'
VAL_IMAGE_LIST = '/data/fjsdata/NIH-CXR/chexnet_dataset/val.txt'
save_model_path = '/data/tmpexec/AGCNN/'
save_model_name = 'AG_CNN'

# learning rate
LR_G = 1e-8
LR_L = 1e-8
LR_F = 1e-3
num_epochs = 50
BATCH_SIZE = 32

normalize = transforms.Normalize(
   mean=[0.485, 0.456, 0.406],
   std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
   transforms.Resize((256,256)),
   transforms.CenterCrop(224),
   transforms.ToTensor(),
   normalize,
])


def Attention_gen_patchs(ori_image, fm_cuda):
    # feature map -> feature mask (using feature map to crop on the original image) -> crop -> patchs
    feature_conv = fm_cuda.data.cpu().numpy()
    size_upsample = (224, 224) 
    bz, nc, h, w = feature_conv.shape

    patchs_cuda = torch.FloatTensor().cuda()

    for i in range(0, bz):
        feature = feature_conv[i]
        cam = feature.reshape((nc, h*w))
        cam = cam.sum(axis=0)
        cam = cam.reshape(h,w)
        cam = cam - np.min(cam)
        cam_img = cam / np.max(cam)
        cam_img = np.uint8(255 * cam_img)

        heatmap_bin = binImage(cv2.resize(cam_img, size_upsample))
        heatmap_maxconn = selectMaxConnect(heatmap_bin)
        heatmap_mask = heatmap_bin * heatmap_maxconn

        ind = np.argwhere(heatmap_mask != 0)
        minh = min(ind[:,0])
        minw = min(ind[:,1])
        maxh = max(ind[:,0])
        maxw = max(ind[:,1])
        
        # to ori image 
        image = ori_image[i].numpy().reshape(224,224,3)
        image = image[int(224*0.334):int(224*0.667),int(224*0.334):int(224*0.667),:]

        image = cv2.resize(image, size_upsample)
        image_crop = image[minh:maxh,minw:maxw,:] * 256 # because image was normalized before
        image_crop = preprocess(Image.fromarray(image_crop.astype('uint8')).convert('RGB')) 

        img_variable = torch.autograd.Variable(image_crop.reshape(3,224,224).unsqueeze(0).cuda())

        patchs_cuda = torch.cat((patchs_cuda,img_variable),0)

    return patchs_cuda


def binImage(heatmap):
    _, heatmap_bin = cv2.threshold(heatmap , 0 , 255 , cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    # t in the paper
    #_, heatmap_bin = cv2.threshold(heatmap , 178 , 255 , cv2.THRESH_BINARY)
    return heatmap_bin


def selectMaxConnect(heatmap):
    labeled_img, num = label(heatmap, connectivity=2, background=0, return_num=True)    
    max_label = 0
    max_num = 0
    for i in range(1, num+1):
        if np.sum(labeled_img == i) > max_num:
            max_num = np.sum(labeled_img == i)
            max_label = i
    lcc = (labeled_img == max_label)
    if max_num == 0:
        lcc = (labeled_img == -1)
    lcc = lcc + 0
    return lcc 


def main():
    print('********************load data********************')
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TRAIN_IMAGE_LIST,
                                    transform=transforms.Compose([
                                        transforms.Resize(224),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        normalize,
                                    ]))
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, num_workers=0, pin_memory=True)
    
    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=VAL_IMAGE_LIST,
                                    transform=transforms.Compose([
                                        transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        normalize,
                                    ]))
    test_loader = DataLoader(dataset=test_dataset, batch_size=128,
                             shuffle=False, num_workers=0, pin_memory=True)
    print('********************load data succeed!********************')


    print('********************load model********************')
    # initialize and load the model
    Global_Branch_model = Densenet121_AG(pretrained = False, num_classes = N_CLASSES).cuda()
    Local_Branch_model = Densenet121_AG(pretrained = False, num_classes = N_CLASSES).cuda()
    Fusion_Branch_model = Fusion_Branch(input_size = 2048, output_size = N_CLASSES).cuda()

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        # to load state
        # Code modified from torchvision densenet source for loading from pre .4 densenet weights.
        state_dict = checkpoint['state_dict']
        remove_data_parallel = True # Change if you don't want to use nn.DataParallel(model)

        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        for key in list(state_dict.keys()):
            ori_key =  key
            key = key.replace('densenet121.','')
            #print('key',key)
            match = pattern.match(key)
            new_key = match.group(1) + match.group(2) if match else key
            new_key = new_key[7:] if remove_data_parallel else new_key
            #print('new_key',new_key)
            if '.0.' in new_key:
                new_key = new_key.replace('0.','')
            state_dict[new_key] = state_dict[ori_key]
            # Delete old key only if modified.
            if match or remove_data_parallel: 
                del state_dict[ori_key]
        
        Global_Branch_model.load_state_dict(state_dict)
        Local_Branch_model.load_state_dict(state_dict)
        print("=> loaded baseline checkpoint")
        
    else:
        print("=> no checkpoint found")

    if os.path.isfile(CKPT_PATH_G):
        checkpoint = torch.load(CKPT_PATH_G)
        Global_Branch_model.load_state_dict(checkpoint)
        print("=> loaded Global_Branch_model checkpoint")

    if os.path.isfile(CKPT_PATH_L):
        checkpoint = torch.load(CKPT_PATH_L)
        Local_Branch_model.load_state_dict(checkpoint)
        print("=> loaded Local_Branch_model checkpoint")

    if os.path.isfile(CKPT_PATH_F):
        checkpoint = torch.load(CKPT_PATH_F)
        Fusion_Branch_model.load_state_dict(checkpoint)
        print("=> loaded Fusion_Branch_model checkpoint")

    cudnn.benchmark = True
    criterion = nn.BCELoss()
    optimizer_global = optim.Adam(Global_Branch_model.parameters(), lr=LR_G, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
    lr_scheduler_global = lr_scheduler.StepLR(optimizer_global , step_size = 10, gamma = 1)
    
    optimizer_local = optim.Adam(Local_Branch_model.parameters(), lr=LR_L, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
    lr_scheduler_local = lr_scheduler.StepLR(optimizer_local , step_size = 10, gamma = 1)
    
    optimizer_fusion = optim.Adam(Fusion_Branch_model.parameters(), lr=LR_F, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
    lr_scheduler_fusion = lr_scheduler.StepLR(optimizer_fusion , step_size = 15, gamma = 0.1)
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROCs_best = float('inf')
    for epoch in range(num_epochs):
        since = time.time()
        print('Epoch {}/{}'.format(epoch+1 , num_epochs))
        print('-' * 10)
        #set the mode of model
        lr_scheduler_global.step()  #about lr and gamma
        lr_scheduler_local.step() 
        lr_scheduler_fusion.step() 
        Global_Branch_model.train()  #set model to training mode
        Local_Branch_model.train()
        Fusion_Branch_model.train()

        running_loss = 0.0
        #Iterate over data
        for i, (input, target) in enumerate(train_loader):
            input_var = torch.autograd.Variable(input.cuda())
            target_var = torch.autograd.Variable(target.cuda())
            optimizer_global.zero_grad()
            optimizer_local.zero_grad()
            optimizer_fusion.zero_grad()

            # compute output
            output_global, fm_global, pool_global = Global_Branch_model(input_var)
            patchs_var = Attention_gen_patchs(input,fm_global)
            output_local, _, pool_local = Local_Branch_model(patchs_var)
            #print(fusion_var.shape)
            output_fusion = Fusion_Branch_model(pool_global, pool_local)
            #
            # loss
            loss1 = criterion(output_global, target_var)
            loss2 = criterion(output_local, target_var)
            loss3 = criterion(output_fusion, target_var)
            #
            loss = loss1*0.8 + loss2*0.1 + loss3*0.1 

            if (i%500) == 0: 
                print('step: {} totalloss: {loss:.3f} loss1: {loss1:.3f} loss2: {loss2:.3f} loss3: {loss3:.3f}'\
                      .format(i, loss = loss, loss1 = loss1, loss2 = loss2, loss3 = loss3))

            loss.backward() 
            optimizer_global.step()  
            optimizer_local.step()
            optimizer_fusion.step()

            #print(loss.data.item())
            running_loss += loss.data.item()

        epoch_loss = float(running_loss) / float(i)
        print(' Epoch over  Loss: {:.5f}'.format(epoch_loss))

        print('*******testing!*********')
        AUROCs_f = test(Global_Branch_model, Local_Branch_model, Fusion_Branch_model,test_loader)
        #break

        #save
        if epoch % 1 == 0 and AUROCs_best<AUROCs_f:
            AUROCs_best = AUROCs_f
            save_path = save_model_path
            torch.save(Global_Branch_model.state_dict(), save_path+save_model_name+'_Global'+'_epoch_'+str(epoch)+'.pkl')
            print('Global_Branch_model already save!')
            torch.save(Local_Branch_model.state_dict(), save_path+save_model_name+'_Local'+'_epoch_'+str(epoch)+'.pkl')
            print('Local_Branch_model already save!')
            torch.save(Fusion_Branch_model.state_dict(), save_path+save_model_name+'_Fusion'+'_epoch_'+str(epoch)+'.pkl')            
            print('Fusion_Branch_model already save!')

        time_elapsed = time.time() - since
        print('Training one epoch complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60 , time_elapsed % 60))
    

def test(model_global, model_local, model_fusion, test_loader):

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor().cuda()
    pred_global = torch.FloatTensor().cuda()
    pred_local = torch.FloatTensor().cuda()
    pred_fusion = torch.FloatTensor().cuda()

    # switch to evaluate mode
    model_global.eval()
    model_local.eval()
    model_fusion.eval()
    cudnn.benchmark = True

    for i, (inp, target) in enumerate(test_loader):
        with torch.no_grad():
            target = target.cuda()
            gt = torch.cat((gt, target), 0)
            input_var = torch.autograd.Variable(inp.cuda())
            #output = model_global(input_var)

            output_global, fm_global, pool_global = model_global(input_var)
            
            patchs_var = Attention_gen_patchs(inp,fm_global)

            output_local, _, pool_local = model_local(patchs_var)

            output_fusion = model_fusion(pool_global,pool_local)

            pred_global = torch.cat((pred_global, output_global.data), 0)
            pred_local = torch.cat((pred_local, output_local.data), 0)
            pred_fusion = torch.cat((pred_fusion, output_fusion.data), 0)
            
            sys.stdout.write('\r testing process: = {}'.format(i+1))
            sys.stdout.flush()
            
    AUROCs_g = compute_AUCs(gt, pred_global)
    AUROC_avg = np.array(AUROCs_g).mean()
    print('Global branch: The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROCs_g[i]))

    AUROCs_l = compute_AUCs(gt, pred_local)
    AUROC_avg = np.array(AUROCs_l).mean()
    print('\n')
    print('Local branch: The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROCs_l[i]))

    AUROCs_f = compute_AUCs(gt, pred_fusion)
    AUROC_avg = np.array(AUROCs_f).mean()
    print('\n')
    print('Fusion branch: The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROCs_f[i]))
    
    return np.array(AUROCs_f).mean()


def compute_AUCs(gt, pred):
    """Computes Area Under the Curve (AUC) from prediction scores.
    Args:
        gt: Pytorch tensor on GPU, shape = [n_samples, n_classes]
          true binary labels.
        pred: Pytorch tensor on GPU, shape = [n_samples, n_classes]
          can either be probability estimates of the positive class,
          confidence values, or binary decisions.
    Returns:
        List of AUROCs of all classes.
    """
    AUROCs = []
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    for i in range(N_CLASSES):
        AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i]))
    return AUROCs

if __name__ == '__main__':
    main()

********************load data********************
********************load data succeed!********************
********************load model********************
=> no checkpoint found
********************load model succeed!********************
********************begin training!********************
Epoch 1/50
----------
step: 0 totalloss: 0.705 loss1: 0.709 loss2: 0.662 loss3: 0.719
step: 500 totalloss: 0.656 loss1: 0.709 loss2: 0.661 loss3: 0.229
step: 1000 totalloss: 0.647 loss1: 0.706 loss2: 0.662 loss3: 0.153
step: 1500 totalloss: 0.639 loss1: 0.701 loss2: 0.663 loss3: 0.117
step: 2000 totalloss: 0.646 loss1: 0.702 loss2: 0.663 loss3: 0.177
 Epoch over  Loss: 0.64864
*******testing!*********
 testing process: = 88Global branch: The average AUROC is 0.475
The AUROC of Atelectasis is 0.5033
The AUROC of Cardiomegaly is 0.4456
The AUROC of Effusion is 0.4767
The AUROC of Infiltration is 0.4672
The AUROC of Mass is 0.4967
The AUROC of Nodule is 0.4978
The AUROC of Pneumonia is 0.5419
Th

step: 0 totalloss: 0.634 loss1: 0.685 loss2: 0.650 loss3: 0.204
step: 500 totalloss: 0.627 loss1: 0.686 loss2: 0.648 loss3: 0.136
step: 1000 totalloss: 0.627 loss1: 0.682 loss2: 0.646 loss3: 0.165
step: 1500 totalloss: 0.636 loss1: 0.685 loss2: 0.643 loss3: 0.233
step: 2000 totalloss: 0.631 loss1: 0.686 loss2: 0.642 loss3: 0.180
 Epoch over  Loss: 0.63019
*******testing!*********
 testing process: = 88Global branch: The average AUROC is 0.479
The AUROC of Atelectasis is 0.5026
The AUROC of Cardiomegaly is 0.4546
The AUROC of Effusion is 0.4770
The AUROC of Infiltration is 0.4799
The AUROC of Mass is 0.4986
The AUROC of Nodule is 0.5024
The AUROC of Pneumonia is 0.5410
The AUROC of Pneumothorax is 0.5231
The AUROC of Consolidation is 0.4689
The AUROC of Edema is 0.4057
The AUROC of Emphysema is 0.4308
The AUROC of Fibrosis is 0.4638
The AUROC of Pleural_Thickening is 0.5144
The AUROC of Hernia is 0.4491


Local branch: The average AUROC is 0.504
The AUROC of Atelectasis is 0.5851
The AU

step: 1500 totalloss: 0.613 loss1: 0.664 loss2: 0.626 loss3: 0.187
step: 2000 totalloss: 0.607 loss1: 0.661 loss2: 0.625 loss3: 0.159
 Epoch over  Loss: 0.61272
*******testing!*********
 testing process: = 88Global branch: The average AUROC is 0.480
The AUROC of Atelectasis is 0.5043
The AUROC of Cardiomegaly is 0.4546
The AUROC of Effusion is 0.4831
The AUROC of Infiltration is 0.5045
The AUROC of Mass is 0.4959
The AUROC of Nodule is 0.5047
The AUROC of Pneumonia is 0.5397
The AUROC of Pneumothorax is 0.5286
The AUROC of Consolidation is 0.4707
The AUROC of Edema is 0.4034
The AUROC of Emphysema is 0.4336
The AUROC of Fibrosis is 0.4650
The AUROC of Pleural_Thickening is 0.5225
The AUROC of Hernia is 0.4159


Local branch: The average AUROC is 0.507
The AUROC of Atelectasis is 0.5761
The AUROC of Cardiomegaly is 0.5062
The AUROC of Effusion is 0.4803
The AUROC of Infiltration is 0.5023
The AUROC of Mass is 0.4876
The AUROC of Nodule is 0.4823
The AUROC of Pneumonia is 0.5307
The AURO

 testing process: = 88Global branch: The average AUROC is 0.487
The AUROC of Atelectasis is 0.5098
The AUROC of Cardiomegaly is 0.4552
The AUROC of Effusion is 0.4867
The AUROC of Infiltration is 0.5281
The AUROC of Mass is 0.5059
The AUROC of Nodule is 0.5058
The AUROC of Pneumonia is 0.5266
The AUROC of Pneumothorax is 0.5335
The AUROC of Consolidation is 0.5015
The AUROC of Edema is 0.4294
The AUROC of Emphysema is 0.4307
The AUROC of Fibrosis is 0.4907
The AUROC of Pleural_Thickening is 0.5316
The AUROC of Hernia is 0.3793


Local branch: The average AUROC is 0.506
The AUROC of Atelectasis is 0.5749
The AUROC of Cardiomegaly is 0.4944
The AUROC of Effusion is 0.4815
The AUROC of Infiltration is 0.5030
The AUROC of Mass is 0.4956
The AUROC of Nodule is 0.4729
The AUROC of Pneumonia is 0.5188
The AUROC of Pneumothorax is 0.4701
The AUROC of Consolidation is 0.4734
The AUROC of Edema is 0.5675
The AUROC of Emphysema is 0.5314
The AUROC of Fibrosis is 0.5049
The AUROC of Pleural_Thicke

In [19]:
# encoding: utf-8
import re
import sys
import os
import cv2
import time
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import lr_scheduler
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from skimage.measure import label
from PIL import Image

#np.set_printoptions(threshold = np.nan)


CKPT_PATH = ''

CKPT_PATH_G = '/data/tmpexec/AGCNN/AG_CNN_Global_epoch_1.pkl' 
CKPT_PATH_L = '/data/tmpexec/AGCNN/AG_CNN_Local_epoch_2.pkl' 
CKPT_PATH_F = '/data/tmpexec/AGCNN/AG_CNN_Fusion_epoch_23.pkl'

N_CLASSES = 14
CLASS_NAMES = [ 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
                'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']

DATA_DIR = '/data/fjsdata/NIH-CXR/images/images/'
TRAIN_IMAGE_LIST = '/data/fjsdata/NIH-CXR/chexnet_dataset/train.txt'
TEST_IMAGE_LIST = '/data/fjsdata/NIH-CXR/chexnet_dataset/test.txt'

num_epochs = 50
BATCH_SIZE = 32

normalize = transforms.Normalize(
   mean=[0.485, 0.456, 0.406],
   std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
   transforms.Resize((256,256)),
   transforms.CenterCrop(224),
   transforms.ToTensor(),
   normalize,
])


def Attention_gen_patchs(ori_image, fm_cuda):
    # fm => mask =>(+ ori-img) => crop = patchs
    feature_conv = fm_cuda.data.cpu().numpy()
    size_upsample = (224, 224) 
    bz, nc, h, w = feature_conv.shape

    patchs_cuda = torch.FloatTensor().cuda()

    for i in range(0, bz):
        feature = feature_conv[i]
        cam = feature.reshape((nc, h*w))
        cam = cam.sum(axis=0)
        cam = cam.reshape(h,w)
        cam = cam - np.min(cam)
        cam_img = cam / np.max(cam)
        cam_img = np.uint8(255 * cam_img)

        heatmap_bin = binImage(cv2.resize(cam_img, size_upsample))
        heatmap_maxconn = selectMaxConnect(heatmap_bin)
        heatmap_mask = heatmap_bin * heatmap_maxconn

        ind = np.argwhere(heatmap_mask != 0)
        minh = min(ind[:,0])
        minw = min(ind[:,1])
        maxh = max(ind[:,0])
        maxw = max(ind[:,1])
        
        # to ori image 
        image = ori_image[i].numpy().reshape(224,224,3)
        image = image[int(224*0.334):int(224*0.667),int(224*0.334):int(224*0.667),:]

        image = cv2.resize(image, size_upsample)
        image_crop = image[minh:maxh,minw:maxw,:] * 256 # because image was normalized before
        image_crop = preprocess(Image.fromarray(image_crop.astype('uint8')).convert('RGB')) 

        img_variable = torch.autograd.Variable(image_crop.reshape(3,224,224).unsqueeze(0).cuda())

        patchs_cuda = torch.cat((patchs_cuda,img_variable),0)

    return patchs_cuda


def binImage(heatmap):
    _, heatmap_bin = cv2.threshold(heatmap , 0 , 255 , cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    # t in the paper
    #_, heatmap_bin = cv2.threshold(heatmap , 178 , 255 , cv2.THRESH_BINARY)
    return heatmap_bin


def selectMaxConnect(heatmap):
    labeled_img, num = label(heatmap, connectivity=2, background=0, return_num=True)    
    max_label = 0
    max_num = 0
    for i in range(1, num+1):
        if np.sum(labeled_img == i) > max_num:
            max_num = np.sum(labeled_img == i)
            max_label = i
    lcc = (labeled_img == max_label)
    if max_num == 0:
        lcc = (labeled_img == -1)
    lcc = lcc + 0
    return lcc 


def main():
    print('********************load data********************')
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=transforms.Compose([
                                        transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        normalize,
                                    ]))
    test_loader = DataLoader(dataset=test_dataset, batch_size=128,
                             shuffle=False, num_workers=0, pin_memory=True)
    print('********************load data succeed!********************')


    print('********************load model********************')
    # initialize and load the model
    Global_Branch_model = Densenet121_AG(pretrained = False, num_classes = N_CLASSES).cuda()
    Local_Branch_model = Densenet121_AG(pretrained = False, num_classes = N_CLASSES).cuda()
    Fusion_Branch_model = Fusion_Branch(input_size = 2048, output_size = N_CLASSES).cuda()

    if os.path.isfile(CKPT_PATH_G):
        checkpoint = torch.load(CKPT_PATH_G)
        Global_Branch_model.load_state_dict(checkpoint)
        print("=> loaded Global_Branch_model checkpoint")

    if os.path.isfile(CKPT_PATH_L):
        checkpoint = torch.load(CKPT_PATH_L)
        Local_Branch_model.load_state_dict(checkpoint)
        print("=> loaded Local_Branch_model checkpoint")

    if os.path.isfile(CKPT_PATH_F):
        checkpoint = torch.load(CKPT_PATH_F)
        Fusion_Branch_model.load_state_dict(checkpoint)
        print("=> loaded Fusion_Branch_model checkpoint")

    cudnn.benchmark = True
    print('******************** load model succeed!********************')

    print('******* begin testing!*********')
    test(Global_Branch_model, Local_Branch_model, Fusion_Branch_model,test_loader)

def test(model_global, model_local, model_fusion, test_loader):

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor().cuda()
    pred_global = torch.FloatTensor().cuda()
    pred_local = torch.FloatTensor().cuda()
    pred_fusion = torch.FloatTensor().cuda()

    # switch to evaluate mode
    model_global.eval()
    model_local.eval()
    model_fusion.eval()
    cudnn.benchmark = True

    for i, (inp, target) in enumerate(test_loader):
        with torch.no_grad():     
            target = target.cuda()
            gt = torch.cat((gt, target), 0)
            input_var = torch.autograd.Variable(inp.cuda())

            output_global, fm_global, pool_global = model_global(input_var)
            
            patchs_var = Attention_gen_patchs(inp,fm_global)

            output_local, _, pool_local = model_local(patchs_var)

            output_fusion = model_fusion(pool_global,pool_local)

            pred_global = torch.cat((pred_global, output_global.data), 0)
            pred_local = torch.cat((pred_local, output_local.data), 0)
            pred_fusion = torch.cat((pred_fusion, output_fusion.data), 0)
            
            sys.stdout.write('\r testing process: = {}'.format(i+1))
            sys.stdout.flush()
            
    AUROCs_g = compute_AUCs(gt, pred_global)
    AUROC_avg = np.array(AUROCs_g).mean()
    print('Global branch: The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs_g[i]))

    AUROCs_l = compute_AUCs(gt, pred_local)
    AUROC_avg = np.array(AUROCs_l).mean()
    print('\n')
    print('Local branch: The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs_l[i]))

    AUROCs_f = compute_AUCs(gt, pred_fusion)
    AUROC_avg = np.array(AUROCs_f).mean()
    print('\n')
    print('Fusion branch: The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs_f[i]))


def compute_AUCs(gt, pred):
    """Computes Area Under the Curve (AUC) from prediction scores.
    Args:
        gt: Pytorch tensor on GPU, shape = [n_samples, n_classes]
          true binary labels.
        pred: Pytorch tensor on GPU, shape = [n_samples, n_classes]
          can either be probability estimates of the positive class,
          confidence values, or binary decisions.
    Returns:
        List of AUROCs of all classes.
    """
    AUROCs = []
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    for i in range(N_CLASSES):
        AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i]))
    return AUROCs

if __name__ == '__main__':
    main()

********************load data********************
********************load data succeed!********************
********************load model********************
=> loaded Global_Branch_model checkpoint
=> loaded Local_Branch_model checkpoint
=> loaded Fusion_Branch_model checkpoint
******************** load model succeed!********************
******* begin testing!*********
 testing process: = 176Global branch: The average AUROC is 0.839
The AUROC of Atelectasis is 0.8235533370251485
The AUROC of Cardiomegaly is 0.9149242739132466
The AUROC of Effusion is 0.8855035262218596
The AUROC of Infiltration is 0.7073808055682215
The AUROC of Mass is 0.8428857286111491
The AUROC of Nodule is 0.7757692210739041
The AUROC of Pneumonia is 0.77675317705674
The AUROC of Pneumothorax is 0.8680464132603121
The AUROC of Consolidation is 0.8127160925962797
The AUROC of Edema is 0.8931314367524132
The AUROC of Emphysema is 0.9152569924536593
The AUROC of Fibrosis is 0.8239410556238518
The AUROC of Pleural_

In [28]:
CKPT_PATH_G = '/data/tmpexec/AGCNN/paper/AG_CNN_Global_epoch_1.pkl' 
CKPT_PATH_L = '/data/tmpexec/AGCNN/paper/AG_CNN_Local_epoch_2.pkl' 
CKPT_PATH_F = '/data/tmpexec/AGCNN/paper/AG_CNN_Fusion_epoch_23.pkl'
# initialize and load the model
Global_Branch_model = Densenet121_AG(pretrained = False, num_classes = N_CLASSES).cuda()
Local_Branch_model = Densenet121_AG(pretrained = False, num_classes = N_CLASSES).cuda()
Fusion_Branch_model = Fusion_Branch(input_size = 2048, output_size = N_CLASSES).cuda()

if os.path.isfile(CKPT_PATH_G):
    checkpoint = torch.load(CKPT_PATH_G)
    Global_Branch_model.load_state_dict(checkpoint)
    print("=> loaded Global_Branch_model checkpoint")

if os.path.isfile(CKPT_PATH_L):
    checkpoint = torch.load(CKPT_PATH_L)
    Local_Branch_model.load_state_dict(checkpoint)
    print("=> loaded Local_Branch_model checkpoint")

if os.path.isfile(CKPT_PATH_F):
    checkpoint = torch.load(CKPT_PATH_F)
    Fusion_Branch_model.load_state_dict(checkpoint)
    print("=> loaded Fusion_Branch_model checkpoint")
        
#performance of testset
# initialize the ground truth and output tensor
Global_Branch_model.eval()
Local_Branch_model.eval()
Fusion_Branch_model.eval()
cudnn.benchmark = True
# initialize the ground truth and output tensor
gt = torch.FloatTensor().cuda()
pred_global = torch.FloatTensor().cuda()
pred_local = torch.FloatTensor().cuda()
pred_fusion = torch.FloatTensor().cuda()
num_batches = len(teY) // batchSize  +1
with torch.no_grad():     
    for i in range(num_batches):
        min_idx = i * batchSize
        max_idx = np.min([len(teY), (i+1)*batchSize])
        I_batch = torch.from_numpy(teI[min_idx:max_idx]).type(torch.FloatTensor).cuda()
        y_batch = torch.from_numpy(teY[min_idx:max_idx]).type(torch.FloatTensor).cuda()
        gt = torch.cat((gt, y_batch), 0)
        
        output_global, fm_global, pool_global = Global_Branch_model(I_batch)
        patchs_var = Attention_gen_patchs(I_batch.data.cpu(),fm_global)
        output_local, _, pool_local = Local_Branch_model(patchs_var)
        output_fusion = Fusion_Branch_model(pool_global,pool_local)

        pred_global = torch.cat((pred_global, output_global.data), 0)
        pred_local = torch.cat((pred_local, output_local.data), 0)
        pred_fusion = torch.cat((pred_fusion, output_fusion.data), 0)
            
        sys.stdout.write('\r {} / {} '.format(i, num_batches))
        sys.stdout.flush()
        
            
AUROCs_g = compute_AUCs(gt, pred_global)
AUROC_avg = np.array(AUROCs_g).mean()
print('Global branch: The average AUROC is {AUROC_avg:.4f}'.format(AUROC_avg=AUROC_avg))
for i in range(N_CLASSES):
    print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROCs_g[i]))

AUROCs_l = compute_AUCs(gt, pred_local)
AUROC_avg = np.array(AUROCs_l).mean()
print('\n')
print('Local branch: The average AUROC is {AUROC_avg:.4f}'.format(AUROC_avg=AUROC_avg))
for i in range(N_CLASSES):
    print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROCs_l[i]))

AUROCs_f = compute_AUCs(gt, pred_fusion)
AUROC_avg = np.array(AUROCs_f).mean()
print('\n')
print('Fusion branch: The average AUROC is {AUROC_avg:.4f}'.format(AUROC_avg=AUROC_avg))
for i in range(N_CLASSES):
    print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROCs_f[i]))

=> loaded Global_Branch_model checkpoint
=> loaded Local_Branch_model checkpoint
=> loaded Fusion_Branch_model checkpoint
 2243 / 2244 Global branch: The average AUROC is 0.838
The AUROC of Atelectasis is 0.8262332376517247
The AUROC of Cardiomegaly is 0.9141369201374948
The AUROC of Effusion is 0.8859380050537341
The AUROC of Infiltration is 0.7063778922034437
The AUROC of Mass is 0.8432511426310141
The AUROC of Nodule is 0.7738812774201932
The AUROC of Pneumonia is 0.7756925132704011
The AUROC of Pneumothorax is 0.8673377455555967
The AUROC of Consolidation is 0.8131091098653926
The AUROC of Edema is 0.8914614273178907
The AUROC of Emphysema is 0.9123397885676864
The AUROC of Fibrosis is 0.8216879678365977
The AUROC of Pleural_Thickening is 0.7816199794739345
The AUROC of Hernia is 0.922022241079005


Local branch: The average AUROC is 0.791
The AUROC of Atelectasis is 0.7862530677373757
The AUROC of Cardiomegaly is 0.8832458854022425
The AUROC of Effusion is 0.8591878314338008
The A

In [2]:
def Image_Processing(img_path, crop_size=224):
    img = Image.open(img_path).convert('RGB').resize((256, 256),Image.ANTIALIAS) #open and resize
    #crop and normalize
    transform_sequence = transforms.Compose([
                                             #transforms.ToPILImage(), #if not PILImage
                                             transforms.CenterCrop(crop_size),
                                             #transforms.RandomCrop(crop_size),
                                             #transforms.RandomHorizontalFlip(),
                                             transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
                                             transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
                                            ])
    img = transform_sequence(img).numpy() #tensor to numpy
    return img

CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion','Infiltration', 'Mass', 'Nodule', 'Pneumonia','Pneumothorax', \
               'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'] 
N_CLASSES = len(CLASS_NAMES) #class numbers
def compute_AUCs(gt, pred):
    AUROCs = []
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    for i in range(N_CLASSES):
        AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i]))
    return AUROCs 

img_path = '/data/fjsdata/NIH-CXR/images/images/' 
#preparing the trainset and  testset
trN, trI, trY = [],[],[]
with open('/data/fjsdata/NIH-CXR/chexnet_dataset/train.txt', "r") as file_descriptor: #tarinset
    lines = file_descriptor.readlines()
    for line in lines:
        #if(len(trN)>1001): break
        try:
            line_items = line.split()
            image_name = line_items[0].split('/')[1]
            trN.append(image_name)
            image_label = line_items[1:]  # 14 labels from index 2
            image_label = [int(i) for i in image_label]  
            trY.append(np.array(image_label))
            img = Image_Processing(os.path.join(img_path, image_name))
            trI.append(img)
        except:
            print(image_name+":"+str(os.path.join(img_path, image_name)))
        sys.stdout.write('\r{} / {} '.format(len(trN),78468))
        sys.stdout.flush()
trI = np.array(trI)
trY = np.array(trY)   
print('The length of trainset is %d'%len(trN))
        
valN, valI, valY = [],[],[]
with open('/data/fjsdata/NIH-CXR/chexnet_dataset/val.txt', "r") as file_descriptor: #valset
    lines = file_descriptor.readlines()
    for line in lines:
        #if(len(valN)>1001): break
        try:
            line_items = line.split()
            image_name = line_items[0].split('/')[1]
            valN.append(image_name)
            image_label = line_items[1:]  # 14 labels from index 2
            image_label = [int(i) for i in image_label]  
            valY.append(np.array(image_label))
            img = Image_Processing(os.path.join(img_path, image_name))
            valI.append(img)
        except:
            print(image_name+":"+str(os.path.join(img_path, image_name)))
        sys.stdout.write('\r{} / {} '.format(len(valN),11219))
        sys.stdout.flush()
valI = np.array(valI)
valY = np.array(valY) 
print('The length of validset is %d'%len(valN))

teN, teI, teY = [],[],[]
with open('/data/fjsdata/NIH-CXR/chexnet_dataset/test.txt', "r") as file_descriptor: #testset
    lines = file_descriptor.readlines()
    for line in lines:
        #if(len(teN)>1001): break
        try:
            line_items = line.split()
            image_name = line_items[0].split('/')[1]
            teN.append(image_name)
            image_label = line_items[1:]  # 14 labels from index 2
            image_label = [int(i) for i in image_label]  
            teY.append(np.array(image_label))
            img = Image_Processing(os.path.join(img_path, image_name))                    
            teI.append(img)
        except:
            print(image_name+":"+str(os.path.join(img_path, image_name)))
        sys.stdout.write('\r{} / {} '.format(len(teN),22433))
        sys.stdout.flush()
teI = np.array(teI)
teY = np.array(teY)    
print('The length of testset is %d'%len(teN))

#preparing bounding box dataset
boxdata = pd.read_csv("/data/fjsdata/NIH-CXR/chexnet_dataset/fjs_BBox.csv" , sep=',')
boxdata = boxdata[['Image Index','Finding Label','Bbox [x', 'y', 'w', 'h]']]
#print('Dataset statistic, records: %d, fields: %d' % (boxdata.shape[0], boxdata.shape[1]))
#print(boxdata.columns.values.tolist())
bbN, bbI, bbY, bBox = [],[],[],[]
for _, row in boxdata.iterrows():
    bbN.append(row['Image Index'])
    
    img = Image_Processing(os.path.join(img_path, row['Image Index']))
    bbI.append(img)
    
    labels = np.zeros(len(CLASS_NAMES))
    labels[CLASS_NAMES.index(row['Finding Label'])] = 1
    bbY.append(labels)
    
    bBox.append(np.array([row['Bbox [x'], row['y'], row['w'], row['h]']])) #xywh  
print('The length of boxset is %d'%len(bbN))
bbI = np.array(bbI)
bbY = np.array(bbY)
bBox = np.array(bBox)

78468 / 78468 The length of trainset is 78468
11219 / 11219 The length of validset is 11219
22433 / 22433 The length of testset is 22433
The length of boxset is 984


In [3]:
#construct model
class DenseNet121_AG(nn.Module):
    def __init__(self, num_classes, is_pre_trained, fusion_size=2048):
        super(DenseNet121_AG, self).__init__()
        self.dense_net_121 = torchvision.models.densenet121(pretrained=is_pre_trained)
        #num_fc_kernels = self.dense_net_121.classifier.in_features
        #self.dense_net_121.classifier = nn.Sequential(nn.Linear(num_fc_kernels, num_classes), nn.Sigmoid())
        #for fusion
        self.fc = nn.Linear(fusion_size, num_classes)
        self.Sigmoid = nn.Sigmoid()
        

    def forward(self, x):
        #global
        features_g = self.dense_net_121.features(x)
        out_g = F.relu(features_g, inplace=True)
        pooling_g = F.avg_pool2d(out_g, kernel_size=7, stride=1).view(features_g.size(0), -1)
        #out_g = self.dense_net_121.classifier(features_g)
        #local
        patchs_var = Attention_gen_patchs(x, features_g)
        features_l = self.dense_net_121.features(patchs_var)
        out_l = F.relu(features_l, inplace=True)
        pooling_l = F.avg_pool2d(out_l, kernel_size=7, stride=1).view(features_l.size(0), -1) 
        #fusion
        fusion = torch.cat((pooling_g,pooling_l), 1)
        out = self.fc(fusion)
        out = self.Sigmoid(out)
        
        return out

def binImage(heatmap):
    _, heatmap_bin = cv2.threshold(heatmap , 0 , 255 , cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    # t in the paper
    #_, heatmap_bin = cv2.threshold(heatmap , 178 , 255 , cv2.THRESH_BINARY)
    return heatmap_bin


def selectMaxConnect(heatmap):
    labeled_img, num = label(heatmap, connectivity=2, background=0, return_num=True)    
    max_label = 0
    max_num = 0
    for i in range(1, num+1):
        if np.sum(labeled_img == i) > max_num:
            max_num = np.sum(labeled_img == i)
            max_label = i
    lcc = (labeled_img == max_label)
    if max_num == 0:
        lcc = (labeled_img == -1)
    lcc = lcc + 0
    return lcc 

normalize = transforms.Normalize(
   mean=[0.485, 0.456, 0.406],
   std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
   transforms.Resize((256,256)),
   transforms.CenterCrop(224),
   transforms.ToTensor(),
   normalize,
])

def Attention_gen_patchs(ori_image, fm_cuda):
    # feature map -> feature mask (using feature map to crop on the original image) -> crop -> patchs
    feature_conv = fm_cuda.data.cpu().numpy()
    ori_image = ori_image.data.cpu().numpy()
    size_upsample = (224, 224) 
    bz, nc, h, w = feature_conv.shape

    patchs_cuda = torch.FloatTensor().cuda()

    for i in range(0, bz):
        feature = feature_conv[i]
        cam = feature.reshape((nc, h*w))
        cam = cam.sum(axis=0)
        cam = cam.reshape(h,w)
        cam = cam - np.min(cam)
        cam_img = cam / np.max(cam)
        cam_img = np.uint8(255 * cam_img)

        heatmap_bin = binImage(cv2.resize(cam_img, size_upsample))
        heatmap_maxconn = selectMaxConnect(heatmap_bin)
        heatmap_mask = heatmap_bin * heatmap_maxconn

        ind = np.argwhere(heatmap_mask != 0)
        minh = min(ind[:,0])
        minw = min(ind[:,1])
        maxh = max(ind[:,0])
        maxw = max(ind[:,1])
        
        # to ori image 
        #image = ori_image[i].numpy().reshape(224,224,3)
        image = ori_image[i].reshape(224,224,3)
        image = image[int(224*0.334):int(224*0.667),int(224*0.334):int(224*0.667),:]

        image = cv2.resize(image, size_upsample)
        image_crop = image[minh:maxh,minw:maxw,:] * 256 # because image was normalized before
        image_crop = preprocess(Image.fromarray(image_crop.astype('uint8')).convert('RGB')) 

        img_variable = torch.autograd.Variable(image_crop.reshape(3,224,224).unsqueeze(0).cuda())

        patchs_cuda = torch.cat((patchs_cuda,img_variable),0)

    return patchs_cuda

In [5]:
model = DenseNet121_AG(num_classes=N_CLASSES, is_pre_trained=True).cuda()#initialize model
#model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]).cuda()# make model available multi GPU cores training
torch.backends.cudnn.benchmark = True  # improve train speed slightly
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5, mode='min')
criterion = torch.nn.BCELoss()
#train model
best_net, best_loss = None, float('inf')
AUROC_best = 0.
batchSize = 10 #'Batch Size': 32
for epoch in range(10):#'Max Epoch': 50
    model.train()  # set network as train mode
    shuffled_idx = np.random.permutation(np.arange(len(trY)))
    num_batches = len(shuffled_idx) // batchSize + 1
    with torch.autograd.enable_grad():
        for i in range(num_batches):
            optimizer.zero_grad()#grad vanish
            min_idx = i * batchSize
            max_idx = np.min([len(shuffled_idx), (i+1)*batchSize])
            selected_idx = shuffled_idx[min_idx:max_idx]
            I_batch = torch.from_numpy(trI[selected_idx]).type(torch.FloatTensor).cuda()
            y_batch = torch.from_numpy(trY[selected_idx]).type(torch.FloatTensor).cuda()
            #forword
            y_outputs = model(I_batch)#permute the dims of matrix， .permute(0, 3, 1, 2)
            #loss
            loss = criterion(y_outputs, y_batch)
            loss.backward()
            #update parameters
            optimizer.step()
            sys.stdout.write('\r {} / {} : train loss = {}'.format(i+1, num_batches, float('%0.6f'%loss.item())))
            sys.stdout.flush()     
    #validation process
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()
    loss_val = []
    mean_loss_tensor = 0.
    num_batches = len(valY) // batchSize  +1
    model.eval()  # set network as eval mode without BN & Dropout
    with torch.autograd.no_grad():
        for j in range(num_batches):
            min_idx = j * batchSize
            max_idx = np.min([len(valY), (j+1)*batchSize])
            I_batch = torch.from_numpy(valI[min_idx:max_idx]).type(torch.FloatTensor).cuda()
            y_batch = torch.from_numpy(valY[min_idx:max_idx]).type(torch.FloatTensor).cuda()
            y_outputs = model(I_batch)#forword， .permute(0, 3, 1, 2)
            curr_loss = criterion(y_outputs, y_batch)
            gt = torch.cat((gt, y_batch), 0)
            pred = torch.cat((pred, y_outputs.data), 0)
            sys.stdout.write('\r {} / {} : validation loss = {}'.format(j + 1, num_batches, float('%0.6f'%curr_loss.item()) ) )
            sys.stdout.flush()  
            mean_loss_tensor += curr_loss  # tensor op.
            loss_val.append(curr_loss.item())
    mean_loss_tensor = mean_loss_tensor / len(valY)  # tensor
    scheduler.step(mean_loss_tensor.item())
    AUROCs = compute_AUCs(gt, pred)
    AUROC_avg =  np.array(AUROCs).mean()
    print("\r Eopch: %5d val_loss = %.6f avg_auroc= %.6f" % (epoch + 1, np.mean(loss_val), AUROC_avg)) 
    #if np.mean(loss_val) < best_loss:
    if AUROC_avg > AUROC_best:
        best_loss = np.mean(loss_val)
        AUROC_best = AUROC_avg
        best_net = copy.deepcopy(model)        
print("\r best_loss = %.6f best_auroc = %0.6f" % (best_loss, AUROC_best))
model = model.cpu()#release gpu memory
torch.cuda.empty_cache()

 Eopch:     1 val_loss = 0.174227 avg_auroc= 0.636918
 Eopch:     2 val_loss = 0.173026 avg_auroc= 0.670172
 Eopch:     3 val_loss = 0.175175 avg_auroc= 0.687696
 5899 / 7847 : train loss = 0.088487

KeyboardInterrupt: 

In [6]:
#performance of testset
# initialize the ground truth and output tensor
gt = torch.FloatTensor().cuda()
pred = torch.FloatTensor().cuda()
num_batches = len(teY) // batchSize  +1
best_net.eval()  # set network as eval mode without BN & Dropout
with torch.autograd.no_grad():
    for i in range(num_batches):
        min_idx = i * batchSize
        max_idx = np.min([len(teY), (i+1)*batchSize])
        I_batch = torch.from_numpy(teI[min_idx:max_idx]).type(torch.FloatTensor).cuda()
        y_batch = torch.from_numpy(teY[min_idx:max_idx]).type(torch.FloatTensor).cuda()
        gt = torch.cat((gt, y_batch), 0)
        y_outputs = best_net(I_batch)#forword，.permute(0, 3, 1, 2)
        pred = torch.cat((pred, y_outputs.data), 0)
        sys.stdout.write('\r {} / {} '.format(i, num_batches))
        sys.stdout.flush()

AUROCs = compute_AUCs(gt, pred)
AUROC_avg = np.array(AUROCs).mean()
print('The average AUROC is {AUROC_avg:.4f}'.format(AUROC_avg=AUROC_avg))
for i in range(N_CLASSES):
    print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROCs[i]))

 2243 / 2244 The average AUROC is 0.6770
The AUROC of Atelectasis is 0.6664
The AUROC of Cardiomegaly is 0.6957
The AUROC of Effusion is 0.7541
The AUROC of Infiltration is 0.6342
The AUROC of Mass is 0.5953
The AUROC of Nodule is 0.5806
The AUROC of Pneumonia is 0.6538
The AUROC of Pneumothorax is 0.6684
The AUROC of Consolidation is 0.7341
The AUROC of Edema is 0.8026
The AUROC of Emphysema is 0.6738
The AUROC of Fibrosis is 0.6843
The AUROC of Pleural_Thickening is 0.6295
The AUROC of Hernia is 0.7047
