In [1]:
from ROOT import TMVA, TFile, TTree, TCut
from subprocess import call
from os.path import isfile

import torch
from torch import nn


Welcome to JupyROOT 6.24/04


In [2]:
# Setup TMVA
# Initialise the TMVA tools singleton and the PyMethodBase Python layer
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# Result file (opened in 'RECREATE' mode) handed to the classification factory
output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory(
    'TMVAClassification',
    output,
    ':'.join([
        '!V',
        '!Silent',
        'Color',
        'DrawProgressBar',
        'Transformations=D,G',
        'AnalysisType=Classification',
    ]),
)


In [3]:
# Load data
# Download the example dataset once.
#   -L  follow redirects: the recorded run of this cell without it fetched only
#       332 bytes, presumably a redirect page rather than the ROOT file
#   -f  fail on HTTP errors instead of saving the server's error page
if not isfile('tmva_class_example.root'):
    if call(['curl', '-f', '-L', '-O', 'https://root.cern.ch/files/tmva_class_example.root']) != 0:
        raise RuntimeError('Download of tmva_class_example.root failed; '
                           'remove any partial file and retry')

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   332  100   332    0     0    555      0 --:--:-- --:--:-- --:--:--   557


In [7]:
# Load data
# NOTE(review): this cell repeats the download guard of an earlier cell; it is a
# no-op once tmva_class_example.root exists on disk.
#   -L  follow redirects so the actual ROOT file is stored, not a redirect page
#   -f  fail on HTTP errors instead of saving the server's error page
if not isfile('tmva_class_example.root'):
    if call(['curl', '-f', '-L', '-O', 'https://root.cern.ch/files/tmva_class_example.root']) != 0:
        raise RuntimeError('Download of tmva_class_example.root failed; '
                           'remove any partial file and retry')


In [9]:
# Open the example file and fetch the signal / background trees
data = TFile.Open('tmva_class_example.root')
# A failed open gives a null or zombie TFile (e.g. when the download stored a
# non-ROOT payload); stop here rather than failing obscurely in later cells.
if not data or data.IsZombie():
    raise IOError("Could not open 'tmva_class_example.root' as a ROOT file")

signal = data.Get('TreeS')
background = data.Get('TreeB')
# Get() hands back a null object for a missing key, which is falsy in PyROOT
if not signal or not background:
    raise KeyError("'TreeS' and/or 'TreeB' not found in tmva_class_example.root")

In [10]:
# Declare one TMVA input variable per branch of the signal tree
dataloader = TMVA.DataLoader('dataset')
for variable_name in (branch.GetName() for branch in signal.GetListOfBranches()):
    dataloader.AddVariable(variable_name)

In [11]:
# Attach both trees (second argument 1.0 for each), then configure the
# train/test split with an empty selection cut
dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    ':'.join([
        'nTrain_Signal=4000',
        'nTrain_Background=4000',
        'SplitMode=Random',
        'NormMode=NumEvents',
        '!V',
    ]),
)

DataSetInfo              : [dataset] : Added class "Signal"
                         : Add Tree TreeS of type Signal with 6000 events
DataSetInfo              : [dataset] : Added class "Background"
                         : Add Tree TreeB of type Background with 6000 events
                         : Dataset[dataset] : Class index : 0  name : Signal
                         : Dataset[dataset] : Class index : 1  name : Background


In [12]:
# Define model
# 4 inputs -> 64 hidden units (ReLU) -> 2 outputs normalised by a softmax.
# Layers are registered in order under the names given here.
model = nn.Sequential()
for layer_name, layer in (
    ('linear_1', nn.Linear(in_features=4, out_features=64)),
    ('relu', nn.ReLU()),
    ('linear_2', nn.Linear(in_features=64, out_features=2)),
    ('softmax', nn.Softmax(dim=1)),
):
    model.add_module(layer_name, layer)

In [13]:
# Construct loss function and Optimizer.
# The optimizer is kept as a class (not an instance): train() instantiates it
# with the model parameters. The loss is a ready-to-call module instance.
optimizer = torch.optim.SGD
loss = nn.MSELoss()

