In [1]:
import csv, struct, torch, h5py
import numpy as np
from torch.utils import data
from chofer_torchex.nn import SLayer

In [2]:
def readPersistenceDiagram(file, dim):
    
    birth = np.empty(0, dtype = float)
    death = np.empty(0, dtype = float)
    
    file.read(24)
    
    for s in iter(lambda: file.read(24), b''):
        
        d = int(struct.unpack('<q', s[:8])[0])
        
        if d == dim:
            
            birth = np.append(birth, struct.unpack('<d', s[8:16])[0])
            death = np.append(death, struct.unpack('<d', s[-8:])[0])
    
    return [birth, death]

In [3]:
class PHdata(data.Dataset):
    
    def __init__(self, totnum, validate = False, nu = 1.0):
        
        self.nu = nu
        self.dset = h5py.File('./PersistenceDiagrams/jetPDs_Mass60-95_pT250-300_R1.25_Pix25.hdf5','r')
        self.totnum = totnum
        
    def __len__(self):
        return self.totnum
    
    def __getitem__(self, idx):
        
        #if validate: idx += self.totnum
            
        label = int(self.dset['signal'][idx])
        pd_birth = self.dset['pd_birth'][idx]
        pd_death = self.dset['pd_death'][idx]
        
        x = (pd_birth+pd_death)/np.sqrt(2)
        y = (-pd_birth+pd_death)/np.sqrt(2)
        
        #for i, bo in enumerate(y < self.nu):    
        #    if bo: 
        #        y[i] = np.log(y[i]/self.nu) + self.nu
        
        mset = torch.FloatTensor(list(zip(x, y)))
        
        return [mset, label]

In [4]:
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        sharpness = torch.ones(100, 2)*0.01
        self.input = SLayer(100, 2, sharpness_init = sharpness)
        self.hidden1 = torch.nn.Linear(100, 70)
        self.hidden2 = torch.nn.Linear(70, 40)
        self.hidden3 = torch.nn.Linear(40, 10)
        self.hidden4 = torch.nn.Linear(10, 5)
        self.output = torch.nn.Linear(5, 2)
        self.relu = torch.nn.ReLU()
        
    def forward(self, mset):
        x = self.input(mset)
        
#        for i, xx in enumerate(x):
#            x[i] = xx / len(mset[i])
        
        x = self.relu(x)
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        x = self.relu(self.hidden3(x))
        x = self.relu(self.hidden4(x))
        x = self.output(x)
        return x

In [5]:
def mycollate_fn(batch): 
    
    batch_input = []
    batch_label = []
    
    for x in batch:
        
        batch_input.append(x[0])
        batch_label.append(x[1])
    
    return (batch_input, torch.LongTensor(batch_label))

In [6]:
def validate(model, totnum):
    
    model.load_state_dict(torch.load('./TrainedModels/trained_params.dat'))
    
    dataset = PHdata(totnum, validate = True)
    loader = data.DataLoader(dataset, 100, shuffle = False, collate_fn = mycollate_fn)
    
    csv_file = open('./Statistics/t-statistic.csv', mode = 'w')
    csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    
    for i, (batch_input, labels) in enumerate(loader):
        outputs = model(batch_input)   
        outputs = outputs.cpu()
        outputs = outputs.detach().numpy()
    
        for j in range(len(outputs)):
            val = np.exp(outputs[j])
            csv_writer.writerow([val[1]/(val[0]+val[1])])
    
    csv_file.close()
    
    return

In [7]:
def train(model, batchsize, num_epoch, totnum = 2000, lr = 0.005, momentum = 0.9, fn = mycollate_fn):
    
    dataset = PHdata(totnum)
    loader = data.DataLoader(dataset, batchsize, shuffle = True, collate_fn = fn)
    optimizer = torch.optim.SGD(model.parameters(), lr, momentum)
    criterion = torch.nn.CrossEntropyLoss()
    
    for epoch in range(num_epoch):
    
        for i, (batch_input, labels) in enumerate(loader):
        
            optimizer.zero_grad()

            outputs = model(batch_input)
        
            loss = criterion(outputs, labels)
            loss.backward()
        
            optimizer.step()
    
            print ('Epoch [%d/%d], Loss: %.4f' %(epoch+1, num_epoch, loss.item() ))
            
    torch.save(model.state_dict(), './TrainedModels/trained_params.dat')
    
    return

In [8]:
model = MyModel()

In [9]:
train(model, batchsize = 500, num_epoch = 200)

