In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torch.utils.data import Dataset, DataLoader
import math
import sys
import cmath

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from torchvision import datasets
import torchvision.transforms as transforms

import time

In [None]:
class JetTrainData(Dataset):
    """Training set: jet images read from a ``.npy`` file plus generated binary labels.

    Labels are not stored on disk. They are built from the fixed per-process
    sample counts in ``CLASS_COUNTS``, so the rows of the data file are assumed
    to be stored in exactly that order (signal ``hh`` first, then backgrounds).

    Args:
        data_path: location of the ``.npy`` array; defaults to the path the
            notebook has always used.
    """

    # (process, number of samples) in file order; only 'hh' is labelled 1.
    CLASS_COUNTS = (('hh', 19000), ('tt', 9900), ('tw', 5500), ('tth', 200),
                    ('ttv', 250), ('llbj', 1000), ('tatabb', 60))

    def __init__(self, data_path='/data/github/data/diHiggs_neutrino_train_data.npy'):

        self.x = torch.from_numpy(np.load(data_path).astype(np.float32))

        labels = np.concatenate([
            np.ones(count) if name == 'hh' else np.zeros(count)
            for name, count in self.CLASS_COUNTS])
        # ``np.int`` was removed in NumPy 1.24; int64 is explicit and is the
        # index dtype ``torch.index_select`` is used with later on.
        self.y = torch.from_numpy(labels.astype(np.int64))

        # Derived from the labels (35910) instead of a separately typed constant.
        self.n_samples = len(self.y)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Number of labelled samples."""
        return self.n_samples

    
class JetTestData(Dataset):
    """Test set: jet images read from a ``.npy`` file plus generated binary labels.

    Labels are built from the fixed per-process sample counts in
    ``CLASS_COUNTS``; the rows of the data file are assumed to be stored in
    exactly that order (signal ``hh`` first, then backgrounds).

    Args:
        data_path: location of the ``.npy`` array; defaults to the path the
            notebook has always used.
    """

    # (process, number of samples) in file order; only 'hh' is labelled 1.
    CLASS_COUNTS = (('hh', 5000), ('tt', 19000), ('tw', 800), ('tth', 100),
                    ('ttv', 100), ('llbj', 240), ('tatabb', 10))

    def __init__(self, data_path='/data/github/data/diHiggs_neutrino_test_data.npy'):

        self.x = torch.from_numpy(np.load(data_path).astype(np.float32))

        labels = np.concatenate([
            np.ones(count) if name == 'hh' else np.zeros(count)
            for name, count in self.CLASS_COUNTS])
        # ``np.int`` was removed in NumPy 1.24; int64 is explicit and is the
        # index dtype ``torch.index_select`` is used with later on.
        self.y = torch.from_numpy(labels.astype(np.int64))

        # Derived from the labels (25250) instead of a separately typed constant.
        self.n_samples = len(self.y)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Number of labelled samples."""
        return self.n_samples

In [None]:
# DataLoader settings.
num_workers = 0   # 0 = load batches in the main process
batch_size = 20

train_data = JetTrainData()
test_data = JetTestData()

# Training batches are shuffled; the test order is kept fixed because the
# evaluation code later slices the scores by position.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Pull one training batch for a visual sanity check.
# The iterator's ``.next()`` method no longer exists in current PyTorch;
# the built-in ``next()`` works on every version.
dataiter = iter(train_loader)
images, labels = next(dataiter)
images = images.numpy()

