In [133]:
import torch

import matplotlib.pyplot as plt
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import Model as kdd

In [134]:
batch_size=400
datamodel=kdd.kdd99Data(batch_size)
#dataloaders
trainloader=datamodel.train_dataloader
testloader=datamodel.test_dataloader

#constant for classes
classes=tuple([i for i in range(23)])

In [135]:
#CNN
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.convmodel=nn.Sequential(nn.Conv2d(1,6,3,padding=2),
                                    nn.BatchNorm2d(6),
                                    nn.Sigmoid(),
                                    nn.Conv2d(6,16,3),
                                    nn.BatchNorm2d(16),
                                    nn.Sigmoid(),
                                    )
        self.fcmodel=nn.Sequential(nn.Linear(16*8*8,200),
                                    nn.BatchNorm1d(200),
                                    nn.Sigmoid(),
                                    nn.Dropout(p=0.3),
                                    nn.Linear(200,128),
                                    nn.Sigmoid(),
                                    nn.Dropout(p=0.3),
                                    nn.Linear(128,23)
                                    )

    def forward(self,out):
        out=self.convmodel(out)
        out=out.view(-1,16*8*8)
        out=self.fcmodel(out)
        return out


In [136]:
def accuracy(net,testloader):
    correct = 0.
    total = 0.
    with torch.no_grad():
        for data in testloader:
            cvdata, labels = data
            outputs = net(cvdata)
            _, predicted = torch.max(outputs.data, 1)
            # predicted[predicted>0]=1.
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    #print("correct:%d  total:%d" %(correct,total))
    print('Accuracy: %6f %%' % (100 * correct / total))

In [137]:
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-2)
#scheduler = optim.lr_scheduler.StepLR(optimizer, 6,gamma=0.1)

In [138]:
running_loss = 0.0
mini_batches_count=200
for epoch in range(2):  # loop over the dataset multiple times
    for i, data in enumerate(trainloader, 0):

        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # every * mini-batches...
        if i % mini_batches_count == (mini_batches_count-1):    
            #scheduler.step()
            print("[%d,%6d] ava_loss: %.5f"%(epoch+1,i+1,running_loss/mini_batches_count))    
            accuracy(net,testloader)
            running_loss = 0.0

print('Finished Training')

[1,   200] ava_loss: 0.17079
Accuracy: 98.329364 %
[1,   400] ava_loss: 0.05484
Accuracy: 98.781434 %
[1,   600] ava_loss: 0.05643
Accuracy: 98.128294 %
[1,   800] ava_loss: 0.04382
Accuracy: 99.055375 %
[2,   200] ava_loss: 0.04129
Accuracy: 99.197069 %
[2,   400] ava_loss: 0.02739
Accuracy: 99.430526 %
[2,   600] ava_loss: 0.02885
Accuracy: 99.280736 %
[2,   800] ava_loss: 0.02074
Accuracy: 99.538483 %
Finished Training


In [139]:
PATH="../../pth/nsl_kdd_test.pth"
torch.save(net.state_dict(),PATH)

In [140]:
net = Net()
net.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [141]:
class_correct = list(0 for i in range(23))
class_total = list(0 for i in range(23))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        # predicted[predicted>0]=1
        c = (predicted == labels).squeeze()
        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(23):
    print('total:%d,correct:%d' %(class_total[i],class_correct[i]))
    if class_total[i]!=0:
        print('Accuracy: %.4f %%' % (100 * class_correct[i] / class_total[i]))
    else:
        print('Accuracy:')

total:692,correct:691
Accuracy: 99.8555 %
total:7,correct:0
Accuracy: 0.0000 %
total:1,correct:0
Accuracy: 0.0000 %
total:18,correct:0
Accuracy: 0.0000 %
total:2,correct:0
Accuracy: 0.0000 %
total:369,correct:279
Accuracy: 75.6098 %
total:5,correct:0
Accuracy: 0.0000 %
total:4,correct:0
Accuracy: 0.0000 %
total:0,correct:0
Accuracy:
total:32141,correct:32104
Accuracy: 99.8849 %
total:75,correct:0
Accuracy: 0.0000 %
total:29294,correct:29177
Accuracy: 99.6006 %
total:1,correct:0
Accuracy: 0.0000 %
total:2,correct:0
Accuracy: 0.0000 %
total:90,correct:43
Accuracy: 47.7778 %
total:326,correct:274
Accuracy: 84.0491 %
total:3,correct:0
Accuracy: 0.0000 %
total:473,correct:382
Accuracy: 80.7611 %
total:84083,correct:84075
Accuracy: 99.9905 %
total:0,correct:0
Accuracy:
total:281,correct:267
Accuracy: 95.0178 %
total:335,correct:270
Accuracy: 80.5970 %
total:5,correct:0
Accuracy: 0.0000 %