Epoch [1/200], Loss: 0.7329
Epoch [1/200], Loss: 0.7327
Epoch [1/200], Loss: 0.7322
Epoch [1/200], Loss: 0.7208
Epoch [2/200], Loss: 0.7288
Epoch [2/200], Loss: 0.6991
Epoch [2/200], Loss: 0.7164
Epoch [2/200], Loss: 0.7180
Epoch [3/200], Loss: 0.6992
Epoch [3/200], Loss: 0.7041
Epoch [3/200], Loss: 0.7192
Epoch [3/200], Loss: 0.7108
Epoch [4/200], Loss: 0.6969
Epoch [4/200], Loss: 0.7061
Epoch [4/200], Loss: 0.7151
Epoch [4/200], Loss: 0.7026
Epoch [5/200], Loss: 0.6991
Epoch [5/200], Loss: 0.6954
Epoch [5/200], Loss: 0.7069
Epoch [5/200], Loss: 0.7107
Epoch [6/200], Loss: 0.6949
Epoch [6/200], Loss: 0.7005
Epoch [6/200], Loss: 0.7031
Epoch [6/200], Loss: 0.7060
Epoch [7/200], Loss: 0.6914
Epoch [7/200], Loss: 0.7007
Epoch [7/200], Loss: 0.7004
Epoch [7/200], Loss: 0.7053
Epoch [8/200], Loss: 0.7064
Epoch [8/200], Loss: 0.6954
Epoch [8/200], Loss: 0.6971
Epoch [8/200], Loss: 0.6941
Epoch [9/200], Loss: 0.6943
Epoch [9/200], Loss: 0.7032
Epoch [9/200], Loss: 0.6996
Epoch [9/200], Loss:

Epoch [72/200], Loss: 0.6471
Epoch [72/200], Loss: 0.6540
Epoch [72/200], Loss: 0.6498
Epoch [72/200], Loss: 0.6525
Epoch [73/200], Loss: 0.6574
Epoch [73/200], Loss: 0.6470
Epoch [73/200], Loss: 0.6445
Epoch [73/200], Loss: 0.6483
Epoch [74/200], Loss: 0.6517
Epoch [74/200], Loss: 0.6508
Epoch [74/200], Loss: 0.6462
Epoch [74/200], Loss: 0.6411
Epoch [75/200], Loss: 0.6480
Epoch [75/200], Loss: 0.6471
Epoch [75/200], Loss: 0.6439
Epoch [75/200], Loss: 0.6451
Epoch [76/200], Loss: 0.6359
Epoch [76/200], Loss: 0.6516
Epoch [76/200], Loss: 0.6468
Epoch [76/200], Loss: 0.6428
Epoch [77/200], Loss: 0.6443
Epoch [77/200], Loss: 0.6375
Epoch [77/200], Loss: 0.6389
Epoch [77/200], Loss: 0.6517
Epoch [78/200], Loss: 0.6453
Epoch [78/200], Loss: 0.6389
Epoch [78/200], Loss: 0.6369
Epoch [78/200], Loss: 0.6439
Epoch [79/200], Loss: 0.6383
Epoch [79/200], Loss: 0.6481
Epoch [79/200], Loss: 0.6319
Epoch [79/200], Loss: 0.6408
Epoch [80/200], Loss: 0.6382
Epoch [80/200], Loss: 0.6451
Epoch [80/200]

Epoch [141/200], Loss: 0.5638
Epoch [141/200], Loss: 0.5862
Epoch [141/200], Loss: 0.5581
Epoch [142/200], Loss: 0.5595
Epoch [142/200], Loss: 0.5584
Epoch [142/200], Loss: 0.5893
Epoch [142/200], Loss: 0.5672
Epoch [143/200], Loss: 0.5881
Epoch [143/200], Loss: 0.5822
Epoch [143/200], Loss: 0.5231
Epoch [143/200], Loss: 0.5503
Epoch [144/200], Loss: 0.5503
Epoch [144/200], Loss: 0.5591
Epoch [144/200], Loss: 0.5631
Epoch [144/200], Loss: 0.5673
Epoch [145/200], Loss: 0.5662
Epoch [145/200], Loss: 0.5787
Epoch [145/200], Loss: 0.5270
Epoch [145/200], Loss: 0.5617
Epoch [146/200], Loss: 0.5490
Epoch [146/200], Loss: 0.6080
Epoch [146/200], Loss: 0.5534
Epoch [146/200], Loss: 0.5833
Epoch [147/200], Loss: 0.5510
Epoch [147/200], Loss: 0.5711
Epoch [147/200], Loss: 0.5859
Epoch [147/200], Loss: 0.5574
Epoch [148/200], Loss: 0.5728
Epoch [148/200], Loss: 0.5681
Epoch [148/200], Loss: 0.5558
Epoch [148/200], Loss: 0.5461
Epoch [149/200], Loss: 0.5624
Epoch [149/200], Loss: 0.5577
Epoch [149

In [10]:
validate(model, totnum = 4000)