# Show index 2 along axis 1 of every image in the batch on a 2-row grid,
# titled with its label.
fig = plt.figure(figsize=(25, 4))
for idx in range(batch_size):
    ax = fig.add_subplot(2, batch_size // 2, idx + 1, xticks=[], yticks=[])
    ax.imshow(images[idx, 2], cmap='gray')
    ax.set_title(str(labels[idx].item()))

In [None]:
class ConvLayer(nn.Module):
    """Convolutional front end: one 3-D convolution followed by three 2-D ones.

    ``forward`` inserts a singleton channel axis, so the stacked layers of the
    input are treated as the depth axis of the 3-D convolution. Its depth
    kernel of 5 (no padding) collapses a 5-layer input to depth 1, which is
    then squeezed away before the 2-D convolutions.

    Args:
        in_channels: input channels of the 3-D convolution. ``forward`` always
            supplies exactly one channel, so this must stay 1 unless
            ``forward`` is changed as well.
        out_channels: number of feature maps produced by every convolution.

    Both arguments were previously ignored (1 and 256 were hard-coded); the
    defaults reproduce that configuration.
    """

    def __init__(self, in_channels=1, out_channels=256):
        super(ConvLayer, self).__init__()

        self.conv1 = nn.Conv3d(in_channels, out_channels,
                               kernel_size=(5, 3, 3), stride=1, padding=0)

        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=(3, 3), stride=1, padding=0)

        self.conv3 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=(3, 3), stride=1, padding=0)

        self.conv4 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=(3, 3), stride=1, padding=0)

    def forward(self, x):
        """Apply the four ReLU-activated convolutions.

        Each 3x3 unpadded convolution trims 2 pixels per spatial axis, so the
        output is 8 pixels smaller than the input in height and width.
        """
        # (batch, layers, H, W) -> (batch, 1, layers, H, W)
        x = self.conv1(x.unsqueeze(1))
        # Drop the depth axis; this only removes it when it has size 1.
        x = (F.relu(x)).squeeze(2)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = F.relu(x)

        return x

