In [None]:
import os
# Restrict this process to GPU 0. CUDA_VISIBLE_DEVICES must be set before CUDA is
# initialised, so it is assigned here, ahead of `import torch`; keep it above that import.
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import numpy as np
import torch
import argparse
from utils import str2bool  # project-local; used below as the argparse `type` for the boolean flags
from solver import Solver

import warnings
# Silences every warning raised from this point on. Note that it is installed only *after*
# `utils` and `solver` have been imported, so anything emitted while importing them is not
# suppressed — the captured output still shows "Using TensorFlow backend." followed by what
# appear to be warning source lines (`_np_qint8 = np.dtype(...)` etc.).
warnings.filterwarnings("ignore")


# Run configuration for the VIBI explainer. The options are declared with argparse so the
# same definitions work as a script, but inside a notebook the kernel's own sys.argv must not
# be parsed: `parse_args([])` at the bottom therefore always yields the defaults. Edit the
# defaults (or pass a list such as ['--mode', 'train']) to change a run.
parser = argparse.ArgumentParser(description='VIBI for interpretation')

# --- optimisation / model hyper-parameters ---
parser.add_argument('--epoch', default=100, type=int,
                    help='epoch number')
parser.add_argument('--lr', default=1e-6, type=float,
                    help='learning rate')
parser.add_argument('--beta', default=0.01, type=float,
                    help='beta for balance between information loss and prediction loss')
parser.add_argument('--K', default=100, type=int,
                    help='dimension of encoding Z')
parser.add_argument('--chunk_size', default=2, type=int,
                    help='chunk size. for image, chunk x chunk will be the actual chunk size')
parser.add_argument('--num_avg', default=4, type=int,
                    help='the number of samples when perform multi-shot prediction')
parser.add_argument('--batch_size', default=8, type=int,
                    help='batch size')
parser.add_argument('--tau', default=0.7, type=float,
                    help='tau')

# --- data / models ---
parser.add_argument('--env_name', default='main', type=str,
                    help='visdom env name')
# NOTE(review): the default dataset is 'mnist' while --data_dir points at a BUSI image folder;
# presumably Solver only uses this name to pick loaders/architectures — confirm.
parser.add_argument('--dataset', default='mnist', type=str,
                    help='dataset name: imdb, mnist')
parser.add_argument('--model_name', default='original_BUSI6.ckpt', type=str,
                    help='model names to be interpreted')
# NOTE(review): the default 'cnn4' is not among the types listed in the help text; check the
# accepted names in Solver and update the help accordingly.
parser.add_argument('--explainer_type', default='cnn4', type=str,
                    help='explainer types: nn, cnn for mnist')
# The default is the *string* 'None' (not the None object); presumably Solver compares against
# that string — confirm before changing it.
parser.add_argument('--approximater_type', default='None', type=str,
                    help='approximater types: nn, cnn')

# --- checkpoints / paths ---
# NOTE(review): --load_checkpoint and --checkpoint_name share the same help text. Judging by the
# names, the former is the checkpoint to resume from ('' = none) and the latter the file name used
# when saving — confirm in Solver.
parser.add_argument('--load_checkpoint', default='', type=str,
                    help='checkpoint name')
# NOTE(review): the hyper-parameters encoded in this file name (k10, le-5, tau0.6, beta0.001) do
# not match the defaults above (K=100, lr=1e-6, tau=0.7, beta=0.01); if this name is used when
# saving, runs with the current defaults are stored under a misleading name.
parser.add_argument('--checkpoint_name', default='best_acc_k10_le-5_tau0.6_beta0.001.tar', type=str,
                    help='checkpoint name')
parser.add_argument('--default_dir', default='.', type=str,
                    help='default directory path')
parser.add_argument('--data_dir', default='dataset/Dataset_BUSI_AN/train/images', type=str,
                    help='data directory path')
parser.add_argument('--summary_dir', default='summary', type=str,
                    help='summary directory path')
parser.add_argument('--checkpoint_dir', default='checkpoints', type=str,
                    help='checkpoint directory path')

# --- runtime switches ---
parser.add_argument('--cuda', default=True, type=str2bool,
                    help='enable cuda')
# `choices` rejects a mistyped mode at parse time instead of letting it reach the run cell.
parser.add_argument('--mode', default='test', type=str, choices=['train', 'test'],
                    help='train or test')
parser.add_argument('--tensorboard', default=True, type=str2bool,
                    help='enable tensorboard')
parser.add_argument('--save_image', default=True, type=str2bool,
                    help='if True, then save images')
