In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import time
import random
import datetime
import os
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
from torch import Tensor
import functools
from torchvision import datasets, transforms
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
import torchvision
from model.drn import drn_d_54, drn_d_base
from model.CBAM import drn_d_CBAM
from model.resnet import ResNet50 as resnet
from load_data import *
from optimizer import Ranger
from collections import Counter
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import recall_score,accuracy_score, precision_score,f1_score


Matplotlib created a temporary config/cache directory at /tmp/matplotlib-7tllr00w because the default path (/home/ubuntu/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.
In /home/ubuntu/tensorflow/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The text.latex.preview rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /home/ubuntu/tensorflow/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The mathtext.fallback_to_cm rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /home/ubuntu/tensorflow/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: Support for setting the 'mathtext.fallback_to_cm' rcParam is dep

PyTorch Version:  1.5.0+cu101
Torchvision Version:  0.6.0+cu101


In [15]:
class A_net(nn.Module):
    """Small convolutional branch whose output has ``classes`` channels.

    The pipeline is:
    - a 1x1 ``nn.Conv2d`` mapping 1 input channel to 3;
    - five ``conv_set`` stages with channel widths 3 -> 12 -> 16 -> 32 -> 128 -> ``classes``;
    - a 2x2 / stride-2 max-pool after stages 1, 2, 3 and 5.

    ``conv_set`` is defined outside this cell. Its positional arguments are
    presumably (in_ch, out_ch, kernel, stride, padding) — TODO confirm.
    """

    def __init__(self, classes):
        super(A_net, self).__init__()
        self.classes = classes
        # Attribute names/order are significant: they become state_dict keys.
        self.conv = nn.Conv2d(1, 3, kernel_size=1, stride=1, padding=0)
        self.conv1 = conv_set(3, 12, 3, 1, 0)
        self.conv2 = conv_set(12, 16, 3, 1, 0)
        self.conv3 = conv_set(16, 32, 3, 1, 0)
        self.conv4 = conv_set(32, 128, 3, 1, 0)
        self.conv5 = conv_set(128, classes, 3, 1, 0)

    def forward(self, x):
        feats = self.conv(x)
        # Stages 1-3: conv_set followed by a 2x2 max-pool each.
        for stage in (self.conv1, self.conv2, self.conv3):
            feats = F.max_pool2d(stage(feats), 2, 2)
        # Stages 4 and 5 run back to back; only one pool follows them.
        feats = self.conv5(self.conv4(feats))
        return F.max_pool2d(feats, 2, 2)
        
class CNNX(nn.Module):
    """DRN-D-54 features gated element-wise by an ``A_net`` attention map, then classified.

    Args:
        backbone: Only ``'drn'`` is implemented. Any other value currently also
            builds the DRN-D-54 backbone (kept unchanged for compatibility).
        out_stride: Unused; accepted for API compatibility only.
        num_class: Number of logits produced by the final linear layer.
            Previously ignored (the layer always had 2 outputs); the default
            of 2 preserves the old behaviour.
    """

    def __init__(self, backbone='drn', out_stride=16, num_class=2):
        super(CNNX, self).__init__()
        # NOTE(review): ``backbone`` and ``out_stride`` do not influence the
        # architecture; the DRN-D-54 backbone is always built.
        self.drn = drn_d_54(nn.BatchNorm2d)
        # 512-channel attention map; presumably matches the DRN feature depth
        # so the element-wise product in forward() lines up — TODO confirm.
        self.attention = A_net(512)
        # Kernel 36 / stride 1 ties the model to one input resolution.
        self.avgpool = nn.AvgPool2d(36, stride=1)
        # 430592 == 512 * 29 * 29, presumably the flattened size of the gated
        # features at the training resolution — other input sizes will fail here.
        self.fc = nn.Linear(430592, num_class)

    def forward(self, input):
        """Return ``(N, num_class)`` logits for a batch ``input``."""
        x = self.drn(input)
        x = self.avgpool(x)

        # The attention branch sees the raw input, not the DRN features.
        atten = self.attention(input)
        atten = torch.sigmoid(atten)
        x = torch.mul(x, atten)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

In [16]:
# Intermediate feature extraction
class FeatureExtractor(nn.Module):
    """Replay a model's direct children in order and collect selected outputs.

    The direct children of ``submodule`` are executed one after another in
    registration order; ``submodule.forward`` is never called. The result is
    only faithful for models whose forward pass is exactly that sequential chain.

    Args:
        submodule: Model whose direct children are executed in order.
        extracted_layers: Names of children whose outputs are returned.
        verbose: If True, print each child's name and module as it runs.
            This is opt-in; the previous version always printed this debug output.
    """

    def __init__(self, submodule, extracted_layers, verbose=False):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers
        self.verbose = verbose

    def forward(self, x):
        """Return the outputs of the children named in ``extracted_layers``, in execution order."""
        outputs = []
        for name, module in self.submodule._modules.items():
            if "fc" in name:
                # Any child whose name contains "fc" receives an (N, -1) input.
                # reshape (unlike view) also works on non-contiguous tensors.
                x = x.reshape(x.size(0), -1)
            if self.verbose:
                print('name', name, module)
            x = module(x)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs


def get_feature(max_channels=32):  # feature visualization
    """Plot the ``conv1`` activation maps of a trained ``LeNet`` for one image.

    This relies on names that are not defined in this cell: ``get_picture``,
    ``pic_dir``, ``transform``, ``device``, ``LeNet`` and ``FeatureExtractor``.
    They are presumably provided by ``from load_data import *`` or other cells
    — TODO confirm. It also needs the checkpoint ``./model/net_050.pth``.

    Args:
        max_channels: Upper bound on how many conv1 channels to draw. The
            original always tried to draw 32. If conv1 has fewer channels,
            only that many are drawn instead of raising IndexError.
    """
    import math

    # Load one image and add a batch dimension.
    # Original note: image is [3, 256, 256] -> [1, 3, 256, 256] (not verified here).
    img = get_picture(pic_dir, transform)
    img = img.unsqueeze(0).to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    net = LeNet().to(device)
    net.load_state_dict(torch.load('./model/net_050.pth', map_location=device))
    net.eval()  # inference behaviour for any dropout / batch-norm layers

    exact_list = ['conv1']
    myexactor = FeatureExtractor(net, exact_list)
    with torch.no_grad():  # no autograd graph needed for visualization
        x = myexactor(img)

    # conv1 activations as a NumPy array, converted once instead of per channel.
    feature_maps = x[0].detach().cpu().numpy()
    n_show = min(max_channels, feature_maps.shape[1])
    grid = max(1, math.ceil(math.sqrt(max_channels)))  # 6x6 for the default 32

    fig = plt.figure()
    # Shared title for the whole figure; previously it overwrote each axis title.
    fig.suptitle('new—conv1-image')
    for i in range(n_show):
        ax = fig.add_subplot(grid, grid, i + 1)
        ax.set_title('Feature {}'.format(i))
        ax.axis('off')
        ax.imshow(feature_maps[0, i, :, :], cmap='jet')

    plt.show()