In [None]:
class PrimaryCaps(nn.Module):
    """Primary capsule layer built from ``num_capsules`` parallel 3x3 convolutions.

    Every convolution contributes one component of each capsule vector, so the
    output holds ``out_channels * H_out * W_out`` capsules of length
    ``num_capsules``.
    """

    def __init__(self, num_capsules=8, in_channels=256, out_channels=32):
        super(PrimaryCaps, self).__init__()
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=1, padding=0)
            for _ in range(num_capsules)])

    def forward(self, x):
        """Return squashed capsules of shape ``(batch, nodes, num_capsules)``."""
        batch_size = x.size(0)
        # ``-1`` lets the node count follow the real feature-map size and
        # ``out_channels``; it used to be hard-coded as 32 * 40 * 40.
        u = [capsule(x).view(batch_size, -1, 1) for capsule in self.capsules]
        u = torch.cat(u, dim=-1)
        u_squash = self.squash(u)

        return u_squash

    def squash(self, input_tensor, eps=1e-8):
        """Rescale each vector on the last axis to length ``|v|^2 / (1 + |v|^2)``.

        ``eps`` keeps the division finite for an all-zero vector, which would
        otherwise evaluate 0/0 and produce NaN.
        """
        squared_norm = (input_tensor ** 2).sum(dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        output_tensor = scale * input_tensor / torch.sqrt(squared_norm + eps)

        return output_tensor

In [None]:
import helpers

def dynamic_routing(b_ij, u_hat, squash, routing_iterations=3):
    """Iteratively route prediction vectors ``u_hat`` to the output capsules.

    Args:
        b_ij: initial routing logits, broadcast-compatible with ``u_hat``.
        u_hat: prediction vectors; axis 2 is the one that gets summed over.
        squash: callable applied to the weighted sums to produce the outputs.
        routing_iterations: number of routing passes.

    Returns:
        The squashed output ``v_j`` of the final pass.
    """
    final_pass = routing_iterations - 1

    for step in range(routing_iterations):
        # Coupling coefficients from the current logits.
        # NOTE(review): helpers.softmax is project-local; presumably a softmax
        # along ``dim`` - confirm against helpers.py.
        coupling = helpers.softmax(b_ij, dim=2)

        # Weighted sum over axis 2, then squash.
        v_j = squash((coupling * u_hat).sum(dim=2, keepdim=True))

        # Every pass except the last one updates the logits by the agreement
        # (dot product along the last axis) between predictions and outputs.
        if step < final_pass:
            agreement = (u_hat * v_j).sum(dim=-1, keepdim=True)
            b_ij = b_ij + agreement

    return v_j

In [None]:
# Detect CUDA once; later cells branch on this flag when placing tensors
# and the model.
TRAIN_ON_GPU = torch.cuda.is_available()

print('Training on GPU!' if TRAIN_ON_GPU else 'Only CPU available')

In [None]:
class DigitCaps(nn.Module):
    """Output capsule layer: one capsule per class, fed by dynamic routing.

    Args:
        num_capsules: number of output capsules.
        previous_layer_nodes: number of capsules delivered by the previous layer.
        in_channels: length of each incoming capsule vector.
        out_channels: length of each output capsule vector.
    """

    def __init__(self, num_capsules=2, previous_layer_nodes=32*40*40,
                 in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.num_capsules = num_capsules
        self.previous_layer_nodes = previous_layer_nodes
        self.in_channels = in_channels

        # One (in_channels x out_channels) transform per (output capsule, input node).
        self.W = nn.Parameter(torch.randn(num_capsules, previous_layer_nodes,
                                          in_channels, out_channels))

    def forward(self, u):
        """Map ``u`` of shape (batch, nodes, in_channels) to routed capsules.

        Returns a tensor of shape (num_capsules, batch, 1, 1, out_channels).
        """
        # Add broadcast axes: u -> (1, batch, nodes, 1, in),
        # W -> (caps, 1, nodes, in, out).
        u = u[None, :, :, None, :]
        W = self.W[:, None, :, :, :]

        # Prediction vectors, shape (caps, batch, nodes, 1, out).
        u_hat = torch.matmul(u, W)

        # Routing logits start at zero. zeros_like puts them on the same
        # device/dtype as u_hat, so no global GPU flag is needed here.
        b_ij = torch.zeros_like(u_hat)

        v_j = dynamic_routing(b_ij, u_hat, self.squash, routing_iterations=3)

        return v_j

    def squash(self, input_tensor, eps=1e-8):
        """Rescale each vector on the last axis to length ``|v|^2 / (1 + |v|^2)``.

        ``eps`` keeps the division finite for an all-zero vector, which would
        otherwise evaluate 0/0 and produce NaN.
        """
        squared_norm = (input_tensor ** 2).sum(dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        output_tensor = scale * input_tensor / torch.sqrt(squared_norm + eps)

        return output_tensor

In [None]:
class Decoder(nn.Module):
    """Reconstruct the flattened input image from the winning output capsule.

    Args:
        input_vector_length: length of every output capsule vector.
        input_capsules: number of output capsules (classes).
        hidden_dim: width of the first hidden layer; the second is twice as wide.

    The reconstruction has ``5 * 50 * 50`` sigmoid-activated values.
    """

    def __init__(self, input_vector_length=16, input_capsules=2, hidden_dim=512):
        super(Decoder, self).__init__()

        input_dim = input_vector_length * input_capsules

        self.linear_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim*2),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim*2, 5*50*50),
            nn.Sigmoid()
            )

    def forward(self, x):
        """Mask all but the longest capsule and decode it.

        Args:
            x: capsule outputs of shape (batch, capsules, vector_length).

        Returns:
            ``(reconstructions, y)`` where ``y`` is the one-hot predicted class
            of shape (batch, capsules).
        """
        # Capsule lengths act as class scores; softmax keeps their ordering.
        classes = (x ** 2).sum(dim=-1) ** 0.5
        classes = F.softmax(classes, dim=-1)

        _, max_length_indices = classes.max(dim=1)

        # One-hot rows sized from the input and created on the input's device.
        # (Previously a fixed 2x2 identity moved according to a global flag,
        # which broke for any other capsule count.)
        sparse_matrix = torch.eye(x.size(1), device=x.device, dtype=x.dtype)

        y = sparse_matrix.index_select(dim=0, index=max_length_indices)

        # Zero every capsule except the predicted one.
        x = x * y[:, :, None]

        flattened_x = x.contiguous().view(x.size(0), -1)

        reconstructions = self.linear_layers(flattened_x)

        return reconstructions, y