parser.add_argument('--save_checkpoint', default=True, type=str2bool,
                    help='if True, then save checkpoint')

# Empty list: never parse the Jupyter kernel's own command-line arguments.
args = parser.parse_args([])

# Console formatting: numpy arrays and torch tensors both print floats with 4 decimals.
np.set_printoptions(precision=4)
torch.set_printoptions(precision=4)

# cuDNN on, with autotuning of convolution algorithms (pays off when input shapes stay
# constant — the logged batches are all torch.Size([8, 1, 224, 224])).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Echo the configuration as requested, before the CUDA flag is reconciled below.
print('\n[ARGUMENTS]\n', args)

# Use the GPU only when it was asked for AND one is actually present; nudge the user
# if a device exists but was not requested.
gpu_present = torch.cuda.is_available()
if gpu_present and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda True")
args.cuda = (args.cuda and gpu_present)

# Build the solver and run it in the requested mode. Both modes go through Solver.train:
# 'test' passes test=True (the captured log shows a "[VAL RESULT]" report after each epoch),
# 'train' passes test=False.
# The mode is validated *before* Solver is constructed, so an invalid value fails loudly
# instead of building the solver and then only printing a message (the previous behaviour).
MODE_TO_TEST_FLAG = {'train': False, 'test': True}
if args.mode not in MODE_TO_TEST_FLAG:
    raise ValueError(
        '--mode must be "train" or "test", got {!r}'.format(args.mode)
    )

net = Solver(args)
net.train(test=MODE_TO_TEST_FLAG[args.mode])




Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])



[ARGUMENTS]
 Namespace(K=100, approximater_type='None', batch_size=8, beta=0.01, checkpoint_dir='checkpoints', checkpoint_name='best_acc_k10_le-5_tau0.6_beta0.001.tar', chunk_size=2, cuda=True, data_dir='dataset/Dataset_BUSI_AN/train/images', dataset='mnist', default_dir='.', env_name='main', epoch=100, explainer_type='cnn4', load_checkpoint='', lr=1e-06, mode='test', model_name='original_BUSI6.ckpt', num_avg=4, save_checkpoint=True, save_image=True, summary_dir='summary', tau=0.7, tensorboard=True)
/workspace/MTL-IBA/MTL-VIBI
torch.Size([8, 1, 224, 224])
torch.Size([8, 1, 224, 224])
torch.Size([8, 1, 224, 224])
test True


[VAL RESULT]

epoch 1
global iter 125
IZY:2.53 IZX:0.64
acc:0.6719 avg_acc:0.6719
acc_fixed:0.6719 avg_acc_fixed:0.6719
vmi:-0.1008 avg_vmi:-0.1007
vmi_fixed:-0.1008 avg_vmi_fixed:-0.1008

epoch:1
Time spent is 40.505823612213135
test True


[VAL RESULT]

epoch 2
global iter 250
IZY:2.53 IZX:0.54
acc:0.6719 avg_acc:0.6719
acc_fixed:0.6719 avg_acc_fixed:0.6719
vmi:-

test True


[VAL RESULT]

epoch 30
global iter 3750
IZY:2.54 IZX:0.01
acc:0.6250 avg_acc:0.5938
acc_fixed:0.4688 avg_acc_fixed:0.4688
vmi:-0.0911 avg_vmi:-0.0947
vmi_fixed:-0.1119 avg_vmi_fixed:-0.1119

epoch:30
Time spent is 1123.9176931381226
test True


[VAL RESULT]

epoch 31
global iter 3875
IZY:2.54 IZX:0.01
acc:0.6406 avg_acc:0.6562
acc_fixed:0.5156 avg_acc_fixed:0.5156
vmi:-0.0917 avg_vmi:-0.0903
vmi_fixed:-0.1122 avg_vmi_fixed:-0.1122

epoch:31
Time spent is 1156.0101566314697


[TRAINING RESULT]

epoch 32 Time since 20m 3s
global iter 4000
i:125 IZY:2.56 IZX:0.00
acc:0.6250 avg_acc:0.7500
acc_fixed:0.5000 avg_acc_fixed:0.5000
vmi:0.0348 avg_vmi:0.0476
vmi_fixed:-0.0037 avg_vmi_fixed:-0.0143
test True


[VAL RESULT]

epoch 32
global iter 4000
IZY:2.54 IZX:0.01
acc:0.6406 avg_acc:0.5938
acc_fixed:0.5156 avg_acc_fixed:0.5156
vmi:-0.0877 avg_vmi:-0.0896
vmi_fixed:-0.1127 avg_vmi_fixed:-0.1127

