<img src="http://oproject.org/img/ROOT.png" height="30%" width="30%">
<img src="http://oproject.org/img/tmvalogo.png" height="30%" width="30%">

<hr style="border-top-width: 4px; border-top-color: #34609b;">

In [None]:
from ROOT import TMVA, TFile, TTree, TCut
from subprocess import call
from os.path import isfile

from keras.models import Sequential
from keras.layers import Dense

In [None]:
# Initialise TMVA and the Python-method support before anything is booked.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# Output file handed to the factory below; 'RECREATE' overwrites an existing one.
output = TFile.Open('TMVA.root', 'RECREATE')

# Factory options listed one per entry and joined with ':' into a single
# option string.
factory_options = ':'.join([
    '!V',
    '!Silent',
    'Color',
    '!DrawProgressBar',
    'Transformations=D,G',
    'AnalysisType=Classification',
])
factory = TMVA.Factory('TMVAClassification', output, factory_options)

In [None]:
# Fetch the example dataset once; later runs reuse the local copy.
input_name = 'tmva_class_example.root'
if not isfile(input_name):
    # --fail: return a non-zero status on an HTTP error instead of saving the
    #         server's error page under the dataset name.
    # -L:     follow redirects (the server may redirect plain http elsewhere);
    #         without it the redirect response body is what gets saved.
    status = call(['curl', '--fail', '-L', '-O',
                   'http://root.cern.ch/files/tmva_class_example.root'])
    if status != 0:
        # Remove any partial download so the isfile() check above does not
        # accept a broken file on the next run.
        if isfile(input_name):
            from os import remove
            remove(input_name)
        raise RuntimeError(
            'Download of %s failed (curl exit status %d)' % (input_name, status))

data = TFile.Open(input_name)
# Stop here with a clear message rather than failing later on Get().
if not data or data.IsZombie():
    raise RuntimeError('Could not open %s as a ROOT file' % input_name)

signal = data.Get('TreeS')
background = data.Get('TreeB')

# Every branch of the signal tree is registered as an input variable.
dataloader = TMVA.DataLoader('dataset')
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# Both trees enter with weight 1.0.
dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)
# Empty cut: no event preselection is applied.
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')

In [None]:
# Fully connected classifier: 4 inputs -> 64 ReLU units -> 2-way softmax.
model = Sequential([
    Dense(64, activation='relu', input_dim=4),
    Dense(2, activation='softmax'),
])

# Print the layer/parameter overview.
model.summary()

In [None]:
# Configure training: Adam optimizer, categorical cross-entropy loss,
# accuracy reported as the metric.
model.compile(
    optimizer="adam",
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Write the compiled model to disk; the PyKeras booking refers to this
# file through FilenameModel=model.h5.
model.save("model.h5")

In [None]:
# (method type, method title, option string) for each classifier to book.
booked_methods = (
    (TMVA.Types.kFisher, 'Fisher',
     '!H:!V:Fisher:VarTransform=D,G'),
    (TMVA.Types.kBDT, 'BDT',
     '!H:!V:NTrees=500:MaxDepth=3:BoostType=AdaBoost:nCuts=20'),
    (TMVA.Types.kPyKeras, 'PyKeras',
     'H:!V:VarTransform=N:FilenameModel=model.h5:NumEpochs=10:BatchSize=32'),
)

# Book them on the shared dataloader, in the order listed above.
for method_type, method_title, method_options in booked_methods:
    factory.BookMethod(dataloader, method_type, method_title, method_options)

In [None]:
# Run the factory stages strictly in this order: train, then test, then evaluate.
factory_stages = (
    factory.TrainAllMethods,
    factory.TestAllMethods,
    factory.EvaluateAllMethods,
)
for run_stage in factory_stages:
    run_stage()

In [None]:
# IPython magic (not plain Python); presumably switches ROOT's notebook
# output to interactive JSROOT rendering -- confirm against the ROOT kernel docs.
%jsroot on
# Ask the factory for the ROC curve of the methods booked on `dataloader`
# and draw it. The result is kept in a named variable rather than chained,
# which keeps a Python reference to the drawn object.
canvas = factory.GetROCCurve(dataloader)
canvas.Draw()