# Notebook for Development

This notebook replicates what was already done in [the previous example](https://github.com/WatChMaL/ExampleNotebooks/blob/master/HKML%20CNN%20Image%20Classification.ipynb), but uses functions from the `classification` module. There is no new ML example here. The purpose of this notebook is to keep the workflow short and concise, so that you can use it as a testbed for developing different networks more easily.

In [1]:
from __future__ import print_function
from IPython.display import display
import torch, time
import numpy as np
%matplotlib inline

## Defining a network
Let us define the network the same way we did in [the previous example](https://github.com/WatChMaL/ExampleNotebooks/blob/master/HKML%20CNN%20Image%20Classification.ipynb).

In [2]:
class CNN(torch.nn.Module):
    
    def __init__(self, num_class):
        
        super(CNN, self).__init__()
        # feature extractor CNN
        self._feature = torch.nn.Sequential(
            torch.nn.Conv2d(2,16,3), torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Conv2d(16,32,3), torch.nn.ReLU(),
            torch.nn.Conv2d(32,32,3), torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Conv2d(32,64,3), torch.nn.ReLU(),
            torch.nn.Conv2d(64,64,3), torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Conv2d(64,128,3), torch.nn.ReLU(),
            torch.nn.Conv2d(128,128,3), torch.nn.ReLU()
        )
        self._classifier = torch.nn.Sequential(
            torch.nn.Linear(128,128), torch.nn.ReLU(),
            torch.nn.Linear(128,128), torch.nn.ReLU(),
            torch.nn.Linear(128,num_class)
        )

    def forward(self, x):
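        net = self._feature(x)  # convolutional feature maps, shape (batch, 128, H', W')
        # global average pooling over the remaining spatial dimensions -> (batch, 128, 1, 1)
        net = torch.nn.AvgPool2d(net.size()[2:])(net)
        # flatten to (batch, 128) and apply the fully connected head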
        return self._classifier(net.view(-1,128))
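
Every `Conv2d` in `_feature` uses a 3x3 kernel without padding, which trims 2 pixels from each spatial dimension, and every `MaxPool2d(2,2)` halves it, rounding down; `AvgPool2d(net.size()[2:])` in `forward` then averages over whatever spatial extent is left. As a rough check of which image sizes the network accepts, the plain-Python sketch below tracks the side length through the feature extractor. The helpers `output_size` and `min_input_size` are hypothetical and not part of the `classification` module.

```python
# Spatial-size bookkeeping for the feature extractor of CNN above (hypothetical helpers).
# ("conv", 3): Conv2d with kernel 3, stride 1, no padding -> side length shrinks by 2
# ("pool", 2): MaxPool2d(2, 2) -> side length is halved, rounding down
LAYERS = [("conv", 3), ("pool", 2),
          ("conv", 3), ("conv", 3), ("pool", 2),
          ("conv", 3), ("conv", 3), ("pool", 2),
          ("conv", 3), ("conv", 3)]


def output_size(n, layers=LAYERS):
    """Side length of the final feature map for an n-pixel input side (0 if it vanishes)."""
    for kind, k in layers:
        n = n - k + 1 if kind == "conv" else n // k
        if n < 1:
            return 0
    return n


def min_input_size(layers=LAYERS):
    """Smallest input side length that still yields at least a 1x1 feature map."""
    n = 1
    while output_size(n, layers) < 1:
        n += 1
    return n


print(min_input_size())   # 66
print(output_size(100))   # 5
```

Height and width are reduced independently, so for non-square images apply `output_size` to each dimension separately; under these layer settings an input needs at least 66 pixels along each dimension to reach the classifier.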

## Preparing a _blob_

In [3]:
class BLOB:
    pass
blob=BLOB()
blob.net          = CNN(4).cuda() # construct the CNN with 4 outputs (presumably 3 PID scores + 1 energy estimate) and move it to the GPU
blob.criterionPID = torch.nn.CrossEntropyLoss() # softmax cross-entropy loss for particle identification (PID)
blob.alpha        = 4*1e-3 # presumably the weight of the energy loss relative to the PID loss (see train_loop)
#blob.criterionE  = torch.nn.MSELoss() # alternative: mean squared error loss for the energy regression
blob.criterionE   = torch.nn.SmoothL1Loss() # smooth L1 (Huber) loss for the energy regression
blob.optimizer    = torch.optim.Adam(blob.net.parameters(),weight_decay=0.001) # Adam optimizer with L2 weight decay
blob.softmax      = torch.nn.Softmax(dim=1) # not for training, but softmax score for each class
blob.data         = None # data for training/analysis
blob.labelPID     = None # particle-type label for training/analysis
blob.labelE       = None # energy label for training/analysis

# Create data loader
from iotools import loader_factory
DATA_DIRS=['/data/hkml_data/IWCDgrid/varyE/e-','/data/hkml_data/IWCDgrid/varyE/mu-','/data/hkml_data/IWCDgrid/varyE/gamma']
# training loader: start_fraction=0.0, use_fraction=0.9 (presumably the first 90% of the events)
blob.train_loader=loader_factory('H5Dataset', batch_size=64, shuffle=True, num_workers=4, data_dirs=DATA_DIRS, flavour='100k.h5', start_fraction=0.0, use_fraction=0.9, read_keys=["energies"])
# validation loader: start_fraction=0.9, use_fraction=0.1 (presumably the remaining 10%)
blob.test_loader=loader_factory('H5Dataset', batch_size=200, shuffle=True, num_workers=2, data_dirs=DATA_DIRS, flavour='100k.h5', start_fraction=0.9, use_fraction=0.1, read_keys=["energies"])

# Create & attach data recording utility (into csv file)
from utils import CSVData
blob.train_log, blob.test_log = CSVData('log_train.csv'), CSVData('log_test.csv')
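
The two criteria defined above are softmax cross-entropy for the particle identification and smooth L1 (equivalent to the Huber loss with threshold 1) for the energy. How `train_loop` combines them is not shown in this notebook; a weighted sum with `blob.alpha` scaling the energy term is an assumption. The NumPy sketch below reproduces both losses with PyTorch's default settings (mean over the batch, smooth-L1 threshold `beta=1`) and adds a hypothetical `combined_loss` built on that assumption.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy, mirroring torch.nn.CrossEntropyLoss() defaults."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)
    # numerically stable log-softmax: subtract the row-wise maximum first
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def smooth_l1(pred, target, beta=1.0):
    """Mean smooth L1 loss, mirroring torch.nn.SmoothL1Loss() defaults."""
    diff = np.abs(np.asarray(pred, dtype=float) - np.asarray(target, dtype=float))
    # quadratic for small residuals, linear for large ones
    loss = np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return loss.mean()


def combined_loss(logits, labels, e_pred, e_true, alpha=4e-3):
    """Hypothetical total loss: PID cross-entropy plus alpha-weighted energy loss."""
    return cross_entropy(logits, labels) + alpha * smooth_l1(e_pred, e_true)


# uniform logits over 3 classes give a cross-entropy of log(3)
print(cross_entropy([[0.0, 0.0, 0.0]], [1]))   # log(3), about 1.0986
print(smooth_l1([0.5, 3.0], [0.0, 0.0]))       # (0.125 + 2.5) / 2 = 1.3125
```

The linear branch of the smooth L1 loss keeps large energy residuals from dominating the gradient the way a squared error (the commented-out `MSELoss`) would.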

## Running a training loop

In [None]:
from classification import train_loop
train_loop(blob,50.)

Epoch 0 Starting @ 2019-04-18 02:10:27


Epoch 1 Starting @ 2019-04-18 02:16:21


Epoch 2 Starting @ 2019-04-18 02:22:22


Epoch 3 Starting @ 2019-04-18 02:28:31


Epoch 4 Starting @ 2019-04-18 02:34:56


Epoch 5 Starting @ 2019-04-18 02:40:57


Epoch 6 Starting @ 2019-04-18 02:46:58


Epoch 7 Starting @ 2019-04-18 02:53:04


Epoch 8 Starting @ 2019-04-18 02:59:25


Epoch 9 Starting @ 2019-04-18 03:05:48


Epoch 10 Starting @ 2019-04-18 03:12:09


Epoch 11 Starting @ 2019-04-18 03:18:33


Epoch 12 Starting @ 2019-04-18 03:24:45


Epoch 13 Starting @ 2019-04-18 03:30:59


Epoch 14 Starting @ 2019-04-18 03:36:57


Epoch 15 Starting @ 2019-04-18 03:43:08


Epoch 16 Starting @ 2019-04-18 03:49:22


Epoch 17 Starting @ 2019-04-18 03:55:25


Epoch 18 Starting @ 2019-04-18 04:01:22


Epoch 19 Starting @ 2019-04-18 04:07:32


Epoch 20 Starting @ 2019-04-18 04:13:33


Epoch 21 Starting @ 2019-04-18 04:19:50
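
The timestamps in the log above give a rough cost estimate. The sketch below parses the first and last `Epoch N Starting @ ...` lines, computes the average wall-clock time per epoch, and extrapolates it to the 50 epochs requested in `train_loop(blob,50.)`. The extrapolation assumes that the second argument of `train_loop` is the number of epochs and that later epochs run at the same speed; `parse` and `seconds_per_epoch` are hypothetical helpers.

```python
from datetime import datetime

# first and last "Epoch ... Starting" lines copied from the log above
LOG_LINES = [
    "Epoch 0 Starting @ 2019-04-18 02:10:27",
    "Epoch 21 Starting @ 2019-04-18 04:19:50",
]


def parse(line):
    """Split 'Epoch N Starting @ YYYY-MM-DD HH:MM:SS' into (N, datetime)."""
    head, stamp = line.split(" @ ")
    epoch = int(head.split()[1])
    return epoch, datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S")


def seconds_per_epoch(lines):
    """Average wall-clock seconds per epoch between the first and last entries."""
    (e0, t0), (e1, t1) = parse(lines[0]), parse(lines[-1])
    return (t1 - t0).total_seconds() / (e1 - e0)


per_epoch = seconds_per_epoch(LOG_LINES)
print(round(per_epoch, 1))              # 369.7 (seconds, about 6.2 minutes)
print(round(50 * per_epoch / 3600, 2))  # 5.13 (hours for 50 epochs)
```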


## Inspecting the training process

In [None]:
from classification import plot_log
plot_log(blob.train_log.name,blob.test_log.name)

In [None]:
torch.save(blob.net.state_dict(), "cvn_10epochs_huberloss_EkinAboveThr.cnn")

## Performance Analysis

In [None]:
from classification import inference
accuracy, labelPID, predictionPID,labelE, predictionE = inference(blob,blob.test_loader)
print('Accuracy mean',accuracy.mean(),'std',accuracy.std())

Print the true and predicted labels, then plot the confusion matrix.

In [None]:
print(labelPID)
print(predictionPID)

from utils import plot_confusion_matrix
plot_confusion_matrix(labelPID,predictionPID,['gamma','electron','muon'])
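
`plot_confusion_matrix` comes from the local `utils` module, whose internals are not shown here. As a sketch of what such a matrix contains, the NumPy functions below count how often each true class (rows) is assigned to each predicted class (columns) and normalise each row to fractions, so the diagonal holds the per-class efficiency; the overall accuracy is the trace of the count matrix divided by the number of events. The function names and the toy arrays are hypothetical, and integer labels 0, 1, 2 are assumed to follow the order `['gamma','electron','muon']` used above.

```python
import numpy as np


def confusion_matrix(labels, predictions, num_classes):
    """counts[i, j] = number of events with true class i predicted as class j."""
    counts = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when (i, j) pairs repeat
    np.add.at(counts, (np.asarray(labels), np.asarray(predictions)), 1)
    return counts


def row_normalise(counts):
    """Convert counts to fractions of each true class (each non-empty row sums to 1)."""
    totals = counts.sum(axis=1, keepdims=True)
    return counts / np.maximum(totals, 1)


# toy arrays, not taken from the dataset
labels      = np.array([0, 0, 1, 1, 1, 2, 2, 2])
predictions = np.array([0, 1, 1, 1, 0, 2, 2, 2])
counts = confusion_matrix(labels, predictions, 3)
print(counts)
print(row_normalise(counts))
print(np.trace(counts) / counts.sum())   # overall accuracy: 6 / 8 = 0.75
```

With the notebook's outputs, the same calls would be `confusion_matrix(labelPID, predictionPID, 3)`.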

In [None]:
import matplotlib.pyplot as plt
plt.hist2d(x = labelE, y = predictionE, bins = 100) # true vs. predicted energy
plt.xlabel("True energy")
plt.ylabel("Predicted energy")
#plt.savefig("cnn_50epochs_varyE_EtrueErec.pdf")
#plt.savefig("cnn_50epochs_varyE_EtrueErec.png")
plt.show()

In [None]:
plt.hist2d(x = labelE, y = predictionE/labelE-1., bins = 100, range = ((0, 2500), (-0.4, 0.4))) # fractional energy residual vs. true energy
plt.xlabel("True energy")
plt.ylabel(r"$\frac{\mathrm{Predicted}}{\mathrm{True}}-1$")
#plt.savefig("cnn_50epochs_varyE_EtrueErecFrac.pdf")
#plt.savefig("cnn_50epochs_varyE_EtrueErecFrac.png")
plt.show()
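
The 2D histogram above shows the fractional residual `predictionE/labelE - 1` against the true energy. One common way to turn it into numbers is to compute, in bins of true energy, the mean of the residual (bias) and its standard deviation (resolution). The NumPy sketch below does this for arrays shaped like `labelE` and `predictionE`; the helper `binned_residuals`, the toy arrays, and the choice of mean and standard deviation (rather than, say, a Gaussian fit) are assumptions for illustration.

```python
import numpy as np


def binned_residuals(true_e, pred_e, edges):
    """Per-bin (bias, resolution) of the fractional residual pred/true - 1.

    Bins are half-open [edges[k], edges[k+1]); empty bins give NaN.
    True energies are assumed to be nonzero.
    """
    true_e = np.asarray(true_e, dtype=float)
    residual = np.asarray(pred_e, dtype=float) / true_e - 1.0
    bias, resolution = [], []
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (true_e >= lo) & (true_e < hi)
        if in_bin.any():
            bias.append(residual[in_bin].mean())
            resolution.append(residual[in_bin].std())
        else:
            bias.append(np.nan)
            resolution.append(np.nan)
    return np.array(bias), np.array(resolution)


# toy arrays, not taken from the dataset
true_e = np.array([100.0, 200.0, 600.0, 800.0])
pred_e = np.array([110.0, 220.0, 540.0, 880.0])
bias, resolution = binned_residuals(true_e, pred_e, edges=[0, 500, 1000, 1500])
# first bin: bias 0.1, resolution 0; second bin: bias ~0, resolution 0.1; third bin empty (NaN)
print(bias)
print(resolution)
```

With the notebook's outputs, a call such as `binned_residuals(labelE, predictionE, np.linspace(0, 2500, 6))` would cover the same true-energy range as the plot.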

In [None]:
maskGamma = labelPID == 0
maskElectron = labelPID == 1
maskMuon = labelPID == 2

# histogram of the fractional energy residual for each particle type
for mask, label in [(maskGamma, r"$\gamma$"), (maskElectron, r"$e$"), (maskMuon, r"$\mu$")]:
    plt.hist(predictionE[mask]/labelE[mask]-1., bins = 100, range = (-0.4, 0.4), histtype="step", label=label)

plt.xlabel(r"$\frac{\mathrm{Predicted}}{\mathrm{True}}-1$")
plt.ylabel("Counts")

plt.legend()

#plt.savefig("cnn_50epochs_varyE_EtrueErecProj.pdf")
#plt.savefig("cnn_50epochs_varyE_EtrueErecProj.png")

plt.show()
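
The per-particle histograms can also be summarised numerically. The sketch below uses quantiles instead of the mean and standard deviation: for each class it reports the median of `predictionE/labelE - 1` and half the width of its central 68% interval (16th to 84th percentile), which equals one standard deviation for a Gaussian but is less sensitive to tails. The helper `per_class_summary`, the `CLASS_NAMES` mapping (copied from the label order used above), and the toy arrays are hypothetical.

```python
import numpy as np

CLASS_NAMES = {0: "gamma", 1: "electron", 2: "muon"}   # same order as the labels above


def per_class_summary(label_pid, true_e, pred_e):
    """For each class present: (median residual, half-width of the central 68% interval)."""
    label_pid = np.asarray(label_pid)
    residual = np.asarray(pred_e, dtype=float) / np.asarray(true_e, dtype=float) - 1.0
    summary = {}
    for cls, name in CLASS_NAMES.items():
        r = residual[label_pid == cls]
        if r.size == 0:
            continue  # skip classes with no events
        q16, q50, q84 = np.percentile(r, [16, 50, 84])
        summary[name] = (q50, 0.5 * (q84 - q16))
    return summary


# toy arrays, not taken from the dataset: residuals spread uniformly in two classes
r0 = np.linspace(-0.1, 0.1, 101)
r1 = np.linspace(0.0, 0.2, 101)
labels = np.concatenate([np.zeros(101, dtype=int), np.ones(101, dtype=int)])
true_e = np.full(202, 1000.0)
pred_e = true_e * (1.0 + np.concatenate([r0, r1]))
for name, (median, width) in per_class_summary(labels, true_e, pred_e).items():
    print(name, median, width)
```

With the notebook's outputs, the call would be `per_class_summary(labelPID, labelE, predictionE)`.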