epoch:32
Time spent is 1194.2665691375732
test True


[VAL RESULT]

epoch 33
global iter

test True


[VAL RESULT]

epoch 60
global iter 7500
IZY:2.64 IZX:0.00
acc:0.6875 avg_acc:0.6719
acc_fixed:0.3438 avg_acc_fixed:0.3438
vmi:0.0009 avg_vmi:-0.0238
vmi_fixed:-0.2205 avg_vmi_fixed:-0.2205

epoch:60
Time spent is 2121.4342999458313
test True


[VAL RESULT]

epoch 61
global iter 7625
IZY:2.62 IZX:0.00
acc:0.6875 avg_acc:0.6562
acc_fixed:0.3438 avg_acc_fixed:0.3438
vmi:-0.0132 avg_vmi:-0.0172
vmi_fixed:-0.2258 avg_vmi_fixed:-0.2258

epoch:61
Time spent is 2155.629629611969
test True


[VAL RESULT]

epoch 62
global iter 7750
IZY:2.63 IZX:0.00
acc:0.7188 avg_acc:0.6875
acc_fixed:0.3281 avg_acc_fixed:0.3281
vmi:-0.0049 avg_vmi:-0.0200
vmi_fixed:-0.2318 avg_vmi_fixed:-0.2318

epoch:62
Time spent is 2189.463306903839
test True


[VAL RESULT]

epoch 63
global iter 7875
IZY:2.60 IZX:0.00
acc:0.6719 avg_acc:0.6406
acc_fixed:0.3281 avg_acc_fixed:0.3281
vmi:-0.0314 avg_vmi:-0.0378
vmi_fixed:-0.2305 avg_vmi_fixed:-0.2305

epoch:63
Time spent is 2222.720778942108


[TRAINING RESULT]

epo

tensor([[-0.4353, -0.1137, -0.6157,  ...,  0.4118,  0.3569, -1.0000],
        [ 0.3725, -0.4275,  0.5529,  ...,  0.2157,  0.3098,  0.0431],
        [-0.7725, -0.2157, -0.3333,  ...,  0.2863,  0.2235,  0.2627],
        ...,
        [-0.3804, -0.3333, -0.2706,  ..., -0.6078, -0.6314, -0.6078],
        [-0.3255, -0.3412, -0.3098,  ..., -0.6314, -0.6549, -0.6627],
        [-0.3412, -0.3098, -0.2706,  ..., -0.6392, -0.6706, -0.6627]],
       device='cuda:0')
tensor([[-0.2078, -0.7725,  0.2941,  ...,  0.7412,  0.5765,  0.5059],
        [-0.9922, -0.8745,  0.5765,  ...,  0.5294,  0.4039,  0.3804],
        [-0.3804, -0.5608,  0.4275,  ...,  0.1922,  0.1373,  0.1608],
        ...,
        [-0.7725, -0.7882, -0.8353,  ..., -0.8353, -0.8196, -0.8039],
        [-0.7882, -0.7725, -0.8275,  ..., -0.8118, -0.7882, -0.7961],
        [-0.7725, -0.7882, -0.8275,  ..., -0.8196, -0.8353, -0.8353]],
       device='cuda:0')
