In [1]:
import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"]='3'
import sys
import copy
import matplotlib.pyplot as plt
import matplotlib as mpl
sys.path.append('../')
import numpy as np
from return_data import return_data
#from scipy import misc
#import cv2
import torch
import torch.nn as nn
from torch.autograd import Variable
from utils import str2bool, label2binary, cuda, idxtobool, UnknownDatasetError, UnknownModelError, index_transfer, save_batch, save_batch2
from pathlib import Path
from torch.nn import functional as F
from sklearn.preprocessing import label_binarize
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
matplotlib.pyplot.switch_backend('agg')
##https://github.com/lightdogs/pytorch-smoothgrad
from lib.gradients import VanillaGrad, SmoothGrad #, GuidedBackpropGrad, GuidedBackpropSmoothGrad
from lib.image_utils import Segmentation
from lib.labels import IMAGENET_LABELS

#%%    
def parse_args():
    """Build the evaluation configuration.

    The command line is deliberately ignored (``parse_args([])``) so this can
    run inside a notebook kernel, whose own argv would otherwise be parsed as
    script options; edit the defaults below to change settings.

    Returns:
        argparse.Namespace holding the options below. ``cuda`` is forced to
        False when ``torch.cuda.is_available()`` reports no device.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset', default='mnist', type=str,
                        help='dataset name (only names containing "mnist" are supported)')
    parser.add_argument('--default_dir', default='.', type=str, help='default directory path')
    # Plain path: the previous default wrapped the path in literal double quotes.
    parser.add_argument('--data_dir', default='../mnist/dataset/Dataset_BUSI_AN/train/images', type=str,
                        help='data directory path')
    parser.add_argument('--method', default='smoothgrad', type=str,
                        help='interpretable ML method: saliency, smoothgrad')
    parser.add_argument('--batch_size', type=int, default=8, metavar='N', help='input batch size for training')
    parser.add_argument('--model_name', default='original_BUSI7.ckpt', type=str,
                        help='if train is True, model name to be saved, otherwise model name to be loaded')
    parser.add_argument('--chunk_size', default=8, type=int,
                        help='chunk size. for image, chunk x chunk will be the actual chunk size')
    parser.add_argument('--cuda', default=True, type=str2bool, help='enable cuda')
    parser.add_argument('--out_dir', type=str, default='./result/saliency/BUSI_8_50_2', help='Result directory path')
    parser.add_argument('--K', type=int, default=50,
                        help='number of features (chunks when chunk_size > 1) to select')

    args = parser.parse_args([])

    # Fall back to CPU when CUDA was requested but is not available.
    args.cuda = args.cuda and torch.cuda.is_available()
    if args.cuda:
        print("Using GPU for acceleration")
    else:
        print("Using CPU for computation")

    print('Input data: {}'.format(args.dataset))

    return args
#%%  
def main():
    """Load the test data and the trained model, then run ``test``.

    Raises:
        UnknownDatasetError: if ``args.dataset`` does not contain 'mnist'.
        ValueError: if ``chunk_size`` does not evenly divide the 224x224 image.
    """
    args = parse_args()

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.out_dir, exist_ok=True)

    ## Data Loader
    # NOTE(review): this hard-coded root overrides --data_dir; confirm which
    # attribute return_data actually reads.
    args.root = "../mnist/dataset/Dataset_BUSI_AN/train/images"
    args.load_pred = False
    args.device = torch.device("cuda" if args.cuda else "cpu")
    args.model_dir = '../' + args.dataset + '/models'
    device = args.device

    data_loader = return_data(args)
    test_loader = data_loader['test']

    if 'mnist' not in args.dataset:
        raise UnknownDatasetError()

    from Net import Net
    ## load model
    model = Net().to(device)

    args.word_idx = None
    args.original_ncol = 224
    args.original_nrow = 224
    # Non-positive chunk sizes fall back to per-pixel selection.
    args.chunk_size = max(args.chunk_size, 1)
    # Chunks must tile the image exactly for avg_pool2d / index_transfer.
    if args.original_nrow % args.chunk_size or args.original_ncol % args.chunk_size:
        raise ValueError('chunk_size {} must evenly divide the {}x{} image'.format(
            args.chunk_size, args.original_nrow, args.original_ncol))
    args.filter_size = (args.chunk_size, args.chunk_size)

    model_name = Path(args.model_dir).joinpath(args.model_name)
    # load_state_dict copies into the existing parameters, so the model stays on `device`.
    model.load_state_dict(torch.load(model_name, map_location='cpu'))

    ## Prediction
    test(args, model, device, test_loader, k=args.K)
    
    
def test(args, model, device, test_loader, k, **kargs):
    '''
    k: the number of raw features selected
    '''

    model.eval()
    # test_loss = 0
    total_num = 0
    total_num_ind = 0
    # correct = 0
    
    correct_zeropadded = 0
    precision_macro_zeropadded = 0  
    precision_micro_zeropadded = 0
    precision_weighted_zeropadded = 0
    recall_macro_zeropadded = 0
    recall_micro_zeropadded = 0
    recall_weighted_zeropadded = 0
    f1_macro_zeropadded = 0
    f1_micro_zeropadded = 0
    f1_weighted_zeropadded = 0

    vmi_zeropadded_sum = 0
    vmi_fidel_sum = 0
    vmi_fidel_fixed_sum = 0
    
    correct_approx = 0
    precision_macro_approx = 0
    precision_micro_approx = 0
    precision_weighted_approx = 0
    recall_macro_approx = 0
    recall_micro_approx = 0
    recall_weighted_approx = 0
    f1_macro_approx = 0
    f1_micro_approx = 0
    f1_weighted_approx = 0
    
    correct_approx_fixed = 0
    precision_macro_approx_fixed = 0
    precision_micro_approx_fixed = 0
    precision_weighted_approx_fixed = 0
    recall_macro_approx_fixed = 0
    recall_micro_approx_fixed = 0
    recall_weighted_approx_fixed = 0
    f1_macro_approx_fixed = 0
    f1_micro_approx_fixed = 0
    f1_weighted_approx_fixed = 0    
        
    is_cuda = args.cuda       
                        
    # outmode = "TEST"
    #
    # if outfile:
    #    assert kargs['outmode'] in ['train', 'test', 'valid']
    #    outmode = kargs['outmode']
    
    # with torch.no_grad():
        
    # predictions = []
    # predictions_idx = []

    for idx, batch in enumerate(test_loader):  # (data, target, _, _)

        if 'mnist' in args.dataset:
            num_labels = 2
            data = batch[0]
            target = batch[1]
            idx_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17,]
            
        else:
            raise UnknownDatasetError()
                   
        data, target = data.to(device), target.to(device)
        output_all, output_all2 = model(data)
        # test_loss += F.cross_entropy(output_all, target, reduction = 'sum').item()
        pred = output_all.max(-1, keepdim=True)[1] # get the index of the max log-probability
        # correct += pred.eq(target.view_as(pred)).sum().item()
        total_num += 1
        total_num_ind += data.size(0)
           
        for i in range(data.size(0)):
            
            if 'mnist' in args.dataset:
            
                ## Calculate Gradient
                input = Variable(data[i:(i+1)], requires_grad=True)
                output, _ = model(input)
                output = torch.max(output)

                if args.method == 'saliency':
                    grad, = torch.autograd.grad(output, input, retain_graph=True)
                    
                elif args.method == 'smoothgrad':
                    grad_ft = SmoothGrad(pretrained_model=model,
                                         is_cuda=args.cuda,
                                         n_samples=2,
                                         magnitude=False)
                    grad = grad_ft(input, index=None)
                    
                else:
                    UnknownModelError()
                    
                ## Select Variables
                grad_size = grad.size()
                if args.chunk_size > 1:
                    grad_chunk = F.avg_pool2d(torch.abs(grad), kernel_size=args.filter_size, stride=args.filter_size, padding=0)
                    _, index_chunk = torch.abs(grad_chunk.view(grad_size[0], grad_size[1], -1)).topk(k, dim=-1) # index_chunk:[1, 1, 30]

#                    if args.method == 'saliency':    
#                        _, index_chunk = torch.abs(grad_chunk.view(grad_size[0], grad_size[1], -1)).topk(k, dim=-1)
#                    
#                    elif args.method == 'taylor':
#                        _, index_chunk = torch.abs(grad_chunk.view(grad_size[0], grad_size[1], -1) * input.view())
                    #                        .topk(k, dim = -1)
#                        approx = torch.addcmul(torch.zeros(1), value = 1, 
#                                               tensor1 = grad_chunk.view(grad_size[0], grad_size[1], -1), 
#                                               tensor2 = input.unsqueeze(-1).view(grad_size[0], grad_size[1], -1), out=None)
#                
#                    else:
#                        UnknownModelError()
                    
                    index = index_transfer(dataset=args.dataset,
                                           idx=index_chunk,
                                           filter_size=args.filter_size,
                                           original_nrow=args.original_nrow,
                                           original_ncol=args.original_ncol,
                                           is_cuda=args.cuda).output.unsqueeze(1)
                else:
                    grad_chunk = grad
                    _, index = torch.abs(grad_chunk.view(grad_size[0], grad_size[1], grad_size[2] * grad_size[3])).topk(k, dim=-1)

                ## Approximation
                grad_selected = grad.view(grad_size[0], grad_size[1], grad_size[2] * grad_size[3])[:, :, index[0][0].type(torch.long)]
                data_selected = input.view(grad_size[0], grad_size[1], grad_size[2] * grad_size[3])[:, :, index[0][0].type(torch.long)]
            
            else:
            
                raise UnknownDatasetError()    

            if i == 0:
                grad_all = grad
                index_all = index
                grad_selected_all = grad_selected
                data_selected_all = data_selected
            else:
                grad_all = torch.cat((grad_all, grad), dim = 0) 
                index_all = torch.cat((index_all, index), dim = 0)
                grad_selected_all = torch.cat((grad_selected_all, grad_selected), dim = 0)
                data_selected_all = torch.cat((data_selected_all, data_selected), dim = 0)
                
        if 'mnist' in args.dataset:
            data_size = data.size()
            binary_selected_all = idxtobool(index_all, [data_size[0], data_size[1], data_size[2] * data_size[3]], is_cuda)            
            data_zeropadded = torch.addcmul(torch.zeros(1), value=1, tensor1=binary_selected_all.view(data_size).type(torch.FloatTensor), tensor2=data.type(torch.FloatTensor), out=None)
        
        else:
            raise UnknownDatasetError()

        # Post-hoc Accuracy (zero-padded accuracy)
        output_zeropadded, output_zeropadded2 = model(cuda(data_zeropadded, is_cuda))             
        pred_zeropadded = output_zeropadded.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct_zeropadded += pred_zeropadded.eq(pred).sum().item()
       
        pred, pred_zeropadded = pred.cpu(), pred_zeropadded.cpu()
        precision_macro_zeropadded += precision_score(pred, pred_zeropadded, average='macro')
        precision_micro_zeropadded += precision_score(pred, pred_zeropadded, average='micro')
        precision_weighted_zeropadded += precision_score(pred, pred_zeropadded, average='weighted')
        recall_macro_zeropadded += recall_score(pred, pred_zeropadded, average='macro')
        recall_micro_zeropadded += recall_score(pred, pred_zeropadded, average='micro')
        recall_weighted_zeropadded += recall_score(pred, pred_zeropadded, average='weighted')
        f1_macro_zeropadded += f1_score(pred, pred_zeropadded, average='macro')
        f1_micro_zeropadded += f1_score(pred, pred_zeropadded, average='micro')
        f1_weighted_zeropadded += f1_score(pred, pred_zeropadded, average='weighted')

        ## Variational Mutual Information            
        vmi = torch.sum(torch.addcmul(torch.zeros(1), value=1,
                                      tensor1=torch.exp(output_all).type(torch.FloatTensor),
                                      tensor2 = output_zeropadded.type(torch.FloatTensor) - torch.logsumexp(output_all, dim = 0).unsqueeze(0).expand(output_zeropadded.size()).type(torch.FloatTensor) + torch.log(torch.tensor(output_all.size(0)).type(torch.FloatTensor)),
                                      #tensor2 = output_zeropadded.type(torch.FloatTensor) - torch.sum(output_all, dim = -1).unsqueeze(-1).expand(output_zeropadded.size()).type(torch.FloatTensor),
                                      out=None), dim = -1)
        vmi_zeropadded_sum += vmi.sum().item()

        ## Approximation Fidelity (prediction performance)
        for outidx in range(num_labels):

            for i in range(data.size(0)):

                if 'mnist' in args.dataset:
                
                    ## Calculate Gradient
                    input = Variable(data[i:(i+1), :, :, :], requires_grad = True) 
                    output, output2 = model(input)
                    output = output[0][outidx]
                    if args.method == 'saliency':
                        grad, = torch.autograd.grad(output, input, retain_graph = True)
#                         autograd.grad(outputs=b,inputs=a,grad_outputs=torch.ones_like(a))
                        
                    elif args.method == 'smoothgrad':
                        grad_ft = SmoothGrad(pretrained_model = model, 
                                             is_cuda = args.cuda,
                                             n_samples = 2,
                                             magnitude = False)
                        grad = grad_ft(input, index = None)
                        
                    else:
                        UnknownModelError()

                    ## Select Variables
                    grad_size = grad.size()
                    if args.chunk_size > 1:
                        grad_chunk = F.avg_pool2d(torch.abs(grad), kernel_size = args.filter_size, stride = args.filter_size, padding = 0)
                        _, index_chunk = torch.abs(grad_chunk.view(grad_size[0], grad_size[1], -1)).topk(k, dim = -1)
                        index = index_transfer(dataset = args.dataset,
                                                     idx = index_chunk, 
                                                     filter_size = args.filter_size,
                                                     original_nrow = args.original_nrow,
                                                     original_ncol = args.original_ncol, 
                                                     is_cuda = args.cuda).output.unsqueeze(1)
                    else:
                        grad_chunk = grad
                        _, index = torch.abs(grad_chunk.view(grad_size[0], grad_size[1], grad_size[2] * grad_size[3])).topk(k, dim=-1)
                    
                    ## Approximation
                    grad_selected = grad.view(grad_size[0], grad_size[1], grad_size[2] * grad_size[3])[:, :, index[0][0].type(torch.long)]
                    data_selected = input.view(grad_size[0], grad_size[1], grad_size[2] * grad_size[3])[:, :, index[0][0].type(torch.long)]
                    
                else:
                
                    raise UnknownDatasetError()    
                # print(i)
                if i == 0:
                    grad_all = grad
                    index_all = index
                    grad_selected_all = grad_selected
                    data_selected_all = data_selected
                else:
                    grad_all = torch.cat((grad_all, grad), dim = 0) 
                    index_all = torch.cat((index_all, index), dim = 0)
                    grad_selected_all = torch.cat((grad_selected_all, grad_selected), dim = 0)
                    data_selected_all = torch.cat((data_selected_all, data_selected), dim = 0)
            
            if 'mnist' in args.dataset:
            
                approx = torch.addcmul(torch.zeros(1), value = 1, tensor1 = grad_all.view(data_size[0], data_size[1], data_size[2] * data_size[3]).type(torch.FloatTensor), tensor2 = data.view(data_size[0], data_size[1], data_size[2] * data_size[3]).type(torch.FloatTensor), out=None)
                approx = torch.exp(torch.sum(approx, dim = -1))##squeeze(-1)
                approx_fixed = torch.addcmul(torch.zeros(1), value=1, tensor1 = grad_selected_all.type(torch.FloatTensor) , tensor2 = data_selected_all.type(torch.FloatTensor), out=None)
                approx_fixed = torch.exp(torch.sum(approx_fixed, dim = -1)) #.squeeze(-1)
                
            else:
                
                raise UnknownDatasetError()   
            
            if outidx == 0:
                approx_all = approx
                approx_fixed_all = approx_fixed
            else:
                approx_all = torch.cat((approx_all, approx), dim = 1)
                approx_fixed_all = torch.cat((approx_fixed_all, approx_fixed), dim = 1)

        pred = pred.type(torch.LongTensor)
        pred_approx = approx_all.topk(1, dim = -1)[1]
        pred_approx = pred_approx.type(torch.LongTensor)
        pred_approx_fixed = approx_fixed_all.topk(1, dim = -1)[1]
        pred_approx_fixed = pred_approx_fixed.type(torch.LongTensor)
        pred_approx_logit = F.softmax(torch.log(approx_all), dim=1)
        pred_approx_fixed_logit = F.softmax(torch.log(approx_fixed_all), dim = -1)
  
        correct_approx += pred_approx.eq(pred).sum().item()
        precision_macro_approx += precision_score(pred, pred_approx, average = 'macro')  
        precision_micro_approx += precision_score(pred, pred_approx, average = 'micro')  
        precision_weighted_approx += precision_score(pred, pred_approx, average = 'weighted')
        recall_macro_approx += recall_score(pred, pred_approx, average = 'macro')
        recall_micro_approx += recall_score(pred, pred_approx, average = 'micro')
        recall_weighted_approx += recall_score(pred, pred_approx, average = 'weighted')
        f1_macro_approx += f1_score(pred, pred_approx, average = 'macro')
        f1_micro_approx += f1_score(pred, pred_approx, average = 'micro')
        f1_weighted_approx += f1_score(pred, pred_approx, average = 'weighted')
        
        correct_approx_fixed += pred_approx_fixed.eq(pred).sum().item()
        precision_macro_approx_fixed += precision_score(pred, pred_approx_fixed, average = 'macro')  
        precision_micro_approx_fixed += precision_score(pred, pred_approx_fixed, average = 'micro')  
        precision_weighted_approx_fixed += precision_score(pred, pred_approx_fixed, average = 'weighted')
        recall_macro_approx_fixed += recall_score(pred, pred_approx_fixed, average = 'macro')
        recall_micro_approx_fixed += recall_score(pred, pred_approx_fixed, average = 'micro')
        recall_weighted_approx_fixed += recall_score(pred, pred_approx_fixed, average = 'weighted')
        f1_macro_approx_fixed += f1_score(pred, pred_approx_fixed, average = 'macro')
        f1_micro_approx_fixed += f1_score(pred, pred_approx_fixed, average = 'micro')
        f1_weighted_approx_fixed += f1_score(pred, pred_approx_fixed, average = 'weighted')    
        
        ## Variational Mutual Information    
        vmi = torch.sum(torch.addcmul(torch.zeros(1), value=1,
                                      tensor1=torch.exp(output_all).type(torch.FloatTensor),
                                      tensor2 = pred_approx_logit.type(torch.FloatTensor) - torch.logsumexp(output_all, dim = 0).unsqueeze(0).expand(pred_approx_logit.size()).type(torch.FloatTensor) + torch.log(torch.tensor(output_all.size(0)).type(torch.FloatTensor)),
                                      out=None), dim = -1)
        vmi_fidel_sum += vmi.sum().item()

        vmi = torch.sum(torch.addcmul(torch.zeros(1), value = 1, 
                                      tensor1 = torch.exp(output_all).type(torch.FloatTensor),
                                      tensor2 = pred_approx_fixed_logit.type(torch.FloatTensor) - torch.logsumexp(output_all, dim = 0).unsqueeze(0).expand(pred_approx_fixed_logit.size()).type(torch.FloatTensor) + torch.log(torch.tensor(output_all.size(0)).type(torch.FloatTensor)),
                                      out=None), dim = -1)
        vmi_fidel_fixed_sum += vmi.sum().item()            

#        #if (idx == 0 or idx == 200): ## figure
#        if idx in idx_list:
#
#            filename = 'figure_saliency_' + args.dataset + '_' + str(k) + '_idx' + str(idx) + '.png'
#            filename = Path(args.out_dir).joinpath(filename)
#    
#            #img = copy.deepcopy(grad_all)    
#            img = copy.deepcopy(data)
#            n_img = img.size(0)
#            n_col = 8
#            n_row = n_img // n_col + 1
#    
#            fig = plt.figure(figsize=(n_col * 1.5, n_row * 1.5)) 
#    
#            for i in range(n_img):
#    
#                plt.subplot(n_row, n_col, 1 + i)
#                plt.axis('off')
#                # original image
#                img0 = img[i].squeeze(0)#.numpy()
#                plt.imshow(img0, cmap = 'autumn_r')
#                # chunk selected
#                img2 = img[i].view(-1)#.numpy()
#                img2[index_all[i]] = cuda(torch.tensor(float('nan')), is_cuda)
#                img2 = img2.view(img0.size())#.numpy()
#                plt.title('BB {}, Apx {}'.format(pred[i].item(), pred[i].item()))
#                plt.imshow(img2, cmap = 'gray')
#    
#            fig.subplots_adjust(wspace = 0.05, hspace = 0.35)       
#            fig.savefig(filename)
#            
#            ## Save predictions
#            #predictions.extend(pred.data.squeeze(-1).cpu().tolist())
#            #predictions_idx.extend(idx.cpu().tolist())

        #print("SAVED!!!!")
#         if idx in idx_list:
        if idx in range(len(test_loader)):
            # filename
            filename = 'figure_'+ args.method + '_' + args.dataset + '_chunk' + str(args.chunk_size) + '_' + str(k) + '_idx' + str(idx) + '.png'
            filename = Path(args.out_dir).joinpath(filename)
            index_chunk = index_all
            
#            if args.chunk_size is not 1:
#                
#                index_chunk = index_transfer(dataset = args.dataset,
#                                             idx = index_chunk, 
#                                             filter_size = args.filter_size,
#                                             original_nrow = args.original_nrow,
#                                             original_ncol = args.original_ncol, 
#                                             is_cuda = args.cuda).output
            
            save_batch(dataset = args.dataset, 
                       batch = data, label = target, label_pred = pred.squeeze(-1), label_approx = pred_approx_fixed.squeeze(-1),
                       index = index_chunk, 
                       filename = filename, 
                       is_cuda = args.cuda,
                       word_idx = args.word_idx).output
#%%                                 
    ## Post-hoc Accuracy (zero-padded accuracy)
    accuracy_zeropadded = correct_zeropadded/total_num_ind
    precision_macro_zeropadded = precision_macro_zeropadded/total_num
    precision_micro_zeropadded = precision_micro_zeropadded/total_num
    precision_weighted_zeropadded = precision_weighted_zeropadded/total_num
    recall_macro_zeropadded = recall_macro_zeropadded/total_num
    recall_micro_zeropadded = recall_micro_zeropadded/total_num
    recall_weighted_zeropadded = recall_weighted_zeropadded/total_num
    f1_macro_zeropadded = f1_macro_zeropadded/total_num
    f1_micro_zeropadded = f1_micro_zeropadded/total_num
    f1_weighted_zeropadded = f1_weighted_zeropadded/total_num
    
    ## VMI
    vmi_zeropadded = vmi_zeropadded_sum/total_num_ind
    vmi_fidel = vmi_fidel_sum / total_num_ind
    vmi_fidel_fixed = vmi_fidel_fixed_sum / total_num_ind
    
    ## Approximation Fidelity (prediction performance)
    accuracy_approx = correct_approx/total_num_ind
    precision_macro_approx = precision_macro_approx/total_num
    precision_micro_approx = precision_micro_approx/total_num
    precision_weighted_approx = precision_weighted_approx/total_num
    recall_macro_approx = recall_macro_approx/total_num
    recall_micro_approx = recall_micro_approx/total_num
    recall_weighted_approx = recall_weighted_approx/total_num
    f1_macro_approx = f1_macro_approx/total_num
    f1_micro_approx = f1_micro_approx/total_num
    f1_weighted_approx = f1_weighted_approx/total_num
    
    accuracy_approx_fixed = correct_approx_fixed/total_num_ind
    precision_macro_approx_fixed = precision_macro_approx_fixed/total_num
    precision_micro_approx_fixed = precision_micro_approx_fixed/total_num
    precision_weighted_approx_fixed = precision_weighted_approx_fixed/total_num
    recall_macro_approx_fixed = recall_macro_approx_fixed/total_num
    recall_micro_approx_fixed = recall_micro_approx_fixed/total_num
    recall_weighted_approx_fixed = recall_weighted_approx_fixed/total_num
    f1_macro_approx_fixed = f1_macro_approx_fixed/total_num
    f1_micro_approx_fixed = f1_micro_approx_fixed/total_num
    f1_weighted_approx_fixed = f1_weighted_approx/total_num

    print('\n\n[VAL RESULT]\n')
    # Post-hoc Accuracy (zero-padded accuracy); 

    print('acc_zeropadded:{:.4f} avg_acc:{:.4f} avg_acc_fixed:{:.4f}'
            .format(accuracy_zeropadded, accuracy_approx, accuracy_approx_fixed), end = '\n')
    print('precision_macro_zeropadded:{:.4f} precision_macro_approx:{:.4f} precision_macro_approx_fixed:{:.4f}'
            .format(precision_macro_zeropadded, precision_macro_approx, precision_macro_approx_fixed), end = '\n')   
    print('precision_micro_zeropadded:{:.4f} precision_micro_approx:{:.4f} precision_micro_approx_fixed:{:.4f}'
            .format(precision_micro_zeropadded, precision_micro_approx, precision_micro_approx_fixed), end = '\n')   
    print('recall_macro_zeropadded:{:.4f} recall_macro_approx:{:.4f} recall_macro_approx_fixed:{:.4f}'
            .format(recall_macro_zeropadded, recall_macro_approx, recall_macro_approx_fixed), end = '\n')   
    print('recall_micro_zeropadded:{:.4f} recall_micro_approx:{:.4f} recall_micro_approx_fixed:{:.4f}'
            .format(recall_micro_zeropadded, recall_micro_approx, recall_micro_approx_fixed), end = '\n') 
    print('f1_macro_zeropadded:{:.4f} f1_macro_approx:{:.4f} f1_macro_approx_fixed:{:.4f}'
            .format(f1_macro_zeropadded, f1_macro_approx, f1_macro_approx_fixed), end = '\n')   
    print('f1_micro_zeropadded:{:.4f} f1_micro_approx:{:.4f} f1_micro_approx_fixed:{:.4f}'
            .format(f1_micro_zeropadded, f1_micro_approx_fixed, f1_micro_approx_fixed), end = '\n') 
    print('vmi:{:.4f} vmi_fixed:{:.4f} vmi_zeropadded:{:.4f}'.format(vmi_fidel, vmi_fidel_fixed, vmi_zeropadded), end = '\n')
    print()
    
#%%
#        if outfile:
#            
#            predictions = np.array(predictions)
#            predictions_idx = np.array(predictions_idx)
#            inds = predictions_idx.argsort()
#            sorted_predictions = predictions[inds]
#
#            output_name = model_name + '_pred_' + outmode + '.pt'
#            torch.save(sorted_predictions, Path(outfile_path).joinpath(output_name))                

# Entry point: run the full evaluation when this file is executed directly.
if __name__ == '__main__':
    main()







Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Using GPU for acceleration
Input data: mnist


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


tensor([[-0.1922, -0.1843, -0.1608,  ..., -0.2392, -0.2392, -0.2314],
        [-0.2627, -0.2627, -0.2471,  ..., -0.2157, -0.2235, -0.2235],
        [-0.3020, -0.3098, -0.3098,  ..., -0.1686, -0.1922, -0.2000],
        ...,
        [-0.9294, -0.9294, -0.9451,  ..., -0.8039, -0.8039, -0.7961],
        [-0.8431, -0.8667, -0.8902,  ..., -0.8039, -0.7961, -0.7882],
        [-0.8039, -0.8039, -0.8039,  ..., -0.8275, -0.8353, -0.8353]],
       device='cuda:0')
tensor([[-0.9922, -0.9922, -0.9922,  ..., -0.9529, -0.1765, -0.9922],
        [-0.9922, -0.9922, -0.9922,  ..., -0.9529, -0.1765, -0.9843],
        [-0.9922, -0.9922, -0.9922,  ..., -0.9608, -0.1765, -0.9922],
        ...,
        [-0.5922, -0.6000, -0.5608,  ..., -0.7882, -0.8196, -0.8039],
        [-0.6000, -0.5608, -0.5529,  ..., -0.8039, -0.7647, -0.7725],
        [-0.5294, -0.5922, -0.5922,  ..., -0.7647, -0.7961, -0.8431]],
       device='cuda:0')
tensor([[-0.7961, -0.8745,  0.4824,  ...,  0.6314,  0.7098,  0.7569],
        [-0.93

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


tensor([[-0.4353, -0.1137, -0.6157,  ...,  0.4118,  0.3569, -1.0000],
        [ 0.3725, -0.4275,  0.5529,  ...,  0.2157,  0.3098,  0.0431],
        [-0.7725, -0.2157, -0.3333,  ...,  0.2863,  0.2235,  0.2627],
        ...,
        [-0.3804, -0.3333, -0.2706,  ..., -0.6078, -0.6314, -0.6078],
        [-0.3255, -0.3412, -0.3098,  ..., -0.6314, -0.6549, -0.6627],
        [-0.3412, -0.3098, -0.2706,  ..., -0.6392, -0.6706, -0.6627]],
       device='cuda:0')
tensor([[-0.2078, -0.7725,  0.2941,  ...,  0.7412,  0.5765,  0.5059],
        [-0.9922, -0.8745,  0.5765,  ...,  0.5294,  0.4039,  0.3804],
        [-0.3804, -0.5608,  0.4275,  ...,  0.1922,  0.1373,  0.1608],
        ...,
        [-0.7725, -0.7882, -0.8353,  ..., -0.8353, -0.8196, -0.8039],
        [-0.7882, -0.7725, -0.8275,  ..., -0.8118, -0.7882, -0.7961],
        [-0.7725, -0.7882, -0.8275,  ..., -0.8196, -0.8353, -0.8353]],
       device='cuda:0')
tensor([[ 0.4902, -1.0000, -0.7412,  ...,  0.5216,  0.4510,  0.3882],
        [ 0.58

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


tensor([[-0.9216,  0.4431, -0.8118,  ...,  0.0745,  0.1373, -0.0510],
        [-0.8824, -0.1216, -0.5137,  ...,  0.2157,  0.2235,  0.1294],
        [-0.9843, -0.1216, -0.5608,  ...,  0.0824,  0.1059,  0.1216],
        ...,
        [-0.8039, -0.8196, -0.8118,  ..., -0.8118, -0.8353, -0.8353],
        [-0.8275, -0.8196, -0.7804,  ..., -0.8196, -0.8275, -0.8196],
        [-0.7882, -0.8039, -0.8039,  ..., -0.8275, -0.8275, -0.8039]],
       device='cuda:0')
tensor([[-0.1686, -0.0353,  0.0667,  ...,  0.1922,  0.1686,  0.0510],
        [-0.1059, -0.0667, -0.0510,  ...,  0.0275, -0.0824, -0.1922],
        [-0.1922, -0.0196, -0.0275,  ..., -0.1059, -0.1529, -0.1216],
        ...,
        [-0.4588, -0.4431, -0.3647,  ..., -0.1451, -0.0824, -0.0118],
        [-0.5216, -0.4980, -0.4667,  ..., -0.0980, -0.0039, -0.0039],
        [-0.6078, -0.5765, -0.5294,  ..., -0.0902, -0.0902, -0.1294]],
       device='cuda:0')
tensor([[-0.3961, -0.4118, -0.4118,  ..., -0.9922, -0.9922, -0.9922],
        [-0.28

  fig = plt.figure(frameon=False)


tensor([[-0.9922, -0.9765,  0.2784,  ...,  0.3569,  0.3333,  0.3725],
        [-0.9843, -0.4431,  0.4510,  ...,  0.3961,  0.3647,  0.3255],
        [ 0.1922,  0.5451,  0.4745,  ...,  0.3255,  0.3098,  0.2392],
        ...,
        [-0.8353, -0.8588, -0.8667,  ..., -0.8275, -0.7882, -0.8118],
        [-0.8353, -0.7961, -0.7961,  ..., -0.7961, -0.8275, -0.8431],
        [-0.7961, -0.8196, -0.8275,  ..., -0.7961, -0.8353, -0.8431]],
       device='cuda:0')
tensor([[-0.9922, -0.9922, -0.0510,  ..., -0.9922, -0.9922, -0.9922],
        [ 0.4902, -0.3882, -0.0980,  ..., -0.9922, -0.9922, -0.9922],
        [-0.6235,  0.4667, -0.9765,  ..., -0.9922, -0.9922, -0.9922],
        ...,
        [-0.8039, -0.8039, -0.8431,  ..., -0.7333, -0.7098, -0.6706],
        [-0.8196, -0.7882, -0.7725,  ..., -0.6235, -0.6000, -0.6392],
        [-0.8275, -0.8118, -0.7961,  ..., -0.6549, -0.7020, -0.7569]],
       device='cuda:0')


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


tensor([[-0.9843, -0.9843, -0.9922,  ..., -0.0275, -0.0745, -0.0980],
        [ 0.0980, -0.8902, -0.9922,  ..., -0.0196,  0.0431, -0.0039],
        [ 0.5216, -0.5451, -0.8980,  ...,  0.0745,  0.0667,  0.0196],
        ...,
        [-0.7569, -0.7569, -0.7569,  ..., -0.7647, -0.7412, -0.7333],
        [-0.7569, -0.7725, -0.7647,  ..., -0.6941, -0.6706, -0.6627],
        [-0.7882, -0.7961, -0.7882,  ..., -0.6863, -0.7176, -0.7412]],
       device='cuda:0')
tensor([[-0.8588,  0.0118,  0.4431,  ...,  0.4510,  0.5373,  0.5686],
        [-0.2941,  0.0745, -0.7725,  ...,  0.2941,  0.3804,  0.2471],
        [ 0.0588, -0.5765, -0.7490,  ...,  0.1529,  0.0667,  0.0353],
        ...,
        [-0.8588, -0.8980, -0.9216,  ..., -0.8824, -0.8667, -0.8824],
        [-0.8510, -0.8902, -0.9137,  ..., -0.8745, -0.8588, -0.8667],
        [-0.8431, -0.8824, -0.9137,  ..., -0.8902, -0.8745, -0.8824]],
       device='cuda:0')
tensor([[ 0.2471, -0.7804, -0.7098,  ...,  0.6941,  0.5843,  0.5059],
        [ 0.12

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


tensor([[ 0.5765, -0.2706,  0.8353,  ...,  0.7725,  0.7725,  0.7961],
        [ 0.3098,  0.4353,  0.7412,  ...,  0.7098,  0.6941,  0.7176],
        [-0.9216, -0.3647,  0.6235,  ...,  0.6235,  0.5686,  0.5843],
        ...,
        [-0.8431, -0.8510, -0.8824,  ..., -0.8431, -0.7333, -0.7020],
        [-0.8745, -0.9137, -0.9137,  ..., -0.8118, -0.7412, -0.7176],
        [-0.8196, -0.8118, -0.8275,  ..., -0.7961, -0.7333, -0.7176]],
       device='cuda:0')
tensor([[-0.9529,  0.7020,  0.6784,  ...,  0.5294,  0.4510,  0.3961],
        [-0.6627,  0.6863,  0.6784,  ...,  0.4353,  0.4510,  0.4902],
        [-0.4824,  0.5216,  0.6549,  ...,  0.4196,  0.3882,  0.3333],
        ...,
        [-0.8196, -0.7961, -0.7961,  ..., -0.8275, -0.8353, -0.8431],
        [-0.8667, -0.8353, -0.8275,  ..., -0.8431, -0.8510, -0.8431],
        [-0.8510, -0.8588, -0.8510,  ..., -0.7961, -0.8196, -0.8275]],
       device='cuda:0')
tensor([[-0.9922, -0.9765, -0.0275,  ...,  0.7020,  0.7412,  0.8275],
        [ 0.52

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


tensor([[-0.6863,  0.4353,  0.5922,  ...,  0.4510,  0.4824,  0.4745],
        [ 0.4039, -0.6941, -0.9529,  ...,  0.4039,  0.3412,  0.3176],
        [ 0.5686, -0.9843, -0.9294,  ...,  0.2392,  0.1686,  0.1922],
        ...,
        [-0.8824, -0.8824, -0.8431,  ..., -0.4275, -0.3490, -0.2706],
        [-0.8902, -0.8980, -0.8667,  ..., -0.5294, -0.4745, -0.4431],
        [-0.8745, -0.8745, -0.8667,  ..., -0.6471, -0.6314, -0.6314]],
       device='cuda:0')
tensor([[-0.0588, -0.0745,  0.3176,  ...,  0.5451,  0.5608,  0.4667],
        [-0.9922, -0.9843,  0.3882,  ...,  0.3804,  0.3098,  0.2392],
        [-0.9843, -0.9922,  0.3098,  ...,  0.3725,  0.3725,  0.3412],
        ...,
        [-0.8118, -0.8431, -0.8431,  ..., -0.8353, -0.8431, -0.8353],
        [-0.7882, -0.7490, -0.7725,  ..., -0.8353, -0.8353, -0.8588],
        [-0.7961, -0.8118, -0.8431,  ..., -0.8196, -0.8353, -0.8196]],
       device='cuda:0')
tensor([[ 0.5843,  0.3882, -1.0000,  ..., -0.9922, -0.9922, -0.9922],
        [-0.96

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


tensor([[-0.6941, -0.9843, -0.5765,  ...,  0.8118,  0.7961,  0.6627],
        [ 0.5451, -0.0039,  0.0196,  ...,  0.7804,  0.7569,  0.6471],
        [-0.3882, -0.3176,  0.5922,  ...,  0.6314,  0.6157,  0.6392],
        ...,
        [-0.8902, -0.8902, -0.9059,  ..., -0.8196, -0.8275, -0.8118],
        [-0.8353, -0.8118, -0.8510,  ..., -0.7333, -0.7882, -0.8039],
        [-0.8196, -0.8667, -0.8745,  ..., -0.7804, -0.8039, -0.8039]],
       device='cuda:0')
tensor([[-0.9922, -0.9922, -0.9765,  ...,  0.7255,  0.7725,  0.7255],
        [-0.9843,  0.5059,  0.1686,  ...,  0.8118,  0.8510,  0.7255],
        [-0.6471, -0.5686, -0.9686,  ...,  0.7804,  0.6706,  0.5529],
        ...,
        [-0.8588, -0.8275, -0.7961,  ..., -0.6941, -0.6627, -0.7020],
        [-0.8824, -0.8431, -0.7647,  ..., -0.7255, -0.7412, -0.7882],
        [-0.8118, -0.8039, -0.7333,  ..., -0.7725, -0.7569, -0.7569]],
       device='cuda:0')
tensor([[ 0.2549,  0.1922,  0.1608,  ...,  0.2314,  0.2157,  0.2549],
        [ 0.12

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


tensor([[-0.9843, -0.6000, -0.0745,  ...,  0.7176,  0.6941,  0.5608],
        [-0.7647,  0.4980, -0.4824,  ...,  0.5843,  0.5843,  0.4275],
        [ 0.0275, -0.3804, -0.0039,  ...,  0.4275,  0.3804,  0.4431],
        ...,
        [-0.7961, -0.7725, -0.7569,  ..., -0.8353, -0.8196, -0.8196],
        [-0.8353, -0.8275, -0.7961,  ..., -0.7961, -0.7961, -0.8510],
        [-0.8745, -0.8431, -0.8745,  ..., -0.8745, -0.8431, -0.8353]],
       device='cuda:0')
tensor([[-0.0039,  0.0667,  0.1529,  ...,  0.0745,  0.0667,  0.0196],
        [-0.2235, -0.0667,  0.0431,  ...,  0.1294,  0.1059,  0.0353],
        [-0.2784, -0.1608, -0.0510,  ...,  0.0353,  0.0196, -0.0431],
        ...,
        [-0.9059, -0.9137, -0.9137,  ..., -0.8902, -0.8824, -0.9137],
        [-0.9137, -0.9294, -0.9373,  ..., -0.8667, -0.8588, -0.8588],
        [-0.9216, -0.9294, -0.9373,  ..., -0.9294, -0.9294, -0.9137]],
       device='cuda:0')
tensor([[-0.9451,  0.8118,  0.7412,  ...,  0.8980,  0.7961,  0.7569],
        [-0.65