In [None]:
import torch as t
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from collections import OrderedDict
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt

from sklearn.metrics import accuracy_score,precision_score,recall_score,roc_curve,auc,average_precision_score,precision_recall_curve

In [None]:
class ResLayer(nn.Sequential):
    """Residual unit: conv(ks) -> BN -> LeakyReLU -> conv3x3 -> BN, plus a skip connection.

    Args:
        in_channel: channels of the input tensor.
        out_channel: channels produced by both convolutions.
        ks: kernel size of the first convolution.
        p: padding of the first convolution.
        downsample: optional module applied to the skip path only; it must map the
            input to the main path's output shape when channels or stride change.
        s: stride of the first convolution.

    NOTE(review): ``self.relu`` is registered as a child module, so the original
    ``nn.Sequential.forward`` also applied it to the main path right after ``norm2``
    (i.e. before the residual addition). That extra activation is preserved here so
    models saved as whole pickled modules keep producing the same outputs.
    """

    def __init__(self,in_channel,out_channel,ks = 1,p = 0,downsample = None,s = 1):
        super(ResLayer,self).__init__()
        self.add_module("conv1",nn.Conv2d(in_channel,out_channel,kernel_size=ks,padding=p,bias=False,stride=s))
        self.add_module("norm1",nn.BatchNorm2d(out_channel))
        self.add_module("relu1",nn.LeakyReLU())
        self.add_module("conv2",nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1,bias=False,stride=1))
        self.add_module("norm2",nn.BatchNorm2d(out_channel))
        self.downsample = downsample
        self.relu = nn.LeakyReLU()

    def forward(self,x):
        # The main path is run explicitly instead of via nn.Sequential.forward, which
        # iterates over *every* registered child - including ``downsample`` - and so
        # would apply ``downsample`` to the main path as well as the skip path.
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.norm2(out)
        # Kept for checkpoint compatibility: the Sequential chain ran ``relu`` here too.
        out = self.relu(out)

        residual = x if self.downsample is None else self.downsample(x)
        # Out-of-place add, then the post-addition activation.
        return self.relu(out + residual)

class ResBlock(nn.Sequential):
    """Stack of ``num_layers`` ResLayer units registered as ``reslayer1`` ... ``reslayerN``.

    Only the first layer may change the tensor shape: it receives ``in_channel``,
    ``downsample`` and stride ``s``. Every following layer maps ``out_channel`` to
    ``out_channel`` with stride 1 and an identity skip, so the stack stays
    shape-consistent when ``in_channel != out_channel``. (Previously every layer took
    ``in_channel`` as input and shared one ``downsample`` instance, which only worked
    when ``in_channel == out_channel``, no downsample and ``s == 1``; that case builds
    exactly the same modules as before.)

    Args:
        num_layers: number of ResLayer units.
        in_channel: channels entering the block.
        out_channel: channels produced by every layer.
        ks, p: kernel size / padding of each layer's first convolution.
        downsample: optional skip-path module for the first layer only.
        s: stride of the first layer's first convolution.
    """
    def __init__(self,num_layers,in_channel,out_channel,ks = 1,p = 0,downsample = None,s = 1):
        super(ResBlock,self).__init__()
        for i in range(num_layers):
            first = i == 0
            layer = ResLayer(
                in_channel if first else out_channel,
                out_channel,
                ks,
                p,
                downsample if first else None,
                s if first else 1,
            )
            self.add_module("reslayer%d" % (i+1),layer)

class ResClass(nn.Module):
    """Residual CNN classifier over 2-channel inputs producing 2 sigmoid outputs.

    Layout:
        stem conv 3x3 (2 -> 16) -> BN -> LeakyReLU
        -> for each entry of ``resb``, a ResBlock with that many layers; after the
           blocks at positions 0, 2 and 3, an ``Upchannel`` stage doubles the channels
           and strides the height by 5
        -> conv (1, 2) -> BN -> LeakyReLU -> AvgPool (5, 1)
        -> MLP (fc_in -> 512 -> 64 -> 2) -> sigmoid

    Args:
        resb: number of ResLayer units in each residual block.

    NOTE(review): the two outputs go through independent sigmoids (not a softmax).
    """
    def __init__(self,resb = (3,6,9,6)):
        super(ResClass,self).__init__()
        self.in_channel = 2  # channels of the model input (matches conv0 below)

        self.features = nn.Sequential(OrderedDict([
            ("conv0",nn.Conv2d(2,16,3,padding=1)),
            ("norm0",nn.BatchNorm2d(16)),
            ("relu0",nn.LeakyReLU())
        ]))

        in_channel = 16
        for i,layer_num in enumerate(resb):
            block = ResBlock(layer_num,in_channel,in_channel)
            self.features.add_module("resblock%d"%(i+1),block)
            if i in (0,2,3):
                up = self.Upchannel(in_channel,in_channel*2,s = 5)
                self.features.add_module("up%d"%(i+1),up)
                in_channel = in_channel * 2

        # Use the channel count the loop actually produced (128 for the default resb)
        # instead of a hard-coded 128 that breaks for other ``resb`` lengths.
        self.features.add_module("convv",nn.Conv2d(in_channel,in_channel,kernel_size=(1,2),bias=False))
        self.features.add_module("normvv",nn.BatchNorm2d(in_channel))
        self.features.add_module("reluvv",nn.LeakyReLU())

        self.avg = nn.AvgPool2d((5,1))

        # NOTE(review): 8 is presumably the pooled spatial size (h * w) for the input
        # shape this model was trained on - TODO confirm. 8 * 128 == 1024 by default.
        fc_in = 8 * in_channel
        self.fc = nn.Sequential(OrderedDict([
            ("fc1",nn.Linear(fc_in,512)),
            ('fcr1',nn.LeakyReLU()),
            ("fc2",nn.Linear(512,64)),
            ('fcr2',nn.LeakyReLU()),
            ("fc3",nn.Linear(64,2))
        ]))

        self.sigmoid = nn.Sigmoid()

    def Upchannel(self,inc,outc,s = 2):
        """3x3 conv (stride ``(s, 1)``) -> BN -> LeakyReLU mapping ``inc`` to ``outc`` channels."""
        return nn.Sequential(nn.Conv2d(inc,outc,3,padding = 1,stride=(s,1),bias=False),
                            nn.BatchNorm2d(outc),
                            nn.LeakyReLU())

    def forward(self,x):
        out = self.features(x)
        out = self.avg(out)
        # flatten(1) keeps the batch dimension. The former view(-1, 1024) silently
        # regrouped features across samples whenever the input size was wrong;
        # now fc1 raises a shape error instead. Identical for valid inputs.
        out = out.flatten(1)
        out = self.fc(out)

        return self.sigmoid(out)