tensor([[ 0.4902, -1.0000, -0.7412,  ...,  0.5216,  0.4510,  0.3882],
        [ 0.58

In [2]:
# Train for another round of epochs on the `net` built in the setup cell (original note:
# "train another 100 epochs"; the count presumably comes from --epoch, default 100 — confirm
# in Solver.train). The logged global iteration keeps growing across calls, suggesting the
# model/optimizer state carries over from the previous run.
# NOTE(review): this cell's saved log starts at epoch 41, while the previous cell's log reached
# epoch 63+, so the stored outputs come from different kernel sessions; re-run top to bottom
# before trusting them.
validate_each_epoch = True  # same flag the setup cell uses for --mode test
net.train(test=validate_each_epoch)




test True


[VAL RESULT]

epoch 41
global iter 5125
IZY:2.71 IZX:0.00
acc:0.7656 avg_acc:0.7656
acc_fixed:0.5781 avg_acc_fixed:0.5781
vmi:0.0600 avg_vmi:0.0505
vmi_fixed:-0.2149 avg_vmi_fixed:-0.2149

epoch:1
Time spent is 24.678877592086792
test True


[VAL RESULT]

epoch 42
global iter 5250
IZY:2.75 IZX:0.00
acc:0.8281 avg_acc:0.7656
acc_fixed:0.5938 avg_acc_fixed:0.5938
vmi:0.0751 avg_vmi:0.0523
vmi_fixed:-0.2271 avg_vmi_fixed:-0.2271

epoch:2
Time spent is 48.13144373893738
test True


[VAL RESULT]

epoch 43
global iter 5375
IZY:2.67 IZX:0.00
acc:0.7188 avg_acc:0.7500
acc_fixed:0.5469 avg_acc_fixed:0.5469
vmi:0.0285 avg_vmi:0.0518
vmi_fixed:-0.2217 avg_vmi_fixed:-0.2217

epoch:3
Time spent is 73.05545258522034
test True


[VAL RESULT]

epoch 44
global iter 5500
IZY:2.67 IZX:0.00
acc:0.7031 avg_acc:0.7969
acc_fixed:0.5625 avg_acc_fixed:0.5625
vmi:0.0269 avg_vmi:0.0610
vmi_fixed:-0.2166 avg_vmi_fixed:-0.2166

epoch:4
Time spent is 97.28594875335693
test True


[VAL RESULT]

epoch 45




[VAL RESULT]

epoch 72
global iter 9000
IZY:2.74 IZX:0.00
acc:0.7656 avg_acc:0.7969
acc_fixed:0.6250 avg_acc_fixed:0.6250
vmi:0.0778 avg_vmi:0.0560
vmi_fixed:-0.1854 avg_vmi_fixed:-0.1854

epoch:32
Time spent is 807.4938464164734
test True


[VAL RESULT]

epoch 73
global iter 9125
IZY:2.68 IZX:0.00
acc:0.7500 avg_acc:0.7812
acc_fixed:0.6250 avg_acc_fixed:0.6250
vmi:0.0298 avg_vmi:0.0747
vmi_fixed:-0.1899 avg_vmi_fixed:-0.1899

epoch:33
Time spent is 833.1035802364349
test True


[VAL RESULT]

epoch 74
global iter 9250
IZY:2.68 IZX:0.00
acc:0.7500 avg_acc:0.8281
acc_fixed:0.6250 avg_acc_fixed:0.6250
vmi:0.0317 avg_vmi:0.0985
vmi_fixed:-0.1979 avg_vmi_fixed:-0.1979

epoch:34
Time spent is 858.3779797554016
test True


[VAL RESULT]

epoch 75
global iter 9375
IZY:2.74 IZX:0.00
acc:0.7812 avg_acc:0.7969
acc_fixed:0.6562 avg_acc_fixed:0.6562
vmi:0.0882 avg_vmi:0.0655
vmi_fixed:-0.1937 avg_vmi_fixed:-0.1937

epoch:35
Time spent is 883.3417990207672
test True


[VAL RESULT]

epoch 76
global 

tensor([[-0.9686, -0.3098, -0.0039,  ...,  0.2471,  0.1922,  0.1843],
        [-0.7098, -0.8824,  0.5451,  ...,  0.2471,  0.2784,  0.2471],
        [-0.9843, -0.8510,  0.5373,  ...,  0.1451,  0.1765,  0.1529],
        ...,
        [-0.7412, -0.8275, -0.8745,  ..., -0.9765, -0.9529, -0.9765],
        [-0.7490, -0.8353, -0.8667,  ..., -0.9686, -0.9843, -1.0000],
        [-0.7961, -0.8275, -0.8275,  ..., -0.9608, -0.9608, -0.9608]],
       device='cuda:0')
tensor([[ 0.1922, -0.3490,  0.2549,  ..., -0.9922, -0.9922, -0.9922],
        [-0.8980, -0.1765,  0.2941,  ..., -0.9922, -0.9922, -0.9922],
        [ 0.4902, -0.8902,  0.2706,  ..., -0.9922, -0.9922, -0.9922],
        ...,
        [-0.8588, -0.8667, -0.8196,  ..., -0.5843, -0.5373, -0.5608],
        [-0.8118, -0.8196, -0.8745,  ..., -0.5608, -0.5059, -0.4745],
        [-0.8118, -0.8039, -0.8275,  ..., -0.6157, -0.6235, -0.6078]],
       device='cuda:0')
tensor([[-0.9216,  0.4431, -0.8118,  ...,  0.0745,  0.1373, -0.0510],
        [-0.88

In [None]:
# Same call the setup cell makes for --mode train: run Solver.train with test=False.
use_test_mode = False
net.train(test=use_test_mode)