In [None]:
class CapsuleNetwork(nn.Module):
    """Full model: convolutions -> primary capsules -> output capsules -> decoder."""

    def __init__(self):
        super(CapsuleNetwork, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.decoder = Decoder()

    def forward(self, images):
        """Run the network on a batch of images.

        Returns:
            ``(caps_output, reconstructions, y)`` - the output capsules with
            shape (batch, capsules, vector_length), plus the two values
            returned by the decoder.
        """
        primary_caps_output = self.primary_capsules(self.conv_layer(images))

        # The capsule layer returns (capsules, batch, 1, 1, vector_length).
        # Remove exactly the two singleton axes: a bare ``squeeze()`` would
        # also drop the batch axis for a batch of one and scramble the layout.
        routed = self.digit_capsules(primary_caps_output)
        caps_output = routed.squeeze(3).squeeze(2).transpose(0, 1)

        reconstructions, y = self.decoder(caps_output)

        return caps_output, reconstructions, y

In [None]:
# Build the model and print its layer structure.
capsule_net = CapsuleNetwork()
print(capsule_net)

# Move the parameters to the GPU when one was detected.
capsule_net = capsule_net.cuda() if TRAIN_ON_GPU else capsule_net

In [None]:
class CapsuleLoss(nn.Module):
    """Margin loss on capsule lengths plus a down-weighted reconstruction loss."""

    def __init__(self):
        super(CapsuleLoss, self).__init__()
        # Summed (not averaged) squared error; normalised by batch size in forward().
        self.reconstruction_loss = nn.MSELoss(reduction='sum')

    def forward(self, x, labels, images, reconstructions):
        """Compute the per-sample average loss for one batch.

        Args:
            x: output capsules, (batch, capsules, vector_length).
            labels: one-hot targets, (batch, capsules).
            images: the network input; flattened to match ``reconstructions``.
            reconstructions: decoder output, (batch, features).

        Returns:
            Scalar tensor: (margin loss + 0.0005 * reconstruction loss) / batch.
        """
        n_samples = x.size(0)

        # Length of every capsule vector.
        lengths = x.pow(2).sum(dim=2, keepdim=True).sqrt()

        # Penalty for a too-short correct capsule / a too-long wrong capsule.
        short_penalty = F.relu(0.9 - lengths).view(n_samples, -1)
        long_penalty = F.relu(lengths - 0.1).view(n_samples, -1)

        per_capsule = labels * short_penalty + 0.5 * (1. - labels) * long_penalty
        margin_loss = per_capsule.sum()

        flat_images = images.view(reconstructions.size()[0], -1)
        recon_term = self.reconstruction_loss(reconstructions, flat_images)

        return (margin_loss + 0.0005 * recon_term) / flat_images.size(0)

In [None]:
# Loss function and optimiser used by the train()/test() cells below.
criterion = CapsuleLoss()
optimizer = optim.Adam(
    capsule_net.parameters(),
    lr=1e-4,
    weight_decay=0,
)

In [None]:
def train(capsule_net, criterion, optimizer, epoch, print_every=300):
    """Train ``capsule_net`` for one pass over the global ``train_loader``.

    Args:
        capsule_net: model returning ``(caps_output, reconstructions, y)``.
        criterion: loss called as ``criterion(caps_output, target, images, reconstructions)``.
        optimizer: optimiser stepping the model's parameters.
        epoch: epoch number, used only in the printed summary.
        print_every: currently unused; kept so existing calls stay valid.

    Returns:
        ``(loss, accuracy)`` - mean batch loss and the fraction of correctly
        classified samples over the epoch.
    """
    # Per-class tallies (index = true class).
    class_correct = [0., 0.]
    class_total = [0., 0.]

    train_loss = 0.0

    capsule_net.train()

    for images, target in train_loader:

        # Integer labels -> one-hot rows.
        target = torch.eye(2).index_select(dim=0, index=target)

        if TRAIN_ON_GPU:
            images, target = images.cuda(), target.cuda()

        optimizer.zero_grad()
        caps_output, reconstructions, y = capsule_net(images)
        loss = criterion(caps_output, target, images, reconstructions)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Both y and target are one-hot, so argmax recovers the class index.
        pred = y.detach().argmax(dim=1).cpu()
        truth = target.argmax(dim=1).cpu()

        # Plain Python lists avoid the old np.squeeze() call, which collapsed
        # a one-sample batch to a 0-d tensor and then failed on indexing.
        for label, hit in zip(truth.tolist(), (pred == truth).tolist()):
            class_correct[label] += hit
            class_total[label] += 1

    loss = train_loss/len(train_loader)
    accuracy = np.sum(class_correct) / np.sum(class_total)

    print(f'\n Train Epoch: {epoch} \tLoss: {loss:.6f}')

    return loss, accuracy

In [None]:
def test(capsule_net, test_loader):
    """Evaluate ``capsule_net`` on ``test_loader`` and append per-sample scores to disk.

    For every test sample, the length of output capsule 1 is appended as one
    line to ``CapsNet_DNN_score.TXT``, in loader order. Uses the global
    ``criterion`` and ``TRAIN_ON_GPU``.

    Args:
        capsule_net: model returning ``(caps_output, reconstructions, y)``.
        test_loader: iterable of ``(images, integer_labels)`` batches.

    Returns:
        ``(loss, accuracy)`` - mean batch loss and the fraction of correctly
        classified samples.
    """
    # Per-class tallies (index = true class).
    class_correct = [0., 0.]
    class_total = [0., 0.]

    test_loss = 0

    capsule_net.eval()

    # ``with`` closes the score file even if evaluation fails; no_grad avoids
    # building autograd graphs that are never used during evaluation.
    with open('/data/github/result/CapsNet_DNN_score.TXT', 'a') as DNN_score, torch.no_grad():
        for images, target in test_loader:
            # Integer labels -> one-hot rows.
            target = torch.eye(2).index_select(dim=0, index=target)

            if TRAIN_ON_GPU:
                images, target = images.cuda(), target.cuda()

            caps_output, reconstructions, y = capsule_net(images)
            loss = criterion(caps_output, target, images, reconstructions)
            test_loss += loss.item()

            # Exactly one score per sample, whatever the batch size. The old
            # two-branch bookkeeping wrote part of the second batch twice,
            # which shifted every later score in the file.
            squared_lengths = (caps_output[:, 1] ** 2).sum(dim=-1)
            for squared in squared_lengths.tolist():
                DNN_score.write(str(math.sqrt(squared)))
                DNN_score.write('\n')

            # Both y and target are one-hot, so argmax recovers the class index.
            pred = y.argmax(dim=1).cpu()
            truth = target.argmax(dim=1).cpu()
            for label, hit in zip(truth.tolist(), (pred == truth).tolist()):
                class_correct[label] += hit
                class_total[label] += 1

    loss = test_loss/len(test_loader)
    accuracy = np.sum(class_correct) / np.sum(class_total)

    print(f'\nTest set: Loss: {loss:.6f}, Accuracy: {round(np.sum(class_correct))}/{round(np.sum(class_total))}({(100. * accuracy):.2f}%)\n')
    print('----------------------------------------------------')

    return loss, accuracy

In [None]:
def Like(s, b, u, n):
    """Poisson probability of observing ``n*s + b`` events when ``u*s + b`` are expected.

    Evaluated as ``exp(k*ln(mu) - lgamma(k + 1) - mu)`` with ``k = n*s + b``
    and ``mu = u*s + b``.
    """
    observed = n * s + b
    expected = u * s + b
    return math.e**(observed * math.log(expected) - math.lgamma(observed + 1) - expected)

def IndiLikeRatioDis(NS1, NB1):
    """Return ``sqrt(-2 ln[ Like(NS1, NB1, 0, 1) / Like(NS1, NB1, 1, 1) ])``.

    The log of the ratio is evaluated analytically instead of dividing two
    exponentiated likelihoods: the ``lgamma`` terms cancel, leaving
    ``-(NS1 + NB1) * ln(1 + NS1/NB1) + NS1``. This gives the same value but
    cannot underflow to ``log(0)`` for large event counts.

    Raises:
        ValueError: if ``NB1`` is not positive (the logarithm is undefined),
            matching the error type of the direct evaluation.
    """
    if NB1 <= 0:
        raise ValueError('math domain error')
    log_ratio = -(NS1 + NB1) * math.log1p(NS1 / NB1) + NS1
    # Clamp tiny negative rounding residue so sqrt never sees a negative number.
    return math.sqrt(max(0.0, -2.0 * log_ratio))

In [None]:
# ---------------------------------------------------------------------------
# Train, evaluate, and scan score cuts for the best significance each epoch.
# ---------------------------------------------------------------------------
SCORE_PATH = '/data/github/result/CapsNet_DNN_score.TXT'   # appended to by test()
ROC_PATH = '/data/github/result/CapsNet_ROC.TXT'           # epoch number is appended

# Start from an empty score file. The handle is closed immediately instead of
# being left open for the rest of the session.
with open(SCORE_PATH, 'w'):
    pass

Train_Loss = []
Train_Accuracy = []
Test_Loss = []
Test_Accuracy = []

epochs = 1
best_sig = 0

# Test samples per process, in file order: hh, tt, tw, tth, ttv, llbj.
# tatabb takes whatever remains up to len(test_data).
TEST_CLASS_COUNTS = (5000, 19000, 800, 100, 100, 240)

# Per-process rates that weight the selection efficiencies.
# NOTE(review): presumably cross-sections after reconstruction, with the extra
# factors being k-factors - units and meaning are not visible here; confirm.
Xreco_Sig = 0.0214964
Xreco_tt = 120.907 * 1.596
Xreco_tw = 4.38354
Xreco_tth = 0.15258 * 1.27
Xreco_ttv = 0.157968 * 1.54
Xreco_llbj = 1.22936
Xreco_tatabb = 0.011392

# Rate -> event count conversion.
# NOTE(review): 3000 looks like an integrated luminosity and the ratio like an
# efficiency correction - confirm.
EVENT_SCALE = 3000
EFFICIENCY_SCALE = (0.8**2/0.7**2)

# Score cuts scanned: n_thresholds evenly spaced values in [Ival, Ival + 0.1).
# (This count used to be stored in a variable called ``nn``, which overwrote
# the ``torch.nn`` import for the rest of the session.)
n_thresholds = 10000
Ival = 0.9
thresholds = [float(Ival + float(0.1*j)/float(n_thresholds)) for j in range(n_thresholds)]

# Plot styling (usetex needs a working LaTeX installation).
plt.rc('text', usetex=True)
plt.rc('font', family='Times New Roman')   # was misspelled 'Time New Roman'

logs = False
axislabels = [r'$DNN $']
Yaxislabels = [r'$DNN $']
Bmax = 1
Bmin = 0
bins = np.linspace(Bmin, Bmax, 25)


def _rate_above(scores, rate, cuts):
    """For each cut, return rate * (number of scores > cut) / (number of scores)."""
    total = float(len(scores))
    return [float(rate) * float(np.count_nonzero(scores > cut)) / total for cut in cuts]


def _event_yield(rate):
    """Convert a rate to an event count, rounded to 3 decimals."""
    return round(float(EVENT_SCALE * rate * EFFICIENCY_SCALE), 3)


for epoch in range(1, epochs + 1):
    start = time.time()
    print('########### Training epoch {} start ###########'.format(epoch))

    train_loss, train_accuracy = train(capsule_net, criterion, optimizer, epoch)
    test_loss, test_accuracy = test(capsule_net, test_loader)

    Train_Loss.append(train_loss)
    Train_Accuracy.append(train_accuracy)
    Test_Loss.append(test_loss)
    Test_Accuracy.append(test_accuracy)

    # The score file accumulates one block of len(test_data) lines per epoch;
    # pick out this epoch's block and split it by process.
    Results = np.loadtxt(SCORE_PATH)

    bounds = [len(test_data) * (epoch - 1)]
    for count in TEST_CLASS_COUNTS:
        bounds.append(bounds[-1] + count)

    Results_hh_1 = Results[bounds[0]:bounds[1]]
    Results_tt_1 = Results[bounds[1]:bounds[2]]
    Results_tw_1 = Results[bounds[2]:bounds[3]]
    Results_tth_1 = Results[bounds[3]:bounds[4]]
    Results_ttv_1 = Results[bounds[4]:bounds[5]]
    Results_llbj_1 = Results[bounds[5]:bounds[6]]
    Results_tatabb_1 = Results[bounds[6]:len(test_data) * epoch]

    # --- Normalised score distribution per process -------------------------
    # color=None means "take the next colour from matplotlib's default cycle".
    histograms = [
        (Results_hh_1, 'black', r'$h \; h$'),
        (Results_tt_1, 'blue', r'$t \; \overline{t}$'),
        (Results_tw_1, 'red', r'$t \; w$'),
        (Results_tth_1, None, r'$t \; \overline{t} \; h$'),
        (Results_ttv_1, None, r'$t \; \overline{t} \; v$'),
        (Results_llbj_1, None, r'$l \; l \; b \; j$'),
        (Results_tatabb_1, None, r'$\tau \; \tau \; b \; b$'),
    ]

    plt.xlim(0, 1)
    for scores, color, label in histograms:
        color_kwargs = {'color': color} if color is not None else {}
        plt.hist(scores, bins=bins, alpha=1, density=True, histtype='step', align='mid',
                 linewidth=1.5, log=logs, label=label, **color_kwargs)
    plt.legend(loc=9, fontsize=10)
    plt.xlabel(axislabels[0], fontsize=20)
    plt.ylabel(r'$\rm{(1/\sigma) \; d \sigma / d }$' + Yaxislabels[0], fontsize=20)
    plt.show()

    # --- Rate surviving each score cut -------------------------------------
    # One vectorised comparison per cut replaces the nested Python loops.
    XSig_box = _rate_above(Results_hh_1, Xreco_Sig, thresholds)
    Xbkg_tt_box = _rate_above(Results_tt_1, Xreco_tt, thresholds)
    Xbkg_tw_box = _rate_above(Results_tw_1, Xreco_tw, thresholds)
    Xbkg_tth_box = _rate_above(Results_tth_1, Xreco_tth, thresholds)
    Xbkg_ttv_box = _rate_above(Results_ttv_1, Xreco_ttv, thresholds)
    Xbkg_llbj_box = _rate_above(Results_llbj_1, Xreco_llbj, thresholds)
    Xbkg_tatabb_box = _rate_above(Results_tatabb_1, Xreco_tatabb, thresholds)

    # Total background, summed in a fixed order so results stay reproducible.
    Xbkg_box = [
        tt_rate + tw_rate + tth_rate + ttv_rate + llbj_rate + tatabb_rate
        for tt_rate, tw_rate, tth_rate, ttv_rate, llbj_rate, tatabb_rate
        in zip(Xbkg_tt_box, Xbkg_tw_box, Xbkg_tth_box, Xbkg_ttv_box, Xbkg_llbj_box, Xbkg_tatabb_box)
    ]

    # --- Event yields and significance per cut -----------------------------
    # Columns: Nsig Nbkg significance Nbkg_tt Nbkg_tw Nbkg_tth Nbkg_ttv Nbkg_llbj Nbkg_tatabb
    roc_rows = []
    with open(ROC_PATH + str(epoch), 'w') as roc_file:
        for j in range(len(XSig_box)):

            # Stop at the first cut that removes all background.
            if float(Xbkg_box[j]) == 0:
                break

            Nsig = _event_yield(XSig_box[j])
            Nbkg = _event_yield(Xbkg_box[j])
            Nbkg_tt = _event_yield(Xbkg_tt_box[j])
            Nbkg_tw = _event_yield(Xbkg_tw_box[j])
            Nbkg_tth = _event_yield(Xbkg_tth_box[j])
            Nbkg_ttv = _event_yield(Xbkg_ttv_box[j])
            Nbkg_llbj = _event_yield(Xbkg_llbj_box[j])
            Nbkg_tatabb = _event_yield(Xbkg_tatabb_box[j])

            SobSqrtB = round(float(IndiLikeRatioDis(float(Nsig), float(Nbkg))), 3)

            row = (Nsig, Nbkg, SobSqrtB, Nbkg_tt, Nbkg_tw, Nbkg_tth, Nbkg_ttv, Nbkg_llbj, Nbkg_tatabb)
            roc_file.write(' '.join(str(value) for value in row))
            roc_file.write('\n')
            roc_rows.append(row)

    # Same numbers as the file just written, without re-reading it (and with a
    # stable 2-D shape even for zero or one rows).
    ROC_Results = np.array(roc_rows, dtype=float).reshape(-1, 9)

    if len(ROC_Results) == 0:
        # No background survives even the loosest cut, so there is nothing to scan.
        print('\nNo score cut with surviving background; significance not evaluated.\n')
    else:
        SB = ROC_Results[:, 2].tolist()
        hh = ROC_Results[:, 0].tolist()
        tt = ROC_Results[:, 3].tolist()
        tw = ROC_Results[:, 4].tolist()
        tth = ROC_Results[:, 5].tolist()
        ttv = ROC_Results[:, 6].tolist()
        llbj = ROC_Results[:, 7].tolist()
        tatabb = ROC_Results[:, 8].tolist()

        plt.plot(hh, SB, color='r', label='Significance')
        plt.xlabel(r'$ N_s $', fontsize=20)
        plt.ylabel(r'Significance', fontsize=20)
        plt.legend(loc='best', fontsize=15)
        plt.show()

        j = SB.index(max(SB))
        print('\nsignificance: {:.3f} hh: {:.3f} tt: {:.3f} tw: {:.3f} tth: {:.3f} ttv: {:.3f} llbj: {:.3f} tatabb: {:.3f} \n'.format(SB[j], hh[j], tt[j], tw[j], tth[j], ttv[j], llbj[j], tatabb[j]))

        sig = max(SB)
        best_sig = max(best_sig, sig)

    end = time.time()

    print('* Best Significance : {:.3f} *'.format(best_sig))

    print('Epoch time: {:.2f} mins'.format((end-start)/60))
    print('='*69)