<img src="http://oproject.org/img/ROOT.png" height="30%" width="30%">
<img src="http://oproject.org/img/tmvalogo.png" height="30%" width="30%">

<hr style="border-top-width: 4px; border-top-color: #34609b;">

# DNN Example

## Declare Factory

In [None]:
TFile* inputFile = TFile::Open("inputdata.root");
TFile* outputFile = TFile::Open("TMVAOutputDNN.root", "RECREATE");

TMVA::Factory factory("TMVAClassification", outputFile,
                      "!V:ROC:!Correlations:!Silent:Color:!DrawProgressBar:AnalysisType=Classification" ); 

## Declare DataLoader

In [None]:
TMVA::DataLoader loader("dataset_dnn");

loader.AddVariable("var1");
loader.AddVariable("var2");
loader.AddVariable("var3");
loader.AddVariable("var4");
loader.AddVariable("var5 := var1-var3");
loader.AddVariable("var6 := var1+var2");

## Setup Dataset(s)

In [None]:
TTree *tsignal = (TTree*) inputFile->Get("Sig");
TTree *tbackground = (TTree*) inputFile->Get("Bkg");

loader.AddSignalTree(tsignal);
loader.AddBackgroundTree(tbackground);
loader.PrepareTrainingAndTestTree("",
        "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V"); 

## Configure Network Layout

In [None]:
// Layout option: three TANH entries of width 128 followed by a LINEAR entry.
TString layoutString = "Layout=TANH|128,TANH|128,TANH|128,LINEAR";

// Three training phases. They differ in learning rate (1e-1, 1e-2, 1e-3),
// momentum (0.9, 0.9, 0.0) and drop configuration (non-zero only in phase 0);
// all other settings are the same.
TString training0 = "LearningRate=1e-1,Momentum=0.9,Repetitions=1,"
                    "ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,"
                    "WeightDecay=1e-4,Regularization=L2,"
                    "DropConfig=0.0+0.5+0.5+0.5, Multithreading=True";
TString training1 = "LearningRate=1e-2,Momentum=0.9,Repetitions=1,"
                    "ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,"
                    "WeightDecay=1e-4,Regularization=L2,"
                    "DropConfig=0.0+0.0+0.0+0.0, Multithreading=True";
TString training2 = "LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
                    "ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,"
                    "WeightDecay=1e-4,Regularization=L2,"
                    "DropConfig=0.0+0.0+0.0+0.0, Multithreading=True";

// The phases are joined with '|' into one TrainingStrategy option.
TString trainingStrategyString =
    "TrainingStrategy=" + training0 + "|" + training1 + "|" + training2;

// Full DNN option string: general options, then layout, then training strategy,
// each separated by ':'.
TString dnnOptions = "!H:!V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
                     "WeightInitialization=XAVIERUNIFORM";
dnnOptions += ":" + layoutString;
dnnOptions += ":" + trainingStrategyString;

## Book Methods

In [None]:
// Standard implementation, no dependencies.
TString stdOptions = dnnOptions + ":Architecture=CPU";
factory.BookMethod(&loader, TMVA::Types::kDNN, "DNN", stdOptions);

// CPU implementation, using BLAS
//TString cpuOptions = dnnOptions + ":Architecture=CPU";
//factory->BookMethod(dataloader, TMVA::Types::kDNN, "DNN CPU", cpuOptions);

// Multi-Layer Perceptron (Neural Network)
factory.BookMethod(&loader, TMVA::Types::kMLP, "MLP",
        "!H:!V:NeuronType=tanh:VarTransform=N:NCycles=100:HiddenLayers=N+5:TestRate=5:!UseRegulator");

## Train Methods

In [None]:
// Train every method booked on the factory (the DNN and the MLP booked above).
factory.TrainAllMethods();

## Test and Evaluate Methods

In [None]:
// Run the test pass first, then the evaluation pass, over all booked methods.
// NOTE(review): evaluation presumably relies on the test results, so keep this order.
factory.TestAllMethods();
factory.EvaluateAllMethods();    

## Plot ROC Curve
We enable JavaScript visualisation for the plots with the `%jsroot on` magic, then draw the ROC curve.

In [None]:
%jsroot on

In [None]:
auto c = factory.GetROCCurve(&loader);
c->Draw();