# CNN for Particle Classification

In this notebook, we train a 10-layer CNN for particle-type classification ($e^-$, $\mu^-$, and $\gamma$) using the workshop dataset.

In [1]:
from __future__ import print_function
from IPython.display import display
import torch, time
import numpy as np

## Defining a network
Let's define our network. The design below consists of 7 convolution layers + 3 fully-connected layers (10 learnable layers). Here is a summary of the graph operations.
* Feature extractor:
    1. Input shape: (N,88,168,2) ... N samples of 88x168 2D images with 2 channels
    2. Convolution layer + ReLU, 16 filters, kernel size 3x3, stride 1 (default)
    3. 2D max-pooling, kernel size 2, stride 2
    4. 2x Convolution layer + ReLU, 32 filters, kernel size 3x3, stride 1 (default)
    5. 2D max-pooling, kernel size 2, stride 2
    6. 2x Convolution layer + ReLU, 64 filters, kernel size 3x3, stride 1 (default)
    7. 2D max-pooling, kernel size 2, stride 2
    8. 2x Convolution layer + ReLU, 128 filters, kernel size 3x3, stride 1 (default)
* Flattening
    9. 2D average-pooling, kernel size = 2D image spatial dimension at this point (results in length 128 1D array)
* Classifier:
    10. Fully-connected layer + ReLU, 128 output units
    11. Fully-connected layer + ReLU, 128 output units
    12. Fully-connected layer, M output units where M = number of classification categories
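
To see why the flattening step ends with a length-128 array, it can help to trace the spatial size of the image through the feature extractor. The sketch below assumes the PyTorch defaults for `Conv2d` (no padding, stride 1) and floor rounding in `MaxPool2d`; the helper names `conv_out`, `pool_out`, and `trace_shapes` are introduced here for illustration only.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output length of one spatial dimension after a convolution."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Output length of one spatial dimension after max-pooling (floor rounding)."""
    return (size - kernel) // stride + 1

def trace_shapes(height=88, width=168):
    """Follow (H, W) through conv, pool, 2x conv, pool, 2x conv, pool, 2x conv."""
    ops = ['conv', 'pool', 'conv', 'conv', 'pool',
           'conv', 'conv', 'pool', 'conv', 'conv']
    shapes = [(height, width)]
    for op in ops:
        step = conv_out if op == 'conv' else pool_out
        height, width = step(height), step(width)
        shapes.append((height, width))
    return shapes

for shape in trace_shapes():
    print(shape)
```

Under these assumptions the last feature map is 3x13 pixels with 128 channels, and the average pooling over that whole area leaves one value per channel.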

In [2]:
class CNN(torch.nn.Module):
    
    def __init__(self, num_class):
        
        super(CNN, self).__init__()
        # feature extractor CNN
        self._feature = torch.nn.Sequential(
            torch.nn.Conv2d(2,16,3), torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Conv2d(16,32,3), torch.nn.ReLU(),
            torch.nn.Conv2d(32,32,3), torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Conv2d(32,64,3), torch.nn.ReLU(),
            torch.nn.Conv2d(64,64,3), torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Conv2d(64,128,3), torch.nn.ReLU(),
            torch.nn.Conv2d(128,128,3), torch.nn.ReLU())
        self._classifier = torch.nn.Sequential(
            torch.nn.Linear(128,128), torch.nn.ReLU(),
            torch.nn.Linear(128,128), torch.nn.ReLU(),
            torch.nn.Linear(128,num_class)
        )

    def forward(self, x):
        net = self._feature(x)
        net = torch.nn.AvgPool2d(net.size()[2:])(net)
        return self._classifier(net.view(-1,128))
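
As a rough size check, the number of learnable parameters can be counted by hand from the layer definitions above. This is a sketch in plain Python; `conv_params`, `linear_params`, and `count_parameters` are hypothetical helpers, not part of the workshop code.

```python
def conv_params(c_in, c_out, kernel=3):
    """Weights plus biases of a 2D convolution with a square kernel."""
    return c_in * c_out * kernel * kernel + c_out

def linear_params(n_in, n_out):
    """Weights plus biases of a fully-connected layer."""
    return n_in * n_out + n_out

def count_parameters(num_class=3):
    """Sum the parameters of the 7 convolution and 3 fully-connected layers."""
    conv_channels = [(2, 16), (16, 32), (32, 32), (32, 64),
                     (64, 64), (64, 128), (128, 128)]
    linear_sizes = [(128, 128), (128, 128), (128, num_class)]
    total = sum(conv_params(c_in, c_out) for c_in, c_out in conv_channels)
    total += sum(linear_params(n_in, n_out) for n_in, n_out in linear_sizes)
    return total

print(count_parameters(3))
```

The result can be cross-checked against the real model with `sum(p.numel() for p in CNN(3).parameters())`.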


## Defining a train loop
For convenience, define a _BLOB_ class to keep objects together. To a BLOB instance, we attach the CNN defined above, our loss function (`nn.CrossEntropyLoss`), and the Adam optimizer. For analysis purposes, we also include `nn.Softmax`. Finally, we attach data and label placeholders.

In [3]:
class BLOB:
    pass
blob=BLOB()
blob.net       = CNN(3).cuda() # construct the CNN for 3-class classification, use GPU
blob.criterion = torch.nn.CrossEntropyLoss() # use softmax loss to define an error
blob.optimizer = torch.optim.Adam(blob.net.parameters()) # use Adam optimizer algorithm
blob.softmax   = torch.nn.Softmax(dim=1) # not for training, but softmax score for each class
blob.data      = None # data for training/analysis
blob.label     = None # label for training/analysis
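
`nn.CrossEntropyLoss` takes the raw network outputs (logits) and applies a log-softmax internally, which is why the separate `nn.Softmax` is needed only for analysis. The NumPy sketch below shows the relationship between the two; the function names are illustrative, and the default mean reduction over the batch is assumed.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def cross_entropy(logits, labels):
    """Mean negative log-probability assigned to the true class."""
    probs = softmax(logits)
    true_class_prob = probs[np.arange(len(labels)), labels]
    return -np.log(true_class_prob).mean()

logits = np.zeros((4, 3))        # uninformative scores, as from an untrained network
labels = np.array([0, 1, 2, 1])
print(softmax(logits)[0])        # equal probability for each of the 3 classes
print(cross_entropy(logits, labels), np.log(3))
```

With three classes and uninformative outputs the loss equals ln 3 (about 1.10), which is a useful reference value for the start of training.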

We define 2 functions to be called in the training loop: forward and backward. These functions implement the evaluation of the results, error (loss) definition, and propagation of errors (gradients) back to update the network parameters.

In [4]:
def forward(blob,train=True):
    """
       Args: blob should have attributes, net, criterion, softmax, data, label
       Returns: a dictionary of predicted labels, softmax, loss, and accuracy
                (loss and accuracy are -1 when blob.label is None)
    """
    with torch.set_grad_enabled(train):
        # Prediction: move data to the GPU and reorder (N,H,W,C) -> (N,C,H,W) for Conv2d
        data = torch.as_tensor(blob.data).cuda()
        data = data.permute(0,3,1,2)
        prediction = blob.net(data)
        # Softmax scores and predicted class index for each sample
        softmax         = blob.softmax(prediction).cpu().detach().numpy()
        predicted_class = torch.argmax(prediction,dim=-1)
        # Loss and accuracy are only defined when labels are available
        blob.loss,loss,accuracy = None,-1,-1
        if blob.label is not None:
            label     = torch.as_tensor(blob.label).type(torch.LongTensor).cuda()
            blob.loss = blob.criterion(prediction,label)
            loss      = blob.loss.cpu().detach().item()
            accuracy  = (predicted_class == label).sum().item() / float(predicted_class.nelement())
        
        return {'prediction' : predicted_class.cpu().detach().numpy(),
                'softmax'    : softmax,
                'loss'       : loss,
                'accuracy'   : accuracy}

def backward(blob):
    blob.optimizer.zero_grad()  # Reset gradients accumulation
    blob.loss.backward()
    blob.optimizer.step()
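
The forward/backward split can be illustrated on a much smaller problem that runs without a GPU. The sketch below is not the CNN above: it trains a linear softmax classifier on made-up 2D data with plain gradient descent (not Adam), and every name in it is hypothetical. It only mirrors the structure of the two functions: forward computes the prediction, loss, and accuracy, and backward turns the loss gradient into a parameter update.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy data: 300 points in 2D, 3 classes centred at different locations
centres = np.array([[0.0, 2.0], [2.0, -1.0], [-2.0, -1.0]])
labels = rng.integers(0, 3, size=300)
data = centres[labels] + 0.5 * rng.standard_normal((300, 2))

weights = np.zeros((2, 3))
bias = np.zeros(3)

def toy_forward(weights, bias, data, labels):
    """Compute class probabilities, mean cross-entropy loss, and accuracy."""
    logits = data @ weights + bias
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = exp / exp.sum(axis=1, keepdims=True)
    loss = -np.log(probs[np.arange(len(labels)), labels]).mean()
    accuracy = (probs.argmax(axis=1) == labels).mean()
    return probs, loss, accuracy

def toy_backward(weights, bias, data, labels, probs, learning_rate=0.1):
    """Update the parameters in place with one gradient-descent step."""
    # Gradient of the mean loss w.r.t. the logits is (probs - one_hot) / N
    grad_logits = probs.copy()
    grad_logits[np.arange(len(labels)), labels] -= 1.0
    grad_logits /= len(labels)
    weights -= learning_rate * (data.T @ grad_logits)
    bias -= learning_rate * grad_logits.sum(axis=0)

history = []
for step in range(100):
    probs, loss, accuracy = toy_forward(weights, bias, data, labels)
    toy_backward(weights, bias, data, labels, probs)
    history.append(loss)
print(history[0], history[-1], accuracy)
```

In the PyTorch version, `blob.loss.backward()` computes the gradients automatically and `blob.optimizer.step()` applies the update; `zero_grad()` is needed because PyTorch accumulates gradients across calls to `backward()`.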


## Running a train loop 
Let's prepare the data loaders for both train and test datasets. We use the latter to check if the network suffers from overtraining.

In [5]:
# Create data loader
from iotools import loader_factory
DATA_DIRS=['/data/hkml_data/IWCDgrid/varyE/e-','/data/hkml_data/IWCDgrid/varyE/mu-','/data/hkml_data/IWCDgrid/varyE/gamma']
# for train
train_loader=loader_factory('H5Dataset', batch_size=200, shuffle=True, num_workers=4, data_dirs=DATA_DIRS, flavour='100k.h5', start_fraction=0.0, use_fraction=0.2)
# for validation
test_loader=loader_factory('H5Dataset', batch_size=200, shuffle=True, num_workers=2, data_dirs=DATA_DIRS, flavour='100k.h5', start_fraction=0.2, use_fraction=0.1)
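
The two loaders are meant to read disjoint parts of each file. The sketch below shows one way such fractions could map to event indices; it assumes that `start_fraction` and `use_fraction` select a contiguous slice of each file, which is a guess about `loader_factory` and not something confirmed here. The helper `fraction_to_range` and the event count are hypothetical.

```python
def fraction_to_range(num_events, start_fraction, use_fraction):
    """Map (start_fraction, use_fraction) to a half-open index range [begin, end)."""
    begin = int(round(num_events * start_fraction))
    end = int(round(num_events * (start_fraction + use_fraction)))
    return begin, min(end, num_events)

num_events = 100000   # hypothetical file size, suggested only by the '100k.h5' name
train_range = fraction_to_range(num_events, start_fraction=0.0, use_fraction=0.2)
test_range = fraction_to_range(num_events, start_fraction=0.2, use_fraction=0.1)
print(train_range, test_range)
```

If the loader works this way, the train range ends exactly where the test range begins, so no event is used for both training and validation.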

Also import `CSVData` from our utility module, which lets us write the training log (accuracy, loss, etc.) to a CSV file.

In [6]:
# Import 0) progress bar and 1) data recording utility (into csv file)
from utils import progress_bar, CSVData
blob.train_log, blob.test_log = CSVData('log_train.csv'), CSVData('log_test.csv')
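
The implementation of `CSVData` is not shown in this notebook. For readers without the `utils` module, the sketch below is a hypothetical stand-in that supports the same calls used later (`record`, `write`, `close`, and the `name` attribute), built on the standard `csv` module. It is an assumption about the interface, not the actual utility.

```python
import csv
import os
import tempfile

class SimpleCSVLog:
    """Hypothetical stand-in offering the record/write/close calls used in this notebook."""

    def __init__(self, name):
        self.name = name
        self._file = open(name, 'w', newline='')
        self._writer = None
        self._row = {}

    def record(self, keys, values):
        # Stage one row of named values; nothing is written to the file yet
        self._row = dict(zip(keys, values))

    def write(self):
        # Emit the header once, on the first row, then append the staged row
        if self._writer is None:
            self._writer = csv.DictWriter(self._file, fieldnames=list(self._row))
            self._writer.writeheader()
        self._writer.writerow(self._row)
        self._file.flush()

    def close(self):
        self._file.close()

# Usage with made-up values, writing to a temporary directory
path = os.path.join(tempfile.mkdtemp(), 'log_sketch.csv')
log = SimpleCSVLog(path)
log.record(['iteration', 'epoch', 'accuracy', 'loss'], [1, 0.01, 0.33, 1.10])
log.write()
log.record(['iteration', 'epoch', 'accuracy', 'loss'], [2, 0.02, 0.35, 1.08])
log.write()
log.close()
with open(path) as f:
    print(f.read())
```

Flushing after every row keeps the file readable even if training is interrupted partway through.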

Finally, we're ready to run the training! Let's write a loop that iterates over the train data loader and calls forward and backward.

In [7]:
# Define train period. "epoch" = N image consumption where N is the total number of train samples.
TRAIN_EPOCH=3.0
# Set the network to training mode
blob.net.train()
epoch=0.
iteration=0
# Start training
while int(epoch+0.5) < TRAIN_EPOCH:
    print('Epoch',int(epoch+0.5),'Starting @',time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    # Create a progress bar for this epoch (progress_bar was imported from utils above)
    progress = display(progress_bar(0,len(train_loader)),display_id=True)
    # Loop over data samples and into the network forward function
    for i,data in enumerate(train_loader):
        # Data and label
        blob.data,blob.label = data[0:2]
        # Call forward: make a prediction & measure the average error
        res = forward(blob,True)
        # Call backward: backpropagate error and update weights
        backward(blob)
        # Epoch update
        epoch += 1./len(train_loader)
        iteration += 1
        
        #
        # Log/Report
        #
        # Record the current performance on train set
        blob.train_log.record(['iteration','epoch','accuracy','loss'],[iteration,epoch,res['accuracy'],res['loss']])
        blob.train_log.write()
        # once in a while, report
        if i==0 or (i+1)%10 == 0:
            message = '... Iteration %d ... Epoch %1.2f ... Loss %1.3f ... Accuracy %1.3f' % (iteration,epoch,res['loss'],res['accuracy'])
            progress.update(progress_bar((i+1),len(train_loader),message))
        # more rarely, run validation
        if (i+1)%100 == 0:
            with torch.no_grad():
                blob.net.eval()
                test_data = next(iter(test_loader))
                blob.data,blob.label = test_data[0:2]
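                # keep the validation result separate so the train result in `res` is not overwritten
                test_res = forward(blob,False)
                blob.test_log.record(['iteration','epoch','accuracy','loss'],[iteration,epoch,test_res['accuracy'],test_res['loss']])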
                blob.test_log.write()
            blob.net.train()
        if epoch >= TRAIN_EPOCH:
            break
    message = '... Iteration %d ... Epoch %1.2f ... Loss %1.3f ... Accuracy %1.3f' % (iteration,epoch,res['loss'],res['accuracy'])
    progress.update(progress_bar((i+1),len(train_loader),message))

blob.test_log.close()
blob.train_log.close()

Epoch 0 Starting @ 2019-04-16 23:02:43


KeyboardInterrupt: 
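
In the loop above, `next(iter(test_loader))` builds a fresh iterator at every validation step. With `num_workers > 0` this typically starts new worker processes each time. One possible alternative is to keep a single endless iterator over the test loader. The sketch below uses a plain list as a stand-in for `test_loader`; the helper name `cycle` is hypothetical.

```python
def cycle(loader):
    """Yield batches forever, restarting the loader whenever it is exhausted."""
    while True:
        for batch in loader:
            yield batch

# A list of (data, label) pairs stands in for test_loader in this sketch
fake_loader = [('data0', 'label0'), ('data1', 'label1'), ('data2', 'label2')]
test_batches = cycle(fake_loader)

for _ in range(5):
    data, label = next(test_batches)[0:2]
    print(data, label)
```

In the training loop, `test_batches = cycle(test_loader)` would be created once before the `while` loop, and each validation step would call `next(test_batches)`.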

## Inspecting the training process
Let's plot the logs for both the train and test sets.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

train_log = pd.read_csv(blob.train_log.name)
test_log  = pd.read_csv(blob.test_log.name)

fig, ax1 = plt.subplots(figsize=(12,8),facecolor='w')
line11 = ax1.plot(train_log.epoch, train_log.loss, linewidth=2, label='Train loss', color='b', alpha=0.3)
line12 = ax1.plot(test_log.epoch, test_log.loss, marker='o', markersize=12, linestyle='', label='Test loss', color='blue')
ax1.set_xlabel('Epoch',fontweight='bold',fontsize=24,color='black')
ax1.tick_params('x',colors='black',labelsize=18)
ax1.set_ylabel('Loss', fontsize=24, fontweight='bold',color='b')
ax1.tick_params('y',colors='b',labelsize=18)

ax2 = ax1.twinx()
line21 = ax2.plot(train_log.epoch, train_log.accuracy, linewidth=2, label='Train accuracy', color='r', alpha=0.3)
line22 = ax2.plot(test_log.epoch, test_log.accuracy, marker='o', markersize=12, linestyle='', label='Test accuracy', color='red')

ax2.set_ylabel('Accuracy', fontsize=24, fontweight='bold',color='r')
ax2.tick_params('y',colors='r',labelsize=18)
ax2.set_ylim(0.,1.0)

# combine the line handles from both axes into a single legend
lines  = line11 + line12 + line21 + line22
labels = [l.get_label() for l in lines]
leg    = ax1.legend(lines, labels, fontsize=16, loc=5)
leg_frame = leg.get_frame()
leg_frame.set_facecolor('white')

plt.grid()
plt.show()

We see the loss is coming down while the accuracy is increasing. These two should be anti-correlated, so this is expected. We also see that the network performance on the test dataset (circles) follows that on the train dataset (lines). This means there is no apparent overtraining.

**Question: is the network still learning?**
Both the loss and accuracy curves have large fluctuations, and it is somewhat hard to see if the values are still changing. Let's plot the moving average of the loss and accuracy values.

In [None]:
def moving_average(a, n=3) :
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n

epoch    = moving_average(np.array(train_log.epoch),40)
accuracy = moving_average(np.array(train_log.accuracy),40)
loss     = moving_average(np.array(train_log.loss),40)

fig, ax1 = plt.subplots(figsize=(12,8),facecolor='w')
line11 = ax1.plot(train_log.epoch, train_log.loss, linewidth=2, label='Loss', color='b', alpha=0.3)
line12 = ax1.plot(epoch, loss, label='Loss (averaged)', color='blue')
ax1.set_xlabel('Epoch',fontweight='bold',fontsize=24,color='black')
ax1.tick_params('x',colors='black',labelsize=18)
ax1.set_ylabel('Loss', fontsize=24, fontweight='bold',color='b')
ax1.tick_params('y',colors='b',labelsize=18)

ax2 = ax1.twinx()
line21 = ax2.plot(train_log.epoch, train_log.accuracy, linewidth=2, label='Accuracy', color='r', alpha=0.3)
line22 = ax2.plot(epoch, accuracy, label='Accuracy (averaged)', color='red')

ax2.set_ylabel('Accuracy', fontsize=24, fontweight='bold',color='r')
ax2.tick_params('y',colors='r',labelsize=18)
ax2.set_ylim(0.,1.0)

# combine the line handles from both axes into a single legend
lines  = line11 + line12 + line21 + line22
labels = [l.get_label() for l in lines]
leg    = ax1.legend(lines, labels, fontsize=16, loc=5)
leg_frame = leg.get_frame()
leg_frame.set_facecolor('white')

plt.grid()
plt.show()


The darker lines now represent the moving average over 40 iterations (all data points are from the train log). It appears the network is still learning, so we can train for a longer period to achieve better accuracy.
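
The `moving_average` function relies on a cumulative-sum trick: the sum over a window of `n` values is the difference between two cumulative sums taken `n` entries apart. A small worked example, with made-up input values, makes this concrete and compares it with an equivalent convolution.

```python
import numpy as np

def moving_average(a, n=3):
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]     # difference of cumulative sums = sum over n values
    return ret[n - 1:] / n

values = np.array([1, 2, 3, 4, 5])
print(moving_average(values, 3))                           # [2. 3. 4.]
print(np.convolve(values, np.ones(3) / 3, mode='valid'))   # same windows via convolution
```

The output has `len(a) - n + 1` entries, which is why the same averaging is applied to the epoch values: it keeps the x and y arrays the same length for plotting.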

## Performance Analysis
Beyond looking at the overall performance of the network, we can analyze how the network is performing for each classification target. Let's first obtain a high-statistics analysis output by running the network on all test samples.

In [None]:
def inference(blob,data_loader):
    # set the network to test (non-train) mode
    blob.net.eval()
    # create the result holders
    label,prediction,accuracy = [],[],[]
    for i,data in enumerate(data_loader):
        blob.data, blob.label = data[0:2]
        # gradients are not needed for inference
        res = forward(blob,False)
        accuracy.append(res['accuracy'])
        prediction.append(res['prediction'])
        label.append(blob.label)
    # stack the per-batch results into flat arrays
    accuracy   = np.array(accuracy,dtype=np.float32)
    label      = np.hstack(label)
    prediction = np.hstack(prediction)
    
    return accuracy, label, prediction

Let's run the inference using this function on the test sample, and look at the confusion matrix.

In [None]:
from utils import plot_confusion_matrix
accuracy,label,prediction = inference(blob,test_loader)
print('Accuracy mean',accuracy.mean(),'std',accuracy.std())
plot_confusion_matrix(label,prediction,['gamma','electron','muon'])

As one may expect, muon is distinguished fairly well while there is some confusion between electron and gamma ray.
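
The implementation of `plot_confusion_matrix` is not shown here. As a minimal sketch of the underlying counting, a confusion matrix can be built from the label and prediction arrays with NumPy. The function below is hypothetical, and the input values are made up for illustration; they are not results from this network.

```python
import numpy as np

def confusion_matrix(label, prediction, num_class):
    """Count samples per (true class, predicted class); rows are true classes."""
    matrix = np.zeros((num_class, num_class), dtype=np.int64)
    np.add.at(matrix, (label, prediction), 1)
    return matrix

# Made-up labels and predictions, for illustration only
label      = np.array([0, 0, 1, 1, 2, 2, 2])
prediction = np.array([0, 1, 1, 1, 2, 0, 2])
matrix = confusion_matrix(label, prediction, 3)
print(matrix)
# Fraction of each true class that is predicted correctly
print(matrix.diagonal() / matrix.sum(axis=1))
```

Off-diagonal entries show which pairs of classes are confused with each other, which is the information used to compare muon, electron, and gamma separation.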