In [14]:
# Define train function
def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` for ``num_epochs`` epochs and return it.

    Args:
        model: module to optimise; switched between train and eval mode each epoch.
        train_loader: iterable yielding ``(X, y)`` training mini-batches.
        val_loader: sized iterable (``len()`` is used) yielding ``(X, y)``
            validation mini-batches.
        num_epochs: number of passes over ``train_loader``.
        batch_size: not used by this function; kept for the caller's signature.
        optimizer: optimizer class/factory, instantiated here as
            ``optimizer(model.parameters(), lr=0.01)``.
        criterion: callable ``criterion(output, y)`` returning a scalar loss tensor.
        save_best: optional callable ``save_best(model, curr_val, best_val)``
            returning the new best validation loss; pass a falsy value to disable.
        scheduler: pair ``(schedule, schedulerSteps)``. When ``schedule`` is
            truthy it is called once per epoch as
            ``schedule(trainer, epoch, schedulerSteps)`` where ``trainer`` is the
            instantiated optimizer.

    Returns:
        The same ``model`` object after training.
    """
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None

    for epoch in range(num_epochs):
        # Training Loop
        # Set to train mode
        model.train()
        running_train_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()
            output = model(X)
            train_loss = criterion(output, y)
            train_loss.backward()
            trainer.step()

            # print train statistics
            running_train_loss += train_loss.item()
            if i % 32 == 31:    # print every 32 mini-batches
                print("[{}, {}] train loss: {:.3f}".format(epoch+1, i+1, running_train_loss / 32))
                running_train_loss = 0.0

        if schedule:
            # Hand over the optimizer *instance*: the class passed in as
            # `optimizer` has no param_groups for a schedule to adjust.
            schedule(trainer, epoch, schedulerSteps)

        # Validation Loop
        # Set to eval mode
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                running_val_loss += criterion(model(X), y).item()

            curr_val = running_val_loss / len(val_loader)
            if save_best:
                # First epoch: seed the best value with the current one
                if best_val is None:
                    best_val = curr_val
                best_val = save_best(model, curr_val, best_val)

            # print val statistics per epoch
            print("[{}] val loss: {:.3f}".format(epoch+1, curr_val))

    # Report num_epochs rather than the loop variable, which is undefined
    # when num_epochs == 0.
    print("Finished Training on {} Epochs!".format(num_epochs))

    return model

In [15]:
# Define predict function
def predict(model, test_X, batch_size=32):
    """Run ``model`` over ``test_X`` in mini-batches and return a NumPy array.

    Args:
        model: module to evaluate; it is put into eval mode first.
        test_X: input data accepted by ``torch.Tensor(...)``.
        batch_size: number of rows fed to the model per forward pass.

    Returns:
        The per-batch model outputs concatenated in input order, as ``numpy``.
    """
    model.eval()

    features = torch.Tensor(test_X)
    batches = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features),
        batch_size=batch_size,
        shuffle=False,  # keep outputs aligned with the input order
    )

    # Gradients are not needed for inference
    with torch.no_grad():
        stacked = torch.cat([model(batch[0]) for batch in batches])

    return stacked.numpy()


In [16]:
# Bundle of user-supplied training pieces. NOTE(review): presumably picked up by
# TMVA's PyTorch method under this exact variable name -- the booking output
# echoes these objects as "custom objects for loading model"; confirm.
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=loss,
    train_func=train,
    predict_func=predict,
)


In [17]:
# Store model to file
# Convert the model to torchscript before saving
# The scripted module is written to "model.pt", the same file name that the
# PyTorch method is booked with (FilenameModel=model.pt) further down.
m = torch.jit.script(model)
torch.jit.save(m, "model.pt")
# Show the scripted module structure as a quick sanity check
print(m)

RecursiveScriptModule(
  original_name=Sequential
  (linear_1): RecursiveScriptModule(original_name=Linear)
  (relu): RecursiveScriptModule(original_name=ReLU)
  (linear_2): RecursiveScriptModule(original_name=Linear)
  (softmax): RecursiveScriptModule(original_name=Softmax)
)


In [18]:

# Book methods
# Fisher method, booked next to the neural network
factory.BookMethod(
    dataloader, TMVA.Types.kFisher, 'Fisher',
    ':'.join(['!H', '!V', 'Fisher', 'VarTransform=D,G']),
)
# PyTorch method, pointed at the TorchScript file model.pt
factory.BookMethod(
    dataloader, TMVA.Types.kPyTorch, 'PyTorch',
    ':'.join([
        'H',
        '!V',
        'VarTransform=D,G',
        'FilenameModel=model.pt',
        'NumEpochs=20',
        'BatchSize=32',
    ]),
)



custom objects for loading model :  {'optimizer': <class 'torch.optim.sgd.SGD'>, 'criterion': MSELoss(), 'train_func': <function train at 0x7f7cb042bd30>, 'predict_func': <function predict at 0x7f7bab815dc0>}


<cppyy.gbl.TMVA.MethodPyTorch object at 0x5641f24f7620>

Factory                  : Booking method: [1mFisher[0m
                         : 
Fisher                   : [dataset] : Create Transformation "D" with events from all classes.
                         : 
                         : Transformation, Variable selection : 
                         : Input : variable 'var1' <---> Output : variable 'var1'
                         : Input : variable 'var2' <---> Output : variable 'var2'
                         : Input : variable 'var3' <---> Output : variable 'var3'
                         : Input : variable 'var4' <---> Output : variable 'var4'
Fisher                   : [dataset] : Create Transformation "G" with events from all classes.
                         : 
                         : Transformation, Variable selection : 
                         : Input : variable 'var1' <---> Output : variable 'var1'
                         : Input : variable 'var2' <---> Output : variable 'var2'
                         : Input : variable 'v

In [19]:
# Run training, test and evaluation
# The three factory stages are order-dependent and must run in this sequence.
# The "[epoch, batch] train loss" / "[epoch] val loss" lines in the output are
# printed by the custom train() function defined above.
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

RecursiveScriptModule(
  original_name=Sequential
  (linear_1): RecursiveScriptModule(original_name=Linear)
  (relu): RecursiveScriptModule(original_name=ReLU)
  (linear_2): RecursiveScriptModule(original_name=Linear)
  (softmax): RecursiveScriptModule(original_name=Softmax)
)
[1, 32] train loss: 0.219
[1, 64] train loss: 0.200
[1, 96] train loss: 0.190
[1, 128] train loss: 0.178
[1, 160] train loss: 0.173
[1, 192] train loss: 0.166
[1] val loss: 0.165
[2, 32] train loss: 0.166
[2, 64] train loss: 0.149
[2, 96] train loss: 0.153
[2, 128] train loss: 0.146
[2, 160] train loss: 0.148
[2, 192] train loss: 0.144
[2] val loss: 0.148
[3, 32] train loss: 0.151
[3, 64] train loss: 0.133
[3, 96] train loss: 0.140
[3, 128] train loss: 0.134
[3, 160] train loss: 0.138
[3, 192] train loss: 0.135
[3] val loss: 0.139
[4, 32] train loss: 0.144
[4, 64] train loss: 0.124
[4, 96] train loss: 0.133
[4, 128] train loss: 0.128
[4, 160] train loss: 0.132
[4, 192] train loss: 0.130
[4] val loss: 0.135
[5, 32

0%, time left: unknown
7%, time left: 0 sec
13%, time left: 0 sec
19%, time left: 0 sec
25%, time left: 0 sec
32%, time left: 0 sec
38%, time left: 0 sec
44%, time left: 0 sec
50%, time left: 0 sec
57%, time left: 0 sec
63%, time left: 0 sec
69%, time left: 0 sec
75%, time left: 0 sec
82%, time left: 0 sec
88%, time left: 0 sec
94%, time left: 0 sec
0%, time left: unknown
7%, time left: 0 sec
13%, time left: 0 sec
19%, time left: 0 sec
25%, time left: 0 sec
32%, time left: 0 sec
38%, time left: 0 sec
44%, time left: 0 sec
50%, time left: 0 sec
57%, time left: 0 sec
63%, time left: 0 sec
69%, time left: 0 sec
75%, time left: 0 sec
82%, time left: 0 sec
88%, time left: 0 sec
94%, time left: 0 sec


In [20]:
# Plot ROC Curves
roc = factory.GetROCCurve(dataloader)
roc.SaveAs('ROC_ClassificationPyTorch.png')

# Close the TMVA result file now that every result has been produced; in a
# long-lived kernel it would otherwise stay open (and possibly unflushed).
output.Close()

Info in <TCanvas::Print>: png file ROC_ClassificationPyTorch.png has been created