In [None]:
def roc(l,y):
    """Plot the ROC curve of scores ``y`` against binary labels ``l``.

    Args:
        l: ground-truth binary labels.
        y: predicted scores for the positive class.

    Returns:
        An 8x8-inch matplotlib Figure with the ROC curve (AUC in the legend) and the
        chance diagonal. It is already closed in pyplot; call ``savefig`` to write it.
    """
    fpr, tpr, _ = roc_curve(l, y)
    area = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_title('ROC')
    ax.plot(fpr, tpr, 'b', label='AUC=%0.2f' % area)
    ax.legend(loc='lower right')
    ax.plot([0, 1], [0, 1], 'r--')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_ylabel('recall')
    ax.set_xlabel('Fall-out')
    # Detach from pyplot (presumably so the notebook does not also render it inline).
    plt.close(fig)

    return fig

def Aupr(l,y):
    """Plot the precision-recall curve of scores ``y`` against binary labels ``l``.

    The legend value labelled "AUPR" is sklearn's average precision.

    Args:
        l: ground-truth binary labels.
        y: predicted scores for the positive class.

    Returns:
        An 8x8-inch matplotlib Figure, already closed in pyplot; call ``savefig``
        to write it.
    """
    ap = average_precision_score(l, y)
    prec, rec, _ = precision_recall_curve(l, y)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_title("PR curve")
    ax.set_xlabel('Recall')
    ax.set_ylabel("Precision")
    # NOTE(review): points are joined by straight lines; a step plot would match how
    # average precision is computed.
    ax.plot(rec, prec, 'b', label='AUPR=%0.2f' % ap)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.legend(loc='lower left')
    plt.close(fig)
    return fig

class MyDataSet(Dataset):
    """Map-style dataset over an ``.npz`` archive holding ``Sample`` and ``label`` arrays.

    Item ``i`` is the pair ``(Sample[i], label[i])``; both arrays are read into memory
    once at construction. The dataset length is taken from ``label``.

    Args:
        fp: path (or file-like object) of the ``.npz`` archive.
    """
    def __init__(self,fp):
        # np.load on an .npz returns a lazily-read NpzFile that holds the file open;
        # the context manager closes it once both arrays have been materialised.
        with np.load(fp) as archive:
            self.X = archive["Sample"]
            self.yl = archive['label']
        print(self.yl.shape)

    def __getitem__(self,index):
        return self.X[index],self.yl[index]

    def __len__(self):
        return self.yl.shape[0]

# Held-out split, read from "test.npz" relative to the working directory.
validd = MyDataSet("test.npz")
valid_loader = DataLoader(dataset=validd, batch_size=1, shuffle=True)

In [8]:
# Evaluate the saved checkpoint on the held-out split and write ROC / PR figures.
# SECURITY: t.load un-pickles a whole nn.Module, which can execute arbitrary code -
# only load checkpoints from a trusted source. (Recent torch versions default to
# weights_only=True and need weights_only=False to load a full pickled module.)
model = t.load("smodel/node4-30.pt",map_location = t.device('cpu'))
model.eval()
# Every metric below is independent of sample order, so a fixed order keeps the
# run reproducible without changing results.
valid_loader = DataLoader(validd,batch_size = 1,shuffle=False)

label_chunks, pred_chunks, score_chunks = [], [], []
with t.no_grad():
    for X, l in valid_loader:
        y = model(X.float())
        label_chunks.append(l)
        pred_chunks.append(y.argmax(1))   # predicted class index
        score_chunks.append(y[:, 1])      # score used for the ROC / PR curves

if not label_chunks:
    raise ValueError("valid_loader yielded no batches; cannot compute metrics")

# A single hstack over all batches gives the same result as the former pairwise
# hstack in the loop, without re-copying the growing tensor on every step.
total_l = t.hstack(label_chunks).cpu()
total_yl = t.hstack(pred_chunks).cpu()
total_yp = t.hstack(score_chunks).cpu()

acc = accuracy_score(total_l,total_yl)
pre = precision_score(total_l,total_yl)
rec = recall_score(total_l,total_yl)

fig = roc(total_l,total_yp)
fig.savefig("roc.jpg",format = 'jpg')
fig.clear()
fig = Aupr(total_l,total_yp)
fig.savefig("aupr.jpg",format = 'jpg')
fig.clear()

{"accuracy": acc, "precision": pre, "recall": rec}